[GRPO] add cosine reward (#206)
* add cosine reward

* fix merge

* fix typo

* fix check
kashif authored Feb 7, 2025
1 parent e8c2673 commit 250ab46
Showing 5 changed files with 170 additions and 43 deletions.
2 changes: 1 addition & 1 deletion Makefile
@@ -3,7 +3,7 @@
# make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
export PYTHONPATH = src

check_dirs := src
check_dirs := src tests

style:
ruff format --line-length 119 --target-version py310 $(check_dirs) setup.py
4 changes: 4 additions & 0 deletions setup.py
@@ -52,6 +52,8 @@
"hf_transfer>=0.1.4",
"huggingface-hub[cli]>=0.19.2,<1.0",
"isort>=5.12.0",
"latex2sympy2_extended>=1.0.6",
"math-verify>=0.5.2",
"liger_kernel==0.5.2",
"lighteval @ git+https://github.com/huggingface/lighteval.git@86f62259f105ae164f655e0b91c92a823a742724#egg=lighteval[math]",
"math-verify==0.5.2", # Used for math verification in grpo
@@ -96,6 +98,8 @@ def deps_list(*pkgs):
deps["deepspeed"],
deps["hf_transfer"],
deps["huggingface-hub"],
deps["latex2sympy2_extended"],
deps["math-verify"],
deps["liger_kernel"],
deps["packaging"], # utilities from PyPA to e.g., compare versions
deps["safetensors"],
52 changes: 48 additions & 4 deletions src/open_r1/grpo.py
@@ -16,6 +16,7 @@
import os
import sys
from dataclasses import dataclass, field
from functools import partial

import datasets
import torch
@@ -25,7 +26,7 @@
from transformers.trainer_utils import get_last_checkpoint

from open_r1.configs import GRPOConfig
from open_r1.rewards import REWARD_FUNCS_REGISTRY
from open_r1.rewards import accuracy_reward, cosine_scaled_reward, format_reward, reasoning_steps_reward
from open_r1.utils.callbacks import get_callbacks
from trl import GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config

@@ -40,15 +41,45 @@ class GRPOScriptArguments(ScriptArguments):
Args:
reward_funcs (`list[str]`):
List of reward functions. Possible values are dynamically populated from REWARD_FUNCS_REGISTRY.
List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine'.
cosine_min_value_wrong (`float`):
Minimum reward for cosine scaling for wrong answers.
cosine_max_value_wrong (`float`):
Maximum reward for cosine scaling for wrong answers.
cosine_min_value_correct (`float`):
Minimum reward for cosine scaling for correct answers.
cosine_max_value_correct (`float`):
Maximum reward for cosine scaling for correct answers.
cosine_max_len (`int`):
Maximum length for cosine scaling.
"""

reward_funcs: list[str] = field(
default_factory=lambda: ["accuracy", "format"],
default_factory=lambda: ["accuracy", "format", "reasoning_steps", "cosine"],
metadata={
"help": f"List of reward functions. Possible values: {', '.join(REWARD_FUNCS_REGISTRY.keys())}"
"help": "List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine'"
},
)
cosine_min_value_wrong: float = field(
default=0.0,
metadata={"help": "Minimum reward for wrong answers"},
)
cosine_max_value_wrong: float = field(
default=-0.5,
metadata={"help": "Maximum reward for wrong answers"},
)
cosine_min_value_correct: float = field(
default=0.5,
metadata={"help": "Minimum reward for correct answers"},
)
cosine_max_value_correct: float = field(
default=1.0,
metadata={"help": "Maximum reward for correct answers"},
)
cosine_max_len: int = field(
default=1000,
metadata={"help": "Maximum length for scaling"},
)


SYSTEM_PROMPT = (
@@ -98,6 +129,19 @@ def main(script_args, training_args, model_args):
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

# Get reward functions
REWARD_FUNCS_REGISTRY = {
"accuracy": accuracy_reward,
"format": format_reward,
"reasoning_steps": reasoning_steps_reward,
"cosine": partial(
cosine_scaled_reward,
min_value_wrong=script_args.cosine_min_value_wrong,
max_value_wrong=script_args.cosine_max_value_wrong,
min_value_correct=script_args.cosine_min_value_correct,
max_value_correct=script_args.cosine_max_value_correct,
max_len=script_args.cosine_max_len,
),
}
reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]

# Format into conversation
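Aside, not part of the diff: because the registry is now built inside `main`, the cosine entry can be specialized with the parsed script arguments through `functools.partial` while keeping the same call signature as the other reward functions. A minimal sketch of that pattern, with the dataclass defaults written as literals in place of `script_args`:

```python
from functools import partial

from open_r1.rewards import cosine_scaled_reward

# Literal values stand in for the parsed script_args fields.
cosine_fn = partial(
    cosine_scaled_reward,
    min_value_wrong=0.0,     # script_args.cosine_min_value_wrong
    max_value_wrong=-0.5,    # script_args.cosine_max_value_wrong
    min_value_correct=0.5,   # script_args.cosine_min_value_correct
    max_value_correct=1.0,   # script_args.cosine_max_value_correct
    max_len=1000,            # script_args.cosine_max_len
)

# The bound callable keeps the (completions, solution, **kwargs) interface,
# so it can sit in REWARD_FUNCS_REGISTRY next to accuracy_reward and friends.
rewards = cosine_fn(
    [[{"content": r"<think>...</think><answer>\boxed{\frac{63}{400}}</answer>"}]],
    [r"\frac{63}{400}"],
)
```

Building the registry per run, rather than at import time as before, is what makes these CLI-configurable values available when the cosine reward function is constructed.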
78 changes: 73 additions & 5 deletions src/open_r1/rewards.py
@@ -1,5 +1,6 @@
"""Reward functions for GRPO training."""

import math
import re

from latex2sympy2_extended import NormalizationConfig
@@ -73,8 +74,75 @@ def reasoning_steps_reward(completions, **kwargs):
return [min(1.0, count / 3) for count in matches]


REWARD_FUNCS_REGISTRY = {
"accuracy": accuracy_reward,
"format": format_reward,
"reasoning_steps": reasoning_steps_reward,
}
def cosine_scaled_reward(
completions,
solution,
min_value_wrong: float = -1.0,
max_value_wrong: float = -0.5,
min_value_correct: float = 0.5,
max_value_correct: float = 1.0,
max_len: int = 1000,
**kwargs,
):
"""Reward function that scales based on completion length using a cosine schedule.
Shorter correct solutions are rewarded more than longer ones.
Longer incorrect solutions are penalized less than shorter ones.
Args:
completions: List of model completions
solution: List of ground truth solutions
min_value_wrong: Minimum reward for wrong answers
max_value_wrong: Maximum reward for wrong answers
min_value_correct: Minimum reward for correct answers
max_value_correct: Maximum reward for correct answers
max_len: Maximum length for scaling
"""
contents = [completion[0]["content"] for completion in completions]
rewards = []

for content, sol in zip(contents, solution):
gold_parsed = parse(sol, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
if len(gold_parsed) == 0:
rewards.append(1.0) # Skip unparseable examples
print("Failed to parse gold solution: ", sol)
continue

answer_parsed = parse(
content,
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
equations=True,
boxed=True,
units=True,
),
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)

is_correct = verify(answer_parsed, gold_parsed)
gen_len = len(content)

# Apply cosine scaling based on length
progress = gen_len / max_len
cosine = math.cos(progress * math.pi)

if is_correct:
min_value = min_value_correct
max_value = max_value_correct
else:
# Swap min/max for incorrect answers
min_value = max_value_wrong
max_value = min_value_wrong

reward = min_value + 0.5 * (max_value - min_value) * (1.0 + cosine)
rewards.append(float(reward))

return rewards
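For orientation, not part of the commit: the schedule above collapses to `lo + 0.5 * (hi - lo) * (1 + cos(pi * len / max_len))`, with `lo` and `hi` swapped for wrong answers. A standalone sketch using the parameter values from the new tests (`max_len=100`, completion lengths of 22 and 80 characters):

```python
import math

def cosine_schedule(length: int, correct: bool, max_len: int = 100) -> float:
    """Standalone restatement of the scaling logic above, with the test's parameter values."""
    cosine = math.cos(math.pi * length / max_len)
    if correct:
        lo, hi = 0.5, 1.0     # min_value_correct, max_value_correct
    else:
        lo, hi = -0.5, -1.0   # swapped: max_value_wrong, min_value_wrong
    return lo + 0.5 * (hi - lo) * (1.0 + cosine)

print(round(cosine_schedule(22, True), 3))   # ~0.943: short correct answer, near max reward
print(round(cosine_schedule(80, True), 3))   # ~0.548: long correct answer, near min reward
print(round(cosine_schedule(22, False), 3))  # ~-0.943: short wrong answer, penalized most
print(round(cosine_schedule(80, False), 3))  # ~-0.548: long wrong answer, penalized less
```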
77 changes: 44 additions & 33 deletions tests/test_rewards.py
@@ -1,21 +1,22 @@
import unittest
from open_r1.rewards import accuracy_reward, format_reward, reasoning_steps_reward

from open_r1.rewards import accuracy_reward, cosine_scaled_reward, format_reward, reasoning_steps_reward


class TestRewards(unittest.TestCase):
def test_accuracy_reward_correct_answer(self):
"""Test accuracy_reward with a correct answer."""
completion = [[{"content": r"\boxed{\frac{63}{400}}"}]]
solution = [r"\frac{63}{400}"]

rewards = accuracy_reward(completion, solution)
self.assertEqual(rewards[0], 1.0)

def test_accuracy_reward_wrong_answer(self):
"""Test accuracy_reward with an incorrect answer."""
completion = [[{"content": r"\boxed{\frac{64}{400}}"}]]
solution = [r"\frac{63}{400}"]

rewards = accuracy_reward(completion, solution)
self.assertEqual(rewards[0], 0.0)

@@ -32,9 +33,9 @@ def test_format_reward_incorrect(self):
"<answer>Only answer</answer>",
"No tags at all",
"<think>Missing closing</think><answer>Missing closing",
"<think>Wrong order</answer><answer>Wrong order</think>"
"<think>Wrong order</answer><answer>Wrong order</think>",
]

for fmt in incorrect_formats:
completion = [[{"content": fmt}]]
rewards = format_reward(completion)
@@ -44,48 +45,58 @@ def test_reasoning_steps_reward(self):
"""Test reasoning_steps_reward with various formats."""
test_cases = [
# Full credit cases (3 or more steps)
(
"Step 1: First step\nStep 2: Second step\nStep 3: Third step",
1.0
),
(
"First, we do this.\nSecond, we do that.\nFinally, we conclude.",
1.0
),
("Step 1: First step\nStep 2: Second step\nStep 3: Third step", 1.0),
("First, we do this.\nSecond, we do that.\nFinally, we conclude.", 1.0),
# Partial credit cases (less than 3 steps)
(
"Step 1: Only step",
1/3
),
(
"First, we do this.\nFinally, we conclude.",
2/3
),
("Step 1: Only step", 1 / 3),
("First, we do this.\nFinally, we conclude.", 2 / 3),
# No credit case
(
"Just plain text without any clear steps",
0.0
)
("Just plain text without any clear steps", 0.0),
]

for content, expected_reward in test_cases:
completion = [[{"content": content}]]
rewards = reasoning_steps_reward(completion)
self.assertAlmostEqual(rewards[0], expected_reward)

def test_multiple_completions(self):
"""Test handling multiple completions at once."""
completions = [
[{"content": r"\boxed{\frac{63}{400}}"}],
[{"content": r"\boxed{\frac{64}{400}}"}]
]
completions = [[{"content": r"\boxed{\frac{63}{400}}"}], [{"content": r"\boxed{\frac{64}{400}}"}]]
solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]

rewards = accuracy_reward(completions, solutions)
self.assertEqual(len(rewards), 2)
self.assertEqual(rewards[0], 1.0)
self.assertEqual(rewards[1], 0.0)

def test_cosine_scaled_reward(self):
"""Test cosine_scaled_reward with various cases."""
# Test parameters
test_params = {
"min_value_wrong": -1.0,
"max_value_wrong": -0.5,
"min_value_correct": 0.5,
"max_value_correct": 1.0,
"max_len": 100,
}

test_cases = [
# Correct answers with different lengths
(r"\boxed{\frac{63}{400}}", r"\frac{63}{400}", 20, 0.943), # Short correct answer
(r"\boxed{\frac{63}{400}}", r"\frac{63}{400}", 80, 0.547), # Long correct answer
# Wrong answers with different lengths
(r"\boxed{\frac{64}{400}}", r"\frac{63}{400}", 20, -0.942), # Short wrong answer
(r"\boxed{\frac{64}{400}}", r"\frac{63}{400}", 80, -0.547), # Long wrong answer
]

for content, solution, content_len, expected_reward in test_cases:
# Pad content to desired length
padded_content = content + " " * (content_len - len(content))
completion = [[{"content": padded_content}]]

rewards = cosine_scaled_reward(completion, [solution], **test_params)
self.assertAlmostEqual(rewards[0], expected_reward, places=2)


if __name__ == '__main__':
unittest.main()
if __name__ == "__main__":
unittest.main()
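For reference, not part of the commit: the new tests can be exercised through the `unittest` entry point the file already defines. A minimal programmatic invocation, assuming the repository root is the working directory so that `tests.test_rewards` is importable:

```python
import unittest

# Equivalent to `python -m unittest tests.test_rewards -v`.
suite = unittest.defaultTestLoader.loadTestsFromName("tests.test_rewards")
unittest.TextTestRunner(verbosity=2).run(suite)
```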
