Skip to content

Commit

Permalink
hacking
Browse files Browse the repository at this point in the history
ghstack-source-id: f8b11a94ec97e73f1754983ff1029f5ef672f1f0
Pull Request resolved: #40

stack-info: PR: #41, branch: drisspg/stack/3
  • Loading branch information
drisspg committed Jan 25, 2025
1 parent 46127c6 commit fc45656
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 0 deletions.
25 changes: 25 additions & 0 deletions test/test_mx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import pytest
import torch

from transformer_nuggets.mx.to_blocked import (
_to_blocked_single,
_to_blocked_single_manual,
to_blocked,
to_blocked_manual,
)


@pytest.mark.parametrize("device", ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"])
def test_individual(device):
    """Check the single-tile rearrangement against the slow loop-based reference.

    Args:
        device: Device the random (128, 4) uint8 scale tile is created on;
            "cuda" is only included when CUDA is available.
    """
    # Build the input on the parametrized device. Hard-coding "cuda" here would
    # make the "cpu" case fail on CPU-only hosts and never exercise the CPU path.
    scales = torch.randint(256, size=(128, 4), device=device, dtype=torch.uint8)
    single = _to_blocked_single(scales)
    single_manual = _to_blocked_single_manual(scales)
    assert torch.equal(single, single_manual)


@pytest.mark.parametrize("device", ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"])
@pytest.mark.parametrize("shape", [(128, 4), (256, 8), (300, 9)])
def test_kernel(shape, device):
    """Check the vectorized ``to_blocked`` against the slow per-tile reference.

    Args:
        shape: (rows, cols) of the random uint8 scale matrix; (300, 9) covers
            the case where both dimensions need padding to whole tiles.
        device: Device the input is created on; "cuda" is only included when
            CUDA is available, so the test also runs on CPU-only hosts.
    """
    scales = torch.randint(256, size=shape, device=device, dtype=torch.uint8)
    eager = to_blocked(scales)
    manual = to_blocked_manual(scales)
    assert torch.equal(eager, manual)
Empty file.
104 changes: 104 additions & 0 deletions transformer_nuggets/mx/to_blocked.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import torch
import torch.nn.functional as F

# Short alias for torch.Tensor, used in this module's type annotations.
Tensor = torch.Tensor


def ceil_div(a, b):
    """Return ``a / b`` rounded up, for integer ``a`` and positive integer ``b``.

    Implemented with floor division only, as ``(a + b - 1) // b``.
    """
    # Adding b - 1 pushes any nonzero remainder past the next multiple of b.
    bumped = a + b - 1
    return bumped // b


def to_blocked(input_matrix) -> Tensor:
    """
    Rearrange a 2D matrix into the blocked layout, one (128, 4) tile at a time.

    The matrix is zero-padded up to whole (128, 4) tiles. Each tile is turned
    into a (32, 16) block, and the blocks are laid out in the same grid
    arrangement the tiles had in the padded input.

    See:
        https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout

    Args:
        input_matrix: Input tensor of shape (H, W)

    Returns:
        Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4))
    """
    height, width = input_matrix.shape
    tile_rows = ceil_div(height, 128)
    tile_cols = ceil_div(width, 4)

    # Amount needed to round each dimension up to a whole number of tiles.
    pad_right = -width % 4
    pad_bottom = -height % 128
    padded = F.pad(input_matrix, (0, pad_right, 0, pad_bottom))

    # Split both dimensions into (tile index, offset within tile) and bring the
    # two tile indices to the front: (tile_rows, tile_cols, 128, 4).
    tiles = padded.view(tile_rows, 128, tile_cols, 4)
    tiles = tiles.permute(0, 2, 1, 3)

    # Within every tile, cut the 128 rows into 4 groups of 32 and place the
    # groups side by side, giving one (32, 16) block per tile.
    per_tile = tiles.reshape(-1, 4, 32, 4)
    per_tile = per_tile.transpose(1, 2)
    per_tile = per_tile.reshape(-1, 32, 16)

    # Stitch the blocks back together following the original tile grid.
    grid = per_tile.reshape(tile_rows, tile_cols, 32, 16)
    grid = grid.permute(0, 2, 1, 3)
    return grid.reshape(32 * tile_rows, 16 * tile_cols)


def _to_blocked_single_manual(scales: Tensor) -> Tensor:
"""Slow for testing"""
scales = scales.view(-1, 32, 4)
output = torch.zeros(512, dtype=scales.dtype, device=scales.device).view(32, 16)
for i in range(4):
start = i * 4
end = start + 4
output[:, start:end] = scales[i, :, :] # copying 32x4 blocks
return output


def _to_blocked_single(scales: Tensor) -> Tensor:
"""Assume that we have a 128x4 block of scales in K Major order
To see more information on the individual tile layout:
https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
"""
assert scales.shape == (128, 4)
scales_tiled = scales.view(4, 32, 4) # view as 4 - (32, 4) tiles
return scales_tiled.transpose(0, 1).reshape(32, 16) # Interleave tiles


def to_blocked_manual(input_matrix) -> Tensor:
    """Slow for testing purposes.

    Walks the input one (128, 4) tile at a time, zero-pads tiles that run past
    the edge of the matrix, rearranges each tile with ``_to_blocked_single``,
    and writes the resulting (32, 16) block into the matching output position.

    Args:
        input_matrix: Input tensor of shape (H, W)

    Returns:
        Tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4)) with the same
        dtype and device as the input.
    """
    rows, cols = input_matrix.shape
    n_row_blocks = ceil_div(rows, 128)
    n_col_blocks = ceil_div(cols, 4)

    output = torch.zeros(
        (32 * n_row_blocks, 16 * n_col_blocks),
        dtype=input_matrix.dtype,
        device=input_matrix.device,
    )

    for rb in range(n_row_blocks):
        # Source rows are clamped so the last tile never reads out of bounds.
        src_rows = slice(128 * rb, min(128 * rb + 128, rows))
        dst_rows = slice(32 * rb, 32 * rb + 32)

        for cb in range(n_col_blocks):
            src_cols = slice(4 * cb, min(4 * cb + 4, cols))
            dst_cols = slice(16 * cb, 16 * cb + 16)

            tile = input_matrix[src_rows, src_cols]

            # Edge tiles can be smaller than (128, 4); fill them up with zeros.
            missing_rows = 128 - tile.shape[0]
            missing_cols = 4 - tile.shape[1]
            if missing_rows > 0 or missing_cols > 0:
                tile = torch.nn.functional.pad(tile, (0, missing_cols, 0, missing_rows))

            output[dst_rows, dst_cols] = _to_blocked_single(tile)

    return output

0 comments on commit fc45656

Please sign in to comment.