diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml
index c6a406c..0eb876b 100644
--- a/.github/workflows/ruff.yml
+++ b/.github/workflows/ruff.yml
@@ -13,7 +13,7 @@ jobs:
       matrix:
         python-version: ["3.10"]
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
       - name: Set up Python ${{ matrix.python-version }}
         uses: actions/setup-python@v3
         with:
diff --git a/benchmarks/fp8_matmul.py b/benchmarks/fp8_matmul.py
index 0217fdf..6bd4bf8 100644
--- a/benchmarks/fp8_matmul.py
+++ b/benchmarks/fp8_matmul.py
@@ -14,16 +14,26 @@
     preprocess_data,
     Float8MMConfig,
 )
-from transformer_nuggets.fp8.fp8_matmul import (
-    matmul_persistent,
-    matmul_tma_persistent,
-    matmul_device_tma_persistent,
-)
+
+try:
+    from transformer_nuggets.fp8.fp8_matmul import (
+        matmul_persistent,
+        matmul_tma_persistent,
+        matmul_device_tma_persistent,
+    )
+except ModuleNotFoundError:
+    print("Triton version not new enough")
+    pass
+
 from datetime import datetime
 from enum import Enum
 import csv
+import logging
 
-torch._dynamo.config.cache_size_limit = 1000
+torch._dynamo.config.cache_size_limit = 10000
+logging.getLogger("transformer_nuggets").setLevel(logging.INFO)
+torch._inductor.config.max_autotune_gemm_backends = "TRITON"
+CHECK = False
 
 
 class FP8Kernel(Enum):
@@ -80,13 +90,14 @@ class ExperimentConfig:
     scaling_strategy: ScalingStrategy
     fp8_kernel: FP8Kernel
     compile: bool
+    bf16: bool
 
 
 @dataclass(frozen=True)
 class ExperimentResult:
-    bf16_time: float
+    bf16_time: Optional[float]
     fp8_time: float
-    bf16_tflops: float
+    bf16_tflops: Optional[float]
     fp8_tflops: float
 
 
@@ -113,29 +124,34 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
 
     if config.compile and config.fp8_kernel == FP8Kernel.SCALED_MM:
         bf16_matmul = torch.compile(bf16_matmul)
-        fp8_matmul = torch.compile(fp8_matmul, mode="max-autotune")
+        fp8_matmul = torch.compile(fp8_matmul, mode="max-autotune-no-cudagraphs", dynamic=False)
 
     # Warmup phase
    warmup_iterations = 5
     for _ in range(warmup_iterations):
-        _ = bf16_matmul(A, B)
+        if config.bf16:
+            _ = bf16_matmul(A, B)
         _ = fp8_matmul()
     torch.cuda.synchronize()
 
     # Actual benchmarking
-    bf16_time = benchmark_cuda_function_in_microseconds(lambda: bf16_matmul(A, B))
+
+    bf16_time = (
+        benchmark_cuda_function_in_microseconds(lambda: bf16_matmul(A, B)) if config.bf16 else None
+    )
     fp8_time = benchmark_cuda_function_in_microseconds(fp8_matmul)
 
     # Calculate TFLOPS
-    bf16_tflops = calculate_tflops(config.M, config.N, config.K, bf16_time)
+    bf16_tflops = calculate_tflops(config.M, config.N, config.K, bf16_time) if bf16_time else None
     fp8_tflops = calculate_tflops(config.M, config.N, config.K, fp8_time)
 
     # Baseline fp8_matmul correctness
-    scaled_mm_base = get_fp8_matmul(A, B, config.scaling_strategy, FP8Kernel.SCALED_MM)
-    out_base = scaled_mm_base()
-    out = fp8_matmul()
-    # Failing on one sample with large N
-    torch.testing.assert_close(out, out_base)
+    if CHECK:
+        scaled_mm_base = get_fp8_matmul(A, B, config.scaling_strategy, FP8Kernel.SCALED_MM)
+        out_base = scaled_mm_base()
+        out = fp8_matmul()
+        # Failing on one sample with large N
+        torch.testing.assert_close(out, out_base)
 
     return ExperimentResult(
         bf16_time=bf16_time, fp8_time=fp8_time, bf16_tflops=bf16_tflops, fp8_tflops=fp8_tflops
@@ -161,24 +177,38 @@ def print_results(experiments: List[Experiment], save_path: Optional[Path] = Non
     for experiment in experiments:
         config = experiment.config
         result = experiment.result
-        speedup = result.bf16_time / result.fp8_time
-        tflops_ratio = result.fp8_tflops / result.bf16_tflops
+
+        # Format values handling None cases
+        bf16_time = f"{result.bf16_time:.4f}" if result.bf16_time is not None else "N/A"
+        fp8_time = f"{result.fp8_time:.4f}"
+        bf16_tflops = f"{result.bf16_tflops:.2f}" if result.bf16_tflops is not None else "N/A"
+        fp8_tflops = f"{result.fp8_tflops:.2f}"
+
+        # Calculate ratios only if bf16 results exist
+        if result.bf16_time is not None:
+            speedup = f"{(result.bf16_time / result.fp8_time):.2f}x"
+            tflops_ratio = f"{(result.fp8_tflops / result.bf16_tflops):.2f}x"
+        else:
+            speedup = "N/A"
+            tflops_ratio = "N/A"
+
         rows.append(
             [
                 config.M,
                 config.K,
                 config.N,
-                config.scaling_strategy,
-                config.fp8_kernel,
+                config.scaling_strategy.value,
+                config.fp8_kernel.value,
                 config.compile,
-                f"{result.bf16_time:.4f}",
-                f"{result.fp8_time:.4f}",
-                f"{speedup:.2f}x",
-                f"{result.bf16_tflops:.2f}",
-                f"{result.fp8_tflops:.2f}",
-                f"{tflops_ratio:.2f}x",
+                bf16_time,
+                fp8_time,
+                speedup,
+                bf16_tflops,
+                fp8_tflops,
+                tflops_ratio,
             ]
         )
+
     print(tabulate(rows, headers=headers, floatfmt=".4f"))
 
     if save_path is not None:
@@ -189,16 +219,18 @@ def print_results(experiments: List[Experiment], save_path: Optional[Path] = Non
         print(f"💾 Results saved to: {save_path}")
 
 
-def get_configs_varying_k(M: int = 8192, N: int = 8192) -> List[ExperimentConfig]:
-    shapes = [(M, K, N) for K in range(512, 8193, 512)]
+def get_configs_varying_k(
+    M: int = 8192, N: int = 8192, bf16: bool = False
+) -> List[ExperimentConfig]:
+    shapes = [(M, K, N) for K in range(1024, 16385, 1024)]
     scaling_strategies = [ScalingStrategy.PER_ROW]
-    compile_options = [False]
+    compile_options = [True, False]
     configs = []
     fp8_kernels = [
         FP8Kernel.SCALED_MM,
         # FP8Kernel.PERSISTENT,
-        FP8Kernel.PERSISTENT_TMA,
-        FP8Kernel.DEVICE_TMA,
+        # FP8Kernel.PERSISTENT_TMA,
+        # FP8Kernel.DEVICE_TMA,
     ]
 
     for (M, K, N), strategy, compile, kernel in itertools.product(
@@ -206,7 +238,13 @@ def get_configs_varying_k(M: int = 8192, N: int = 8192) -> List[ExperimentConfig
     ):
         configs.append(
             ExperimentConfig(
-                M=M, K=K, N=N, scaling_strategy=strategy, compile=compile, fp8_kernel=kernel
+                M=M,
+                K=K,
+                N=N,
+                scaling_strategy=strategy,
+                compile=compile,
+                fp8_kernel=kernel,
+                bf16=bf16,
             )
         )
     return configs
@@ -214,8 +252,8 @@ def get_configs_varying_k(M: int = 8192, N: int = 8192) -> List[ExperimentConfig
 
 def load_and_process_data(file_path):
     df = pd.read_csv(file_path)
-    df["Speedup"] = df["Speedup"].str.rstrip("x").astype(float)
-    df["TFLOPS Ratio"] = df["TFLOPS Ratio"].str.rstrip("x").astype(float)
+    # df["Speedup"] = df["Speedup"].str.rstrip("x").astype(float)
+    # df["TFLOPS Ratio"] = df["TFLOPS Ratio"].str.rstrip("x").astype(float)
     return df
 
 
@@ -250,7 +288,13 @@ def plot_tflops_comparison(df, save_path: Path):
     print(f"TFLOPS comparison plot saved as {graph_path}")
 
 
-def main(save_path: Optional[str] = None, M: int = 8192, N: int = 8192, graph: bool = False):
+def main(
+    save_path: Optional[str] = None,
+    M: int = 8192,
+    N: int = 8192,
+    graph: bool = False,
+    bf_16: bool = False,
+):
     """Benchmark FP8 MatMul with different configurations and optionally graph results.
 
     Args:
@@ -258,9 +302,10 @@ def main(save_path: Optional[str] = None, M: int = 8192, N: int = 8192, graph: b
         M (int, optional): Number of rows in the first matrix. Defaults to 8192.
         N (int, optional): Number of columns in the second matrix. Defaults to 8192.
         graph_results (bool, optional): Whether to create a graph of the results. Defaults to False.
+        bf_16 (bool, optional): Whether to use BF16 for the baseline. Defaults to False.
""" torch.random.manual_seed(123) - configs = get_configs_varying_k(M, N) + configs = get_configs_varying_k(M, N, bf16=bf_16) results = [] if save_path is not None: save_path = Path(save_path) diff --git a/benchmarks/llama.py b/benchmarks/llama.py index 70d48ec..fa649ef 100644 --- a/benchmarks/llama.py +++ b/benchmarks/llama.py @@ -99,15 +99,15 @@ def forward( block_size = self.config.block_size if max_seq_length is None: max_seq_length = block_size - assert ( - T <= max_seq_length - ), f"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}" - assert ( - max_seq_length <= block_size - ), f"Cannot attend to {max_seq_length}, block size is only {block_size}" - assert ( - T <= block_size - ), f"Cannot forward sequence of length {T}, block size is only {block_size}" + assert T <= max_seq_length, ( + f"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}" + ) + assert max_seq_length <= block_size, ( + f"Cannot attend to {max_seq_length}, block size is only {block_size}" + ) + assert T <= block_size, ( + f"Cannot forward sequence of length {T}, block size is only {block_size}" + ) if self.rope_cache is None: self.rope_cache = self.build_rope_cache(idx) diff --git a/transformer_nuggets/__init__.py b/transformer_nuggets/__init__.py index b572048..99625cc 100644 --- a/transformer_nuggets/__init__.py +++ b/transformer_nuggets/__init__.py @@ -1 +1,19 @@ from transformer_nuggets import quant as quant, utils as utils + + +def init_logging(): + """ + Configure logging for transformer_nuggets library at INFO level. + Adds a StreamHandler if none exists. + """ + import logging + + logger = logging.getLogger("transformer_nuggets") + logger.setLevel(logging.INFO) + + if not logger.handlers: + handler = logging.StreamHandler() + formatter = logging.Formatter("%(name)s - %(levelname)s - %(message)s") + handler.setFormatter(formatter) + logger.addHandler(handler) + logger.propagate = False diff --git a/transformer_nuggets/flex/utils.py b/transformer_nuggets/flex/utils.py index 1c5948d..387a636 100644 --- a/transformer_nuggets/flex/utils.py +++ b/transformer_nuggets/flex/utils.py @@ -95,9 +95,9 @@ def visualize_attention_scores( Returns: None """ - assert ( - score_mod is not None or mask_mod is not None - ), "Must provide either score_mod or mask_mod" + assert score_mod is not None or mask_mod is not None, ( + "Must provide either score_mod or mask_mod" + ) query = query[batch_idx, head_idx, :, :] key = key[batch_idx, head_idx, :, :] scores_viz = create_score_mod( diff --git a/transformer_nuggets/fp8/fp8_matmul.py b/transformer_nuggets/fp8/fp8_matmul.py index 8134793..037f222 100644 --- a/transformer_nuggets/fp8/fp8_matmul.py +++ b/transformer_nuggets/fp8/fp8_matmul.py @@ -47,21 +47,21 @@ def validate_matmul_inputs( if ROW_WISE_SCALING: assert a_scale.dim() == 2, f"a_scale must be a 2D tensor but got {a_scale.dim()}" assert b_scale.dim() == 2, f"b_scale must be a 2D tensor but got {b_scale.dim()}" - assert ( - a_scale.shape[0] == a.shape[0] - ), f"a_scale must have same number of rows as a, got {a_scale.shape[0]} vs {a.shape[0]}" + assert a_scale.shape[0] == a.shape[0], ( + f"a_scale must have same number of rows as a, got {a_scale.shape[0]} vs {a.shape[0]}" + ) assert a_scale.shape[1] == 1, f"a_scale must have 1 column, got {a_scale.shape[1]}" - assert ( - b_scale.shape[1] == b.shape[1] - ), f"b_scale must have same number of columns as b, got {b_scale.shape[0]} vs {b.shape[1]}" + assert b_scale.shape[1] == b.shape[1], ( + f"b_scale must have same number of 
columns as b, got {b_scale.shape[0]} vs {b.shape[1]}" + ) assert b_scale.shape[0] == 1, f"b_scale must have 1 column, got {b_scale.shape[1]}" else: - assert ( - a_scale.numel() == 1 - ), f"a_scale must be a scalar for per-tensor scaling, got shape {a_scale.shape}" - assert ( - b_scale.numel() == 1 - ), f"b_scale must be a scalar for per-tensor scaling, got shape {b_scale.shape}" + assert a_scale.numel() == 1, ( + f"a_scale must be a scalar for per-tensor scaling, got shape {a_scale.shape}" + ) + assert b_scale.numel() == 1, ( + f"b_scale must be a scalar for per-tensor scaling, got shape {b_scale.shape}" + ) return ROW_WISE_SCALING diff --git a/transformer_nuggets/llama/train.py b/transformer_nuggets/llama/train.py index cb44a08..a1f2d35 100644 --- a/transformer_nuggets/llama/train.py +++ b/transformer_nuggets/llama/train.py @@ -401,9 +401,9 @@ def entrypoint( overfit: bool = False, profile: bool = False, ): - assert ( - isinstance(fp8_linear_type, str) or fp8_linear_type is None - ), "fp8_linear_type must be str" + assert isinstance(fp8_linear_type, str) or fp8_linear_type is None, ( + "fp8_linear_type must be str" + ) assert isinstance(compile, bool), "compile must be bool" assert isinstance(overfit, bool), "overfit must be bool" assert isinstance(profile, bool), "profile must be bool" diff --git a/transformer_nuggets/quant/dequant_kernel.py b/transformer_nuggets/quant/dequant_kernel.py index 908f771..232b5ab 100644 --- a/transformer_nuggets/quant/dequant_kernel.py +++ b/transformer_nuggets/quant/dequant_kernel.py @@ -91,9 +91,9 @@ def dequant_nf4_tensor_kernel( def dequant_nf4_tensor(weight: NF4Tensor): """Takes a quantized tensor and dequantizes it to bfloat16""" assert isinstance(weight, NF4Tensor), "Input tensor must be of type NF4Tensor" - assert ( - weight.shape.numel() % weight.block_size == 0 - ), "Input tensor must be a multiple of block size" + assert weight.shape.numel() % weight.block_size == 0, ( + "Input tensor must be a multiple of block size" + ) out_tensor = torch.empty(weight.shape, dtype=weight.dtype, device="cuda") numel = weight.shape.numel() grid = (triton.cdiv(numel, (weight.block_size)),) diff --git a/transformer_nuggets/quant/nf4_tensor.py b/transformer_nuggets/quant/nf4_tensor.py index 12c4c6d..e42cfc4 100644 --- a/transformer_nuggets/quant/nf4_tensor.py +++ b/transformer_nuggets/quant/nf4_tensor.py @@ -85,9 +85,9 @@ def get_block_absmax(inpt_tensor: torch.Tensor, block_size: int) -> torch.Tensor torch.Tensor: Tensor of scalers for each block """ assert inpt_tensor.dim() == 1, "Input tensor must be flattened" - assert ( - (inpt_tensor.numel() % block_size) == 0 - ), f"Input tensor must be divisible by block size, got {inpt_tensor.numel()} and {block_size}" + assert (inpt_tensor.numel() % block_size) == 0, ( + f"Input tensor must be divisible by block size, got {inpt_tensor.numel()} and {block_size}" + ) n_blocks = inpt_tensor.numel() // block_size blocks = inpt_tensor.view(n_blocks, block_size) @@ -167,9 +167,9 @@ def from_tensor( block_size: int = 64, scaler_block_size: int = 256, ): - assert ( - inpt_tensor.numel() % block_size == 0 - ), "Input tensor must be divisible by block size" + assert inpt_tensor.numel() % block_size == 0, ( + "Input tensor must be divisible by block size" + ) assert inpt_tensor.is_contiguous, "Input tensor must be contiguous!" 
         # I think I want do this
         assert not inpt_tensor.requires_grad, "Input tensor must not require grad"
@@ -246,9 +246,9 @@ def double_quantize_scalers(
             size: (n_scaler_blocks)
         """
         assert inpt_tensor.dim() == 1, "Input tensor must be flattened"
-        assert (
-            (inpt_tensor.numel() % scaler_block_size) == 0
-        ), f"Input tensor must be divisible by block size, got {inpt_tensor.numel()} and {scaler_block_size}"
+        assert (inpt_tensor.numel() % scaler_block_size) == 0, (
+            f"Input tensor must be divisible by block size, got {inpt_tensor.numel()} and {scaler_block_size}"
+        )
 
         # First round of quantization
         # Produces: A tensor of size (n_blocks) of inpt_tensor.dtype
@@ -256,9 +256,9 @@ def double_quantize_scalers(
         scalers_1_mean = scalers_1.mean()
         scalers_1 = scalers_1 - scalers_1_mean
         # Second round of quantization
-        assert (
-            scalers_1.numel() % scaler_block_size == 0
-        ), "Number of scalers must be divisible by scaler block size"
+        assert scalers_1.numel() % scaler_block_size == 0, (
+            "Number of scalers must be divisible by scaler block size"
+        )
         n_scaler_blocks = scalers_1.numel() // scaler_block_size
         scaler_blocks = scalers_1.view(n_scaler_blocks, scaler_block_size)
 
@@ -298,9 +298,9 @@ def dequantize_scalers(
 
         """
         assert inpt_tensor.dim() == 1, "Input tensor must be flattened"
-        assert (
-            (inpt_tensor.numel() % scaler_block_size) == 0
-        ), f"Input tensor must be divisible by block size, got {inpt_tensor.numel()} and {scaler_block_size}"
+        assert (inpt_tensor.numel() % scaler_block_size) == 0, (
+            f"Input tensor must be divisible by block size, got {inpt_tensor.numel()} and {scaler_block_size}"
+        )
         n_scaler_blocks = inpt_tensor.numel() // scaler_block_size
         inpt_tensor = inpt_tensor.view(n_scaler_blocks, scaler_block_size)
         dequantized = (inpt_tensor / quantization_factor.unsqueeze(-1)).flatten().to(
@@ -316,9 +316,9 @@ def convert_to_norm_float_weight(
         flattened_tensor = inpt_tensor.flatten()
         # Since we are using uint8 we will encode 2 entries per byte
         numel = inpt_tensor.numel()
-        assert (
-            numel % 2 == 0
-        ), "Number of elements must be even just to not have to think about the end"
+        assert numel % 2 == 0, (
+            "Number of elements must be even just to not have to think about the end"
+        )
 
         # Reshape the flattened tensor into blocks of size self.block_size
         blocks = flattened_tensor.view(n_blocks, block_size)
diff --git a/transformer_nuggets/quant/qlora_debug.py b/transformer_nuggets/quant/qlora_debug.py
index 98952bb..c454a35 100644
--- a/transformer_nuggets/quant/qlora_debug.py
+++ b/transformer_nuggets/quant/qlora_debug.py
@@ -79,9 +79,9 @@ def get_scalers(self, inpt_tensor: torch.Tensor, block_size: int) -> torch.Tenso
         return torch.tensor(block_scalers)
 
     def __init__(self, inpt_tensor: torch.Tensor, block_size=64):
-        assert (
-            inpt_tensor.numel() % block_size == 0
-        ), "Input tensor must be divisible by block size"
+        assert inpt_tensor.numel() % block_size == 0, (
+            "Input tensor must be divisible by block size"
+        )
         self.block_size = block_size
         self.n_blocks = inpt_tensor.numel() // block_size
         self.scalers = self.get_scalers(inpt_tensor, self.block_size)
@@ -94,9 +94,9 @@ def get_norm_float_weight(self, inpt_tensor: torch.Tensor) -> torch.Tensor:
         flattened_tensor = inpt_tensor.flatten()
         # Since we are using uint8 we will encode 2 entries per byte
         numel = inpt_tensor.numel()
-        assert (
-            numel % 2 == 0
-        ), "Number of elements must be even just to not have to think about the end"
+        assert numel % 2 == 0, (
+            "Number of elements must be even just to not have to think about the end"
+        )
         quantized_length = numel // 2
         quantized_tensor = torch.zeros(quantized_length, dtype=torch.uint8)
         for i in tqdm(range(len(self.scalers))):
diff --git a/transformer_nuggets/subclass.py b/transformer_nuggets/subclass.py
index fe851b6..826e9e0 100644
--- a/transformer_nuggets/subclass.py
+++ b/transformer_nuggets/subclass.py
@@ -1,6 +1,5 @@
 from dataclasses import dataclass
-from abc import ABC, abstractmethod
-from typing import Tuple, Iterable, Type, Dict, Any
+from typing import Tuple, Iterable, Type
 import torch
 import functools
 
@@ -68,91 +67,3 @@ def wrapper(f, types, args, kwargs):
         return func
 
     return decorator
-
-
-class PT2Subclass(ABC, torch.Tensor):
-    """Abstract base class for PyTorch 2.0 compliant tensor subclasses.
-
-    This class enforces implementation of required methods for proper tensor subclassing
-    while maintaining inheritance from torch.Tensor.
-
-    Required Methods:
-        __new__: Constructor for creating new instances
-        __tensor_flatten__: Method for flattening the tensor into constituent parts
-        __tensor_unflatten__: Method for reconstructing the tensor from flattened parts
-        __torch_dispatch__: Handler for tensor operations
-    """
-
-    implements = classmethod(_implements)
-
-    @staticmethod
-    @abstractmethod
-    def __new__(cls, *args, **kwargs) -> "PT2Subclass":
-        """Create a new instance of the tensor subclass.
-        I like structuring this as SubclassArgs then everything else that
-        goes on the instance
-        Example:
-            subclass = torch.Tensor._make_wrapper_subclass(
-                cls,
-                tensor_meta.original_shape,
-                tensor_meta.original_strides,
-                tensor_meta.storage_offset,
-                dtype=tensor_meta.dtype,
-                device=tensor_meta.device,
-                requires_grad=tensor_meta.requires_grad,
-            )
-            return subclass
-
-        """
-        pass
-
-    @abstractmethod
-    def __init__(self, *args, **kwargs) -> None:
-        """Initialize the tensor subclass instance."""
-        pass
-
-    @abstractmethod
-    def __tensor_flatten__(self) -> Tuple[List[str], Dict[str, Any]]:
-        """Flatten the tensor into its constituent parts.
-
-        Returns:
-            Tuple containing:
-                - List of the attributes on the subclass that are tensors
-                - Dictionary of metadata needed for reconstruction
-        """
-        pass
-
-    @staticmethod
-    @abstractmethod
-    def __tensor_unflatten__(
-        inner_tensors: Dict[str, torch.Tensor], meta: Dict[str, Any], outer_size: torch.Size, outer_stride: torch.Size
-    ) -> "PT2Subclass":
-        """Reconstruct the tensor from flattened parts.
-
-        Args:
-            inner_tensors: Dictionary mapping names to constituent tensors
-            meta: Metadata dictionary from __tensor_flatten__
-            *args, **kwargs: Additional arguments for reconstruction
-
-        Returns:
-            Reconstructed tensor subclass instance
-        """
-        pass
-
-    @classmethod
-    @abstractmethod
-    def __torch_dispatch__(
-        cls, func: Op, types: Tuple[Type, ...], args: Tuple[Any, ...], kwargs: Dict[str, Any]
-    ) -> Any:
-        """Handle tensor operations.
-
-        Args:
-            func: The operation to perform
-            types: Tuple of argument types
-            args: Positional arguments
-            kwargs: Keyword arguments
-
-        Returns:
-            Result of the operation
-        """
-        pass
diff --git a/transformer_nuggets/utils/benchmark.py b/transformer_nuggets/utils/benchmark.py
index b480fec..3722673 100644
--- a/transformer_nuggets/utils/benchmark.py
+++ b/transformer_nuggets/utils/benchmark.py
@@ -255,6 +255,7 @@ def oom_observer(device, alloc, device_alloc, device_free):
     torch._C._cuda_attach_out_of_memory_observer(oom_observer)
     torch.cuda.memory._record_memory_history(max_entries=max_entries)
 
+
 @contextmanager
 def profiler(
     path: Path,
@@ -296,7 +297,7 @@ def trace_handler(prof) -> None:
         profile_memory=profile_memory,
         with_stack=with_stack,
     )
-    
+
     try:
         profiler.start()
         yield profiler