From 3f368a0748f926855486a83e0286d0b1f90b250b Mon Sep 17 00:00:00 2001 From: Nathan Azrak <42650258+nathan-az@users.noreply.github.com> Date: Mon, 4 Dec 2023 19:10:41 +1100 Subject: [PATCH] Add check that parameters are not intended to be offloaded (#51) * Add check that parameters are not intended to be offloaded * Only push model to device if quantization config is set. --- scripts/run_dpo.py | 6 ++++-- scripts/run_sft.py | 5 +++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/scripts/run_dpo.py b/scripts/run_dpo.py index 11f9a2f1..06189674 100644 --- a/scripts/run_dpo.py +++ b/scripts/run_dpo.py @@ -105,14 +105,16 @@ def main(): torch_dtype = ( model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) ) + quantization_config = get_quantization_config(model_args) + model_kwargs = dict( revision=model_args.model_revision, trust_remote_code=model_args.trust_remote_code, use_flash_attention_2=model_args.use_flash_attention_2, torch_dtype=torch_dtype, use_cache=False if training_args.gradient_checkpointing else True, - device_map=get_kbit_device_map(), - quantization_config=get_quantization_config(model_args), + device_map=get_kbit_device_map() if quantization_config is not None else None, + quantization_config=quantization_config, ) model = model_args.model_name_or_path diff --git a/scripts/run_sft.py b/scripts/run_sft.py index 1ed8e335..748dd71b 100644 --- a/scripts/run_sft.py +++ b/scripts/run_sft.py @@ -109,6 +109,7 @@ def main(): torch_dtype = ( model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) ) + quantization_config = get_quantization_config(model_args) model_kwargs = dict( revision=model_args.model_revision, @@ -116,8 +117,8 @@ def main(): use_flash_attention_2=model_args.use_flash_attention_2, torch_dtype=torch_dtype, use_cache=False if training_args.gradient_checkpointing else True, - device_map=get_kbit_device_map(), - quantization_config=get_quantization_config(model_args), + device_map=get_kbit_device_map() if quantization_config is not None else None, + quantization_config=quantization_config, ) logger.info("*** Model loaded! ***")