"None of the inputs have requires_grad=True" with online DPO and GRPO #2671

benjamin-marie opened this issue Jan 28, 2025 · 5 comments
Labels: 🐛 bug (Something isn't working) · 🏋 GRPO (Related to GRPO) · 🏋 Online DPO (Related to Online DPO)

Comments

@benjamin-marie

Reproduction

Are online DPO and GRPO supposed to work with gradient checkpointing enabled?
I always get this warning when using them:
/usr/local/lib/python3.11/dist-packages/torch/utils/checkpoint.py:87: UserWarning: None of the inputs have requires_grad=True. Gradients will be None

The model then doesn't seem to learn: the training loss fluctuates up and down, and the learning rate doesn't seem to have any effect.

Here is the notebook (simple code) to reproduce the error:
https://colab.research.google.com/drive/1Tb2m_EBdKuuELEEMkA7YYHmOIxozMBmu?usp=sharing

I tried many variations and first thought the problem was related to the use of an adapter, but it isn't.

This notebook runs online DPO, but I have the exact same problem with GRPO.
PS: use_vllm doesn't work with a PEFT config. In the same notebook, setting use_vllm=True together with the peft_config triggers an error.
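
For reference, here is a minimal sketch of the kind of setup that triggers the warning with GRPO (the dataset, reward function, and LoRA parameters are illustrative, not taken from the notebook):

    from datasets import load_dataset
    from peft import LoraConfig
    from trl import GRPOConfig, GRPOTrainer

    # Illustrative reward function: rewards longer completions.
    def reward_len(completions, **kwargs):
        return [float(len(c)) for c in completions]

    dataset = load_dataset("trl-lib/tldr", split="train")

    args = GRPOConfig(
        output_dir="grpo-checkpointing-repro",
        per_device_train_batch_size=1,
        gradient_accumulation_steps=16,
        gradient_checkpointing=True,  # this is what triggers the warning
        logging_steps=10,
        bf16=True,
    )

    trainer = GRPOTrainer(
        model="Qwen/Qwen2.5-1.5B",
        reward_funcs=reward_len,
        args=args,
        train_dataset=dataset,
        # The warning also appears without PEFT.
        peft_config=LoraConfig(task_type="CAUSAL_LM", r=16, lora_alpha=32),
    )
    trainer.train()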

System Info

Google Colab L4/A100

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks (no screenshots; more on code blocks)
  • Any traceback provided is complete
The github-actions bot added the 🏋 Online DPO, 🏋 GRPO, and 🐛 bug labels on Jan 28, 2025.
@benjamin-marie (Author)

Same issue in Open R1:

!git clone https://github.com/huggingface/open-r1.git
%cd open-r1/
!python src/open_r1/grpo.py \
    --output_dir DeepSeek-R1-Distill-Qwen-7B-GRPO \
    --model_name_or_path Qwen/Qwen2.5-1.5B \
    --dataset_name AI-MO/NuminaMath-TIR \
    --max_prompt_length 256 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --gradient_checkpointing \
    --logging_steps 10 \
    --bf16

This prints the same warning:

/usr/local/lib/python3.11/dist-packages/torch/utils/checkpoint.py:87: UserWarning: None of the inputs have requires_grad=True. Gradients will be None

@benjamin-marie (Author)

This gradient-checkpointing handling is present in the other TRL trainers, but not in the online DPO and GRPO trainers:

            elif getattr(args, "gradient_checkpointing", False):
                # For backward compatibility with older versions of transformers
                if hasattr(model, "enable_input_require_grads"):
                    model.enable_input_require_grads()
                else:

                    def make_inputs_require_grad(module, input, output):
                        output.requires_grad_(True)

                    model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

Probably an easy fix?
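
In the meantime, a possible workaround is to make that call manually before constructing the trainer (a sketch, assuming a standard transformers causal LM):

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-1.5B")
    # Ensure the checkpointed segments see at least one input with
    # requires_grad=True; this is the same call the other TRL trainers
    # make when gradient_checkpointing is enabled.
    model.enable_input_require_grads()
    # ...then pass `model` (instead of the model name) to GRPOTrainer or
    # OnlineDPOTrainer.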

@qgallouedec (Member)

Probably. But I'm not sure I understand the fix at this point.

@benjamin-marie
Copy link
Author

I saw that Philipp uses gradient checkpointing in the following tutorial:
https://www.philschmid.de/mini-deepseek-r1

I tried it, but it doesn't work either. Gradient checkpointing in this tutorial doesn't trigger the warning because use_reentrant is set to False instead of True. I might be wrong, but I think the non-reentrant variant is not implemented in Qwen (and most LLMs). The consequence is that it consumes as much memory as if gradient checkpointing were disabled.
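
For context, the non-reentrant variant is selected through gradient_checkpointing_kwargs, which GRPOConfig and OnlineDPOConfig inherit from transformers' TrainingArguments (output_dir here is a placeholder):

    from trl import GRPOConfig

    args = GRPOConfig(
        output_dir="grpo-non-reentrant",
        gradient_checkpointing=True,
        # Non-reentrant checkpointing avoids the requires_grad warning, but
        # as noted above it may not actually reduce memory on all models.
        gradient_checkpointing_kwargs={"use_reentrant": False},
    )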

@qgallouedec (Member)

Thanks for the follow-up. Can you submit a PR so that we can run some tests?
