diff --git a/src/open_r1/rewards.py b/src/open_r1/rewards.py
index bec3d11c..96f66e5c 100644
--- a/src/open_r1/rewards.py
+++ b/src/open_r1/rewards.py
@@ -51,9 +51,9 @@ def accuracy_reward(completions, solution, **kwargs):
def format_reward(completions, **kwargs):
"""Reward function that checks if the completion has a specific format."""
- pattern = r"^.*?\s*.*?$"
+ pattern = re.compile(r"^.*?\s*.*?$", re.DOTALL | re.MULTILINE)
completion_contents = [completion[0]["content"] for completion in completions]
- matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents]
+ matches = [pattern.match(content) for content in completion_contents]
return [1.0 if match else 0.0 for match in matches]
@@ -66,9 +66,9 @@ def reasoning_steps_reward(completions, **kwargs):
\n\* - matches bullet points with asterisks
First,|Second,|Next,|Finally, - matches transition words
"""
- pattern = r"(Step \d+:|^\d+\.|\n-|\n\*|First,|Second,|Next,|Finally,)"
+ pattern = re.compile(r"(Step \d+:|^\d+\.|\n-|\n\*|First,|Second,|Next,|Finally,)")
completion_contents = [completion[0]["content"] for completion in completions]
- matches = [len(re.findall(pattern, content)) for content in completion_contents]
+ matches = [len(pattern.findall(content)) for content in completion_contents]
# Magic nubmer 3 to encourage 3 steps and more, otherwise partial reward
return [min(1.0, count / 3) for count in matches]
@@ -161,11 +161,7 @@ def get_repetition_penalty_reward(ngram_size: int, max_penalty: float):
"""
if max_penalty > 0:
raise ValueError(f"max_penalty {max_penalty} should not be positive")
-
- def zipngram(text: str, ngram_size: int):
- words = text.lower().split()
- return zip(*[words[i:] for i in range(ngram_size)])
-
+
def repetition_penalty_reward(completions, **kwargs) -> float:
"""
reward function the penalizes repetitions
@@ -181,16 +177,13 @@ def repetition_penalty_reward(completions, **kwargs) -> float:
if completion == "":
rewards.append(0.0)
continue
- if len(completion.split()) < ngram_size:
+ words = completion.lower().split()
+ if len(words) < ngram_size:
rewards.append(0.0)
continue
-
- ngrams = set()
- total = 0
- for ng in zipngram(completion, ngram_size):
- ngrams.add(ng)
- total += 1
-
+
+ ngrams = set(zip(*[words[i:] for i in range(ngram_size)]))
+ total = len(words) - ngram_size + 1
scaling = 1 - len(ngrams) / total
reward = scaling * max_penalty
rewards.append(reward)