
🧶 [GRPO][vLLM + LoRA] Move unmerge of PEFT model after weight loading #2873

Merged (5 commits into huggingface:main) on Feb 17, 2025

Conversation

@XZ-X (Contributor) commented Feb 15, 2025

What does this PR do?

The state_dict obtained from the PEFT model seems to be a shallow copy of the model weights.
Therefore, if we unmerge the model before generation, the weights are rolled back to the base weights, without the adapter.

I moved the unmerge operation to the end of generation.

Fixes #2856
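
A minimal sketch (illustrative only, not from this PR) of that aliasing behaviour, assuming standard PyTorch semantics: state_dict() returns references to the live parameter tensors, so any in-place change (such as unmerging the adapter) is reflected in the tensors that were already collected.

  import torch
  from torch import nn

  layer = nn.Linear(4, 4)
  state = layer.state_dict()           # tensors in `state` alias layer.weight
  with torch.no_grad():
      layer.weight.add_(1.0)           # stands in for an in-place merge/unmerge
  print(torch.equal(state["weight"], layer.weight))  # True: the "snapshot" changed too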

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@XZ-X XZ-X changed the title Move unmerge of PEFT model to the end of generation [GRPO][vllm + Lora] Move unmerge of PEFT model to the end of generation Feb 15, 2025
@matt23654 commented:

This could be a stupid question, but is it guaranteed that, in a multi-GPU/multi-node setup, no other process/worker will be doing anything with the model weights during the window in which the adapters are merged? Come to think of it, is it a good idea to use merge/unmerge adapters at all, given all the possible configurations with multiple GPUs, multiple nodes, etc.?

@XZ-X (Contributor, Author) commented Feb 16, 2025

> This could be a stupid question, but is it guaranteed that, in a multi-GPU/multi-node setup, no other process/worker will be doing anything with the model weights during the window in which the adapters are merged? Come to think of it, is it a good idea to use merge/unmerge adapters at all, given all the possible configurations with multiple GPUs, multiple nodes, etc.?

Thanks for noticing that!

Correct me if I'm wrong, but from my understanding the synchronization between processes happens at self.optimizer.step() in the trainer. That is, before the model weights are updated, all processes sync with each other, so the weights change in a synchronized way across processes.

@qgallouedec (Member) commented Feb 17, 2025

Thanks for spotting it!
Why not unmerge just after the weights are loaded? It seems sufficient, like this:

  def _move_model_to_vllm(self):
      with unwrap_model_for_generation(
          self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
      ) as unwrapped_model:
          if is_compiled_module(unwrapped_model):
              unwrapped_model = unwrapped_model._orig_mod
          if is_peft_model(unwrapped_model):
              unwrapped_model.merge_adapter()
              state_dict = unwrapped_model.state_dict()
-             unwrapped_model.unmerge_adapter()
              # Remove base_model and base_layer prefixes
              state_dict = {
                  k.removeprefix("base_model.model.").replace(".base_layer", ""): v for k, v in state_dict.items()
              }
              # Remove values with adapter prefix (example: "_lora")
              state_dict = {k: v for k, v in state_dict.items() if unwrapped_model.prefix not in k}
              # When module to save, remove its prefix and discard the original module
              state_dict = {
                  k.replace("modules_to_save.default.", ""): v
                  for k, v in state_dict.items()
                  if "original_module" not in k
              }
          else:
              state_dict = unwrapped_model.state_dict()
+         if self.accelerator.is_main_process:
+             llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
+             llm_model.load_weights(state_dict.items())
+         if is_peft_model(unwrapped_model):
+             unwrapped_model.unmerge_adapter()
-      if self.accelerator.is_main_process:
-          llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
-          llm_model.load_weights(state_dict.items())

@XZ-X (Contributor, Author) commented Feb 17, 2025

You are right. I did not put it right after weight loading because I was not sure how vLLM loads weights (i.e., whether vLLM creates its own copy of the weights or not).

But I just looked at vLLM's codebase, and it seems that the weight loader indeed copies the weights, as shown here:

https://github.com/vllm-project/vllm/blob/ce77eb9410c694000c5da5abfa638500c6c72aeb/vllm/model_executor/model_loader/weight_utils.py#L521
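
A tiny sketch of why that makes unmerging after the load safe (an illustration of the copy semantics, not vLLM's actual loader code): the loader copies the data into vLLM's own parameter tensor, so later in-place changes to the source weights, such as unmerge_adapter(), no longer affect vLLM.

  import torch

  vllm_param = torch.empty(4)            # stands in for a parameter owned by vLLM
  merged_weight = torch.ones(4)          # stands in for a merged PEFT weight

  vllm_param.data.copy_(merged_weight)   # copy on load, as in the default weight loader
  merged_weight.zero_()                  # later unmerge mutates the source tensor
  print(vllm_param)                      # still all ones: vLLM kept its own copy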

I'll improve the PR as you suggested.

@XZ-X (Contributor, Author) commented Feb 17, 2025

Changed.

@qgallouedec (Member) commented:

@AndreiCComan thanks for your work in #2856! Can you check whether the patch still gives better results?

@qgallouedec (Member) commented:

Thanks! I'm running some experiments to validate.

@AndreiCComan commented Feb 17, 2025

> @AndreiCComan thanks for your work in #2856! Can you check whether the patch still gives better results?

Thanks @qgallouedec for the help and @XZ-X for the speedy update!

@qgallouedec here are the plots with patched (my inefficient implementation), @XZ-X's fix-vllm-peft-v1, and @XZ-X's fix-vllm-peft-v2:

patched and fix-vllm-peft-v2 now seem to be fairly close to each other.
Note: I changed my instance in the meantime, so a slight difference between the two is to be expected.

@qgallouedec (Member) commented:

Seems to work 🥳

Thanks for the work guys

[W&B chart, 17/02/2025 18:53]


@qgallouedec qgallouedec changed the title [GRPO][vllm + Lora] Move unmerge of PEFT model to the end of generation 🧶 [GRPO][vLLM + LoRA] Move unmerge of PEFT model after weight loading Feb 17, 2025
@qgallouedec qgallouedec merged commit 8226538 into huggingface:main Feb 17, 2025
13 checks passed
Successfully merging this pull request may close these issues:

  • [GRPO] use_peft: true together with use_vllm: true not working as intended (#2856)