diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py index 7c12f4e20..efbba089b 100644 --- a/sharktank/sharktank/examples/export_paged_llm_v1.py +++ b/sharktank/sharktank/examples/export_paged_llm_v1.py @@ -83,6 +83,7 @@ def main(): block_seq_stride=args.block_seq_stride, activation_dtype=args.activation_dtype, attention_dtype=args.attention_dtype, + kv_cache_dtype=args.kv_cache_dtype, ) llama_config.fake_quant = args.fake_quant diff --git a/sharktank/sharktank/examples/paged_llm_v1.py b/sharktank/sharktank/examples/paged_llm_v1.py index 5d338bd74..b8b6026b6 100644 --- a/sharktank/sharktank/examples/paged_llm_v1.py +++ b/sharktank/sharktank/examples/paged_llm_v1.py @@ -288,8 +288,9 @@ def main(): block_seq_stride=args.block_seq_stride, device=device, activation_dtype=args.activation_dtype, - attention_dtype=args.activation_dtype, + attention_dtype=args.attention_dtype, attention_kernel=args.attention_kernel, + kv_cache_dtype=args.kv_cache_dtype, use_hf=args.use_hf, tensor_parallelism_size=args.tensor_parallelism_size, fake_quant=args.fake_quant, diff --git a/sharktank/sharktank/kernels/batch_matmul_transpose_b.py b/sharktank/sharktank/kernels/batch_matmul_transpose_b.py index 21f9e9ed4..e9879bfab 100644 --- a/sharktank/sharktank/kernels/batch_matmul_transpose_b.py +++ b/sharktank/sharktank/kernels/batch_matmul_transpose_b.py @@ -7,16 +7,36 @@ from sharktank.kernels.base import * import torch +from typing import cast, Optional -from iree.compiler.ir import IntegerType +from iree.compiler.ir import IntegerType, Type +from iree.turbine.support.conversions import ( + TORCH_DTYPE_TO_IREE_TYPE_ASM, + IREE_TYPE_ASM_TO_TORCH_DTYPE, +) +from iree.turbine.runtime.op_reg import AttrArg __all__ = [ "batch_matmul_transpose_b", ] +def batch_matmul_transpose_b( + lhs: torch.Tensor, + rhs: torch.Tensor, + /, + *, + accum_dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + if accum_dtype is None: + accum_dtype = lhs.dtype + return _batch_matmul_transpose_b( + lhs, rhs, accum_dtype=TORCH_DTYPE_TO_IREE_TYPE_ASM[accum_dtype] + ) + + @CustomOp.register(library=LIBRARY) -class batch_matmul_transpose_b(CustomOp): +class _batch_matmul_transpose_b(CustomOp): """Generic block scaled matmul with transposed RHS. The LHS is expected to be a 3d tensor of shape [B, M, K]. RHS must be @@ -25,11 +45,18 @@ class batch_matmul_transpose_b(CustomOp): The kernel will be specialized for all values of N, K and LHS dtype. """ - signature = "batch_matmul_transpose_b(Tensor lhs, Tensor rhs) -> (Tensor)" + signature = ( + "batch_matmul_transpose_b(Tensor lhs, Tensor rhs, str accum_dtype) -> (Tensor)" + ) + + def eager_execute(self, lhs: torch.Tensor, rhs: torch.Tensor, accum_dtype: str): + dtype = IREE_TYPE_ASM_TO_TORCH_DTYPE[accum_dtype] + return torch.matmul(lhs.to(dtype=dtype), rhs.transpose(-1, -2).to(dtype=dtype)) def select(self, ksel: KernelSelection): lhs_desc = ksel.arg_tensor(0) # Shape [B, M, K] rhs_desc = ksel.arg_tensor(1) # Shape [B, N, K] + accum_type_attr = ksel.attr_str(2) # Rank check. torch._check( @@ -60,7 +87,8 @@ def select(self, ksel: KernelSelection): ) # Shape batch, m, n c_desc = ksel.return_new_tensor( - [lhs_batch, lhs_m, rhs_n], dtype=lhs_desc.t.dtype + [lhs_batch, lhs_m, rhs_n], + dtype=IREE_TYPE_ASM_TO_TORCH_DTYPE[accum_type_attr.v], ) specialize_all_known_dims(lhs_desc) specialize_all_known_dims(rhs_desc) @@ -74,12 +102,14 @@ def select(self, ksel: KernelSelection): def generate(self, ksel: KernelSelection, kb: KernelBuilder): lhs = kb.arg_value(0) rhs = kb.arg_value(1) + accum_type_str = cast(AttrArg, ksel.arg_descs[2]).v result_desc = ksel.result_descs[0] # Generate specialization signature and types. - a_asm_type, a_ident, accum_type = unpack_tensor_type(lhs.type) + a_asm_type, a_ident, _ = unpack_tensor_type(lhs.type) b_asm_type, b_ident, _ = unpack_tensor_type(rhs.type) - spec_sig = f"L{a_ident}_R{b_ident}" + accum_type = Type.parse(accum_type_str) + spec_sig = f"L{a_ident}_R{b_ident}_{accum_type_str}" template_file = "batch_matmul_transpose_b.mlir" target_function_name = f"sharktank_batch_matmul_transpose_b_{spec_sig}" cst_zero = "0" if IntegerType.isinstance(accum_type) else "0." diff --git a/sharktank/sharktank/layers/configs/llm_configs.py b/sharktank/sharktank/layers/configs/llm_configs.py index fd8fd6d12..fafc2a98e 100644 --- a/sharktank/sharktank/layers/configs/llm_configs.py +++ b/sharktank/sharktank/layers/configs/llm_configs.py @@ -14,12 +14,14 @@ (and indeed, can bootstrap these off of GGUF files). """ -from dataclasses import asdict, dataclass, field +from dataclasses import asdict, dataclass, field, fields from typing import Any, Optional import torch +from transformers import T5Config as T5ConfigHf from ...types.tensors import serialized_name_to_dtype, dtype_to_serialized_name + __all__ = ["ClipTextConfig", "LlamaHParams", "LlamaModelConfig", "T5Config"] @@ -158,6 +160,9 @@ class LlamaModelConfig: # Either "paged" or "direct". kv_cache_type: str = "paged" + # If None will use attention_dtype. + kv_cache_dtype: Optional[torch.dtype] = None + # The device on which to place intermediate state. device: Optional[torch.device] = None @@ -231,53 +236,41 @@ def __post_init__(self): self.dense_act_fn = "gelu_new" @staticmethod - def from_gguf_properties(properties: dict[str, Any], **kwargs): - assert properties["general.architecture"] == "t5" - assert ( - properties["t5.attention.layer_norm_epsilon"] - == properties["t5.attention.layer_norm_rms_epsilon"] - ) - - all_kwargs = {"vocab_size": None, "feed_forward_proj": None} - - gguf_to_config_names_map = { - "t5.context_length": ["context_length"], - "t5.embedding_length": ["d_model"], - "t5.feed_forward_length": ["d_ff"], - "t5.block_count": ["num_layers", "num_decoder_layers"], - "t5.attention.head_count": ["num_heads"], - "t5.attention.key_length": ["d_kv"], - "t5.attention.layer_norm_epsilon": ["layer_norm_epsilon"], - "t5.attention.relative_buckets_count": ["relative_attention_num_buckets"], - "tokenizer.ggml.eos_token_id": ["eos_token_id"], - "tokenizer.ggml.padding_token_id": ["pad_token_id"], - } - all_kwargs.update( - { - config_name: properties[gguf_name] - for gguf_name, config_names in gguf_to_config_names_map.items() - for config_name in config_names - } - ) - - gguf_to_optional_config_names_map = { - "t5.decoder_start_token_id": ["decoder_start_token_id"], - } - all_kwargs.update( - { - config_name: properties[gguf_name] - for gguf_name, config_names in gguf_to_optional_config_names_map.items() - for config_name in config_names - if gguf_name in properties - } - ) - - if "tokenizer.ggml.tokens" in properties: - all_kwargs["vocab_size"] = len(properties["tokenizer.ggml.tokens"]) + def from_hugging_face_config( + config: T5ConfigHf, tokenizer_config: dict[str, Any], **kwargs + ) -> "T5Config": + all_kwargs = {} + for filed in fields(T5Config): + if hasattr(config, filed.name): + all_kwargs[filed.name] = getattr(config, filed.name) + all_kwargs["context_length"] = tokenizer_config["model_max_length"] + del all_kwargs["is_gated_act"] + del all_kwargs["dense_act_fn"] all_kwargs.update(kwargs) - return T5Config(**all_kwargs) + @staticmethod + def from_properties(properties: dict[str, Any]) -> "T5Config": + kwargs = dict(properties) + if "SHARK_DATASET_VERSION" in kwargs: + kwargs.pop("SHARK_DATASET_VERSION") + if "activation_dtype" in kwargs and kwargs["activation_dtype"] is not None: + kwargs["activation_dtype"] = serialized_name_to_dtype( + kwargs["activation_dtype"] + ) + if "is_gated_act" in kwargs: + kwargs.pop("is_gated_act") + if "dense_act_fn" in kwargs: + kwargs.pop("dense_act_fn") + + return T5Config(**kwargs) + + def to_properties(self) -> dict[str, Any]: + res = asdict(self) + if self.activation_dtype is not None: + res["activation_dtype"] = dtype_to_serialized_name(self.activation_dtype) + return res + @dataclass class ClipTextConfig: @@ -336,7 +329,8 @@ def to_hugging_face_clip_text_model_config(self) -> "transformers.CLIPTextConfig @staticmethod def from_properties(properties: dict[str, Any]) -> "ClipTextConfig": kwargs = dict(properties) - kwargs.pop("SHARK_DATASET_VERSION") + if "SHARK_DATASET_VERSION" in kwargs: + kwargs.pop("SHARK_DATASET_VERSION") if "dtype" in kwargs and kwargs["dtype"] is not None: kwargs["dtype"] = serialized_name_to_dtype(kwargs["dtype"]) diff --git a/sharktank/sharktank/layers/kv_cache.py b/sharktank/sharktank/layers/kv_cache.py index be8c66fb4..a961a9dd7 100644 --- a/sharktank/sharktank/layers/kv_cache.py +++ b/sharktank/sharktank/layers/kv_cache.py @@ -275,7 +275,14 @@ def write_timestep( partitions = partitions.repeat(bs, 1) indices = (page_id, transformer_block, partitions, page_offset) - page_table.index_put_(indices=indices, values=cache_partition) + values = ops.to(cache_partition, dtype=page_table.dtype) + if page_table.dtype == torch.float8_e4m3fnuz: + # Workaround for Torch not supporting torch.Tensor.index_copy_ for f8. + page_table_as_int8 = page_table.view(dtype=torch.int8) + values_int8 = values.view(dtype=torch.int8) + page_table_as_int8.index_put_(indices=indices, values=values_int8) + else: + page_table.index_put_(indices=indices, values=values) return @@ -320,4 +327,11 @@ def write( (base_subblock_ids + index) if index > 0 else base_subblock_ids ).flatten(0, 1) - subblock_table.index_copy_(0, subblock_ids, part_block_view) + part_block = ops.to(part_block_view, dtype=subblock_table.dtype) + if subblock_table.dtype == torch.float8_e4m3fnuz: + # Workaround for Torch not supporting torch.Tensor.index_copy_ for f8. + subblock_table_as_int8 = subblock_table.view(dtype=torch.int8) + part_block_as_int8 = part_block.view(dtype=torch.int8) + subblock_table_as_int8.index_copy_(0, subblock_ids, part_block_as_int8) + else: + subblock_table.index_copy_(0, subblock_ids, part_block) diff --git a/sharktank/sharktank/layers/linear.py b/sharktank/sharktank/layers/linear.py index a1f1366ab..79ede9f5c 100644 --- a/sharktank/sharktank/layers/linear.py +++ b/sharktank/sharktank/layers/linear.py @@ -81,14 +81,6 @@ def forward(self, x): # Unconditionally dequantize. if isinstance(y, QuantizedTensor): y = y.unpack().dequant() - # Note that f8_e4m3fnuz types on AMD GPUs accumulate to fp32. - # We can truncate to fp16 in iree, so we do a cast here - # to account for this in the IR. This is may not be the right - # level to do this, but for now its here. - if not isinstance(y, QuantizedTensor): - if y.dtype == torch.float8_e4m3fnuz: - y = ops.to(y, torch.bfloat16) - return y if qdq_output is not None: y = qdq_output.quantize(y).unpack().dequant() return y diff --git a/sharktank/sharktank/layers/paged_llama_attention_block.py b/sharktank/sharktank/layers/paged_llama_attention_block.py index 71bcd1e8c..76dafb918 100644 --- a/sharktank/sharktank/layers/paged_llama_attention_block.py +++ b/sharktank/sharktank/layers/paged_llama_attention_block.py @@ -37,6 +37,7 @@ def __init__( head_dim: int, head_count_kv: int, rms_epsilon: float, + attention_dtype: Optional[torch.dtype] = None, attention_kernel: str = "decomposed", attention_scale: Optional[float] = None, softcap: Optional[float] = None, @@ -49,6 +50,7 @@ def __init__( self.head_count = head_count self.head_dim = head_dim self.head_count_kv = head_count_kv + self.attention_dtype = attention_dtype self.attention_kernel = attention_kernel self.attention_scale = attention_scale self.softcap = softcap @@ -161,13 +163,13 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor: # Fake quant is already dequantized when stored in the cache. if self.cache_quantizer and not self.fake_quant: xk = self.cache_quantizer.dequantize_raw_tensor( - xk, torch.bfloat16, name="xk_deq" + xk, self.attention_dtype, name="xk_deq" ) xv = self.cache_quantizer.dequantize_raw_tensor( - xv, torch.bfloat16, name="xv_deq" + xv, self.attention_dtype, name="xv_deq" ) if attention_mask is not None: - attention_mask = attention_mask.to(torch.bfloat16) + attention_mask = attention_mask.to(self.attention_dtype) # Transpose into [bs, heads, sl, dim] xq = xq.transpose(1, 2) diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py index a31913833..bc18cb492 100644 --- a/sharktank/sharktank/models/llama/llama.py +++ b/sharktank/sharktank/models/llama/llama.py @@ -117,6 +117,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): head_dim=hp.attn_head_dim, head_count_kv=hp.attention_head_count_kv, rms_epsilon=hp.attention_layer_norm_rms_epsilon, + attention_dtype=config.attention_dtype, attention_kernel=self.attention_kernel, fake_quant=self.fake_quant, ) @@ -241,6 +242,7 @@ def __init__( head_dim: int, head_count_kv: int, rms_epsilon: float, + attention_dtype: Optional[torch.dtype] = None, attention_kernel: str = "decomposed", fake_quant: bool = True, ): @@ -255,6 +257,7 @@ def __init__( head_dim=head_dim, head_count_kv=head_count_kv, rms_epsilon=rms_epsilon, + attention_dtype=attention_dtype, attention_kernel=attention_kernel, fake_quant=fake_quant, ), diff --git a/sharktank/sharktank/models/t5/export.py b/sharktank/sharktank/models/t5/export.py index 8d5f75db2..5072da025 100644 --- a/sharktank/sharktank/models/t5/export.py +++ b/sharktank/sharktank/models/t5/export.py @@ -5,20 +5,21 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import functools -from typing import Optional, Union +from typing import Any, Optional, Union from pathlib import Path import torch from copy import copy +import transformers from .t5 import T5Config, T5Encoder -from ...types import Dataset +from ...types import Dataset, Theta, DefaultPrimitiveTensor from ...transforms.dataset import set_float_dtype from iree.turbine.aot import FxProgramsBuilder, export __all__ = [ "export_encoder_mlir", "export_encoder_iree_parameters", - "prune_decoder_parameters", + "import_encoder_dataset_from_hugging_face", ] @@ -33,13 +34,8 @@ def export_encoder_mlir( """ if isinstance(model, (Path, str)): dataset = Dataset.load(model) - config = T5Config.from_gguf_properties( + config = T5Config.from_properties( dataset.properties, - # TODO: add this property to our HuggingFace-to-GGUF conversion script. - # We currently use llama.cpp's converter and it can not make a distinction - # between T5 V1 and V1.1. - # V1 uses ReLU and V1.1 uses gated GeLU. - feed_forward_proj="gated-gelu", ) model = T5Encoder(theta=dataset.root_theta, config=config) @@ -82,18 +78,6 @@ def _( output.save_mlir(mlir_output_path) -def prune_decoder_parameters(dataset: Dataset): - # Remove decoder tensors/parameters if present. - try: - del dataset.root_theta.tree["dec"] - except KeyError: - pass - try: - del dataset.properties["t5.decoder_start_token_id"] - except KeyError: - pass - - def export_encoder_iree_parameters( model_path_or_dataset: str | Dataset, output_path: str, @@ -107,5 +91,35 @@ def export_encoder_iree_parameters( dataset.root_theta = dataset.root_theta.transform( functools.partial(set_float_dtype, dtype=dtype) ) - prune_decoder_parameters(dataset) dataset.save(output_path) + + +def import_encoder_dataset_from_hugging_face( + repo_id_or_model: transformers.T5EncoderModel | str, + /, + *, + tokenizer_config: dict[str, Any] | None = None, +) -> Dataset: + model = repo_id_or_model + if not isinstance(repo_id_or_model, transformers.T5EncoderModel): + model = transformers.T5EncoderModel.from_pretrained(repo_id_or_model) + from transformers.models.auto.tokenization_auto import get_tokenizer_config + + if tokenizer_config is None: + tokenizer_config = get_tokenizer_config(repo_id_or_model) + else: + if tokenizer_config is None: + raise ValueError( + "When providing a model directly tokenizer_config must also be provided." + ) + + theta = Theta( + { + name: DefaultPrimitiveTensor(data=param, name=name) + for name, param in model.named_parameters() + } + ) + config = T5Config.from_hugging_face_config( + model.config, tokenizer_config=tokenizer_config + ) + return Dataset(properties=config.to_properties(), root_theta=theta) diff --git a/sharktank/sharktank/models/t5/t5.py b/sharktank/sharktank/models/t5/t5.py index a2a8958be..9af2f5636 100644 --- a/sharktank/sharktank/models/t5/t5.py +++ b/sharktank/sharktank/models/t5/t5.py @@ -54,12 +54,27 @@ def __init__( activation_dtype: torch.dtype, ): super().__init__() + + ffn_theta = theta("DenseReluDense") + ffn_theta_dict = {} + if is_gated_act: + ffn_theta_dict["ffn_gate"] = ffn_theta("wi_0").tree + ffn_theta_dict["ffn_up"] = ffn_theta("wi_1").tree + else: + ffn_theta_dict["ffn_up"] = ffn_theta("wi").tree + ffn_theta_dict["ffn_down"] = ffn_theta("wo").tree + ffn_theta = Theta(ffn_theta_dict) + self.dense_activation_dense = FFN( - theta=theta, is_gated=is_gated_act, activation_fn=ACT2FN[dense_act_fn] + theta=ffn_theta, + is_gated=is_gated_act, + activation_fn=ACT2FN[dense_act_fn], ) self.layer_norm = RMSNormLayer( - theta=theta("ffn_norm"), epsilon=layer_norm_epsilon, dtype=activation_dtype + theta=theta("layer_norm"), + epsilon=layer_norm_epsilon, + dtype=activation_dtype, ) def forward(self, hidden_states): @@ -93,14 +108,14 @@ def __init__( self.inner_dim = self.n_heads * self.key_value_proj_dim self.activation_dtype = activation_dtype - self.q = LinearLayer(theta("attn_q")) - self.k = LinearLayer(theta("attn_k")) - self.v = LinearLayer(theta("attn_v")) - self.o = LinearLayer(theta("attn_o")) + self.q = LinearLayer(theta("q")) + self.k = LinearLayer(theta("k")) + self.v = LinearLayer(theta("v")) + self.o = LinearLayer(theta("o")) if self.has_relative_attention_bias: self.relative_attention_bias = TokenEmbeddingLayer( - theta("attn_rel_b"), dtype=activation_dtype + theta("relative_attention_bias"), dtype=activation_dtype ) self.pruned_heads = set() @@ -360,7 +375,7 @@ def __init__( ): super().__init__() self.attention = T5Attention( - theta=theta, + theta=theta("SelfAttention"), is_decoder=is_decoder, relative_attention_num_buckets=relative_attention_num_buckets, relative_attention_max_distance=relative_attention_max_distance, @@ -371,7 +386,9 @@ def __init__( has_relative_attention_bias=has_relative_attention_bias, ) self.layer_norm = RMSNormLayer( - theta=theta("attn_norm"), epsilon=layer_norm_epsilon, dtype=activation_dtype + theta=theta("layer_norm"), + epsilon=layer_norm_epsilon, + dtype=activation_dtype, ) def forward( @@ -482,7 +499,7 @@ def __init__( self.layer = nn.ModuleList() self.layer.append( T5SelfAttention( - theta=theta, + theta=theta(f"layer.{len(self.layer)}"), is_decoder=is_decoder, relative_attention_num_buckets=relative_attention_num_buckets, relative_attention_max_distance=relative_attention_max_distance, @@ -497,7 +514,7 @@ def __init__( if self.is_decoder: self.layer.append( T5CrossAttention( - theta=theta, + theta=theta(f"layer.{len(self.layer)}"), is_decoder=is_decoder, relative_attention_num_buckets=relative_attention_num_buckets, relative_attention_max_distance=relative_attention_max_distance, @@ -511,7 +528,7 @@ def __init__( self.layer.append( T5LayerFF( - theta=theta, + theta=theta(f"layer.{len(self.layer)}"), is_gated_act=is_gated_act, dense_act_fn=dense_act_fn, layer_norm_epsilon=layer_norm_epsilon, @@ -654,12 +671,11 @@ def __init__(self, theta: Theta, config: T5Config, embed_tokens=None): self.embed_tokens = embed_tokens self.config = config self.is_decoder = config.is_decoder - theta_prefix = "dec" if config.is_decoder else "enc" self.block = torch.nn.ModuleList( [ T5Block( - theta=theta(f"{theta_prefix}.blk.{i}"), + theta=theta(f"block.{i}"), is_decoder=config.is_decoder, relative_attention_num_buckets=config.relative_attention_num_buckets, relative_attention_max_distance=config.relative_attention_max_distance, @@ -678,7 +694,7 @@ def __init__(self, theta: Theta, config: T5Config, embed_tokens=None): self.add_module( "final_layer_norm", RMSNormLayer( - theta(f"{theta_prefix}.output_norm"), + theta(f"final_layer_norm"), epsilon=config.layer_norm_epsilon, dtype=config.activation_dtype, ), @@ -1043,7 +1059,7 @@ def __init__(self, theta: Theta, config: T5Config): self.add_module( "token_embedding", TokenEmbeddingLayer( - theta("token_embd"), dtype=theta("token_embd").tensor("weight").dtype + theta("shared"), dtype=theta("shared").tensor("weight").dtype ), ) @@ -1051,7 +1067,9 @@ def __init__(self, theta: Theta, config: T5Config): encoder_config.use_cache = False encoder_config.is_encoder_decoder = False self.encoder = T5Stack( - theta=theta, config=encoder_config, embed_tokens=self.token_embedding + theta=theta("encoder"), + config=encoder_config, + embed_tokens=self.token_embedding, ) @property diff --git a/sharktank/sharktank/models/vae/model.py b/sharktank/sharktank/models/vae/model.py index 126c05f2c..881923b8b 100644 --- a/sharktank/sharktank/models/vae/model.py +++ b/sharktank/sharktank/models/vae/model.py @@ -37,8 +37,8 @@ def __init__(self, hp, theta: Theta): self.mid_block = self._create_mid_block(theta("decoder")("mid_block")) # up self.up_blocks = nn.ModuleList([]) - self.upscale_dtype = theta("decoder")("up_blocks")(0)("resnets")(0)("conv1")( - "weight" + self.upscale_dtype = unbox_tensor( + theta("decoder")("up_blocks")(0)("resnets")(0)("conv1")("weight") ).dtype for i, up_block_name in enumerate(hp.up_block_types): up_block_theta = theta("decoder")("up_blocks")(i) @@ -74,6 +74,16 @@ def forward( "latent_embeds": latent_embeds, }, ) + if not self.hp.use_post_quant_conv: + sample = rearrange( + sample, + "b (h w) (c ph pw) -> b c (h ph) (w pw)", + h=math.ceil(1024 / 16), + w=math.ceil(1024 / 16), + ph=2, + pw=2, + ) + sample = sample / self.hp.scaling_factor + self.hp.shift_factor if self.hp.use_post_quant_conv: diff --git a/sharktank/sharktank/ops/qlinear_impls.py b/sharktank/sharktank/ops/qlinear_impls.py index f88684273..104349266 100644 --- a/sharktank/sharktank/ops/qlinear_impls.py +++ b/sharktank/sharktank/ops/qlinear_impls.py @@ -44,16 +44,16 @@ def qlinear_tensor_scaled( # Now we know that both the x/weight are TensorScaledLayout. There are still # degrees of freedom: # * Either/both can be per-tensor or per-axis scaled (d is 0D or d is nd>0). - # * Either/both can have offsets of not (m is not None). + # * Either/both can have offsets or not (m is not None). x_layout: TensorScaledLayout = x.unpack() weight_layout: TensorScaledLayout = weight.unpack() # Handle only integer and fp8 quantizations. if x_layout.qs.dtype.is_floating_point or weight_layout.qs.dtype.is_floating_point: - if x_layout.qs.dtype == torch.float8_e4m3fnuz: - # assume quark - return matmul(x_layout.qs, weight_layout.qs, transpose_rhs=True) - else: + if ( + x_layout.qs.dtype != torch.float8_e4m3fnuz + or weight_layout.qs.dtype != torch.float8_e4m3fnuz + ): return NotImplemented # Bias. @@ -160,6 +160,8 @@ def linear_quantized_weight( *, accum_dtype: Optional[torch.dtype], ) -> AnyTensor: + if accum_dtype is not None: + raise NotImplementedError("TODO: implement when accum_dtype is passed") res = matmul(x, weight, transpose_rhs=True) if bias is not None: res = res + bias @@ -170,7 +172,13 @@ def linear_quantized_weight( linear.override(Tensor, QuantizedTensor, AnyTensor)(linear_quantized_weight) -def _invoke_mmt_kernel(lhs, rhs, *, accum_dtype): +def _is_dtype_unsigned_integer(dtype: torch.dtype): + return not dtype.is_complex and not dtype.is_floating_point and not dtype.is_signed + + +def _invoke_mmt_kernel( + lhs: torch.Tensor, rhs: torch.Tensor, *, accum_dtype: torch.dtype +): if debugging.flags.use_custom_iree_kernels: # The custom kernel requires that the lhs and rhs be the same # rank. Broadcast the rhs to match. @@ -187,13 +195,17 @@ def _invoke_mmt_kernel(lhs, rhs, *, accum_dtype): rhs_size = [lhs.shape[0]] + list(rhs.shape) rhs = rhs.unsqueeze(0).expand(rhs_size) rhs_rank = len(rhs.shape) - y_qs = kernels.batch_matmul_transpose_b( - lhs.to(accum_dtype), rhs.to(accum_dtype) - ) - # Squeeze the batch dimension to maintain shape parity with other - # layers. - if len(y_qs.shape) > 2: - y_qs = y_qs.squeeze(0) + if ( + _is_dtype_unsigned_integer(lhs.dtype) + or _is_dtype_unsigned_integer(rhs.dtype) + or _is_dtype_unsigned_integer(accum_dtype) + ): + # TODO: make the kernel work with unsigned types. + y_qs = kernels.batch_matmul_transpose_b( + lhs.to(dtype=accum_dtype), rhs.to(dtype=accum_dtype) + ) + else: + y_qs = kernels.batch_matmul_transpose_b(lhs, rhs, accum_dtype=accum_dtype) else: # FP emulation. y_qs = torch.matmul( diff --git a/sharktank/sharktank/pipelines/flux/export_components.py b/sharktank/sharktank/pipelines/flux/export_components.py index 85c7888de..334dbdadb 100644 --- a/sharktank/sharktank/pipelines/flux/export_components.py +++ b/sharktank/sharktank/pipelines/flux/export_components.py @@ -267,15 +267,7 @@ def __init__(self, weight_file, height=1024, width=1024, precision="fp32"): self.width = width def forward(self, z): - d_in = rearrange( - z, - "b (h w) (c ph pw) -> b c (h ph) (w pw)", - h=math.ceil(self.height / 16), - w=math.ceil(self.width / 16), - ph=2, - pw=2, - ) - return self.ae.forward(d_in) + return self.ae.forward(z) def get_ae_model_and_inputs( diff --git a/sharktank/sharktank/tools/import_hf_dataset.py b/sharktank/sharktank/tools/import_hf_dataset.py index d4e9c43be..deeb741b2 100644 --- a/sharktank/sharktank/tools/import_hf_dataset.py +++ b/sharktank/sharktank/tools/import_hf_dataset.py @@ -52,7 +52,9 @@ def import_hf_dataset( for params_path in param_paths: with safetensors.safe_open(params_path, framework="pt", device="cpu") as st: tensors += [ - DefaultPrimitiveTensor(name=name, data=st.get_tensor(name)) + DefaultPrimitiveTensor( + name=name, data=st.get_tensor(name).to(target_dtype) + ) for name in st.keys() ] diff --git a/sharktank/sharktank/utils/cli.py b/sharktank/sharktank/utils/cli.py index e3dba31fa..36ae89cc6 100644 --- a/sharktank/sharktank/utils/cli.py +++ b/sharktank/sharktank/utils/cli.py @@ -32,10 +32,12 @@ def parse(parser: argparse.ArgumentParser, *, args: Sequence[str] | None = None) """Parses arguments and does any prescribed global process setup.""" parsed_args = parser.parse_args(args) # Set torch dtypes - for attr in ["activation_dtype", "attention_dtype"]: + for attr in ["activation_dtype", "attention_dtype", "kv_cache_dtype"]: if hasattr(parsed_args, attr): - dtype = getattr(torch, getattr(parsed_args, attr)) - assert isinstance(dtype, torch.dtype) + dtype = getattr(parsed_args, attr) + if dtype is not None: + dtype = getattr(torch, dtype) + assert isinstance(dtype, torch.dtype) setattr(parsed_args, attr, dtype) return parsed_args @@ -100,6 +102,11 @@ def add_model_options(parser: argparse.ArgumentParser): help="DType to use for activations in the model", default="float16", ) + parser.add_argument( + "--kv-cache-dtype", + help="DType to use for the KV cache. If not given will be attention dtype", + default=None, + ) parser.add_argument("--device", help="Torch device (or default)") parser.add_argument( diff --git a/sharktank/sharktank/utils/create_cache.py b/sharktank/sharktank/utils/create_cache.py index f462d9c00..eb26f5a14 100644 --- a/sharktank/sharktank/utils/create_cache.py +++ b/sharktank/sharktank/utils/create_cache.py @@ -12,6 +12,7 @@ def create_paged_kv_cache(config: LlamaModelConfig) -> PagedKVCache: raise ValueError("Model does not use paged kv cache, cannot create kv cache") hp = config.hp + dtype = config.kv_cache_dtype or config.attention_dtype return PagedKVCache( transformer_block_count=hp.block_count, attn_head_count=hp.attention_head_count_kv, @@ -19,6 +20,6 @@ def create_paged_kv_cache(config: LlamaModelConfig) -> PagedKVCache: cache_partition_count=2, # One for each of K/V. block_seq_stride=config.block_seq_stride, device=config.device, - dtype=config.attention_dtype, + dtype=dtype, shard_count=config.tensor_parallelism_size, ) diff --git a/sharktank/sharktank/utils/export_artifacts.py b/sharktank/sharktank/utils/export_artifacts.py index 75071e286..608f65c48 100644 --- a/sharktank/sharktank/utils/export_artifacts.py +++ b/sharktank/sharktank/utils/export_artifacts.py @@ -95,6 +95,7 @@ def __init__( use_attention_mask: bool = False, activation_dtype: str = "float16", attention_dtype: str = "float16", + kv_cache_dtype: Optional[str] = None, ): self.sharktank_dir = str( Path(os.path.dirname(os.path.abspath(__file__))).parent.parent.parent @@ -109,6 +110,7 @@ def __init__( self.use_attention_mask = use_attention_mask self.activation_dtype = activation_dtype self.attention_dtype = attention_dtype + self.kv_cache_dtype = kv_cache_dtype def timeit(func): def wrapper(*args, **kwargs): @@ -189,6 +191,8 @@ def export_to_mlir( f"--attention-dtype={self.attention_dtype}", f"--activation-dtype={self.activation_dtype}", ] + if self.kv_cache_dtype is not None: + export_args.append(f"--kv-cache-dtype={self.kv_cache_dtype}") if skip_decode: export_args.append("--skip-decode") if self.attention_kernel in ["decomposed", "torch"]: diff --git a/sharktank/sharktank/utils/iree.py b/sharktank/sharktank/utils/iree.py index c38976e01..6ae0a6a88 100644 --- a/sharktank/sharktank/utils/iree.py +++ b/sharktank/sharktank/utils/iree.py @@ -128,7 +128,7 @@ def torch_tensor_to_device_array( if tensor.dtype == torch.bfloat16: tensor_as_int16 = tensor.view(dtype=torch.int16) device_array_as_int16 = iree.runtime.asdevicearray( - device, unbox_tensor(tensor_as_int16).to("cpu").numpy() + device, unbox_tensor(tensor_as_int16).to("cpu").detach().numpy() ) buffer_view = iree.runtime.HalBufferView( buffer=device_array_as_int16._buffer_view.get_buffer(), @@ -137,7 +137,9 @@ def torch_tensor_to_device_array( ) return iree.runtime.DeviceArray(device, buffer_view) - return iree.runtime.asdevicearray(device, unbox_tensor(tensor).to("cpu").numpy()) + return iree.runtime.asdevicearray( + device, unbox_tensor(tensor).to("cpu").detach().numpy() + ) def run_iree_module_function( @@ -162,7 +164,7 @@ def run_iree_module_function( for i, arg in enumerate(args): np.save( f"{trace_path_prefix}{function_name}_arg{i}.npy", - promote_bfloat16_to_float32(device_array_to_host(arg)).numpy(), + promote_bfloat16_to_float32(device_array_to_host(arg)).detach().numpy(), ) results = invoker(*args) if isinstance(results, iree.runtime.DeviceArray): @@ -172,12 +174,12 @@ def run_iree_module_function( for i, arg in enumerate(args): np.save( f"{trace_path_prefix}{function_name}_arg{i}_post_call.npy", - device_array_to_host(arg).numpy(), + device_array_to_host(arg).detach().numpy(), ) for i, arg in enumerate(results): np.save( f"{trace_path_prefix}{function_name}_result{i}.npy", - promote_bfloat16_to_float32(device_array_to_host(arg)).numpy(), + promote_bfloat16_to_float32(device_array_to_host(arg)).detach().numpy(), ) return results @@ -233,7 +235,7 @@ def call_torch_module_function( for i, arg in enumerate(flat_args): np.save( f"{trace_path_prefix}{function_name}_arg{i}.npy", - promote_bfloat16_to_float32(arg.to("cpu")).numpy(), + promote_bfloat16_to_float32(arg.to("cpu")).detach().numpy(), ) res = getattr(module, function_name)(*args, **kwargs) if trace_path_prefix is not None: @@ -241,7 +243,7 @@ def call_torch_module_function( for i, arg in enumerate(flat_args): np.save( f"{trace_path_prefix}{function_name}_arg{i}.npy", - promote_bfloat16_to_float32(arg.to("cpu")).numpy(), + promote_bfloat16_to_float32(arg.to("cpu")).detach().numpy(), ) results = ( (res,) @@ -258,7 +260,7 @@ def call_torch_module_function( for i, result in enumerate(flat_results): np.save( f"{trace_path_prefix}{function_name}_result{i}.npy", - result.to("cpu").numpy(), + result.to("cpu").detach().numpy(), ) return res diff --git a/sharktank/tests/kernels/batch_matmul_transpose_b_test.py b/sharktank/tests/kernels/batch_matmul_transpose_b_test.py index 208d54782..3ac260265 100644 --- a/sharktank/tests/kernels/batch_matmul_transpose_b_test.py +++ b/sharktank/tests/kernels/batch_matmul_transpose_b_test.py @@ -10,11 +10,13 @@ import unittest from parameterized import parameterized - +import pytest import torch from iree.turbine import aot +from iree.turbine.support.conversions import TORCH_DTYPE_TO_IREE_TYPE_ASM from sharktank import kernels +from sharktank.utils.testing import skip class batch_matmul_transpose_b_test(unittest.TestCase): @@ -40,24 +42,111 @@ def testBS32(self, atol, rtol): ref = torch.matmul(a, bT) torch.testing.assert_close(result, ref, atol=atol, rtol=rtol) - def testExportStaticDims(self): + def testArgF8AccumF32(self): + # TODO: make this test not use eager but actually execute with IREE. + # Does not compile for llvm-cpu with + # :0: error: 'llvm.fpext' op operand #0 must be floating point LLVM type or LLVM dialect-compatible vector of floating point LLVM type, but got 'vector<4xi8>' + # :0: note: see current operation: %120 = "llvm.fpext"(%109) : (vector<4xi8>) -> vector<4xf32> + arg_dtype = torch.float8_e4m3fnuz + a = torch.rand([3, 4, 6]).to(arg_dtype) + b = torch.rand([3, 5, 6]).to(arg_dtype) + accum_dtype = torch.float32 + result = kernels.batch_matmul_transpose_b(a, b, accum_dtype=accum_dtype) + + # Dequantize and test with normal matmul. + # Tolerances are empirical and results are not expected to match exactly. + bT = torch.transpose(b, 1, 2) + ref = torch.matmul(a.to(dtype=accum_dtype), bT.to(dtype=accum_dtype)) + torch.testing.assert_close(result, ref, atol=1e-3, rtol=0) + + def testArgUi8AccumI32(self): + # TODO: make this test not use eager but actually execute with IREE. + # Does not work with unsigned types. The kernel needs to be adapted. + arg_dtype = torch.uint8 + a = ((torch.rand([2, 3, 5]) * 255) + 0.5).to(dtype=arg_dtype) + b = ((torch.rand([2, 4, 5]) * 255) + 0.5).to(dtype=arg_dtype) + accum_dtype = torch.int32 + result = kernels.batch_matmul_transpose_b(a, b, accum_dtype=accum_dtype) + + bT = torch.transpose(b, 1, 2) + ref = torch.matmul(a.to(dtype=accum_dtype), bT.to(dtype=accum_dtype)) + torch.testing.assert_close(result, ref, atol=0, rtol=0) + + def testArgLhsI8RhsUi8AccumI32(self): + a = ((torch.rand([2, 3, 5]) - 0.5) * 255).to(dtype=torch.int8) + b = ((torch.rand([2, 4, 5]) * 255) + 0.5).to(dtype=torch.uint8) + accum_dtype = torch.int32 + result = kernels.batch_matmul_transpose_b(a, b, accum_dtype=accum_dtype) + + bT = torch.transpose(b, 1, 2) + ref = torch.matmul(a.to(dtype=accum_dtype), bT.to(dtype=accum_dtype)) + torch.testing.assert_close(result, ref, atol=0, rtol=0) + + def testArgI8AccumI32(self): + arg_dtype = torch.int8 + a = ((torch.rand([2, 3, 5]) - 0.5) * 255).to(dtype=arg_dtype) + b = ((torch.rand([2, 3, 5]) - 0.5) * 255).to(dtype=arg_dtype) + accum_dtype = torch.int32 + result = kernels.batch_matmul_transpose_b(a, b, accum_dtype=accum_dtype) + + bT = torch.transpose(b, 1, 2) + ref = torch.matmul(a.to(dtype=accum_dtype), bT.to(dtype=accum_dtype)) + torch.testing.assert_close(result, ref, atol=0, rtol=0) + + @pytest.mark.xfail( + reason="""No uint32 dtype conversions in IREE Turbine. + Does not work with unsigned types. The kernel needs to be adapted. + The problem is that we reinterpret cast to signless integer types. + Maybe linalg.batch_matmul_transpose_b when promoting from i8 to i32 assumes a + signed type even though i8 is signless.""" + ) + def testArgUi8AccumUi32(self): + arg_dtype = torch.uint8 + a = ((torch.rand([2, 3, 5]) * 255) + 0.5).to(dtype=arg_dtype) + b = ((torch.rand([2, 4, 5]) * 255) + 0.5).to(dtype=arg_dtype) + accum_dtype = torch.uint32 + result = kernels.batch_matmul_transpose_b(a, b, accum_dtype=accum_dtype) + + bT = torch.transpose(b, 1, 2) + ref = torch.matmul(a.to(dtype=torch.int32), bT.to(dtype=torch.int32)) + ref = ref.to(dtype=accum_dtype) + torch.testing.assert_close(result, ref, atol=0, rtol=0) + + @parameterized.expand( + [ + (torch.int32, None), + (torch.float8_e4m3fnuz, torch.float32), + ] + ) + def testExportStaticDims( + self, arg_dtype: torch.dtype, accum_dtype: torch.dtype | None + ): class MyModule(torch.nn.Module): def forward(self, a, b): - return kernels.batch_matmul_transpose_b(a, b) + return kernels.batch_matmul_transpose_b(a, b, accum_dtype=accum_dtype) mod = MyModule() - dtype = torch.int32 ep = torch.export.export( mod, args=( - (torch.rand([4, 16, 2]) * 64).to(dtype), - (torch.rand([4, 8, 2]) * 64).to(dtype), + (torch.rand([4, 16, 2]) * 64).to(arg_dtype), + (torch.rand([4, 8, 2]) * 64).to(arg_dtype), ), ) output = aot.export(ep) output.verify() asm = str(output.mlir_module) - self.assertIn("@sharktank_batch_matmul_transpose_b_L4x16x2xi32_R4x8x2xi32", asm) + arg_dtype_asm = TORCH_DTYPE_TO_IREE_TYPE_ASM[arg_dtype] + accum_dtype_asm = arg_dtype_asm + if accum_dtype is not None: + accum_dtype_asm = TORCH_DTYPE_TO_IREE_TYPE_ASM[accum_dtype] + self.assertIn( + ( + "@sharktank_batch_matmul_transpose_b_" + f"L4x16x2x{arg_dtype_asm}_R4x8x2x{arg_dtype_asm}_{accum_dtype_asm}" + ), + asm, + ) if __name__ == "__main__": diff --git a/sharktank/tests/layers/kv_cache_test.py b/sharktank/tests/layers/kv_cache_test.py index d59d0a85b..ff60ef8b6 100644 --- a/sharktank/tests/layers/kv_cache_test.py +++ b/sharktank/tests/layers/kv_cache_test.py @@ -4,8 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -import unittest - +import pytest import torch from sharktank.ops import replicate, reshard_split, unshard @@ -13,7 +12,16 @@ from sharktank.types import * -def test_paged(): +@pytest.mark.parametrize( + "dtype", + [ + torch.float8_e4m3fnuz, + torch.bfloat16, + torch.float16, + torch.float32, + ], +) +def test_paged(dtype: torch.dtype): bs = 4 seq_length = 24 attn_head_count = 4 @@ -25,7 +33,7 @@ def test_paged(): transformer_block_count=transformer_block_count, attn_head_count=attn_head_count, attn_head_dim=attn_head_dim, - dtype=torch.float32, + dtype=dtype, device=None, ) @@ -36,15 +44,18 @@ def test_paged(): write_page_ids = page_ids[:, : write_seq_length // block_seq_stride] allocation = cache.allocate(page_count=page_count) - allocation = [torch.full(t.shape, 0.0, out=t) for t in allocation] + for t in allocation: + t[...] = torch.full(t.shape, 0.0).to(dtype=dtype) # Write a prefill in: write_ones = torch.full( - (bs, write_seq_length, attn_head_count, attn_head_dim), 1.0, dtype=torch.float32 - ) + (bs, write_seq_length, attn_head_count, attn_head_dim), + 1.0, + ).to(dtype=dtype) write_twos = torch.full( - (bs, write_seq_length, attn_head_count, attn_head_dim), 2.0, dtype=torch.float32 - ) + (bs, write_seq_length, attn_head_count, attn_head_dim), + 2.0, + ).to(dtype=dtype) cache.write( allocation, @@ -72,15 +83,19 @@ def test_paged(): seq_len=write_seq_length, page_ids=write_page_ids, ) - torch.testing.assert_close(read_ones[0], torch.full(read_ones[0].shape, 0.0)) - torch.testing.assert_close(read_ones[1], torch.full(read_ones[0].shape, 0.0)) + torch.testing.assert_close( + read_ones[0], torch.full(read_ones[0].shape, 0.0).to(dtype=dtype) + ) + torch.testing.assert_close( + read_ones[1], torch.full(read_ones[0].shape, 0.0).to(dtype=dtype) + ) # Write timestep - write_threes = torch.full( - (bs, 1, attn_head_count, attn_head_dim), 3.0, dtype=torch.float32 + write_threes = torch.full((bs, 1, attn_head_count, attn_head_dim), 3.0).to( + dtype=dtype ) - write_fours = torch.full( - (bs, 1, attn_head_count, attn_head_dim), 4.0, dtype=torch.float32 + write_fours = torch.full((bs, 1, attn_head_count, attn_head_dim), 4.0).to( + dtype=dtype ) write_pos = torch.full((bs,), write_seq_length, dtype=torch.int64) cache.write_timestep( @@ -98,8 +113,16 @@ def test_paged(): page_ids=page_ids, ) - check_concat_0 = torch.concat([write_ones, write_threes], dim=1) - check_concat_1 = torch.concat([write_twos, write_fours], dim=1) + if dtype == torch.float8_e4m3fnuz: + check_concat_0 = torch.concat( + [write_ones.view(torch.int8), write_threes.view(torch.int8)], dim=1 + ).view(torch.float8_e4m3fnuz) + check_concat_1 = torch.concat( + [write_twos.view(torch.int8), write_fours.view(torch.int8)], dim=1 + ).view(torch.float8_e4m3fnuz) + else: + check_concat_0 = torch.concat([write_ones, write_threes], dim=1) + check_concat_1 = torch.concat([write_twos, write_fours], dim=1) torch.testing.assert_close(check_concat_0, read_back[0]) torch.testing.assert_close(check_concat_1, read_back[1]) diff --git a/sharktank/tests/layers/linear_test.py b/sharktank/tests/layers/linear_test.py index ad657889d..08164176f 100644 --- a/sharktank/tests/layers/linear_test.py +++ b/sharktank/tests/layers/linear_test.py @@ -4,12 +4,16 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import logging import unittest - import torch +from parameterized import parameterized from sharktank.layers import * from sharktank.types import * +from sharktank.utils.testing import make_rand_torch + +logger = logging.getLogger(__name__) def _randomize_per_axis(t: torch.Tensor, axis: int, offset_range: float = 0.0): @@ -91,6 +95,102 @@ def testNativeQuant_SymPerTensor_AsymPerAxis0_Dynamic(self): print(torch.abs(output - output_ref)) torch.testing.assert_close(output, output_ref, atol=1e-1, rtol=1e-1) + @parameterized.expand( + [ + (torch.bfloat16, torch.float32, torch.float8_e4m3fnuz, False, False, 1e-2), + (torch.bfloat16, torch.float32, torch.float8_e4m3fnuz, False, True, 1e-2), + (torch.float32, torch.float32, torch.float8_e4m3fnuz, False, False, 1e-6), + (torch.float32, torch.float32, torch.float8_e4m3fnuz, False, True, 1e-6), + (torch.float32, torch.float32, torch.float16, True, False, 1e-6), + (torch.float32, torch.float32, torch.float16, False, False, 1e-6), + (torch.float32, torch.float32, torch.float16, False, True, 1e-6), + (torch.float32, torch.float32, torch.float32, True, False, 1e-6), + ], + ) + def testPerTensorScale( + self, + dequantized_dtype: torch.dtype, + quantized_scale_dtype: torch.dtype, + quantized_dtype: torch.dtype, + with_bias: bool, + fake_quant: bool, + atol: float, + ): + """Test a linear layer where each tensor being quantized with a single + different scale.""" + ref_dtype = torch.float64 + + x = make_rand_torch([10, 8, 8], dtype=dequantized_dtype) + input_scale = torch.tensor(0.5, dtype=quantized_scale_dtype) + input_quantizer = StaticScaledQuantizer( + name="q_input", scale=input_scale, dtype=quantized_dtype + ) + # We roundtrip through quantization to know that any discrepancies in the + # results come from the quantized linear operation itself. Not form the + # inaccuracies of the initial quantization. + x_dequantized = input_quantizer.quantize(x).unpack().dequant() + torch.testing.assert_close( + input_quantizer.quantize(x_dequantized).unpack().dequant(), + x_dequantized, + atol=0, + rtol=0, + ) + + weight = make_rand_torch([12, x.shape[2]], dtype=dequantized_dtype) + weight_scale = torch.tensor(0.66, dtype=quantized_scale_dtype) + weight_quantizer = StaticScaledQuantizer( + scale=weight_scale, dtype=quantized_dtype + ) + weight_dequantized = weight_quantizer.quantize(weight).unpack().dequant() + weight_quantized = weight_quantizer.quantize(weight_dequantized, name="weight") + torch.testing.assert_close( + weight_quantizer.quantize(weight_dequantized).unpack().dequant(), + weight_dequantized, + atol=0, + rtol=0, + ) + + if with_bias: + bias = make_rand_torch( + [x.shape[1], weight.shape[0]], dtype=dequantized_dtype + ) + bias_scale = torch.tensor(1.25, dtype=quantized_scale_dtype) + bias_quantizer = StaticScaledQuantizer( + scale=bias_scale, dtype=quantized_dtype + ) + bias_dequantized = bias_quantizer.quantize(bias).unpack().dequant() + bias_quantized = bias_quantizer.quantize(bias_dequantized, name="bias") + torch.testing.assert_close( + bias_quantizer.quantize(bias_dequantized).unpack().dequant(), + bias_dequantized, + atol=0, + rtol=0, + ) + + expected = torch.matmul( + x_dequantized.to(ref_dtype), weight_dequantized.T.to(ref_dtype) + ) + if with_bias: + expected += bias_dequantized.to(ref_dtype) + + theta_tensors = [ + input_quantizer, + weight_quantized, + ] + if with_bias: + theta_tensors += [bias_quantized] + theta = Theta(theta_tensors) + linear = LinearLayer(theta, fake_quant=fake_quant) + actual = linear(x_dequantized) + actual = actual.to(dtype=expected.dtype) + + abs_diff = (expected - actual).abs() + logger.info( + f"abs diff from expected (std, mean, median) = {[float(abs_diff.std()), float(abs_diff.mean()), float(abs_diff.median())]}" + ) + + torch.testing.assert_close(actual, expected, atol=atol, rtol=0) + if __name__ == "__main__": unittest.main() diff --git a/sharktank/tests/models/llama/benchmark_amdgpu_test.py b/sharktank/tests/models/llama/benchmark_amdgpu_test.py index 6056f9c6c..39d6f4ce8 100644 --- a/sharktank/tests/models/llama/benchmark_amdgpu_test.py +++ b/sharktank/tests/models/llama/benchmark_amdgpu_test.py @@ -94,7 +94,8 @@ def setUp(self): tensor_parallelism_size=self.tensor_parallelism_size, block_seq_stride=32, activation_dtype="bfloat16", - attention_dtype="float8_e4m3fnuz", + attention_dtype="bfloat16", + kv_cache_dtype="float8_e4m3fnuz", ) self.prefill_args_bs4_128_stride_32_f16 = ( self.artifacts_dir / "prefill_args_bs4_128_stride_32_tp1" @@ -242,9 +243,9 @@ def testBenchmark8B_f16_Non_Decomposed_Input_Len_2048(self): ) @pytest.mark.xfail( - reason="Benchmark inputs not configured yet.", + reason="Fails due to https://github.com/iree-org/iree/issues/20002.", strict=True, - raises=IreeBenchmarkException, + raises=IreeCompileException, ) def testBenchmark8B_fp8_Non_Decomposed(self): output_file_name = self.dir_path_8b / "fp8_torch" diff --git a/sharktank/tests/models/llama/quark_parity_test.py b/sharktank/tests/models/llama/quark_parity_test.py index b45696fe4..b8a30c543 100644 --- a/sharktank/tests/models/llama/quark_parity_test.py +++ b/sharktank/tests/models/llama/quark_parity_test.py @@ -11,11 +11,12 @@ import pytest from pathlib import Path import subprocess +from sharktank.utils.testing import TempDirTestBase with_quark_data = pytest.mark.skipif("not config.getoption('with_quark_data')") -class QuarkParityTest(unittest.TestCase): +class QuarkParityTest(TempDirTestBase): def setUp(self): super().setUp() self.path_prefix = Path("/shark-dev/quark_test") @@ -25,7 +26,7 @@ def test_compare_against_quark(self): sharktank_dir = str( Path(os.path.dirname(os.path.abspath(__file__))).parent.parent.parent.parent ) - our_path = self.path_prefix / "ours_prefill.safetensors" + our_path = self._temp_dir / "ours_prefill.safetensors" if os.path.exists(our_path): os.remove(our_path) @@ -58,16 +59,15 @@ def test_compare_against_quark(self): "--fake-quant", "--attention-kernel=torch", "--activation-dtype=bfloat16", - f"--save_intermediates_path={self.path_prefix}/ours", + f"--save_intermediates_path={self._temp_dir / 'ours'}", "--use-hf", "--attention-dtype=bfloat16", + "--kv-cache-dtype=float8_e4m3fnuz", "--skip-decode", "--block-seq-stride=16", ] command = subprocess.list2cmdline(command) - proc = subprocess.run( - command, shell=True, capture_output=True, cwd=sharktank_dir - ) + subprocess.check_call(command, shell=True, cwd=sharktank_dir) ours = dict() with safe_open(our_path, "pytorch") as st: diff --git a/sharktank/tests/models/t5/t5_test.py b/sharktank/tests/models/t5/t5_test.py index 629354a15..d15f951c6 100644 --- a/sharktank/tests/models/t5/t5_test.py +++ b/sharktank/tests/models/t5/t5_test.py @@ -5,6 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import functools +from copy import copy from transformers.models.t5.modeling_t5 import ( T5Attention as ReferenceT5Attention, T5LayerSelfAttention as ReferenceT5LayerSelfAttention, @@ -15,13 +16,14 @@ T5EncoderModel as ReferenceT5EncoderModel, T5Config as ReferenceT5Config, ) +from transformers.models.auto.tokenization_auto import get_tokenizer_config from typing import Optional import os from collections import OrderedDict import logging import pytest import torch -from torch.utils._pytree import tree_map, tree_unflatten, tree_flatten_with_path +from torch.utils._pytree import tree_map from unittest import TestCase from parameterized import parameterized from sharktank.types import ( @@ -29,7 +31,6 @@ DefaultPrimitiveTensor, unbox_tensor, Dataset, - dtype_to_serialized_short_name, ) from sharktank.models.t5 import ( T5Attention, @@ -38,7 +39,7 @@ T5Encoder, T5LayerFF, export_encoder_mlir, - export_encoder_iree_parameters, + import_encoder_dataset_from_hugging_face, ) from sharktank.utils.testing import ( assert_text_encoder_state_close, @@ -168,17 +169,12 @@ def runTestV1_1CompareTorchEagerAgainstHuggingFace( ) reference_model.eval() - target_model_name = ( - f"{huggingface_repo_id.replace('/', '__').replace('-', '_')}_f32_model" - ) - target_model_path = getattr(self, target_model_name) - dataset = Dataset.load(target_model_path) + dataset = import_encoder_dataset_from_hugging_face(huggingface_repo_id) dataset.root_theta = dataset.root_theta.transform( functools.partial(set_float_dtype, dtype=target_dtype) ) - config = T5Config.from_gguf_properties( + config = T5Config.from_properties( dataset.properties, - feed_forward_proj="gated-gelu", ) input_ids = tokenizer( @@ -258,6 +254,7 @@ def testV1_1SmallBf16CompareTorchEagerAgainstHuggingFaceF32(self): "google/t5-v1_1-small", reference_dtype=torch.float32, target_dtype=torch.bfloat16, + # The observed error is 0.055. atol=1e-1, ) @@ -285,6 +282,7 @@ def testV1_1XxlBf16CompareTorchEagerAgainstHuggingFaceF32(self): "google/t5-v1_1-xxl", reference_dtype=torch.float32, target_dtype=torch.bfloat16, + # The observed error is 0.026. atol=5e-2, ) @@ -310,19 +308,20 @@ def runTestV1_1CompareIreeAgainstTorchEager( ).download() tokenizer = AutoTokenizer.from_pretrained(huggingface_repo_id) - huggingface_repo_id_as_path = ( - f"{huggingface_repo_id.replace('/', '__').replace('-', '_')}" + reference_dataset = import_encoder_dataset_from_hugging_face( + huggingface_repo_id ) - source_model_name = f"{huggingface_repo_id_as_path}_f32_model" - source_model_path = getattr(self, source_model_name) + target_dataset = copy(reference_dataset) - reference_dataset = Dataset.load(source_model_path) reference_dataset.root_theta = reference_dataset.root_theta.transform( functools.partial(set_float_dtype, dtype=reference_dtype) ) - config = T5Config.from_gguf_properties( + config = T5Config.from_properties( reference_dataset.properties, - feed_forward_proj="gated-gelu", + ) + + target_dataset.root_theta = target_dataset.root_theta.transform( + functools.partial(set_float_dtype, dtype=target_dtype) ) input_ids = tokenizer( @@ -334,32 +333,26 @@ def runTestV1_1CompareIreeAgainstTorchEager( input_args = OrderedDict([("input_ids", input_ids)]) batch_size = input_ids.shape[0] - reference_dtype_name = dtype_to_serialized_short_name(reference_dtype) - target_dtype_name = dtype_to_serialized_short_name(target_dtype) - target_model_path_prefix = f"{self.path_prefix}{huggingface_repo_id_as_path}_encoder_{target_dtype_name}" - reference_model = T5Encoder(theta=reference_dataset.root_theta, config=config) reference_result_dict = call_torch_module_function( module=reference_model, function_name="forward", kwargs=input_args, - trace_path_prefix=f"{self.path_prefix}{huggingface_repo_id_as_path}_encoder_{reference_dtype_name}_torch_", + trace_path_prefix=f"{self.path_prefix}torch_", ) reference_result = flatten_for_iree_signature(reference_result_dict) - parameters_path = f"{target_model_path_prefix}.irpa" + parameters_path = f"{self.path_prefix}parameters.irpa" if not self.caching or not os.path.exists(parameters_path): - export_encoder_iree_parameters( - source_model_path, parameters_path, dtype=target_dtype - ) + target_dataset.save(parameters_path) - mlir_path = f"{target_model_path_prefix}.mlir" + mlir_path = f"{self.path_prefix}model.mlir" if not self.caching or not os.path.exists(mlir_path): logger.info("Exporting T5 encoder model to MLIR...") export_encoder_mlir( parameters_path, batch_sizes=[batch_size], mlir_output_path=mlir_path ) - iree_module_path = f"{target_model_path_prefix}.vmfb" + iree_module_path = f"{self.path_prefix}model.vmfb" if not self.caching or not os.path.exists(iree_module_path): logger.info("Compiling MLIR file...") iree.compiler.compile_file( @@ -386,7 +379,7 @@ def runTestV1_1CompareIreeAgainstTorchEager( args=iree_args, device=iree_devices[0], function_name=f"forward_bs{batch_size}", - trace_path_prefix=f"{target_model_path_prefix}_iree_", + trace_path_prefix=f"{self.path_prefix}iree_", ) ) iree_result = [ @@ -498,19 +491,19 @@ def testCompareAgainstTransformers( theta = Theta( { - "attn_q.weight": DefaultPrimitiveTensor( + "q.weight": DefaultPrimitiveTensor( data=reference_model.q.weight.to(dtype=target_dtype) ), - "attn_k.weight": DefaultPrimitiveTensor( + "k.weight": DefaultPrimitiveTensor( data=reference_model.k.weight.to(dtype=target_dtype) ), - "attn_v.weight": DefaultPrimitiveTensor( + "v.weight": DefaultPrimitiveTensor( data=reference_model.v.weight.to(dtype=target_dtype) ), - "attn_o.weight": DefaultPrimitiveTensor( + "o.weight": DefaultPrimitiveTensor( data=reference_model.o.weight.to(dtype=target_dtype) ), - "attn_rel_b.weight": DefaultPrimitiveTensor( + "relative_attention_bias.weight": DefaultPrimitiveTensor( data=reference_model.relative_attention_bias.weight.to( dtype=target_dtype ) @@ -593,24 +586,24 @@ def testCompareSelfAttentionAgainstTransformers( theta = Theta( { - "attn_q.weight": DefaultPrimitiveTensor( + "SelfAttention.q.weight": DefaultPrimitiveTensor( data=reference_model.SelfAttention.q.weight.to(dtype=target_dtype) ), - "attn_k.weight": DefaultPrimitiveTensor( + "SelfAttention.k.weight": DefaultPrimitiveTensor( data=reference_model.SelfAttention.k.weight.to(dtype=target_dtype) ), - "attn_v.weight": DefaultPrimitiveTensor( + "SelfAttention.v.weight": DefaultPrimitiveTensor( data=reference_model.SelfAttention.v.weight.to(dtype=target_dtype) ), - "attn_o.weight": DefaultPrimitiveTensor( + "SelfAttention.o.weight": DefaultPrimitiveTensor( data=reference_model.SelfAttention.o.weight.to(dtype=target_dtype) ), - "attn_rel_b.weight": DefaultPrimitiveTensor( + "SelfAttention.relative_attention_bias.weight": DefaultPrimitiveTensor( data=reference_model.SelfAttention.relative_attention_bias.weight.to( dtype=target_dtype ) ), - "attn_norm.weight": DefaultPrimitiveTensor( + "layer_norm.weight": DefaultPrimitiveTensor( data=reference_model.layer_norm.weight.to(dtype=target_dtype) ), } @@ -708,20 +701,20 @@ def testCompareAgainstTransformers( theta = Theta( { - "ffn_gate.weight": DefaultPrimitiveTensor( + "DenseReluDense.wi_0.weight": DefaultPrimitiveTensor( data=reference_model.DenseReluDense.wi_0.weight.to( dtype=target_dtype ) ), - "ffn_up.weight": DefaultPrimitiveTensor( + "DenseReluDense.wi_1.weight": DefaultPrimitiveTensor( data=reference_model.DenseReluDense.wi_1.weight.to( dtype=target_dtype ) ), - "ffn_down.weight": DefaultPrimitiveTensor( + "DenseReluDense.wo.weight": DefaultPrimitiveTensor( data=reference_model.DenseReluDense.wo.weight.to(dtype=target_dtype) ), - "ffn_norm.weight": DefaultPrimitiveTensor( + "layer_norm.weight": DefaultPrimitiveTensor( data=reference_model.layer_norm.weight.to(dtype=target_dtype) ), } diff --git a/sharktank/tests/models/vae/vae_test.py b/sharktank/tests/models/vae/vae_test.py index 1e0a1a8f3..43a989f2f 100644 --- a/sharktank/tests/models/vae/vae_test.py +++ b/sharktank/tests/models/vae/vae_test.py @@ -277,7 +277,7 @@ def testVaeIreeVsHuggingFace(self): model_f32 = VaeDecoderModel.from_dataset(ds_f32).to(device="cpu") # TODO: Decomposing attention due to https://github.com/iree-org/iree/issues/19286, remove once issue is resolved - module = export_vae(model, inputs, True) + module = export_vae(model, inputs.to(dtype=dtype), True) module_f32 = export_vae(model_f32, inputs, True) module.save_mlir("{self._temp_dir}/flux_vae_bf16.mlir") @@ -317,7 +317,7 @@ def testVaeIreeVsHuggingFace(self): parameters_path="{self._temp_dir}/flux_vae_bf16.irpa", ) - input_args = OrderedDict([("inputs", inputs)]) + input_args = OrderedDict([("inputs", inputs.to(dtype=dtype))]) iree_args = flatten_for_iree_signature(input_args) iree_args = prepare_iree_module_function_args(