diff --git a/scripts/run_dpo.py b/scripts/run_dpo.py
index 11f9a2f1..06189674 100644
--- a/scripts/run_dpo.py
+++ b/scripts/run_dpo.py
@@ -105,14 +105,16 @@ def main():
     torch_dtype = (
         model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
     )
+    quantization_config = get_quantization_config(model_args)
+
     model_kwargs = dict(
         revision=model_args.model_revision,
         trust_remote_code=model_args.trust_remote_code,
         use_flash_attention_2=model_args.use_flash_attention_2,
         torch_dtype=torch_dtype,
         use_cache=False if training_args.gradient_checkpointing else True,
-        device_map=get_kbit_device_map(),
-        quantization_config=get_quantization_config(model_args),
+        device_map=get_kbit_device_map() if quantization_config is not None else None,
+        quantization_config=quantization_config,
     )
 
     model = model_args.model_name_or_path
diff --git a/scripts/run_sft.py b/scripts/run_sft.py
index 1ed8e335..748dd71b 100644
--- a/scripts/run_sft.py
+++ b/scripts/run_sft.py
@@ -109,6 +109,7 @@ def main():
     torch_dtype = (
         model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
     )
+    quantization_config = get_quantization_config(model_args)
 
     model_kwargs = dict(
         revision=model_args.model_revision,
@@ -116,8 +117,8 @@ def main():
         use_flash_attention_2=model_args.use_flash_attention_2,
         torch_dtype=torch_dtype,
         use_cache=False if training_args.gradient_checkpointing else True,
-        device_map=get_kbit_device_map(),
-        quantization_config=get_quantization_config(model_args),
+        device_map=get_kbit_device_map() if quantization_config is not None else None,
+        quantization_config=quantization_config,
     )
 
     logger.info("*** Model loaded! ***")
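
For reviewers, here is a self-contained sketch of the effect of this change at model-load time. The helper bodies below are illustrative stand-ins, not the repo's actual implementations of `get_quantization_config` and `get_kbit_device_map`, and `ModelArguments` is a minimal stand-in for the scripts' `model_args`; the point is that `device_map` is only populated when a quantization config is present, so full-precision runs pass `None` and leave device placement to the trainer.

```python
from dataclasses import dataclass
from typing import Optional

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig


@dataclass
class ModelArguments:
    # Minimal stand-in for the scripts' model_args.
    model_name_or_path: str = "gpt2"
    load_in_4bit: bool = False


def get_quantization_config(model_args: ModelArguments) -> Optional[BitsAndBytesConfig]:
    # Illustrative stand-in: return a 4-bit config only when quantization
    # was requested, otherwise None.
    if model_args.load_in_4bit:
        return BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
    return None


def get_kbit_device_map() -> Optional[dict]:
    # Illustrative stand-in: pin all modules to the current GPU so that
    # k-bit (bitsandbytes) layers land on a device that supports them.
    return {"": torch.cuda.current_device()} if torch.cuda.is_available() else None


model_args = ModelArguments()
quantization_config = get_quantization_config(model_args)

model_kwargs = dict(
    # The guard from the diff: only request a k-bit device map when a
    # quantization config exists; full-precision loads get device_map=None.
    device_map=get_kbit_device_map() if quantization_config is not None else None,
    quantization_config=quantization_config,
)
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs)
```

Computing `quantization_config` once up front also means the same value feeds both the `device_map` guard and the `quantization_config` kwarg, rather than calling the helper twice as the old code did.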