🔁 🦈 Support iterative GRPO #2700

Open · wants to merge 6 commits into main
3 changes: 3 additions & 0 deletions tests/test_grpo_trainer.py
@@ -58,6 +58,9 @@ def test_training(self, config_name):
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=training_args,
train_dataset=dataset,
sync_ref_model=True,
ref_model_mixup_alpha=0.5,
ref_model_sync_steps=1,
)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
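
For intuition on the values used in this test, here is a quick numeric sketch (the tensors below are illustrative stand-ins, not real model weights): with `ref_model_sync_steps=1` and `ref_model_mixup_alpha=0.5`, the reference weights are pulled halfway toward the policy weights at every step.

```python
import torch

alpha = 0.5
policy_w = torch.tensor([1.0])  # stand-in for one policy parameter
ref_w = torch.tensor([0.0])     # stand-in for the matching reference parameter

for step in range(4):
    # TR-DPO soft update: π_ref = α·π_θ + (1 − α)·π_ref_prev
    ref_w = alpha * policy_w + (1 - alpha) * ref_w
    print(step, ref_w.item())   # 0.5, 0.75, 0.875, 0.9375 — geometric approach to π_θ
```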
35 changes: 35 additions & 0 deletions trl/trainer/grpo_config.py
@@ -78,6 +78,19 @@ class GRPOConfig(TrainingArguments):
Number of updates steps to accumulate the gradients for, before performing a backward/update pass.
beta (`float`, *optional*, defaults to `0.04`):
KL coefficient.
sync_ref_model (`bool`, *optional*, defaults to `False`):
Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using
the `ref_model_mixup_alpha` parameter. This synchronization originates from the
[TR-DPO](https://huggingface.co/papers/2404.09656) paper.
ref_model_mixup_alpha (`float`, *optional*, defaults to `0.9`):
α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix
between the current policy and the previous reference policy during updates. The reference policy is
updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you
must set `sync_ref_model=True`.
ref_model_sync_steps (`int`, *optional*, defaults to `64`):
τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how
frequently the current policy is synchronized with the reference policy. To use this parameter, you must
set `sync_ref_model=True`.
"""

# Parameters that control the model and reference model
@@ -174,3 +187,25 @@ class GRPOConfig(TrainingArguments):
default=0.04,
metadata={"help": "KL coefficient."},
)
sync_ref_model: bool = field(
default=False,
metadata={
"help": "Whether to synchronize the reference model with the active model every `ref_model_sync_steps` "
"steps, using the `ref_model_mixup_alpha` parameter."
},
)
ref_model_mixup_alpha: float = field(
default=0.9,
metadata={
"help": "α parameter from the TR-DPO paper, which controls the mix between the current policy and the "
"previous reference policy during updates. The reference policy is updated according to the equation: "
"`π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you must set `sync_ref_model=True`."
},
)
ref_model_sync_steps: int = field(
default=64,
metadata={
"help": "τ parameter from the TR-DPO paper, which determines how frequently the current policy is "
"synchronized with the reference policy. To use this parameter, you must set `sync_ref_model=True`."
},
)
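
Since the docstring above fully specifies the update rule, a minimal sketch of what one sync amounts to may help. This only illustrates the equation `π_ref = α * π_θ + (1 - α) * π_ref_prev`; it is not the actual `SyncRefModelCallback` implementation in TRL.

```python
import torch

@torch.no_grad()
def soft_sync(policy: torch.nn.Module, ref: torch.nn.Module, alpha: float) -> None:
    """Apply π_ref = α·π_θ + (1 − α)·π_ref_prev in place, parameter by parameter."""
    for p, p_ref in zip(policy.parameters(), ref.parameters()):
        p_ref.mul_(1.0 - alpha).add_(p, alpha=alpha)
```

With `alpha=1.0` this is a hard reset of the reference model to the current policy; smaller values blend the two, which is why the option is named a mixup alpha.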
3 changes: 3 additions & 0 deletions trl/trainer/grpo_trainer.py
@@ -42,6 +42,7 @@
from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from ..import_utils import is_vllm_available
from ..models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
from .callbacks import SyncRefModelCallback
from .grpo_config import GRPOConfig
from .utils import generate_model_card, get_comet_experiment_url, pad

@@ -207,6 +208,8 @@ def __init__(
# If PEFT is used, the reference model is not needed since the adapter can be disabled
# to revert to the initial model.
self.ref_model = None
if args.sync_ref_model:
self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))

# Processing class
if processing_class is None:
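
End to end, enabling iterative GRPO should reduce to setting the new config flags. A usage sketch follows; the model id, dataset, output path, and hyperparameter values are placeholders chosen for illustration, not values from this PR.

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")  # placeholder prompt dataset

training_args = GRPOConfig(
    output_dir="grpo-iterative",   # placeholder path
    sync_ref_model=True,           # turn on periodic reference-model syncing
    ref_model_mixup_alpha=0.6,     # weight of the current policy in each sync
    ref_model_sync_steps=512,      # sync every 512 optimizer steps
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",  # placeholder model id
    reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```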