📉 Use num_logits_to_keep to reduce memory usage in GRPO (#2683)
* use num_logits to keep

* add comment back

* Update trl/trainer/grpo_trainer.py
qgallouedec authored Jan 29, 2025
1 parent ed14ed9 commit 801582e
Showing 1 changed file with 9 additions and 10 deletions.
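
Background for the change: recent transformers causal LMs accept a `num_logits_to_keep` forward argument (renamed `logits_to_keep` in later releases) that materializes the lm_head projection only for the last k positions, so the returned logits have shape (B, k, V) instead of (B, L, V). The sketch below is not part of the commit; it only illustrates the effect, and the checkpoint name is an arbitrary small model chosen for the example.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Arbitrary small checkpoint for illustration; any causal LM whose forward accepts
# `num_logits_to_keep` (recent transformers versions, later renamed `logits_to_keep`) behaves the same.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")

input_ids = tokenizer("The quick brown fox jumps", return_tensors="pt").input_ids  # (1, L)

full = model(input_ids).logits                            # (1, L, V): vocab logits for every position
k = 3
trimmed = model(input_ids, num_logits_to_keep=k).logits   # (1, k, V): only the last k positions

print(full.shape, trimmed.shape)
# Same values as the last k rows of `full`, but a much smaller tensor when L >> k.
torch.testing.assert_close(trimmed, full[:, -k:])
```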
trl/trainer/grpo_trainer.py (19 changes: 9 additions & 10 deletions)
@@ -427,29 +427,28 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
         completion_ids = prompt_completion_ids[:, prompt_length:]

         # Get the per-token log probabilities for the completions for the model and the reference model
-        def get_per_token_logps(model, input_ids):
-            logits = model(input_ids).logits  # (B, L, V)
+        def get_per_token_logps(model, input_ids, num_logits_to_keep):
+            # We add 1 to `num_logits_to_keep` because the last logits of the sequence is later excluded
+            logits = model(input_ids, num_logits_to_keep=num_logits_to_keep + 1).logits  # (B, L, V)
             logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
-            input_ids = input_ids[:, 1:]  # (B, L-1), exclude the first input ID since we don't have logits for it

             # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
             per_token_logps = []
-            for logits_row, input_ids_row in zip(logits, input_ids):
+            for logits_row, input_ids_row in zip(logits, input_ids[:, -num_logits_to_keep:]):
                 log_probs = logits_row.log_softmax(dim=-1)
                 token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
                 per_token_logps.append(token_log_prob)
             return torch.stack(per_token_logps)

-        per_token_logps = get_per_token_logps(model, prompt_completion_ids)
-        # Get rid of the prompt (-1 because of the shift done in get_per_token_logps)
-        per_token_logps = per_token_logps[:, prompt_length - 1 :]
+        num_logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
+        per_token_logps = get_per_token_logps(model, prompt_completion_ids, num_logits_to_keep)

         with torch.inference_mode():
             if self.ref_model is not None:
-                ref_per_token_logps = get_per_token_logps(self.ref_model, prompt_completion_ids)
+                ref_per_token_logps = get_per_token_logps(self.ref_model, prompt_completion_ids, num_logits_to_keep)
             else:
                 with self.accelerator.unwrap_model(model).disable_adapter():
-                    ref_per_token_logps = get_per_token_logps(model, prompt_completion_ids)
-            ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1 :]
+                    ref_per_token_logps = get_per_token_logps(model, prompt_completion_ids, num_logits_to_keep)

         # Compute the KL divergence between the model and the reference model
         per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
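
To follow the indexing in the new `get_per_token_logps` outside the diff context, here is a self-contained sketch in which a random tensor stands in for the model's trimmed logits (shapes and toy numbers are illustrative, not TRL API): the logit at position i scores token i + 1, so the trainer requests `num_logits_to_keep + 1` rows, drops the last row, and gathers against the last `num_logits_to_keep` input IDs, which are exactly the completion tokens.

```python
import torch

B, P, C, V = 2, 4, 3, 11  # batch size, prompt length, completion length, vocab size (toy values)
prompt_completion_ids = torch.randint(0, V, (B, P + C))
num_logits_to_keep = C  # logits are only needed for the completion tokens

# Stand-in for `model(input_ids, num_logits_to_keep=num_logits_to_keep + 1).logits`:
# called this way, a causal LM returns the logits of the last C + 1 positions, shape (B, C + 1, V).
logits = torch.randn(B, num_logits_to_keep + 1, V)

# The logit at position i predicts token i + 1, so the last row predicts a token beyond the
# sequence and is dropped; the remaining C rows predict exactly the C completion tokens.
logits = logits[:, :-1, :]                                   # (B, C, V)
target_ids = prompt_completion_ids[:, -num_logits_to_keep:]  # (B, C), the completion tokens

per_token_logps = []
for logits_row, ids_row in zip(logits, target_ids):  # loop over the batch to limit peak memory
    log_probs = logits_row.log_softmax(dim=-1)  # (C, V)
    per_token_logps.append(torch.gather(log_probs, dim=1, index=ids_row.unsqueeze(1)).squeeze(1))
per_token_logps = torch.stack(per_token_logps)  # (B, C)
print(per_token_logps.shape)  # torch.Size([2, 3])
```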
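The last context line of the hunk is unchanged by this commit, but for completeness: it is the k3 estimator of the per-token KL divergence KL(pi_theta || pi_ref), exp(d) - d - 1 with d = log pi_ref - log pi_theta, which is non-negative for every token and matches the exact KL in expectation when tokens are sampled from pi_theta, as the completions here are. A quick numerical check on a toy distribution (illustrative only, not trainer code):

```python
import torch

# Two toy categorical distributions over a 3-token vocabulary (illustrative numbers only).
logp = torch.log(torch.tensor([0.7, 0.2, 0.1]))      # current policy pi_theta
ref_logp = torch.log(torch.tensor([0.5, 0.3, 0.2]))  # reference policy pi_ref

exact_kl = torch.sum(logp.exp() * (logp - ref_logp))  # KL(pi_theta || pi_ref)

# k3 estimator, same form as the trainer line: exp(d) - d - 1 with d = ref_logp - logp,
# averaged over tokens sampled from pi_theta.
samples = torch.multinomial(logp.exp(), num_samples=100_000, replacement=True)
d = ref_logp[samples] - logp[samples]
k3 = torch.exp(d) - d - 1
print(exact_kl.item(), k3.mean().item())  # the two values should be close
```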