data:image/s3,"s3://crabby-images/5b9a2/5b9a247478401e34029a00366cf1b1ec1cf211d5" alt="Truncation input ids"
@@ -55,7 +55,7 @@ To set the truncation parameter, use the following code snippet:
```python
from trl import SFTConfig
-training_args = SFTConfig(..., max_seq_length=...)
+training_args = SFTConfig(..., max_length=...)
```
@@ -85,7 +85,7 @@ Packing eliminates padding, preserves all sequence information, and allows for f
```python
from trl import SFTConfig
-training_args = SFTConfig(..., packing=True, max_seq_length=512)
+training_args = SFTConfig(..., packing=True, max_length=512)
```
diff --git a/docs/source/sft_trainer.md b/docs/source/sft_trainer.md
index ab3e9e1cc5..5c30b744fa 100644
--- a/docs/source/sft_trainer.md
+++ b/docs/source/sft_trainer.md
@@ -19,7 +19,7 @@ from trl import SFTConfig, SFTTrainer
dataset = load_dataset("stanfordnlp/imdb", split="train")
training_args = SFTConfig(
- max_seq_length=512,
+ max_length=512,
output_dir="/tmp",
)
trainer = SFTTrainer(
@@ -29,7 +29,7 @@ trainer = SFTTrainer(
)
trainer.train()
```
-Make sure to pass the correct value for `max_seq_length` as the default value will be set to `min(tokenizer.model_max_length, 1024)`.
+Make sure to pass the correct value for `max_length` as the default value will be set to `min(tokenizer.model_max_length, 1024)`.
You can also construct a model outside of the trainer and pass it as follows:
@@ -550,12 +550,12 @@ import torch
from trl import SFTConfig, SFTTrainer
from unsloth import FastLanguageModel
-max_seq_length = 2048 # Supports automatic RoPE Scaling, so choose any number
+max_length = 2048 # Supports automatic RoPE Scaling, so choose any number
# Load model
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="unsloth/mistral-7b",
- max_seq_length=max_seq_length,
+ max_seq_length=max_length,
dtype=None, # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit=True, # Use 4bit quantization to reduce memory usage. Can be False
# token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
@@ -581,7 +581,7 @@ model = FastLanguageModel.get_peft_model(
random_state=3407,
)
-training_args = SFTConfig(output_dir="./output", max_seq_length=max_seq_length)
+training_args = SFTConfig(output_dir="./output", max_length=max_length)
trainer = SFTTrainer(
model=model,
@@ -624,7 +624,7 @@ To learn more about Liger-Kernel, visit their [official repository](https://gith
Pay attention to the following best practices when training a model with that trainer:
-- [`SFTTrainer`] always truncates by default the sequences to the `max_seq_length` argument of the [`SFTConfig`]. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide a default value, so there is a check to retrieve the minimum between 1024 and that value. Make sure to check it before training.
+- [`SFTTrainer`] always truncates by default the sequences to the `max_length` argument of the [`SFTConfig`]. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide a default value, so there is a check to retrieve the minimum between 1024 and that value. Make sure to check it before training.
- For training adapters in 8bit, you might need to tweak the arguments of the `prepare_model_for_kbit_training` method from PEFT, hence we advise users to use `prepare_in_int8_kwargs` field, or create the `PeftModel` outside the [`SFTTrainer`] and pass it.
- For a more memory-efficient training using adapters, you can load the base model in 8bit, for that simply add `load_in_8bit` argument when creating the [`SFTTrainer`], or create a base model in 8bit outside the trainer and pass it.
- If you create a model outside the trainer, make sure to not pass to the trainer any additional keyword arguments that are relative to `from_pretrained()` method.
diff --git a/examples/research_projects/stack_llama_2/scripts/sft_llama2.py b/examples/research_projects/stack_llama_2/scripts/sft_llama2.py
index 3ae1e82c2a..1f4611a3e8 100644
--- a/examples/research_projects/stack_llama_2/scripts/sft_llama2.py
+++ b/examples/research_projects/stack_llama_2/scripts/sft_llama2.py
@@ -185,7 +185,7 @@ def create_datasets(tokenizer, args, seed=None):
train_dataset=train_dataset,
eval_dataset=eval_dataset,
peft_config=peft_config,
- max_seq_length=None,
+ max_length=None,
formatting_func=prepare_sample_text,
processing_class=tokenizer,
args=training_args,
diff --git a/tests/slow/test_sft_slow.py b/tests/slow/test_sft_slow.py
index 74811d092c..8a772a48aa 100644
--- a/tests/slow/test_sft_slow.py
+++ b/tests/slow/test_sft_slow.py
@@ -46,7 +46,7 @@ class SFTTrainerSlowTester(unittest.TestCase):
def setUp(self):
self.train_dataset = load_dataset("stanfordnlp/imdb", split="train[:10%]")
self.eval_dataset = load_dataset("stanfordnlp/imdb", split="test[:10%]")
- self.max_seq_length = 128
+ self.max_length = 128
self.peft_config = LoraConfig(
lora_alpha=16,
lora_dropout=0.1,
@@ -74,7 +74,7 @@ def test_sft_trainer_str(self, model_name, packing):
per_device_train_batch_size=2,
max_steps=10,
packing=packing,
- max_seq_length=self.max_seq_length,
+ max_length=self.max_length,
)
trainer = SFTTrainer(
@@ -100,7 +100,7 @@ def test_sft_trainer_transformers(self, model_name, packing):
per_device_train_batch_size=2,
max_steps=10,
packing=packing,
- max_seq_length=self.max_seq_length,
+ max_length=self.max_length,
)
model = AutoModelForCausalLM.from_pretrained(model_name)
@@ -135,7 +135,7 @@ def test_sft_trainer_peft(self, model_name, packing):
max_steps=10,
fp16=True,
packing=packing,
- max_seq_length=self.max_seq_length,
+ max_length=self.max_length,
)
model = AutoModelForCausalLM.from_pretrained(model_name)
@@ -172,7 +172,7 @@ def test_sft_trainer_transformers_mp(self, model_name, packing):
max_steps=10,
fp16=True, # this is sufficient to enable amp
packing=packing,
- max_seq_length=self.max_seq_length,
+ max_length=self.max_length,
)
model = AutoModelForCausalLM.from_pretrained(model_name)
@@ -205,7 +205,7 @@ def test_sft_trainer_transformers_mp_gc(self, model_name, packing, gradient_chec
per_device_train_batch_size=2,
max_steps=10,
packing=packing,
- max_seq_length=self.max_seq_length,
+ max_length=self.max_length,
fp16=True, # this is sufficient to enable amp
gradient_checkpointing=True,
gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
@@ -242,7 +242,7 @@ def test_sft_trainer_transformers_mp_gc_peft(self, model_name, packing, gradient
per_device_train_batch_size=2,
max_steps=10,
packing=packing,
- max_seq_length=self.max_seq_length,
+ max_length=self.max_length,
fp16=True, # this is sufficient to enable amp
gradient_checkpointing=True,
gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
@@ -286,7 +286,7 @@ def test_sft_trainer_transformers_mp_gc_device_map(
per_device_train_batch_size=2,
max_steps=10,
packing=packing,
- max_seq_length=self.max_seq_length,
+ max_length=self.max_length,
fp16=True, # this is sufficient to enable amp
gradient_checkpointing=True,
gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
@@ -324,7 +324,7 @@ def test_sft_trainer_transformers_mp_gc_peft_qlora(self, model_name, packing, gr
per_device_train_batch_size=2,
max_steps=10,
packing=packing,
- max_seq_length=self.max_seq_length,
+ max_length=self.max_length,
fp16=True, # this is sufficient to enable amp
gradient_checkpointing=True,
gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
@@ -364,7 +364,7 @@ def test_sft_trainer_with_chat_format_qlora(self, model_name, packing):
training_args = SFTConfig(
packing=packing,
- max_seq_length=self.max_seq_length,
+ max_length=self.max_length,
output_dir=tmp_dir,
logging_strategy="no",
report_to="none",
@@ -411,7 +411,7 @@ def test_sft_trainer_with_liger(self, model_name, packing):
per_device_train_batch_size=2,
max_steps=2,
packing=packing,
- max_seq_length=self.max_seq_length,
+ max_length=self.max_length,
use_liger=True,
)
diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index 14d235585a..1a26378f3f 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -326,7 +326,7 @@ def test_sft_trainer_uncorrect_data(self):
eval_steps=1,
save_steps=1,
per_device_train_batch_size=2,
- max_seq_length=32, # make sure there is at least 1 packed sequence
+ max_length=32, # make sure there is at least 1 packed sequence
packing=True,
report_to="none",
)
@@ -353,7 +353,7 @@ def test_sft_trainer_uncorrect_data(self):
train_dataset=self.conversational_lm_dataset["train"],
)
- # Same, but with packing with `max_seq_length`
+ # Same, but with packing with `max_length`
training_args = SFTConfig(
output_dir=tmp_dir,
dataloader_drop_last=True,
@@ -361,7 +361,7 @@ def test_sft_trainer_uncorrect_data(self):
eval_steps=1,
save_steps=1,
per_device_train_batch_size=2,
- max_seq_length=16, # make sure there is at least 1 packed sequence
+ max_length=16, # make sure there is at least 1 packed sequence
packing=True,
report_to="none",
)
@@ -396,7 +396,7 @@ def test_sft_trainer_uncorrect_data(self):
eval_steps=1,
save_steps=1,
per_device_train_batch_size=2,
- max_seq_length=32, # make sure there is at least 1 packed sequence
+ max_length=32, # make sure there is at least 1 packed sequence
packing=True,
report_to="none",
)
@@ -461,7 +461,7 @@ def test_sft_trainer_with_model_num_train_epochs(self):
save_steps=1,
num_train_epochs=2,
per_device_train_batch_size=2,
- max_seq_length=16,
+ max_length=16,
packing=True,
report_to="none",
)
@@ -485,7 +485,7 @@ def test_sft_trainer_with_model_num_train_epochs(self):
save_steps=1,
num_train_epochs=2,
per_device_train_batch_size=2,
- max_seq_length=16,
+ max_length=16,
report_to="none",
)
trainer = SFTTrainer(
@@ -534,7 +534,7 @@ def test_sft_trainer_with_model(self):
max_steps=2,
save_steps=1,
per_device_train_batch_size=2,
- max_seq_length=16,
+ max_length=16,
packing=True,
report_to="none",
)
@@ -558,7 +558,7 @@ def test_sft_trainer_with_model(self):
max_steps=2,
save_steps=1,
per_device_train_batch_size=2,
- max_seq_length=16,
+ max_length=16,
packing=True,
report_to="none",
)
@@ -583,7 +583,7 @@ def test_sft_trainer_with_model(self):
max_steps=2,
save_steps=1,
per_device_train_batch_size=2,
- max_seq_length=16,
+ max_length=16,
report_to="none",
)
trainer = SFTTrainer(
@@ -606,7 +606,7 @@ def test_sft_trainer_with_model(self):
max_steps=2,
save_steps=1,
per_device_train_batch_size=2,
- max_seq_length=16,
+ max_length=16,
report_to="none",
)
trainer = SFTTrainer(
@@ -755,7 +755,7 @@ def test_sft_trainer_infinite_with_model(self):
save_steps=1,
per_device_train_batch_size=2,
packing=True,
- max_seq_length=500,
+ max_length=500,
report_to="none",
)
trainer = SFTTrainer(
@@ -782,7 +782,7 @@ def test_sft_trainer_infinite_with_model_epochs(self):
per_device_train_batch_size=2,
save_strategy="epoch",
packing=True,
- max_seq_length=500,
+ max_length=500,
report_to="none",
)
trainer = SFTTrainer(
@@ -1088,7 +1088,7 @@ def test_sft_trainer_only_train_packing(self):
per_device_train_batch_size=2,
gradient_checkpointing=True,
packing=True,
- max_seq_length=16, # make sure there is at least 1 packed sequence
+ max_length=16, # make sure there is at least 1 packed sequence
eval_packing=False,
report_to="none",
)
@@ -1114,7 +1114,7 @@ def test_sft_trainer_eval_packing(self):
save_steps=2,
per_device_train_batch_size=2,
gradient_checkpointing=True,
- max_seq_length=16, # make sure there is at least 1 packed sequence
+ max_length=16, # make sure there is at least 1 packed sequence
packing=True,
report_to="none",
)
@@ -1139,7 +1139,7 @@ def test_sft_trainer_no_packing(self):
save_steps=2,
per_device_train_batch_size=2,
gradient_checkpointing=True,
- max_seq_length=16, # make sure there is at least 1 packed sequence
+ max_length=16, # make sure there is at least 1 packed sequence
packing=False,
report_to="none",
)
diff --git a/tests/test_trainers_args.py b/tests/test_trainers_args.py
index 251b1f5a96..406eba4f86 100644
--- a/tests/test_trainers_args.py
+++ b/tests/test_trainers_args.py
@@ -368,7 +368,7 @@ def test_sft(self):
tmp_dir,
dataset_text_field="dummy_text_field",
packing=True,
- max_seq_length=256,
+ max_length=256,
dataset_num_proc=4,
dataset_batch_size=512,
neftune_noise_alpha=0.1,
@@ -379,7 +379,7 @@ def test_sft(self):
trainer = SFTTrainer(model_id, args=training_args, train_dataset=dataset)
self.assertEqual(trainer.args.dataset_text_field, "dummy_text_field")
self.assertEqual(trainer.args.packing, True)
- self.assertEqual(trainer.args.max_seq_length, 256)
+ self.assertEqual(trainer.args.max_length, 256)
self.assertEqual(trainer.args.dataset_num_proc, 4)
self.assertEqual(trainer.args.dataset_batch_size, 512)
self.assertEqual(trainer.args.neftune_noise_alpha, 0.1)
diff --git a/trl/trainer/gkd_trainer.py b/trl/trainer/gkd_trainer.py
index 5e76f30ed2..7cfad453f7 100644
--- a/trl/trainer/gkd_trainer.py
+++ b/trl/trainer/gkd_trainer.py
@@ -87,7 +87,7 @@ def __init__(
):
# add remove_unused_columns=False to the dataclass args
args.remove_unused_columns = False
- data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_seq_length)
+ data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_length)
super().__init__(
model,
diff --git a/trl/trainer/sft_config.py b/trl/trainer/sft_config.py
index ad0e936c18..23b617dfe2 100644
--- a/trl/trainer/sft_config.py
+++ b/trl/trainer/sft_config.py
@@ -49,13 +49,11 @@ class SFTConfig(TrainingArguments):
`skip_prepare_dataset`.
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
Number of processes to use for processing the dataset.
- max_seq_length (`int` or `None`, *optional*, defaults to `1024`):
- Maximum length of the tokenized sequence. Sequences longer than `max_seq_length` are truncated from the
- right.
+ max_length (`int` or `None`, *optional*, defaults to `1024`):
+ Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from the right.
If `None`, no truncation is applied. When packing is enabled, this value sets the sequence length.
packing (`bool`, *optional*, defaults to `False`):
- Whether to pack multiple sequences into a fixed-length format. Uses `max_seq_length` to define sequence
- length.
+ Whether to pack multiple sequences into a fixed-length format. Uses `max_length` to define sequence length.
eval_packing (`bool` or `None`, *optional*, defaults to `None`):
Whether to pack the eval dataset. If `None`, uses the same value as `packing`.
@@ -95,19 +93,19 @@ class SFTConfig(TrainingArguments):
default=None,
metadata={"help": "Number of processes to use for processing the dataset."},
)
- max_seq_length: Optional[int] = field(
+ max_length: Optional[int] = field(
default=1024,
metadata={
- "help": "Maximum length of the tokenized sequence. Sequences longer than `max_seq_length` are truncated "
- "from the right. If `None`, no truncation is applied. When packing is enabled, this value sets the "
+ "help": "Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from"
+ "the right. If `None`, no truncation is applied. When packing is enabled, this value sets the "
"sequence length."
},
)
packing: bool = field(
default=False,
metadata={
- "help": "Whether to pack multiple sequences into a fixed-length format. Uses `max_seq_length` to "
- "define sequence length."
+ "help": "Whether to pack multiple sequences into a fixed-length format. Uses `max_length` to define "
+ "sequence length."
},
)
eval_packing: Optional[bool] = field(
@@ -132,13 +130,17 @@ class SFTConfig(TrainingArguments):
num_of_sequences: int = field(
default=None,
metadata={
- "help": "Deprecated. Use `max_seq_length` instead, which specifies the maximum length of the tokenized "
+ "help": "Deprecated. Use `max_length` instead, which specifies the maximum length of the tokenized "
"sequence, unlike `num_of_sequences`, which referred to string sequences."
},
)
chars_per_token: float = field(
default=None,
- metadata={"help": "Deprecated. If you want to customize the packing length, use `max_seq_length`."},
+ metadata={"help": "Deprecated. If you want to customize the packing length, use `max_length`."},
+ )
+ max_seq_length: Optional[int] = field(
+ default=None,
+ metadata={"help": "Deprecated. Use `max_length` instead."},
)
def __post_init__(self):
@@ -153,7 +155,7 @@ def __post_init__(self):
if self.num_of_sequences is not None:
warnings.warn(
- "`num_of_sequences` is deprecated and will be remove in version 0.18.0. Use `max_seq_length` instead, "
+ "`num_of_sequences` is deprecated and will be remove in version 0.18.0. Use `max_length` instead, "
"which specifies the maximum length of the tokenized sequence, unlike `num_of_sequences`, which r"
"eferred to string sequences.",
DeprecationWarning,
@@ -162,6 +164,12 @@ def __post_init__(self):
if self.chars_per_token is not None:
warnings.warn(
"`chars_per_token` is deprecated and will be remove in version 0.18.0. If you want to customize the "
- "packing length, use `max_seq_length`.",
+ "packing length, use `max_length`.",
+ DeprecationWarning,
+ )
+
+ if self.max_seq_length is not None:
+ warnings.warn(
+ "`max_seq_length` is deprecated and will be remove in version 0.20.0. Use `max_length` instead.",
DeprecationWarning,
)
diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index e4708eb7c7..b0104f4b53 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -434,17 +434,17 @@ def tokenize(ex):
# Pack or truncate
if packing:
- if args.max_seq_length is None:
- raise ValueError("When packing is enabled, `max_seq_length` can't be `None`.")
+ if args.max_length is None:
+ raise ValueError("When packing is enabled, `max_length` can't be `None`.")
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Packing {dataset_name} dataset"
dataset = dataset.select_columns("input_ids")
dataset = dataset.map(
- pack_examples, batched=True, fn_kwargs={"seq_length": args.max_seq_length}, **map_kwargs
+ pack_examples, batched=True, fn_kwargs={"seq_length": args.max_length}, **map_kwargs
)
- elif args.max_seq_length is not None:
+ elif args.max_length is not None:
dataset = dataset.map(
- lambda ex: {key: ex[key][: args.max_seq_length] for key in ["input_ids", "attention_mask"]},
+ lambda ex: {key: ex[key][: args.max_length] for key in ["input_ids", "attention_mask"]},
**map_kwargs,
)
# For Liger kernel, ensure only input_ids is present
diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index 7a20645535..853ba1f3ca 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -140,7 +140,7 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d
warnings.warn(
f"Could not find response key `{self.response_template}` in the following instance: "
f"{self.tokenizer.decode(batch['input_ids'][i])}. This instance will be ignored in loss "
- "calculation. Note, if this happens often, consider increasing the `max_seq_length`.",
+ "calculation. Note, if this happens often, consider increasing the `max_length`.",
UserWarning,
)
batch["labels"][i, :] = self.ignore_index
@@ -167,7 +167,7 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d
warnings.warn(
f"Could not find response key `{self.response_template}` in the following instance: "
f"{self.tokenizer.decode(batch['input_ids'][i])}. This instance will be ignored in loss "
- "calculation. Note, if this happens often, consider increasing the `max_seq_length`.",
+ "calculation. Note, if this happens often, consider increasing the `max_length`.",
UserWarning,
)
batch["labels"][i, :] = self.ignore_index
@@ -182,7 +182,7 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d
warnings.warn(
f"Could not find instruction key `{self.instruction_template}` in the following instance: "
f"{self.tokenizer.decode(batch['input_ids'][i])}. This instance will be ignored in loss "
- "calculation. Note, if this happens often, consider increasing the `max_seq_length`.",
+ "calculation. Note, if this happens often, consider increasing the `max_length`.",
UserWarning,
)
batch["labels"][i, :] = self.ignore_index
From 15fec312d5ff08f6c92831d6b43c9e4bb4711190 Mon Sep 17 00:00:00 2001
From: Pierre TASSEL
Date: Tue, 18 Feb 2025 17:57:15 +0100
Subject: [PATCH 96/96] =?UTF-8?q?=F0=9F=8D=83=20GRPO=20-=20Do=20not=20load?=
=?UTF-8?q?=20reference=20model=20when=20beta=20=3D=3D=200=20(#2806)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
* 🔧 Optimize GRPO training by conditionally loading reference model based on beta value
* ✅ Add test for GRPOTrainer with beta=0 to ensure no reference model and KL divergence
* 🔧 Refactor GRPOTrainer code for improved readability and maintainability
* 🔧 Simplify per_token_loss calculation in GRPOTrainer for clarity
* fix test, style, and some struct for clarity
---------
Co-authored-by: Quentin Gallouédec
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
---
tests/test_grpo_trainer.py | 30 ++++++++++++++++++++++++++++++
trl/trainer/grpo_config.py | 8 ++++++--
trl/trainer/grpo_trainer.py | 30 ++++++++++++++++++++----------
3 files changed, 56 insertions(+), 12 deletions(-)
diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py
index a891793550..5f58d69ca7 100644
--- a/tests/test_grpo_trainer.py
+++ b/tests/test_grpo_trainer.py
@@ -500,6 +500,36 @@ def test_training_with_sync_ref_model(self):
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+ def test_beta_zero_no_ref_model_and_no_kl(self):
+ dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ training_args = GRPOConfig(
+ output_dir=tmp_dir,
+ beta=0.0, # set beta to 0 to test the case where the reference model is not used
+ learning_rate=0.1, # increase the learning rate to speed up the test
+ per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
+ num_generations=3, # reduce the number of generations to reduce memory usage
+ max_completion_length=32, # reduce the completion length to reduce memory usage
+ report_to="none",
+ )
+ trainer = GRPOTrainer(
+ model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+ reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+ args=training_args,
+ train_dataset=dataset,
+ )
+
+ previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+ trainer.train()
+
+ self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+ # Check that the params have changed
+ for n, param in previous_trainable_params.items():
+ new_param = trainer.model.get_parameter(n)
+ self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+
@unittest.skipIf(not is_vllm_available(), "vLLM is not available")
@require_torch_accelerator
@require_peft
diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py
index 02a02dc788..923686276c 100644
--- a/trl/trainer/grpo_config.py
+++ b/trl/trainer/grpo_config.py
@@ -88,7 +88,8 @@ class GRPOConfig(TrainingArguments):
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
[`~transformers.TrainingArguments`].
beta (`float`, *optional*, defaults to `0.04`):
- KL coefficient.
+ KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving training
+ speed.
reward_weights (`list[float]` or `None`, *optional*, defaults to `None`):
Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are
weighted equally with weight `1.0`.
@@ -218,7 +219,10 @@ class GRPOConfig(TrainingArguments):
)
beta: float = field(
default=0.04,
- metadata={"help": "KL coefficient."},
+ metadata={
+ "help": "KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving "
+ "training speed."
+ },
)
reward_weights: Optional[list[float]] = field(
default=None,
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 93993e082a..573350c277 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -244,11 +244,16 @@ def __init__(
"This argument can only be used when the `model` argument is a string."
)
+ self.beta = args.beta
+
if peft_config is not None:
model = get_peft_model(model, peft_config)
# Reference model
- if is_deepspeed_zero3_enabled():
+ if self.beta == 0.0:
+ # If beta is 0.0, the reference model is not needed
+ self.ref_model = None
+ elif is_deepspeed_zero3_enabled():
self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
elif not is_peft_model(model):
# If PEFT configuration is not provided, create a reference model based on the initial model.
@@ -314,8 +319,6 @@ def data_collator(features): # No data collation is needed in GRPO
self.num_generations = args.num_generations # = G in the GRPO paper
self.use_vllm = args.use_vllm
- self.beta = args.beta
-
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
# input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
# "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
@@ -603,7 +606,9 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
with torch.inference_mode():
- if self.ref_model is not None:
+ if self.beta == 0.0:
+ ref_per_token_logps = None
+ elif self.ref_model is not None:
ref_per_token_logps = self._get_per_token_logps(
self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep
)
@@ -723,21 +728,26 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
# Compute the KL divergence between the model and the reference model
- ref_per_token_logps = inputs["ref_per_token_logps"]
- per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
+ if self.beta != 0.0:
+ ref_per_token_logps = inputs["ref_per_token_logps"]
+ per_token_kl = (
+ torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
+ )
# x - x.detach() allows for preserving gradients from x
advantages = inputs["advantages"]
- per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
- per_token_loss = -(per_token_loss - self.beta * per_token_kl)
+ per_token_loss = -torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
+ if self.beta != 0.0:
+ per_token_loss = per_token_loss + self.beta * per_token_kl
loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
# Log the metrics
completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
self._metrics["completion_length"].append(completion_length)
- mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
- self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
+ if self.beta != 0.0:
+ mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
+ self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
return loss