🧶 [GRPO][vLLM + LoRA] Move unmerge of PEFT model after weight loading #2873
Conversation
This could be a stupid question, but is it guaranteed that in a multi-GPU/multi-node setup no other process/worker will be touching the model weights during the window in which the adapters are merged? Come to think of it, is it a good idea to use merge/unmerge of the adapter at all, given all the possible configurations with multiple GPUs, multiple nodes, etc.?
Thanks for noticing that! Correct me if I'm wrong. From my understanding, the sync between processes happens at
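For illustration only (this is not the TRL implementation, and the helper below is hypothetical): one way to make the merged-weights window explicit is to bracket it with Accelerate barriers, so that no rank touches the shared weights while another rank still has the adapter merged.

```python
from accelerate import Accelerator

def sync_merged_weights(unwrapped_model, accelerator: Accelerator, llm_model):
    """Hypothetical helper: keep every rank inside the merge/unmerge window together."""
    accelerator.wait_for_everyone()        # all ranks enter the window at the same time
    unwrapped_model.merge_adapter()        # PEFT: fold the LoRA deltas into the base weights
    state_dict = unwrapped_model.state_dict()
    if accelerator.is_main_process:
        llm_model.load_weights(state_dict.items())  # only rank 0 pushes weights to vLLM
    unwrapped_model.unmerge_adapter()      # restore the base weights after the copy
    accelerator.wait_for_everyone()        # no rank leaves until all ranks have unmerged
```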
Thanks for spotting it!

```diff
 def _move_model_to_vllm(self):
     with unwrap_model_for_generation(
         self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
     ) as unwrapped_model:
         if is_compiled_module(unwrapped_model):
             unwrapped_model = unwrapped_model._orig_mod
         if is_peft_model(unwrapped_model):
             unwrapped_model.merge_adapter()
             state_dict = unwrapped_model.state_dict()
-            unwrapped_model.unmerge_adapter()
             # Remove base_model and base_layer prefixes
             state_dict = {
                 k.removeprefix("base_model.model.").replace(".base_layer", ""): v for k, v in state_dict.items()
             }
             # Remove values with adapter prefix (example: "_lora")
             state_dict = {k: v for k, v in state_dict.items() if unwrapped_model.prefix not in k}
             # When module to save, remove its prefix and discard the original module
             state_dict = {
                 k.replace("modules_to_save.default.", ""): v
                 for k, v in state_dict.items()
                 if "original_module" not in k
             }
         else:
             state_dict = unwrapped_model.state_dict()
+        if self.accelerator.is_main_process:
+            llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
+            llm_model.load_weights(state_dict.items())
+        if is_peft_model(unwrapped_model):
+            unwrapped_model.unmerge_adapter()
-        if self.accelerator.is_main_process:
-            llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
-            llm_model.load_weights(state_dict.items())
```
You are right. I did not put it right after weight loading because I was not sure how vLLM loads weights (i.e., whether vLLM creates its own copy of the weights or not). But I just looked at vLLM's codebase and it seems that the weight loader indeed copies the weights, as shown here. I'll improve the PR as you suggested.
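A minimal sketch of why a copying loader makes this safe (plain PyTorch, not vLLM's actual loader): once the weights have been copied into separate storage, an in-place unmerge of the source tensors no longer affects the copy.

```python
import torch

merged = torch.randn(4, 4)              # stands in for a merged (base + LoRA) weight
loaded = torch.empty_like(merged)
loaded.copy_(merged)                    # a copying loader keeps its own storage, as vLLM appears to do
merged.zero_()                          # stands in for unmerge_adapter() mutating the weight in place
print(torch.allclose(loaded, merged))   # False: the loaded copy is unaffected by the unmerge
```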
Changed.
@AndreiCComan thanks for your work in #2856! Can you check whether the patch still gives better results?
Thanks! I'm running some experiments to validate it.
Thanks @qgallouedec for the help and @XZ-X for the speedy update! @qgallouedec, here are the plots: [plot images]
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
What does this PR do?
The `state_dict` obtained from the PEFT model seems to be a shallow copy of the model weights. Therefore, if we `unmerge` the model before generation, the weights are "rolled back" to the base weights without the adapter. I move the unmerge operation to the end of generation.
Fixes #2856 (issue)
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.