Revert "Weighted reward functions" #317

Merged
merged 1 commit on Feb 13, 2025
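This PR reverts the weighted reward functions change: the reward_weights entries are removed from the GRPO recipe configs, the reward_weights script argument and its length validation are dropped from src/open_r1/grpo.py, and the create_weighted_reward helper and its tests are deleted. For reference, the reverted helper simply scaled every reward returned by a reward function by a constant factor. A minimal sketch, reconstructed from the removed code in this diff (this helper no longer exists after the revert):

```python
from functools import wraps


def create_weighted_reward(func, weight):
    """Return a wrapper around `func` that scales each of its rewards by `weight`."""

    @wraps(func)  # preserve func.__name__, which the reward logging relies on
    def weighted_reward(*args, **kwargs):
        rewards = func(*args, **kwargs)
        return [r * weight for r in rewards]

    return weighted_reward


# Usage as in the deleted test: count the format reward twice
# weighted_format = create_weighted_reward(format_reward, 2.0)
```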
6 changes: 0 additions & 6 deletions recipes/DeepSeek-R1-Distill-Qwen-7B/grpo/config_demo.yaml
@@ -40,12 +40,6 @@ per_device_train_batch_size: 16
push_to_hub: true
report_to:
- wandb
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 1.0
save_strategy: "no"
seed: 42
warmup_ratio: 0.1
6 changes: 0 additions & 6 deletions recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo.yaml
@@ -42,12 +42,6 @@ per_device_train_batch_size: 16
push_to_hub: true
report_to:
- wandb
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 1.0
save_strategy: "no"
seed: 42
warmup_ratio: 0.1
6 changes: 0 additions & 6 deletions recipes/Qwen2.5-Math-7B/grpo/config_simple_rl.yaml
@@ -42,12 +42,6 @@ per_device_train_batch_size: 16
push_to_hub: true
report_to:
- wandb
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 1.0
save_strategy: "no"
seed: 42
warmup_ratio: 0.1
28 changes: 2 additions & 26 deletions src/open_r1/grpo.py
@@ -16,7 +16,6 @@
import os
import sys
from dataclasses import dataclass, field
from typing import Optional

import datasets
import torch
@@ -28,7 +27,6 @@
from open_r1.configs import GRPOConfig
from open_r1.rewards import (
accuracy_reward,
create_weighted_reward,
format_reward,
get_cosine_scaled_reward,
get_repetition_penalty_reward,
@@ -51,8 +49,6 @@ class GRPOScriptArguments(ScriptArguments):
Args:
reward_funcs (`list[str]`):
List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length'.
reward_weights (`list[float]` or `None`, *optional*):
List of weights for each reward function. If not provided, defaults to 1.0 for each reward function.
cosine_min_value_wrong (`float`):
Minimum reward for cosine scaling for wrong answers.
cosine_max_value_wrong (`float`):
@@ -71,12 +67,6 @@ class GRPOScriptArguments(ScriptArguments):
"help": "List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length'"
},
)
reward_weights: Optional[list[float]] = field(
default=None,
metadata={
"help": "List of weights for each reward function. If not provided, defaults to 1.0 for each function."
},
)
cosine_min_value_wrong: float = field(
default=0.0,
metadata={"help": "Minimum reward for wrong answers"},
@@ -98,17 +88,6 @@ class GRPOScriptArguments(ScriptArguments):
metadata={"help": "Maximum length for scaling"},
)

def __post_init__(self):
# If no weights were provided, default to 1.0 for each reward function
if self.reward_weights is None:
self.reward_weights = [1.0] * len(self.reward_funcs)
# If weights were provided, validate the length
elif len(self.reward_weights) != len(self.reward_funcs):
raise ValueError(
f"Number of reward weights ({len(self.reward_weights)}: {self.reward_weights}) must match "
f"number of reward functions ({len(self.reward_funcs)}: {self.reward_funcs})"
)

repetition_n_grams: int = field(
default=3,
metadata={"help": "Number of n-grams for repetition penalty reward"},
@@ -168,7 +147,7 @@ def main(script_args, training_args, model_args):
# Load the dataset
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

# Create weighted reward functions
# Get reward functions
REWARD_FUNCS_REGISTRY = {
"accuracy": accuracy_reward,
"format": format_reward,
@@ -186,10 +165,7 @@
),
"length": len_reward,
}
reward_funcs = [
create_weighted_reward(REWARD_FUNCS_REGISTRY[func], weight)
for func, weight in zip(script_args.reward_funcs, script_args.reward_weights)
]
reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]

# Format into conversation
def make_conversation(example):
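The net effect of the grpo.py change above is that the weight-wrapping step disappears and the reward callables are handed to the trainer unmodified. A short sketch of the post-revert lookup, assuming a recipe that lists accuracy and format as in the configs above (registry trimmed to those two entries for illustration):

```python
from open_r1.rewards import accuracy_reward, format_reward

# Post-revert: plain registry lookup, no create_weighted_reward() wrapping.
REWARD_FUNCS_REGISTRY = {
    "accuracy": accuracy_reward,
    "format": format_reward,
}

selected = ["accuracy", "format"]  # e.g. script_args.reward_funcs from the recipe YAML
reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in selected]
```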
20 changes: 0 additions & 20 deletions src/open_r1/rewards.py
@@ -2,7 +2,6 @@

import math
import re
from functools import wraps
from typing import Dict

from latex2sympy2_extended import NormalizationConfig
Expand Down Expand Up @@ -272,22 +271,3 @@ def repetition_penalty_reward(completions, **kwargs) -> float:
return rewards

return repetition_penalty_reward


def create_weighted_reward(func, weight):
"""Create a weighted version of a reward function.

Args:
func: The reward function to weight
weight: The weight to apply to the reward

Returns:
A new function that applies the weight to the reward
"""

@wraps(func)
def weighted_reward(*args, **kwargs):
rewards = func(*args, **kwargs)
return [r * weight for r in rewards]

return weighted_reward
47 changes: 0 additions & 47 deletions tests/test_grpo.py

This file was deleted.

41 changes: 0 additions & 41 deletions tests/test_rewards.py
@@ -2,7 +2,6 @@

from open_r1.rewards import (
accuracy_reward,
create_weighted_reward,
format_reward,
get_cosine_scaled_reward,
get_repetition_penalty_reward,
@@ -77,35 +76,6 @@ def test_multiple_completions(self):
self.assertEqual(rewards[0], 1.0)
self.assertEqual(rewards[1], 0.0)

def test_weighted_reward(self):
"""Test create_weighted_reward with different weights."""
# Test with weight = 2.0
completion = [[{"content": "<think>Some reasoning</think><answer>The answer</answer>"}]]
base_reward_func = format_reward
weighted_reward_func = create_weighted_reward(base_reward_func, 2.0)

base_rewards = base_reward_func(completion)
weighted_rewards = weighted_reward_func(completion)

self.assertEqual(base_rewards[0], 1.0)
self.assertEqual(weighted_rewards[0], 2.0)

# Test with weight = 0.5
weighted_reward_func = create_weighted_reward(base_reward_func, 0.5)
weighted_rewards = weighted_reward_func(completion)
self.assertEqual(weighted_rewards[0], 0.5)

# Test with multiple completions
completions = [
[{"content": "<think>Some reasoning</think><answer>The answer</answer>"}],
[{"content": "Invalid format"}],
]
weighted_reward_func = create_weighted_reward(base_reward_func, 2.0)
weighted_rewards = weighted_reward_func(completions)

self.assertEqual(weighted_rewards[0], 2.0)
self.assertEqual(weighted_rewards[1], 0.0)

def test_cosine_scaled_reward(self):
"""Test cosine_scaled_reward with various cases."""
# Test parameters
@@ -141,17 +111,6 @@ def test_format_reward_specific_multiline(self):
rewards = format_reward(completion)
self.assertEqual(rewards[0], 1.0)

def test_weighted_reward_preserves_name(self):
"""Test that create_weighted_reward preserves the original function name. Important for logging."""
base_reward_func = format_reward
weighted_reward_func = create_weighted_reward(base_reward_func, 2.0)

self.assertEqual(
weighted_reward_func.__name__,
base_reward_func.__name__,
"Weighted reward function should preserve the original function name",
)

def test_same_length_responses(self):
"""Test len_reward when all responses have the same length."""
completions = [[{"content": r"\boxed{\frac{63}{400}}"}], [{"content": r"\boxed{\frac{64}{400}}"}]]