From 0f5ffad26e96d1a0eef568be91da956e81b4a11b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Mon, 20 Jan 2025 19:02:15 +0100 Subject: [PATCH 01/96] =?UTF-8?q?=F0=9F=91=A8=E2=80=8D=F0=9F=91=A8?= =?UTF-8?q?=E2=80=8D=F0=9F=91=A7=E2=80=8D=F0=9F=91=A7=20GRPO=20(#2565)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * init grpo [ci skip] * initial version * refine args defs * model card * initial doc * fix badges * fix spaces * try link to super in doc * temperature, fix indexing, and std=0.0 * grpo script for cli * peft support * move data preparation in `compute_loss` * weird doc trial * fix device and some logging * unwrap_model_for_generation for distributed setting * Compat with distrib training * revert grpo config doc trial (didn't work) * test * allow model to be str and processing_class to be none; fix loss computation * advantage is always 0.0: don't log * fix peft not installed * proper reward model for testing * fix script for cli * add trl grpo to cli doc * test peft * flush left * fix reward calculation * new reward model * support any reward model * fix reward processing class def * log reward std * fix reward logging * fix grad computation * skip embed layer in test * remove optimizer_cls_and_kwargs * improve GRPO default args * reduce mem usage for grpo test * reduce mem usage in test grpo * reduce memory usage for test * Fix the test * remove redondant * fix min version * Update test_grpo_trainer.py * Update test_grpo_trainer.py * Fix test, finally found the solution! * some doc * Update doc-builder workflow to use specific commit sha * more doc * advantages * drop cancel fo no grad * logged metrics [ci skip] * completion col is ignored [ci skip] * fix latex * double space? ~? 
* try a latex fix * with branch * Empty commit * Empty commit * double space seems to be the solution --- .github/workflows/build_pr_documentation.yml | 2 +- docs/source/_toctree.yml | 2 + docs/source/clis.mdx | 1 + docs/source/dataset_formats.mdx | 1 + docs/source/grpo_trainer.md | 123 +++++++ docs/source/kto_trainer.mdx | 2 +- scripts/generate_tiny_models.py | 23 ++ tests/test_cli.py | 6 + tests/test_grpo_trainer.py | 148 ++++++++ tests/test_utils.py | 47 +++ trl/__init__.py | 4 + trl/cli.py | 11 + trl/scripts/grpo.py | 92 +++++ trl/trainer/__init__.py | 4 + trl/trainer/dpo_trainer.py | 14 +- trl/trainer/grpo_config.py | 92 +++++ trl/trainer/grpo_trainer.py | 351 +++++++++++++++++++ trl/trainer/utils.py | 66 ++++ 18 files changed, 975 insertions(+), 14 deletions(-) create mode 100644 docs/source/grpo_trainer.md create mode 100644 tests/test_grpo_trainer.py create mode 100644 trl/scripts/grpo.py create mode 100644 trl/trainer/grpo_config.py create mode 100644 trl/trainer/grpo_trainer.py diff --git a/.github/workflows/build_pr_documentation.yml b/.github/workflows/build_pr_documentation.yml index acc8d16d35..bf72dc7c1e 100644 --- a/.github/workflows/build_pr_documentation.yml +++ b/.github/workflows/build_pr_documentation.yml @@ -9,7 +9,7 @@ concurrency: jobs: build: - uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main + uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@e4fcf608695cf4bddb8c7f4f72aa15fa14110a94 with: commit_sha: ${{ github.event.pull_request.head.sha }} pr_number: ${{ github.event.number }} diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 1f793e6ea9..4ccc57ae8f 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -68,6 +68,8 @@ title: Online DPO - local: gkd_trainer title: GKD + - local: grpo_trainer + title: GRPO - local: kto_trainer title: KTO - local: nash_md_trainer diff --git a/docs/source/clis.mdx b/docs/source/clis.mdx index 9c7a2dfca8..885227a116 100644 --- a/docs/source/clis.mdx +++ b/docs/source/clis.mdx @@ -7,6 +7,7 @@ Currently supported CLIs are: #### Training commands - `trl dpo`: fine-tune a LLM with DPO +- `trl grpo`: fine-tune a LLM with GRPO - `trl kto`: fine-tune a LLM with KTO - `trl sft`: fine-tune a LLM with SFT diff --git a/docs/source/dataset_formats.mdx b/docs/source/dataset_formats.mdx index 1306fdad7f..d29770abb4 100644 --- a/docs/source/dataset_formats.mdx +++ b/docs/source/dataset_formats.mdx @@ -270,6 +270,7 @@ Choosing the right dataset type depends on the task you are working on and the s | [`CPOTrainer`] | [Preference (explicit prompt recommended)](#preference) | | [`DPOTrainer`] | [Preference (explicit prompt recommended)](#preference) | | [`GKDTrainer`] | [Prompt-completion](#prompt-completion) | +| [`GRPOTrainer`] | [Prompt-only](#prompt-only) | | [`IterativeSFTTrainer`] | [Unpaired preference](#unpaired-preference) | | [`KTOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) | | [`NashMDTrainer`] | [Prompt-only](#prompt-only) | diff --git a/docs/source/grpo_trainer.md b/docs/source/grpo_trainer.md new file mode 100644 index 0000000000..59abe04356 --- /dev/null +++ b/docs/source/grpo_trainer.md @@ -0,0 +1,123 @@ +# GRPO Trainer + +[![](https://img.shields.io/badge/All_models-GRPO-blue)](https://huggingface.co/models?other=grpo,trl) + +## Overview + +TRL supports the GRPO Trainer for training language models, as described in the paper [DeepSeekMath: Pushing the Limits of Mathematical 
Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300) by [Zhihong Shao](https://huggingface.co/syhia), [Peiyi Wang](https://huggingface.co/peiyiwang89), [Qihao Zhu](https://huggingface.co/zqh11), Runxin Xu, [Junxiao Song](https://huggingface.co/haha-point), Mingchuan Zhang, Y. K. Li, Y. Wu, [Daya Guo](https://huggingface.co/guoday). + +The abstract from the paper is the following: + +> Mathematical reasoning poses a significant challenge for language models due to its complex and structured nature. In this paper, we introduce DeepSeekMath 7B, which continues pre-training DeepSeek-Coder-Base-v1.5 7B with 120B math-related tokens sourced from Common Crawl, together with natural language and code data. DeepSeekMath 7B has achieved an impressive score of 51.7% on the competition-level MATH benchmark without relying on external toolkits and voting techniques, approaching the performance level of Gemini-Ultra and GPT-4. Self-consistency over 64 samples from DeepSeekMath 7B achieves 60.9% on MATH. The mathematical reasoning capability of DeepSeekMath is attributed to two key factors: First, we harness the significant potential of publicly available web data through a meticulously engineered data selection pipeline. Second, we introduce Group Relative Policy Optimization (GRPO), a variant of Proximal Policy Optimization (PPO), that enhances mathematical reasoning abilities while concurrently optimizing the memory usage of PPO. + +This post-training method was contributed by [Quentin Gallouédec](https://huggingface.co/qgallouedec). + +## Quick start + +This example demonstrates how to train a model using the GRPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B) as the base model and the [RM-Gemma-2B model](https://huggingface.co/weqweasdas/RM-Gemma-2B) as the reward model. We use the prompts from the [TLDR dataset](https://huggingface.co/datasets/trl-lib/tldr) (completion column is ingored!). You can view the data in the dataset here: + + + +Below is the script to train the model. We use PEFT to reduce the memory requirements. + +```python +# train_grpo.py +from datasets import load_dataset +from peft import LoraConfig +from trl import GRPOConfig, GRPOTrainer + +# Load the dataset +dataset = load_dataset("trl-lib/tldr", split="train") + +training_args = GRPOConfig( + output_dir="Qwen2-0.5B-GRPO", + learning_rate=1e-5, + logging_steps=10, + gradient_accumulation_steps=16, + max_completion_length=128, +) +trainer = GRPOTrainer( + model="Qwen/Qwen2-0.5B-Instruct", + reward_model="weqweasdas/RM-Gemma-2B", + args=training_args, + train_dataset=dataset, + peft_config=LoraConfig(task_type="CAUSAL_LM"), +) + +trainer.train() +``` + +Execute the script using the following command: + +```bash +accelerate launch train_grpo.py +``` + +Distributed across 8 GPUs, the training takes approximately 1 day. + +![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/grpo_curves.png) + +## Looking deeper into the GRPO method + +GRPO is an online learning algorithm, meaning it improves iteratively by using the data generated by the trained model itself during training. The intuition behind GRPO objective is to maximize the advantage of the generated completions, while ensuring that the model remains close to the reference policy. To understand how GRPO works, it can be broken down into four main steps: **Generating completions**, **computing the advantage**, **estimating the KL divergence**, and **computing the loss**. 
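Each of these steps is detailed in the sections that follow. As a purely illustrative preview, the self-contained sketch below shows how the pieces fit together numerically for a single batch, using random dummy tensors in place of real generations, reward-model scores, and log-probabilities; the group size, sequence length, and `beta` value are arbitrary choices for illustration, not the trainer's actual internals.

```python
import torch

# Illustrative shapes only: 2 prompts, G = 4 completions per prompt, 5 completion tokens.
num_prompts, G, T = 2, 4, 5
beta = 0.04  # KL coefficient (the default documented for GRPOConfig below)

rewards = torch.randn(num_prompts * G)                 # one scalar reward per completion
per_token_logps = torch.randn(num_prompts * G, T)      # log-probs under the trained policy
ref_per_token_logps = torch.randn(num_prompts * G, T)  # log-probs under the reference policy
completion_mask = torch.ones(num_prompts * G, T)       # 1 up to and including the first EOS token

# Group-relative advantage: normalize each reward by the mean/std of its own group of G completions.
mean_grouped = rewards.view(-1, G).mean(dim=1).repeat_interleave(G)
std_grouped = rewards.view(-1, G).std(dim=1).repeat_interleave(G)
advantages = (rewards - mean_grouped) / (std_grouped + 1e-4)

# Per-token KL estimate between the policy and the reference (Schulman's approximator).
per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1

# Maximize the advantage while penalizing the KL, averaging over the non-masked completion tokens.
# (The trainer additionally multiplies the advantage by exp(logps - logps.detach()) so that gradients
# flow through the log-probabilities; that detail is omitted in this dummy-tensor sketch.)
per_token_loss = -(advantages.unsqueeze(1) - beta * per_token_kl)
loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
print(loss)
```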
+ + ![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/grpo_visual.png) + + ### Generating completions + + At each training step, we sample a batch of prompts and generate a set of \\( G \\) completions for each prompt (denoted as \\( o_i \\)). + + ### Computing the advantage + + For each of the \\( G \\) sequences, we compute the reward using a reward model. To align with the comparative nature of reward models—typically trained on datasets of comparisons between outputs for the same question—the advantage is calculated to reflect these relative comparisons. It is normalized as follows: + + $$\hat{A}_{i,t} = \frac{r_i - \text{mean}(\mathbf{r})}{\text{std}(\mathbf{r})}$$ + + This approach gives the method its name: **Group Relative Policy Optimization (GRPO)**. + + ### Estimating the KL divergence + + KL divergence is estimated using the approximator introduced by [Schulman et al. (2020)](http://joschu.net/blog/kl-approx.html). The approximator is defined as follows: + + $$\mathbb{D}_{\text{KL}}\left[\pi_\theta \|\pi_{\text{ref}}\right] = \frac{\pi_{\text{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})} - \log \frac{\pi_{\text{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})} - 1$$
diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py # Flush left to reduce the memory usage # [[0, 0, x, x, x, x], -> [[x, x, x, x], # [0, x, x, x, 0, 0]] [x, x, x, 0]] - for i in range(attention_mask.size(0)): - first_one_idx = torch.nonzero(attention_mask[i])[0].item() - input_ids[i] = torch.roll(input_ids[i], shifts=-first_one_idx) - attention_mask[i] = torch.roll(attention_mask[i], shifts=-first_one_idx) - loss_mask[i] = torch.roll(loss_mask[i], shifts=-first_one_idx) - - # Get the first column idx that is all zeros and remove every column after that - empty_cols = torch.sum(attention_mask, dim=0) == 0 - first_empty_col = torch.nonzero(empty_cols)[0].item() if empty_cols.any() else attention_mask.size(1) - input_ids = input_ids[:, :first_empty_col] - attention_mask = attention_mask[:, :first_empty_col] - loss_mask = loss_mask[:, :first_empty_col] + attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) # Truncate right if self.args.max_length is not None:
diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py new file mode 100644 index 0000000000..a0a61d3b16 --- /dev/null +++ b/trl/trainer/grpo_config.py @@ -0,0 +1,92 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Optional + +from transformers import TrainingArguments + + +@dataclass +class GRPOConfig(TrainingArguments): + r""" + Configuration class for the [`GRPOTrainer`]. + + Only the parameters specific to GRPO training are listed here. For details on other parameters, refer to the + [`~transformers.TrainingArguments`] documentation. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line.
+ + Parameters: + > Parameters that control the model and reference model + + model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`): + Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model` + argument of the [`GRPOTrainer`] is provided as a string. + + > Parameters that control the data preprocessing + + num_generations (`int` or `None`, *optional*, defaults to `8`): + Number of generations per prompt to sample. + temperature (`float`, *optional*, defaults to `0.9`): + Temperature for sampling. The higher the temperature, the more random the completions. + max_completion_length (`int` or `None`, *optional*, defaults to `None`): + Maximum length of the generated completion. + + > Parameters that control the training + + learning_rate (`float`, *optional*, defaults to `1e-6`): + Initial learning rate for [`AdamW`] optimizer. The default value replaces that of + [`~transformers.TrainingArguments`]. + beta (`float`, *optional*, defaults to `0.04`): + KL coefficient. + """ + + # Parameters that control the model and reference model + model_init_kwargs: Optional[dict] = field( + default=None, + metadata={ + "help": "Keyword arguments for `transformers.AutoModelForCausalLM.from_pretrained`, used when the `model` " + "argument of the `GRPOTrainer` is provided as a string." + }, + ) + + # Parameters that control the data preprocessing + num_generations: Optional[int] = field( + default=8, + metadata={"help": "Number of generations to sample."}, + ) + temperature: Optional[float] = field( + default=0.9, + metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."}, + ) + max_completion_length: Optional[int] = field( + default=256, + metadata={"help": "Maximum length of the generated completion."}, + ) + + # Parameters that control the training + learning_rate: float = field( + default=1e-6, + metadata={ + "help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of " + "`transformers.TrainingArguments`." + }, + ) + beta: float = field( + default=0.04, + metadata={"help": "KL coefficient."}, + ) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py new file mode 100644 index 0000000000..0e830b6822 --- /dev/null +++ b/trl/trainer/grpo_trainer.py @@ -0,0 +1,351 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
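Stepping back to the configuration class defined just above: as a quick usage sketch, it can be instantiated with its GRPO-specific fields as shown below. The output directory and the concrete values are arbitrary illustrations (they mirror the defaults documented above), not recommendations.

```python
from trl import GRPOConfig

# All keyword arguments below correspond to fields defined in GRPOConfig above;
# the output directory is a hypothetical example.
training_args = GRPOConfig(
    output_dir="my-grpo-run",
    num_generations=8,          # G: number of completions sampled per prompt
    temperature=0.9,            # sampling temperature for generation
    max_completion_length=256,  # maximum length of each generated completion
    learning_rate=1e-6,         # GRPO-specific default, overriding TrainingArguments
    beta=0.04,                  # KL coefficient
)
```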
+ +import os +import textwrap +from typing import Any, Callable, Optional, Union + +import torch +import torch.nn as nn +import torch.utils.data +import transformers +from datasets import Dataset, IterableDataset +from packaging import version +from transformers import ( + AutoModelForCausalLM, + AutoModelForSequenceClassification, + AutoTokenizer, + DataCollator, + EvalPrediction, + GenerationConfig, + PreTrainedModel, + PreTrainedTokenizerBase, + Trainer, + TrainerCallback, + is_wandb_available, +) +from transformers.utils import is_peft_available + +from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template +from ..models import create_reference_model, unwrap_model_for_generation +from .grpo_config import GRPOConfig +from .utils import generate_model_card, get_comet_experiment_url + + +if is_peft_available(): + from peft import PeftConfig, get_peft_model + +if is_wandb_available(): + import wandb + + +class GRPOTrainer(Trainer): + def __init__( + self, + model: Union[str, PreTrainedModel, nn.Module] = None, + reward_model: Optional[Union[PreTrainedModel, nn.Module]] = None, + args: GRPOConfig = None, + data_collator: Optional[DataCollator] = None, + train_dataset: Optional[Union[Dataset, IterableDataset]] = None, + eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None, + processing_class: Optional[PreTrainedTokenizerBase] = None, + reward_processing_class: Optional[PreTrainedTokenizerBase] = None, + model_init: Optional[Callable[[], PreTrainedModel]] = None, + compute_loss_func: Optional[Callable] = None, + compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None, + callbacks: Optional[list[TrainerCallback]] = None, + optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + peft_config: Optional["PeftConfig"] = None, + ): + # Args + if args is None: + model_name = model if isinstance(model, str) else model.config._name_or_path + model_name = model_name.split("/")[-1] + args = GRPOConfig(f"{model_name}-GRPO") + + # Models + # Trained model + model_init_kwargs = args.model_init_kwargs or {} + if isinstance(model, str): + torch_dtype = model_init_kwargs.get("torch_dtype") + if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None: + pass # torch_dtype is already a torch.dtype or "auto" or None + elif isinstance(torch_dtype, str): # it's a str, but not "auto" + torch_dtype = getattr(torch, torch_dtype) + model_init_kwargs["torch_dtype"] = torch_dtype + else: + raise ValueError( + "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing " + f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}." + ) + model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) + else: + if args.model_init_kwargs is not None: + raise ValueError( + "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. " + "This argument can only be used when the `model` argument is a string." + ) + + if peft_config is not None: + model = get_peft_model(model, peft_config) + + # Reference model + if peft_config is None: + # If PEFT configuration is not provided, create a reference model based on the initial model. 
+ self.ref_model = create_reference_model(model) + else: + # If PEFT is used, the reference model is not needed since the adapter can be disabled + # to revert to the initial model. + self.ref_model = None + + # Processing class + if processing_class is None: + processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left") + + # Reward model + if isinstance(reward_model, str): + reward_model = AutoModelForSequenceClassification.from_pretrained( + reward_model, num_labels=1, **model_init_kwargs + ) + self.reward_model = reward_model + + # Reward processing class + if reward_processing_class is None: + reward_processing_class = AutoTokenizer.from_pretrained(reward_model.config._name_or_path) + if reward_processing_class.pad_token_id is None: + reward_processing_class.pad_token = reward_processing_class.eos_token + self.reward_processing_class = reward_processing_class + # The reward model computes the reward for the latest non-padded token in the input sequence. + # So it's important to set the pad token ID to the padding token ID of the processing class. + self.reward_model.config.pad_token_id = reward_processing_class.pad_token_id + + # Data loading and preprocessing + if data_collator is None: + + def data_collator(features): # No data collation is needed in GRPO + return features + + # Training arguments + self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper + self.num_generations = args.num_generations # = G in the GRPO paper + self.generation_config = GenerationConfig( + max_new_tokens=self.max_completion_length, + do_sample=True, + temperature=args.temperature, + num_return_sequences=self.num_generations, + pad_token_id=processing_class.pad_token_id, + ) + self.beta = args.beta + + # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the + # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the + # "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning: + # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To + # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True. + # This acts as a flag to indicate that the warning has already been issued. + model.warnings_issued["estimate_tokens"] = True + + # Initialize the metrics + self._metrics = {"kl": [], "reward": [], "reward_std": []} + + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + model_init=model_init, + compute_loss_func=compute_loss_func, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + if self.ref_model is not None: + self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True) + + def _set_signature_columns_if_needed(self): + # If `self.args.remove_unused_columns` is True, non-signature columns are removed. + # By default, this method sets `self._signature_columns` to the model's expected inputs. + # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work. 
+ # Instead, we set them to the columns expected by the `training_step` method, hence the override. + if self._signature_columns is None: + self._signature_columns = ["prompt"] + + # Trainer "prepares" the inputs before calling `compute_loss`. It converts to tensor and move to device. + # Since we preprocess the data in `compute_loss`, we need to override this method to skip this step. + def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]: + return inputs + + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + if return_outputs: + raise ValueError("The GRPOTrainer does not support returning outputs") + + prompts = [x["prompt"] for x in inputs] + prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs] + prompt_inputs = self.processing_class( + prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False + ) + prompt_inputs = super()._prepare_inputs(prompt_inputs) + + # Generate completions + with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model: + prompt_completion_ids = unwrapped_model.generate(**prompt_inputs, generation_config=self.generation_config) + prompt_length = prompt_inputs["input_ids"].size(1) + completion_ids = prompt_completion_ids[:, prompt_length:] + + # Get the per-token log probabilities for the completions for the model and the reference model + def get_per_token_logps(model, input_ids): + logits = model(input_ids).logits + logits = torch.roll(logits, shifts=1, dims=1) # Shape (B*G, L) + per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=input_ids.unsqueeze(2)).squeeze(2) + return per_token_logps + + per_token_logps = get_per_token_logps(model, prompt_completion_ids) + per_token_logps = per_token_logps[:, prompt_length:] # get rid of the prompt + + with torch.inference_mode(): + if self.ref_model is not None: + ref_per_token_logps = get_per_token_logps(self.ref_model, prompt_completion_ids) + else: + with self.accelerator.unwrap_model(model).disable_adapter(): + ref_per_token_logps = get_per_token_logps(model, prompt_completion_ids) + ref_per_token_logps = ref_per_token_logps[:, prompt_length:] # get rid of the prompt + + # Compute the KL divergence between the model and the reference model + per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 + + # Mask everything after the first EOS token + is_eos = completion_ids == self.processing_class.eos_token_id + device = self.accelerator.device + eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) + eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] + sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) + completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() + + # Decode the generated completions + completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) + + # Compute the rewards + prompts = [prompt for prompt in prompts for _ in range(self.num_generations)] + if is_conversational(inputs[0]): + completions = [[{"role": "assistant", "content": completion}] for completion in completions] + messages = [{"messages": p + c} for p, c in zip(prompts, completions)] + texts = [apply_chat_template(x, self.reward_processing_class)["text"] for x in messages] + reward_inputs = self.reward_processing_class( + texts, 
return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False + ) + else: + texts = [p + c for p, c in zip(prompts, completions)] + reward_inputs = self.reward_processing_class( + texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False + ) + reward_inputs = super()._prepare_inputs(reward_inputs) + with torch.inference_mode(): + rewards = self.reward_model(**reward_inputs).logits[:, 0] # Shape (B*G,) + + # Compute grouped-wise rewards + mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1) + std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1) + + # Normalize the rewards to compute the advantages + mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0) + std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0) + advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4) + + # x - x.detach() allows for preserving gradients from x + advantages = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) + per_token_loss = -(advantages - self.beta * per_token_kl) + loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() + + # Log the metrics + self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item()) + + self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item()) + + mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() + self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item()) + + return loss + + def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: + metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics + logs = {**logs, **metrics} + if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): + return super().log(logs, start_time) + else: # transformers<=4.46 + return super().log(logs) + self._metrics = {key: [] for key in self._metrics} + + def create_model_card( + self, + model_name: Optional[str] = None, + dataset_name: Optional[str] = None, + tags: Union[str, list[str], None] = None, + ): + """ + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + model_name (`str` or `None`, *optional*, defaults to `None`): + Name of the model. + dataset_name (`str` or `None`, *optional*, defaults to `None`): + Name of the dataset used for training. + tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`): + Tags to be associated with the model card. + """ + if not self.is_world_process_zero(): + return + + if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): + base_model = self.model.config._name_or_path + else: + base_model = None + + tags = tags or [] + if isinstance(tags, str): + tags = [tags] + + if hasattr(self.model.config, "unsloth_version"): + tags.append("unsloth") + + citation = textwrap.dedent( + """\ + @article{zhihong2024deepseekmath, + title = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}}, + author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. 
Wu and Daya Guo}, + year = 2024, + eprint = {arXiv:2402.03300}, + """ + ) + + model_card = generate_model_card( + base_model=base_model, + model_name=model_name, + hub_model_id=self.hub_model_id, + dataset_name=dataset_name, + tags=tags, + wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + comet_url=get_comet_experiment_url(), + trainer_name="GRPO", + trainer_citation=citation, + paper_title="DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models", + paper_id="2402.03300", + ) + + model_card.save(os.path.join(self.args.output_dir, "README.md")) diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index cda8803f3c..1228dc7ece 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -1569,3 +1569,69 @@ def log_table_to_comet_experiment(name: str, table: pd.DataFrame) -> None: experiment = comet_ml.get_running_experiment() if experiment is not None: experiment.log_table(tabular_data=table, filename=name) + + +def flush_left(mask: torch.Tensor, *tensors: torch.Tensor) -> tuple[torch.Tensor, ...]: + """ + Shift non-zero elements in the mask and corresponding tensors to the left. + + This function operates on a binary mask and any number of additional tensors with the same dimensions as the mask. + For each row, non-zero values are shifted to the leftmost positions. Then, columns that contain only zeros across + all rows are truncated from the mask and tensors. Visually, this operation can be represented as follows: + + ``` + [[0, 0, x, x, x, x], -> [[x, x, x, x], + [0, x, x, x, 0, 0]] [x, x, x, 0]] + ``` + + Args: + + mask (`torch.Tensor`): + 2D tensor (binary mask) with shape `(N, M)`. + *tensors (`torch.Tensor`) + One or more 2D tensors with the same shape as `mask`. These tensors will be processed alongside `mask`, + with non-zero values shifted and excess zero columns truncated in the same manner. + + Returns: + `torch.Tensor`: + Updated binary mask with non-zero values flushed to the left and trailing zero columns removed. + `*torch.Tensor` + Updated tensors, processed in the same way as the mask. + + Example: + ```python + >>> mask = torch.tensor([[0, 0, 1, 1, 1], + ... [0, 1, 1, 0, 0]]) + >>> tensor = torch.tensor([[9, 9, 2, 3, 4], + ... 
[9, 5, 6, 9, 9]]) + >>> new_mask, new_tensor = flush_left(mask, tensor) + >>> print(new_mask) + tensor([[1, 1, 1], + [1, 1, 0]]) + >>> print(new_tensor) + tensor([[2, 3, 4], + [5, 6, 0]]) + ``` + """ + # Create copy of mask and tensors + mask = mask.clone() + tensors = [t.clone() for t in tensors] + + # Shift non-zero values to the left + for i in range(mask.size(0)): + first_one_idx = torch.nonzero(mask[i])[0].item() + mask[i] = torch.roll(mask[i], shifts=-first_one_idx) + for tensor in tensors: + tensor[i] = torch.roll(tensor[i], shifts=-first_one_idx) + + # Get the first column idx that is all zeros and remove every column after that + empty_cols = torch.sum(mask, dim=0) == 0 + first_empty_col = torch.nonzero(empty_cols)[0].item() if empty_cols.any() else mask.size(1) + mask = mask[:, :first_empty_col] + for i, tensor in enumerate(tensors): + tensors[i] = tensor[:, :first_empty_col] + + if not tensors: + return mask + else: + return mask, *tensors From 5fd78367aef83d3666737d29e50a0d88eb6d807f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Mon, 20 Jan 2025 21:26:11 +0100 Subject: [PATCH 02/96] =?UTF-8?q?=F0=9F=AB=A3=20Ignore=20CLI=20test=20for?= =?UTF-8?q?=20Python=203.9=20(#2592)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ignore cli test for python 3.9 * move import inside tests --- tests/test_cli.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/tests/test_cli.py b/tests/test_cli.py index 12118ebc2f..2a5b2ed2cd 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -13,16 +13,22 @@ # limitations under the License. +import sys import tempfile import unittest from io import StringIO from unittest.mock import patch -from trl.cli import main - +@unittest.skipIf( + sys.version_info < (3, 10), + "Transformers' generation codebase uses a Python >3.10 syntax (`str | None`), which seems to cause the CLI tests " + "to fail on Python <3.10.", # let's say it's a known issue, but not expected to be fixed, because too niche +) class TestCLI(unittest.TestCase): def test_dpo(self): + from trl.cli import main + with tempfile.TemporaryDirectory() as tmp_dir: # Create a temporary directory command = f"trl dpo --output_dir {tmp_dir} --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --dataset_name trl-internal-testing/zen --dataset_config standard_preference --report_to none" with patch("sys.argv", command.split(" ")): @@ -30,6 +36,8 @@ def test_dpo(self): @patch("sys.stdout", new_callable=StringIO) def test_env(self, mock_stdout): + from trl.cli import main + command = "trl env" with patch("sys.argv", command.split(" ")): main() @@ -42,12 +50,16 @@ def test_grpo(self): main() def test_kto(self): + from trl.cli import main + with tempfile.TemporaryDirectory() as tmp_dir: # Create a temporary directory command = f"trl kto --output_dir {tmp_dir} --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --dataset_name trl-internal-testing/zen --dataset_config standard_unpaired_preference --report_to none" with patch("sys.argv", command.split(" ")): main() def test_sft(self): + from trl.cli import main + with tempfile.TemporaryDirectory() as tmp_dir: # Create a temporary directory command = f"trl sft --output_dir {tmp_dir} --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --dataset_name trl-internal-testing/zen --dataset_config standard_language_modeling --report_to none" with patch("sys.argv", 
command.split(" ")): From 3d2c1e49b1407989beaeb494d0d92ac73c5637ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Mon, 20 Jan 2025 22:17:39 +0100 Subject: [PATCH 03/96] Fix merge error (#2595) --- tests/test_cli.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_cli.py b/tests/test_cli.py index 2a5b2ed2cd..234b4e7ba1 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -44,6 +44,8 @@ def test_env(self, mock_stdout): self.assertIn("TRL version: ", mock_stdout.getvalue().strip()) def test_grpo(self): + from trl.cli import main + with tempfile.TemporaryDirectory() as tmp_dir: # Create a temporary directory command = f"trl grpo --output_dir {tmp_dir} --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --reward_model_name_or_path trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5 --dataset_name trl-internal-testing/zen --dataset_config standard_prompt_only --num_generations 3 --max_completion_length 32 --report_to none" with patch("sys.argv", command.split(" ")): From d9f056862f2bd1514fad068f49f6bc05d8494d4a Mon Sep 17 00:00:00 2001 From: August Moharrami Date: Tue, 21 Jan 2025 09:32:31 +0330 Subject: [PATCH 04/96] =?UTF-8?q?=F0=9F=A7=B0=20Tool=20fine-tuning=20suppo?= =?UTF-8?q?rt=20DPO=20(#2479)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * adding tool fine-tuning support for DPO * precommit * adding test for DPOTrainer with tool usage * style * fix test * a comment --------- Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Co-authored-by: Quentin Gallouédec --- tests/test_dpo_trainer.py | 39 +++++++++++++++++++++++++++++++++++++- trl/trainer/dpo_config.py | 12 +++++++++++- trl/trainer/dpo_trainer.py | 4 +++- 3 files changed, 52 insertions(+), 3 deletions(-) diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index 6200ef8f91..c4a0232ee3 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -1152,11 +1152,48 @@ def test_dpo_trainer_use_num_logits_to_keep(self): trainer.train() - def test_padding_free(self): + def test_dpo_trainer_with_tools(self): model_id = "trl-internal-testing/tiny-LlamaForCausalLM-3.2" tokenizer = AutoTokenizer.from_pretrained(model_id) tokenizer.pad_token = tokenizer.eos_token + model = AutoModelForCausalLM.from_pretrained(model_id) + + # Define dummy test tools + def get_current_temperature(location: str): + """ + Gets the temperature at a given location. + + Args: + location: The location to get the temperature for + """ + return 22.0 + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + tools=[get_current_temperature], + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference") + + trainer = DPOTrainer( + model=model, + ref_model=None, + args=training_args, + processing_class=tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + # We don't run the training, but at this stage, the dataset is supposed to be pre-processed. When + # pre-processing, we expect the available tools to be explicitly mentioned in the system prompt. 
That's + # what we're checking here + self.assertIn("get_current_temperature", tokenizer.decode(trainer.train_dataset["prompt_input_ids"][0])) + + def test_padding_free(self): + model_id = "trl-internal-testing/tiny-LlamaForCausalLM-3.2" + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer.pad_token = tokenizer.eos_token # Normally, we need `attn_implementation="flash_attention_2"` to that the model returns correct logits. # Without it, the logits may be incorrect, but that's fine here. This test focuses only on the inner logic # of padding_free. diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py index 6f98b86b88..b7c18e11cc 100644 --- a/trl/trainer/dpo_config.py +++ b/trl/trainer/dpo_config.py @@ -15,7 +15,7 @@ import warnings from dataclasses import dataclass, field from enum import Enum -from typing import Any, Optional +from typing import Any, Callable, Optional, Union from transformers import TrainingArguments @@ -93,6 +93,9 @@ class DPOConfig(TrainingArguments): Batch size to use when precomputing reference model log probabilities. This can be set higher than the training batch size to speed up preprocessing. If `None`, defaults to `per_device_train_batch_size` for training and `per_device_eval_batch_size` for evaluation. + tools (`Optional[list[Union[dict, Callable]]]`, *optional*, defaults to `None`): + List of tools (callable functions) that will be accessible to the model. + If the template does not support function calling, this argument will have no effect. > Parameters that control the training @@ -261,6 +264,13 @@ class DPOConfig(TrainingArguments): "`per_device_train_batch_size` for training and `per_device_eval_batch_size` for evaluation." }, ) + tools: Optional[list[Union[dict, Callable]]] = field( + default=None, + metadata={ + "help": "List of tools (callable functions) that will be accessible to the model. If the template does " + "not support function calling, this argument will have no effect." 
+ }, + ) # Parameters that control the training learning_rate: float = field( diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 7609d55ad0..218f1af5a9 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -543,7 +543,9 @@ def _prepare_dataset( # Apply the chat template if needed if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset" - dataset = dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, **map_kwargs) + dataset = dataset.map( + maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class, "tools": args.tools}, **map_kwargs + ) # Tokenize the dataset if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` From b6a084c46e17ad3439557b2c75703671eb4db151 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Tue, 21 Jan 2025 15:12:04 +0100 Subject: [PATCH 05/96] =?UTF-8?q?=F0=9F=92=BE=20Reduce=20memory=20peak=20i?= =?UTF-8?q?n=20GRPO=20by=20adding=20`max=5Fprompt=5Flength`=20and=20loop?= =?UTF-8?q?=20usage=20in=20logp=20computation=20(#2598)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add max_prompt len to config * truncate prompt and compute log probs line by line --- trl/trainer/grpo_config.py | 8 ++++++++ trl/trainer/grpo_trainer.py | 28 ++++++++++++++++++++-------- 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index a0a61d3b16..f26e3f9c4a 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -39,6 +39,8 @@ class GRPOConfig(TrainingArguments): > Parameters that control the data preprocessing + max_prompt_length (`int` or `None`, *optional*, defaults to `512`): + Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left. num_generations (`int` or `None`, *optional*, defaults to `8`): Number of generations per prompt to sample. temperature (`float`, *optional*, defaults to `0.9`): @@ -65,6 +67,12 @@ class GRPOConfig(TrainingArguments): ) # Parameters that control the data preprocessing + max_prompt_length: Optional[int] = field( + default=512, + metadata={ + "help": "Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left." 
+ }, + ) num_generations: Optional[int] = field( default=8, metadata={"help": "Number of generations to sample."}, diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 0e830b6822..579f55ec38 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -138,6 +138,7 @@ def data_collator(features): # No data collation is needed in GRPO return features # Training arguments + self.max_prompt_length = args.max_prompt_length self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper self.num_generations = args.num_generations # = G in the GRPO paper self.generation_config = GenerationConfig( @@ -203,6 +204,10 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N ) prompt_inputs = super()._prepare_inputs(prompt_inputs) + if self.max_prompt_length is not None: + prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -self.max_prompt_length :] + prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -self.max_prompt_length :] + # Generate completions with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model: prompt_completion_ids = unwrapped_model.generate(**prompt_inputs, generation_config=self.generation_config) @@ -211,13 +216,20 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N # Get the per-token log probabilities for the completions for the model and the reference model def get_per_token_logps(model, input_ids): - logits = model(input_ids).logits - logits = torch.roll(logits, shifts=1, dims=1) # Shape (B*G, L) - per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=input_ids.unsqueeze(2)).squeeze(2) - return per_token_logps + logits = model(input_ids).logits # (B, L, V) + logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred + input_ids = input_ids[:, 1:] # (B, L-1), exclude the first input ID since we don't have logits for it + # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak. 
+ per_token_logps = [] + for logits_row, input_ids_row in zip(logits, input_ids): + log_probs = logits_row.log_softmax(dim=-1) + token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1) + per_token_logps.append(token_log_prob) + return torch.stack(per_token_logps) per_token_logps = get_per_token_logps(model, prompt_completion_ids) - per_token_logps = per_token_logps[:, prompt_length:] # get rid of the prompt + # Get rid of the prompt (-1 because of the shift done in get_per_token_logps) + per_token_logps = per_token_logps[:, prompt_length - 1 :] with torch.inference_mode(): if self.ref_model is not None: @@ -225,7 +237,7 @@ def get_per_token_logps(model, input_ids): else: with self.accelerator.unwrap_model(model).disable_adapter(): ref_per_token_logps = get_per_token_logps(model, prompt_completion_ids) - ref_per_token_logps = ref_per_token_logps[:, prompt_length:] # get rid of the prompt + ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1 :] # Compute the KL divergence between the model and the reference model per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 @@ -287,9 +299,9 @@ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> Non metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics logs = {**logs, **metrics} if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): - return super().log(logs, start_time) + super().log(logs, start_time) else: # transformers<=4.46 - return super().log(logs) + super().log(logs) self._metrics = {key: [] for key in self._metrics} def create_model_card( From a5c88d6c7508beb107219de7a656118ac4a36f1f Mon Sep 17 00:00:00 2001 From: Steven Liu <59462357+stevhliu@users.noreply.github.com> Date: Tue, 21 Jan 2025 13:09:18 -0800 Subject: [PATCH 06/96] =?UTF-8?q?=E2=9A=A1=20Add=20uv=20installation=20ins?= =?UTF-8?q?tructions=20(#2601)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add uv * Update docs/source/installation.mdx * Update docs/source/installation.mdx * pypi -> PyPI --------- Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Co-authored-by: Quentin Gallouédec --- docs/source/installation.mdx | 25 ++++++++++++++++++++----- setup.py | 6 +++--- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx index bf74b64175..8ab4165931 100644 --- a/docs/source/installation.mdx +++ b/docs/source/installation.mdx @@ -1,14 +1,29 @@ # Installation -You can install TRL either from pypi or from source: +You can install TRL either from PyPI or from source: -## pypi -Install the library with pip: +## PyPI +Install the library with pip or [uv](https://docs.astral.sh/uv/): + + + + +uv is a fast Rust-based Python package and project manager. Refer to [Installation](https://docs.astral.sh/uv/getting-started/installation/) for installation instructions), . + +```bash +uv pip install trl +``` + + + ```bash pip install trl ``` -### Source + + + +## Source You can also install the latest version from source. 
First clone the repo and then run the installation with `pip`: ```bash @@ -21,4 +36,4 @@ If you want the development install you can replace the pip install with the fol ```bash pip install -e ".[dev]" -``` \ No newline at end of file +``` diff --git a/setup.py b/setup.py index e438844a75..87c52071ee 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ Simple check list for release from AllenNLP repo: https://github.com/allenai/allennlp/blob/master/setup.py -To create the package for pypi. +To create the package for PyPI. 0. Prerequisites: - Dependencies: @@ -50,7 +50,7 @@ For the sources, run: "python setup.py sdist" You should now have a /dist directory with both .whl and .tar.gz source versions. -5. Check that everything looks correct by uploading the package to the pypi test server: +5. Check that everything looks correct by uploading the package to the PyPI test server: twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/ @@ -59,7 +59,7 @@ pip install -U tqdm pip install -i https://testpypi.python.org/pypi evaluate -6. Upload the final version to actual pypi: +6. Upload the final version to actual PyPI: twine upload dist/* -r pypi 7. Fill release notes in the tag in github once everything is looking hunky-dory. From d4222a1e08def2be56572eb2973ef3bf50143a4f Mon Sep 17 00:00:00 2001 From: Dawid Motyka Date: Tue, 21 Jan 2025 22:44:18 +0100 Subject: [PATCH 07/96] =?UTF-8?q?=F0=9F=A7=A9=20PPO/RLOO/OnlineDPO=20seque?= =?UTF-8?q?nce=20generation:=20make=20deepsped=203=20weight=20gathering=20?= =?UTF-8?q?optional=20(#2557)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * PPO/RLOO/OnlineDPO: add ds3_gather_for_generation argument to control weights gathering for generation * code formatting * rephrase and document * more doc * style [ci skip] * Trigger CI --------- Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Co-authored-by: Quentin Gallouédec --- docs/source/reducing_memory_usage.md | 38 ++++++++++++++++++++++++++++ trl/models/utils.py | 7 ++++- trl/trainer/online_dpo_config.py | 16 ++++++++++-- trl/trainer/online_dpo_trainer.py | 4 ++- trl/trainer/ppo_config.py | 12 +++++++++ trl/trainer/ppo_trainer.py | 8 ++++-- trl/trainer/rloo_config.py | 12 +++++++++ trl/trainer/rloo_trainer.py | 8 ++++-- 8 files changed, 97 insertions(+), 8 deletions(-) diff --git a/docs/source/reducing_memory_usage.md b/docs/source/reducing_memory_usage.md index e015c43906..6c05490616 100644 --- a/docs/source/reducing_memory_usage.md +++ b/docs/source/reducing_memory_usage.md @@ -93,3 +93,41 @@ training_args = SFTConfig(..., packing=True, max_seq_length=512) Packing may cause batch contamination, where adjacent sequences influence one another. This can be problematic for some applications. For more details, see [#1230](https://github.com/huggingface/trl/issues/1230). + +## Disabling model gathering for generation in online methods + +When using DeepSpeed ZeRO-3, model weights are sharded across multiple GPUs. Online methods involve generating completions from the model as part of the training process. During this step, the model weights are temporarily gathered on a single GPU for generation. For very large models, this gathering can lead to out-of-memory (OOM) errors, as described in this issue: [#2250](https://github.com/huggingface/trl/issues/2250#issue-2598304204). 
+ +If you encounter this issue, you can disable the gathering of model weights for generation by setting the following parameter: + + + + +```python +from trl import OnlineDPOConfig + +training_args = OnlineDPOConfig(..., ds3_gather_for_generation=False) +``` + + + + +```python +from trl import PPOConfig + +training_args = PPOConfig(..., ds3_gather_for_generation=False) +``` + + + + +```python +from trl import RLOOConfig + +training_args = RLOOConfig(..., ds3_gather_for_generation=False) +``` + + + + +This adjustment prevents model weights from being gathered, avoiding OOM errors, but it may result in slower generation speeds. diff --git a/trl/models/utils.py b/trl/models/utils.py index 1f4932d21b..a5b7b2c7bf 100644 --- a/trl/models/utils.py +++ b/trl/models/utils.py @@ -172,7 +172,10 @@ def add_hooks(model: "DeepSpeedEngine") -> None: @contextmanager def unwrap_model_for_generation( - model: Union["DistributedDataParallel", "DeepSpeedEngine"], accelerator: "Accelerator", is_peft_model: bool = False + model: Union["DistributedDataParallel", "DeepSpeedEngine"], + accelerator: "Accelerator", + is_peft_model: bool = False, + gather_deepspeed3_params: bool = True, ) -> Union["PreTrainedModelWrapper", "DeepSpeedEngine"]: """Context manager to unwrap a model for generation. For ZeRO-3 models, we gather the weights once to speed up generation. @@ -181,6 +184,8 @@ def unwrap_model_for_generation( if is_peft_model: unwrapped_model.pretrained_model.disable_adapter() if accelerator.state.deepspeed_plugin is not None and accelerator.state.deepspeed_plugin.zero_stage == 3: + if not gather_deepspeed3_params: + yield accelerator.unwrap_model(model) with deepspeed.zero.GatheredParameters(model.parameters()): remove_hooks(model) yield accelerator.unwrap_model(model) diff --git a/trl/trainer/online_dpo_config.py b/trl/trainer/online_dpo_config.py index d01294c2e5..8557d29628 100644 --- a/trl/trainer/online_dpo_config.py +++ b/trl/trainer/online_dpo_config.py @@ -64,6 +64,10 @@ class OnlineDPOConfig(TrainingArguments): Whether to disable dropout in the model and reference model. use_vllm (`bool`, *optional*, defaults to `False`): Whether to use the vLLM for generating completions. Requires vLLM to be installed (`pip install vllm`). + ds3_gather_for_generation (`bool`, *optional*, defaults to `True`): + This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation, + improving generation speed. However, disabling this option allows training models that exceed the VRAM + capacity of a single GPU, albeit at the cost of slower generation. """ learning_rate: float = field( @@ -114,8 +118,8 @@ class OnlineDPOConfig(TrainingArguments): metadata={ "help": "Parameter controlling the deviation from the reference model. Higher β means less deviation from " "the reference model. For the IPO loss (`loss_type='ipo'`), β is the regularization parameter denoted by " - "τ in the [paper](https://huggingface.co/papers/2310.12036). If a list of floats is provided then the β is " - "selected for each new epoch and the last β is used for the rest of the epochs." + "τ in the [paper](https://huggingface.co/papers/2310.12036). If a list of floats is provided then the β " + "is selected for each new epoch and the last β is used for the rest of the epochs." }, ) loss_type: str = field( @@ -140,6 +144,14 @@ class OnlineDPOConfig(TrainingArguments): "(`pip install vllm`)." 
}, ) + ds3_gather_for_generation: bool = field( + default=True, + metadata={ + "help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for " + "generation, improving generation speed. However, disabling this option allows training models that " + "exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation." + }, + ) def __post_init__(self): super().__post_init__() diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index 7c7a6b3169..44bc02c563 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -476,7 +476,9 @@ def _generate(self, model, prompts): inputs = self._prepare_inputs(inputs) prompt_ids = inputs["prompt_input_ids"].repeat(2, 1) prompt_mask = inputs["prompt_attention_mask"].repeat(2, 1) - with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model: + with unwrap_model_for_generation( + model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation + ) as unwrapped_model: output = unwrapped_model.generate( input_ids=prompt_ids, attention_mask=prompt_mask, diff --git a/trl/trainer/ppo_config.py b/trl/trainer/ppo_config.py index ecaa7192d5..0b0ec0c318 100644 --- a/trl/trainer/ppo_config.py +++ b/trl/trainer/ppo_config.py @@ -53,6 +53,10 @@ class PPOConfig(OnPolicyConfig): Discount factor. lam (`float`, *optional*, defaults to `0.95`): Lambda value for GAE. + ds3_gather_for_generation (`bool`, *optional*, defaults to `True`): + This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation, + improving generation speed. However, disabling this option allows training models that exceed the VRAM + capacity of a single GPU, albeit at the cost of slower generation. """ exp_name: str = field( @@ -103,3 +107,11 @@ class PPOConfig(OnPolicyConfig): default=0.95, metadata={"help": "Lambda value for GAE."}, ) + ds3_gather_for_generation: bool = field( + default=True, + metadata={ + "help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for " + "generation, improving generation speed. However, disabling this option allows training models that " + "exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation." 
+ }, + ) diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py index fe3ea3a147..ef29461a70 100644 --- a/trl/trainer/ppo_trainer.py +++ b/trl/trainer/ppo_trainer.py @@ -414,7 +414,9 @@ def repeat_generator(): scores = [] sequence_lengths = [] values = [] - with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model: + with unwrap_model_for_generation( + self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation + ) as unwrapped_model: query_responses, logitss = batch_generation( unwrapped_model.policy, queries, @@ -688,7 +690,9 @@ def generate_completions(self, sampling: bool = False): ) table = defaultdict(list) - with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model: + with unwrap_model_for_generation( + self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation + ) as unwrapped_model: for batch in self.eval_dataloader: query = batch["input_ids"] with torch.no_grad(): diff --git a/trl/trainer/rloo_config.py b/trl/trainer/rloo_config.py index a52407c171..bd0b6ed8ed 100644 --- a/trl/trainer/rloo_config.py +++ b/trl/trainer/rloo_config.py @@ -50,6 +50,10 @@ class RLOOConfig(OnPolicyConfig): Whether to normalize advantages. token_level_kl (`bool`, *optional*, defaults to `True`): Whether to use token-level KL penalty or sequence-level KL penalty. + ds3_gather_for_generation (`bool`, *optional*, defaults to `True`): + This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation, + improving generation speed. However, disabling this option allows training models that exceed the VRAM + capacity of a single GPU, albeit at the cost of slower generation. """ exp_name: str = field( @@ -96,3 +100,11 @@ class RLOOConfig(OnPolicyConfig): default=False, metadata={"help": "Whether to use token-level KL penalty or sequence-level KL penalty"}, ) + ds3_gather_for_generation: bool = field( + default=True, + metadata={ + "help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for " + "generation, improving generation speed. However, disabling this option allows training models that " + "exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation." 
+ }, + ) diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 321ba164be..03b7ef922d 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -310,7 +310,9 @@ def repeat_generator(): sequence_lengths = [] # Generate responses and compute logprobs - with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model: + with unwrap_model_for_generation( + self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation + ) as unwrapped_model: query_responses, logitss = batch_generation( unwrapped_model, queries, @@ -565,7 +567,9 @@ def generate_completions(self, sampling: bool = False): ) table = defaultdict(list) - with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model: + with unwrap_model_for_generation( + self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation + ) as unwrapped_model: for batch in self.eval_dataloader: query = batch["input_ids"] with torch.no_grad(): From a9b54a852ee12ff508773edb02e1c243817e71ae Mon Sep 17 00:00:00 2001 From: Dawid Motyka Date: Wed, 22 Jan 2025 12:24:42 +0100 Subject: [PATCH 08/96] =?UTF-8?q?=F0=9F=AB=B7=20Include=20stop=20token=20i?= =?UTF-8?q?n=20policy=20model's=20generation=5Fconfig=20(#2528)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Include stop token in policy model's generation_config * Fix formatting * Update trl/trainer/ppo_trainer.py * Update trl/trainer/ppo_trainer.py * don't modify args * clarify doc * more nice doc * missing no [ci skip] * really don't modify args * oups --------- Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Co-authored-by: Quentin Gallouédec --- trl/trainer/ppo_trainer.py | 26 ++++++++++++++++---------- trl/trainer/utils.py | 20 ++++++++++++++++---- 2 files changed, 32 insertions(+), 14 deletions(-) diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py index ef29461a70..83926cfd6a 100644 --- a/trl/trainer/ppo_trainer.py +++ b/trl/trainer/ppo_trainer.py @@ -138,10 +138,18 @@ def __init__( if data_collator is None: data_collator = DataCollatorWithPadding(self.processing_class) - self.policy_model.generation_config.eos_token_id = ( - None # disable `pad_token_id` and `eos_token_id` because we just want to - ) - self.policy_model.generation_config.pad_token_id = None # generate tokens without truncation / padding + # Handle stop token settings: update policy model's generation_config to use provided stop token + if args.stop_token and args.stop_token_id: + raise ValueError("You cannot set both `stop_token` and `stop_token_id`.") + elif args.stop_token: + if args.stop_token == "eos": + self.policy_model.generation_config.eos_token_id = self.stop_token_id = processing_class.eos_token_id + else: + raise ValueError( + f"Unknown `stop_token` {args.stop_token}. Allowed values are: `'eos'` and `None` (no stop token)." 
+ ) + else: + self.policy_model.generation_config.eos_token_id = self.stop_token_id = args.stop_token_id # None or int # peft support if not is_peft_available() and peft_config is not None: @@ -220,8 +228,6 @@ def __init__( for module in [self.policy_model, self.ref_model, self.value_model, self.reward_model]: if module is not None: disable_dropout_in_model(module) - if args.stop_token and args.stop_token == "eos": - args.stop_token_id = processing_class.eos_token_id self.model = PolicyAndValueWrapper(self.policy_model, self.value_model) self.model.config = self.policy_model.config # needed for pushing to hub self.create_optimizer_and_scheduler( @@ -449,9 +455,9 @@ def repeat_generator(): # Response Processing 1. truncate response after the first occurrence of `stop_token_id` postprocessed_response = response - if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 + if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 postprocessed_response = truncate_response( - args.stop_token_id, processing_class.pad_token_id, response + self.stop_token_id, processing_class.pad_token_id, response ) # Response Processing 2. run reward model on the truncated responses @@ -706,9 +712,9 @@ def generate_completions(self, sampling: bool = False): ) response = query_response[:, context_length:] postprocessed_response = response - if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 + if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 postprocessed_response = truncate_response( - args.stop_token_id, processing_class.pad_token_id, response + self.stop_token_id, processing_class.pad_token_id, response ) table["query"].extend( gather_object(processing_class.batch_decode(query, skip_special_tokens=True)) diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 1228dc7ece..719d952f1f 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -993,9 +993,15 @@ class OnPolicyConfig(TrainingArguments): response_length (`int`, *optional*, defaults to `53`): Length of the response. stop_token (`str` or `None`, *optional*, defaults to `None`): - Stop token. + Specifies the stop token to use for text generation. This parameter is mutually exclusive with + `stop_token_id`. + + - `None`: No stop token is applied, unless `stop_token_id` is specified. + - `'eos'`: Uses the tokenizer's `eos_token`. + stop_token_id (`int` or `None`, *optional*, defaults to `None`): - Truncation token id. + Specifies the ID of the stop token to use for text generation. If `None`, no stop token ID is applied, + unless `stop_token` is specified. This parameter is mutually exclusive with `stop_token`. temperature (`float`, *optional*, defaults to `0.7`): Sampling temperature. missing_eos_penalty (`float` or `None`, *optional*, defaults to `None`): @@ -1054,11 +1060,17 @@ class OnPolicyConfig(TrainingArguments): ) stop_token: Optional[Literal["eos"]] = field( default=None, - metadata={"help": "Stop token."}, + metadata={ + "help": "Specifies the stop token to use for text generation. This parameter is mutually exclusive with " + "`stop_token_id`." + }, ) stop_token_id: Optional[int] = field( default=None, - metadata={"help": "Truncation token id."}, + metadata={ + "help": "Specifies the ID of the stop token to use for text generation. If `None`, no stop token ID is " + "applied, unless `stop_token` is specified. This parameter is mutually exclusive with `stop_token`." 
+ }, ) temperature: float = field( default=0.7, From fe4b5efe4e23f4331ba9c5b0c8bd92dc8302c287 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Wed, 22 Jan 2025 15:33:50 +0100 Subject: [PATCH 09/96] =?UTF-8?q?=E2=9C=82=EF=B8=8F=20Reintroduce=20`trunc?= =?UTF-8?q?ation=5Fmode`=20in=20`DPOTrainer`=20(#2551)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * reintroduce truncation mode in DPOTrainer * move truncation_mode in dataset.map invocation * truncate full sequence * "." [ci skip] * Empty commit --------- Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Co-authored-by: Quentin Gallouédec --- trl/trainer/dpo_config.py | 20 +++++++++++--------- trl/trainer/dpo_trainer.py | 26 +++++++++++++++++++------- 2 files changed, 30 insertions(+), 16 deletions(-) diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py index b7c18e11cc..a3cdc28d28 100644 --- a/trl/trainer/dpo_config.py +++ b/trl/trainer/dpo_config.py @@ -71,14 +71,15 @@ class DPOConfig(TrainingArguments): Padding value to use. If `None`, the padding value of the tokenizer is used. label_pad_token_id (`int`, *optional*, defaults to `-100`): Padding value to use for labels. - truncation_mode (`str`, *optional*, defaults to `"keep_end"`): - Truncation mode to usewhen the prompt is too long, either `keep_end` or `keep_start`. max_prompt_length (`int` or `None`, *optional*, defaults to `512`): Maximum length of the prompt. max_completion_length (`int` or `None`, *optional*, defaults to `None`): Maximum length of the completion. max_length (`int` or `None`, *optional*, defaults to `1024`): Maximum length of the full sequence (prompt + completion). + truncation_mode (`str`, *optional*, defaults to `"keep_end"`): + Truncation mode to use when the sequence exceeds `max_length`. Possible values are `"keep_end"` and + `"keep_start"`. padding_free (`bool`, *optional*, defaults to `False`): Whether forward passes are performed without padding by flattening all sequences in the batch into a single continuous sequence. This approach requires associating a `position_ids` vector to track @@ -219,13 +220,6 @@ class DPOConfig(TrainingArguments): default=-100, metadata={"help": "Padding value to use for labels."}, ) - truncation_mode: str = field( - default="keep_end", - metadata={ - "help": "Truncation mode to use when the prompt is too long.", - "choices": ["keep_end", "keep_start"], - }, - ) max_prompt_length: Optional[int] = field( default=512, metadata={"help": "Maximum length of the prompt."}, @@ -238,6 +232,14 @@ class DPOConfig(TrainingArguments): default=1024, metadata={"help": "Maximum length of the full sequence (prompt + completion)."}, ) + truncation_mode: str = field( + default="keep_end", + metadata={ + "help": "Truncation mode to use when the sequence exceeds `max_length`. 
Possible values are `'keep_end'` " + "and `'keep_start'`.", + "choices": ["keep_end", "keep_start"], + }, + ) padding_free: bool = field( default=False, metadata={ diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 218f1af5a9..903bb719ca 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -388,12 +388,12 @@ def make_inputs_require_grad(module, input, output): if self.ref_model is not None: disable_dropout_in_model(self.ref_model) - self.max_length = args.max_length self.generate_during_eval = args.generate_during_eval self.label_pad_token_id = args.label_pad_token_id self.max_prompt_length = args.max_prompt_length - self.truncation_mode = args.truncation_mode self.max_completion_length = args.max_completion_length + self.max_length = args.max_length + self.truncation_mode = args.truncation_mode self.precompute_ref_log_probs = args.precompute_ref_log_probs self.use_num_logits_to_keep = args.use_num_logits_to_keep @@ -595,7 +595,9 @@ def tokenize_row(features, processing_class, max_prompt_length, max_completion_l >>> from transformers import GPT2Tokenizer >>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2") >>> features = {"prompt": "The sky is", "chosen": " blue", "rejected": " green"} - >>> DPOTrainer.tokenize_row(features, tokenizer, max_prompt_length=3, max_completion_length=3, add_special_tokens=False) + >>> DPOTrainer.tokenize_row( + ... features, tokenizer, max_prompt_length=3, max_completion_length=3, add_special_tokens=False + ... ) {'prompt_input_ids': [464, 6766, 318], 'chosen_input_ids': [4171, 50256], 'rejected_input_ids': [4077, 50256]} ``` """ @@ -1145,10 +1147,20 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) # Truncate right - if self.args.max_length is not None: - input_ids = input_ids[:, : self.args.max_length] - attention_mask = attention_mask[:, : self.args.max_length] - loss_mask = loss_mask[:, : self.args.max_length] + if self.max_length is not None: + if self.truncation_mode == "keep_end": + input_ids = input_ids[:, -self.max_length :] + attention_mask = attention_mask[:, -self.max_length :] + loss_mask = loss_mask[:, -self.max_length :] + elif self.truncation_mode == "keep_start": + input_ids = input_ids[:, : self.max_length] + attention_mask = attention_mask[:, : self.max_length] + loss_mask = loss_mask[:, : self.max_length] + else: + raise ValueError( + f"Unknown truncation mode: '{self.truncation_mode}'. Should be one of ['keep_end', " + "'keep_start']." 
+ ) if self.use_num_logits_to_keep: # Compute num_logits_to_keep based on loss_mask pattern: From 949db2357e62d2f0a34decfc5e87eeeea0c6d72c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Thu, 23 Jan 2025 13:38:15 +0100 Subject: [PATCH 10/96] =?UTF-8?q?=F0=9F=91=8B=20Drop=20MDX=20(#2611)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/source/{alignprop_trainer.mdx => alignprop_trainer.md} | 0 docs/source/{bco_trainer.mdx => bco_trainer.md} | 0 docs/source/{best_of_n.mdx => best_of_n.md} | 0 docs/source/{callbacks.mdx => callbacks.md} | 0 docs/source/{clis.mdx => clis.md} | 0 docs/source/{cpo_trainer.mdx => cpo_trainer.md} | 0 docs/source/{customization.mdx => customization.md} | 0 docs/source/{data_utils.mdx => data_utils.md} | 0 docs/source/{dataset_formats.mdx => dataset_formats.md} | 0 docs/source/{ddpo_trainer.mdx => ddpo_trainer.md} | 0 docs/source/{detoxifying_a_lm.mdx => detoxifying_a_lm.md} | 0 docs/source/{dpo_trainer.mdx => dpo_trainer.md} | 0 docs/source/{index.mdx => index.md} | 0 docs/source/{installation.mdx => installation.md} | 0 .../{iterative_sft_trainer.mdx => iterative_sft_trainer.md} | 0 docs/source/{judges.mdx => judges.md} | 0 docs/source/{kto_trainer.mdx => kto_trainer.md} | 0 docs/source/{learning_tools.mdx => learning_tools.md} | 0 docs/source/{logging.mdx => logging.md} | 0 docs/source/{models.mdx => models.md} | 0 docs/source/{multi_adapter_rl.mdx => multi_adapter_rl.md} | 0 docs/source/{prm_trainer.mdx => prm_trainer.md} | 0 docs/source/{quickstart.mdx => quickstart.md} | 0 docs/source/{reward_trainer.mdx => reward_trainer.md} | 0 docs/source/{sentiment_tuning.mdx => sentiment_tuning.md} | 0 docs/source/{sft_trainer.mdx => sft_trainer.md} | 0 docs/source/{using_llama_models.mdx => using_llama_models.md} | 0 docs/source/{xpo_trainer.mdx => xpo_trainer.md} | 0 28 files changed, 0 insertions(+), 0 deletions(-) rename docs/source/{alignprop_trainer.mdx => alignprop_trainer.md} (100%) rename docs/source/{bco_trainer.mdx => bco_trainer.md} (100%) rename docs/source/{best_of_n.mdx => best_of_n.md} (100%) rename docs/source/{callbacks.mdx => callbacks.md} (100%) rename docs/source/{clis.mdx => clis.md} (100%) rename docs/source/{cpo_trainer.mdx => cpo_trainer.md} (100%) rename docs/source/{customization.mdx => customization.md} (100%) rename docs/source/{data_utils.mdx => data_utils.md} (100%) rename docs/source/{dataset_formats.mdx => dataset_formats.md} (100%) rename docs/source/{ddpo_trainer.mdx => ddpo_trainer.md} (100%) rename docs/source/{detoxifying_a_lm.mdx => detoxifying_a_lm.md} (100%) rename docs/source/{dpo_trainer.mdx => dpo_trainer.md} (100%) rename docs/source/{index.mdx => index.md} (100%) rename docs/source/{installation.mdx => installation.md} (100%) rename docs/source/{iterative_sft_trainer.mdx => iterative_sft_trainer.md} (100%) rename docs/source/{judges.mdx => judges.md} (100%) rename docs/source/{kto_trainer.mdx => kto_trainer.md} (100%) rename docs/source/{learning_tools.mdx => learning_tools.md} (100%) rename docs/source/{logging.mdx => logging.md} (100%) rename docs/source/{models.mdx => models.md} (100%) rename docs/source/{multi_adapter_rl.mdx => multi_adapter_rl.md} (100%) rename docs/source/{prm_trainer.mdx => prm_trainer.md} (100%) rename docs/source/{quickstart.mdx => quickstart.md} (100%) rename docs/source/{reward_trainer.mdx => reward_trainer.md} (100%) rename docs/source/{sentiment_tuning.mdx => 
sentiment_tuning.md} (100%) rename docs/source/{sft_trainer.mdx => sft_trainer.md} (100%) rename docs/source/{using_llama_models.mdx => using_llama_models.md} (100%) rename docs/source/{xpo_trainer.mdx => xpo_trainer.md} (100%) diff --git a/docs/source/alignprop_trainer.mdx b/docs/source/alignprop_trainer.md similarity index 100% rename from docs/source/alignprop_trainer.mdx rename to docs/source/alignprop_trainer.md diff --git a/docs/source/bco_trainer.mdx b/docs/source/bco_trainer.md similarity index 100% rename from docs/source/bco_trainer.mdx rename to docs/source/bco_trainer.md diff --git a/docs/source/best_of_n.mdx b/docs/source/best_of_n.md similarity index 100% rename from docs/source/best_of_n.mdx rename to docs/source/best_of_n.md diff --git a/docs/source/callbacks.mdx b/docs/source/callbacks.md similarity index 100% rename from docs/source/callbacks.mdx rename to docs/source/callbacks.md diff --git a/docs/source/clis.mdx b/docs/source/clis.md similarity index 100% rename from docs/source/clis.mdx rename to docs/source/clis.md diff --git a/docs/source/cpo_trainer.mdx b/docs/source/cpo_trainer.md similarity index 100% rename from docs/source/cpo_trainer.mdx rename to docs/source/cpo_trainer.md diff --git a/docs/source/customization.mdx b/docs/source/customization.md similarity index 100% rename from docs/source/customization.mdx rename to docs/source/customization.md diff --git a/docs/source/data_utils.mdx b/docs/source/data_utils.md similarity index 100% rename from docs/source/data_utils.mdx rename to docs/source/data_utils.md diff --git a/docs/source/dataset_formats.mdx b/docs/source/dataset_formats.md similarity index 100% rename from docs/source/dataset_formats.mdx rename to docs/source/dataset_formats.md diff --git a/docs/source/ddpo_trainer.mdx b/docs/source/ddpo_trainer.md similarity index 100% rename from docs/source/ddpo_trainer.mdx rename to docs/source/ddpo_trainer.md diff --git a/docs/source/detoxifying_a_lm.mdx b/docs/source/detoxifying_a_lm.md similarity index 100% rename from docs/source/detoxifying_a_lm.mdx rename to docs/source/detoxifying_a_lm.md diff --git a/docs/source/dpo_trainer.mdx b/docs/source/dpo_trainer.md similarity index 100% rename from docs/source/dpo_trainer.mdx rename to docs/source/dpo_trainer.md diff --git a/docs/source/index.mdx b/docs/source/index.md similarity index 100% rename from docs/source/index.mdx rename to docs/source/index.md diff --git a/docs/source/installation.mdx b/docs/source/installation.md similarity index 100% rename from docs/source/installation.mdx rename to docs/source/installation.md diff --git a/docs/source/iterative_sft_trainer.mdx b/docs/source/iterative_sft_trainer.md similarity index 100% rename from docs/source/iterative_sft_trainer.mdx rename to docs/source/iterative_sft_trainer.md diff --git a/docs/source/judges.mdx b/docs/source/judges.md similarity index 100% rename from docs/source/judges.mdx rename to docs/source/judges.md diff --git a/docs/source/kto_trainer.mdx b/docs/source/kto_trainer.md similarity index 100% rename from docs/source/kto_trainer.mdx rename to docs/source/kto_trainer.md diff --git a/docs/source/learning_tools.mdx b/docs/source/learning_tools.md similarity index 100% rename from docs/source/learning_tools.mdx rename to docs/source/learning_tools.md diff --git a/docs/source/logging.mdx b/docs/source/logging.md similarity index 100% rename from docs/source/logging.mdx rename to docs/source/logging.md diff --git a/docs/source/models.mdx b/docs/source/models.md similarity index 100% rename from 
docs/source/models.mdx rename to docs/source/models.md diff --git a/docs/source/multi_adapter_rl.mdx b/docs/source/multi_adapter_rl.md similarity index 100% rename from docs/source/multi_adapter_rl.mdx rename to docs/source/multi_adapter_rl.md diff --git a/docs/source/prm_trainer.mdx b/docs/source/prm_trainer.md similarity index 100% rename from docs/source/prm_trainer.mdx rename to docs/source/prm_trainer.md diff --git a/docs/source/quickstart.mdx b/docs/source/quickstart.md similarity index 100% rename from docs/source/quickstart.mdx rename to docs/source/quickstart.md diff --git a/docs/source/reward_trainer.mdx b/docs/source/reward_trainer.md similarity index 100% rename from docs/source/reward_trainer.mdx rename to docs/source/reward_trainer.md diff --git a/docs/source/sentiment_tuning.mdx b/docs/source/sentiment_tuning.md similarity index 100% rename from docs/source/sentiment_tuning.mdx rename to docs/source/sentiment_tuning.md diff --git a/docs/source/sft_trainer.mdx b/docs/source/sft_trainer.md similarity index 100% rename from docs/source/sft_trainer.mdx rename to docs/source/sft_trainer.md diff --git a/docs/source/using_llama_models.mdx b/docs/source/using_llama_models.md similarity index 100% rename from docs/source/using_llama_models.mdx rename to docs/source/using_llama_models.md diff --git a/docs/source/xpo_trainer.mdx b/docs/source/xpo_trainer.md similarity index 100% rename from docs/source/xpo_trainer.mdx rename to docs/source/xpo_trainer.md From 887c1f3fa3b9942d1f6089e7a5314d43a339d6b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Thu, 23 Jan 2025 17:30:22 +0100 Subject: [PATCH 11/96] =?UTF-8?q?=F0=9F=92=8E=20Rename=20an=20inner=20var?= =?UTF-8?q?=20in=20GRPO=20to=20improve=20clarity=20(#2616)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * rename advatages to per_token_loss for clarity * doc ci --- .github/workflows/build_pr_documentation.yml | 2 +- trl/trainer/grpo_trainer.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/build_pr_documentation.yml b/.github/workflows/build_pr_documentation.yml index bf72dc7c1e..acc8d16d35 100644 --- a/.github/workflows/build_pr_documentation.yml +++ b/.github/workflows/build_pr_documentation.yml @@ -9,7 +9,7 @@ concurrency: jobs: build: - uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@e4fcf608695cf4bddb8c7f4f72aa15fa14110a94 + uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main with: commit_sha: ${{ github.event.pull_request.head.sha }} pr_number: ${{ github.event.number }} diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 579f55ec38..683df410a5 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -281,8 +281,8 @@ def get_per_token_logps(model, input_ids): advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4) # x - x.detach() allows for preserving gradients from x - advantages = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) - per_token_loss = -(advantages - self.beta * per_token_kl) + per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) + per_token_loss = -(per_token_loss - self.beta * per_token_kl) loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() # Log the metrics From a1d295511675963d2031014484f2445b152d9452 Mon Sep 17 
00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Thu, 23 Jan 2025 17:39:45 +0100 Subject: [PATCH 12/96] =?UTF-8?q?=F0=9F=8F=86=20Custom=20reward=20function?= =?UTF-8?q?=20for=20GRPO=20and=20shiny=20doc=20(#2606)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * initial commit * doc on custom reward function * test * doc doc doc * fix collator * style * links? * I need a docdoc 🎵 * fix link * I do like writing doc tbh * it takes time, but it's worth it * no return! * type hint * it's probably the best of both worlds [ci skip] * new doc before implementation * tests * more doc * style * multiple pretrained funcs * fix arg name * main? * example for R1 * fix script * clearer * import [ci skip] * Update docs/source/grpo_trainer.md Co-authored-by: lewtun --------- Co-authored-by: lewtun --- docs/source/grpo_trainer.md | 93 +++++++++++++++++ tests/test_grpo_trainer.py | 155 +++++++++++++++++++++++++++- trl/scripts/grpo.py | 2 +- trl/trainer/grpo_trainer.py | 199 +++++++++++++++++++++++++++--------- 4 files changed, 392 insertions(+), 57 deletions(-) diff --git a/docs/source/grpo_trainer.md b/docs/source/grpo_trainer.md index 59abe04356..72a36d2be1 100644 --- a/docs/source/grpo_trainer.md +++ b/docs/source/grpo_trainer.md @@ -114,6 +114,99 @@ The GRPO Trainer logs the following metrics: - `reward_std` : The average standard deviation within reward groups. - `kl` : The average KL divergence between the model and the reference model calculated on completions. +## Customization + +### Using a custom reward function + +The [`GRPOTrainer`] supports using custom reward functions instead of dense reward models. To ensure compatibility, your reward function must satisfy the following requirements: + +1. **Input arguments**: + - The function must accept two arguments: `prompts` and `completions`. + - Depending on the dataset format, the input will vary: + - For [standard format](dataset_formats#standard), `prompts` and `completions` will be lists of strings. + - For [conversational format](dataset_formats#conversational), `prompts` and `completions` will be lists of message dictionaries. + +2. **Return value**: The function must return a list of floats. Each float represents the reward corresponding to a single completion. + +#### Example 1: Reward longer completions + +Below is an example of a reward function for a standard format that rewards longer completions: + +```python +def reward_func(prompts, completions): + """Reward function that gives higher scores to longer completions.""" + return [float(len(completion)) for completion in completions] +``` + +You can test it as follows: + +```python +>>> prompts = ["The sky is", "The sun is"] +>>> completions = [" blue.", " in the sky."] +>>> print(reward_func(prompts, completions)) +[6.0, 12.0] +``` + +#### Example 2: Reward completions with specific format + +Below is an example of a reward function that checks if the completion has a specific format. This example is inspired by the reward function used in the paper [DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning](https://huggingface.co/papers/2501.12948). +It is designed for conversational format, where prompts and completions consist of structured messages. 
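Since the example below indexes `completion[0]["content"]`, it may help to recall the shape of a single conversational completion; a minimal illustration with invented values, not taken from the patch:

```python
# Illustrative only: in conversational format, each completion handed to a reward function
# is a list of message dicts, so the generated text lives under the "content" key.
completion = [{"role": "assistant", "content": "The result is 12."}]
print(completion[0]["content"])  # -> "The result is 12."
```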
+ +```python +import re + +def format_reward_func(prompts, completions): + """Reward function that checks if the completion has a specific format.""" + pattern = r"^.*?.*?$" + completion_contents = [completion[0]["content"] for completion in completions] + matches = [re.match(pattern, content) for content in completion_contents] + return [1.0 if match else 0.0 for match in matches] +``` + +You can test this function as follows: + +```python +>>> prompts = [ +... [{"role": "assistant", "content": "What is the result of (1 + 2) * 4?"}], +... [{"role": "assistant", "content": "What is the result of (3 + 1) * 2?"}], +... ] +>>> completions = [ +... [{"role": "assistant", "content": "The sum of 1 and 2 is 3, which we multiply by 4 to get 12.(1 + 2) * 4 = 12"}], +... [{"role": "assistant", "content": "The sum of 3 and 1 is 4, which we multiply by 2 to get 8. So (3 + 1) * 2 = 8."}], +... ] +>>> format_reward_func(prompts, completions) +[1.0, 0.0] +>>> +``` + +#### Passing the reward function to the trainer + +To use your custom reward function, pass it to the `GRPOTrainer` as follows: + +```python +from trl import GRPOTrainer + +trainer = GRPOTrainer( + reward_funcs=reward_func, + ..., +) +``` + +If you have multiple reward functions, you can pass them as a list: + +```python +from trl import GRPOTrainer + +trainer = GRPOTrainer( + reward_funcs=[reward_func1, reward_func2], + ..., +) +``` + +and the reward will be computed as the sum of the rewards from each function. + +Note that [`GRPOTrainer`] supports multiple reward functions of different types. See the parameters documentation for more details. + ## GRPOTrainer [[autodoc]] GRPOTrainer diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index 1afd490bd1..8eaa6a566c 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -35,7 +35,7 @@ def test_init_minimal(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") GRPOTrainer( model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", - reward_model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", train_dataset=dataset, ) @@ -54,7 +54,7 @@ def test_training(self, config_name): ) trainer = GRPOTrainer( model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", - reward_model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", args=training_args, train_dataset=dataset, ) @@ -87,7 +87,7 @@ def test_training_peft(self): ) trainer = GRPOTrainer( model=model, - reward_model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", args=training_args, train_dataset=dataset, peft_config=LoraConfig(), @@ -130,10 +130,155 @@ def test_training_different_reward_model(self): ) trainer = GRPOTrainer( model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", - reward_model=reward_model, + reward_funcs=reward_model, + args=training_args, + train_dataset=dataset, + reward_processing_classes=reward_tokenizer, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, 
new_param), f"Parameter {n} has not changed.") + + def test_training_reward_func_standard(self): + # Test if trainer can handle reward function with standard format + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + def reward_func(prompts, completions): + """Reward function that rewards longer completions.""" + return [float(len(completion)) for completion in completions] + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=2, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=32, # reduce the completion length to reduce memory usage + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=reward_func, + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + def test_training_reward_func_conversational(self): + # Test if trainer can handle reward function with conversational format + dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train") + + def reward_func(prompts, completions): + """Reward function that gives higher scores to longer completion content.""" + completion_contents = [completion[0]["content"] for completion in completions] + return [float(len(content)) for content in completion_contents] + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=2, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=32, # reduce the completion length to reduce memory usage + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=reward_func, + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + def test_training_multiple_reward_funcs(self): + # Test that GRPOTrainer can be instantiated with multiple reward functions + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + def reward_func1(prompts, completions): + """Reward function that rewards longer completions.""" + return [float(len(completion)) for completion in completions] + + def reward_func2(prompts, completions): + """Reward function that rewards completions with more unique letters.""" + return [float(len(set(completion))) for completion in completions] + + with tempfile.TemporaryDirectory() as 
tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=2, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=32, # reduce the completion length to reduce memory usage + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=[reward_func1, reward_func2], + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + def test_training_multiple_mixed_reward_funcs(self): + # Test if the trainer can handle a mix of reward functions and reward models + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + def reward_func(prompts, completions): + """Reward function that rewards longer completions.""" + return [float(len(completion)) for completion in completions] + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=2, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=32, # reduce the completion length to reduce memory usage + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=[reward_func, "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5"], args=training_args, train_dataset=dataset, - reward_processing_class=reward_tokenizer, ) previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} diff --git a/trl/scripts/grpo.py b/trl/scripts/grpo.py index 552f6c3a4c..4b336b28e9 100644 --- a/trl/scripts/grpo.py +++ b/trl/scripts/grpo.py @@ -60,7 +60,7 @@ def main(script_args, training_args, model_args): # Initialize the GRPO trainer trainer = GRPOTrainer( model=model, - reward_model=reward_model, + reward_funcs=reward_model, args=training_args, train_dataset=dataset[script_args.dataset_train_split], eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 683df410a5..d787178e8f 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -17,7 +17,6 @@ from typing import Any, Callable, Optional, Union import torch -import torch.nn as nn import torch.utils.data import transformers from datasets import Dataset, IterableDataset @@ -26,8 +25,6 @@ AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, - DataCollator, - EvalPrediction, GenerationConfig, PreTrainedModel, PreTrainedTokenizerBase, @@ -49,24 +46,105 @@ if is_wandb_available(): import wandb +# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of +# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model. 
+RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]] + class GRPOTrainer(Trainer): + """ + Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the + paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300). + + Example: + + ```python + from datasets import load_dataset + from trl import GRPOTrainer + + dataset = load_dataset("trl-lib/tldr", split="train") + + trainer = GRPOTrainer( + model="Qwen/Qwen2-0.5B-Instruct", + reward_funcs="weqweasdas/RM-Gemma-2B", + train_dataset=dataset, + ) + + trainer.train() + ``` + + Args: + model (`Union[str, PreTrainedModel]`): + Model to be trained. Can be either: + + - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or + a path to a *directory* containing model weights saved using + [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is + loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keywork arguments + in `args.model_init_kwargs`. + - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported. + reward_funcs (`Union[RewardFunc, list[RewardFunc]]`): + Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward + functions with the prompts and completions and sum the rewards. Can be either: + + - A single reward function, such as: + - A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a + path to a *directory* containing model weights saved using + [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded + using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the + keyword arguments in `args.model_init_kwargs`. + - A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported. + - A custom reward function: This should take a list of prompts and completions and return a list of + rewards. For more details, see [Using a custom reward function](#using-a-custom-reward-function). + - A list of reward functions, where each item can independently be any of the above types. Mixing different + types within the list (e.g., a string model ID and a custom reward function) is allowed. + args ([`GRPOConfig`], *optional*, defaults to `None`): + Configuration for this trainer. If `None`, a default configuration is used. + train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): + Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset is + ignored. The format of the samples can be either: + + - [Standard](dataset_formats#standard): Each sample contains plain text. + - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role + and content). + eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`): + Dataset to use for evaluation. It must meet the same requirements as `train_dataset`. + processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`): + Processing class used to process the data. The padding side must be set to "left". If `None`, the + processing class is loaded from the model's name with [`~transformers.AutoTokenizer.from_pretrained`]. 
+ reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*, defaults to `None`): + Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either: + + - A single processing class: Used when `reward_funcs` contains only one reward function. + - A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`. + If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is + `None`, the tokenizer for the model is automatically loaded using [`~transformers.AutoTokenizer.from_pretrained`]. + For elements in `reward_funcs` that are custom reward functions (not [`~transformers.PreTrainedModel`]), + the corresponding entries in `reward_processing_classes` are ignored. + callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`): + List of callbacks to customize the training loop. Will add those to the list of default callbacks + detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback). + + If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`] + method. + optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`): + A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your + model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`. + peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`): + PEFT configuration used to wrap the model. If `None`, the model is not wrapped. + """ + def __init__( self, - model: Union[str, PreTrainedModel, nn.Module] = None, - reward_model: Optional[Union[PreTrainedModel, nn.Module]] = None, + model: Union[str, PreTrainedModel], + reward_funcs: Union[RewardFunc, list[RewardFunc]], args: GRPOConfig = None, - data_collator: Optional[DataCollator] = None, train_dataset: Optional[Union[Dataset, IterableDataset]] = None, eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None, processing_class: Optional[PreTrainedTokenizerBase] = None, - reward_processing_class: Optional[PreTrainedTokenizerBase] = None, - model_init: Optional[Callable[[], PreTrainedModel]] = None, - compute_loss_func: Optional[Callable] = None, - compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None, + reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None, callbacks: Optional[list[TrainerCallback]] = None, optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), - preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, peft_config: Optional["PeftConfig"] = None, ): # Args @@ -114,28 +192,40 @@ def __init__( if processing_class is None: processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left") - # Reward model - if isinstance(reward_model, str): - reward_model = AutoModelForSequenceClassification.from_pretrained( - reward_model, num_labels=1, **model_init_kwargs - ) - self.reward_model = reward_model + # Reward functions + if not isinstance(reward_funcs, list): + reward_funcs = [reward_funcs] + for i, reward_func in enumerate(reward_funcs): + if isinstance(reward_func, str): + reward_funcs[i] = 
AutoModelForSequenceClassification.from_pretrained( + reward_func, num_labels=1, **model_init_kwargs + ) + self.reward_funcs = reward_funcs # Reward processing class - if reward_processing_class is None: - reward_processing_class = AutoTokenizer.from_pretrained(reward_model.config._name_or_path) - if reward_processing_class.pad_token_id is None: - reward_processing_class.pad_token = reward_processing_class.eos_token - self.reward_processing_class = reward_processing_class - # The reward model computes the reward for the latest non-padded token in the input sequence. - # So it's important to set the pad token ID to the padding token ID of the processing class. - self.reward_model.config.pad_token_id = reward_processing_class.pad_token_id - - # Data loading and preprocessing - if data_collator is None: - - def data_collator(features): # No data collation is needed in GRPO - return features + if reward_processing_classes is None: + reward_processing_classes = [None] * len(reward_funcs) + elif not isinstance(reward_processing_classes, list): + reward_processing_classes = [reward_processing_classes] + else: + if len(reward_processing_classes) != len(reward_funcs): + raise ValueError("The number of reward processing classes must match the number of reward functions.") + + for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)): + if isinstance(reward_func, PreTrainedModel): + if reward_processing_class is None: + reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path) + if reward_processing_class.pad_token_id is None: + reward_processing_class.pad_token = reward_processing_class.eos_token + # The reward model computes the reward for the latest non-padded token in the input sequence. + # So it's important to set the pad token ID to the padding token ID of the processing class. + reward_func.config.pad_token_id = reward_processing_class.pad_token_id + reward_processing_classes[i] = reward_processing_class + self.reward_processing_classes = reward_processing_classes + + # Data collator + def data_collator(features): # No data collation is needed in GRPO + return features # Training arguments self.max_prompt_length = args.max_prompt_length @@ -168,17 +258,16 @@ def data_collator(features): # No data collation is needed in GRPO train_dataset=train_dataset, eval_dataset=eval_dataset, processing_class=processing_class, - model_init=model_init, - compute_loss_func=compute_loss_func, - compute_metrics=compute_metrics, callbacks=callbacks, optimizers=optimizers, - preprocess_logits_for_metrics=preprocess_logits_for_metrics, ) if self.ref_model is not None: self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) - self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True) + + for i, reward_func in enumerate(self.reward_funcs): + if isinstance(reward_func, PreTrainedModel): + self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True) def _set_signature_columns_if_needed(self): # If `self.args.remove_unused_columns` is True, non-signature columns are removed. 
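For intuition before the reward-computation hunk that follows: when several reward functions are supplied, the final reward is their element-wise sum per completion (the docs above note that "the reward will be computed as the sum of the rewards from each function"). A minimal sketch of that combination, independent of the trainer and with illustrative names:

```python
# Sketch only: each reward function returns one float per completion; the final reward is
# the element-wise sum across functions (mirroring rewards.sum(dim=0) in the trainer).
def combined_reward(reward_funcs, prompts, completions):
    totals = [0.0] * len(completions)
    for reward_func in reward_funcs:
        for i, reward in enumerate(reward_func(prompts, completions)):
            totals[i] += float(reward)
    return totals


def length_reward(prompts, completions):
    return [float(len(c)) for c in completions]


def unique_chars_reward(prompts, completions):
    return [float(len(set(c))) for c in completions]


print(combined_reward([length_reward, unique_chars_reward], ["The sky is"], [" blue."]))  # [12.0]
```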
@@ -252,24 +341,32 @@ def get_per_token_logps(model, input_ids): # Decode the generated completions completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) + if is_conversational(inputs[0]): + completions = [[{"role": "assistant", "content": completion}] for completion in completions] # Compute the rewards prompts = [prompt for prompt in prompts for _ in range(self.num_generations)] - if is_conversational(inputs[0]): - completions = [[{"role": "assistant", "content": completion}] for completion in completions] - messages = [{"messages": p + c} for p, c in zip(prompts, completions)] - texts = [apply_chat_template(x, self.reward_processing_class)["text"] for x in messages] - reward_inputs = self.reward_processing_class( - texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False - ) - else: - texts = [p + c for p, c in zip(prompts, completions)] - reward_inputs = self.reward_processing_class( - texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False - ) - reward_inputs = super()._prepare_inputs(reward_inputs) - with torch.inference_mode(): - rewards = self.reward_model(**reward_inputs).logits[:, 0] # Shape (B*G,) + + rewards = torch.zeros(len(self.reward_funcs), len(prompts), device=device) + for i, (reward_func, reward_processing_class) in enumerate( + zip(self.reward_funcs, self.reward_processing_classes) + ): + if isinstance(reward_func, PreTrainedModel): + if is_conversational(inputs[0]): + messages = [{"messages": p + c} for p, c in zip(prompts, completions)] + texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages] + else: + texts = [p + c for p, c in zip(prompts, completions)] + reward_inputs = reward_processing_class( + texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False + ) + reward_inputs = super()._prepare_inputs(reward_inputs) + with torch.inference_mode(): + rewards[i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,) + else: + rewards[i] = torch.tensor(reward_func(prompts, completions)) + # Sum the rewards from all reward functions + rewards = rewards.sum(dim=0) # Compute grouped-wise rewards mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1) From 40c238395e345e6013f899b3768b53c73e60844b Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 23 Jan 2025 12:12:06 -0500 Subject: [PATCH 13/96] =?UTF-8?q?=F0=9F=A5=9E=20Fix=20DPO=20gradient=20acc?= =?UTF-8?q?umulation=20loss=20scaling=20(#2615)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix DPO for gradient accumulation * Update trl/trainer/dpo_trainer.py * Update trl/trainer/dpo_trainer.py * Update trl/trainer/dpo_trainer.py --------- Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> --- trl/trainer/dpo_trainer.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 903bb719ca..886022a612 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -480,6 +480,11 @@ def make_inputs_require_grad(module, input, output): preprocess_logits_for_metrics=preprocess_logits_for_metrics, ) + # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the + # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set + # self.model_accepts_loss_kwargs to False to enable scaling. 
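As background for the comment above, a simplified sketch of the loss scaling it refers to; this is an assumed toy training loop, not the `transformers` Trainer internals: each micro-batch loss is divided by the number of accumulation steps so the accumulated gradient matches a single large batch.

```python
# Simplified illustration (toy loop, not Trainer code): without dividing by the number of
# accumulation steps, summing gradients over N micro-batches would scale the update by N.
def accumulation_step(model, optimizer, micro_batches, compute_loss):
    accumulation_steps = len(micro_batches)
    optimizer.zero_grad()
    for batch in micro_batches:
        loss = compute_loss(model, batch) / accumulation_steps  # scale each micro-batch loss
        loss.backward()  # gradients accumulate across backward() calls
    optimizer.step()
```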
+ self.model_accepts_loss_kwargs = False + # Add tags for models that have been loaded with the correct transformers version if hasattr(self.model, "add_model_tags"): self.model.add_model_tags(self._tag_names) From 59c201433cfae41c5869372823ba9ce177e59a23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Thu, 23 Jan 2025 18:57:43 +0100 Subject: [PATCH 14/96] =?UTF-8?q?=F0=9F=A5=9E=20Fix=20BCO=20gradient=20acc?= =?UTF-8?q?umulation=20loss=20scaling=20(#2638)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- trl/trainer/bco_trainer.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/trl/trainer/bco_trainer.py b/trl/trainer/bco_trainer.py index ccbb701d9f..18df4a7210 100644 --- a/trl/trainer/bco_trainer.py +++ b/trl/trainer/bco_trainer.py @@ -682,6 +682,11 @@ def make_inputs_require_grad(module, input, output): preprocess_logits_for_metrics=preprocess_logits_for_metrics, ) + # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the + # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set + # self.model_accepts_loss_kwargs to False to enable scaling. + self.model_accepts_loss_kwargs = False + # Add tags for models that have been loaded with the correct transformers version if hasattr(self.model, "add_model_tags"): self.model.add_model_tags(self._tag_names) From 0e216f7411fd55f6d90edf1b4dbee86b70425894 Mon Sep 17 00:00:00 2001 From: August Moharrami Date: Thu, 23 Jan 2025 22:46:37 +0330 Subject: [PATCH 15/96] =?UTF-8?q?=F0=9F=8D=AD=20Custom=20reward=20function?= =?UTF-8?q?=20for=20RLOO=20(#2612)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * rloo custom reward function and test * idont even know why i did that * removing get_reward_custom * remove get_reward_custom test * fix code quality check * adding test * end this mysery already * fix test --- tests/test_rloo_trainer.py | 39 +++++++++++++++++++++++++++++ trl/trainer/rloo_trainer.py | 50 +++++++++++++++++++++++++++---------- 2 files changed, 76 insertions(+), 13 deletions(-) diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py index 7de90a3b24..dcafe7829e 100644 --- a/tests/test_rloo_trainer.py +++ b/tests/test_rloo_trainer.py @@ -172,3 +172,42 @@ def test_rloo_training(self): # Check if objective/rlhf_reward is available self.assertIn("objective/rlhf_reward", trainer.state.log_history[-1]) + + def test_rloo_training_with_custom_reward(self): + # dummy reward function + def reward_function(texts): + # based on length of text + rewards = [len(text) for text in texts] + return rewards + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = RLOOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + per_device_eval_batch_size=2, + total_episodes=1, + num_train_epochs=1, + max_steps=2, + report_to="none", + ) + + # Create a simple dataset + dummy_text = [{"content": "Hello World!", "role": "user"}] + dummy_data = self.tokenizer.apply_chat_template(dummy_text) + dummy_dataset = Dataset.from_dict({"input_ids": [dummy_data, dummy_data]}) + + trainer = RLOOTrainer( + config=training_args, + policy=self.policy_model, + reward_model=reward_function, + ref_policy=self.policy_ref_model, + processing_class=self.tokenizer, + train_dataset=dummy_dataset, + eval_dataset=dummy_dataset, + ) + + # Test that training completes without errors + 
trainer.train() + + # Check if objective/rlhf_reward is available + self.assertIn("objective/rlhf_reward", trainer.state.log_history[-1]) diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 03b7ef922d..6626baa15c 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -18,7 +18,7 @@ import textwrap import time from collections import defaultdict -from typing import Optional, Union +from typing import Callable, Optional, Union import numpy as np import pandas as pd @@ -79,7 +79,7 @@ def __init__( ], policy: nn.Module, ref_policy: nn.Module, - reward_model: nn.Module, + reward_model: Union[nn.Module, Callable[[list[str]], list[float]]], train_dataset: Dataset, data_collator: Optional[DataCollatorWithPadding] = None, eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, @@ -152,7 +152,8 @@ def __init__( # setup model, optimizer, and others ######### for module in [policy, ref_policy, reward_model]: - disable_dropout_in_model(module) + if isinstance(module, nn.Module): + disable_dropout_in_model(module) if args.stop_token and args.stop_token == "eos": args.stop_token_id = self.processing_class.eos_token_id self.model = policy @@ -219,16 +220,18 @@ def __init__( self.eval_dataloader = accelerator.prepare(self.eval_dataloader) if self.is_deepspeed_enabled: - self.reward_model = prepare_deepspeed( - self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16 - ) + if isinstance(self.reward_model, nn.Module): + self.reward_model = prepare_deepspeed( + self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16 + ) self.ref_policy = prepare_deepspeed( self.ref_policy, args.per_device_train_batch_size, args.fp16, args.bf16 ) self.deepspeed = self.model else: self.ref_policy = self.ref_policy.to(self.accelerator.device) - self.reward_model = self.reward_model.to(self.accelerator.device) + if isinstance(self.reward_model, nn.Module): + self.reward_model = self.reward_model.to(self.accelerator.device) def get_train_dataloader(self) -> DataLoader: return self.dataloader @@ -350,9 +353,18 @@ def repeat_generator(): # Response Processing 2. 
run reward model on the truncated responses postprocessed_query_response = torch.cat((query, postprocessed_response), 1) sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1 - _, score, _ = get_reward( - reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length - ) + + if isinstance(reward_model, nn.Module): + _, score, _ = get_reward( + reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length + ) + else: + score = torch.tensor( + reward_model( + processing_class.batch_decode(postprocessed_query_response, skip_special_tokens=True) + ), + dtype=torch.float, + ).to(device) # Store batch results responses.append(response) @@ -595,9 +607,21 @@ def generate_completions(self, sampling: bool = False): ) postprocessed_query_response = torch.cat((query, postprocessed_response), 1) - _, score, _ = get_reward( - self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length - ) + + if isinstance(self.reward_model, nn.Module): + _, score, _ = get_reward( + self.reward_model, + postprocessed_query_response, + processing_class.pad_token_id, + context_length, + ) + else: + score = torch.tensor( + self.reward_model( + processing_class.batch_decode(postprocessed_query_response, skip_special_tokens=True) + ), + dtype=torch.float, + ).to(postprocessed_query_response.device) table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy()) if sampling: From f34b70a32ef2820d3fd5c5b1ff6d1fd1e7799f04 Mon Sep 17 00:00:00 2001 From: Superskyyy Date: Thu, 23 Jan 2025 15:23:54 -0500 Subject: [PATCH 16/96] =?UTF-8?q?=F0=9F=8C=AF=20Fix=20context=20manager=20?= =?UTF-8?q?runtime=20error=20when=20gather=20is=20disabled=20(#2639)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- trl/models/utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/trl/models/utils.py b/trl/models/utils.py index a5b7b2c7bf..a024b10a09 100644 --- a/trl/models/utils.py +++ b/trl/models/utils.py @@ -186,9 +186,10 @@ def unwrap_model_for_generation( if accelerator.state.deepspeed_plugin is not None and accelerator.state.deepspeed_plugin.zero_stage == 3: if not gather_deepspeed3_params: yield accelerator.unwrap_model(model) - with deepspeed.zero.GatheredParameters(model.parameters()): - remove_hooks(model) - yield accelerator.unwrap_model(model) - add_hooks(model) + else: + with deepspeed.zero.GatheredParameters(model.parameters()): + remove_hooks(model) + yield accelerator.unwrap_model(model) + add_hooks(model) else: yield unwrapped_model From 5e4d7be0e10d98f18842297cba34544d92f722ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Fri, 24 Jan 2025 09:06:16 +0100 Subject: [PATCH 17/96] Update grpo_trainer.md --- docs/source/grpo_trainer.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/grpo_trainer.md b/docs/source/grpo_trainer.md index 72a36d2be1..56b46e78ac 100644 --- a/docs/source/grpo_trainer.md +++ b/docs/source/grpo_trainer.md @@ -43,7 +43,7 @@ training_args = GRPOConfig( ) trainer = GRPOTrainer( model="Qwen/Qwen2-0.5B-Instruct", - reward_model="weqweasdas/RM-Gemma-2B", + reward_funcs="weqweasdas/RM-Gemma-2B", args=training_args, train_dataset=dataset, peft_config=LoraConfig(task_type="CAUSAL_LM"), From 8e65825d4cb22cf304bcd245adef978c473efcb1 Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Fri, 24 Jan 2025 12:22:46 +0100 Subject: [PATCH 18/96] =?UTF-8?q?=F0=9F=A5=9E=20Fix=20CPO=20gradient=20acc?= =?UTF-8?q?umulation=20loss=20scaling=20(#2645)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- trl/trainer/cpo_trainer.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/trl/trainer/cpo_trainer.py b/trl/trainer/cpo_trainer.py index 050dddad99..644a2d5353 100644 --- a/trl/trainer/cpo_trainer.py +++ b/trl/trainer/cpo_trainer.py @@ -357,6 +357,11 @@ def make_inputs_require_grad(module, input, output): preprocess_logits_for_metrics=preprocess_logits_for_metrics, ) + # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the + # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set + # self.model_accepts_loss_kwargs to False to enable scaling. + self.model_accepts_loss_kwargs = False + # Add tags for models that have been loaded with the correct transformers version if hasattr(self.model, "add_model_tags"): self.model.add_model_tags(self._tag_names) From d14f7f3eb274c77490cf1fc34b2d379288adb794 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Fri, 24 Jan 2025 16:22:54 +0100 Subject: [PATCH 19/96] =?UTF-8?q?=F0=9F=A5=9E=20Fix=20GRPO=20gradient=20ac?= =?UTF-8?q?cumulation=20loss=20scaling=20(#2647)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- trl/trainer/grpo_trainer.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index d787178e8f..86753a3a5c 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -262,6 +262,11 @@ def data_collator(features): # No data collation is needed in GRPO optimizers=optimizers, ) + # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the + # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set + # self.model_accepts_loss_kwargs to False to enable scaling. 
+ self.model_accepts_loss_kwargs = False + if self.ref_model is not None: self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) @@ -393,7 +398,7 @@ def get_per_token_logps(model, input_ids): return loss def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: - metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics + metrics = {key: sum(val) / len(val) for key, val in self._metrics.items() if val} # average the metrics logs = {**logs, **metrics} if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): super().log(logs, start_time) From 6f99f42f724123409422f2fad42bf56fa91f366f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Fri, 24 Jan 2025 16:23:16 +0100 Subject: [PATCH 20/96] =?UTF-8?q?=F0=9F=A5=9E=20Fix=20KTO=20gradient=20acc?= =?UTF-8?q?umulation=20loss=20scaling=20(#2648)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- trl/trainer/kto_trainer.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index 897ce25520..c45a88d554 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -746,6 +746,11 @@ def make_inputs_require_grad(module, input, output): preprocess_logits_for_metrics=preprocess_logits_for_metrics, ) + # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the + # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set + # self.model_accepts_loss_kwargs to False to enable scaling. + self.model_accepts_loss_kwargs = False + # Add tags for models that have been loaded with the correct transformers version if hasattr(self.model, "add_model_tags"): self.model.add_model_tags(self._tag_names) From 2578e950238077d1ce421801e19f26a5f74a619c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Fri, 24 Jan 2025 20:31:07 +0100 Subject: [PATCH 21/96] =?UTF-8?q?=F0=9F=9A=9B=20Provide=20all=20columns=20?= =?UTF-8?q?of=20the=20dataset=20to=20the=20reward=20function=20(#2650)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * The reward function is provided with all col from the dataset * Minor clarifications * minor renaming in doc [ci skip] * fix indentation --- docs/source/grpo_trainer.md | 44 +++++++++++++++++++++++++++------ tests/test_grpo_trainer.py | 49 +++++++++++++++++++++++++++++++++---- trl/trainer/grpo_config.py | 12 +++++++++ trl/trainer/grpo_trainer.py | 14 ++++++++--- 4 files changed, 104 insertions(+), 15 deletions(-) diff --git a/docs/source/grpo_trainer.md b/docs/source/grpo_trainer.md index 56b46e78ac..9db61842ac 100644 --- a/docs/source/grpo_trainer.md +++ b/docs/source/grpo_trainer.md @@ -121,7 +121,12 @@ The GRPO Trainer logs the following metrics: The [`GRPOTrainer`] supports using custom reward functions instead of dense reward models. To ensure compatibility, your reward function must satisfy the following requirements: 1. **Input arguments**: - - The function must accept two arguments: `prompts` and `completions`. + - The function must accept the following as keyword arguments: + - `prompts` (contains the prompts), + - `completions` (contains the generated completions), + - All columns names (but `prompt`) that the dataset may have. 
For example, if the dataset contains a column named `ground_truth`, the function will be called with `ground_truth` as a keyword argument. + + The easiest way to comply with this requirement is to use `**kwargs` in the function signature. - Depending on the dataset format, the input will vary: - For [standard format](dataset_formats#standard), `prompts` and `completions` will be lists of strings. - For [conversational format](dataset_formats#conversational), `prompts` and `completions` will be lists of message dictionaries. @@ -133,7 +138,7 @@ The [`GRPOTrainer`] supports using custom reward functions instead of dense rewa Below is an example of a reward function for a standard format that rewards longer completions: ```python -def reward_func(prompts, completions): +def reward_func(completions, **kwargs): """Reward function that gives higher scores to longer completions.""" return [float(len(completion)) for completion in completions] ``` @@ -143,19 +148,19 @@ You can test it as follows: ```python >>> prompts = ["The sky is", "The sun is"] >>> completions = [" blue.", " in the sky."] ->>> print(reward_func(prompts, completions)) +>>> print(reward_func(prompts=prompts, completions=completions)) [6.0, 12.0] ``` #### Example 2: Reward completions with specific format -Below is an example of a reward function that checks if the completion has a specific format. This example is inspired by the reward function used in the paper [DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning](https://huggingface.co/papers/2501.12948). +Below is an example of a reward function that checks if the completion has a specific format. This example is inspired by the _format reward_ function used in the paper [DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning](https://huggingface.co/papers/2501.12948). It is designed for conversational format, where prompts and completions consist of structured messages. ```python import re -def format_reward_func(prompts, completions): +def format_reward_func(completions, **kwargs): """Reward function that checks if the completion has a specific format.""" pattern = r"^.*?.*?$" completion_contents = [completion[0]["content"] for completion in completions] @@ -174,9 +179,34 @@ You can test this function as follows: ... [{"role": "assistant", "content": "The sum of 1 and 2 is 3, which we multiply by 4 to get 12.(1 + 2) * 4 = 12"}], ... [{"role": "assistant", "content": "The sum of 3 and 1 is 4, which we multiply by 2 to get 8. So (3 + 1) * 2 = 8."}], ... ] ->>> format_reward_func(prompts, completions) +>>> format_reward_func(prompts=prompts, completions=completions) +[1.0, 0.0] +``` + +#### Example 3: Reward completions based on a reference + +Below is an example of a reward function that checks if the is correct. This example is inspired by the _accuracy reward_ function used in the paper [DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning](https://huggingface.co/papers/2501.12948). +This example is designed for [standard format](dataset_formats#standard), where the dataset contains a column named `ground_truth`. 
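For context, the dataset for this example could look like the following minimal sketch (the values are illustrative and mirror the test snippet further down; every column other than `prompt` is forwarded to the reward function as a keyword argument):

```python
from datasets import Dataset

# Toy standard-format dataset with an extra "ground_truth" column.
# During training, each "ground_truth" value is repeated once per generated
# completion and passed to the reward function as the `ground_truth` keyword.
dataset = Dataset.from_dict(
    {
        "prompt": [
            "Problem: Solve the equation $2x + 3 = 7$. Solution:",
            "Problem: Solve the equation $3x - 5 = 10$.",
        ],
        "ground_truth": ["2", "5"],
    }
)
```

The reward function below then compares the value extracted from each completion (the content of `\boxed{}`) against this `ground_truth` column.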
+ +```python +import re + +def reward_func(completions, ground_truth, **kwargs): + # Regular expression to capture content inside \boxed{} + matches = [re.search(r"\\boxed\{(.*?)\}", completion) for completion in completions] + contents = [match.group(1) if match else "" for match in matches] + # Reward 1 if the content is the same as the ground truth, 0 otherwise + return [1.0 if c == gt else 0.0 for c, gt in zip(contents, ground_truth)] +``` + +You can test this function as follows: + +```python +>>> prompts = ["Problem: Solve the equation $2x + 3 = 7$. Solution:", "Problem: Solve the equation $3x - 5 = 10$."] +>>> completions = [r" The solution is \boxed{2}.", r" The solution is \boxed{6}."] +>>> ground_truth = ["2", "5"] +>>> reward_func(prompts=prompts, completions=completions, ground_truth=ground_truth) [1.0, 0.0] ->>> ``` #### Passing the reward function to the trainer diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index 8eaa6a566c..5de52ef69d 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -151,7 +151,7 @@ def test_training_reward_func_standard(self): # Test if trainer can handle reward function with standard format dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") - def reward_func(prompts, completions): + def reward_func(completions, **kwargs): """Reward function that rewards longer completions.""" return [float(len(completion)) for completion in completions] @@ -186,7 +186,7 @@ def test_training_reward_func_conversational(self): # Test if trainer can handle reward function with conversational format dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train") - def reward_func(prompts, completions): + def reward_func(completions, **kwargs): """Reward function that gives higher scores to longer completion content.""" completion_contents = [completion[0]["content"] for completion in completions] return [float(len(content)) for content in completion_contents] @@ -222,11 +222,11 @@ def test_training_multiple_reward_funcs(self): # Test that GRPOTrainer can be instantiated with multiple reward functions dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") - def reward_func1(prompts, completions): + def reward_func1(completions, **kwargs): """Reward function that rewards longer completions.""" return [float(len(completion)) for completion in completions] - def reward_func2(prompts, completions): + def reward_func2(completions, **kwargs): """Reward function that rewards completions with more unique letters.""" return [float(len(set(completion))) for completion in completions] @@ -261,7 +261,7 @@ def test_training_multiple_mixed_reward_funcs(self): # Test if the trainer can handle a mix of reward functions and reward models dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") - def reward_func(prompts, completions): + def reward_func(completions, **kwargs): """Reward function that rewards longer completions.""" return [float(len(completion)) for completion in completions] @@ -291,3 +291,42 @@ def reward_func(prompts, completions): for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + def test_training_reward_func_additional_column(self): + # Test if trainer can handle reward function that rely on additional columns in the dataset + dataset = 
load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + # Add a column to the dataset (dummy example, the column could be anything) + some_values = list(range(len(dataset))) + dataset = dataset.add_column("some_values", some_values) + + def reward_func(completions, some_values, **kwargs): + """Reward function that rewards completions with lengths closer to the values in some_values.""" + return [float(abs(len(completion) - value)) for completion, value in zip(completions, some_values)] + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=2, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=32, # reduce the completion length to reduce memory usage + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=reward_func, + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index f26e3f9c4a..490ee90554 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -39,6 +39,9 @@ class GRPOConfig(TrainingArguments): > Parameters that control the data preprocessing + remove_unused_columns (`bool`, *optional*, defaults to `False`): + Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that + requires any column other than `"prompts"` and `"completions"`, you should keep this to `False`. max_prompt_length (`int` or `None`, *optional*, defaults to `512`): Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left. num_generations (`int` or `None`, *optional*, defaults to `8`): @@ -67,6 +70,15 @@ class GRPOConfig(TrainingArguments): ) # Parameters that control the data preprocessing + # The default value remove_unused_columns is overwritten from the parent class, because in GRPO we usually rely on + # additional columns to compute the reward + remove_unused_columns: Optional[bool] = field( + default=False, + metadata={ + "help": "Whether to only keep the column 'prompt' in the dataset. If you use a custom reward function " + "that requires any column other than 'prompts' and 'completions', you should keep this to `False`." + }, + ) max_prompt_length: Optional[int] = field( default=512, metadata={ diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 86753a3a5c..edfbe9f100 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -94,8 +94,9 @@ class GRPOTrainer(Trainer): using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the keyword arguments in `args.model_init_kwargs`. - A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported. - - A custom reward function: This should take a list of prompts and completions and return a list of - rewards. 
For more details, see [Using a custom reward function](#using-a-custom-reward-function). + - A custom reward function: The function is provided with the prompts and the generated completions, + plus any additional columns in the dataset. It should return a list of rewards. For more details, see + [Using a custom reward function](#using-a-custom-reward-function). - A list of reward functions, where each item can independently be any of the above types. Mixing different types within the list (e.g., a string model ID and a custom reward function) is allowed. args ([`GRPOConfig`], *optional*, defaults to `None`): @@ -369,7 +370,14 @@ def get_per_token_logps(model, input_ids): with torch.inference_mode(): rewards[i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,) else: - rewards[i] = torch.tensor(reward_func(prompts, completions)) + # Repeat all input columns (but "prompt" and "completion") to match the number of generations + reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]} + for key in reward_kwargs: + for example in inputs: + # Repeat each value in the column for `num_generations` times + reward_kwargs[key].extend([example[key]] * self.num_generations) + output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs) + rewards[i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) # Sum the rewards from all reward functions rewards = rewards.sum(dim=0) From aeb03cf1a9d367440ba6ffcb6c14de8911288282 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Sat, 25 Jan 2025 10:10:29 +0100 Subject: [PATCH 22/96] =?UTF-8?q?=F0=9F=91=90=20DeepSpeed=20integration=20?= =?UTF-8?q?for=20GRPO=20(#2652)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- trl/models/__init__.py | 4 ++-- trl/models/utils.py | 35 +++++++++++++++++++++++++++++++++++ trl/trainer/grpo_trainer.py | 14 +++++++++++--- 3 files changed, 48 insertions(+), 5 deletions(-) diff --git a/trl/models/__init__.py b/trl/models/__init__.py index db998369c3..2365e7c1de 100644 --- a/trl/models/__init__.py +++ b/trl/models/__init__.py @@ -20,7 +20,7 @@ _import_structure = { "modeling_base": ["GeometricMixtureWrapper", "PreTrainedModelWrapper", "create_reference_model"], "modeling_value_head": ["AutoModelForCausalLMWithValueHead", "AutoModelForSeq2SeqLMWithValueHead"], - "utils": ["SUPPORTED_ARCHITECTURES", "setup_chat_format", "unwrap_model_for_generation"], + "utils": ["SUPPORTED_ARCHITECTURES", "prepare_deepspeed", "setup_chat_format", "unwrap_model_for_generation"], } try: @@ -39,7 +39,7 @@ if TYPE_CHECKING: from .modeling_base import GeometricMixtureWrapper, PreTrainedModelWrapper, create_reference_model from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead - from .utils import SUPPORTED_ARCHITECTURES, setup_chat_format, unwrap_model_for_generation + from .utils import SUPPORTED_ARCHITECTURES, prepare_deepspeed, setup_chat_format, unwrap_model_for_generation try: if not is_diffusers_available(): diff --git a/trl/models/utils.py b/trl/models/utils.py index a024b10a09..22a30c0afb 100644 --- a/trl/models/utils.py +++ b/trl/models/utils.py @@ -14,6 +14,7 @@ import itertools from contextlib import contextmanager +from copy import deepcopy from dataclasses import dataclass from typing import TYPE_CHECKING, Literal, Optional, Union @@ -193,3 +194,37 @@ def unwrap_model_for_generation( 
add_hooks(model) else: yield unwrapped_model + + +def prepare_deepspeed(model, accelerator): + # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473 + deepspeed_plugin = accelerator.state.deepspeed_plugin + config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config) + stage = config_kwargs["zero_optimization"]["stage"] + + if model is not None: + hidden_size = ( + max(model.config.hidden_sizes) + if getattr(model.config, "hidden_sizes", None) + else getattr(model.config, "hidden_size", None) + ) + if hidden_size is not None and stage == 3: + # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache + # @ step 0: expected module 1, but got module 0` + # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081 + config_kwargs.update( + { + "zero_optimization.reduce_bucket_size": hidden_size * hidden_size, + "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size, + "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size, + } + ) + + # If ZeRO-3 is used, we shard both the active and reference model. + # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO + # disabled (stage 0) + if stage != 3: + config_kwargs["zero_optimization"]["stage"] = 0 + model, *_ = deepspeed.initialize(model=model, config=config_kwargs) + model.eval() + return model diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index edfbe9f100..bf2ef75aa0 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -32,10 +32,11 @@ TrainerCallback, is_wandb_available, ) +from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.utils import is_peft_available from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template -from ..models import create_reference_model, unwrap_model_for_generation +from ..models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation from .grpo_config import GRPOConfig from .utils import generate_model_card, get_comet_experiment_url @@ -158,6 +159,7 @@ def __init__( # Trained model model_init_kwargs = args.model_init_kwargs or {} if isinstance(model, str): + model_id = model torch_dtype = model_init_kwargs.get("torch_dtype") if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None: pass # torch_dtype is already a torch.dtype or "auto" or None @@ -171,6 +173,7 @@ def __init__( ) model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) else: + model_id = model.config._name_or_path if args.model_init_kwargs is not None: raise ValueError( "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. " @@ -181,7 +184,9 @@ def __init__( model = get_peft_model(model, peft_config) # Reference model - if peft_config is None: + if is_deepspeed_zero3_enabled(): + self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs) + elif peft_config is None: # If PEFT configuration is not provided, create a reference model based on the initial model. 
self.ref_model = create_reference_model(model) else: @@ -269,7 +274,10 @@ def data_collator(features): # No data collation is needed in GRPO self.model_accepts_loss_kwargs = False if self.ref_model is not None: - self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + if self.is_deepspeed_enabled: + self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator) + else: + self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) for i, reward_func in enumerate(self.reward_funcs): if isinstance(reward_func, PreTrainedModel): From 317d2d477ba4dcebe81fbfcc6abc7bfb91b50ad7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Sat, 25 Jan 2025 11:43:00 +0100 Subject: [PATCH 23/96] =?UTF-8?q?=F0=9F=94=8E=20Finegrained=20reward=20log?= =?UTF-8?q?ging=20for=20GRPO=20(#2651)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/source/grpo_trainer.md | 1 + trl/trainer/grpo_trainer.py | 24 +++++++++++++++++------- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/docs/source/grpo_trainer.md b/docs/source/grpo_trainer.md index 9db61842ac..4abd250ae4 100644 --- a/docs/source/grpo_trainer.md +++ b/docs/source/grpo_trainer.md @@ -110,6 +110,7 @@ In TRL though, as in the original paper, we only do one update per generation, s The GRPO Trainer logs the following metrics: +- `reward/{reward_func_name}`: The reward computed by each reward function. - `reward`: The average reward. - `reward_std` : The average standard deviation within reward groups. - `kl` : The average KL divergence between the model and the reference model calculated on completions. diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index bf2ef75aa0..3cddc9eb5c 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -14,6 +14,7 @@ import os import textwrap +from collections import defaultdict from typing import Any, Callable, Optional, Union import torch @@ -255,7 +256,7 @@ def data_collator(features): # No data collation is needed in GRPO model.warnings_issued["estimate_tokens"] = True # Initialize the metrics - self._metrics = {"kl": [], "reward": [], "reward_std": []} + self._metrics = defaultdict(list) super().__init__( model=model, @@ -361,7 +362,7 @@ def get_per_token_logps(model, input_ids): # Compute the rewards prompts = [prompt for prompt in prompts for _ in range(self.num_generations)] - rewards = torch.zeros(len(self.reward_funcs), len(prompts), device=device) + rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device) for i, (reward_func, reward_processing_class) in enumerate( zip(self.reward_funcs, self.reward_processing_classes) ): @@ -376,7 +377,7 @@ def get_per_token_logps(model, input_ids): ) reward_inputs = super()._prepare_inputs(reward_inputs) with torch.inference_mode(): - rewards[i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,) + rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,) else: # Repeat all input columns (but "prompt" and "completion") to match the number of generations reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]} @@ -385,9 +386,10 @@ def get_per_token_logps(model, input_ids): # Repeat each value in the column for `num_generations` times reward_kwargs[key].extend([example[key]] * self.num_generations) output_reward_func = reward_func(prompts=prompts, 
completions=completions, **reward_kwargs) - rewards[i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) + rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) + # Sum the rewards from all reward functions - rewards = rewards.sum(dim=0) + rewards = rewards_per_func.sum(dim=1) # Compute grouped-wise rewards mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1) @@ -404,6 +406,14 @@ def get_per_token_logps(model, input_ids): loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() # Log the metrics + reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0) + for i, reward_func in enumerate(self.reward_funcs): + if isinstance(reward_func, PreTrainedModel): + reward_func_name = reward_func.config._name_or_path.split("/")[-1] + else: + reward_func_name = reward_func.__name__ + self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item()) + self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item()) self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item()) @@ -414,13 +424,13 @@ def get_per_token_logps(model, input_ids): return loss def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: - metrics = {key: sum(val) / len(val) for key, val in self._metrics.items() if val} # average the metrics + metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics logs = {**logs, **metrics} if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): super().log(logs, start_time) else: # transformers<=4.46 super().log(logs) - self._metrics = {key: [] for key in self._metrics} + self._metrics.clear() def create_model_card( self, From 807046b7d739294502de837f04803e7165c60c71 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Sat, 25 Jan 2025 13:14:34 +0100 Subject: [PATCH 24/96] =?UTF-8?q?=F0=9F=93=8D=20Disable=20caching=20when?= =?UTF-8?q?=20grad=20checkpointing=20enable=20in=20GRPO=20(#2653)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * disable caching when grad checkpointing * style --- trl/trainer/grpo_trainer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 3cddc9eb5c..1fdea42328 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -172,6 +172,10 @@ def __init__( "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing " f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}." 
) + # Disable caching if gradient checkpointing is enabled (not supported) + model_init_kwargs["use_cache"] = ( + False if args.gradient_checkpointing else model_init_kwargs.get("use_cache") + ) model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) else: model_id = model.config._name_or_path From 472065665461685ddc912b75d646a459599b4ebe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Sat, 25 Jan 2025 20:56:09 +0100 Subject: [PATCH 25/96] =?UTF-8?q?=F0=9F=93=8F=20Log=20completion=20length?= =?UTF-8?q?=20in=20GRPO=20(#2659)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/source/grpo_trainer.md | 1 + trl/trainer/grpo_trainer.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/docs/source/grpo_trainer.md b/docs/source/grpo_trainer.md index 4abd250ae4..7250b72cab 100644 --- a/docs/source/grpo_trainer.md +++ b/docs/source/grpo_trainer.md @@ -110,6 +110,7 @@ In TRL though, as in the original paper, we only do one update per generation, s The GRPO Trainer logs the following metrics: +- `completion_length`: The average completion length. - `reward/{reward_func_name}`: The reward computed by each reward function. - `reward`: The average reward. - `reward_std` : The average standard deviation within reward groups. diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 1fdea42328..c7bd0e3e0e 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -410,6 +410,9 @@ def get_per_token_logps(model, input_ids): loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() # Log the metrics + completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item() + self._metrics["completion_length"].append(completion_length) + reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0) for i, reward_func in enumerate(self.reward_funcs): if isinstance(reward_func, PreTrainedModel): From 55a329e9f0636a2ad6522caa4a601326def44545 Mon Sep 17 00:00:00 2001 From: Andy Liu <31980222+andyl98@users.noreply.github.com> Date: Sun, 26 Jan 2025 01:05:21 -0800 Subject: [PATCH 26/96] =?UTF-8?q?=F0=9F=8C=80=20Fix=20GRPO=20default=20com?= =?UTF-8?q?pletion=20length=20doc=20(#2662)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- trl/trainer/grpo_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index 490ee90554..310a9eceb6 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -48,7 +48,7 @@ class GRPOConfig(TrainingArguments): Number of generations per prompt to sample. temperature (`float`, *optional*, defaults to `0.9`): Temperature for sampling. The higher the temperature, the more random the completions. - max_completion_length (`int` or `None`, *optional*, defaults to `None`): + max_completion_length (`int` or `None`, *optional*, defaults to `256`): Maximum length of the generated completion. 
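For illustration, the data-preprocessing and generation parameters documented above can be set explicitly as in the sketch below (the parameter values mirror the defaults documented above, with an arbitrary output directory; they are not tuned recommendations):

```python
from trl import GRPOConfig

# Minimal sketch using the documented default values.
training_args = GRPOConfig(
    output_dir="Qwen2-0.5B-GRPO",
    remove_unused_columns=False,  # keep extra dataset columns for custom reward functions
    max_prompt_length=512,        # longer prompts are truncated from the left
    num_generations=8,            # completions sampled per prompt
    temperature=0.9,              # sampling temperature
    max_completion_length=256,    # maximum length of each generated completion
)
```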
> Parameters that control the training From 1123bd0f514164eb297e7d6d48d8d8057c6e7334 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Sun, 26 Jan 2025 13:37:15 +0100 Subject: [PATCH 27/96] =?UTF-8?q?=F0=9F=8F=B7=EF=B8=8F=20Add=20model=20tag?= =?UTF-8?q?s=20to=20model=20trained=20with=20GRPO=20(#2663)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- trl/trainer/grpo_trainer.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index c7bd0e3e0e..a5f6e6f219 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -137,6 +137,8 @@ class GRPOTrainer(Trainer): PEFT configuration used to wrap the model. If `None`, the model is not wrapped. """ + _tag_names = ["trl", "grpo"] + def __init__( self, model: Union[str, PreTrainedModel], @@ -278,6 +280,9 @@ def data_collator(features): # No data collation is needed in GRPO # self.model_accepts_loss_kwargs to False to enable scaling. self.model_accepts_loss_kwargs = False + # Add tags to the model + self.model.add_model_tags(self._tag_names) + if self.ref_model is not None: if self.is_deepspeed_enabled: self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator) From 4659ad916fe3e05fd81bf1297db7a6767d2c69bc Mon Sep 17 00:00:00 2001 From: omahs <73983677+omahs@users.noreply.github.com> Date: Tue, 28 Jan 2025 11:26:36 +0100 Subject: [PATCH 28/96] =?UTF-8?q?=F0=9F=96=8A=20Fix=20typos=20(#2673)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix typos * fix typo * fix typo * fix typos * fix typos * fix typo * fix typo * fix typo * fix typo * fix typo * fix typo * fix typo * fix typo * fix typo * fix typo --- docs/source/bco_trainer.md | 4 ++-- docs/source/ddpo_trainer.md | 6 +++--- docs/source/detoxifying_a_lm.md | 6 +++--- examples/datasets/hh-rlhf-helpful-base.py | 2 +- examples/datasets/lm-human-preferences-descriptiveness.py | 2 +- examples/datasets/lm-human-preferences-sentiment.py | 2 +- examples/datasets/math_shepherd.py | 2 +- examples/datasets/prm800k.py | 2 +- examples/datasets/rlaif-v.py | 2 +- examples/datasets/tldr.py | 2 +- examples/datasets/tldr_preference.py | 2 +- examples/datasets/ultrafeedback-prompt.py | 2 +- examples/datasets/ultrafeedback.py | 2 +- tests/test_data_utils.py | 8 ++++---- trl/trainer/gkd_trainer.py | 2 +- 15 files changed, 23 insertions(+), 23 deletions(-) diff --git a/docs/source/bco_trainer.md b/docs/source/bco_trainer.md index c23365cc00..e449f86b63 100644 --- a/docs/source/bco_trainer.md +++ b/docs/source/bco_trainer.md @@ -62,7 +62,7 @@ embedding_model = Accelerator().prepare_model(self.embedding_model) embedding_func = partial(embed_prompt, model=embedding_model) ``` -Set `prompt_sample_size` to defined how many prompts are selected to train the UDM classifier and start the training with the provided embedding function: +Set `prompt_sample_size` to define how many prompts are selected to train the UDM classifier and start the training with the provided embedding function: ```py training_args = BCOConfig( @@ -97,4 +97,4 @@ To scale how much the auxiliary loss contributes to the total loss, use the hype ## BCOConfig -[[autodoc]] BCOConfig \ No newline at end of file +[[autodoc]] BCOConfig diff --git a/docs/source/ddpo_trainer.md b/docs/source/ddpo_trainer.md index 0682144edb..eca557c9e4 100644 --- a/docs/source/ddpo_trainer.md +++ b/docs/source/ddpo_trainer.md @@ 
-14,8 +14,8 @@ ## Getting started with Stable Diffusion finetuning with reinforcement learning The machinery for finetuning of Stable Diffusion models with reinforcement learning makes heavy use of HuggingFace's `diffusers` -library. A reason for stating this is that getting started requires a bit of familiarity with the `diffusers` library concepts, mainly two of them - pipelines and schedulers. -Right out of the box (`diffusers` library), there isn't a `Pipeline` nor a `Scheduler` instance that is suitable for finetuning with reinforcement learning. Some adjustments need to made. +library. A reason for stating this is that getting started requires a bit of familiarity with the `diffusers` library concepts, mainly two of them - pipelines and schedulers. +Right out of the box (`diffusers` library), there isn't a `Pipeline` nor a `Scheduler` instance that is suitable for finetuning with reinforcement learning. Some adjustments need to be made. There is a pipeline interface that is provided by this library that is required to be implemented to be used with the `DDPOTrainer`, which is the main machinery for fine-tuning Stable Diffusion with reinforcement learning. **Note: Only the StableDiffusion architecture is supported at this point.** There is a default implementation of this interface that you can use out of the box. Assuming the default implementation is sufficient and/or to get things moving, refer to the training example alongside this guide. @@ -26,7 +26,7 @@ For a more detailed look into the interface and the associated default implement Note that the default implementation has a LoRA implementation path and a non-LoRA based implementation path. The LoRA flag enabled by default and this can be turned off by passing in the flag to do so. LORA based training is faster and the LORA associated model hyperparameters responsible for model convergence aren't as finicky as non-LORA based training. -Also in addition, there is the expectation of providing a reward function and a prompt function. The reward function is used to evaluate the generated images and the prompt function is used to generate the prompts that are used to generate the images. +Also in addition, there is the expectation of providing a reward function and a prompt function. The reward function is used to evaluate the generated images and the prompt function is used to generate the prompts that are used to generate the images. ## Getting started with `examples/scripts/ddpo.py` diff --git a/docs/source/detoxifying_a_lm.md b/docs/source/detoxifying_a_lm.md index fe97422889..2bc639b11b 100644 --- a/docs/source/detoxifying_a_lm.md +++ b/docs/source/detoxifying_a_lm.md @@ -45,7 +45,7 @@ When doing PPO, it is very important to design the problem efficiently so that t ### Pre-processing the dataset -The dataset consist of prompts and their continuations, and each of them has an associated `toxicity` score. +The dataset consists of prompts and their continuations, and each of them has an associated `toxicity` score. A `prompt` example: ``` @@ -109,7 +109,7 @@ ref_model = create_reference_model(model, num_shared_layers=6) trainer = PPOTrainer(..., ref_model=ref_model) ``` -In the example above this means that the model have the 4 first layers frozen (i.e. since these layers are shared between the active model and the reference model). +In the example above this means that the model has the 4 first layers frozen (i.e. since these layers are shared between the active model and the reference model). 
- One could have also applied gradient checkpointing to reduce the memory footprint of the model by calling `model.pretrained_model.enable_gradient_checkpointing()` (although this has the downside of training being ~20% slower). @@ -176,7 +176,7 @@ The evaluation script can be found [here](https://github.com/huggingface/trl/blo The results are quite promising, as we can see that the models are able to reduce the toxicity score of the generated text by an interesting margin. The gap is clear for `gpt-neo-2B` model but we less so for the `gpt-j-6B` model. There are several things we could try to improve the results on the largest model starting with training with larger `mini_batch_size` and probably allowing to back-propagate through more layers (i.e. use less shared layers). -To sum up, in addition to human feedback this could be a useful additional signal when training large language models to ensure there outputs are less toxic as well as useful. +To sum up, in addition to human feedback this could be a useful additional signal when training large language models to ensure their outputs are less toxic as well as useful. ### Limitations diff --git a/examples/datasets/hh-rlhf-helpful-base.py b/examples/datasets/hh-rlhf-helpful-base.py index 44966f917e..98a225c8ec 100644 --- a/examples/datasets/hh-rlhf-helpful-base.py +++ b/examples/datasets/hh-rlhf-helpful-base.py @@ -110,7 +110,7 @@ def extract_dialogue(example: str) -> list[dict[str, str]]: - **Type**: [Preference](https://huggingface.co/docs/trl/main/dataset_formats#preference) Columns: -- `"pompt"`: The user query. +- `"prompt"`: The user query. - `"chosen"`: A response deemed helpful by human evaluators. - `"rejected"`: A response considered less helpful or unhelpful. diff --git a/examples/datasets/lm-human-preferences-descriptiveness.py b/examples/datasets/lm-human-preferences-descriptiveness.py index b836fcc6d5..7515b77373 100644 --- a/examples/datasets/lm-human-preferences-descriptiveness.py +++ b/examples/datasets/lm-human-preferences-descriptiveness.py @@ -82,7 +82,7 @@ def to_prompt_completion(example, tokenizer): - **Type**: [Preference](https://huggingface.co/docs/trl/main/dataset_formats#preference) Columns: -- `"pompt"`: The text sample. +- `"prompt"`: The text sample. - `"chosen"`: A version of the text with enhanced descriptiveness. - `"rejected"`: A version of the text with less descriptiveness. diff --git a/examples/datasets/lm-human-preferences-sentiment.py b/examples/datasets/lm-human-preferences-sentiment.py index 198469c9e0..da411742ba 100644 --- a/examples/datasets/lm-human-preferences-sentiment.py +++ b/examples/datasets/lm-human-preferences-sentiment.py @@ -77,7 +77,7 @@ def to_prompt_completion(example, tokenizer): - **Type**: [Preference](https://huggingface.co/docs/trl/main/dataset_formats#preference) Columns: -- `"pompt"`: The text sample. +- `"prompt"`: The text sample. - `"chosen"`: A version of the text that conveys the desired sentiment. - `"rejected"`: A version of the text that does not convey the desired sentiment. diff --git a/examples/datasets/math_shepherd.py b/examples/datasets/math_shepherd.py index 5dbd5ab7ea..47a28f0a30 100644 --- a/examples/datasets/math_shepherd.py +++ b/examples/datasets/math_shepherd.py @@ -141,7 +141,7 @@ def process_example(example): - **Type**: [Stepwise supervision](https://huggingface.co/docs/trl/main/dataset_formats#stepwise-supervision) Columns: -- `"pompt"`: The problem statement. +- `"prompt"`: The problem statement. 
- `"completions"`: A list of reasoning steps generated to solve the problem. - `"labels"`: A list of booleans or floats indicating the correctness of each corresponding reasoning step. diff --git a/examples/datasets/prm800k.py b/examples/datasets/prm800k.py index c859272909..631fc89d24 100644 --- a/examples/datasets/prm800k.py +++ b/examples/datasets/prm800k.py @@ -115,7 +115,7 @@ def process_batch(examples): - **Type**: [Stepwise supervision](https://huggingface.co/docs/trl/main/dataset_formats#stepwise-supervision) Columns: -- `"pompt"`: The problem statement. +- `"prompt"`: The problem statement. - `"completions"`: A list of reasoning steps generated to solve the problem. - `"labels"`: A list of booleans or floats indicating the correctness of each corresponding reasoning step. diff --git a/examples/datasets/rlaif-v.py b/examples/datasets/rlaif-v.py index 9548daa6b8..b867d6ed68 100644 --- a/examples/datasets/rlaif-v.py +++ b/examples/datasets/rlaif-v.py @@ -77,7 +77,7 @@ def to_conversational(example): - **Type**: [Preference](https://huggingface.co/docs/trl/main/dataset_formats#preference) Columns: -- `"pompt"`: The task related to the image. +- `"prompt"`: The task related to the image. - `"images"`: The image. - `"chosen"`: The preferred answer. - `"rejected"`: An alternative answer that was not preferred. diff --git a/examples/datasets/tldr.py b/examples/datasets/tldr.py index 0fc27bd8c8..1f14943594 100644 --- a/examples/datasets/tldr.py +++ b/examples/datasets/tldr.py @@ -72,7 +72,7 @@ def to_prompt_completion(example): - **Type**: [Prompt-completion](https://huggingface.co/docs/trl/main/dataset_formats#prompt-completion) Columns: -- `"pompt"`: The unabridged Reddit post. +- `"prompt"`: The unabridged Reddit post. - `"completion"`: The concise "TL;DR" summary appended by the author. This structure enables models to learn the relationship between detailed content and its abbreviated form, enhancing their summarization capabilities. diff --git a/examples/datasets/tldr_preference.py b/examples/datasets/tldr_preference.py index f6c05f8e27..3de9a557a9 100644 --- a/examples/datasets/tldr_preference.py +++ b/examples/datasets/tldr_preference.py @@ -83,7 +83,7 @@ def to_preference(example): - **Type**: [Preference](https://huggingface.co/docs/trl/main/dataset_formats#preference) Columns: -- `"pompt"`: The unabridged Reddit post. +- `"prompt"`: The unabridged Reddit post. - `"chosen"`: The concise "TL;DR" summary appended by the author. - `"rejected"`: An alternative summary or response that was not selected. diff --git a/examples/datasets/ultrafeedback-prompt.py b/examples/datasets/ultrafeedback-prompt.py index 7c218ee786..9036f9607e 100644 --- a/examples/datasets/ultrafeedback-prompt.py +++ b/examples/datasets/ultrafeedback-prompt.py @@ -77,7 +77,7 @@ def drop_long_prompt(example): - **Type**: [Prompt-only](https://huggingface.co/docs/trl/main/dataset_formats#prompt-only) Column: -- `"pompt"`: The input question or instruction provided to the model. +- `"prompt"`: The input question or instruction provided to the model. 
## Generation script diff --git a/examples/datasets/ultrafeedback.py b/examples/datasets/ultrafeedback.py index 49e4e2cc0c..8ed25798ee 100644 --- a/examples/datasets/ultrafeedback.py +++ b/examples/datasets/ultrafeedback.py @@ -112,7 +112,7 @@ def to_unpaired_preference(example, model_name, aspect): - **Type**: [Unpaired preference](https://huggingface.co/docs/trl/main/dataset_formats#unpaired-preference) Column: -- `"pompt"`: The input question or instruction provided to the model. +- `"prompt"`: The input question or instruction provided to the model. - `"completion"`: The model's response to the prompt. - `"label"`: A binary value indicating whether the response is sufficiently helpful. diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py index f95e37fc84..95fe8e7049 100644 --- a/tests/test_data_utils.py +++ b/tests/test_data_utils.py @@ -41,7 +41,7 @@ class IsConversationalTester(unittest.TestCase): { # Prompt only "prompt": [{"role": "user", "content": "What color is the sky?"}], }, - { # Pompt-completion + { # Prompt-completion "prompt": [{"role": "user", "content": "What color is the sky?"}], "completion": [{"role": "assistant", "content": "It is blue."}], }, @@ -110,7 +110,7 @@ class ApplyChatTemplateTester(unittest.TestCase): { # Prompt only "prompt": [{"role": "user", "content": "What color is the sky?"}], }, - { # Pompt-completion + { # Prompt-completion "prompt": [{"role": "user", "content": "What color is the sky?"}], "completion": [{"role": "assistant", "content": "It is blue."}], }, @@ -153,7 +153,7 @@ def test_apply_chat_template(self, tokenizer_id, example): # Checking if the result is a dictionary self.assertIsInstance(result, dict) - # The chat template should be applied to the the following keys + # The chat template should be applied to the following keys for key in ["prompt", "chosen", "rejected", "completion"]: if key in example: self.assertIn(key, result) @@ -179,7 +179,7 @@ def test_maybe_apply_chat_template(self, tokenizer_id, example): # Checking if the result is a dictionary self.assertIsInstance(result, dict) - # The chat template should be applied to the the following keys + # The chat template should be applied to the following keys for key in ["prompt", "chosen", "rejected", "completion"]: if key in example: self.assertIn(key, result) diff --git a/trl/trainer/gkd_trainer.py b/trl/trainer/gkd_trainer.py index 59d71d1e44..3db1f95252 100644 --- a/trl/trainer/gkd_trainer.py +++ b/trl/trainer/gkd_trainer.py @@ -86,7 +86,7 @@ def __init__( peft_config: Optional["PeftConfig"] = None, formatting_func: Optional[Callable] = None, ): - # add remove_unused_columns=False to the the dataclass args + # add remove_unused_columns=False to the dataclass args args.remove_unused_columns = False data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_seq_length) From ed14ed90438860fc59b8b7694d4e103a2a146a57 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Wed, 29 Jan 2025 13:01:10 +0100 Subject: [PATCH 29/96] =?UTF-8?q?=E2=9A=A1=20vLLM=20for=20fast=20generatio?= =?UTF-8?q?n=20in=20GRPO=20(#2600)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * doc * fsdp * use vllm config * vllm * Update trl/trainer/grpo_config.py Co-authored-by: lewtun * Update trl/trainer/grpo_config.py Co-authored-by: lewtun * typo * top_k, top_p * Link to vllm pr * fix missing device * fix tests * fix citation * fix title and paper_id * formatting * 
output the correct number of generations * initial async vllm * fix missing args * fix promps * Pass prompt_token_ids directly * Repeat each prompt num_generations times * get the slice of results per processor * undo citation * OMG * nothing can resist me!!!! * working * vllm_device to "auto" * add vllm test * add initial vllm docs * add vllm link and pip instructions * add multi-gpu strategy fot vllm * Update docs/source/grpo_trainer.md Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> * Update docs/source/grpo_trainer.md Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> * Update docs/source/grpo_trainer.md Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> * add doc strings * Update docs/source/grpo_trainer.md Co-authored-by: lewtun * Update trl/trainer/grpo_trainer.py Co-authored-by: lewtun * Update docs/source/grpo_trainer.md Co-authored-by: lewtun * add important tag * fix typo * overrides default batch size and grad accum and better doc * Under no circumstances should you examine the contents of this commit. * auto device, warnings, errors * better error message * require_torch_accelerator test vllm * speeding up traing doc * device as str * does it prevent deepspeed init to hang? * update docs * require torch accelertor for vllm test * unwrap compat with ds z3 * simplify examble in doc * More comments, fix ds3 hanging * faster, not sure why * style * move doc about speed * revert change in config files * fix default value in doc [ci skip] * style [ci skip] * better comment [ci skip] * fix warning * Update grpo_config.py * Update deepspeed_zero1.yaml * Update trl/trainer/grpo_trainer.py Co-authored-by: lewtun * Apply suggestions from code review Co-authored-by: lewtun * Update docs/source/grpo_trainer.md --------- Co-authored-by: lewtun Co-authored-by: Kashif Rasul --- docs/source/grpo_trainer.md | 34 ++++---- docs/source/speeding_up_training.md | 52 +++++++++++- tests/test_grpo_trainer.py | 37 ++++++++- trl/trainer/grpo_config.py | 64 +++++++++++++++ trl/trainer/grpo_trainer.py | 123 +++++++++++++++++++++++++--- trl/trainer/online_dpo_config.py | 4 +- trl/trainer/online_dpo_trainer.py | 3 +- 7 files changed, 283 insertions(+), 34 deletions(-) diff --git a/docs/source/grpo_trainer.md b/docs/source/grpo_trainer.md index 7250b72cab..76ae09a160 100644 --- a/docs/source/grpo_trainer.md +++ b/docs/source/grpo_trainer.md @@ -14,7 +14,7 @@ This post-training method was contributed by [Quentin Gallouédec](https://huggi ## Quick start -This example demonstrates how to train a model using the GRPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B) as the base model and the [RM-Gemma-2B model](https://huggingface.co/weqweasdas/RM-Gemma-2B) as the reward model. We use the prompts from the [TLDR dataset](https://huggingface.co/datasets/trl-lib/tldr) (completion column is ingored!). You can view the data in the dataset here: +This example demonstrates how to train a model using the GRPO method. We train a [Qwen 0.5B Instruct model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) with the prompts from the [TLDR dataset](https://huggingface.co/datasets/trl-lib/tldr) (completion column is ingored!). You can view the data in the dataset here: -Below is the script to train the model. We use PEFT to reduce the memory requirements. +Below is the script to train the model. 
```python # train_grpo.py from datasets import load_dataset -from peft import LoraConfig from trl import GRPOConfig, GRPOTrainer -# Load the dataset dataset = load_dataset("trl-lib/tldr", split="train") -training_args = GRPOConfig( - output_dir="Qwen2-0.5B-GRPO", - learning_rate=1e-5, - logging_steps=10, - gradient_accumulation_steps=16, - max_completion_length=128, -) +# Define the reward function, which rewards completions that are close to 20 characters +def reward_len(completions, **kwargs): + return [abs(20 - len(completion)) for completion in completions] + +training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO", logging_steps=10) trainer = GRPOTrainer( model="Qwen/Qwen2-0.5B-Instruct", - reward_funcs="weqweasdas/RM-Gemma-2B", + reward_funcs=reward_len, args=training_args, train_dataset=dataset, - peft_config=LoraConfig(task_type="CAUSAL_LM"), ) - trainer.train() ``` @@ -118,6 +112,18 @@ The GRPO Trainer logs the following metrics: ## Customization +## Speed up training with vLLM-powered generation + +Generation is often the main bottleneck that makes training slow with online methods. To accelerate generation, you can use [vLLM](https://github.com/vllm-project/vllm), a library that enables fast generation. To enable it, pass `use_vllm=True` in the training arguments. + +```python +from trl import GRPOConfig + +training_args = GRPOConfig(..., use_vllm=True) +``` + +For more information, see [Speeding up training with vLLM](speeding_up_training#vllm-for-fast-generation-in-online-methods). + ### Using a custom reward function The [`GRPOTrainer`] supports using custom reward functions instead of dense reward models. To ensure compatibility, your reward function must satisfy the following requirements: diff --git a/docs/source/speeding_up_training.md b/docs/source/speeding_up_training.md index f47f1b2907..83d14cb5a2 100644 --- a/docs/source/speeding_up_training.md +++ b/docs/source/speeding_up_training.md @@ -8,14 +8,21 @@ Section under construction. Feel free to contribute! ## vLLM for fast generation in online methods -Online methods such as Online DPO or Nash-MD require the model to generate completions, which is often a slow process and can significantly impact training time. -To speed up generation, you can use [vLLM](https://github.com/vllm-project/vllm), a library that enables fast generation through PagedAttention. TRL's online trainers support vLLM, greatly improving training speed. +Online methods such as GRPO or Online DPO require the model to generate completions, which is often a slow process and can significantly impact training time. +To speed up generation, you can use [vLLM](https://github.com/vllm-project/vllm), a library that enables fast generation through, among other things, PagedAttention. TRL's online trainers support vLLM, greatly improving training speed. + +To use [vLLM](https://github.com/vllm-project/vllm), first install it using: -To use vLLM, first install it using: ```bash pip install vllm ``` +or + +```bash +pip install "trl[vllm]" +``` + @@ -24,7 +31,44 @@ Then, enable it by passing `use_vllm=True` in the training arguments. ```python from trl import OnlineDPOConfig -training_args = DPOConfig(..., use_vllm=True) +training_args = OnlineDPOConfig(..., use_vllm=True) +``` + + + + +Then, enable it by passing `use_vllm=True` in the training arguments. 
+ +```python +from trl import GRPOConfig + +training_args = GRPOConfig(..., use_vllm=True) +``` + +The strategy here is to use a dedicated GPU for generation powered by vLLM, while using the remainder for training. + + + +When using vLLM, an additional GPU is required exclusively for generation. This means you need at least two available GPUs and must ensure that one remains unused by the trainer. To achieve this, run the training with `--num_processes `. + +For example, if you have 4 GPUs, set `--num_processes 3` to allocate three GPUs for training while reserving one for generation. +```bash +accelerate launch --multi_gpu --num_processes 3 train_grpo.py +``` + +![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/1_gpu_for_generation.png) + + + +You can further tune the vLLM configuration by setting a specific `vllm_device` and `vllm_gpu_memory_utilization` in the [`GRPOConfig`]. + +```python +training_args = GRPOConfig( + ..., + use_vllm=True, + vllm_device="cuda:4", + vllm_gpu_memory_utilization=0.7, +) ``` diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index 5de52ef69d..e1c0ac38ba 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -19,10 +19,11 @@ from datasets import load_dataset from parameterized import parameterized from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer -from transformers.testing_utils import require_peft +from transformers.testing_utils import require_peft, require_torch_accelerator from transformers.utils import is_peft_available from trl import GRPOConfig, GRPOTrainer +from trl.import_utils import is_vllm_available if is_peft_available(): @@ -330,3 +331,37 @@ def reward_func(completions, some_values, **kwargs): for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + @unittest.skipIf(not is_vllm_available(), "vLLM is not available") + @require_torch_accelerator + def test_training_vllm(self): + """Test that training works with vLLM for generation.""" + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=2, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=32, # reduce the completion length to reduce memory usage + report_to="none", + use_vllm=True, + ) + trainer = GRPOTrainer( + model="trl-internal-testing/small-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index 310a9eceb6..0fd0d9f5d2 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -51,11 +51,31 @@ class GRPOConfig(TrainingArguments): 
max_completion_length (`int` or `None`, *optional*, defaults to `256`): Maximum length of the generated completion. + > Parameters that control generation acceleration powered by vLLM + + use_vllm (`bool`, *optional*, defaults to `False`): + Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept unused for + training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`). + vllm_device (`str`, *optional*, defaults to `"auto"`): + Device where vLLM generation will run, e.g. `"cuda:1"`. If set to `"auto"` (default), the system will + automatically select the next available GPU after the last one used for training. This assumes that + training has not already occupied all available GPUs. + vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.9`): + Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the + device dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus + improve the model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors + during initialization. + > Parameters that control the training learning_rate (`float`, *optional*, defaults to `1e-6`): Initial learning rate for [`AdamW`] optimizer. The default value replaces that of [`~transformers.TrainingArguments`]. + per_device_train_batch_size (`int`, *optional*, defaults to `1`): + Number of prompts sampled per device for training. The actual batch passed into the model will be this + value multiplied by `num_generations`. + gradient_accumulation_steps (`int`, *optional*, defaults to `8`): + Number of updates steps to accumulate the gradients for, before performing a backward/update pass. beta (`float`, *optional*, defaults to `0.04`): KL coefficient. """ @@ -98,6 +118,33 @@ class GRPOConfig(TrainingArguments): metadata={"help": "Maximum length of the generated completion."}, ) + # Parameters that control generation acceleration powered by vLLM + use_vllm: Optional[bool] = field( + default=False, + metadata={ + "help": "Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept " + "unused for training, as vLLM will require one for generation. vLLM must be installed " + "(`pip install vllm`)." + }, + ) + vllm_device: Optional[str] = field( + default="auto", + metadata={ + "help": "Device where vLLM generation will run, e.g. 'cuda:1'. If set to 'auto' (default), the system " + "will automatically select the next available GPU after the last one used for training. This assumes " + "that training has not already occupied all available GPUs." + }, + ) + vllm_gpu_memory_utilization: float = field( + default=0.9, + metadata={ + "help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV " + "cache on the device dedicated to generation powered by vLLM. Higher values will increase the KV cache " + "size and thus improve the model's throughput. However, if the value is too high, it may cause " + "out-of-memory (OOM) errors during initialization." + }, + ) + # Parameters that control the training learning_rate: float = field( default=1e-6, @@ -106,6 +153,23 @@ class GRPOConfig(TrainingArguments): "`transformers.TrainingArguments`." }, ) + # GRPO generates multiple completions per prompt, increasing memory usage. 
+ # To accommodate this, the per-device train batch size is decreased (overriden from the parent class), + # and the number gradient accumulation steps is increased to maintain the effective batch size. + per_device_train_batch_size: int = field( + default=1, + metadata={ + "help": "Number of prompts sampled per device for training. The actual batch passed into the model will " + "be this value multiplied by `num_generations`." + }, + ) + gradient_accumulation_steps: int = field( + default=8, + metadata={ + "help": "Number of updates steps to accumulate the gradients for, before performing a backward/update " + "pass." + }, + ) beta: float = field( default=0.04, metadata={"help": "KL coefficient."}, diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index a5f6e6f219..66911f90cd 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -14,12 +14,15 @@ import os import textwrap +import warnings from collections import defaultdict from typing import Any, Callable, Optional, Union +from unittest.mock import patch import torch import torch.utils.data import transformers +from accelerate.utils import broadcast_object_list, gather_object from datasets import Dataset, IterableDataset from packaging import version from transformers import ( @@ -37,14 +40,18 @@ from transformers.utils import is_peft_available from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template +from ..import_utils import is_vllm_available from ..models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation from .grpo_config import GRPOConfig -from .utils import generate_model_card, get_comet_experiment_url +from .utils import generate_model_card, get_comet_experiment_url, pad if is_peft_available(): from peft import PeftConfig, get_peft_model +if is_vllm_available(): + from vllm import LLM, SamplingParams + if is_wandb_available(): import wandb @@ -244,13 +251,8 @@ def data_collator(features): # No data collation is needed in GRPO self.max_prompt_length = args.max_prompt_length self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper self.num_generations = args.num_generations # = G in the GRPO paper - self.generation_config = GenerationConfig( - max_new_tokens=self.max_completion_length, - do_sample=True, - temperature=args.temperature, - num_return_sequences=self.num_generations, - pad_token_id=processing_class.pad_token_id, - ) + self.use_vllm = args.use_vllm + self.beta = args.beta # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the @@ -275,6 +277,65 @@ def data_collator(features): # No data collation is needed in GRPO optimizers=optimizers, ) + if self.use_vllm: + if not is_vllm_available(): + raise ImportError( + "vLLM is not available and `use_vllm` is set to True. Please install vLLM with " + "`pip install vllm` to use it." + ) + + if self.accelerator.is_main_process: + vllm_device = self.args.vllm_device + if vllm_device == "auto": + vllm_device = f"cuda:{self.accelerator.num_processes}" # take the next GPU idx + # Check that the requested device is available + if vllm_device.split(":")[0] == "cuda" and int(vllm_device.split(":")[1]) >= torch.cuda.device_count(): + raise ValueError( + f"The requested device for vllm ({vllm_device}) is not available. You are likely using vLLM " + "without restricting the number of GPUs for training. 
Set the `--num_processes` argument to a " + "value lower than the number of GPUs available on your machine—typically, reducing it by one " + f"is sufficient. In your case: `--num_processes {torch.cuda.device_count() - 1}`." + ) + # Check that the requested device is not also used for training + if vllm_device in {f"cuda:{idx}" for idx in range(self.accelerator.num_processes)}: + warnings.warn( + f"The requested device {vllm_device} is also used for training. This may lead to unexpected " + "behavior. It is recommended to use a dedicated device for vLLM." + ) + # vLLM is not compatible with accelerate. So we need to patch it to make sure we can (1) place the vLLM + # model on the desired device (world_size_patch) and (2) avoid a test that is not designed for our + # setting (profiling_patch). + world_size_patch = patch("torch.distributed.get_world_size", return_value=1) + profiling_patch = patch( + "vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling", return_value=None + ) + with world_size_patch, profiling_patch: + self.llm = LLM( + model=model.name_or_path, + device=vllm_device, + gpu_memory_utilization=self.args.vllm_gpu_memory_utilization, + ) + self.sampling_params = SamplingParams( + n=self.num_generations, + temperature=args.temperature, + max_tokens=self.max_completion_length, + ) + + self._last_loaded_step = 0 # tag to avoid useless loading during grad checkpointing + + # When using vLLM, the main process is responsible for loading the model weights. This can cause process + # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we + # synchronize all processes after vLLM has been fully initialized. + self.accelerator.wait_for_everyone() + else: + self.generation_config = GenerationConfig( + max_new_tokens=self.max_completion_length, + do_sample=True, + temperature=args.temperature, + num_return_sequences=self.num_generations, + pad_token_id=processing_class.pad_token_id, + ) + # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set # self.model_accepts_loss_kwargs to False to enable scaling. 
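For reference, a minimal standalone sketch of the vLLM calls that the initialization above sets up, assuming vLLM is installed; the model name and prompts are placeholders, and the sampling values mirror the GRPO defaults (8 generations, temperature 0.9, 256 completion tokens):

```python
# Standalone sketch, not trainer code: one LLM instance on a dedicated GPU,
# several sampled completions per prompt, token ids read back from the outputs.
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2-0.5B-Instruct", gpu_memory_utilization=0.9)
sampling_params = SamplingParams(n=8, temperature=0.9, max_tokens=256)

prompts = ["Summarize: the cat sat on the mat.", "Summarize: it rained all day."]
outputs = llm.generate(prompts, sampling_params)  # one RequestOutput per prompt

for request_output in outputs:
    for completion in request_output.outputs:  # n completions per prompt
        print(len(completion.token_ids), completion.text[:40])
```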
@@ -310,6 +371,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N if return_outputs: raise ValueError("The GRPOTrainer does not support returning outputs") + device = self.accelerator.device prompts = [x["prompt"] for x in inputs] prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs] prompt_inputs = self.processing_class( @@ -321,9 +383,46 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -self.max_prompt_length :] prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -self.max_prompt_length :] - # Generate completions - with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model: - prompt_completion_ids = unwrapped_model.generate(**prompt_inputs, generation_config=self.generation_config) + # Generate completions using either vLLM or regular generation + if self.args.use_vllm: + # First, have main process load weights if needed + if self.state.global_step != self._last_loaded_step: + with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model: + state_dict = unwrapped_model.state_dict() + if self.accelerator.is_main_process: + llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model + llm_model.load_weights(state_dict.items()) + self._last_loaded_step = self.state.global_step + + # Generate completions using vLLM: gather all prompts and use them in a single call in the main process + all_prompts_text = gather_object(prompts_text) + if self.accelerator.is_main_process: + outputs = self.llm.generate(all_prompts_text, sampling_params=self.sampling_params, use_tqdm=False) + completion_ids = [out.token_ids for completions in outputs for out in completions.outputs] + else: + completion_ids = [None] * len(all_prompts_text) * self.num_generations + + # Broadcast the completions from the main process to all processes, ensuring each process receives its + # corresponding slice. 
+ completion_ids = broadcast_object_list(completion_ids, from_process=0) + process_slice = slice( + self.accelerator.process_index * len(prompts) * self.num_generations, + (self.accelerator.process_index + 1) * len(prompts) * self.num_generations, + ) + completion_ids = completion_ids[process_slice] + + # Pad the completions, and concatenate them with the prompts + completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids] + completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id) + prompt_inputs_repeated = torch.repeat_interleave(prompt_inputs["input_ids"], self.num_generations, dim=0) + prompt_completion_ids = torch.cat([prompt_inputs_repeated, completion_ids], dim=1) + else: + # Regular generation path + with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model: + prompt_completion_ids = unwrapped_model.generate( + **prompt_inputs, generation_config=self.generation_config + ) + prompt_length = prompt_inputs["input_ids"].size(1) completion_ids = prompt_completion_ids[:, prompt_length:] @@ -357,7 +456,6 @@ def get_per_token_logps(model, input_ids): # Mask everything after the first EOS token is_eos = completion_ids == self.processing_class.eos_token_id - device = self.accelerator.device eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) @@ -483,6 +581,7 @@ def create_model_card( author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo}, year = 2024, eprint = {arXiv:2402.03300}, + } """ ) diff --git a/trl/trainer/online_dpo_config.py b/trl/trainer/online_dpo_config.py index 8557d29628..12daa74a07 100644 --- a/trl/trainer/online_dpo_config.py +++ b/trl/trainer/online_dpo_config.py @@ -63,7 +63,7 @@ class OnlineDPOConfig(TrainingArguments): disable_dropout (`bool`, *optional*, defaults to `True`): Whether to disable dropout in the model and reference model. use_vllm (`bool`, *optional*, defaults to `False`): - Whether to use the vLLM for generating completions. Requires vLLM to be installed (`pip install vllm`). + Whether to use vLLM for generating completions. Requires vLLM to be installed (`pip install vllm`). ds3_gather_for_generation (`bool`, *optional*, defaults to `True`): This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation, improving generation speed. However, disabling this option allows training models that exceed the VRAM @@ -140,7 +140,7 @@ class OnlineDPOConfig(TrainingArguments): use_vllm: bool = field( default=False, metadata={ - "help": "Whether to use the vLLM for generating completions. Requires vLLM to be installed " + "help": "Whether to use vLLM for generating completions. Requires vLLM to be installed " "(`pip install vllm`)." }, ) diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index 44bc02c563..9abefc5140 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -262,7 +262,7 @@ def __init__( top_p=1.0, detokenize=False, # to avoid vllm to decode (we don't need it) ) - # vLLM dynamically adjusts the size of the key-value cache based on available GPU memory at instanciation. + # vLLM dynamically adjusts the size of the key-value cache based on available GPU memory at instantiation. 
# A larger cache size improves speed, so we would expect gpu_memory_utilization=1. # However, at this stage, the optimizer's weights are not yet loaded onto the GPU; they will be loaded # after the first optimizer step and remain in GPU memory throughout training. So we must reserve enough @@ -272,6 +272,7 @@ def __init__( gpu_memory_utilization=0.55, dtype=torch.float32, # When release by vLLM, we would be able to distribute the model on multiple GPUs + # See https://github.com/vllm-project/vllm/pull/12071 # tensor_parallel_size=torch.cuda.device_count(), # distributed_executor_backend="external_launcher", ) From 801582ec240c46ecd487a72fcc8944f268380830 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Wed, 29 Jan 2025 17:12:18 +0100 Subject: [PATCH 30/96] =?UTF-8?q?=F0=9F=93=89=20Use=20`num=5Flogits=5Fto?= =?UTF-8?q?=5Fkeep`=20to=20reduce=20memory=20usage=20in=20GRPO=20(#2683)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * use num_logits to keep * add comment back * Update trl/trainer/grpo_trainer.py --- trl/trainer/grpo_trainer.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 66911f90cd..b0a36f7590 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -427,29 +427,28 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N completion_ids = prompt_completion_ids[:, prompt_length:] # Get the per-token log probabilities for the completions for the model and the reference model - def get_per_token_logps(model, input_ids): - logits = model(input_ids).logits # (B, L, V) + def get_per_token_logps(model, input_ids, num_logits_to_keep): + # We add 1 to `num_logits_to_keep` because the last logits of the sequence is later excluded + logits = model(input_ids, num_logits_to_keep=num_logits_to_keep + 1).logits # (B, L, V) logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred - input_ids = input_ids[:, 1:] # (B, L-1), exclude the first input ID since we don't have logits for it + # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak. 
per_token_logps = [] - for logits_row, input_ids_row in zip(logits, input_ids): + for logits_row, input_ids_row in zip(logits, input_ids[:, -num_logits_to_keep:]): log_probs = logits_row.log_softmax(dim=-1) token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1) per_token_logps.append(token_log_prob) return torch.stack(per_token_logps) - per_token_logps = get_per_token_logps(model, prompt_completion_ids) - # Get rid of the prompt (-1 because of the shift done in get_per_token_logps) - per_token_logps = per_token_logps[:, prompt_length - 1 :] + num_logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens + per_token_logps = get_per_token_logps(model, prompt_completion_ids, num_logits_to_keep) with torch.inference_mode(): if self.ref_model is not None: - ref_per_token_logps = get_per_token_logps(self.ref_model, prompt_completion_ids) + ref_per_token_logps = get_per_token_logps(self.ref_model, prompt_completion_ids, num_logits_to_keep) else: with self.accelerator.unwrap_model(model).disable_adapter(): - ref_per_token_logps = get_per_token_logps(model, prompt_completion_ids) - ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1 :] + ref_per_token_logps = get_per_token_logps(model, prompt_completion_ids, num_logits_to_keep) # Compute the KL divergence between the model and the reference model per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 From 56880ba73d5a58d29838fcaecd19a0750772bf1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Thu, 30 Jan 2025 09:23:31 +0100 Subject: [PATCH 31/96] =?UTF-8?q?=E2=AC=86=EF=B8=8F=20Bump=20dev=20version?= =?UTF-8?q?=20(#2689)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .github/workflows/tests_latest.yml | 2 +- CITATION.cff | 2 +- setup.py | 2 +- trl/__init__.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/tests_latest.yml b/.github/workflows/tests_latest.yml index d8ea6d524e..ab071f05d3 100644 --- a/.github/workflows/tests_latest.yml +++ b/.github/workflows/tests_latest.yml @@ -17,7 +17,7 @@ jobs: steps: - name: Git checkout uses: actions/checkout@v4 - with: { ref: v0.13-release } + with: { ref: v0.14-release } - name: Set up Python 3.12 uses: actions/setup-python@v5 with: diff --git a/CITATION.cff b/CITATION.cff index 05b9ae2bbc..cc7130b2b1 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -31,4 +31,4 @@ keywords: - pytorch - transformers license: Apache-2.0 -version: 0.13 +version: 0.14 diff --git a/setup.py b/setup.py index 87c52071ee..205666f90f 100644 --- a/setup.py +++ b/setup.py @@ -71,7 +71,7 @@ from setuptools import find_packages, setup -__version__ = "0.14.0.dev0" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) +__version__ = "0.15.0.dev0" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) REQUIRED_PKGS = [ "accelerate>=0.34.0", diff --git a/trl/__init__.py b/trl/__init__.py index 1eb78822e1..4d5e9e041e 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-__version__ = "0.14.0.dev0" +__version__ = "0.15.0.dev0" from typing import TYPE_CHECKING From df8f619ec590f6e5e785e8df761dee5b60d24400 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Thu, 30 Jan 2025 09:31:08 +0100 Subject: [PATCH 32/96] =?UTF-8?q?=F0=9F=93=A6=20`trl.templates`=20in=20exc?= =?UTF-8?q?luded=20packages=20(#2690)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 205666f90f..22e8d6c2cb 100644 --- a/setup.py +++ b/setup.py @@ -123,7 +123,7 @@ package_data={ "trl": ["templates/*.md"], }, - packages=find_packages(exclude={"tests", "tests.slow"}), + packages=find_packages(exclude={"tests", "tests.slow", "trl.templates"}), install_requires=REQUIRED_PKGS, extras_require=EXTRAS, python_requires=">=3.9", From 094d51b599717a1bd06fc4a5f89bd9bdcbc8e69a Mon Sep 17 00:00:00 2001 From: Elias Rad <146735585+nnsW3@users.noreply.github.com> Date: Thu, 30 Jan 2025 10:42:14 +0200 Subject: [PATCH 33/96] =?UTF-8?q?=F0=9F=93=96=20=20Docs=20fix=20spelling?= =?UTF-8?q?=20issues=20(#2682)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Update alignprop_trainer.md * Update best_of_n.md * Update clis.md * Update community_tutorials.md * Update cpo_trainer.md * Update dataset_formats.md * Update detoxifying_a_lm.md * Update dpo_trainer.md * Update rloo_trainer.md * Update clis.md * Update rloo_trainer.md --- docs/source/alignprop_trainer.md | 4 ++-- docs/source/best_of_n.md | 2 +- docs/source/clis.md | 2 +- docs/source/community_tutorials.md | 2 +- docs/source/cpo_trainer.md | 4 ++-- docs/source/dataset_formats.md | 2 +- docs/source/detoxifying_a_lm.md | 4 ++-- docs/source/dpo_trainer.md | 4 ++-- docs/source/rloo_trainer.md | 4 ++-- 9 files changed, 14 insertions(+), 14 deletions(-) diff --git a/docs/source/alignprop_trainer.md b/docs/source/alignprop_trainer.md index a4c6b007ef..4c3b21042c 100644 --- a/docs/source/alignprop_trainer.md +++ b/docs/source/alignprop_trainer.md @@ -16,7 +16,7 @@ The `alignprop.py` script is a working example of using the `AlignProp` trainer **Note:** one A100 GPU is recommended to get this running. For lower memory setting, consider setting truncated_backprop_rand to False. With default settings this will do truncated backpropagation with K=1. -Almost every configuration parameter has a default. There is only one commandline flag argument that is required of the user to get things up and running. The user is expected to have a [huggingface user access token](https://huggingface.co/docs/hub/security-tokens) that will be used to upload the model post finetuning to HuggingFace hub. The following bash command is to be entered to get things running +Almost every configuration parameter has a default. There is only one commandline flag argument that is required of the user to get things up and running. The user is expected to have a [huggingface user access token](https://huggingface.co/docs/hub/security-tokens) that will be used to upload the model post-finetuning to HuggingFace hub. 
The following bash command is to be entered to get things running ```batch python alignprop.py --hf_user_access_token @@ -26,7 +26,7 @@ To obtain the documentation of `stable_diffusion_tuning.py`, please run `python The following are things to keep in mind (The code checks this for you as well) in general while configuring the trainer (beyond the use case of using the example script) -- The configurable randomized truncation range (`--alignprop_config.truncated_rand_backprop_minmax=(0,50)`) the first number should be equal and greater to 0, while the second number should equal or less to the number of diffusion timesteps (sample_num_steps) +- The configurable randomized truncation range (`--alignprop_config.truncated_rand_backprop_minmax=(0,50)`) the first number should be equal and greater than 0, while the second number should equal or less to the number of diffusion timesteps (sample_num_steps) - The configurable truncation backprop absolute step (`--alignprop_config.truncated_backprop_timestep=49`) the number should be less than the number of diffusion timesteps (sample_num_steps), it only matters when truncated_backprop_rand is set to False ## Setting up the image logging hook function diff --git a/docs/source/best_of_n.md b/docs/source/best_of_n.md index 9dd56aba2c..8b2978c2a3 100644 --- a/docs/source/best_of_n.md +++ b/docs/source/best_of_n.md @@ -67,6 +67,6 @@ best_of_n.generate(query_tensors, device=device) ``` -Furthermore, at the time of initialization you can set the seed to control repeatability of the generation process and the number of samples to generate for each query +Furthermore, at the time of initialization you can set the seed to control the repeatability of the generation process and the number of samples to generate for each query diff --git a/docs/source/clis.md b/docs/source/clis.md index 885227a116..d165a49668 100644 --- a/docs/source/clis.md +++ b/docs/source/clis.md @@ -13,7 +13,7 @@ Currently supported CLIs are: #### Other commands -- `trl chat`: quickly spin up a LLM fine-tuned for chatting +- `trl chat`: quickly spin up an LLM fine-tuned for chatting - `trl env`: get the system information ## Fine-tuning with the CLI diff --git a/docs/source/community_tutorials.md b/docs/source/community_tutorials.md index 4b2b9a6e54..67c37442f5 100644 --- a/docs/source/community_tutorials.md +++ b/docs/source/community_tutorials.md @@ -1,6 +1,6 @@ # Community Tutorials -Community tutorials are made by active members of the Hugging Face community that want to share their knowledge and expertise with others. They are a great way to learn about the library and its features, and to get started with core classes and modalities. +Community tutorials are made by active members of the Hugging Face community who want to share their knowledge and expertise with others. They are a great way to learn about the library and its features, and to get started with core classes and modalities. 
# Language Models diff --git a/docs/source/cpo_trainer.md b/docs/source/cpo_trainer.md index 3f9fb88cfc..24e0f3fdae 100644 --- a/docs/source/cpo_trainer.md +++ b/docs/source/cpo_trainer.md @@ -4,7 +4,7 @@ ## Overview -Contrastive Preference Optimization (CPO) as introduced in the paper [Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation](https://huggingface.co/papers/2401.08417) by [Haoran Xu](https://huggingface.co/haoranxu), [Amr Sharaf](https://huggingface.co/amrsharaf), [Yunmo Chen](https://huggingface.co/yunmochen), Weiting Tan, Lingfeng Shen, Benjamin Van Durme, [Kenton Murray](https://huggingface.co/Kenton), and [Young Jin Kim](https://huggingface.co/ykim362). At a high-level, CPO trains models to avoid generating adequate, but not perfect translations in Machine Translation (MT) tasks. However, CPO is a general approximation to the DPO loss and can be applied to other domains like chat. +Contrastive Preference Optimization (CPO) as introduced in the paper [Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation](https://huggingface.co/papers/2401.08417) by [Haoran Xu](https://huggingface.co/haoranxu), [Amr Sharaf](https://huggingface.co/amrsharaf), [Yunmo Chen](https://huggingface.co/yunmochen), Weiting Tan, Lingfeng Shen, Benjamin Van Durme, [Kenton Murray](https://huggingface.co/Kenton), and [Young Jin Kim](https://huggingface.co/ykim362). At a high-level, CPO trains models to avoid generating adequate, but not perfect translations in Machine Translation (MT) tasks. However, CPO is a general approximation of the DPO loss and can be applied to other domains, such as chat. CPO aims to mitigate two fundamental shortcomings of SFT. First, SFT’s methodology of minimizing the discrepancy between predicted outputs and gold-standard references inherently caps model performance at the quality level of the training data. Secondly, SFT lacks a mechanism to prevent the model from rejecting mistakes in translations. The CPO objective is derived from the DPO objective. @@ -105,4 +105,4 @@ To scale how much the auxiliary loss contributes to the total loss, use the hype ## CPOConfig -[[autodoc]] CPOConfig \ No newline at end of file +[[autodoc]] CPOConfig diff --git a/docs/source/dataset_formats.md b/docs/source/dataset_formats.md index d29770abb4..a8e10f1830 100644 --- a/docs/source/dataset_formats.md +++ b/docs/source/dataset_formats.md @@ -341,7 +341,7 @@ dataset = dataset.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer}) -We recommend using the [`apply_chat_template`] function instead of calling `tokenizer.apply_chat_template` directly. Handling chat templates for non-language modeling datasets can be tricky and may result in errors, such as mistakenly placing a system prompt in the middle conversation. +We recommend using the [`apply_chat_template`] function instead of calling `tokenizer.apply_chat_template` directly. Handling chat templates for non-language modeling datasets can be tricky and may result in errors, such as mistakenly placing a system prompt in the middle of a conversation. For additional examples, see [#1930 (comment)](https://github.com/huggingface/trl/pull/1930#issuecomment-2292908614). The [`apply_chat_template`] is designed to handle these intricacies and ensure the correct application of chat templates for various tasks. 
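A hedged sketch of the recommendation above, calling the helper the same way the GRPO trainer in this PR does (`maybe_apply_chat_template(example, tokenizer)`); the model name and example record are placeholders:

```python
from transformers import AutoTokenizer
from trl.data_utils import maybe_apply_chat_template

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

# A conversational, prompt-only example; non-conversational examples pass through unchanged.
example = {"prompt": [{"role": "user", "content": "What color is the sky?"}]}

formatted = maybe_apply_chat_template(example, tokenizer)
print(formatted["prompt"])  # the prompt rendered as a single string via the chat template
```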
diff --git a/docs/source/detoxifying_a_lm.md b/docs/source/detoxifying_a_lm.md index 2bc639b11b..eb0ab5fd80 100644 --- a/docs/source/detoxifying_a_lm.md +++ b/docs/source/detoxifying_a_lm.md @@ -30,7 +30,7 @@ We selected the following models for our experiments to show that TRL can be eas * [`EleutherAI/gpt-neo-2.7B`](https://huggingface.co/EleutherAI/gpt-neo-2.7B) (2.7 billion parameters) * [`EleutherAI/gpt-j-6B`](https://huggingface.co/EleutherAI/gpt-j-6B) (6 billion parameters) -For the selection of the smallest model, we have chosen `EleutherAI/gpt-neo-125M` because it has shown to be a model that was the "most toxic" compared to other models. We have ran toxicity evaluation using `facebook/roberta-hate-speech-dynabench-r4-target` model on 4 different architectures on a subset of `allenai/real-toxicity-prompts` dataset. Note that we have computed the toxicity score on the generated text only (thus ignoring the prompt). +For the selection of the smallest model, we have chosen `EleutherAI/gpt-neo-125M` because it has shown to be a model that was the "most toxic" compared to other models. We have run toxicity evaluation using `facebook/roberta-hate-speech-dynabench-r4-target` model on 4 different architectures on a subset of `allenai/real-toxicity-prompts` dataset. Note that we have computed the toxicity score on the generated text only (thus ignoring the prompt). | Model | Mean toxicity score | |---|---| @@ -88,7 +88,7 @@ As a compromise between the two we took for a context window of 10 to 15 tokens ### How to deal with OOM issues -Our goal is to train models up to 6B parameters, which is about 24GB in float32! Here two tricks we use to be able to train a 6B model on a single 40GB-RAM GPU: +Our goal is to train models up to 6B parameters, which is about 24GB in float32! Here are two tricks we use to be able to train a 6B model on a single 40GB-RAM GPU: - Use `bfloat16` precision: Simply load your model in `bfloat16` when calling `from_pretrained` and you can reduce the size of the model by 2: diff --git a/docs/source/dpo_trainer.md b/docs/source/dpo_trainer.md index b0d6b1f8d6..dac5c227d5 100644 --- a/docs/source/dpo_trainer.md +++ b/docs/source/dpo_trainer.md @@ -81,7 +81,7 @@ The best programming language based on these factors is subjective and depends o ## Expected dataset type -DPO requires a [preference dataset](dataset_formats#preference). The [`DPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. +DPO requires a [preference dataset](dataset_formats#preference). The [`DPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. Although the [`DPOTrainer`] supports both explicit and implicit prompts, we recommend using explicit prompts. If provided with an implicit prompt dataset, the trainer will automatically extract the prompt from the `"chosen"` and `"rejected"` columns. For more information, refer to the [preference style](dataset_formats#preference) section. 
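To make the distinction above concrete, here is a small illustrative sketch of what single preference examples look like in the two supported formats (contents are placeholders):

```python
# Standard format with an explicit prompt: the prompt is its own field.
example_standard = {
    "prompt": "The sky is",
    "chosen": " blue.",
    "rejected": " green.",
}

# Conversational format: each field is a list of chat messages, and the trainer
# applies the chat template automatically, as described above.
example_conversational = {
    "prompt": [{"role": "user", "content": "What color is the sky?"}],
    "chosen": [{"role": "assistant", "content": "It is blue."}],
    "rejected": [{"role": "assistant", "content": "It is green."}],
}
```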
@@ -280,4 +280,4 @@ dpo_trainer = DPOTrainer( ## DataCollatorForPreference -[[autodoc]] trainer.dpo_trainer.DataCollatorForPreference \ No newline at end of file +[[autodoc]] trainer.dpo_trainer.DataCollatorForPreference diff --git a/docs/source/rloo_trainer.md b/docs/source/rloo_trainer.md index 3ef57a3dc6..5ad5eca53c 100644 --- a/docs/source/rloo_trainer.md +++ b/docs/source/rloo_trainer.md @@ -2,7 +2,7 @@ [![](https://img.shields.io/badge/All_models-RLOO-blue)](https://huggingface.co/models?other=rloo,trl) -TRL supports training LLMs with REINFORCE Leave-One-Out (RLOO). The idea is that instead of using a value function, RLOO generates K completions for each prompt. For each completion, RLOO uses the mean scores from the other K-1 completions as a baseline to calculate the advantage. RLOO also models the entire completion as a single action, where as PPO models each token as an action. Note that REINFORCE / A2C is a special case of PPO, when the number of PPO epochs is 1 and the number of mini-batches is 1, which is how we implement RLOO in TRL. +TRL supports training LLMs with REINFORCE Leave-One-Out (RLOO). The idea is that instead of using a value function, RLOO generates K completions for each prompt. For each completion, RLOO uses the mean scores from the other K-1 completions as a baseline to calculate the advantage. RLOO also models the entire completion as a single action, whereas PPO models each token as an action. Note that REINFORCE / A2C is a special case of PPO, when the number of PPO epochs is 1 and the number of mini-batches is 1, which is how we implement RLOO in TRL. References: - [Back to Basics: Revisiting REINFORCE Style Optimization for Learning from Human Feedback in LLMs](https://huggingface.co/papers/2402.14740) @@ -58,7 +58,7 @@ The logged metrics are as follows. Here is an example [tracked run at Weights an ## Cookbook * Debugging TIP: `objective/rlhf_reward`: this is the ultimate objective of the RLHF training. If training works as intended, this metric should keep going up. -* Debugging TIP: `val/ratio`: this number should float around 1.0, and it gets clipped by `--cliprange 0.2` with PPO's surrogate loss. So if this `ratio` is too high like 2.0 or 1000.0 or too small like 0.1, it means the updates between consecutive policies are too drastic. You should try undertand why this is happening and try to fix it. +* Debugging TIP: `val/ratio`: this number should float around 1.0, and it gets clipped by `--cliprange 0.2` with PPO's surrogate loss. So if this `ratio` is too high like 2.0 or 1000.0 or too small like 0.1, it means the updates between consecutive policies are too drastic. You should try understand why this is happening and try to fix it. * Memory TIP: If you are running out of memory, you can try to reduce the `--per_device_train_batch_size` or increase the `--gradient_accumulation_steps` to reduce the memory footprint. * Memory TIP: If you have multiple GPUs, you can also run training with DeepSpeed stage 3 to reduce the memory footprint `accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml`. * Usage TIP: We recommend to use the "EOS trick" via `--missing_eos_penalty`, which subtracts a static scalar penalty from the score of completions that do not end with an EOS token. This can help the model learn to generate more coherent completions. 
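A hedged sketch of the cookbook tips above applied to an RLOO configuration; the field names are assumed to mirror the CLI flags mentioned in the tips (for example `--missing_eos_penalty`), and the values are illustrative:

```python
from trl import RLOOConfig

training_args = RLOOConfig(
    output_dir="my-rloo-run",          # placeholder
    per_device_train_batch_size=1,     # lower this if you run out of memory
    gradient_accumulation_steps=8,     # raise this to keep the effective batch size
    missing_eos_penalty=1.0,           # "EOS trick": penalize completions without an EOS token
)
```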
From 9ac8d9773b2ea6e291653e5baa914dde6041145b Mon Sep 17 00:00:00 2001 From: Adam Yanxiao Zhao Date: Thu, 30 Jan 2025 16:57:43 +0800 Subject: [PATCH 34/96] =?UTF-8?q?=F0=9F=93=84=20Add=20GRPO=20batch=20size?= =?UTF-8?q?=20note=20in=20docs=20(#2672)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add note for OOM error * update note * Apply suggestions from code review --------- Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> --- docs/source/grpo_trainer.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/grpo_trainer.md b/docs/source/grpo_trainer.md index 76ae09a160..02056ac662 100644 --- a/docs/source/grpo_trainer.md +++ b/docs/source/grpo_trainer.md @@ -24,6 +24,8 @@ This example demonstrates how to train a model using the GRPO method. We train a > Below is the script to train the model. +Note that the input tensor for the forward pass has a size of `num_generations * per_device_train_batch_size` because GRPO generates `num_generations` completions for each prompt in the batch. Adjusting these values appropriately can help prevent OOM errors. +Consequently, the effective train batch size is `num_generations * per_device_train_batch_size * gradient_accumulation_steps`. ```python # train_grpo.py From 41979563959d0e559aa4da61aecf6cdd7a061706 Mon Sep 17 00:00:00 2001 From: wizard <112275929+famouswizard@users.noreply.github.com> Date: Thu, 30 Jan 2025 13:17:02 +0300 Subject: [PATCH 35/96] =?UTF-8?q?=F0=9F=99=88=20Fixed=20typo=20in=20the=20?= =?UTF-8?q?GRPO=20documentation=20(#2691)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/source/grpo_trainer.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/grpo_trainer.md b/docs/source/grpo_trainer.md index 02056ac662..fa601bbeeb 100644 --- a/docs/source/grpo_trainer.md +++ b/docs/source/grpo_trainer.md @@ -14,7 +14,7 @@ This post-training method was contributed by [Quentin Gallouédec](https://huggi ## Quick start -This example demonstrates how to train a model using the GRPO method. We train a [Qwen 0.5B Instruct model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) with the prompts from the [TLDR dataset](https://huggingface.co/datasets/trl-lib/tldr) (completion column is ingored!). You can view the data in the dataset here: +This example demonstrates how to train a model using the GRPO method. We train a [Qwen 0.5B Instruct model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) with the prompts from the [TLDR dataset](https://huggingface.co/datasets/trl-lib/tldr) (completion column is ignored!). You can view the data in the dataset here: Below is the script to train the model. -Note that the input tensor for the forward pass has a size of `num_generations * per_device_train_batch_size` because GRPO generates `num_generations` completions for each prompt in the batch. Adjusting these values appropriately can help prevent OOM errors. -Consequently, the effective train batch size is `num_generations * per_device_train_batch_size * gradient_accumulation_steps`. 
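The note removed here is superseded later in this same patch, where each prompt is repeated `num_generations` times by a sampler and the global train batch size (`num_processes * per_device_train_batch_size`) must be divisible by `num_generations`. A small worked sketch of that constraint, mirroring the validation code added further down (numbers are illustrative):

```python
# Illustrative arithmetic for the divisibility rule introduced by this patch.
num_processes = 4                  # GPUs used for training
per_device_train_batch_size = 3
num_generations = 6

global_batch_size = num_processes * per_device_train_batch_size            # 12
valid_values = [n for n in range(2, global_batch_size + 1) if global_batch_size % n == 0]
print(valid_values)                # [2, 3, 4, 6, 12]
assert num_generations in valid_values
# 12 rows per step = 2 unique prompts, each repeated 6 times by the sampler.
```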
```python # train_grpo.py diff --git a/tests/test_cli.py b/tests/test_cli.py index 234b4e7ba1..6b00ed1ed0 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -47,7 +47,7 @@ def test_grpo(self): from trl.cli import main with tempfile.TemporaryDirectory() as tmp_dir: # Create a temporary directory - command = f"trl grpo --output_dir {tmp_dir} --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --reward_model_name_or_path trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5 --dataset_name trl-internal-testing/zen --dataset_config standard_prompt_only --num_generations 3 --max_completion_length 32 --report_to none" + command = f"trl grpo --output_dir {tmp_dir} --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --reward_model_name_or_path trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5 --dataset_name trl-internal-testing/zen --dataset_config standard_prompt_only --num_generations 4 --max_completion_length 32 --report_to none" with patch("sys.argv", command.split(" ")): main() diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index 6c85ddfc4f..54d2b3965d 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -49,7 +49,7 @@ def test_training(self, config_name): training_args = GRPOConfig( output_dir=tmp_dir, learning_rate=0.1, # increase the learning rate to speed up the test - per_device_train_batch_size=2, # reduce the batch size to reduce memory usage + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage num_generations=3, # reduce the number of generations to reduce memory usage max_completion_length=32, # reduce the completion length to reduce memory usage report_to="none", @@ -78,8 +78,8 @@ def test_training_with_eval(self): with tempfile.TemporaryDirectory() as tmp_dir: training_args = GRPOConfig( output_dir=tmp_dir, - per_device_train_batch_size=2, # reduce the batch size to reduce memory usage - per_device_eval_batch_size=2, # reduce the batch size to reduce memory usage + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + per_device_eval_batch_size=3, # reduce the batch size to reduce memory usage num_generations=3, # reduce the number of generations to reduce memory usage max_completion_length=32, # reduce the completion length to reduce memory usage eval_strategy="steps", @@ -106,7 +106,7 @@ def test_training_peft(self): training_args = GRPOConfig( output_dir=tmp_dir, learning_rate=0.1, # increase the learning rate to speed up the test - per_device_train_batch_size=2, # reduce the batch size to reduce memory usage + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage num_generations=3, # reduce the number of generations to reduce memory usage max_completion_length=32, # reduce the completion length to reduce memory usage report_to="none", @@ -149,7 +149,7 @@ def test_training_different_reward_model(self): training_args = GRPOConfig( output_dir=tmp_dir, learning_rate=0.1, # increase the learning rate to speed up the test - per_device_train_batch_size=2, # reduce the batch size to reduce memory usage + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage num_generations=3, # reduce the number of generations to reduce memory usage max_completion_length=32, # reduce the completion length to reduce memory usage report_to="none", @@ -185,7 +185,7 @@ def reward_func(completions, **kwargs): training_args = GRPOConfig( output_dir=tmp_dir, learning_rate=0.1, # increase the learning rate to 
speed up the test - per_device_train_batch_size=2, # reduce the batch size to reduce memory usage + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage num_generations=3, # reduce the number of generations to reduce memory usage max_completion_length=32, # reduce the completion length to reduce memory usage report_to="none", @@ -221,7 +221,7 @@ def reward_func(completions, **kwargs): training_args = GRPOConfig( output_dir=tmp_dir, learning_rate=0.1, # increase the learning rate to speed up the test - per_device_train_batch_size=2, # reduce the batch size to reduce memory usage + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage num_generations=3, # reduce the number of generations to reduce memory usage max_completion_length=32, # reduce the completion length to reduce memory usage report_to="none", @@ -260,7 +260,7 @@ def reward_func2(completions, **kwargs): training_args = GRPOConfig( output_dir=tmp_dir, learning_rate=0.1, # increase the learning rate to speed up the test - per_device_train_batch_size=2, # reduce the batch size to reduce memory usage + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage num_generations=3, # reduce the number of generations to reduce memory usage max_completion_length=32, # reduce the completion length to reduce memory usage report_to="none", @@ -295,7 +295,7 @@ def reward_func(completions, **kwargs): training_args = GRPOConfig( output_dir=tmp_dir, learning_rate=0.1, # increase the learning rate to speed up the test - per_device_train_batch_size=2, # reduce the batch size to reduce memory usage + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage num_generations=3, # reduce the number of generations to reduce memory usage max_completion_length=32, # reduce the completion length to reduce memory usage report_to="none", @@ -334,7 +334,7 @@ def reward_func(completions, some_values, **kwargs): training_args = GRPOConfig( output_dir=tmp_dir, learning_rate=0.1, # increase the learning rate to speed up the test - per_device_train_batch_size=2, # reduce the batch size to reduce memory usage + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage num_generations=3, # reduce the number of generations to reduce memory usage max_completion_length=32, # reduce the completion length to reduce memory usage report_to="none", @@ -367,7 +367,7 @@ def test_training_vllm(self): training_args = GRPOConfig( output_dir=tmp_dir, learning_rate=0.1, # increase the learning rate to speed up the test - per_device_train_batch_size=2, # reduce the batch size to reduce memory usage + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage num_generations=3, # reduce the number of generations to reduce memory usage max_completion_length=32, # reduce the completion length to reduce memory usage report_to="none", @@ -400,7 +400,7 @@ def test_training_torch_compile(self): training_args = GRPOConfig( output_dir=tmp_dir, learning_rate=0.1, # increase the learning rate to speed up the test - per_device_train_batch_size=2, # reduce the batch size to reduce memory usage + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage num_generations=3, # reduce the number of generations to reduce memory usage max_completion_length=32, # reduce the completion length to reduce memory usage torch_compile=True, @@ -431,7 +431,7 @@ def test_training_with_sync_ref_model(self): training_args = GRPOConfig( output_dir=tmp_dir, 
learning_rate=0.1, # increase the learning rate to speed up the test - per_device_train_batch_size=2, # reduce the batch size to reduce memory usage + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage num_generations=3, # reduce the number of generations to reduce memory usage max_completion_length=32, # reduce the completion length to reduce memory usage sync_ref_model=True, diff --git a/trl/models/utils.py b/trl/models/utils.py index 22a30c0afb..dce9d60228 100644 --- a/trl/models/utils.py +++ b/trl/models/utils.py @@ -137,6 +137,8 @@ def setup_chat_format( def remove_hooks(model: "DeepSpeedEngine") -> None: """Removes the optimizer hooks from a DeepSpeed ZeRO-3 model.""" + if not hasattr(model, "optimizer"): # before the first training step, the model has no optimizer + return if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"): optimizer_offload = model.optimizer.parameter_offload elif model.optimizer is not None: @@ -164,6 +166,8 @@ def iter_params(module, recurse=False): def add_hooks(model: "DeepSpeedEngine") -> None: """Adds the optimizer hooks from a DeepSpeed ZeRO-3 model.""" + if not hasattr(model, "optimizer"): # before the first training step, the model has no optimizer + return if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"): optimizer_offload = model.optimizer.parameter_offload elif model.optimizer is not None: diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index e641065968..b30cf9899d 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -45,7 +45,8 @@ class GRPOConfig(TrainingArguments): max_prompt_length (`int` or `None`, *optional*, defaults to `512`): Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left. num_generations (`int` or `None`, *optional*, defaults to `8`): - Number of generations per prompt to sample. + Number of generations per prompt to sample. The global batch size (num_processes * per_device_batch_size) + must be divisible by this value. temperature (`float`, *optional*, defaults to `0.9`): Temperature for sampling. The higher the temperature, the more random the completions. max_completion_length (`int` or `None`, *optional*, defaults to `256`): @@ -83,11 +84,6 @@ class GRPOConfig(TrainingArguments): learning_rate (`float`, *optional*, defaults to `1e-6`): Initial learning rate for [`AdamW`] optimizer. The default value replaces that of [`~transformers.TrainingArguments`]. - per_device_train_batch_size (`int`, *optional*, defaults to `1`): - Number of prompts sampled per device for training. The actual batch passed into the model will be this - value multiplied by `num_generations`. - gradient_accumulation_steps (`int`, *optional*, defaults to `8`): - Number of updates steps to accumulate the gradients for, before performing a backward/update pass. beta (`float`, *optional*, defaults to `0.04`): KL coefficient. sync_ref_model (`bool`, *optional*, defaults to `False`): @@ -132,7 +128,10 @@ class GRPOConfig(TrainingArguments): ) num_generations: Optional[int] = field( default=8, - metadata={"help": "Number of generations to sample."}, + metadata={ + "help": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) " + "must be divisible by this value." + }, ) temperature: Optional[float] = field( default=0.9, @@ -202,23 +201,6 @@ class GRPOConfig(TrainingArguments): "`transformers.TrainingArguments`." 
}, ) - # GRPO generates multiple completions per prompt, increasing memory usage. - # To accommodate this, the per-device train batch size is decreased (overriden from the parent class), - # and the number gradient accumulation steps is increased to maintain the effective batch size. - per_device_train_batch_size: int = field( - default=1, - metadata={ - "help": "Number of prompts sampled per device for training. The actual batch passed into the model will " - "be this value multiplied by `num_generations`." - }, - ) - gradient_accumulation_steps: int = field( - default=8, - metadata={ - "help": "Number of updates steps to accumulate the gradients for, before performing a backward/update " - "pass." - }, - ) beta: float = field( default=0.04, metadata={"help": "KL coefficient."}, diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index c6c5f4fa76..fa55a4bb8e 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -16,17 +16,18 @@ import textwrap import warnings from collections import defaultdict -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, Sized, Union from unittest.mock import patch import torch import torch.utils.data import transformers -from accelerate.utils import broadcast_object_list, gather_object +from accelerate.utils import broadcast_object_list, gather, gather_object from accelerate.utils.other import is_compiled_module from datasets import Dataset, IterableDataset from packaging import version from torch import nn +from torch.utils.data import Sampler from transformers import ( AutoModelForCausalLM, AutoModelForSequenceClassification, @@ -63,6 +64,37 @@ RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]] +class RepeatRandomSampler(Sampler): + """ + Sampler that repeats the indices of a dataset N times. + + Args: + data_source (`Sized`): + Dataset to sample from. + repeat_count (`int`): + Number of times to repeat each index. + + Example: + ```python + >>> sampler = RepeatRandomSampler(["a", "b", "c", "d"], repeat_count=2) + >>> list(sampler) + [2, 2, 0, 0, 3, 3, 1, 1] + ``` + """ + + def __init__(self, data_source: Sized, repeat_count: int): + self.data_source = data_source + self.repeat_count = repeat_count + self.num_samples = len(data_source) + + def __iter__(self): + indexes = [idx for idx in torch.randperm(self.num_samples).tolist() for _ in range(self.repeat_count)] + return iter(indexes) + + def __len__(self): + return self.num_samples * self.repeat_count + + class GRPOTrainer(Trainer): """ Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the @@ -280,6 +312,26 @@ def data_collator(features): # No data collation is needed in GRPO optimizers=optimizers, ) + # Check if the per_device_train/eval_batch_size * num processes can be divided by the number of generations + num_processes = self.accelerator.num_processes + global_batch_size = args.per_device_train_batch_size * num_processes + possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0] + if self.num_generations not in possible_values: + raise ValueError( + f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly " + f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train " + f"batch size, the valid values for the number of generations are: {possible_values}." 
+ ) + if self.args.eval_strategy != "no": + global_batch_size = args.per_device_eval_batch_size * num_processes + possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0] + if self.num_generations not in possible_values: + raise ValueError( + f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly " + f"divisible by the number of generations per prompt ({self.num_generations}). Given the current " + f"eval batch size, the valid values for the number of generations are: {possible_values}." + ) + if self.use_vllm: if not is_vllm_available(): raise ImportError( @@ -325,12 +377,11 @@ def data_collator(features): # No data collation is needed in GRPO max_model_len=self.args.vllm_max_model_len, ) self.sampling_params = SamplingParams( - n=self.num_generations, temperature=args.temperature, max_tokens=self.max_completion_length, ) - self._last_loaded_step = 0 # tag to avoid useless loading during grad checkpointing + self._last_loaded_step = 0 # tag to avoid useless loading during grad accumulation # When using vLLM, the main process is responsible for loading the model weights. This can cause process # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we @@ -341,7 +392,6 @@ def data_collator(features): # No data collation is needed in GRPO max_new_tokens=self.max_completion_length, do_sample=True, temperature=args.temperature, - num_return_sequences=self.num_generations, pad_token_id=processing_class.pad_token_id, ) @@ -374,12 +424,17 @@ def _set_signature_columns_if_needed(self): if self._signature_columns is None: self._signature_columns = ["prompt"] + # We need a custom sampler that samples the same prompt multiple times + def _get_train_sampler(self) -> Sampler: + return RepeatRandomSampler(self.train_dataset, self.num_generations) + + def _get_eval_sampler(self, eval_dataset) -> Sampler: + return RepeatRandomSampler(eval_dataset, self.num_generations) + # Get the per-token log probabilities for the completions for the model and the reference model def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded - logits = model( - input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1 - ).logits # (B, L, V) + logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred input_ids = input_ids[:, -logits_to_keep:] @@ -389,8 +444,7 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) # Compute the log probabilities for the input tokens. 
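The batch-size check and the `RepeatRandomSampler` introduced in the GRPO trainer hunks above fit together: the global batch has to split evenly into groups of `num_generations` completions that share a prompt. A minimal standalone sketch, with made-up numbers rather than the trainer's defaults:

```python
# Illustrative numbers only: 2 processes x 3 prompts per device.
num_processes = 2
per_device_train_batch_size = 3
num_generations = 3

global_batch_size = num_processes * per_device_train_batch_size
valid = [n for n in range(2, global_batch_size + 1) if global_batch_size % n == 0]
assert num_generations in valid, f"num_generations must be one of {valid}"

# A repeat sampler over a 4-example dataset with repeat_count=3 yields each
# shuffled index three times in a row, e.g. [2, 2, 2, 0, 0, 0, 3, 3, 3, 1, 1, 1],
# so every run of `num_generations` consecutive rows shares the same prompt.
```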
token_logits = logits.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1) - # use a loop to reduce memory peak - logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits]) + logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits]) # loop to reduce memory peak token_log_probs = token_logits - logsumexp_values # log_softmax = logits - log(sum(exp(logits))) return token_log_probs @@ -430,22 +484,19 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s outputs = self.llm.generate(all_prompts_text, sampling_params=self.sampling_params, use_tqdm=False) completion_ids = [out.token_ids for completions in outputs for out in completions.outputs] else: - completion_ids = [None] * len(all_prompts_text) * self.num_generations - + completion_ids = [None] * len(all_prompts_text) # Broadcast the completions from the main process to all processes, ensuring each process receives its # corresponding slice. completion_ids = broadcast_object_list(completion_ids, from_process=0) process_slice = slice( - self.accelerator.process_index * len(prompts) * self.num_generations, - (self.accelerator.process_index + 1) * len(prompts) * self.num_generations, + self.accelerator.process_index * len(prompts), + (self.accelerator.process_index + 1) * len(prompts), ) completion_ids = completion_ids[process_slice] # Pad the completions, and concatenate them with the prompts completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids] completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id) - prompt_ids = torch.repeat_interleave(prompt_ids, self.num_generations, dim=0) - prompt_mask = torch.repeat_interleave(prompt_mask, self.num_generations, dim=0) prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) else: # Regular generation path @@ -458,7 +509,6 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s prompt_length = prompt_ids.size(1) prompt_ids = prompt_completion_ids[:, :prompt_length] completion_ids = prompt_completion_ids[:, prompt_length:] - prompt_mask = prompt_mask.repeat_interleave(self.num_generations, dim=0) # Mask everything after the first EOS token is_eos = completion_ids == self.processing_class.eos_token_id @@ -488,9 +538,6 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s if is_conversational(inputs[0]): completions = [[{"role": "assistant", "content": completion}] for completion in completions] - # Compute the rewards - prompts = [prompt for prompt in prompts for _ in range(self.num_generations)] # repeat prompts - rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device) for i, (reward_func, reward_processing_class) in enumerate( zip(self.reward_funcs, self.reward_processing_classes) @@ -509,14 +556,15 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,) else: # Repeat all input columns (but "prompt" and "completion") to match the number of generations - reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]} - for key in reward_kwargs: - for example in inputs: - # Repeat each value in the column for `num_generations` times - reward_kwargs[key].extend([example[key]] * self.num_generations) + keys = [key for key in inputs[0] if key not in ["prompt", "completion"]] + reward_kwargs = {key: [example[key] for example in inputs] for key 
in keys} output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs) rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) + # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the + # completions may be distributed across processes + rewards_per_func = gather(rewards_per_func) + # Sum the rewards from all reward functions rewards = rewards_per_func.sum(dim=1) @@ -529,8 +577,15 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0) advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4) + # Slice to keep only the local part of the data + process_slice = slice( + self.accelerator.process_index * len(prompts), + (self.accelerator.process_index + 1) * len(prompts), + ) + advantages = advantages[process_slice] + # Log the metrics - reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0) + reward_per_func = rewards_per_func.mean(0) for i, reward_func in enumerate(self.reward_funcs): if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models reward_func_name = reward_func.config._name_or_path.split("/")[-1] @@ -538,8 +593,8 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s reward_func_name = reward_func.__name__ self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item()) - self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item()) - self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item()) + self._metrics["reward"].append(rewards.mean().item()) + self._metrics["reward_std"].append(std_grouped_rewards.mean().item()) return { "prompt_ids": prompt_ids, From 2241f17914c8964f25e821bac192342b327bc419 Mon Sep 17 00:00:00 2001 From: binary-husky <96192199+binary-husky@users.noreply.github.com> Date: Fri, 7 Feb 2025 18:08:49 +0800 Subject: [PATCH 66/96] =?UTF-8?q?=F0=9F=86=9A=20Distinguish=20padding=20an?= =?UTF-8?q?d=20eos=20when=20they=20differ=20(#2793)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- trl/scripts/sft.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/trl/scripts/sft.py b/trl/scripts/sft.py index 2095df3074..764ca3c929 100644 --- a/trl/scripts/sft.py +++ b/trl/scripts/sft.py @@ -84,7 +84,8 @@ def main(script_args, training_args, model_args): tokenizer = AutoTokenizer.from_pretrained( model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True ) - tokenizer.pad_token = tokenizer.eos_token + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token ################ # Dataset From 84d73fd00b188721e28bd9a18ad38f100114dbda Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 7 Feb 2025 11:09:46 +0100 Subject: [PATCH 67/96] =?UTF-8?q?=F0=9F=8E=AF=20[SFT]=20add=20token=20accu?= =?UTF-8?q?racy=20metric=20(#2597)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add token accuracy metric * fix return type * shift tokens * use compute_loss so that the model is called only once * add to logs * log from main process --------- Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> --- tests/test_utils.py | 55 ++++++++++++++++++++++++++++++++++++++ 
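The `gather` call and per-group statistics shown above are the heart of GRPO's advantage computation: rewards from all processes are collected, normalized within each group of `num_generations` completions of the same prompt, and each process then keeps only its local slice. A small self-contained sketch with toy rewards (the numbers and shapes are made up):

```python
import torch

num_generations = 4  # G completions per prompt (illustrative)
# Gathered rewards for 2 prompts x 4 completions, flattened in group order.
rewards = torch.tensor([1.0, 0.0, 0.5, 0.5,   # group for prompt A
                        2.0, 2.0, 1.0, 3.0])  # group for prompt B

grouped = rewards.view(-1, num_generations)                      # (num_prompts, G)
mean = grouped.mean(dim=1).repeat_interleave(num_generations)    # broadcast back to (B*G,)
std = grouped.std(dim=1).repeat_interleave(num_generations)
advantages = (rewards - mean) / (std + 1e-4)                     # group-relative advantage

# With 2 processes and 4 local rows each, process p would then keep
# advantages[p * 4 : (p + 1) * 4].
```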
trl/__init__.py | 4 +-- trl/trainer/__init__.py | 2 ++ trl/trainer/sft_trainer.py | 46 +++++++++++++++++++++++++++++++ trl/trainer/utils.py | 21 +++++++++++++++ 5 files changed, 126 insertions(+), 2 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 0061dd5e5e..9fd7ed9e0f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -26,6 +26,7 @@ from trl.trainer.utils import ( DataCollatorForChatML, batch_generation, + compute_token_accuracy, decode_and_strip_padding, flush_left, generate_model_card, @@ -451,3 +452,57 @@ def test_no_tensors(self): expected_mask = torch.tensor([[1, 1, 1], [1, 1, 0]]) self.assertTrue(torch.equal(new_mask, expected_mask)) + + +class TestComputeTokenAccuracy(unittest.TestCase): + def test_basic_accuracy(self): + # Test basic accuracy computation + logits = torch.tensor([[[0.9, 0.1], [0.8, 0.2]], [[0.3, 0.7], [0.6, 0.4]]]) # Shape: [2, 2, 2] + labels = torch.tensor([[1, 0], [1, 0]]) # Shape: [2, 2] + accuracy = compute_token_accuracy(logits, labels) + self.assertAlmostEqual(accuracy, 0.75) # 3 correct out of 4 tokens + + def test_with_ignore_index(self): + # Test accuracy computation with ignored tokens + logits = torch.tensor([[[0.9, 0.1], [0.8, 0.2]], [[0.3, 0.7], [0.6, 0.4]]]) + labels = torch.tensor([[1, -100], [1, 0]]) # -100 is ignored + accuracy = compute_token_accuracy(logits, labels, ignore_index=-100) + self.assertAlmostEqual(accuracy, 2 / 3) # 2 correct out of 3 non-ignored tokens + + def test_all_ignored(self): + # Test case where all tokens are ignored + logits = torch.tensor([[[0.1, 0.9], [0.8, 0.2]]]) + labels = torch.tensor([[-100, -100]]) + accuracy = compute_token_accuracy(logits, labels) + self.assertEqual(accuracy, 0.0) # No valid tokens to compute accuracy + + def test_perfect_accuracy(self): + # Test case with 100% accuracy + logits = torch.tensor([[[0.1, 0.9], [0.8, 0.2]]]) + labels = torch.tensor([[1, 0]]) + accuracy = compute_token_accuracy(logits, labels) + self.assertEqual(accuracy, 1.0) # All predictions correct + + def test_zero_accuracy(self): + # Test case with 0% accuracy + logits = torch.tensor([[[0.1, 0.9], [0.8, 0.2]]]) + labels = torch.tensor([[0, 1]]) + accuracy = compute_token_accuracy(logits, labels) + self.assertEqual(accuracy, 0.0) # All predictions wrong + + def test_batch_accuracy(self): + # Test accuracy computation across multiple batches + logits = torch.tensor( + [ + [[0.9, 0.1], [0.8, 0.2], [0.3, 0.7]], # Batch 1 + [[0.2, 0.8], [0.7, 0.3], [0.6, 0.4]], # Batch 2 + ] + ) + labels = torch.tensor( + [ + [1, 0, 1], # Batch 1 + [1, 0, -100], # Batch 2 (last token ignored) + ] + ) + accuracy = compute_token_accuracy(logits, labels) + self.assertAlmostEqual(accuracy, 0.8) diff --git a/trl/__init__.py b/trl/__init__.py index 4d5e9e041e..44a4333d53 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -100,7 +100,7 @@ "XPOTrainer", ], "trainer.callbacks": ["MergeModelCallback", "RichProgressCallback", "SyncRefModelCallback"], - "trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config"], + "trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config", "compute_token_accuracy"], } try: @@ -200,7 +200,7 @@ XPOTrainer, ) from .trainer.callbacks import RichProgressCallback, SyncRefModelCallback - from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config + from .trainer.utils import compute_token_accuracy, get_kbit_device_map, get_peft_config, get_quantization_config try: if not is_diffusers_available(): diff --git 
a/trl/trainer/__init__.py b/trl/trainer/__init__.py index 85968218cc..9ef887864a 100644 --- a/trl/trainer/__init__.py +++ b/trl/trainer/__init__.py @@ -76,6 +76,7 @@ "disable_dropout_in_model", "empty_cache", "peft_module_casting_to_bf16", + "compute_token_accuracy", ], "xpo_config": ["XPOConfig"], "xpo_trainer": ["XPOTrainer"], @@ -144,6 +145,7 @@ DataCollatorForCompletionOnlyLM, RunningMoments, compute_accuracy, + compute_token_accuracy, disable_dropout_in_model, empty_cache, peft_module_casting_to_bf16, diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 9e2e5fe04f..086a5bac79 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -16,15 +16,18 @@ import inspect import os import warnings +from collections import defaultdict from typing import Callable, Optional, Union import datasets import torch import torch.nn as nn +import transformers from accelerate.state import PartialState from datasets import Dataset from datasets.arrow_writer import SchemaInferenceError from datasets.builder import DatasetGenerationError +from packaging import version from transformers import ( AutoModelForCausalLM, AutoTokenizer, @@ -48,6 +51,7 @@ from .utils import ( ConstantLengthDataset, DataCollatorForCompletionOnlyLM, + compute_token_accuracy, generate_model_card, get_comet_experiment_url, peft_module_casting_to_bf16, @@ -304,6 +308,9 @@ def make_inputs_require_grad(module, input, output): UserWarning, ) + # Initialize the metrics + self._metrics = defaultdict(list) + super().__init__( model=model, args=args, @@ -546,3 +553,42 @@ def create_model_card( ) model_card.save(os.path.join(self.args.output_dir, "README.md")) + + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + """ + Compute training loss and additionally compute token accuracies + """ + (loss, outputs) = super().compute_loss( + model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch + ) + + # Compute token accuracy if we have labels + if "labels" in inputs: + shift_logits = outputs.logits[..., :-1, :].contiguous() + shift_labels = inputs["labels"][..., 1:].contiguous() + + # Gather logits and labels from all GPUs first + shift_logits = self.accelerator.gather_for_metrics(shift_logits) + shift_labels = self.accelerator.gather_for_metrics(shift_labels) + + # Then compute accuracy on the gathered tensors + if self.accelerator.is_main_process: + accuracy = compute_token_accuracy(shift_logits, shift_labels) + self._metrics["mean_token_accuracy"].append(accuracy) + + return (loss, outputs) if return_outputs else loss + + def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: + metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics + + # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs` + # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format. 
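The token-accuracy metric added to `SFTTrainer.compute_loss` above compares each position's argmax prediction with the next token and skips positions carrying the ignore index. The sketch below reproduces that logic on random tensors; the shapes and the `-100` masking are illustrative only:

```python
import torch

vocab_size, ignore_index = 32, -100
logits = torch.randn(2, 6, vocab_size)        # (batch, seq_len, vocab)
labels = torch.randint(0, vocab_size, (2, 6))
labels[:, :2] = ignore_index                  # e.g. prompt tokens excluded from the metric

# Logits at position t predict the token at position t + 1, hence the shift.
shift_logits = logits[..., :-1, :]
shift_labels = labels[..., 1:]

predictions = shift_logits.argmax(dim=-1)
mask = shift_labels != ignore_index
correct = ((predictions == shift_labels) & mask).sum()
total = mask.sum()
accuracy = (correct / total).item() if total > 0 else 0.0
print(f"mean_token_accuracy: {accuracy:.2f}")
```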
+ if next(iter(logs.keys())).startswith("eval_"): + metrics = {f"eval_{key}": val for key, val in metrics.items()} + + logs = {**logs, **metrics} + if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): + super().log(logs, start_time) + else: # transformers<=4.46 + super().log(logs) + self._metrics.clear() diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 719d952f1f..029e5639ab 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -1647,3 +1647,24 @@ def flush_left(mask: torch.Tensor, *tensors: torch.Tensor) -> tuple[torch.Tensor return mask else: return mask, *tensors + + +def compute_token_accuracy(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> float: + """ + Compute the mean token accuracy. + """ + # Get predictions + predictions = logits.argmax(dim=-1) + + # Create mask for non-padding tokens (assuming pad_token_id is ignore_index) + mask = labels != ignore_index + + # Calculate accuracy only on non-padding tokens + correct_predictions = (predictions == labels) & mask + total_tokens = mask.sum() + correct_tokens = correct_predictions.sum() + + # Calculate accuracy + accuracy = correct_tokens.item() / total_tokens.item() if total_tokens > 0 else 0.0 + + return accuracy From 82d12eb75103821cd4af1978e99b1026a90ac67d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Fri, 7 Feb 2025 12:41:58 +0100 Subject: [PATCH 68/96] =?UTF-8?q?=F0=9F=93=A0=20Log=20completions=20for=20?= =?UTF-8?q?GRPO=20(#2772)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * log completions * typo * wandb * Fix completions * Fix style? * Remove double import * Revert * group logging --------- Co-authored-by: lewtun --- trl/trainer/grpo_config.py | 11 +++++++++++ trl/trainer/grpo_trainer.py | 26 ++++++++++++++++++++++++-- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index b30cf9899d..708a3188c8 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -99,6 +99,11 @@ class GRPOConfig(TrainingArguments): τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how frequently the current policy is synchronized with the reference policy. To use this parameter, you must set `sync_ref_model=True`. + + > Parameters that control the logging + + log_completions (`bool`, *optional*, defaults to `False`): + Whether to log the completions during training. """ # Parameters that control the model and reference model @@ -227,3 +232,9 @@ class GRPOConfig(TrainingArguments): "synchronized with the reference policy. To use this parameter, you must set `sync_ref_model=True`." 
}, ) + + # Parameters that control the logging + log_completions: bool = field( + default=False, + metadata={"help": "Whether to log the completions during training."}, + ) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index fa55a4bb8e..4cbba2fcaa 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -300,6 +300,7 @@ def data_collator(features): # No data collation is needed in GRPO # Initialize the metrics self._metrics = defaultdict(list) + self.log_completions = args.log_completions super().__init__( model=model, @@ -534,9 +535,11 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s ) # Decode the generated completions - completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) + completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) if is_conversational(inputs[0]): - completions = [[{"role": "assistant", "content": completion}] for completion in completions] + completions = [[{"role": "assistant", "content": completion}] for completion in completions_text] + else: + completions = completions_text rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device) for i, (reward_func, reward_processing_class) in enumerate( @@ -596,6 +599,25 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s self._metrics["reward"].append(rewards.mean().item()) self._metrics["reward_std"].append(std_grouped_rewards.mean().item()) + if ( + self.log_completions + and self.state.global_step % self.args.logging_steps == 0 + and "wandb" in self.args.report_to + ): + import pandas as pd + + # For logging + table = { + "step": [str(self.state.global_step)] * len(rewards), + "prompt": gather_object(prompts_text), + "completion": gather_object(completions_text), + "reward": rewards.tolist(), + } + df = pd.DataFrame(table) + + if wandb.run is not None and self.accelerator.is_main_process: + wandb.log({"completions": wandb.Table(dataframe=df)}) + return { "prompt_ids": prompt_ids, "prompt_mask": prompt_mask, From 5b9236d1e8d062ad76f088b0730e36c724ec170c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Sat, 8 Feb 2025 00:21:36 +0100 Subject: [PATCH 69/96] =?UTF-8?q?=F0=9F=94=AC=20SFT=20simplification=20(#2?= =?UTF-8?q?405)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * initial commit * update * Refactor SFTTrainer and SFTConfig * Update SFTConfig class in sft_config.py * Fix SFTConfig torch_dtype validation and dataset preprocessing flag * Refactor dataset mapping and conversion * Refactor dataset mapping in SFTTrainer * Fix SFTTrainerTester unit test by removing unnecessary code * Remove unused variables and update tokenization logic * Remove pack_dataset function * Add deprecation warning for tokenizer in SFTTrainer constructor * add docstring back * Update model parameter type annotation * Update SFTTrainer class definition * style * preprocess_dataset -> _prepare_dataset * Retro compat * Update formatting_func type hint in SFTTrainer constructor * typo * better comment * simplify tokenize row * Fix type hint for peft_config * fix doc * Add pack_examples function to `test_data_utils.py` * promote pack_examples and document * improve doc * Add new SFTTrainerTester2 class for testing * test was reversed * ©️ Copyrights update (#2454) * First changes * Other files * Finally * rm comment * fix 
nashmd * Fix example * Fix example * 💬 Fix chat for windows (#2443) * fix chat for windows * add some tests back * Revert "add some tests back" This reverts commit 350aef52f53f8cf34fccd7ad0f78a3dd63867e06. * 🆔 Add `datast_config` to `ScriptArguments` (#2440) * datast_config_name * Update trl/utils.py * sort import * typo * Trigger CI * Rename `dataset_config_name` to `dataset_config` * 🏎 Fix deepspeed preparation of `ref_model` in `OnlineDPOTrainer` (#2417) * Remove unused deepspeed code * add model prep back * add deepspeed even if it doesn't work * rm old code * 👯 Standardize `model_args` (#2442) * `model_config` -> `model_args` * sort * refactor config * drop skip prepare dataset * add sep to packing * drop prompt-completion for now * Revert "drop prompt-completion for now" This reverts commit 16ef195031ac9c860f8f2ac383ff34133fcbe70f. * Revert "add sep to packing" This reverts commit dc84d08da7a4b7804c064be1a15605f1770549e2. * Revert "drop skip prepare dataset" This reverts commit d2ee070d994a4b29ad33128a8ef99f101994a6c7. * Revert "refactor config" This reverts commit f732aa8728e42623ee5817b514263912cab337e4. * Format * Update doc-builder workflow to use specific commit sha * add peft edge cases * no logits when using liger * remove unused columns * proper handle of prompt-completion * trick to keep messages * fix messages missing * for Liger kernel, ensure only input_ids is present * packing and liger are compatible * shinny doc and final nits * another nit * refactor config and doc * re add truncation * fix ci * drop deprecated params in tests * fix link * fix config docstring --------- Co-authored-by: Kashif Rasul --- docs/source/data_utils.md | 4 + tests/test_data_utils.py | 43 +++ tests/test_sft_trainer.py | 120 +++--- tests/test_trainers_args.py | 4 - trl/__init__.py | 2 + trl/data_utils.py | 34 ++ trl/trainer/gkd_trainer.py | 10 +- trl/trainer/sft_config.py | 162 +++++--- trl/trainer/sft_trainer.py | 744 ++++++++++++++++-------------------- 9 files changed, 597 insertions(+), 526 deletions(-) diff --git a/docs/source/data_utils.md b/docs/source/data_utils.md index 9b8391278d..7faafd256a 100644 --- a/docs/source/data_utils.md +++ b/docs/source/data_utils.md @@ -27,3 +27,7 @@ ## maybe_unpair_preference_dataset [[autodoc]] maybe_unpair_preference_dataset + +## pack_examples + +[[autodoc]] pack_examples diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py index 95fe8e7049..b1908fc996 100644 --- a/tests/test_data_utils.py +++ b/tests/test_data_utils.py @@ -26,6 +26,7 @@ maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset, + pack_examples, unpair_preference_dataset, ) @@ -392,6 +393,48 @@ def test_maybe_extract_prompt_standard_already_explicit(self): ) +class TestPackExamples(unittest.TestCase): + def test_pack_examples_larger_chunks(self): + examples = { + "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]], + "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]], + } + seq_length = 5 + expected_output = { + "input_ids": [[1, 2, 3, 4, 5], [6, 7, 8]], + "attention_mask": [[0, 1, 1, 0, 0], [1, 1, 1]], + } + result = pack_examples(examples, seq_length) + self.assertEqual(result, expected_output) + + def test_pack_examples_smaller_chunks(self): + examples = { + "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]], + "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]], + } + seq_length = 2 + expected_output = { + "input_ids": [[1, 2], [3, 4], [5, 6], [7, 8]], + "attention_mask": [[0, 1], [1, 0], [0, 1], [1, 1]], + } + result = pack_examples(examples, seq_length) + 
self.assertEqual(result, expected_output) + + def test_pack_with_dataset(self): + examples = { + "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]], + "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]], + } + dataset = Dataset.from_dict(examples) + seq_length = 3 + expected_output = { + "input_ids": [[1, 2, 3], [4, 5, 6], [7, 8]], + "attention_mask": [[0, 1, 1], [0, 0, 1], [1, 1]], + } + dataset = dataset.map(pack_examples, batched=True, fn_kwargs={"seq_length": seq_length}) + self.assertEqual(dataset.to_dict(), expected_output) + + # Run the tests if __name__ == "__main__": unittest.main() diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 1724ff4c13..158a87586e 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -53,7 +53,7 @@ def formatting_prompts_func_batched(example): if is_peft_available(): - from peft import LoraConfig, PeftModel + from peft import LoraConfig, PeftModel, get_peft_model if is_vision_available(): from PIL import Image as PILImage @@ -327,7 +327,6 @@ def test_sft_trainer_uncorrect_data(self): save_steps=1, per_device_train_batch_size=2, max_seq_length=32, # make sure there is at least 1 packed sequence - num_of_sequences=32, packing=True, report_to="none", ) @@ -408,45 +407,6 @@ def test_sft_trainer_uncorrect_data(self): formatting_func=formatting_prompts_func, ) - # This should not work because not enough data for one sample - training_args = SFTConfig( - output_dir=tmp_dir, - dataloader_drop_last=True, - max_steps=2, - eval_steps=1, - save_steps=1, - per_device_train_batch_size=2, - max_seq_length=1024, # make sure there is NOT at least 1 packed sequence - packing=True, - report_to="none", - ) - with self.assertRaises(ValueError): - _ = SFTTrainer( - model=self.model, - args=training_args, - train_dataset=self.dummy_dataset, - formatting_func=formatting_prompts_func, - ) - - # This should not work as well - with self.assertRaises(ValueError): - training_args = SFTConfig( - output_dir=tmp_dir, - dataloader_drop_last=True, - max_steps=2, - eval_steps=1, - save_steps=1, - per_device_train_batch_size=2, - packing=False, - report_to="none", - ) - _ = SFTTrainer( - model=self.model, - args=training_args, - train_dataset=self.dummy_dataset, - formatting_func=formatting_prompts_func, - ) - # but this should work training_args = SFTConfig( output_dir=tmp_dir, @@ -502,7 +462,6 @@ def test_sft_trainer_with_model_num_train_epochs(self): num_train_epochs=2, per_device_train_batch_size=2, max_seq_length=16, - num_of_sequences=16, packing=True, report_to="none", ) @@ -576,7 +535,6 @@ def test_sft_trainer_with_model(self): save_steps=1, per_device_train_batch_size=2, max_seq_length=16, - num_of_sequences=16, packing=True, report_to="none", ) @@ -601,7 +559,6 @@ def test_sft_trainer_with_model(self): save_steps=1, per_device_train_batch_size=2, max_seq_length=16, - num_of_sequences=16, packing=True, report_to="none", ) @@ -808,8 +765,6 @@ def test_sft_trainer_infinite_with_model(self): eval_dataset=self.eval_dataset, ) - self.assertTrue(trainer.train_dataset.infinite) - trainer.train() self.assertIsNotNone(trainer.state.log_history[(-1)]["train_loss"]) @@ -837,8 +792,6 @@ def test_sft_trainer_infinite_with_model_epochs(self): eval_dataset=self.eval_dataset, ) - self.assertFalse(trainer.train_dataset.infinite) - trainer.train() self.assertIsNotNone(trainer.state.log_history[(-1)]["train_loss"]) @@ -1345,6 +1298,75 @@ def test_sft_trainer_torch_dtype(self): ) self.assertIn( - "Invalid `torch_dtype` passed to the SFTConfig. 
Expected a string with either `torch.dtype` or 'auto', but got -1.", + "Invalid `torch_dtype` passed to `SFTConfig`. Expected either 'auto' or a string representing " + "a `torch.dtype` (e.g., 'float32'), but got -1.", str(context.exception), ) + + +# This new tester aims to replace the first one at some point +class SFTTrainerTester2(unittest.TestCase): + def test_train(self): + # Get the model and dataset + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + model = AutoModelForCausalLM.from_pretrained(model_id) + dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") + + with tempfile.TemporaryDirectory() as tmp_dir: + # Initialize the trainer + training_args = SFTConfig(output_dir=tmp_dir, report_to="none") + trainer = SFTTrainer(args=training_args, model=model, train_dataset=dataset) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + @require_peft + def test_train_peft_model(self): + # Get the base model + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + model = AutoModelForCausalLM.from_pretrained(model_id) + + # Get the base model parameter names + base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()] + + # Turn the model into a peft model + lora_config = LoraConfig() + model = get_peft_model(model, lora_config) + + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") + + with tempfile.TemporaryDirectory() as tmp_dir: + # Initialize the trainer + training_args = SFTConfig(output_dir=tmp_dir, report_to="none") + trainer = SFTTrainer(args=training_args, model=model, train_dataset=dataset) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the peft params have changed and the base model params have not changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if n in base_param_names: # We expect the base model parameters to be the same + self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed") + elif ( + "base_layer" not in n + ): # We expect the peft parameters to be different (except for the base layer) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") diff --git a/tests/test_trainers_args.py b/tests/test_trainers_args.py index 25ed71ff0b..251b1f5a96 100644 --- a/tests/test_trainers_args.py +++ b/tests/test_trainers_args.py @@ -375,8 +375,6 @@ def test_sft(self): model_init_kwargs={"trust_remote_code": True}, dataset_kwargs={"append_concat_token": True, "skip_prepare_dataset": True}, eval_packing=True, - num_of_sequences=32, - chars_per_token=4.2, ) trainer = SFTTrainer(model_id, args=training_args, train_dataset=dataset) self.assertEqual(trainer.args.dataset_text_field, 
"dummy_text_field") @@ -389,8 +387,6 @@ def test_sft(self): self.assertIn("append_concat_token", trainer.args.dataset_kwargs) self.assertEqual(trainer.args.dataset_kwargs["append_concat_token"], True) self.assertEqual(trainer.args.eval_packing, True) - self.assertEqual(trainer.args.num_of_sequences, 32) - self.assertEqual(trainer.args.chars_per_token, 4.2) @parameterized.expand([(False,), (True,)]) def test_xpo(self, alpha_list): diff --git a/trl/__init__.py b/trl/__init__.py index 44a4333d53..0cd4be772f 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -28,6 +28,7 @@ "maybe_apply_chat_template", "maybe_extract_prompt", "maybe_unpair_preference_dataset", + "pack_examples", "unpair_preference_dataset", ], "environment": ["TextEnvironment", "TextHistory"], @@ -127,6 +128,7 @@ maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset, + pack_examples, unpair_preference_dataset, ) from .environment import TextEnvironment, TextHistory diff --git a/trl/data_utils.py b/trl/data_utils.py index 35332fd45f..2d69cdd431 100644 --- a/trl/data_utils.py +++ b/trl/data_utils.py @@ -412,3 +412,37 @@ def maybe_extract_prompt(example: dict[str, list]) -> dict[str, list]: if (chosen_conv and prompt_conv) or (not chosen_conv and not prompt_conv): return example return extract_prompt({"chosen": example["chosen"], "rejected": example["rejected"]}) + + +def pack_examples(examples: dict[str, list[list]], seq_length: int) -> dict[str, list[list]]: + """ + Pack examples into chunks of size `seq_length`. + + Args: + examples (`dict[str, list[list]]`): + Dictionary of examples with keys as strings and values as lists of lists. + seq_length (`int`): + Maximum sequence length. + + Returns: + `dict[str, list[list]]`: Dictionary of examples with keys as strings and values as lists of lists. + + Example: + + ```python + >>> from trl import pack_examples + >>> examples = { + ... "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]], + ... "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]], + ... 
} + >>> pack_examples(examples, seq_length=5) + {'input_ids': [[1, 2, 3, 4, 5], [6, 7, 8]], 'attention_mask': [[0, 1, 1, 0, 0], [1, 1, 1]]} + >>> pack_examples(examples, seq_length=2) + {'input_ids': [[1, 2], [3, 4], [5, 6], [7, 8]], 'attention_mask': [[0, 1], [1, 0], [0, 1], [1, 1]]} + ``` + """ + # Join all the values into a single list + examples = {k: sum(v, []) for k, v in examples.items()} + # Split the values into chunks of size seq_length + examples = {k: [v[i : i + seq_length] for i in range(0, len(v), seq_length)] for k, v in examples.items()} + return examples diff --git a/trl/trainer/gkd_trainer.py b/trl/trainer/gkd_trainer.py index 3db1f95252..5e76f30ed2 100644 --- a/trl/trainer/gkd_trainer.py +++ b/trl/trainer/gkd_trainer.py @@ -78,7 +78,6 @@ def __init__( processing_class: Optional[ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] ] = None, - model_init: Optional[Callable[[], PreTrainedModel]] = None, compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None, callbacks: Optional[list[TrainerCallback]] = None, optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), @@ -97,7 +96,6 @@ def __init__( train_dataset=train_dataset, eval_dataset=eval_dataset, processing_class=processing_class, - model_init=model_init, compute_metrics=compute_metrics, callbacks=callbacks, optimizers=optimizers, @@ -158,6 +156,14 @@ def __init__( ): self.generation_config.eos_token_id = self.model.generation_config.eos_token_id + def _prepare_dataset(self, dataset, *args): + # SFTTrainer._prepare_dataset() applies the chat template and rename the messages column to text. However, we + # need to keep the messages column as it is. We use the following workaround to keep the messages column. + dataset = dataset.add_column("_messages", dataset["messages"]) + dataset = super()._prepare_dataset(dataset, *args) + dataset = dataset.rename_column("_messages", "messages") + return dataset + @staticmethod def generalized_jsd_loss( student_logits, teacher_logits, labels=None, beta=0.5, temperature=1.0, reduction="batchmean" diff --git a/trl/trainer/sft_config.py b/trl/trainer/sft_config.py index 250eb74a0a..ad0e936c18 100644 --- a/trl/trainer/sft_config.py +++ b/trl/trainer/sft_config.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from dataclasses import dataclass, field from typing import Any, Optional @@ -23,103 +24,144 @@ class SFTConfig(TrainingArguments): r""" Configuration class for the [`SFTTrainer`]. + Only the parameters specific to SFT training are listed here. For details on other parameters, refer to the + [`~transformers.TrainingArguments`] documentation. + Using [`~transformers.HfArgumentParser`] we can turn this class into [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the command line. Parameters: - dataset_text_field (`str`, *optional*, defaults to `"text"`): - Name of the text field of the dataset. If provided, the trainer will automatically create a - [`ConstantLengthDataset`] based on `dataset_text_field`. - packing (`bool`, *optional*, defaults to `False`): - Controls whether the [`ConstantLengthDataset`] packs the sequences of the dataset. - learning_rate (`float`, *optional*, defaults to `2e-5`): - Initial learning rate for [`AdamW`] optimizer. The default value replaces that of [`~transformers.TrainingArguments`]. 
- max_seq_length (`int` or `None`, *optional*, defaults to `None`): - Maximum sequence length for the [`ConstantLengthDataset`] and for automatically creating the dataset. If - `None`, it uses the smaller value between `tokenizer.model_max_length` and `1024`. - dataset_num_proc (`int` or `None`, *optional*, defaults to `None`): - Number of processes to use for processing the dataset. Only used when `packing=False`. - dataset_batch_size (`Union[int, None]`, *optional*, defaults to `1000`): - Number of examples to tokenize per batch. If `dataset_batch_size <= 0` or `dataset_batch_size is None`, - tokenizes the full dataset as a single batch. + > Parameters that control the model + model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`): - Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a - string. + Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model` + argument of the [`SFTTrainer`] is provided as a string. + use_liger (`bool`, *optional*, defaults to `False`): + Monkey patch the model with Liger kernels to increase throughput and reduce memory usage. + + > Parameters that control the data preprocessing + + dataset_text_field (`str`, *optional*, defaults to `"text"`): + Name of the column that contains text data in the dataset. dataset_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`): - Dictionary of optional keyword arguments to pass when creating packed or non-packed datasets. + Dictionary of optional keyword arguments for the dataset preparation. The only supported key is + `skip_prepare_dataset`. + dataset_num_proc (`int` or `None`, *optional*, defaults to `None`): + Number of processes to use for processing the dataset. + max_seq_length (`int` or `None`, *optional*, defaults to `1024`): + Maximum length of the tokenized sequence. Sequences longer than `max_seq_length` are truncated from the + right. + If `None`, no truncation is applied. When packing is enabled, this value sets the sequence length. + packing (`bool`, *optional*, defaults to `False`): + Whether to pack multiple sequences into a fixed-length format. Uses `max_seq_length` to define sequence + length. eval_packing (`bool` or `None`, *optional*, defaults to `None`): Whether to pack the eval dataset. If `None`, uses the same value as `packing`. - num_of_sequences (`int`, *optional*, defaults to `1024`): - Number of sequences to use for the [`ConstantLengthDataset`]. - chars_per_token (`float`, *optional*, defaults to `3.6`): - Number of characters per token to use for the [`ConstantLengthDataset`]. See - [chars_token_ratio](https://github.com/huggingface/trl/blob/08f550674c553c36c51d1027613c29f14f3676a5/examples/stack_llama/scripts/supervised_finetuning.py#L53) for more details. - use_liger (`bool`, *optional*, defaults to `False`): - Monkey patch the model with Liger kernels to increase throughput and reduce memory usage. + + > Parameters that control the training + + learning_rate (`float`, *optional*, defaults to `2e-5`): + Initial learning rate for [`AdamW`] optimizer. The default value replaces that of + [`~transformers.TrainingArguments`]. """ - dataset_text_field: str = field( - default="text", + # Parameters that control the model + model_init_kwargs: Optional[dict[str, Any]] = field( + default=None, metadata={ - "help": "Name of the text field of the dataset. If provided, the trainer will automatically create a " - "`ConstantLengthDataset` based on `dataset_text_field`." 
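The restructured `SFTConfig` above groups its options by role (model, data preprocessing, training) and makes `max_seq_length` the single knob for packed sequence length. A hypothetical configuration using only fields documented in this hunk:

```python
from trl import SFTConfig

# Values are placeholders; with packing enabled, the tokenized dataset is
# chunked into fixed-length sequences of `max_seq_length` tokens.
training_args = SFTConfig(
    output_dir="Qwen2-0.5B-SFT",   # hypothetical output directory
    max_seq_length=512,
    packing=True,
    learning_rate=2e-5,
)
```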
+ "help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of " + "the `SFTTrainer` is provided as a string." }, ) - packing: bool = field( + use_liger: bool = field( default=False, - metadata={"help": "Controls whether the `ConstantLengthDataset` packs the sequences of the dataset."}, + metadata={"help": "Monkey patch the model with Liger kernels to increase throughput and reduce memory usage."}, ) - learning_rate: float = field( - default=2.0e-5, - metadata={ - "help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of " - "`TrainingArguments`." - }, + + # Parameters that control the data preprocessing + dataset_text_field: str = field( + default="text", + metadata={"help": "Name of the column that contains text data in the dataset."}, ) - max_seq_length: Optional[int] = field( + dataset_kwargs: Optional[dict[str, Any]] = field( default=None, metadata={ - "help": "Maximum sequence length for the `ConstantLengthDataset` and for automatically creating the " - "dataset. If `None`, it uses the smaller value between `tokenizer.model_max_length` and `1024`." + "help": "Dictionary of optional keyword arguments for the dataset preparation. The only supported key is " + "`skip_prepare_dataset`." }, ) dataset_num_proc: Optional[int] = field( default=None, - metadata={"help": "Number of processes to use for processing the dataset. Only used when `packing=False`."}, + metadata={"help": "Number of processes to use for processing the dataset."}, ) - dataset_batch_size: int = field( - default=1000, + max_seq_length: Optional[int] = field( + default=1024, metadata={ - "help": "Number of examples to tokenize per batch. If `dataset_batch_size <= 0` or `dataset_batch_size is " - "None`, tokenizes the full dataset as a single batch." + "help": "Maximum length of the tokenized sequence. Sequences longer than `max_seq_length` are truncated " + "from the right. If `None`, no truncation is applied. When packing is enabled, this value sets the " + "sequence length." }, ) - model_init_kwargs: Optional[dict[str, Any]] = field( - default=None, + packing: bool = field( + default=False, metadata={ - "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model " - "from a string." + "help": "Whether to pack multiple sequences into a fixed-length format. Uses `max_seq_length` to " + "define sequence length." }, ) - dataset_kwargs: Optional[dict[str, Any]] = field( + eval_packing: Optional[bool] = field( default=None, + metadata={"help": "Whether to pack the eval dataset. If `None`, uses the same value as `packing`."}, + ) + + # Parameters that control the training + learning_rate: float = field( + default=2.0e-5, metadata={ - "help": "Dictionary of optional keyword arguments to pass when creating packed or non-packed datasets." + "help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of " + "`TrainingArguments`." }, ) - eval_packing: Optional[bool] = field( + + # Deprecated parameters + dataset_batch_size: int = field( default=None, - metadata={"help": "Whether to pack the eval dataset. If `None`, uses the same value as `packing`."}, + metadata={"help": "Deprecated. You can safely remove this parameter from your configuration."}, ) num_of_sequences: int = field( - default=1024, - metadata={"help": "Number of sequences to use for the `ConstantLengthDataset`."}, + default=None, + metadata={ + "help": "Deprecated. 
Use `max_seq_length` instead, which specifies the maximum length of the tokenized " + "sequence, unlike `num_of_sequences`, which referred to string sequences." + }, ) chars_per_token: float = field( - default=3.6, metadata={"help": "Number of characters per token to use for the `ConstantLengthDataset`."} - ) - use_liger: bool = field( - default=False, - metadata={"help": "Monkey patch the model with Liger kernels to increase throughput and reduce memory usage."}, + default=None, + metadata={"help": "Deprecated. If you want to customize the packing length, use `max_seq_length`."}, ) + + def __post_init__(self): + super().__post_init__() + + if self.dataset_batch_size is not None: + warnings.warn( + "`dataset_batch_size` is deprecated and will be remove in version 0.18.0. You can safely remove this " + "parameter from your configuration.", + DeprecationWarning, + ) + + if self.num_of_sequences is not None: + warnings.warn( + "`num_of_sequences` is deprecated and will be remove in version 0.18.0. Use `max_seq_length` instead, " + "which specifies the maximum length of the tokenized sequence, unlike `num_of_sequences`, which r" + "eferred to string sequences.", + DeprecationWarning, + ) + + if self.chars_per_token is not None: + warnings.warn( + "`chars_per_token` is deprecated and will be remove in version 0.18.0. If you want to customize the " + "packing length, use `max_seq_length`.", + DeprecationWarning, + ) diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 086a5bac79..1229e4c6ac 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -13,20 +13,16 @@ # limitations under the License. import dataclasses -import inspect import os import warnings from collections import defaultdict -from typing import Callable, Optional, Union +from typing import Any, Callable, Optional, Type, Union -import datasets import torch import torch.nn as nn import transformers -from accelerate.state import PartialState -from datasets import Dataset -from datasets.arrow_writer import SchemaInferenceError -from datasets.builder import DatasetGenerationError +from accelerate import PartialState +from datasets import Dataset, IterableDataset from packaging import version from transformers import ( AutoModelForCausalLM, @@ -39,6 +35,7 @@ PreTrainedTokenizerBase, ProcessorMixin, Trainer, + TrainingArguments, is_wandb_available, ) from transformers.trainer_callback import TrainerCallback @@ -46,11 +43,10 @@ from transformers.utils import is_liger_kernel_available, is_peft_available from transformers.utils.deprecation import deprecate_kwarg -from ..extras.dataset_formatting import get_formatting_func_from_dataset +from ..data_utils import is_conversational, maybe_apply_chat_template, pack_examples from .sft_config import SFTConfig from .utils import ( ConstantLengthDataset, - DataCollatorForCompletionOnlyLM, compute_token_accuracy, generate_model_card, get_comet_experiment_url, @@ -59,6 +55,7 @@ if is_peft_available(): + import peft from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training if is_liger_kernel_available(): @@ -69,45 +66,76 @@ class SFTTrainer(Trainer): - r""" - Class definition of the Supervised Finetuning Trainer (SFT Trainer). - This class is a wrapper around the `transformers.Trainer` class and inherits all of its attributes and methods. - The trainer takes care of properly initializing the PeftModel in case a user passes a `PeftConfig` object. + """ + Trainer for Supervised Fine-Tuning (SFT) method. 
+ + This class is a wrapper around the [`transformers.Trainer`] class and inherits all of its attributes and methods. + + Example: + + ```python + from datasets import load_dataset + from trl import SFTTrainer + + dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]") + + trainer = SFTTrainer(model="Qwen/Qwen2-0.5B-Instruct", train_dataset=dataset) + trainer.train() + ``` Args: - model (Union[`transformers.PreTrainedModel`, `nn.Module`, `str`]): - The model to train, can be a `PreTrainedModel`, a `torch.nn.Module` or a string with the model name to - load from cache or download. The model can be also converted to a `PeftModel` if a `PeftConfig` object is - passed to the `peft_config` argument. - args (`Optional[SFTConfig]`): - The arguments to tweak for training. Will default to a basic instance of [`SFTConfig`] with the `output_dir` - set to a directory named *tmp_trainer* in the current directory if not provided. - data_collator (`Optional[transformers.DataCollator]`): - The data collator to use for training. - train_dataset (`Optional[datasets.Dataset]`): - The dataset to use for training. We recommend users to use `trl.trainer.ConstantLengthDataset` to create their dataset. - eval_dataset (Optional[Union[`datasets.Dataset`, dict[`str`, `datasets.Dataset`]]]): - The dataset to use for evaluation. We recommend users to use `trl.trainer.ConstantLengthDataset` to create their dataset. - processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*): - Processing class used to process the data. If provided, will be used to automatically process the inputs - for the model, and it will be saved along the model to make it easier to rerun an interrupted training or - reuse the fine-tuned model. - This supercedes the `tokenizer` argument, which is now deprecated. - model_init (`Callable[[], transformers.PreTrainedModel]`): - The model initializer to use for training. If None is specified, the default model initializer will be used. - compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional* defaults to None): - The function used to compute metrics during evaluation. It should return a dictionary mapping metric names to metric values. - If not specified, only the loss will be computed during evaluation. - callbacks (`list[transformers.TrainerCallback]`): - The callbacks to use for training. - optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): - The optimizer and scheduler to use for training. - preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): - The function to use to preprocess the logits before computing the metrics. - peft_config (`Optional[PeftConfig]`): - The PeftConfig object to use to initialize the PeftModel. + model (`Union[str, PreTrainedModel]`): + Model to be trained. Can be either: + + - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or + a path to a *directory* containing model weights saved using + [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is + loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keywork arguments + in `args.model_init_kwargs`. + - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported. + args ([`SFTConfig`], *optional*, defaults to `None`): + Configuration for this trainer. If `None`, a default configuration is used. 
+ data_collator (`DataCollator`, *optional*): + Function to use to form a batch from a list of elements of the prcessed `train_dataset` or `eval_dataset`. + Will default to [`~transformers.default_data_collator`] if no `processing_class` is provided, an instance + of [`~transformers.DataCollatorWithPadding`] otherwise if the processing_class is a feature extractor or + tokenizer. + train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): + Dataset to use for training. SFT supports both [language modeling](#language-modeling) type and + [prompt-completion](#prompt-completion) type. The format of the samples can be either: + + - [Standard](dataset_formats#standard): Each sample contains plain text. + - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role + and content). + eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`): + Dataset to use for evaluation. It must meet the same requirements as `train_dataset`. + processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`): + Processing class used to process the data. If `None`, the processing class is loaded from the model's name + with [`~transformers.AutoTokenizer.from_pretrained`]. + callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`): + List of callbacks to customize the training loop. Will add those to the list of default callbacks + detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback). + + If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`] + method. + optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`): + A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your + model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`. + optimizer_cls_and_kwargs (`Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*, defaults to `None`): + A tuple containing the optimizer class and keyword arguments to use. + Overrides `optim` and `optim_args` in `args`. Incompatible with the `optimizers` argument. + + Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before initializing the Trainer. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*, defaults to `None`): + A function that preprocess the logits right before caching them at each evaluation step. Must take two + tensors, the logits and the labels, and return the logits once processed as desired. The modifications made + by this function will be reflected in the predictions received by `compute_metrics`. + + Note that the labels (second parameter) will be `None` if the dataset does not have them. + peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`): + PEFT configuration used to wrap the model. If `None`, the model is not wrapped. formatting_func (`Optional[Callable]`): - The formatting function to be used for creating the `ConstantLengthDataset`. + Formatting function applied to the dataset before tokenization. 
""" _tag_names = ["trl", "sft"] @@ -117,200 +145,94 @@ class SFTTrainer(Trainer): ) def __init__( self, - model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, - args: Optional[SFTConfig] = None, + model: Union[str, nn.Module, PreTrainedModel], + args: Optional[Union[SFTConfig, TrainingArguments]] = None, data_collator: Optional[DataCollator] = None, # type: ignore - train_dataset: Optional[Dataset] = None, + train_dataset: Optional[Union[Dataset, IterableDataset]] = None, eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, processing_class: Optional[ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] ] = None, - model_init: Optional[Callable[[], PreTrainedModel]] = None, + compute_loss_func: Optional[Callable] = None, compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None, callbacks: Optional[list[TrainerCallback]] = None, - optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), + optimizer_cls_and_kwargs: Optional[tuple[Type[torch.optim.Optimizer], dict[str, Any]]] = None, preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, peft_config: Optional["PeftConfig"] = None, - formatting_func: Optional[Callable] = None, + formatting_func: Optional[Union[Callable[[dict], str], Callable[[dict], list[str]]]] = None, ): + # Args if args is None: - args = SFTConfig(output_dir="tmp_trainer") - elif args is not None and args.__class__.__name__ == "TrainingArguments": - args_as_dict = args.to_dict() - # Manually copy token values as TrainingArguments.to_dict() redacts them - args_as_dict.update({k: getattr(args, k) for k in args_as_dict.keys() if k.endswith("_token")}) - args = SFTConfig(**args_as_dict) - - if getattr(args, "model_init_kwargs", None) is None: - model_init_kwargs = {} - elif not isinstance(model, str): - raise ValueError("You passed model_init_kwargs to the SFTConfig, but your model is already instantiated.") - else: - model_init_kwargs = args.model_init_kwargs - torch_dtype = model_init_kwargs.get("torch_dtype") - if torch_dtype is not None: - # Convert to `torch.dtype` if an str is passed - if isinstance(torch_dtype, str) and torch_dtype != "auto": - torch_dtype = getattr(torch, torch_dtype) - if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype): - raise ValueError( - f"Invalid `torch_dtype` passed to the SFTConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}." - ) - model_init_kwargs["torch_dtype"] = torch_dtype - - if isinstance(model, str): - if args.use_liger: - model = AutoLigerKernelForCausalLM.from_pretrained(model, **model_init_kwargs) - else: - model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) - - if args.packing and data_collator is not None and isinstance(data_collator, DataCollatorForCompletionOnlyLM): - raise ValueError( - "You passed a `DataCollatorForCompletionOnlyLM` to the SFTTrainer. This is not compatible with the `packing` argument." 
+ model_name = model if isinstance(model, str) else model.config._name_or_path + model_name = model_name.split("/")[-1] + args = SFTConfig(f"{model_name}-SFT") + elif isinstance(args, TrainingArguments) and not isinstance(args, SFTConfig): + dict_args = args.to_dict() + dict_args["hub_token"] = args.hub_token # to_dict hides the hub_token + dict_args.pop("push_to_hub_token") + args = SFTConfig(**dict_args) + + # Model + if args.model_init_kwargs is not None and not isinstance(model, str): + warnings.warn( + "You passed model_init_kwargs to the `SFTConfig`, but your model is already instantiated. " + "The `model_init_kwargs` will be ignored." ) + if isinstance(model, str): + model = self._create_model_from_path(model, args) - if is_peft_available() and peft_config is not None: - if not isinstance(peft_config, PeftConfig): - raise ValueError( - "If you want to use the PeftModel, you need to pass a PeftConfig object to the SFTTrainer." - f" and you passed a {type(peft_config)}." - ) - - if not isinstance(model, PeftModel): - _support_gc_kwargs = hasattr( - args, "gradient_checkpointing_kwargs" - ) and "gradient_checkpointing_kwargs" in list( - inspect.signature(prepare_model_for_kbit_training).parameters - ) - gradient_checkpointing_kwargs = getattr(args, "gradient_checkpointing_kwargs", None) or {} - is_sharded_qlora = False - # Below is to support QLoRA + FSDP / DS-Zero3 - one should never call - # peft_module_casting_to_bf16 or prepare_model_for_kbit_training when doing - # QLoRA + FSDP / DS-Zero3 - if getattr(model, "is_loaded_in_4bit", False): - for _, param in model.named_parameters(): - if param.__class__.__name__ == "Params4bit": - is_sharded_qlora = param.data.device.type in {"cpu", "meta"} - break - if getattr(model, "is_loaded_in_8bit", False) or ( - getattr(model, "is_loaded_in_4bit", False) and not is_sharded_qlora - ): - prepare_model_kwargs = { - "use_gradient_checkpointing": getattr(args, "gradient_checkpointing", False) - } - - if _support_gc_kwargs: - prepare_model_kwargs["gradient_checkpointing_kwargs"] = gradient_checkpointing_kwargs - - model = prepare_model_for_kbit_training(model, **prepare_model_kwargs) - - if args is not None: - args = dataclasses.replace(args, gradient_checkpointing=False) - elif getattr(args, "gradient_checkpointing", False) and ( - "use_reentrant" not in gradient_checkpointing_kwargs - or gradient_checkpointing_kwargs["use_reentrant"] - ): - # For backward compatibility with older versions of transformers - if hasattr(model, "enable_input_require_grads"): - model.enable_input_require_grads() - else: - - def make_inputs_require_grad(module, input, output): - output.requires_grad_(True) - - model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) - - if ( - "autocast_adapter_dtype" in list(inspect.signature(get_peft_model).parameters) - and getattr(model, "is_loaded_in_4bit", False) - and is_sharded_qlora - ): - model = get_peft_model(model, peft_config, autocast_adapter_dtype=False) - else: - model = get_peft_model(model, peft_config) - if ( - args is not None - and args.bf16 - and getattr(model, "is_loaded_in_4bit", False) - and not is_sharded_qlora - ): - peft_module_casting_to_bf16(model) + # PEFT configuration and model wrapping + if peft_config is not None: + model = self._prepare_peft_model(model, peft_config, args) + # 3. 
Handle the tokenizer if processing_class is None: processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path) - if getattr(processing_class, "pad_token", None) is None: - processing_class.pad_token = processing_class.eos_token - - if args.max_seq_length is None: - # to overcome some issues with broken tokenizers - args.max_seq_length = min(processing_class.model_max_length, 1024) - - self.dataset_num_proc = args.dataset_num_proc - self.dataset_batch_size = args.dataset_batch_size - - if args.dataset_kwargs is None: - args.dataset_kwargs = {} - - if formatting_func is None: - # check if dataset has ChatML format or instruction format and is supported - # if not stays None - formatting_func = get_formatting_func_from_dataset(train_dataset, processing_class) - # if a template is detected, we don't need to add special tokens again - if formatting_func is not None: - args.dataset_kwargs["add_special_tokens"] = False - - if not args.packing: - if data_collator is None: - data_collator = DataCollatorForLanguageModeling(tokenizer=processing_class, mlm=False) - - # Pre-process the datasets only once per node. The remaining processes will use the cache. - with PartialState().local_main_process_first(): - if train_dataset is not None: - train_dataset = self._prepare_dataset( - train_dataset, - processing_class, - args.packing, - args.dataset_text_field, - args.max_seq_length, - formatting_func, - args.num_of_sequences, - args.chars_per_token, - remove_unused_columns=args.remove_unused_columns if args is not None else True, - **args.dataset_kwargs, - ) + if processing_class.pad_token is None: + processing_class.pad_token = processing_class.eos_token # required for padding when collating data + + # Dataset + preprocess_dataset = args.dataset_kwargs is None or not args.dataset_kwargs.get("skip_prepare_dataset", False) + if preprocess_dataset: + train_dataset = self._prepare_dataset( + train_dataset, processing_class, args, args.packing, formatting_func, "train" + ) if eval_dataset is not None: - _multiple = isinstance(eval_dataset, dict) - _eval_datasets = eval_dataset if _multiple else {"singleton": eval_dataset} - - eval_packing = args.packing if args.eval_packing is None else args.eval_packing - - for _eval_dataset_name, _eval_dataset in _eval_datasets.items(): - _eval_datasets[_eval_dataset_name] = self._prepare_dataset( - _eval_dataset, - processing_class, - eval_packing, - args.dataset_text_field, - args.max_seq_length, - formatting_func, - args.num_of_sequences, - args.chars_per_token, - remove_unused_columns=args.remove_unused_columns if args is not None else True, - **args.dataset_kwargs, + packing = args.packing if args.eval_packing is None else args.eval_packing + if isinstance(eval_dataset, dict): + eval_dataset = { + key: self._prepare_dataset(dataset, processing_class, args, packing, formatting_func, key) + for key, dataset in eval_dataset.items() + } + else: + eval_dataset = self._prepare_dataset( + eval_dataset, processing_class, args, packing, formatting_func, "eval" ) - if not _multiple: - eval_dataset = _eval_datasets["singleton"] - if processing_class.padding_side is not None and processing_class.padding_side != "right": - warnings.warn( - "You passed a processing_class with `padding_side` not equal to `right` to the SFTTrainer. This might " - "lead to some unexpected behaviour due to overflow issues when training a model in half-precision. 
" - "You might consider adding `processing_class.padding_side = 'right'` to your code.", - UserWarning, - ) + # Data collator + if data_collator is None: + data_collator = DataCollatorForLanguageModeling(tokenizer=processing_class, mlm=False) # Initialize the metrics self._metrics = defaultdict(list) + # Initialize the Trainer. Parent class will handle: + # - DeepSpeed configuration (through create_accelerator_and_postprocess) + # - FSDP setup + # - Distributed training setup + # - Optimizer and scheduler creation + # Some arguments are only available for transformers>=4.47.0. Can be removed when the min version is bumped. + super_init_kwargs = {} + if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): + super_init_kwargs["optimizer_cls_and_kwargs"] = optimizer_cls_and_kwargs + else: + if optimizer_cls_and_kwargs is not None: + warnings.warn( + "The `optimizer_cls_and_kwargs` argument is only available for `transformers>=4.47.0`. " + "The default optimizer will be used. " + "Remove the `optimizer_cls_and_kwargs` or upgrade to `transformers>=4.47.0`." + ) super().__init__( model=model, args=args, @@ -318,196 +240,235 @@ def make_inputs_require_grad(module, input, output): train_dataset=train_dataset, eval_dataset=eval_dataset, processing_class=processing_class, - model_init=model_init, + compute_loss_func=compute_loss_func, compute_metrics=compute_metrics, callbacks=callbacks, optimizers=optimizers, preprocess_logits_for_metrics=preprocess_logits_for_metrics, + **super_init_kwargs, ) # Add tags for models that have been loaded with the correct transformers version if hasattr(self.model, "add_model_tags"): self.model.add_model_tags(self._tag_names) - if self.train_dataset is not None: - if self.args.max_steps > 0 and args.packing: - self.train_dataset.infinite = True - elif self.args.max_steps == -1 and args.packing: - self.train_dataset.infinite = False + def _create_model_from_path(self, model_path: str, args: SFTConfig) -> PreTrainedModel: + """Creates a model from a path or model identifier.""" + model_init_kwargs = args.model_init_kwargs or {} + # Handle torch dtype + torch_dtype = model_init_kwargs.get("torch_dtype") + if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None: + pass # torch_dtype is already a torch.dtype or "auto" or None + elif isinstance(torch_dtype, str): # it's a str, but not "auto" + torch_dtype = getattr(torch, torch_dtype) + model_init_kwargs["torch_dtype"] = torch_dtype + else: + raise ValueError( + "Invalid `torch_dtype` passed to `SFTConfig`. Expected either 'auto' or a string representing " + f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}." + ) + # Disable caching if gradient checkpointing is enabled (not supported) + if args.gradient_checkpointing: + model_init_kwargs["use_cache"] = False + + # Create model + if args.use_liger: + if not is_liger_kernel_available(): + raise ImportError("Please install Liger-kernel for use_liger=True") + return AutoLigerKernelForCausalLM.from_pretrained(model_path, **model_init_kwargs) + return AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs) + + def _prepare_peft_model(self, model: PreTrainedModel, peft_config: Any, args: SFTConfig) -> PreTrainedModel: + """Prepares a model for PEFT training.""" + if not is_peft_available(): + raise ImportError("To use PeftModel, you need to install the `peft` library.") + + if not isinstance(peft_config, PeftConfig): + raise ValueError( + f"Expected PeftConfig object but got {type(peft_config)}. 
If you want to use the PeftModel, you need " + "to pass a PeftConfig object to the SFTTrainer." + ) - def _prepare_dataset( - self, - dataset, - processing_class, - packing, - dataset_text_field: str, - max_seq_length, - formatting_func: Optional[Callable], - num_of_sequences, - chars_per_token, - remove_unused_columns=True, - append_concat_token=True, - add_special_tokens=True, - skip_prepare_dataset=False, - ): - if dataset is None: - raise ValueError("The dataset should not be None") + if isinstance(model, PeftModel): + return model + + # Handle quantized models (QLoRA) + is_qlora = getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False) + + is_sharded_qlora = False + if getattr(model, "is_loaded_in_4bit", False): + # Check if model is sharded (FSDP/DS-Zero3) + for _, param in model.named_parameters(): + if param.__class__.__name__ == "Params4bit": + is_sharded_qlora = param.data.device.type in {"cpu", "meta"} + break + + # Prepare model for kbit training if needed + if is_qlora and not is_sharded_qlora: + model = self._prepare_model_for_kbit_training(model, args) + # Disable gradient checkpointing as it's handled by prepare_model_for_kbit_training + args = dataclasses.replace(args, gradient_checkpointing=False) + elif args.gradient_checkpointing: + model = self._enable_gradient_checkpointing(model, args) + + # Create PEFT model + if ( + version.parse(peft.__version__) >= version.parse("0.12") # autocast_adapter_dtype introduced in 0.12 + and getattr(model, "is_loaded_in_4bit", False) + and is_sharded_qlora + ): + model = get_peft_model(model, peft_config, autocast_adapter_dtype=False) + else: + model = get_peft_model(model, peft_config) - if skip_prepare_dataset: - return dataset + # Handle bf16 casting for 4-bit models + if args.bf16 and getattr(model, "is_loaded_in_4bit", False) and not is_sharded_qlora: + peft_module_casting_to_bf16(model) + + return model + + def _prepare_model_for_kbit_training(self, model: PreTrainedModel, args: SFTConfig) -> PreTrainedModel: + """Prepares a quantized model for kbit training.""" + prepare_model_kwargs = { + "use_gradient_checkpointing": args.gradient_checkpointing, + "gradient_checkpointing_kwargs": args.gradient_checkpointing_kwargs or {}, + } - # If the dataset is already preprocessed (tokenized), return as-is. Only works if dataset is - # a datasets.Dataset or datasets.IterableDataset -- not for torch Dataset - column_names = ( - dataset.column_names if isinstance(dataset, (datasets.Dataset, datasets.IterableDataset)) else None + return prepare_model_for_kbit_training(model, **prepare_model_kwargs) + + def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: SFTConfig) -> PreTrainedModel: + """Enables gradient checkpointing for the model.""" + gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {} + use_reentrant = ( + "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"] ) - if column_names and "input_ids" in column_names: - if formatting_func is not None: - warnings.warn( - "You passed a dataset that is already processed (contains an `input_ids` field) together with a " - "valid formatting function. Therefore `formatting_func` will be ignored. 
Either remove the " - "`formatting_func` or pass a dataset that is not already processed.", - UserWarning, - ) - def formatting_func(x): - return x["input_ids"] + if use_reentrant: + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) - if not packing: - return dataset + return model - # check if torch dataset / dataloader and do nothing - # see https://github.com/huggingface/trl/pull/1468 for why datasets.IterableDataset needs a separate check - if isinstance( - dataset, (torch.utils.data.IterableDataset, torch.utils.data.Dataset, ConstantLengthDataset) - ) and not isinstance(dataset, datasets.IterableDataset): + def _prepare_dataset( + self, + dataset: Union[Dataset, IterableDataset], + processing_class: Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin], + args: SFTConfig, + packing: bool, + formatting_func: Optional[Callable[[dict], str]], + dataset_name: str, + ) -> Union[Dataset, IterableDataset]: + # Convert the dataset to an IterableDataset if it is a ConstantLengthDataset + if isinstance(dataset, ConstantLengthDataset): return dataset - if not packing: - return self._prepare_non_packed_dataloader( - processing_class, - dataset, - dataset_text_field, - max_seq_length, - formatting_func, - add_special_tokens, - remove_unused_columns, - ) + # Build the kwargs for the `map` function + map_kwargs = {} + if isinstance(dataset, Dataset): # IterableDataset does not support num_proc + map_kwargs["num_proc"] = args.dataset_num_proc - else: - return self._prepare_packed_dataloader( - processing_class, - dataset, - dataset_text_field, - max_seq_length, - num_of_sequences, - chars_per_token, - formatting_func, - append_concat_token, - add_special_tokens, - ) + with PartialState().local_main_process_first(): + # Apply the formatting function if any + if formatting_func is not None: + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Applying formatting function to {dataset_name} dataset" - def _prepare_non_packed_dataloader( - self, - processing_class, - dataset, - dataset_text_field: str, - max_seq_length, - formatting_func: Optional[Callable] = None, - add_special_tokens=True, - remove_unused_columns=True, - ): - # Inspired from: https://huggingface.co/learn/nlp-course/chapter7/6?fw=pt - def tokenize(element): - outputs = processing_class( - element[dataset_text_field] if formatting_func is None else formatting_func(element), - add_special_tokens=add_special_tokens, - truncation=True, - padding=False, - max_length=max_seq_length, - return_overflowing_tokens=False, - return_length=False, - ) + batched = isinstance(formatting_func(next(iter(dataset))), list) - if formatting_func is not None and not isinstance(formatting_func(element), list): - raise ValueError( - "The `formatting_func` should return a list of processed strings since it can lead to silent bugs." 
- ) + def _func(example): + return {"text": formatting_func(example)} - return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"]} + dataset = dataset.map(_func, batched=batched, **map_kwargs) - signature_columns = ["input_ids", "labels", "attention_mask"] + # If the dataset is prompt-completion, convert it to language modeling type + if "prompt" in dataset.column_names and "completion" in dataset.column_names: + key = "messages" if is_conversational(dataset[0]) else "text" - if dataset.column_names is not None: # None for IterableDataset - extra_columns = list(set(dataset.column_names) - set(signature_columns)) - else: - extra_columns = [] + def concat_prompt_completion(example): + return {key: example["prompt"] + example["completion"]} - if not remove_unused_columns and len(extra_columns) > 0: - warnings.warn( - "You passed `remove_unused_columns=False` on a non-packed dataset. This might create some issues with " - "the default collator and yield to errors. If you want to inspect dataset other columns (in this " - f"case {extra_columns}), you can subclass `DataCollatorForLanguageModeling` in case you used the " - "default collator and create your own data collator in order to inspect the unused dataset columns.", - UserWarning, + dataset = dataset.map(concat_prompt_completion, remove_columns=["prompt", "completion"]) + + # Apply the chat template if needed + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset" + dataset = dataset.map( + maybe_apply_chat_template, + fn_kwargs={"tokenizer": processing_class}, + remove_columns="messages" if "messages" in dataset.column_names else None, # renamed to "text" + **map_kwargs, ) - map_kwargs = { - "batched": True, - "remove_columns": dataset.column_names if remove_unused_columns else None, - "batch_size": self.dataset_batch_size, - } - if isinstance(dataset, datasets.Dataset): - map_kwargs["num_proc"] = self.dataset_num_proc # this arg is not available for IterableDataset - tokenized_dataset = dataset.map(tokenize, **map_kwargs) + # Tokenize the dataset + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset" + dataset = dataset.map(lambda ex: processing_class(ex[args.dataset_text_field]), **map_kwargs) + + # Pack or truncate + if packing: + if args.max_seq_length is None: + raise ValueError("When packing is enabled, `max_seq_length` can't be `None`.") + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Packing {dataset_name} dataset" + dataset = dataset.select_columns("input_ids") + dataset = dataset.map( + pack_examples, batched=True, fn_kwargs={"seq_length": args.max_seq_length}, **map_kwargs + ) + elif args.max_seq_length is not None: + dataset = dataset.map( + lambda ex: {key: ex[key][: args.max_seq_length] for key in ["input_ids", "attention_mask"]}, + **map_kwargs, + ) + # For Liger kernel, ensure only input_ids is present + if args.use_liger: + dataset = dataset.select_columns("input_ids") - return tokenized_dataset + return dataset - def _prepare_packed_dataloader( - self, - processing_class, - dataset, - dataset_text_field: str, - max_seq_length, - num_of_sequences, - chars_per_token, - formatting_func: Optional[Callable] = None, - append_concat_token=True, - add_special_tokens=True, - ): - if processing_class is None: - raise ValueError("You need to pass a 
processing_class with `SFTTrainer`.") - - constant_length_iterator = ConstantLengthDataset( - processing_class, - dataset, - dataset_text_field=None if formatting_func is not None else dataset_text_field, - formatting_func=formatting_func, - seq_length=max_seq_length, - infinite=False, - num_of_sequences=num_of_sequences, - chars_per_token=chars_per_token, - eos_token_id=processing_class.eos_token_id, - append_concat_token=append_concat_token, - add_special_tokens=add_special_tokens, + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + """ + Compute training loss and additionally compute token accuracies + """ + (loss, outputs) = super().compute_loss( + model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch ) - if isinstance(dataset, datasets.IterableDataset): - return constant_length_iterator + # Compute token accuracy if we have labels and if the model is not using Liger (no logits) + if "labels" in inputs and not self.args.use_liger: + shift_logits = outputs.logits[..., :-1, :].contiguous() + shift_labels = inputs["labels"][..., 1:].contiguous() - def data_generator(constant_length_iterator): - yield from constant_length_iterator + # Gather logits and labels from all GPUs first + shift_logits = self.accelerator.gather_for_metrics(shift_logits) + shift_labels = self.accelerator.gather_for_metrics(shift_labels) - try: - packed_dataset = Dataset.from_generator( - data_generator, gen_kwargs={"constant_length_iterator": constant_length_iterator} - ) - except (DatasetGenerationError, SchemaInferenceError) as exc: - raise ValueError( - "Error occurred while packing the dataset. " - "Make sure that your dataset has enough samples to at least yield one packed sequence." - ) from exc - return packed_dataset + # Then compute accuracy on the gathered tensors + if self.accelerator.is_main_process: + accuracy = compute_token_accuracy(shift_logits, shift_labels) + self._metrics["mean_token_accuracy"].append(accuracy) + + return (loss, outputs) if return_outputs else loss + + def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: + metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics + + # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs` + # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format. 
+ if next(iter(logs.keys())).startswith("eval_"): + metrics = {f"eval_{key}": val for key, val in metrics.items()} + + logs = {**logs, **metrics} + if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): + super().log(logs, start_time) + else: # transformers<=4.46 + super().log(logs) + self._metrics.clear() def create_model_card( self, @@ -553,42 +514,3 @@ def create_model_card( ) model_card.save(os.path.join(self.args.output_dir, "README.md")) - - def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): - """ - Compute training loss and additionally compute token accuracies - """ - (loss, outputs) = super().compute_loss( - model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch - ) - - # Compute token accuracy if we have labels - if "labels" in inputs: - shift_logits = outputs.logits[..., :-1, :].contiguous() - shift_labels = inputs["labels"][..., 1:].contiguous() - - # Gather logits and labels from all GPUs first - shift_logits = self.accelerator.gather_for_metrics(shift_logits) - shift_labels = self.accelerator.gather_for_metrics(shift_labels) - - # Then compute accuracy on the gathered tensors - if self.accelerator.is_main_process: - accuracy = compute_token_accuracy(shift_logits, shift_labels) - self._metrics["mean_token_accuracy"].append(accuracy) - - return (loss, outputs) if return_outputs else loss - - def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: - metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics - - # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs` - # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format. 
- if next(iter(logs.keys())).startswith("eval_"): - metrics = {f"eval_{key}": val for key, val in metrics.items()} - - logs = {**logs, **metrics} - if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): - super().log(logs, start_time) - else: # transformers<=4.46 - super().log(logs) - self._metrics.clear() From 7fdb69aa7d2799b2806d5207f56548be1cb4032a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Sat, 8 Feb 2025 00:29:26 +0100 Subject: [PATCH 70/96] =?UTF-8?q?=E2=9E=96=20Fix=20GRPO=20example=20in=20R?= =?UTF-8?q?EADME=20(#2800)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 90887a2715..01f50ad7ae 100644 --- a/README.md +++ b/README.md @@ -149,7 +149,7 @@ dataset = load_dataset("trl-lib/tldr", split="train") # Dummy reward function: rewards completions that are close to 20 characters def reward_len(completions, **kwargs): - return [abs(20 - len(completion)) for completion in completions] + return [-abs(20 - len(completion)) for completion in completions] training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO", logging_steps=10) trainer = GRPOTrainer( From 09eefa73abbba26c73851f1094f325aa1671f3fa Mon Sep 17 00:00:00 2001 From: Tyler Romero Date: Fri, 7 Feb 2025 15:59:46 -0800 Subject: [PATCH 71/96] =?UTF-8?q?=E2=9B=B0=EF=B8=8F=20Reduce=20peak=20vram?= =?UTF-8?q?=20consumption=20with=20efficient=20selective=20log=5Fsoftmax?= =?UTF-8?q?=20(#2799)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Reduce mem consumption across many trainers with efficient selective log-softmax approach * rename * typo fix * precommit * Update tests/test_core.py * relocate * precommit * style * smaller values for test, and run on cpu * nit doc improvements * style * fix test --------- Co-authored-by: Quentin Gallouédec --- tests/test_utils.py | 23 ++++++++++++++++++++++ trl/trainer/bco_trainer.py | 11 +++++++---- trl/trainer/cpo_trainer.py | 3 ++- trl/trainer/dpo_trainer.py | 11 +++++++---- trl/trainer/grpo_trainer.py | 9 ++------- trl/trainer/kto_trainer.py | 11 +++++++---- trl/trainer/nash_md_trainer.py | 4 ++-- trl/trainer/orpo_trainer.py | 3 ++- trl/trainer/ppo_trainer.py | 25 ++++++++++++----------- trl/trainer/rloo_trainer.py | 20 ++++++++----------- trl/trainer/utils.py | 36 ++++++++++++++++++++++++++++++++++ trl/trainer/xpo_trainer.py | 4 ++-- 12 files changed, 110 insertions(+), 50 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 9fd7ed9e0f..64ab68bf74 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -17,6 +17,7 @@ import numpy as np import torch from datasets import load_dataset +from parameterized import parameterized from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig from transformers.testing_utils import require_peft from transformers.utils import is_peft_available @@ -32,6 +33,7 @@ generate_model_card, get_peft_config, pad, + selective_log_softmax, ) @@ -506,3 +508,24 @@ def test_batch_accuracy(self): ) accuracy = compute_token_accuracy(logits, labels) self.assertAlmostEqual(accuracy, 0.8) + + +class TestSelectiveLogSoftmax(unittest.TestCase): + @parameterized.expand([(torch.float64,), (torch.float32,), (torch.float16,), (torch.bfloat16,)]) + def test_selective_log_softmax(self, dtype): + """Test selective_log_softmax with logits of different 
dtypes""" + vocab_size = 1024 + batch_size = 4 + seq_len = 32 + + input_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len)) + logits = torch.randn(batch_size, seq_len, vocab_size, dtype=dtype) + + expected_output = torch.gather(logits.log_softmax(-1), dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1) + actual_output = selective_log_softmax(logits, input_ids) + + if dtype in [torch.float16, torch.bfloat16]: + # half-precision dtypes fall back to an exact method + self.assertTrue(torch.equal(actual_output, expected_output)) + else: + torch.testing.assert_close(actual_output, expected_output, rtol=1e-5, atol=1e-5) diff --git a/trl/trainer/bco_trainer.py b/trl/trainer/bco_trainer.py index 18df4a7210..db4fa156e4 100644 --- a/trl/trainer/bco_trainer.py +++ b/trl/trainer/bco_trainer.py @@ -65,6 +65,7 @@ log_table_to_comet_experiment, pad_to_length, peft_module_casting_to_bf16, + selective_log_softmax, ) @@ -897,9 +898,11 @@ def _load_optimizer_and_scheduler(self, checkpoint): @contextmanager def null_ref_context(self): """Context manager for handling null reference model (that is, peft adapter manipulation).""" - with self.accelerator.unwrap_model( - self.model - ).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext(): + with ( + self.accelerator.unwrap_model(self.model).disable_adapter() + if self.is_peft_model and not self.ref_adapter_name + else nullcontext() + ): if self.ref_adapter_name: self.model.set_adapter(self.ref_adapter_name) yield @@ -1062,7 +1065,7 @@ def get_batch_logps( # dummy token; we'll ignore the losses on these tokens later labels[labels == label_pad_token_id] = 0 - per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + per_token_logps = selective_log_softmax(logits, labels) if average_log_prob: return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) diff --git a/trl/trainer/cpo_trainer.py b/trl/trainer/cpo_trainer.py index 644a2d5353..174cb4f255 100644 --- a/trl/trainer/cpo_trainer.py +++ b/trl/trainer/cpo_trainer.py @@ -60,6 +60,7 @@ log_table_to_comet_experiment, pad_to_length, peft_module_casting_to_bf16, + selective_log_softmax, ) @@ -711,7 +712,7 @@ def get_batch_logps( # dummy token; we'll ignore the losses on these tokens later labels[labels == label_pad_token_id] = 0 - per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + per_token_logps = selective_log_softmax(logits, labels) if average_log_prob: return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index cd3c3b4dca..a16edb6f37 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -69,6 +69,7 @@ pad, pad_to_length, peft_module_casting_to_bf16, + selective_log_softmax, ) @@ -822,9 +823,11 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa @contextmanager def null_ref_context(self): """Context manager for handling null reference model (that is, peft adapter manipulation).""" - with self.accelerator.unwrap_model( - self.model - ).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext(): + with ( + self.accelerator.unwrap_model(self.model).disable_adapter() + if self.is_peft_model and not self.ref_adapter_name + else nullcontext() + ): if self.ref_adapter_name: self.model.set_adapter(self.ref_adapter_name) yield @@ -1211,7 +1214,7 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, 
Union[list, to # Compute the log probabilities of the labels labels[~loss_mask] = 0 # dummy token; we'll ignore the losses on these tokens later - per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + per_token_logps = selective_log_softmax(logits, labels) per_token_logps[~loss_mask] = 0 per_token_logps = torch.roll(per_token_logps, shifts=1, dims=1) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 4cbba2fcaa..bab86f7bcc 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -47,7 +47,7 @@ from ..models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation from .callbacks import SyncRefModelCallback from .grpo_config import GRPOConfig -from .utils import generate_model_card, get_comet_experiment_url, pad +from .utils import generate_model_card, get_comet_experiment_url, pad, selective_log_softmax if is_peft_available(): @@ -442,12 +442,7 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves. # See https://github.com/huggingface/trl/issues/2770 logits = logits[:, -logits_to_keep:] - - # Compute the log probabilities for the input tokens. - token_logits = logits.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1) - logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits]) # loop to reduce memory peak - token_log_probs = token_logits - logsumexp_values # log_softmax = logits - log(sum(exp(logits))) - return token_log_probs + return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]: device = self.accelerator.device diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index c45a88d554..0c92ad70b8 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -63,6 +63,7 @@ log_table_to_comet_experiment, pad_to_length, peft_module_casting_to_bf16, + selective_log_softmax, ) @@ -812,9 +813,11 @@ def _prepare_deepspeed(self, model: PreTrainedModelWrapper): @contextmanager def null_ref_context(self): """Context manager for handling null reference model (that is, peft adapter manipulation).""" - with self.accelerator.unwrap_model( - self.model - ).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext(): + with ( + self.accelerator.unwrap_model(self.model).disable_adapter() + if self.is_peft_model and not self.ref_adapter_name + else nullcontext() + ): if self.ref_adapter_name: self.model.set_adapter(self.ref_adapter_name) yield @@ -1032,7 +1035,7 @@ def get_batch_logps( # dummy token; we'll ignore the losses on these tokens later labels[labels == label_pad_token_id] = 0 - per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + per_token_logps = selective_log_softmax(logits, labels) if average_log_prob: return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) diff --git a/trl/trainer/nash_md_trainer.py b/trl/trainer/nash_md_trainer.py index cbe218066e..5d2a8e830d 100644 --- a/trl/trainer/nash_md_trainer.py +++ b/trl/trainer/nash_md_trainer.py @@ -46,6 +46,7 @@ generate_model_card, get_comet_experiment_url, get_reward, + selective_log_softmax, truncate_right, ) @@ -277,8 +278,7 @@ def _compute_logprobs(self, model, model_data, context_length): def 
compute_logprobs_for_data(m, data): output = m(data["input_ids"], attention_mask=data["attention_mask"]) logits = output.logits[:, context_length - 1 : -1] - logprobs = F.log_softmax(logits, dim=-1) - token_logprobs = torch.gather(logprobs, 2, data["input_ids"][:, context_length:].unsqueeze(-1)).squeeze(-1) + token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:]) return token_logprobs # Compute logprobs for model completions under the model diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py index 832927d05c..72436d321d 100644 --- a/trl/trainer/orpo_trainer.py +++ b/trl/trainer/orpo_trainer.py @@ -64,6 +64,7 @@ log_table_to_comet_experiment, pad_to_length, peft_module_casting_to_bf16, + selective_log_softmax, ) @@ -718,7 +719,7 @@ def get_batch_logps( # dummy token; we'll ignore the losses on these tokens later labels = torch.where(labels == label_pad_token_id, 0, labels) - per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + per_token_logps = selective_log_softmax(logits, labels) if average_log_prob: return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py index 27cbdd016c..cf7f6768a6 100644 --- a/trl/trainer/ppo_trainer.py +++ b/trl/trainer/ppo_trainer.py @@ -25,7 +25,6 @@ import pandas as pd import torch import torch.nn as nn -import torch.nn.functional as F from accelerate import Accelerator from accelerate.utils import broadcast, gather_object from datasets import Dataset @@ -65,6 +64,7 @@ peft_module_casting_to_bf16, prepare_deepspeed, print_rich_table, + selective_log_softmax, truncate_response, ) @@ -310,9 +310,11 @@ def get_eval_dataloader(self) -> DataLoader: @contextmanager def null_ref_context(self): """Context manager for handling null reference model (that is, peft adapter manipulation).""" - with self.accelerator.unwrap_model( - self.model.policy - ).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext(): + with ( + self.accelerator.unwrap_model(self.model.policy).disable_adapter() + if self.is_peft_model and not self.ref_adapter_name + else nullcontext() + ): if self.ref_adapter_name: self.model.policy.set_adapter(self.ref_adapter_name) yield @@ -427,9 +429,8 @@ def repeat_generator(): query_response = query_responses[i : i + args.local_rollout_forward_batch_size] response = query_response[:, context_length:] logits = logitss[i : i + args.local_rollout_forward_batch_size] - all_logprob = F.log_softmax(logits, dim=-1) - logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) - del logits, all_logprob + logprob = selective_log_softmax(logits, response) + del logits torch.cuda.empty_cache() if ref_policy is None: @@ -439,9 +440,8 @@ def repeat_generator(): ref_output = forward(ref_policy, query_response, processing_class.pad_token_id) ref_logits = ref_output.logits[:, context_length - 1 : -1] ref_logits /= args.temperature + 1e-7 - ref_all_logprob = F.log_softmax(ref_logits, dim=-1) - ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) - del ref_output, ref_logits, ref_all_logprob + ref_logprob = selective_log_softmax(ref_logits, response) + del ref_output, ref_logits torch.cuda.empty_cache() # Response Processing 1. 
truncate response after the first occurrence of `stop_token_id` @@ -547,8 +547,7 @@ def repeat_generator(): output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id) logits = output.logits[:, context_length - 1 : -1] logits /= args.temperature + 1e-7 - new_all_logprobs = F.log_softmax(logits, dim=-1) - new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) + new_logprobs = selective_log_softmax(logits, mb_responses) new_logprobs = torch.masked_fill( new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB ) @@ -599,7 +598,7 @@ def repeat_generator(): # del everything and empty cache # fmt: off del ( - output, vpred_temp, logits, new_all_logprobs, new_logprobs, vpred, vpredclipped, + output, vpred_temp, logits, new_logprobs, vpred, vpredclipped, vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max, pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return, mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs, diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 6626baa15c..344253c2b8 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -24,7 +24,6 @@ import pandas as pd import torch import torch.nn as nn -import torch.nn.functional as F from accelerate import Accelerator from accelerate.utils import broadcast, gather_object from datasets import Dataset @@ -56,6 +55,7 @@ get_reward, prepare_deepspeed, print_rich_table, + selective_log_softmax, truncate_response, ) from .rloo_config import RLOOConfig @@ -330,17 +330,15 @@ def repeat_generator(): query_response = query_responses[i : i + args.local_rollout_forward_batch_size] response = query_response[:, context_length:] logits = logitss[i : i + args.local_rollout_forward_batch_size] - all_logprob = F.log_softmax(logits, dim=-1) - logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) - del logits, all_logprob + logprob = selective_log_softmax(logits, response) + del logits torch.cuda.empty_cache() ref_output = forward(ref_policy, query_response, processing_class.pad_token_id) ref_logits = ref_output.logits[:, context_length - 1 : -1] ref_logits /= args.temperature + 1e-7 - ref_all_logprob = F.log_softmax(ref_logits, dim=-1) - ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) - del ref_output, ref_logits, ref_all_logprob + ref_logprob = selective_log_softmax(ref_logits, response) + del ref_output, ref_logits torch.cuda.empty_cache() # Response Processing 1. 
truncate response after the first occurrence of `stop_token_id` @@ -467,8 +465,7 @@ def repeat_generator(): logits /= args.temperature + 1e-7 # Compute new logprobs - new_all_logprobs = F.log_softmax(logits, dim=-1) - new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) + new_logprobs = selective_log_softmax(logits, mb_responses) new_logprobs = torch.masked_fill( new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB ) @@ -512,9 +509,8 @@ def repeat_generator(): # del everything and empty cache # fmt: off del ( - output, logits, new_all_logprobs, new_logprobs, - logprobs_diff, ratio, pg_losses, pg_losses2, - pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, + output, logits, new_logprobs, logprobs_diff, ratio, pg_losses, + pg_losses2, pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_advantage, mb_responses, mb_query_responses, mb_logprobs, ) # fmt: on diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 029e5639ab..ea603b9637 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -26,6 +26,7 @@ import numpy as np import pandas as pd import torch +import torch.nn.functional as F import torch.utils.data from accelerate import Accelerator, PartialState from accelerate.state import AcceleratorState @@ -1668,3 +1669,38 @@ def compute_token_accuracy(logits: torch.Tensor, labels: torch.Tensor, ignore_in accuracy = correct_tokens.item() / total_tokens.item() if total_tokens > 0 else 0.0 return accuracy + + +def selective_log_softmax(logits, index): + """ + A memory-efficient implementation of the common `log_softmax -> gather` operation. + + This function is equivalent to the following naive implementation: + ```python + logps = torch.gather(logits.log_softmax(-1), dim=-1, index=index.unsqueeze(-1)).squeeze(-1) + ``` + + Args: + logits (`torch.Tensor`): + Logits tensor of shape `(..., num_classes)`. + index (`torch.Tensor`): + Index tensor of shape `(...)`, specifying the positions to gather from the log-softmax output. + + Returns: + `torch.Tensor`: + Gathered log probabilities with the same shape as `index`. 
+ """ + if logits.dtype in [torch.float32, torch.float64]: + selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1) + # loop to reduce peak mem consumption + logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits]) + per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x) + else: + # logsumexp approach is unstable with bfloat16, fall back to slightly less efficent approach + per_token_logps = [] + for row_logits, row_labels in zip(logits, index): # loop to reduce peak mem consumption + row_logps = F.log_softmax(row_logits, dim=-1) + row_per_token_logps = row_logps.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1) + per_token_logps.append(row_per_token_logps) + per_token_logps = torch.stack(per_token_logps) + return per_token_logps diff --git a/trl/trainer/xpo_trainer.py b/trl/trainer/xpo_trainer.py index 2d535344e7..6c7579ae8a 100644 --- a/trl/trainer/xpo_trainer.py +++ b/trl/trainer/xpo_trainer.py @@ -44,6 +44,7 @@ generate_model_card, get_comet_experiment_url, get_reward, + selective_log_softmax, truncate_right, ) from .xpo_config import XPOConfig @@ -274,8 +275,7 @@ def _compute_logprobs(self, model, model_data, ref_data, context_length): def compute_logprobs_for_data(m, data): output = m(data["input_ids"], attention_mask=data["attention_mask"]) logits = output.logits[:, context_length - 1 : -1] - logprobs = F.log_softmax(logits, dim=-1) - token_logprobs = torch.gather(logprobs, 2, data["input_ids"][:, context_length:].unsqueeze(-1)).squeeze(-1) + token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:]) return token_logprobs # Compute logprobs for model completions From 55e680e142d88e090dcbf5a469eab1ebba28ddef Mon Sep 17 00:00:00 2001 From: Maxim Evtush <154841002+maximevtush@users.noreply.github.com> Date: Sat, 8 Feb 2025 20:46:47 +0100 Subject: [PATCH 72/96] fix: typos in documentation files (#2804) --- docs/source/learning_tools.md | 4 ++-- docs/source/text_environments.md | 2 +- docs/source/xpo_trainer.md | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/learning_tools.md b/docs/source/learning_tools.md index bf42994194..368d666fe8 100644 --- a/docs/source/learning_tools.md +++ b/docs/source/learning_tools.md @@ -22,7 +22,7 @@ Note that the scripts above rely heavily on the `TextEnvironment` API which is s The rough idea is as follows: -1. Load a tool such as [ybelkada/simple-calculator](https://huggingface.co/spaces/ybelkada/simple-calculator) that parse a text calculation like `"14 + 34"` and return the calulated number: +1. Load a tool such as [ybelkada/simple-calculator](https://huggingface.co/spaces/ybelkada/simple-calculator) that parse a text calculation like `"14 + 34"` and return the calculated number: ```python from transformers import AutoTokenizer, load_tool tool = load_tool("ybelkada/simple-calculator") @@ -154,7 +154,7 @@ We then basically deployed this snippet as a Hugging Face space [here](https://h We use the following settings: * use the `bigcode/starcoderbase` model as the base model -* use the `pyserini-wikipedia-kilt-doc` space as the wiki tool and only uses the first paragrahs of the search result, allowing the `TextEnvironment` to obtain at most `max_tool_reponse=400` response tokens from the tool. 
+* use the `pyserini-wikipedia-kilt-doc` space as the wiki tool and only uses the first paragraphs of the search result, allowing the `TextEnvironment` to obtain at most `max_tool_reponse=400` response tokens from the tool. * test if the response contain the answer string, if so, give a reward of 1, otherwise, give a reward of 0. * notice this is a simplified evaluation criteria. In [ToolFormer](https://huggingface.co/papers/2302.04761), the authors checks if the first 20 words of the response contain the correct answer. * used the following prompt that demonstrates the usage of the wiki tool. diff --git a/docs/source/text_environments.md b/docs/source/text_environments.md index c7b0bd0cfd..15807f576d 100644 --- a/docs/source/text_environments.md +++ b/docs/source/text_environments.md @@ -174,7 +174,7 @@ With these attributes you can reconstruct every interaction of the model with th ### Visualization -When the model interacts inside the `TextEnvironment` it can be useful to visualize and separate which parts of the text outputs were generated by the model and which parts come from the system and tools. For that purpose there are the two methods [`TextHistory.show_text`] and [`TextHistory.show_tokens`]. They print the text and tokens respectively and highlight the various segments using the [`rich` libray](https://github.com/Textualize/rich) (make sure to install it before using these methods). +When the model interacts inside the `TextEnvironment` it can be useful to visualize and separate which parts of the text outputs were generated by the model and which parts come from the system and tools. For that purpose there are the two methods [`TextHistory.show_text`] and [`TextHistory.show_tokens`]. They print the text and tokens respectively and highlight the various segments using the [`rich` library](https://github.com/Textualize/rich) (make sure to install it before using these methods). You can see that the prompt is highlighted in gray, whereas system segments such as query and tool responses are highlighted in green. All segments generated by the model are highlighted in blue and in addition to the pure text output the reward is displayed as additional text in plum. Here an example of `show_text`: diff --git a/docs/source/xpo_trainer.md b/docs/source/xpo_trainer.md index 07a76f36dc..4501aaf68b 100644 --- a/docs/source/xpo_trainer.md +++ b/docs/source/xpo_trainer.md @@ -4,7 +4,7 @@ ## Overview -Exploratory Preference Optimization (XPO) was proposed in the paper [Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF](https://huggingface.co/papers/2405.21046) by Tengyang Xie, Dylan J. Foster, Akshay Krishnamurthy, [Corby Rosset](https://huggingface.co/corbyrosset), [Ahmed Awadallah](https://huggingface.co/AhmedAwadallah), and Alexander Rakhlin. It is a simple online preference tuning method based on the DPO loss together with a reward model (RM). XPO augments the DPO objective with an exploration bonus allowing the method to explore outside the support of the intitial model and human feedback data. +Exploratory Preference Optimization (XPO) was proposed in the paper [Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF](https://huggingface.co/papers/2405.21046) by Tengyang Xie, Dylan J. Foster, Akshay Krishnamurthy, [Corby Rosset](https://huggingface.co/corbyrosset), [Ahmed Awadallah](https://huggingface.co/AhmedAwadallah), and Alexander Rakhlin. 
It is a simple online preference tuning method based on the DPO loss together with a reward model (RM). XPO augments the DPO objective with an exploration bonus allowing the method to explore outside the support of the initial model and human feedback data. The abstract from the paper is the following: From b9df81045b02a0540cdfe4e8b285d63f80224b0b Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 10 Feb 2025 09:20:38 -0500 Subject: [PATCH 73/96] =?UTF-8?q?=F0=9F=93=A4=20GRPO=20refactor=20loading?= =?UTF-8?q?=20the=20model=20weights=20to=20vllm=20(#2817)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * GRPO refactor loading the model weights to vllm * style --------- Co-authored-by: Quentin Gallouédec --- trl/trainer/grpo_trainer.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index bab86f7bcc..789dfdb800 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -444,6 +444,18 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) logits = logits[:, -logits_to_keep:] return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens + def _move_model_to_vllm(self): + with unwrap_model_for_generation( + self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation + ) as unwrapped_model: + if is_compiled_module(unwrapped_model): + state_dict = unwrapped_model._orig_mod.state_dict() + else: + state_dict = unwrapped_model.state_dict() + if self.accelerator.is_main_process: + llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model + llm_model.load_weights(state_dict.items()) + def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]: device = self.accelerator.device prompts = [x["prompt"] for x in inputs] @@ -462,16 +474,7 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s if self.args.use_vllm: # First, have main process load weights if needed if self.state.global_step != self._last_loaded_step: - with unwrap_model_for_generation( - self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation - ) as unwrapped_model: - if is_compiled_module(unwrapped_model): - state_dict = unwrapped_model._orig_mod.state_dict() - else: - state_dict = unwrapped_model.state_dict() - if self.accelerator.is_main_process: - llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model - llm_model.load_weights(state_dict.items()) + self._move_model_to_vllm() self._last_loaded_step = self.state.global_step # Generate completions using vLLM: gather all prompts and use them in a single call in the main process From 674bb75f59d4ddc19558209af2b27a2f537c15c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Tue, 11 Feb 2025 10:30:27 +0100 Subject: [PATCH 74/96] =?UTF-8?q?=F0=9F=AB=98=20Add=20`set=5Fseed()`=20cal?= =?UTF-8?q?l=20in=20GRPO=20to=20ensure=20unique=20seed=20for=20each=20proc?= =?UTF-8?q?ess=20(#2824)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add set_seed() function to ensure unique seed for each process * share seed sampler * style --- trl/trainer/grpo_trainer.py | 34 ++++++++++++++++++++++++++++------ 1 file changed, 28 insertions(+), 6 deletions(-) diff --git 
a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 789dfdb800..7ffb26e6ed 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -22,7 +22,7 @@ import torch import torch.utils.data import transformers -from accelerate.utils import broadcast_object_list, gather, gather_object +from accelerate.utils import broadcast_object_list, gather, gather_object, set_seed from accelerate.utils.other import is_compiled_module from datasets import Dataset, IterableDataset from packaging import version @@ -73,6 +73,8 @@ class RepeatRandomSampler(Sampler): Dataset to sample from. repeat_count (`int`): Number of times to repeat each index. + seed (`Optional[int]`): + Random seed for reproducibility (only affects this sampler). Example: ```python @@ -82,13 +84,21 @@ class RepeatRandomSampler(Sampler): ``` """ - def __init__(self, data_source: Sized, repeat_count: int): + def __init__(self, data_source: Sized, repeat_count: int, seed: Optional[int] = None): self.data_source = data_source self.repeat_count = repeat_count self.num_samples = len(data_source) + self.seed = seed + self.generator = torch.Generator() # Create a local random generator + if seed is not None: + self.generator.manual_seed(seed) def __iter__(self): - indexes = [idx for idx in torch.randperm(self.num_samples).tolist() for _ in range(self.repeat_count)] + indexes = [ + idx + for idx in torch.randperm(self.num_samples, generator=self.generator).tolist() + for _ in range(self.repeat_count) + ] return iter(indexes) def __len__(self): @@ -333,6 +343,11 @@ def data_collator(features): # No data collation is needed in GRPO f"eval batch size, the valid values for the number of generations are: {possible_values}." ) + # Ensure each process receives a unique seed to prevent duplicate completions when generating with + # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but + # it's safer to set it in all cases. + set_seed(args.seed, device_specific=True) + if self.use_vllm: if not is_vllm_available(): raise ImportError( @@ -425,12 +440,19 @@ def _set_signature_columns_if_needed(self): if self._signature_columns is None: self._signature_columns = ["prompt"] - # We need a custom sampler that samples the same prompt multiple times def _get_train_sampler(self) -> Sampler: - return RepeatRandomSampler(self.train_dataset, self.num_generations) + # Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that + # identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly + # within each prompt group. Using the same seed across processes ensures consistent prompt assignment, + # preventing discrepancies in group formation. + return RepeatRandomSampler(self.train_dataset, self.num_generations, seed=self.args.seed) def _get_eval_sampler(self, eval_dataset) -> Sampler: - return RepeatRandomSampler(eval_dataset, self.num_generations) + # Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that + # identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly + # within each prompt group. Using the same seed across processes ensures consistent prompt assignment, + # preventing discrepancies in group formation. 
+ return RepeatRandomSampler(eval_dataset, self.num_generations, seed=self.args.seed) # Get the per-token log probabilities for the completions for the model and the reference model def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): From e752fc6c2e298683196f4afd84c232991c1b32a0 Mon Sep 17 00:00:00 2001 From: Hesam Sheikh <41022652+hesamsheikh@users.noreply.github.com> Date: Tue, 11 Feb 2025 11:15:41 +0100 Subject: [PATCH 75/96] =?UTF-8?q?=E2=9A=96=EF=B8=8F=20Add=20reward=20weigh?= =?UTF-8?q?t=20in=20multi-reward=20settings=20for=20GRPO=20(#2676)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * added reward weights for multi-reward runs in GRPO * reward_weights are float, moved from GRPOTrainer to GRPOConfig * minor comment fix * minor * fix test * missing link --------- Co-authored-by: Quentin Gallouédec Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> --- docs/source/grpo_trainer.md | 7 +++--- tests/test_grpo_trainer.py | 43 +++++++++++++++++++++++++++++++++++++ trl/trainer/grpo_config.py | 10 +++++++++ trl/trainer/grpo_trainer.py | 21 +++++++++++++++--- 4 files changed, 74 insertions(+), 7 deletions(-) diff --git a/docs/source/grpo_trainer.md b/docs/source/grpo_trainer.md index 9227ff52cf..c1b9a5c28e 100644 --- a/docs/source/grpo_trainer.md +++ b/docs/source/grpo_trainer.md @@ -193,7 +193,7 @@ You can test this function as follows: #### Example 3: Reward completions based on a reference -Below is an example of a reward function that checks if the is correct. This example is inspired by the _accuracy reward_ function used in the paper [DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning](https://huggingface.co/papers/2501.12948). +Below is an example of a reward function that checks if the completion is correct. This example is inspired by the _accuracy reward_ function used in the paper [DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning](https://huggingface.co/papers/2501.12948). This example is designed for [standard format](dataset_formats#standard), where the dataset contains a column named `ground_truth`. ```python @@ -219,7 +219,7 @@ You can test this function as follows: #### Passing the reward function to the trainer -To use your custom reward function, pass it to the `GRPOTrainer` as follows: +To use your custom reward function, pass it to the [`GRPOTrainer`] as follows: ```python from trl import GRPOTrainer @@ -240,8 +240,7 @@ trainer = GRPOTrainer( ..., ) ``` - -and the reward will be computed as the sum of the rewards from each function. +and the reward will be computed as the sum of the rewards from each function, or the weighted sum if `reward_weights` is provided in the config. Note that [`GRPOTrainer`] supports multiple reward functions of different types. See the parameters documentation for more details. 
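As a standalone illustration of the weighted combination described above (toy numbers, not trainer code):

```python
import torch

# Shape (num_samples, num_reward_funcs): one row per completion, one column per
# reward function. Weights mirror the `reward_weights=[0.7, 0.3]` case used in
# this patch's test.
rewards_per_func = torch.tensor([[2.0, 10.0],
                                 [4.0,  6.0]])
reward_weights = torch.tensor([0.7, 0.3])

# Weighted sum across reward functions, as done in `_prepare_inputs`
rewards = (rewards_per_func * reward_weights.unsqueeze(0)).sum(dim=1)
print(rewards)  # tensor([4.4000, 4.6000])
```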
diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index 54d2b3965d..3af3986f0d 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -283,6 +283,49 @@ def reward_func2(completions, **kwargs): new_param = trainer.model.get_parameter(n) self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + def test_training_multiple_reward_funcs_with_weights(self): + """Test that GRPOTrainer can handle multiple reward functions with weights.""" + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + def reward_func1(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion)) for completion in completions] + + def reward_func2(completions, **kwargs): + """Reward function that rewards completions with more unique letters.""" + return [float(len(set(completion))) for completion in completions] + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=32, # reduce the completion length to reduce memory usage + report_to="none", + reward_weights=[0.7, 0.3], # weight of reward_func1 and reward_func2 respectively + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=[reward_func1, reward_func2], + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + # Check that training logs contain both reward metrics + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + self.assertIn("rewards/reward_func1", trainer.state.log_history[-1]) + self.assertIn("rewards/reward_func2", trainer.state.log_history[-1]) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + def test_training_multiple_mixed_reward_funcs(self): # Test if the trainer can handle a mix of reward functions and reward models dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index 708a3188c8..6809b5c1ad 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -86,6 +86,9 @@ class GRPOConfig(TrainingArguments): [`~transformers.TrainingArguments`]. beta (`float`, *optional*, defaults to `0.04`): KL coefficient. + reward_weights (`list[float]` or `None`, *optional*, defaults to `None`): + Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are + weighted equally with weight `1.0`. sync_ref_model (`bool`, *optional*, defaults to `False`): Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using the `ref_model_mixup_alpha` parameter. This synchronization originites from the @@ -210,6 +213,13 @@ class GRPOConfig(TrainingArguments): default=0.04, metadata={"help": "KL coefficient."}, ) + reward_weights: Optional[list[float]] = field( + default=None, + metadata={ + "help": "Weights for each reward function. 
Must match the number of reward functions. If `None`, all " + "rewards are weighted equally with weight `1.0`." + }, + ) sync_ref_model: bool = field( default=False, metadata={ diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 7ffb26e6ed..abbd7328d0 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -118,9 +118,13 @@ class GRPOTrainer(Trainer): dataset = load_dataset("trl-lib/tldr", split="train") + def reward_func(completions, **kwargs): + # Dummy reward function that rewards completions with more unique letters. + return [float(len(set(completion))) for completion in completions] + trainer = GRPOTrainer( model="Qwen/Qwen2-0.5B-Instruct", - reward_funcs="weqweasdas/RM-Gemma-2B", + reward_funcs=reward_func, train_dataset=dataset, ) @@ -267,6 +271,17 @@ def __init__( ) self.reward_funcs = reward_funcs + # Reward weights + if args.reward_weights is not None: + if len(args.reward_weights) != len(reward_funcs): + raise ValueError( + f"Number of reward weights ({len(len(args.reward_weights))}) must match number of reward " + f"functions ({len(reward_funcs)})" + ) + self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32) + else: + self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32) + # Reward processing class if reward_processing_classes is None: reward_processing_classes = [None] * len(reward_funcs) @@ -588,8 +603,8 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s # completions may be distributed across processes rewards_per_func = gather(rewards_per_func) - # Sum the rewards from all reward functions - rewards = rewards_per_func.sum(dim=1) + # Apply weights to each reward function's output and sum + rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(dim=1) # Compute grouped-wise rewards mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1) From 9b67eea4735b1db5e09e1eafd57eae13b55353c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Tue, 11 Feb 2025 11:30:37 +0100 Subject: [PATCH 76/96] =?UTF-8?q?=F0=9F=99=8C=20Share=20vLLM=20device=20wi?= =?UTF-8?q?th=20training=20when=20only=201=20available=20(#2827)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix GPU device selection in GRPOTrainer in case training with onyl one * update doc * style * update warning --- trl/trainer/grpo_config.py | 3 ++- trl/trainer/grpo_trainer.py | 11 ++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index 6809b5c1ad..cd6cc91748 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -65,7 +65,8 @@ class GRPOConfig(TrainingArguments): vllm_device (`str`, *optional*, defaults to `"auto"`): Device where vLLM generation will run, e.g. `"cuda:1"`. If set to `"auto"` (default), the system will automatically select the next available GPU after the last one used for training. This assumes that - training has not already occupied all available GPUs. + training has not already occupied all available GPUs. If only one device is available, the device will be + shared between both training and vLLM. vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.9`): Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the device dedicated to generation powered by vLLM. 
Higher values will increase the KV cache size and thus diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index abbd7328d0..0ab8f3eda9 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -373,7 +373,10 @@ def data_collator(features): # No data collation is needed in GRPO if self.accelerator.is_main_process: vllm_device = self.args.vllm_device if vllm_device == "auto": - vllm_device = f"cuda:{self.accelerator.num_processes}" # take the next GPU idx + if torch.cuda.device_count() == 1: + vllm_device = "cuda:0" # particular case when training with onyl 1 GPU: share it + else: + vllm_device = f"cuda:{self.accelerator.num_processes}" # take the next GPU idx # Check that the requested device is available if vllm_device.split(":")[0] == "cuda" and int(vllm_device.split(":")[1]) >= torch.cuda.device_count(): raise ValueError( @@ -385,8 +388,10 @@ def data_collator(features): # No data collation is needed in GRPO # Check that the requested device is not also used for training if vllm_device in {f"cuda:{idx}" for idx in range(self.accelerator.num_processes)}: warnings.warn( - f"The requested device {vllm_device} is also used for training. This may lead to unexpected " - "behavior. It is recommended to use a dedicated device for vLLM." + f"The requested device {vllm_device} is also being used for training. For higher throughput " + "and to avoid out-of-memory errors, it is recommended to use a dedicated device for vLLM. " + "If this is intentional, you may ignore this warning but should adjust " + "`vllm_gpu_memory_utilization` accordingly." ) # vLLM is not compatible with accelerate. So we need to patch it to make sure we can (1) place the vLLM # model on the desired device (world_size_patch) and (2) avoid a test that is not designed for our From 2106b3129815b68acdbbd0c4c15d162fae33ec10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Tue, 11 Feb 2025 11:46:26 +0100 Subject: [PATCH 77/96] =?UTF-8?q?=F0=9F=91=B4=20Update=20`tokenizer`=20par?= =?UTF-8?q?ameter=20to=20`processing=5Fclass`=20in=20tests=20(#2828)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_callbacks.py | 4 ++-- tests/test_dpo_trainer.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index eac36ebe02..cf056f3e03 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -341,7 +341,7 @@ def test_callback(self): model=self.model, args=training_args, train_dataset=self.dataset, - tokenizer=self.tokenizer, + processing_class=self.tokenizer, callbacks=[merge_callback], ) trainer.train() @@ -364,7 +364,7 @@ def test_every_checkpoint(self): model=self.model, args=training_args, train_dataset=self.dataset, - tokenizer=self.tokenizer, + processing_class=self.tokenizer, callbacks=[merge_callback], ) trainer.train() diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index 4e1a72ce9a..e5c6b1b08a 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -265,7 +265,7 @@ def test_dpo_trainer_with_weighting(self): model=self.model, ref_model=self.ref_model, args=training_args, - tokenizer=self.tokenizer, + processing_class=self.tokenizer, train_dataset=dummy_dataset["train"], eval_dataset=dummy_dataset["test"], ) From 7347c292c3d18ff9209ad745ede6fce7e3b94155 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= 
<45557362+qgallouedec@users.noreply.github.com> Date: Tue, 11 Feb 2025 18:56:22 +0100 Subject: [PATCH 78/96] =?UTF-8?q?=F0=9F=A5=BE=20=20Allow=20bootstrap=20GRP?= =?UTF-8?q?O=20(#2829)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Kashif Rasul --- trl/data_utils.py | 22 +++++++++++++++++++--- trl/trainer/grpo_trainer.py | 5 ++++- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/trl/data_utils.py b/trl/data_utils.py index 2d69cdd431..1c718a6c2d 100644 --- a/trl/data_utils.py +++ b/trl/data_utils.py @@ -90,8 +90,21 @@ def apply_chat_template( # Apply the chat template to the prompt, adding the generation prompt if "prompt" in example: + last_role = example["prompt"][-1]["role"] + if last_role == "user": + add_generation_prompt = True + continue_final_message = False + elif last_role == "assistant": + add_generation_prompt = False + continue_final_message = True + else: + raise ValueError(f"Invalid role in the last message: {last_role}") prompt = tokenizer.apply_chat_template( - example["prompt"], tools=tools, tokenize=False, add_generation_prompt=True + example["prompt"], + tools=tools, + continue_final_message=continue_final_message, + tokenize=False, + add_generation_prompt=add_generation_prompt, ) # Apply the chat template to the entire prompt + completion @@ -180,10 +193,13 @@ def maybe_apply_chat_template( Returns: `dict[str, str]`: The formatted example with the chat template applied. - Note: - This function does not alter the keys, except for Language modeling dataset, where `"messages"` is replaced by + Notes: + - This function does not alter the keys, except for Language modeling dataset, where `"messages"` is replaced by `"text"`. + - In case of prompt-only data, if the last role is `"user"`, the generation prompt is added to the prompt. Else, + if the last role is `"assistant"`, the final message is continued. 
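To make that note concrete, a small hedged example (assuming the Qwen2.5 chat template here; any chat-templated tokenizer that supports `continue_final_message` behaves the same way):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
prompt = [
    {"role": "user", "content": "Count to five:"},
    {"role": "assistant", "content": "1, 2,"},  # partial answer to be continued
]

last_role = prompt[-1]["role"]
text = tokenizer.apply_chat_template(
    prompt,
    tokenize=False,
    # user last -> append the generation prompt; assistant last -> continue it
    add_generation_prompt=(last_role == "user"),
    continue_final_message=(last_role == "assistant"),
)
# The rendered prompt ends with "1, 2," and no end-of-turn token, so the model's
# completion picks up mid-message, which is what the GRPO bootstrap relies on.
print(text)
```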
+ Example: ```python diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 0ab8f3eda9..620cffd145 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -577,7 +577,10 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s # Decode the generated completions completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) if is_conversational(inputs[0]): - completions = [[{"role": "assistant", "content": completion}] for completion in completions_text] + completions = [] + for prompt, completion in zip(prompts, completions_text): + bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else "" + completions.append([{"role": "assistant", "content": bootstrap + completion}]) else: completions = completions_text From 81221661c6f864bd7cdb7c461e881bbe03414be8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Wed, 12 Feb 2025 18:36:01 +0100 Subject: [PATCH 79/96] =?UTF-8?q?=E2=9A=A1=20Fix=20GRPO=20PEFT=20(#2725)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_grpo_trainer.py | 35 +++++++++++++++++++++++++++++++++++ trl/models/utils.py | 31 +++++++++++++++++++++++-------- trl/trainer/grpo_trainer.py | 16 +++++++++++++++- 3 files changed, 73 insertions(+), 9 deletions(-) diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index 3af3986f0d..d06b4c8ca6 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -498,3 +498,38 @@ def test_training_with_sync_ref_model(self): for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + @unittest.skipIf(not is_vllm_available(), "vLLM is not available") + @require_torch_accelerator + @require_peft + def test_training_vllm_and_peft(self): + """Test that training works with vLLM for generation.""" + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=32, # reduce the completion length to reduce memory usage + use_vllm=True, + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/small-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") diff --git a/trl/models/utils.py b/trl/models/utils.py index dce9d60228..34c9b3c037 100644 --- a/trl/models/utils.py +++ b/trl/models/utils.py @@ -20,6 +20,7 @@ from accelerate.utils import is_deepspeed_available from transformers import PreTrainedModel, PreTrainedTokenizer +from transformers.utils.deprecation 
import deprecate_kwarg from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead @@ -37,8 +38,6 @@ from deepspeed.runtime.engine import DeepSpeedEngine from torch.nn.parallel.distributed import DistributedDataParallel - from .modeling_base import PreTrainedModelWrapper - # TODO: Add Abstract Base Class if more formats are added @dataclass @@ -176,18 +175,34 @@ def add_hooks(model: "DeepSpeedEngine") -> None: @contextmanager +@deprecate_kwarg("is_peft_model", "0.16.0", warn_if_greater_or_equal_version=True) def unwrap_model_for_generation( model: Union["DistributedDataParallel", "DeepSpeedEngine"], accelerator: "Accelerator", - is_peft_model: bool = False, gather_deepspeed3_params: bool = True, -) -> Union["PreTrainedModelWrapper", "DeepSpeedEngine"]: - """Context manager to unwrap a model for generation. - For ZeRO-3 models, we gather the weights once to speed up generation. +): + """ + Context manager to unwrap distributed or accelerated models for generation tasks. + + Args: + model (`Union[DistributedDataParallel, DeepSpeedEngine]`): + Model to be unwrapped. + accelerator (`~accelerate.Accelerator`): + Accelerator instance managing the model. + gather_deepspeed3_params (`bool`, *optional*, defaults to `True`): + Whether to gather weights for DeepSpeed ZeRO Stage 3 models. If `False`, skips parameter gathering, which + can be more memory-efficient but may lead to slower generation times. + + Yields: + Unwrapped model. + + Example: + ```python + with unwrap_model_for_generation(model, accelerator) as unwrapped_model: + generated_outputs = unwrapped_model.generate(input_ids) + ``` """ unwrapped_model = accelerator.unwrap_model(model) - if is_peft_model: - unwrapped_model.pretrained_model.disable_adapter() if accelerator.state.deepspeed_plugin is not None and accelerator.state.deepspeed_plugin.zero_stage == 3: if not gather_deepspeed3_params: yield accelerator.unwrap_model(model) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 620cffd145..a078f0b2c2 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -51,7 +51,7 @@ if is_peft_available(): - from peft import PeftConfig, get_peft_model + from peft import PeftConfig, PeftModel, get_peft_model if is_vllm_available(): from vllm import LLM, SamplingParams @@ -492,6 +492,20 @@ def _move_model_to_vllm(self): ) as unwrapped_model: if is_compiled_module(unwrapped_model): state_dict = unwrapped_model._orig_mod.state_dict() + elif isinstance(unwrapped_model, PeftModel): + unwrapped_model.merge_adapter() + state_dict = unwrapped_model.state_dict() + unwrapped_model.unmerge_adapter() + state_dict = { + k.removeprefix("base_model.model.").replace(".base_layer", ""): v + for k, v in state_dict.items() + if self.model.prefix not in k + } + state_dict = { + k.replace("modules_to_save.default.", ""): v + for k, v in state_dict.items() + if "original_module" not in k + } else: state_dict = unwrapped_model.state_dict() if self.accelerator.is_main_process: From b0f513c13de297dcff9caadfc19392b8d62ed396 Mon Sep 17 00:00:00 2001 From: Edward Beeching Date: Thu, 13 Feb 2025 12:23:10 +0100 Subject: [PATCH 80/96] Fix PeftModel check when moving weights to vlllm (#2850) This check meant that peft now because a required dep when running GRPO with vllm. This PR should resolve this. 
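For reference, a toy sketch (standalone, with made-up keys) of the key clean-up that this `_move_model_to_vllm` path applies to a merged PEFT state dict before calling `load_weights`:

```python
# Made-up example keys; the real dict comes from `unwrapped_model.state_dict()`
state_dict = {
    "base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight": "W_q",
    "base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight": "A",
    "base_model.model.model.embed_tokens.modules_to_save.default.weight": "E",
    "base_model.model.model.embed_tokens.original_module.weight": "E_orig",
}
prefix = "lora_"  # adapter prefix; for LoRA this is what `model.prefix` resolves to

# Drop the PEFT wrapping so names match the base model's own state_dict
state_dict = {
    k.removeprefix("base_model.model.").replace(".base_layer", ""): v
    for k, v in state_dict.items()
    if prefix not in k  # adapter weights are already merged into the base ones
}
state_dict = {
    k.replace("modules_to_save.default.", ""): v
    for k, v in state_dict.items()
    if "original_module" not in k  # keep the saved copy, discard the original module
}
print(state_dict)
# {'model.layers.0.self_attn.q_proj.weight': 'W_q', 'model.embed_tokens.weight': 'E'}
```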
--- trl/trainer/grpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index a078f0b2c2..a71a2f9b1f 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -492,7 +492,7 @@ def _move_model_to_vllm(self): ) as unwrapped_model: if is_compiled_module(unwrapped_model): state_dict = unwrapped_model._orig_mod.state_dict() - elif isinstance(unwrapped_model, PeftModel): + elif is_peft_available() and isinstance(unwrapped_model, PeftModel): unwrapped_model.merge_adapter() state_dict = unwrapped_model.state_dict() unwrapped_model.unmerge_adapter() From 8830786a233853b9310e2e154f6eb87194a75173 Mon Sep 17 00:00:00 2001 From: codeychen Date: Thu, 13 Feb 2025 20:46:18 +0800 Subject: [PATCH 81/96] =?UTF-8?q?=F0=9F=AA=86=20Fix=20for=20Incorrect=20Va?= =?UTF-8?q?lueError=20Handling=20in=20reward=5Fweights=20in=20grpo=5Ftrain?= =?UTF-8?q?er.py=20(#2843)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fixed a bug where an extra `len` call inside the error message caused a `TypeError` instead of the expected `ValueError`. - Replaced `len(len(args.reward_weights))` with the correct `len(args.reward_weights)` to properly calculate the number of reward weights. - Ensured that a `ValueError` is now raised with an accurate and clear message when the number of reward weights does not match the number of reward functions. This fix prevents confusion during debugging and ensures proper error handling during validation. Tested with cases where: - `args.reward_weights` is None (default case). - `args.reward_weights` has mismatched lengths with `reward_funcs`. --- trl/trainer/grpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index a71a2f9b1f..ff7d677cf2 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -275,7 +275,7 @@ def __init__( if args.reward_weights is not None: if len(args.reward_weights) != len(reward_funcs): raise ValueError( - f"Number of reward weights ({len(len(args.reward_weights))}) must match number of reward " + f"Number of reward weights ({len(args.reward_weights)}) must match number of reward " f"functions ({len(reward_funcs)})" ) self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32) From 5c9cf2003dee448f730e80a7259f49eef5ae93b1 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 13 Feb 2025 09:23:36 -0500 Subject: [PATCH 82/96] =?UTF-8?q?=F0=9F=91=A8=E2=80=8D=F0=9F=91=A9?= =?UTF-8?q?=E2=80=8D=F0=9F=91=A7=20GRPO=20+=20PEFT=20+=20vLLM=20(#2818)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * peft + grpo + vllm * test change * support model alread peft * Update tests/test_grpo_trainer.py --------- Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Co-authored-by: Quentin Gallouédec --- tests/test_grpo_trainer.py | 39 +++++++++++++++++++++++++------------ trl/trainer/grpo_trainer.py | 18 +++++++++-------- 2 files changed, 37 insertions(+), 20 deletions(-) diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index d06b4c8ca6..8685ac99a4 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -125,7 +125,7 @@ def test_training_peft(self): self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) - # Check the peft params have changed and the base model params have not changed + # Check that the peft 
params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model params to be the same @@ -168,7 +168,7 @@ def test_training_different_reward_model(self): self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) - # Check the params have changed + # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") @@ -203,7 +203,7 @@ def reward_func(completions, **kwargs): self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) - # Check the params have changed + # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") @@ -239,7 +239,7 @@ def reward_func(completions, **kwargs): self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) - # Check the params have changed + # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") @@ -278,7 +278,7 @@ def reward_func2(completions, **kwargs): self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) - # Check the params have changed + # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") @@ -356,7 +356,7 @@ def reward_func(completions, **kwargs): self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) - # Check the params have changed + # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") @@ -395,7 +395,7 @@ def reward_func(completions, some_values, **kwargs): self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) - # Check the params have changed + # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") @@ -416,9 +416,10 @@ def test_training_vllm(self): report_to="none", use_vllm=True, vllm_device="cuda:0", # will raise a warning, but allows this test to work with only one GPU + vllm_gpu_memory_utilization=0.5, # reduce since because we use the same device for training and vllm ) trainer = GRPOTrainer( - model="trl-internal-testing/small-Qwen2ForCausalLM-2.5", + model="Qwen/Qwen2.5-0.5B-Instruct", # tiny is too small for vLLM reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", args=training_args, train_dataset=dataset, @@ -504,6 +505,8 @@ def test_training_with_sync_ref_model(self): @require_peft def test_training_vllm_and_peft(self): """Test that training works with vLLM for generation.""" + model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") # tiny model is too small for vLLM + base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()] dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") with 
tempfile.TemporaryDirectory() as tmp_dir: @@ -513,14 +516,22 @@ def test_training_vllm_and_peft(self): per_device_train_batch_size=3, # reduce the batch size to reduce memory usage num_generations=3, # reduce the number of generations to reduce memory usage max_completion_length=32, # reduce the completion length to reduce memory usage - use_vllm=True, report_to="none", + use_vllm=True, + vllm_device="cuda:0", # will raise a warning, but allows this test to work with only one GPU + vllm_gpu_memory_utilization=0.5, # reduce since because we use the same device for training and vllm + ) + lora_config = LoraConfig( + target_modules="all-linear", + # test with non-default modules as it add extra keys in state_dict tht we need to handle + modules_to_save=["embed_tokens", "lm_head"], ) trainer = GRPOTrainer( - model="trl-internal-testing/small-Qwen2ForCausalLM-2.5", + model=model, reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", args=training_args, train_dataset=dataset, + peft_config=lora_config, ) previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} @@ -529,7 +540,11 @@ def test_training_vllm_and_peft(self): self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) - # Check that the params have changed + # Check that the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + if n in base_param_names: # We expect the base model params to be the same + self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed.") + elif "base_layer" not in n and "original_module" not in n: + # We expect the peft params to be different (except for the base layer) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed.") diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index ff7d677cf2..1211a453fc 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -22,7 +22,7 @@ import torch import torch.utils.data import transformers -from accelerate.utils import broadcast_object_list, gather, gather_object, set_seed +from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed from accelerate.utils.other import is_compiled_module from datasets import Dataset, IterableDataset from packaging import version @@ -51,7 +51,7 @@ if is_peft_available(): - from peft import PeftConfig, PeftModel, get_peft_model + from peft import PeftConfig, get_peft_model if is_vllm_available(): from vllm import LLM, SamplingParams @@ -249,7 +249,7 @@ def __init__( # Reference model if is_deepspeed_zero3_enabled(): self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs) - elif peft_config is None: + elif not is_peft_model(model): # If PEFT configuration is not provided, create a reference model based on the initial model. 
self.ref_model = create_reference_model(model) else: @@ -491,16 +491,18 @@ def _move_model_to_vllm(self): self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation ) as unwrapped_model: if is_compiled_module(unwrapped_model): - state_dict = unwrapped_model._orig_mod.state_dict() - elif is_peft_available() and isinstance(unwrapped_model, PeftModel): + unwrapped_model = unwrapped_model._orig_mod + if is_peft_model(unwrapped_model): unwrapped_model.merge_adapter() state_dict = unwrapped_model.state_dict() unwrapped_model.unmerge_adapter() + # Remove base_model and base_layer prefixes state_dict = { - k.removeprefix("base_model.model.").replace(".base_layer", ""): v - for k, v in state_dict.items() - if self.model.prefix not in k + k.removeprefix("base_model.model.").replace(".base_layer", ""): v for k, v in state_dict.items() } + # Remove values with adapter prefix (example: "_lora") + state_dict = {k: v for k, v in state_dict.items() if unwrapped_model.prefix not in k} + # When module to save, remove its prefix and discard the original module state_dict = { k.replace("modules_to_save.default.", ""): v for k, v in state_dict.items() From 00e58893809af89d07c2ff9c582ac5060888363c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 13 Feb 2025 14:28:36 +0000 Subject: [PATCH 83/96] Release: v0.15 --- .github/workflows/tests_latest.yml | 2 +- CITATION.cff | 2 +- setup.py | 2 +- trl/__init__.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/tests_latest.yml b/.github/workflows/tests_latest.yml index ab071f05d3..9d957a4242 100644 --- a/.github/workflows/tests_latest.yml +++ b/.github/workflows/tests_latest.yml @@ -17,7 +17,7 @@ jobs: steps: - name: Git checkout uses: actions/checkout@v4 - with: { ref: v0.14-release } + with: { ref: v0.15-release } - name: Set up Python 3.12 uses: actions/setup-python@v5 with: diff --git a/CITATION.cff b/CITATION.cff index cc7130b2b1..68076ef706 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -31,4 +31,4 @@ keywords: - pytorch - transformers license: Apache-2.0 -version: 0.14 +version: 0.15 diff --git a/setup.py b/setup.py index d70406d67e..10e3cb2a62 100644 --- a/setup.py +++ b/setup.py @@ -71,7 +71,7 @@ from setuptools import find_packages, setup -__version__ = "0.15.0.dev0" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) +__version__ = "0.15.0" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) REQUIRED_PKGS = [ "accelerate>=0.34.0", diff --git a/trl/__init__.py b/trl/__init__.py index 0cd4be772f..e177bcd40f 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-__version__ = "0.15.0.dev0" +__version__ = "0.15.0" from typing import TYPE_CHECKING From ffcb9f4aee725a2bd072d0387afe68a4b1c7967c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 13 Feb 2025 14:33:44 +0000 Subject: [PATCH 84/96] =?UTF-8?q?=E2=AC=86=EF=B8=8F=20Bump=20dev=20version?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- setup.py | 2 +- trl/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 10e3cb2a62..5c3c4d5ac3 100644 --- a/setup.py +++ b/setup.py @@ -71,7 +71,7 @@ from setuptools import find_packages, setup -__version__ = "0.15.0" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) +__version__ = "0.16.0.dev0" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) REQUIRED_PKGS = [ "accelerate>=0.34.0", diff --git a/trl/__init__.py b/trl/__init__.py index e177bcd40f..6850dc7681 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.15.0" +__version__ = "0.16.0.dev0" from typing import TYPE_CHECKING From 6d9fc11fd6c759310e006aae353c8d098cc42071 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 17 Feb 2025 07:50:55 +0100 Subject: [PATCH 85/96] [SFT] fix check for AutoLigerKernelForCausalLM (#2874) * fix check for AutoLigerKernelForCausalLM * fix case where AutoLigerKernelForCausalLM is not defined * update min liger version * formatting * fix win CI --- setup.py | 2 +- trl/trainer/sft_trainer.py | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 5c3c4d5ac3..2395da3fa6 100644 --- a/setup.py +++ b/setup.py @@ -85,7 +85,7 @@ "diffusers": ["diffusers>=0.18.0"], "judges": ["openai>=1.23.2", "llm-blender>=0.0.2"], # liger-kernel depends on triton, which is only available on Linux https://github.com/triton-lang/triton#compatibility - "liger": ["liger-kernel>=0.4.0; sys_platform != 'win32'"], + "liger": ["liger-kernel>=0.5.3; sys_platform != 'win32'"], "mergekit": ["mergekit>=0.0.5.1"], "peft": ["peft>=0.8.0"], "quantization": ["bitsandbytes"], diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 1229e4c6ac..7092ea95d0 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -60,6 +60,8 @@ if is_liger_kernel_available(): from liger_kernel.transformers import AutoLigerKernelForCausalLM +else: + AutoLigerKernelForCausalLM = None if is_wandb_available(): import wandb @@ -440,7 +442,10 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N ) # Compute token accuracy if we have labels and if the model is not using Liger (no logits) - if "labels" in inputs and not self.args.use_liger: + use_liger = self.args.use_liger or ( + AutoLigerKernelForCausalLM is not None and isinstance(model, AutoLigerKernelForCausalLM) + ) + if "labels" in inputs and not use_liger: shift_logits = outputs.logits[..., :-1, :].contiguous() shift_labels = inputs["labels"][..., 1:].contiguous() From ae3bd0d07a7a54541f6cb883dd0f7378fe384f3b Mon Sep 17 00:00:00 2001 From: Edward Beeching Date: Mon, 17 Feb 2025 10:54:07 +0100 Subject: [PATCH 86/96] =?UTF-8?q?=F0=9F=86=99=20Bump=20vLLM=20min=20versio?= =?UTF-8?q?n=20to=200.7.2=20(#2860)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps vllm as there were a number of 
throughput improvements in vllm==0.7.2 Also may resolve issue such as https://github.com/huggingface/trl/issues/2851 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 2395da3fa6..819c75b6be 100644 --- a/setup.py +++ b/setup.py @@ -91,7 +91,7 @@ "quantization": ["bitsandbytes"], "scikit": ["scikit-learn"], "test": ["parameterized", "pytest-cov", "pytest-rerunfailures", "pytest-xdist", "pytest"], - "vllm": ["vllm>=0.7.1; sys_platform != 'win32'"], # vllm is not available on Windows + "vllm": ["vllm>=0.7.2; sys_platform != 'win32'"], # vllm is not available on Windows "vlm": ["Pillow"], } EXTRAS["dev"] = [] From 293b62095099561a6a2e27d2d521c3cee709a1f2 Mon Sep 17 00:00:00 2001 From: Edward Beeching Date: Mon, 17 Feb 2025 13:26:21 +0100 Subject: [PATCH 87/96] [GRPO] Fix loss normalization (#2881) * fix GRPO loss normalization * fix sum dim * fix loss= repeated --- trl/trainer/grpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 1211a453fc..6957216970 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -707,7 +707,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N advantages = inputs["advantages"] per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) per_token_loss = -(per_token_loss - self.beta * per_token_kl) - loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() + loss = (per_token_loss * completion_mask).sum() / completion_mask.sum() # Log the metrics completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item() From ba036576d4a62d91da0388b7e727f6656f4c08d7 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 17 Feb 2025 16:47:06 +0100 Subject: [PATCH 88/96] =?UTF-8?q?=F0=9F=92=AC=20Add=20`maybe=5Fconvert=5Ft?= =?UTF-8?q?o=5Fchatml`=20map=20for=20conversational=20datasets=20in=20SFT?= =?UTF-8?q?=20(#2862)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add back get_formatting_func_from_dataset * maybe_convert_to_chatml * maybe_convert_to_chatml before maybe_apply_chat_template map * remove comment * test * desc * style * Update trl/data_utils.py --------- Co-authored-by: Quentin Gallouédec Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> --- docs/source/data_utils.md | 4 +++ tests/test_data_utils.py | 46 +++++++++++++++++++++++++++ tests/test_sft_trainer.py | 30 ++++++++++++++++++ trl/__init__.py | 2 ++ trl/data_utils.py | 65 ++++++++++++++++++++++++++++++++++---- trl/trainer/sft_trainer.py | 11 ++++++- 6 files changed, 150 insertions(+), 8 deletions(-) diff --git a/docs/source/data_utils.md b/docs/source/data_utils.md index 7faafd256a..bdadd4206b 100644 --- a/docs/source/data_utils.md +++ b/docs/source/data_utils.md @@ -12,6 +12,10 @@ [[autodoc]] maybe_apply_chat_template +## maybe_convert_to_chatml + +[[autodoc]] maybe_convert_to_chatml + ## extract_prompt [[autodoc]] extract_prompt diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py index b1908fc996..20c8614f68 100644 --- a/tests/test_data_utils.py +++ b/tests/test_data_utils.py @@ -24,6 +24,7 @@ extract_prompt, is_conversational, maybe_apply_chat_template, + maybe_convert_to_chatml, maybe_extract_prompt, maybe_unpair_preference_dataset, pack_examples, @@ -435,6 +436,51 @@ def test_pack_with_dataset(self): 
self.assertEqual(dataset.to_dict(), expected_output) +class TestMaybeConvertToChatML(unittest.TestCase): + def test_with_conversations_key(self): + # Particular case where the key is "conversations": we rename it to "messages" + example = { + "conversations": [ + {"from": "user", "value": "What color is the sky?"}, + {"from": "assistant", "value": "It is blue."}, + ] + } + expected_output = { + "messages": [ + {"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is blue."}, + ] + } + self.assertEqual(maybe_convert_to_chatml(example), expected_output) + + def test_without_conversations_key(self): + # Same as before, but we don't rename the keys + example = { + "prompt": [{"from": "user", "value": "What color is the sky?"}], + "completion": [{"from": "assistant", "value": "It is blue."}], + } + expected_output = { + "prompt": [{"role": "user", "content": "What color is the sky?"}], + "completion": [{"role": "assistant", "content": "It is blue."}], + } + self.assertEqual(maybe_convert_to_chatml(example), expected_output) + + def test_not_conversional(self): + # When not needed, the example should remain unchanged + example = {"text": "The sky is blue."} + self.assertEqual(maybe_convert_to_chatml(example), example) + + def test_already_chatml(self): + # When the example is already in ChatML format, it should remain unchanged + example = { + "messages": [ + {"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is blue."}, + ] + } + self.assertEqual(maybe_convert_to_chatml(example), example) + + # Run the tests if __name__ == "__main__": unittest.main() diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 158a87586e..ecf8d44b6c 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -1370,3 +1370,33 @@ def test_train_peft_model(self): "base_layer" not in n ): # We expect the peft parameters to be different (except for the base layer) self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + def test_train_with_non_chatml_conversational_data(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + model = AutoModelForCausalLM.from_pretrained(model_id) + dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train") + + # Rename role/content to from/value to ensure SFT works with non-chatML conversational data + def rename_fields(example: list[dict]): + return {"conversations": [{"from": m["role"], "value": m["content"]} for m in example["messages"]]} + + dataset = dataset.map(rename_fields, remove_columns="messages") + + with tempfile.TemporaryDirectory() as tmp_dir: + # Initialize the trainer + training_args = SFTConfig(output_dir=tmp_dir, report_to="none") + trainer = SFTTrainer(args=training_args, model=model, train_dataset=dataset) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") diff --git a/trl/__init__.py b/trl/__init__.py index 6850dc7681..188b05056d 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -26,6 +26,7 @@ 
"extract_prompt", "is_conversational", "maybe_apply_chat_template", + "maybe_convert_to_chatml", "maybe_extract_prompt", "maybe_unpair_preference_dataset", "pack_examples", @@ -126,6 +127,7 @@ extract_prompt, is_conversational, maybe_apply_chat_template, + maybe_convert_to_chatml, maybe_extract_prompt, maybe_unpair_preference_dataset, pack_examples, diff --git a/trl/data_utils.py b/trl/data_utils.py index 1c718a6c2d..bb0891d001 100644 --- a/trl/data_utils.py +++ b/trl/data_utils.py @@ -31,7 +31,8 @@ def is_conversational(example: dict[str, Any]) -> bool: dataset type. Returns: - `bool`: `True` if the data is in a conversational format, `False` otherwise. + `bool`: + `True` if the data is in a conversational format, `False` otherwise. Examples: @@ -185,20 +186,21 @@ def maybe_apply_chat_template( For keys `"messages"`, `"prompt"`, `"chosen"`, `"rejected"`, and `"completion"`, the values are lists of messages, where each message is a dictionary with keys `"role"` and `"content"`. tokenizer (`PreTrainedTokenizer`): - The tokenizer to apply the chat template with. + Tokenizer to apply the chat template with. tools (`list[Union[dict, Callable]]` or `None`, *optional*, defaults to `None`): A list of tools (callable functions) that will be accessible to the model. If the template does not support function calling, this argument will have no effect Returns: - `dict[str, str]`: The formatted example with the chat template applied. + `dict[str, str]`: + Formatted example with the chat template applied. Notes: - - This function does not alter the keys, except for Language modeling dataset, where `"messages"` is replaced by - `"text"`. + - This function does not alter the keys, except for Language modeling dataset, where `"messages"` is replaced + by `"text"`. - - In case of prompt-only data, if the last role is `"user"`, the generation prompt is added to the prompt. Else, - if the last role is `"assistant"`, the final message is continued. + - In case of prompt-only data, if the last role is `"user"`, the generation prompt is added to the prompt. + Else, if the last role is `"assistant"`, the final message is continued. Example: @@ -462,3 +464,52 @@ def pack_examples(examples: dict[str, list[list]], seq_length: int) -> dict[str, # Split the values into chunks of size seq_length examples = {k: [v[i : i + seq_length] for i in range(0, len(v), seq_length)] for k, v in examples.items()} return examples + + +def maybe_convert_to_chatml(example: dict[str, list]) -> dict[str, list]: + """ + Convert a conversational dataset with fields `from` and `value` to ChatML format. + + This function modifies conversational data to align with OpenAI's ChatML format: + - Replaces the key `"from"` with `"role"` in message dictionaries. + - Replaces the key `"value"` with `"content"` in message dictionaries. + - Renames `"conversations"` to `"messages"` for consistency with ChatML. + + Args: + example (`dict[str, list]`): + A single data entry containing a list of messages. + + Returns: + `dict[str, list]`: + Example reformatted to ChatML style. + + Example: + ```python + >>> from trl import maybe_convert_to_chatml + >>> example = { + ... "conversations": [ + ... {"from": "user", "value": "What color is the sky?"}, + ... {"from": "assistant", "value": "It is blue."} + ... ] + ... 
} + >>> maybe_convert_to_chatml(example) + {'messages': [{'role': 'user', 'content': 'What color is the sky?'}, + {'role': 'assistant', 'content': 'It is blue.'}]} + ``` + """ + # List of possible keys containing message lists + for key in ["prompt", "completion", "chosen", "rejected", "messages", "conversations"]: + if key in example and isinstance(example[key], list): + messages = example[key] + for message in messages: + if isinstance(message, dict): + if "from" in message: + message["role"] = message.pop("from") + if "value" in message: + message["content"] = message.pop("value") + + # Rename "conversations" to "messages" + if "conversations" in example: + example["messages"] = example.pop("conversations") + + return example diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 7092ea95d0..f6b406aeb4 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -43,7 +43,7 @@ from transformers.utils import is_liger_kernel_available, is_peft_available from transformers.utils.deprecation import deprecate_kwarg -from ..data_utils import is_conversational, maybe_apply_chat_template, pack_examples +from ..data_utils import is_conversational, maybe_apply_chat_template, maybe_convert_to_chatml, pack_examples from .sft_config import SFTConfig from .utils import ( ConstantLengthDataset, @@ -397,6 +397,15 @@ def concat_prompt_completion(example): dataset = dataset.map(concat_prompt_completion, remove_columns=["prompt", "completion"]) + # Convert the dataset to ChatML if needed + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Converting {dataset_name} dataset to ChatML" + dataset = dataset.map( + maybe_convert_to_chatml, + remove_columns="conversations" if "conversations" in dataset.column_names else None, + **map_kwargs, + ) + # Apply the chat template if needed if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset" From 822653824bf084bc6c042cf0e759f86187c92569 Mon Sep 17 00:00:00 2001 From: XZ-X Date: Mon, 17 Feb 2025 14:34:07 -0500 Subject: [PATCH 89/96] =?UTF-8?q?=F0=9F=A7=B6=20[GRPO][vLLM=20+=20LoRA]=20?= =?UTF-8?q?Move=20unmerge=20of=20PEFT=20model=20after=20weight=20loading?= =?UTF-8?q?=20(#2873)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- trl/trainer/grpo_trainer.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 6957216970..3c05c9b2a3 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -495,7 +495,6 @@ def _move_model_to_vllm(self): if is_peft_model(unwrapped_model): unwrapped_model.merge_adapter() state_dict = unwrapped_model.state_dict() - unwrapped_model.unmerge_adapter() # Remove base_model and base_layer prefixes state_dict = { k.removeprefix("base_model.model.").replace(".base_layer", ""): v for k, v in state_dict.items() @@ -510,9 +509,13 @@ def _move_model_to_vllm(self): } else: state_dict = unwrapped_model.state_dict() - if self.accelerator.is_main_process: - llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model - llm_model.load_weights(state_dict.items()) + if self.accelerator.is_main_process: + llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model + llm_model.load_weights(state_dict.items()) + # Unmerge the adapter to restore the model to its original state. 
+ # This must be done after loading weights to ensure they correspond to the merged state. + if is_peft_model(unwrapped_model): + unwrapped_model.unmerge_adapter() def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]: device = self.accelerator.device From aafd8cbea59ae10767942bb053377db372b9d5a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B7=A6=E5=85=B6=E5=8F=B3?= <48852791+BenasdTW@users.noreply.github.com> Date: Tue, 18 Feb 2025 16:56:47 +0800 Subject: [PATCH 90/96] =?UTF-8?q?=F0=9F=8D=9F=20[SFT]=20Handles=20the=20da?= =?UTF-8?q?taset=20if=20it=20has=20been=20preprocessed=20(#2863)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * return dataset if it's preprocessed * add is_processed flag variable * add test * move test_sft_trainer_directly_with_pretokenized_data to Tester2 * Update sft_trainer.py * no need for padding and truncation * minor reorganization * Update trl/trainer/sft_trainer.py * let the collator pad * style * fix tests --------- Co-authored-by: Kashif Rasul Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Co-authored-by: Quentin Gallouédec --- tests/test_sft_trainer.py | 34 +++++++++++++++++++++++++++++++++- trl/trainer/sft_trainer.py | 30 +++++++++++++++++++++++++----- 2 files changed, 58 insertions(+), 6 deletions(-) diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index ecf8d44b6c..14d235585a 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -288,7 +288,7 @@ def test_sft_trainer(self): self.assertIn("model.safetensors", os.listdir(tmp_dir + "/checkpoint-2")) - def test_sft_trainer_with_pretokenzied_data_packing(self): + def test_sft_trainer_with_pretokenized_data_packing(self): with tempfile.TemporaryDirectory() as tmp_dir: training_args = SFTConfig( output_dir=tmp_dir, @@ -1400,3 +1400,35 @@ def rename_fields(example: list[dict]): for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + def test_sft_trainer_with_pretokenized_data(self): + # Get the model and dataset + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + model = AutoModelForCausalLM.from_pretrained(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id) + dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") + + def tokenize_example(example): + return tokenizer(example["text"]) + + # Apply tokenization + tokenized_dataset = dataset.map(tokenize_example, remove_columns=["text"]) + + with tempfile.TemporaryDirectory() as tmp_dir: + # Initialize the trainer + training_args = SFTConfig(output_dir=tmp_dir, report_to="none") + trainer = SFTTrainer(args=training_args, model=model, train_dataset=tokenized_dataset) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index f6b406aeb4..510ceb5348 100644 --- 
a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -109,6 +109,8 @@ class SFTTrainer(Trainer): - [Standard](dataset_formats#standard): Each sample contains plain text. - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role and content). + + The trainer also supports processed datasets (tokenized) as long as they contain an `input_ids` field. eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`): Dataset to use for evaluation. It must meet the same requirements as `train_dataset`. processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`): @@ -370,6 +372,10 @@ def _prepare_dataset( if isinstance(dataset, ConstantLengthDataset): return dataset + # If the dataset is already preprocessed (tokenized), skip the processing steps. + column_names = list(next(iter(dataset)).keys()) + is_processed = "input_ids" in column_names + # Build the kwargs for the `map` function map_kwargs = {} if isinstance(dataset, Dataset): # IterableDataset does not support num_proc @@ -377,7 +383,15 @@ def _prepare_dataset( with PartialState().local_main_process_first(): # Apply the formatting function if any - if formatting_func is not None: + if formatting_func is not None and is_processed: + warnings.warn( + "You passed a dataset that is already processed (contains an `input_ids` field) together with a " + "formatting function. Therefore `formatting_func` will be ignored. Either remove the " + "`formatting_func` or pass a dataset that is not already processed.", + UserWarning, + ) + + if formatting_func is not None and not is_processed: if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` map_kwargs["desc"] = f"Applying formatting function to {dataset_name} dataset" @@ -416,10 +430,16 @@ def concat_prompt_completion(example): **map_kwargs, ) - # Tokenize the dataset - if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` - map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset" - dataset = dataset.map(lambda ex: processing_class(ex[args.dataset_text_field]), **map_kwargs) + # Tokenize the dataset if needed + if not is_processed: + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset" + + def tokenize(ex): + tokenized = processing_class(ex[args.dataset_text_field]) + return {"input_ids": tokenized["input_ids"], "attention_mask": tokenized["attention_mask"]} + + dataset = dataset.map(tokenize, **map_kwargs) # Pack or truncate if packing: From 963243a7d1268ce98fc9f65939b5aab885c331f7 Mon Sep 17 00:00:00 2001 From: Edward Beeching Date: Tue, 18 Feb 2025 11:44:15 +0100 Subject: [PATCH 91/96] Optimize vllm num_generations (#2855) * small optimization of vllm batching * style * adds comment * style --- trl/trainer/grpo_trainer.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 3c05c9b2a3..028021570d 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -415,6 +415,7 @@ def data_collator(features): # No data collation is needed in GRPO self.sampling_params = SamplingParams( temperature=args.temperature, max_tokens=self.max_completion_length, + n=args.num_generations, ) self._last_loaded_step = 0 # tag to avoid useless loading during grad accumulation @@ -541,8 +542,17 @@ def _prepare_inputs(self, inputs: 
dict[str, Union[torch.Tensor, Any]]) -> dict[s # Generate completions using vLLM: gather all prompts and use them in a single call in the main process all_prompts_text = gather_object(prompts_text) if self.accelerator.is_main_process: - outputs = self.llm.generate(all_prompts_text, sampling_params=self.sampling_params, use_tqdm=False) - completion_ids = [out.token_ids for completions in outputs for out in completions.outputs] + # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate + # num_generations outputs for each one. This is faster than generating outputs for each duplicate + # prompt individually. + ordered_set_of_prompts = list(dict.fromkeys(all_prompts_text)) + all_outputs = self.llm.generate( + ordered_set_of_prompts, sampling_params=self.sampling_params, use_tqdm=False + ) + completion_ids = [] + for outputs in all_outputs: + for output in outputs.outputs: + completion_ids.append(output.token_ids) else: completion_ids = [None] * len(all_prompts_text) # Broadcast the completions from the main process to all processes, ensuring each process receives its From 6c54f023ae7ccecc9b755dc8e994bef5c42ecadf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Tue, 18 Feb 2025 15:31:08 +0100 Subject: [PATCH 92/96] =?UTF-8?q?=F0=9F=AA=82=20Don't=20gather=20logits=20?= =?UTF-8?q?in=20SFT=20to=20avoid=20hanging=20(#2890)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Don't gather logits * Remove unused function and test --- tests/test_utils.py | 55 -------------------------------------- trl/__init__.py | 4 +-- trl/trainer/__init__.py | 2 -- trl/trainer/sft_trainer.py | 48 +++++++++++++++++---------------- trl/trainer/utils.py | 21 --------------- 5 files changed, 27 insertions(+), 103 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 64ab68bf74..871ec6c737 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -27,7 +27,6 @@ from trl.trainer.utils import ( DataCollatorForChatML, batch_generation, - compute_token_accuracy, decode_and_strip_padding, flush_left, generate_model_card, @@ -456,60 +455,6 @@ def test_no_tensors(self): self.assertTrue(torch.equal(new_mask, expected_mask)) -class TestComputeTokenAccuracy(unittest.TestCase): - def test_basic_accuracy(self): - # Test basic accuracy computation - logits = torch.tensor([[[0.9, 0.1], [0.8, 0.2]], [[0.3, 0.7], [0.6, 0.4]]]) # Shape: [2, 2, 2] - labels = torch.tensor([[1, 0], [1, 0]]) # Shape: [2, 2] - accuracy = compute_token_accuracy(logits, labels) - self.assertAlmostEqual(accuracy, 0.75) # 3 correct out of 4 tokens - - def test_with_ignore_index(self): - # Test accuracy computation with ignored tokens - logits = torch.tensor([[[0.9, 0.1], [0.8, 0.2]], [[0.3, 0.7], [0.6, 0.4]]]) - labels = torch.tensor([[1, -100], [1, 0]]) # -100 is ignored - accuracy = compute_token_accuracy(logits, labels, ignore_index=-100) - self.assertAlmostEqual(accuracy, 2 / 3) # 2 correct out of 3 non-ignored tokens - - def test_all_ignored(self): - # Test case where all tokens are ignored - logits = torch.tensor([[[0.1, 0.9], [0.8, 0.2]]]) - labels = torch.tensor([[-100, -100]]) - accuracy = compute_token_accuracy(logits, labels) - self.assertEqual(accuracy, 0.0) # No valid tokens to compute accuracy - - def test_perfect_accuracy(self): - # Test case with 100% accuracy - logits = torch.tensor([[[0.1, 0.9], [0.8, 0.2]]]) - labels = torch.tensor([[1, 0]]) - accuracy = 
compute_token_accuracy(logits, labels) - self.assertEqual(accuracy, 1.0) # All predictions correct - - def test_zero_accuracy(self): - # Test case with 0% accuracy - logits = torch.tensor([[[0.1, 0.9], [0.8, 0.2]]]) - labels = torch.tensor([[0, 1]]) - accuracy = compute_token_accuracy(logits, labels) - self.assertEqual(accuracy, 0.0) # All predictions wrong - - def test_batch_accuracy(self): - # Test accuracy computation across multiple batches - logits = torch.tensor( - [ - [[0.9, 0.1], [0.8, 0.2], [0.3, 0.7]], # Batch 1 - [[0.2, 0.8], [0.7, 0.3], [0.6, 0.4]], # Batch 2 - ] - ) - labels = torch.tensor( - [ - [1, 0, 1], # Batch 1 - [1, 0, -100], # Batch 2 (last token ignored) - ] - ) - accuracy = compute_token_accuracy(logits, labels) - self.assertAlmostEqual(accuracy, 0.8) - - class TestSelectiveLogSoftmax(unittest.TestCase): @parameterized.expand([(torch.float64,), (torch.float32,), (torch.float16,), (torch.bfloat16,)]) def test_selective_log_softmax(self, dtype): diff --git a/trl/__init__.py b/trl/__init__.py index 188b05056d..9a17e8e873 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -102,7 +102,7 @@ "XPOTrainer", ], "trainer.callbacks": ["MergeModelCallback", "RichProgressCallback", "SyncRefModelCallback"], - "trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config", "compute_token_accuracy"], + "trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config"], } try: @@ -204,7 +204,7 @@ XPOTrainer, ) from .trainer.callbacks import RichProgressCallback, SyncRefModelCallback - from .trainer.utils import compute_token_accuracy, get_kbit_device_map, get_peft_config, get_quantization_config + from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config try: if not is_diffusers_available(): diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py index 9ef887864a..85968218cc 100644 --- a/trl/trainer/__init__.py +++ b/trl/trainer/__init__.py @@ -76,7 +76,6 @@ "disable_dropout_in_model", "empty_cache", "peft_module_casting_to_bf16", - "compute_token_accuracy", ], "xpo_config": ["XPOConfig"], "xpo_trainer": ["XPOTrainer"], @@ -145,7 +144,6 @@ DataCollatorForCompletionOnlyLM, RunningMoments, compute_accuracy, - compute_token_accuracy, disable_dropout_in_model, empty_cache, peft_module_casting_to_bf16, diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 510ceb5348..a3d6829fd4 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -45,13 +45,7 @@ from ..data_utils import is_conversational, maybe_apply_chat_template, maybe_convert_to_chatml, pack_examples from .sft_config import SFTConfig -from .utils import ( - ConstantLengthDataset, - compute_token_accuracy, - generate_model_card, - get_comet_experiment_url, - peft_module_casting_to_bf16, -) +from .utils import ConstantLengthDataset, generate_model_card, get_comet_experiment_url, peft_module_casting_to_bf16 if is_peft_available(): @@ -60,8 +54,6 @@ if is_liger_kernel_available(): from liger_kernel.transformers import AutoLigerKernelForCausalLM -else: - AutoLigerKernelForCausalLM = None if is_wandb_available(): import wandb @@ -185,12 +177,13 @@ def __init__( ) if isinstance(model, str): model = self._create_model_from_path(model, args) + self.use_liger = is_liger_kernel_available() and isinstance(model, AutoLigerKernelForCausalLM) # PEFT configuration and model wrapping if peft_config is not None: model = self._prepare_peft_model(model, peft_config, args) - # 3. 
Handle the tokenizer + # Handle the tokenizer if processing_class is None: processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path) if processing_class.pad_token is None: @@ -279,8 +272,10 @@ def _create_model_from_path(self, model_path: str, args: SFTConfig) -> PreTraine if args.use_liger: if not is_liger_kernel_available(): raise ImportError("Please install Liger-kernel for use_liger=True") - return AutoLigerKernelForCausalLM.from_pretrained(model_path, **model_init_kwargs) - return AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs) + model = AutoLigerKernelForCausalLM.from_pretrained(model_path, **model_init_kwargs) + else: + model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs) + return model def _prepare_peft_model(self, model: PreTrainedModel, peft_config: Any, args: SFTConfig) -> PreTrainedModel: """Prepares a model for PEFT training.""" @@ -471,21 +466,28 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N ) # Compute token accuracy if we have labels and if the model is not using Liger (no logits) - use_liger = self.args.use_liger or ( - AutoLigerKernelForCausalLM is not None and isinstance(model, AutoLigerKernelForCausalLM) - ) - if "labels" in inputs and not use_liger: + if "labels" in inputs and not self.use_liger: shift_logits = outputs.logits[..., :-1, :].contiguous() shift_labels = inputs["labels"][..., 1:].contiguous() - # Gather logits and labels from all GPUs first - shift_logits = self.accelerator.gather_for_metrics(shift_logits) - shift_labels = self.accelerator.gather_for_metrics(shift_labels) + # Get predictions + predictions = shift_logits.argmax(dim=-1) + + # Create mask for non-padding tokens (assuming ignore_index is -100) + mask = shift_labels != -100 + + # Calculate accuracy only on non-padding tokens + correct_predictions = (predictions == shift_labels) & mask + total_tokens = mask.sum() + correct_tokens = correct_predictions.sum() + + # Gather the correct_tokens and total_tokens across all processes + correct_tokens = self.accelerator.gather_for_metrics(correct_tokens) + total_tokens = self.accelerator.gather_for_metrics(total_tokens) - # Then compute accuracy on the gathered tensors - if self.accelerator.is_main_process: - accuracy = compute_token_accuracy(shift_logits, shift_labels) - self._metrics["mean_token_accuracy"].append(accuracy) + # Compute the mean token accuracy and log it + accuracy = (correct_tokens.sum() / total_tokens.sum()).item() if total_tokens.sum() > 0 else 0.0 + self._metrics["mean_token_accuracy"].append(accuracy) return (loss, outputs) if return_outputs else loss diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index ea603b9637..7a20645535 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -1650,27 +1650,6 @@ def flush_left(mask: torch.Tensor, *tensors: torch.Tensor) -> tuple[torch.Tensor return mask, *tensors -def compute_token_accuracy(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> float: - """ - Compute the mean token accuracy. 
- """ - # Get predictions - predictions = logits.argmax(dim=-1) - - # Create mask for non-padding tokens (assuming pad_token_id is ignore_index) - mask = labels != ignore_index - - # Calculate accuracy only on non-padding tokens - correct_predictions = (predictions == labels) & mask - total_tokens = mask.sum() - correct_tokens = correct_predictions.sum() - - # Calculate accuracy - accuracy = correct_tokens.item() / total_tokens.item() if total_tokens > 0 else 0.0 - - return accuracy - - def selective_log_softmax(logits, index): """ A memory-efficient implementation of the common `log_softmax -> gather` operation. From 49adf74833d71feff22a9a21e0d5ed65b909c257 Mon Sep 17 00:00:00 2001 From: Nikolai Kolodziej <7687617+kldzj@users.noreply.github.com> Date: Tue, 18 Feb 2025 16:53:05 +0100 Subject: [PATCH 93/96] =?UTF-8?q?=E2=9C=A8=20Add=20vLLM=20guided=20decodin?= =?UTF-8?q?g=20support=20to=20GRPO=20Trainer=20(#2811)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ✨ Add vLLM guided decoding support to GRPO Trainer * 🔧 Update vLLM guided decoding in GRPO to use regex parameter * style and docstring * test --------- Co-authored-by: Quentin Gallouédec Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> --- tests/test_grpo_trainer.py | 36 ++++++++++++++++++++++++++++++++++++ trl/trainer/grpo_config.py | 6 ++++++ trl/trainer/grpo_trainer.py | 10 ++++++++++ 3 files changed, 52 insertions(+) diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index 8685ac99a4..a891793550 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -548,3 +548,39 @@ def test_training_vllm_and_peft(self): elif "base_layer" not in n and "original_module" not in n: # We expect the peft params to be different (except for the base layer) self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed.") + + @unittest.skipIf(not is_vllm_available(), "vLLM is not available") + @require_torch_accelerator + def test_training_vllm_guided_decoding(self): + """Test that training works with vLLM for generation with guided decoding.""" + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=32, # reduce the completion length to reduce memory usage + report_to="none", + use_vllm=True, + vllm_device="cuda:0", # will raise a warning, but allows this test to work with only one GPU + vllm_guided_decoding_regex=r"\n.*\n\n\n.*\n", + ) + trainer = GRPOTrainer( + model="Qwen/Qwen2.5-0.5B-Instruct", # tiny model is too small for vLLM + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") diff --git a/trl/trainer/grpo_config.py 
b/trl/trainer/grpo_config.py index cd6cc91748..02a02dc788 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -79,6 +79,8 @@ class GRPOConfig(TrainingArguments): If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced `vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model context size, which might be much larger than the KV cache, leading to inefficiencies. + vllm_guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`): + Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled. > Parameters that control the training @@ -201,6 +203,10 @@ class GRPOConfig(TrainingArguments): "context size, which might be much larger than the KV cache, leading to inefficiencies." }, ) + vllm_guided_decoding_regex: Optional[str] = field( + default=None, + metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."}, + ) # Parameters that control the training learning_rate: float = field( diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 028021570d..93993e082a 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -55,6 +55,7 @@ if is_vllm_available(): from vllm import LLM, SamplingParams + from vllm.sampling_params import GuidedDecodingParams if is_wandb_available(): import wandb @@ -412,9 +413,18 @@ def data_collator(features): # No data collation is needed in GRPO enable_prefix_caching=True, max_model_len=self.args.vllm_max_model_len, ) + + # Guided decoding, if enabled + if args.vllm_guided_decoding_regex is not None: + guided_decoding = GuidedDecodingParams(backend="outlines", regex=args.vllm_guided_decoding_regex) + else: + guided_decoding = None + + # Sampling parameters self.sampling_params = SamplingParams( temperature=args.temperature, max_tokens=self.max_completion_length, + guided_decoding=guided_decoding, n=args.num_generations, ) From 6aaf379a82da43550e5296143673a7edcdb822ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Tue, 18 Feb 2025 16:53:21 +0100 Subject: [PATCH 94/96] =?UTF-8?q?=E2=9A=B0=EF=B8=8F=20Remove=20deprecated?= =?UTF-8?q?=20(#2894)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- trl/models/utils.py | 2 -- trl/trainer/dpo_trainer.py | 4 ---- trl/trainer/sft_trainer.py | 4 ---- 3 files changed, 10 deletions(-) diff --git a/trl/models/utils.py b/trl/models/utils.py index 34c9b3c037..0632e025f4 100644 --- a/trl/models/utils.py +++ b/trl/models/utils.py @@ -20,7 +20,6 @@ from accelerate.utils import is_deepspeed_available from transformers import PreTrainedModel, PreTrainedTokenizer -from transformers.utils.deprecation import deprecate_kwarg from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead @@ -175,7 +174,6 @@ def add_hooks(model: "DeepSpeedEngine") -> None: @contextmanager -@deprecate_kwarg("is_peft_model", "0.16.0", warn_if_greater_or_equal_version=True) def unwrap_model_for_generation( model: Union["DistributedDataParallel", "DeepSpeedEngine"], accelerator: "Accelerator", diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index a16edb6f37..0346d991fc 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -51,7 +51,6 @@ from transformers.trainer_callback import TrainerCallback from transformers.trainer_utils import EvalLoopOutput from 
transformers.utils import is_peft_available, is_torch_xpu_available -from transformers.utils.deprecation import deprecate_kwarg from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt from ..models import PreTrainedModelWrapper, create_reference_model @@ -202,9 +201,6 @@ class DPOTrainer(Trainer): _tag_names = ["trl", "dpo"] - @deprecate_kwarg( - "tokenizer", "0.16.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True - ) def __init__( self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index a3d6829fd4..e4708eb7c7 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -41,7 +41,6 @@ from transformers.trainer_callback import TrainerCallback from transformers.trainer_utils import EvalPrediction from transformers.utils import is_liger_kernel_available, is_peft_available -from transformers.utils.deprecation import deprecate_kwarg from ..data_utils import is_conversational, maybe_apply_chat_template, maybe_convert_to_chatml, pack_examples from .sft_config import SFTConfig @@ -136,9 +135,6 @@ class SFTTrainer(Trainer): _tag_names = ["trl", "sft"] - @deprecate_kwarg( - "tokenizer", "0.16.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True - ) def __init__( self, model: Union[str, nn.Module, PreTrainedModel], From be1e34003ccd84ff39afec6d2d4bc13c9a142f82 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Tue, 18 Feb 2025 16:53:37 +0100 Subject: [PATCH 95/96] =?UTF-8?q?=F0=9F=A9=B3=20`max=5Fseq=5Flength`=20to?= =?UTF-8?q?=20`max=5Flength`=20(#2895)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * `max_seq_length` to `max_length` * remove in 0.20 --- commands/run_sft.sh | 2 +- docs/source/reducing_memory_usage.md | 6 ++-- docs/source/sft_trainer.md | 12 +++---- .../stack_llama_2/scripts/sft_llama2.py | 2 +- tests/slow/test_sft_slow.py | 22 ++++++------ tests/test_sft_trainer.py | 30 ++++++++-------- tests/test_trainers_args.py | 4 +-- trl/trainer/gkd_trainer.py | 2 +- trl/trainer/sft_config.py | 36 +++++++++++-------- trl/trainer/sft_trainer.py | 10 +++--- trl/trainer/utils.py | 6 ++-- 11 files changed, 70 insertions(+), 62 deletions(-) diff --git a/commands/run_sft.sh b/commands/run_sft.sh index bdea77fcb6..b7beaaf7fd 100644 --- a/commands/run_sft.sh +++ b/commands/run_sft.sh @@ -42,7 +42,7 @@ accelerate launch $EXTRA_ACCELERATE_ARGS \ --output_dir $OUTPUT_DIR \ --max_steps $MAX_STEPS \ --per_device_train_batch_size $BATCH_SIZE \ - --max_seq_length $SEQ_LEN \ + --max_length $SEQ_LEN \ $EXTRA_TRAINING_ARGS """ diff --git a/docs/source/reducing_memory_usage.md b/docs/source/reducing_memory_usage.md index 6c05490616..cc335156e6 100644 --- a/docs/source/reducing_memory_usage.md +++ b/docs/source/reducing_memory_usage.md @@ -44,7 +44,7 @@ training_args = DPOConfig(..., max_completion_length=...) -SFT truncation is applied to the input sequence via the `max_seq_length` parameter. +SFT truncation is applied to the input sequence via the `max_length` parameter.
Truncation input ids @@ -55,7 +55,7 @@ To set the truncation parameter, use the following code snippet: ```python from trl import SFTConfig -training_args = SFTConfig(..., max_seq_length=...) +training_args = SFTConfig(..., max_length=...) ``` @@ -85,7 +85,7 @@ Packing eliminates padding, preserves all sequence information, and allows for f ```python from trl import SFTConfig -training_args = SFTConfig(..., packing=True, max_seq_length=512) +training_args = SFTConfig(..., packing=True, max_length=512) ``` diff --git a/docs/source/sft_trainer.md b/docs/source/sft_trainer.md index ab3e9e1cc5..5c30b744fa 100644 --- a/docs/source/sft_trainer.md +++ b/docs/source/sft_trainer.md @@ -19,7 +19,7 @@ from trl import SFTConfig, SFTTrainer dataset = load_dataset("stanfordnlp/imdb", split="train") training_args = SFTConfig( - max_seq_length=512, + max_length=512, output_dir="/tmp", ) trainer = SFTTrainer( @@ -29,7 +29,7 @@ trainer = SFTTrainer( ) trainer.train() ``` -Make sure to pass the correct value for `max_seq_length` as the default value will be set to `min(tokenizer.model_max_length, 1024)`. +Make sure to pass the correct value for `max_length` as the default value will be set to `min(tokenizer.model_max_length, 1024)`. You can also construct a model outside of the trainer and pass it as follows: @@ -550,12 +550,12 @@ import torch from trl import SFTConfig, SFTTrainer from unsloth import FastLanguageModel -max_seq_length = 2048 # Supports automatic RoPE Scaling, so choose any number +max_length = 2048 # Supports automatic RoPE Scaling, so choose any number # Load model model, tokenizer = FastLanguageModel.from_pretrained( model_name="unsloth/mistral-7b", - max_seq_length=max_seq_length, + max_seq_length=max_length, dtype=None, # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+ load_in_4bit=True, # Use 4bit quantization to reduce memory usage. Can be False # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf @@ -581,7 +581,7 @@ model = FastLanguageModel.get_peft_model( random_state=3407, ) -training_args = SFTConfig(output_dir="./output", max_seq_length=max_seq_length) +training_args = SFTConfig(output_dir="./output", max_length=max_length) trainer = SFTTrainer( model=model, @@ -624,7 +624,7 @@ To learn more about Liger-Kernel, visit their [official repository](https://gith Pay attention to the following best practices when training a model with that trainer: -- [`SFTTrainer`] always truncates by default the sequences to the `max_seq_length` argument of the [`SFTConfig`]. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide a default value, so there is a check to retrieve the minimum between 1024 and that value. Make sure to check it before training. +- [`SFTTrainer`] always truncates by default the sequences to the `max_length` argument of the [`SFTConfig`]. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide a default value, so there is a check to retrieve the minimum between 1024 and that value. Make sure to check it before training. - For training adapters in 8bit, you might need to tweak the arguments of the `prepare_model_for_kbit_training` method from PEFT, hence we advise users to use `prepare_in_int8_kwargs` field, or create the `PeftModel` outside the [`SFTTrainer`] and pass it. 
- For a more memory-efficient training using adapters, you can load the base model in 8bit, for that simply add `load_in_8bit` argument when creating the [`SFTTrainer`], or create a base model in 8bit outside the trainer and pass it. - If you create a model outside the trainer, make sure to not pass to the trainer any additional keyword arguments that are relative to `from_pretrained()` method. diff --git a/examples/research_projects/stack_llama_2/scripts/sft_llama2.py b/examples/research_projects/stack_llama_2/scripts/sft_llama2.py index 3ae1e82c2a..1f4611a3e8 100644 --- a/examples/research_projects/stack_llama_2/scripts/sft_llama2.py +++ b/examples/research_projects/stack_llama_2/scripts/sft_llama2.py @@ -185,7 +185,7 @@ def create_datasets(tokenizer, args, seed=None): train_dataset=train_dataset, eval_dataset=eval_dataset, peft_config=peft_config, - max_seq_length=None, + max_length=None, formatting_func=prepare_sample_text, processing_class=tokenizer, args=training_args, diff --git a/tests/slow/test_sft_slow.py b/tests/slow/test_sft_slow.py index 74811d092c..8a772a48aa 100644 --- a/tests/slow/test_sft_slow.py +++ b/tests/slow/test_sft_slow.py @@ -46,7 +46,7 @@ class SFTTrainerSlowTester(unittest.TestCase): def setUp(self): self.train_dataset = load_dataset("stanfordnlp/imdb", split="train[:10%]") self.eval_dataset = load_dataset("stanfordnlp/imdb", split="test[:10%]") - self.max_seq_length = 128 + self.max_length = 128 self.peft_config = LoraConfig( lora_alpha=16, lora_dropout=0.1, @@ -74,7 +74,7 @@ def test_sft_trainer_str(self, model_name, packing): per_device_train_batch_size=2, max_steps=10, packing=packing, - max_seq_length=self.max_seq_length, + max_length=self.max_length, ) trainer = SFTTrainer( @@ -100,7 +100,7 @@ def test_sft_trainer_transformers(self, model_name, packing): per_device_train_batch_size=2, max_steps=10, packing=packing, - max_seq_length=self.max_seq_length, + max_length=self.max_length, ) model = AutoModelForCausalLM.from_pretrained(model_name) @@ -135,7 +135,7 @@ def test_sft_trainer_peft(self, model_name, packing): max_steps=10, fp16=True, packing=packing, - max_seq_length=self.max_seq_length, + max_length=self.max_length, ) model = AutoModelForCausalLM.from_pretrained(model_name) @@ -172,7 +172,7 @@ def test_sft_trainer_transformers_mp(self, model_name, packing): max_steps=10, fp16=True, # this is sufficient to enable amp packing=packing, - max_seq_length=self.max_seq_length, + max_length=self.max_length, ) model = AutoModelForCausalLM.from_pretrained(model_name) @@ -205,7 +205,7 @@ def test_sft_trainer_transformers_mp_gc(self, model_name, packing, gradient_chec per_device_train_batch_size=2, max_steps=10, packing=packing, - max_seq_length=self.max_seq_length, + max_length=self.max_length, fp16=True, # this is sufficient to enable amp gradient_checkpointing=True, gradient_checkpointing_kwargs=gradient_checkpointing_kwargs, @@ -242,7 +242,7 @@ def test_sft_trainer_transformers_mp_gc_peft(self, model_name, packing, gradient per_device_train_batch_size=2, max_steps=10, packing=packing, - max_seq_length=self.max_seq_length, + max_length=self.max_length, fp16=True, # this is sufficient to enable amp gradient_checkpointing=True, gradient_checkpointing_kwargs=gradient_checkpointing_kwargs, @@ -286,7 +286,7 @@ def test_sft_trainer_transformers_mp_gc_device_map( per_device_train_batch_size=2, max_steps=10, packing=packing, - max_seq_length=self.max_seq_length, + max_length=self.max_length, fp16=True, # this is sufficient to enable amp gradient_checkpointing=True, 
gradient_checkpointing_kwargs=gradient_checkpointing_kwargs, @@ -324,7 +324,7 @@ def test_sft_trainer_transformers_mp_gc_peft_qlora(self, model_name, packing, gr per_device_train_batch_size=2, max_steps=10, packing=packing, - max_seq_length=self.max_seq_length, + max_length=self.max_length, fp16=True, # this is sufficient to enable amp gradient_checkpointing=True, gradient_checkpointing_kwargs=gradient_checkpointing_kwargs, @@ -364,7 +364,7 @@ def test_sft_trainer_with_chat_format_qlora(self, model_name, packing): training_args = SFTConfig( packing=packing, - max_seq_length=self.max_seq_length, + max_length=self.max_length, output_dir=tmp_dir, logging_strategy="no", report_to="none", @@ -411,7 +411,7 @@ def test_sft_trainer_with_liger(self, model_name, packing): per_device_train_batch_size=2, max_steps=2, packing=packing, - max_seq_length=self.max_seq_length, + max_length=self.max_length, use_liger=True, ) diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 14d235585a..1a26378f3f 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -326,7 +326,7 @@ def test_sft_trainer_uncorrect_data(self): eval_steps=1, save_steps=1, per_device_train_batch_size=2, - max_seq_length=32, # make sure there is at least 1 packed sequence + max_length=32, # make sure there is at least 1 packed sequence packing=True, report_to="none", ) @@ -353,7 +353,7 @@ def test_sft_trainer_uncorrect_data(self): train_dataset=self.conversational_lm_dataset["train"], ) - # Same, but with packing with `max_seq_length` + # Same, but with packing with `max_length` training_args = SFTConfig( output_dir=tmp_dir, dataloader_drop_last=True, @@ -361,7 +361,7 @@ def test_sft_trainer_uncorrect_data(self): eval_steps=1, save_steps=1, per_device_train_batch_size=2, - max_seq_length=16, # make sure there is at least 1 packed sequence + max_length=16, # make sure there is at least 1 packed sequence packing=True, report_to="none", ) @@ -396,7 +396,7 @@ def test_sft_trainer_uncorrect_data(self): eval_steps=1, save_steps=1, per_device_train_batch_size=2, - max_seq_length=32, # make sure there is at least 1 packed sequence + max_length=32, # make sure there is at least 1 packed sequence packing=True, report_to="none", ) @@ -461,7 +461,7 @@ def test_sft_trainer_with_model_num_train_epochs(self): save_steps=1, num_train_epochs=2, per_device_train_batch_size=2, - max_seq_length=16, + max_length=16, packing=True, report_to="none", ) @@ -485,7 +485,7 @@ def test_sft_trainer_with_model_num_train_epochs(self): save_steps=1, num_train_epochs=2, per_device_train_batch_size=2, - max_seq_length=16, + max_length=16, report_to="none", ) trainer = SFTTrainer( @@ -534,7 +534,7 @@ def test_sft_trainer_with_model(self): max_steps=2, save_steps=1, per_device_train_batch_size=2, - max_seq_length=16, + max_length=16, packing=True, report_to="none", ) @@ -558,7 +558,7 @@ def test_sft_trainer_with_model(self): max_steps=2, save_steps=1, per_device_train_batch_size=2, - max_seq_length=16, + max_length=16, packing=True, report_to="none", ) @@ -583,7 +583,7 @@ def test_sft_trainer_with_model(self): max_steps=2, save_steps=1, per_device_train_batch_size=2, - max_seq_length=16, + max_length=16, report_to="none", ) trainer = SFTTrainer( @@ -606,7 +606,7 @@ def test_sft_trainer_with_model(self): max_steps=2, save_steps=1, per_device_train_batch_size=2, - max_seq_length=16, + max_length=16, report_to="none", ) trainer = SFTTrainer( @@ -755,7 +755,7 @@ def test_sft_trainer_infinite_with_model(self): save_steps=1, 
per_device_train_batch_size=2, packing=True, - max_seq_length=500, + max_length=500, report_to="none", ) trainer = SFTTrainer( @@ -782,7 +782,7 @@ def test_sft_trainer_infinite_with_model_epochs(self): per_device_train_batch_size=2, save_strategy="epoch", packing=True, - max_seq_length=500, + max_length=500, report_to="none", ) trainer = SFTTrainer( @@ -1088,7 +1088,7 @@ def test_sft_trainer_only_train_packing(self): per_device_train_batch_size=2, gradient_checkpointing=True, packing=True, - max_seq_length=16, # make sure there is at least 1 packed sequence + max_length=16, # make sure there is at least 1 packed sequence eval_packing=False, report_to="none", ) @@ -1114,7 +1114,7 @@ def test_sft_trainer_eval_packing(self): save_steps=2, per_device_train_batch_size=2, gradient_checkpointing=True, - max_seq_length=16, # make sure there is at least 1 packed sequence + max_length=16, # make sure there is at least 1 packed sequence packing=True, report_to="none", ) @@ -1139,7 +1139,7 @@ def test_sft_trainer_no_packing(self): save_steps=2, per_device_train_batch_size=2, gradient_checkpointing=True, - max_seq_length=16, # make sure there is at least 1 packed sequence + max_length=16, # make sure there is at least 1 packed sequence packing=False, report_to="none", ) diff --git a/tests/test_trainers_args.py b/tests/test_trainers_args.py index 251b1f5a96..406eba4f86 100644 --- a/tests/test_trainers_args.py +++ b/tests/test_trainers_args.py @@ -368,7 +368,7 @@ def test_sft(self): tmp_dir, dataset_text_field="dummy_text_field", packing=True, - max_seq_length=256, + max_length=256, dataset_num_proc=4, dataset_batch_size=512, neftune_noise_alpha=0.1, @@ -379,7 +379,7 @@ def test_sft(self): trainer = SFTTrainer(model_id, args=training_args, train_dataset=dataset) self.assertEqual(trainer.args.dataset_text_field, "dummy_text_field") self.assertEqual(trainer.args.packing, True) - self.assertEqual(trainer.args.max_seq_length, 256) + self.assertEqual(trainer.args.max_length, 256) self.assertEqual(trainer.args.dataset_num_proc, 4) self.assertEqual(trainer.args.dataset_batch_size, 512) self.assertEqual(trainer.args.neftune_noise_alpha, 0.1) diff --git a/trl/trainer/gkd_trainer.py b/trl/trainer/gkd_trainer.py index 5e76f30ed2..7cfad453f7 100644 --- a/trl/trainer/gkd_trainer.py +++ b/trl/trainer/gkd_trainer.py @@ -87,7 +87,7 @@ def __init__( ): # add remove_unused_columns=False to the dataclass args args.remove_unused_columns = False - data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_seq_length) + data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_length) super().__init__( model, diff --git a/trl/trainer/sft_config.py b/trl/trainer/sft_config.py index ad0e936c18..23b617dfe2 100644 --- a/trl/trainer/sft_config.py +++ b/trl/trainer/sft_config.py @@ -49,13 +49,11 @@ class SFTConfig(TrainingArguments): `skip_prepare_dataset`. dataset_num_proc (`int` or `None`, *optional*, defaults to `None`): Number of processes to use for processing the dataset. - max_seq_length (`int` or `None`, *optional*, defaults to `1024`): - Maximum length of the tokenized sequence. Sequences longer than `max_seq_length` are truncated from the - right. + max_length (`int` or `None`, *optional*, defaults to `1024`): + Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from the right. If `None`, no truncation is applied. When packing is enabled, this value sets the sequence length. 
packing (`bool`, *optional*, defaults to `False`): - Whether to pack multiple sequences into a fixed-length format. Uses `max_seq_length` to define sequence - length. + Whether to pack multiple sequences into a fixed-length format. Uses `max_length` to define sequence length. eval_packing (`bool` or `None`, *optional*, defaults to `None`): Whether to pack the eval dataset. If `None`, uses the same value as `packing`. @@ -95,19 +93,19 @@ class SFTConfig(TrainingArguments): default=None, metadata={"help": "Number of processes to use for processing the dataset."}, ) - max_seq_length: Optional[int] = field( + max_length: Optional[int] = field( default=1024, metadata={ - "help": "Maximum length of the tokenized sequence. Sequences longer than `max_seq_length` are truncated " - "from the right. If `None`, no truncation is applied. When packing is enabled, this value sets the " + "help": "Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from" + "the right. If `None`, no truncation is applied. When packing is enabled, this value sets the " "sequence length." }, ) packing: bool = field( default=False, metadata={ - "help": "Whether to pack multiple sequences into a fixed-length format. Uses `max_seq_length` to " - "define sequence length." + "help": "Whether to pack multiple sequences into a fixed-length format. Uses `max_length` to define " + "sequence length." }, ) eval_packing: Optional[bool] = field( @@ -132,13 +130,17 @@ class SFTConfig(TrainingArguments): num_of_sequences: int = field( default=None, metadata={ - "help": "Deprecated. Use `max_seq_length` instead, which specifies the maximum length of the tokenized " + "help": "Deprecated. Use `max_length` instead, which specifies the maximum length of the tokenized " "sequence, unlike `num_of_sequences`, which referred to string sequences." }, ) chars_per_token: float = field( default=None, - metadata={"help": "Deprecated. If you want to customize the packing length, use `max_seq_length`."}, + metadata={"help": "Deprecated. If you want to customize the packing length, use `max_length`."}, + ) + max_seq_length: Optional[int] = field( + default=None, + metadata={"help": "Deprecated. Use `max_length` instead."}, ) def __post_init__(self): @@ -153,7 +155,7 @@ def __post_init__(self): if self.num_of_sequences is not None: warnings.warn( - "`num_of_sequences` is deprecated and will be remove in version 0.18.0. Use `max_seq_length` instead, " + "`num_of_sequences` is deprecated and will be remove in version 0.18.0. Use `max_length` instead, " "which specifies the maximum length of the tokenized sequence, unlike `num_of_sequences`, which r" "eferred to string sequences.", DeprecationWarning, @@ -162,6 +164,12 @@ def __post_init__(self): if self.chars_per_token is not None: warnings.warn( "`chars_per_token` is deprecated and will be remove in version 0.18.0. If you want to customize the " - "packing length, use `max_seq_length`.", + "packing length, use `max_length`.", + DeprecationWarning, + ) + + if self.max_seq_length is not None: + warnings.warn( + "`max_seq_length` is deprecated and will be remove in version 0.20.0. 
Use `max_length` instead.", DeprecationWarning, ) diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index e4708eb7c7..b0104f4b53 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -434,17 +434,17 @@ def tokenize(ex): # Pack or truncate if packing: - if args.max_seq_length is None: - raise ValueError("When packing is enabled, `max_seq_length` can't be `None`.") + if args.max_length is None: + raise ValueError("When packing is enabled, `max_length` can't be `None`.") if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` map_kwargs["desc"] = f"Packing {dataset_name} dataset" dataset = dataset.select_columns("input_ids") dataset = dataset.map( - pack_examples, batched=True, fn_kwargs={"seq_length": args.max_seq_length}, **map_kwargs + pack_examples, batched=True, fn_kwargs={"seq_length": args.max_length}, **map_kwargs ) - elif args.max_seq_length is not None: + elif args.max_length is not None: dataset = dataset.map( - lambda ex: {key: ex[key][: args.max_seq_length] for key in ["input_ids", "attention_mask"]}, + lambda ex: {key: ex[key][: args.max_length] for key in ["input_ids", "attention_mask"]}, **map_kwargs, ) # For Liger kernel, ensure only input_ids is present diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 7a20645535..853ba1f3ca 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -140,7 +140,7 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d warnings.warn( f"Could not find response key `{self.response_template}` in the following instance: " f"{self.tokenizer.decode(batch['input_ids'][i])}. This instance will be ignored in loss " - "calculation. Note, if this happens often, consider increasing the `max_seq_length`.", + "calculation. Note, if this happens often, consider increasing the `max_length`.", UserWarning, ) batch["labels"][i, :] = self.ignore_index @@ -167,7 +167,7 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d warnings.warn( f"Could not find response key `{self.response_template}` in the following instance: " f"{self.tokenizer.decode(batch['input_ids'][i])}. This instance will be ignored in loss " - "calculation. Note, if this happens often, consider increasing the `max_seq_length`.", + "calculation. Note, if this happens often, consider increasing the `max_length`.", UserWarning, ) batch["labels"][i, :] = self.ignore_index @@ -182,7 +182,7 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d warnings.warn( f"Could not find instruction key `{self.instruction_template}` in the following instance: " f"{self.tokenizer.decode(batch['input_ids'][i])}. This instance will be ignored in loss " - "calculation. Note, if this happens often, consider increasing the `max_seq_length`.", + "calculation. 
Note, if this happens often, consider increasing the `max_length`.", UserWarning, ) batch["labels"][i, :] = self.ignore_index From 15fec312d5ff08f6c92831d6b43c9e4bb4711190 Mon Sep 17 00:00:00 2001 From: Pierre TASSEL Date: Tue, 18 Feb 2025 17:57:15 +0100 Subject: [PATCH 96/96] =?UTF-8?q?=F0=9F=8D=83=20GRPO=20-=20Do=20not=20load?= =?UTF-8?q?=20reference=20model=20when=20beta=20=3D=3D=200=20(#2806)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🔧 Optimize GRPO training by conditionally loading reference model based on beta value * ✅ Add test for GRPOTrainer with beta=0 to ensure no reference model and KL divergence * 🔧 Refactor GRPOTrainer code for improved readability and maintainability * 🔧 Simplify per_token_loss calculation in GRPOTrainer for clarity * fix test, style, and some struct for clarity --------- Co-authored-by: Quentin Gallouédec Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> --- tests/test_grpo_trainer.py | 30 ++++++++++++++++++++++++++++++ trl/trainer/grpo_config.py | 8 ++++++-- trl/trainer/grpo_trainer.py | 30 ++++++++++++++++++++---------- 3 files changed, 56 insertions(+), 12 deletions(-) diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index a891793550..5f58d69ca7 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -500,6 +500,36 @@ def test_training_with_sync_ref_model(self): new_param = trainer.model.get_parameter(n) self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + def test_beta_zero_no_ref_model_and_no_kl(self): + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + beta=0.0, # set beta to 0 to test the case where the reference model is not used + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=32, # reduce the completion length to reduce memory usage + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + @unittest.skipIf(not is_vllm_available(), "vLLM is not available") @require_torch_accelerator @require_peft diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index 02a02dc788..923686276c 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -88,7 +88,8 @@ class GRPOConfig(TrainingArguments): Initial learning rate for [`AdamW`] optimizer. The default value replaces that of [`~transformers.TrainingArguments`]. beta (`float`, *optional*, defaults to `0.04`): - KL coefficient. + KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving training + speed. 
reward_weights (`list[float]` or `None`, *optional*, defaults to `None`): Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are weighted equally with weight `1.0`. @@ -218,7 +219,10 @@ class GRPOConfig(TrainingArguments): ) beta: float = field( default=0.04, - metadata={"help": "KL coefficient."}, + metadata={ + "help": "KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving " + "training speed." + }, ) reward_weights: Optional[list[float]] = field( default=None, diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 93993e082a..573350c277 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -244,11 +244,16 @@ def __init__( "This argument can only be used when the `model` argument is a string." ) + self.beta = args.beta + if peft_config is not None: model = get_peft_model(model, peft_config) # Reference model - if is_deepspeed_zero3_enabled(): + if self.beta == 0.0: + # If beta is 0.0, the reference model is not needed + self.ref_model = None + elif is_deepspeed_zero3_enabled(): self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs) elif not is_peft_model(model): # If PEFT configuration is not provided, create a reference model based on the initial model. @@ -314,8 +319,6 @@ def data_collator(features): # No data collation is needed in GRPO self.num_generations = args.num_generations # = G in the GRPO paper self.use_vllm = args.use_vllm - self.beta = args.beta - # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the # "input_ids" key. Instead, the available keys is "prompt". 
As a result, the trainer issues the warning: @@ -603,7 +606,9 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens with torch.inference_mode(): - if self.ref_model is not None: + if self.beta == 0.0: + ref_per_token_logps = None + elif self.ref_model is not None: ref_per_token_logps = self._get_per_token_logps( self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep ) @@ -723,21 +728,26 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) # Compute the KL divergence between the model and the reference model - ref_per_token_logps = inputs["ref_per_token_logps"] - per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 + if self.beta != 0.0: + ref_per_token_logps = inputs["ref_per_token_logps"] + per_token_kl = ( + torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 + ) # x - x.detach() allows for preserving gradients from x advantages = inputs["advantages"] - per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) - per_token_loss = -(per_token_loss - self.beta * per_token_kl) + per_token_loss = -torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) + if self.beta != 0.0: + per_token_loss = per_token_loss + self.beta * per_token_kl loss = (per_token_loss * completion_mask).sum() / completion_mask.sum() # Log the metrics completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item() self._metrics["completion_length"].append(completion_length) - mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() - self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item()) + if self.beta != 0.0: + mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() + self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item()) return loss
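
Patch 90 above lets `SFTTrainer` accept a dataset that already contains an `input_ids` column and skip its own formatting and tokenization steps. The following is a minimal sketch of that flow, condensed from the test added in that patch; the tiny checkpoint and the `trl-internal-testing/zen` dataset are the test fixtures used throughout this series and stand in for a real model and corpus, and the output directory name is arbitrary.

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer

model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Tokenize ahead of time; the resulting `input_ids` column marks the dataset as
# already processed, so the trainer skips formatting and tokenization.
dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
tokenized = dataset.map(lambda ex: tokenizer(ex["text"]), remove_columns=["text"])

training_args = SFTConfig(output_dir="sft-pretokenized", report_to="none")  # output_dir is illustrative
trainer = SFTTrainer(args=training_args, model=model, train_dataset=tokenized)
trainer.train()
```

Note that passing a `formatting_func` together with such a processed dataset now triggers a warning and the function is ignored, per the check introduced in the same patch.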
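
Patch 93 adds `vllm_guided_decoding_regex` to `GRPOConfig`: when set (and `use_vllm=True`), completions generated through vLLM are constrained to match the regex via vLLM's guided decoding. A hedged sketch mirroring the new test follows; it assumes vLLM and a CUDA device are available, the regex shown is purely illustrative, and the small instruct model and tiny reward checkpoint are the ones used in the test.

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

training_args = GRPOConfig(
    output_dir="grpo-guided",                           # illustrative
    per_device_train_batch_size=3,
    num_generations=3,
    max_completion_length=32,
    use_vllm=True,
    vllm_device="cuda:0",                               # sharing the training GPU warns but works for a smoke test
    vllm_guided_decoding_regex=r"<answer>.*</answer>",  # example constraint, not the one from the test suite
    report_to="none",
)
trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```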
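
Finally, patch 96 makes `beta=0.0` skip both loading the reference model and computing the KL term, reducing memory usage and speeding up training. A minimal sketch of that configuration, again following the test fixtures in this series (checkpoint names and output directory are placeholders):

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

training_args = GRPOConfig(
    output_dir="grpo-no-ref",   # illustrative
    beta=0.0,                   # no reference model is instantiated and no KL penalty enters the loss
    per_device_train_batch_size=3,
    num_generations=3,
    max_completion_length=32,
    report_to="none",
)
trainer = GRPOTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```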