
remove bad format
drisspg committed Jan 28, 2025
1 parent ade059b · commit 382cb0f
Showing 2 changed files with 0 additions and 83 deletions.
10 changes: 0 additions & 10 deletions test/test_mx.py
@@ -5,9 +5,7 @@
_to_blocked_single,
_to_blocked_single_manual,
to_blocked,
to_blocked_v2,
to_blocked_manual,
to_blocked_manual_v2,
)


@@ -25,11 +23,3 @@ def test_rearrange(shape):
eager = to_blocked(scales)
manual = to_blocked_manual(scales)
assert torch.equal(eager, manual)


@pytest.mark.parametrize("shape", [(128, 4), (256, 8), (300, 9)])
def test_rearrange_v2(shape):
scales = torch.randint(256, size=shape, device="cuda", dtype=torch.uint8)
eager = to_blocked_v2(scales)
manual = to_blocked_manual_v2(scales)
assert torch.equal(eager, manual)
73 changes: 0 additions & 73 deletions transformer_nuggets/mx/to_blocked.py
@@ -32,37 +32,6 @@ def to_blocked(input_matrix) -> Tensor:
# rearrange all tiles
rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)

return (
rearranged.reshape(n_row_blocks, n_col_blocks, 32, 16)
.permute(0, 2, 1, 3)
.reshape(32 * n_row_blocks, 16 * n_col_blocks)
)


def to_blocked_v2(input_matrix) -> Tensor:
"""
Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.
See:
https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
Args:
input_matrix: Input tensor of shape (H, W)
Returns:
Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4))
"""
rows, cols = input_matrix.shape
n_row_blocks = ceil_div(rows, 128)
n_col_blocks = ceil_div(cols, 4)

# Pad out and view as tiles of (128, 4)
padded = F.pad(input_matrix, (0, -cols % 4, 0, -rows % 128))
blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)

# rearrange all tiles
rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)

# Layout rearranged tiles according to second pic
return rearranged.flatten()
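As a side note on the shape formula in the docstring above (to_blocked keeps this 2D layout, while the removed v2 variant flattens it), here is a minimal sketch of the arithmetic for the oddly sized (300, 9) case from the test parametrization; it assumes ceil_div is the usual ceiling-division helper used in this module.

def ceil_div(a: int, b: int) -> int:
    # assumed behavior of the module's ceil_div helper
    return -(a // -b)

rows, cols = 300, 9                 # smallest non-multiple case in the tests
n_row_blocks = ceil_div(rows, 128)  # 3
n_col_blocks = ceil_div(cols, 4)    # 3
# Padded, blocked output shape: (32 * 3, 16 * 3) == (96, 48)
print(32 * n_row_blocks, 16 * n_col_blocks)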

@@ -99,48 +68,6 @@ def to_blocked_manual(input_matrix) -> Tensor:
n_row_blocks = ceil_div(rows, 128)
n_col_blocks = ceil_div(cols, 4)

# Create output tensor
output = torch.zeros((32 * n_row_blocks, 16 * n_col_blocks), dtype=dtype, device=device)

# Process each block
for row_block in range(n_row_blocks):
for col_block in range(n_col_blocks):
# Calculate input block boundaries
row_start = row_block * 128
row_end = min(row_start + 128, rows) # Avoid going out of bounds
col_start = col_block * 4
col_end = min(col_start + 4, cols) # Avoid going out of bounds

# Calculate output block boundaries
out_row_start = row_block * 32
out_row_end = out_row_start + 32
out_col_start = col_block * 16
out_col_end = out_col_start + 16

block = input_matrix[row_start:row_end, col_start:col_end]

row_size = row_end - row_start
col_size = col_end - col_start
if row_size < 128 or col_size < 4:
# pad out local block with zeros
block = torch.nn.functional.pad(block, (0, 4 - col_size, 0, 128 - row_size))

rearranged_block = _to_blocked_single(block)
output[out_row_start:out_row_end, out_col_start:out_col_end] = rearranged_block

return output


def to_blocked_manual_v2(input_matrix) -> Tensor:
"""Slow for testing purposes"""
device = input_matrix.device
dtype = input_matrix.dtype

rows, cols = input_matrix.shape

n_row_blocks = ceil_div(rows, 128)
n_col_blocks = ceil_div(cols, 4)

# Create output tensor
output = torch.zeros(512 * n_row_blocks * n_col_blocks, dtype=dtype, device=device)
# output = torch.zeros((32 * n_row_blocks, 16 * n_col_blocks), dtype=dtype, device=device)
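For context, a hedged usage sketch of the code path that survives this commit, mirroring the remaining test_rearrange test. It assumes the import path transformer_nuggets.mx.to_blocked (the file changed here) and runs on CPU rather than the CUDA device the test requests; this is only an illustration, not the project's documented API.

import torch
from transformer_nuggets.mx.to_blocked import to_blocked, to_blocked_manual

# Hypothetical standalone usage; the test runs this on CUDA with the same dtype.
scales = torch.randint(256, size=(300, 9), dtype=torch.uint8)

eager = to_blocked(scales)          # vectorized reshape/permute path
manual = to_blocked_manual(scales)  # slow per-block reference used for testing
assert torch.equal(eager, manual)
print(eager.shape)                  # torch.Size([96, 48]) per the blocked-layout shape formula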
