Skip to content

Commit

Permalink
Language specific code format reward
Browse files Browse the repository at this point in the history
  • Loading branch information
zeenolife committed Feb 19, 2025
1 parent d76ecc1 commit 8b36ef1
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 1 deletion.
13 changes: 12 additions & 1 deletion src/open_r1/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
accuracy_reward,
code_reward,
format_reward,
get_code_format_reward,
get_cosine_scaled_reward,
get_repetition_penalty_reward,
len_reward,
Expand Down Expand Up @@ -61,12 +62,14 @@ class GRPOScriptArguments(ScriptArguments):
Maximum reward for cosine scaling for correct answers.
cosine_max_len (`int`):
Maximum length for cosine scaling.
code_language (`str`):
Language for code format reward.
"""

reward_funcs: list[str] = field(
default_factory=lambda: ["accuracy", "format"],
metadata={
"help": "List of reward functions. Possible values: 'accuracy', 'format', 'format_deepseek', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length'"
"help": "List of reward functions. Possible values: 'accuracy', 'format', 'format_deepseek', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length', 'code', 'code_format'"
},
)
cosine_min_value_wrong: float = field(
Expand Down Expand Up @@ -97,6 +100,13 @@ class GRPOScriptArguments(ScriptArguments):
default=-1.0,
metadata={"help": "Maximum (negative) penalty for for repetition penalty reward"},
)
code_language: str = field(
default="python",
metadata={
"help": "Language for code format reward. Based on E2B supported languages https://e2b.dev/docs/code-interpreting/supported-languages",
"choices": ["python", "javascript", "r", "java", "bash"],
},
)


def main(script_args, training_args, model_args):
Expand Down Expand Up @@ -163,6 +173,7 @@ def main(script_args, training_args, model_args):
),
"length": len_reward,
"code": code_reward,
"code_format": get_code_format_reward(language=script_args.code_language),
}
reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]

Expand Down
16 changes: 16 additions & 0 deletions src/open_r1/rewards.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,3 +358,19 @@ def evaluate_code(code, test_cases):
print(f"Error from E2B executor: {e}")
rewards = [0.0] * len(completions)
return rewards


def get_code_format_reward(language: str = "python"):
    """Format reward function specifically for code responses.

    Builds a reward function that scores a completion 1.0 when its content
    starts with a ``<think>...</think>`` section followed (after optional
    whitespace) by an ``<answer>...</answer>`` section that contains a fenced
    code block tagged with ``language``, and 0.0 otherwise.

    Args:
        language: Programming language supported by E2B https://e2b.dev/docs/code-interpreting/supported-languages
            It is matched literally as the tag right after the opening fence.

    Returns:
        A callable ``code_format_reward(completions, **kwargs)`` that returns a
        list with one float (1.0 or 0.0) per completion. Each completion is
        indexed as ``completion[0]["content"]``; extra keyword arguments are
        accepted and ignored.
    """
    # Escape the tag so that regex metacharacters in a language name (for
    # example "c++") are matched literally instead of being read as pattern
    # syntax. Compiling here means the work is done once per reward function.
    # NOTE: with re.MULTILINE the trailing ``$`` also matches at the end of a
    # line, so text on later lines after ``</answer>`` does not prevent a match.
    pattern = re.compile(
        rf"^<think>.*?</think>\s*<answer>.*?```{re.escape(language)}\n.*?```.*?</answer>$",
        re.DOTALL | re.MULTILINE,
    )

    def code_format_reward(completions, **kwargs):
        """Return 1.0 for each completion whose content matches the format, else 0.0."""
        completion_contents = [completion[0]["content"] for completion in completions]
        return [1.0 if pattern.match(content) else 0.0 for content in completion_contents]

    return code_format_reward
79 changes: 79 additions & 0 deletions tests/test_rewards.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from open_r1.rewards import (
accuracy_reward,
format_reward,
get_code_format_reward,
get_cosine_scaled_reward,
get_repetition_penalty_reward,
len_reward,
Expand Down Expand Up @@ -313,5 +314,83 @@ def test_long_completion_without_repetition(self):
self.assertEqual(rewards, [0.0])


class TestCodeFormat(unittest.TestCase):
    """Tests for the reward functions built by ``get_code_format_reward``."""

    def test_correct_python_format(self):
        """Test code format reward with correct Python format."""
        completion = [
            [
                {
                    "content": "<think>Let's solve this\nStep 1: First step</think>\n<answer>```python\ndef hello():\n print('world')\n```</answer>"
                }
            ]
        ]
        reward_fn = get_code_format_reward(language="python")
        rewards = reward_fn(completion)
        self.assertEqual(rewards[0], 1.0)

    def test_incorrect_formats(self):
        """Test code format reward with various incorrect formats."""
        incorrect_formats = [
            # Missing think/answer tags
            "```python\ndef hello():\n print('world')\n```",
            # Missing code block
            "<think>Some thinking</think><answer>Just plain text</answer>",
            # Wrong language
            "<think>Analysis</think><answer>```javascript\nconsole.log('hello');\n```</answer>",
            # Missing language identifier
            "<think>Analysis</think><answer>```\ndef hello(): pass\n```</answer>",
            # Wrong order of tags
            "<answer>```python\ndef hello(): pass\n```</answer><think>Analysis</think>",
        ]

        reward_fn = get_code_format_reward(language="python")
        for fmt in incorrect_formats:
            # subTest reports the offending input on failure and keeps checking
            # the remaining cases instead of stopping at the first mismatch.
            with self.subTest(fmt=fmt):
                completion = [[{"content": fmt}]]
                rewards = reward_fn(completion)
                self.assertEqual(rewards[0], 0.0)

    def test_multiple_code_blocks(self):
        """Test format reward with multiple code blocks in think and answer sections."""
        completion = [
            [
                {
                    "content": "<think>Here's an example:\n```python\nx = 1\n```\nNow the solution:</think>\n<answer>```python\ndef solution():\n return 42\n```</answer>"
                }
            ]
        ]
        reward_fn = get_code_format_reward(language="python")
        rewards = reward_fn(completion)
        self.assertEqual(rewards[0], 1.0)

    def test_different_languages(self):
        """Test code format reward with different programming languages."""
        completion = [
            [{"content": "<think>Analysis</think><answer>```javascript\nconsole.log('hello');\n```</answer>"}]
        ]

        # Test with JavaScript
        js_reward_fn = get_code_format_reward(language="javascript")
        rewards = js_reward_fn(completion)
        self.assertEqual(rewards[0], 1.0)

        # Same completion should fail for Python
        py_reward_fn = get_code_format_reward(language="python")
        rewards = py_reward_fn(completion)
        self.assertEqual(rewards[0], 0.0)

    def test_multiline_code(self):
        """Test format reward with complex multiline code blocks."""
        completion = [
            [
                {
                    "content": "<think>Here's the analysis</think>\n<answer>```python\nclass Solution:\n def __init__(self):\n self.value = 42\n \n def get_value(self):\n return self.value\n```</answer>"
                }
            ]
        ]
        reward_fn = get_code_format_reward(language="python")
        rewards = reward_fn(completion)
        self.assertEqual(rewards[0], 1.0)


if __name__ == "__main__":
unittest.main()

0 comments on commit 8b36ef1

Please sign in to comment.