diff --git a/benchmarks/fp8_matmul.py b/benchmarks/fp8_matmul.py
index 0309bc1..5b0ab61 100644
--- a/benchmarks/fp8_matmul.py
+++ b/benchmarks/fp8_matmul.py
@@ -46,12 +46,13 @@ class FP8Kernel(Enum):
     PERSISTENT_TMA = "Persistent-TMA"
     DEVICE_TMA = "Device-TMA"
     SCALED_MM = "Scaled-MM"
-    MX_FP8 = "MX-FP8"
+    CUTLASS_MX = "Cutlass-MX-FP8"
 
 
 class ScalingStrategy(Enum):
     PER_TENSOR = "PerTensor"
     PER_ROW = "PerRow"
+    E8M0 = "E8M0"
 
 
 def is_col_major(stride):
@@ -59,6 +60,18 @@ def is_col_major(stride):
     return stride[1] > stride[0] and stride[0] == 1
 
 
+def get_e8_scales(A: torch.Tensor, B: torch.Tensor):
+    M, K = A.shape
+    N, _ = B.shape
+    n_a_rows = ceil_div(M, 128)
+    n_a_cols = ceil_div(K, 32)
+    n_b_rows = ceil_div(N, 128)
+    n_b_cols = ceil_div(K, 32)
+    a_scales = torch.randint(256, (n_a_rows, n_a_cols), dtype=torch.uint8, device="cuda")
+    b_scales = torch.randint(256, (n_b_rows, n_b_cols), dtype=torch.uint8, device="cuda")
+    return a_scales, b_scales
+
+
 def get_fp8_matmul(
     A: torch.Tensor, B: torch.Tensor, scaling_strategy: ScalingStrategy, fp8_kernel: FP8Kernel
 ):
@@ -66,14 +79,24 @@ def get_fp8_matmul(
     B_fp8 = B.to(torch.float8_e4m3fn)
     A_fp8, B_fp8 = preprocess_data(A_fp8, B_fp8, Float8MMConfig(use_fast_accum=True))
 
+    # Handle E8M0 format for supported kernels
+    if scaling_strategy == ScalingStrategy.E8M0:
+        if fp8_kernel not in [FP8Kernel.CUTLASS_MX, FP8Kernel.SCALED_MM]:
+            raise ValueError(
+                "E8M0 scaling strategy is only supported by the CUTLASS_MX and SCALED_MM kernels"
+            )
+
     if scaling_strategy == ScalingStrategy.PER_TENSOR:
         a_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
         b_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
     elif scaling_strategy == ScalingStrategy.PER_ROW:
         a_scale = torch.ones((A_fp8.size(0), 1), device="cuda", dtype=torch.float32)
         b_scale = torch.ones((B_fp8.size(1), 1), device="cuda", dtype=torch.float32).T
+    elif scaling_strategy == ScalingStrategy.E8M0:
+        a_scale, b_scale = get_e8_scales(A_fp8, B_fp8)
     else:
         raise ValueError(f"Invalid scaling strategy: {scaling_strategy}")
+
     if fp8_kernel == FP8Kernel.PERSISTENT:
         return lambda: matmul_persistent(A_fp8, a_scale, B_fp8, b_scale, torch.bfloat16)
     elif fp8_kernel == FP8Kernel.PERSISTENT_TMA:
@@ -81,26 +104,29 @@ def get_fp8_matmul(
     elif fp8_kernel == FP8Kernel.DEVICE_TMA:
         return lambda: matmul_device_tma_persistent(A_fp8, a_scale, B_fp8, b_scale, torch.bfloat16)
     elif fp8_kernel == FP8Kernel.SCALED_MM:
+        if scaling_strategy == ScalingStrategy.E8M0:
+            # Use the scales we computed earlier for E8M0
+            return lambda: torch._scaled_mm(
+                A_fp8,
+                B_fp8,
+                b_scale,  # swapped; correct scale ordering not figured out yet
+                a_scale,
+                out_dtype=torch.bfloat16,
+                scale_dtype=1,
+            )
         return lambda: addmm_float8_unwrapped_inference(
             A_fp8, a_scale, B_fp8, b_scale, output_dtype=torch.bfloat16, use_fast_accum=True
         )
-    elif fp8_kernel == FP8Kernel.MX_FP8:
+    elif fp8_kernel == FP8Kernel.CUTLASS_MX:
+        assert (
+            scaling_strategy == ScalingStrategy.E8M0
+        ), "E8M0 scaling strategy is required for CUTLASS_MX"
         try:
             from driss_torch import mx_fp8_bf16
         except ModuleNotFoundError:
             print("Driss Torch not installed")
             return None
-        # This is crude but we just care about perf, numerics checked elsewhere
-        M, K = A_fp8.shape
-        N, _ = B_fp8.shape
-        n_a_rows = ceil_div(M, 128)
-        n_a_cols = ceil_div(K, 32)
-        n_b_rows = ceil_div(N, 128)
-        n_b_cols = ceil_div(K, 32)
-        a_scales = torch.randint(256, (n_a_rows, n_a_cols), dtype=torch.uint8, device="cuda")
-        b_scales = torch.randint(256, (n_b_rows, n_b_cols), dtype=torch.uint8, device="cuda")
-
-        return lambda: mx_fp8_bf16(A_fp8, B_fp8, a_scales, b_scales)
+        return lambda: mx_fp8_bf16(A_fp8, B_fp8, a_scale, b_scale)
     else:
         raise ValueError(f"Invalid FP8 kernel: {fp8_kernel}")
 
@@ -246,7 +272,7 @@ def get_configs_varying_k(
     M: int = 8192, N: int = 8192, bf16: bool = False
 ) -> List[ExperimentConfig]:
     shapes = [(M, K, N) for K in range(1024, 16385, 1024)]
-    scaling_strategies = [ScalingStrategy.PER_TENSOR]
+    scaling_strategies = [ScalingStrategy.E8M0]
     compile_options = [False]
     configs = []
     fp8_kernels = [
@@ -254,7 +280,7 @@ def get_configs_varying_k(
         # FP8Kernel.PERSISTENT,
         # FP8Kernel.PERSISTENT_TMA,
         # FP8Kernel.DEVICE_TMA,
-        FP8Kernel.MX_FP8,
+        FP8Kernel.CUTLASS_MX,
     ]
     for (M, K, N), strategy, compile, kernel in itertools.product(
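For context on the new `get_e8_scales` helper: the note dropped from the old `MX_FP8` branch still applies, i.e. the scale values are random because this benchmark only measures perf and numerics are checked elsewhere. Below is a minimal standalone sketch of the scale layout the helper allocates, assuming `ceil_div` is the benchmark's existing round-up integer division; the `cuda` device argument is omitted here so the sketch runs anywhere:

```python
import torch

def ceil_div(a: int, b: int) -> int:
    # Assumed equivalent of the benchmark's ceil_div helper: round-up integer division.
    return -(a // -b)

# Example shapes from the K-sweep in get_configs_varying_k: M = N = 8192, K = 4096.
M, K, N = 8192, 4096, 8192

# One uint8 (E8M0-encoded) scale per 128-row x 32-column tile of each operand.
# Note that get_e8_scales reads N from B.shape[0], i.e. it expects B laid out as (N, K).
a_scales = torch.randint(256, (ceil_div(M, 128), ceil_div(K, 32)), dtype=torch.uint8)
b_scales = torch.randint(256, (ceil_div(N, 128), ceil_div(K, 32)), dtype=torch.uint8)

print(a_scales.shape, b_scales.shape)  # torch.Size([64, 128]) torch.Size([64, 128])
```

With the renamed enum, the MX path is exercised as `get_fp8_matmul(A, B, ScalingStrategy.E8M0, FP8Kernel.CUTLASS_MX)`, which returns `None` rather than a callable when `driss_torch` is not installed.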