Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[rewards] use dense rep penalty #296

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 26 additions & 14 deletions src/open_r1/rewards.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def cosine_scaled_reward(completions, solution, **kwargs):
return cosine_scaled_reward


def get_repetition_penalty_reward(ngram_size: int, max_penalty: float):
def get_repetition_penalty_reward(ngram_size: int, max_penalty: float, only_start: bool = False):
"""
Computes N-gram repetition penalty as described in Appendix C.2 of https://arxiv.org/abs/2502.03373.
Reference implementation from: https://github.com/eddycmu/demystify-long-cot/blob/release/openrlhf/openrlhf/reward/repetition.py
Expand All @@ -162,9 +162,8 @@ def get_repetition_penalty_reward(ngram_size: int, max_penalty: float):
if max_penalty > 0:
raise ValueError(f"max_penalty {max_penalty} should not be positive")

def zipngram(text: str, ngram_size: int):
words = text.lower().split()
return zip(*[words[i:] for i in range(ngram_size)])
def zipngram(tokens: list[str], ngram_size: int):
return zip(*[tokens[i:] for i in range(ngram_size)])

def repetition_penalty_reward(completions, **kwargs) -> float:
"""
Expand All @@ -178,22 +177,35 @@ def repetition_penalty_reward(completions, **kwargs) -> float:
contents = [completion[0]["content"] for completion in completions]
rewards = []
for completion in contents:
if completion == "":
rewards.append(0.0)
continue
if len(completion.split()) < ngram_size:
if completion == "" or len(completion.split()) < ngram_size:
rewards.append(0.0)
continue

# Find repeated n-grams and their positions
words = completion.lower().split()
repeated_positions = []
ngrams = set()
total = 0
for ng in zipngram(completion, ngram_size):

for start_idx, ng in enumerate(zipngram(words, ngram_size)):
if ng in ngrams:
repeated_positions.append(start_idx)
ngrams.add(ng)
total += 1

scaling = 1 - len(ngrams) / total
reward = scaling * max_penalty
rewards.append(reward)
# Calculate word-level penalties
word_penalties = [0.0] * len(words)
curr_end_idx = -1

for start_idx in repeated_positions:
if not only_start or start_idx > curr_end_idx:
# Apply penalty to each token in the repeated n-gram
for i in range(start_idx, start_idx + ngram_size):
word_penalties[i] = max_penalty
curr_end_idx = start_idx + ngram_size

# Average the word-level penalties for the final reward
reward = sum(word_penalties) / len(word_penalties) if word_penalties else 0.0
rewards.append(float(reward))

return rewards

return repetition_penalty_reward
14 changes: 7 additions & 7 deletions tests/test_rewards.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def test_full_repetition(self):

rewards = reward_fn(completions)
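# Previous ratio-based formula: 1 unique 2-gram out of 4 total -> (1 - 1/4) * -1 = -0.75
# Dense per-word penalty: every word after the first lies inside a repeated 2-gram,
# so 4 of the 5 words are penalized -> 4/5 * -1 = -0.8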
self.assertEqual(rewards, [-0.75])
self.assertEqual(rewards, [-0.8])

def test_partial_repetition(self):
reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)
Expand All @@ -139,7 +139,7 @@ def test_partial_repetition(self):
rewards = reward_fn(completions)
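# Unique 2-grams: (this, is), (is, a), (a, this), (a, test). 4 unique out of 6 total
# Previous ratio-based formula: (1 - 4/6) * -1 = -1/3 = -0.3333...
# Dense per-word penalty: the repeated 2-grams (this, is) and (is, a) start at word
# indices 3 and 4 (0-based), so words 3-5 of the 7 words are penalized -> 3/7 * -1 = -0.428571...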
self.assertAlmostEqual(rewards[0], -1 / 3)
self.assertAlmostEqual(rewards[0], -0.4285714, places=4)

def test_multiple_completions(self):
reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-0.5)
Expand All @@ -152,7 +152,7 @@ def test_multiple_completions(self):
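# Completion 1: (this, is, a), (is, a, test) -> 2 unique / 2 total -> (1 - 2/2) * -0.5 = 0 (no repeated 3-gram, so the dense penalty is also 0)
# Completion 2: (test, test, test) -> 1 unique / 2 total -> previous ratio-based formula: (1 - 1/2) * -0.5 = -0.25
# Dense per-word penalty for completion 2: the repeated 3-gram starts at word index 1 (0-based),
# so 3 of the 4 words are penalized -> 3/4 * -0.5 = -0.375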
self.assertAlmostEqual(rewards[0], 0.0)
self.assertAlmostEqual(rewards[1], -0.25)
self.assertAlmostEqual(rewards[1], -0.375)

def test_empty_completion(self):
reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)
Expand All @@ -165,7 +165,7 @@ def test_different_ngram_size(self):
completions = [[{"content": "this is a this is a test"}]]

rewards = reward_fn(completions)
self.assertAlmostEqual(rewards[0], -0.4)
self.assertAlmostEqual(rewards[0], -0.8571428, places=4)

def test_mixed_case(self):
reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)
Expand Down Expand Up @@ -204,7 +204,7 @@ def test_three_word_repetition_completion(self):
completions = [[{"content": "word word word word"}]]

rewards = reward_fn(completions)
self.assertEqual(rewards, [-0.5])
self.assertEqual(rewards, [-0.75])

def test_four_word_completion_with_repetition(self):
reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
Expand All @@ -227,13 +227,13 @@ def test_six_word_completion_with_repetition(self):
completions = [[{"content": "A B C A B C"}]]

rewards = reward_fn(completions)
self.assertEqual(rewards, [-0.25])
self.assertEqual(rewards, [-0.5])

def test_long_completion_with_repetition(self):
reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
completions = [[{"content": "A B C A B C E F G A B C A B C"}]]
rewards = reward_fn(completions)
self.assertAlmostEqual(rewards[0], -0.3846, places=4)
self.assertAlmostEqual(rewards[0], -0.6)

def test_long_completion_without_repetition(self):
reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
Expand Down