Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🐦‍🔥 6x faster GRPO with multi-step optimization #2899

Merged
merged 16 commits into from
Feb 20, 2025

Conversation

qgallouedec
Copy link
Member

@qgallouedec qgallouedec commented Feb 19, 2025

What does this PR do?

This PR implements the multi-step trick. It allows the generation to be reused several times, speeding up training.

It requires to implement the importance sampling and the clipping logic.

- per_token_loss = -torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
+ coef_1 = torch.exp(per_token_logps - old_per_token_logps)
+ coef_2 = torch.clamp(coef_1, 1 - self.epsilon, 1 + self.epsilon)
+ per_token_loss1 = coef_1 * advantages.unsqueeze(1)
+ per_token_loss2 = coef_2 * advantages.unsqueeze(1)
+ per_token_loss = -torch.min(per_token_loss1, per_token_loss2)

Results with various mu (note that mu=1 matches the current implementation in main):

W B Chart 20_02_2025, 13_24_19

W B Chart 20_02_2025, 13_17_55-3

W B Chart 20_02_2025, 13_17_55-2

W B Chart 20_02_2025, 13_18_56-2

For reproduction:

from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer


dataset = load_dataset("trl-lib/tldr", split="train")

# Dummy reward function: the closer the completion is to 20 characters, the higher the reward
def reward_len(completions, **kwargs):
    return [-abs(100 - len(completion)) for completion in completions]

def main():
    num_iterations=4
    training_args = GRPOConfig(
        output_dir=f"Qwen2.5-0.5B-GRPO-2899-μ={num_iterations}",
        logging_steps=5,
        gradient_accumulation_steps=4,
        per_device_train_batch_size=4,
        num_generations=8,
        max_prompt_length=64,
        max_completion_length=64,
        log_completions=True,
        max_steps=200,
        num_iterations=num_iterations,
    )
    trainer = GRPOTrainer(
        model="Qwen/Qwen2.5-0.5B",
        reward_funcs=reward_len,
        args=training_args,
        train_dataset=dataset,
    )
    trainer.train()


if __name__ == "__main__":
    main()

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@qgallouedec qgallouedec changed the title Multi-step GRPOTrainer 🐦‍🔥 6x faster GRPO with multi-step optimization Feb 20, 2025
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@@ -224,6 +228,14 @@ class GRPOConfig(TrainingArguments):
"training speed."
},
)
num_iterations: int = field(
default=1,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we set another default? eg 4?

@qgallouedec qgallouedec marked this pull request as ready for review February 20, 2025 12:53
@@ -638,6 +731,10 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens

with torch.inference_mode():
old_per_token_logps = self._get_per_token_logps(
Copy link
Collaborator

@edbeeching edbeeching Feb 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if num_iterations == 1, I think we do a redundant forward pass here as the same thing will be recalculated at the optimization step. Would it be better to only calculate this if num_iterations > 1, and then at the optimization step ?
For example:

if num_iterations==1:
    old_per_token_logps = per_token_logps.detach()
else:
    old_per_token_logps = inputs["old_per_token_logps"]

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, that's why this version is slightly slower with num_iterations=1:
W B Chart 20_02_2025, 13_24_19

Let me add this optimization

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kashif
Copy link
Collaborator

kashif commented Feb 20, 2025

clearer when one realizes that the integers are the indices of the prompts!

@qgallouedec
Copy link
Member Author

clearer when one realizes that the integers are the indices of the prompts!

I'll add it, thanks for the feedback

idanshen added a commit to idanshen/multi_ref that referenced this pull request Feb 20, 2025
@qgallouedec qgallouedec merged commit e5ae703 into main Feb 20, 2025
14 checks passed
@qgallouedec qgallouedec deleted the multi-step-grpi branch February 20, 2025 18:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants