diff --git a/examples/datasets/tokenize_ds.py b/examples/datasets/tokenize_ds.py
new file mode 100644
index 0000000000..39c9f2814a
--- /dev/null
+++ b/examples/datasets/tokenize_ds.py
@@ -0,0 +1,54 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass, field
+from typing import Optional
+
+from datasets import load_dataset
+from transformers import AutoTokenizer, HfArgumentParser
+
+from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
+
+
+"""
+python -i examples/datasets/tokenize_ds.py --model HuggingFaceH4/zephyr-7b-beta
+python -i examples/datasets/tokenize_ds.py --model gpt2
+"""
+
+
+@dataclass
+class ScriptArguments:
+    dataset_name: str = field(
+        default="trl-internal-testing/hh-rlhf-helpful-base-trl-style", metadata={"help": "The dataset to load"}
+    )
+    model: str = field(default="gpt2", metadata={"help": "The model to use for tokenization"})
+    dataset_num_proc: Optional[int] = field(
+        default=None, metadata={"help": "The number of workers to use to tokenize the data"}
+    )
+
+
+if __name__ == "__main__":
+    script_args = HfArgumentParser(ScriptArguments).parse_args_into_dataclasses()[0]
+    dataset = load_dataset(script_args.dataset_name)
+    tokenizer = AutoTokenizer.from_pretrained(script_args.model)
+    if tokenizer.chat_template is None:
+        tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
+
+    def process(row):
+        row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
+        row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
+        return row
+
+    dataset = dataset.map(process, num_proc=script_args.dataset_num_proc)
+    print(dataset["train"][0]["chosen"])
diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py
index c30813a4b1..717b0318d4 100644
--- a/trl/trainer/ppo_trainer.py
+++ b/trl/trainer/ppo_trainer.py
@@ -337,13 +337,20 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa
         if self.is_deepspeed_enabled:
             self.deepspeed = backup_deepspeed
 
+    def _compute_score(self, query_response: torch.Tensor, context_length: int) -> torch.Tensor:
+        """
+        This method decouples the score computation from the training loop.
+        Override it to implement your custom reward function.
+        """
+        _, score, _ = get_reward(self.reward_model, query_response, self.processing_class.pad_token_id, context_length)
+        return score
+
     def train(self):
         args = self.args
         accelerator = self.accelerator
         optimizer = self.optimizer
         model = self.model
         ref_policy = self.ref_model
-        reward_model = self.reward_model
         processing_class = self.processing_class
         dataloader = self.dataloader
         device = accelerator.device
@@ -460,9 +467,7 @@ def repeat_generator():
                         unwrapped_value_model, query_response, processing_class.pad_token_id, context_length
                     )
                     value = full_value[:, context_length - 1 : -1].squeeze(-1)
-                    _, score, _ = get_reward(
-                        reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
-                    )
+                    score = self._compute_score(postprocessed_query_response, context_length)
 
                     responses.append(response)
                     postprocessed_responses.append(postprocessed_response)
@@ -714,9 +719,7 @@ def generate_completions(self, sampling: bool = False):
                 )
                 postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
-                _, score, _ = get_reward(
-                    self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
-                )
+                score = self._compute_score(postprocessed_query_response, context_length)
                 table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy())
 
             if sampling: