adds liger kernel support
edbeeching committed Feb 21, 2025
1 parent 382a0c7 commit 420d72a
Showing 1 changed file with 34 additions and 23 deletions.
57 changes: 34 additions & 23 deletions src/open_r1/trainers/faster_grpo_trainer.py
@@ -14,6 +14,8 @@

# great reference: https://github.com/vllm-project/vllm/issues/11400

import contextlib
import functools
import gc
import math
import os
@@ -24,7 +26,7 @@
from multiprocessing import reduction
from typing import Callable, Optional, Union
from unittest.mock import patch

from transformers.utils import is_liger_kernel_available
import torch
from datasets import Dataset
from torch.utils.data import DataLoader
@@ -52,17 +54,10 @@
from open_r1.trainers.remote_model import RemoteModel
from trl.data_utils import is_conversational, maybe_apply_chat_template
from trl.trainer.utils import pad, selective_log_softmax
from vllm import LLM, SamplingParams


RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
import contextlib
import functools
import time

from transformers import is_wandb_available

import wandb
from vllm import LLM, SamplingParams


if is_wandb_available():
@@ -141,7 +136,7 @@ class FastGRPOConfig(trl.GRPOConfig):
metadata={"help": ("The project to store runs under.")},
)
remote_gen_model_url: str = field(
default="26.0.164.45",
default="26.0.165.24",
)
remote_gen_model_port: str = field(
default="30010",
@@ -206,11 +201,20 @@ def __init__(
model_str = model
model = AutoModelForCausalLM.from_pretrained(model_str, **model_init_kwargs)
# offload to cpu
ref_model = AutoModelForCausalLM.from_pretrained(model_str, **model_init_kwargs).to("cpu")
ref_model = AutoModelForCausalLM.from_pretrained(model_str, **model_init_kwargs) #.to("cpu")

self.model = model
self.ref_model = ref_model

if self.args.use_liger_kernel:
if is_liger_kernel_available():
from liger_kernel.transformers import _apply_liger_kernel_to_instance
_apply_liger_kernel_to_instance(model=self.model)
_apply_liger_kernel_to_instance(model=self.ref_model)
else:
raise ImportError(
"You have set `use_liger_kernel` to `True` but liger-kernel >= 0.3.0 is not available. "
"Please install it with `pip install liger-kernel`"
)
# Processing class
if processing_class is None:
processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left")
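
The Liger path added above can be exercised on its own. A minimal sketch, assuming liger-kernel and a recent transformers release are installed; the checkpoint name is purely illustrative:

    import torch
    from transformers import AutoModelForCausalLM
    from transformers.utils import is_liger_kernel_available

    # Illustrative checkpoint; any architecture supported by liger-kernel works here.
    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", torch_dtype=torch.bfloat16)

    if is_liger_kernel_available():
        from liger_kernel.transformers import _apply_liger_kernel_to_instance

        # Patches supported modules (e.g. RMSNorm, SwiGLU MLP, fused cross-entropy) in place.
        _apply_liger_kernel_to_instance(model=model)
    else:
        raise ImportError("Please install it with `pip install liger-kernel`")

Applying the same patch to both the policy and the reference model, as the commit does, keeps the two forward passes consistent in behavior and memory footprint.
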
@@ -286,19 +290,34 @@ def data_collator(features): # No data collation is needed in GRPO
drop_last=True,
)
torch.manual_seed(args.seed)
# Enable gradient checkpointing if requested
if args.gradient_checkpointing:
self.model = self._enable_gradient_checkpointing(self.model, self.args)
self.model, self.optimizer, self.dataloader = self.accelerator.prepare(
self.model, self.optimizer, self.dataloader
)
device = self.accelerator.device

self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
# connect to a remote sglang model
if self.args.remote_gen_model_url is None:
self.sglang_job_launcher.wait_for_server()
self.args.remote_gen_model_url = self.sglang_job_launcher.get_remote_ip()
self.remote_model = RemoteModel(
self.args.remote_gen_model_url, self.args.remote_gen_model_port, self.processing_class.eos_token_id
)
def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: FastGRPOConfig) -> PreTrainedModel:
"""Enables gradient checkpointing for the model."""
# Ensure use_cache is disabled
model.config.use_cache = False
model.gradient_checkpointing_enable()

gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
use_reentrant = (
"use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
)
if use_reentrant:
model.enable_input_require_grads()

return model
def print_gpu_memory_usage(self):
if torch.cuda.is_available():
gpu_memory_allocated = torch.cuda.memory_allocated()
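
The _enable_gradient_checkpointing helper above follows the usual Transformers/TRL convention: with reentrant checkpointing (the default when use_reentrant is unset), the input embeddings must require gradients, hence the enable_input_require_grads() call. A hedged sketch of the config side, using the fields FastGRPOConfig inherits from transformers.TrainingArguments via trl.GRPOConfig; output_dir is illustrative:

    args = FastGRPOConfig(
        output_dir="runs/fast-grpo",  # illustrative path
        gradient_checkpointing=True,
        # With use_reentrant=False, the helper skips enable_input_require_grads().
        gradient_checkpointing_kwargs={"use_reentrant": False},
    )
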
@@ -485,9 +504,6 @@ def mini_batch_collator(examples):
device = self.accelerator.device
for step in range(start_step, self.total_steps_per_device + 1):
batch = next(iter_dataloader)

self.ref_model = self.ref_model.to("cpu")

batch = self._prepare_batch(batch)

# TODO: log completions, rewards, etc
@@ -506,7 +522,6 @@ def mini_batch_collator(examples):
kls = []

with profiling_context(self, "train_step"):
# at this point there should be no tensors in GPU memory
for mini_batch in mini_batch_dataloader:
prompt_completion_ids = mini_batch["prompt_completion_ids"]
attention_mask = mini_batch["attention_mask"][
@@ -515,16 +530,12 @@
logits_to_keep = (
prompt_completion_ids.size(1) - 1
) # TODO, fix padding with the optimization from the original grpo trainer
# torch.cuda.empty_cache()

# get the ref logprobs, this could also be done at the batch prepare step to avoid too much model unloading
self.ref_model = self.ref_model.to(device)
# self.ref_model = self.ref_model.to(device)
with torch.inference_mode():
ref_per_token_logps = self._get_per_token_logps(
self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep
)
self.ref_model = self.ref_model.to("cpu")
# torch.cuda.empty_cache()

with self.accelerator.accumulate(self.model):
per_token_logps = self._get_per_token_logps(

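
_get_per_token_logps itself is not part of this diff, but the inference-mode call above relies on it. A sketch of what such a helper typically computes in GRPO-style trainers, assuming it follows the upstream TRL pattern and uses the selective_log_softmax imported earlier; names and shapes here are assumptions, not the file's actual implementation:

    from trl.trainer.utils import selective_log_softmax

    def get_per_token_logps(model, input_ids, attention_mask, logits_to_keep):
        # Single forward pass over prompt + completion tokens.
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        logits = logits[:, :-1, :]  # logits at position i predict token i + 1
        # Keep only the last `logits_to_keep` positions/tokens (the completion part).
        target_ids = input_ids[:, -logits_to_keep:]
        logits = logits[:, -logits_to_keep:, :]
        # Log-probability of each realized token: shape (batch, logits_to_keep).
        return selective_log_softmax(logits, target_ids)

In the loop above this is evaluated under torch.inference_mode() for the frozen reference model, and with gradients enabled for the policy inside accelerator.accumulate().
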