
Revert "Adds repetition penalty reward (#263)" (#267)
This reverts commit d57f2ed.
edbeeching authored Feb 10, 2025
1 parent d57f2ed commit 486f7d4
Showing 3 changed files with 2 additions and 214 deletions.
21 changes: 1 addition & 20 deletions src/open_r1/grpo.py
@@ -25,13 +25,7 @@
from transformers.trainer_utils import get_last_checkpoint

from open_r1.configs import GRPOConfig
from open_r1.rewards import (
accuracy_reward,
format_reward,
get_cosine_scaled_reward,
get_repetition_penalty_reward,
reasoning_steps_reward,
)
from open_r1.rewards import accuracy_reward, format_reward, get_cosine_scaled_reward, reasoning_steps_reward
from open_r1.utils.callbacks import get_callbacks
from trl import GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config

@@ -86,15 +80,6 @@ class GRPOScriptArguments(ScriptArguments):
metadata={"help": "Maximum length for scaling"},
)

repetition_n_grams: int = field(
default=3,
metadata={"help": "Number of n-grams for repetition penalty reward"},
)
repetition_max_penalty: float = field(
default=-1.0,
metadata={"help": "Maximum (negative) penalty for for repetition penalty reward"},
)


SYSTEM_PROMPT = (
"A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
@@ -154,10 +139,6 @@ def main(script_args, training_args, model_args):
max_value_correct=script_args.cosine_max_value_correct,
max_len=script_args.cosine_max_len,
),
"repetition_penalty": get_repetition_penalty_reward(
ngram_size=script_args.repetition_n_grams,
max_penalty=script_args.repetition_max_penalty,
),
}
reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]

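After this revert, "repetition_penalty" is no longer a key in REWARD_FUNCS_REGISTRY, so requesting it via --reward_funcs fails fast with a KeyError. A minimal, self-contained sketch of the surviving lookup (stub bodies stand in for the real functions in open_r1.rewards):

# Sketch only: stubs replace the real reward functions from open_r1.rewards.
def accuracy_reward(completions, solution, **kwargs):
    return [1.0 for _ in completions]  # stub

def format_reward(completions, **kwargs):
    return [1.0 for _ in completions]  # stub

REWARD_FUNCS_REGISTRY = {
    "accuracy": accuracy_reward,
    "format": format_reward,
}

selected = [REWARD_FUNCS_REGISTRY[name] for name in ["accuracy", "format"]]
# REWARD_FUNCS_REGISTRY["repetition_penalty"] would now raise KeyError.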
48 changes: 0 additions & 48 deletions src/open_r1/rewards.py
@@ -148,51 +148,3 @@ def cosine_scaled_reward(completions, solution, **kwargs):
return rewards

return cosine_scaled_reward


def get_repetition_penalty_reward(ngram_size: int, max_penalty: float):
if max_penalty > 0:
raise ValueError(f"max_penalty {max_penalty} should not be positive")

if max_penalty == 0:
return 0

def zipngram(text: str, ngram_size: int):
words = text.lower().split()
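# Zipping word lists shifted by 0..ngram_size-1 yields overlapping n-gram tuples.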
return zip(*[words[i:] for i in range(ngram_size)])

def repetition_penalty_reward(completions, *args, **kwargs):
"""
Reward function that penalizes repetitions.
ref implementation: https://github.com/eddycmu/demystify-long-cot/blob/release/openrlhf/openrlhf/reward/repetition.py
Args:
completions: List of model completions
solution: List of ground truth solutions
This function is parameterized by the following arguments:
ngram_size: size of the n-grams
max_penalty: Maximum (negative) penalty, approached as the completion becomes fully repetitive
"""

rewards = []
for completion in completions:
if completion == "":
rewards.append(0.0)
continue
if len(completion.split()) < ngram_size:
rewards.append(0.0)
continue

ngrams = set()
total = 0
for ng in zipngram(completion, ngram_size):
ngrams.add(ng)
total += 1

scaling = 1 - len(ngrams) / total
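# scaling is 0.0 when every n-gram is unique and approaches 1.0 under heavy
# repetition, so the reward runs from 0.0 down toward max_penalty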
reward = scaling * max_penalty
rewards.append(reward)
return rewards

return repetition_penalty_reward
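For reference, a usage sketch of the factory removed by this revert; the expected value matches test_full_repetition below:

# Hypothetical usage of the reverted reward function (numbers taken from the tests).
reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)
rewards = reward_fn(["this this this this this"], [])
# 2-grams: 4 total, 1 unique -> (1 - 1/4) * -1.0 = -0.75
print(rewards)  # [-0.75]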
147 changes: 1 addition & 146 deletions tests/test_rewards.py
@@ -1,12 +1,6 @@
import unittest

from open_r1.rewards import (
accuracy_reward,
format_reward,
get_cosine_scaled_reward,
get_repetition_penalty_reward,
reasoning_steps_reward,
)
from open_r1.rewards import accuracy_reward, format_reward, get_cosine_scaled_reward, reasoning_steps_reward


class TestRewards(unittest.TestCase):
@@ -111,144 +105,5 @@ def test_format_reward_specific_multiline(self):
self.assertEqual(rewards[0], 1.0)


class TestRepetitionPenaltyReward(unittest.TestCase):
def test_positive_max_penalty_raises_value_error(self):
with self.assertRaises(ValueError):
get_repetition_penalty_reward(ngram_size=2, max_penalty=1.0)
with self.assertRaisesRegex(ValueError, "max_penalty 1.5 should not be positive"):
get_repetition_penalty_reward(ngram_size=2, max_penalty=1.5)

def test_zero_max_penalty_returns_zero(self):
reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=0.0)
self.assertEqual(reward_fn, 0)

def test_no_repetition(self):
reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)
completions = ["this is a test sentence"]
solution = [] # Solution is not used in the reward calculation
rewards = reward_fn(completions, solution)
self.assertEqual(rewards, [0.0])

def test_full_repetition(self):
reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)
completions = ["this this this this this"]
solution = []
rewards = reward_fn(completions, solution)
# (1 - 1/4) * -1 = -0.75
self.assertEqual(rewards, [-0.75])

def test_partial_repetition(self):
reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)
completions = [
"this is a this is a test"
] # 2-grams: (this, is), (is, a), (a, this), (this, is), (is, a), (a, test)
solution = []
rewards = reward_fn(completions, solution)
# Unique 2-grams: (this, is), (is, a), (a, this), (a, test). 4 unique out of 6 total
# (1 - 4/6) * -1 = -1/3 = -0.3333...
self.assertAlmostEqual(rewards[0], -1 / 3)

def test_multiple_completions(self):
reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-0.5)
completions = ["this is a test", "test test test test"]
solution = []
rewards = reward_fn(completions, solution)
# Completion 1: (this, is, a), (is, a, test) -> 2 unique / 2 total -> (1 - 2/2) * -0.5 = 0
# Completion 2: (test, test, test) -> 1 unique / 2 total -> (1 - 1/2) * -0.5 = -0.25
self.assertAlmostEqual(rewards[0], 0.0)
self.assertAlmostEqual(rewards[1], -0.25)

def test_empty_completion(self):
reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)
completions = [""]
solution = []
rewards = reward_fn(completions, solution)
self.assertEqual(rewards, [0.0])

def test_different_ngram_size(self):
reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-2.0)
completions = [
"this is a this is a test"
] # 3-grams: (this, is, a), (is, a, this), (a, this, is), (this, is, a), (is, a, test)
solution = []
rewards = reward_fn(completions, solution)
# Unique 3-grams: (this, is, a), (is, a, this), (a, this, is), (is, a, test) = 4. Total 3-grams: 5
# (1 - 4/5) * -2 = -0.4
self.assertAlmostEqual(rewards[0], -0.4)

def test_mixed_case(self):
reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)
completions = ["This is A Test", "this IS a test"]
solution = []
rewards = reward_fn(completions, solution)
# both completions should produce the same reward, because the text gets lowercased
self.assertAlmostEqual(rewards[0], rewards[1])

def test_one_word_completion(self):
reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
completions = ["word"]
solutions = []
rewards = reward_fn(completions, solutions)
self.assertEqual(rewards, [0.0])

def test_two_word_completion(self):
reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
completions = ["two words"]
solutions = []
rewards = reward_fn(completions, solutions)
self.assertEqual(rewards, [0.0])

def test_three_word_completion(self):
reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
completions = ["three different words"]
solutions = []
rewards = reward_fn(completions, solutions)
self.assertEqual(rewards, [0.0])

def test_three_word_repetition_completion(self):
reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
completions = ["word word word word"]
solutions = []
rewards = reward_fn(completions, solutions)
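# 3-grams: (word, word, word) twice -> 1 unique / 2 total -> (1 - 1/2) * -1 = -0.5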
self.assertEqual(rewards, [-0.5])

def test_four_word_completion_with_repetition(self):
reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
completions = ["one two one two"]
solutions = []
rewards = reward_fn(completions, solutions)
# 3-grams: (one, two, one), (two, one, two) -> 2 unique / 2 total -> (1 - 2/2) * -1 = 0.0
self.assertEqual(rewards, [0.0])

def test_five_word_completion_with_repetition(self):
reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-0.5)
completions = ["A B C A B"]
solutions = []
rewards = reward_fn(completions, solutions)
# 3-grams: (A, B, C), (B, C, A), (C, A, B) -> 3 unique / 3 total -> (1 - 3/3) * -0.5 = 0.0
self.assertEqual(rewards, [0.0])

def test_six_word_completion_with_repetition(self):
reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
completions = ["A B C A B C"]
solutions = []
rewards = reward_fn(completions, solutions)
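# 3-grams: (A, B, C), (B, C, A), (C, A, B), (A, B, C) -> 3 unique / 4 total -> (1 - 3/4) * -1 = -0.25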
self.assertEqual(rewards, [-0.25])

def test_long_completion_with_repetition(self):
reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
completions = ["A B C A B C E F G A B C A B C"]
solutions = []
rewards = reward_fn(completions, solutions)
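# 13 total 3-grams, 8 unique -> (1 - 8/13) * -1 = -5/13 ≈ -0.3846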
self.assertAlmostEqual(rewards[0], -0.3846, places=4)

def test_long_completion_without_repetition(self):
reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
completions = ["A B C D E F G H I J K L"]
solutions = []
rewards = reward_fn(completions, solutions)
self.assertEqual(rewards, [0.0])


if __name__ == "__main__":
unittest.main()
