Commit
ghstack-source-id: f8b11a94ec97e73f1754983ff1029f5ef672f1f0
Pull Request resolved: #40
stack-info: PR: #41, branch: drisspg/stack/3
Showing 3 changed files with 129 additions and 0 deletions.
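In short, the commit adds to_blocked / to_blocked_manual (plus the single-tile helpers they share) for converting a matrix of uint8 scales into the cuBLAS blocked scale-factor layout. A quick usage sketch, assuming the import path shown in the diff below (this snippet is not part of the commit):

import torch
from transformer_nuggets.mx.to_blocked import to_blocked

# One full (128, 4) tile of scales maps to a single (32, 16) output tile
scales = torch.randint(256, size=(128, 4), dtype=torch.uint8)
blocked = to_blocked(scales)
print(blocked.shape)  # torch.Size([32, 16])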
@@ -0,0 +1,25 @@
import pytest
import torch

from transformer_nuggets.mx.to_blocked import (
    _to_blocked_single,
    _to_blocked_single_manual,
    to_blocked,
    to_blocked_manual,
)


@pytest.mark.parametrize("device", ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"])
def test_individual(device):
    scales = torch.randint(256, size=(128, 4), device=device, dtype=torch.uint8)
    single = _to_blocked_single(scales)
    single_vmap = _to_blocked_single_manual(scales)
    assert torch.equal(single, single_vmap)


@pytest.mark.parametrize("shape", [(128, 4), (256, 8), (300, 9)])
def test_kernel(shape):
    scales = torch.randint(256, size=shape, device="cuda", dtype=torch.uint8)
    eager = to_blocked(scales)
    triton = to_blocked_manual(scales)
    assert torch.equal(eager, triton)
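The first test asserts that the two single-tile code paths agree. For intuition, the interleaving done by _to_blocked_single can also be checked element-wise; the snippet below is illustrative only and not part of the commit:

import torch
from transformer_nuggets.mx.to_blocked import _to_blocked_single

# Element (r, c) of a 128x4 scale block should land at row r % 32,
# column (r // 32) * 4 + c of the resulting 32x16 tile.
scales = torch.arange(128 * 4).reshape(128, 4)
blocked = _to_blocked_single(scales)
for r in range(128):
    for c in range(4):
        assert blocked[r % 32, (r // 32) * 4 + c] == scales[r, c]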
Empty file.
@@ -0,0 +1,104 @@
import torch
import torch.nn.functional as F

Tensor = torch.Tensor


def ceil_div(a, b):
    return (a + b - 1) // b


def to_blocked(input_matrix) -> Tensor:
    """
    Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.
    See:
        https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
    Args:
        input_matrix: Input tensor of shape (H, W)
    Returns:
        Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4))
    """
    rows, cols = input_matrix.shape
    n_row_blocks = ceil_div(rows, 128)
    n_col_blocks = ceil_div(cols, 4)

    # Pad out and view as tiles of (128, 4)
    padded = F.pad(input_matrix, (0, -cols % 4, 0, -rows % 128))
    blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)

    # Rearrange all tiles
    rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)

    # Lay out the rearranged tiles according to the second diagram in the cuBLAS docs
    return (
        rearranged.reshape(n_row_blocks, n_col_blocks, 32, 16)
        .permute(0, 2, 1, 3)
        .reshape(32 * n_row_blocks, 16 * n_col_blocks)
    )


def _to_blocked_single_manual(scales: Tensor) -> Tensor:
    """Slow path, for testing"""
    scales = scales.view(-1, 32, 4)
    output = torch.zeros(512, dtype=scales.dtype, device=scales.device).view(32, 16)
    for i in range(4):
        start = i * 4
        end = start + 4
        output[:, start:end] = scales[i, :, :]  # copying 32x4 blocks
    return output


def _to_blocked_single(scales: Tensor) -> Tensor:
    """Assume that we have a 128x4 block of scales in K-major order.
    For more information on the individual tile layout, see:
    https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
    """
    assert scales.shape == (128, 4)
    scales_tiled = scales.view(4, 32, 4)  # view as 4 - (32, 4) tiles
    return scales_tiled.transpose(0, 1).reshape(32, 16)  # interleave tiles


def to_blocked_manual(input_matrix) -> Tensor:
    """Slow path, for testing purposes"""
    device = input_matrix.device
    dtype = input_matrix.dtype

    rows, cols = input_matrix.shape

    n_row_blocks = ceil_div(rows, 128)
    n_col_blocks = ceil_div(cols, 4)

    # Create output tensor
    output = torch.zeros((32 * n_row_blocks, 16 * n_col_blocks), dtype=dtype, device=device)

    # Process each block
    for row_block in range(n_row_blocks):
        for col_block in range(n_col_blocks):
            # Calculate input block boundaries
            row_start = row_block * 128
            row_end = min(row_start + 128, rows)  # Avoid going out of bounds
            col_start = col_block * 4
            col_end = min(col_start + 4, cols)  # Avoid going out of bounds

            # Calculate output block boundaries
            out_row_start = row_block * 32
            out_row_end = out_row_start + 32
            out_col_start = col_block * 16
            out_col_end = out_col_start + 16

            block = input_matrix[row_start:row_end, col_start:col_end]

            row_size = row_end - row_start
            col_size = col_end - col_start
            if row_size < 128 or col_size < 4:
                # Pad out the local block with zeros
                block = torch.nn.functional.pad(block, (0, 4 - col_size, 0, 128 - row_size))

            rearranged_block = _to_blocked_single(block)
            output[out_row_start:out_row_end, out_col_start:out_col_end] = rearranged_block

    return output
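Putting the shape arithmetic together: to_blocked pads each dimension up to whole (128, 4) tiles and emits a (32, 16) tile per block, zero-filling where the input runs out. A minimal sketch of the (300, 9) case exercised by the tests, run on CPU here purely for illustration:

import torch
from transformer_nuggets.mx.to_blocked import to_blocked, to_blocked_manual

scales = torch.randint(256, size=(300, 9), dtype=torch.uint8)
blocked = to_blocked(scales)
# ceil_div(300, 128) = 3 row blocks and ceil_div(9, 4) = 3 column blocks,
# so the output is (32 * 3, 16 * 3) = (96, 48).
assert blocked.shape == (96, 48)
assert torch.equal(blocked, to_blocked_manual(scales))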