Revert "Adds repetition penalty reward" #267

Merged 1 commit on Feb 10, 2025
21 changes: 1 addition & 20 deletions src/open_r1/grpo.py
@@ -25,13 +25,7 @@
from transformers.trainer_utils import get_last_checkpoint

from open_r1.configs import GRPOConfig
-from open_r1.rewards import (
-    accuracy_reward,
-    format_reward,
-    get_cosine_scaled_reward,
-    get_repetition_penalty_reward,
-    reasoning_steps_reward,
-)
+from open_r1.rewards import accuracy_reward, format_reward, get_cosine_scaled_reward, reasoning_steps_reward
from open_r1.utils.callbacks import get_callbacks
from trl import GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config

@@ -86,15 +80,6 @@ class GRPOScriptArguments(ScriptArguments):
metadata={"help": "Maximum length for scaling"},
)

-    repetition_n_grams: int = field(
-        default=3,
-        metadata={"help": "Number of n-grams for repetition penalty reward"},
-    )
-    repetition_max_penalty: float = field(
-        default=-1.0,
-        metadata={"help": "Maximum (negative) penalty for repetition penalty reward"},
-    )


SYSTEM_PROMPT = (
"A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
@@ -154,10 +139,6 @@ def main(script_args, training_args, model_args):
            max_value_correct=script_args.cosine_max_value_correct,
            max_len=script_args.cosine_max_len,
        ),
"repetition_penalty": get_repetition_penalty_reward(
ngram_size=script_args.repetition_n_grams,
max_penalty=script_args.repetition_max_penalty,
),
    }
    reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]

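Aside (not part of the PR): after this revert, the reward registry in grpo.py simply no longer exposes a repetition penalty. A minimal sketch of the surviving dispatch, with the remaining registry keys assumed from context rather than confirmed by this diff:

# Hypothetical sketch of the lookup in src/open_r1/grpo.py after the revert.
# Key names besides "repetition_penalty" are assumptions for illustration.
REWARD_FUNCS_REGISTRY = {
    "accuracy": accuracy_reward,
    "format": format_reward,
    # ... cosine / reasoning-steps entries as in the hunk above ...
}
# Every entry in script_args.reward_funcs must be a registry key, so a config
# still requesting "repetition_penalty" now fails with a KeyError here:
reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]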
48 changes: 0 additions & 48 deletions src/open_r1/rewards.py
@@ -148,51 +148,3 @@ def cosine_scaled_reward(completions, solution, **kwargs):
        return rewards

    return cosine_scaled_reward


-def get_repetition_penalty_reward(ngram_size: int, max_penalty: float):
-    if max_penalty > 0:
-        raise ValueError(f"max_penalty {max_penalty} should not be positive")
-
-    if max_penalty == 0:
-        return 0
-
-    def zipngram(text: str, ngram_size: int):
-        words = text.lower().split()
-        return zip(*[words[i:] for i in range(ngram_size)])
-
-    def repetition_penalty_reward(completions, *args, **kwargs):
-        """
-        Reward function that penalizes repetitions.
-        Ref implementation: https://github.com/eddycmu/demystify-long-cot/blob/release/openrlhf/openrlhf/reward/repetition.py
-
-        Args:
-            completions: List of model completions
-            solution: List of ground truth solutions
-
-        This function is parameterized by the following arguments:
-            ngram_size: size of the n-grams
-            max_penalty: Maximum (negative) penalty for repeated n-grams
-        """
-
-        rewards = []
-        for completion in completions:
-            if completion == "":
-                rewards.append(0.0)
-                continue
-            if len(completion.split()) < ngram_size:
-                rewards.append(0.0)
-                continue
-
-            ngrams = set()
-            total = 0
-            for ng in zipngram(completion, ngram_size):
-                ngrams.add(ng)
-                total += 1
-
-            scaling = 1 - len(ngrams) / total
-            reward = scaling * max_penalty
-            rewards.append(reward)
-        return rewards
-
-    return repetition_penalty_reward
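To make the reverted scaling concrete, a short worked example against the deleted function above (the numbers match test_full_repetition in the test diff below):

# Worked example of the deleted reward, exactly as it appears in this diff.
reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)

# "this this this this this" yields 4 bigrams, all identical, so 1 is unique:
# scaling = 1 - 1/4 = 0.75, reward = 0.75 * -1.0 = -0.75
print(reward_fn(["this this this this this"]))  # [-0.75]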
147 changes: 1 addition & 146 deletions tests/test_rewards.py
@@ -1,12 +1,6 @@
import unittest

-from open_r1.rewards import (
-    accuracy_reward,
-    format_reward,
-    get_cosine_scaled_reward,
-    get_repetition_penalty_reward,
-    reasoning_steps_reward,
-)
+from open_r1.rewards import accuracy_reward, format_reward, get_cosine_scaled_reward, reasoning_steps_reward


class TestRewards(unittest.TestCase):
@@ -111,144 +105,5 @@ def test_format_reward_specific_multiline(self):
        self.assertEqual(rewards[0], 1.0)


-class TestRepetitionPenaltyReward(unittest.TestCase):
-    def test_positive_max_penalty_raises_value_error(self):
-        with self.assertRaises(ValueError):
-            get_repetition_penalty_reward(ngram_size=2, max_penalty=1.0)
-        with self.assertRaisesRegex(ValueError, "max_penalty 1.5 should not be positive"):
-            get_repetition_penalty_reward(ngram_size=2, max_penalty=1.5)
-
-    def test_zero_max_penalty_returns_zero(self):
-        reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=0.0)
-        self.assertEqual(reward_fn, 0)
-
-    def test_no_repetition(self):
-        reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)
-        completions = ["this is a test sentence"]
-        solution = []  # Solution is not used in the reward calculation
-        rewards = reward_fn(completions, solution)
-        self.assertEqual(rewards, [0.0])
-
-    def test_full_repetition(self):
-        reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)
-        completions = ["this this this this this"]
-        solution = []
-        rewards = reward_fn(completions, solution)
-        # (1 - 1/4) * -1 = -0.75
-        self.assertEqual(rewards, [-0.75])
-
-    def test_partial_repetition(self):
-        reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)
-        completions = [
-            "this is a this is a test"
-        ]  # 2-grams: (this, is), (is, a), (a, this), (this, is), (is, a), (a, test)
-        solution = []
-        rewards = reward_fn(completions, solution)
-        # Unique 2-grams: (this, is), (is, a), (a, this), (a, test). 4 unique out of 6 total
-        # (1 - 4/6) * -1 = -1/3 = -0.3333...
-        self.assertAlmostEqual(rewards[0], -1 / 3)
-
-    def test_multiple_completions(self):
-        reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-0.5)
-        completions = ["this is a test", "test test test test"]
-        solution = []
-        rewards = reward_fn(completions, solution)
-        # Completion 1: (this, is, a), (is, a, test) -> 2 unique / 2 total -> (1 - 2/2) * -0.5 = 0
-        # Completion 2: (test, test, test) -> 1 unique / 2 total -> (1 - 1/2) * -0.5 = -0.25
-        self.assertAlmostEqual(rewards[0], 0.0)
-        self.assertAlmostEqual(rewards[1], -0.25)
-
-    def test_empty_completion(self):
-        reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)
-        completions = [""]
-        solution = []
-        rewards = reward_fn(completions, solution)
-        self.assertEqual(rewards, [0.0])
-
-    def test_different_ngram_size(self):
-        reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-2.0)
-        completions = [
-            "this is a this is a test"
-        ]  # 3-grams: (this, is, a), (is, a, this), (a, this, is), (this, is, a), (is, a, test)
-        solution = []
-        rewards = reward_fn(completions, solution)
-        # Unique 3-grams: (this, is, a), (is, a, this), (a, this, is), (is, a, test) = 4. Total 3-grams: 5
-        # (1 - 4/5) * -2 = -0.4
-        self.assertAlmostEqual(rewards[0], -0.4)
-
-    def test_mixed_case(self):
-        reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)
-        completions = ["This is A Test", "this IS a test"]
-        solution = []
-        rewards = reward_fn(completions, solution)
-        # Both completions should produce the same reward, because the text gets lowercased
-        self.assertAlmostEqual(rewards[0], rewards[1])
-
-    def test_one_word_completion(self):
-        reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
-        completions = ["word"]
-        solutions = []
-        rewards = reward_fn(completions, solutions)
-        self.assertEqual(rewards, [0.0])
-
-    def test_two_word_completion(self):
-        reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
-        completions = ["two words"]
-        solutions = []
-        rewards = reward_fn(completions, solutions)
-        self.assertEqual(rewards, [0.0])
-
-    def test_three_word_completion(self):
-        reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
-        completions = ["three different words"]
-        solutions = []
-        rewards = reward_fn(completions, solutions)
-        self.assertEqual(rewards, [0.0])
-
-    def test_three_word_repetition_completion(self):
-        reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
-        completions = ["word word word word"]
-        solutions = []
-        rewards = reward_fn(completions, solutions)
-        # 3-grams: (word, word, word) twice -> 1 unique / 2 total -> (1 - 1/2) * -1 = -0.5
-        self.assertEqual(rewards, [-0.5])
-
-    def test_four_word_completion_with_repetition(self):
-        reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
-        completions = ["one two one two"]
-        solutions = []
-        rewards = reward_fn(completions, solutions)
-        # 3-grams: (one, two, one), (two, one, two) -> 2 unique / 2 total -> (1 - 1) * -1 = 0
-        self.assertEqual(rewards, [0.0])
-
-    def test_five_word_completion_with_repetition(self):
-        reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-0.5)
-        completions = ["A B C A B"]
-        solutions = []
-        rewards = reward_fn(completions, solutions)
-        # 3-grams: (a, b, c), (b, c, a), (c, a, b) -> 3 unique / 3 total -> (1 - 1) * -0.5 = 0
-        self.assertEqual(rewards, [0.0])
-
-    def test_six_word_completion_with_repetition(self):
-        reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
-        completions = ["A B C A B C"]
-        solutions = []
-        rewards = reward_fn(completions, solutions)
-        # 3-grams: (a, b, c), (b, c, a), (c, a, b), (a, b, c) -> 3 unique / 4 total -> (1 - 3/4) * -1 = -0.25
-        self.assertEqual(rewards, [-0.25])
-
-    def test_long_completion_with_repetition(self):
-        reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
-        completions = ["A B C A B C E F G A B C A B C"]
-        solutions = []
-        rewards = reward_fn(completions, solutions)
-        self.assertAlmostEqual(rewards[0], -0.3846, places=4)
-
-    def test_long_completion_without_repetition(self):
-        reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
-        completions = ["A B C D E F G H I J K L"]
-        solutions = []
-        rewards = reward_fn(completions, solutions)
-        self.assertEqual(rewards, [0.0])


if __name__ == "__main__":
unittest.main()
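As a quick post-revert sanity check (a hypothetical snippet, not part of this PR), the deleted symbol should no longer import while the surviving rewards still do:

# Hypothetical check that the revert landed; run from the repository root.
from open_r1.rewards import accuracy_reward, format_reward  # still present

try:
    from open_r1.rewards import get_repetition_penalty_reward
except ImportError:
    print("get_repetition_penalty_reward removed, as expected")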