[ScaleMM] Add a shape dependent max_swizzle size (#137681)
# Summary

I started to explore the performance of _scaled_mm against a triton-based persistent TMA kernel for RowWise scaling.
There are more details here: drisspg/transformer_nuggets#36

It clearly showed that there was some room for improvement on larger problem sizes compared to Triton's performance. Note that the Triton kernel only has a 128x128x128 tile shape, whereas _scaled_mm has a 64x128x128 tile shape that we use for smaller problem sizes, which may explain some of the perf delta at smaller shapes.

This led me to look at whether we can improve our Triton codegen lowering for _scaled_mm (I think we should still do this: #137517).

In the meantime, @Chillee suggested I make sure swizzling is set for the large matmul shapes.

This PR increases the `max_swizzle_size` for large matmuls.
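The shape-dependent choice can be sketched as a standalone function (a minimal sketch: `select_max_swizzle_size` is a hypothetical name, and the 4096 thresholds mirror the dispatch condition added in the diff below):

```cpp
// Hypothetical helper sketching the shape-dependent heuristic: large GEMMs
// (all of M, N, K at least 4096) get max_swizzle_size = 8; everything else
// keeps the default of 1.
int select_max_swizzle_size(int M, int N, int K) {
  if (M >= 4096 && N >= 4096 && K >= 4096) {
    return 8;  // large matmul: allow a bigger rasterization swizzle
  }
  return 1;  // default
}
```

In the actual kernel the chosen value is plumbed through `handle_transposition` and the tile-size dispatch into `arguments.scheduler.max_swizzle_size` before the GEMM workspace is set up.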

## Performance
Note: red means the Triton-based TMA kernel beats _scaled_mm; blue means _scaled_mm is faster.

On nightly, with Triton at 2ef33c6c4c3:
![swizzle_tst_8_full_nightly_heatmaps](https://github.com/user-attachments/assets/e92af19b-4e79-4126-b9d0-da039da5363b)

You can see that as M, K, and N increase, there is a clear win for the Triton persistent TMA kernel.

After this PR:

![swizzle_tst_8_full_heatmaps](https://github.com/user-attachments/assets/472068b3-45c2-43f8-84d3-b116da7898d5)

For example, with this change (on a power-limited GPU):

M=16384, K=16384, N=16384
TFLOPs before: `985.49`
TFLOPs after: `1304.69`
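That works out to roughly a 1.32x speedup, i.e. about a 32% throughput improvement (a quick sanity check of the arithmetic; `speedup` is just an illustrative helper, not part of the PR):

```cpp
// Illustrative helper: ratio of after/before throughput.
// 1304.69 / 985.49 is roughly 1.32, i.e. about a 32% improvement.
double speedup(double tflops_before, double tflops_after) {
  return tflops_after / tflops_before;
}
```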
Pull Request resolved: #137681
Approved by: https://github.com/eqy
drisspg authored and pytorchmergebot committed Oct 11, 2024
1 parent 4e30989 commit 1c71de5
Showing 1 changed file with 21 additions and 7 deletions.
aten/src/ATen/native/cuda/RowwiseScaledMM.cu
@@ -141,7 +141,8 @@ void f8f8bf16_rowwise_impl(
at::Tensor x_scale,
at::Tensor w_scale,
std::optional<at::Tensor> bias,
at::Tensor out) {
at::Tensor out,
const int swizzle) {
int M = XQ.size(0);
int N = WQ.size(1);
int K = XQ.size(1);
@@ -276,6 +277,9 @@ void f8f8bf16_rowwise_impl(
// multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);

// Set the swizzle size
arguments.scheduler.max_swizzle_size = swizzle;

// Allocate workspace memory
auto workspace = XQ.new_empty(
{static_cast<int64_t>(workspace_size)},
@@ -309,7 +313,8 @@ void dispatch_fp8_rowwise_kernel_on_tile_size(
at::Tensor x_scale,
at::Tensor w_scale,
std::optional<at::Tensor> bias,
at::Tensor out) {
at::Tensor out,
const int swizzle) {
int M = XQ.size(0);
int N = WQ.size(1);

@@ -323,13 +328,13 @@ void dispatch_fp8_rowwise_kernel_on_tile_size(
/*TileShape=*/cute::Shape<cute::_64, cute::_128, cute::_128>,
ClusterShape,
/*PingPong=*/std::false_type,
Types...>(XQ, WQ, x_scale, w_scale, bias, out);
Types...>(XQ, WQ, x_scale, w_scale, bias, out, swizzle);
} else {
return f8f8bf16_rowwise_impl<
/*TileShape=*/cute::Shape<cute::_128, cute::_128, cute::_128>,
ClusterShape,
/*PingPong=*/std::true_type,
Types...>(XQ, WQ, x_scale, w_scale, bias, out);
Types...>(XQ, WQ, x_scale, w_scale, bias, out, swizzle);
}
}

@@ -346,23 +351,24 @@ void handle_transposition(
at::Tensor x_scale,
at::Tensor w_scale,
std::optional<at::Tensor> bias,
at::Tensor out) {
at::Tensor out,
const int swizzle=1) {
if constexpr (!Transposed::value) {
dispatch_fp8_rowwise_kernel_on_tile_size<
ClusterShape,
Transposed,
FastAccum,
DtypeA,
DtypeB,
DtypeBias>(XQ, WQ, x_scale, w_scale, bias, out);
DtypeBias>(XQ, WQ, x_scale, w_scale, bias, out, swizzle);
} else {
dispatch_fp8_rowwise_kernel_on_tile_size<
ClusterShape,
Transposed,
FastAccum,
DtypeB,
DtypeA,
DtypeBias>(WQ.t(), XQ.t(), w_scale.t(), x_scale.t(), bias, out.t());
DtypeBias>(WQ.t(), XQ.t(), w_scale.t(), x_scale.t(), bias, out.t(), swizzle);
}
}

@@ -438,6 +444,14 @@ void dispatch_fp8_rowwise_kernel_on_cluster_size_and_transpose(
}

// General case for large tensors.

// Large M, N, K
if (M >= 4096 && N >= 4096 && K >= 4096) {
return handle_transposition<
/*ClusterShape=*/cute::Shape<cute::_2, cute::_1, cute::_1>,
/*Transposed=*/std::true_type,
Types...>(XQ, WQ, x_scale, w_scale, bias, out, 8);
}
if ((M <= N) ^ (M >= 2048 && N >= 2048)) {
return handle_transposition<
/*ClusterShape=*/cute::Shape<cute::_1, cute::_2, cute::_1>,
