From a7b91ba282c7f2d5e277424827b169929ca0a530 Mon Sep 17 00:00:00 2001
From: August Moharrami
Date: Fri, 3 Jan 2025 08:57:16 +0000
Subject: [PATCH 1/3] custom reward function support for ppo trainer

---
 trl/trainer/ppo_trainer.py | 34 +++++++++++++------
 trl/trainer/utils.py       | 68 +++++++++++++++++++++++---------------
 2 files changed, 65 insertions(+), 37 deletions(-)

diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py
index 51897eeb44..cae35720bb 100644
--- a/trl/trainer/ppo_trainer.py
+++ b/trl/trainer/ppo_trainer.py
@@ -19,7 +19,7 @@
 import time
 from collections import defaultdict
 from contextlib import contextmanager, nullcontext
-from typing import Optional, Union
+from typing import Callable, Optional, Union
 
 import numpy as np
 import pandas as pd
@@ -114,7 +114,7 @@ def __init__(
         ],
         model: nn.Module,
         ref_model: Optional[nn.Module],
-        reward_model: nn.Module,
+        reward_model: Union[nn.Module, Callable],
         train_dataset: Dataset,
         value_model: Optional[nn.Module] = None,
         data_collator: Optional[DataCollatorWithPadding] = None,
@@ -218,7 +218,7 @@ def __init__(
         # setup model, optimizer, and others
         #########
         for module in [self.policy_model, self.ref_model, self.value_model, self.reward_model]:
-            if module is not None:
+            if isinstance(module, nn.Module):
                 disable_dropout_in_model(module)
         if args.stop_token and args.stop_token == "eos":
             args.stop_token_id = processing_class.eos_token_id
@@ -285,9 +285,10 @@ def __init__(
             self.eval_dataloader = accelerator.prepare(self.eval_dataloader)
 
         if self.is_deepspeed_enabled:
-            self.reward_model = prepare_deepspeed(
-                self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
-            )
+            if isinstance(self.reward_model, nn.Module):
+                self.reward_model = prepare_deepspeed(
+                    self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
+                )
 
             if self.ref_model is None:
                 if not self.is_peft_model:
@@ -302,7 +303,8 @@ def __init__(
                 raise ValueError("No reference model and model is not a Peft model.")
             else:
                 self.ref_model = self.ref_model.to(self.accelerator.device)
-            self.reward_model = self.reward_model.to(self.accelerator.device)
+            if isinstance(self.reward_model, nn.Module):
+                self.reward_model = self.reward_model.to(self.accelerator.device)
 
     def get_train_dataloader(self) -> DataLoader:
         return self.dataloader
@@ -457,11 +459,19 @@ def repeat_generator():
                 sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1
                 unwrapped_value_model = accelerator.unwrap_model(model).value_model
                 full_value, _, _ = get_reward(
-                    unwrapped_value_model, query_response, processing_class.pad_token_id, context_length
+                    unwrapped_value_model,
+                    processing_class,
+                    query_response,
+                    processing_class.pad_token_id,
+                    context_length,
                 )
                 value = full_value[:, context_length - 1 : -1].squeeze(-1)
                 _, score, _ = get_reward(
-                    reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
+                    reward_model,
+                    processing_class,
+                    postprocessed_query_response,
+                    processing_class.pad_token_id,
+                    context_length,
                 )
 
                 responses.append(response)
                 postprocessed_responses.append(postprocessed_response)
@@ -713,7 +723,11 @@ def generate_completions(self, sampling: bool = False):
 
                 postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
                 _, score, _ = get_reward(
-                    self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
+                    self.reward_model,
+                    processing_class,
+                    postprocessed_query_response,
+                    processing_class.pad_token_id,
+                    context_length,
                 )
                 table["score"].extend(self.accelerator.gather(score).float().cpu().numpy())
 
diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index ab9a06e469..360a883380 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -20,7 +20,7 @@
 from collections import deque
 from dataclasses import dataclass
 from importlib.metadata import version
-from typing import Any, Literal, Optional, Union
+from typing import Any, Callable, Literal, Optional, Union
 
 import datasets
 import numpy as np
@@ -1049,14 +1049,20 @@ def first_true_indices(bools: torch.Tensor, dtype=torch.long):
 
 
 def get_reward(
-    model: torch.nn.Module, query_responses: torch.Tensor, pad_token_id: int, context_length: int
+    model: Union[torch.nn.Module, Callable],
+    processor: PreTrainedTokenizerBase,
+    query_responses: torch.Tensor,
+    pad_token_id: int,
+    context_length: int,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
-    Computes the reward logits and the rewards for a given model and query responses.
+    Computes the reward logits and the rewards for a given model/function and query responses.
 
     Args:
-        model (`torch.nn.Module`):
-            The model used to compute the reward logits.
+        model (`torch.nn.Module` or `Callable`):
+            The model or a custom function used to compute the reward logits.
+        processor:
+            The processor (e.g., tokenizer) to decode the input if needed.
         query_responses (`torch.Tensor`):
             The tensor containing the query responses.
         pad_token_id (`int`):
@@ -1073,29 +1079,37 @@ def get_reward(
             - `sequence_lengths` (`torch.Tensor`):
                 The lengths of the sequences in the query responses.
     """
-    attention_mask = query_responses != pad_token_id
-    position_ids = attention_mask.cumsum(1) - attention_mask.long()  # exclusive cumsum
-    lm_backbone = getattr(model, model.base_model_prefix)
-    input_ids = torch.masked_fill(query_responses, ~attention_mask, 0)
-    output = lm_backbone(
-        input_ids=input_ids,
-        attention_mask=attention_mask,
-        position_ids=position_ids,
-        return_dict=True,
-        output_hidden_states=True,
-        use_cache=False,  # otherwise mistral-based RM would error out
-    )
-    reward_logits = model.score(output.hidden_states[-1])
-    sequence_lengths = first_true_indices(query_responses[:, context_length:] == pad_token_id) - 1 + context_length
-    # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454
-    return (
-        reward_logits,
-        reward_logits[
-            torch.arange(reward_logits.size(0), device=reward_logits.device),
+
+    if isinstance(model, torch.nn.Module):
+        attention_mask = query_responses != pad_token_id
+        position_ids = attention_mask.cumsum(1) - attention_mask.long()  # exclusive cumsum
+        lm_backbone = getattr(model, model.base_model_prefix)
+        input_ids = torch.masked_fill(query_responses, ~attention_mask, 0)
+        output = lm_backbone(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            return_dict=True,
+            output_hidden_states=True,
+            use_cache=False,  # otherwise mistral-based RM would error out
+        )
+        reward_logits = model.score(output.hidden_states[-1])
+        sequence_lengths = first_true_indices(query_responses[:, context_length:] == pad_token_id) - 1 + context_length
+        # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454
+        return (
+            reward_logits,
+            reward_logits[
+                torch.arange(reward_logits.size(0), device=reward_logits.device),
+                sequence_lengths,
+            ].squeeze(-1),
             sequence_lengths,
-        ].squeeze(-1),
-        sequence_lengths,
-    )
+        )
+    else:
+        texts = processor.batch_decode(query_responses)
+        rewards = model(texts)
+        rewards = torch.tensor(rewards, dtype=torch.float)
+        final_rewards, sequence_lengths = None, None
+        return final_rewards, rewards, sequence_lengths
 
 
 def forward(
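
With this first patch, a reward "model" that is not an `nn.Module` is treated as a plain callable: `get_reward` decodes the query/response tokens with the processing class and passes the resulting strings to it. A minimal sketch of such a callable (the function name and the length-based heuristic are only illustrative, not part of the patch):

    # Illustrative custom reward function for the callable path added above.
    # It receives the decoded query+response texts and returns one float per text.
    def length_penalty_reward(texts: list[str]) -> list[float]:
        return [1.0 if len(text) <= 200 else -1.0 for text in texts]

    # With this patch applied, the callable can be passed to PPOTrainer wherever a
    # reward model is expected, e.g. PPOTrainer(..., reward_model=length_penalty_reward, ...).
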
From 0c4a98a08f87d7eff64c56ad8652f14771273ddc Mon Sep 17 00:00:00 2001
From: August Moharrami
Date: Thu, 9 Jan 2025 12:50:50 +0000
Subject: [PATCH 2/3] moving the custom reward tensor to device

---
 trl/trainer/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index 360a883380..44aaa769e2 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -1107,7 +1107,7 @@ def get_reward(
     else:
         texts = processor.batch_decode(query_responses)
         rewards = model(texts)
-        rewards = torch.tensor(rewards, dtype=torch.float)
+        rewards = torch.tensor(rewards, dtype=torch.float).to(query_responses.device)
         final_rewards, sequence_lengths = None, None
         return final_rewards, rewards, sequence_lengths
 
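
The `.to(query_responses.device)` added here matters because the scores are later gathered with `self.accelerator.gather(score)` and combined with tensors that live on the accelerator, so a reward tensor built on the CPU would typically raise a device-mismatch error. A standalone sketch of the intent (the helper name is hypothetical):

    import torch

    def rewards_to_device(rewards: list[float], reference: torch.Tensor) -> torch.Tensor:
        # Mirror the fix above: build the reward tensor on the same device as the
        # token tensor it scores, before it is gathered across processes.
        return torch.tensor(rewards, dtype=torch.float).to(reference.device)
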
From 11484e724804f0fe4395b1863696f2019a458bbc Mon Sep 17 00:00:00 2001
From: August Moharrami
Date: Sat, 18 Jan 2025 12:47:00 +0000
Subject: [PATCH 3/3] alternative approach to avoid modifying get_reward

---
 trl/trainer/ppo_trainer.py | 32 +++++++++--------
 trl/trainer/utils.py       | 74 ++++++++++++++++++--------------------
 2 files changed, 51 insertions(+), 55 deletions(-)

diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py
index cae35720bb..0f5e420628 100644
--- a/trl/trainer/ppo_trainer.py
+++ b/trl/trainer/ppo_trainer.py
@@ -62,6 +62,7 @@
     generate_model_card,
     get_comet_experiment_url,
     get_reward,
+    get_reward_custom,
     log_table_to_comet_experiment,
     peft_module_casting_to_bf16,
     prepare_deepspeed,
@@ -460,19 +461,19 @@ def repeat_generator():
                 unwrapped_value_model = accelerator.unwrap_model(model).value_model
                 full_value, _, _ = get_reward(
                     unwrapped_value_model,
-                    processing_class,
                     query_response,
                     processing_class.pad_token_id,
                     context_length,
                 )
                 value = full_value[:, context_length - 1 : -1].squeeze(-1)
-                _, score, _ = get_reward(
-                    reward_model,
-                    processing_class,
-                    postprocessed_query_response,
-                    processing_class.pad_token_id,
-                    context_length,
-                )
+                if isinstance(reward_model, torch.nn.Module):
+                    _, score, _ = get_reward(
+                        reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
+                    )
+                else:
+                    score = get_reward_custom(
+                        reward_model, processing_class, postprocessed_query_response
+                    )
 
                 responses.append(response)
                 postprocessed_responses.append(postprocessed_response)
@@ -722,13 +723,14 @@ def generate_completions(self, sampling: bool = False):
                     )
 
                 postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
-                _, score, _ = get_reward(
-                    self.reward_model,
-                    processing_class,
-                    postprocessed_query_response,
-                    processing_class.pad_token_id,
-                    context_length,
-                )
+                if isinstance(self.reward_model, torch.nn.Module):
+                    _, score, _ = get_reward(
+                        self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
+                    )
+                else:
+                    score = get_reward_custom(
+                        self.reward_model, processing_class, postprocessed_query_response
+                    )
                 table["score"].extend(self.accelerator.gather(score).float().cpu().numpy())
 
                 if sampling:
diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index 44aaa769e2..2c3c234112 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -1049,20 +1049,14 @@ def first_true_indices(bools: torch.Tensor, dtype=torch.long):
 
 
 def get_reward(
-    model: Union[torch.nn.Module, Callable],
-    processor: PreTrainedTokenizerBase,
-    query_responses: torch.Tensor,
-    pad_token_id: int,
-    context_length: int,
+    model: torch.nn.Module, query_responses: torch.Tensor, pad_token_id: int, context_length: int
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
-    Computes the reward logits and the rewards for a given model/function and query responses.
+    Computes the reward logits and the rewards for a given model and query responses.
 
     Args:
-        model (`torch.nn.Module` or `Callable`):
-            The model or a custom function used to compute the reward logits.
-        processor:
-            The processor (e.g., tokenizer) to decode the input if needed.
+        model (`torch.nn.Module`):
+            The model used to compute the reward logits.
         query_responses (`torch.Tensor`):
             The tensor containing the query responses.
         pad_token_id (`int`):
@@ -1079,38 +1073,38 @@ def get_reward(
             - `sequence_lengths` (`torch.Tensor`):
                 The lengths of the sequences in the query responses.
     """
-
-    if isinstance(model, torch.nn.Module):
-        attention_mask = query_responses != pad_token_id
-        position_ids = attention_mask.cumsum(1) - attention_mask.long()  # exclusive cumsum
-        lm_backbone = getattr(model, model.base_model_prefix)
-        input_ids = torch.masked_fill(query_responses, ~attention_mask, 0)
-        output = lm_backbone(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            return_dict=True,
-            output_hidden_states=True,
-            use_cache=False,  # otherwise mistral-based RM would error out
-        )
-        reward_logits = model.score(output.hidden_states[-1])
-        sequence_lengths = first_true_indices(query_responses[:, context_length:] == pad_token_id) - 1 + context_length
-        # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454
-        return (
-            reward_logits,
-            reward_logits[
-                torch.arange(reward_logits.size(0), device=reward_logits.device),
-                sequence_lengths,
-            ].squeeze(-1),
+    attention_mask = query_responses != pad_token_id
+    position_ids = attention_mask.cumsum(1) - attention_mask.long()  # exclusive cumsum
+    lm_backbone = getattr(model, model.base_model_prefix)
+    input_ids = torch.masked_fill(query_responses, ~attention_mask, 0)
+    output = lm_backbone(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        return_dict=True,
+        output_hidden_states=True,
+        use_cache=False,  # otherwise mistral-based RM would error out
+    )
+    reward_logits = model.score(output.hidden_states[-1])
+    sequence_lengths = first_true_indices(query_responses[:, context_length:] == pad_token_id) - 1 + context_length
+    # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454
+    return (
+        reward_logits,
+        reward_logits[
+            torch.arange(reward_logits.size(0), device=reward_logits.device),
             sequence_lengths,
-        )
-    else:
-        texts = processor.batch_decode(query_responses)
-        rewards = model(texts)
-        rewards = torch.tensor(rewards, dtype=torch.float).to(query_responses.device)
-        final_rewards, sequence_lengths = None, None
-        return final_rewards, rewards, sequence_lengths
+        ].squeeze(-1),
+        sequence_lengths,
+    )
+def get_reward_custom(model: Callable, processor: PreTrainedTokenizerBase, query_responses: torch.Tensor) -> torch.Tensor:
+    """
+    This function ensures that the custom reward function produces the correct output structure for integration with the trainer script.
+    """
+    texts = processor.batch_decode(query_responses)
+    rewards = model(texts)
+    rewards = torch.tensor(rewards, dtype=torch.float).to(query_responses.device)
+    return rewards
 
 
 def forward(
     model: torch.nn.Module,
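
Taken together, the final state keeps `get_reward` unchanged and routes callables through the new `get_reward_custom` helper, while `PPOTrainer` still accepts either an `nn.Module` or a callable as `reward_model`. A rough end-to-end sketch of the callable path, assuming the patches above are applied (the checkpoint name and the reward heuristic are placeholders):

    from transformers import AutoTokenizer
    from trl.trainer.utils import get_reward_custom  # added by this patch series

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
    tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

    def keyword_reward(texts: list[str]) -> list[float]:
        # Toy reward: +1.0 when the decoded text contains "thanks", else 0.0.
        # Note: texts include padding/special tokens, since get_reward_custom calls
        # batch_decode without skip_special_tokens.
        return [1.0 if "thanks" in text.lower() else 0.0 for text in texts]

    query_responses = tokenizer(
        ["hello there", "thanks a lot"], return_tensors="pt", padding=True
    ).input_ids
    scores = get_reward_custom(keyword_reward, tokenizer, query_responses)
    print(scores)  # 1-D float tensor on query_responses.device, one score per sequence
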