Skip to content

Commit

Permalink
fix sum dim
Browse files Browse the repository at this point in the history
  • Loading branch information
edbeeching committed Feb 17, 2025
1 parent 0e10950 commit de44135
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,7 +707,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
advantages = inputs["advantages"]
per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
per_token_loss = -(per_token_loss - self.beta * per_token_kl)
- loss = loss = (per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum()
+ loss = loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()

# Log the metrics
completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
Expand Down

0 comments on commit de44135

Please sign in to comment.