diff --git a/src/open_r1/rewards.py b/src/open_r1/rewards.py index f92da19f..f33b5641 100644 --- a/src/open_r1/rewards.py +++ b/src/open_r1/rewards.py @@ -51,9 +51,9 @@ def accuracy_reward(completions, solution, **kwargs): def format_reward(completions, **kwargs): """Reward function that checks if the completion has a specific format.""" - pattern = r"^<think>.*?</think><answer>.*?</answer>$" + pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$" completion_contents = [completion[0]["content"] for completion in completions] - matches = [re.match(pattern, content) for content in completion_contents] + matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents] return [1.0 if match else 0.0 for match in matches] diff --git a/tests/test_rewards.py b/tests/test_rewards.py index 0ff8a106..473a0760 100644 --- a/tests/test_rewards.py +++ b/tests/test_rewards.py @@ -97,6 +97,13 @@ def test_cosine_scaled_reward(self): rewards = get_cosine_scaled_reward(**test_params)(completion, [solution]) self.assertAlmostEqual(rewards[0], expected_reward, places=2) + def test_format_reward_specific_multiline(self): + """Test format_reward with a specific multiline input.""" + inputs = "<think>\nI will count each distinct object in the image:\n1. Purple scooter\n2. Red bicycle\n3. Green motorcycle\n4. Gray sedan\n5. Yellow school bus\n6. Small green double-decker bus\n7. Small red car\n8. Small purple car\n9. Small gray dirt bike\n\nThere are 9 distinct objects in total.\n</think>\n<answer>9</answer>" + completion = [[{"content": inputs}]] + rewards = format_reward(completion) + self.assertEqual(rewards[0], 1.0) + if __name__ == "__main__": unittest.main()