Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Language specific code format reward #377

Merged
merged 1 commit into from
Feb 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/open_r1/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
accuracy_reward,
code_reward,
format_reward,
get_code_format_reward,
get_cosine_scaled_reward,
get_repetition_penalty_reward,
len_reward,
Expand Down Expand Up @@ -61,12 +62,14 @@ class GRPOScriptArguments(ScriptArguments):
Maximum reward for cosine scaling for correct answers.
cosine_max_len (`int`):
Maximum length for cosine scaling.
code_language (`str`):
Language for code format reward.
"""

reward_funcs: list[str] = field(
default_factory=lambda: ["accuracy", "format"],
metadata={
"help": "List of reward functions. Possible values: 'accuracy', 'format', 'format_deepseek', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length'"
"help": "List of reward functions. Possible values: 'accuracy', 'format', 'format_deepseek', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length', 'code', 'code_format'"
},
)
cosine_min_value_wrong: float = field(
Expand Down Expand Up @@ -97,6 +100,13 @@ class GRPOScriptArguments(ScriptArguments):
default=-1.0,
metadata={"help": "Maximum (negative) penalty for for repetition penalty reward"},
)
code_language: str = field(
default="python",
metadata={
"help": "Language for code format reward. Based on E2B supported languages https://e2b.dev/docs/code-interpreting/supported-languages",
"choices": ["python", "javascript", "r", "java", "bash"],
},
)


def main(script_args, training_args, model_args):
Expand Down Expand Up @@ -163,6 +173,7 @@ def main(script_args, training_args, model_args):
),
"length": len_reward,
"code": code_reward,
"code_format": get_code_format_reward(language=script_args.code_language),
}
reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]

Expand Down
16 changes: 16 additions & 0 deletions src/open_r1/rewards.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,3 +358,19 @@ def evaluate_code(code, test_cases):
print(f"Error from E2B executor: {e}")
rewards = [0.0] * len(completions)
return rewards


def get_code_format_reward(language: str = "python"):
    """Build a format reward function specifically for code responses.

    A completion is rewarded when it starts with a ``<think>...</think>``
    section, followed (after optional whitespace) by an ``<answer>...</answer>``
    section that contains a fenced code block opened with
    ```` ```{language} ```` and a newline.

    Note: the pattern is applied with ``re.match`` under
    ``re.DOTALL | re.MULTILINE``, so ``$`` may match before any newline; text on
    later lines after ``</answer>`` therefore does not prevent a match.

    Args:
        language: Programming language supported by E2B https://e2b.dev/docs/code-interpreting/supported-languages
            The value is matched literally as the code-fence tag.

    Returns:
        A callable ``code_format_reward(completions, **kwargs)`` that returns one
        float per completion: 1.0 if the format matches, 0.0 otherwise.
    """
    # Escape the tag so languages containing regex metacharacters (e.g. "c++")
    # are matched literally instead of being interpreted as regex syntax.
    # Compiling here hoists the work out of every reward call.
    pattern = re.compile(
        rf"^<think>.*?</think>\s*<answer>.*?```{re.escape(language)}\n.*?```.*?</answer>$",
        re.DOTALL | re.MULTILINE,
    )

    def code_format_reward(completions, **kwargs):
        """Score each completion by the content of its first message."""
        return [1.0 if pattern.match(completion[0]["content"]) else 0.0 for completion in completions]

    return code_format_reward
79 changes: 79 additions & 0 deletions tests/test_rewards.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from open_r1.rewards import (
accuracy_reward,
format_reward,
get_code_format_reward,
get_cosine_scaled_reward,
get_repetition_penalty_reward,
len_reward,
Expand Down Expand Up @@ -313,5 +314,83 @@ def test_long_completion_without_repetition(self):
self.assertEqual(rewards, [0.0])


class TestCodeFormat(unittest.TestCase):
    """Checks for the reward functions produced by ``get_code_format_reward``."""

    @staticmethod
    def _wrap(text):
        """Put a raw response string into the nested completions structure."""
        return [[{"content": text}]]

    def test_correct_python_format(self):
        """A well-formed think/answer response with a python block scores 1.0."""
        response = "<think>Let's solve this\nStep 1: First step</think>\n<answer>```python\ndef hello():\n print('world')\n```</answer>"
        score = get_code_format_reward(language="python")
        result = score(self._wrap(response))
        self.assertEqual(result[0], 1.0)

    def test_incorrect_formats(self):
        """Each malformed response scores 0.0 for the python reward."""
        no_tags = "```python\ndef hello():\n print('world')\n```"
        no_code_block = "<think>Some thinking</think><answer>Just plain text</answer>"
        other_language = "<think>Analysis</think><answer>```javascript\nconsole.log('hello');\n```</answer>"
        untagged_fence = "<think>Analysis</think><answer>```\ndef hello(): pass\n```</answer>"
        answer_before_think = "<answer>```python\ndef hello(): pass\n```</answer><think>Analysis</think>"

        score = get_code_format_reward(language="python")
        for bad_response in (
            no_tags,
            no_code_block,
            other_language,
            untagged_fence,
            answer_before_think,
        ):
            result = score(self._wrap(bad_response))
            self.assertEqual(result[0], 0.0)

    def test_multiple_code_blocks(self):
        """A code block inside the think section does not break the match."""
        response = "<think>Here's an example:\n```python\nx = 1\n```\nNow the solution:</think>\n<answer>```python\ndef solution():\n return 42\n```</answer>"
        score = get_code_format_reward(language="python")
        result = score(self._wrap(response))
        self.assertEqual(result[0], 1.0)

    def test_different_languages(self):
        """The language passed to the factory decides which fence tag is accepted."""
        js_completion = self._wrap(
            "<think>Analysis</think><answer>```javascript\nconsole.log('hello');\n```</answer>"
        )

        # The JavaScript reward accepts the javascript-tagged block.
        score_js = get_code_format_reward(language="javascript")
        self.assertEqual(score_js(js_completion)[0], 1.0)

        # The Python reward rejects the very same completion.
        score_py = get_code_format_reward(language="python")
        self.assertEqual(score_py(js_completion)[0], 0.0)

    def test_multiline_code(self):
        """A longer multi-line class definition inside the answer scores 1.0."""
        response = "<think>Here's the analysis</think>\n<answer>```python\nclass Solution:\n def __init__(self):\n self.value = 42\n \n def get_value(self):\n return self.value\n```</answer>"
        score = get_code_format_reward(language="python")
        result = score(self._wrap(response))
        self.assertEqual(result[0], 1.0)


if __name__ == "__main__":
unittest.main()