Revert "Weighted reward functions (#213)"
This reverts commit fbea532.
kashif authored Feb 13, 2025
1 parent fbea532 commit 3f346ae
Showing 7 changed files with 2 additions and 152 deletions.
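For context on what this revert removes: the reverted change wrapped each configured reward function with a small create_weighted_reward helper (deleted from src/open_r1/rewards.py below) so that every reward it returned was scaled by a per-function weight. A minimal, runnable sketch of that mechanism, with a toy reward function standing in for the real ones:

```python
from functools import wraps


def create_weighted_reward(func, weight):
    """Wrap a reward function so each reward it returns is scaled by `weight` (as in the code removed below)."""

    @wraps(func)  # preserves the original __name__, which per-reward logging relies on
    def weighted_reward(*args, **kwargs):
        rewards = func(*args, **kwargs)
        return [r * weight for r in rewards]

    return weighted_reward


def toy_format_reward(completions, **kwargs):
    """Toy stand-in for open_r1's format_reward: 1.0 when <think>/<answer> tags are present, else 0.0."""
    return [
        1.0 if "<think>" in c[0]["content"] and "<answer>" in c[0]["content"] else 0.0
        for c in completions
    ]


completions = [
    [{"content": "<think>Some reasoning</think><answer>The answer</answer>"}],
    [{"content": "Invalid format"}],
]
weighted = create_weighted_reward(toy_format_reward, 2.0)
print(weighted(completions))  # [2.0, 0.0]: each raw reward multiplied by the weight
```

The @wraps decorator is what the deleted test_weighted_reward_preserves_name test below checked: the wrapper keeps the wrapped function's __name__ so logged reward columns keep their original names.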
6 changes: 0 additions & 6 deletions recipes/DeepSeek-R1-Distill-Qwen-7B/grpo/config_demo.yaml
@@ -40,12 +40,6 @@ per_device_train_batch_size: 16
 push_to_hub: true
 report_to:
 - wandb
-reward_funcs:
-- accuracy
-- format
-reward_weights:
-- 1.0
-- 1.0
 save_strategy: "no"
 seed: 42
 warmup_ratio: 0.1
6 changes: 0 additions & 6 deletions recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo.yaml
@@ -42,12 +42,6 @@ per_device_train_batch_size: 16
 push_to_hub: true
 report_to:
 - wandb
-reward_funcs:
-- accuracy
-- format
-reward_weights:
-- 1.0
-- 1.0
 save_strategy: "no"
 seed: 42
 warmup_ratio: 0.1
6 changes: 0 additions & 6 deletions recipes/Qwen2.5-Math-7B/grpo/config_simple_rl.yaml
@@ -42,12 +42,6 @@ per_device_train_batch_size: 16
 push_to_hub: true
 report_to:
 - wandb
-reward_funcs:
-- accuracy
-- format
-reward_weights:
-- 1.0
-- 1.0
 save_strategy: "no"
 seed: 42
 warmup_ratio: 0.1
28 changes: 2 additions & 26 deletions src/open_r1/grpo.py
@@ -16,7 +16,6 @@
 import os
 import sys
 from dataclasses import dataclass, field
-from typing import Optional
 
 import datasets
 import torch
@@ -28,7 +27,6 @@
 from open_r1.configs import GRPOConfig
 from open_r1.rewards import (
     accuracy_reward,
-    create_weighted_reward,
     format_reward,
     get_cosine_scaled_reward,
     get_repetition_penalty_reward
@@ -51,8 +49,6 @@ class GRPOScriptArguments(ScriptArguments):
     Args:
         reward_funcs (`list[str]`):
             List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length'.
-        reward_weights (`list[float]` or `None`, *optional*):
-            List of weights for each reward function. If not provided, defaults to 1.0 for each reward function.
         cosine_min_value_wrong (`float`):
             Minimum reward for cosine scaling for wrong answers.
         cosine_max_value_wrong (`float`):
@@ -71,12 +67,6 @@ class GRPOScriptArguments(ScriptArguments):
             "help": "List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length'"
         },
     )
-    reward_weights: Optional[list[float]] = field(
-        default=None,
-        metadata={
-            "help": "List of weights for each reward function. If not provided, defaults to 1.0 for each function."
-        },
-    )
     cosine_min_value_wrong: float = field(
         default=0.0,
         metadata={"help": "Minimum reward for wrong answers"},
@@ -98,17 +88,6 @@ class GRPOScriptArguments(ScriptArguments):
         metadata={"help": "Maximum length for scaling"},
     )
 
-    def __post_init__(self):
-        # If no weights were provided, default to 1.0 for each reward function
-        if self.reward_weights is None:
-            self.reward_weights = [1.0] * len(self.reward_funcs)
-        # If weights were provided, validate the length
-        elif len(self.reward_weights) != len(self.reward_funcs):
-            raise ValueError(
-                f"Number of reward weights ({len(self.reward_weights)}: {self.reward_weights}) must match "
-                f"number of reward functions ({len(self.reward_funcs)}: {self.reward_funcs})"
-            )
-
     repetition_n_grams: int = field(
         default=3,
         metadata={"help": "Number of n-grams for repetition penalty reward"},
@@ -168,7 +147,7 @@ def main(script_args, training_args, model_args):
     # Load the dataset
     dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
 
-    # Create weighted reward functions
+    # Get reward functions
     REWARD_FUNCS_REGISTRY = {
         "accuracy": accuracy_reward,
         "format": format_reward,
@@ -186,10 +165,7 @@ def main(script_args, training_args, model_args):
         ),
         "length": len_reward,
     }
-    reward_funcs = [
-        create_weighted_reward(REWARD_FUNCS_REGISTRY[func], weight)
-        for func, weight in zip(script_args.reward_funcs, script_args.reward_weights)
-    ]
+    reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]
 
     # Format into conversation
     def make_conversation(example):
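With the weighting removed, the script falls back to plain name-to-function lookups, as the replacement line in the hunk above shows: each entry in reward_funcs selects a function from the registry, and the functions are used unscaled. A condensed, self-contained sketch of that post-revert path, with toy functions standing in for the real registry entries:

```python
def toy_accuracy_reward(completions, **kwargs):
    # Toy stand-in for accuracy_reward: pretend every completion is correct.
    return [1.0 for _ in completions]


def toy_format_reward(completions, **kwargs):
    # Toy stand-in for format_reward: reward only completions containing an <answer> tag.
    return [1.0 if "<answer>" in c[0]["content"] else 0.0 for c in completions]


# Mirrors the post-revert selection in src/open_r1/grpo.py: names from the recipe YAML
# pick functions out of a registry, and the functions are passed along with no weights.
REWARD_FUNCS_REGISTRY = {
    "accuracy": toy_accuracy_reward,
    "format": toy_format_reward,
}

requested = ["accuracy", "format"]  # stands in for script_args.reward_funcs
reward_funcs = [REWARD_FUNCS_REGISTRY[name] for name in requested]

completions = [[{"content": "<answer>42</answer>"}], [{"content": "no tags"}]]
for fn in reward_funcs:
    print(fn.__name__, fn(completions))  # each function's raw rewards, unweighted
```

Anyone who still wants per-function weighting after this revert would have to reapply a wrapper like the one sketched after the commit summary above in their own training script.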
20 changes: 0 additions & 20 deletions src/open_r1/rewards.py
@@ -2,7 +2,6 @@
 
 import math
 import re
-from functools import wraps
 from typing import Dict
 
 from latex2sympy2_extended import NormalizationConfig
@@ -272,22 +271,3 @@ def repetition_penalty_reward(completions, **kwargs) -> float:
         return rewards
 
     return repetition_penalty_reward
-
-
-def create_weighted_reward(func, weight):
-    """Create a weighted version of a reward function.
-    Args:
-        func: The reward function to weight
-        weight: The weight to apply to the reward
-    Returns:
-        A new function that applies the weight to the reward
-    """
-
-    @wraps(func)
-    def weighted_reward(*args, **kwargs):
-        rewards = func(*args, **kwargs)
-        return [r * weight for r in rewards]
-
-    return weighted_reward
47 changes: 0 additions & 47 deletions tests/test_grpo.py

This file was deleted.

41 changes: 0 additions & 41 deletions tests/test_rewards.py
@@ -2,7 +2,6 @@
 
 from open_r1.rewards import (
     accuracy_reward,
-    create_weighted_reward,
     format_reward,
     get_cosine_scaled_reward,
     get_repetition_penalty_reward
@@ -77,35 +76,6 @@ def test_multiple_completions(self):
         self.assertEqual(rewards[0], 1.0)
         self.assertEqual(rewards[1], 0.0)
 
-    def test_weighted_reward(self):
-        """Test create_weighted_reward with different weights."""
-        # Test with weight = 2.0
-        completion = [[{"content": "<think>Some reasoning</think><answer>The answer</answer>"}]]
-        base_reward_func = format_reward
-        weighted_reward_func = create_weighted_reward(base_reward_func, 2.0)
-
-        base_rewards = base_reward_func(completion)
-        weighted_rewards = weighted_reward_func(completion)
-
-        self.assertEqual(base_rewards[0], 1.0)
-        self.assertEqual(weighted_rewards[0], 2.0)
-
-        # Test with weight = 0.5
-        weighted_reward_func = create_weighted_reward(base_reward_func, 0.5)
-        weighted_rewards = weighted_reward_func(completion)
-        self.assertEqual(weighted_rewards[0], 0.5)
-
-        # Test with multiple completions
-        completions = [
-            [{"content": "<think>Some reasoning</think><answer>The answer</answer>"}],
-            [{"content": "Invalid format"}],
-        ]
-        weighted_reward_func = create_weighted_reward(base_reward_func, 2.0)
-        weighted_rewards = weighted_reward_func(completions)
-
-        self.assertEqual(weighted_rewards[0], 2.0)
-        self.assertEqual(weighted_rewards[1], 0.0)
-
     def test_cosine_scaled_reward(self):
         """Test cosine_scaled_reward with various cases."""
         # Test parameters
@@ -141,17 +111,6 @@ def test_format_reward_specific_multiline(self):
         rewards = format_reward(completion)
         self.assertEqual(rewards[0], 1.0)
 
-    def test_weighted_reward_preserves_name(self):
-        """Test that create_weighted_reward preserves the original function name. Important for logging."""
-        base_reward_func = format_reward
-        weighted_reward_func = create_weighted_reward(base_reward_func, 2.0)
-
-        self.assertEqual(
-            weighted_reward_func.__name__,
-            base_reward_func.__name__,
-            "Weighted reward function should preserve the original function name",
-        )
-
     def test_same_length_responses(self):
         """Test len_reward when all responses have the same length."""
         completions = [[{"content": r"\boxed{\frac{63}{400}}"}], [{"content": r"\boxed{\frac{64}{400}}"}]]
