Merge branch 'main' into llm-micro-server
renxida authored Feb 22, 2025
2 parents 5cb7094 + 6397ead commit 292e55d
Showing 25 changed files with 527 additions and 222 deletions.
1 change: 1 addition & 0 deletions sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -83,6 +83,7 @@ def main():
block_seq_stride=args.block_seq_stride,
activation_dtype=args.activation_dtype,
attention_dtype=args.attention_dtype,
kv_cache_dtype=args.kv_cache_dtype,
)
llama_config.fake_quant = args.fake_quant

3 changes: 2 additions & 1 deletion sharktank/sharktank/examples/paged_llm_v1.py
@@ -288,8 +288,9 @@ def main():
block_seq_stride=args.block_seq_stride,
device=device,
activation_dtype=args.activation_dtype,
attention_dtype=args.activation_dtype,
attention_dtype=args.attention_dtype,
attention_kernel=args.attention_kernel,
kv_cache_dtype=args.kv_cache_dtype,
use_hf=args.use_hf,
tensor_parallelism_size=args.tensor_parallelism_size,
fake_quant=args.fake_quant,
42 changes: 36 additions & 6 deletions sharktank/sharktank/kernels/batch_matmul_transpose_b.py
@@ -7,16 +7,36 @@
from sharktank.kernels.base import *

import torch
from typing import cast, Optional

from iree.compiler.ir import IntegerType
from iree.compiler.ir import IntegerType, Type
from iree.turbine.support.conversions import (
TORCH_DTYPE_TO_IREE_TYPE_ASM,
IREE_TYPE_ASM_TO_TORCH_DTYPE,
)
from iree.turbine.runtime.op_reg import AttrArg

__all__ = [
"batch_matmul_transpose_b",
]


def batch_matmul_transpose_b(
lhs: torch.Tensor,
rhs: torch.Tensor,
/,
*,
accum_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
if accum_dtype is None:
accum_dtype = lhs.dtype
return _batch_matmul_transpose_b(
lhs, rhs, accum_dtype=TORCH_DTYPE_TO_IREE_TYPE_ASM[accum_dtype]
)


@CustomOp.register(library=LIBRARY)
class batch_matmul_transpose_b(CustomOp):
class _batch_matmul_transpose_b(CustomOp):
"""Generic block scaled matmul with transposed RHS.
The LHS is expected to be a 3d tensor of shape [B, M, K]. RHS must be
@@ -25,11 +45,18 @@ class batch_matmul_transpose_b(CustomOp):
The kernel will be specialized for all values of N, K and LHS dtype.
"""

signature = "batch_matmul_transpose_b(Tensor lhs, Tensor rhs) -> (Tensor)"
signature = (
"batch_matmul_transpose_b(Tensor lhs, Tensor rhs, str accum_dtype) -> (Tensor)"
)

def eager_execute(self, lhs: torch.Tensor, rhs: torch.Tensor, accum_dtype: str):
dtype = IREE_TYPE_ASM_TO_TORCH_DTYPE[accum_dtype]
return torch.matmul(lhs.to(dtype=dtype), rhs.transpose(-1, -2).to(dtype=dtype))

def select(self, ksel: KernelSelection):
lhs_desc = ksel.arg_tensor(0) # Shape [B, M, K]
rhs_desc = ksel.arg_tensor(1) # Shape [B, N, K]
accum_type_attr = ksel.attr_str(2)

# Rank check.
torch._check(
@@ -60,7 +87,8 @@ def select(self, ksel: KernelSelection):
)
# Shape batch, m, n
c_desc = ksel.return_new_tensor(
[lhs_batch, lhs_m, rhs_n], dtype=lhs_desc.t.dtype
[lhs_batch, lhs_m, rhs_n],
dtype=IREE_TYPE_ASM_TO_TORCH_DTYPE[accum_type_attr.v],
)
specialize_all_known_dims(lhs_desc)
specialize_all_known_dims(rhs_desc)
@@ -74,12 +102,14 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder):
def generate(self, ksel: KernelSelection, kb: KernelBuilder):
lhs = kb.arg_value(0)
rhs = kb.arg_value(1)
accum_type_str = cast(AttrArg, ksel.arg_descs[2]).v
result_desc = ksel.result_descs[0]

# Generate specialization signature and types.
a_asm_type, a_ident, accum_type = unpack_tensor_type(lhs.type)
a_asm_type, a_ident, _ = unpack_tensor_type(lhs.type)
b_asm_type, b_ident, _ = unpack_tensor_type(rhs.type)
spec_sig = f"L{a_ident}_R{b_ident}"
accum_type = Type.parse(accum_type_str)
spec_sig = f"L{a_ident}_R{b_ident}_{accum_type_str}"
template_file = "batch_matmul_transpose_b.mlir"
target_function_name = f"sharktank_batch_matmul_transpose_b_{spec_sig}"
cst_zero = "0" if IntegerType.isinstance(accum_type) else "0."
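The kernel change above splits the op into a thin Python wrapper and a registered custom op: the wrapper defaults the accumulator dtype to the LHS dtype and forwards it to the op as an IREE type string, while the eager path simply matmuls in that dtype. A minimal usage sketch, assuming the wrapper is importable from the module path shown in this diff; shapes and dtypes are illustrative only:

```python
import torch

# Assumed import path, following the file shown in this diff.
from sharktank.kernels.batch_matmul_transpose_b import batch_matmul_transpose_b

lhs = torch.randn(4, 8, 16, dtype=torch.float16)   # [B, M, K]
rhs = torch.randn(4, 32, 16, dtype=torch.float16)  # [B, N, K]; the kernel transposes it

# Request float32 accumulation; in eager mode this reduces to
# torch.matmul(lhs.to(f32), rhs.transpose(-1, -2).to(f32)).
out = batch_matmul_transpose_b(lhs, rhs, accum_dtype=torch.float32)

print(out.shape, out.dtype)  # torch.Size([4, 8, 32]) torch.float32
```

Leaving accum_dtype unset keeps the previous behavior of accumulating in the LHS dtype.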
86 changes: 40 additions & 46 deletions sharktank/sharktank/layers/configs/llm_configs.py
@@ -14,12 +14,14 @@
(and indeed, can bootstrap these off of GGUF files).
"""

from dataclasses import asdict, dataclass, field
from dataclasses import asdict, dataclass, field, fields
from typing import Any, Optional
import torch
from transformers import T5Config as T5ConfigHf

from ...types.tensors import serialized_name_to_dtype, dtype_to_serialized_name


__all__ = ["ClipTextConfig", "LlamaHParams", "LlamaModelConfig", "T5Config"]


@@ -158,6 +160,9 @@ class LlamaModelConfig:
# Either "paged" or "direct".
kv_cache_type: str = "paged"

# If None will use attention_dtype.
kv_cache_dtype: Optional[torch.dtype] = None

# The device on which to place intermediate state.
device: Optional[torch.device] = None

@@ -231,53 +236,41 @@ def __post_init__(self):
self.dense_act_fn = "gelu_new"

@staticmethod
def from_gguf_properties(properties: dict[str, Any], **kwargs):
assert properties["general.architecture"] == "t5"
assert (
properties["t5.attention.layer_norm_epsilon"]
== properties["t5.attention.layer_norm_rms_epsilon"]
)

all_kwargs = {"vocab_size": None, "feed_forward_proj": None}

gguf_to_config_names_map = {
"t5.context_length": ["context_length"],
"t5.embedding_length": ["d_model"],
"t5.feed_forward_length": ["d_ff"],
"t5.block_count": ["num_layers", "num_decoder_layers"],
"t5.attention.head_count": ["num_heads"],
"t5.attention.key_length": ["d_kv"],
"t5.attention.layer_norm_epsilon": ["layer_norm_epsilon"],
"t5.attention.relative_buckets_count": ["relative_attention_num_buckets"],
"tokenizer.ggml.eos_token_id": ["eos_token_id"],
"tokenizer.ggml.padding_token_id": ["pad_token_id"],
}
all_kwargs.update(
{
config_name: properties[gguf_name]
for gguf_name, config_names in gguf_to_config_names_map.items()
for config_name in config_names
}
)

gguf_to_optional_config_names_map = {
"t5.decoder_start_token_id": ["decoder_start_token_id"],
}
all_kwargs.update(
{
config_name: properties[gguf_name]
for gguf_name, config_names in gguf_to_optional_config_names_map.items()
for config_name in config_names
if gguf_name in properties
}
)

if "tokenizer.ggml.tokens" in properties:
all_kwargs["vocab_size"] = len(properties["tokenizer.ggml.tokens"])
def from_hugging_face_config(
config: T5ConfigHf, tokenizer_config: dict[str, Any], **kwargs
) -> "T5Config":
all_kwargs = {}
for field in fields(T5Config):
if hasattr(config, field.name):
all_kwargs[field.name] = getattr(config, field.name)
all_kwargs["context_length"] = tokenizer_config["model_max_length"]
del all_kwargs["is_gated_act"]
del all_kwargs["dense_act_fn"]
all_kwargs.update(kwargs)

return T5Config(**all_kwargs)

@staticmethod
def from_properties(properties: dict[str, Any]) -> "T5Config":
kwargs = dict(properties)
if "SHARK_DATASET_VERSION" in kwargs:
kwargs.pop("SHARK_DATASET_VERSION")
if "activation_dtype" in kwargs and kwargs["activation_dtype"] is not None:
kwargs["activation_dtype"] = serialized_name_to_dtype(
kwargs["activation_dtype"]
)
if "is_gated_act" in kwargs:
kwargs.pop("is_gated_act")
if "dense_act_fn" in kwargs:
kwargs.pop("dense_act_fn")

return T5Config(**kwargs)

def to_properties(self) -> dict[str, Any]:
res = asdict(self)
if self.activation_dtype is not None:
res["activation_dtype"] = dtype_to_serialized_name(self.activation_dtype)
return res


@dataclass
class ClipTextConfig:
@@ -336,7 +329,8 @@ def to_hugging_face_clip_text_model_config(self) -> "transformers.CLIPTextConfig
@staticmethod
def from_properties(properties: dict[str, Any]) -> "ClipTextConfig":
kwargs = dict(properties)
kwargs.pop("SHARK_DATASET_VERSION")
if "SHARK_DATASET_VERSION" in kwargs:
kwargs.pop("SHARK_DATASET_VERSION")
if "dtype" in kwargs and kwargs["dtype"] is not None:
kwargs["dtype"] = serialized_name_to_dtype(kwargs["dtype"])

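With the GGUF-based constructor removed, T5Config is now built from a Hugging Face T5Config plus a tokenizer config dict (which supplies context_length), and it round-trips through from_properties/to_properties with serialized dtype names. A rough usage sketch; the checkpoint name, the literal tokenizer_config dict, and the import path are assumptions for illustration, not taken from this commit:

```python
from transformers import T5Config as T5ConfigHf

from sharktank.layers.configs.llm_configs import T5Config  # assumed import path

hf_config = T5ConfigHf.from_pretrained("google/t5-v1_1-small")  # any T5 checkpoint
tokenizer_config = {"model_max_length": 512}  # normally read from tokenizer_config.json

# Fields shared with the Hugging Face config are copied over; extra keyword
# arguments (such as activation_dtype) override the copied values.
config = T5Config.from_hugging_face_config(hf_config, tokenizer_config)
```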
18 changes: 16 additions & 2 deletions sharktank/sharktank/layers/kv_cache.py
@@ -275,7 +275,14 @@ def write_timestep(
partitions = partitions.repeat(bs, 1)

indices = (page_id, transformer_block, partitions, page_offset)
page_table.index_put_(indices=indices, values=cache_partition)
values = ops.to(cache_partition, dtype=page_table.dtype)
if page_table.dtype == torch.float8_e4m3fnuz:
# Workaround for Torch not supporting torch.Tensor.index_put_ for f8.
page_table_as_int8 = page_table.view(dtype=torch.int8)
values_int8 = values.view(dtype=torch.int8)
page_table_as_int8.index_put_(indices=indices, values=values_int8)
else:
page_table.index_put_(indices=indices, values=values)

return

@@ -320,4 +327,11 @@ def write(
(base_subblock_ids + index) if index > 0 else base_subblock_ids
).flatten(0, 1)

subblock_table.index_copy_(0, subblock_ids, part_block_view)
part_block = ops.to(part_block_view, dtype=subblock_table.dtype)
if subblock_table.dtype == torch.float8_e4m3fnuz:
# Workaround for Torch not supporting torch.Tensor.index_copy_ for f8.
subblock_table_as_int8 = subblock_table.view(dtype=torch.int8)
part_block_as_int8 = part_block.view(dtype=torch.int8)
subblock_table_as_int8.index_copy_(0, subblock_ids, part_block_as_int8)
else:
subblock_table.index_copy_(0, subblock_ids, part_block)
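Both cache writes lean on the same workaround: torch.float8_e4m3fnuz and torch.int8 are one byte wide, so viewing the tensors as int8 reinterprets the same storage, the int8 alias supports index_put_/index_copy_, and the scatter lands in the original f8 table. A standalone sketch of the trick outside the cache classes; tensor names and shapes are made up for illustration:

```python
import torch

# An f8 "page table" and f8 values to scatter into two of its rows.
table = torch.zeros(8, 4).to(torch.float8_e4m3fnuz)
values = torch.randn(2, 4).to(torch.float8_e4m3fnuz)
rows = torch.tensor([1, 5])

# index_copy_ is not implemented for float8 dtypes, so bitcast both sides to
# int8 (same element size). The view shares storage with `table`, so the
# write below updates the original f8 tensor in place.
table.view(dtype=torch.int8).index_copy_(0, rows, values.view(dtype=torch.int8))

print(table.to(torch.float32)[rows])  # the two written rows, in a readable dtype
```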
8 changes: 0 additions & 8 deletions sharktank/sharktank/layers/linear.py
@@ -81,14 +81,6 @@ def forward(self, x):
# Unconditionally dequantize.
if isinstance(y, QuantizedTensor):
y = y.unpack().dequant()
# Note that f8_e4m3fnuz types on AMD GPUs accumulate to fp32.
# We can truncate to fp16 in iree, so we do a cast here
# to account for this in the IR. This may not be the right
# level to do this, but for now it's here.
if not isinstance(y, QuantizedTensor):
if y.dtype == torch.float8_e4m3fnuz:
y = ops.to(y, torch.bfloat16)
return y
if qdq_output is not None:
y = qdq_output.quantize(y).unpack().dequant()
return y
8 changes: 5 additions & 3 deletions sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -37,6 +37,7 @@ def __init__(
head_dim: int,
head_count_kv: int,
rms_epsilon: float,
attention_dtype: Optional[torch.dtype] = None,
attention_kernel: str = "decomposed",
attention_scale: Optional[float] = None,
softcap: Optional[float] = None,
@@ -49,6 +50,7 @@
self.head_count = head_count
self.head_dim = head_dim
self.head_count_kv = head_count_kv
self.attention_dtype = attention_dtype
self.attention_kernel = attention_kernel
self.attention_scale = attention_scale
self.softcap = softcap
@@ -161,13 +163,13 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
# Fake quant is already dequantized when stored in the cache.
if self.cache_quantizer and not self.fake_quant:
xk = self.cache_quantizer.dequantize_raw_tensor(
xk, torch.bfloat16, name="xk_deq"
xk, self.attention_dtype, name="xk_deq"
)
xv = self.cache_quantizer.dequantize_raw_tensor(
xv, torch.bfloat16, name="xv_deq"
xv, self.attention_dtype, name="xv_deq"
)
if attention_mask is not None:
attention_mask = attention_mask.to(torch.bfloat16)
attention_mask = attention_mask.to(self.attention_dtype)

# Transpose into [bs, heads, sl, dim]
xq = xq.transpose(1, 2)
3 changes: 3 additions & 0 deletions sharktank/sharktank/models/llama/llama.py
@@ -117,6 +117,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
head_dim=hp.attn_head_dim,
head_count_kv=hp.attention_head_count_kv,
rms_epsilon=hp.attention_layer_norm_rms_epsilon,
attention_dtype=config.attention_dtype,
attention_kernel=self.attention_kernel,
fake_quant=self.fake_quant,
)
@@ -241,6 +242,7 @@ def __init__(
head_dim: int,
head_count_kv: int,
rms_epsilon: float,
attention_dtype: Optional[torch.dtype] = None,
attention_kernel: str = "decomposed",
fake_quant: bool = True,
):
@@ -255,6 +257,7 @@
head_dim=head_dim,
head_count_kv=head_count_kv,
rms_epsilon=rms_epsilon,
attention_dtype=attention_dtype,
attention_kernel=attention_kernel,
fake_quant=fake_quant,
),