-
Notifications
You must be signed in to change notification settings - Fork 56
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update Paxml patchlist with TE and config patches for improved perf #225
Conversation
ashors1
commented
Sep 12, 2023
•
edited
Loading
edited
- adds Transformer Engine support to Pax
- updates GPU configs and default XLA flags for improved performance
Rosetta pax build/test: https://github.com/NVIDIA/JAX-Toolbox/actions/runs/6239209930 Above workflow should be sufficient for this change |
Looks like the build passed (yay!), but only some of the MGMN tests passed. Unit tests are expected to fail. Re-run just to make sure it wasn't a one-off: https://github.com/NVIDIA/JAX-Toolbox/actions/runs/6241478583 |
To address the unit test failure: #253 |
The failures don't appear to be one-offs. I see errors like this in the logs:
Investigating now |
Google reverted their commit that broke TE: google/paxml@1696411. |
Merging failing test isn't great. If we do that, can we disable the test until it is fixed? |
TP=8 looks like it failed, but it's not showing up in the metrics pytest check. Let me look into why |
So I've reminded myself that actually two tests are not measured for perf/loss tests because they were failing some time ago. The two tests were:
It looks like (2) is working now, so created an issue to track adding these back in: #272 But to @nouiz 's comment, the test is actually already omitted, so I think this is okay to merge |