From d12886da7fa7f84b8b78b94e3a54e10a5a98bd4d Mon Sep 17 00:00:00 2001
From: JamesHujy <48405323+JamesHujy@users.noreply.github.com>
Date: Sat, 8 Feb 2025 09:46:44 -0500
Subject: [PATCH] fix format reward (#238)

* fix format reward

* failing test

* add \s* between </think> and <answer> tag to handle multilines

---------

Co-authored-by: Kashif Rasul
---
 src/open_r1/rewards.py | 4 ++--
 tests/test_rewards.py  | 7 +++++++
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/src/open_r1/rewards.py b/src/open_r1/rewards.py
index f92da19f..f33b5641 100644
--- a/src/open_r1/rewards.py
+++ b/src/open_r1/rewards.py
@@ -51,9 +51,9 @@ def accuracy_reward(completions, solution, **kwargs):
 
 def format_reward(completions, **kwargs):
     """Reward function that checks if the completion has a specific format."""
-    pattern = r"^<think>.*?</think><answer>.*?</answer>$"
+    pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"
     completion_contents = [completion[0]["content"] for completion in completions]
-    matches = [re.match(pattern, content) for content in completion_contents]
+    matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents]
     return [1.0 if match else 0.0 for match in matches]
 
 
diff --git a/tests/test_rewards.py b/tests/test_rewards.py
index 0ff8a106..473a0760 100644
--- a/tests/test_rewards.py
+++ b/tests/test_rewards.py
@@ -97,6 +97,13 @@ def test_cosine_scaled_reward(self):
         rewards = get_cosine_scaled_reward(**test_params)(completion, [solution])
         self.assertAlmostEqual(rewards[0], expected_reward, places=2)
 
+    def test_format_reward_specific_multiline(self):
+        """Test format_reward with a specific multiline input."""
+        inputs = "<think>\nI will count each distinct object in the image:\n1. Purple scooter\n2. Red bicycle\n3. Green motorcycle\n4. Gray sedan\n5. Yellow school bus\n6. Small green double-decker bus\n7. Small red car\n8. Small purple car\n9. Small gray dirt bike\n\nThere are 9 distinct objects in total.\n</think>\n<answer>9</answer>"
+        completion = [[{"content": inputs}]]
+        rewards = format_reward(completion)
+        self.assertEqual(rewards[0], 1.0)
+
 
 if __name__ == "__main__":
     unittest.main()