From fc4565668bf91d82712364f6ad5841df1bc9eb3b Mon Sep 17 00:00:00 2001
From: drisspg
Date: Thu, 23 Jan 2025 17:58:54 -0800
Subject: [PATCH] hacking

ghstack-source-id: f8b11a94ec97e73f1754983ff1029f5ef672f1f0
Pull Request resolved: https://github.com/drisspg/transformer_nuggets/pull/40

stack-info: PR: https://github.com/drisspg/transformer_nuggets/pull/41, branch: drisspg/stack/3
---
 test/test_mx.py                      |  26 +++++++
 transformer_nuggets/mx/__init__.py   |   0
 transformer_nuggets/mx/to_blocked.py | 105 +++++++++++++++++++++++++++
 3 files changed, 131 insertions(+)
 create mode 100644 test/test_mx.py
 create mode 100644 transformer_nuggets/mx/__init__.py
 create mode 100644 transformer_nuggets/mx/to_blocked.py

diff --git a/test/test_mx.py b/test/test_mx.py
new file mode 100644
index 0000000..10768a8
--- /dev/null
+++ b/test/test_mx.py
@@ -0,0 +1,26 @@
+import pytest
+import torch
+
+from transformer_nuggets.mx.to_blocked import (
+    _to_blocked_single,
+    _to_blocked_single_manual,
+    to_blocked,
+    to_blocked_manual,
+)
+
+
+@pytest.mark.parametrize("device", ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"])
+def test_individual(device):
+    scales = torch.randint(256, size=(128, 4), device=device, dtype=torch.uint8)
+    single = _to_blocked_single(scales)
+    single_manual = _to_blocked_single_manual(scales)
+    assert torch.equal(single, single_manual)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.parametrize("shape", [(128, 4), (256, 8), (300, 9)])
+def test_kernel(shape):
+    scales = torch.randint(256, size=shape, device="cuda", dtype=torch.uint8)
+    eager = to_blocked(scales)
+    manual = to_blocked_manual(scales)
+    assert torch.equal(eager, manual)
diff --git a/transformer_nuggets/mx/__init__.py b/transformer_nuggets/mx/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/transformer_nuggets/mx/to_blocked.py b/transformer_nuggets/mx/to_blocked.py
new file mode 100644
index 0000000..bf476af
--- /dev/null
+++ b/transformer_nuggets/mx/to_blocked.py
@@ -0,0 +1,105 @@
+import torch
+import torch.nn.functional as F
+
+Tensor = torch.Tensor
+
+
+def ceil_div(a: int, b: int) -> int:
+    return (a + b - 1) // b
+
+
+def to_blocked(input_matrix: Tensor) -> Tensor:
+    """
+    Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.
+
+    See:
+    https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
+
+    Args:
+        input_matrix: Input tensor of shape (H, W)
+
+    Returns:
+        Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4))
+    """
+    rows, cols = input_matrix.shape
+    n_row_blocks = ceil_div(rows, 128)
+    n_col_blocks = ceil_div(cols, 4)
+
+    # Pad out and view as tiles of (128, 4)
+    padded = F.pad(input_matrix, (0, -cols % 4, 0, -rows % 128))
+    blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
+
+    # Rearrange each (128, 4) tile into a (32, 16) block
+    rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
+
+    # Lay out the rearranged tiles following the layout diagram in the cuBLAS docs linked above
+    return (
+        rearranged.reshape(n_row_blocks, n_col_blocks, 32, 16)
+        .permute(0, 2, 1, 3)
+        .reshape(32 * n_row_blocks, 16 * n_col_blocks)
+    )
+
+
+def _to_blocked_single_manual(scales: Tensor) -> Tensor:
+    """Slow reference implementation of `_to_blocked_single`, for testing"""
+    assert scales.shape == (128, 4)
+    scales = scales.view(-1, 32, 4)
+    output = torch.zeros(512, dtype=scales.dtype, device=scales.device).view(32, 16)
+    for i in range(4):
+        start = i * 4
+        end = start + 4
+        output[:, start:end] = scales[i, :, :]  # copy the i-th (32, 4) tile into columns 4*i:4*(i+1)
+    return output
+
+
+def _to_blocked_single(scales: Tensor) -> Tensor:
+    """Assumes a (128, 4) block of scales in K-major order.
+
+    For more information on the individual tile layout, see:
+    https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
+    """
+    assert scales.shape == (128, 4)
+    scales_tiled = scales.view(4, 32, 4)  # view as 4 (32, 4) tiles
+    return scales_tiled.transpose(0, 1).reshape(32, 16)  # interleave the tiles
+
+
+def to_blocked_manual(input_matrix: Tensor) -> Tensor:
+    """Slow reference implementation of `to_blocked`, for testing"""
+    device = input_matrix.device
+    dtype = input_matrix.dtype
+
+    rows, cols = input_matrix.shape
+
+    n_row_blocks = ceil_div(rows, 128)
+    n_col_blocks = ceil_div(cols, 4)
+
+    # Create output tensor
+    output = torch.zeros((32 * n_row_blocks, 16 * n_col_blocks), dtype=dtype, device=device)
+
+    # Process each block
+    for row_block in range(n_row_blocks):
+        for col_block in range(n_col_blocks):
+            # Calculate input block boundaries
+            row_start = row_block * 128
+            row_end = min(row_start + 128, rows)  # Avoid going out of bounds
+            col_start = col_block * 4
+            col_end = min(col_start + 4, cols)  # Avoid going out of bounds
+
+            # Calculate output block boundaries
+            out_row_start = row_block * 32
+            out_row_end = out_row_start + 32
+            out_col_start = col_block * 16
+            out_col_end = out_col_start + 16
+
+            block = input_matrix[row_start:row_end, col_start:col_end]
+
+            row_size = row_end - row_start
+            col_size = col_end - col_start
+            if row_size < 128 or col_size < 4:
+                # Pad out the local block with zeros
+                block = F.pad(block, (0, 4 - col_size, 0, 128 - row_size))
+
+            rearranged_block = _to_blocked_single(block)
+            output[out_row_start:out_row_end, out_col_start:out_col_end] = rearranged_block
+
+    return output
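
A quick sanity check of the resulting layout (an illustrative sketch against the code above, not part of the patch itself): a (256, 8) scale matrix holds a 2x2 grid of (128, 4) tiles, each tile rearranges into one (32, 16) block, and each block lands at the corresponding position in the (64, 32) output.

import torch

from transformer_nuggets.mx.to_blocked import _to_blocked_single, to_blocked

# A (256, 8) scale matrix holds a 2x2 grid of (128, 4) tiles.
scales = torch.randint(256, size=(256, 8), dtype=torch.uint8)
blocked = to_blocked(scales)
assert blocked.shape == (32 * 2, 16 * 2)  # (64, 32)

# Each (128, 4) input tile maps onto one (32, 16) output block;
# e.g. the top-left tile lands in the top-left block.
assert torch.equal(blocked[:32, :16], _to_blocked_single(scales[:128, :4]))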