
🔧 Optimize GRPO VRAM Usage #2669

Open
andyl98 wants to merge 41 commits into main

Conversation

@andyl98 (Contributor) commented Jan 27, 2025

What does this PR do?

TL;DR:

  • Re-uses prompt_cache to greatly reduce VRAM usage when the prompt_len/completion_len ratio is high
  • Uses a mini-batch approach for the logit calculations, which further reduces peak memory usage
  1. "Fixes" Crazy VRAM usage with longer prompts open-r1#47
  2. Applies a similar idea to GRPO memory bottleneck from num_generations in compute_loss #2709

Intuition for 1

When running open-r1, I realized that the VRAM requirement scales poorly with prompt length. After further investigation of the existing implementation, it became clear that the OOM issue mainly comes from this line:

logits = model(input_ids).logits  # (B, L, V)

https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L332

Here we feed all B * G * (P + C) tokens into the model to compute logprobs, where

  • B denotes the batch size
  • G denotes the number of generations per prompt (a.k.a Group size)
  • P denotes the maximum prompt length
  • C denotes the maximum completion length

This is actually a bit counterintuitive, because

  1. We only need the logprobs of the completion tokens for the reward/loss calculation.
  2. The prompt is fixed across all G generations, so a single forward pass should already give us what we need.

Thus, a better approach should do the following:

  1. Only "forward" the prompt once per group
  2. Re-use the cached hidden states G times

This way, we only need to store information for B * (P + G * C) tokens per batch.
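With the numbers used later in this PR (B = 8, G = 16, P = 5000, C = 256), that is 8 * (5000 + 16 * 256) = 72,768 token positions instead of 8 * 16 * (5000 + 256) = 672,768, roughly a 9x reduction. To make the idea concrete, here is a minimal sketch of the prompt-cache reuse (illustrative only, not the exact PR code): it assumes no padding, a decoder-only Hugging Face causal-LM, and a transformers Cache object that exposes batch_repeat_interleave; the names prompt_ids and completion_ids are placeholders.

import torch
import torch.nn.functional as F

def completion_logps_with_prompt_cache(model, prompt_ids, completion_ids, G):
    # prompt_ids: (B, P), one row per prompt; completion_ids: (B * G, C),
    # rows grouped by prompt: [p0_g0, ..., p0_g{G-1}, p1_g0, ...]
    # 1) Forward the prompt once per group (B sequences instead of B * G).
    prompt_out = model(prompt_ids, use_cache=True)

    # 2) Repeat the cached key/value states G times so every completion in a
    #    group attends to its shared prompt.
    prompt_out.past_key_values.batch_repeat_interleave(G)

    # The logit of the last prompt token predicts the first completion token.
    first_logits = prompt_out.logits[:, -1:, :].repeat_interleave(G, dim=0)

    # 3) Forward only the completion tokens on top of the cached prompt.
    completion_out = model(completion_ids, past_key_values=prompt_out.past_key_values)

    # Standard causal shift: the logits at position t score token t.
    logits = torch.cat([first_logits, completion_out.logits[:, :-1, :]], dim=1)
    logps = F.log_softmax(logits, dim=-1)
    return torch.gather(logps, 2, completion_ids.unsqueeze(-1)).squeeze(-1)  # (B*G, C)

Mini-batching over the B * G completion rows in step 3 is what brings the peak memory down further.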

This PR

With this change, I can fit a 7B Qwen model with B = 8, G = 16, P = 5000 and C = 256 without issues, whereas previously this configuration would reliably cause OOM.

02/03/25 Update: combining the two approaches above, I can train a 32B model on 4x A100-80GB GPUs.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.


@andyl98 (Contributor, Author) commented Jan 27, 2025

Gently adding @qgallouedec to review

@andyl98 andyl98 changed the title 🔧 GRPO VRAM Optimization 🔧 Optimize GRPO VRAM Usage by Compute Prompt Tokens Once Jan 27, 2025
@andyl98 andyl98 changed the title 🔧 Optimize GRPO VRAM Usage by Compute Prompt Tokens Once 🔧 Optimize GRPO VRAM Usage by Computing Prompt Tokens Just Once Jan 27, 2025
@andyl98 andyl98 marked this pull request as draft January 28, 2025 00:44
@andyl98 andyl98 marked this pull request as ready for review January 28, 2025 01:48
andyl98 and others added 2 commits January 27, 2025 17:55

@andyl98 (Contributor, Author) commented Jan 28, 2025

For some reason I cannot upload images, but I let the training job run overnight without any issues.

@qgallouedec (Member)

Can you profile so that we can compare?

@andyl98 (Contributor, Author) commented Jan 28, 2025

Can you profile so that we can compare?

For sure. What are some profiling tools you guys recommend using?
Would torch.cuda.memory_allocated() and torch.cuda.memory_reserved() be enough?
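For a quick first pass I could wrap a single training step like this (assuming a single CUDA device):

import torch

torch.cuda.reset_peak_memory_stats()
# ... run one training step here ...
print(f"allocated: {torch.cuda.memory_allocated() / 2**30:.2f} GiB")
print(f"reserved:  {torch.cuda.memory_reserved() / 2**30:.2f} GiB")
print(f"peak:      {torch.cuda.max_memory_allocated() / 2**30:.2f} GiB")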

@qgallouedec (Member)

Whatever profiler you're familiar with. Can be wandb, torch profiler...

@andyl98 (Contributor, Author) commented Jan 28, 2025

Whatever profiler you're familiar with. Can be wandb, torch profiler...

Sure, let me make two side-by-side comparisons and I'll share the results.

@andyl98 (Contributor, Author) commented Jan 28, 2025

Ok, done. The results match my expectations. I still have no clue why I cannot directly upload images, but I put all the comparisons in the following PDF, which should give people a pretty good overview.

Cheers!

GRPO VRAM Diff Investigation.pdf

@Superskyyy mentioned this pull request Jan 28, 2025
@qgallouedec (Member)

Seems very promising!
What about the training speed?

@ArthurZucker left a comment

Super nice!

del prompt_last_logps

# Interleave the past key values G times
prompt_out.past_key_values.batch_repeat_interleave(G)


I don't have the full context, but if this shares memory per repeat (as an expand would) then perfect!

@andyl98 (Contributor, Author)

I believe so!

@ArthurZucker

If you have a static cache, you can also specifically compile this part of the code as the shape will never change

@fkxie commented Jan 29, 2025

Hi @andyl98, many thanks for your contribution!
When I test your code, it aborts with:

utils.py, line 1698, in compute_logps_with_prompt_cache
    prompt_out.past_key_values.batch_repeat_interleave(G)
AttributeError: 'tuple' object has no attribute 'batch_repeat_interleave'

It seems past_key_values is a tuple. Did I miss something?
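For reference, I suspect the legacy tuple would need to be wrapped first, assuming the installed transformers version provides DynamicCache.from_legacy_cache:

from transformers import DynamicCache

past = prompt_out.past_key_values
if isinstance(past, tuple):
    past = DynamicCache.from_legacy_cache(past)  # legacy tuple -> Cache object
past.batch_repeat_interleave(G)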

@Rocketknight1 (Member) commented Jan 29, 2025

Hmm, I'm not familiar with the TRL code, but is there a danger here that gradient will not be propagated back through the prompt sequence if the cached hidden states for the prompt are computed without gradient? Remember that even if we are not computing a label or loss on the prompt tokens, we still want gradient to flow back to these tokens via their key and value vectors.

@qgallouedec (Member)

That's a good point @Rocketknight1. I'm currently looking into this

@Rocketknight1 (Member)

I think it should still be possible to share the prompt computation and re-use the hidden states! It's just important that those states are computed with gradient, and the graph isn't deleted until you compute the loss on the completion tokens
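A tiny autograd toy (illustrative only, nothing to do with the TRL code itself) shows the difference: the same loss yields different gradients depending on whether the "cached" prompt term was built under no_grad.

import torch

w = torch.nn.Parameter(torch.randn(4, 4))   # stands in for the model weights
x_prompt = torch.randn(2, 4)                # "prompt" tokens
x_completion = torch.randn(2, 4)            # "completion" tokens

# Prompt pass under no_grad: its contribution is cut out of the graph.
with torch.no_grad():
    cached = x_prompt @ w
loss = ((cached + x_completion @ w).sum()) ** 2
loss.backward()
grad_without_prompt_graph = w.grad.clone()

# Prompt pass with grad enabled: the same loss also backprops through `cached`.
w.grad = None
cached = x_prompt @ w
loss = ((cached + x_completion @ w).sum()) ** 2
loss.backward()
print(torch.allclose(grad_without_prompt_graph, w.grad))  # False: gradients differ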

@andyl98 andyl98 changed the title 🔧 Optimize GRPO VRAM Usage by Computing Prompt Tokens Just Once 🔧 Optimize GRPO VRAM Usage Feb 4, 2025
@mchen30 commented Feb 6, 2025

I ran into an odd behavior with the improved version and was wondering if someone might be able to explain it: when running the forward passes (1: get KV cache, 2: get policy logp, 3: get ref logp) in a standalone script, the memory consumption only increases by about 6GB for a mini-batch in Step 2; however, if the training is started from the trl framework with minimal configuration, the same PEFT model consumes an additional 30GB VRAM for one mini-batch in forward pass Step 2. (For reference, the output logp in that mini-batch is about 320MB in fp16.) And the memory consumption keeps increasing with each mini-batch, so with a moderate-sized G, the GPU quickly runs out of memory.

I tried to keep everything constant between the two setups, i.e. same model/config. I also cached the initial generation results for better comparability.

Is this kind of VRAM consumption normal for trl with PEFT enabled? Any help or insight is greatly appreciated!

@qgallouedec (Member)

Thanks @andyl98 for pushing in this direction. Why isn't your approach compatible with gradient checkpointing?

@andyl98 (Contributor, Author) commented Feb 6, 2025

Thanks @andyl98 for pushing in this direction. Why isn't your approach compatible with gradient checkpointing?

Ah, this is because it reuses the prompt tokens' past KV cache. If we turn gradient checkpointing on, that field will be empty/not reusable.

Unless we can do something like

model.gradient_checkpointing_disable()
model.config.use_cache = True

before the prompt forward and turn it back off afterwards, so the two stay compatible? Not sure if this is an option. cc @ArthurZucker @qgallouedec if you know of better solutions here.
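Roughly, the toggle I have in mind would look like this (untested sketch; prompt_ids, completion_ids and G are placeholders, and whether the second forward plays nicely with gradient checkpointing is exactly the open question):

model.gradient_checkpointing_disable()
model.config.use_cache = True
prompt_out = model(prompt_ids, use_cache=True)        # build the prompt KV cache

model.gradient_checkpointing_enable()
model.config.use_cache = False
prompt_out.past_key_values.batch_repeat_interleave(G)
completion_out = model(completion_ids, past_key_values=prompt_out.past_key_values)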

@andyl98 (Contributor, Author) commented Feb 6, 2025

I ran into an odd behavior with the improved version and was wondering if someone might be able to explain it: when running the forward passes (1: get KV cache, 2: get policy logp, 3: get ref logp) in a standalone script, the memory consumption only increases by about 6GB for a mini-batch in Step 2; however, if the training is started from the trl framework with minimal configuration, the same PEFT model consumes an additional 30GB VRAM for one mini-batch in forward pass Step 2. (For reference, the output logp in that mini-batch is about 320MB in fp16.) And the memory consumption keeps increasing with each mini-batch, so with a moderate-sized G, the GPU quickly runs out of memory.

I tried to keep everything constant between the two setups, i.e. same model/config. I also cached the initial generation results for better comparability.

Is this kind of VRAM consumption normal for trl with PEFT enabled? Any help or insight is greatly appreciated!

I haven’t tested PEFT but let me try to understand the situation a bit more. When you launch the training script, are you using accelerate/torch distributed? Maybe I can try to reproduce

@mchen30 commented Feb 6, 2025

I ran into an odd behavior with the improved version and was wondering if someone might be able to explain it: when running the forward passes (1: get KV cache, 2: get policy logp, 3: get ref logp) in a standalone script, the memory consumption only increases by about 6GB for a mini-batch in Step 2; however, if the training is started from the trl framework with minimal configuration, the same PEFT model consumes an additional 30GB VRAM for one mini-batch in forward pass Step 2. (For reference, the output logp in that mini-batch is about 320MB in fp16.) And the memory consumption keeps increasing with each mini-batch, so with a moderate-sized G, the GPU quickly runs out of memory.
I tried to keep everything constant between the two setups, i.e. same model/config. I also cached the initial generation results for better comparability.
Is this kind of VRAM consumption normal for trl with PEFT enabled? Any help or insight is greatly appreciated!

I haven’t tested PEFT but let me try to understand the situation a bit more. When you launch the training script, are you using accelerate/torch distributed? Maybe I can try to reproduce

Thanks so much! I ran the script directly with python, but I think it uses the accelerate backend nonetheless; using accelerate launch results in the same OOM behavior. When I ran your code in a separate script, I used torch/transformers directly without the accelerate backend, which ended up being significantly more memory efficient despite the computation appearing to be identical. I debugged both setups and the call stacks seem identical. The memory allocations mainly happen during the MLP forward pass in each decoder layer, which uses code from AutoAWQ's GEMM kernels, as well as during the QKV matrix calculation before F.scaled_dot_product_attention.

@Superskyyy (Contributor)

@andyl98 Hello, just a kind reminder: is it possible to merge this? Thanks

@andyl98 (Contributor, Author) commented Feb 18, 2025

@andyl98 Hello, just a kind reminder: is it possible to merge this? Thanks

Hi @Superskyyy, I can resolve the conflicts, but I don't know whether the trl team wants to merge this. Let me do this later tonight.
