From 7431386dfa8598675505a6ed5f5824182280216a Mon Sep 17 00:00:00 2001
From: Yiming Zheng
Date: Tue, 18 Feb 2025 22:58:46 +0800
Subject: [PATCH 1/3] Add reference model OpenAI compatible API server support

---
 trl/trainer/grpo_config.py  | 18 ++++++++++++++++++
 trl/trainer/grpo_trainer.py | 35 +++++++++++++++++++++++++++++++++++
 2 files changed, 53 insertions(+)

diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py
index cd6cc91748..32051cefff 100644
--- a/trl/trainer/grpo_config.py
+++ b/trl/trainer/grpo_config.py
@@ -249,3 +249,21 @@ class GRPOConfig(TrainingArguments):
         default=False,
         metadata={"help": "Whether to log the completions during training."},
     )
+
+    # reference model using api server
+    use_openai_compatible_server: bool = field(
+        default=False,
+        metadata={"help": "Whether to use an OpenAI-compatible server for the reference model."},
+    )
+    api_endpoint: str = field(
+        default=None,
+        metadata={"help": "OpenAI-compatible server API endpoint."},
+    )
+    api_key: str = field(
+        default=None,
+        metadata={"help": "OpenAI-compatible server API key."},
+    )
+    ref_model_name: str = field(
+        default=None,
+        metadata={"help": "Name of the reference model served by the OpenAI-compatible server."},
+    )
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 1211a453fc..7273b397d6 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -18,6 +18,8 @@
 from collections import defaultdict
 from typing import Any, Callable, Optional, Sized, Union
 from unittest.mock import patch
+from functools import partial
+import logging
 
 import torch
 import torch.utils.data
@@ -41,6 +43,7 @@
 )
 from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
 from transformers.utils import is_peft_available
+import openai
 
 from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
 from ..import_utils import is_vllm_available
@@ -423,6 +426,23 @@ def data_collator(features):  # No data collation is needed in GRPO
             # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
             # synchronize all processes after vLLM has been fully initialized.
             self.accelerator.wait_for_everyone()
+
+        elif self.args.use_openai_compatible_server:
+            api_endpoint = args.api_endpoint
+            api_key = args.api_key
+
+            openai_serving_client = openai.OpenAI(base_url=api_endpoint, api_key=api_key)
+            # set the openai and httpx loggers to ERROR level to avoid noisy log output
+            logging.getLogger("openai").setLevel(logging.ERROR)
+            logging.getLogger("httpx").setLevel(logging.ERROR)
+            self.ref_model_name = args.ref_model_name
+
+            self.ref_llm = partial(openai_serving_client.chat.completions.create,
+                                   model=args.ref_model_name,
+                                   max_tokens=self.max_completion_length,
+                                   temperature=args.temperature,
+                                   )
+
         else:
             self.generation_config = GenerationConfig(
                 max_new_tokens=self.max_completion_length,
@@ -555,6 +575,21 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
             completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
             completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
             prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
+
+        elif self.args.use_openai_compatible_server:
+            completions = []
+            for prompt in prompts_text:
+                # request the reference completion from the server
+                response = self.ref_llm(messages=[{"role": "user", "content": prompt}])
+                completion_text = response.choices[0].message.content
+                completion_tokens = self.processing_class.encode(completion_text, add_special_tokens=False)
+                completions.append(completion_tokens)
+
+            completion_ids = completions
+            completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
+            completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
+            prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
+
         else:
             # Regular generation path
             with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:

From d27686478414d13221fb80dffa01be56a81a7943 Mon Sep 17 00:00:00 2001
From: Yiming Zheng
Date: Wed, 19 Feb 2025 10:22:10 +0800
Subject: [PATCH 2/3] Add annotation

---
 trl/trainer/grpo_config.py  | 2 +-
 trl/trainer/grpo_trainer.py | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py
index 32051cefff..15d90ec1f8 100644
--- a/trl/trainer/grpo_config.py
+++ b/trl/trainer/grpo_config.py
@@ -250,7 +250,7 @@ class GRPOConfig(TrainingArguments):
         metadata={"help": "Whether to log the completions during training."},
     )
 
-    # reference model using api server
+    # reference model using OpenAI-compatible API server
     use_openai_compatible_server: bool = field(
         default=False,
         metadata={"help": "Whether to use an OpenAI-compatible server for the reference model."},
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 7273b397d6..0287628e45 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -578,7 +578,8 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
 
         elif self.args.use_openai_compatible_server:
             completions = []
-            for prompt in prompts_text:
+            # don't apply a chat template here; the server has already loaded one
+            for prompt in prompts:
                 # request the reference completion from the server
                 response = self.ref_llm(messages=[{"role": "user", "content": prompt}])
                 completion_text = response.choices[0].message.content

From 40c7bc1efff69b55eaf698f8ac50e03a25ae85be Mon Sep 17 00:00:00 2001
From: Yiming Zheng
Date: Wed, 19 Feb 2025 14:30:54 +0800
Subject: [PATCH 3/3] Use Default Chat template

---
 trl/trainer/grpo_trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 0287628e45..6a6899a37a 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -581,7 +581,7 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
             # don't apply a chat template here; the server has already loaded one
             for prompt in prompts:
                 # request the reference completion from the server
-                response = self.ref_llm(messages=[{"role": "user", "content": prompt}])
+                response = self.ref_llm(messages=prompt)
                 completion_text = response.choices[0].message.content
                 completion_tokens = self.processing_class.encode(completion_text, add_special_tokens=False)
                 completions.append(completion_tokens)
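
Not part of the diff: a minimal usage sketch of the options this series adds,
assuming a reference model is already served behind an OpenAI-compatible
endpoint. The endpoint URL, API key, model names, reward function, and dataset
below are illustrative placeholders, not values defined by the patch.

    from trl import GRPOConfig, GRPOTrainer

    training_args = GRPOConfig(
        output_dir="grpo-ref-server",              # hypothetical output dir
        use_openai_compatible_server=True,         # route reference completions to the server
        api_endpoint="http://localhost:8000/v1",   # hypothetical local endpoint
        api_key="EMPTY",                           # many local servers accept a dummy key
        ref_model_name="my-ref-model",             # must match the model the server loaded
    )
    trainer = GRPOTrainer(
        model="Qwen/Qwen2.5-0.5B-Instruct",  # policy model (example)
        reward_funcs=my_reward_fn,           # user-supplied reward function
        args=training_args,
        train_dataset=my_dataset,            # conversational prompts: patch 3/3 forwards
                                             # the chat `messages` to the server directly
    )
    trainer.train()

Note that since patch 3/3 passes `prompt` directly as `messages`, prompts must
be in the conversational (list-of-messages) format; plain-text prompts would
likely be rejected by the chat completions endpoint.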