From a82e5541cf6d63fbc90bab4092ee3b34eb81a532 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 12 Feb 2025 20:01:53 +0100 Subject: [PATCH 1/2] use dense rep penalty --- src/open_r1/rewards.py | 40 ++++++++++++++++++++++++++-------------- tests/test_rewards.py | 14 +++++++------- 2 files changed, 33 insertions(+), 21 deletions(-) diff --git a/src/open_r1/rewards.py b/src/open_r1/rewards.py index bec3d11c..1112acec 100644 --- a/src/open_r1/rewards.py +++ b/src/open_r1/rewards.py @@ -150,7 +150,7 @@ def cosine_scaled_reward(completions, solution, **kwargs): return cosine_scaled_reward -def get_repetition_penalty_reward(ngram_size: int, max_penalty: float): +def get_repetition_penalty_reward(ngram_size: int, max_penalty: float, only_start: bool = False): """ Computes N-gram repetition penalty as described in Appendix C.2 of https://arxiv.org/abs/2502.03373. Reference implementation from: https://github.com/eddycmu/demystify-long-cot/blob/release/openrlhf/openrlhf/reward/repetition.py @@ -162,9 +162,8 @@ def get_repetition_penalty_reward(ngram_size: int, max_penalty: float): if max_penalty > 0: raise ValueError(f"max_penalty {max_penalty} should not be positive") - def zipngram(text: str, ngram_size: int): - words = text.lower().split() - return zip(*[words[i:] for i in range(ngram_size)]) + def zipngram(tokens: list[str], ngram_size: int): + return zip(*[tokens[i:] for i in range(ngram_size)]) def repetition_penalty_reward(completions, **kwargs) -> float: """ @@ -178,22 +177,35 @@ def repetition_penalty_reward(completions, **kwargs) -> float: contents = [completion[0]["content"] for completion in completions] rewards = [] for completion in contents: - if completion == "": - rewards.append(0.0) - continue - if len(completion.split()) < ngram_size: + if completion == "" or len(completion.split()) < ngram_size: rewards.append(0.0) continue + # Find repeated n-grams and their positions + words = completion.lower().split() + repeated_positions = [] ngrams = set() - total = 0 - for ng in zipngram(completion, ngram_size): + + for start_idx, ng in enumerate(zipngram(words, ngram_size)): + if ng in ngrams: + repeated_positions.append(start_idx) ngrams.add(ng) - total += 1 - scaling = 1 - len(ngrams) / total - reward = scaling * max_penalty - rewards.append(reward) + # Calculate word-level penalties + word_penalties = [0.0] * len(words) + curr_end_idx = -1 + + for start_idx in repeated_positions: + if not only_start or start_idx > curr_end_idx: + # Apply penalty to each token in the repeated n-gram + for i in range(start_idx, start_idx + ngram_size): + word_penalties[i] = max_penalty + curr_end_idx = start_idx + ngram_size + + # Average the word-level penalties for the final reward + reward = sum(word_penalties) / len(word_penalties) if word_penalties else 0.0 + rewards.append(float(reward)) + return rewards return repetition_penalty_reward diff --git a/tests/test_rewards.py b/tests/test_rewards.py index 7f0cbfa9..33636f28 100644 --- a/tests/test_rewards.py +++ b/tests/test_rewards.py @@ -130,7 +130,7 @@ def test_full_repetition(self): rewards = reward_fn(completions) # (1 - 1/4) * -1 = -0.75 - self.assertEqual(rewards, [-0.75]) + self.assertEqual(rewards, [-0.8]) def test_partial_repetition(self): reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0) @@ -139,7 +139,7 @@ def test_partial_repetition(self): rewards = reward_fn(completions) # Unique 2-grams: (this, is), (is, a), (a, this), (a, test). 4 unique out of 6 total # (1 - 4/6) * -1 = -1/3 = -0.3333... - self.assertAlmostEqual(rewards[0], -1 / 3) + self.assertAlmostEqual(rewards[0], -0.4285714, places=4) def test_multiple_completions(self): reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-0.5) @@ -152,7 +152,7 @@ def test_multiple_completions(self): # Completion 1: (this, is, a), (is, a, test) -> 2 unique / 2 total -> (1 - 2/2) * -0.5 = 0 # Completion 2: (test, test, test) -> 1 unique / 2 total -> (1 - 1/2) * -0.5 = -0.25 self.assertAlmostEqual(rewards[0], 0.0) - self.assertAlmostEqual(rewards[1], -0.25) + self.assertAlmostEqual(rewards[1], -0.375) def test_empty_completion(self): reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0) @@ -165,7 +165,7 @@ def test_different_ngram_size(self): completions = [[{"content": "this is a this is a test"}]] rewards = reward_fn(completions) - self.assertAlmostEqual(rewards[0], -0.4) + self.assertAlmostEqual(rewards[0], -0.8571428, places=4) def test_mixed_case(self): reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0) @@ -204,7 +204,7 @@ def test_three_word_repetition_completion(self): completions = [[{"content": "word word word word"}]] rewards = reward_fn(completions) - self.assertEqual(rewards, [-0.5]) + self.assertEqual(rewards, [-0.75]) def test_four_word_completion_with_repetition(self): reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0) @@ -227,13 +227,13 @@ def test_six_word_completion_with_repetition(self): completions = [[{"content": "A B C A B C"}]] rewards = reward_fn(completions) - self.assertEqual(rewards, [-0.25]) + self.assertEqual(rewards, [-0.5]) def test_long_completion_with_repetition(self): reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0) completions = [[{"content": "A B C A B C E F G A B C A B C"}]] rewards = reward_fn(completions) - self.assertAlmostEqual(rewards[0], -0.3846, places=4) + self.assertAlmostEqual(rewards[0], -0.6) def test_long_completion_without_repetition(self): reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0) From 473cad8eb315861259bcb323dad42371b03b07fe Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 12 Feb 2025 20:03:44 +0100 Subject: [PATCH 2/2] formatting --- tests/test_rewards.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_rewards.py b/tests/test_rewards.py index 33636f28..92882821 100644 --- a/tests/test_rewards.py +++ b/tests/test_rewards.py @@ -165,7 +165,7 @@ def test_different_ngram_size(self): completions = [[{"content": "this is a this is a test"}]] rewards = reward_fn(completions) - self.assertAlmostEqual(rewards[0], -0.8571428, places=4) + self.assertAlmostEqual(rewards[0], -0.8571428, places=4) def test_mixed_case(self): reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)