Skip to content

Commit

Permalink
fix format reward (#238)
Browse files Browse the repository at this point in the history
* fix format reward

* failing test

* add \s* between the </think> and <answer> tags to handle multiline completions

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
  • Loading branch information
JamesHujy and kashif authored Feb 8, 2025
1 parent f5f0b55 commit d12886d
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/open_r1/rewards.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ def accuracy_reward(completions, solution, **kwargs):

def format_reward(completions, **kwargs):
    """Reward function that checks if the completion has a specific format.

    The expected layout is ``<think>...</think>`` followed by optional
    whitespace and then ``<answer>...</answer>``, with nothing before the
    opening tag and nothing after the closing tag.

    Args:
        completions: Sequence of completions; the text of each one is read
            from ``completion[0]["content"]``.
        **kwargs: Accepted and ignored.

    Returns:
        A list with one float per completion: 1.0 when the content matches
        the expected layout, 0.0 otherwise.
    """
    # \s* tolerates whitespace (including newlines) between the two tags.
    pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"
    completion_contents = [completion[0]["content"] for completion in completions]
    # DOTALL lets "." span the newlines inside the think/answer bodies.
    # MULTILINE is deliberately NOT set: with it, "$" also matches at the end
    # of every line, so arbitrary text placed on a new line after </answer>
    # would still be rewarded.
    matches = [re.match(pattern, content, re.DOTALL) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]


Expand Down
7 changes: 7 additions & 0 deletions tests/test_rewards.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,13 @@ def test_cosine_scaled_reward(self):
rewards = get_cosine_scaled_reward(**test_params)(completion, [solution])
self.assertAlmostEqual(rewards[0], expected_reward, places=2)

def test_format_reward_specific_multiline(self):
    """Check format_reward on one concrete completion whose think block spans many lines."""
    # The think body contains several newlines (including a blank line), and a
    # newline separates </think> from <answer>.
    text = (
        "<think>\n"
        "I will count each distinct object in the image:\n"
        "1. Purple scooter\n"
        "2. Red bicycle\n"
        "3. Green motorcycle\n"
        "4. Gray sedan\n"
        "5. Yellow school bus\n"
        "6. Small green double-decker bus\n"
        "7. Small red car\n"
        "8. Small purple car\n"
        "9. Small gray dirt bike\n"
        "\n"
        "There are 9 distinct objects in total.\n"
        "</think>\n"
        "<answer>9</answer>"
    )
    scores = format_reward([[{"content": text}]])
    self.assertEqual(scores[0], 1.0)


# Run the unittest test runner when this file is executed directly as a script.
if __name__ == "__main__":
unittest.main()

0 comments on commit d12886d

Please sign in to comment.