[sharktank] restore custom matmul kernel #896

Merged: 35 commits, Feb 21, 2025
Changes shown from 27 commits

Commits (35)
8300bc8
restore custom matmul kernel
dan-garvey Feb 2, 2025
a98a332
not mergeable as-is
dan-garvey Feb 2, 2025
80fee98
Make batch_matmul_transpose_b accept accumulation dtype
sogartar Feb 4, 2025
9ff020e
Merge batch_matmul_transpose_b export tests into 1
sogartar Feb 4, 2025
7b93a6a
Merge branch 'main' into users/dan-garvey/enable_custom_fp8_matmul
dan-garvey Feb 7, 2025
781c8e8
Add exception to qlinear to not use the kernel when unsigned ints
sogartar Feb 8, 2025
9f1c3d4
Small fix
sogartar Feb 8, 2025
82b032a
Add eager execution to circumvent failure to compile for llvm-cpu
sogartar Feb 10, 2025
aa5c7b0
Merge branch 'main' into users/dan-garvey/enable_custom_fp8_matmul
sogartar Feb 11, 2025
de70094
Convert dtype when writing into the cache
sogartar Feb 11, 2025
ae89b55
Fix attention_dtype flag for paged_llm_v1
aviator19941 Feb 13, 2025
53f8cd1
Merge branch 'main' into users/dan-garvey/enable_custom_fp8_matmul
sogartar Feb 13, 2025
5bf4636
KV cache workaround for Torch not supporting torch.Tensor.index_copy_…
sogartar Feb 14, 2025
fe5c881
Fix kv_cache index_put_ issue
archana-ramalingam Feb 14, 2025
b4be2a8
Revert "Fix kv_cache index_put_ issue"
archana-ramalingam Feb 14, 2025
338fe67
Merge branch 'main' into users/dan-garvey/enable_custom_fp8_matmul
archana-ramalingam Feb 14, 2025
462ddc4
Fix KV cache index_copy_ f8 workaround
sogartar Feb 14, 2025
4dc2ac2
In linear for (Tensor, QuantizedTensor) raise if accum_dtype is given
sogartar Feb 14, 2025
13bfc68
Fix KV cache f8
sogartar Feb 20, 2025
8b20445
Remove unused HF dataset
sogartar Feb 20, 2025
9f160f1
Add KV cache dtype different from attention dtype
sogartar Feb 20, 2025
77a8443
Add more KV cache tests for various dtypes
sogartar Feb 20, 2025
1ea608a
Remove some unwanted corner case handling in linear layer
sogartar Feb 20, 2025
6f0c98b
Add more linear layer tests
sogartar Feb 20, 2025
664a847
Refactor quark parity test to use tmp dir
sogartar Feb 20, 2025
9816c35
Fix KV cache dtype CLI arg parsing
sogartar Feb 20, 2025
b8ff8cc
Merge remote-tracking branch 'origin/main' into users/dan-garvey/enab…
sogartar Feb 20, 2025
c17629e
Change doc example to not use the removed Llama dataset
sogartar Feb 20, 2025
9b7dfdf
Add KV cache dtype to benchmark
sogartar Feb 20, 2025
55c8701
Change testBenchmark8B_fp8_Non_Decomposed xfail reason to compilation…
sogartar Feb 21, 2025
fea5204
Merge branch 'main' into users/dan-garvey/enable_custom_fp8_matmul
archana-ramalingam Feb 21, 2025
740bb80
Put back in the llama3_8B_fp16 HF dataset
sogartar Feb 21, 2025
02215d5
Remove left behind comment
sogartar Feb 21, 2025
40f993a
Make quark parity test use f8 KV cache
sogartar Feb 21, 2025
85053ef
Add more bf16 qlinear tests and make ref dtype be f64
sogartar Feb 21, 2025
1 change: 1 addition & 0 deletions sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -83,6 +83,7 @@ def main():
block_seq_stride=args.block_seq_stride,
activation_dtype=args.activation_dtype,
attention_dtype=args.attention_dtype,
kv_cache_dtype=args.kv_cache_dtype,
)
llama_config.fake_quant = args.fake_quant

3 changes: 2 additions & 1 deletion sharktank/sharktank/examples/paged_llm_v1.py
@@ -288,8 +288,9 @@ def main():
block_seq_stride=args.block_seq_stride,
device=device,
activation_dtype=args.activation_dtype,
attention_dtype=args.activation_dtype,
attention_dtype=args.attention_dtype,
attention_kernel=args.attention_kernel,
kv_cache_dtype=args.kv_cache_dtype,
use_hf=args.use_hf,
tensor_parallelism_size=args.tensor_parallelism_size,
fake_quant=args.fake_quant,
42 changes: 36 additions & 6 deletions sharktank/sharktank/kernels/batch_matmul_transpose_b.py
@@ -7,16 +7,36 @@
from sharktank.kernels.base import *

import torch
from typing import cast, Optional

from iree.compiler.ir import IntegerType
from iree.compiler.ir import IntegerType, Type
from iree.turbine.support.conversions import (
TORCH_DTYPE_TO_IREE_TYPE_ASM,
IREE_TYPE_ASM_TO_TORCH_DTYPE,
)
from iree.turbine.runtime.op_reg import AttrArg

__all__ = [
"batch_matmul_transpose_b",
]


def batch_matmul_transpose_b(
lhs: torch.Tensor,
rhs: torch.Tensor,
/,
*,
accum_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
if accum_dtype is None:
accum_dtype = lhs.dtype
return _batch_matmul_transpose_b(
lhs, rhs, accum_dtype=TORCH_DTYPE_TO_IREE_TYPE_ASM[accum_dtype]
)


@CustomOp.register(library=LIBRARY)
class batch_matmul_transpose_b(CustomOp):
class _batch_matmul_transpose_b(CustomOp):
"""Generic block scaled matmul with transposed RHS.
The LHS is expected to be a 3d tensor of shape [B, M, K]. RHS must be
@@ -25,11 +45,18 @@ class batch_matmul_transpose_b(CustomOp):
The kernel will be specialized for all values of N, K and LHS dtype.
"""

signature = "batch_matmul_transpose_b(Tensor lhs, Tensor rhs) -> (Tensor)"
signature = (
"batch_matmul_transpose_b(Tensor lhs, Tensor rhs, str accum_dtype) -> (Tensor)"
)

def eager_execute(self, lhs: torch.Tensor, rhs: torch.Tensor, accum_dtype: str):
dtype = IREE_TYPE_ASM_TO_TORCH_DTYPE[accum_dtype]
return torch.matmul(lhs.to(dtype=dtype), rhs.transpose(-1, -2).to(dtype=dtype))

def select(self, ksel: KernelSelection):
lhs_desc = ksel.arg_tensor(0) # Shape [B, M, K]
rhs_desc = ksel.arg_tensor(1) # Shape [B, N, K]
accum_type_attr = ksel.attr_str(2)

# Rank check.
torch._check(
@@ -60,7 +87,8 @@ def select(self, ksel: KernelSelection):
)
# Shape batch, m, n
c_desc = ksel.return_new_tensor(
[lhs_batch, lhs_m, rhs_n], dtype=lhs_desc.t.dtype
[lhs_batch, lhs_m, rhs_n],
dtype=IREE_TYPE_ASM_TO_TORCH_DTYPE[accum_type_attr.v],
)
specialize_all_known_dims(lhs_desc)
specialize_all_known_dims(rhs_desc)
@@ -74,12 +102,14 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder):
def generate(self, ksel: KernelSelection, kb: KernelBuilder):
lhs = kb.arg_value(0)
rhs = kb.arg_value(1)
accum_type_str = cast(AttrArg, ksel.arg_descs[2]).v
result_desc = ksel.result_descs[0]

# Generate specialization signature and types.
a_asm_type, a_ident, accum_type = unpack_tensor_type(lhs.type)
a_asm_type, a_ident, _ = unpack_tensor_type(lhs.type)
b_asm_type, b_ident, _ = unpack_tensor_type(rhs.type)
spec_sig = f"L{a_ident}_R{b_ident}"
accum_type = Type.parse(accum_type_str)
spec_sig = f"L{a_ident}_R{b_ident}_{accum_type_str}"
template_file = "batch_matmul_transpose_b.mlir"
target_function_name = f"sharktank_batch_matmul_transpose_b_{spec_sig}"
cst_zero = "0" if IntegerType.isinstance(accum_type) else "0."
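For context, a minimal sketch of how the new wrapper could be called once this diff lands. The shapes and dtypes are illustrative assumptions, the import path is inferred from how qlinear_impls.py uses the kernel, and the eager result relies on the eager_execute path shown above:

import torch
from sharktank import kernels

# Hypothetical shapes: the kernel expects LHS of shape [B, M, K] and a
# transposed RHS of shape [B, N, K].
lhs = torch.randn(4, 16, 32, dtype=torch.float16)
rhs = torch.randn(4, 8, 32, dtype=torch.float16)

# Request a wider accumulation dtype; when accum_dtype is omitted, the
# wrapper defaults it to lhs.dtype.
y = kernels.batch_matmul_transpose_b(lhs, rhs, accum_dtype=torch.float32)
assert y.shape == (4, 16, 8) and y.dtype == torch.float32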
3 changes: 3 additions & 0 deletions sharktank/sharktank/layers/configs/llm_configs.py
@@ -158,6 +158,9 @@ class LlamaModelConfig:
# Either "paged" or "direct".
kv_cache_type: str = "paged"

# If None will use attention_dtype.
kv_cache_dtype: Optional[torch.dtype] = None

# The device on which to place intermediate state.
device: Optional[torch.device] = None

18 changes: 16 additions & 2 deletions sharktank/sharktank/layers/kv_cache.py
@@ -275,7 +275,14 @@ def write_timestep(
partitions = partitions.repeat(bs, 1)

indices = (page_id, transformer_block, partitions, page_offset)
page_table.index_put_(indices=indices, values=cache_partition)
values = ops.to(cache_partition, dtype=page_table.dtype)
if page_table.dtype == torch.float8_e4m3fnuz:
# Workaround for Torch not supporting torch.Tensor.index_copy_ for f8.
page_table_as_int8 = page_table.view(dtype=torch.int8)
values_int8 = values.view(dtype=torch.int8)
page_table_as_int8.index_put_(indices=indices, values=values_int8)
else:
page_table.index_put_(indices=indices, values=values)

return

@@ -320,4 +327,11 @@ def write(
(base_subblock_ids + index) if index > 0 else base_subblock_ids
).flatten(0, 1)

subblock_table.index_copy_(0, subblock_ids, part_block_view)
part_block = ops.to(part_block_view, dtype=subblock_table.dtype)
if subblock_table.dtype == torch.float8_e4m3fnuz:
# Workaround for Torch not supporting torch.Tensor.index_copy_ for f8.
subblock_table_as_int8 = subblock_table.view(dtype=torch.int8)
part_block_as_int8 = part_block.view(dtype=torch.int8)
subblock_table_as_int8.index_copy_(0, subblock_ids, part_block_as_int8)
else:
subblock_table.index_copy_(0, subblock_ids, part_block)
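
The int8-view workaround can be demonstrated in isolation. This is a hedged, standalone sketch: the shapes are made up and it assumes a PyTorch build that exposes torch.float8_e4m3fnuz:

import torch

# An f8 table and some f8 rows to scatter into it.
table = torch.zeros(8, 4, dtype=torch.float8_e4m3fnuz)
values = torch.randn(2, 4).to(dtype=torch.float8_e4m3fnuz)
ids = torch.tensor([1, 5])

# torch.Tensor.index_copy_ is not implemented for f8, so reinterpret both
# tensors as int8 (same element size); the copy is bitwise, and the f8
# view of `table` observes the written rows.
table.view(dtype=torch.int8).index_copy_(0, ids, values.view(dtype=torch.int8))
assert torch.equal(table.view(torch.int8)[1], values.view(torch.int8)[0])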
4 changes: 0 additions & 4 deletions sharktank/sharktank/layers/linear.py
@@ -85,10 +85,6 @@ def forward(self, x):
# We can truncate to fp16 in iree, so we do a cast here
# to account for this in the IR. This is may not be the right
# level to do this, but for now its here.
if not isinstance(y, QuantizedTensor):
if y.dtype == torch.float8_e4m3fnuz:
y = ops.to(y, torch.bfloat16)
return y
if qdq_output is not None:
y = qdq_output.quantize(y).unpack().dequant()
return y
8 changes: 5 additions & 3 deletions sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -37,6 +37,7 @@ def __init__(
head_dim: int,
head_count_kv: int,
rms_epsilon: float,
attention_dtype: Optional[torch.dtype] = None,
attention_kernel: str = "decomposed",
attention_scale: Optional[float] = None,
softcap: Optional[float] = None,
@@ -49,6 +50,7 @@ def __init__(
self.head_count = head_count
self.head_dim = head_dim
self.head_count_kv = head_count_kv
self.attention_dtype = attention_dtype
self.attention_kernel = attention_kernel
self.attention_scale = attention_scale
self.softcap = softcap
@@ -161,13 +163,13 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
# Fake quant is already dequantized when stored in the cache.
if self.cache_quantizer and not self.fake_quant:
xk = self.cache_quantizer.dequantize_raw_tensor(
xk, torch.bfloat16, name="xk_deq"
xk, self.attention_dtype, name="xk_deq"
)
xv = self.cache_quantizer.dequantize_raw_tensor(
xv, torch.bfloat16, name="xv_deq"
xv, self.attention_dtype, name="xv_deq"
)
if attention_mask is not None:
attention_mask = attention_mask.to(torch.bfloat16)
attention_mask = attention_mask.to(self.attention_dtype)

# Transpose into [bs, heads, sl, dim]
xq = xq.transpose(1, 2)
3 changes: 3 additions & 0 deletions sharktank/sharktank/models/llama/llama.py
@@ -117,6 +117,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
head_dim=hp.attn_head_dim,
head_count_kv=hp.attention_head_count_kv,
rms_epsilon=hp.attention_layer_norm_rms_epsilon,
attention_dtype=config.attention_dtype,
attention_kernel=self.attention_kernel,
fake_quant=self.fake_quant,
)
@@ -241,6 +242,7 @@ def __init__(
head_dim: int,
head_count_kv: int,
rms_epsilon: float,
attention_dtype: Optional[torch.dtype] = None,
attention_kernel: str = "decomposed",
fake_quant: bool = True,
):
@@ -255,6 +257,7 @@
head_dim=head_dim,
head_count_kv=head_count_kv,
rms_epsilon=rms_epsilon,
attention_dtype=attention_dtype,
attention_kernel=attention_kernel,
fake_quant=fake_quant,
),
38 changes: 25 additions & 13 deletions sharktank/sharktank/ops/qlinear_impls.py
@@ -44,16 +44,16 @@ def qlinear_tensor_scaled(
# Now we know that both the x/weight are TensorScaledLayout. There are still
# degrees of freedom:
# * Either/both can be per-tensor or per-axis scaled (d is 0D or d is nd>0).
# * Either/both can have offsets of not (m is not None).
# * Either/both can have offsets or not (m is not None).
x_layout: TensorScaledLayout = x.unpack()
weight_layout: TensorScaledLayout = weight.unpack()

# Handle only integer and fp8 quantizations.
if x_layout.qs.dtype.is_floating_point or weight_layout.qs.dtype.is_floating_point:
if x_layout.qs.dtype == torch.float8_e4m3fnuz:
# assume quark
return matmul(x_layout.qs, weight_layout.qs, transpose_rhs=True)
else:
if (
x_layout.qs.dtype != torch.float8_e4m3fnuz
or weight_layout.qs.dtype != torch.float8_e4m3fnuz
):
return NotImplemented

# Bias.
@@ -160,6 +162,8 @@ def linear_quantized_weight(
*,
accum_dtype: Optional[torch.dtype],
) -> AnyTensor:
if accum_dtype is not None:
raise NotImplementedError("TODO: implement when is passed accum_dtype")
res = matmul(x, weight, transpose_rhs=True)
if bias is not None:
res = res + bias
@@ -170,7 +172,13 @@
linear.override(Tensor, QuantizedTensor, AnyTensor)(linear_quantized_weight)


def _invoke_mmt_kernel(lhs, rhs, *, accum_dtype):
def _is_dtype_unsigned_integer(dtype: torch.dtype):
return not dtype.is_complex and not dtype.is_floating_point and not dtype.is_signed


def _invoke_mmt_kernel(
lhs: torch.Tensor, rhs: torch.Tensor, *, accum_dtype: torch.dtype
):
if debugging.flags.use_custom_iree_kernels:
# The custom kernel requires that the lhs and rhs be the same
# rank. Broadcast the rhs to match.
@@ -187,13 +195,17 @@ def _invoke_mmt_kernel(lhs, rhs, *, accum_dtype):
rhs_size = [lhs.shape[0]] + list(rhs.shape)
rhs = rhs.unsqueeze(0).expand(rhs_size)
rhs_rank = len(rhs.shape)
y_qs = kernels.batch_matmul_transpose_b(
lhs.to(accum_dtype), rhs.to(accum_dtype)
)
# Squeeze the batch dimension to maintain shape parity with other
# layers.
if len(y_qs.shape) > 2:
y_qs = y_qs.squeeze(0)
if (
_is_dtype_unsigned_integer(lhs.dtype)
or _is_dtype_unsigned_integer(rhs.dtype)
or _is_dtype_unsigned_integer(accum_dtype)
):
# TODO: make the kernel work with unsigned types.
y_qs = kernels.batch_matmul_transpose_b(
lhs.to(dtype=accum_dtype), rhs.to(dtype=accum_dtype)
)
else:
y_qs = kernels.batch_matmul_transpose_b(lhs, rhs, accum_dtype=accum_dtype)
else:
# FP emulation.
y_qs = torch.matmul(
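The unsigned-integer guard only inspects torch.dtype attributes. A small standalone sketch of the same predicate (re-implemented here for illustration) behaves as follows:

import torch

def is_dtype_unsigned_integer(dtype: torch.dtype) -> bool:
    # An unsigned integer dtype is neither complex, nor floating point,
    # nor signed (e.g. torch.uint8).
    return not dtype.is_complex and not dtype.is_floating_point and not dtype.is_signed

assert is_dtype_unsigned_integer(torch.uint8)
assert not is_dtype_unsigned_integer(torch.int8)
assert not is_dtype_unsigned_integer(torch.float8_e4m3fnuz)  # floating point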
13 changes: 10 additions & 3 deletions sharktank/sharktank/utils/cli.py
@@ -32,10 +32,12 @@ def parse(parser: argparse.ArgumentParser, *, args: Sequence[str] | None = None)
"""Parses arguments and does any prescribed global process setup."""
parsed_args = parser.parse_args(args)
# Set torch dtypes
for attr in ["activation_dtype", "attention_dtype"]:
for attr in ["activation_dtype", "attention_dtype", "kv_cache_dtype"]:
if hasattr(parsed_args, attr):
dtype = getattr(torch, getattr(parsed_args, attr))
assert isinstance(dtype, torch.dtype)
dtype = getattr(parsed_args, attr)
if dtype is not None:
dtype = getattr(torch, dtype)
assert isinstance(dtype, torch.dtype)
setattr(parsed_args, attr, dtype)
return parsed_args

@@ -100,6 +102,11 @@ def add_model_options(parser: argparse.ArgumentParser):
help="DType to use for activations in the model",
default="float16",
)
parser.add_argument(
"--kv-cache-dtype",
help="DType to use for the KV cache. If not given will be attention dtype",
default=None,
)
parser.add_argument("--device", help="Torch device (or default)")

parser.add_argument(
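A self-contained sketch of the dtype resolution that parse() now performs. The parser below is a stand-in for illustration, not the actual sharktank.utils.cli parser:

import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument("--attention-dtype", default="float16")
parser.add_argument("--kv-cache-dtype", default=None)
args = parser.parse_args(["--kv-cache-dtype", "float8_e4m3fnuz"])

# Mirror the loop in parse(): skip dtypes left at None, otherwise map the
# string (e.g. "float16") to the torch.dtype of the same name.
for attr in ["attention_dtype", "kv_cache_dtype"]:
    value = getattr(args, attr)
    if value is not None:
        dtype = getattr(torch, value)
        assert isinstance(dtype, torch.dtype)
        setattr(args, attr, dtype)

print(args.attention_dtype, args.kv_cache_dtype)
# torch.float16 torch.float8_e4m3fnuz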
3 changes: 2 additions & 1 deletion sharktank/sharktank/utils/create_cache.py
@@ -12,13 +12,14 @@ def create_paged_kv_cache(config: LlamaModelConfig) -> PagedKVCache:
raise ValueError("Model does not use paged kv cache, cannot create kv cache")

hp = config.hp
dtype = config.kv_cache_dtype or config.attention_dtype
return PagedKVCache(
transformer_block_count=hp.block_count,
attn_head_count=hp.attention_head_count_kv,
attn_head_dim=hp.attn_head_dim,
cache_partition_count=2, # One for each of K/V.
block_seq_stride=config.block_seq_stride,
device=config.device,
dtype=config.attention_dtype,
dtype=dtype,
shard_count=config.tensor_parallelism_size,
)
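
The fallback is a plain Python or-expression over the two config fields; a tiny hedged illustration (the dtype values are examples):

import torch

attention_dtype = torch.float16

kv_cache_dtype = None
assert (kv_cache_dtype or attention_dtype) is torch.float16

kv_cache_dtype = torch.float8_e4m3fnuz
assert (kv_cache_dtype or attention_dtype) is torch.float8_e4m3fnuz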
17 changes: 0 additions & 17 deletions sharktank/sharktank/utils/hf_datasets.py
@@ -93,23 +93,6 @@ def alias_dataset(from_name: str, to_name: str):
# Dataset definitions
################################################################################

Dataset(
"SanctumAI/Meta-Llama-3.1-8B-Instruct-GGUF",
(
RemoteFile(
"gguf",
"SanctumAI/Meta-Llama-3.1-8B-Instruct-GGUF",
"meta-llama-3.1-8b-instruct.f16.gguf",
),
RemoteFile(
"tokenizer_config.json",
"NousResearch/Meta-Llama-3-8B-Instruct",
"tokenizer_config.json",
extra_filenames=["tokenizer.json"],
),
),
).alias_to("llama3_8B_fp16")
Collaborator: This model cannot be removed, as it is part of the llama_serving.md release docs, which were tested in sharktank/shortfin. However, I agree that we need only one of the two Llama 8B fp16 models listed here.

Member: The docs could be updated to use standard Hugging Face tooling instead of sharktank.utils.hf_datasets. As written today, the hf_datasets file should only be a utility for project development and testing, not something user-facing.

Contributor: I fixed the doc.

Collaborator (@archana-ramalingam, Feb 20, 2025): llama3_8B_fp16 is the right instruct version that was tested to be numerically correct for the release; the other model is the non-instruct version, and I remember it generating repetitive tokens. This can be a separate PR where we consult with the shortfin folks on whether this model switch in the release docs can be made.


Dataset(
"QuantFactory/Llama-3-8B_q4_1_gguf",
(