Support e8m0 2
drisspg committed Feb 24, 2025
1 parent f234d14 commit 662f16a
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions benchmarks/fp8_matmul.py
@@ -68,8 +68,9 @@ def get_e8_scales(A: torch.Tensor, B: torch.Tensor):
     n_a_cols = ceil_div(K, 32)
     n_b_rows = ceil_div(N, 128) * 128
     n_b_cols = ceil_div(K, 32)
-    a_scales = torch.randint(256, (n_a_rows, n_a_cols), dtype=torch.uint8, device="cuda")
-    b_scales = torch.randint(256, (n_b_rows, n_b_cols), dtype=torch.uint8, device="cuda")
+    a_scales = torch.randn(n_a_rows, n_a_cols, dtype=torch.float32, device="cuda").to(torch.float8_e8m0fnu)
+    b_scales = torch.randn(n_b_rows, n_b_cols, dtype=torch.float32, device="cuda").to(torch.float8_e8m0fnu)
+
     return a_scales, b_scales
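For reference, a minimal standalone sketch of the new scale construction follows. It is not part of the commit: the make_e8m0_scales name, the 128-row padding for A (mirroring the visible padding of B), and the CUDA device default are assumptions inferred from the hunk above, and torch.float8_e8m0fnu requires a PyTorch build that exposes that dtype.

import torch

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def make_e8m0_scales(M: int, N: int, K: int, device: str = "cuda"):
    # One scale per 32-wide block along K; row counts padded to multiples of 128.
    n_a_rows = ceil_div(M, 128) * 128  # assumed to match the padding used for B
    n_a_cols = ceil_div(K, 32)
    n_b_rows = ceil_div(N, 128) * 128
    n_b_cols = ceil_div(K, 32)
    # The commit swaps uint8 randint scales for float32 noise cast to
    # torch.float8_e8m0fnu, PyTorch's 8-bit exponent-only scale dtype.
    a_scales = torch.randn(n_a_rows, n_a_cols, dtype=torch.float32, device=device).to(torch.float8_e8m0fnu)
    b_scales = torch.randn(n_b_rows, n_b_cols, dtype=torch.float32, device=device).to(torch.float8_e8m0fnu)
    return a_scales, b_scales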


@@ -110,10 +111,9 @@ def get_fp8_matmul(
         return lambda: torch._scaled_mm(
             A_fp8,
             B_fp8,
-            b_scale, # swap since we haven't figured this out yet
             a_scale,
+            b_scale,
             out_dtype=torch.bfloat16,
-            scale_dtype=1,
         )
     return lambda: addmm_float8_unwrapped_inference(
         A_fp8, a_scale, B_fp8, b_scale, output_dtype=torch.bfloat16, use_fast_accum=True
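With the workaround removed, the SCALED_MM path now issues the call roughly as sketched below. This is an illustration rather than the commit's code; torch._scaled_mm is a private API whose exact signature varies across PyTorch versions, and the scaled_mm_bf16 wrapper name is an assumption.

import torch

def scaled_mm_bf16(A_fp8, B_fp8, a_scale, b_scale):
    # Scales are now passed in their natural (a_scale, b_scale) order and the
    # earlier scale_dtype=1 keyword argument is gone.
    return torch._scaled_mm(
        A_fp8,
        B_fp8,
        a_scale,
        b_scale,
        out_dtype=torch.bfloat16,
    )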
@@ -272,7 +272,7 @@ def get_configs_varying_k(
     compile_options = [False]
     configs = []
     fp8_kernels = [
-        # FP8Kernel.SCALED_MM,
+        FP8Kernel.SCALED_MM,
         # FP8Kernel.PERSISTENT,
         # FP8Kernel.PERSISTENT_TMA,
         # FP8Kernel.DEVICE_TMA,

0 comments on commit 662f16a
