diff --git a/transformer_nuggets/mx/to_blocked.py b/transformer_nuggets/mx/to_blocked.py
index c6f55cd..4432f15 100644
--- a/transformer_nuggets/mx/to_blocked.py
+++ b/transformer_nuggets/mx/to_blocked.py
@@ -1,5 +1,4 @@
 import torch
-import torch.nn.functional as F
 
 Tensor = torch.Tensor
 
@@ -8,7 +7,7 @@ def ceil_div(a, b):
     return (a + b - 1) // b
 
 
-def to_blocked(input_matrix) -> Tensor:
+def to_blocked(input_matrix: Tensor) -> Tensor:
     """
     Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.
 
@@ -25,14 +24,21 @@ def to_blocked(input_matrix) -> Tensor:
     n_row_blocks = ceil_div(rows, 128)
     n_col_blocks = ceil_div(cols, 4)
 
-    # Pad out and view as tiles of (128, 4)
-    padded = F.pad(input_matrix, (0, -cols % 4, 0, -rows % 128))
-    blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
+    # Calculate the padded shape
+    padded_rows = n_row_blocks * 128
+    padded_cols = n_col_blocks * 4
+
+    padded = input_matrix
+    if (rows, cols) != (padded_rows, padded_cols):
+        padded = torch.zeros(
+            (padded_rows, padded_cols), device=input_matrix.device, dtype=input_matrix.dtype
+        )
+        padded[:rows, :cols] = input_matrix
 
-    # rearrange all tiles
+    # Rearrange the blocks
+    blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
     rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
 
-    # Layout rearranged tiles according to second pic
     return rearranged.flatten()
 
 
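
For review purposes, here is a minimal equivalence sketch contrasting the two padding strategies. The `to_blocked_old`, `to_blocked_new`, and `_rearrange` names are hypothetical stand-ins introduced for this sketch; their bodies mirror the pre- and post-diff states of `to_blocked` above.

```python
import torch
import torch.nn.functional as F

Tensor = torch.Tensor


def ceil_div(a, b):
    return (a + b - 1) // b


def _rearrange(padded: Tensor, n_row_blocks: int, n_col_blocks: int) -> Tensor:
    # Shared tail of both versions: view as (128, 4) tiles, then shuffle
    # each tile into the (32, 16) blocked layout and flatten.
    blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
    rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
    return rearranged.flatten()


def to_blocked_old(input_matrix: Tensor) -> Tensor:
    # Pre-diff path: F.pad with negative-modulo amounts rounds the shape
    # up to the next multiple of (128, 4), zero-filling the new cells.
    rows, cols = input_matrix.shape
    padded = F.pad(input_matrix, (0, -cols % 4, 0, -rows % 128))
    return _rearrange(padded, ceil_div(rows, 128), ceil_div(cols, 4))


def to_blocked_new(input_matrix: Tensor) -> Tensor:
    # Post-diff path: explicit zeros-plus-copy, skipped entirely when the
    # input is already tile-aligned.
    rows, cols = input_matrix.shape
    n_row_blocks = ceil_div(rows, 128)
    n_col_blocks = ceil_div(cols, 4)
    padded_rows = n_row_blocks * 128
    padded_cols = n_col_blocks * 4
    padded = input_matrix
    if (rows, cols) != (padded_rows, padded_cols):
        padded = torch.zeros(
            (padded_rows, padded_cols), device=input_matrix.device, dtype=input_matrix.dtype
        )
        padded[:rows, :cols] = input_matrix
    return _rearrange(padded, n_row_blocks, n_col_blocks)


if __name__ == "__main__":
    # (100, 3) exercises the padding branch; (256, 8) is already aligned.
    for shape in [(100, 3), (256, 8)]:
        x = torch.randn(shape)
        assert torch.equal(to_blocked_old(x), to_blocked_new(x))
    print("old and new padding paths agree")
```

Both paths zero-fill, so their outputs match; the observable difference is that the rewrite only allocates and copies when `(rows, cols)` is not already a multiple of `(128, 4)`, whereas `F.pad` was invoked unconditionally.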