Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

custom reward function support for ppo trainer #2540

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 24 additions & 10 deletions trl/trainer/ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import time
from collections import defaultdict
from contextlib import contextmanager, nullcontext
from typing import Optional, Union
from typing import Callable, Optional, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -114,7 +114,7 @@ def __init__(
],
model: nn.Module,
ref_model: Optional[nn.Module],
reward_model: nn.Module,
reward_model: Union[nn.Module, Callable],
train_dataset: Dataset,
value_model: Optional[nn.Module] = None,
data_collator: Optional[DataCollatorWithPadding] = None,
Expand Down Expand Up @@ -218,7 +218,7 @@ def __init__(
# setup model, optimizer, and others
#########
for module in [self.policy_model, self.ref_model, self.value_model, self.reward_model]:
if module is not None:
if isinstance(module, nn.Module):
disable_dropout_in_model(module)
if args.stop_token and args.stop_token == "eos":
args.stop_token_id = processing_class.eos_token_id
Expand Down Expand Up @@ -285,9 +285,10 @@ def __init__(
self.eval_dataloader = accelerator.prepare(self.eval_dataloader)

if self.is_deepspeed_enabled:
self.reward_model = prepare_deepspeed(
self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
)
if isinstance(self.reward_model, nn.Module):
self.reward_model = prepare_deepspeed(
self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
)

if self.ref_model is None:
if not self.is_peft_model:
Expand All @@ -302,7 +303,8 @@ def __init__(
raise ValueError("No reference model and model is not a Peft model.")
else:
self.ref_model = self.ref_model.to(self.accelerator.device)
self.reward_model = self.reward_model.to(self.accelerator.device)
if isinstance(self.reward_model, nn.Module):
self.reward_model = self.reward_model.to(self.accelerator.device)

def get_train_dataloader(self) -> DataLoader:
return self.dataloader
Expand Down Expand Up @@ -457,11 +459,19 @@ def repeat_generator():
sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1
unwrapped_value_model = accelerator.unwrap_model(model).value_model
full_value, _, _ = get_reward(
unwrapped_value_model, query_response, processing_class.pad_token_id, context_length
unwrapped_value_model,
processing_class,
query_response,
processing_class.pad_token_id,
context_length,
)
value = full_value[:, context_length - 1 : -1].squeeze(-1)
_, score, _ = get_reward(
reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
reward_model,
processing_class,
postprocessed_query_response,
processing_class.pad_token_id,
context_length,
)

responses.append(response)
Expand Down Expand Up @@ -713,7 +723,11 @@ def generate_completions(self, sampling: bool = False):

postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
_, score, _ = get_reward(
self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
self.reward_model,
processing_class,
postprocessed_query_response,
processing_class.pad_token_id,
context_length,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we move the if isinstance(model, torch.nn.Module): here? I would allow not to introduce breaking change in get_reward

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need to clarify what you mean.

Copy link
Member

@qgallouedec qgallouedec Jan 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry, it wasn't clear:

something like this instead:

if isinstance(model, torch.nn.Module):
    full_value, _, _ = get_reward(
        unwrapped_value_model, query_response, processing_class.pad_token_id, context_length
    )
else:
    full_value = ...

doing such we don't introduce a breaking change in get_reward.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean I changed get_reward to work either way with both a callable and an nn.Module.
So you want to add if isinstance(model, torch.nn.Module) there and keep get_reward as it is without change?

)
table["score"].extend(self.accelerator.gather(score).float().cpu().numpy())

Expand Down
68 changes: 41 additions & 27 deletions trl/trainer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from collections import deque
from dataclasses import dataclass
from importlib.metadata import version
from typing import Any, Literal, Optional, Union
from typing import Any, Callable, Literal, Optional, Union

import datasets
import numpy as np
Expand Down Expand Up @@ -1049,14 +1049,20 @@ def first_true_indices(bools: torch.Tensor, dtype=torch.long):


def get_reward(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is where the primary change are:
modifying the get_reward function to work with both a nn.Module and a Callable.

model: torch.nn.Module, query_responses: torch.Tensor, pad_token_id: int, context_length: int
model: Union[torch.nn.Module, Callable],
processor: PreTrainedTokenizerBase,
query_responses: torch.Tensor,
pad_token_id: int,
context_length: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Computes the reward logits and the rewards for a given model and query responses.
Computes the reward logits and the rewards for a given model/function and query responses.

Args:
model (`torch.nn.Module`):
The model used to compute the reward logits.
model (`torch.nn.Module` or `Callable`):
The model or a custom function used to compute the reward logits.
processor:
The processor (e.g., tokenizer) to decode the input if needed.
query_responses (`torch.Tensor`):
The tensor containing the query responses.
pad_token_id (`int`):
Expand All @@ -1073,29 +1079,37 @@ def get_reward(
- `sequence_lengths` (`torch.Tensor`):
The lengths of the sequences in the query responses.
"""
attention_mask = query_responses != pad_token_id
position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum
lm_backbone = getattr(model, model.base_model_prefix)
input_ids = torch.masked_fill(query_responses, ~attention_mask, 0)
output = lm_backbone(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
return_dict=True,
output_hidden_states=True,
use_cache=False, # otherwise mistral-based RM would error out
)
reward_logits = model.score(output.hidden_states[-1])
sequence_lengths = first_true_indices(query_responses[:, context_length:] == pad_token_id) - 1 + context_length
# https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454
return (
reward_logits,
reward_logits[
torch.arange(reward_logits.size(0), device=reward_logits.device),

if isinstance(model, torch.nn.Module):
attention_mask = query_responses != pad_token_id
position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum
lm_backbone = getattr(model, model.base_model_prefix)
input_ids = torch.masked_fill(query_responses, ~attention_mask, 0)
output = lm_backbone(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
return_dict=True,
output_hidden_states=True,
use_cache=False, # otherwise mistral-based RM would error out
)
reward_logits = model.score(output.hidden_states[-1])
sequence_lengths = first_true_indices(query_responses[:, context_length:] == pad_token_id) - 1 + context_length
# https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454
return (
reward_logits,
reward_logits[
torch.arange(reward_logits.size(0), device=reward_logits.device),
sequence_lengths,
].squeeze(-1),
sequence_lengths,
].squeeze(-1),
sequence_lengths,
)
)
else:
texts = processor.batch_decode(query_responses)
rewards = model(texts)
rewards = torch.tensor(rewards, dtype=torch.float)
final_rewards, sequence_lengths = None, None
return final_rewards, rewards, sequence_lengths


def forward(
Expand Down
Loading