Add Optional ZeRO-3 Weight Gathering for GRPO in Sequence Generation #2667

Open
wants to merge 1 commit into main from SeungyounShin:feat/grpo-unwrap-for-generation

Conversation

@SeungyounShin (Contributor) commented Jan 27, 2025

What does this PR do?

This PR makes DeepSpeed ZeRO-3 weight gathering during sequence generation optional for the GRPO trainer, via the ds3_gather_for_generation option. Similar support was added for PPO/RLOO/OnlineDPO in [#2557](#2557), but GRPO, another online RL algorithm, was missing this feature. This update ensures consistent behavior across the online RL algorithms in the TRL library.
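
For context, the gather/skip behavior can be illustrated with DeepSpeed's deepspeed.zero.GatheredParameters context manager. The sketch below is illustrative only; the wrapper function and its arguments are placeholders of mine, not TRL's actual internals:

    # Illustrative sketch of optional ZeRO-3 weight gathering around generation.
    # Only deepspeed.zero.GatheredParameters is a real API here; the wrapper
    # function is a placeholder, not TRL's implementation.
    import deepspeed
    import torch

    def generate_with_optional_gather(model, gather_weights, **generate_kwargs):
        # enabled=True materializes the full (unsharded) weights on every rank for
        # the duration of the block; enabled=False leaves them partitioned, so
        # DeepSpeed fetches each layer on demand, trading speed for memory.
        with deepspeed.zero.GatheredParameters(list(model.parameters()), enabled=gather_weights):
            with torch.no_grad():
                return model.generate(**generate_kwargs)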

Motivation and Context

This addition is particularly important for reducing memory usage in online RL pipelines. As highlighted in the [Open-R1 project issue #65](huggingface/open-r1#65), being able to skip the full weight gather during generation is crucial for managing memory in large-scale training. With this option, the GRPO trainer can handle longer contexts without exceeding memory limits.

How to Test

You can reproduce the results and verify the implementation using the following steps:

  1. Clone and set up the Open-R1 repository:

    git clone https://github.com/SeungyounShin/open-r1.git && pip install vllm==0.6.6.post1 && cd open-r1 && pip install -e ".[dev]" && apt update && apt-get install -y git-lfs

    Then uninstall the preinstalled trl and install this branch (SeungyounShin:feat/grpo-unwrap-for-generation):

    pip uninstall -y trl
    pip install git+https://github.com/SeungyounShin/trl.git@feat/grpo-unwrap-for-generation
  2. Run the training script with the specified parameters:

    accelerate launch src/open_r1/grpo.py \
        --output_dir DeepSeek-R1-Distill-Qwen-7B-GRPO \
        --model_name_or_path Seungyoun/ \
        --dataset_name AI-MO/NuminaMath-TIR \
        --max_prompt_length 1024 \
        --per_device_train_batch_size 1 \
        --max_completion_length 512 \
        --gradient_accumulation_steps 16 \
        --logging_steps 10 \
        --bf16 \
        --ds3_gather_for_generation false

Running with --ds3_gather_for_generation false keeps the ZeRO-3 weight shards partitioned during generation, demonstrating that the changes work with DeepSpeed ZeRO-3 and reduce peak memory usage.
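
If you drive training from Python rather than through the CLI script, the same option would be set on the config. This is a minimal sketch, assuming GRPOConfig exposes a field named after the --ds3_gather_for_generation flag used above:

    # Minimal sketch: disabling ZeRO-3 weight gathering for generation from Python.
    # Assumes GRPOConfig exposes ds3_gather_for_generation, mirroring the CLI flag.
    from trl import GRPOConfig

    training_args = GRPOConfig(
        output_dir="DeepSeek-R1-Distill-Qwen-7B-GRPO",
        per_device_train_batch_size=1,
        gradient_accumulation_steps=16,
        max_prompt_length=1024,
        max_completion_length=512,
        logging_steps=10,
        bf16=True,
        ds3_gather_for_generation=False,  # keep ZeRO-3 shards partitioned during generation
    )
    # training_args is then passed to GRPOTrainer as usual.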

Checklist

Who can review?

Anyone in the community, especially contributors familiar with online RL algorithms and DeepSpeed integration, is welcome to review this PR. Please feel free to tag interested members for feedback.

@qgallouedec (Member) commented:

Thanks! Looks good overall. I'll just wait for #2600 to be merged before merging this one. Consequently, can you also add a note that this option is not compatible with use_vllm?
