Skip to content

Commit

Permalink
🌯 Fix context manager runtime error when gather is disabled (#2639)
Browse files Browse the repository at this point in the history
  • Loading branch information
Superskyyy authored Jan 23, 2025
1 parent 0e216f7 commit f34b70a
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions trl/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,10 @@ def unwrap_model_for_generation(
     if accelerator.state.deepspeed_plugin is not None and accelerator.state.deepspeed_plugin.zero_stage == 3:
         if not gather_deepspeed3_params:
             yield accelerator.unwrap_model(model)
-        with deepspeed.zero.GatheredParameters(model.parameters()):
-            remove_hooks(model)
-            yield accelerator.unwrap_model(model)
-            add_hooks(model)
+        else:
+            with deepspeed.zero.GatheredParameters(model.parameters()):
+                remove_hooks(model)
+                yield accelerator.unwrap_model(model)
+                add_hooks(model)
     else:
         yield unwrapped_model

0 comments on commit f34b70a

Please sign in to comment.