diff --git a/README.md b/README.md
index 48b767b5..07d2ddd3 100644
--- a/README.md
+++ b/README.md
@@ -57,7 +57,7 @@ uv venv openr1 --python 3.11 && source openr1/bin/activate && uv pip install --u
 Next, install vLLM:
 
 ```shell
-uv pip install vllm==0.7.1 --link-mode=copy
+uv pip install vllm==0.7.2 --link-mode=copy
 ```
 
 This will also install PyTorch `v2.5.1` and it is **very important** to use this version since the vLLM binaries are compiled for it. You can then install the remaining dependencies for your specific use case via `pip install -e .[LIST OF MODES]`. For most contributors, we recommend:
@@ -126,6 +126,14 @@ accelerate launch --config_file recipes/accelerate_configs/zero3.yaml src/open_r
     --per_device_train_batch_size=1 --num_train_epochs=5
 ```
 
+If you also wish to override the Weights & Biases default settings, you can do so as follows:
+
+```shell
+accelerate launch --config_file recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
+    --config recipes/Qwen2.5-1.5B-Instruct/sft/config_demo.yaml \
+    --wandb_entity huggingface --wandb_project open-r1 --run_name Qwen2.5-1.5B-SFT
+```
+
 > [!NOTE]
 > The training commands below are configured for a node of 8 x H100s (80GB). For different hardware and topologies, you may need to tune the batch size and number of gradient accumulation steps.
 
@@ -141,10 +149,10 @@ ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_con
 
 ### GRPO
 
-To train via the GRPO trainer, we use one GPU to run vLLM for faster generation and the remaining GPUs for training. For example, one a node with 8 GPUs, use the `recipes/accelerate_configs/zero3.yaml` config and then overwrite `num_processes` to run on 7 devices:
+To train via the GRPO trainer, we use one GPU to run vLLM for faster generation and the remaining GPUs for training. For example, on a node with 8 GPUs, use the `recipes/accelerate_configs/zero2.yaml` config and then override `num_processes` to run on 7 devices:
 
 ```shell
-ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml \
+ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero2.yaml \
     --num_processes=7 src/open_r1/grpo.py \
     --config recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo.yaml
 ```
diff --git a/src/open_r1/utils/upload_details.py b/scripts/upload_details.py
similarity index 100%
rename from src/open_r1/utils/upload_details.py
rename to scripts/upload_details.py
diff --git a/slurm/evaluate.slurm b/slurm/evaluate.slurm
index c659c0b3..da106f6b 100644
--- a/slurm/evaluate.slurm
+++ b/slurm/evaluate.slurm
@@ -81,7 +81,7 @@ echo "Uploading details to Hugging Face Hub..."
 DETAILS_FILEPATHS=$(find $OUTPUT_DIR/details/ -type f \( -name "*.parquet" \))
 echo "DETAILS_FILEPATHS: $DETAILS_FILEPATHS"
 TIMESTAMP=$(date +"%Y-%m-%dT%H-%M-%S")
-python src/open_r1/utils/upload_details.py --data_files $DETAILS_FILEPATHS --hub_repo_id $DETAILS_REPO_ID --config_name $MODEL_REVISION.$TASK_NAME.$TIMESTAMP
+python scripts/upload_details.py --data_files $DETAILS_FILEPATHS --hub_repo_id $DETAILS_REPO_ID --config_name $MODEL_REVISION.$TASK_NAME.$TIMESTAMP
 
 echo "Cleaning up ..."
 rm -rf $OUTPUT_DIR
diff --git a/src/open_r1/configs.py b/src/open_r1/configs.py
index 57968b4b..3a6f6866 100644
--- a/src/open_r1/configs.py
+++ b/src/open_r1/configs.py
@@ -40,6 +40,14 @@ class GRPOConfig(trl.GRPOConfig):
     )
     overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
     push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
+    wandb_entity: Optional[str] = field(
+        default=None,
+        metadata={"help": "The entity to store runs under."},
+    )
+    wandb_project: Optional[str] = field(
+        default=None,
+        metadata={"help": "The project to store runs under."},
+    )
 
 
 @dataclass
@@ -64,3 +72,11 @@ class SFTConfig(trl.SFTConfig):
     )
     overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
     push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
+    wandb_entity: Optional[str] = field(
+        default=None,
+        metadata={"help": "The entity to store runs under."},
+    )
+    wandb_project: Optional[str] = field(
+        default=None,
+        metadata={"help": "The project to store runs under."},
+    )
diff --git a/src/open_r1/grpo.py b/src/open_r1/grpo.py
index 803a8da0..916be06e 100644
--- a/src/open_r1/grpo.py
+++ b/src/open_r1/grpo.py
@@ -34,6 +34,7 @@
     reasoning_steps_reward,
 )
 from open_r1.utils.callbacks import get_callbacks
+from open_r1.utils.wandb_logging import init_wandb_training
 from trl import GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
 
 
@@ -131,7 +132,7 @@ def main(script_args, training_args, model_args):
     )
     logger.info(f"Model parameters {model_args}")
     logger.info(f"Script parameters {script_args}")
-    logger.info(f"Data parameters {training_args}")
+    logger.info(f"Training parameters {training_args}")
 
     # Check for last checkpoint
     last_checkpoint = None
@@ -140,6 +141,9 @@ def main(script_args, training_args, model_args):
     if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
         logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")
 
+    if "wandb" in training_args.report_to:
+        init_wandb_training(training_args)
+
     # Load the dataset
     dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
 
diff --git a/src/open_r1/sft.py b/src/open_r1/sft.py
index e8587d03..16791cd4 100644
--- a/src/open_r1/sft.py
+++ b/src/open_r1/sft.py
@@ -48,6 +48,7 @@
 
 from open_r1.configs import SFTConfig
 from open_r1.utils.callbacks import get_callbacks
+from open_r1.utils.wandb_logging import init_wandb_training
 from trl import (
     ModelConfig,
     ScriptArguments,
@@ -88,7 +89,7 @@ def main(script_args, training_args, model_args):
     )
     logger.info(f"Model parameters {model_args}")
     logger.info(f"Script parameters {script_args}")
-    logger.info(f"Data parameters {training_args}")
+    logger.info(f"Training parameters {training_args}")
 
     # Check for last checkpoint
     last_checkpoint = None
@@ -97,6 +98,9 @@ def main(script_args, training_args, model_args):
     if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
         logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")
 
+    if "wandb" in training_args.report_to:
+        init_wandb_training(training_args)
+
     ################
     # Load datasets
     ################
diff --git a/src/open_r1/utils/wandb_logging.py b/src/open_r1/utils/wandb_logging.py
new file mode 100644
index 00000000..13b55276
--- /dev/null
+++ b/src/open_r1/utils/wandb_logging.py
@@ -0,0 +1,12 @@
+import os
+
+
+def init_wandb_training(training_args):
+    """
+    Helper function for setting up Weights & Biases logging.
+    """
+    # W&B reads these environment variables when the Trainer later calls `wandb.init()`.
+    if training_args.wandb_entity is not None:
+        os.environ["WANDB_ENTITY"] = training_args.wandb_entity
+    if training_args.wandb_project is not None:
+        os.environ["WANDB_PROJECT"] = training_args.wandb_project
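As a quick way to sanity-check the new helper outside a full training run, here is a minimal sketch (not part of the diff) that exercises `init_wandb_training` directly. `DummyArgs` is a hypothetical stand-in for the `SFTConfig`/`GRPOConfig` instance that `TrlParser` would normally build from the YAML config and CLI flags; the entity/project values are the ones from the README example.

```python
# Minimal sketch: verify that init_wandb_training exports the W&B settings
# as environment variables. DummyArgs is a hypothetical stand-in for the
# SFTConfig/GRPOConfig instances used in the actual scripts.
import os
from dataclasses import dataclass
from typing import Optional

from open_r1.utils.wandb_logging import init_wandb_training


@dataclass
class DummyArgs:
    wandb_entity: Optional[str] = "huggingface"
    wandb_project: Optional[str] = "open-r1"


init_wandb_training(DummyArgs())

# wandb.init(), called by the Trainer when `report_to` includes "wandb",
# picks these up and routes the run to the requested entity/project.
assert os.environ["WANDB_ENTITY"] == "huggingface"
assert os.environ["WANDB_PROJECT"] == "open-r1"
```

Setting environment variables, rather than calling `wandb.init(entity=..., project=...)` ourselves, fits the design here: the `wandb.init` call is owned by the Trainer's W&B callback, and the environment is the standard channel for configuring it.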