nsys-jax: optimise data loading and .zip creation #1193

Merged: 12 commits into main from olupton/nsys-jax-python-opt on Jan 15, 2025

@olupton (Collaborator) commented Dec 10, 2024

Some rough measurements on vanilla jax-nccl-test and 8xH100:

Profile collection, whole execution: 52s (nsys), 58s (nsys-jax with this PR), 1m5s (nsys-jax without this PR)
Profile collection, restricted range: 46s (nsys), 50s (nsys-jax with this PR), 55s (nsys-jax without this PR)
Communication analysis, whole execution: 1.1s (with this PR), 2.1s (without this PR)
Communication analysis, restricted range: 1.0s (with this PR), 1.7s (without this PR)

The differences are more pronounced on larger workloads with more activity.

The two bigger changes are:

  • Convert .csv to .parquet as part of nsys-jax to avoid compressing .csv with Python's lzma module, which is slow and single-threaded. This speeds up nsys-jax and subsequent data-loading.
  • A new algorithm for calculating the hidden/exposed time of communication kernels when loading profile data. Essentially, this adds a fast, pandas-friendly pass that identifies [most] non-overlapping kernels, so that the [relatively slow and pandas-unfriendly] overlap calculation can be skipped for them. This also removes the assumption that there is no compute-compute overlap.
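
The idea behind the fast pass in the second bullet can be sketched roughly as follows. This is an illustrative NumPy sketch rather than the code in this PR: the function name `overlaps_any_compute` and its array-based interface are hypothetical, intervals are assumed to be half-open `[start, end)`, and the same building blocks (sorting, cumulative maximum, `searchsorted`) are also available on pandas columns. Taking the running maximum of compute end times means the check does not rely on compute kernels being mutually non-overlapping.

```python
import numpy as np


def overlaps_any_compute(comm_start, comm_end, comp_start, comp_end):
    """Return a boolean mask, True where a communication kernel overlaps
    at least one compute kernel. Intervals are half-open: [start, end).

    Hypothetical helper for illustration; not part of nsys-jax.
    """
    comm_start = np.asarray(comm_start)
    comm_end = np.asarray(comm_end)
    comp_start = np.asarray(comp_start)
    comp_end = np.asarray(comp_end)
    if comp_start.size == 0:
        return np.zeros(comm_start.shape, dtype=bool)

    # Sort compute kernels by start time.
    order = np.argsort(comp_start, kind="stable")
    sorted_start = comp_start[order]
    # Running maximum of end times, so overlapping compute kernels are handled.
    latest_end_so_far = np.maximum.accumulate(comp_end[order])

    # Number of compute kernels that start strictly before each comm kernel ends.
    n_started = np.searchsorted(sorted_start, comm_end, side="left")
    has_candidate = n_started > 0
    # Latest end time among those candidates (index clamped when there are none).
    latest_end = latest_end_so_far[np.maximum(n_started - 1, 0)]

    # Overlap exists iff some candidate ends after the comm kernel starts.
    return has_candidate & (latest_end > comm_start)


# Kernels where the mask is False are fully exposed: exposed time == duration.
mask = overlaps_any_compute([0, 20, 40], [10, 30, 50], [5, 30, 12], [8, 40, 18])
```

Only the kernels flagged `True` would then need the slower per-kernel overlap calculation; for the rest, the exposed time is simply the kernel duration.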

Otherwise there are some tweaks to pandas usage and minor reorganisations to make Python profiles more informative, and minor bugfixes in the example Jupyter notebook.

@olupton force-pushed the olupton/nsys-jax-python-opt branch 6 times, most recently from 74a3d94 to d8056e0 Compare December 10, 2024 15:58
@olupton requested a review from gspschmid December 11, 2024 09:56
@olupton force-pushed the olupton/nsys-jax-python-opt branch from 61357e9 to 4bf5946 Compare January 10, 2025 11:09

@gspschmid (Contributor) left a comment

LGTM, a few questions inline.

@olupton merged commit 20d3137 into main Jan 15, 2025
127 of 139 checks passed
@olupton deleted the olupton/nsys-jax-python-opt branch January 15, 2025 15:51

@gspschmid (Contributor) commented

Here's what a single sweep over all the comm/compute intervals might look like: https://gist.github.com/gspschmid/478b056b42a5c81e00999617b18d5b71
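
For readers who cannot open the gist, a generic single-sweep approach could look something like the sketch below. It is not taken from the linked gist or from nsys-jax: the function name `exposed_times` and the list-of-tuples interface are assumptions made for illustration, and intervals are assumed to be half-open `[start, end)`. The sweep visits all interval boundaries in time order, counts how many compute kernels are active, and credits elapsed time to every active communication kernel whenever no compute kernel is running.

```python
def exposed_times(comm, compute):
    """Return, for each communication interval, the time not covered by
    any compute interval. `comm` and `compute` are lists of (start, end).

    Hypothetical helper for illustration; not part of nsys-jax.
    """
    events = []  # (time, +1 for start / -1 for end, comm index or None)
    for i, (start, end) in enumerate(comm):
        events.append((start, +1, i))
        events.append((end, -1, i))
    for start, end in compute:
        events.append((start, +1, None))
        events.append((end, -1, None))
    # Sort by time only; ties span zero time, so their order is irrelevant.
    events.sort(key=lambda event: event[0])

    exposed = [0] * len(comm)
    active_comm = set()  # indices of comm kernels currently running
    active_compute = 0   # number of compute kernels currently running
    previous_time = None
    for time, delta, index in events:
        # Credit the elapsed slice to active comm kernels if nothing hides them.
        if previous_time is not None and active_compute == 0:
            for i in active_comm:
                exposed[i] += time - previous_time
        previous_time = time
        if index is None:
            active_compute += delta
        elif delta > 0:
            active_comm.add(index)
        else:
            active_comm.discard(index)
    return exposed


result = exposed_times([(0, 10), (20, 30)], [(5, 8), (6, 12), (25, 26)])
```

Counting active compute kernels, rather than tracking a single one, lets the sweep tolerate compute-compute overlap. The inner loop over `active_comm` is cheap as long as few communication kernels run concurrently.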

2 participants