
Support e8m0
drisspg committed Feb 22, 2025
1 parent 2d17686 commit f234d14
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions transformer_nuggets/mx/to_blocked.py
@@ -1,5 +1,4 @@
 import torch
-import torch.nn.functional as F
 
 Tensor = torch.Tensor
 
@@ -8,7 +7,7 @@ def ceil_div(a, b):
     return (a + b - 1) // b
 
 
-def to_blocked(input_matrix) -> Tensor:
+def to_blocked(input_matrix: Tensor) -> Tensor:
     """
     Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.
@@ -25,14 +24,21 @@ def to_blocked(input_matrix) -> Tensor:
     n_row_blocks = ceil_div(rows, 128)
     n_col_blocks = ceil_div(cols, 4)
 
-    # Pad out and view as tiles of (128, 4)
-    padded = F.pad(input_matrix, (0, -cols % 4, 0, -rows % 128))
-    blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
+    # Calculate the padded shape
+    padded_rows = n_row_blocks * 128
+    padded_cols = n_col_blocks * 4
+
+    padded = input_matrix
+    if (rows, cols) != (padded_rows, padded_cols):
+        padded = torch.zeros(
+            (padded_rows, padded_cols), device=input_matrix.device, dtype=input_matrix.dtype
+        )
+        padded[:rows, :cols] = input_matrix
 
-    # rearrange all tiles
+    # Rearrange the blocks
+    blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
     rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
 
     # Layout rearranged tiles according to second pic
     return rearranged.flatten()


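Note on the change: n_row_blocks * 128 and n_col_blocks * 4 give the same padded shape as the old F.pad(input_matrix, (0, -cols % 4, 0, -rows % 128)) call; the difference is that padding is now done with torch.zeros plus a slice copy, presumably because F.pad does not handle the e8m0 float8 scale dtype this commit targets. Below is a minimal usage sketch, not part of the commit, assuming a PyTorch build that exposes torch.float8_e8m0fnu and supports torch.zeros and slice assignment for that dtype; the shapes are arbitrary examples.

# Minimal usage sketch (not part of this commit). Assumes a PyTorch build that
# exposes torch.float8_e8m0fnu and supports torch.zeros / slice assignment for it.
import torch

from transformer_nuggets.mx.to_blocked import to_blocked

# A 200 x 10 matrix of e8m0 scales: draw raw exponent bytes as uint8, then
# reinterpret them as float8_e8m0fnu (both dtypes are one byte per element).
scales = torch.randint(0, 255, (200, 10), dtype=torch.uint8).view(torch.float8_e8m0fnu)

blocked = to_blocked(scales)

# ceil(200 / 128) * 128 = 256 padded rows and ceil(10 / 4) * 4 = 12 padded cols,
# so the flattened blocked layout holds 256 * 12 = 3072 scale values.
assert blocked.numel() == 256 * 12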
