Add E2B code interpreter reward function (#364)
* Add stuff

* Make it kind of work

* Add more stuff

* Add fix for parse

* Fix

* Refactor

* Clean up

* Fix config

* Fix sys

* Add SFT config

* Use min rate

* Fix eval

* Add base model

* Add s1k

* Disable eval

* Fix

* Add import checker

* Fix importer

* Fix

* Tune config

* Tune

* Fix

* Fix save

* Tune beta

* Remove configs

* Fix vLLM

* Fix

* Add note

* Add doc

* doc

* Fix

* Tune lr

* Add command
lewtun authored Feb 19, 2025
1 parent 740a7a4 commit d76ecc1
Showing 7 changed files with 211 additions and 1 deletion.
37 changes: 37 additions & 0 deletions README.md
@@ -170,6 +170,43 @@ ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_con

Our final [model](https://huggingface.co/Dongwei/Qwen-2.5-7B_Base_Math_smalllr), while using different learning rates, loss functions and reward structures, achieves 69.4% accuracy on MATH-500, demonstrating a 17%+ improvement over the base model.

#### 👨‍💻 Training with a code interpreter

We provide a `code` reward function for executing code generated by the policy during training. Currently, this reward function targets code contests like [Codeforces](https://codeforces.com), where solutions are executed against a set of test cases and the overall success rate is returned as the final reward. To ensure safe execution, we use [E2B](https://e2b.dev) sandboxes, which are fast and cheap to run. To use this reward function, first install the necessary dependencies:

```shell
uv pip install -e '.[code]'
```

Then create a `.env` file and place an API token from E2B within it:

```
E2B_API_KEY="e2b_xxx"
```
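
You can sanity-check the setup with a minimal sketch like the one below (illustrative only, not part of the training code), which loads the key via `python-dotenv` and runs a trivial expression in an E2B sandbox:

```python
from dotenv import load_dotenv
from e2b_code_interpreter import Sandbox

load_dotenv()  # reads E2B_API_KEY from the .env file

# Run a trivial expression in a sandbox; the value of the last expression is returned as text
with Sandbox() as sbx:
    execution = sbx.run_code("21 * 2")
    print(execution.text)  # "42" if the key is valid and the sandbox is reachable
```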

Then make sure your dataset contains a `verification_info` column with the following schema (adapted from PrimeIntellect's excellent [datasets](https://huggingface.co/collections/PrimeIntellect/synthetic-1-67a2c399cfdd6c9f7fae0c37) of verifiable problems):

```python
{
    "language": "python",
    "test_cases": [
        {
            "input": "4\n4\n0001\n1000\n0011\n0111\n3\n010\n101\n0\n2\n00000\n00001\n4\n01\n001\n0001\n00001\n",
            "output": "1\n3 \n-1\n0\n\n2\n1 2 \n",
            "type": "stdin_stdout",
        }
    ],
}
```
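
For illustration, here is a minimal sketch of building a dataset with such a column using the `datasets` library (the prompt and test case below are made-up placeholders):

```python
from datasets import Dataset

dataset = Dataset.from_dict(
    {
        "problem": ["Read an integer n from stdin and print 2 * n."],
        "verification_info": [
            {
                "language": "python",
                "test_cases": [{"input": "3\n", "output": "6\n", "type": "stdin_stdout"}],
            }
        ],
    }
)
print(dataset[0]["verification_info"]["test_cases"])
```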

For example, to train a smol model on Python problems, run:

```shell
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero2.yaml \
--num_processes=7 src/open_r1/grpo.py \
--config recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo_code.yaml
```

### Launching jobs on a Slurm cluster

If you have access to a Slurm cluster, we provide a `slurm/train.slurm` script that will automatically queue training jobs for you. Here's how you can use it:
57 changes: 57 additions & 0 deletions recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo_code.yaml
@@ -0,0 +1,57 @@
# Model arguments
model_name_or_path: Qwen/Qwen2.5-1.5B-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2

# Data training arguments
dataset_name: open-r1/verifiable-coding-problems-python-10k
dataset_configs:
- default
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"

# GRPO trainer config
beta: 0.01
bf16: true
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.9
do_eval: false
gradient_accumulation_steps: 4
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: Qwen2.5-1.5B-Open-R1-Code-GRPO
hub_strategy: every_save
learning_rate: 5.0e-06
log_completions: true
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: cosine_with_min_lr
lr_scheduler_kwargs:
  min_lr_rate: 0.1
max_prompt_length: 1024
max_completion_length: 2048
max_steps: 500
num_generations: 14
num_train_epochs: 1
output_dir: data/Qwen2.5-1.5B-Open-R1-Code-GRPO
overwrite_output_dir: true
per_device_train_batch_size: 16
push_to_hub: true
report_to:
- wandb
reward_funcs:
- code
- format
reward_weights:
- 1.0
- 0.1
save_strategy: "steps"
save_steps: 50
save_total_limit: 1
seed: 42
temperature: 1.0
warmup_ratio: 0.03
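
Note that `reward_funcs` and `reward_weights` pair up positionally: the GRPO trainer scales each reward function's output by its weight and sums them, so with this config the per-completion reward is roughly as sketched below (the combination itself happens inside the trainer; the helper is purely illustrative):

```python
# Illustrative helper mirroring how reward_weights scale the code and format rewards above.
def combine_rewards(code_score: float, format_score: float) -> float:
    return 1.0 * code_score + 0.1 * format_score

print(combine_rewards(0.75, 1.0))  # 0.85: 75% of test cases passed, format respected
```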
3 changes: 3 additions & 0 deletions setup.py
@@ -46,6 +46,7 @@
"datasets>=3.2.0",
"deepspeed==0.15.4",
"distilabel[vllm,ray,openai]>=1.5.2",
"e2b-code-interpreter>=1.0.5",
"einops>=0.8.0",
"flake8>=6.0.0",
"flash_attn>=2.7.4.post1",
@@ -60,6 +61,7 @@
"parameterized>=0.9.0",
"peft>=0.14.0",
"pytest",
"python-dotenv",
"ruff>=0.9.0",
"safetensors>=0.3.3",
"sentencepiece>=0.1.99",
@@ -88,6 +90,7 @@ def deps_list(*pkgs):
extras["torch"] = deps_list("torch")
extras["quality"] = deps_list("ruff", "isort", "flake8")
extras["train"] = deps_list("flash_attn")
extras["code"] = deps_list("e2b-code-interpreter", "python-dotenv")
extras["eval"] = deps_list("lighteval", "math-verify")
extras["dev"] = extras["quality"] + extras["tests"] + extras["eval"] + extras["train"]

2 changes: 2 additions & 0 deletions src/open_r1/grpo.py
@@ -27,6 +27,7 @@
from open_r1.configs import GRPOConfig
from open_r1.rewards import (
    accuracy_reward,
    code_reward,
    format_reward,
    get_cosine_scaled_reward,
    get_repetition_penalty_reward,
@@ -161,6 +162,7 @@ def main(script_args, training_args, model_args):
            max_penalty=script_args.repetition_max_penalty,
        ),
        "length": len_reward,
        "code": code_reward,
    }
    reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]

87 changes: 87 additions & 0 deletions src/open_r1/rewards.py
@@ -1,12 +1,22 @@
"""Reward functions for GRPO training."""

import json
import math
import re
from typing import Dict

from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify

from .utils import is_e2b_available


if is_e2b_available():
    from dotenv import load_dotenv
    from e2b_code_interpreter import Sandbox

    load_dotenv()


def accuracy_reward(completions, solution, **kwargs):
"""Reward function that checks if the completion is the same as the ground truth."""
@@ -271,3 +281,80 @@ def repetition_penalty_reward(completions, **kwargs) -> float:
        return rewards

    return repetition_penalty_reward


def extract_code(completion: str) -> str:
    pattern = re.compile(r"```python\n(.*?)```", re.DOTALL)
    matches = pattern.findall(completion)
    extracted_answer = matches[-1] if len(matches) >= 1 else ""
    return extracted_answer
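

# Example behaviour (hypothetical completion string, shown only for illustration):
#   extract_code("<think>...</think>\n```python\nprint('hi')\n```")
#   returns "print('hi')\n"; the *last* fenced ```python block is used, and a
#   completion without any such block yields "".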


def code_reward(completions, **kwargs) -> list[float]:
    """Reward function that evaluates code snippets using the E2B code interpreter.

    Assumes the dataset contains a `verification_info` column with test cases.
    """
    if not is_e2b_available():
        raise ImportError(
            "E2B is not available and required for this reward function. Please install E2B with "
            "`pip install e2b-code-interpreter` and add an API key to a `.env` file."
        )

    rewards = []
    # TODO: add support for other languages in E2B: https://e2b.dev/docs/code-interpreting/supported-languages
    try:
        # Script template executed inside the sandbox: it runs the policy's code against each
        # test case and returns the fraction of test cases that pass as the reward.
        evaluation_script_template = """
import subprocess
import json

def evaluate_code(code, test_cases):
    passed = 0
    total = len(test_cases)
    exec_timeout = 5

    for case in test_cases:
        process = subprocess.run(
            ["python3", "-c", code],
            input=case["input"],
            text=True,
            capture_output=True,
            timeout=exec_timeout
        )

        if process.returncode != 0:  # Error in execution
            continue

        output = process.stdout.strip()
        if output.strip() == case["output"].strip():
            passed += 1

    success_rate = (passed / total)
    return success_rate

code_snippet = {code}
test_cases = json.loads({test_cases})

evaluate_code(code_snippet, test_cases)
"""
        code_snippets = [extract_code(completion[-1]["content"]) for completion in completions]
        verification_info = kwargs["verification_info"]
        scripts = [
            evaluation_script_template.format(
                code=json.dumps(code), test_cases=json.dumps(json.dumps(info["test_cases"]))
            )
            for code, info in zip(code_snippets, verification_info)
        ]
        with Sandbox(timeout=30, request_timeout=3) as sbx:
            for script, info in zip(scripts, verification_info):
                execution = sbx.run_code(script, language=info["language"])
                try:
                    output = float(execution.text)
                except (TypeError, ValueError):
                    output = 0.0
                rewards.append(output)
    except Exception as e:
        print(f"Error from E2B executor: {e}")
        rewards = [0.0] * len(completions)
    return rewards
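
For reference, a minimal sketch of how this reward function would be invoked (the completion and test case are made-up placeholders; during training, TRL's GRPO trainer forwards the dataset's `verification_info` column to the reward function via `**kwargs`):

```python
completions = [
    [{"role": "assistant", "content": "<think>\n...\n</think>\n<answer>\n```python\nn = int(input())\nprint(2 * n)\n```\n</answer>"}]
]
verification_info = [
    {"language": "python", "test_cases": [{"input": "3\n", "output": "6\n", "type": "stdin_stdout"}]}
]

# Requires a valid E2B_API_KEY; returns one pass rate in [0, 1] per completion.
print(code_reward(completions, verification_info=verification_info))  # e.g. [1.0]
```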
3 changes: 2 additions & 1 deletion src/open_r1/utils/__init__.py
@@ -1,4 +1,5 @@
from .import_utils import is_e2b_available
from .model_utils import get_tokenizer


__all__ = ["get_tokenizer"]
__all__ = ["get_tokenizer", "is_e2b_available"]
23 changes: 23 additions & 0 deletions src/open_r1/utils/import_utils.py
@@ -0,0 +1,23 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from transformers.utils.import_utils import _is_package_available


# Same convention as transformers.utils.import_utils
_e2b_available = _is_package_available("e2b")


def is_e2b_available() -> bool:
    return _e2b_available
