From c076794e202765ae2b8b4dd738d01522ddf6342c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?=
Date: Wed, 19 Feb 2025 00:13:44 +0000
Subject: [PATCH] Add num_updates and epsilon parameters to GRPOConfig and GRPOTrainer

---
 trl/trainer/grpo_config.py  |  12 ++++
 trl/trainer/grpo_trainer.py | 108 ++++++++++++++++++++++++++++++------
 2 files changed, 103 insertions(+), 17 deletions(-)

diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py
index 02a02dc788..6c43aac766 100644
--- a/trl/trainer/grpo_config.py
+++ b/trl/trainer/grpo_config.py
@@ -89,6 +89,10 @@ class GRPOConfig(TrainingArguments):
             [`~transformers.TrainingArguments`].
         beta (`float`, *optional*, defaults to `0.04`):
             KL coefficient.
+        num_updates (`int`, *optional*, defaults to `1`):
+            Number of policy updates performed on each generated batch of completions.
+        epsilon (`float`, *optional*, defaults to `0.2`):
+            Epsilon value for clipping the probability ratio in the surrogate objective.
         reward_weights (`list[float]` or `None`, *optional*, defaults to `None`):
             Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are
             weighted equally with weight `1.0`.
@@ -220,6 +224,14 @@ class GRPOConfig(TrainingArguments):
         default=0.04,
         metadata={"help": "KL coefficient."},
     )
+    num_updates: int = field(
+        default=1,
+        metadata={"help": "Number of policy updates performed on each generated batch of completions."},
+    )
+    epsilon: float = field(
+        default=0.2,
+        metadata={"help": "Epsilon value for clipping the probability ratio in the surrogate objective."},
+    )
     reward_weights: Optional[list[float]] = field(
         default=None,
         metadata={
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 93993e082a..2682f3ea38 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -67,26 +67,51 @@ class RepeatRandomSampler(Sampler):
     """
-    Sampler that repeats the indices of a dataset N times.
+    Sampler that repeats the indices of a dataset in a structured manner.
 
     Args:
         data_source (`Sized`):
             Dataset to sample from.
-        repeat_count (`int`):
-            Number of times to repeat each index.
-        seed (`Optional[int]`):
+        mini_repeat_count (`int`):
+            Number of times to repeat each index per batch.
+        batch_size (`int`, *optional*, defaults to `1`):
+            Number of unique indices per batch.
+        repeat_count (`int`, *optional*, defaults to `1`):
+            Number of times to repeat the full sampling process.
+        seed (`int` or `None`, *optional*, defaults to `None`):
             Random seed for reproducibility (only affects this sampler).
 
     Example:
     ```python
-    >>> sampler = RepeatRandomSampler(["a", "b", "c", "d"], repeat_count=2)
+    >>> sampler = RepeatRandomSampler(["a", "b", "c", "d", "e", "f", "g"], mini_repeat_count=2, batch_size=3, repeat_count=4)
     >>> list(sampler)
-    [2, 2, 0, 0, 3, 3, 1, 1]
+    [4, 4, 3, 3, 0, 0, 4, 4, 3, 3, 0, 0, 4, 4, 3, 3, 0, 0, 4, 4, 3, 3, 0, 0,
+     1, 1, 2, 2, 6, 6, 1, 1, 2, 2, 6, 6, 1, 1, 2, 2, 6, 6, 1, 1, 2, 2, 6, 6]
     ```
+
+    ```txt
+    mini_repeat_count
+        -  -
+       [4, 4, 3, 3, 0, 0,  |
+        4, 4, 3, 3, 0, 0,  |
+        4, 4, 3, 3, 0, 0,  |  repeat_count
+        4, 4, 3, 3, 0, 0]  |
+        ----  ----  ----
+           batch_size
+    ```
     """
 
-    def __init__(self, data_source: Sized, repeat_count: int, seed: Optional[int] = None):
+    def __init__(
+        self,
+        data_source: Sized,
+        mini_repeat_count: int,
+        batch_size: int = 1,
+        repeat_count: int = 1,
+        seed: Optional[int] = None,
+    ):
         self.data_source = data_source
+        self.mini_repeat_count = mini_repeat_count
+        self.batch_size = batch_size
         self.repeat_count = repeat_count
         self.num_samples = len(data_source)
         self.seed = seed
@@ -95,15 +120,33 @@ def __init__(self, data_source: Sized, repeat_count: int, seed: Optional[int] =
             self.generator.manual_seed(seed)
 
     def __iter__(self):
-        indexes = [
-            idx
-            for idx in torch.randperm(self.num_samples, generator=self.generator).tolist()
-            for _ in range(self.repeat_count)
-        ]
+        # [2, 4, 3, 1, 0, 6, 5]  (num_samples = 7)
+        indexes = torch.randperm(self.num_samples, generator=self.generator).tolist()
+
+        #    [2, 4, 3, 1, 0, 6, 5]
+        # -> [[2, 4, 3], [1, 0, 6], [5]]  (batch_size = 3)
+        indexes = [indexes[i : i + self.batch_size] for i in range(0, len(indexes), self.batch_size)]
+
+        #    [[2, 4, 3], [1, 0, 6], [5]]
+        # -> [[2, 4, 3], [1, 0, 6]]  (drop the last chunk if it is incomplete)
+        indexes = [chunk for chunk in indexes if len(chunk) == self.batch_size]
+
+        #    [[2, 4, 3], [1, 0, 6]]
+        # -> [[2, 4, 3], [2, 4, 3], [1, 0, 6], [1, 0, 6]]  (repeat_count = 2)
+        indexes = [chunk for chunk in indexes for _ in range(self.repeat_count)]
+
+        #    [[2, 4, 3], [2, 4, 3], [1, 0, 6], [1, 0, 6]]
+        # -> [[2, 2, 4, 4, 3, 3], [2, 2, 4, 4, 3, 3], [1, 1, 0, 0, 6, 6], [1, 1, 0, 0, 6, 6]]  (mini_repeat_count = 2)
+        indexes = [[index for index in chunk for _ in range(self.mini_repeat_count)] for chunk in indexes]
+
+        #    [[2, 2, 4, 4, 3, 3], [2, 2, 4, 4, 3, 3], [1, 1, 0, 0, 6, 6], [1, 1, 0, 0, 6, 6]]
+        # -> [2, 2, 4, 4, 3, 3, 2, 2, 4, 4, 3, 3, 1, 1, 0, 0, 6, 6, 1, 1, 0, 0, 6, 6]
+        indexes = sum(indexes, [])
+
         return iter(indexes)
 
-    def __len__(self):
-        return self.num_samples * self.repeat_count
+    def __len__(self) -> int:
+        return self.num_samples * self.mini_repeat_count * self.repeat_count
 
 
 class GRPOTrainer(Trainer):
@@ -313,7 +356,9 @@ def data_collator(features):  # No data collation is needed in GRPO
         self.max_completion_length = args.max_completion_length  # = |o_i| in the GRPO paper
         self.num_generations = args.num_generations  # = G in the GRPO paper
         self.use_vllm = args.use_vllm
-
+        self.num_updates = args.num_updates
+        self._buffered_inputs = []
+        self.epsilon = args.epsilon
         self.beta = args.beta
 
         # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
@@ -476,14 +521,26 @@ def _get_train_sampler(self) -> Sampler:
         # identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
         # within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
         # preventing discrepancies in group formation.
-        return RepeatRandomSampler(self.train_dataset, self.num_generations, seed=self.args.seed)
+        return RepeatRandomSampler(
+            data_source=self.train_dataset,
+            mini_repeat_count=self.num_generations,
+            batch_size=self.args.per_device_train_batch_size * self.accelerator.num_processes,
+            repeat_count=self.num_updates,
+            seed=self.args.seed,
+        )
 
     def _get_eval_sampler(self, eval_dataset) -> Sampler:
         # Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
         # identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
         # within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
         # preventing discrepancies in group formation.
-        return RepeatRandomSampler(eval_dataset, self.num_generations, seed=self.args.seed)
+        return RepeatRandomSampler(
+            data_source=eval_dataset,
+            mini_repeat_count=self.num_generations,
+            batch_size=self.args.per_device_eval_batch_size * self.accelerator.num_processes,
+            repeat_count=self.num_updates,
+            seed=self.args.seed,
+        )
 
     # Get the per-token log probabilities for the completions for the model and the reference model
     def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
@@ -529,6 +586,16 @@ def _move_model_to_vllm(self):
                 unwrapped_model.unmerge_adapter()
 
     def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
+        # Generate completions only every `num_updates` steps; buffer them so the intermediate steps
+        # can reuse the same generations for additional policy updates.
+        if self.state.global_step % self.num_updates == 0:
+            inputs = self._generate_and_score_completions(inputs)
+            self._buffered_inputs.append(inputs)
+        else:
+            inputs = self._buffered_inputs.pop(0)
+        return inputs
+
+    def _generate_and_score_completions(
+        self, inputs: dict[str, Union[torch.Tensor, Any]]
+    ) -> dict[str, Union[torch.Tensor, Any]]:
         device = self.accelerator.device
         prompts = [x["prompt"] for x in inputs]
         prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
@@ -729,6 +796,13 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
         # x - x.detach() allows for preserving gradients from x
         advantages = inputs["advantages"]
-        per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
-        per_token_loss = -(per_token_loss - self.beta * per_token_kl)
+        per_token_loss1 = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
+        per_token_loss2 = (
+            torch.clamp(torch.exp(per_token_logps - per_token_logps.detach()), 1 - self.epsilon, 1 + self.epsilon)
+            * advantages.unsqueeze(1)
+        )
+        # Elementwise minimum of the unclipped and clipped terms (PPO-style pessimistic bound)
+        per_token_loss_min = torch.min(per_token_loss1, per_token_loss2)
+        per_token_loss = -(per_token_loss_min - self.beta * per_token_kl)
         loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
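
The sampler change and the `_prepare_inputs` buffering work together: each batch of prompts is drawn `num_updates` times in a row, completions are generated only on the first pass, and the buffered batch is reused for the remaining passes. The following standalone sketch illustrates that generate-once, reuse-`num_updates`-times pattern; `generate_batch` and `update_policy` are hypothetical stand-ins, not part of the trainer's API.

```python
# Minimal, self-contained sketch of the buffering pattern behind `num_updates`:
# generate completions once, then reuse them for `num_updates` optimizer steps.

def generate_batch(step):
    """Stand-in for the expensive generation + reward scoring."""
    return {"generated_at": step, "prompts": ["p0", "p1"]}


def update_policy(batch, step):
    """Stand-in for one optimizer step on a (possibly buffered) batch."""
    print(f"step {step}: update on batch generated at step {batch['generated_at']}")


num_updates = 2  # mirrors GRPOConfig.num_updates
buffer = []

for step in range(6):
    if step % num_updates == 0:
        batch = generate_batch(step)  # fresh completions on the first of every `num_updates` steps
        buffer.append(batch)
    else:
        batch = buffer.pop(0)  # reuse the buffered completions for the remaining steps
    update_policy(batch, step)
```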
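
For `epsilon`, the intent is the PPO-style clipped surrogate objective: the per-token probability ratio is bounded to `[1 - epsilon, 1 + epsilon]` and the pessimistic (elementwise minimum) term is kept. The sketch below assumes the old per-token log-probs are recorded at generation time so the ratio can actually move away from 1; the patch itself uses `per_token_logps.detach()`, which keeps the ratio at 1 in value while preserving gradients. All names here are illustrative, not the trainer's API.

```python
import torch


def clipped_per_token_loss(per_token_logps, old_per_token_logps, advantages, per_token_kl, epsilon=0.2, beta=0.04):
    """Per-token clipped surrogate loss with a KL penalty (tensor shapes: [batch, completion_len])."""
    # Probability ratio pi_theta / pi_old for each completion token
    ratio = torch.exp(per_token_logps - old_per_token_logps)
    per_token_loss1 = ratio * advantages.unsqueeze(1)                                          # unclipped term
    per_token_loss2 = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantages.unsqueeze(1)   # clipped term
    # Elementwise minimum (pessimistic bound), KL penalty, negated to give a loss to minimize
    return -(torch.min(per_token_loss1, per_token_loss2) - beta * per_token_kl)


# Toy usage: 2 prompts, 4 completion tokens each
logps = torch.randn(2, 4)
old_logps = logps + 0.1 * torch.randn(2, 4)
advantages = torch.tensor([0.5, -0.5])
kl = torch.zeros(2, 4)
print(clipped_per_token_loss(logps, old_logps, advantages, kl).shape)  # torch.Size([2, 4])
```

With `old_per_token_logps` equal to `per_token_logps.detach()`, the ratio is exactly 1 and the clipped and unclipped terms coincide; the clipping only starts to matter once the generation-time log-probs are carried along with the buffered batch and the policy has been updated in between.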