
GRPOTrainer adds support for OpenAI API-compatible servers to models that generate samples #2901

Draft · ZYM66 wants to merge 3 commits into main

Conversation

@ZYM66 commented Feb 19, 2025

What does this PR do?

The generation model currently loads on a single GPU, which becomes very slow once the generation length reaches 4096 tokens or more. I therefore suggest loading the generation model externally to leverage multi-GPU support.

Motivation Behind This Feature:

The motivation for this feature stems from a performance limitation in the current vLLM generation path, which loads the model on a single GPU. This becomes a bottleneck, especially when the generation length reaches 4096 tokens or more, and significantly slows down training. By distributing the workload across multiple GPUs, we can drastically improve performance, especially for long-generation tasks.

This feature is crucial for my project, as faster generation times are essential. I believe it could also benefit the broader community by enhancing the scalability and efficiency of model inference on multi-GPU setups.

Requested Feature:

The requested feature loads the generation model outside of the current single-GPU setup in order to leverage multi-GPU capabilities. This improves performance for tasks with long generation lengths, making the library more efficient and scalable for demanding use cases.

Code Snippet:

In GRPOTrainer.__init__:

# requires: import logging, openai; from functools import partial

elif self.args.use_openai_compatible_server:
    api_endpoint = args.api_endpoint
    api_key = args.api_key

    openai_serving_client = openai.OpenAI(base_url=api_endpoint, api_key=api_key)
    # silence the openai/httpx loggers to avoid noisy log output
    logging.getLogger("openai").setLevel(logging.ERROR)
    logging.getLogger("httpx").setLevel(logging.ERROR)
    self.ref_model_name = args.ref_model_name

    # bind the model and sampling parameters once so generation is a single call
    self.ref_llm = partial(
        openai_serving_client.chat.completions.create,
        model=args.ref_model_name,
        max_tokens=self.max_completion_length,
        temperature=args.temperature,
    )
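
For reference, the snippet above reads several new attributes from args that do not exist in the released GRPOConfig. A minimal sketch of what the extended config could look like, with field names taken from the snippet and defaults assumed:

from dataclasses import dataclass, field
from typing import Optional

from trl import GRPOConfig


@dataclass
class OpenAIServerGRPOConfig(GRPOConfig):
    # hypothetical extension; field names match the snippet above
    use_openai_compatible_server: bool = field(default=False)
    api_endpoint: Optional[str] = field(default=None)    # e.g. "http://localhost:8000/v1"
    api_key: Optional[str] = field(default=None)         # many servers accept a dummy key
    ref_model_name: Optional[str] = field(default=None)  # model name as served by the endpoint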

In GRPOTrainer._prepare_inputs:

elif self.args.use_openai_compatible_server:
    completions = []
    # don't apply a chat template here; the server applies its own
    for prompt in prompts:
        # request one completion from the server and re-tokenize it locally
        response = self.ref_llm(messages=prompt)
        completion_text = response.choices[0].message.content
        completion_tokens = self.processing_class.encode(completion_text, add_special_tokens=False)
        completions.append(completion_tokens)

    # pad the variable-length completions and concatenate them to the prompts
    completion_ids = [torch.tensor(ids, device=device) for ids in completions]
    completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
    prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
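
Putting the pieces together, usage could look roughly like the sketch below. It assumes the hypothetical OpenAIServerGRPOConfig from above, a toy reward function, and an OpenAI-compatible server already running with tensor parallelism across several GPUs (for example, vllm serve Qwen/Qwen2.5-0.5B-Instruct --tensor-parallel-size 4); the model name and dataset are placeholders.

from datasets import load_dataset
from trl import GRPOTrainer

# hypothetical config class carrying the fields added by this PR
config = OpenAIServerGRPOConfig(
    output_dir="grpo-openai-server",
    use_openai_compatible_server=True,
    api_endpoint="http://localhost:8000/v1",  # any OpenAI-compatible endpoint
    api_key="EMPTY",                          # vLLM ignores the key unless one is configured
    ref_model_name="Qwen/Qwen2.5-0.5B-Instruct",
    max_completion_length=4096,
)

def reward_len(completions, **kwargs):
    # toy reward: prefer shorter completions (illustration only)
    return [-float(len(c)) for c in completions]

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=reward_len,
    args=config,
    train_dataset=load_dataset("trl-lib/tldr", split="train"),
)
trainer.train()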

Fixes #2887

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR. If additional documentation improvements are needed, I would be happy to contribute.

@ZYM66 (Author) commented Feb 19, 2025

Additionally, you can set these keywords to use this PR:

[image: screenshot of the new configuration keywords]

@qgallouedec (Member) commented

Cc @edbeeching

@ZYM66 ZYM66 changed the title Add support for OpenAI API-compatible server reference models. GRPOTrainer add support for OpenAI API-compatible server reference models. Feb 19, 2025
@ZYM66 ZYM66 changed the title GRPOTrainer add support for OpenAI API-compatible server reference models. GRPOTrainer adds support for OpenAI API-compatible servers to models that generate samples Feb 20, 2025
@ZYM66 (Author) commented Feb 20, 2025

> Cc @edbeeching

I would like to clarify this PR, since I am not a native English speaker and there may be errors in my original description.

The default code loads the vLLM generation model on a single GPU. During training, the other GPUs must wait for that single GPU to finish generating, which causes delays. This PR adds an optional feature that uses an external API for completions instead of relying solely on the local vLLM implementation.

Thanks!

@XZ-X (Contributor) commented Feb 20, 2025

I might not fully understand it, but I don't see how the external OpenAI-compatible model is updated during training.

The original slow implementation loads the most updated weights into vLLM at each step before generating responses.

@ZYM66 (Author) commented Feb 20, 2025

> I might not fully understand it, but I don't see how the external OpenAI-compatible model is updated during training.
>
> The original slow implementation loads the most updated weights into vLLM at each step before generating responses.

Hmm, you're right. This code doesn't update the vLLM server model weights in real time. I'm currently looking for ways to address this issue.
I've now changed this PR to a draft.
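
To make the gap concrete: without a weight-sync step, the external server keeps sampling from the weights it loaded at startup while the local policy moves on. A self-contained toy sketch, not part of this PR, with stand-ins for the policy, the loss, and the server call:

import torch

policy = torch.nn.Linear(4, 4)  # stand-in for the trained policy
optimizer = torch.optim.SGD(policy.parameters(), lr=0.1)

def sample_from_server(step):
    # stand-in for the OpenAI-compatible call; the server's weights never
    # change, so its samples always reflect the step-0 policy
    return torch.randn(2, 4)

for step in range(3):
    completions = sample_from_server(step)
    loss = policy(completions).pow(2).mean()  # stand-in for the GRPO loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # the missing piece: nothing ships policy.state_dict() back to the
    # server, so generation drifts off-policy as training progresses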

@ZYM66 ZYM66 marked this pull request as draft February 20, 2025 03:23