Add check that parameters are not intended to be offloaded (#51)
* Add check that parameters are not intended to be offloaded

* Only push model to device if quantization config is set.
nathan-az authored Dec 4, 2023
1 parent 15279e7 commit 3f368a0
Showing 2 changed files with 7 additions and 4 deletions.
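
Both diffs apply the same pattern: call get_quantization_config once, then pass a device_map only when that call returns a config. Below is a minimal sketch of the intent; the simplified helper bodies and the load_in_4bit/load_in_8bit flags are illustrative assumptions for this sketch, not this repository's code.

from dataclasses import dataclass
from typing import Optional

import torch
from transformers import BitsAndBytesConfig


@dataclass
class ModelArguments:
    # Hypothetical stand-in for the repository's model arguments.
    load_in_4bit: bool = False
    load_in_8bit: bool = False


def get_quantization_config(model_args: ModelArguments) -> Optional[BitsAndBytesConfig]:
    # Assumed behavior: return a bitsandbytes config only when k-bit
    # loading was requested; None signals "no quantization".
    if model_args.load_in_4bit:
        return BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
    if model_args.load_in_8bit:
        return BitsAndBytesConfig(load_in_8bit=True)
    return None


def get_kbit_device_map() -> Optional[dict]:
    # Assumed behavior: pin the whole model to the current GPU so
    # quantized parameters are never offloaded to CPU or disk.
    return {"": torch.cuda.current_device()} if torch.cuda.is_available() else None


model_args = ModelArguments(load_in_4bit=True)
quantization_config = get_quantization_config(model_args)

# The fix: only set device_map when quantizing. In full precision the
# Trainer handles device placement itself, and an unconditional
# device_map could leave parameters flagged for offloading.
model_kwargs = dict(
    device_map=get_kbit_device_map() if quantization_config is not None else None,
    quantization_config=quantization_config,
)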
scripts/run_dpo.py (4 additions & 2 deletions)

@@ -105,14 +105,16 @@ def main():
     torch_dtype = (
         model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
     )
+    quantization_config = get_quantization_config(model_args)
+
     model_kwargs = dict(
         revision=model_args.model_revision,
         trust_remote_code=model_args.trust_remote_code,
         use_flash_attention_2=model_args.use_flash_attention_2,
         torch_dtype=torch_dtype,
         use_cache=False if training_args.gradient_checkpointing else True,
-        device_map=get_kbit_device_map(),
-        quantization_config=get_quantization_config(model_args),
+        device_map=get_kbit_device_map() if quantization_config is not None else None,
+        quantization_config=quantization_config,
     )

     model = model_args.model_name_or_path
scripts/run_sft.py (3 additions & 2 deletions)

@@ -109,15 +109,16 @@ def main():
     torch_dtype = (
         model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
     )
+    quantization_config = get_quantization_config(model_args)

     model_kwargs = dict(
         revision=model_args.model_revision,
         trust_remote_code=model_args.trust_remote_code,
         use_flash_attention_2=model_args.use_flash_attention_2,
         torch_dtype=torch_dtype,
         use_cache=False if training_args.gradient_checkpointing else True,
-        device_map=get_kbit_device_map(),
-        quantization_config=get_quantization_config(model_args),
+        device_map=get_kbit_device_map() if quantization_config is not None else None,
+        quantization_config=quantization_config,
     )
     logger.info("*** Model loaded! ***")
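
Both scripts pass the checkpoint as a string (model = model_args.model_name_or_path above), so model_kwargs takes effect when the trainer eventually loads the model. The direct from_pretrained call below only illustrates that effect, under the same assumptions as the sketch above; it is not the scripts' actual code path, and "org/model" is a placeholder.

from transformers import AutoModelForCausalLM

# Illustration only: the scripts hand model_kwargs to their trainer rather
# than calling from_pretrained themselves. "org/model" is a placeholder.
model = AutoModelForCausalLM.from_pretrained("org/model", **model_kwargs)

# With quantization requested, device_map pins the k-bit weights to a GPU;
# without it, device_map is None, so nothing is marked for offloading and
# the Trainer moves the model to the training device as usual.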
