🔧 Optimize GRPO VRAM Usage #2669
base: main
Conversation
Gently adding @qgallouedec to review
For some reason I cannot upload images, but I let the training job run for one night without any issues.
Can you profile so that we can compare?
For sure. What are some profiling tools you guys recommend using?
Whatever profiler you're familiar with. Can be wandb, torch profiler...
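(For readers following along, a minimal torch.profiler setup for this kind of side-by-side VRAM comparison could look like the sketch below; `run_a_few_steps` is a placeholder for whatever is being measured, not anything from this PR.)

```python
import torch
from torch.profiler import profile, ProfilerActivity

torch.cuda.reset_peak_memory_stats()

# Profile a handful of training steps and report memory usage per op.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
             profile_memory=True, record_shapes=True) as prof:
    run_a_few_steps()  # placeholder: e.g. a few GRPO training steps

print(prof.key_averages().table(sort_by="self_cuda_memory_usage", row_limit=15))
print(f"peak VRAM: {torch.cuda.max_memory_allocated() / 2**30:.2f} GiB")
```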
Sure, let me make 2 side-by-side comparisons and I'll share the results.
Ok done. The results match my expectations, and I still have no clue why I cannot directly upload images, but I put all the comparisons in the following PDF, which should give people a pretty good overview. Cheers!
Seems very promising!
Super nice!
trl/trainer/utils.py (Outdated)

    del prompt_last_logps

    # Interleave the past key values for the G times
    prompt_out.past_key_values.batch_repeat_interleave(G)
I don't have the full context, but if this shares memory per repeat (as an expand would) then perfect!
I believe so!
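(Side note for readers: a quick, purely illustrative way to check whether a repeat shares storage with the original tensor, as an `expand` would, is to compare storage pointers; the tensor below is just a stand-in for one layer's key cache.)

```python
import torch

k = torch.randn(2, 8, 100, 64)  # stand-in for one layer's key cache: (batch, heads, prompt_len, head_dim)

expanded = k.unsqueeze(1).expand(2, 4, 8, 100, 64)  # expand returns a view over the same storage
repeated = k.repeat_interleave(4, dim=0)             # repeat_interleave materializes a new tensor

print(expanded.untyped_storage().data_ptr() == k.untyped_storage().data_ptr())  # True: shared memory
print(repeated.untyped_storage().data_ptr() == k.untyped_storage().data_ptr())  # False: copied
```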
If you have a static cache, you can also specifically compile this part of the code, as the shape will never change.
Hi @andyl98, great thanks for your contributions!
Hmm, I'm not familiar with the TRL code, but is there a danger here that gradient will not be propagated back through the prompt sequence if the cached hidden states for the prompt are computed without gradient? Remember that even if we are not computing a label or loss on the prompt tokens, we still want gradient to flow back to these tokens via their key and value vectors.
That's a good point @Rocketknight1. I'm currently looking into this.
I think it should still be possible to share the prompt computation and re-use the hidden states! It's just important that those states are computed with gradient, and the graph isn't deleted until you compute the loss on the completion tokens.
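(To make the gradient concern concrete, here is a small runnable illustration; the model name is just an example and this is not TRL code.)

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompt_ids = tok("The quick brown fox", return_tensors="pt").input_ids

# If the prompt cache is built under no_grad, the cached keys/values are detached,
# so no gradient can ever flow back into the prompt positions:
with torch.no_grad():
    detached = model(prompt_ids, use_cache=True)
print(detached.past_key_values[0][0].requires_grad)  # False

# Built normally, the cached K/V stay in the autograd graph, so gradient reaches
# the prompt tokens once the loss on the completion tokens is backpropagated
# (as long as the graph is kept alive until then):
out = model(prompt_ids, use_cache=True)
print(out.past_key_values[0][0].requires_grad)  # True
```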
I ran into an odd behavior with the improved version and was wondering if someone might be able to explain it: when running the forward passes (1: get KV cache, 2: get policy logp, 3: get ref logp) in a standalone script, the memory consumption only increases by about 6GB for a mini-batch in Step 2; however, if the training is started from the trl framework with minimal configuration, the same PEFT model consumes an additional 30GB of VRAM for one mini-batch in forward-pass Step 2. (For reference, the output logp in that mini-batch is about 320MB in fp16.) And the memory consumption keeps increasing with each mini-batch, so with a moderate-sized G, the GPU quickly runs out of memory. I tried to keep everything constant between the two setups, i.e. same model/config. I also cached the initial generation results for better comparability. Is this kind of VRAM consumption normal for trl with PEFT enabled? Any help or insight is greatly appreciated!
Thanks @andyl98 for pushing in this direction. Why isn't your approach compatible with gradient checkpointing?
Ah, this is because it uses the prompt tokens' past KV cache. If we turn gradient checkpointing on, that field will be empty/not reusable. Unless we can do something like model.gradient_checkpointing_disable() and model.config.use_cache = True before the prompt forwarding part, and turn it back off afterwards, so it stays compatible? Not sure if this is an option. cc @ArthurZucker @qgallouedec if you know of better solutions in this case.
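(A rough, untested sketch of the toggle suggested above, just to make the idea explicit; `model` and `prompt_ids` are placeholders, and whether this interacts cleanly with the rest of the trainer is exactly the open question.)

```python
# Temporarily turn gradient checkpointing off (it forces use_cache=False) so the
# prompt forward pass can produce a reusable KV cache...
model.gradient_checkpointing_disable()
model.config.use_cache = True
prompt_out = model(prompt_ids, use_cache=True)  # builds the prompt KV cache

# ...then switch checkpointing back on for the completion forward pass and loss.
model.gradient_checkpointing_enable()
model.config.use_cache = False
```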
I haven't tested PEFT, but let me try to understand the situation a bit more. When you launch the training script, are you using accelerate/torch distributed? Maybe I can try to reproduce.
Thanks so much! I ran the script directly with python, but I think it uses the accelerate backend nonetheless. Using accelerate launch results in the same OOM behavior. When I ran your code in a separate script, I used torch/transformers directly without the accelerate backend, which ended up being significantly more memory efficient despite the computation appearing to be identical. I debugged both setups and the call stacks seem identical. The memory allocations mainly happen during the MLP forward pass in each decoder layer, which uses code from AutoAWQ GEMM, as well as during the qkv matrix calculation before F.scaled_dot_product_attention.
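(Not an answer, but for pinpointing exactly where those extra allocations come from, PyTorch's built-in CUDA memory history can help; this is generic PyTorch tooling, and `run_step2_forward` is a placeholder for the mini-batch forward pass in question.)

```python
import torch

# Record allocation stack traces, run the suspicious step, then dump a snapshot
# that can be inspected at https://pytorch.org/memory_viz
torch.cuda.memory._record_memory_history(max_entries=100_000)

run_step2_forward()  # placeholder: the policy-logp forward pass for one mini-batch

torch.cuda.memory._dump_snapshot("step2_forward.pickle")
torch.cuda.memory._record_memory_history(enabled=None)  # stop recording
```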
@andyl98 Hello, just a kind reminder: is it possible to merge this? Thanks!
Hi @Superskyyy, I can resolve the conflicts, but I don't know if the TRL team wants to merge this. Let me do this later tonight.
What does this PR do?
TL;DR:
1. Use `prompt_cache` to greatly reduce VRAM usage if the `prompt_len` / `completion_len` ratio is high.

Intuition for 1
When running `open-r1`, I realized that the VRAM requirement scales poorly with the prompt length. After some further investigation of the existing implementation, I find it obvious that the OOM issue mainly comes from this line:
https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L332
where we basically feed all `B * G * (P + C)` tokens into the model to calculate logprobs, where
- `B` denotes the batch size
- `G` denotes the number of generations per prompt (a.k.a. group size)
- `P` denotes the maximum prompt length
- `C` denotes the maximum completion length

This is actually a bit counterintuitive, because the prompt is shared across the `G` generations, so a single forward pass over it should already give us what we need.
Thus, a more ideal approach should tackle the following: run the forward pass over the prompt tokens only once, then interleave the resulting KV cache `G` times so every completion can reuse it.

This way, we only need to store information for `B * (P + G * C)` tokens per batch.
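To make the idea concrete, here is a rough standalone sketch of the two-pass log-prob computation (this is not the PR's actual code: the function and variable names are illustrative, padding is ignored, and it assumes the cache object exposes `batch_repeat_interleave`, as transformers' `DynamicCache` does).

```python
import torch
import torch.nn.functional as F

def grouped_per_token_logps(model, prompt_ids, completion_ids, G):
    """prompt_ids: (B, P); completion_ids: (B * G, C); no padding assumed."""
    # 1) Forward the prompt only once per prompt (B rows instead of B * G).
    #    Done with gradients enabled so gradient can flow back through the prompt K/V.
    prompt_out = model(input_ids=prompt_ids, use_cache=True)

    # Log-prob of the *first* completion token is predicted by the prompt's last position.
    first_logits = prompt_out.logits[:, -1:, :].repeat_interleave(G, dim=0)  # (B*G, 1, V)

    # 2) Repeat the prompt KV cache G times so every completion can reuse it (in place).
    prompt_out.past_key_values.batch_repeat_interleave(G)

    # 3) Forward only the completion tokens on top of the cached prompt.
    completion_out = model(
        input_ids=completion_ids,
        past_key_values=prompt_out.past_key_values,
        use_cache=True,
    )

    # Logits predicting completion tokens 1..C-1, prepended with the prompt's last logits.
    logits = torch.cat([first_logits, completion_out.logits[:, :-1, :]], dim=1)  # (B*G, C, V)
    logps = F.log_softmax(logits, dim=-1)
    return torch.gather(logps, dim=-1, index=completion_ids.unsqueeze(-1)).squeeze(-1)
```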
This PR

With this change, I can fit a 7B Qwen model with `B` = 8, `G` = 16, `P` = 5000 and `C` = 256 without issues, whereas previously this would absolutely cause OOM.

02/03/25 Update: combining the 2 approaches above, I can train a `32b` model with 4x A100-80GBs.

Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.