From 382cb0f19a5f615827174289b8ef552419d51fea Mon Sep 17 00:00:00 2001
From: drisspg
Date: Tue, 28 Jan 2025 11:54:21 -0800
Subject: [PATCH] remove bad format

---
 test/test_mx.py                      | 10 ----
 transformer_nuggets/mx/to_blocked.py | 73 ----------------------------
 2 files changed, 83 deletions(-)

diff --git a/test/test_mx.py b/test/test_mx.py
index a0195e9..2ed0570 100644
--- a/test/test_mx.py
+++ b/test/test_mx.py
@@ -5,9 +5,7 @@
     _to_blocked_single,
     _to_blocked_single_manual,
     to_blocked,
-    to_blocked_v2,
     to_blocked_manual,
-    to_blocked_manual_v2,
 )
 
 
@@ -25,11 +23,3 @@ def test_rearrange(shape):
     eager = to_blocked(scales)
     manual = to_blocked_manual(scales)
     assert torch.equal(eager, manual)
-
-
-@pytest.mark.parametrize("shape", [(128, 4), (256, 8), (300, 9)])
-def test_rearrange_v2(shape):
-    scales = torch.randint(256, size=shape, device="cuda", dtype=torch.uint8)
-    eager = to_blocked_v2(scales)
-    manual = to_blocked_manual_v2(scales)
-    assert torch.equal(eager, manual)
diff --git a/transformer_nuggets/mx/to_blocked.py b/transformer_nuggets/mx/to_blocked.py
index 44461ac..c6f55cd 100644
--- a/transformer_nuggets/mx/to_blocked.py
+++ b/transformer_nuggets/mx/to_blocked.py
@@ -32,37 +32,6 @@ def to_blocked(input_matrix) -> Tensor:
     # rearrange all tiles
     rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
 
-    return (
-        rearranged.reshape(n_row_blocks, n_col_blocks, 32, 16)
-        .permute(0, 2, 1, 3)
-        .reshape(32 * n_row_blocks, 16 * n_col_blocks)
-    )
-
-
-def to_blocked_v2(input_matrix) -> Tensor:
-    """
-    Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.
-
-    See:
-        https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
-
-    Args:
-        input_matrix: Input tensor of shape (H, W)
-
-    Returns:
-        Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4))
-    """
-    rows, cols = input_matrix.shape
-    n_row_blocks = ceil_div(rows, 128)
-    n_col_blocks = ceil_div(cols, 4)
-
-    # Pad out and view as tiles of (128, 4)
-    padded = F.pad(input_matrix, (0, -cols % 4, 0, -rows % 128))
-    blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
-
-    # rearrange all tiles
-    rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
-
     # Layout rearranged tiles according to second pic
     return rearranged.flatten()
 
@@ -99,48 +68,6 @@ def to_blocked_manual(input_matrix) -> Tensor:
     n_row_blocks = ceil_div(rows, 128)
     n_col_blocks = ceil_div(cols, 4)
 
-    # Create output tensor
-    output = torch.zeros((32 * n_row_blocks, 16 * n_col_blocks), dtype=dtype, device=device)
-
-    # Process each block
-    for row_block in range(n_row_blocks):
-        for col_block in range(n_col_blocks):
-            # Calculate input block boundaries
-            row_start = row_block * 128
-            row_end = min(row_start + 128, rows)  # Avoid going out of bounds
-            col_start = col_block * 4
-            col_end = min(col_start + 4, cols)  # Avoid going out of bounds
-
-            # Calculate output block boundaries
-            out_row_start = row_block * 32
-            out_row_end = out_row_start + 32
-            out_col_start = col_block * 16
-            out_col_end = out_col_start + 16
-
-            block = input_matrix[row_start:row_end, col_start:col_end]
-
-            row_size = row_end - row_start
-            col_size = col_end - col_start
-            if row_size < 128 or col_size < 4:
-                # pad out local block with zeros
-                block = torch.nn.functional.pad(block, (0, 4 - col_size, 0, 128 - row_size))
-
-            rearranged_block = _to_blocked_single(block)
-            output[out_row_start:out_row_end, out_col_start:out_col_end] = rearranged_block
-
-    return output
-
-
-def to_blocked_manual_v2(input_matrix) -> Tensor:
-    """Slow for testing purposes"""
-    device = input_matrix.device
-    dtype = input_matrix.dtype
-
-    rows, cols = input_matrix.shape
-
-    n_row_blocks = ceil_div(rows, 128)
-    n_col_blocks = ceil_div(cols, 4)
-
     # Create output tensor
     output = torch.zeros(512 * n_row_blocks * n_col_blocks, dtype=dtype, device=device)
     # output = torch.zeros((32 * n_row_blocks, 16 * n_col_blocks), dtype=dtype, device=device)
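
Note (illustrative, not part of the patch): after this change, to_blocked keeps
only the padded-view code path that previously lived in to_blocked_v2. It
returns a flat tensor of 512 * n_row_blocks * n_col_blocks scale entries, per
the cuBLAS block-scaling-factors layout the removed docstring links to, rather
than the 2D (32 * n_row_blocks, 16 * n_col_blocks) view that is deleted above.
A minimal standalone sketch of the surviving function, assuming ceil_div is
the usual (a + b - 1) // b helper defined elsewhere in the module:

    import torch
    import torch.nn.functional as F
    from torch import Tensor

    def ceil_div(a: int, b: int) -> int:
        # Assumed definition, inferred from how the module uses it.
        return (a + b - 1) // b

    def to_blocked(input_matrix: Tensor) -> Tensor:
        rows, cols = input_matrix.shape
        n_row_blocks = ceil_div(rows, 128)
        n_col_blocks = ceil_div(cols, 4)

        # Pad up to a multiple of (128, 4), then view as (128, 4) tiles.
        padded = F.pad(input_matrix, (0, -cols % 4, 0, -rows % 128))
        blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)

        # Rearrange each (128, 4) tile into a (32, 16) block.
        rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)

        # Lay the blocks out contiguously: 512 entries per input tile.
        return rearranged.flatten()

    # Example: a ragged (300, 9) scale matrix pads out to 3 x 3 tiles.
    scales = torch.randint(256, size=(300, 9), dtype=torch.uint8)
    assert to_blocked(scales).numel() == 512 * 3 * 3

The flat output matches the layout that the retained context line in
to_blocked_manual allocates (torch.zeros(512 * n_row_blocks * n_col_blocks)),
which is what the remaining test_rearrange test compares against.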