
Add GRPO Trainer support for third-party accelerators #2836

Open
wants to merge 6 commits into main from npu

Conversation

ji-huazhong
Contributor

@ji-huazhong ji-huazhong commented Feb 12, 2025

What does this PR do?

This PR makes the GRPO Trainer work out of the box on Ascend NPUs.

cc @qgallouedec @lewtun
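For context, detecting which accelerator backend is active could be sketched roughly as below. This is a hypothetical `get_device_type` helper, not code from this PR; the package-probing logic is an illustrative assumption.

```python
import importlib.util


def get_device_type() -> str:
    """Hypothetical sketch: infer the accelerator backend from installed packages."""
    # Ascend NPUs register themselves through the torch_npu plugin.
    if importlib.util.find_spec("torch_npu") is not None:
        return "npu"
    # Fall back to CUDA if PyTorch is present and a GPU is visible.
    if importlib.util.find_spec("torch") is not None:
        import torch

        if torch.cuda.is_available():
            return "cuda"
    return "cpu"
```

A string like this can then replace the hard-coded `"cuda"` in device checks such as the one discussed below.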

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@ji-huazhong ji-huazhong force-pushed the npu branch 3 times, most recently from 48739bc to 86c5569, February 12, 2025 09:05
@Superskyyy
Contributor

Since the vLLM device patch is growing larger, it might be wise to move it into a utility module instead. Wdyt?

@ji-huazhong ji-huazhong changed the title Add GRPO Trainer support for Ascend NPU Add GRPO Trainer support for third-party accelerators Feb 13, 2025
@baymax591

This PR helps a lot; I hope it can speed up the integration.

ji-huazhong and others added 2 commits February 14, 2025 21:18
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
@ji-huazhong
Contributor Author

I think this PR is ready to be merged 🤗 @qgallouedec

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qgallouedec
Member

qgallouedec commented Feb 14, 2025

Can you make sure to run make precommit to apply the style 🙏

@ji-huazhong
Contributor Author

make precommit executed successfully locally.

@lynnzhiyun

Hi @ji-huazhong, thank you for your excellent work! This PR has been incredibly helpful in enabling me to train models with GRPO on the NPU smoothly.

I'd like to ask whether this PR is ready to be merged; I'd be extremely grateful if it could be done promptly.

cc @qgallouedec

# Check that the requested device is available
- if vllm_device.split(":")[0] == "cuda" and int(vllm_device.split(":")[1]) >= torch.cuda.device_count():
+ if (
+     vllm_device.split(":")[0] == f"{device_type}"
Member

this should always be the case, no?

Contributor Author
@ji-huazhong ji-huazhong Feb 18, 2025

Hi @qgallouedec

Thanks for your review. In line 387, I kept the same logic as the original conditional statement, only replacing the 'cuda' type with a more general device type.

I believe the check for device availability here is necessary. However, perhaps we could split this conditional statement into two parts.

First, we check whether the device type matches, and only after this condition is met do we check whether the device index is within the range of available devices. Wdyt?
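The two-part check proposed above could be sketched roughly like this. This is a hypothetical helper, not the PR's actual code; `check_vllm_device`, its parameters, and the error messages are illustrative assumptions.

```python
def check_vllm_device(vllm_device: str, device_type: str, device_count: int) -> None:
    """Hypothetical sketch: validate a device string like "cuda:0" or "npu:1".

    device_type is the active backend ("cuda", "npu", ...) and device_count
    is how many devices of that type are available.
    """
    requested_type, _, index = vllm_device.partition(":")
    # Step 1: the requested device type must match the active backend.
    if requested_type != device_type:
        raise ValueError(
            f"Requested device type {requested_type!r} does not match the active backend {device_type!r}."
        )
    # Step 2: only then check that the device index is within range.
    if int(index) >= device_count:
        raise ValueError(
            f"Device index {index} is out of range: only {device_count} {device_type} device(s) available."
        )
```

Splitting the check this way keeps the type mismatch and the out-of-range index as two distinct, clearer error cases.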

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
@ji-huazhong
Contributor Author

ji-huazhong commented Feb 19, 2025

[asciicast recording]
I ran a test on an Ascend NPU using the GRPO script provided by open-r1, and it works 🤗

Since one GRPO training step takes a long time, only the output of the first 4 steps is shown here; after that I pressed ctrl-c to exit.

@ji-huazhong
Contributor Author

Hi @kashif, the failing test case seems unrelated to this PR. Could you take a look? Thanks!

6 participants