GRPOTrainer: add support for OpenAI API-compatible servers as the model that generates samples #2901

Draft: wants to merge 3 commits into main
18 changes: 18 additions & 0 deletions trl/trainer/grpo_config.py
@@ -249,3 +249,21 @@ class GRPOConfig(TrainingArguments):
default=False,
metadata={"help": "Whether to log the completions during training."},
)

# Reference model served through an OpenAI-compatible API server
use_openai_compatible_server: bool = field(
default=False,
metadata={"help": "Whether to use an OpenAI-compatible server to generate completions."},
)
api_endpoint: Optional[str] = field(
default=None,
metadata={"help": "Base URL of the OpenAI-compatible server API endpoint."},
)
api_key: Optional[str] = field(
default=None,
metadata={"help": "API key for the OpenAI-compatible server."},
)
ref_model_name: Optional[str] = field(
default=None,
metadata={"help": "Name of the model to request from the OpenAI-compatible server."},
)
36 changes: 36 additions & 0 deletions trl/trainer/grpo_trainer.py
@@ -18,6 +18,8 @@
from collections import defaultdict
from typing import Any, Callable, Optional, Sized, Union
from unittest.mock import patch
from functools import partial
import logging

import torch
import torch.utils.data
@@ -41,6 +43,7 @@
)
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.utils import is_peft_available
import openai

from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from ..import_utils import is_vllm_available
@@ -423,6 +426,23 @@ def data_collator(features):  # No data collation is needed in GRPO
# desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
# synchronize all processes after vLLM has been fully initialized.
self.accelerator.wait_for_everyone()

elif self.args.use_openai_compatible_server:
api_endpoint = args.api_endpoint
api_key = args.api_key

openai_serving_client = openai.OpenAI(base_url=api_endpoint, api_key=api_key)
# Silence the OpenAI and httpx loggers to avoid noisy per-request log output.
logging.getLogger("openai").setLevel(logging.ERROR)
logging.getLogger("httpx").setLevel(logging.ERROR)
self.ref_model_name = args.ref_model_name

# Bind the model name and sampling parameters once; each call then only supplies the messages.
self.ref_llm = partial(
openai_serving_client.chat.completions.create,
model=args.ref_model_name,
max_tokens=self.max_completion_length,
temperature=args.temperature,
)

else:
self.generation_config = GenerationConfig(
max_new_tokens=self.max_completion_length,
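For context, a standalone sketch of the client setup this branch performs, outside the trainer. The endpoint, key, and model name are placeholders, and the sampling parameters are illustrative rather than the trainer's actual values:

from functools import partial
import openai

# Placeholder endpoint and credentials; any OpenAI-compatible server should work.
client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Bind the fixed generation parameters so each call only supplies the messages.
ref_llm = partial(
    client.chat.completions.create,
    model="my-served-model",  # placeholder model name
    max_tokens=256,
    temperature=0.7,
)

response = ref_llm(messages=[{"role": "user", "content": "Say hello."}])
print(response.choices[0].message.content)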
@@ -555,6 +575,22 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)

elif self.args.use_openai_compatible_server:
completions = []
# Do not apply a chat template here; the server applies its own when handling the request.
for prompt in prompts:
# Query the server for one completion per prompt.
response = self.ref_llm(messages=prompt)
completion_text = response.choices[0].message.content
completion_tokens = self.processing_class.encode(completion_text, add_special_tokens=False)
completions.append(completion_tokens)

completion_ids = [torch.tensor(ids, device=device) for ids in completions]
completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)

else:
# Regular generation path
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
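A self-contained sketch of the tokenize-and-pad step that the new branch performs on the completion texts returned by the server. The tokenizer name and sample texts are placeholders, and torch's pad_sequence stands in for the trainer's own pad helper, so padding-side details of the real trainer are not reproduced here:

import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # placeholder tokenizer
completion_texts = ["The answer is 42.", "Paris is the capital of France."]  # e.g. returned by the server

# Encode each completion without special tokens, as the trainer branch does.
completion_ids = [
    torch.tensor(tokenizer.encode(text, add_special_tokens=False))
    for text in completion_texts
]

# Right-pad to a rectangular batch with the tokenizer's pad token.
batch = pad_sequence(completion_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
print(batch.shape)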