Merge branch 'main' into len-reward
kashif authored Feb 13, 2025
2 parents 452edf2 + 80e7e7b commit 127811a
Showing 7 changed files with 49 additions and 6 deletions.
14 changes: 11 additions & 3 deletions README.md
@@ -57,7 +57,7 @@ uv venv openr1 --python 3.11 && source openr1/bin/activate && uv pip install --u
Next, install vLLM:

```shell
uv pip install vllm==0.7.1 --link-mode=copy
uv pip install vllm==0.7.2 --link-mode=copy
```

This will also install PyTorch `v2.5.1`, and it is **very important** to use this version since the vLLM binaries are compiled for it. You can then install the remaining dependencies for your specific use case via `pip install -e .[LIST OF MODES]`. For most contributors, we recommend:
@@ -126,6 +126,14 @@ accelerate launch --config_file recipes/accelerate_configs/zero3.yaml src/open_r
--per_device_train_batch_size=1 --num_train_epochs=5
```

If you also wish to override the Weights & Biases default settings, you can do so as follows:

```shell
accelerate launch --config_file recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
    --config recipes/Qwen2.5-1.5B-Instruct/sft/config_demo.yaml \
--wandb_entity huggingface --wandb_project open-r1 --run_name Qwen2.5-1.5B-GRPO
```

> [!NOTE]
> The training commands below are configured for a node of 8 x H100s (80GB). For different hardware and topologies, you may need to tune the batch size and number of gradient accumulation steps.
@@ -141,10 +149,10 @@ ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_con

### GRPO

To train via the GRPO trainer, we use one GPU to run vLLM for faster generation and the remaining GPUs for training. For example, on a node with 8 GPUs, use the `recipes/accelerate_configs/zero3.yaml` config and then overwrite `num_processes` to run on 7 devices:
To train via the GRPO trainer, we use one GPU to run vLLM for faster generation and the remaining GPUs for training. For example, on a node with 8 GPUs, use the `recipes/accelerate_configs/zero2.yaml` config and then overwrite `num_processes` to run on 7 devices:

```shell
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml \
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero2.yaml \
--num_processes=7 src/open_r1/grpo.py \
--config recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo.yaml
```
File renamed without changes.
2 changes: 1 addition & 1 deletion slurm/evaluate.slurm
@@ -81,7 +81,7 @@ echo "Uploading details to Hugging Face Hub..."
DETAILS_FILEPATHS=$(find $OUTPUT_DIR/details/ -type f \( -name "*.parquet" \))
echo "DETAILS_FILEPATHS: $DETAILS_FILEPATHS"
TIMESTAMP=$(date +"%Y-%m-%dT%H-%M-%S")
python src/open_r1/utils/upload_details.py --data_files $DETAILS_FILEPATHS --hub_repo_id $DETAILS_REPO_ID --config_name $MODEL_REVISION.$TASK_NAME.$TIMESTAMP
python scripts/upload_details.py --data_files $DETAILS_FILEPATHS --hub_repo_id $DETAILS_REPO_ID --config_name $MODEL_REVISION.$TASK_NAME.$TIMESTAMP

echo "Cleaning up ..."
rm -rf $OUTPUT_DIR
16 changes: 16 additions & 0 deletions src/open_r1/configs.py
@@ -40,6 +40,14 @@ class GRPOConfig(trl.GRPOConfig):
)
overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
wandb_entity: Optional[str] = field(
default=None,
metadata={"help": ("The entity to store runs under.")},
)
wandb_project: Optional[str] = field(
default=None,
metadata={"help": ("The project to store runs under.")},
)


@dataclass
@@ -64,3 +72,11 @@ class SFTConfig(trl.SFTConfig):
)
overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
wandb_entity: Optional[str] = field(
default=None,
metadata={"help": ("The entity to store runs under.")},
)
wandb_project: Optional[str] = field(
default=None,
metadata={"help": ("The project to store runs under.")},
)
6 changes: 5 additions & 1 deletion src/open_r1/grpo.py
@@ -34,6 +34,7 @@
reasoning_steps_reward,
)
from open_r1.utils.callbacks import get_callbacks
from open_r1.utils.logging import init_wandb_training
from trl import GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config


@@ -131,7 +132,7 @@ def main(script_args, training_args, model_args):
)
logger.info(f"Model parameters {model_args}")
logger.info(f"Script parameters {script_args}")
logger.info(f"Data parameters {training_args}")
logger.info(f"Training parameters {training_args}")

# Check for last checkpoint
last_checkpoint = None
@@ -140,6 +141,9 @@
if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")

if "wandb" in training_args.report_to:
init_wandb_training(training_args)

# Load the dataset
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

6 changes: 5 additions & 1 deletion src/open_r1/sft.py
@@ -48,6 +48,7 @@

from open_r1.configs import SFTConfig
from open_r1.utils.callbacks import get_callbacks
from open_r1.utils.logging import init_wandb_training
from trl import (
ModelConfig,
ScriptArguments,
@@ -88,7 +89,7 @@ def main(script_args, training_args, model_args):
)
logger.info(f"Model parameters {model_args}")
logger.info(f"Script parameters {script_args}")
logger.info(f"Data parameters {training_args}")
logger.info(f"Training parameters {training_args}")

# Check for last checkpoint
last_checkpoint = None
@@ -97,6 +98,9 @@
if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")

if "wandb" in training_args.report_to:
init_wandb_training(training_args)

################
# Load datasets
################
11 changes: 11 additions & 0 deletions src/open_r1/utils/wandb_logging.py
@@ -0,0 +1,11 @@
import os


def init_wandb_training(training_args):
"""
Helper function for setting up Weights & Biases logging tools.
"""
if training_args.wandb_entity is not None:
os.environ["WANDB_ENTITY"] = training_args.wandb_entity
if training_args.wandb_project is not None:
os.environ["WANDB_PROJECT"] = training_args.wandb_project
