
Adding comparison for different fp8 matmuls #36

Merged: 1 commit merged into main from fp8-kernel-experiments on Oct 3, 2024

Conversation

@drisspg (Owner) commented Sep 27, 2024

Summary

Comparing different PerTensor and PerRow scaling kernels

[image: fp8_performance_heatmaps]
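For readers who want to reproduce the comparison, here is a minimal sketch of the two scaling granularities being compared. The helper names are illustrative (not from this repo), and `torch._scaled_mm` is a private API whose signature varies across PyTorch versions (checked against ~2.5):

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def quantize_per_tensor(x: torch.Tensor):
    # PerTensor: one scalar scale for the whole tensor.
    scale = FP8_MAX / x.abs().max()
    return (x * scale).to(torch.float8_e4m3fn), scale.reciprocal().float()

def quantize_per_row(x: torch.Tensor):
    # PerRow: one scale per row, shape (rows, 1).
    scale = FP8_MAX / x.abs().amax(dim=1, keepdim=True)
    return (x * scale).to(torch.float8_e4m3fn), scale.reciprocal().float()

M, K, N = 1024, 2048, 4096
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)  # row-major (N, K)

a_fp8, a_scale = quantize_per_row(a)  # a_scale: (M, 1)
b_fp8, b_scale = quantize_per_row(b)  # b_scale: (N, 1)

# _scaled_mm wants the second operand column-major; transposing the
# row-major (N, K) tensor yields a column-major (K, N) view.
out = torch._scaled_mm(
    a_fp8, b_fp8.t(),
    scale_a=a_scale,      # (M, 1) -> row-wise scaling of the output rows
    scale_b=b_scale.t(),  # (1, N) -> row-wise scaling of the output cols
    out_dtype=torch.bfloat16,
)
```

Per-tensor scaling works the same way, just with 1-element fp32 scale tensors from `quantize_per_tensor`.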

Eager Results

[image: fp8_kernel_comparison]

Max Autotune for scaled-fp8

This doesn't seem right...

[image: fp8_compare_compile]
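For reference, max-autotune numbers like these are typically gathered by compiling the matmul call under Inductor's max-autotune mode. A minimal sketch (the PR's actual benchmark harness may differ, and `torch._scaled_mm` is a private API):

```python
import torch

def scaled_mm(a_fp8, b_fp8, scale_a, scale_b):
    # Wrap the scaled matmul so torch.compile can trace and lower it.
    return torch._scaled_mm(a_fp8, b_fp8, scale_a=scale_a, scale_b=scale_b,
                            out_dtype=torch.bfloat16)

# mode="max-autotune" makes Inductor benchmark multiple Triton matmul
# configs and pick the fastest one for these shapes.
compiled_scaled_mm = torch.compile(scaled_mm, mode="max-autotune")
```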

Traces

@drisspg drisspg force-pushed the fp8-kernel-experiments branch 13 times, most recently from ac31f09 to f7d68be Compare October 1, 2024 18:17
```python
    save_path (Optional[str], optional): Path to save the results. Defaults to None.
    """
    torch.random.manual_seed(123)
    configs = get_configs_varying_k()
```
A reviewer commented:

just curious, any interesting insights from varying M and N in this script?

@drisspg (Owner, Author) replied:

M Sweeps

[images: M-sweep comparison plots]

N Sweeps

These are pretty similar.

[image: N-sweep comparison plot]

So it seems that, for smaller sizes, scaled_mm with cutlass is generally much better when the problem is less compute bound, while (with the current, non-fine-tuned kernel config) the Triton implementation can overtake it.
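One way to see the compute-bound vs. memory-bound distinction is arithmetic intensity, which grows linearly with the matmul size. An illustrative back-of-the-envelope (not from the PR's harness):

```python
# FLOPs = 2*M*N*K; bytes moved ~= fp8 inputs + bf16 output (scales ignored).
def arithmetic_intensity(M: int, N: int, K: int) -> float:
    flops = 2 * M * N * K
    bytes_moved = M * K * 1 + K * N * 1 + M * N * 2  # fp8 in, bf16 out
    return flops / bytes_moved

for m in (256, 1024, 4096, 16384):
    print(m, f"{arithmetic_intensity(m, m, m):.0f} FLOPs/byte")
# Square shapes give roughly m/2 FLOPs/byte, so small shapes sit well
# below a modern GPU's compute/bandwidth ratio -> memory/overhead bound.
```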

@drisspg drisspg force-pushed the fp8-kernel-experiments branch 2 times, most recently from 3836e3c to 06aaf0b Compare October 2, 2024 21:27
@drisspg drisspg force-pushed the fp8-kernel-experiments branch from 06aaf0b to ba0b51b Compare October 2, 2024 21:31
@drisspg drisspg merged commit b4a789c into main Oct 3, 2024
5 checks passed
@drisspg drisspg deleted the fp8-kernel-experiments branch October 3, 2024 19:06
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Oct 11, 2024
# Summary

I started to explore the performance of _scaled_mm against a triton-based persistent TMA kernel for RowWise scaling.
There are more details here: drisspg/transformer_nuggets#36

It clearly showed that there was some room for improvement on larger problem sizes compared to Triton's performance. Note that the Triton kernel only has a 128x128x128 tile shape, whereas scaled_mm also has a 64x128x128 tile shape that we use for smaller problem sizes, which may explain some of the perf delta at smaller shapes.

This led me to look into whether we can improve our Triton codegen lowering for _scaled_mm (I think we should still do this: #137517).

In the meantime, @Chillee suggested I make sure swizzling is set for the large matmul shapes.

This PR makes sure that we increase the max_swizzle_size for the large matmuls.
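For context, swizzling reorders the launch order of output tiles so that tiles reusing the same rows of A and columns of B run close together in time, improving L2 hit rates. An illustrative sketch of the idea, using the standard grouped ordering from the Triton matmul tutorial (the PR itself tunes CUTLASS's max_swizzle_size rather than this exact arithmetic):

```python
def swizzled_tile(pid: int, num_pid_m: int, num_pid_n: int, group_m: int):
    # Walk the output-tile grid in groups of `group_m` tile-rows so that
    # consecutive program ids touch overlapping slices of A and B.
    num_pid_in_group = group_m * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * group_m
    group_size_m = min(num_pid_m - first_pid_m, group_m)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n
```

Larger matmuls have more tiles per row/column, so a larger swizzle/group size keeps more of the reused operand tiles resident in L2.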

## Performance
Note: red means the Triton-based TMA kernel beats _scaled_mm; blue means _scaled_mm is faster.

On nightly, with Triton at 2ef33c6c4c3:
![swizzle_tst_8_full_nightly_heatmaps](https://github.com/user-attachments/assets/e92af19b-4e79-4126-b9d0-da039da5363b)

You can see that as M, K, and N increase, there is a clear win for the Triton persistent TMA kernel.

After this PR:

![swizzle_tst_8_full_heatmaps](https://github.com/user-attachments/assets/472068b3-45c2-43f8-84d3-b116da7898d5)

For example, with this change (on a power-limited GPU):

M=16384, K=16384, N=16384
TFLOPS before: `985.49`
TFLOPS after: `1304.69`
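To put those numbers in wall-clock terms, a quick sanity-check conversion (not part of the commit; standard 2*M*N*K FLOP counting):

```python
# A matmul does 2*M*N*K FLOPs (one multiply + one add per MAC).
M = K = N = 16384
flops = 2 * M * N * K  # ~8.8e12 FLOPs
for label, tflops in [("before", 985.49), ("after", 1304.69)]:
    ms = flops / (tflops * 1e12) * 1e3
    print(f"{label}: {ms:.2f} ms")
# before: ~8.93 ms, after: ~6.74 ms, i.e. roughly a 32% speedup.
```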
Pull Request resolved: #137681
Approved by: https://github.com/eqy
jackzhxng pushed a commit to pytorch/pytorch that referenced this pull request Oct 16, 2024 (same commit message as above)