Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add _compute_score method to PPOTrainer #2560

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions examples/datasets/tokenize_ds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Optional

from datasets import load_dataset
from transformers import AutoTokenizer, HfArgumentParser

from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE


"""
python -i examples/datasets/tokenize_ds.py --model HuggingFaceH4/zephyr-7b-beta
python -i examples/datasets/tokenize_ds.py --model gpt2
"""


@dataclass
class ScriptArguments:
dataset_name: str = field(
default="trl-internal-testing/hh-rlhf-helpful-base-trl-style", metadata={"help": "The dataset to load"}
)
model: str = field(default="gpt2", metadata={"help": "The model to use for tokenization"})
dataset_num_proc: Optional[int] = field(
default=None, metadata={"help": "The number of workers to use to tokenize the data"}
)


if __name__ == "__main__":
script_args = HfArgumentParser(ScriptArguments).parse_args_into_dataclasses()[0]
dataset = load_dataset(script_args.dataset_name)
tokenizer = AutoTokenizer.from_pretrained(script_args.model)
if tokenizer.chat_template is None:
tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE

def process(row):
row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
return row

dataset = dataset.map(process, num_proc=script_args.dataset_num_proc)
print(dataset["train"][0]["chosen"])
17 changes: 10 additions & 7 deletions trl/trainer/ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,13 +337,20 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa
if self.is_deepspeed_enabled:
self.deepspeed = backup_deepspeed

def _compute_score(self, query_reponse: torch.Tensor, context_length: int) -> torch.Tensor:
"""
This methods decoples the score computing from the training method.
Override it to implement your custom reward function.
"""
_, score, _ = get_reward(self.reward_model, query_reponse, self.processing_class.pad_token_id, context_length)
return score

def train(self):
args = self.args
accelerator = self.accelerator
optimizer = self.optimizer
model = self.model
ref_policy = self.ref_model
reward_model = self.reward_model
processing_class = self.processing_class
dataloader = self.dataloader
device = accelerator.device
Expand Down Expand Up @@ -460,9 +467,7 @@ def repeat_generator():
unwrapped_value_model, query_response, processing_class.pad_token_id, context_length
)
value = full_value[:, context_length - 1 : -1].squeeze(-1)
_, score, _ = get_reward(
reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
)
score = self._compute_score(postprocessed_query_response, context_length)

responses.append(response)
postprocessed_responses.append(postprocessed_response)
Expand Down Expand Up @@ -714,9 +719,7 @@ def generate_completions(self, sampling: bool = False):
)

postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
_, score, _ = get_reward(
self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
)
score = self._compute_score(postprocessed_query_response, context_length)
table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy())

if sampling:
Expand Down