Clean-up todos and other comments
KyleHerndon committed Sep 26, 2024
1 parent fa79dc9 commit b293169
Showing 2 changed files with 8 additions and 16 deletions.
19 changes: 8 additions & 11 deletions sharktank/sharktank/kernels/einsum_2args_q4.py
@@ -66,8 +66,7 @@ def einsum_util(einsum_str):
else:
out_dyn_dim_size_str += "%b" + str(es_in1.find(c)) + ","
else:
printf("invalid einsum string")
exit(1)
raise Exception("Invalid einsum string")
out_dyn_dim_size_str = out_dyn_dim_size_str[:-1]
return (
(in0_idx, in1_idx, out_idx),
@@ -79,17 +78,16 @@ def einsum_util(einsum_str):

@CustomOp.register(library=LIBRARY)
class einsum_2args_q4(CustomOp):
"""Generic block scaled matmul with transposed RHS.
"""Einsum that takes two tensor inputs and returns one tensor.
This corresponds to the BlockScaledLayout and operates on planar `d`
and `qs` tensors as specified there:
The first input is expected to be a normal tensor.
* `d`: `[N, K // BLOCK_SIZE, 1]`
* `qs`: `[N, K // BLOCK_SIZE, BLOCK_SIZE // 2]` (of uint8)
* `m`: `[N, K // BLOCK_SIZE, 1]`
The second input corresponds to the BlockScaledLayout and operates on planar `d`
and `qs` tensors as specified there:
The LHS is expected to be a 3d tensor of shape [B, M, K]. The kernel
will be specialized for all values of N, K and LHS dtype.
* `d`: `[..., K // BLOCK_SIZE, 1]`
* `qs`: `[..., K // BLOCK_SIZE, BLOCK_SIZE // 2]` (of uint8)
* `m`: `[..., K // BLOCK_SIZE, 1]`
"""

signature = (
@@ -178,7 +176,6 @@ def select(self, ksel: KernelSelection):
elif pos1 >= 0:
out_dims.append(b_dim)
else:
# TODO: I'm not sure if einsum notation supports broadcast in outputs, disabling it for now
torch._check(
False,
lambda: f"einsum_2args_q4 arg 'einsum_str': output indices must be in input indices (got '{einsum_str}')",
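For reference, the revised docstring above describes the second input in the planar block-scaled layout: per-block scales `d`, packed 4-bit values `qs`, and per-block offsets `m`. The sketch below is only a minimal illustration of those shapes and of unpacking the packed nibbles; the example dimensions, the low/high nibble order, and the affine `d * q + m` dequantization are assumptions for illustration, not taken from this commit.

import torch

# Hypothetical example dimensions (illustration only).
N, K, BLOCK_SIZE = 4, 64, 32
blocks = K // BLOCK_SIZE

# Planar tensors with the shapes listed in the docstring.
d = torch.rand(N, blocks, 1)   # per-block scale
m = torch.rand(N, blocks, 1)   # per-block offset
qs = torch.randint(0, 256, (N, blocks, BLOCK_SIZE // 2), dtype=torch.uint8)  # packed 4-bit values

# Unpack two 4-bit values from each uint8 (nibble order is an assumption).
lo = (qs & 0x0F).to(d.dtype)
hi = (qs >> 4).to(d.dtype)
q = torch.stack([lo, hi], dim=-1).reshape(N, blocks, BLOCK_SIZE)

# Assumed affine dequantization: per-block scale * quantized + offset.
dequant = d * q + m                 # [N, K // BLOCK_SIZE, BLOCK_SIZE]
unblocked = dequant.reshape(N, K)   # collapse blocks back to a contiguous K
print(unblocked.shape)              # torch.Size([4, 64])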
5 changes: 0 additions & 5 deletions sharktank/sharktank/kernels/templates/einsum_2args_q4.mlir
@@ -27,7 +27,6 @@ util.func private @sharktank_einsum_2args_q4_{{es_name}}_{{bs}}_{{a_type}}(
-> !c_tensor_type {
%debug = tensor.empty() : tensor<1xf32>
%zero = arith.constant 0.0: !accum_type
// todo: loop
{% for i in range(a_size) %}
%k{{i}} = arith.constant {{i}} : index
{% endfor %}
@@ -47,7 +46,6 @@ util.func private @sharktank_einsum_2args_q4_{{es_name}}_{{bs}}_{{a_type}}(
%qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type{{"{"}}{% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b{{b_size-1}}{{"}"}} -> !qs_tensor_type{{"{"}}{% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b{{b_size-1}}{{"}"}}

// Dequantize.
// todo: loop
%b_grouped = tensor.empty({% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b{{b_size-1}}) : !b_grouped_tensor_type
%b_grouped_dequant = linalg.generic {
indexing_maps = [
@@ -71,11 +69,9 @@ util.func private @sharktank_einsum_2args_q4_{{es_name}}_{{bs}}_{{a_type}}(
} -> !b_grouped_tensor_type

// Collapse %b to the same unblocked structure.
// todo: loop
%b_unblocked = tensor.collapse_shape %b_grouped_dequant [{% for i in range(b_size-1) %}[{{i}}], {% endfor %}[{{b_size-1}}, {{b_size}}]] : !b_grouped_tensor_type into !b_tensor_type

// Einsum
// todo: loop, right dimensions
%result_empty = tensor.empty({{out_dyn_dim_size_str}}) : !accum_tensor_type
%result_fill = linalg.fill ins(%zero: !accum_type) outs(%result_empty: !accum_tensor_type) -> !accum_tensor_type
%result = linalg.generic {
@@ -96,7 +92,6 @@ util.func private @sharktank_einsum_2args_q4_{{es_name}}_{{bs}}_{{a_type}}(
} -> !accum_tensor_type

// Cast.
// todo: loop, right dimensions
%result_cast_empty = tensor.empty({{out_dyn_dim_size_str}}) : !c_tensor_type
%result_cast = linalg.copy
ins(%result : !accum_tensor_type)
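With the TODO comments removed, the template above still follows one fixed pipeline: bitcast and dequantize the blocked second operand, collapse the block dimension, contract per the einsum string into a wider accumulator, and cast to the output type. The torch sketch below mirrors that pipeline as a rough reference only; the function signature, nibble order, and `d * q + m` dequantization are assumptions for illustration, not the shipped kernel.

import torch

def einsum_2args_q4_reference(a, d, qs, m, einsum_str, block_size=32):
    # Rough reference for the pipeline above; an illustration, not the shipped kernel.
    # Dequantize: unpack packed 4-bit values, then apply the per-block scale/offset
    # (the nibble order and the d * q + m formula are assumptions).
    lo = (qs & 0x0F).to(d.dtype)
    hi = (qs >> 4).to(d.dtype)
    q = torch.stack([lo, hi], dim=-1).reshape(*qs.shape[:-1], block_size)
    b_grouped = d * q + m                             # [..., K // block_size, block_size]

    # Collapse the trailing block dimensions back into a single K dimension.
    b = b_grouped.reshape(*b_grouped.shape[:-2], -1)  # [..., K]

    # Einsum in a wider accumulator dtype, then cast back to the LHS dtype.
    accum = torch.einsum(einsum_str, a.to(torch.float32), b.to(torch.float32))
    return accum.to(a.dtype)

# Hypothetical "bmk,bnk->bmn" contraction with made-up shapes.
B, M, N, K, BS = 2, 8, 16, 64, 32
a = torch.rand(B, M, K, dtype=torch.float16)
d = torch.rand(B, N, K // BS, 1)
m = torch.rand(B, N, K // BS, 1)
qs = torch.randint(0, 256, (B, N, K // BS, BS // 2), dtype=torch.uint8)
print(einsum_2args_q4_reference(a, d, qs, m, "bmk,bnk->bmn", block_size=BS).shape)  # torch.Size([2, 8, 16])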
