Fix CUDA sync point in mean_token_accuracy
cyyever committed Feb 19, 2025
1 parent 9b3c5bf commit f677ed3
Showing 1 changed file with 14 additions and 4 deletions.
trl/trainer/sft_trainer.py
@@ -481,14 +481,24 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
            correct_tokens = self.accelerator.gather_for_metrics(correct_tokens)
            total_tokens = self.accelerator.gather_for_metrics(total_tokens)

-           # Compute the mean token accuracy and log it
-           accuracy = (correct_tokens.sum() / total_tokens.sum()).item() if total_tokens.sum() > 0 else 0.0
-           self._metrics["mean_token_accuracy"].append(accuracy)
+           # Record data for the mean token accuracy and logging
+           if not self._metrics["mean_token_accuracy"]:
+               self._metrics["mean_token_accuracy"].append([correct_tokens.sum(), total_tokens.sum()])
+           else:
+               self._metrics["mean_token_accuracy"][0][0] += correct_tokens.sum()
+               self._metrics["mean_token_accuracy"][0][1] += total_tokens.sum()

        return (loss, outputs) if return_outputs else loss

    def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
-       metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()}  # average the metrics
+       # average the metrics
+       metrics = {}
+       for key, val in self._metrics.items():
+           if key == "mean_token_accuracy":
+               total_tokens = val[0][1].item()
+               metrics[key] = (val[0][0].item() / total_tokens) if total_tokens > 0 else 0.0
+           else:
+               metrics[key] = sum(val) / len(val)

        # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
        # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
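
Why the change removes a CUDA sync point: calling .item() on a CUDA tensor copies a scalar to the host and forces the CPU to wait for all queued GPU work, and the old code paid that cost on every compute_loss call. The new code accumulates the counts as tensors, an asynchronous device-side operation, and defers the .item() calls to log, so the pipeline stalls at most once per logging interval. Below is a minimal sketch of the same pattern outside the trainer; the TokenAccuracy class and its method names are hypothetical, not TRL API, and a CUDA device is assumed.

import torch

class TokenAccuracy:
    """Running token-accuracy counts kept on-device; syncs once at log time."""

    def __init__(self, device: str = "cuda") -> None:
        self.correct = torch.zeros((), dtype=torch.float64, device=device)
        self.total = torch.zeros((), dtype=torch.float64, device=device)

    def update(self, correct_tokens: torch.Tensor, total_tokens: torch.Tensor) -> None:
        # Pure tensor arithmetic: queued on the CUDA stream, so the CPU
        # does not block here (mirrors the new branch in compute_loss).
        self.correct += correct_tokens.sum()
        self.total += total_tokens.sum()

    def compute(self) -> float:
        # The only host-device sync: .item() at logging time (mirrors log()).
        total = self.total.item()
        return self.correct.item() / total if total > 0 else 0.0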

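One way to confirm the hot path no longer synchronizes is PyTorch's CUDA sync debug mode, which reports every operation that makes the CPU wait on the GPU. A sketch, assuming a CUDA build of PyTorch:

import torch

torch.cuda.set_sync_debug_mode("warn")  # report implicit host-device syncs

acc = torch.zeros((), device="cuda")
for _ in range(10):
    acc += torch.ones((), device="cuda")  # stays on-device, no warning
    # acc.item() here would trigger a sync warning on every iteration

print(acc.item())  # one expected sync, at logging time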