diff --git a/.github/workflows/tests_latest.yml b/.github/workflows/tests_latest.yml index d8ea6d524e..9d957a4242 100644 --- a/.github/workflows/tests_latest.yml +++ b/.github/workflows/tests_latest.yml @@ -17,7 +17,7 @@ jobs: steps: - name: Git checkout uses: actions/checkout@v4 - with: { ref: v0.13-release } + with: { ref: v0.15-release } - name: Set up Python 3.12 uses: actions/setup-python@v5 with: diff --git a/CITATION.cff b/CITATION.cff index 05b9ae2bbc..68076ef706 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -31,4 +31,4 @@ keywords: - pytorch - transformers license: Apache-2.0 -version: 0.13 +version: 0.15 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 9a8979a658..b3b1c0d411 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -23,7 +23,7 @@ There are several ways you can contribute to TRL: * Contribute to the examples or the documentation. If you don't know where to start, there is a special [Good First -Issue](https://github.com/huggingface/trl/contribute) listing. It will give you a list of +Issue](https://github.com/huggingface/trl/labels/%F0%9F%91%B6%20good%20first%20issue) listing. It will give you a list of open issues that are beginner-friendly and help you start contributing to open-source. The best way to do that is to open a Pull Request and link it to the issue that you'd like to work on. We try to give priority to opened PRs as we can easily track the progress of the fix, and if the contributor does not have time anymore, someone else can take the PR over. For something slightly more challenging, you can also take a look at the [Good Second Issue](https://github.com/huggingface/trl/labels/Good%20Second%20Issue) list. In general though, if you feel like you know what you're doing, go for it and we'll help you get there! 🚀 diff --git a/README.md b/README.md index b7f1e5b7aa..01f50ad7ae 100644 --- a/README.md +++ b/README.md @@ -137,39 +137,26 @@ trainer = RewardTrainer( trainer.train() ``` -### `RLOOTrainer` +### `GRPOTrainer` -`RLOOTrainer` implements a [REINFORCE-style optimization](https://huggingface.co/papers/2402.14740) for RLHF that is more performant and memory-efficient than PPO. Here is a basic example of how to use the `RLOOTrainer`: +`GRPOTrainer` implements the [Group Relative Policy Optimization (GRPO) algorithm](https://huggingface.co/papers/2402.03300) that is more memory-efficient than PPO and was used to train [Deepseek AI's R1](https://huggingface.co/deepseek-ai/DeepSeek-R1). 
```python -from trl import RLOOConfig, RLOOTrainer, apply_chat_template from datasets import load_dataset -from transformers import ( - AutoModelForCausalLM, - AutoModelForSequenceClassification, - AutoTokenizer, -) +from trl import GRPOConfig, GRPOTrainer -tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") -reward_model = AutoModelForSequenceClassification.from_pretrained( - "Qwen/Qwen2.5-0.5B-Instruct", num_labels=1 -) -ref_policy = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") -policy = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") +dataset = load_dataset("trl-lib/tldr", split="train") -dataset = load_dataset("trl-lib/ultrafeedback-prompt") -dataset = dataset.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer}) -dataset = dataset.map(lambda x: tokenizer(x["prompt"]), remove_columns="prompt") +# Dummy reward function: rewards completions that are close to 20 characters +def reward_len(completions, **kwargs): + return [-abs(20 - len(completion)) for completion in completions] -training_args = RLOOConfig(output_dir="Qwen2.5-0.5B-RL") -trainer = RLOOTrainer( - config=training_args, - processing_class=tokenizer, - policy=policy, - ref_policy=ref_policy, - reward_model=reward_model, - train_dataset=dataset["train"], - eval_dataset=dataset["test"], +training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO", logging_steps=10) +trainer = GRPOTrainer( + model="Qwen/Qwen2-0.5B-Instruct", + reward_funcs=reward_len, + args=training_args, + train_dataset=dataset, ) trainer.train() ``` diff --git a/commands/run_sft.sh b/commands/run_sft.sh index bdea77fcb6..b7beaaf7fd 100644 --- a/commands/run_sft.sh +++ b/commands/run_sft.sh @@ -42,7 +42,7 @@ accelerate launch $EXTRA_ACCELERATE_ARGS \ --output_dir $OUTPUT_DIR \ --max_steps $MAX_STEPS \ --per_device_train_batch_size $BATCH_SIZE \ - --max_seq_length $SEQ_LEN \ + --max_length $SEQ_LEN \ $EXTRA_TRAINING_ARGS """ diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 1f793e6ea9..4ccc57ae8f 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -68,6 +68,8 @@ title: Online DPO - local: gkd_trainer title: GKD + - local: grpo_trainer + title: GRPO - local: kto_trainer title: KTO - local: nash_md_trainer diff --git a/docs/source/alignprop_trainer.mdx b/docs/source/alignprop_trainer.md similarity index 96% rename from docs/source/alignprop_trainer.mdx rename to docs/source/alignprop_trainer.md index a4c6b007ef..4c3b21042c 100644 --- a/docs/source/alignprop_trainer.mdx +++ b/docs/source/alignprop_trainer.md @@ -16,7 +16,7 @@ The `alignprop.py` script is a working example of using the `AlignProp` trainer **Note:** one A100 GPU is recommended to get this running. For lower memory setting, consider setting truncated_backprop_rand to False. With default settings this will do truncated backpropagation with K=1. -Almost every configuration parameter has a default. There is only one commandline flag argument that is required of the user to get things up and running. The user is expected to have a [huggingface user access token](https://huggingface.co/docs/hub/security-tokens) that will be used to upload the model post finetuning to HuggingFace hub. The following bash command is to be entered to get things running +Almost every configuration parameter has a default. There is only one commandline flag argument that is required of the user to get things up and running. 
The user is expected to have a [huggingface user access token](https://huggingface.co/docs/hub/security-tokens) that will be used to upload the model post-finetuning to HuggingFace hub. The following bash command is to be entered to get things running ```batch python alignprop.py --hf_user_access_token @@ -26,7 +26,7 @@ To obtain the documentation of `stable_diffusion_tuning.py`, please run `python The following are things to keep in mind (The code checks this for you as well) in general while configuring the trainer (beyond the use case of using the example script) -- The configurable randomized truncation range (`--alignprop_config.truncated_rand_backprop_minmax=(0,50)`) the first number should be equal and greater to 0, while the second number should equal or less to the number of diffusion timesteps (sample_num_steps) +- The configurable randomized truncation range (`--alignprop_config.truncated_rand_backprop_minmax=(0,50)`) the first number should be equal and greater than 0, while the second number should equal or less to the number of diffusion timesteps (sample_num_steps) - The configurable truncation backprop absolute step (`--alignprop_config.truncated_backprop_timestep=49`) the number should be less than the number of diffusion timesteps (sample_num_steps), it only matters when truncated_backprop_rand is set to False ## Setting up the image logging hook function diff --git a/docs/source/bco_trainer.mdx b/docs/source/bco_trainer.md similarity index 95% rename from docs/source/bco_trainer.mdx rename to docs/source/bco_trainer.md index c23365cc00..e449f86b63 100644 --- a/docs/source/bco_trainer.mdx +++ b/docs/source/bco_trainer.md @@ -62,7 +62,7 @@ embedding_model = Accelerator().prepare_model(self.embedding_model) embedding_func = partial(embed_prompt, model=embedding_model) ``` -Set `prompt_sample_size` to defined how many prompts are selected to train the UDM classifier and start the training with the provided embedding function: +Set `prompt_sample_size` to define how many prompts are selected to train the UDM classifier and start the training with the provided embedding function: ```py training_args = BCOConfig( @@ -97,4 +97,4 @@ To scale how much the auxiliary loss contributes to the total loss, use the hype ## BCOConfig -[[autodoc]] BCOConfig \ No newline at end of file +[[autodoc]] BCOConfig diff --git a/docs/source/best_of_n.mdx b/docs/source/best_of_n.md similarity index 96% rename from docs/source/best_of_n.mdx rename to docs/source/best_of_n.md index 9dd56aba2c..8b2978c2a3 100644 --- a/docs/source/best_of_n.mdx +++ b/docs/source/best_of_n.md @@ -67,6 +67,6 @@ best_of_n.generate(query_tensors, device=device) ``` -Furthermore, at the time of initialization you can set the seed to control repeatability of the generation process and the number of samples to generate for each query +Furthermore, at the time of initialization you can set the seed to control the repeatability of the generation process and the number of samples to generate for each query diff --git a/docs/source/callbacks.mdx b/docs/source/callbacks.md similarity index 100% rename from docs/source/callbacks.mdx rename to docs/source/callbacks.md diff --git a/docs/source/clis.mdx b/docs/source/clis.md similarity index 98% rename from docs/source/clis.mdx rename to docs/source/clis.md index 9c7a2dfca8..d165a49668 100644 --- a/docs/source/clis.mdx +++ b/docs/source/clis.md @@ -7,12 +7,13 @@ Currently supported CLIs are: #### Training commands - `trl dpo`: fine-tune a LLM with DPO +- `trl grpo`: fine-tune a LLM 
with GRPO - `trl kto`: fine-tune a LLM with KTO - `trl sft`: fine-tune a LLM with SFT #### Other commands -- `trl chat`: quickly spin up a LLM fine-tuned for chatting +- `trl chat`: quickly spin up an LLM fine-tuned for chatting - `trl env`: get the system information ## Fine-tuning with the CLI diff --git a/docs/source/community_tutorials.md b/docs/source/community_tutorials.md index 4b2b9a6e54..0cd5e8e25f 100644 --- a/docs/source/community_tutorials.md +++ b/docs/source/community_tutorials.md @@ -1,28 +1,30 @@ # Community Tutorials -Community tutorials are made by active members of the Hugging Face community that want to share their knowledge and expertise with others. They are a great way to learn about the library and its features, and to get started with core classes and modalities. +Community tutorials are made by active members of the Hugging Face community who want to share their knowledge and expertise with others. They are a great way to learn about the library and its features, and to get started with core classes and modalities. # Language Models -| Task | Class | Description | Author | Tutorial | Colab | -| ----------------------- | --------------- | ---------------------------------------------------------------------------------------- | -------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| Instruction tuning | [`SFTTrainer`] | Fine-tuning Google Gemma LLMs using ChatML format with QLoRA | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-google-gemma) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/gemma-lora-example.ipynb) | -| Structured Generation | [`SFTTrainer`] | Fine-tuning Llama-2-7B to generate Persian product catalogs in JSON using QLoRA and PEFT | [Mohammadreza Esmaeilian](https://huggingface.co/Mohammadreza) | [Link](https://huggingface.co/learn/cookbook/en/fine_tuning_llm_to_generate_persian_product_catalogs_in_json_format) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_llm_to_generate_persian_product_catalogs_in_json_format.ipynb) | -| Preference Optimization | [`DPOTrainer`] | Align Mistral-7b using Direct Preference Optimization for human preference alignment | [Maxime Labonne](https://huggingface.co/mlabonne) | [Link](https://mlabonne.github.io/blog/posts/Fine_tune_Mistral_7b_with_DPO.html) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mlabonne/llm-course/blob/main/Fine_tune_a_Mistral_7b_model_with_DPO.ipynb) | -| Preference Optimization | [`ORPOTrainer`] | Fine-tuning Llama 3 with ORPO combining instruction tuning and preference alignment | [Maxime Labonne](https://huggingface.co/mlabonne) | [Link](https://mlabonne.github.io/blog/posts/2024-04-19_Fine_tune_Llama_3_with_ORPO.html) | [![Open In 
Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1eHNWg9gnaXErdAa8_mcvjMupbSS6rDvi) | +| Task | Class | Description | Author | Tutorial | Colab | +| --- | --- | --- | --- | --- | --- | +| Reinforcement Learning | [`GRPOTrainer`] | Post training an LLM for reasoning with GRPO in TRL | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_llm_grpo_trl) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_llm_grpo_trl.ipynb) | +| Reinforcement Learning | [`GRPOTrainer`] | Mini-R1: Reproduce Deepseek R1 „aha moment“ a RL tutorial | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/mini-deepseek-r1) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/mini-deepseek-r1-aha-grpo.ipynb) | +| Instruction tuning | [`SFTTrainer`] | Fine-tuning Google Gemma LLMs using ChatML format with QLoRA | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-google-gemma) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/gemma-lora-example.ipynb) | +| Structured Generation | [`SFTTrainer`] | Fine-tuning Llama-2-7B to generate Persian product catalogs in JSON using QLoRA and PEFT | [Mohammadreza Esmaeilian](https://huggingface.co/Mohammadreza) | [Link](https://huggingface.co/learn/cookbook/en/fine_tuning_llm_to_generate_persian_product_catalogs_in_json_format) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_llm_to_generate_persian_product_catalogs_in_json_format.ipynb) | +| Preference Optimization | [`DPOTrainer`] | Align Mistral-7b using Direct Preference Optimization for human preference alignment | [Maxime Labonne](https://huggingface.co/mlabonne) | [Link](https://mlabonne.github.io/blog/posts/Fine_tune_Mistral_7b_with_DPO.html) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mlabonne/llm-course/blob/main/Fine_tune_a_Mistral_7b_model_with_DPO.ipynb) | +| Preference Optimization | [`ORPOTrainer`] | Fine-tuning Llama 3 with ORPO combining instruction tuning and preference alignment | [Maxime Labonne](https://huggingface.co/mlabonne) | [Link](https://mlabonne.github.io/blog/posts/2024-04-19_Fine_tune_Llama_3_with_ORPO.html) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1eHNWg9gnaXErdAa8_mcvjMupbSS6rDvi) | | Instruction tuning | [`SFTTrainer`] | How to fine-tune open LLMs in 2025 with Hugging Face | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-llms-in-2025) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/fine-tune-llms-in-2025.ipynb) | # Vision Language Models -| Task | Class | Description | Author | Tutorial | Colab | -| --------------- | -------------- | 
---------------------------------------------------------------------------- | ------------------------------------------------------ | -------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| Visual QA | [`SFTTrainer`] | Fine-tuning Qwen2-VL-7B for visual question answering on ChartQA dataset | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_vlm_trl) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_vlm_trl.ipynb) | -| Visual QA | [`SFTTrainer`] | Fine-tuning SmolVLM with TRL on a consumer GPU | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_smol_vlm_sft_trl) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_smol_vlm_sft_trl.ipynb) | -| SEO Description | [`SFTTrainer`] | Fine-tuning Qwen2-VL-7B for generating SEO-friendly descriptions from images | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-multimodal-llms-with-trl) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/fine-tune-multimodal-llms-with-trl.ipynb) | -| Visual QA | [`DPOTrainer`] | PaliGemma 🤝 Direct Preference Optimization | [Merve Noyan](https://huggingface.co/merve) | [Link](https://github.com/merveenoyan/smol-vision/blob/main/PaliGemma_DPO.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/merveenoyan/smol-vision/blob/main/PaliGemma_DPO.ipynb) | -| Visual QA | [`DPOTrainer`] | Fine-tuning SmolVLM using direct preference optimization (DPO) with TRL on a consumer GPU | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_vlm_dpo_smolvlm_instruct) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_vlm_dpo_smolvlm_instruct.ipynb) | +| Task | Class | Description | Author | Tutorial | Colab | +| --- | --- | --- | --- | --- | --- | +| Visual QA | [`SFTTrainer`] | Fine-tuning Qwen2-VL-7B for visual question answering on ChartQA dataset | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_vlm_trl) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_vlm_trl.ipynb) | +| Visual QA | [`SFTTrainer`] | Fine-tuning SmolVLM with TRL on a consumer GPU | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_smol_vlm_sft_trl) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_smol_vlm_sft_trl.ipynb) | +| SEO 
Description | [`SFTTrainer`] | Fine-tuning Qwen2-VL-7B for generating SEO-friendly descriptions from images | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-multimodal-llms-with-trl) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/fine-tune-multimodal-llms-with-trl.ipynb) | +| Visual QA | [`DPOTrainer`] | PaliGemma 🤝 Direct Preference Optimization | [Merve Noyan](https://huggingface.co/merve) | [Link](https://github.com/merveenoyan/smol-vision/blob/main/PaliGemma_DPO.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/merveenoyan/smol-vision/blob/main/PaliGemma_DPO.ipynb) | +| Visual QA | [`DPOTrainer`] | Fine-tuning SmolVLM using direct preference optimization (DPO) with TRL on a consumer GPU | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_vlm_dpo_smolvlm_instruct) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_vlm_dpo_smolvlm_instruct.ipynb) | ## Contributing diff --git a/docs/source/cpo_trainer.mdx b/docs/source/cpo_trainer.md similarity index 98% rename from docs/source/cpo_trainer.mdx rename to docs/source/cpo_trainer.md index 3f9fb88cfc..24e0f3fdae 100644 --- a/docs/source/cpo_trainer.mdx +++ b/docs/source/cpo_trainer.md @@ -4,7 +4,7 @@ ## Overview -Contrastive Preference Optimization (CPO) as introduced in the paper [Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation](https://huggingface.co/papers/2401.08417) by [Haoran Xu](https://huggingface.co/haoranxu), [Amr Sharaf](https://huggingface.co/amrsharaf), [Yunmo Chen](https://huggingface.co/yunmochen), Weiting Tan, Lingfeng Shen, Benjamin Van Durme, [Kenton Murray](https://huggingface.co/Kenton), and [Young Jin Kim](https://huggingface.co/ykim362). At a high-level, CPO trains models to avoid generating adequate, but not perfect translations in Machine Translation (MT) tasks. However, CPO is a general approximation to the DPO loss and can be applied to other domains like chat. +Contrastive Preference Optimization (CPO) as introduced in the paper [Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation](https://huggingface.co/papers/2401.08417) by [Haoran Xu](https://huggingface.co/haoranxu), [Amr Sharaf](https://huggingface.co/amrsharaf), [Yunmo Chen](https://huggingface.co/yunmochen), Weiting Tan, Lingfeng Shen, Benjamin Van Durme, [Kenton Murray](https://huggingface.co/Kenton), and [Young Jin Kim](https://huggingface.co/ykim362). At a high-level, CPO trains models to avoid generating adequate, but not perfect translations in Machine Translation (MT) tasks. However, CPO is a general approximation of the DPO loss and can be applied to other domains, such as chat. CPO aims to mitigate two fundamental shortcomings of SFT. First, SFT’s methodology of minimizing the discrepancy between predicted outputs and gold-standard references inherently caps model performance at the quality level of the training data. Secondly, SFT lacks a mechanism to prevent the model from rejecting mistakes in translations. The CPO objective is derived from the DPO objective. 
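To make the CPO workflow concrete, here is a minimal training sketch. It is illustrative only: it assumes a binarized preference dataset such as `trl-lib/ultrafeedback_binarized` and the `CPOConfig`/`CPOTrainer` entry points with a `processing_class` argument; adapt the names to your setup.

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import CPOConfig, CPOTrainer

# Load a small instruct model and its tokenizer
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

# CPO trains on preference pairs ("chosen" / "rejected" columns)
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

training_args = CPOConfig(output_dir="Qwen2-0.5B-CPO", logging_steps=10)
trainer = CPOTrainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
)
trainer.train()
```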
@@ -105,4 +105,4 @@ To scale how much the auxiliary loss contributes to the total loss, use the hype ## CPOConfig -[[autodoc]] CPOConfig \ No newline at end of file +[[autodoc]] CPOConfig diff --git a/docs/source/customization.mdx b/docs/source/customization.md similarity index 100% rename from docs/source/customization.mdx rename to docs/source/customization.md diff --git a/docs/source/data_utils.mdx b/docs/source/data_utils.md similarity index 79% rename from docs/source/data_utils.mdx rename to docs/source/data_utils.md index 9b8391278d..bdadd4206b 100644 --- a/docs/source/data_utils.mdx +++ b/docs/source/data_utils.md @@ -12,6 +12,10 @@ [[autodoc]] maybe_apply_chat_template +## maybe_convert_to_chatml + +[[autodoc]] maybe_convert_to_chatml + ## extract_prompt [[autodoc]] extract_prompt @@ -27,3 +31,7 @@ ## maybe_unpair_preference_dataset [[autodoc]] maybe_unpair_preference_dataset + +## pack_examples + +[[autodoc]] pack_examples diff --git a/docs/source/dataset_formats.mdx b/docs/source/dataset_formats.md similarity index 99% rename from docs/source/dataset_formats.mdx rename to docs/source/dataset_formats.md index 1306fdad7f..a8e10f1830 100644 --- a/docs/source/dataset_formats.mdx +++ b/docs/source/dataset_formats.md @@ -270,6 +270,7 @@ Choosing the right dataset type depends on the task you are working on and the s | [`CPOTrainer`] | [Preference (explicit prompt recommended)](#preference) | | [`DPOTrainer`] | [Preference (explicit prompt recommended)](#preference) | | [`GKDTrainer`] | [Prompt-completion](#prompt-completion) | +| [`GRPOTrainer`] | [Prompt-only](#prompt-only) | | [`IterativeSFTTrainer`] | [Unpaired preference](#unpaired-preference) | | [`KTOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) | | [`NashMDTrainer`] | [Prompt-only](#prompt-only) | @@ -340,7 +341,7 @@ dataset = dataset.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer}) -We recommend using the [`apply_chat_template`] function instead of calling `tokenizer.apply_chat_template` directly. Handling chat templates for non-language modeling datasets can be tricky and may result in errors, such as mistakenly placing a system prompt in the middle conversation. +We recommend using the [`apply_chat_template`] function instead of calling `tokenizer.apply_chat_template` directly. Handling chat templates for non-language modeling datasets can be tricky and may result in errors, such as mistakenly placing a system prompt in the middle of a conversation. For additional examples, see [#1930 (comment)](https://github.com/huggingface/trl/pull/1930#issuecomment-2292908614). The [`apply_chat_template`] is designed to handle these intricacies and ensure the correct application of chat templates for various tasks. diff --git a/docs/source/ddpo_trainer.mdx b/docs/source/ddpo_trainer.md similarity index 95% rename from docs/source/ddpo_trainer.mdx rename to docs/source/ddpo_trainer.md index 0682144edb..eca557c9e4 100644 --- a/docs/source/ddpo_trainer.mdx +++ b/docs/source/ddpo_trainer.md @@ -14,8 +14,8 @@ ## Getting started with Stable Diffusion finetuning with reinforcement learning The machinery for finetuning of Stable Diffusion models with reinforcement learning makes heavy use of HuggingFace's `diffusers` -library. A reason for stating this is that getting started requires a bit of familiarity with the `diffusers` library concepts, mainly two of them - pipelines and schedulers. 
-Right out of the box (`diffusers` library), there isn't a `Pipeline` nor a `Scheduler` instance that is suitable for finetuning with reinforcement learning. Some adjustments need to made. +library. A reason for stating this is that getting started requires a bit of familiarity with the `diffusers` library concepts, mainly two of them - pipelines and schedulers. +Right out of the box (`diffusers` library), there isn't a `Pipeline` nor a `Scheduler` instance that is suitable for finetuning with reinforcement learning. Some adjustments need to be made. There is a pipeline interface that is provided by this library that is required to be implemented to be used with the `DDPOTrainer`, which is the main machinery for fine-tuning Stable Diffusion with reinforcement learning. **Note: Only the StableDiffusion architecture is supported at this point.** There is a default implementation of this interface that you can use out of the box. Assuming the default implementation is sufficient and/or to get things moving, refer to the training example alongside this guide. @@ -26,7 +26,7 @@ For a more detailed look into the interface and the associated default implement Note that the default implementation has a LoRA implementation path and a non-LoRA based implementation path. The LoRA flag enabled by default and this can be turned off by passing in the flag to do so. LORA based training is faster and the LORA associated model hyperparameters responsible for model convergence aren't as finicky as non-LORA based training. -Also in addition, there is the expectation of providing a reward function and a prompt function. The reward function is used to evaluate the generated images and the prompt function is used to generate the prompts that are used to generate the images. +Also in addition, there is the expectation of providing a reward function and a prompt function. The reward function is used to evaluate the generated images and the prompt function is used to generate the prompts that are used to generate the images. ## Getting started with `examples/scripts/ddpo.py` diff --git a/docs/source/detoxifying_a_lm.mdx b/docs/source/detoxifying_a_lm.md similarity index 96% rename from docs/source/detoxifying_a_lm.mdx rename to docs/source/detoxifying_a_lm.md index fe97422889..eb0ab5fd80 100644 --- a/docs/source/detoxifying_a_lm.mdx +++ b/docs/source/detoxifying_a_lm.md @@ -30,7 +30,7 @@ We selected the following models for our experiments to show that TRL can be eas * [`EleutherAI/gpt-neo-2.7B`](https://huggingface.co/EleutherAI/gpt-neo-2.7B) (2.7 billion parameters) * [`EleutherAI/gpt-j-6B`](https://huggingface.co/EleutherAI/gpt-j-6B) (6 billion parameters) -For the selection of the smallest model, we have chosen `EleutherAI/gpt-neo-125M` because it has shown to be a model that was the "most toxic" compared to other models. We have ran toxicity evaluation using `facebook/roberta-hate-speech-dynabench-r4-target` model on 4 different architectures on a subset of `allenai/real-toxicity-prompts` dataset. Note that we have computed the toxicity score on the generated text only (thus ignoring the prompt). +For the selection of the smallest model, we have chosen `EleutherAI/gpt-neo-125M` because it has shown to be a model that was the "most toxic" compared to other models. We have run toxicity evaluation using `facebook/roberta-hate-speech-dynabench-r4-target` model on 4 different architectures on a subset of `allenai/real-toxicity-prompts` dataset. 
Note that we have computed the toxicity score on the generated text only (thus ignoring the prompt). | Model | Mean toxicity score | |---|---| @@ -45,7 +45,7 @@ When doing PPO, it is very important to design the problem efficiently so that t ### Pre-processing the dataset -The dataset consist of prompts and their continuations, and each of them has an associated `toxicity` score. +The dataset consists of prompts and their continuations, and each of them has an associated `toxicity` score. A `prompt` example: ``` @@ -88,7 +88,7 @@ As a compromise between the two we took for a context window of 10 to 15 tokens ### How to deal with OOM issues -Our goal is to train models up to 6B parameters, which is about 24GB in float32! Here two tricks we use to be able to train a 6B model on a single 40GB-RAM GPU: +Our goal is to train models up to 6B parameters, which is about 24GB in float32! Here are two tricks we use to be able to train a 6B model on a single 40GB-RAM GPU: - Use `bfloat16` precision: Simply load your model in `bfloat16` when calling `from_pretrained` and you can reduce the size of the model by 2: @@ -109,7 +109,7 @@ ref_model = create_reference_model(model, num_shared_layers=6) trainer = PPOTrainer(..., ref_model=ref_model) ``` -In the example above this means that the model have the 4 first layers frozen (i.e. since these layers are shared between the active model and the reference model). +In the example above this means that the model has the 4 first layers frozen (i.e. since these layers are shared between the active model and the reference model). - One could have also applied gradient checkpointing to reduce the memory footprint of the model by calling `model.pretrained_model.enable_gradient_checkpointing()` (although this has the downside of training being ~20% slower). @@ -176,7 +176,7 @@ The evaluation script can be found [here](https://github.com/huggingface/trl/blo The results are quite promising, as we can see that the models are able to reduce the toxicity score of the generated text by an interesting margin. The gap is clear for `gpt-neo-2B` model but we less so for the `gpt-j-6B` model. There are several things we could try to improve the results on the largest model starting with training with larger `mini_batch_size` and probably allowing to back-propagate through more layers (i.e. use less shared layers). -To sum up, in addition to human feedback this could be a useful additional signal when training large language models to ensure there outputs are less toxic as well as useful. +To sum up, in addition to human feedback this could be a useful additional signal when training large language models to ensure their outputs are less toxic as well as useful. ### Limitations diff --git a/docs/source/dpo_trainer.mdx b/docs/source/dpo_trainer.md similarity index 99% rename from docs/source/dpo_trainer.mdx rename to docs/source/dpo_trainer.md index b0d6b1f8d6..dac5c227d5 100644 --- a/docs/source/dpo_trainer.mdx +++ b/docs/source/dpo_trainer.md @@ -81,7 +81,7 @@ The best programming language based on these factors is subjective and depends o ## Expected dataset type -DPO requires a [preference dataset](dataset_formats#preference). The [`DPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. +DPO requires a [preference dataset](dataset_formats#preference). 
The [`DPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. Although the [`DPOTrainer`] supports both explicit and implicit prompts, we recommend using explicit prompts. If provided with an implicit prompt dataset, the trainer will automatically extract the prompt from the `"chosen"` and `"rejected"` columns. For more information, refer to the [preference style](dataset_formats#preference) section. @@ -280,4 +280,4 @@ dpo_trainer = DPOTrainer( ## DataCollatorForPreference -[[autodoc]] trainer.dpo_trainer.DataCollatorForPreference \ No newline at end of file +[[autodoc]] trainer.dpo_trainer.DataCollatorForPreference diff --git a/docs/source/grpo_trainer.md b/docs/source/grpo_trainer.md new file mode 100644 index 0000000000..c1b9a5c28e --- /dev/null +++ b/docs/source/grpo_trainer.md @@ -0,0 +1,253 @@ +# GRPO Trainer + +[![](https://img.shields.io/badge/All_models-GRPO-blue)](https://huggingface.co/models?other=grpo,trl) + +## Overview + +TRL supports the GRPO Trainer for training language models, as described in the paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300) by [Zhihong Shao](https://huggingface.co/syhia), [Peiyi Wang](https://huggingface.co/peiyiwang89), [Qihao Zhu](https://huggingface.co/zqh11), Runxin Xu, [Junxiao Song](https://huggingface.co/haha-point), Mingchuan Zhang, Y. K. Li, Y. Wu, [Daya Guo](https://huggingface.co/guoday). + +The abstract from the paper is the following: + +> Mathematical reasoning poses a significant challenge for language models due to its complex and structured nature. In this paper, we introduce DeepSeekMath 7B, which continues pre-training DeepSeek-Coder-Base-v1.5 7B with 120B math-related tokens sourced from Common Crawl, together with natural language and code data. DeepSeekMath 7B has achieved an impressive score of 51.7% on the competition-level MATH benchmark without relying on external toolkits and voting techniques, approaching the performance level of Gemini-Ultra and GPT-4. Self-consistency over 64 samples from DeepSeekMath 7B achieves 60.9% on MATH. The mathematical reasoning capability of DeepSeekMath is attributed to two key factors: First, we harness the significant potential of publicly available web data through a meticulously engineered data selection pipeline. Second, we introduce Group Relative Policy Optimization (GRPO), a variant of Proximal Policy Optimization (PPO), that enhances mathematical reasoning abilities while concurrently optimizing the memory usage of PPO. + +This post-training method was contributed by [Quentin Gallouédec](https://huggingface.co/qgallouedec). + +## Quick start + +This example demonstrates how to train a model using the GRPO method. We train a [Qwen 0.5B Instruct model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) with the prompts from the [TLDR dataset](https://huggingface.co/datasets/trl-lib/tldr) (completion column is ignored!). You can view the data in the dataset here: + + + +Below is the script to train the model. 
+
+```python
+# train_grpo.py
+from datasets import load_dataset
+from trl import GRPOConfig, GRPOTrainer
+
+dataset = load_dataset("trl-lib/tldr", split="train")
+
+# Define the reward function, which rewards completions that are close to 20 characters
+def reward_len(completions, **kwargs):
+    return [-abs(20 - len(completion)) for completion in completions]
+
+training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO", logging_steps=10)
+trainer = GRPOTrainer(
+    model="Qwen/Qwen2-0.5B-Instruct",
+    reward_funcs=reward_len,
+    args=training_args,
+    train_dataset=dataset,
+)
+trainer.train()
+```
+
+Execute the script using the following command:
+
+```bash
+accelerate launch train_grpo.py
+```
+
+Distributed across 8 GPUs, the training takes approximately 1 day.
+
+![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/grpo_curves.png)
+
+## Looking deeper into the GRPO method
+
+GRPO is an online learning algorithm, meaning it improves iteratively by using the data generated by the trained model itself during training. The intuition behind the GRPO objective is to maximize the advantage of the generated completions, while ensuring that the model remains close to the reference policy. To understand how GRPO works, it can be broken down into four main steps: **Generating completions**, **computing the advantage**, **estimating the KL divergence**, and **computing the loss**.
+
+![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/grpo_visual.png)
+
+### Generating completions
+
+At each training step, we sample a batch of prompts and generate a set of \\( G \\) completions for each prompt (denoted as \\( o_i \\)).
+
+### Computing the advantage
+
+For each of the \\( G \\) sequences, we compute the reward using a reward model. To align with the comparative nature of reward models—typically trained on datasets of comparisons between outputs for the same question—the advantage is calculated to reflect these relative comparisons. It is normalized as follows:
+
+$$\hat{A}_{i,t} = \frac{r_i - \text{mean}(\mathbf{r})}{\text{std}(\mathbf{r})}$$
+
+This approach gives the method its name: **Group Relative Policy Optimization (GRPO)**.
+
+### Estimating the KL divergence
+
+KL divergence is estimated using the approximator introduced by [Schulman et al. (2020)](http://joschu.net/blog/kl-approx.html). The approximator is defined as follows:
+
+$$\mathbb{D}_{\text{KL}}\left[\pi_\theta \|\pi_{\text{ref}}\right] = \frac{\pi_{\text{ref}}(o_{i,t} \mid q, o_{i,< t})}{\pi_\theta(o_{i,t} \mid q, o_{i,< t})} - \log \frac{\pi_{\text{ref}}(o_{i,t} \mid q, o_{i,< t})}{\pi_\theta(o_{i,t} \mid q, o_{i,< t})} - 1$$
+
+### Computing the loss
+
+The objective is to maximize the advantage of the generated completions while keeping the policy close to the reference policy. The loss therefore combines an advantage-weighted policy term with a KL penalty scaled by the coefficient \\( \beta \\):
+
+$$\mathcal{L}_{\text{GRPO}}(\theta) = -\frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \left[ \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\left[\pi_\theta(o_{i,t} \mid q, o_{i,< t})\right]_{\text{no grad}}} \hat{A}_{i,t} - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right]$$
+
+where the first term uses a detached-denominator ratio so that its value equals the advantage while its gradient is the usual policy-gradient term, and the second term penalizes deviations from the reference policy.
+
+## Customization
+
+### Using a custom reward function
+
+The [`GRPOTrainer`] also accepts custom reward functions in place of a reward model. A reward function receives the `prompts`, the generated `completions`, and any remaining dataset columns as keyword arguments, and must return a list of floats with one reward per completion.
+
+#### Example 1: Reward longer completions
+
+Below is an example of a reward function for a [standard format](dataset_formats#standard) that rewards longer completions:
+
+```python
+def reward_func(completions, **kwargs):
+    """Reward function that gives higher scores to longer completions."""
+    return [float(len(completion)) for completion in completions]
+```
+
+You can test it as follows:
+
+```python
+>>> prompts = ["The sky is", "The sun is"]
+>>> completions = [" blue.", " in the sky."]
+>>> print(reward_func(prompts=prompts, completions=completions))
+[6.0, 12.0]
+```
+
+#### Example 2: Reward completions with specific format
+
+Below is an example of a reward function that checks if the completion has a specific format. This example is inspired by the _format reward_ function used in the paper [DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning](https://huggingface.co/papers/2501.12948).
+It is designed for conversational format, where prompts and completions consist of structured messages.
+ +```python +import re + +def format_reward_func(completions, **kwargs): + """Reward function that checks if the completion has a specific format.""" + pattern = r"^.*?.*?$" + completion_contents = [completion[0]["content"] for completion in completions] + matches = [re.match(pattern, content) for content in completion_contents] + return [1.0 if match else 0.0 for match in matches] +``` + +You can test this function as follows: + +```python +>>> prompts = [ +... [{"role": "assistant", "content": "What is the result of (1 + 2) * 4?"}], +... [{"role": "assistant", "content": "What is the result of (3 + 1) * 2?"}], +... ] +>>> completions = [ +... [{"role": "assistant", "content": "The sum of 1 and 2 is 3, which we multiply by 4 to get 12.(1 + 2) * 4 = 12"}], +... [{"role": "assistant", "content": "The sum of 3 and 1 is 4, which we multiply by 2 to get 8. So (3 + 1) * 2 = 8."}], +... ] +>>> format_reward_func(prompts=prompts, completions=completions) +[1.0, 0.0] +``` + +#### Example 3: Reward completions based on a reference + +Below is an example of a reward function that checks if the completion is correct. This example is inspired by the _accuracy reward_ function used in the paper [DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning](https://huggingface.co/papers/2501.12948). +This example is designed for [standard format](dataset_formats#standard), where the dataset contains a column named `ground_truth`. + +```python +import re + +def reward_func(completions, ground_truth, **kwargs): + # Regular expression to capture content inside \boxed{} + matches = [re.search(r"\\boxed\{(.*?)\}", completion) for completion in completions] + contents = [match.group(1) if match else "" for match in matches] + # Reward 1 if the content is the same as the ground truth, 0 otherwise + return [1.0 if c == gt else 0.0 for c, gt in zip(contents, ground_truth)] +``` + +You can test this function as follows: + +```python +>>> prompts = ["Problem: Solve the equation $2x + 3 = 7$. Solution:", "Problem: Solve the equation $3x - 5 = 10$."] +>>> completions = [r" The solution is \boxed{2}.", r" The solution is \boxed{6}."] +>>> ground_truth = ["2", "5"] +>>> reward_func(prompts=prompts, completions=completions, ground_truth=ground_truth) +[1.0, 0.0] +``` + +#### Passing the reward function to the trainer + +To use your custom reward function, pass it to the [`GRPOTrainer`] as follows: + +```python +from trl import GRPOTrainer + +trainer = GRPOTrainer( + reward_funcs=reward_func, + ..., +) +``` + +If you have multiple reward functions, you can pass them as a list: + +```python +from trl import GRPOTrainer + +trainer = GRPOTrainer( + reward_funcs=[reward_func1, reward_func2], + ..., +) +``` +and the reward will be computed as the sum of the rewards from each function, or the weighted sum if `reward_weights` is provided in the config. + +Note that [`GRPOTrainer`] supports multiple reward functions of different types. See the parameters documentation for more details. 
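For instance, to reuse the format and accuracy reward functions from the examples above and down-weight the format reward, a configuration along these lines should work. The weights are assumed to be matched positionally to the entries of `reward_funcs`, and the dataset must provide any extra columns (such as `ground_truth`) that the reward functions expect.

```python
from trl import GRPOConfig, GRPOTrainer

# Down-weight the format reward (0.2) relative to the accuracy reward (1.0)
training_args = GRPOConfig(
    output_dir="Qwen2-0.5B-GRPO",
    reward_weights=[0.2, 1.0],
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=[format_reward_func, reward_func],  # defined in the examples above
    args=training_args,
    train_dataset=dataset,  # must contain the "ground_truth" column used by reward_func
)
```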
+ +## GRPOTrainer + +[[autodoc]] GRPOTrainer + +## GRPOConfig + +[[autodoc]] GRPOConfig diff --git a/docs/source/index.mdx b/docs/source/index.md similarity index 100% rename from docs/source/index.mdx rename to docs/source/index.md diff --git a/docs/source/installation.md b/docs/source/installation.md new file mode 100644 index 0000000000..8ab4165931 --- /dev/null +++ b/docs/source/installation.md @@ -0,0 +1,39 @@ +# Installation +You can install TRL either from PyPI or from source: + +## PyPI +Install the library with pip or [uv](https://docs.astral.sh/uv/): + + + + +uv is a fast Rust-based Python package and project manager. Refer to [Installation](https://docs.astral.sh/uv/getting-started/installation/) for installation instructions), . + +```bash +uv pip install trl +``` + + + + +```bash +pip install trl +``` + + + + +## Source +You can also install the latest version from source. First clone the repo and then run the installation with `pip`: + +```bash +git clone https://github.com/huggingface/trl.git +cd trl/ +pip install -e . +``` + +If you want the development install you can replace the pip install with the following: + +```bash +pip install -e ".[dev]" +``` diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx deleted file mode 100644 index bf74b64175..0000000000 --- a/docs/source/installation.mdx +++ /dev/null @@ -1,24 +0,0 @@ -# Installation -You can install TRL either from pypi or from source: - -## pypi -Install the library with pip: - -```bash -pip install trl -``` - -### Source -You can also install the latest version from source. First clone the repo and then run the installation with `pip`: - -```bash -git clone https://github.com/huggingface/trl.git -cd trl/ -pip install -e . -``` - -If you want the development install you can replace the pip install with the following: - -```bash -pip install -e ".[dev]" -``` \ No newline at end of file diff --git a/docs/source/iterative_sft_trainer.mdx b/docs/source/iterative_sft_trainer.md similarity index 100% rename from docs/source/iterative_sft_trainer.mdx rename to docs/source/iterative_sft_trainer.md diff --git a/docs/source/judges.mdx b/docs/source/judges.md similarity index 100% rename from docs/source/judges.mdx rename to docs/source/judges.md diff --git a/docs/source/kto_trainer.mdx b/docs/source/kto_trainer.md similarity index 98% rename from docs/source/kto_trainer.mdx rename to docs/source/kto_trainer.md index 05de7a026d..6e81f39019 100644 --- a/docs/source/kto_trainer.mdx +++ b/docs/source/kto_trainer.md @@ -115,7 +115,7 @@ Each choice of `beta` has a maximum learning rate it can tolerate before learnin ### Imbalanced data The `desirable_weight` and `undesirable_weight` of the [`KTOConfig`] refer to the weights placed on the losses for desirable/positive and undesirable/negative examples. -By default, they are both 1. However, if you have more of one or the other, then you should upweight the less common type such that the ratio of (`desirable_weight` \\(\times\\) number of positives) to (`undesirable_weight` \\(\times\\) number of negatives) is in the range 1:1 to 4:3. +By default, they are both 1. However, if you have more of one or the other, then you should upweight the less common type such that the ratio of (`desirable_weight` \\(\times\\) number of positives) to (`undesirable_weight` \\(\times\\) number of negatives) is in the range 1:1 to 4:3. 
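As a worked example of that rule of thumb (the dataset counts below are hypothetical): with 1,000 desirable and 3,000 undesirable examples and `undesirable_weight` left at 1.0, `desirable_weight` should land between 3.0 and 4.0 so that the weighted ratio falls in the 1:1 to 4:3 range.

```python
from trl import KTOConfig

# Hypothetical counts: 1,000 desirable vs. 3,000 undesirable examples.
# Target: (desirable_weight * 1_000) : (undesirable_weight * 3_000) between 1:1 and 4:3,
# i.e. desirable_weight between 3.0 and 4.0 when undesirable_weight stays at 1.0.
training_args = KTOConfig(
    output_dir="Qwen2-0.5B-KTO",
    desirable_weight=3.0,
    undesirable_weight=1.0,
)
```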
## Logged metrics diff --git a/docs/source/learning_tools.mdx b/docs/source/learning_tools.md similarity index 98% rename from docs/source/learning_tools.mdx rename to docs/source/learning_tools.md index add4844e2b..368d666fe8 100644 --- a/docs/source/learning_tools.mdx +++ b/docs/source/learning_tools.md @@ -22,7 +22,7 @@ Note that the scripts above rely heavily on the `TextEnvironment` API which is s The rough idea is as follows: -1. Load a tool such as [ybelkada/simple-calculator](https://huggingface.co/spaces/ybelkada/simple-calculator) that parse a text calculation like `"14 + 34"` and return the calulated number: +1. Load a tool such as [ybelkada/simple-calculator](https://huggingface.co/spaces/ybelkada/simple-calculator) that parse a text calculation like `"14 + 34"` and return the calculated number: ```python from transformers import AutoTokenizer, load_tool tool = load_tool("ybelkada/simple-calculator") @@ -154,7 +154,7 @@ We then basically deployed this snippet as a Hugging Face space [here](https://h We use the following settings: * use the `bigcode/starcoderbase` model as the base model -* use the `pyserini-wikipedia-kilt-doc` space as the wiki tool and only uses the first paragrahs of the search result, allowing the `TextEnvironment` to obtain at most `max_tool_reponse=400` response tokens from the tool. +* use the `pyserini-wikipedia-kilt-doc` space as the wiki tool and only uses the first paragraphs of the search result, allowing the `TextEnvironment` to obtain at most `max_tool_reponse=400` response tokens from the tool. * test if the response contain the answer string, if so, give a reward of 1, otherwise, give a reward of 0. * notice this is a simplified evaluation criteria. In [ToolFormer](https://huggingface.co/papers/2302.04761), the authors checks if the first 20 words of the response contain the correct answer. * used the following prompt that demonstrates the usage of the wiki tool. @@ -220,9 +220,9 @@ def solution(): result = money_left return result print(solution()) -72 +8 -Result = 72 +Result = 8 Q: """ ``` diff --git a/docs/source/logging.mdx b/docs/source/logging.md similarity index 96% rename from docs/source/logging.mdx rename to docs/source/logging.md index 4c60868dac..f6d9657675 100644 --- a/docs/source/logging.mdx +++ b/docs/source/logging.md @@ -16,8 +16,8 @@ If you want to log with tensorboard, add the kwarg `project_kwargs={"logging_dir Here's a brief explanation for the logged metrics provided in the data: Key metrics to monitor. We want to maximize the reward, maintain a low KL divergence, and maximize entropy: -1. `env/reward_mean`: The average reward obtained from the environment. Alias `ppo/mean_scores`, which is sed to specifically monitor the reward model. -1. `env/reward_std`: The standard deviation of the reward obtained from the environment. Alias ``ppo/std_scores`, which is sed to specifically monitor the reward model. +1. `env/reward_mean`: The average reward obtained from the environment. Alias `ppo/mean_scores`, which is used to specifically monitor the reward model. +1. `env/reward_std`: The standard deviation of the reward obtained from the environment. Alias ``ppo/std_scores`, which is used to specifically monitor the reward model. 1. `env/reward_dist`: The histogram distribution of the reward obtained from the environment. 1. `objective/kl`: The mean Kullback-Leibler (KL) divergence between the old and new policies. It measures how much the new policy deviates from the old policy. 
The KL divergence is used to compute the KL penalty in the objective function. 1. `objective/kl_dist`: The histogram distribution of the `objective/kl`. @@ -71,4 +71,4 @@ Here are some parameters that are useful to monitor for stability (when these di 1. `ppo/policy/ratio`: `ratio` being 1 is a baseline value, meaning that the probability of sampling a token is the same under the new and old policy. If the ratio is too high like 200, it means the probability of sampling a token is 200 times higher under the new policy than the old policy. This is a sign that the new policy is too different from the old policy, which will likely cause overoptimization and collapse training later on. 1. `ppo/policy/clipfrac` and `ppo/policy/approxkl`: if `ratio` is too high, the `ratio` is going to get clipped, resulting in high `clipfrac` and high `approxkl` as well. 1. `objective/kl`: it should stay positive so that the policy is not too far away from the reference policy. -1. `objective/kl_coef`: The target coefficient with [`AdaptiveKLController`]. Often increases before numerical instabilities. \ No newline at end of file +1. `objective/kl_coef`: The target coefficient with [`AdaptiveKLController`]. Often increases before numerical instabilities. diff --git a/docs/source/models.mdx b/docs/source/models.md similarity index 100% rename from docs/source/models.mdx rename to docs/source/models.md diff --git a/docs/source/multi_adapter_rl.mdx b/docs/source/multi_adapter_rl.md similarity index 100% rename from docs/source/multi_adapter_rl.mdx rename to docs/source/multi_adapter_rl.md diff --git a/docs/source/ppo_trainer.md b/docs/source/ppo_trainer.md index 1e0faf663f..13dae3eb27 100644 --- a/docs/source/ppo_trainer.md +++ b/docs/source/ppo_trainer.md @@ -26,6 +26,8 @@ python examples/scripts/ppo/ppo.py \ --gradient_accumulation_steps 1 \ --total_episodes 10000 \ --model_name_or_path EleutherAI/pythia-1b-deduped \ + --sft_model_path EleutherAI/pythia-1b-deduped \ + --reward_model_path EleutherAI/pythia-1b-deduped \ --missing_eos_penalty 1.0 ``` @@ -56,7 +58,7 @@ The logged metrics are as follows. Here is an example [tracked run at Weights an ## Cookbook * Debugging TIP: `objective/rlhf_reward`: this is the ultimate objective of the RLHF training. If training works as intended, this metric should keep going up. -* Debugging TIP: `val/ratio`: this number should float around 1.0, and it gets clipped by `--cliprange 0.2` with PPO's surrogate loss. So if this `ratio` is too high like 2.0 or 1000.0 or too small like 0.1, it means the updates between consecutive policies are too drastic. You should try undertand why this is happening and try to fix it. +* Debugging TIP: `val/ratio`: this number should float around 1.0, and it gets clipped by `--cliprange 0.2` with PPO's surrogate loss. So if this `ratio` is too high like 2.0 or 1000.0 or too small like 0.1, it means the updates between consecutive policies are too drastic. You should try understand why this is happening and try to fix it. * Memory TIP: If you are running out of memory, you can try to reduce the `--per_device_train_batch_size` or increase the `--gradient_accumulation_steps` to reduce the memory footprint. * Memory TIP: If you have multiple GPUs, you can also run training with DeepSpeed stage 3 to reduce the memory footprint `accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml`. 
* Usage TIP: We recommend to use the "EOS trick" via `--missing_eos_penalty`, which subtracts a static scalar penalty from the score of completions that do not end with an EOS token. This can help the model learn to generate more coherent completions. @@ -234,4 +236,4 @@ python -m openrlbenchmark.rlops_multi_metrics \ ## PPOConfig -[[autodoc]] PPOConfig \ No newline at end of file +[[autodoc]] PPOConfig diff --git a/docs/source/prm_trainer.mdx b/docs/source/prm_trainer.md similarity index 100% rename from docs/source/prm_trainer.mdx rename to docs/source/prm_trainer.md diff --git a/docs/source/quickstart.mdx b/docs/source/quickstart.md similarity index 100% rename from docs/source/quickstart.mdx rename to docs/source/quickstart.md diff --git a/docs/source/reducing_memory_usage.md b/docs/source/reducing_memory_usage.md index e015c43906..cc335156e6 100644 --- a/docs/source/reducing_memory_usage.md +++ b/docs/source/reducing_memory_usage.md @@ -44,7 +44,7 @@ training_args = DPOConfig(..., max_completion_length=...) -SFT truncation is applied to the input sequence via the `max_seq_length` parameter. +SFT truncation is applied to the input sequence via the `max_length` parameter.
Truncation input ids @@ -55,7 +55,7 @@ To set the truncation parameter, use the following code snippet: ```python from trl import SFTConfig -training_args = SFTConfig(..., max_seq_length=...) +training_args = SFTConfig(..., max_length=...) ``` @@ -85,7 +85,7 @@ Packing eliminates padding, preserves all sequence information, and allows for f ```python from trl import SFTConfig -training_args = SFTConfig(..., packing=True, max_seq_length=512) +training_args = SFTConfig(..., packing=True, max_length=512) ``` @@ -93,3 +93,41 @@ training_args = SFTConfig(..., packing=True, max_seq_length=512) Packing may cause batch contamination, where adjacent sequences influence one another. This can be problematic for some applications. For more details, see [#1230](https://github.com/huggingface/trl/issues/1230). + +## Disabling model gathering for generation in online methods + +When using DeepSpeed ZeRO-3, model weights are sharded across multiple GPUs. Online methods involve generating completions from the model as part of the training process. During this step, the model weights are temporarily gathered on a single GPU for generation. For very large models, this gathering can lead to out-of-memory (OOM) errors, as described in this issue: [#2250](https://github.com/huggingface/trl/issues/2250#issue-2598304204). + +If you encounter this issue, you can disable the gathering of model weights for generation by setting the following parameter: + + + + +```python +from trl import OnlineDPOConfig + +training_args = OnlineDPOConfig(..., ds3_gather_for_generation=False) +``` + + + + +```python +from trl import PPOConfig + +training_args = PPOConfig(..., ds3_gather_for_generation=False) +``` + + + + +```python +from trl import RLOOConfig + +training_args = RLOOConfig(..., ds3_gather_for_generation=False) +``` + + + + +This adjustment prevents model weights from being gathered, avoiding OOM errors, but it may result in slower generation speeds. diff --git a/docs/source/reward_trainer.mdx b/docs/source/reward_trainer.md similarity index 100% rename from docs/source/reward_trainer.mdx rename to docs/source/reward_trainer.md diff --git a/docs/source/rloo_trainer.md b/docs/source/rloo_trainer.md index 3ef57a3dc6..5ad5eca53c 100644 --- a/docs/source/rloo_trainer.md +++ b/docs/source/rloo_trainer.md @@ -2,7 +2,7 @@ [![](https://img.shields.io/badge/All_models-RLOO-blue)](https://huggingface.co/models?other=rloo,trl) -TRL supports training LLMs with REINFORCE Leave-One-Out (RLOO). The idea is that instead of using a value function, RLOO generates K completions for each prompt. For each completion, RLOO uses the mean scores from the other K-1 completions as a baseline to calculate the advantage. RLOO also models the entire completion as a single action, where as PPO models each token as an action. Note that REINFORCE / A2C is a special case of PPO, when the number of PPO epochs is 1 and the number of mini-batches is 1, which is how we implement RLOO in TRL. +TRL supports training LLMs with REINFORCE Leave-One-Out (RLOO). The idea is that instead of using a value function, RLOO generates K completions for each prompt. For each completion, RLOO uses the mean scores from the other K-1 completions as a baseline to calculate the advantage. RLOO also models the entire completion as a single action, whereas PPO models each token as an action. Note that REINFORCE / A2C is a special case of PPO, when the number of PPO epochs is 1 and the number of mini-batches is 1, which is how we implement RLOO in TRL. 
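To make the leave-one-out baseline concrete, here is a small illustrative computation (not TRL's internal implementation) for a single prompt with K=4 completions and hypothetical reward-model scores:

```python
import torch

# Hypothetical reward-model scores for K=4 completions of the same prompt
rewards = torch.tensor([1.0, 0.5, 0.2, 0.9])
k = rewards.numel()

# For each completion, the baseline is the mean reward of the other K-1 completions
baseline = (rewards.sum() - rewards) / (k - 1)
advantages = rewards - baseline
print(advantages)  # tensor([ 0.4667, -0.2000, -0.6000,  0.3333])
```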
References: - [Back to Basics: Revisiting REINFORCE Style Optimization for Learning from Human Feedback in LLMs](https://huggingface.co/papers/2402.14740) @@ -58,7 +58,7 @@ The logged metrics are as follows. Here is an example [tracked run at Weights an ## Cookbook * Debugging TIP: `objective/rlhf_reward`: this is the ultimate objective of the RLHF training. If training works as intended, this metric should keep going up. -* Debugging TIP: `val/ratio`: this number should float around 1.0, and it gets clipped by `--cliprange 0.2` with PPO's surrogate loss. So if this `ratio` is too high like 2.0 or 1000.0 or too small like 0.1, it means the updates between consecutive policies are too drastic. You should try undertand why this is happening and try to fix it. +* Debugging TIP: `val/ratio`: this number should float around 1.0, and it gets clipped by `--cliprange 0.2` with PPO's surrogate loss. So if this `ratio` is too high like 2.0 or 1000.0 or too small like 0.1, it means the updates between consecutive policies are too drastic. You should try to understand why this is happening and try to fix it. * Memory TIP: If you are running out of memory, you can try to reduce the `--per_device_train_batch_size` or increase the `--gradient_accumulation_steps` to reduce the memory footprint. * Memory TIP: If you have multiple GPUs, you can also run training with DeepSpeed stage 3 to reduce the memory footprint `accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml`. * Usage TIP: We recommend to use the "EOS trick" via `--missing_eos_penalty`, which subtracts a static scalar penalty from the score of completions that do not end with an EOS token. This can help the model learn to generate more coherent completions. diff --git a/docs/source/sentiment_tuning.mdx b/docs/source/sentiment_tuning.md similarity index 100% rename from docs/source/sentiment_tuning.mdx rename to docs/source/sentiment_tuning.md diff --git a/docs/source/sft_trainer.mdx b/docs/source/sft_trainer.md similarity index 98% rename from docs/source/sft_trainer.mdx rename to docs/source/sft_trainer.md index be009f448e..5c30b744fa 100644 --- a/docs/source/sft_trainer.mdx +++ b/docs/source/sft_trainer.md @@ -19,7 +19,7 @@ from trl import SFTConfig, SFTTrainer dataset = load_dataset("stanfordnlp/imdb", split="train") training_args = SFTConfig( - max_seq_length=512, + max_length=512, output_dir="/tmp", ) trainer = SFTTrainer( @@ -29,7 +29,7 @@ trainer = SFTTrainer( ) trainer.train() ``` -Make sure to pass the correct value for `max_seq_length` as the default value will be set to `min(tokenizer.model_max_length, 1024)`. +Make sure to pass the correct value for `max_length` as the default value will be set to `min(tokenizer.model_max_length, 1024)`. You can also construct a model outside of the trainer and pass it as follows: @@ -550,12 +550,12 @@ import torch from trl import SFTConfig, SFTTrainer from unsloth import FastLanguageModel -max_seq_length = 2048 # Supports automatic RoPE Scaling, so choose any number +max_length = 2048 # Supports automatic RoPE Scaling, so choose any number # Load model model, tokenizer = FastLanguageModel.from_pretrained( model_name="unsloth/mistral-7b", - max_seq_length=max_seq_length, + max_seq_length=max_length, dtype=None, # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+ load_in_4bit=True, # Use 4bit quantization to reduce memory usage.
Can be False # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf @@ -581,7 +581,7 @@ model = FastLanguageModel.get_peft_model( random_state=3407, ) -training_args = SFTConfig(output_dir="./output", max_seq_length=max_seq_length) +training_args = SFTConfig(output_dir="./output", max_length=max_length) trainer = SFTTrainer( model=model, @@ -624,7 +624,7 @@ To learn more about Liger-Kernel, visit their [official repository](https://gith Pay attention to the following best practices when training a model with that trainer: -- [`SFTTrainer`] always truncates by default the sequences to the `max_seq_length` argument of the [`SFTTrainer`]. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide a default value, so there is a check to retrieve the minimum between 1024 and that value. Make sure to check it before training. +- [`SFTTrainer`] always truncates by default the sequences to the `max_length` argument of the [`SFTConfig`]. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide a default value, so there is a check to retrieve the minimum between 1024 and that value. Make sure to check it before training. - For training adapters in 8bit, you might need to tweak the arguments of the `prepare_model_for_kbit_training` method from PEFT, hence we advise users to use `prepare_in_int8_kwargs` field, or create the `PeftModel` outside the [`SFTTrainer`] and pass it. - For a more memory-efficient training using adapters, you can load the base model in 8bit, for that simply add `load_in_8bit` argument when creating the [`SFTTrainer`], or create a base model in 8bit outside the trainer and pass it. - If you create a model outside the trainer, make sure to not pass to the trainer any additional keyword arguments that are relative to `from_pretrained()` method. diff --git a/docs/source/speeding_up_training.md b/docs/source/speeding_up_training.md index f47f1b2907..83d14cb5a2 100644 --- a/docs/source/speeding_up_training.md +++ b/docs/source/speeding_up_training.md @@ -8,14 +8,21 @@ Section under construction. Feel free to contribute! ## vLLM for fast generation in online methods -Online methods such as Online DPO or Nash-MD require the model to generate completions, which is often a slow process and can significantly impact training time. -To speed up generation, you can use [vLLM](https://github.com/vllm-project/vllm), a library that enables fast generation through PagedAttention. TRL's online trainers support vLLM, greatly improving training speed. +Online methods such as GRPO or Online DPO require the model to generate completions, which is often a slow process and can significantly impact training time. +To speed up generation, you can use [vLLM](https://github.com/vllm-project/vllm), a library that enables fast generation through, among other things, PagedAttention. TRL's online trainers support vLLM, greatly improving training speed. + +To use [vLLM](https://github.com/vllm-project/vllm), first install it using: -To use vLLM, first install it using: ```bash pip install vllm ``` +or + +```bash +pip install "trl[vllm]" +``` + @@ -24,7 +31,44 @@ Then, enable it by passing `use_vllm=True` in the training arguments. ```python from trl import OnlineDPOConfig -training_args = DPOConfig(..., use_vllm=True) +training_args = OnlineDPOConfig(..., use_vllm=True) +``` + + + + +Then, enable it by passing `use_vllm=True` in the training arguments. 
+ +```python +from trl import GRPOConfig + +training_args = GRPOConfig(..., use_vllm=True) +``` + +The strategy here is to use a dedicated GPU for generation powered by vLLM, while using the remainder for training. + + + +When using vLLM, an additional GPU is required exclusively for generation. This means you need at least two available GPUs and must ensure that one remains unused by the trainer. To achieve this, run the training with `--num_processes` set to the number of available GPUs minus one. + +For example, if you have 4 GPUs, set `--num_processes 3` to allocate three GPUs for training while reserving one for generation. +```bash +accelerate launch --multi_gpu --num_processes 3 train_grpo.py +``` + +![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/1_gpu_for_generation.png) + + + +You can further tune the vLLM configuration by setting a specific `vllm_device` and `vllm_gpu_memory_utilization` in the [`GRPOConfig`]. + +```python +training_args = GRPOConfig( + ..., + use_vllm=True, + vllm_device="cuda:4", + vllm_gpu_memory_utilization=0.7, +) ``` diff --git a/docs/source/text_environments.md b/docs/source/text_environments.md index c7b0bd0cfd..15807f576d 100644 --- a/docs/source/text_environments.md +++ b/docs/source/text_environments.md @@ -174,7 +174,7 @@ With these attributes you can reconstruct every interaction of the model with th ### Visualization -When the model interacts inside the `TextEnvironment` it can be useful to visualize and separate which parts of the text outputs were generated by the model and which parts come from the system and tools. For that purpose there are the two methods [`TextHistory.show_text`] and [`TextHistory.show_tokens`]. They print the text and tokens respectively and highlight the various segments using the [`rich` libray](https://github.com/Textualize/rich) (make sure to install it before using these methods). +When the model interacts inside the `TextEnvironment` it can be useful to visualize and separate which parts of the text outputs were generated by the model and which parts come from the system and tools. For that purpose there are the two methods [`TextHistory.show_text`] and [`TextHistory.show_tokens`]. They print the text and tokens respectively and highlight the various segments using the [`rich` library](https://github.com/Textualize/rich) (make sure to install it before using these methods). You can see that the prompt is highlighted in gray, whereas system segments such as query and tool responses are highlighted in green. All segments generated by the model are highlighted in blue and in addition to the pure text output the reward is displayed as additional text in plum. Here an example of `show_text`: diff --git a/docs/source/using_llama_models.mdx b/docs/source/using_llama_models.md similarity index 100% rename from docs/source/using_llama_models.mdx rename to docs/source/using_llama_models.md diff --git a/docs/source/xpo_trainer.mdx b/docs/source/xpo_trainer.md similarity index 99% rename from docs/source/xpo_trainer.mdx rename to docs/source/xpo_trainer.md index 07a76f36dc..4501aaf68b 100644 --- a/docs/source/xpo_trainer.mdx +++ b/docs/source/xpo_trainer.md @@ -4,7 +4,7 @@ ## Overview -Exploratory Preference Optimization (XPO) was proposed in the paper [Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF](https://huggingface.co/papers/2405.21046) by Tengyang Xie, Dylan J.
Foster, Akshay Krishnamurthy, [Corby Rosset](https://huggingface.co/corbyrosset), [Ahmed Awadallah](https://huggingface.co/AhmedAwadallah), and Alexander Rakhlin. It is a simple online preference tuning method based on the DPO loss together with a reward model (RM). XPO augments the DPO objective with an exploration bonus allowing the method to explore outside the support of the intitial model and human feedback data. +Exploratory Preference Optimization (XPO) was proposed in the paper [Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF](https://huggingface.co/papers/2405.21046) by Tengyang Xie, Dylan J. Foster, Akshay Krishnamurthy, [Corby Rosset](https://huggingface.co/corbyrosset), [Ahmed Awadallah](https://huggingface.co/AhmedAwadallah), and Alexander Rakhlin. It is a simple online preference tuning method based on the DPO loss together with a reward model (RM). XPO augments the DPO objective with an exploration bonus allowing the method to explore outside the support of the initial model and human feedback data. The abstract from the paper is the following: diff --git a/examples/datasets/hh-rlhf-helpful-base.py b/examples/datasets/hh-rlhf-helpful-base.py index 44966f917e..98a225c8ec 100644 --- a/examples/datasets/hh-rlhf-helpful-base.py +++ b/examples/datasets/hh-rlhf-helpful-base.py @@ -110,7 +110,7 @@ def extract_dialogue(example: str) -> list[dict[str, str]]: - **Type**: [Preference](https://huggingface.co/docs/trl/main/dataset_formats#preference) Columns: -- `"pompt"`: The user query. +- `"prompt"`: The user query. - `"chosen"`: A response deemed helpful by human evaluators. - `"rejected"`: A response considered less helpful or unhelpful. diff --git a/examples/datasets/lm-human-preferences-descriptiveness.py b/examples/datasets/lm-human-preferences-descriptiveness.py index b836fcc6d5..7515b77373 100644 --- a/examples/datasets/lm-human-preferences-descriptiveness.py +++ b/examples/datasets/lm-human-preferences-descriptiveness.py @@ -82,7 +82,7 @@ def to_prompt_completion(example, tokenizer): - **Type**: [Preference](https://huggingface.co/docs/trl/main/dataset_formats#preference) Columns: -- `"pompt"`: The text sample. +- `"prompt"`: The text sample. - `"chosen"`: A version of the text with enhanced descriptiveness. - `"rejected"`: A version of the text with less descriptiveness. diff --git a/examples/datasets/lm-human-preferences-sentiment.py b/examples/datasets/lm-human-preferences-sentiment.py index 198469c9e0..da411742ba 100644 --- a/examples/datasets/lm-human-preferences-sentiment.py +++ b/examples/datasets/lm-human-preferences-sentiment.py @@ -77,7 +77,7 @@ def to_prompt_completion(example, tokenizer): - **Type**: [Preference](https://huggingface.co/docs/trl/main/dataset_formats#preference) Columns: -- `"pompt"`: The text sample. +- `"prompt"`: The text sample. - `"chosen"`: A version of the text that conveys the desired sentiment. - `"rejected"`: A version of the text that does not convey the desired sentiment. diff --git a/examples/datasets/math_shepherd.py b/examples/datasets/math_shepherd.py index 5dbd5ab7ea..47a28f0a30 100644 --- a/examples/datasets/math_shepherd.py +++ b/examples/datasets/math_shepherd.py @@ -141,7 +141,7 @@ def process_example(example): - **Type**: [Stepwise supervision](https://huggingface.co/docs/trl/main/dataset_formats#stepwise-supervision) Columns: -- `"pompt"`: The problem statement. +- `"prompt"`: The problem statement. - `"completions"`: A list of reasoning steps generated to solve the problem. 
- `"labels"`: A list of booleans or floats indicating the correctness of each corresponding reasoning step. diff --git a/examples/datasets/prm800k.py b/examples/datasets/prm800k.py index c859272909..631fc89d24 100644 --- a/examples/datasets/prm800k.py +++ b/examples/datasets/prm800k.py @@ -115,7 +115,7 @@ def process_batch(examples): - **Type**: [Stepwise supervision](https://huggingface.co/docs/trl/main/dataset_formats#stepwise-supervision) Columns: -- `"pompt"`: The problem statement. +- `"prompt"`: The problem statement. - `"completions"`: A list of reasoning steps generated to solve the problem. - `"labels"`: A list of booleans or floats indicating the correctness of each corresponding reasoning step. diff --git a/examples/datasets/rlaif-v.py b/examples/datasets/rlaif-v.py index 9548daa6b8..b867d6ed68 100644 --- a/examples/datasets/rlaif-v.py +++ b/examples/datasets/rlaif-v.py @@ -77,7 +77,7 @@ def to_conversational(example): - **Type**: [Preference](https://huggingface.co/docs/trl/main/dataset_formats#preference) Columns: -- `"pompt"`: The task related to the image. +- `"prompt"`: The task related to the image. - `"images"`: The image. - `"chosen"`: The preferred answer. - `"rejected"`: An alternative answer that was not preferred. diff --git a/examples/datasets/tldr.py b/examples/datasets/tldr.py index 0fc27bd8c8..1f14943594 100644 --- a/examples/datasets/tldr.py +++ b/examples/datasets/tldr.py @@ -72,7 +72,7 @@ def to_prompt_completion(example): - **Type**: [Prompt-completion](https://huggingface.co/docs/trl/main/dataset_formats#prompt-completion) Columns: -- `"pompt"`: The unabridged Reddit post. +- `"prompt"`: The unabridged Reddit post. - `"completion"`: The concise "TL;DR" summary appended by the author. This structure enables models to learn the relationship between detailed content and its abbreviated form, enhancing their summarization capabilities. diff --git a/examples/datasets/tldr_preference.py b/examples/datasets/tldr_preference.py index f6c05f8e27..3de9a557a9 100644 --- a/examples/datasets/tldr_preference.py +++ b/examples/datasets/tldr_preference.py @@ -83,7 +83,7 @@ def to_preference(example): - **Type**: [Preference](https://huggingface.co/docs/trl/main/dataset_formats#preference) Columns: -- `"pompt"`: The unabridged Reddit post. +- `"prompt"`: The unabridged Reddit post. - `"chosen"`: The concise "TL;DR" summary appended by the author. - `"rejected"`: An alternative summary or response that was not selected. diff --git a/examples/datasets/ultrafeedback-prompt.py b/examples/datasets/ultrafeedback-prompt.py index 7c218ee786..3b49ccc1a0 100644 --- a/examples/datasets/ultrafeedback-prompt.py +++ b/examples/datasets/ultrafeedback-prompt.py @@ -77,11 +77,11 @@ def drop_long_prompt(example): - **Type**: [Prompt-only](https://huggingface.co/docs/trl/main/dataset_formats#prompt-only) Column: -- `"pompt"`: The input question or instruction provided to the model. +- `"prompt"`: The input question or instruction provided to the model. ## Generation script -The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/ultafeedback-prompt.py). +The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/ultrafeedback-prompt.py). 
""") if __name__ == "__main__": diff --git a/examples/datasets/ultrafeedback.py b/examples/datasets/ultrafeedback.py index 49e4e2cc0c..a7d9a28c2f 100644 --- a/examples/datasets/ultrafeedback.py +++ b/examples/datasets/ultrafeedback.py @@ -112,13 +112,13 @@ def to_unpaired_preference(example, model_name, aspect): - **Type**: [Unpaired preference](https://huggingface.co/docs/trl/main/dataset_formats#unpaired-preference) Column: -- `"pompt"`: The input question or instruction provided to the model. +- `"prompt"`: The input question or instruction provided to the model. - `"completion"`: The model's response to the prompt. - `"label"`: A binary value indicating whether the response is sufficiently helpful. ## Generation script -The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/ultafeedback.py). +The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/ultrafeedback.py). """) if __name__ == "__main__": diff --git a/examples/research_projects/stack_llama_2/scripts/sft_llama2.py b/examples/research_projects/stack_llama_2/scripts/sft_llama2.py index 3ae1e82c2a..1f4611a3e8 100644 --- a/examples/research_projects/stack_llama_2/scripts/sft_llama2.py +++ b/examples/research_projects/stack_llama_2/scripts/sft_llama2.py @@ -185,7 +185,7 @@ def create_datasets(tokenizer, args, seed=None): train_dataset=train_dataset, eval_dataset=eval_dataset, peft_config=peft_config, - max_seq_length=None, + max_length=None, formatting_func=prepare_sample_text, processing_class=tokenizer, args=training_args, diff --git a/scripts/generate_tiny_models.py b/scripts/generate_tiny_models.py index 412fa85988..4e667aea0b 100644 --- a/scripts/generate_tiny_models.py +++ b/scripts/generate_tiny_models.py @@ -43,6 +43,7 @@ Idefics2ForConditionalGeneration, LlamaConfig, LlamaForCausalLM, + LlamaForSequenceClassification, LlavaConfig, LlavaForConditionalGeneration, LlavaNextConfig, @@ -57,6 +58,7 @@ Phi3ForCausalLM, Qwen2Config, Qwen2ForCausalLM, + Qwen2ForSequenceClassification, SiglipVisionConfig, T5Config, T5ForConditionalGeneration, @@ -131,6 +133,7 @@ def push_to_hub(model, tokenizer, prefix=None, suffix=None): model = model_class(config) push_to_hub(model, tokenizer, "tiny", suffix) + # A slightly bigger model, required for vLLM testing tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-32B-Instruct") config = Qwen2Config( @@ -144,6 +147,26 @@ def push_to_hub(model, tokenizer, prefix=None, suffix=None): model = Qwen2ForCausalLM(config) push_to_hub(model, tokenizer, "small", "2.5") + +# Reward models +for model_id, config_class, model_class, suffix in [ + ("meta-llama/Llama-3.2-1B-Instruct", LlamaConfig, LlamaForSequenceClassification, "3.2"), + ("Qwen/Qwen2.5-32B-Instruct", Qwen2Config, Qwen2ForSequenceClassification, "2.5"), +]: + tokenizer = AutoTokenizer.from_pretrained(model_id) + config = config_class( + vocab_size=tokenizer.vocab_size + len(tokenizer.added_tokens_encoder.keys()), + hidden_size=8, + num_attention_heads=4, + num_key_value_heads=2, + num_hidden_layers=2, + intermediate_size=32, + num_labels=1, + ) + model = model_class(config) + push_to_hub(model, tokenizer, "tiny", suffix) + + # Encoder-decoder models for model_id, config_class, model_class, suffix in [ ("google/flan-t5-small", T5Config, T5ForConditionalGeneration, None), diff --git a/setup.py b/setup.py index b74aeaf326..6382cf4c21 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ Simple check list for 
release from AllenNLP repo: https://github.com/allenai/allennlp/blob/master/setup.py -To create the package for pypi. +To create the package for PyPI. 0. Prerequisites: - Dependencies: @@ -50,7 +50,7 @@ For the sources, run: "python setup.py sdist" You should now have a /dist directory with both .whl and .tar.gz source versions. -5. Check that everything looks correct by uploading the package to the pypi test server: +5. Check that everything looks correct by uploading the package to the PyPI test server: twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/ @@ -59,7 +59,7 @@ pip install -U tqdm pip install -i https://testpypi.python.org/pypi evaluate -6. Upload the final version to actual pypi: +6. Upload the final version to actual PyPI: twine upload dist/* -r pypi 7. Fill release notes in the tag in github once everything is looking hunky-dory. @@ -71,7 +71,7 @@ from setuptools import find_packages, setup -__version__ = "0.14.0.dev0" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) +__version__ = "0.16.0.dev0" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) REQUIRED_PKGS = [ "accelerate>=0.34.0", @@ -85,13 +85,13 @@ "diffusers": ["diffusers>=0.18.0"], "judges": ["openai>=1.23.2", "llm-blender>=0.0.2"], # liger-kernel depends on triton, which is only available on Linux https://github.com/triton-lang/triton#compatibility - "liger": ["liger-kernel>=0.4.0; sys_platform != 'win32'"], + "liger": ["liger-kernel>=0.5.3; sys_platform != 'win32'"], "mergekit": ["mergekit>=0.0.5.1"], "peft": ["peft>=0.8.0"], "quantization": ["bitsandbytes"], "scikit": ["scikit-learn"], "test": ["parameterized", "pytest-cov", "pytest-rerunfailures", "pytest-xdist", "pytest"], - "vllm": ["vllm; sys_platform != 'win32'"], # vllm is not available on Windows + "vllm": ["vllm>=0.7.2; sys_platform != 'win32'"], # vllm is not available on Windows "vlm": ["Pillow"], } EXTRAS["dev"] = [] @@ -124,7 +124,7 @@ package_data={ "trl": ["templates/*.md"], }, - packages=find_packages(exclude={"tests", "tests.slow"}), + packages=find_packages(exclude={"tests", "tests.slow", "trl.templates"}), install_requires=REQUIRED_PKGS, extras_require=EXTRAS, python_requires=">=3.9", diff --git a/tests/slow/test_sft_slow.py b/tests/slow/test_sft_slow.py index 74811d092c..8a772a48aa 100644 --- a/tests/slow/test_sft_slow.py +++ b/tests/slow/test_sft_slow.py @@ -46,7 +46,7 @@ class SFTTrainerSlowTester(unittest.TestCase): def setUp(self): self.train_dataset = load_dataset("stanfordnlp/imdb", split="train[:10%]") self.eval_dataset = load_dataset("stanfordnlp/imdb", split="test[:10%]") - self.max_seq_length = 128 + self.max_length = 128 self.peft_config = LoraConfig( lora_alpha=16, lora_dropout=0.1, @@ -74,7 +74,7 @@ def test_sft_trainer_str(self, model_name, packing): per_device_train_batch_size=2, max_steps=10, packing=packing, - max_seq_length=self.max_seq_length, + max_length=self.max_length, ) trainer = SFTTrainer( @@ -100,7 +100,7 @@ def test_sft_trainer_transformers(self, model_name, packing): per_device_train_batch_size=2, max_steps=10, packing=packing, - max_seq_length=self.max_seq_length, + max_length=self.max_length, ) model = AutoModelForCausalLM.from_pretrained(model_name) @@ -135,7 +135,7 @@ def test_sft_trainer_peft(self, model_name, packing): max_steps=10, fp16=True, packing=packing, - max_seq_length=self.max_seq_length, + max_length=self.max_length, ) model = AutoModelForCausalLM.from_pretrained(model_name) @@ -172,7 +172,7 @@ 
def test_sft_trainer_transformers_mp(self, model_name, packing): max_steps=10, fp16=True, # this is sufficient to enable amp packing=packing, - max_seq_length=self.max_seq_length, + max_length=self.max_length, ) model = AutoModelForCausalLM.from_pretrained(model_name) @@ -205,7 +205,7 @@ def test_sft_trainer_transformers_mp_gc(self, model_name, packing, gradient_chec per_device_train_batch_size=2, max_steps=10, packing=packing, - max_seq_length=self.max_seq_length, + max_length=self.max_length, fp16=True, # this is sufficient to enable amp gradient_checkpointing=True, gradient_checkpointing_kwargs=gradient_checkpointing_kwargs, @@ -242,7 +242,7 @@ def test_sft_trainer_transformers_mp_gc_peft(self, model_name, packing, gradient per_device_train_batch_size=2, max_steps=10, packing=packing, - max_seq_length=self.max_seq_length, + max_length=self.max_length, fp16=True, # this is sufficient to enable amp gradient_checkpointing=True, gradient_checkpointing_kwargs=gradient_checkpointing_kwargs, @@ -286,7 +286,7 @@ def test_sft_trainer_transformers_mp_gc_device_map( per_device_train_batch_size=2, max_steps=10, packing=packing, - max_seq_length=self.max_seq_length, + max_length=self.max_length, fp16=True, # this is sufficient to enable amp gradient_checkpointing=True, gradient_checkpointing_kwargs=gradient_checkpointing_kwargs, @@ -324,7 +324,7 @@ def test_sft_trainer_transformers_mp_gc_peft_qlora(self, model_name, packing, gr per_device_train_batch_size=2, max_steps=10, packing=packing, - max_seq_length=self.max_seq_length, + max_length=self.max_length, fp16=True, # this is sufficient to enable amp gradient_checkpointing=True, gradient_checkpointing_kwargs=gradient_checkpointing_kwargs, @@ -364,7 +364,7 @@ def test_sft_trainer_with_chat_format_qlora(self, model_name, packing): training_args = SFTConfig( packing=packing, - max_seq_length=self.max_seq_length, + max_length=self.max_length, output_dir=tmp_dir, logging_strategy="no", report_to="none", @@ -411,7 +411,7 @@ def test_sft_trainer_with_liger(self, model_name, packing): per_device_train_batch_size=2, max_steps=2, packing=packing, - max_seq_length=self.max_seq_length, + max_length=self.max_length, use_liger=True, ) diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index eac36ebe02..cf056f3e03 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -341,7 +341,7 @@ def test_callback(self): model=self.model, args=training_args, train_dataset=self.dataset, - tokenizer=self.tokenizer, + processing_class=self.tokenizer, callbacks=[merge_callback], ) trainer.train() @@ -364,7 +364,7 @@ def test_every_checkpoint(self): model=self.model, args=training_args, train_dataset=self.dataset, - tokenizer=self.tokenizer, + processing_class=self.tokenizer, callbacks=[merge_callback], ) trainer.train() diff --git a/tests/test_cli.py b/tests/test_cli.py index 5d07e289a7..6b00ed1ed0 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -13,16 +13,22 @@ # limitations under the License. 
+import sys import tempfile import unittest from io import StringIO from unittest.mock import patch -from trl.cli import main - +@unittest.skipIf( + sys.version_info < (3, 10), + "Transformers' generation codebase uses a Python >3.10 syntax (`str | None`), which seems to cause the CLI tests " + "to fail on Python <3.10.", # let's say it's a known issue, but not expected to be fixed, because too niche +) class TestCLI(unittest.TestCase): def test_dpo(self): + from trl.cli import main + with tempfile.TemporaryDirectory() as tmp_dir: # Create a temporary directory command = f"trl dpo --output_dir {tmp_dir} --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --dataset_name trl-internal-testing/zen --dataset_config standard_preference --report_to none" with patch("sys.argv", command.split(" ")): @@ -30,18 +36,32 @@ def test_dpo(self): @patch("sys.stdout", new_callable=StringIO) def test_env(self, mock_stdout): + from trl.cli import main + command = "trl env" with patch("sys.argv", command.split(" ")): main() self.assertIn("TRL version: ", mock_stdout.getvalue().strip()) + def test_grpo(self): + from trl.cli import main + + with tempfile.TemporaryDirectory() as tmp_dir: # Create a temporary directory + command = f"trl grpo --output_dir {tmp_dir} --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --reward_model_name_or_path trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5 --dataset_name trl-internal-testing/zen --dataset_config standard_prompt_only --num_generations 4 --max_completion_length 32 --report_to none" + with patch("sys.argv", command.split(" ")): + main() + def test_kto(self): + from trl.cli import main + with tempfile.TemporaryDirectory() as tmp_dir: # Create a temporary directory command = f"trl kto --output_dir {tmp_dir} --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --dataset_name trl-internal-testing/zen --dataset_config standard_unpaired_preference --report_to none" with patch("sys.argv", command.split(" ")): main() def test_sft(self): + from trl.cli import main + with tempfile.TemporaryDirectory() as tmp_dir: # Create a temporary directory command = f"trl sft --output_dir {tmp_dir} --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --dataset_name trl-internal-testing/zen --dataset_config standard_language_modeling --report_to none" with patch("sys.argv", command.split(" ")): diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py index f95e37fc84..20c8614f68 100644 --- a/tests/test_data_utils.py +++ b/tests/test_data_utils.py @@ -24,8 +24,10 @@ extract_prompt, is_conversational, maybe_apply_chat_template, + maybe_convert_to_chatml, maybe_extract_prompt, maybe_unpair_preference_dataset, + pack_examples, unpair_preference_dataset, ) @@ -41,7 +43,7 @@ class IsConversationalTester(unittest.TestCase): { # Prompt only "prompt": [{"role": "user", "content": "What color is the sky?"}], }, - { # Pompt-completion + { # Prompt-completion "prompt": [{"role": "user", "content": "What color is the sky?"}], "completion": [{"role": "assistant", "content": "It is blue."}], }, @@ -110,7 +112,7 @@ class ApplyChatTemplateTester(unittest.TestCase): { # Prompt only "prompt": [{"role": "user", "content": "What color is the sky?"}], }, - { # Pompt-completion + { # Prompt-completion "prompt": [{"role": "user", "content": "What color is the sky?"}], "completion": [{"role": "assistant", "content": "It is blue."}], }, @@ -153,7 +155,7 @@ def test_apply_chat_template(self, tokenizer_id, example): # Checking if the 
result is a dictionary self.assertIsInstance(result, dict) - # The chat template should be applied to the the following keys + # The chat template should be applied to the following keys for key in ["prompt", "chosen", "rejected", "completion"]: if key in example: self.assertIn(key, result) @@ -179,7 +181,7 @@ def test_maybe_apply_chat_template(self, tokenizer_id, example): # Checking if the result is a dictionary self.assertIsInstance(result, dict) - # The chat template should be applied to the the following keys + # The chat template should be applied to the following keys for key in ["prompt", "chosen", "rejected", "completion"]: if key in example: self.assertIn(key, result) @@ -392,6 +394,93 @@ def test_maybe_extract_prompt_standard_already_explicit(self): ) +class TestPackExamples(unittest.TestCase): + def test_pack_examples_larger_chunks(self): + examples = { + "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]], + "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]], + } + seq_length = 5 + expected_output = { + "input_ids": [[1, 2, 3, 4, 5], [6, 7, 8]], + "attention_mask": [[0, 1, 1, 0, 0], [1, 1, 1]], + } + result = pack_examples(examples, seq_length) + self.assertEqual(result, expected_output) + + def test_pack_examples_smaller_chunks(self): + examples = { + "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]], + "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]], + } + seq_length = 2 + expected_output = { + "input_ids": [[1, 2], [3, 4], [5, 6], [7, 8]], + "attention_mask": [[0, 1], [1, 0], [0, 1], [1, 1]], + } + result = pack_examples(examples, seq_length) + self.assertEqual(result, expected_output) + + def test_pack_with_dataset(self): + examples = { + "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]], + "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]], + } + dataset = Dataset.from_dict(examples) + seq_length = 3 + expected_output = { + "input_ids": [[1, 2, 3], [4, 5, 6], [7, 8]], + "attention_mask": [[0, 1, 1], [0, 0, 1], [1, 1]], + } + dataset = dataset.map(pack_examples, batched=True, fn_kwargs={"seq_length": seq_length}) + self.assertEqual(dataset.to_dict(), expected_output) + + +class TestMaybeConvertToChatML(unittest.TestCase): + def test_with_conversations_key(self): + # Particular case where the key is "conversations": we rename it to "messages" + example = { + "conversations": [ + {"from": "user", "value": "What color is the sky?"}, + {"from": "assistant", "value": "It is blue."}, + ] + } + expected_output = { + "messages": [ + {"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is blue."}, + ] + } + self.assertEqual(maybe_convert_to_chatml(example), expected_output) + + def test_without_conversations_key(self): + # Same as before, but we don't rename the keys + example = { + "prompt": [{"from": "user", "value": "What color is the sky?"}], + "completion": [{"from": "assistant", "value": "It is blue."}], + } + expected_output = { + "prompt": [{"role": "user", "content": "What color is the sky?"}], + "completion": [{"role": "assistant", "content": "It is blue."}], + } + self.assertEqual(maybe_convert_to_chatml(example), expected_output) + + def test_not_conversional(self): + # When not needed, the example should remain unchanged + example = {"text": "The sky is blue."} + self.assertEqual(maybe_convert_to_chatml(example), example) + + def test_already_chatml(self): + # When the example is already in ChatML format, it should remain unchanged + example = { + "messages": [ + {"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": 
"It is blue."}, + ] + } + self.assertEqual(maybe_convert_to_chatml(example), example) + + # Run the tests if __name__ == "__main__": unittest.main() diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index 6200ef8f91..e5c6b1b08a 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -265,7 +265,7 @@ def test_dpo_trainer_with_weighting(self): model=self.model, ref_model=self.ref_model, args=training_args, - tokenizer=self.tokenizer, + processing_class=self.tokenizer, train_dataset=dummy_dataset["train"], eval_dataset=dummy_dataset["test"], ) @@ -1070,7 +1070,7 @@ def test_dpo_loss_js_div_f(self): ) self.assertTrue(torch.isfinite(losses).cpu().numpy().all()) - def test_dpo_trainer_use_num_logits_to_keep(self): + def test_dpo_trainer_use_logits_to_keep(self): model_id = "trl-internal-testing/tiny-LlamaForCausalLM-3.2" tokenizer = AutoTokenizer.from_pretrained(model_id) tokenizer.pad_token = tokenizer.eos_token @@ -1087,7 +1087,7 @@ def test_dpo_trainer_use_num_logits_to_keep(self): learning_rate=9e-1, eval_strategy="steps", beta=0.1, - use_num_logits_to_keep=True, + use_logits_to_keep=True, rpo_alpha=0.5, report_to="none", ) @@ -1104,7 +1104,7 @@ def test_dpo_trainer_use_num_logits_to_keep(self): eval_dataset=dummy_dataset["test"], ) - training_args.use_num_logits_to_keep = False + training_args.use_logits_to_keep = False trainer2 = DPOTrainer( model=model, ref_model=None, @@ -1152,11 +1152,48 @@ def test_dpo_trainer_use_num_logits_to_keep(self): trainer.train() - def test_padding_free(self): + def test_dpo_trainer_with_tools(self): model_id = "trl-internal-testing/tiny-LlamaForCausalLM-3.2" tokenizer = AutoTokenizer.from_pretrained(model_id) tokenizer.pad_token = tokenizer.eos_token + model = AutoModelForCausalLM.from_pretrained(model_id) + + # Define dummy test tools + def get_current_temperature(location: str): + """ + Gets the temperature at a given location. + + Args: + location: The location to get the temperature for + """ + return 22.0 + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + tools=[get_current_temperature], + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference") + + trainer = DPOTrainer( + model=model, + ref_model=None, + args=training_args, + processing_class=tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + # We don't run the training, but at this stage, the dataset is supposed to be pre-processed. When + # pre-processing, we expect the available tools to be explicitly mentioned in the system prompt. That's + # what we're checking here + self.assertIn("get_current_temperature", tokenizer.decode(trainer.train_dataset["prompt_input_ids"][0])) + + def test_padding_free(self): + model_id = "trl-internal-testing/tiny-LlamaForCausalLM-3.2" + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer.pad_token = tokenizer.eos_token # Normally, we need `attn_implementation="flash_attention_2"` to that the model returns correct logits. # Without it, the logits may be incorrect, but that's fine here. This test focuses only on the inner logic # of padding_free. diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py new file mode 100644 index 0000000000..5f58d69ca7 --- /dev/null +++ b/tests/test_grpo_trainer.py @@ -0,0 +1,616 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import tempfile +import unittest + +import torch +from datasets import load_dataset +from parameterized import parameterized +from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer +from transformers.testing_utils import require_peft, require_torch_accelerator +from transformers.utils import is_peft_available + +from trl import GRPOConfig, GRPOTrainer +from trl.import_utils import is_vllm_available + + +if is_peft_available(): + from peft import LoraConfig + + +class GRPOTrainerTester(unittest.TestCase): + def test_init_minimal(self): + # Test that GRPOTrainer can be instantiated with only model, reward_model and train_dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + train_dataset=dataset, + ) + + @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)]) + def test_training(self, config_name): + dataset = load_dataset("trl-internal-testing/zen", config_name, split="train") + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=32, # reduce the completion length to reduce memory usage + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + def test_training_with_eval(self): + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + per_device_eval_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=32, # reduce the completion length to reduce memory usage + eval_strategy="steps", + eval_steps=2, + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + 
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset["train"], + eval_dataset=dataset["test"], + ) + + trainer.train() + + @require_peft + def test_training_peft(self): + model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()] + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=32, # reduce the completion length to reduce memory usage + report_to="none", + ) + trainer = GRPOTrainer( + model=model, + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + peft_config=LoraConfig(), + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the peft params have changed and the base model params have not changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if n in base_param_names: # We expect the base model params to be the same + self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed.") + elif "base_layer" not in n: # We expect the peft params to be different (except for the base layer) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed.") + + def test_training_different_reward_model(self): + # Use a reward model different from the model: different chat template, tokenization, etc. + dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train") + reward_model_id = "trl-internal-testing/tiny-LlamaForSequenceClassification-3.2" + reward_model = AutoModelForSequenceClassification.from_pretrained(reward_model_id) + reward_tokenizer = AutoTokenizer.from_pretrained(reward_model_id) + # By default, the trainer uses the eos token as the padding token. However, for Llama models, the eos token + # appears in the chat template. Using it as a pad token disrupts the reward calculation, as the calculation + # considers the score of the last token before the first pad token. To ensure correct reward calculations, + # we use a separate pad token instead. 
+ reward_tokenizer.pad_token = "<|finetune_right_pad_id|>" + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=32, # reduce the completion length to reduce memory usage + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=reward_model, + args=training_args, + train_dataset=dataset, + reward_processing_classes=reward_tokenizer, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + def test_training_reward_func_standard(self): + # Test if trainer can handle reward function with standard format + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion)) for completion in completions] + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=32, # reduce the completion length to reduce memory usage + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=reward_func, + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + def test_training_reward_func_conversational(self): + # Test if trainer can handle reward function with conversational format + dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train") + + def reward_func(completions, **kwargs): + """Reward function that gives higher scores to longer completion content.""" + completion_contents = [completion[0]["content"] for completion in completions] + return [float(len(content)) for content in completion_contents] + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=32, # reduce the completion length to reduce memory usage + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + 
reward_funcs=reward_func, + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + def test_training_multiple_reward_funcs(self): + # Test that GRPOTrainer can be instantiated with multiple reward functions + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + def reward_func1(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion)) for completion in completions] + + def reward_func2(completions, **kwargs): + """Reward function that rewards completions with more unique letters.""" + return [float(len(set(completion))) for completion in completions] + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=32, # reduce the completion length to reduce memory usage + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=[reward_func1, reward_func2], + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + def test_training_multiple_reward_funcs_with_weights(self): + """Test that GRPOTrainer can handle multiple reward functions with weights.""" + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + def reward_func1(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion)) for completion in completions] + + def reward_func2(completions, **kwargs): + """Reward function that rewards completions with more unique letters.""" + return [float(len(set(completion))) for completion in completions] + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=32, # reduce the completion length to reduce memory usage + report_to="none", + reward_weights=[0.7, 0.3], # weight of reward_func1 and reward_func2 respectively + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=[reward_func1, reward_func2], + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + # Check that training logs contain 
both reward metrics + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + self.assertIn("rewards/reward_func1", trainer.state.log_history[-1]) + self.assertIn("rewards/reward_func2", trainer.state.log_history[-1]) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + def test_training_multiple_mixed_reward_funcs(self): + # Test if the trainer can handle a mix of reward functions and reward models + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion)) for completion in completions] + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=32, # reduce the completion length to reduce memory usage + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=[reward_func, "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5"], + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + def test_training_reward_func_additional_column(self): + # Test if trainer can handle reward function that rely on additional columns in the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + # Add a column to the dataset (dummy example, the column could be anything) + some_values = list(range(len(dataset))) + dataset = dataset.add_column("some_values", some_values) + + def reward_func(completions, some_values, **kwargs): + """Reward function that rewards completions with lengths closer to the values in some_values.""" + return [float(abs(len(completion) - value)) for completion, value in zip(completions, some_values)] + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=32, # reduce the completion length to reduce memory usage + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=reward_func, + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + 
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + @unittest.skipIf(not is_vllm_available(), "vLLM is not available") + @require_torch_accelerator + def test_training_vllm(self): + """Test that training works with vLLM for generation.""" + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=32, # reduce the completion length to reduce memory usage + report_to="none", + use_vllm=True, + vllm_device="cuda:0", # will raise a warning, but allows this test to work with only one GPU + vllm_gpu_memory_utilization=0.5, # reduce since because we use the same device for training and vllm + ) + trainer = GRPOTrainer( + model="Qwen/Qwen2.5-0.5B-Instruct", # tiny is too small for vLLM + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + @unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows") # compiling seems to be broken on Windows + def test_training_torch_compile(self): + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=32, # reduce the completion length to reduce memory usage + torch_compile=True, + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + def test_training_with_sync_ref_model(self): + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=32, # reduce the completion length to reduce memory usage + 
sync_ref_model=True, + ref_model_sync_steps=2, # reduce sync steps to ensure a sync happens + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + def test_beta_zero_no_ref_model_and_no_kl(self): + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + beta=0.0, # set beta to 0 to test the case where the reference model is not used + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=32, # reduce the completion length to reduce memory usage + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + @unittest.skipIf(not is_vllm_available(), "vLLM is not available") + @require_torch_accelerator + @require_peft + def test_training_vllm_and_peft(self): + """Test that training works with vLLM for generation.""" + model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") # tiny model is too small for vLLM + base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()] + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=32, # reduce the completion length to reduce memory usage + report_to="none", + use_vllm=True, + vllm_device="cuda:0", # will raise a warning, but allows this test to work with only one GPU + vllm_gpu_memory_utilization=0.5, # reduce since because we use the same device for training and vllm + ) + lora_config = LoraConfig( + target_modules="all-linear", + # test with non-default modules as it add extra keys in state_dict tht we need to handle + modules_to_save=["embed_tokens", "lm_head"], + ) + trainer = GRPOTrainer( + model=model, + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + 
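Two reference-model behaviours are exercised in the tests around this hunk: dropping the KL term entirely (`beta=0.0`, in which case no reference model is kept) and periodically refreshing the reference model. A hedged sketch of both configurations, with illustrative values:

```python
from trl import GRPOConfig

# Disable the KL penalty: with beta=0.0 no reference model is instantiated at all
no_kl_args = GRPOConfig(output_dir="grpo-no-kl", beta=0.0)

# Keep the KL penalty but refresh the reference model periodically
synced_ref_args = GRPOConfig(
    output_dir="grpo-synced-ref",
    sync_ref_model=True,
    ref_model_sync_steps=64,  # the test uses a small value (2) only to force a sync quickly
)
```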
train_dataset=dataset, + peft_config=lora_config, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the peft params have changed and the base model params have not changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if n in base_param_names: # We expect the base model params to be the same + self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed.") + elif "base_layer" not in n and "original_module" not in n: + # We expect the peft params to be different (except for the base layer) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed.") + + @unittest.skipIf(not is_vllm_available(), "vLLM is not available") + @require_torch_accelerator + def test_training_vllm_guided_decoding(self): + """Test that training works with vLLM for generation with guided decoding.""" + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=32, # reduce the completion length to reduce memory usage + report_to="none", + use_vllm=True, + vllm_device="cuda:0", # will raise a warning, but allows this test to work with only one GPU + vllm_guided_decoding_regex=r"\n.*\n\n\n.*\n", + ) + trainer = GRPOTrainer( + model="Qwen/Qwen2.5-0.5B-Instruct", # tiny model is too small for vLLM + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py index 7de90a3b24..dcafe7829e 100644 --- a/tests/test_rloo_trainer.py +++ b/tests/test_rloo_trainer.py @@ -172,3 +172,42 @@ def test_rloo_training(self): # Check if objective/rlhf_reward is available self.assertIn("objective/rlhf_reward", trainer.state.log_history[-1]) + + def test_rloo_training_with_custom_reward(self): + # dummy reward function + def reward_function(texts): + # based on length of text + rewards = [len(text) for text in texts] + return rewards + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = RLOOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + per_device_eval_batch_size=2, + total_episodes=1, + num_train_epochs=1, + max_steps=2, + report_to="none", + ) + + # Create a simple dataset + dummy_text = [{"content": "Hello World!", "role": "user"}] + dummy_data = self.tokenizer.apply_chat_template(dummy_text) + dummy_dataset = Dataset.from_dict({"input_ids": [dummy_data, dummy_data]}) + + trainer = RLOOTrainer( + config=training_args, + policy=self.policy_model, + reward_model=reward_function, + ref_policy=self.policy_ref_model, + 
processing_class=self.tokenizer, + train_dataset=dummy_dataset, + eval_dataset=dummy_dataset, + ) + + # Test that training completes without errors + trainer.train() + + # Check if objective/rlhf_reward is available + self.assertIn("objective/rlhf_reward", trainer.state.log_history[-1]) diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 1724ff4c13..1a26378f3f 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -53,7 +53,7 @@ def formatting_prompts_func_batched(example): if is_peft_available(): - from peft import LoraConfig, PeftModel + from peft import LoraConfig, PeftModel, get_peft_model if is_vision_available(): from PIL import Image as PILImage @@ -288,7 +288,7 @@ def test_sft_trainer(self): self.assertIn("model.safetensors", os.listdir(tmp_dir + "/checkpoint-2")) - def test_sft_trainer_with_pretokenzied_data_packing(self): + def test_sft_trainer_with_pretokenized_data_packing(self): with tempfile.TemporaryDirectory() as tmp_dir: training_args = SFTConfig( output_dir=tmp_dir, @@ -326,8 +326,7 @@ def test_sft_trainer_uncorrect_data(self): eval_steps=1, save_steps=1, per_device_train_batch_size=2, - max_seq_length=32, # make sure there is at least 1 packed sequence - num_of_sequences=32, + max_length=32, # make sure there is at least 1 packed sequence packing=True, report_to="none", ) @@ -354,7 +353,7 @@ def test_sft_trainer_uncorrect_data(self): train_dataset=self.conversational_lm_dataset["train"], ) - # Same, but with packing with `max_seq_length` + # Same, but with packing with `max_length` training_args = SFTConfig( output_dir=tmp_dir, dataloader_drop_last=True, @@ -362,7 +361,7 @@ def test_sft_trainer_uncorrect_data(self): eval_steps=1, save_steps=1, per_device_train_batch_size=2, - max_seq_length=16, # make sure there is at least 1 packed sequence + max_length=16, # make sure there is at least 1 packed sequence packing=True, report_to="none", ) @@ -397,7 +396,7 @@ def test_sft_trainer_uncorrect_data(self): eval_steps=1, save_steps=1, per_device_train_batch_size=2, - max_seq_length=32, # make sure there is at least 1 packed sequence + max_length=32, # make sure there is at least 1 packed sequence packing=True, report_to="none", ) @@ -408,45 +407,6 @@ def test_sft_trainer_uncorrect_data(self): formatting_func=formatting_prompts_func, ) - # This should not work because not enough data for one sample - training_args = SFTConfig( - output_dir=tmp_dir, - dataloader_drop_last=True, - max_steps=2, - eval_steps=1, - save_steps=1, - per_device_train_batch_size=2, - max_seq_length=1024, # make sure there is NOT at least 1 packed sequence - packing=True, - report_to="none", - ) - with self.assertRaises(ValueError): - _ = SFTTrainer( - model=self.model, - args=training_args, - train_dataset=self.dummy_dataset, - formatting_func=formatting_prompts_func, - ) - - # This should not work as well - with self.assertRaises(ValueError): - training_args = SFTConfig( - output_dir=tmp_dir, - dataloader_drop_last=True, - max_steps=2, - eval_steps=1, - save_steps=1, - per_device_train_batch_size=2, - packing=False, - report_to="none", - ) - _ = SFTTrainer( - model=self.model, - args=training_args, - train_dataset=self.dummy_dataset, - formatting_func=formatting_prompts_func, - ) - # but this should work training_args = SFTConfig( output_dir=tmp_dir, @@ -501,8 +461,7 @@ def test_sft_trainer_with_model_num_train_epochs(self): save_steps=1, num_train_epochs=2, per_device_train_batch_size=2, - max_seq_length=16, - num_of_sequences=16, + max_length=16, packing=True, 
report_to="none", ) @@ -526,7 +485,7 @@ def test_sft_trainer_with_model_num_train_epochs(self): save_steps=1, num_train_epochs=2, per_device_train_batch_size=2, - max_seq_length=16, + max_length=16, report_to="none", ) trainer = SFTTrainer( @@ -575,8 +534,7 @@ def test_sft_trainer_with_model(self): max_steps=2, save_steps=1, per_device_train_batch_size=2, - max_seq_length=16, - num_of_sequences=16, + max_length=16, packing=True, report_to="none", ) @@ -600,8 +558,7 @@ def test_sft_trainer_with_model(self): max_steps=2, save_steps=1, per_device_train_batch_size=2, - max_seq_length=16, - num_of_sequences=16, + max_length=16, packing=True, report_to="none", ) @@ -626,7 +583,7 @@ def test_sft_trainer_with_model(self): max_steps=2, save_steps=1, per_device_train_batch_size=2, - max_seq_length=16, + max_length=16, report_to="none", ) trainer = SFTTrainer( @@ -649,7 +606,7 @@ def test_sft_trainer_with_model(self): max_steps=2, save_steps=1, per_device_train_batch_size=2, - max_seq_length=16, + max_length=16, report_to="none", ) trainer = SFTTrainer( @@ -798,7 +755,7 @@ def test_sft_trainer_infinite_with_model(self): save_steps=1, per_device_train_batch_size=2, packing=True, - max_seq_length=500, + max_length=500, report_to="none", ) trainer = SFTTrainer( @@ -808,8 +765,6 @@ def test_sft_trainer_infinite_with_model(self): eval_dataset=self.eval_dataset, ) - self.assertTrue(trainer.train_dataset.infinite) - trainer.train() self.assertIsNotNone(trainer.state.log_history[(-1)]["train_loss"]) @@ -827,7 +782,7 @@ def test_sft_trainer_infinite_with_model_epochs(self): per_device_train_batch_size=2, save_strategy="epoch", packing=True, - max_seq_length=500, + max_length=500, report_to="none", ) trainer = SFTTrainer( @@ -837,8 +792,6 @@ def test_sft_trainer_infinite_with_model_epochs(self): eval_dataset=self.eval_dataset, ) - self.assertFalse(trainer.train_dataset.infinite) - trainer.train() self.assertIsNotNone(trainer.state.log_history[(-1)]["train_loss"]) @@ -1135,7 +1088,7 @@ def test_sft_trainer_only_train_packing(self): per_device_train_batch_size=2, gradient_checkpointing=True, packing=True, - max_seq_length=16, # make sure there is at least 1 packed sequence + max_length=16, # make sure there is at least 1 packed sequence eval_packing=False, report_to="none", ) @@ -1161,7 +1114,7 @@ def test_sft_trainer_eval_packing(self): save_steps=2, per_device_train_batch_size=2, gradient_checkpointing=True, - max_seq_length=16, # make sure there is at least 1 packed sequence + max_length=16, # make sure there is at least 1 packed sequence packing=True, report_to="none", ) @@ -1186,7 +1139,7 @@ def test_sft_trainer_no_packing(self): save_steps=2, per_device_train_batch_size=2, gradient_checkpointing=True, - max_seq_length=16, # make sure there is at least 1 packed sequence + max_length=16, # make sure there is at least 1 packed sequence packing=False, report_to="none", ) @@ -1345,6 +1298,137 @@ def test_sft_trainer_torch_dtype(self): ) self.assertIn( - "Invalid `torch_dtype` passed to the SFTConfig. Expected a string with either `torch.dtype` or 'auto', but got -1.", + "Invalid `torch_dtype` passed to `SFTConfig`. 
Expected either 'auto' or a string representing " + "a `torch.dtype` (e.g., 'float32'), but got -1.", str(context.exception), ) + + +# This new tester aims to replace the first one at some point +class SFTTrainerTester2(unittest.TestCase): + def test_train(self): + # Get the model and dataset + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + model = AutoModelForCausalLM.from_pretrained(model_id) + dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") + + with tempfile.TemporaryDirectory() as tmp_dir: + # Initialize the trainer + training_args = SFTConfig(output_dir=tmp_dir, report_to="none") + trainer = SFTTrainer(args=training_args, model=model, train_dataset=dataset) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + @require_peft + def test_train_peft_model(self): + # Get the base model + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + model = AutoModelForCausalLM.from_pretrained(model_id) + + # Get the base model parameter names + base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()] + + # Turn the model into a peft model + lora_config = LoraConfig() + model = get_peft_model(model, lora_config) + + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") + + with tempfile.TemporaryDirectory() as tmp_dir: + # Initialize the trainer + training_args = SFTConfig(output_dir=tmp_dir, report_to="none") + trainer = SFTTrainer(args=training_args, model=model, train_dataset=dataset) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the peft params have changed and the base model params have not changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if n in base_param_names: # We expect the base model parameters to be the same + self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed") + elif ( + "base_layer" not in n + ): # We expect the peft parameters to be different (except for the base layer) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + def test_train_with_non_chatml_conversational_data(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + model = AutoModelForCausalLM.from_pretrained(model_id) + dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train") + + # Rename role/content to from/value to ensure SFT works with non-chatML conversational data + def rename_fields(example: list[dict]): + return {"conversations": [{"from": m["role"], "value": m["content"]} for m in example["messages"]]} + + dataset = dataset.map(rename_fields, remove_columns="messages") + + with tempfile.TemporaryDirectory() as tmp_dir: + # Initialize 
the trainer + training_args = SFTConfig(output_dir=tmp_dir, report_to="none") + trainer = SFTTrainer(args=training_args, model=model, train_dataset=dataset) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + def test_sft_trainer_with_pretokenized_data(self): + # Get the model and dataset + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + model = AutoModelForCausalLM.from_pretrained(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id) + dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") + + def tokenize_example(example): + return tokenizer(example["text"]) + + # Apply tokenization + tokenized_dataset = dataset.map(tokenize_example, remove_columns=["text"]) + + with tempfile.TemporaryDirectory() as tmp_dir: + # Initialize the trainer + training_args = SFTConfig(output_dir=tmp_dir, report_to="none") + trainer = SFTTrainer(args=training_args, model=model, train_dataset=tokenized_dataset) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") diff --git a/tests/test_trainers_args.py b/tests/test_trainers_args.py index 3aff81fd3f..406eba4f86 100644 --- a/tests/test_trainers_args.py +++ b/tests/test_trainers_args.py @@ -153,7 +153,6 @@ def test_dpo(self): max_length=256, max_prompt_length=64, max_completion_length=64, - is_encoder_decoder=True, disable_dropout=False, # generate_during_eval=True, # ignore this one, it requires wandb precompute_ref_log_probs=True, @@ -188,7 +187,6 @@ def test_dpo(self): self.assertEqual(trainer.args.max_length, 256) self.assertEqual(trainer.args.max_prompt_length, 64) self.assertEqual(trainer.args.max_completion_length, 64) - self.assertEqual(trainer.args.is_encoder_decoder, True) self.assertEqual(trainer.args.disable_dropout, False) # self.assertEqual(trainer.args.generate_during_eval, True) self.assertEqual(trainer.args.precompute_ref_log_probs, True) @@ -370,20 +368,18 @@ def test_sft(self): tmp_dir, dataset_text_field="dummy_text_field", packing=True, - max_seq_length=256, + max_length=256, dataset_num_proc=4, dataset_batch_size=512, neftune_noise_alpha=0.1, model_init_kwargs={"trust_remote_code": True}, dataset_kwargs={"append_concat_token": True, "skip_prepare_dataset": True}, eval_packing=True, - num_of_sequences=32, - chars_per_token=4.2, ) trainer = SFTTrainer(model_id, args=training_args, train_dataset=dataset) self.assertEqual(trainer.args.dataset_text_field, "dummy_text_field") self.assertEqual(trainer.args.packing, True) - self.assertEqual(trainer.args.max_seq_length, 256) + self.assertEqual(trainer.args.max_length, 256) 
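The assertions that follow reflect the rename applied throughout this diff: `SFTConfig.max_seq_length` becomes `max_length`, while `num_of_sequences` and `chars_per_token` are dropped. A hedged before/after sketch (values illustrative):

```python
from trl import SFTConfig

# Before this change (older TRL releases):
# training_args = SFTConfig(output_dir="sft-out", max_seq_length=512, num_of_sequences=16, packing=True)

# After this change:
training_args = SFTConfig(output_dir="sft-out", max_length=512, packing=True)
```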
self.assertEqual(trainer.args.dataset_num_proc, 4) self.assertEqual(trainer.args.dataset_batch_size, 512) self.assertEqual(trainer.args.neftune_noise_alpha, 0.1) @@ -391,8 +387,6 @@ def test_sft(self): self.assertIn("append_concat_token", trainer.args.dataset_kwargs) self.assertEqual(trainer.args.dataset_kwargs["append_concat_token"], True) self.assertEqual(trainer.args.eval_packing, True) - self.assertEqual(trainer.args.num_of_sequences, 32) - self.assertEqual(trainer.args.chars_per_token, 4.2) @parameterized.expand([(False,), (True,)]) def test_xpo(self, alpha_list): diff --git a/tests/test_utils.py b/tests/test_utils.py index 4a1227475c..871ec6c737 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -17,6 +17,7 @@ import numpy as np import torch from datasets import load_dataset +from parameterized import parameterized from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig from transformers.testing_utils import require_peft from transformers.utils import is_peft_available @@ -27,9 +28,11 @@ DataCollatorForChatML, batch_generation, decode_and_strip_padding, + flush_left, generate_model_card, get_peft_config, pad, + selective_log_softmax, ) @@ -404,3 +407,70 @@ def test_rewards_comparison_task(self): "These instances are ignored in the accuracy computation." ) self.assertEqual(str(cm.warning), expected_warning) + + +class TestFlushLeft(unittest.TestCase): + def test_basic_case(self): + mask = torch.tensor([[0, 0, 1, 1, 1], [0, 1, 1, 0, 0]]) + tensor1 = torch.tensor([[0, 0, 2, 3, 4], [0, 5, 6, 0, 0]]) + tensor2 = torch.tensor([[0, 0, 7, 8, 9], [0, 10, 11, 0, 0]]) + new_mask, new_tensor1, new_tensor2 = flush_left(mask, tensor1, tensor2) + + expected_mask = torch.tensor([[1, 1, 1], [1, 1, 0]]) + expected_tensor1 = torch.tensor([[2, 3, 4], [5, 6, 0]]) + expected_tensor2 = torch.tensor([[7, 8, 9], [10, 11, 0]]) + + self.assertTrue(torch.equal(new_mask, expected_mask)) + self.assertTrue(torch.equal(new_tensor1, expected_tensor1)) + self.assertTrue(torch.equal(new_tensor2, expected_tensor2)) + + def test_single_row(self): + mask = torch.tensor([[0, 0, 1, 1]]) + tensor1 = torch.tensor([[0, 0, 2, 3]]) + new_mask, new_tensor1 = flush_left(mask, tensor1) + + expected_mask = torch.tensor([[1, 1]]) + expected_tensor1 = torch.tensor([[2, 3]]) + + self.assertTrue(torch.equal(new_mask, expected_mask)) + self.assertTrue(torch.equal(new_tensor1, expected_tensor1)) + + def test_no_shift_needed(self): + mask = torch.tensor([[1, 1, 0, 0], [1, 1, 0, 0]]) + tensor1 = torch.tensor([[5, 6, 0, 0], [7, 8, 0, 0]]) + new_mask, new_tensor1 = flush_left(mask, tensor1) + + expected_mask = torch.tensor([[1, 1], [1, 1]]) + expected_tensor1 = torch.tensor([[5, 6], [7, 8]]) + + self.assertTrue(torch.equal(new_mask, expected_mask)) + self.assertTrue(torch.equal(new_tensor1, expected_tensor1)) + + def test_no_tensors(self): + mask = torch.tensor([[0, 0, 1, 1, 1], [0, 1, 1, 0, 0]]) + new_mask = flush_left(mask) + + expected_mask = torch.tensor([[1, 1, 1], [1, 1, 0]]) + + self.assertTrue(torch.equal(new_mask, expected_mask)) + + +class TestSelectiveLogSoftmax(unittest.TestCase): + @parameterized.expand([(torch.float64,), (torch.float32,), (torch.float16,), (torch.bfloat16,)]) + def test_selective_log_softmax(self, dtype): + """Test selective_log_softmax with logits of different dtypes""" + vocab_size = 1024 + batch_size = 4 + seq_len = 32 + + input_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len)) + logits = torch.randn(batch_size, seq_len, vocab_size, dtype=dtype) + + 
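The expected outputs in the tests above pin down the behaviour of the two new utilities, `flush_left` and `selective_log_softmax`. Hedged reference sketches consistent with those tests, not necessarily TRL's exact implementations:

```python
import torch

def flush_left_sketch(mask: torch.Tensor, *tensors: torch.Tensor):
    # Shift each row so its attended (non-zero) positions start at column 0,
    # then drop trailing columns that contain only padding.
    mask = mask.clone()
    tensors = [t.clone() for t in tensors]
    for i in range(mask.size(0)):
        first_one = torch.nonzero(mask[i])[0].item()
        mask[i] = torch.roll(mask[i], shifts=-first_one)
        for t in tensors:
            t[i] = torch.roll(t[i], shifts=-first_one)
    empty_cols = mask.sum(dim=0) == 0
    first_empty = torch.nonzero(empty_cols)[0].item() if empty_cols.any() else mask.size(1)
    mask = mask[:, :first_empty]
    tensors = [t[:, :first_empty] for t in tensors]
    return (mask, *tensors) if tensors else mask

def selective_log_softmax_sketch(logits: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    # Per-token log-probabilities of `index`, equal to gathering from logits.log_softmax(-1)
    # but without materialising the full (batch, seq, vocab) log-softmax tensor.
    # TRL's real implementation may treat half-precision dtypes differently, per the test above.
    selected = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
    return selected - torch.logsumexp(logits, dim=-1)
```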
expected_output = torch.gather(logits.log_softmax(-1), dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1) + actual_output = selective_log_softmax(logits, input_ids) + + if dtype in [torch.float16, torch.bfloat16]: + # half-precision dtypes fall back to an exact method + self.assertTrue(torch.equal(actual_output, expected_output)) + else: + torch.testing.assert_close(actual_output, expected_output, rtol=1e-5, atol=1e-5) diff --git a/trl/__init__.py b/trl/__init__.py index 692aba0a56..9a17e8e873 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.14.0.dev0" +__version__ = "0.16.0.dev0" from typing import TYPE_CHECKING @@ -26,8 +26,10 @@ "extract_prompt", "is_conversational", "maybe_apply_chat_template", + "maybe_convert_to_chatml", "maybe_extract_prompt", "maybe_unpair_preference_dataset", + "pack_examples", "unpair_preference_dataset", ], "environment": ["TextEnvironment", "TextHistory"], @@ -68,6 +70,8 @@ "FDivergenceType", "GKDConfig", "GKDTrainer", + "GRPOConfig", + "GRPOTrainer", "HfPairwiseJudge", "IterativeSFTTrainer", "KTOConfig", @@ -123,8 +127,10 @@ extract_prompt, is_conversational, maybe_apply_chat_template, + maybe_convert_to_chatml, maybe_extract_prompt, maybe_unpair_preference_dataset, + pack_examples, unpair_preference_dataset, ) from .environment import TextEnvironment, TextHistory @@ -166,6 +172,8 @@ FDivergenceType, GKDConfig, GKDTrainer, + GRPOConfig, + GRPOTrainer, HfPairwiseJudge, IterativeSFTTrainer, KTOConfig, diff --git a/trl/cli.py b/trl/cli.py index 9377defa75..a86ec709d7 100644 --- a/trl/cli.py +++ b/trl/cli.py @@ -21,6 +21,7 @@ from .scripts.chat import make_parser as make_chat_parser from .scripts.dpo import make_parser as make_dpo_parser from .scripts.env import print_env +from .scripts.grpo import make_parser as make_grpo_parser from .scripts.kto import make_parser as make_kto_parser from .scripts.sft import make_parser as make_sft_parser from .scripts.utils import TrlParser @@ -36,6 +37,7 @@ def main(): make_chat_parser(subparsers) make_dpo_parser(subparsers) subparsers.add_parser("env", help="Print the environment information") + make_grpo_parser(subparsers) make_kto_parser(subparsers) make_sft_parser(subparsers) @@ -58,6 +60,15 @@ def main(): elif args.command == "env": print_env() + elif args.command == "grpo": + # Get the default args for the launch command + grpo_training_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts", "grpo.py") + args = launch_command_parser().parse_args([grpo_training_script]) + + # Feed the args to the launch command + args.training_script_args = sys.argv[2:] # remove "trl" and "grpo" + launch_command(args) # launch training + elif args.command == "kto": # Get the default args for the launch command kto_training_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts", "kto.py") diff --git a/trl/data_utils.py b/trl/data_utils.py index 35332fd45f..bb0891d001 100644 --- a/trl/data_utils.py +++ b/trl/data_utils.py @@ -31,7 +31,8 @@ def is_conversational(example: dict[str, Any]) -> bool: dataset type. Returns: - `bool`: `True` if the data is in a conversational format, `False` otherwise. + `bool`: + `True` if the data is in a conversational format, `False` otherwise. 
Examples: @@ -90,8 +91,21 @@ def apply_chat_template( # Apply the chat template to the prompt, adding the generation prompt if "prompt" in example: + last_role = example["prompt"][-1]["role"] + if last_role == "user": + add_generation_prompt = True + continue_final_message = False + elif last_role == "assistant": + add_generation_prompt = False + continue_final_message = True + else: + raise ValueError(f"Invalid role in the last message: {last_role}") prompt = tokenizer.apply_chat_template( - example["prompt"], tools=tools, tokenize=False, add_generation_prompt=True + example["prompt"], + tools=tools, + continue_final_message=continue_final_message, + tokenize=False, + add_generation_prompt=add_generation_prompt, ) # Apply the chat template to the entire prompt + completion @@ -172,17 +186,21 @@ def maybe_apply_chat_template( For keys `"messages"`, `"prompt"`, `"chosen"`, `"rejected"`, and `"completion"`, the values are lists of messages, where each message is a dictionary with keys `"role"` and `"content"`. tokenizer (`PreTrainedTokenizer`): - The tokenizer to apply the chat template with. + Tokenizer to apply the chat template with. tools (`list[Union[dict, Callable]]` or `None`, *optional*, defaults to `None`): A list of tools (callable functions) that will be accessible to the model. If the template does not support function calling, this argument will have no effect Returns: - `dict[str, str]`: The formatted example with the chat template applied. + `dict[str, str]`: + Formatted example with the chat template applied. - Note: - This function does not alter the keys, except for Language modeling dataset, where `"messages"` is replaced by - `"text"`. + Notes: + - This function does not alter the keys, except for Language modeling dataset, where `"messages"` is replaced + by `"text"`. + + - In case of prompt-only data, if the last role is `"user"`, the generation prompt is added to the prompt. + Else, if the last role is `"assistant"`, the final message is continued. Example: @@ -412,3 +430,86 @@ def maybe_extract_prompt(example: dict[str, list]) -> dict[str, list]: if (chosen_conv and prompt_conv) or (not chosen_conv and not prompt_conv): return example return extract_prompt({"chosen": example["chosen"], "rejected": example["rejected"]}) + + +def pack_examples(examples: dict[str, list[list]], seq_length: int) -> dict[str, list[list]]: + """ + Pack examples into chunks of size `seq_length`. + + Args: + examples (`dict[str, list[list]]`): + Dictionary of examples with keys as strings and values as lists of lists. + seq_length (`int`): + Maximum sequence length. + + Returns: + `dict[str, list[list]]`: Dictionary of examples with keys as strings and values as lists of lists. + + Example: + + ```python + >>> from trl import pack_examples + >>> examples = { + ... "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]], + ... "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]], + ... 
} + >>> pack_examples(examples, seq_length=5) + {'input_ids': [[1, 2, 3, 4, 5], [6, 7, 8]], 'attention_mask': [[0, 1, 1, 0, 0], [1, 1, 1]]} + >>> pack_examples(examples, seq_length=2) + {'input_ids': [[1, 2], [3, 4], [5, 6], [7, 8]], 'attention_mask': [[0, 1], [1, 0], [0, 1], [1, 1]]} + ``` + """ + # Join all the values into a single list + examples = {k: sum(v, []) for k, v in examples.items()} + # Split the values into chunks of size seq_length + examples = {k: [v[i : i + seq_length] for i in range(0, len(v), seq_length)] for k, v in examples.items()} + return examples + + +def maybe_convert_to_chatml(example: dict[str, list]) -> dict[str, list]: + """ + Convert a conversational dataset with fields `from` and `value` to ChatML format. + + This function modifies conversational data to align with OpenAI's ChatML format: + - Replaces the key `"from"` with `"role"` in message dictionaries. + - Replaces the key `"value"` with `"content"` in message dictionaries. + - Renames `"conversations"` to `"messages"` for consistency with ChatML. + + Args: + example (`dict[str, list]`): + A single data entry containing a list of messages. + + Returns: + `dict[str, list]`: + Example reformatted to ChatML style. + + Example: + ```python + >>> from trl import maybe_convert_to_chatml + >>> example = { + ... "conversations": [ + ... {"from": "user", "value": "What color is the sky?"}, + ... {"from": "assistant", "value": "It is blue."} + ... ] + ... } + >>> maybe_convert_to_chatml(example) + {'messages': [{'role': 'user', 'content': 'What color is the sky?'}, + {'role': 'assistant', 'content': 'It is blue.'}]} + ``` + """ + # List of possible keys containing message lists + for key in ["prompt", "completion", "chosen", "rejected", "messages", "conversations"]: + if key in example and isinstance(example[key], list): + messages = example[key] + for message in messages: + if isinstance(message, dict): + if "from" in message: + message["role"] = message.pop("from") + if "value" in message: + message["content"] = message.pop("value") + + # Rename "conversations" to "messages" + if "conversations" in example: + example["messages"] = example.pop("conversations") + + return example diff --git a/trl/models/__init__.py b/trl/models/__init__.py index db998369c3..2365e7c1de 100644 --- a/trl/models/__init__.py +++ b/trl/models/__init__.py @@ -20,7 +20,7 @@ _import_structure = { "modeling_base": ["GeometricMixtureWrapper", "PreTrainedModelWrapper", "create_reference_model"], "modeling_value_head": ["AutoModelForCausalLMWithValueHead", "AutoModelForSeq2SeqLMWithValueHead"], - "utils": ["SUPPORTED_ARCHITECTURES", "setup_chat_format", "unwrap_model_for_generation"], + "utils": ["SUPPORTED_ARCHITECTURES", "prepare_deepspeed", "setup_chat_format", "unwrap_model_for_generation"], } try: @@ -39,7 +39,7 @@ if TYPE_CHECKING: from .modeling_base import GeometricMixtureWrapper, PreTrainedModelWrapper, create_reference_model from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead - from .utils import SUPPORTED_ARCHITECTURES, setup_chat_format, unwrap_model_for_generation + from .utils import SUPPORTED_ARCHITECTURES, prepare_deepspeed, setup_chat_format, unwrap_model_for_generation try: if not is_diffusers_available(): diff --git a/trl/models/utils.py b/trl/models/utils.py index 1f4932d21b..0632e025f4 100644 --- a/trl/models/utils.py +++ b/trl/models/utils.py @@ -14,6 +14,7 @@ import itertools from contextlib import contextmanager +from copy import deepcopy from dataclasses import 
dataclass from typing import TYPE_CHECKING, Literal, Optional, Union @@ -36,8 +37,6 @@ from deepspeed.runtime.engine import DeepSpeedEngine from torch.nn.parallel.distributed import DistributedDataParallel - from .modeling_base import PreTrainedModelWrapper - # TODO: Add Abstract Base Class if more formats are added @dataclass @@ -136,6 +135,8 @@ def setup_chat_format( def remove_hooks(model: "DeepSpeedEngine") -> None: """Removes the optimizer hooks from a DeepSpeed ZeRO-3 model.""" + if not hasattr(model, "optimizer"): # before the first training step, the model has no optimizer + return if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"): optimizer_offload = model.optimizer.parameter_offload elif model.optimizer is not None: @@ -163,6 +164,8 @@ def iter_params(module, recurse=False): def add_hooks(model: "DeepSpeedEngine") -> None: """Adds the optimizer hooks from a DeepSpeed ZeRO-3 model.""" + if not hasattr(model, "optimizer"): # before the first training step, the model has no optimizer + return if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"): optimizer_offload = model.optimizer.parameter_offload elif model.optimizer is not None: @@ -172,18 +175,73 @@ def add_hooks(model: "DeepSpeedEngine") -> None: @contextmanager def unwrap_model_for_generation( - model: Union["DistributedDataParallel", "DeepSpeedEngine"], accelerator: "Accelerator", is_peft_model: bool = False -) -> Union["PreTrainedModelWrapper", "DeepSpeedEngine"]: - """Context manager to unwrap a model for generation. - For ZeRO-3 models, we gather the weights once to speed up generation. + model: Union["DistributedDataParallel", "DeepSpeedEngine"], + accelerator: "Accelerator", + gather_deepspeed3_params: bool = True, +): + """ + Context manager to unwrap distributed or accelerated models for generation tasks. + + Args: + model (`Union[DistributedDataParallel, DeepSpeedEngine]`): + Model to be unwrapped. + accelerator (`~accelerate.Accelerator`): + Accelerator instance managing the model. + gather_deepspeed3_params (`bool`, *optional*, defaults to `True`): + Whether to gather weights for DeepSpeed ZeRO Stage 3 models. If `False`, skips parameter gathering, which + can be more memory-efficient but may lead to slower generation times. + + Yields: + Unwrapped model. 
+ + Example: + ```python + with unwrap_model_for_generation(model, accelerator) as unwrapped_model: + generated_outputs = unwrapped_model.generate(input_ids) + ``` """ unwrapped_model = accelerator.unwrap_model(model) - if is_peft_model: - unwrapped_model.pretrained_model.disable_adapter() if accelerator.state.deepspeed_plugin is not None and accelerator.state.deepspeed_plugin.zero_stage == 3: - with deepspeed.zero.GatheredParameters(model.parameters()): - remove_hooks(model) + if not gather_deepspeed3_params: yield accelerator.unwrap_model(model) - add_hooks(model) + else: + with deepspeed.zero.GatheredParameters(model.parameters()): + remove_hooks(model) + yield accelerator.unwrap_model(model) + add_hooks(model) else: yield unwrapped_model + + +def prepare_deepspeed(model, accelerator): + # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473 + deepspeed_plugin = accelerator.state.deepspeed_plugin + config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config) + stage = config_kwargs["zero_optimization"]["stage"] + + if model is not None: + hidden_size = ( + max(model.config.hidden_sizes) + if getattr(model.config, "hidden_sizes", None) + else getattr(model.config, "hidden_size", None) + ) + if hidden_size is not None and stage == 3: + # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache + # @ step 0: expected module 1, but got module 0` + # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081 + config_kwargs.update( + { + "zero_optimization.reduce_bucket_size": hidden_size * hidden_size, + "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size, + "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size, + } + ) + + # If ZeRO-3 is used, we shard both the active and reference model. + # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO + # disabled (stage 0) + if stage != 3: + config_kwargs["zero_optimization"]["stage"] = 0 + model, *_ = deepspeed.initialize(model=model, config=config_kwargs) + model.eval() + return model diff --git a/trl/scripts/grpo.py b/trl/scripts/grpo.py new file mode 100644 index 0000000000..4b336b28e9 --- /dev/null +++ b/trl/scripts/grpo.py @@ -0,0 +1,92 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +from dataclasses import dataclass, field +from typing import Optional + +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer + +from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config + + +@dataclass +class GRPOScriptArguments(ScriptArguments): + """ + Script arguments for the GRPO training script. 
+ + Args: + reward_model_name_or_path (`str` or `None`): + Reward model id of a pretrained model hosted inside a model repo on huggingface.co or local path to a + directory containing model weights saved using [`~transformers.PreTrainedModel.save_pretrained`]. + """ + + reward_model_name_or_path: Optional[str] = field( + default=None, + metadata={ + "help": "Reward model id of a pretrained model hosted inside a model repo on huggingface.co or " + "local path to a directory containing model weights saved using `PreTrainedModel.save_pretrained`." + }, + ) + + +def main(script_args, training_args, model_args): + # Load a pretrained model + model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code + ) + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code + ) + reward_model = AutoModelForSequenceClassification.from_pretrained( + script_args.reward_model_name_or_path, trust_remote_code=model_args.trust_remote_code, num_labels=1 + ) + + # Load the dataset + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) + + # Initialize the GRPO trainer + trainer = GRPOTrainer( + model=model, + reward_funcs=reward_model, + args=training_args, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, + processing_class=tokenizer, + peft_config=get_peft_config(model_args), + ) + + # Train and push the model to the Hub + trainer.train() + + # Save and push to hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) + + +def make_parser(subparsers: argparse._SubParsersAction = None): + dataclass_types = (GRPOScriptArguments, GRPOConfig, ModelConfig) + if subparsers is not None: + parser = subparsers.add_parser("grpo", help="Run the GRPO training script", dataclass_types=dataclass_types) + else: + parser = TrlParser(dataclass_types) + return parser + + +if __name__ == "__main__": + parser = make_parser() + script_args, training_args, model_args = parser.parse_args_and_config() + main(script_args, training_args, model_args) diff --git a/trl/scripts/sft.py b/trl/scripts/sft.py index 2095df3074..764ca3c929 100644 --- a/trl/scripts/sft.py +++ b/trl/scripts/sft.py @@ -84,7 +84,8 @@ def main(script_args, training_args, model_args): tokenizer = AutoTokenizer.from_pretrained( model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True ) - tokenizer.pad_token = tokenizer.eos_token + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token ################ # Dataset diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py index f59051320d..85968218cc 100644 --- a/trl/trainer/__init__.py +++ b/trl/trainer/__init__.py @@ -36,6 +36,8 @@ "dpo_trainer": ["DPOTrainer"], "gkd_config": ["GKDConfig"], "gkd_trainer": ["GKDTrainer"], + "grpo_config": ["GRPOConfig"], + "grpo_trainer": ["GRPOTrainer"], "iterative_sft_trainer": ["IterativeSFTTrainer"], "judges": [ "AllTrueJudge", @@ -105,6 +107,8 @@ from .dpo_trainer import DPOTrainer from .gkd_config import GKDConfig from .gkd_trainer import GKDTrainer + from .grpo_config import GRPOConfig + from .grpo_trainer import GRPOTrainer from .iterative_sft_trainer import IterativeSFTTrainer from .judges import ( AllTrueJudge, diff --git a/trl/trainer/bco_trainer.py 
b/trl/trainer/bco_trainer.py index ccbb701d9f..db4fa156e4 100644 --- a/trl/trainer/bco_trainer.py +++ b/trl/trainer/bco_trainer.py @@ -65,6 +65,7 @@ log_table_to_comet_experiment, pad_to_length, peft_module_casting_to_bf16, + selective_log_softmax, ) @@ -682,6 +683,11 @@ def make_inputs_require_grad(module, input, output): preprocess_logits_for_metrics=preprocess_logits_for_metrics, ) + # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the + # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set + # self.model_accepts_loss_kwargs to False to enable scaling. + self.model_accepts_loss_kwargs = False + # Add tags for models that have been loaded with the correct transformers version if hasattr(self.model, "add_model_tags"): self.model.add_model_tags(self._tag_names) @@ -892,9 +898,11 @@ def _load_optimizer_and_scheduler(self, checkpoint): @contextmanager def null_ref_context(self): """Context manager for handling null reference model (that is, peft adapter manipulation).""" - with self.accelerator.unwrap_model( - self.model - ).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext(): + with ( + self.accelerator.unwrap_model(self.model).disable_adapter() + if self.is_peft_model and not self.ref_adapter_name + else nullcontext() + ): if self.ref_adapter_name: self.model.set_adapter(self.ref_adapter_name) yield @@ -1057,7 +1065,7 @@ def get_batch_logps( # dummy token; we'll ignore the losses on these tokens later labels[labels == label_pad_token_id] = 0 - per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + per_token_logps = selective_log_softmax(logits, labels) if average_log_prob: return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) diff --git a/trl/trainer/callbacks.py b/trl/trainer/callbacks.py index 0d567530d5..b4db2ee5f3 100644 --- a/trl/trainer/callbacks.py +++ b/trl/trainer/callbacks.py @@ -94,6 +94,10 @@ def _generate_completions( class SyncRefModelCallback(TrainerCallback): + """ + Callback to synchronize the model with a reference model. + """ + def __init__( self, ref_model: Union[PreTrainedModel, torch.nn.Module], diff --git a/trl/trainer/cpo_trainer.py b/trl/trainer/cpo_trainer.py index 050dddad99..174cb4f255 100644 --- a/trl/trainer/cpo_trainer.py +++ b/trl/trainer/cpo_trainer.py @@ -60,6 +60,7 @@ log_table_to_comet_experiment, pad_to_length, peft_module_casting_to_bf16, + selective_log_softmax, ) @@ -357,6 +358,11 @@ def make_inputs_require_grad(module, input, output): preprocess_logits_for_metrics=preprocess_logits_for_metrics, ) + # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the + # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set + # self.model_accepts_loss_kwargs to False to enable scaling. 
+ self.model_accepts_loss_kwargs = False + # Add tags for models that have been loaded with the correct transformers version if hasattr(self.model, "add_model_tags"): self.model.add_model_tags(self._tag_names) @@ -706,7 +712,7 @@ def get_batch_logps( # dummy token; we'll ignore the losses on these tokens later labels[labels == label_pad_token_id] = 0 - per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + per_token_logps = selective_log_softmax(logits, labels) if average_log_prob: return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py index 6f98b86b88..09b6e35dea 100644 --- a/trl/trainer/dpo_config.py +++ b/trl/trainer/dpo_config.py @@ -15,7 +15,7 @@ import warnings from dataclasses import dataclass, field from enum import Enum -from typing import Any, Optional +from typing import Any, Callable, Optional, Union from transformers import TrainingArguments @@ -58,7 +58,7 @@ class DPOConfig(TrainingArguments): this flag to `True`. disable_dropout (`bool`, *optional*, defaults to `True`): Whether to disable dropout in the model and reference model. - use_num_logits_to_keep (`bool`, *optional*, defaults to `False`): + use_logits_to_keep (`bool`, *optional*, defaults to `False`): If `True`, only a specified number of logits are computed in the forward pass. This can be useful for saving memory and speeding up training by not computing the logits for all tokens, especially in scenarios when working with very long prompts where labels are ignored (-100). @@ -71,14 +71,15 @@ class DPOConfig(TrainingArguments): Padding value to use. If `None`, the padding value of the tokenizer is used. label_pad_token_id (`int`, *optional*, defaults to `-100`): Padding value to use for labels. - truncation_mode (`str`, *optional*, defaults to `"keep_end"`): - Truncation mode to usewhen the prompt is too long, either `keep_end` or `keep_start`. max_prompt_length (`int` or `None`, *optional*, defaults to `512`): Maximum length of the prompt. max_completion_length (`int` or `None`, *optional*, defaults to `None`): Maximum length of the completion. max_length (`int` or `None`, *optional*, defaults to `1024`): Maximum length of the full sequence (prompt + completion). + truncation_mode (`str`, *optional*, defaults to `"keep_end"`): + Truncation mode to use when the sequence exceeds `max_length`. Possible values are `"keep_end"` and + `"keep_start"`. padding_free (`bool`, *optional*, defaults to `False`): Whether forward passes are performed without padding by flattening all sequences in the batch into a single continuous sequence. This approach requires associating a `position_ids` vector to track @@ -93,6 +94,9 @@ class DPOConfig(TrainingArguments): Batch size to use when precomputing reference model log probabilities. This can be set higher than the training batch size to speed up preprocessing. If `None`, defaults to `per_device_train_batch_size` for training and `per_device_eval_batch_size` for evaluation. + tools (`Optional[list[Union[dict, Callable]]]`, *optional*, defaults to `None`): + List of tools (callable functions) that will be accessible to the model. + If the template does not support function calling, this argument will have no effect. 
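Pulling the `DPOConfig` changes in this hunk together, here is a hedged usage sketch of the renamed and added options: `use_logits_to_keep` replaces `use_num_logits_to_keep`, `truncation_mode` is now documented against `max_length`, and `tools` is passed through to the chat template. All values are illustrative:

```python
from trl import DPOConfig

training_args = DPOConfig(
    output_dir="dpo-out",
    max_prompt_length=512,
    max_completion_length=512,
    max_length=1024,
    truncation_mode="keep_end",  # or "keep_start"; applied when prompt + completion exceeds max_length
    use_logits_to_keep=True,     # compute logits only where labels are not ignored
    tools=None,                  # optionally a list of callables forwarded to the chat template
)
```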
> Parameters that control the training @@ -194,7 +198,7 @@ class DPOConfig(TrainingArguments): default=True, metadata={"help": "Whether to disable dropout in the model and reference model."}, ) - use_num_logits_to_keep: bool = field( + use_logits_to_keep: bool = field( default=False, metadata={ "help": "If `True`, only a specified number of logits are computed in the forward pass. This can be " @@ -216,13 +220,6 @@ class DPOConfig(TrainingArguments): default=-100, metadata={"help": "Padding value to use for labels."}, ) - truncation_mode: str = field( - default="keep_end", - metadata={ - "help": "Truncation mode to use when the prompt is too long.", - "choices": ["keep_end", "keep_start"], - }, - ) max_prompt_length: Optional[int] = field( default=512, metadata={"help": "Maximum length of the prompt."}, @@ -235,6 +232,14 @@ class DPOConfig(TrainingArguments): default=1024, metadata={"help": "Maximum length of the full sequence (prompt + completion)."}, ) + truncation_mode: str = field( + default="keep_end", + metadata={ + "help": "Truncation mode to use when the sequence exceeds `max_length`. Possible values are `'keep_end'` " + "and `'keep_start'`.", + "choices": ["keep_end", "keep_start"], + }, + ) padding_free: bool = field( default=False, metadata={ @@ -261,6 +266,13 @@ class DPOConfig(TrainingArguments): "`per_device_train_batch_size` for training and `per_device_eval_batch_size` for evaluation." }, ) + tools: Optional[list[Union[dict, Callable]]] = field( + default=None, + metadata={ + "help": "List of tools (callable functions) that will be accessible to the model. If the template does " + "not support function calling, this argument will have no effect." + }, + ) # Parameters that control the training learning_rate: float = field( @@ -375,16 +387,18 @@ class DPOConfig(TrainingArguments): ) # Deprecated parameters - is_encoder_decoder: Optional[bool] = field( - default=None, - metadata={"help": "Deprecated. This argument is not used anymore."}, + use_num_logits_to_keep: bool = field( + default=False, + metadata={"help": "Deprecated. Use `use_logits_to_keep` instead."}, ) def __post_init__(self): - if self.is_encoder_decoder is not None: + super().__post_init__() + + if self.use_num_logits_to_keep: warnings.warn( - "The `is_encoder_decoder` parameter is deprecated will be removed in version 0.15. The trainer now " - "automatically determines if the model is an encoder-decoder, so you can safely remove it." + "`use_num_logits_to_keep` is deprecated and will be remove in version 0.17.0. 
Use " + "`use_logits_to_keep` instead.", + DeprecationWarning, ) - - return super().__post_init__() + self.use_logits_to_keep = self.use_num_logits_to_keep diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index d4b197362c..0346d991fc 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -51,7 +51,6 @@ from transformers.trainer_callback import TrainerCallback from transformers.trainer_utils import EvalLoopOutput from transformers.utils import is_peft_available, is_torch_xpu_available -from transformers.utils.deprecation import deprecate_kwarg from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt from ..models import PreTrainedModelWrapper, create_reference_model @@ -62,12 +61,14 @@ cap_exp, disable_dropout_in_model, empty_cache, + flush_left, generate_model_card, get_comet_experiment_url, log_table_to_comet_experiment, pad, pad_to_length, peft_module_casting_to_bf16, + selective_log_softmax, ) @@ -200,9 +201,6 @@ class DPOTrainer(Trainer): _tag_names = ["trl", "dpo"] - @deprecate_kwarg( - "tokenizer", "0.16.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True - ) def __init__( self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, @@ -387,14 +385,14 @@ def make_inputs_require_grad(module, input, output): if self.ref_model is not None: disable_dropout_in_model(self.ref_model) - self.max_length = args.max_length self.generate_during_eval = args.generate_during_eval self.label_pad_token_id = args.label_pad_token_id self.max_prompt_length = args.max_prompt_length - self.truncation_mode = args.truncation_mode self.max_completion_length = args.max_completion_length + self.max_length = args.max_length + self.truncation_mode = args.truncation_mode self.precompute_ref_log_probs = args.precompute_ref_log_probs - self.use_num_logits_to_keep = args.use_num_logits_to_keep + self.use_logits_to_keep = args.use_logits_to_keep if args.padding_free: if model.config._attn_implementation != "flash_attention_2": @@ -479,6 +477,11 @@ def make_inputs_require_grad(module, input, output): preprocess_logits_for_metrics=preprocess_logits_for_metrics, ) + # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the + # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set + # self.model_accepts_loss_kwargs to False to enable scaling. 
+ self.model_accepts_loss_kwargs = False + # Add tags for models that have been loaded with the correct transformers version if hasattr(self.model, "add_model_tags"): self.model.add_model_tags(self._tag_names) @@ -542,7 +545,9 @@ def _prepare_dataset( # Apply the chat template if needed if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset" - dataset = dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, **map_kwargs) + dataset = dataset.map( + maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class, "tools": args.tools}, **map_kwargs + ) # Tokenize the dataset if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` @@ -592,7 +597,9 @@ def tokenize_row(features, processing_class, max_prompt_length, max_completion_l >>> from transformers import GPT2Tokenizer >>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2") >>> features = {"prompt": "The sky is", "chosen": " blue", "rejected": " green"} - >>> DPOTrainer.tokenize_row(features, tokenizer, max_prompt_length=3, max_completion_length=3, add_special_tokens=False) + >>> DPOTrainer.tokenize_row( + ... features, tokenizer, max_prompt_length=3, max_completion_length=3, add_special_tokens=False + ... ) {'prompt_input_ids': [464, 6766, 318], 'chosen_input_ids': [4171, 50256], 'rejected_input_ids': [4077, 50256]} ``` """ @@ -812,9 +819,11 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa @contextmanager def null_ref_context(self): """Context manager for handling null reference model (that is, peft adapter manipulation).""" - with self.accelerator.unwrap_model( - self.model - ).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext(): + with ( + self.accelerator.unwrap_model(self.model).disable_adapter() + if self.is_peft_model and not self.ref_adapter_name + else nullcontext() + ): if self.ref_adapter_name: self.model.set_adapter(self.ref_adapter_name) yield @@ -1139,33 +1148,32 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to # Flush left to reduce the memory usage # [[0, 0, x, x, x, x], -> [[x, x, x, x], # [0, x, x, x, 0, 0]] [x, x, x, 0]] - for i in range(attention_mask.size(0)): - first_one_idx = torch.nonzero(attention_mask[i])[0].item() - input_ids[i] = torch.roll(input_ids[i], shifts=-first_one_idx) - attention_mask[i] = torch.roll(attention_mask[i], shifts=-first_one_idx) - loss_mask[i] = torch.roll(loss_mask[i], shifts=-first_one_idx) - - # Get the first column idx that is all zeros and remove every column after that - empty_cols = torch.sum(attention_mask, dim=0) == 0 - first_empty_col = torch.nonzero(empty_cols)[0].item() if empty_cols.any() else attention_mask.size(1) - input_ids = input_ids[:, :first_empty_col] - attention_mask = attention_mask[:, :first_empty_col] - loss_mask = loss_mask[:, :first_empty_col] + attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) # Truncate right - if self.args.max_length is not None: - input_ids = input_ids[:, : self.args.max_length] - attention_mask = attention_mask[:, : self.args.max_length] - loss_mask = loss_mask[:, : self.args.max_length] + if self.max_length is not None: + if self.truncation_mode == "keep_end": + input_ids = input_ids[:, -self.max_length :] + attention_mask = attention_mask[:, -self.max_length :] + loss_mask = loss_mask[:, -self.max_length :] + elif self.truncation_mode == 
"keep_start": + input_ids = input_ids[:, : self.max_length] + attention_mask = attention_mask[:, : self.max_length] + loss_mask = loss_mask[:, : self.max_length] + else: + raise ValueError( + f"Unknown truncation mode: '{self.truncation_mode}'. Should be one of ['keep_end', " + "'keep_start']." + ) - if self.use_num_logits_to_keep: - # Compute num_logits_to_keep based on loss_mask pattern: + if self.use_logits_to_keep: + # Compute logits_to_keep based on loss_mask pattern: # [[0, 0, 0, x, x, x, x], # [0, 0, 0, x, x, x, 0]] # ^ start computing logits from here ([:, -(7-3+1):]) first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min() - num_logits_to_keep = (loss_mask.shape[1] - first_compute_index).item() + 1 # +1 for the first label - model_kwargs["num_logits_to_keep"] = num_logits_to_keep + logits_to_keep = (loss_mask.shape[1] - first_compute_index).item() + 1 # +1 for the first label + model_kwargs["logits_to_keep"] = logits_to_keep if self.padding_free: # Flatten the input_ids, position_ids, and loss_mask @@ -1185,15 +1193,15 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to labels = torch.roll(input_ids, shifts=-1, dims=1) loss_mask = torch.roll(loss_mask, shifts=-1, dims=1).bool() - if self.use_num_logits_to_keep: + if self.use_logits_to_keep: # Align labels with logits # logits: -, -, [x2, x3, x4, x5, x6] # ^ --------- ^ after logits[:, :-1, :] # labels: [y0, y1, y2, y3, y4, y5, y6] - # ^ --------- ^ with num_logits_to_keep=4, [:, -4:] + # ^ --------- ^ with logits_to_keep=4, [:, -4:] # loss_mask: [0, 0, 0, 1, 1, 1, 1] - labels = labels[:, -num_logits_to_keep:] - loss_mask = loss_mask[:, -num_logits_to_keep:] + labels = labels[:, -logits_to_keep:] + loss_mask = loss_mask[:, -logits_to_keep:] if logits.shape[:2] != labels.shape[:2]: # for llava, the returned logits include the image tokens (placed before the text tokens) @@ -1202,7 +1210,7 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to # Compute the log probabilities of the labels labels[~loss_mask] = 0 # dummy token; we'll ignore the losses on these tokens later - per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + per_token_logps = selective_log_softmax(logits, labels) per_token_logps[~loss_mask] = 0 per_token_logps = torch.roll(per_token_logps, shifts=1, dims=1) diff --git a/trl/trainer/gkd_trainer.py b/trl/trainer/gkd_trainer.py index 59d71d1e44..7cfad453f7 100644 --- a/trl/trainer/gkd_trainer.py +++ b/trl/trainer/gkd_trainer.py @@ -78,7 +78,6 @@ def __init__( processing_class: Optional[ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] ] = None, - model_init: Optional[Callable[[], PreTrainedModel]] = None, compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None, callbacks: Optional[list[TrainerCallback]] = None, optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), @@ -86,9 +85,9 @@ def __init__( peft_config: Optional["PeftConfig"] = None, formatting_func: Optional[Callable] = None, ): - # add remove_unused_columns=False to the the dataclass args + # add remove_unused_columns=False to the dataclass args args.remove_unused_columns = False - data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_seq_length) + data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_length) super().__init__( model, @@ -97,7 +96,6 @@ def __init__( train_dataset=train_dataset, 
eval_dataset=eval_dataset, processing_class=processing_class, - model_init=model_init, compute_metrics=compute_metrics, callbacks=callbacks, optimizers=optimizers, @@ -158,6 +156,14 @@ def __init__( ): self.generation_config.eos_token_id = self.model.generation_config.eos_token_id + def _prepare_dataset(self, dataset, *args): + # SFTTrainer._prepare_dataset() applies the chat template and rename the messages column to text. However, we + # need to keep the messages column as it is. We use the following workaround to keep the messages column. + dataset = dataset.add_column("_messages", dataset["messages"]) + dataset = super()._prepare_dataset(dataset, *args) + dataset = dataset.rename_column("_messages", "messages") + return dataset + @staticmethod def generalized_jsd_loss( student_logits, teacher_logits, labels=None, beta=0.5, temperature=1.0, reduction="batchmean" diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py new file mode 100644 index 0000000000..923686276c --- /dev/null +++ b/trl/trainer/grpo_config.py @@ -0,0 +1,261 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Optional + +from transformers import TrainingArguments + + +@dataclass +class GRPOConfig(TrainingArguments): + r""" + Configuration class for the [`GRPOTrainer`]. + + Only the parameters specific to GRPO training are listed here. For details on other parameters, refer to the + [`~transformers.TrainingArguments`] documentation. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + > Parameters that control the model and reference model + + model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`): + Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model` + argument of the [`GRPOTrainer`] is provided as a string. + + > Parameters that control the data preprocessing + + remove_unused_columns (`bool`, *optional*, defaults to `False`): + Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that + requires any column other than `"prompts"` and `"completions"`, you should keep this to `False`. + max_prompt_length (`int` or `None`, *optional*, defaults to `512`): + Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left. + num_generations (`int` or `None`, *optional*, defaults to `8`): + Number of generations per prompt to sample. The global batch size (num_processes * per_device_batch_size) + must be divisible by this value. + temperature (`float`, *optional*, defaults to `0.9`): + Temperature for sampling. The higher the temperature, the more random the completions. 
+ max_completion_length (`int` or `None`, *optional*, defaults to `256`): + Maximum length of the generated completion. + ds3_gather_for_generation (`bool`, *optional*, defaults to `True`): + This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation, + improving generation speed. However, disabling this option allows training models that exceed the VRAM + capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible + with vLLM generation. + + > Parameters that control generation acceleration powered by vLLM + + use_vllm (`bool`, *optional*, defaults to `False`): + Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept unused for + training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`). + vllm_device (`str`, *optional*, defaults to `"auto"`): + Device where vLLM generation will run, e.g. `"cuda:1"`. If set to `"auto"` (default), the system will + automatically select the next available GPU after the last one used for training. This assumes that + training has not already occupied all available GPUs. If only one device is available, the device will be + shared between both training and vLLM. + vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.9`): + Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the + device dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus + improve the model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors + during initialization. + vllm_dtype (`str`, *optional*, defaults to `"auto"`): + Data type to use for vLLM generation. If set to `"auto"`, the data type will be automatically determined + based on the model configuration. Find the supported values in the vLLM documentation. + vllm_max_model_len (`int` or `None`, *optional*, defaults to `None`): + If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced + `vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model + context size, which might be much larger than the KV cache, leading to inefficiencies. + vllm_guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`): + Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled. + + > Parameters that control the training + + learning_rate (`float`, *optional*, defaults to `1e-6`): + Initial learning rate for [`AdamW`] optimizer. The default value replaces that of + [`~transformers.TrainingArguments`]. + beta (`float`, *optional*, defaults to `0.04`): + KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving training + speed. + reward_weights (`list[float]` or `None`, *optional*, defaults to `None`): + Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are + weighted equally with weight `1.0`. + sync_ref_model (`bool`, *optional*, defaults to `False`): + Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using + the `ref_model_mixup_alpha` parameter. This synchronization originates from the + [TR-DPO](https://huggingface.co/papers/2404.09656) paper.
+ ref_model_mixup_alpha (`float`, *optional*, defaults to `0.9`): + α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix + between the current policy and the previous reference policy during updates. The reference policy is + updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you + must set `sync_ref_model=True`. + ref_model_sync_steps (`int`, *optional*, defaults to `64`): + τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how + frequently the current policy is synchronized with the reference policy. To use this parameter, you must + set `sync_ref_model=True`. + + > Parameters that control the logging + + log_completions (`bool`, *optional*, defaults to `False`): + Whether to log the completions during training. + """ + + # Parameters that control the model and reference model + model_init_kwargs: Optional[dict] = field( + default=None, + metadata={ + "help": "Keyword arguments for `transformers.AutoModelForCausalLM.from_pretrained`, used when the `model` " + "argument of the `GRPOTrainer` is provided as a string." + }, + ) + + # Parameters that control the data preprocessing + # The default value remove_unused_columns is overwritten from the parent class, because in GRPO we usually rely on + # additional columns to compute the reward + remove_unused_columns: Optional[bool] = field( + default=False, + metadata={ + "help": "Whether to only keep the column 'prompt' in the dataset. If you use a custom reward function " + "that requires any column other than 'prompts' and 'completions', you should keep this to `False`." + }, + ) + max_prompt_length: Optional[int] = field( + default=512, + metadata={ + "help": "Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left." + }, + ) + num_generations: Optional[int] = field( + default=8, + metadata={ + "help": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) " + "must be divisible by this value." + }, + ) + temperature: Optional[float] = field( + default=0.9, + metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."}, + ) + max_completion_length: Optional[int] = field( + default=256, + metadata={"help": "Maximum length of the generated completion."}, + ) + ds3_gather_for_generation: bool = field( + default=True, + metadata={ + "help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for " + "generation, improving generation speed. However, disabling this option allows training models that " + "exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation. Disabling this option " + "is not compatible with vLLM generation." + }, + ) + + # Parameters that control generation acceleration powered by vLLM + use_vllm: Optional[bool] = field( + default=False, + metadata={ + "help": "Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept " + "unused for training, as vLLM will require one for generation. vLLM must be installed " + "(`pip install vllm`)." + }, + ) + vllm_device: Optional[str] = field( + default="auto", + metadata={ + "help": "Device where vLLM generation will run, e.g. 'cuda:1'. If set to 'auto' (default), the system " + "will automatically select the next available GPU after the last one used for training. 
This assumes " + "that training has not already occupied all available GPUs." + }, + ) + vllm_gpu_memory_utilization: float = field( + default=0.9, + metadata={ + "help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV " + "cache on the device dedicated to generation powered by vLLM. Higher values will increase the KV cache " + "size and thus improve the model's throughput. However, if the value is too high, it may cause " + "out-of-memory (OOM) errors during initialization." + }, + ) + vllm_dtype: Optional[str] = field( + default="auto", + metadata={ + "help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically " + "determined based on the model configuration. Find the supported values in the vLLM documentation." + }, + ) + vllm_max_model_len: Optional[int] = field( + default=None, + metadata={ + "help": "If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced " + "`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model " + "context size, which might be much larger than the KV cache, leading to inefficiencies." + }, + ) + vllm_guided_decoding_regex: Optional[str] = field( + default=None, + metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."}, + ) + + # Parameters that control the training + learning_rate: float = field( + default=1e-6, + metadata={ + "help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of " + "`transformers.TrainingArguments`." + }, + ) + beta: float = field( + default=0.04, + metadata={ + "help": "KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving " + "training speed." + }, + ) + reward_weights: Optional[list[float]] = field( + default=None, + metadata={ + "help": "Weights for each reward function. Must match the number of reward functions. If `None`, all " + "rewards are weighted equally with weight `1.0`." + }, + ) + sync_ref_model: bool = field( + default=False, + metadata={ + "help": "Whether to synchronize the reference model with the active model every `ref_model_sync_steps` " + "steps, using the `ref_model_mixup_alpha` parameter." + }, + ) + ref_model_mixup_alpha: float = field( + default=0.9, + metadata={ + "help": "α parameter from the TR-DPO paper, which controls the mix between the current policy and the " + "previous reference policy during updates. The reference policy is updated according to the equation: " + "`π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you must set `sync_ref_model=True`." + }, + ) + ref_model_sync_steps: int = field( + default=64, + metadata={ + "help": "τ parameter from the TR-DPO paper, which determines how frequently the current policy is " + "synchronized with the reference policy. To use this parameter, you must set `sync_ref_model=True`." + }, + ) + + # Parameters that control the logging + log_completions: bool = field( + default=False, + metadata={"help": "Whether to log the completions during training."}, + ) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py new file mode 100644 index 0000000000..573350c277 --- /dev/null +++ b/trl/trainer/grpo_trainer.py @@ -0,0 +1,834 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import textwrap +import warnings +from collections import defaultdict +from typing import Any, Callable, Optional, Sized, Union +from unittest.mock import patch + +import torch +import torch.utils.data +import transformers +from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed +from accelerate.utils.other import is_compiled_module +from datasets import Dataset, IterableDataset +from packaging import version +from torch import nn +from torch.utils.data import Sampler +from transformers import ( + AutoModelForCausalLM, + AutoModelForSequenceClassification, + AutoTokenizer, + GenerationConfig, + PreTrainedModel, + PreTrainedTokenizerBase, + Trainer, + TrainerCallback, + is_wandb_available, +) +from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled +from transformers.utils import is_peft_available + +from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template +from ..import_utils import is_vllm_available +from ..models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation +from .callbacks import SyncRefModelCallback +from .grpo_config import GRPOConfig +from .utils import generate_model_card, get_comet_experiment_url, pad, selective_log_softmax + + +if is_peft_available(): + from peft import PeftConfig, get_peft_model + +if is_vllm_available(): + from vllm import LLM, SamplingParams + from vllm.sampling_params import GuidedDecodingParams + +if is_wandb_available(): + import wandb + +# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of +# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model. +RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]] + + +class RepeatRandomSampler(Sampler): + """ + Sampler that repeats the indices of a dataset N times. + + Args: + data_source (`Sized`): + Dataset to sample from. + repeat_count (`int`): + Number of times to repeat each index. + seed (`Optional[int]`): + Random seed for reproducibility (only affects this sampler). + + Example: + ```python + >>> sampler = RepeatRandomSampler(["a", "b", "c", "d"], repeat_count=2) + >>> list(sampler) + [2, 2, 0, 0, 3, 3, 1, 1] + ``` + """ + + def __init__(self, data_source: Sized, repeat_count: int, seed: Optional[int] = None): + self.data_source = data_source + self.repeat_count = repeat_count + self.num_samples = len(data_source) + self.seed = seed + self.generator = torch.Generator() # Create a local random generator + if seed is not None: + self.generator.manual_seed(seed) + + def __iter__(self): + indexes = [ + idx + for idx in torch.randperm(self.num_samples, generator=self.generator).tolist() + for _ in range(self.repeat_count) + ] + return iter(indexes) + + def __len__(self): + return self.num_samples * self.repeat_count + + +class GRPOTrainer(Trainer): + """ + Trainer for the Group Relative Policy Optimization (GRPO) method. 
This algorithm was initially proposed in the + paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300). + + Example: + + ```python + from datasets import load_dataset + from trl import GRPOTrainer + + dataset = load_dataset("trl-lib/tldr", split="train") + + def reward_func(completions, **kwargs): + # Dummy reward function that rewards completions with more unique letters. + return [float(len(set(completion))) for completion in completions] + + trainer = GRPOTrainer( + model="Qwen/Qwen2-0.5B-Instruct", + reward_funcs=reward_func, + train_dataset=dataset, + ) + + trainer.train() + ``` + + Args: + model (`Union[str, PreTrainedModel]`): + Model to be trained. Can be either: + + - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or + a path to a *directory* containing model weights saved using + [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is + loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments + in `args.model_init_kwargs`. + - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported. + reward_funcs (`Union[RewardFunc, list[RewardFunc]]`): + Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward + functions with the prompts and completions and sum the rewards. Can be either: + + - A single reward function, such as: + - A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a + path to a *directory* containing model weights saved using + [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded + using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the + keyword arguments in `args.model_init_kwargs`. + - A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported. + - A custom reward function: The function is provided with the prompts and the generated completions, + plus any additional columns in the dataset. It should return a list of rewards. For more details, see + [Using a custom reward function](#using-a-custom-reward-function). + - A list of reward functions, where each item can independently be any of the above types. Mixing different + types within the list (e.g., a string model ID and a custom reward function) is allowed. + args ([`GRPOConfig`], *optional*, defaults to `None`): + Configuration for this trainer. If `None`, a default configuration is used. + train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): + Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset are + ignored. The format of the samples can be either: + + - [Standard](dataset_formats#standard): Each sample contains plain text. + - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role + and content). + eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`): + Dataset to use for evaluation. It must meet the same requirements as `train_dataset`. + processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`): + Processing class used to process the data. The padding side must be set to "left".
If `None`, the + processing class is loaded from the model's name with [`~transformers.AutoTokenizer.from_pretrained`]. + reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*, defaults to `None`): + Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either: + + - A single processing class: Used when `reward_funcs` contains only one reward function. + - A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`. + If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is + `None`, the tokenizer for the model is automatically loaded using [`~transformers.AutoTokenizer.from_pretrained`]. + For elements in `reward_funcs` that are custom reward functions (not [`~transformers.PreTrainedModel`]), + the corresponding entries in `reward_processing_classes` are ignored. + callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`): + List of callbacks to customize the training loop. Will add those to the list of default callbacks + detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback). + + If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`] + method. + optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`): + A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your + model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`. + peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`): + PEFT configuration used to wrap the model. If `None`, the model is not wrapped. + """ + + _tag_names = ["trl", "grpo"] + + def __init__( + self, + model: Union[str, PreTrainedModel], + reward_funcs: Union[RewardFunc, list[RewardFunc]], + args: GRPOConfig = None, + train_dataset: Optional[Union[Dataset, IterableDataset]] = None, + eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None, + processing_class: Optional[PreTrainedTokenizerBase] = None, + reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None, + callbacks: Optional[list[TrainerCallback]] = None, + optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), + peft_config: Optional["PeftConfig"] = None, + ): + # Args + if args is None: + model_name = model if isinstance(model, str) else model.config._name_or_path + model_name = model_name.split("/")[-1] + args = GRPOConfig(f"{model_name}-GRPO") + + # Models + # Trained model + model_init_kwargs = args.model_init_kwargs or {} + if isinstance(model, str): + model_id = model + torch_dtype = model_init_kwargs.get("torch_dtype") + if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None: + pass # torch_dtype is already a torch.dtype or "auto" or None + elif isinstance(torch_dtype, str): # it's a str, but not "auto" + torch_dtype = getattr(torch, torch_dtype) + model_init_kwargs["torch_dtype"] = torch_dtype + else: + raise ValueError( + "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing " + f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}." 
+ ) + # Disable caching if gradient checkpointing is enabled (not supported) + model_init_kwargs["use_cache"] = ( + False if args.gradient_checkpointing else model_init_kwargs.get("use_cache") + ) + model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) + else: + model_id = model.config._name_or_path + if args.model_init_kwargs is not None: + raise ValueError( + "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. " + "This argument can only be used when the `model` argument is a string." + ) + + self.beta = args.beta + + if peft_config is not None: + model = get_peft_model(model, peft_config) + + # Reference model + if self.beta == 0.0: + # If beta is 0.0, the reference model is not needed + self.ref_model = None + elif is_deepspeed_zero3_enabled(): + self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs) + elif not is_peft_model(model): + # If PEFT configuration is not provided, create a reference model based on the initial model. + self.ref_model = create_reference_model(model) + else: + # If PEFT is used, the reference model is not needed since the adapter can be disabled + # to revert to the initial model. + self.ref_model = None + + # Processing class + if processing_class is None: + processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left") + + # Reward functions + if not isinstance(reward_funcs, list): + reward_funcs = [reward_funcs] + for i, reward_func in enumerate(reward_funcs): + if isinstance(reward_func, str): + reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained( + reward_func, num_labels=1, **model_init_kwargs + ) + self.reward_funcs = reward_funcs + + # Reward weights + if args.reward_weights is not None: + if len(args.reward_weights) != len(reward_funcs): + raise ValueError( + f"Number of reward weights ({len(args.reward_weights)}) must match number of reward " + f"functions ({len(reward_funcs)})" + ) + self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32) + else: + self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32) + + # Reward processing class + if reward_processing_classes is None: + reward_processing_classes = [None] * len(reward_funcs) + elif not isinstance(reward_processing_classes, list): + reward_processing_classes = [reward_processing_classes] + else: + if len(reward_processing_classes) != len(reward_funcs): + raise ValueError("The number of reward processing classes must match the number of reward functions.") + + for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)): + if isinstance(reward_func, PreTrainedModel): + if reward_processing_class is None: + reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path) + if reward_processing_class.pad_token_id is None: + reward_processing_class.pad_token = reward_processing_class.eos_token + # The reward model computes the reward for the latest non-padded token in the input sequence. + # So it's important to set the pad token ID to the padding token ID of the processing class. 
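To make the pad-token comment above concrete: causal sequence-classification heads typically read the score from the position of the last non-padding token, so this only works if the reward model's `pad_token_id` agrees with the tokenizer that produced the padding. A small, framework-free sketch of that pooling rule (all tensor values are invented for illustration):

```python
import torch

pad_token_id = 0  # assumed to match the reward tokenizer's pad token

# Two right-padded sequences of token ids (arbitrary values).
input_ids = torch.tensor(
    [
        [5, 7, 9, 2, pad_token_id, pad_token_id],
        [4, 8, 3, 6, 1, 2],
    ]
)

# Index of the last non-padding token in each row; with a mismatched pad_token_id
# this index would land on padding and the score would be read from the wrong position.
last_non_pad = (input_ids != pad_token_id).sum(dim=1) - 1
print(last_non_pad)  # tensor([3, 5])

# Per-token scores from a hypothetical scalar reward head, shape (batch, seq_len).
per_token_scores = torch.randn(2, 6)
rewards = per_token_scores[torch.arange(2), last_non_pad]
print(rewards.shape)  # torch.Size([2])
```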
+ reward_func.config.pad_token_id = reward_processing_class.pad_token_id + reward_processing_classes[i] = reward_processing_class + self.reward_processing_classes = reward_processing_classes + + # Data collator + def data_collator(features): # No data collation is needed in GRPO + return features + + # Training arguments + self.max_prompt_length = args.max_prompt_length + self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper + self.num_generations = args.num_generations # = G in the GRPO paper + self.use_vllm = args.use_vllm + + # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the + # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the + # "input_ids" key. Instead, the only available key is "prompt". As a result, the trainer issues the warning: + # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To + # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True. + # This acts as a flag to indicate that the warning has already been issued. + model.warnings_issued["estimate_tokens"] = True + + # Initialize the metrics + self._metrics = defaultdict(list) + self.log_completions = args.log_completions + + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + callbacks=callbacks, + optimizers=optimizers, + ) + + # Check if the per_device_train/eval_batch_size * num processes can be divided by the number of generations + num_processes = self.accelerator.num_processes + global_batch_size = args.per_device_train_batch_size * num_processes + possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0] + if self.num_generations not in possible_values: + raise ValueError( + f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly " + f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train " + f"batch size, the valid values for the number of generations are: {possible_values}." + ) + if self.args.eval_strategy != "no": + global_batch_size = args.per_device_eval_batch_size * num_processes + possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0] + if self.num_generations not in possible_values: + raise ValueError( + f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly " + f"divisible by the number of generations per prompt ({self.num_generations}). Given the current " + f"eval batch size, the valid values for the number of generations are: {possible_values}." + ) + + # Ensure each process receives a unique seed to prevent duplicate completions when generating with + # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but + # it's safer to set it in all cases. + set_seed(args.seed, device_specific=True) + + if self.use_vllm: + if not is_vllm_available(): + raise ImportError( + "vLLM is not available and `use_vllm` is set to True. Please install vLLM with " + "`pip install vllm` to use it."
+ ) + + if self.accelerator.is_main_process: + vllm_device = self.args.vllm_device + if vllm_device == "auto": + if torch.cuda.device_count() == 1: + vllm_device = "cuda:0" # particular case when training with only 1 GPU: share it + else: + vllm_device = f"cuda:{self.accelerator.num_processes}" # take the next GPU idx + # Check that the requested device is available + if vllm_device.split(":")[0] == "cuda" and int(vllm_device.split(":")[1]) >= torch.cuda.device_count(): + raise ValueError( + f"The requested device for vllm ({vllm_device}) is not available. You are likely using vLLM " + "without restricting the number of GPUs for training. Set the `--num_processes` argument to a " + "value lower than the number of GPUs available on your machine—typically, reducing it by one " + f"is sufficient. In your case: `--num_processes {torch.cuda.device_count() - 1}`." + ) + # Check that the requested device is not also used for training + if vllm_device in {f"cuda:{idx}" for idx in range(self.accelerator.num_processes)}: + warnings.warn( + f"The requested device {vllm_device} is also being used for training. For higher throughput " + "and to avoid out-of-memory errors, it is recommended to use a dedicated device for vLLM. " + "If this is intentional, you may ignore this warning but should adjust " + "`vllm_gpu_memory_utilization` accordingly." + ) + # vLLM is not compatible with accelerate. So we need to patch it to make sure we can (1) place the vLLM + # model on the desired device (world_size_patch) and (2) avoid a test that is not designed for our + # setting (profiling_patch). + world_size_patch = patch("torch.distributed.get_world_size", return_value=1) + profiling_patch = patch( + "vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling", return_value=None + ) + with world_size_patch, profiling_patch: + self.llm = LLM( + model=model.name_or_path, + device=vllm_device, + gpu_memory_utilization=self.args.vllm_gpu_memory_utilization, + dtype=self.args.vllm_dtype, + # Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can + # directly reuse the KV cache if it shares the same prefix with one of the existing queries. + # This is particularly useful here because we generate completions from the same prompts. + enable_prefix_caching=True, + max_model_len=self.args.vllm_max_model_len, + ) + + # Guided decoding, if enabled + if args.vllm_guided_decoding_regex is not None: + guided_decoding = GuidedDecodingParams(backend="outlines", regex=args.vllm_guided_decoding_regex) + else: + guided_decoding = None + + # Sampling parameters + self.sampling_params = SamplingParams( + temperature=args.temperature, + max_tokens=self.max_completion_length, + guided_decoding=guided_decoding, + n=args.num_generations, + ) + + self._last_loaded_step = 0 # tag to avoid useless loading during grad accumulation + + # When using vLLM, the main process is responsible for loading the model weights. This can cause process + # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we + # synchronize all processes after vLLM has been fully initialized. + self.accelerator.wait_for_everyone() + else: + self.generation_config = GenerationConfig( + max_new_tokens=self.max_completion_length, + do_sample=True, + temperature=args.temperature, + pad_token_id=processing_class.pad_token_id, + ) + + # Gradient accumulation requires scaled loss.
Normally, loss scaling in the parent class depends on whether the + # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set + # self.model_accepts_loss_kwargs to False to enable scaling. + self.model_accepts_loss_kwargs = False + + # Add tags to the model + self.model.add_model_tags(self._tag_names) + + if self.ref_model is not None: + if self.is_deepspeed_enabled: + self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator) + else: + self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + + if args.sync_ref_model: + self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator)) + + for i, reward_func in enumerate(self.reward_funcs): + if isinstance(reward_func, PreTrainedModel): + self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True) + + def _set_signature_columns_if_needed(self): + # If `self.args.remove_unused_columns` is True, non-signature columns are removed. + # By default, this method sets `self._signature_columns` to the model's expected inputs. + # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work. + # Instead, we set them to the columns expected by the `training_step` method, hence the override. + if self._signature_columns is None: + self._signature_columns = ["prompt"] + + def _get_train_sampler(self) -> Sampler: + # Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that + # identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly + # within each prompt group. Using the same seed across processes ensures consistent prompt assignment, + # preventing discrepancies in group formation. + return RepeatRandomSampler(self.train_dataset, self.num_generations, seed=self.args.seed) + + def _get_eval_sampler(self, eval_dataset) -> Sampler: + # Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that + # identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly + # within each prompt group. Using the same seed across processes ensures consistent prompt assignment, + # preventing discrepancies in group formation. + return RepeatRandomSampler(eval_dataset, self.num_generations, seed=self.args.seed) + + # Get the per-token log probabilities for the completions for the model and the reference model + def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): + # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded + logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits + logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred + + input_ids = input_ids[:, -logits_to_keep:] + # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves. 
+ # See https://github.com/huggingface/trl/issues/2770 + logits = logits[:, -logits_to_keep:] + return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens + + def _move_model_to_vllm(self): + with unwrap_model_for_generation( + self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation + ) as unwrapped_model: + if is_compiled_module(unwrapped_model): + unwrapped_model = unwrapped_model._orig_mod + if is_peft_model(unwrapped_model): + unwrapped_model.merge_adapter() + state_dict = unwrapped_model.state_dict() + # Remove base_model and base_layer prefixes + state_dict = { + k.removeprefix("base_model.model.").replace(".base_layer", ""): v for k, v in state_dict.items() + } + # Remove values with adapter prefix (example: "_lora") + state_dict = {k: v for k, v in state_dict.items() if unwrapped_model.prefix not in k} + # When module to save, remove its prefix and discard the original module + state_dict = { + k.replace("modules_to_save.default.", ""): v + for k, v in state_dict.items() + if "original_module" not in k + } + else: + state_dict = unwrapped_model.state_dict() + if self.accelerator.is_main_process: + llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model + llm_model.load_weights(state_dict.items()) + # Unmerge the adapter to restore the model to its original state. + # This must be done after loading weights to ensure they correspond to the merged state. + if is_peft_model(unwrapped_model): + unwrapped_model.unmerge_adapter() + + def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]: + device = self.accelerator.device + prompts = [x["prompt"] for x in inputs] + prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs] + prompt_inputs = self.processing_class( + prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False + ) + prompt_inputs = super()._prepare_inputs(prompt_inputs) + prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] + + if self.max_prompt_length is not None: + prompt_ids = prompt_ids[:, -self.max_prompt_length :] + prompt_mask = prompt_mask[:, -self.max_prompt_length :] + + # Generate completions using either vLLM or regular generation + if self.args.use_vllm: + # First, have main process load weights if needed + if self.state.global_step != self._last_loaded_step: + self._move_model_to_vllm() + self._last_loaded_step = self.state.global_step + + # Generate completions using vLLM: gather all prompts and use them in a single call in the main process + all_prompts_text = gather_object(prompts_text) + if self.accelerator.is_main_process: + # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate + # num_generations outputs for each one. This is faster than generating outputs for each duplicate + # prompt individually. + ordered_set_of_prompts = list(dict.fromkeys(all_prompts_text)) + all_outputs = self.llm.generate( + ordered_set_of_prompts, sampling_params=self.sampling_params, use_tqdm=False + ) + completion_ids = [] + for outputs in all_outputs: + for output in outputs.outputs: + completion_ids.append(output.token_ids) + else: + completion_ids = [None] * len(all_prompts_text) + # Broadcast the completions from the main process to all processes, ensuring each process receives its + # corresponding slice. 
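The broadcast-and-slice step described in the comment above can be illustrated with a standalone sketch (no distributed launch; the completion ids and the two-process layout are invented). The main process produces completions for all prompts in global order, and each rank then keeps only the slice that corresponds to its own prompts:

```python
# Completions for 6 prompts in global order, as produced on the main process.
all_completion_ids = [[101], [102], [103], [104], [105], [106]]
num_local_prompts = 3  # prompts held by each of the 2 processes

for process_index in range(2):
    process_slice = slice(
        process_index * num_local_prompts,
        (process_index + 1) * num_local_prompts,
    )
    print(process_index, all_completion_ids[process_slice])
# 0 [[101], [102], [103]]
# 1 [[104], [105], [106]]
```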
+ completion_ids = broadcast_object_list(completion_ids, from_process=0) + process_slice = slice( + self.accelerator.process_index * len(prompts), + (self.accelerator.process_index + 1) * len(prompts), + ) + completion_ids = completion_ids[process_slice] + + # Pad the completions, and concatenate them with the prompts + completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids] + completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id) + prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) + else: + # Regular generation path + with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model: + prompt_completion_ids = unwrapped_model.generate( + prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config + ) + + # Compute prompt length and extract completion ids + prompt_length = prompt_ids.size(1) + prompt_ids = prompt_completion_ids[:, :prompt_length] + completion_ids = prompt_completion_ids[:, prompt_length:] + + # Mask everything after the first EOS token + is_eos = completion_ids == self.processing_class.eos_token_id + eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) + eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] + sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) + completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() + + # Concatenate prompt_mask with completion_mask for logit computation + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B*G, P+C) + + logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens + + with torch.inference_mode(): + if self.beta == 0.0: + ref_per_token_logps = None + elif self.ref_model is not None: + ref_per_token_logps = self._get_per_token_logps( + self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep + ) + else: + with self.accelerator.unwrap_model(self.model).disable_adapter(): + ref_per_token_logps = self._get_per_token_logps( + self.model, prompt_completion_ids, attention_mask, logits_to_keep + ) + + # Decode the generated completions + completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) + if is_conversational(inputs[0]): + completions = [] + for prompt, completion in zip(prompts, completions_text): + bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else "" + completions.append([{"role": "assistant", "content": bootstrap + completion}]) + else: + completions = completions_text + + rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device) + for i, (reward_func, reward_processing_class) in enumerate( + zip(self.reward_funcs, self.reward_processing_classes) + ): + if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models + if is_conversational(inputs[0]): + messages = [{"messages": p + c} for p, c in zip(prompts, completions)] + texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages] + else: + texts = [p + c for p, c in zip(prompts, completions)] + reward_inputs = reward_processing_class( + texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False + ) + reward_inputs = super()._prepare_inputs(reward_inputs) + with torch.inference_mode(): + rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,) + else: + # Repeat all input 
columns (but "prompt" and "completion") to match the number of generations + keys = [key for key in inputs[0] if key not in ["prompt", "completion"]] + reward_kwargs = {key: [example[key] for example in inputs] for key in keys} + output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs) + rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) + + # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the + # completions may be distributed across processes + rewards_per_func = gather(rewards_per_func) + + # Apply weights to each reward function's output and sum + rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(dim=1) + + # Compute grouped-wise rewards + mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1) + std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1) + + # Normalize the rewards to compute the advantages + mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0) + std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0) + advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4) + + # Slice to keep only the local part of the data + process_slice = slice( + self.accelerator.process_index * len(prompts), + (self.accelerator.process_index + 1) * len(prompts), + ) + advantages = advantages[process_slice] + + # Log the metrics + reward_per_func = rewards_per_func.mean(0) + for i, reward_func in enumerate(self.reward_funcs): + if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models + reward_func_name = reward_func.config._name_or_path.split("/")[-1] + else: + reward_func_name = reward_func.__name__ + self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item()) + + self._metrics["reward"].append(rewards.mean().item()) + self._metrics["reward_std"].append(std_grouped_rewards.mean().item()) + + if ( + self.log_completions + and self.state.global_step % self.args.logging_steps == 0 + and "wandb" in self.args.report_to + ): + import pandas as pd + + # For logging + table = { + "step": [str(self.state.global_step)] * len(rewards), + "prompt": gather_object(prompts_text), + "completion": gather_object(completions_text), + "reward": rewards.tolist(), + } + df = pd.DataFrame(table) + + if wandb.run is not None and self.accelerator.is_main_process: + wandb.log({"completions": wandb.Table(dataframe=df)}) + + return { + "prompt_ids": prompt_ids, + "prompt_mask": prompt_mask, + "completion_ids": completion_ids, + "completion_mask": completion_mask, + "ref_per_token_logps": ref_per_token_logps, + "advantages": advantages, + } + + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + if return_outputs: + raise ValueError("The GRPOTrainer does not support returning outputs") + # Compute the per-token log probabilities for the model + + prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] + completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] + input_ids = torch.cat([prompt_ids, completion_ids], dim=1) + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) + logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens + + per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) + + # Compute 
the KL divergence between the model and the reference model + if self.beta != 0.0: + ref_per_token_logps = inputs["ref_per_token_logps"] + per_token_kl = ( + torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 + ) + + # x - x.detach() allows for preserving gradients from x + advantages = inputs["advantages"] + per_token_loss = -torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) + if self.beta != 0.0: + per_token_loss = per_token_loss + self.beta * per_token_kl + loss = (per_token_loss * completion_mask).sum() / completion_mask.sum() + + # Log the metrics + completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item() + self._metrics["completion_length"].append(completion_length) + + if self.beta != 0.0: + mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() + self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item()) + + return loss + + def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None): + inputs = self._prepare_inputs(inputs) + with torch.no_grad(): + with self.compute_loss_context_manager(): + loss = self.compute_loss(model, inputs) + loss = loss.mean().detach() + return loss, None, None + + def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: + metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics + + # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs` + # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format. + if next(iter(logs.keys())).startswith("eval_"): + metrics = {f"eval_{key}": val for key, val in metrics.items()} + + logs = {**logs, **metrics} + if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): + super().log(logs, start_time) + else: # transformers<=4.46 + super().log(logs) + self._metrics.clear() + + def create_model_card( + self, + model_name: Optional[str] = None, + dataset_name: Optional[str] = None, + tags: Union[str, list[str], None] = None, + ): + """ + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + model_name (`str` or `None`, *optional*, defaults to `None`): + Name of the model. + dataset_name (`str` or `None`, *optional*, defaults to `None`): + Name of the dataset used for training. + tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`): + Tags to be associated with the model card. + """ + if not self.is_world_process_zero(): + return + + if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): + base_model = self.model.config._name_or_path + else: + base_model = None + + tags = tags or [] + if isinstance(tags, str): + tags = [tags] + + if hasattr(self.model.config, "unsloth_version"): + tags.append("unsloth") + + citation = textwrap.dedent( + """\ + @article{zhihong2024deepseekmath, + title = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}}, + author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. 
Wu and Daya Guo}, + year = 2024, + eprint = {arXiv:2402.03300}, + } + """ + ) + + model_card = generate_model_card( + base_model=base_model, + model_name=model_name, + hub_model_id=self.hub_model_id, + dataset_name=dataset_name, + tags=tags, + wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + comet_url=get_comet_experiment_url(), + trainer_name="GRPO", + trainer_citation=citation, + paper_title="DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models", + paper_id="2402.03300", + ) + + model_card.save(os.path.join(self.args.output_dir, "README.md")) diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index 897ce25520..0c92ad70b8 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -63,6 +63,7 @@ log_table_to_comet_experiment, pad_to_length, peft_module_casting_to_bf16, + selective_log_softmax, ) @@ -746,6 +747,11 @@ def make_inputs_require_grad(module, input, output): preprocess_logits_for_metrics=preprocess_logits_for_metrics, ) + # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the + # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set + # self.model_accepts_loss_kwargs to False to enable scaling. + self.model_accepts_loss_kwargs = False + # Add tags for models that have been loaded with the correct transformers version if hasattr(self.model, "add_model_tags"): self.model.add_model_tags(self._tag_names) @@ -807,9 +813,11 @@ def _prepare_deepspeed(self, model: PreTrainedModelWrapper): @contextmanager def null_ref_context(self): """Context manager for handling null reference model (that is, peft adapter manipulation).""" - with self.accelerator.unwrap_model( - self.model - ).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext(): + with ( + self.accelerator.unwrap_model(self.model).disable_adapter() + if self.is_peft_model and not self.ref_adapter_name + else nullcontext() + ): if self.ref_adapter_name: self.model.set_adapter(self.ref_adapter_name) yield @@ -1027,7 +1035,7 @@ def get_batch_logps( # dummy token; we'll ignore the losses on these tokens later labels[labels == label_pad_token_id] = 0 - per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + per_token_logps = selective_log_softmax(logits, labels) if average_log_prob: return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) diff --git a/trl/trainer/nash_md_trainer.py b/trl/trainer/nash_md_trainer.py index cbe218066e..5d2a8e830d 100644 --- a/trl/trainer/nash_md_trainer.py +++ b/trl/trainer/nash_md_trainer.py @@ -46,6 +46,7 @@ generate_model_card, get_comet_experiment_url, get_reward, + selective_log_softmax, truncate_right, ) @@ -277,8 +278,7 @@ def _compute_logprobs(self, model, model_data, context_length): def compute_logprobs_for_data(m, data): output = m(data["input_ids"], attention_mask=data["attention_mask"]) logits = output.logits[:, context_length - 1 : -1] - logprobs = F.log_softmax(logits, dim=-1) - token_logprobs = torch.gather(logprobs, 2, data["input_ids"][:, context_length:].unsqueeze(-1)).squeeze(-1) + token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:]) return token_logprobs # Compute logprobs for model completions under the model diff --git a/trl/trainer/online_dpo_config.py b/trl/trainer/online_dpo_config.py index d01294c2e5..12daa74a07 100644 --- a/trl/trainer/online_dpo_config.py 
+++ b/trl/trainer/online_dpo_config.py @@ -63,7 +63,11 @@ class OnlineDPOConfig(TrainingArguments): disable_dropout (`bool`, *optional*, defaults to `True`): Whether to disable dropout in the model and reference model. use_vllm (`bool`, *optional*, defaults to `False`): - Whether to use the vLLM for generating completions. Requires vLLM to be installed (`pip install vllm`). + Whether to use vLLM for generating completions. Requires vLLM to be installed (`pip install vllm`). + ds3_gather_for_generation (`bool`, *optional*, defaults to `True`): + This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation, + improving generation speed. However, disabling this option allows training models that exceed the VRAM + capacity of a single GPU, albeit at the cost of slower generation. """ learning_rate: float = field( @@ -114,8 +118,8 @@ class OnlineDPOConfig(TrainingArguments): metadata={ "help": "Parameter controlling the deviation from the reference model. Higher β means less deviation from " "the reference model. For the IPO loss (`loss_type='ipo'`), β is the regularization parameter denoted by " - "τ in the [paper](https://huggingface.co/papers/2310.12036). If a list of floats is provided then the β is " - "selected for each new epoch and the last β is used for the rest of the epochs." + "τ in the [paper](https://huggingface.co/papers/2310.12036). If a list of floats is provided then the β " + "is selected for each new epoch and the last β is used for the rest of the epochs." }, ) loss_type: str = field( @@ -136,10 +140,18 @@ class OnlineDPOConfig(TrainingArguments): use_vllm: bool = field( default=False, metadata={ - "help": "Whether to use the vLLM for generating completions. Requires vLLM to be installed " + "help": "Whether to use vLLM for generating completions. Requires vLLM to be installed " "(`pip install vllm`)." }, ) + ds3_gather_for_generation: bool = field( + default=True, + metadata={ + "help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for " + "generation, improving generation speed. However, disabling this option allows training models that " + "exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation." + }, + ) def __post_init__(self): super().__post_init__() diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index 7c7a6b3169..9abefc5140 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -262,7 +262,7 @@ def __init__( top_p=1.0, detokenize=False, # to avoid vllm to decode (we don't need it) ) - # vLLM dynamically adjusts the size of the key-value cache based on available GPU memory at instanciation. + # vLLM dynamically adjusts the size of the key-value cache based on available GPU memory at instantiation. # A larger cache size improves speed, so we would expect gpu_memory_utilization=1. # However, at this stage, the optimizer's weights are not yet loaded onto the GPU; they will be loaded # after the first optimizer step and remain in GPU memory throughout training. 
So we must reserve enough @@ -272,6 +272,7 @@ def __init__( gpu_memory_utilization=0.55, dtype=torch.float32, # When release by vLLM, we would be able to distribute the model on multiple GPUs + # See https://github.com/vllm-project/vllm/pull/12071 # tensor_parallel_size=torch.cuda.device_count(), # distributed_executor_backend="external_launcher", ) @@ -476,7 +477,9 @@ def _generate(self, model, prompts): inputs = self._prepare_inputs(inputs) prompt_ids = inputs["prompt_input_ids"].repeat(2, 1) prompt_mask = inputs["prompt_attention_mask"].repeat(2, 1) - with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model: + with unwrap_model_for_generation( + model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation + ) as unwrapped_model: output = unwrapped_model.generate( input_ids=prompt_ids, attention_mask=prompt_mask, diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py index 06457683e8..72436d321d 100644 --- a/trl/trainer/orpo_trainer.py +++ b/trl/trainer/orpo_trainer.py @@ -50,7 +50,6 @@ from transformers.trainer_callback import TrainerCallback from transformers.trainer_utils import EvalLoopOutput from transformers.utils import is_peft_available, is_torch_fx_proxy -from transformers.utils.deprecation import deprecate_kwarg from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt from ..models import PreTrainedModelWrapper @@ -65,6 +64,7 @@ log_table_to_comet_experiment, pad_to_length, peft_module_casting_to_bf16, + selective_log_softmax, ) @@ -119,9 +119,6 @@ class ORPOTrainer(Trainer): _tag_names = ["trl", "orpo"] - @deprecate_kwarg( - "tokenizer", "0.15.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True - ) def __init__( self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, @@ -722,7 +719,7 @@ def get_batch_logps( # dummy token; we'll ignore the losses on these tokens later labels = torch.where(labels == label_pad_token_id, 0, labels) - per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + per_token_logps = selective_log_softmax(logits, labels) if average_log_prob: return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) diff --git a/trl/trainer/ppo_config.py b/trl/trainer/ppo_config.py index ecaa7192d5..0b0ec0c318 100644 --- a/trl/trainer/ppo_config.py +++ b/trl/trainer/ppo_config.py @@ -53,6 +53,10 @@ class PPOConfig(OnPolicyConfig): Discount factor. lam (`float`, *optional*, defaults to `0.95`): Lambda value for GAE. + ds3_gather_for_generation (`bool`, *optional*, defaults to `True`): + This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation, + improving generation speed. However, disabling this option allows training models that exceed the VRAM + capacity of a single GPU, albeit at the cost of slower generation. """ exp_name: str = field( @@ -103,3 +107,11 @@ class PPOConfig(OnPolicyConfig): default=0.95, metadata={"help": "Lambda value for GAE."}, ) + ds3_gather_for_generation: bool = field( + default=True, + metadata={ + "help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for " + "generation, improving generation speed. However, disabling this option allows training models that " + "exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation." 
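The recurring change in these trainer hunks (KTO, Nash-MD, ORPO above; PPO and RLOO below) swaps the `torch.gather` over a full `log_softmax` for a `selective_log_softmax(logits, index)` utility from the trainer utils. A minimal sketch of what such a helper computes, assuming `(batch, seq, vocab)` logits and `(batch, seq)` token indices as at the call sites; the actual TRL implementation may be organized differently (for example, chunked for lower peak memory):

```python
import torch

def selective_log_softmax(logits: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    # log p(token) = logit_of_token - logsumexp(logits), gathered per position, so the full
    # (batch, seq, vocab) log-softmax tensor never needs to be materialized.
    selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
    return selected_logits - torch.logsumexp(logits, dim=-1)
```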
+ }, + ) diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py index fe3ea3a147..cf7f6768a6 100644 --- a/trl/trainer/ppo_trainer.py +++ b/trl/trainer/ppo_trainer.py @@ -25,7 +25,6 @@ import pandas as pd import torch import torch.nn as nn -import torch.nn.functional as F from accelerate import Accelerator from accelerate.utils import broadcast, gather_object from datasets import Dataset @@ -46,7 +45,6 @@ from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback from transformers.utils import is_peft_available -from transformers.utils.deprecation import deprecate_kwarg from ..core import masked_mean, masked_whiten from ..models import create_reference_model @@ -66,6 +64,7 @@ peft_module_casting_to_bf16, prepare_deepspeed, print_rich_table, + selective_log_softmax, truncate_response, ) @@ -98,14 +97,6 @@ def forward(self, **kwargs): class PPOTrainer(Trainer): _tag_names = ["trl", "ppo"] - @deprecate_kwarg("config", "0.15.0", "args", warn_if_greater_or_equal_version=True, raise_if_both_names=True) - @deprecate_kwarg( - "tokenizer", "0.15.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True - ) - @deprecate_kwarg("policy", "0.15.0", "model", warn_if_greater_or_equal_version=True, raise_if_both_names=True) - @deprecate_kwarg( - "ref_policy", "0.15.0", "ref_model", warn_if_greater_or_equal_version=True, raise_if_both_names=True - ) def __init__( self, args: PPOConfig, @@ -138,10 +129,18 @@ def __init__( if data_collator is None: data_collator = DataCollatorWithPadding(self.processing_class) - self.policy_model.generation_config.eos_token_id = ( - None # disable `pad_token_id` and `eos_token_id` because we just want to - ) - self.policy_model.generation_config.pad_token_id = None # generate tokens without truncation / padding + # Handle stop token settings: update policy model's generation_config to use provided stop token + if args.stop_token and args.stop_token_id: + raise ValueError("You cannot set both `stop_token` and `stop_token_id`.") + elif args.stop_token: + if args.stop_token == "eos": + self.policy_model.generation_config.eos_token_id = self.stop_token_id = processing_class.eos_token_id + else: + raise ValueError( + f"Unknown `stop_token` {args.stop_token}. Allowed values are: `'eos'` and `None` (no stop token)." 
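The stop-token handling added above resolves `stop_token` / `stop_token_id` into `self.stop_token_id` once at init time instead of mutating the config later. A hedged usage sketch, assuming these fields are exposed by the shared on-policy config as the checks above imply (the output directory is illustrative):

```python
from trl import PPOConfig

# Truncate sampled completions at the tokenizer's EOS token; the trainer resolves this to
# self.stop_token_id in __init__. Setting both stop_token and stop_token_id raises a ValueError.
training_args = PPOConfig(output_dir="Qwen2-0.5B-PPO", stop_token="eos")
```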
+ ) + else: + self.policy_model.generation_config.eos_token_id = self.stop_token_id = args.stop_token_id # None or int # peft support if not is_peft_available() and peft_config is not None: @@ -220,8 +219,6 @@ def __init__( for module in [self.policy_model, self.ref_model, self.value_model, self.reward_model]: if module is not None: disable_dropout_in_model(module) - if args.stop_token and args.stop_token == "eos": - args.stop_token_id = processing_class.eos_token_id self.model = PolicyAndValueWrapper(self.policy_model, self.value_model) self.model.config = self.policy_model.config # needed for pushing to hub self.create_optimizer_and_scheduler( @@ -313,9 +310,11 @@ def get_eval_dataloader(self) -> DataLoader: @contextmanager def null_ref_context(self): """Context manager for handling null reference model (that is, peft adapter manipulation).""" - with self.accelerator.unwrap_model( - self.model.policy - ).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext(): + with ( + self.accelerator.unwrap_model(self.model.policy).disable_adapter() + if self.is_peft_model and not self.ref_adapter_name + else nullcontext() + ): if self.ref_adapter_name: self.model.policy.set_adapter(self.ref_adapter_name) yield @@ -414,7 +413,9 @@ def repeat_generator(): scores = [] sequence_lengths = [] values = [] - with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model: + with unwrap_model_for_generation( + self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation + ) as unwrapped_model: query_responses, logitss = batch_generation( unwrapped_model.policy, queries, @@ -428,9 +429,8 @@ def repeat_generator(): query_response = query_responses[i : i + args.local_rollout_forward_batch_size] response = query_response[:, context_length:] logits = logitss[i : i + args.local_rollout_forward_batch_size] - all_logprob = F.log_softmax(logits, dim=-1) - logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) - del logits, all_logprob + logprob = selective_log_softmax(logits, response) + del logits torch.cuda.empty_cache() if ref_policy is None: @@ -440,16 +440,15 @@ def repeat_generator(): ref_output = forward(ref_policy, query_response, processing_class.pad_token_id) ref_logits = ref_output.logits[:, context_length - 1 : -1] ref_logits /= args.temperature + 1e-7 - ref_all_logprob = F.log_softmax(ref_logits, dim=-1) - ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) - del ref_output, ref_logits, ref_all_logprob + ref_logprob = selective_log_softmax(ref_logits, response) + del ref_output, ref_logits torch.cuda.empty_cache() # Response Processing 1. truncate response after the first occurrence of `stop_token_id` postprocessed_response = response - if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 + if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 postprocessed_response = truncate_response( - args.stop_token_id, processing_class.pad_token_id, response + self.stop_token_id, processing_class.pad_token_id, response ) # Response Processing 2. 
run reward model on the truncated responses @@ -548,8 +547,7 @@ def repeat_generator(): output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id) logits = output.logits[:, context_length - 1 : -1] logits /= args.temperature + 1e-7 - new_all_logprobs = F.log_softmax(logits, dim=-1) - new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) + new_logprobs = selective_log_softmax(logits, mb_responses) new_logprobs = torch.masked_fill( new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB ) @@ -600,7 +598,7 @@ def repeat_generator(): # del everything and empty cache # fmt: off del ( - output, vpred_temp, logits, new_all_logprobs, new_logprobs, vpred, vpredclipped, + output, vpred_temp, logits, new_logprobs, vpred, vpredclipped, vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max, pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return, mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs, @@ -688,7 +686,9 @@ def generate_completions(self, sampling: bool = False): ) table = defaultdict(list) - with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model: + with unwrap_model_for_generation( + self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation + ) as unwrapped_model: for batch in self.eval_dataloader: query = batch["input_ids"] with torch.no_grad(): @@ -702,9 +702,9 @@ def generate_completions(self, sampling: bool = False): ) response = query_response[:, context_length:] postprocessed_response = response - if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 + if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 postprocessed_response = truncate_response( - args.stop_token_id, processing_class.pad_token_id, response + self.stop_token_id, processing_class.pad_token_id, response ) table["query"].extend( gather_object(processing_class.batch_decode(query, skip_special_tokens=True)) diff --git a/trl/trainer/reward_config.py b/trl/trainer/reward_config.py index ba40e3f80a..f324637e52 100644 --- a/trl/trainer/reward_config.py +++ b/trl/trainer/reward_config.py @@ -29,8 +29,8 @@ class RewardConfig(TrainingArguments): Parameters: max_length (`int` or `None`, *optional*, defaults to `1024`): - Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want - to use the default data collator. + Maximum length of the sequences (prompt + completion) in the batch, filters out entries that exceed the + limit. This argument is required if you want to use the default data collator. disable_dropout (`bool`, *optional*, defaults to `True`): Whether to disable dropout in the model. dataset_num_proc (`int`, *optional*, defaults to `None`): @@ -46,8 +46,8 @@ class RewardConfig(TrainingArguments): max_length: Optional[int] = field( default=1024, metadata={ - "help": "Maximum length of the sequences (prompt + completion) in the batch. This argument is required if " - "you want to use the default data collator." + "help": "Maximum length of the sequences (prompt + completion) in the batch, filters out entries that " + "exceed the limit. This argument is required if you want to use the default data collator." 
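Per the reworded `max_length` help above, over-long pairs are now filtered out rather than merely capped. A minimal, hedged sketch (the output directory is illustrative):

```python
from trl import RewardConfig

# Entries whose prompt + completion exceed 1024 tokens are filtered out rather than truncated.
training_args = RewardConfig(output_dir="Qwen2-0.5B-Reward", max_length=1024)
```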
}, ) disable_dropout: bool = field( diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index ea25b425ab..063fe5e8e8 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -39,7 +39,6 @@ from transformers.trainer_pt_utils import nested_detach from transformers.trainer_utils import EvalPrediction from transformers.utils import is_peft_available -from transformers.utils.deprecation import deprecate_kwarg from ..data_utils import maybe_apply_chat_template from .reward_config import RewardConfig @@ -84,9 +83,6 @@ def _tokenize(batch: dict[str, list[Any]], tokenizer: "PreTrainedTokenizerBase") class RewardTrainer(Trainer): _tag_names = ["trl", "reward-trainer"] - @deprecate_kwarg( - "tokenizer", "0.15.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True - ) def __init__( self, model: Optional[Union[PreTrainedModel, nn.Module]] = None, diff --git a/trl/trainer/rloo_config.py b/trl/trainer/rloo_config.py index a52407c171..bd0b6ed8ed 100644 --- a/trl/trainer/rloo_config.py +++ b/trl/trainer/rloo_config.py @@ -50,6 +50,10 @@ class RLOOConfig(OnPolicyConfig): Whether to normalize advantages. token_level_kl (`bool`, *optional*, defaults to `True`): Whether to use token-level KL penalty or sequence-level KL penalty. + ds3_gather_for_generation (`bool`, *optional*, defaults to `True`): + This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation, + improving generation speed. However, disabling this option allows training models that exceed the VRAM + capacity of a single GPU, albeit at the cost of slower generation. """ exp_name: str = field( @@ -96,3 +100,11 @@ class RLOOConfig(OnPolicyConfig): default=False, metadata={"help": "Whether to use token-level KL penalty or sequence-level KL penalty"}, ) + ds3_gather_for_generation: bool = field( + default=True, + metadata={ + "help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for " + "generation, improving generation speed. However, disabling this option allows training models that " + "exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation." 
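The `ds3_gather_for_generation` flag added here (and to `OnlineDPOConfig` and `PPOConfig` above) is forwarded to `unwrap_model_for_generation(..., gather_deepspeed3_params=...)` in the corresponding trainers. A hedged sketch of opting out of weight gathering under DeepSpeed ZeRO-3 (the output directory is illustrative):

```python
from trl import RLOOConfig

# Keep the ZeRO-3 parameter shards in place during generation: sampling is slower, but a policy
# that exceeds a single GPU's VRAM can still be trained.
training_args = RLOOConfig(output_dir="Qwen2-0.5B-RLOO", ds3_gather_for_generation=False)
```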
+ }, + ) diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 321ba164be..344253c2b8 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -18,13 +18,12 @@ import textwrap import time from collections import defaultdict -from typing import Optional, Union +from typing import Callable, Optional, Union import numpy as np import pandas as pd import torch import torch.nn as nn -import torch.nn.functional as F from accelerate import Accelerator from accelerate.utils import broadcast, gather_object from datasets import Dataset @@ -56,6 +55,7 @@ get_reward, prepare_deepspeed, print_rich_table, + selective_log_softmax, truncate_response, ) from .rloo_config import RLOOConfig @@ -79,7 +79,7 @@ def __init__( ], policy: nn.Module, ref_policy: nn.Module, - reward_model: nn.Module, + reward_model: Union[nn.Module, Callable[[list[str]], list[float]]], train_dataset: Dataset, data_collator: Optional[DataCollatorWithPadding] = None, eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, @@ -152,7 +152,8 @@ def __init__( # setup model, optimizer, and others ######### for module in [policy, ref_policy, reward_model]: - disable_dropout_in_model(module) + if isinstance(module, nn.Module): + disable_dropout_in_model(module) if args.stop_token and args.stop_token == "eos": args.stop_token_id = self.processing_class.eos_token_id self.model = policy @@ -219,16 +220,18 @@ def __init__( self.eval_dataloader = accelerator.prepare(self.eval_dataloader) if self.is_deepspeed_enabled: - self.reward_model = prepare_deepspeed( - self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16 - ) + if isinstance(self.reward_model, nn.Module): + self.reward_model = prepare_deepspeed( + self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16 + ) self.ref_policy = prepare_deepspeed( self.ref_policy, args.per_device_train_batch_size, args.fp16, args.bf16 ) self.deepspeed = self.model else: self.ref_policy = self.ref_policy.to(self.accelerator.device) - self.reward_model = self.reward_model.to(self.accelerator.device) + if isinstance(self.reward_model, nn.Module): + self.reward_model = self.reward_model.to(self.accelerator.device) def get_train_dataloader(self) -> DataLoader: return self.dataloader @@ -310,7 +313,9 @@ def repeat_generator(): sequence_lengths = [] # Generate responses and compute logprobs - with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model: + with unwrap_model_for_generation( + self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation + ) as unwrapped_model: query_responses, logitss = batch_generation( unwrapped_model, queries, @@ -325,17 +330,15 @@ def repeat_generator(): query_response = query_responses[i : i + args.local_rollout_forward_batch_size] response = query_response[:, context_length:] logits = logitss[i : i + args.local_rollout_forward_batch_size] - all_logprob = F.log_softmax(logits, dim=-1) - logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) - del logits, all_logprob + logprob = selective_log_softmax(logits, response) + del logits torch.cuda.empty_cache() ref_output = forward(ref_policy, query_response, processing_class.pad_token_id) ref_logits = ref_output.logits[:, context_length - 1 : -1] ref_logits /= args.temperature + 1e-7 - ref_all_logprob = F.log_softmax(ref_logits, dim=-1) - ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) - del ref_output, ref_logits, ref_all_logprob + ref_logprob = 
selective_log_softmax(ref_logits, response) + del ref_output, ref_logits torch.cuda.empty_cache() # Response Processing 1. truncate response after the first occurrence of `stop_token_id` @@ -348,9 +351,18 @@ def repeat_generator(): # Response Processing 2. run reward model on the truncated responses postprocessed_query_response = torch.cat((query, postprocessed_response), 1) sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1 - _, score, _ = get_reward( - reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length - ) + + if isinstance(reward_model, nn.Module): + _, score, _ = get_reward( + reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length + ) + else: + score = torch.tensor( + reward_model( + processing_class.batch_decode(postprocessed_query_response, skip_special_tokens=True) + ), + dtype=torch.float, + ).to(device) # Store batch results responses.append(response) @@ -453,8 +465,7 @@ def repeat_generator(): logits /= args.temperature + 1e-7 # Compute new logprobs - new_all_logprobs = F.log_softmax(logits, dim=-1) - new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) + new_logprobs = selective_log_softmax(logits, mb_responses) new_logprobs = torch.masked_fill( new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB ) @@ -498,9 +509,8 @@ def repeat_generator(): # del everything and empty cache # fmt: off del ( - output, logits, new_all_logprobs, new_logprobs, - logprobs_diff, ratio, pg_losses, pg_losses2, - pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, + output, logits, new_logprobs, logprobs_diff, ratio, pg_losses, + pg_losses2, pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_advantage, mb_responses, mb_query_responses, mb_logprobs, ) # fmt: on @@ -565,7 +575,9 @@ def generate_completions(self, sampling: bool = False): ) table = defaultdict(list) - with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model: + with unwrap_model_for_generation( + self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation + ) as unwrapped_model: for batch in self.eval_dataloader: query = batch["input_ids"] with torch.no_grad(): @@ -591,9 +603,21 @@ def generate_completions(self, sampling: bool = False): ) postprocessed_query_response = torch.cat((query, postprocessed_response), 1) - _, score, _ = get_reward( - self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length - ) + + if isinstance(self.reward_model, nn.Module): + _, score, _ = get_reward( + self.reward_model, + postprocessed_query_response, + processing_class.pad_token_id, + context_length, + ) + else: + score = torch.tensor( + self.reward_model( + processing_class.batch_decode(postprocessed_query_response, skip_special_tokens=True) + ), + dtype=torch.float, + ).to(postprocessed_query_response.device) table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy()) if sampling: diff --git a/trl/trainer/sft_config.py b/trl/trainer/sft_config.py index 250eb74a0a..23b617dfe2 100644 --- a/trl/trainer/sft_config.py +++ b/trl/trainer/sft_config.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from dataclasses import dataclass, field from typing import Any, Optional @@ -23,103 +24,152 @@ class SFTConfig(TrainingArguments): r""" Configuration class for the [`SFTTrainer`]. 
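Returning to the RLOO changes above: `reward_model` is now typed as `Union[nn.Module, Callable[[list[str]], list[float]]]`, and when a callable is passed the trainer decodes the post-processed query + response and expects one float per text. A hedged sketch of such a rule-based reward (the keyword it rewards is arbitrary):

```python
# Maps decoded prompt + completion strings to scalar rewards, one per sequence.
def reward_fn(texts: list[str]) -> list[float]:
    return [1.0 if "thank you" in text.lower() else 0.0 for text in texts]

# Passed in place of a reward model, e.g. RLOOTrainer(..., reward_model=reward_fn, ...);
# the other constructor arguments are omitted here.
```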
+ Only the parameters specific to SFT training are listed here. For details on other parameters, refer to the + [`~transformers.TrainingArguments`] documentation. + Using [`~transformers.HfArgumentParser`] we can turn this class into [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the command line. Parameters: - dataset_text_field (`str`, *optional*, defaults to `"text"`): - Name of the text field of the dataset. If provided, the trainer will automatically create a - [`ConstantLengthDataset`] based on `dataset_text_field`. - packing (`bool`, *optional*, defaults to `False`): - Controls whether the [`ConstantLengthDataset`] packs the sequences of the dataset. - learning_rate (`float`, *optional*, defaults to `2e-5`): - Initial learning rate for [`AdamW`] optimizer. The default value replaces that of [`~transformers.TrainingArguments`]. - max_seq_length (`int` or `None`, *optional*, defaults to `None`): - Maximum sequence length for the [`ConstantLengthDataset`] and for automatically creating the dataset. If - `None`, it uses the smaller value between `tokenizer.model_max_length` and `1024`. - dataset_num_proc (`int` or `None`, *optional*, defaults to `None`): - Number of processes to use for processing the dataset. Only used when `packing=False`. - dataset_batch_size (`Union[int, None]`, *optional*, defaults to `1000`): - Number of examples to tokenize per batch. If `dataset_batch_size <= 0` or `dataset_batch_size is None`, - tokenizes the full dataset as a single batch. + > Parameters that control the model + model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`): - Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a - string. + Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model` + argument of the [`SFTTrainer`] is provided as a string. + use_liger (`bool`, *optional*, defaults to `False`): + Monkey patch the model with Liger kernels to increase throughput and reduce memory usage. + + > Parameters that control the data preprocessing + + dataset_text_field (`str`, *optional*, defaults to `"text"`): + Name of the column that contains text data in the dataset. dataset_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`): - Dictionary of optional keyword arguments to pass when creating packed or non-packed datasets. + Dictionary of optional keyword arguments for the dataset preparation. The only supported key is + `skip_prepare_dataset`. + dataset_num_proc (`int` or `None`, *optional*, defaults to `None`): + Number of processes to use for processing the dataset. + max_length (`int` or `None`, *optional*, defaults to `1024`): + Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from the right. + If `None`, no truncation is applied. When packing is enabled, this value sets the sequence length. + packing (`bool`, *optional*, defaults to `False`): + Whether to pack multiple sequences into a fixed-length format. Uses `max_length` to define sequence length. eval_packing (`bool` or `None`, *optional*, defaults to `None`): Whether to pack the eval dataset. If `None`, uses the same value as `packing`. - num_of_sequences (`int`, *optional*, defaults to `1024`): - Number of sequences to use for the [`ConstantLengthDataset`]. - chars_per_token (`float`, *optional*, defaults to `3.6`): - Number of characters per token to use for the [`ConstantLengthDataset`]. 
See - [chars_token_ratio](https://github.com/huggingface/trl/blob/08f550674c553c36c51d1027613c29f14f3676a5/examples/stack_llama/scripts/supervised_finetuning.py#L53) for more details. - use_liger (`bool`, *optional*, defaults to `False`): - Monkey patch the model with Liger kernels to increase throughput and reduce memory usage. + + > Parameters that control the training + + learning_rate (`float`, *optional*, defaults to `2e-5`): + Initial learning rate for [`AdamW`] optimizer. The default value replaces that of + [`~transformers.TrainingArguments`]. """ - dataset_text_field: str = field( - default="text", + # Parameters that control the model + model_init_kwargs: Optional[dict[str, Any]] = field( + default=None, metadata={ - "help": "Name of the text field of the dataset. If provided, the trainer will automatically create a " - "`ConstantLengthDataset` based on `dataset_text_field`." + "help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of " + "the `SFTTrainer` is provided as a string." }, ) - packing: bool = field( + use_liger: bool = field( default=False, - metadata={"help": "Controls whether the `ConstantLengthDataset` packs the sequences of the dataset."}, + metadata={"help": "Monkey patch the model with Liger kernels to increase throughput and reduce memory usage."}, ) - learning_rate: float = field( - default=2.0e-5, - metadata={ - "help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of " - "`TrainingArguments`." - }, + + # Parameters that control the data preprocessing + dataset_text_field: str = field( + default="text", + metadata={"help": "Name of the column that contains text data in the dataset."}, ) - max_seq_length: Optional[int] = field( + dataset_kwargs: Optional[dict[str, Any]] = field( default=None, metadata={ - "help": "Maximum sequence length for the `ConstantLengthDataset` and for automatically creating the " - "dataset. If `None`, it uses the smaller value between `tokenizer.model_max_length` and `1024`." + "help": "Dictionary of optional keyword arguments for the dataset preparation. The only supported key is " + "`skip_prepare_dataset`." }, ) dataset_num_proc: Optional[int] = field( default=None, - metadata={"help": "Number of processes to use for processing the dataset. Only used when `packing=False`."}, + metadata={"help": "Number of processes to use for processing the dataset."}, ) - dataset_batch_size: int = field( - default=1000, + max_length: Optional[int] = field( + default=1024, metadata={ - "help": "Number of examples to tokenize per batch. If `dataset_batch_size <= 0` or `dataset_batch_size is " - "None`, tokenizes the full dataset as a single batch." + "help": "Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from" + "the right. If `None`, no truncation is applied. When packing is enabled, this value sets the " + "sequence length." }, ) - model_init_kwargs: Optional[dict[str, Any]] = field( - default=None, + packing: bool = field( + default=False, metadata={ - "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model " - "from a string." + "help": "Whether to pack multiple sequences into a fixed-length format. Uses `max_length` to define " + "sequence length." }, ) - dataset_kwargs: Optional[dict[str, Any]] = field( + eval_packing: Optional[bool] = field( default=None, + metadata={"help": "Whether to pack the eval dataset. 
If `None`, uses the same value as `packing`."}, + ) + + # Parameters that control the training + learning_rate: float = field( + default=2.0e-5, metadata={ - "help": "Dictionary of optional keyword arguments to pass when creating packed or non-packed datasets." + "help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of " + "`TrainingArguments`." }, ) - eval_packing: Optional[bool] = field( + + # Deprecated parameters + dataset_batch_size: int = field( default=None, - metadata={"help": "Whether to pack the eval dataset. If `None`, uses the same value as `packing`."}, + metadata={"help": "Deprecated. You can safely remove this parameter from your configuration."}, ) num_of_sequences: int = field( - default=1024, - metadata={"help": "Number of sequences to use for the `ConstantLengthDataset`."}, + default=None, + metadata={ + "help": "Deprecated. Use `max_length` instead, which specifies the maximum length of the tokenized " + "sequence, unlike `num_of_sequences`, which referred to string sequences." + }, ) chars_per_token: float = field( - default=3.6, metadata={"help": "Number of characters per token to use for the `ConstantLengthDataset`."} + default=None, + metadata={"help": "Deprecated. If you want to customize the packing length, use `max_length`."}, ) - use_liger: bool = field( - default=False, - metadata={"help": "Monkey patch the model with Liger kernels to increase throughput and reduce memory usage."}, + max_seq_length: Optional[int] = field( + default=None, + metadata={"help": "Deprecated. Use `max_length` instead."}, ) + + def __post_init__(self): + super().__post_init__() + + if self.dataset_batch_size is not None: + warnings.warn( + "`dataset_batch_size` is deprecated and will be remove in version 0.18.0. You can safely remove this " + "parameter from your configuration.", + DeprecationWarning, + ) + + if self.num_of_sequences is not None: + warnings.warn( + "`num_of_sequences` is deprecated and will be remove in version 0.18.0. Use `max_length` instead, " + "which specifies the maximum length of the tokenized sequence, unlike `num_of_sequences`, which r" + "eferred to string sequences.", + DeprecationWarning, + ) + + if self.chars_per_token is not None: + warnings.warn( + "`chars_per_token` is deprecated and will be remove in version 0.18.0. If you want to customize the " + "packing length, use `max_length`.", + DeprecationWarning, + ) + + if self.max_seq_length is not None: + warnings.warn( + "`max_seq_length` is deprecated and will be remove in version 0.20.0. Use `max_length` instead.", + DeprecationWarning, + ) diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 9e2e5fe04f..b0104f4b53 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -13,18 +13,17 @@ # limitations under the License. 
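Summarizing the `SFTConfig` changes above: the `ConstantLengthDataset`-era knobs (`dataset_batch_size`, `num_of_sequences`, `chars_per_token`, `max_seq_length`) survive only as deprecated fields that warn in `__post_init__`, while `max_length` now controls both truncation and, with `packing=True`, the packed sequence length. A hedged sketch of the new-style configuration (the output directory is illustrative):

```python
from trl import SFTConfig

# max_length replaces the deprecated max_seq_length; with packing=True it also sets the packed length.
training_args = SFTConfig(output_dir="Qwen2-0.5B-SFT", max_length=1024, packing=True)
```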
import dataclasses -import inspect import os import warnings -from typing import Callable, Optional, Union +from collections import defaultdict +from typing import Any, Callable, Optional, Type, Union -import datasets import torch import torch.nn as nn -from accelerate.state import PartialState -from datasets import Dataset -from datasets.arrow_writer import SchemaInferenceError -from datasets.builder import DatasetGenerationError +import transformers +from accelerate import PartialState +from datasets import Dataset, IterableDataset +from packaging import version from transformers import ( AutoModelForCausalLM, AutoTokenizer, @@ -36,25 +35,20 @@ PreTrainedTokenizerBase, ProcessorMixin, Trainer, + TrainingArguments, is_wandb_available, ) from transformers.trainer_callback import TrainerCallback from transformers.trainer_utils import EvalPrediction from transformers.utils import is_liger_kernel_available, is_peft_available -from transformers.utils.deprecation import deprecate_kwarg -from ..extras.dataset_formatting import get_formatting_func_from_dataset +from ..data_utils import is_conversational, maybe_apply_chat_template, maybe_convert_to_chatml, pack_examples from .sft_config import SFTConfig -from .utils import ( - ConstantLengthDataset, - DataCollatorForCompletionOnlyLM, - generate_model_card, - get_comet_experiment_url, - peft_module_casting_to_bf16, -) +from .utils import ConstantLengthDataset, generate_model_card, get_comet_experiment_url, peft_module_casting_to_bf16 if is_peft_available(): + import peft from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training if is_liger_kernel_available(): @@ -65,245 +59,173 @@ class SFTTrainer(Trainer): - r""" - Class definition of the Supervised Finetuning Trainer (SFT Trainer). - This class is a wrapper around the `transformers.Trainer` class and inherits all of its attributes and methods. - The trainer takes care of properly initializing the PeftModel in case a user passes a `PeftConfig` object. + """ + Trainer for Supervised Fine-Tuning (SFT) method. + + This class is a wrapper around the [`transformers.Trainer`] class and inherits all of its attributes and methods. + + Example: + + ```python + from datasets import load_dataset + from trl import SFTTrainer + + dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]") + + trainer = SFTTrainer(model="Qwen/Qwen2-0.5B-Instruct", train_dataset=dataset) + trainer.train() + ``` Args: - model (Union[`transformers.PreTrainedModel`, `nn.Module`, `str`]): - The model to train, can be a `PreTrainedModel`, a `torch.nn.Module` or a string with the model name to - load from cache or download. The model can be also converted to a `PeftModel` if a `PeftConfig` object is - passed to the `peft_config` argument. - args (`Optional[SFTConfig]`): - The arguments to tweak for training. Will default to a basic instance of [`SFTConfig`] with the `output_dir` - set to a directory named *tmp_trainer* in the current directory if not provided. - data_collator (`Optional[transformers.DataCollator]`): - The data collator to use for training. - train_dataset (`Optional[datasets.Dataset]`): - The dataset to use for training. We recommend users to use `trl.trainer.ConstantLengthDataset` to create their dataset. - eval_dataset (Optional[Union[`datasets.Dataset`, dict[`str`, `datasets.Dataset`]]]): - The dataset to use for evaluation. We recommend users to use `trl.trainer.ConstantLengthDataset` to create their dataset. 
- processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*): - Processing class used to process the data. If provided, will be used to automatically process the inputs - for the model, and it will be saved along the model to make it easier to rerun an interrupted training or - reuse the fine-tuned model. - This supercedes the `tokenizer` argument, which is now deprecated. - model_init (`Callable[[], transformers.PreTrainedModel]`): - The model initializer to use for training. If None is specified, the default model initializer will be used. - compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional* defaults to None): - The function used to compute metrics during evaluation. It should return a dictionary mapping metric names to metric values. - If not specified, only the loss will be computed during evaluation. - callbacks (`list[transformers.TrainerCallback]`): - The callbacks to use for training. - optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): - The optimizer and scheduler to use for training. - preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): - The function to use to preprocess the logits before computing the metrics. - peft_config (`Optional[PeftConfig]`): - The PeftConfig object to use to initialize the PeftModel. + model (`Union[str, PreTrainedModel]`): + Model to be trained. Can be either: + + - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or + a path to a *directory* containing model weights saved using + [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is + loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments + in `args.model_init_kwargs`. + - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported. + args ([`SFTConfig`], *optional*, defaults to `None`): + Configuration for this trainer. If `None`, a default configuration is used. + data_collator (`DataCollator`, *optional*): + Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`. + Will default to [`~transformers.default_data_collator`] if no `processing_class` is provided, an instance + of [`~transformers.DataCollatorWithPadding`] otherwise if the processing_class is a feature extractor or + tokenizer. + train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): + Dataset to use for training. SFT supports both [language modeling](#language-modeling) type and + [prompt-completion](#prompt-completion) type. The format of the samples can be either: + + - [Standard](dataset_formats#standard): Each sample contains plain text. + - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role + and content). + + The trainer also supports processed datasets (tokenized) as long as they contain an `input_ids` field. + eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`): + Dataset to use for evaluation. It must meet the same requirements as `train_dataset`. + processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`): + Processing class used to process the data. If `None`, the processing class is loaded from the model's name + with [`~transformers.AutoTokenizer.from_pretrained`].
+ callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`): + List of callbacks to customize the training loop. Will add those to the list of default callbacks + detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback). + + If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`] + method. + optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`): + A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your + model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`. + optimizer_cls_and_kwargs (`Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*, defaults to `None`): + A tuple containing the optimizer class and keyword arguments to use. + Overrides `optim` and `optim_args` in `args`. Incompatible with the `optimizers` argument. + + Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before initializing the Trainer. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*, defaults to `None`): + A function that preprocesses the logits right before caching them at each evaluation step. Must take two + tensors, the logits and the labels, and return the logits once processed as desired. The modifications made + by this function will be reflected in the predictions received by `compute_metrics`. + + Note that the labels (second parameter) will be `None` if the dataset does not have them. + peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`): + PEFT configuration used to wrap the model. If `None`, the model is not wrapped. formatting_func (`Optional[Callable]`): - The formatting function to be used for creating the `ConstantLengthDataset`. + Formatting function applied to the dataset before tokenization.
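A usage sketch for the `formatting_func` argument documented above; the column names are hypothetical, and the rewritten `_prepare_dataset` further below maps the function over the dataset and stores the result in a `"text"` column:

```python
from datasets import Dataset
from trl import SFTTrainer

# Tiny illustrative dataset with hypothetical column names.
dataset = Dataset.from_dict({"question": ["What is 2 + 2?"], "answer": ["4"]})

def formatting_func(example):
    return f"### Question: {example['question']}\n### Answer: {example['answer']}"

trainer = SFTTrainer(model="Qwen/Qwen2-0.5B-Instruct", train_dataset=dataset, formatting_func=formatting_func)
```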
""" _tag_names = ["trl", "sft"] - @deprecate_kwarg( - "tokenizer", "0.16.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True - ) def __init__( self, - model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, - args: Optional[SFTConfig] = None, + model: Union[str, nn.Module, PreTrainedModel], + args: Optional[Union[SFTConfig, TrainingArguments]] = None, data_collator: Optional[DataCollator] = None, # type: ignore - train_dataset: Optional[Dataset] = None, + train_dataset: Optional[Union[Dataset, IterableDataset]] = None, eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, processing_class: Optional[ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] ] = None, - model_init: Optional[Callable[[], PreTrainedModel]] = None, + compute_loss_func: Optional[Callable] = None, compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None, callbacks: Optional[list[TrainerCallback]] = None, - optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), + optimizer_cls_and_kwargs: Optional[tuple[Type[torch.optim.Optimizer], dict[str, Any]]] = None, preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, peft_config: Optional["PeftConfig"] = None, - formatting_func: Optional[Callable] = None, + formatting_func: Optional[Union[Callable[[dict], str], Callable[[dict], list[str]]]] = None, ): + # Args if args is None: - args = SFTConfig(output_dir="tmp_trainer") - elif args is not None and args.__class__.__name__ == "TrainingArguments": - args_as_dict = args.to_dict() - # Manually copy token values as TrainingArguments.to_dict() redacts them - args_as_dict.update({k: getattr(args, k) for k in args_as_dict.keys() if k.endswith("_token")}) - args = SFTConfig(**args_as_dict) - - if getattr(args, "model_init_kwargs", None) is None: - model_init_kwargs = {} - elif not isinstance(model, str): - raise ValueError("You passed model_init_kwargs to the SFTConfig, but your model is already instantiated.") - else: - model_init_kwargs = args.model_init_kwargs - torch_dtype = model_init_kwargs.get("torch_dtype") - if torch_dtype is not None: - # Convert to `torch.dtype` if an str is passed - if isinstance(torch_dtype, str) and torch_dtype != "auto": - torch_dtype = getattr(torch, torch_dtype) - if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype): - raise ValueError( - f"Invalid `torch_dtype` passed to the SFTConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}." - ) - model_init_kwargs["torch_dtype"] = torch_dtype - - if isinstance(model, str): - if args.use_liger: - model = AutoLigerKernelForCausalLM.from_pretrained(model, **model_init_kwargs) - else: - model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) - - if args.packing and data_collator is not None and isinstance(data_collator, DataCollatorForCompletionOnlyLM): - raise ValueError( - "You passed a `DataCollatorForCompletionOnlyLM` to the SFTTrainer. This is not compatible with the `packing` argument." 
+ model_name = model if isinstance(model, str) else model.config._name_or_path + model_name = model_name.split("/")[-1] + args = SFTConfig(f"{model_name}-SFT") + elif isinstance(args, TrainingArguments) and not isinstance(args, SFTConfig): + dict_args = args.to_dict() + dict_args["hub_token"] = args.hub_token # to_dict hides the hub_token + dict_args.pop("push_to_hub_token") + args = SFTConfig(**dict_args) + + # Model + if args.model_init_kwargs is not None and not isinstance(model, str): + warnings.warn( + "You passed model_init_kwargs to the `SFTConfig`, but your model is already instantiated. " + "The `model_init_kwargs` will be ignored." ) + if isinstance(model, str): + model = self._create_model_from_path(model, args) + self.use_liger = is_liger_kernel_available() and isinstance(model, AutoLigerKernelForCausalLM) - if is_peft_available() and peft_config is not None: - if not isinstance(peft_config, PeftConfig): - raise ValueError( - "If you want to use the PeftModel, you need to pass a PeftConfig object to the SFTTrainer." - f" and you passed a {type(peft_config)}." - ) - - if not isinstance(model, PeftModel): - _support_gc_kwargs = hasattr( - args, "gradient_checkpointing_kwargs" - ) and "gradient_checkpointing_kwargs" in list( - inspect.signature(prepare_model_for_kbit_training).parameters - ) - gradient_checkpointing_kwargs = getattr(args, "gradient_checkpointing_kwargs", None) or {} - is_sharded_qlora = False - # Below is to support QLoRA + FSDP / DS-Zero3 - one should never call - # peft_module_casting_to_bf16 or prepare_model_for_kbit_training when doing - # QLoRA + FSDP / DS-Zero3 - if getattr(model, "is_loaded_in_4bit", False): - for _, param in model.named_parameters(): - if param.__class__.__name__ == "Params4bit": - is_sharded_qlora = param.data.device.type in {"cpu", "meta"} - break - if getattr(model, "is_loaded_in_8bit", False) or ( - getattr(model, "is_loaded_in_4bit", False) and not is_sharded_qlora - ): - prepare_model_kwargs = { - "use_gradient_checkpointing": getattr(args, "gradient_checkpointing", False) - } - - if _support_gc_kwargs: - prepare_model_kwargs["gradient_checkpointing_kwargs"] = gradient_checkpointing_kwargs - - model = prepare_model_for_kbit_training(model, **prepare_model_kwargs) - - if args is not None: - args = dataclasses.replace(args, gradient_checkpointing=False) - elif getattr(args, "gradient_checkpointing", False) and ( - "use_reentrant" not in gradient_checkpointing_kwargs - or gradient_checkpointing_kwargs["use_reentrant"] - ): - # For backward compatibility with older versions of transformers - if hasattr(model, "enable_input_require_grads"): - model.enable_input_require_grads() - else: - - def make_inputs_require_grad(module, input, output): - output.requires_grad_(True) - - model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) - - if ( - "autocast_adapter_dtype" in list(inspect.signature(get_peft_model).parameters) - and getattr(model, "is_loaded_in_4bit", False) - and is_sharded_qlora - ): - model = get_peft_model(model, peft_config, autocast_adapter_dtype=False) - else: - model = get_peft_model(model, peft_config) - if ( - args is not None - and args.bf16 - and getattr(model, "is_loaded_in_4bit", False) - and not is_sharded_qlora - ): - peft_module_casting_to_bf16(model) + # PEFT configuration and model wrapping + if peft_config is not None: + model = self._prepare_peft_model(model, peft_config, args) + # Handle the tokenizer if processing_class is None: processing_class = 
AutoTokenizer.from_pretrained(model.config._name_or_path) - if getattr(processing_class, "pad_token", None) is None: - processing_class.pad_token = processing_class.eos_token - - if args.max_seq_length is None: - # to overcome some issues with broken tokenizers - args.max_seq_length = min(processing_class.model_max_length, 1024) - - self.dataset_num_proc = args.dataset_num_proc - self.dataset_batch_size = args.dataset_batch_size - - if args.dataset_kwargs is None: - args.dataset_kwargs = {} - - if formatting_func is None: - # check if dataset has ChatML format or instruction format and is supported - # if not stays None - formatting_func = get_formatting_func_from_dataset(train_dataset, processing_class) - # if a template is detected, we don't need to add special tokens again - if formatting_func is not None: - args.dataset_kwargs["add_special_tokens"] = False - - if not args.packing: - if data_collator is None: - data_collator = DataCollatorForLanguageModeling(tokenizer=processing_class, mlm=False) - - # Pre-process the datasets only once per node. The remaining processes will use the cache. - with PartialState().local_main_process_first(): - if train_dataset is not None: - train_dataset = self._prepare_dataset( - train_dataset, - processing_class, - args.packing, - args.dataset_text_field, - args.max_seq_length, - formatting_func, - args.num_of_sequences, - args.chars_per_token, - remove_unused_columns=args.remove_unused_columns if args is not None else True, - **args.dataset_kwargs, - ) + if processing_class.pad_token is None: + processing_class.pad_token = processing_class.eos_token # required for padding when collating data + + # Dataset + preprocess_dataset = args.dataset_kwargs is None or not args.dataset_kwargs.get("skip_prepare_dataset", False) + if preprocess_dataset: + train_dataset = self._prepare_dataset( + train_dataset, processing_class, args, args.packing, formatting_func, "train" + ) if eval_dataset is not None: - _multiple = isinstance(eval_dataset, dict) - _eval_datasets = eval_dataset if _multiple else {"singleton": eval_dataset} - - eval_packing = args.packing if args.eval_packing is None else args.eval_packing - - for _eval_dataset_name, _eval_dataset in _eval_datasets.items(): - _eval_datasets[_eval_dataset_name] = self._prepare_dataset( - _eval_dataset, - processing_class, - eval_packing, - args.dataset_text_field, - args.max_seq_length, - formatting_func, - args.num_of_sequences, - args.chars_per_token, - remove_unused_columns=args.remove_unused_columns if args is not None else True, - **args.dataset_kwargs, + packing = args.packing if args.eval_packing is None else args.eval_packing + if isinstance(eval_dataset, dict): + eval_dataset = { + key: self._prepare_dataset(dataset, processing_class, args, packing, formatting_func, key) + for key, dataset in eval_dataset.items() + } + else: + eval_dataset = self._prepare_dataset( + eval_dataset, processing_class, args, packing, formatting_func, "eval" ) - if not _multiple: - eval_dataset = _eval_datasets["singleton"] - - if processing_class.padding_side is not None and processing_class.padding_side != "right": - warnings.warn( - "You passed a processing_class with `padding_side` not equal to `right` to the SFTTrainer. This might " - "lead to some unexpected behaviour due to overflow issues when training a model in half-precision. 
" - "You might consider adding `processing_class.padding_side = 'right'` to your code.", - UserWarning, - ) + # Data collator + if data_collator is None: + data_collator = DataCollatorForLanguageModeling(tokenizer=processing_class, mlm=False) + + # Initialize the metrics + self._metrics = defaultdict(list) + + # Initialize the Trainer. Parent class will handle: + # - DeepSpeed configuration (through create_accelerator_and_postprocess) + # - FSDP setup + # - Distributed training setup + # - Optimizer and scheduler creation + # Some arguments are only available for transformers>=4.47.0. Can be removed when the min version is bumped. + super_init_kwargs = {} + if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): + super_init_kwargs["optimizer_cls_and_kwargs"] = optimizer_cls_and_kwargs + else: + if optimizer_cls_and_kwargs is not None: + warnings.warn( + "The `optimizer_cls_and_kwargs` argument is only available for `transformers>=4.47.0`. " + "The default optimizer will be used. " + "Remove the `optimizer_cls_and_kwargs` or upgrade to `transformers>=4.47.0`." + ) super().__init__( model=model, args=args, @@ -311,196 +233,274 @@ def make_inputs_require_grad(module, input, output): train_dataset=train_dataset, eval_dataset=eval_dataset, processing_class=processing_class, - model_init=model_init, + compute_loss_func=compute_loss_func, compute_metrics=compute_metrics, callbacks=callbacks, optimizers=optimizers, preprocess_logits_for_metrics=preprocess_logits_for_metrics, + **super_init_kwargs, ) # Add tags for models that have been loaded with the correct transformers version if hasattr(self.model, "add_model_tags"): self.model.add_model_tags(self._tag_names) - if self.train_dataset is not None: - if self.args.max_steps > 0 and args.packing: - self.train_dataset.infinite = True - elif self.args.max_steps == -1 and args.packing: - self.train_dataset.infinite = False + def _create_model_from_path(self, model_path: str, args: SFTConfig) -> PreTrainedModel: + """Creates a model from a path or model identifier.""" + model_init_kwargs = args.model_init_kwargs or {} + # Handle torch dtype + torch_dtype = model_init_kwargs.get("torch_dtype") + if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None: + pass # torch_dtype is already a torch.dtype or "auto" or None + elif isinstance(torch_dtype, str): # it's a str, but not "auto" + torch_dtype = getattr(torch, torch_dtype) + model_init_kwargs["torch_dtype"] = torch_dtype + else: + raise ValueError( + "Invalid `torch_dtype` passed to `SFTConfig`. Expected either 'auto' or a string representing " + f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}." 
+ ) + # Disable caching if gradient checkpointing is enabled (not supported) + if args.gradient_checkpointing: + model_init_kwargs["use_cache"] = False + + # Create model + if args.use_liger: + if not is_liger_kernel_available(): + raise ImportError("Please install Liger-kernel for use_liger=True") + model = AutoLigerKernelForCausalLM.from_pretrained(model_path, **model_init_kwargs) + else: + model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs) + return model + + def _prepare_peft_model(self, model: PreTrainedModel, peft_config: Any, args: SFTConfig) -> PreTrainedModel: + """Prepares a model for PEFT training.""" + if not is_peft_available(): + raise ImportError("To use PeftModel, you need to install the `peft` library.") + + if not isinstance(peft_config, PeftConfig): + raise ValueError( + f"Expected PeftConfig object but got {type(peft_config)}. If you want to use the PeftModel, you need " + "to pass a PeftConfig object to the SFTTrainer." + ) + + if isinstance(model, PeftModel): + return model + + # Handle quantized models (QLoRA) + is_qlora = getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False) + + is_sharded_qlora = False + if getattr(model, "is_loaded_in_4bit", False): + # Check if model is sharded (FSDP/DS-Zero3) + for _, param in model.named_parameters(): + if param.__class__.__name__ == "Params4bit": + is_sharded_qlora = param.data.device.type in {"cpu", "meta"} + break + + # Prepare model for kbit training if needed + if is_qlora and not is_sharded_qlora: + model = self._prepare_model_for_kbit_training(model, args) + # Disable gradient checkpointing as it's handled by prepare_model_for_kbit_training + args = dataclasses.replace(args, gradient_checkpointing=False) + elif args.gradient_checkpointing: + model = self._enable_gradient_checkpointing(model, args) + + # Create PEFT model + if ( + version.parse(peft.__version__) >= version.parse("0.12") # autocast_adapter_dtype introduced in 0.12 + and getattr(model, "is_loaded_in_4bit", False) + and is_sharded_qlora + ): + model = get_peft_model(model, peft_config, autocast_adapter_dtype=False) + else: + model = get_peft_model(model, peft_config) + + # Handle bf16 casting for 4-bit models + if args.bf16 and getattr(model, "is_loaded_in_4bit", False) and not is_sharded_qlora: + peft_module_casting_to_bf16(model) + + return model + + def _prepare_model_for_kbit_training(self, model: PreTrainedModel, args: SFTConfig) -> PreTrainedModel: + """Prepares a quantized model for kbit training.""" + prepare_model_kwargs = { + "use_gradient_checkpointing": args.gradient_checkpointing, + "gradient_checkpointing_kwargs": args.gradient_checkpointing_kwargs or {}, + } + + return prepare_model_for_kbit_training(model, **prepare_model_kwargs) + + def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: SFTConfig) -> PreTrainedModel: + """Enables gradient checkpointing for the model.""" + gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {} + use_reentrant = ( + "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"] + ) + + if use_reentrant: + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + return model def _prepare_dataset( self, - dataset, - processing_class, - packing, - dataset_text_field: str, - 
max_seq_length, - formatting_func: Optional[Callable], - num_of_sequences, - chars_per_token, - remove_unused_columns=True, - append_concat_token=True, - add_special_tokens=True, - skip_prepare_dataset=False, - ): - if dataset is None: - raise ValueError("The dataset should not be None") - - if skip_prepare_dataset: + dataset: Union[Dataset, IterableDataset], + processing_class: Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin], + args: SFTConfig, + packing: bool, + formatting_func: Optional[Callable[[dict], str]], + dataset_name: str, + ) -> Union[Dataset, IterableDataset]: + # Convert the dataset to an IterableDataset if it is a ConstantLengthDataset + if isinstance(dataset, ConstantLengthDataset): return dataset - # If the dataset is already preprocessed (tokenized), return as-is. Only works if dataset is - # a datasets.Dataset or datasets.IterableDataset -- not for torch Dataset - column_names = ( - dataset.column_names if isinstance(dataset, (datasets.Dataset, datasets.IterableDataset)) else None - ) - if column_names and "input_ids" in column_names: - if formatting_func is not None: + # If the dataset is already preprocessed (tokenized), skip the processing steps. + column_names = list(next(iter(dataset)).keys()) + is_processed = "input_ids" in column_names + + # Build the kwargs for the `map` function + map_kwargs = {} + if isinstance(dataset, Dataset): # IterableDataset does not support num_proc + map_kwargs["num_proc"] = args.dataset_num_proc + + with PartialState().local_main_process_first(): + # Apply the formatting function if any + if formatting_func is not None and is_processed: warnings.warn( "You passed a dataset that is already processed (contains an `input_ids` field) together with a " - "valid formatting function. Therefore `formatting_func` will be ignored. Either remove the " + "formatting function. Therefore `formatting_func` will be ignored. 
Either remove the " "`formatting_func` or pass a dataset that is not already processed.", UserWarning, ) - def formatting_func(x): - return x["input_ids"] + if formatting_func is not None and not is_processed: + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Applying formatting function to {dataset_name} dataset" - if not packing: - return dataset + batched = isinstance(formatting_func(next(iter(dataset))), list) - # check if torch dataset / dataloader and do nothing - # see https://github.com/huggingface/trl/pull/1468 for why datasets.IterableDataset needs a separate check - if isinstance( - dataset, (torch.utils.data.IterableDataset, torch.utils.data.Dataset, ConstantLengthDataset) - ) and not isinstance(dataset, datasets.IterableDataset): - return dataset + def _func(example): + return {"text": formatting_func(example)} - if not packing: - return self._prepare_non_packed_dataloader( - processing_class, - dataset, - dataset_text_field, - max_seq_length, - formatting_func, - add_special_tokens, - remove_unused_columns, - ) + dataset = dataset.map(_func, batched=batched, **map_kwargs) - else: - return self._prepare_packed_dataloader( - processing_class, - dataset, - dataset_text_field, - max_seq_length, - num_of_sequences, - chars_per_token, - formatting_func, - append_concat_token, - add_special_tokens, + # If the dataset is prompt-completion, convert it to language modeling type + if "prompt" in dataset.column_names and "completion" in dataset.column_names: + key = "messages" if is_conversational(dataset[0]) else "text" + + def concat_prompt_completion(example): + return {key: example["prompt"] + example["completion"]} + + dataset = dataset.map(concat_prompt_completion, remove_columns=["prompt", "completion"]) + + # Convert the dataset to ChatML if needed + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Converting {dataset_name} dataset to ChatML" + dataset = dataset.map( + maybe_convert_to_chatml, + remove_columns="conversations" if "conversations" in dataset.column_names else None, + **map_kwargs, ) - def _prepare_non_packed_dataloader( - self, - processing_class, - dataset, - dataset_text_field: str, - max_seq_length, - formatting_func: Optional[Callable] = None, - add_special_tokens=True, - remove_unused_columns=True, - ): - # Inspired from: https://huggingface.co/learn/nlp-course/chapter7/6?fw=pt - def tokenize(element): - outputs = processing_class( - element[dataset_text_field] if formatting_func is None else formatting_func(element), - add_special_tokens=add_special_tokens, - truncation=True, - padding=False, - max_length=max_seq_length, - return_overflowing_tokens=False, - return_length=False, + # Apply the chat template if needed + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset" + dataset = dataset.map( + maybe_apply_chat_template, + fn_kwargs={"tokenizer": processing_class}, + remove_columns="messages" if "messages" in dataset.column_names else None, # renamed to "text" + **map_kwargs, ) - if formatting_func is not None and not isinstance(formatting_func(element), list): - raise ValueError( - "The `formatting_func` should return a list of processed strings since it can lead to silent bugs." 
+ # Tokenize the dataset if needed + if not is_processed: + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset" + + def tokenize(ex): + tokenized = processing_class(ex[args.dataset_text_field]) + return {"input_ids": tokenized["input_ids"], "attention_mask": tokenized["attention_mask"]} + + dataset = dataset.map(tokenize, **map_kwargs) + + # Pack or truncate + if packing: + if args.max_length is None: + raise ValueError("When packing is enabled, `max_length` can't be `None`.") + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Packing {dataset_name} dataset" + dataset = dataset.select_columns("input_ids") + dataset = dataset.map( + pack_examples, batched=True, fn_kwargs={"seq_length": args.max_length}, **map_kwargs ) + elif args.max_length is not None: + dataset = dataset.map( + lambda ex: {key: ex[key][: args.max_length] for key in ["input_ids", "attention_mask"]}, + **map_kwargs, + ) + # For Liger kernel, ensure only input_ids is present + if args.use_liger: + dataset = dataset.select_columns("input_ids") - return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"]} + return dataset - signature_columns = ["input_ids", "labels", "attention_mask"] + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + """ + Compute training loss and additionally compute token accuracies + """ + (loss, outputs) = super().compute_loss( + model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch + ) - if dataset.column_names is not None: # None for IterableDataset - extra_columns = list(set(dataset.column_names) - set(signature_columns)) - else: - extra_columns = [] + # Compute token accuracy if we have labels and if the model is not using Liger (no logits) + if "labels" in inputs and not self.use_liger: + shift_logits = outputs.logits[..., :-1, :].contiguous() + shift_labels = inputs["labels"][..., 1:].contiguous() - if not remove_unused_columns and len(extra_columns) > 0: - warnings.warn( - "You passed `remove_unused_columns=False` on a non-packed dataset. This might create some issues with " - "the default collator and yield to errors. 
If you want to inspect dataset other columns (in this " - f"case {extra_columns}), you can subclass `DataCollatorForLanguageModeling` in case you used the " - "default collator and create your own data collator in order to inspect the unused dataset columns.", - UserWarning, - ) + # Get predictions + predictions = shift_logits.argmax(dim=-1) - map_kwargs = { - "batched": True, - "remove_columns": dataset.column_names if remove_unused_columns else None, - "batch_size": self.dataset_batch_size, - } - if isinstance(dataset, datasets.Dataset): - map_kwargs["num_proc"] = self.dataset_num_proc # this arg is not available for IterableDataset - tokenized_dataset = dataset.map(tokenize, **map_kwargs) + # Create mask for non-padding tokens (assuming ignore_index is -100) + mask = shift_labels != -100 - return tokenized_dataset + # Calculate accuracy only on non-padding tokens + correct_predictions = (predictions == shift_labels) & mask + total_tokens = mask.sum() + correct_tokens = correct_predictions.sum() - def _prepare_packed_dataloader( - self, - processing_class, - dataset, - dataset_text_field: str, - max_seq_length, - num_of_sequences, - chars_per_token, - formatting_func: Optional[Callable] = None, - append_concat_token=True, - add_special_tokens=True, - ): - if processing_class is None: - raise ValueError("You need to pass a processing_class with `SFTTrainer`.") - - constant_length_iterator = ConstantLengthDataset( - processing_class, - dataset, - dataset_text_field=None if formatting_func is not None else dataset_text_field, - formatting_func=formatting_func, - seq_length=max_seq_length, - infinite=False, - num_of_sequences=num_of_sequences, - chars_per_token=chars_per_token, - eos_token_id=processing_class.eos_token_id, - append_concat_token=append_concat_token, - add_special_tokens=add_special_tokens, - ) + # Gather the correct_tokens and total_tokens across all processes + correct_tokens = self.accelerator.gather_for_metrics(correct_tokens) + total_tokens = self.accelerator.gather_for_metrics(total_tokens) - if isinstance(dataset, datasets.IterableDataset): - return constant_length_iterator + # Compute the mean token accuracy and log it + accuracy = (correct_tokens.sum() / total_tokens.sum()).item() if total_tokens.sum() > 0 else 0.0 + self._metrics["mean_token_accuracy"].append(accuracy) - def data_generator(constant_length_iterator): - yield from constant_length_iterator + return (loss, outputs) if return_outputs else loss - try: - packed_dataset = Dataset.from_generator( - data_generator, gen_kwargs={"constant_length_iterator": constant_length_iterator} - ) - except (DatasetGenerationError, SchemaInferenceError) as exc: - raise ValueError( - "Error occurred while packing the dataset. " - "Make sure that your dataset has enough samples to at least yield one packed sequence." - ) from exc - return packed_dataset + def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: + metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics + + # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs` + # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format. 
+ if next(iter(logs.keys())).startswith("eval_"): + metrics = {f"eval_{key}": val for key, val in metrics.items()} + + logs = {**logs, **metrics} + if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): + super().log(logs, start_time) + else: # transformers<=4.46 + super().log(logs) + self._metrics.clear() def create_model_card( self, diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index cda8803f3c..853ba1f3ca 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -26,6 +26,7 @@ import numpy as np import pandas as pd import torch +import torch.nn.functional as F import torch.utils.data from accelerate import Accelerator, PartialState from accelerate.state import AcceleratorState @@ -139,7 +140,7 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d warnings.warn( f"Could not find response key `{self.response_template}` in the following instance: " f"{self.tokenizer.decode(batch['input_ids'][i])}. This instance will be ignored in loss " - "calculation. Note, if this happens often, consider increasing the `max_seq_length`.", + "calculation. Note, if this happens often, consider increasing the `max_length`.", UserWarning, ) batch["labels"][i, :] = self.ignore_index @@ -166,7 +167,7 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d warnings.warn( f"Could not find response key `{self.response_template}` in the following instance: " f"{self.tokenizer.decode(batch['input_ids'][i])}. This instance will be ignored in loss " - "calculation. Note, if this happens often, consider increasing the `max_seq_length`.", + "calculation. Note, if this happens often, consider increasing the `max_length`.", UserWarning, ) batch["labels"][i, :] = self.ignore_index @@ -181,7 +182,7 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d warnings.warn( f"Could not find instruction key `{self.instruction_template}` in the following instance: " f"{self.tokenizer.decode(batch['input_ids'][i])}. This instance will be ignored in loss " - "calculation. Note, if this happens often, consider increasing the `max_seq_length`.", + "calculation. Note, if this happens often, consider increasing the `max_length`.", UserWarning, ) batch["labels"][i, :] = self.ignore_index @@ -993,9 +994,15 @@ class OnPolicyConfig(TrainingArguments): response_length (`int`, *optional*, defaults to `53`): Length of the response. stop_token (`str` or `None`, *optional*, defaults to `None`): - Stop token. + Specifies the stop token to use for text generation. This parameter is mutually exclusive with + `stop_token_id`. + + - `None`: No stop token is applied, unless `stop_token_id` is specified. + - `'eos'`: Uses the tokenizer's `eos_token`. + stop_token_id (`int` or `None`, *optional*, defaults to `None`): - Truncation token id. + Specifies the ID of the stop token to use for text generation. If `None`, no stop token ID is applied, + unless `stop_token` is specified. This parameter is mutually exclusive with `stop_token`. temperature (`float`, *optional*, defaults to `0.7`): Sampling temperature. missing_eos_penalty (`float` or `None`, *optional*, defaults to `None`): @@ -1054,11 +1061,17 @@ class OnPolicyConfig(TrainingArguments): ) stop_token: Optional[Literal["eos"]] = field( default=None, - metadata={"help": "Stop token."}, + metadata={ + "help": "Specifies the stop token to use for text generation. This parameter is mutually exclusive with " + "`stop_token_id`." 
+ }, ) stop_token_id: Optional[int] = field( default=None, - metadata={"help": "Truncation token id."}, + metadata={ + "help": "Specifies the ID of the stop token to use for text generation. If `None`, no stop token ID is " + "applied, unless `stop_token` is specified. This parameter is mutually exclusive with `stop_token`." + }, ) temperature: float = field( default=0.7, @@ -1569,3 +1582,104 @@ def log_table_to_comet_experiment(name: str, table: pd.DataFrame) -> None: experiment = comet_ml.get_running_experiment() if experiment is not None: experiment.log_table(tabular_data=table, filename=name) + + +def flush_left(mask: torch.Tensor, *tensors: torch.Tensor) -> tuple[torch.Tensor, ...]: + """ + Shift non-zero elements in the mask and corresponding tensors to the left. + + This function operates on a binary mask and any number of additional tensors with the same dimensions as the mask. + For each row, non-zero values are shifted to the leftmost positions. Then, columns that contain only zeros across + all rows are truncated from the mask and tensors. Visually, this operation can be represented as follows: + + ``` + [[0, 0, x, x, x, x], -> [[x, x, x, x], + [0, x, x, x, 0, 0]] [x, x, x, 0]] + ``` + + Args: + + mask (`torch.Tensor`): + 2D tensor (binary mask) with shape `(N, M)`. + *tensors (`torch.Tensor`) + One or more 2D tensors with the same shape as `mask`. These tensors will be processed alongside `mask`, + with non-zero values shifted and excess zero columns truncated in the same manner. + + Returns: + `torch.Tensor`: + Updated binary mask with non-zero values flushed to the left and trailing zero columns removed. + `*torch.Tensor` + Updated tensors, processed in the same way as the mask. + + Example: + ```python + >>> mask = torch.tensor([[0, 0, 1, 1, 1], + ... [0, 1, 1, 0, 0]]) + >>> tensor = torch.tensor([[9, 9, 2, 3, 4], + ... [9, 5, 6, 9, 9]]) + >>> new_mask, new_tensor = flush_left(mask, tensor) + >>> print(new_mask) + tensor([[1, 1, 1], + [1, 1, 0]]) + >>> print(new_tensor) + tensor([[2, 3, 4], + [5, 6, 0]]) + ``` + """ + # Create copy of mask and tensors + mask = mask.clone() + tensors = [t.clone() for t in tensors] + + # Shift non-zero values to the left + for i in range(mask.size(0)): + first_one_idx = torch.nonzero(mask[i])[0].item() + mask[i] = torch.roll(mask[i], shifts=-first_one_idx) + for tensor in tensors: + tensor[i] = torch.roll(tensor[i], shifts=-first_one_idx) + + # Get the first column idx that is all zeros and remove every column after that + empty_cols = torch.sum(mask, dim=0) == 0 + first_empty_col = torch.nonzero(empty_cols)[0].item() if empty_cols.any() else mask.size(1) + mask = mask[:, :first_empty_col] + for i, tensor in enumerate(tensors): + tensors[i] = tensor[:, :first_empty_col] + + if not tensors: + return mask + else: + return mask, *tensors + + +def selective_log_softmax(logits, index): + """ + A memory-efficient implementation of the common `log_softmax -> gather` operation. + + This function is equivalent to the following naive implementation: + ```python + logps = torch.gather(logits.log_softmax(-1), dim=-1, index=index.unsqueeze(-1)).squeeze(-1) + ``` + + Args: + logits (`torch.Tensor`): + Logits tensor of shape `(..., num_classes)`. + index (`torch.Tensor`): + Index tensor of shape `(...)`, specifying the positions to gather from the log-softmax output. + + Returns: + `torch.Tensor`: + Gathered log probabilities with the same shape as `index`. 
+ """ + if logits.dtype in [torch.float32, torch.float64]: + selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1) + # loop to reduce peak mem consumption + logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits]) + per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x) + else: + # logsumexp approach is unstable with bfloat16, fall back to slightly less efficent approach + per_token_logps = [] + for row_logits, row_labels in zip(logits, index): # loop to reduce peak mem consumption + row_logps = F.log_softmax(row_logits, dim=-1) + row_per_token_logps = row_logps.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1) + per_token_logps.append(row_per_token_logps) + per_token_logps = torch.stack(per_token_logps) + return per_token_logps diff --git a/trl/trainer/xpo_trainer.py b/trl/trainer/xpo_trainer.py index 2d535344e7..6c7579ae8a 100644 --- a/trl/trainer/xpo_trainer.py +++ b/trl/trainer/xpo_trainer.py @@ -44,6 +44,7 @@ generate_model_card, get_comet_experiment_url, get_reward, + selective_log_softmax, truncate_right, ) from .xpo_config import XPOConfig @@ -274,8 +275,7 @@ def _compute_logprobs(self, model, model_data, ref_data, context_length): def compute_logprobs_for_data(m, data): output = m(data["input_ids"], attention_mask=data["attention_mask"]) logits = output.logits[:, context_length - 1 : -1] - logprobs = F.log_softmax(logits, dim=-1) - token_logprobs = torch.gather(logprobs, 2, data["input_ids"][:, context_length:].unsqueeze(-1)).squeeze(-1) + token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:]) return token_logprobs # Compute logprobs for model completions