diff --git a/src/open_r1/rewards.py b/src/open_r1/rewards.py index bec3d11c..96f66e5c 100644 --- a/src/open_r1/rewards.py +++ b/src/open_r1/rewards.py @@ -51,9 +51,9 @@ def accuracy_reward(completions, solution, **kwargs): def format_reward(completions, **kwargs): """Reward function that checks if the completion has a specific format.""" - pattern = r"^.*?\s*.*?$" + pattern = re.compile(r"^.*?\s*.*?$", re.DOTALL | re.MULTILINE) completion_contents = [completion[0]["content"] for completion in completions] - matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents] + matches = [pattern.match(content) for content in completion_contents] return [1.0 if match else 0.0 for match in matches] @@ -66,9 +66,9 @@ def reasoning_steps_reward(completions, **kwargs): \n\* - matches bullet points with asterisks First,|Second,|Next,|Finally, - matches transition words """ - pattern = r"(Step \d+:|^\d+\.|\n-|\n\*|First,|Second,|Next,|Finally,)" + pattern = re.compile(r"(Step \d+:|^\d+\.|\n-|\n\*|First,|Second,|Next,|Finally,)") completion_contents = [completion[0]["content"] for completion in completions] - matches = [len(re.findall(pattern, content)) for content in completion_contents] + matches = [len(pattern.findall(content)) for content in completion_contents] # Magic nubmer 3 to encourage 3 steps and more, otherwise partial reward return [min(1.0, count / 3) for count in matches] @@ -161,11 +161,7 @@ def get_repetition_penalty_reward(ngram_size: int, max_penalty: float): """ if max_penalty > 0: raise ValueError(f"max_penalty {max_penalty} should not be positive") - - def zipngram(text: str, ngram_size: int): - words = text.lower().split() - return zip(*[words[i:] for i in range(ngram_size)]) - + def repetition_penalty_reward(completions, **kwargs) -> float: """ reward function the penalizes repetitions @@ -181,16 +177,13 @@ def repetition_penalty_reward(completions, **kwargs) -> float: if completion == "": rewards.append(0.0) continue - if len(completion.split()) < ngram_size: + words = completion.lower().split() + if len(words) < ngram_size: rewards.append(0.0) continue - - ngrams = set() - total = 0 - for ng in zipngram(completion, ngram_size): - ngrams.add(ng) - total += 1 - + + ngrams = set(zip(*[words[i:] for i in range(ngram_size)])) + total = len(words) - ngram_size + 1 scaling = 1 - len(ngrams) / total reward = scaling * max_penalty rewards.append(reward)