Skip to content

Commit

Permalink
🪪 Adds profiling decorators for GRPOTrainer (#2889)
Browse files Browse the repository at this point in the history
* adds profiling decorator

* naming + precommit

* style

* revert inclusion of slider table

* revert 2

* revert3

* revert4

* revert 5 fml

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
  • Loading branch information
edbeeching and qgallouedec authored Feb 20, 2025
1 parent 9b3c5bf commit a92e00e
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 0 deletions.
41 changes: 41 additions & 0 deletions trl/extras/profiling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import time

from transformers import is_wandb_available


if is_wandb_available():
import wandb


def profiling_decorator(func):
"""
Decorator to profile a function and log the time taken to execute it.
"""

@functools.wraps(func)
def wrapper(self, *args, **kwargs):
start_time = time.perf_counter()
result = func(self, *args, **kwargs)
end_time = time.perf_counter()
duration = end_time - start_time

if "wandb" in self.args.report_to and wandb.run is not None and self.accelerator.is_main_process:
wandb.log({f"profiling/Time taken: {self.__class__.__name__}.{func.__name__}": duration})
return result

return wrapper
5 changes: 5 additions & 0 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from transformers.utils import is_peft_available

from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from ..extras.profiling import profiling_decorator
from ..import_utils import is_vllm_available
from ..models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
from .callbacks import SyncRefModelCallback
Expand Down Expand Up @@ -517,6 +518,7 @@ def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: GRPOConfi
return model

# Get the per-token log probabilities for the completions for the model and the reference model
@profiling_decorator
def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
# We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
Expand All @@ -528,6 +530,7 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep)
logits = logits[:, -logits_to_keep:]
return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens

@profiling_decorator
def _move_model_to_vllm(self):
with unwrap_model_for_generation(
self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
Expand Down Expand Up @@ -559,6 +562,7 @@ def _move_model_to_vllm(self):
if is_peft_model(unwrapped_model):
unwrapped_model.unmerge_adapter()

@profiling_decorator
def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
device = self.accelerator.device
prompts = [x["prompt"] for x in inputs]
Expand Down Expand Up @@ -742,6 +746,7 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
"advantages": advantages,
}

@profiling_decorator
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
if return_outputs:
raise ValueError("The GRPOTrainer does not support returning outputs")
Expand Down

0 comments on commit a92e00e

Please sign in to comment.