diff --git a/benchmarks/fp8_matmul.py b/benchmarks/fp8_matmul.py
index 61c9e22..d45a04a 100644
--- a/benchmarks/fp8_matmul.py
+++ b/benchmarks/fp8_matmul.py
@@ -68,8 +68,9 @@ def get_e8_scales(A: torch.Tensor, B: torch.Tensor):
     n_a_cols = ceil_div(K, 32)
     n_b_rows = ceil_div(N, 128) * 128
     n_b_cols = ceil_div(K, 32)
-    a_scales = torch.randint(256, (n_a_rows, n_a_cols), dtype=torch.uint8, device="cuda")
-    b_scales = torch.randint(256, (n_b_rows, n_b_cols), dtype=torch.uint8, device="cuda")
+    a_scales = torch.randn(n_a_rows, n_a_cols, dtype=torch.float32, device="cuda").to(torch.float8_e8m0fnu)
+    b_scales = torch.randn(n_b_rows, n_b_cols, dtype=torch.float32, device="cuda").to(torch.float8_e8m0fnu)
+
     return a_scales, b_scales
@@ -110,10 +111,9 @@ def get_fp8_matmul(
         return lambda: torch._scaled_mm(
             A_fp8,
             B_fp8,
-            b_scale,  # swap since we haven't figured this out yet
             a_scale,
+            b_scale,
             out_dtype=torch.bfloat16,
-            scale_dtype=1,
         )
     return lambda: addmm_float8_unwrapped_inference(
         A_fp8, a_scale, B_fp8, b_scale, output_dtype=torch.bfloat16, use_fast_accum=True
@@ -272,7 +272,7 @@ def get_configs_varying_k(
     compile_options = [False]
     configs = []
     fp8_kernels = [
-        # FP8Kernel.SCALED_MM,
+        FP8Kernel.SCALED_MM,
         # FP8Kernel.PERSISTENT,
         # FP8Kernel.PERSISTENT_TMA,
         # FP8Kernel.DEVICE_TMA,
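
For reference, a minimal standalone sketch of the torch._scaled_mm call exercised above. This is not the benchmark's code path: it uses per-tensor float32 scales rather than the blocked e8m0 scales built by get_e8_scales, and the shapes and scale values are illustrative. It assumes a CUDA device with FP8 support and a recent PyTorch build where _scaled_mm takes the two scales positionally, as in the diff.

import torch

M, K, N = 1024, 2048, 512
A_fp8 = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
# _scaled_mm expects the second operand in column-major layout.
B_fp8 = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()
# Per-tensor scales for illustration; the benchmark above uses blocked e8m0 scales instead.
a_scale = torch.tensor(1.0, device="cuda")
b_scale = torch.tensor(1.0, device="cuda")

out = torch._scaled_mm(A_fp8, B_fp8, a_scale, b_scale, out_dtype=torch.bfloat16)
print(out.shape, out.dtype)  # torch.Size([1024, 512]) torch.bfloat16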