diff --git a/src/open_r1/rewards.py b/src/open_r1/rewards.py index a6ce10c9..f7e270ef 100644 --- a/src/open_r1/rewards.py +++ b/src/open_r1/rewards.py @@ -162,9 +162,6 @@ def get_repetition_penalty_reward(ngram_size: int, max_penalty: float): if max_penalty > 0: raise ValueError(f"max_penalty {max_penalty} should not be positive") - if max_penalty == 0: - return 0 - def zipngram(text: str, ngram_size: int): words = text.lower().split() return zip(*[words[i:] for i in range(ngram_size)]) diff --git a/tests/test_rewards.py b/tests/test_rewards.py index ccf41e18..0ae015d1 100644 --- a/tests/test_rewards.py +++ b/tests/test_rewards.py @@ -131,7 +131,7 @@ def test_no_repetition(self): def test_full_repetition(self): reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0) completions = [[{"content": "this this this this this"}]] - + rewards = reward_fn(completions) # (1 - 1/4) * -1 = -0.75 self.assertEqual(rewards, [-0.75]) @@ -139,7 +139,7 @@ def test_full_repetition(self): def test_partial_repetition(self): reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0) completions = [[{"content": "this is a this is a test"}]] - + rewards = reward_fn(completions) # Unique 2-grams: (this, is), (is, a), (a, this), (a, test). 4 unique out of 6 total # (1 - 4/6) * -1 = -1/3 = -0.3333... @@ -150,8 +150,8 @@ def test_multiple_completions(self): completions = [ [{"content": "this is a test"}], [{"content": "test test test test"}], - ] - + ] + rewards = reward_fn(completions) # Completion 1: (this, is, a), (is, a, test) -> 2 unique / 2 total -> (1 - 2/2) * -0.5 = 0 # Completion 2: (test, test, test) -> 1 unique / 2 total -> (1 - 1/2) * -0.5 = -0.25 @@ -167,7 +167,7 @@ def test_empty_completion(self): def test_different_ngram_size(self): reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-2.0) completions = [[{"content": "this is a this is a test"}]] - + rewards = reward_fn(completions) self.assertAlmostEqual(rewards[0], -0.4) @@ -177,7 +177,7 @@ def test_mixed_case(self): [{"content": "This is A Test"}], [{"content": "this IS a test"}], ] - + rewards = reward_fn(completions) # both completions should produce the same reward, because the text gets lowercased self.assertAlmostEqual(rewards[0], rewards[1]) @@ -192,28 +192,28 @@ def test_one_word_completion(self): def test_two_word_completion(self): reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0) completions = [[{"content": "two words"}]] - + rewards = reward_fn(completions) self.assertEqual(rewards, [0.0]) def test_three_word_completion(self): reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0) completions = [[{"content": "three different words"}]] - + rewards = reward_fn(completions) self.assertEqual(rewards, [0.0]) def test_three_word_repetition_completion(self): reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0) completions = [[{"content": "word word word word"}]] - + rewards = reward_fn(completions) self.assertEqual(rewards, [-0.5]) def test_four_word_completion_with_repetition(self): reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0) completions = [[{"content": "one two one two"}]] - + rewards = reward_fn(completions) # ngrams are (one two one) (two one two). unique is 2 and count is 2, therefore (1-1) * -1. self.assertEqual(rewards, [0.0]) @@ -221,7 +221,7 @@ def test_four_word_completion_with_repetition(self): def test_five_word_completion_with_repetition(self): reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-0.5) completions = [[{"content": "A B C A B"}]] - + rewards = reward_fn(completions) # (A B C) (B C A) (C A B). unique is 3. count is 3 (1-1) * -.5 = 0 self.assertEqual(rewards, [0.0]) @@ -229,7 +229,7 @@ def test_five_word_completion_with_repetition(self): def test_six_word_completion_with_repetition(self): reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0) completions = [[{"content": "A B C A B C"}]] - + rewards = reward_fn(completions) self.assertEqual(rewards, [-0.25]) @@ -242,7 +242,7 @@ def test_long_completion_with_repetition(self): def test_long_completion_without_repetition(self): reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0) completions = [[{"content": "A B C D E F G H I J K L"}]] - + rewards = reward_fn(completions) self.assertEqual(rewards, [0.0])