Add Optional ZeRO-3 Weight Gathering for GRPO in Sequence Generation #2667
## What does this PR do?
This PR extends DeepSpeed ZeRO-3 weight gathering in sequence generation to the GRPO algorithm, making it optional there as well. Similar support was added for PPO/RLOO/OnlineDPO in [#2557](#2557), but GRPO, another online RL algorithm, was still missing this feature. This update ensures consistent behavior across the online RL algorithms in the TRL library.
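As background, the sketch below illustrates how an optional gather toggle of this kind typically works under ZeRO-3: when enabled, the sharded parameters are temporarily materialized for generation; when disabled, they stay sharded and are fetched on the fly. The function and argument names here are illustrative only, not the exact TRL API.

```python
# Minimal sketch (not the exact TRL implementation): optionally gather
# ZeRO-3-sharded weights before calling `generate`.
import contextlib

import deepspeed
import torch


def generate_with_optional_gather(model, inputs, gather_for_generation: bool):
    if gather_for_generation:
        # Materialize the full parameters on each rank: faster generation,
        # but peak memory grows by roughly the full model size.
        ctx = deepspeed.zero.GatheredParameters(list(model.parameters()))
    else:
        # Keep parameters sharded; ZeRO-3 fetches them layer by layer during
        # the forward passes of generation, trading speed for lower memory.
        ctx = contextlib.nullcontext()

    with ctx, torch.no_grad():
        return model.generate(**inputs)
```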
## Motivation and Context

This change is particularly important for reducing memory usage in online RL pipelines. As highlighted in [Open-R1 issue #65](huggingface/open-r1#65), making weight gathering optional is crucial for managing memory efficiently during large-scale training. With this feature, the GRPO algorithm can handle longer contexts without exceeding memory limits.

## How to Test
You can reproduce the results and verify the implementation using the following steps:
1. Clone and set up the Open-R1 repository.
2. Uninstall `trl` and install this branch (`SeungyounShin:feat/grpo-unwrap-for-generation`).
3. Run the training script with the specified parameters:

```shell
accelerate launch src/open_r1/grpo.py \
    --output_dir DeepSeek-R1-Distill-Qwen-7B-GRPO \
    --model_name_or_path Seungyoun/ \
    --dataset_name AI-MO/NuminaMath-TIR \
    --max_prompt_length 1024 \
    --per_device_train_batch_size 1 \
    --max_completion_length 512 \
    --gradient_accumulation_steps 16 \
    --logging_steps 10 \
    --bf16 \
    --ds3_gather_for_generation false
```
Running with `--ds3_gather_for_generation false` keeps the weights sharded during generation, demonstrating that the changes work with DeepSpeed ZeRO-3 and reduce peak memory usage.
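For reference, a rough Python-level equivalent of the command above is sketched below, assuming the new `GRPOConfig` field mirrors the `--ds3_gather_for_generation` CLI flag. The model id and reward function are placeholders, not the ones used in the run above.

```python
# Hedged sketch: Python-level equivalent of the accelerate command above.
# The `ds3_gather_for_generation` field is assumed to mirror the CLI flag
# added by this PR; the model id and reward function are placeholders.
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer


def reward_len(completions, **kwargs):
    # Toy reward: prefer completions close to 50 characters.
    return [-abs(50 - len(completion)) for completion in completions]


# GRPOTrainer expects a "prompt" column; the open-r1 script maps the
# dataset accordingly before training.
dataset = load_dataset("AI-MO/NuminaMath-TIR", split="train")

training_args = GRPOConfig(
    output_dir="DeepSeek-R1-Distill-Qwen-7B-GRPO",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    max_prompt_length=1024,
    max_completion_length=512,
    logging_steps=10,
    bf16=True,
    # Keep ZeRO-3 shards in place during generation to cut peak memory.
    ds3_gather_for_generation=False,
)

trainer = GRPOTrainer(
    model="your-org/your-model",  # placeholder model id
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```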
## Checklist
- New feature: GRPO support for optional ZeRO-3 weight gathering.

## Who can review?
Anyone in the community, especially contributors familiar with online RL algorithms and DeepSpeed integration, is welcome to review this PR. Please feel free to tag interested members for feedback.