diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index e11e5225ea..78cd149646 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -707,7 +707,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
         advantages = inputs["advantages"]
         per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
         per_token_loss = -(per_token_loss - self.beta * per_token_kl)
-        loss = (per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum()
+        loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()

         # Log the metrics
         completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
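
For context, a minimal standalone sketch of what this reduction change does to the shape of `loss` (the tensor values below are made up for illustration and are not from the trainer):

import torch

# Hypothetical batch of 2 completions, 4 token positions each.
per_token_loss = torch.tensor([[0.5, 0.2, 0.1, 0.0],
                               [0.3, 0.4, 0.0, 0.0]])
completion_mask = torch.tensor([[1, 1, 1, 0],
                                [1, 1, 0, 0]], dtype=torch.float)

# Old line: per-sequence sums divided by the *global* token count.
# This yields a length-2 vector, not a scalar, so it cannot be passed
# directly to .backward() as a loss.
old = (per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum()
print(old)  # tensor([0.1600, 0.1400]), shape (2,)

# New line: global sum divided by the global token count, i.e. the mean
# per-token loss over all unmasked tokens, reduced to a single scalar.
new = (per_token_loss * completion_mask).sum() / completion_mask.sum()
print(new)  # tensor(0.3000)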