Remove CUDA synchronization in mean_token_accuracy #2902

Open · wants to merge 1 commit into main
22 changes: 16 additions & 6 deletions trl/trainer/sft_trainer.py
@@ -478,17 +478,27 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
             correct_tokens = correct_predictions.sum()

             # Gather the correct_tokens and total_tokens across all processes
-            correct_tokens = self.accelerator.gather_for_metrics(correct_tokens)
-            total_tokens = self.accelerator.gather_for_metrics(total_tokens)
+            correct_token_sum = self.accelerator.gather_for_metrics(correct_tokens).detach().sum()
+            total_token_sum = self.accelerator.gather_for_metrics(total_tokens).detach().sum()

-            # Compute the mean token accuracy and log it
-            accuracy = (correct_tokens.sum() / total_tokens.sum()).item() if total_tokens.sum() > 0 else 0.0
-            self._metrics["mean_token_accuracy"].append(accuracy)
+            # Record data for the mean token accuracy and logging
+            if not self._metrics["mean_token_accuracy"]:
+                self._metrics["mean_token_accuracy"] = [correct_token_sum, total_token_sum]
+            else:
+                self._metrics["mean_token_accuracy"][0] += correct_token_sum
+                self._metrics["mean_token_accuracy"][1] += total_token_sum

         return (loss, outputs) if return_outputs else loss

     def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
-        metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()}  # average the metrics
+        # average the metrics
+        metrics = {}
+        for key, val in self._metrics.items():
+            if key == "mean_token_accuracy":
+                total_token_sum = val[1].item()
+                metrics[key] = (val[0].item() / total_token_sum) if total_token_sum > 0 else 0.0
+            else:
+                metrics[key] = sum(val) / len(val)

         # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
         # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
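For readers outside the diff context, here is a minimal sketch of the accumulate-on-device pattern this change adopts: per-step statistics stay as (possibly CUDA) tensors, and .item() — which forces a host/device synchronization for CUDA tensors — is only called once, at logging time. The TokenAccuracyMeter class and its update/compute methods are hypothetical names used for illustration; they are not part of TRL or of this PR.

import torch


class TokenAccuracyMeter:
    """Running token-accuracy statistics kept on the device between steps."""

    def __init__(self, device="cpu"):
        # Scalar tensors; on a CUDA device these accumulate without blocking the host.
        self.correct = torch.zeros((), dtype=torch.long, device=device)
        self.total = torch.zeros((), dtype=torch.long, device=device)

    def update(self, predictions: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> None:
        # Pure tensor ops: nothing here calls .item(), so no host/device sync per step.
        mask = labels != ignore_index
        self.correct += ((predictions == labels) & mask).sum()
        self.total += mask.sum()

    def compute(self) -> float:
        # The only synchronization point, deferred to logging time.
        total = self.total.item()
        return self.correct.item() / total if total > 0 else 0.0

A trainer would call update(predictions, shift_labels) on every step and compute() only when logging, which mirrors how the diff stores the two gathered sums in self._metrics["mean_token_accuracy"] and divides them in log().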