🍭 Custom reward function for RLOO #2612
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
trl/trainer/utils.py
Outdated
texts = processor.batch_decode(query_responses, skip_special_tokens=True)
rewards = model(texts)
rewards = torch.tensor(rewards, dtype=torch.float).to(query_responses.device)
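For context, the callable being decoded-then-scored above only needs to map a batch of strings to per-example scores. A minimal sketch of such a function (the name `length_reward` and its length-based heuristic are invented here for illustration; the trainer then wraps the returned list in a float tensor as the snippet shows):

```python
# Toy reward callable matching the Callable[[list[str]], list[float]]
# interface discussed in this PR. The function name and the scoring
# rule are illustrative assumptions, not part of TRL.
def length_reward(texts: list[str]) -> list[float]:
    # Deliberately trivial heuristic: reward longer completions.
    return [float(len(t)) for t in texts]

rewards = length_reward(["hi", "hello there"])
print(rewards)  # [2.0, 11.0]
```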
I'm wondering, since it's only three lines and used once, maybe you can put it directly in RLOO?
But it could be used in other trainers that use get_reward, like PPO or Online DPO.
True.
We can always limit duplication later, once we're sure that the same lines are used in PPO and OnlineDPO. In fact, I'd rather have duplicated code than utilities that you only use once.
trl/trainer/rloo_trainer.py
Outdated
@@ -79,7 +80,7 @@ def __init__(
     ],
     policy: nn.Module,
     ref_policy: nn.Module,
-    reward_model: nn.Module,
+    reward_model: Union[nn.Module, Callable],
Suggested change:
- reward_model: Union[nn.Module, Callable],
+ reward_model: Union[nn.Module, Callable[[list[str]], list[float]]],
Nice! Can be done in a follow-up PR, but it might make sense to rename
Can you add a test in
Yeah, I think it's better to do it for all trainers in a separate PR.
We just need a unit test and we're good to merge.
Nice! Feel free to merge when the CI is green (ignore dev dep that will fail)
Similar to #2540
Since RLOO doesn't use a value model, it's much simpler.