From ba0b51bb83ba8ffcc0e99eb32d538ca21f8e2d22 Mon Sep 17 00:00:00 2001 From: drisspg Date: Thu, 26 Sep 2024 18:14:18 -0700 Subject: [PATCH] Adding comparison for different fp8 matmuls --- benchmarks/fp8_matmul.py | 278 +++++++++++ benchmarks/profile_fp8_matmul.py | 133 +++++ transformer_nuggets/fp8/fp8_matmul.py | 673 ++++++++++++++++++++++++++ 3 files changed, 1084 insertions(+) create mode 100644 benchmarks/fp8_matmul.py create mode 100644 benchmarks/profile_fp8_matmul.py create mode 100644 transformer_nuggets/fp8/fp8_matmul.py diff --git a/benchmarks/fp8_matmul.py b/benchmarks/fp8_matmul.py new file mode 100644 index 0000000..8958eb8 --- /dev/null +++ b/benchmarks/fp8_matmul.py @@ -0,0 +1,278 @@ +import itertools +from dataclasses import dataclass +from typing import List, Optional +import torch +from tabulate import tabulate +import pandas as pd +import matplotlib.pyplot as plt +from tqdm import tqdm +from jsonargparse import CLI +from pathlib import Path +from transformer_nuggets.utils.benchmark import benchmark_cuda_function_in_microseconds +from torchao.float8.inference import ( + addmm_float8_unwrapped_inference, + preprocess_data, + Float8MMConfig, +) +from transformer_nuggets.fp8.fp8_matmul import ( + matmul_persistent, + matmul_tma_persistent, + matmul_device_tma_persistent, +) +from enum import Enum +import csv + +torch._dynamo.config.cache_size_limit = 1000 + + +class FP8Kernel(Enum): + PERSISTENT = "Persistent" + PERSISTENT_TMA = "Persistent-TMA" + DEVICE_TMA = "Device-TMA" + SCALED_MM = "Scaled-MM" + + +class ScalingStrategy(Enum): + PER_TENSOR = "PerTensor" + PER_ROW = "PerRow" + + +def is_col_major(stride): + assert len(stride) == 2, "is_col_major only supports 2D tensors" + return stride[1] > stride[0] and stride[0] == 1 + + +def get_fp8_matmul( + A: torch.Tensor, B: torch.Tensor, scaling_strategy: ScalingStrategy, fp8_kernel: FP8Kernel +): + A_fp8 = A.to(torch.float8_e4m3fn) + B_fp8 = B.to(torch.float8_e4m3fn) + A_fp8, B_fp8 = preprocess_data(A_fp8, B_fp8, Float8MMConfig(use_fast_accum=True)) + + if scaling_strategy == ScalingStrategy.PER_TENSOR: + a_scale = torch.tensor(1, device="cuda", dtype=torch.float32) + b_scale = torch.tensor(1, device="cuda", dtype=torch.float32) + elif scaling_strategy == ScalingStrategy.PER_ROW: + a_scale = torch.ones((A_fp8.size(0), 1), device="cuda", dtype=torch.float32) + b_scale = torch.ones((B_fp8.size(1), 1), device="cuda", dtype=torch.float32).T + else: + raise ValueError(f"Invalid scaling strategy: {scaling_strategy}") + if fp8_kernel == FP8Kernel.PERSISTENT: + return lambda: matmul_persistent(A_fp8, a_scale, B_fp8, b_scale, torch.bfloat16) + elif fp8_kernel == FP8Kernel.PERSISTENT_TMA: + return lambda: matmul_tma_persistent(A_fp8, a_scale, B_fp8, b_scale, torch.bfloat16) + elif fp8_kernel == FP8Kernel.DEVICE_TMA: + return lambda: matmul_device_tma_persistent(A_fp8, a_scale, B_fp8, b_scale, torch.bfloat16) + elif fp8_kernel == FP8Kernel.SCALED_MM: + return lambda: addmm_float8_unwrapped_inference( + A_fp8, a_scale, B_fp8, b_scale, output_dtype=torch.bfloat16, use_fast_accum=True + ) + else: + raise ValueError(f"Invalid FP8 kernel: {fp8_kernel}") + + +@dataclass(frozen=True) +class ExperimentConfig: + M: int + K: int + N: int + scaling_strategy: ScalingStrategy + fp8_kernel: FP8Kernel + compile: bool + + +@dataclass(frozen=True) +class ExperimentResult: + bf16_time: float + fp8_time: float + bf16_tflops: float + fp8_tflops: float + + +@dataclass(frozen=True) +class Experiment: + config: ExperimentConfig + result: ExperimentResult + + 
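+# Note: get_fp8_matmul above passes unit scales because only kernel throughput
+# is measured. For reference, a minimal sketch of how real per-row
+# dequantization scales could be derived from the inputs' amax (hypothetical
+# helper, not used by the benchmark):
+def _reference_per_row_scales(A: torch.Tensor, B: torch.Tensor):
+    fp8_max = torch.finfo(torch.float8_e4m3fn).max
+    # Dequant scale per row of A (M x 1) and per column of B (1 x N); a real
+    # quantizer would divide A and B by these before casting to float8_e4m3fn.
+    a_scale = A.abs().amax(dim=1, keepdim=True).float().clamp(min=1e-12) / fp8_max
+    b_scale = B.abs().amax(dim=0, keepdim=True).float().clamp(min=1e-12) / fp8_max
+    return a_scale, b_scale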
+def calculate_tflops(M: int, N: int, K: int, time_us: float) -> float: + """Calculate TFLOPS (Tera Floating Point Operations Per Second)""" + flops = 2 * M * N * K # Number of floating point operations for matrix multiplication + tflops = (flops / time_us) / 1e6 # Convert to TFLOPS + return tflops + + +def run_experiment(config: ExperimentConfig) -> ExperimentResult: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + A = torch.randn(config.M, config.K, device=device, dtype=torch.bfloat16) + B = torch.randn(config.K, config.N, device=device, dtype=torch.bfloat16) + + bf16_matmul = lambda x, y: torch.matmul(x, y) + fp8_matmul = get_fp8_matmul(A, B, config.scaling_strategy, config.fp8_kernel) + + if config.compile and config.fp8_kernel == FP8Kernel.SCALED_MM: + bf16_matmul = torch.compile(bf16_matmul) + fp8_matmul = torch.compile(fp8_matmul, mode="max-autotune") + + # Warmup phase + warmup_iterations = 5 + for _ in range(warmup_iterations): + _ = bf16_matmul(A, B) + _ = fp8_matmul() + torch.cuda.synchronize() + + # Actual benchmarking + bf16_time = benchmark_cuda_function_in_microseconds(lambda: bf16_matmul(A, B)) + fp8_time = benchmark_cuda_function_in_microseconds(fp8_matmul) + + # Calculate TFLOPS + bf16_tflops = calculate_tflops(config.M, config.N, config.K, bf16_time) + fp8_tflops = calculate_tflops(config.M, config.N, config.K, fp8_time) + + # Baseline fp8_matmul correctness + scaled_mm_base = get_fp8_matmul(A, B, config.scaling_strategy, FP8Kernel.SCALED_MM) + out_base = scaled_mm_base() + out = fp8_matmul() + # Failing on one sample with large N + torch.testing.assert_close(out, out_base) + + return ExperimentResult( + bf16_time=bf16_time, fp8_time=fp8_time, bf16_tflops=bf16_tflops, fp8_tflops=fp8_tflops + ) + + +def print_results(experiments: List[Experiment], save_path: Optional[Path] = None): + headers = [ + "M", + "K", + "N", + "Scaling Strategy", + "Fp8 Kernel", + "Compiled", + "BF16 Time (ms)", + "FP8 Time (ms)", + "Speedup", + "BF16 TFLOPS", + "FP8 TFLOPS", + "TFLOPS Ratio", + ] + rows = [] + for experiment in experiments: + config = experiment.config + result = experiment.result + speedup = result.bf16_time / result.fp8_time + tflops_ratio = result.fp8_tflops / result.bf16_tflops + rows.append( + [ + config.M, + config.K, + config.N, + config.scaling_strategy, + config.fp8_kernel, + config.compile, + f"{result.bf16_time:.4f}", + f"{result.fp8_time:.4f}", + f"{speedup:.2f}x", + f"{result.bf16_tflops:.2f}", + f"{result.fp8_tflops:.2f}", + f"{tflops_ratio:.2f}x", + ] + ) + print(tabulate(rows, headers=headers, floatfmt=".4f")) + + if save_path is not None: + with open(save_path, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(headers) + writer.writerows(rows) + print(f"💾 Results saved to: {save_path}") + + +def get_configs_varying_k(M: int = 8192, N: int = 8192) -> List[ExperimentConfig]: + shapes = [(M, K, N) for K in range(512, 8193, 512)] + scaling_strategies = [ScalingStrategy.PER_ROW] + compile_options = [False] + configs = [] + fp8_kernels = [ + FP8Kernel.SCALED_MM, + # FP8Kernel.PERSISTENT, + FP8Kernel.PERSISTENT_TMA, + # FP8Kernel.DEVICE_TMA, + ] + + for (M, K, N), strategy, compile, kernel in itertools.product( + shapes, scaling_strategies, compile_options, fp8_kernels + ): + configs.append( + ExperimentConfig( + M=M, K=K, N=N, scaling_strategy=strategy, compile=compile, fp8_kernel=kernel + ) + ) + return configs + + +def load_and_process_data(file_path): + df = pd.read_csv(file_path) + df["Speedup"] = 
df["Speedup"].str.rstrip("x").astype(float) + df["TFLOPS Ratio"] = df["TFLOPS Ratio"].str.rstrip("x").astype(float) + return df + + +def plot_tflops_comparison(df, save_path: Path): + plt.figure(figsize=(12, 6)) + grouped = df.groupby(["K", "Fp8 Kernel"]) + k_values = sorted(df["K"].unique()) + kernel_types = df["Fp8 Kernel"].unique() + scaling_strategy = df["Scaling Strategy"].iloc[0] + m_value = df["M"].iloc[0] + n_value = df["N"].iloc[0] + + for kernel in kernel_types: + tflops_values = [grouped.get_group((k, kernel))["FP8 TFLOPS"].values[0] for k in k_values] + plt.plot(k_values, tflops_values, marker="o", label=kernel.split(".")[-1]) + + plt.xlabel("K (Matrix Dimension)") + plt.ylabel("TFLOPS") + plt.title( + f"FP8 Kernel Performance Comparison\nM={m_value}, N={n_value}\nScaling Strategy: {scaling_strategy}" + ) + plt.legend() + plt.grid(True, which="both", ls="-", alpha=0.2) + plt.xticks(k_values, rotation=45, ha="right") + plt.tight_layout() + + # Generate the file name and save in the same directory as the CSV file + file_name = f"fp8_kernel_comparison_{m_value}_{n_value}.png" + graph_path = save_path.parent / file_name + plt.savefig(graph_path, dpi=300) + print(f"TFLOPS comparison plot saved as {graph_path}") + + +def main(save_path: Optional[str] = None, M: int = 8192, N: int = 8192, graph: bool = False): + """Benchmark FP8 MatMul with different configurations and optionally graph results. + + Args: + save_path (Optional[str], optional): Path to save the results. Defaults to None. + M (int, optional): Number of rows in the first matrix. Defaults to 8192. + N (int, optional): Number of columns in the second matrix. Defaults to 8192. + graph_results (bool, optional): Whether to create a graph of the results. Defaults to False. + """ + torch.random.manual_seed(123) + configs = get_configs_varying_k(M, N) + results = [] + if save_path is not None: + save_path = Path(save_path) + save_path = save_path.with_suffix(".csv") + save_path.parent.mkdir(parents=True, exist_ok=True) + for config in tqdm(configs): + result = run_experiment(config) + results.append(Experiment(config=config, result=result)) + print_results(results, save_path) + + if graph and save_path is not None: + df = load_and_process_data(save_path) + plot_tflops_comparison(df, save_path) + + +if __name__ == "__main__": + CLI(main) diff --git a/benchmarks/profile_fp8_matmul.py b/benchmarks/profile_fp8_matmul.py new file mode 100644 index 0000000..cdcd2af --- /dev/null +++ b/benchmarks/profile_fp8_matmul.py @@ -0,0 +1,133 @@ +import torch +from dataclasses import dataclass +from jsonargparse import CLI +import logging +from pathlib import Path + +from transformer_nuggets.utils.benchmark import ProfileConfig, profile_function +from torchao.float8.inference import ( + addmm_float8_unwrapped_inference, + preprocess_data, + Float8MMConfig, +) +from transformer_nuggets.fp8.fp8_matmul import ( + matmul_persistent, + matmul_tma_persistent, + matmul_device_tma_persistent, +) +from enum import Enum + +logging.getLogger("transformer_nuggets").setLevel(logging.INFO) + + +class FP8Kernel(Enum): + PERSISTENT = "Persistent" + PERSISTENT_TMA = "Persistent-TMA" + DEVICE_TMA = "Device-TMA" + SCALED_MM = "Scaled-MM" + + +class ScalingStrategy(Enum): + PER_TENSOR = "PerTensor" + PER_ROW = "PerRow" + + +@dataclass(frozen=True) +class ExperimentConfig: + M: int + K: int + N: int + scaling_strategy: ScalingStrategy + fp8_kernel: FP8Kernel + compile: bool + + +def get_fp8_matmul( + A: torch.Tensor, B: torch.Tensor, scaling_strategy: 
ScalingStrategy, fp8_kernel: FP8Kernel +): + A_fp8 = A.to(torch.float8_e4m3fn) + B_fp8 = B.to(torch.float8_e4m3fn) + A_fp8, B_fp8 = preprocess_data(A_fp8, B_fp8, Float8MMConfig(use_fast_accum=True)) + + if scaling_strategy == ScalingStrategy.PER_TENSOR: + a_scale = torch.tensor(1, device="cuda", dtype=torch.float32) + b_scale = torch.tensor(1, device="cuda", dtype=torch.float32) + elif scaling_strategy == ScalingStrategy.PER_ROW: + a_scale = torch.ones((A_fp8.size(0), 1), device="cuda", dtype=torch.float32) + b_scale = torch.ones((B_fp8.size(1), 1), device="cuda", dtype=torch.float32).T + else: + raise ValueError(f"Invalid scaling strategy: {scaling_strategy}") + + if fp8_kernel == FP8Kernel.PERSISTENT: + return lambda: matmul_persistent(A_fp8, a_scale, B_fp8, b_scale, torch.bfloat16) + elif fp8_kernel == FP8Kernel.PERSISTENT_TMA: + return lambda: matmul_tma_persistent(A_fp8, a_scale, B_fp8, b_scale, torch.bfloat16) + elif fp8_kernel == FP8Kernel.DEVICE_TMA: + return lambda: matmul_device_tma_persistent(A_fp8, a_scale, B_fp8, b_scale, torch.bfloat16) + elif fp8_kernel == FP8Kernel.SCALED_MM: + return lambda: addmm_float8_unwrapped_inference( + A_fp8, a_scale, B_fp8, b_scale, output_dtype=torch.bfloat16, use_fast_accum=True + ) + else: + raise ValueError(f"Invalid FP8 kernel: {fp8_kernel}") + + +def profile_matmul(config: ExperimentConfig, profile_config: ProfileConfig): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + A = torch.randn(config.M, config.K, device=device, dtype=torch.bfloat16) + B = torch.randn(config.K, config.N, device=device, dtype=torch.bfloat16) + + fp8_matmul = get_fp8_matmul(A, B, config.scaling_strategy, config.fp8_kernel) + bf16_matmul = lambda x, y: torch.matmul(x, y) + + if config.compile and config.fp8_kernel == FP8Kernel.SCALED_MM: + bf16_matmul = torch.compile(bf16_matmul) + fp8_matmul = torch.compile(fp8_matmul, mode="max-autotune") + + # Warmup phase + warmup_iterations = 5 + for _ in range(warmup_iterations): + _ = bf16_matmul(A, B) + _ = fp8_matmul() + torch.cuda.synchronize() + + logging.info("Profiling FP8 MatMul") + fp8_profile = profile_function(profile_config, fp8_matmul) + + return fp8_profile + + +def main(): + torch.random.manual_seed(123) + + # Define your experiment configuration here + config = ExperimentConfig( + M=8192, + K=8192, + N=8192, + scaling_strategy=ScalingStrategy.PER_TENSOR, + fp8_kernel=FP8Kernel.PERSISTENT_TMA, + compile=False, + ) + + base = Path(__file__).resolve().parent / Path("data") + path = base / Path(f"matmul_profile_{config.fp8_kernel.name}.csv") + # Define your profile configuration here + profile_config = ProfileConfig( + file_path=str(path), + name=f"MatMul Profiling {config.fp8_kernel}", + cuda=True, + iters=3, + warmup_iters=5, + sync=True, + ) + + fp8_profile = profile_matmul(config, profile_config) + + print(f"\nProfile for config: {config}") + print("\nFP8 Profile:") + print(fp8_profile.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + + +if __name__ == "__main__": + CLI(main) diff --git a/transformer_nuggets/fp8/fp8_matmul.py b/transformer_nuggets/fp8/fp8_matmul.py new file mode 100644 index 0000000..a270a9a --- /dev/null +++ b/transformer_nuggets/fp8/fp8_matmul.py @@ -0,0 +1,673 @@ +import torch +import triton +import triton.language as tl +import triton.tools.experimental_descriptor +from triton.language.extra.cuda._experimental_tma import experimental_device_tensormap_create2d + +# Autotuner does not work with TMA. Use manual config. 
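+# For the plain persistent kernel one could in principle autotune instead, e.g.
+# by decorating it with triton.autotune over a few triton.Config candidates
+# keyed on the problem shape (illustrative sketch only, not used here):
+#     @triton.autotune(
+#         configs=[triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128,
+#                                 "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8},
+#                                num_stages=4, num_warps=8)],
+#         key=["M", "N", "K"],
+#     )
+# The TMA paths need the block sizes on the host to build descriptors, so a
+# fixed per-dtype table is used instead.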
+configs = { + torch.float8_e4m3fn: { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_stages": 4, + "num_warps": 8, + }, + torch.float16: { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_stages": 3, + "num_warps": 8, + }, +} + + +def validate_matmul_inputs( + a: torch.Tensor, b: torch.Tensor, a_scale: torch.Tensor, b_scale: torch.Tensor +) -> bool: + """ + Validate inputs for matrix multiplication with scaling. + + Args: + a (torch.Tensor): First input matrix + b (torch.Tensor): Second input matrix + a_scale (torch.Tensor): Scaling factor for a + b_scale (torch.Tensor): Scaling factor for b + + Returns: + bool: True if inputs are valid, raises AssertionError otherwise + """ + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.dtype == b.dtype, "Incompatible dtypes" + + ROW_WISE_SCALING = a_scale.numel() != 1 + if ROW_WISE_SCALING: + assert a_scale.dim() == 2, f"a_scale must be a 2D tensor but got {a_scale.dim()}" + assert b_scale.dim() == 2, f"b_scale must be a 2D tensor but got {b_scale.dim()}" + assert ( + a_scale.shape[0] == a.shape[0] + ), f"a_scale must have same number of rows as a, got {a_scale.shape[0]} vs {a.shape[0]}" + assert a_scale.shape[1] == 1, f"a_scale must have 1 column, got {a_scale.shape[1]}" + assert ( + b_scale.shape[1] == b.shape[1] + ), f"b_scale must have same number of columns as b, got {b_scale.shape[0]} vs {b.shape[1]}" + assert b_scale.shape[0] == 1, f"b_scale must have 1 column, got {b_scale.shape[1]}" + else: + assert ( + a_scale.numel() == 1 + ), f"a_scale must be a scalar for per-tensor scaling, got shape {a_scale.shape}" + assert ( + b_scale.numel() == 1 + ), f"b_scale must be a scalar for per-tensor scaling, got shape {b_scale.shape}" + + return ROW_WISE_SCALING + + +def is_row_major(stride): + assert len(stride) == 2, "is_row_major only supports 2D tensors" + return stride[0] > stride[1] and stride[1] == 1 + + +def is_col_major(stride): + assert len(stride) == 2, "is_col_major only supports 2D tensors" + return stride[1] > stride[0] and stride[0] == 1 + + +def _matmul_launch_metadata(grid, kernel, args): + ret = {} + M, N, K = args["M"], args["N"], args["K"] + ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}]" + if "c_ptr" in args: + bytes_per_elem = args["c_ptr"].element_size() + else: + bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2 + ret[f"flops{bytes_per_elem * 8}"] = 2.0 * M * N * K + ret["bytes"] = bytes_per_elem * (M * K + N * K + M * N) + return ret + + +@triton.jit +def load_scales(a_scale_ptr, b_scale_ptr, ROW_WISE_SCALING: tl.constexpr): + if ROW_WISE_SCALING: + # For row-wise scaling, we'll return the pointers + return a_scale_ptr, b_scale_ptr + else: + # For per-tensor scaling, we'll load the scalar values + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr) + return a_scale, b_scale + + +@triton.jit +def apply_scaling( + accumulator, + a_scale, + b_scale, + ROW_WISE_SCALING: tl.constexpr, + offs_cm, + offs_cn, + M, + N, + stride_a_scale_m, + stride_b_scale_n, +): + if ROW_WISE_SCALING: + # For row-wise scaling, we need to load the scales for each row/column + a_scales = tl.load( + a_scale + (offs_cm * stride_a_scale_m), + mask=offs_cm < M, + other=0.0, + ) + b_scales = tl.load( + b_scale + (offs_cn * stride_b_scale_n), + mask=offs_cn < N, + other=0.0, + ) + acc_scale = a_scales[:, None] * b_scales[None, :] + else: + # For per-tensor scaling, we can directly use the loaded scalar values + acc_scale = a_scale * 
b_scale + + return accumulator * acc_scale + + +@triton.jit(launch_metadata=_matmul_launch_metadata) +def matmul_kernel_persistent( + a_ptr, + a_scale_ptr, + b_ptr, + b_scale_ptr, + c_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_a_scale_m, + stride_b_scale_n, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_SMS: tl.constexpr, + ROW_WISE_SCALING: tl.constexpr, +): + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + + tiles_per_SM = num_tiles // NUM_SMS + if start_pid < num_tiles % NUM_SMS: + tiles_per_SM += 1 + + tile_id = start_pid - NUM_SMS + ki = -1 + + offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K) + + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + pid_m = 0 + pid_n = 0 + offs_am = tl.arange(0, BLOCK_SIZE_M) + offs_bn = tl.arange(0, BLOCK_SIZE_N) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + a_scale, b_scale = load_scales(a_scale_ptr, b_scale_ptr, ROW_WISE_SCALING) + + for _ in range(0, k_tiles * tiles_per_SM): + ki = tl.where(ki == k_tiles - 1, 0, ki + 1) + if ki == 0: + tile_id += NUM_SMS + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + start_m = pid_m * BLOCK_SIZE_M + start_n = pid_n * BLOCK_SIZE_N + offs_am = start_m + tl.arange(0, BLOCK_SIZE_M) + offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N) + offs_am = tl.where(offs_am < M, offs_am, 0) + offs_bn = tl.where(offs_bn < N, offs_bn, 0) + offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M) + offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N) + offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + a = tl.load(a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, b, accumulator) + + if ki == k_tiles - 1: + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + # Apply inverse scaling + accumulator = apply_scaling( + accumulator, + a_scale, + b_scale, + ROW_WISE_SCALING, + offs_cm, + offs_cn, + M, + N, + stride_a_scale_m, + stride_b_scale_n, + ) + c = accumulator.to(c_ptr.dtype.element_ty) + tl.store(c_ptrs, c, mask=c_mask) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + +def matmul_persistent( + a: torch.Tensor, + a_scale: torch.Tensor, + b: torch.Tensor, + b_scale: torch.Tensor, + output_dtype: torch.dtype, +) -> torch.Tensor: + # Check constraints. + ROW_WISE_SCALING = validate_matmul_inputs(a, b, a_scale, b_scale) + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + M, K = a.shape + K, N = b.shape + dtype = a.dtype + # Allocates output. 
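+    # The output buffer is (M, N) in the requested output_dtype; each persistent
+    # program writes its finished tiles straight into it. The grid below is
+    # capped at NUM_SMS, so one program per SM iterates over multiple tiles.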
+ c = torch.empty((M, N), device=a.device, dtype=output_dtype) + # 1D launch kernel where each block gets its own program. + grid = lambda META: ( + min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])), + ) + matmul_kernel_persistent[grid]( + a, + a_scale, + b, + b_scale, + c, # + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + a_scale.stride(0) if ROW_WISE_SCALING else 0, + b_scale.stride(1) if ROW_WISE_SCALING else 0, + BLOCK_SIZE_M=configs[dtype]["BLOCK_SIZE_M"], + BLOCK_SIZE_N=configs[dtype]["BLOCK_SIZE_N"], + BLOCK_SIZE_K=configs[dtype]["BLOCK_SIZE_K"], + GROUP_SIZE_M=configs[dtype]["GROUP_SIZE_M"], + NUM_SMS=NUM_SMS, + num_stages=configs[dtype]["num_stages"], + num_warps=configs[dtype]["num_warps"], + ROW_WISE_SCALING=ROW_WISE_SCALING, + ) + return c + + +@triton.jit +def matmul_kernel_tma_persistent( + a_desc_ptr, + a_scale_ptr, + b_desc_ptr, + b_scale_ptr, + c_desc_ptr, + M, + N, + K, + stride_a_scale_m, + stride_b_scale_n, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_SMS: tl.constexpr, + output_dtype: tl.constexpr, + ROW_WISE_SCALING: tl.constexpr, +): + tl.inline_asm_elementwise( + "fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg", + "=r, l", + [a_desc_ptr], + dtype=tl.int32, + is_pure=False, + pack=1, + ) + tl.inline_asm_elementwise( + "fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg", + "=r, l", + [b_desc_ptr], + dtype=tl.int32, + is_pure=False, + pack=1, + ) + tl.inline_asm_elementwise( + "fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg", + "=r, l", + [c_desc_ptr], + dtype=tl.int32, + is_pure=False, + pack=1, + ) + dtype = tl.float8e4nv + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + + tiles_per_SM = num_tiles // NUM_SMS + if start_pid < num_tiles % NUM_SMS: + tiles_per_SM += 1 + + tile_id = start_pid - NUM_SMS + ki = -1 + + pid_m = 0 + pid_n = 0 + offs_am = 0 + offs_bn = 0 + + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + a_scale, b_scale = load_scales(a_scale_ptr, b_scale_ptr, ROW_WISE_SCALING) + + for _ in range(0, k_tiles * tiles_per_SM): + ki = tl.where(ki == k_tiles - 1, 0, ki + 1) + if ki == 0: + tile_id += NUM_SMS + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + + offs_k = ki * BLOCK_SIZE_K + + a = tl._experimental_descriptor_load( + a_desc_ptr, [offs_am, offs_k], [BLOCK_SIZE_M, BLOCK_SIZE_K], dtype + ) + b = tl._experimental_descriptor_load( + b_desc_ptr, [offs_bn, offs_k], [BLOCK_SIZE_N, BLOCK_SIZE_K], dtype + ) + accumulator = tl.dot(a, b.T, accumulator) + + if ki == k_tiles - 1: + offs_cm = offs_am + tl.arange(0, BLOCK_SIZE_M) + offs_cn = offs_bn + tl.arange(0, BLOCK_SIZE_N) + # Apply scaling + accumulator = apply_scaling( + accumulator, + a_scale, + b_scale, + ROW_WISE_SCALING, + offs_cm, + offs_cn, + M, + N, + stride_a_scale_m, + stride_b_scale_n, + ) + c = accumulator.to(output_dtype) + + tl._experimental_descriptor_store(c_desc_ptr, c, [offs_am, 
offs_bn]) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + +def matmul_tma_persistent( + a: torch.Tensor, + a_scale: torch.Tensor, + b: torch.Tensor, + b_scale: torch.Tensor, + output_dtype: torch.dtype, +) -> torch.Tensor: + # Check constraints. + assert is_row_major(a.stride()), "a must be row major" + assert is_col_major(b.stride()), "b must be col major" + ROW_WISE_SCALING = validate_matmul_inputs(a, b, a_scale, b_scale) + + M, K = a.shape + K, N = b.shape + dtype = a.dtype + + c = torch.empty((M, N), device=a.device, dtype=output_dtype) + desc_a = triton.tools.experimental_descriptor.create_2d_tma_descriptor( + a.data_ptr(), + M, + K, + configs[dtype]["BLOCK_SIZE_M"], + configs[dtype]["BLOCK_SIZE_K"], + a.element_size(), + ) + desc_b = triton.tools.experimental_descriptor.create_2d_tma_descriptor( + b.data_ptr(), + N, + K, + configs[dtype]["BLOCK_SIZE_N"], + configs[dtype]["BLOCK_SIZE_K"], + b.element_size(), + ) + desc_c = triton.tools.experimental_descriptor.create_2d_tma_descriptor( + c.data_ptr(), + M, + N, + configs[dtype]["BLOCK_SIZE_M"], + configs[dtype]["BLOCK_SIZE_N"], + c.element_size(), + ) + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + triton_out_dtype = tl.float8e4nv if output_dtype == torch.float8_e4m3fn else tl.bfloat16 + + grid = lambda META: ( + min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])), + ) + matmul_kernel_tma_persistent[grid]( + desc_a, + a_scale, + desc_b, + b_scale, + desc_c, + M, + N, + K, + a_scale.stride(0) if ROW_WISE_SCALING else 0, + b_scale.stride(1) if ROW_WISE_SCALING else 0, + BLOCK_SIZE_M=configs[dtype]["BLOCK_SIZE_M"], + BLOCK_SIZE_N=configs[dtype]["BLOCK_SIZE_N"], + BLOCK_SIZE_K=configs[dtype]["BLOCK_SIZE_K"], + GROUP_SIZE_M=configs[dtype]["GROUP_SIZE_M"], + NUM_SMS=NUM_SMS, + num_stages=configs[dtype]["num_stages"], + num_warps=configs[dtype]["num_warps"], + output_dtype=triton_out_dtype, + ROW_WISE_SCALING=ROW_WISE_SCALING, + ) + return c + + +@triton.jit(launch_metadata=_matmul_launch_metadata) +def matmul_kernel_device_tma_persistent( + workspace_ptr, + tiles_per_update: tl.constexpr, + a_ptr, + a_scale_ptr, + b_ptr, + b_scale_ptr, + c_ptr, + M, + N, + K, + stride_a_scale_m, + stride_b_scale_n, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_SMS: tl.constexpr, + OUTPUT_DTYPE: tl.constexpr, + ROW_WISE_SCALING: tl.constexpr, +): + # Matmul using TMA and device-side descriptor creation + dtype = tl.float8e4nv + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + + TMA_SIZE: tl.constexpr = 128 + workspace_base = workspace_ptr + start_pid * 3 * TMA_SIZE + a_desc_ptr = workspace_base + b_desc_ptr = workspace_base + TMA_SIZE + c_desc_ptr = workspace_base + 2 * TMA_SIZE + + experimental_device_tensormap_create2d( + desc_ptr=a_desc_ptr, + global_address=a_ptr, + load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K], + global_size=[M, K], + element_ty=a_ptr.dtype.element_ty, + ) + experimental_device_tensormap_create2d( + desc_ptr=b_desc_ptr, + global_address=b_ptr, + load_size=[BLOCK_SIZE_N, BLOCK_SIZE_K], + global_size=[N, K], + element_ty=b_ptr.dtype.element_ty, + ) + experimental_device_tensormap_create2d( + desc_ptr=c_desc_ptr, + global_address=c_ptr, + load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], + global_size=[M, N], + element_ty=c_ptr.dtype.element_ty, + ) 
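+    # The three tensormaps above were written through the generic proxy into this
+    # program's workspace slice; each must be fenced (acquired) before the TMA
+    # engine is allowed to read it, which is what the next three calls do.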
+ tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) + + tiles_per_SM = num_tiles // NUM_SMS + if start_pid < num_tiles % NUM_SMS: + tiles_per_SM += 1 + + tile_id = start_pid - NUM_SMS + ki = -1 + ni = -1 + + pid_m = 0 + pid_n = 0 + offs_am = 0 + offs_bn = 0 + + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + a_scale, b_scale = load_scales(a_scale_ptr, b_scale_ptr, ROW_WISE_SCALING) + + for _ in range(0, k_tiles * tiles_per_SM): + ki = tl.where(ki == k_tiles - 1, 0, ki + 1) + if ki == 0: + ni += 1 + + # Simulate a grouped gemm + if ni == tiles_per_update: + tl.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=a_desc_ptr, + global_address=a_ptr, + load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K], + global_size=[M, K], + element_ty=a_ptr.dtype.element_ty, + ) + tl.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=b_desc_ptr, + global_address=b_ptr, + load_size=[BLOCK_SIZE_N, BLOCK_SIZE_K], + global_size=[N, K], + element_ty=b_ptr.dtype.element_ty, + ) + tl.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=c_desc_ptr, + global_address=c_ptr, + load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], + global_size=[M, N], + element_ty=c_ptr.dtype.element_ty, + ) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) + ni = 0 + + tile_id += NUM_SMS + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + + offs_k = ki * BLOCK_SIZE_K + + a = tl._experimental_descriptor_load( + a_desc_ptr, [offs_am, offs_k], [BLOCK_SIZE_M, BLOCK_SIZE_K], dtype + ) + b = tl._experimental_descriptor_load( + b_desc_ptr, [offs_bn, offs_k], [BLOCK_SIZE_N, BLOCK_SIZE_K], dtype + ) + accumulator = tl.dot(a, b.T, accumulator) + + if ki == k_tiles - 1: + # Apply inverse scaling + offs_cm = offs_am + tl.arange(0, BLOCK_SIZE_M) + offs_cn = offs_bn + tl.arange(0, BLOCK_SIZE_N) + # Apply scaling + accumulator = apply_scaling( + accumulator, + a_scale, + b_scale, + ROW_WISE_SCALING, + offs_cm, + offs_cn, + M, + N, + stride_a_scale_m, + stride_b_scale_n, + ) + c = accumulator.to(OUTPUT_DTYPE) + + tl._experimental_descriptor_store(c_desc_ptr, c, [offs_am, offs_bn]) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + +def matmul_device_tma_persistent( + a: torch.Tensor, + a_scale: torch.Tensor, + b: torch.Tensor, + b_scale: torch.Tensor, + output_dtype: torch.dtype, + tiles_per_update: int = 1, +) -> torch.Tensor: + assert is_row_major(a.stride()), "a must be row major" + assert is_col_major(b.stride()), "b must be col major" + ROW_WISE_SCALING = validate_matmul_inputs(a, b, a_scale, b_scale) + + M, K = a.shape + K, N = b.shape + dtype = a.dtype + + c = torch.empty((M, N), device=a.device, dtype=output_dtype) + + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + tma_size = 128 + workspace = torch.empty(NUM_SMS * 3 * tma_size, dtype=torch.uint8, device="cuda") + triton_out_dtype = tl.float8e4nv if output_dtype == torch.float8_e4m3fn else 
tl.bfloat16 + + grid = lambda META: ( + min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])), + ) + matmul_kernel_device_tma_persistent[grid]( + workspace, + tiles_per_update, + a, + a_scale, + b, + b_scale, + c, + M, + N, + K, + a_scale.stride(0) if ROW_WISE_SCALING else 0, + b_scale.stride(1) if ROW_WISE_SCALING else 0, + BLOCK_SIZE_M=configs[dtype]["BLOCK_SIZE_M"], + BLOCK_SIZE_N=configs[dtype]["BLOCK_SIZE_N"], + BLOCK_SIZE_K=configs[dtype]["BLOCK_SIZE_K"], + GROUP_SIZE_M=configs[dtype]["GROUP_SIZE_M"], + NUM_SMS=NUM_SMS, + num_stages=configs[dtype]["num_stages"], + num_warps=configs[dtype]["num_warps"], + OUTPUT_DTYPE=triton_out_dtype, + ROW_WISE_SCALING=ROW_WISE_SCALING, + ) + return c
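+
+
+# Minimal usage sketch (assumes an H100-class GPU and a Triton build exposing
+# the experimental TMA descriptor helpers imported above; _smoke_test is an
+# illustrative name, not part of this module's API):
+def _smoke_test(M: int = 512, K: int = 512, N: int = 512) -> None:
+    a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn)
+    # b must be column-major for the TMA kernels (see the asserts above).
+    b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn).T
+    one = torch.tensor(1.0, device="cuda", dtype=torch.float32)  # per-tensor scales
+    out = matmul_tma_persistent(a, one, b, one, torch.bfloat16)
+    # With unit scales the result should match a bf16 matmul of the fp8 values.
+    ref = torch.matmul(a.to(torch.bfloat16), b.to(torch.bfloat16))
+    torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)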