Skip to content

Commit

Permalink
clean up, style
Browse files Browse the repository at this point in the history
  • Loading branch information
edbeeching committed Feb 10, 2025
1 parent 9e1fae0 commit a315815
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 16 deletions.
3 changes: 0 additions & 3 deletions src/open_r1/rewards.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,6 @@ def get_repetition_penalty_reward(ngram_size: int, max_penalty: float):
if max_penalty > 0:
raise ValueError(f"max_penalty {max_penalty} should not be positive")

if max_penalty == 0:
return 0

def zipngram(text: str, ngram_size: int):
    """Lazily yield the word n-grams of *text* as tuples of lowercase words.

    The text is lowercased and split on whitespace, then ``ngram_size``
    staggered views of the word list are zipped together, so the i-th tuple
    holds words ``i .. i + ngram_size - 1``.

    Args:
        text: Text to break into n-grams.
        ngram_size: Number of consecutive words per n-gram.

    Returns:
        A one-shot ``zip`` iterator of ``ngram_size``-tuples. It is empty when
        the text has fewer than ``ngram_size`` words or ``ngram_size`` is not
        positive.
    """
    from itertools import islice

    words = text.lower().split()
    # islice gives each offset its own lazy view of the word list, avoiding
    # the ngram_size sliced copies that words[i:] would allocate.
    return zip(*(islice(words, offset, None) for offset in range(ngram_size)))
Expand Down
26 changes: 13 additions & 13 deletions tests/test_rewards.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,15 +131,15 @@ def test_no_repetition(self):
def test_full_repetition(self):
    """A completion that repeats one word throughout is penalised by -0.75."""
    penalty_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)
    repeated = [[{"content": "this this this this this"}]]

    # 4 bigrams in total, only 1 distinct: (1 - 1/4) * -1 = -0.75
    self.assertEqual(penalty_fn(repeated), [-0.75])

def test_partial_repetition(self):
reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)
completions = [[{"content": "this is a this is a test"}]]

rewards = reward_fn(completions)
# Unique 2-grams: (this, is), (is, a), (a, this), (a, test). 4 unique out of 6 total
# (1 - 4/6) * -1 = -1/3 = -0.3333...
Expand All @@ -150,8 +150,8 @@ def test_multiple_completions(self):
completions = [
[{"content": "this is a test"}],
[{"content": "test test test test"}],
]
]

rewards = reward_fn(completions)
# Completion 1: (this, is, a), (is, a, test) -> 2 unique / 2 total -> (1 - 2/2) * -0.5 = 0
# Completion 2: (test, test, test) -> 1 unique / 2 total -> (1 - 1/2) * -0.5 = -0.25
Expand All @@ -167,7 +167,7 @@ def test_empty_completion(self):
def test_different_ngram_size(self):
    """Trigram-based penalty with max_penalty=-2.0 on a partly repeated text."""
    penalty_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-2.0)
    sample = [[{"content": "this is a this is a test"}]]

    first_reward = penalty_fn(sample)[0]
    # With the (1 - unique/total) * max_penalty formula used by the
    # neighbouring tests: 4 distinct of 5 trigrams -> (1 - 4/5) * -2 = -0.4
    self.assertAlmostEqual(first_reward, -0.4)

Expand All @@ -177,7 +177,7 @@ def test_mixed_case(self):
[{"content": "This is A Test"}],
[{"content": "this IS a test"}],
]

rewards = reward_fn(completions)
# Both completions should produce the same reward, because the text is lowercased before n-grams are built.
self.assertAlmostEqual(rewards[0], rewards[1])
Expand All @@ -192,44 +192,44 @@ def test_one_word_completion(self):
def test_two_word_completion(self):
    """A two-word completion scored with trigrams is expected to get 0.0."""
    penalty_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
    too_short = [[{"content": "two words"}]]

    # Two words cannot form a single trigram.
    self.assertEqual(penalty_fn(too_short), [0.0])

def test_three_word_completion(self):
    """Three distinct words form exactly one trigram, so the reward is 0.0."""
    penalty_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
    single_trigram = [[{"content": "three different words"}]]

    self.assertEqual(penalty_fn(single_trigram), [0.0])

def test_three_word_repetition_completion(self):
    """Four copies of the same word scored with trigrams yield -0.5."""
    penalty_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
    repeated = [[{"content": "word word word word"}]]

    # 2 trigrams in total, 1 distinct: (1 - 1/2) * -1 = -0.5
    self.assertEqual(penalty_fn(repeated), [-0.5])

def test_four_word_completion_with_repetition(self):
    """Repeated words but no repeated trigram: the reward stays at 0.0."""
    penalty_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
    alternating = [[{"content": "one two one two"}]]

    # Trigrams: (one two one), (two one two) -> 2 distinct out of 2,
    # hence (1 - 2/2) * -1 = 0.
    self.assertEqual(penalty_fn(alternating), [0.0])

def test_five_word_completion_with_repetition(self):
    """Five words whose three trigrams are all distinct score 0.0."""
    penalty_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-0.5)
    sample = [[{"content": "A B C A B"}]]

    # Trigrams: (A B C), (B C A), (C A B) -> 3 distinct out of 3,
    # hence (1 - 3/3) * -0.5 = 0.
    self.assertEqual(penalty_fn(sample), [0.0])

def test_six_word_completion_with_repetition(self):
    """A three-word pattern repeated twice repeats one trigram: reward -0.25."""
    penalty_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
    sample = [[{"content": "A B C A B C"}]]

    # Trigrams: (A B C), (B C A), (C A B), (A B C) -> 3 distinct out of 4,
    # hence (1 - 3/4) * -1 = -0.25.
    self.assertEqual(penalty_fn(sample), [-0.25])

Expand All @@ -242,7 +242,7 @@ def test_long_completion_with_repetition(self):
def test_long_completion_without_repetition(self):
    """Twelve distinct words contain no repeated trigram, so the reward is 0.0."""
    penalty_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
    all_distinct = [[{"content": "A B C D E F G H I J K L"}]]

    self.assertEqual(penalty_fn(all_distinct), [0.0])

Expand Down

0 comments on commit a315815

Please sign in to comment.