Add num_updates and epsilon parameters to GRPOConfig and GRPOTrainer
qgallouedec committed Feb 19, 2025
1 parent be1e340 commit c076794
Showing 2 changed files with 103 additions and 17 deletions.
12 changes: 12 additions & 0 deletions trl/trainer/grpo_config.py
@@ -89,6 +89,10 @@ class GRPOConfig(TrainingArguments):
             [`~transformers.TrainingArguments`].
         beta (`float`, *optional*, defaults to `0.04`):
             KL coefficient.
+        num_updates (`int`, *optional*, defaults to `1`):
+            Number of policy updates per batch of generations.
+        epsilon (`float`, *optional*, defaults to `0.2`):
+            Epsilon value for clipping the probability ratio.
         reward_weights (`list[float]` or `None`, *optional*, defaults to `None`):
             Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are
             weighted equally with weight `1.0`.
@@ -220,6 +224,14 @@ class GRPOConfig(TrainingArguments):
         default=0.04,
         metadata={"help": "KL coefficient."},
     )
+    num_updates: int = field(
+        default=1,
+        metadata={"help": "Number of policy updates per batch of generations."},
+    )
+    epsilon: float = field(
+        default=0.2,
+        metadata={"help": "Epsilon value for clipping the probability ratio."},
+    )
     reward_weights: Optional[list[float]] = field(
         default=None,
         metadata={
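For orientation, here is a hypothetical way to set the two new fields (the `GRPOConfig` import and parameter names are from the diff; `output_dir` and the values are made up for illustration):

```python
from trl import GRPOConfig

# Hypothetical values, for illustration only.
config = GRPOConfig(
    output_dir="grpo-demo",
    num_updates=4,  # reuse each generated batch for 4 policy updates
    epsilon=0.2,    # clipping range for the probability ratio
)
```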
108 changes: 91 additions & 17 deletions trl/trainer/grpo_trainer.py
@@ -67,26 +67,51 @@

 class RepeatRandomSampler(Sampler):
     """
-    Sampler that repeats the indices of a dataset N times.
+    Sampler that repeats the indices of a dataset in a structured manner.
 
     Args:
         data_source (`Sized`):
             Dataset to sample from.
-        repeat_count (`int`):
-            Number of times to repeat each index.
-        seed (`Optional[int]`):
+        mini_repeat_count (`int`):
+            Number of times to repeat each index per batch.
+        batch_size (`int`, *optional*, defaults to `1`):
+            Number of unique indices per batch.
+        repeat_count (`int`, *optional*, defaults to `1`):
+            Number of times to repeat the full sampling process.
+        seed (`int` or `None`, *optional*, defaults to `None`):
             Random seed for reproducibility (only affects this sampler).
 
     Example:
     ```python
-    >>> sampler = RepeatRandomSampler(["a", "b", "c", "d"], repeat_count=2)
+    >>> sampler = RepeatRandomSampler(["a", "b", "c", "d", "e", "f", "g"], mini_repeat_count=2, batch_size=3, repeat_count=4)
     >>> list(sampler)
-    [2, 2, 0, 0, 3, 3, 1, 1]
+    [4, 4, 3, 3, 0, 0, 4, 4, 3, 3, 0, 0, 4, 4, 3, 3, 0, 0, 4, 4, 3, 3, 0, 0,
+     1, 1, 2, 2, 6, 6, 1, 1, 2, 2, 6, 6, 1, 1, 2, 2, 6, 6, 1, 1, 2, 2, 6, 6]
     ```
+
+    ```txt
+    mini_repeat_count
+     -  -
+    [4, 4, 3, 3, 0, 0, |
+     4, 4, 3, 3, 0, 0, |
+     4, 4, 3, 3, 0, 0, |  repeat_count
+     4, 4, 3, 3, 0, 0] |
+     ----  ----  ----
+        batch_size
+    ```
     """
 
-    def __init__(self, data_source: Sized, repeat_count: int, seed: Optional[int] = None):
+    def __init__(
+        self,
+        data_source: Sized,
+        mini_repeat_count: int,
+        batch_size: int = 1,
+        repeat_count: int = 1,
+        seed: Optional[int] = None,
+    ):
         self.data_source = data_source
+        self.mini_repeat_count = mini_repeat_count
+        self.batch_size = batch_size
         self.repeat_count = repeat_count
         self.num_samples = len(data_source)
         self.seed = seed
@@ -95,15 +120,33 @@ def __init__(self, data_source: Sized, repeat_count: int, seed: Optional[int] = None):
             self.generator.manual_seed(seed)
 
     def __iter__(self):
-        indexes = [
-            idx
-            for idx in torch.randperm(self.num_samples, generator=self.generator).tolist()
-            for _ in range(self.repeat_count)
-        ]
+        # E.g., [2, 4, 3, 1, 0, 6, 5]  (num_samples = 7)
+        indexes = torch.randperm(self.num_samples, generator=self.generator).tolist()
+
+        #    [2, 4, 3, 1, 0, 6, 5]
+        # -> [[2, 4, 3], [1, 0, 6], [5]]  (batch_size = 3)
+        indexes = [indexes[i : i + self.batch_size] for i in range(0, len(indexes), self.batch_size)]
+
+        #    [[2, 4, 3], [1, 0, 6], [5]]
+        # -> [[2, 4, 3], [1, 0, 6]]  (drop the incomplete last batch)
+        indexes = [chunk for chunk in indexes if len(chunk) == self.batch_size]
+
+        #    [[2, 4, 3], [1, 0, 6]]
+        # -> [[2, 4, 3], [2, 4, 3], [1, 0, 6], [1, 0, 6]]  (repeat_count = 2)
+        indexes = [chunk for chunk in indexes for _ in range(self.repeat_count)]
+
+        #    [[2, 4, 3], [2, 4, 3], [1, 0, 6], [1, 0, 6]]
+        # -> [[2, 2, 4, 4, 3, 3], [2, 2, 4, 4, 3, 3], [1, 1, 0, 0, 6, 6], [1, 1, 0, 0, 6, 6]]  (mini_repeat_count = 2)
+        indexes = [[index for index in chunk for _ in range(self.mini_repeat_count)] for chunk in indexes]
+
+        #    [[2, 2, 4, 4, 3, 3], [2, 2, 4, 4, 3, 3], [1, 1, 0, 0, 6, 6], [1, 1, 0, 0, 6, 6]]
+        # -> [2, 2, 4, 4, 3, 3, 2, 2, 4, 4, 3, 3, 1, 1, 0, 0, 6, 6, 1, 1, 0, 0, 6, 6]
+        indexes = sum(indexes, [])
 
         return iter(indexes)
 
-    def __len__(self):
-        return self.num_samples * self.repeat_count
+    def __len__(self) -> int:
+        return self.num_samples * self.mini_repeat_count * self.repeat_count
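A minimal sanity check of the ordering documented in the docstring (a sketch assuming `RepeatRandomSampler` above is in scope; the seed is arbitrary, so the exact permutation differs from the docstring example):

```python
sampler = RepeatRandomSampler(
    ["a", "b", "c", "d", "e", "f", "g"],
    mini_repeat_count=2,
    batch_size=3,
    repeat_count=4,
    seed=42,
)
order = list(sampler)

# 7 samples -> 2 full batches of 3 unique indices (the leftover index is
# dropped), each batch replayed 4 times with every index repeated twice:
assert len(order) == 2 * 4 * 3 * 2  # 48 indices in total
assert order[0] == order[1] and order[2] == order[3]  # mini_repeat_count pairs
assert order[:6] == order[6:12]  # the same batch is replayed back to back
```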


class GRPOTrainer(Trainer):
@@ -313,7 +356,9 @@ def data_collator(features):  # No data collation is needed in GRPO
         self.max_completion_length = args.max_completion_length  # = |o_i| in the GRPO paper
         self.num_generations = args.num_generations  # = G in the GRPO paper
         self.use_vllm = args.use_vllm
 
+        self.num_updates = args.num_updates  # = μ in the GRPO paper
+        self._buffered_inputs = None  # last generated batch, reused for num_updates - 1 extra steps
+        self.epsilon = args.epsilon
         self.beta = args.beta
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
@@ -476,14 +521,26 @@ def _get_train_sampler(self) -> Sampler:
# identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
# within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
# preventing discrepancies in group formation.
-        return RepeatRandomSampler(self.train_dataset, self.num_generations, seed=self.args.seed)
+        return RepeatRandomSampler(
+            data_source=self.train_dataset,
+            mini_repeat_count=self.num_generations,
+            batch_size=self.args.per_device_train_batch_size * self.accelerator.num_processes,
+            repeat_count=self.num_updates,
+            seed=self.args.seed,
+        )
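To make the arithmetic concrete, a small sketch with illustrative numbers (none of these values come from the diff):

```python
# Illustrative configuration, not from the diff.
per_device_train_batch_size = 4
num_processes = 2
num_generations = 8  # G: completions sampled per prompt
num_updates = 2      # policy updates amortized over each generation

# One global step consumes this many samples across all GPUs:
global_batch = per_device_train_batch_size * num_processes  # 8

# The sampler emits each prompt index num_generations times in a row, so a
# global batch covers global_batch // num_generations = 1 unique prompt here:
# all completions of a prompt land in the same step, which is what lets the
# per-group reward mean/std be computed without cross-batch bookkeeping.
assert global_batch % num_generations == 0
```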

def _get_eval_sampler(self, eval_dataset) -> Sampler:
# Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
# identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
# within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
# preventing discrepancies in group formation.
-        return RepeatRandomSampler(eval_dataset, self.num_generations, seed=self.args.seed)
+        return RepeatRandomSampler(
+            data_source=eval_dataset,
+            mini_repeat_count=self.num_generations,
+            batch_size=self.args.per_device_eval_batch_size * self.accelerator.num_processes,
+            repeat_count=self.num_updates,
+            seed=self.args.seed,
+        )

# Get the per-token log probabilities for the completions for the model and the reference model
def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
@@ -529,6 +586,16 @@ def _move_model_to_vllm(self):
unwrapped_model.unmerge_adapter()

     def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
+        # Generate and score completions every `num_updates` steps; on the
+        # intermediate steps, reuse the buffered batch so that one generation
+        # pays for several policy updates.
+        if self.state.global_step % self.num_updates == 0:
+            inputs = self._generate_and_score_completions(inputs)
+            self._buffered_inputs = inputs
+        else:
+            inputs = self._buffered_inputs
+        return inputs
+
+    def _generate_and_score_completions(
+        self, inputs: dict[str, Union[torch.Tensor, Any]]
+    ) -> dict[str, Union[torch.Tensor, Any]]:
         device = self.accelerator.device
         prompts = [x["prompt"] for x in inputs]
         prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
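A toy trace of the generate-then-reuse schedule implemented by `_prepare_inputs` (a standalone sketch; the class, its step counter, and the `scored(...)` strings are invented stand-ins):

```python
class _ToyTrainer:
    """Invented stand-in that mimics only the buffering logic above."""

    def __init__(self, num_updates: int):
        self.num_updates = num_updates
        self.global_step = 0
        self._buffered_inputs = None

    def prepare(self, raw_batch: str) -> str:
        if self.global_step % self.num_updates == 0:
            # Stands in for generation + reward scoring, the expensive part.
            self._buffered_inputs = f"scored({raw_batch})"
        self.global_step += 1
        return self._buffered_inputs

toy = _ToyTrainer(num_updates=3)
print([toy.prepare(b) for b in ["a", "b", "c", "d", "e", "f"]])
# ['scored(a)', 'scored(a)', 'scored(a)', 'scored(d)', 'scored(d)', 'scored(d)']
```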
@@ -729,6 +796,13 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
         # x - x.detach() allows for preserving gradients from x
         advantages = inputs["advantages"]
-        per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
+        # PPO-style clipped surrogate: unclipped term vs. ratio-clipped term, take the element-wise minimum.
+        per_token_loss1 = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
+        per_token_loss2 = (
+            torch.clamp(torch.exp(per_token_logps - per_token_logps.detach()), 1 - self.epsilon, 1 + self.epsilon)
+            * advantages.unsqueeze(1)
+        )
+        # torch.min of two tensors is element-wise; it takes no `dim` argument.
+        per_token_loss = torch.min(per_token_loss1, per_token_loss2)
         per_token_loss = -(per_token_loss - self.beta * per_token_kl)
         loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
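One property of this formulation worth noting (a small standalone demonstration, not part of the diff): since the 'old' log-probabilities are a detached copy of the current ones, `exp(per_token_logps - per_token_logps.detach())` always evaluates to 1, so the clamp and the min leave the value unchanged here and only take effect if stale log-probabilities from a previous update are substituted for the detached copy; gradients still flow through the unclipped term.

```python
import torch

# exp(x - x.detach()) is identically 1 in value but keeps gradients w.r.t. x.
logps = torch.tensor([-1.2, -0.7], requires_grad=True)
ratio = torch.exp(logps - logps.detach())

epsilon = 0.2
clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)
assert torch.equal(ratio, clipped)  # the clamp is a no-op at ratio == 1

# Gradients still flow: d/dx exp(x - const) = exp(x - const) = 1 here.
ratio.sum().backward()
print(logps.grad)  # tensor([1., 1.])
```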

