diff --git a/sharktank/sharktank/layers/configs/llm_configs.py b/sharktank/sharktank/layers/configs/llm_configs.py index fd8fd6d12..7d1e506a0 100644 --- a/sharktank/sharktank/layers/configs/llm_configs.py +++ b/sharktank/sharktank/layers/configs/llm_configs.py @@ -14,12 +14,14 @@ (and indeed, can bootstrap these off of GGUF files). """ -from dataclasses import asdict, dataclass, field +from dataclasses import asdict, dataclass, field, fields from typing import Any, Optional import torch +from transformers import T5Config as T5ConfigHf from ...types.tensors import serialized_name_to_dtype, dtype_to_serialized_name + __all__ = ["ClipTextConfig", "LlamaHParams", "LlamaModelConfig", "T5Config"] @@ -231,53 +233,41 @@ def __post_init__(self): self.dense_act_fn = "gelu_new" @staticmethod - def from_gguf_properties(properties: dict[str, Any], **kwargs): - assert properties["general.architecture"] == "t5" - assert ( - properties["t5.attention.layer_norm_epsilon"] - == properties["t5.attention.layer_norm_rms_epsilon"] - ) - - all_kwargs = {"vocab_size": None, "feed_forward_proj": None} - - gguf_to_config_names_map = { - "t5.context_length": ["context_length"], - "t5.embedding_length": ["d_model"], - "t5.feed_forward_length": ["d_ff"], - "t5.block_count": ["num_layers", "num_decoder_layers"], - "t5.attention.head_count": ["num_heads"], - "t5.attention.key_length": ["d_kv"], - "t5.attention.layer_norm_epsilon": ["layer_norm_epsilon"], - "t5.attention.relative_buckets_count": ["relative_attention_num_buckets"], - "tokenizer.ggml.eos_token_id": ["eos_token_id"], - "tokenizer.ggml.padding_token_id": ["pad_token_id"], - } - all_kwargs.update( - { - config_name: properties[gguf_name] - for gguf_name, config_names in gguf_to_config_names_map.items() - for config_name in config_names - } - ) - - gguf_to_optional_config_names_map = { - "t5.decoder_start_token_id": ["decoder_start_token_id"], - } - all_kwargs.update( - { - config_name: properties[gguf_name] - for gguf_name, config_names in gguf_to_optional_config_names_map.items() - for config_name in config_names - if gguf_name in properties - } - ) - - if "tokenizer.ggml.tokens" in properties: - all_kwargs["vocab_size"] = len(properties["tokenizer.ggml.tokens"]) + def from_hugging_face_config( + config: T5ConfigHf, tokenizer_config: dict[str, Any], **kwargs + ) -> "T5Config": + all_kwargs = {} + for filed in fields(T5Config): + if hasattr(config, filed.name): + all_kwargs[filed.name] = getattr(config, filed.name) + all_kwargs["context_length"] = tokenizer_config["model_max_length"] + del all_kwargs["is_gated_act"] + del all_kwargs["dense_act_fn"] all_kwargs.update(kwargs) - return T5Config(**all_kwargs) + @staticmethod + def from_properties(properties: dict[str, Any]) -> "T5Config": + kwargs = dict(properties) + if "SHARK_DATASET_VERSION" in kwargs: + kwargs.pop("SHARK_DATASET_VERSION") + if "activation_dtype" in kwargs and kwargs["activation_dtype"] is not None: + kwargs["activation_dtype"] = serialized_name_to_dtype( + kwargs["activation_dtype"] + ) + if "is_gated_act" in kwargs: + kwargs.pop("is_gated_act") + if "dense_act_fn" in kwargs: + kwargs.pop("dense_act_fn") + + return T5Config(**kwargs) + + def to_properties(self) -> dict[str, Any]: + res = asdict(self) + if self.activation_dtype is not None: + res["activation_dtype"] = dtype_to_serialized_name(self.activation_dtype) + return res + @dataclass class ClipTextConfig: @@ -336,7 +326,8 @@ def to_hugging_face_clip_text_model_config(self) -> "transformers.CLIPTextConfig @staticmethod def from_properties(properties: dict[str, Any]) -> "ClipTextConfig": kwargs = dict(properties) - kwargs.pop("SHARK_DATASET_VERSION") + if "SHARK_DATASET_VERSION" in kwargs: + kwargs.pop("SHARK_DATASET_VERSION") if "dtype" in kwargs and kwargs["dtype"] is not None: kwargs["dtype"] = serialized_name_to_dtype(kwargs["dtype"]) diff --git a/sharktank/sharktank/models/t5/export.py b/sharktank/sharktank/models/t5/export.py index 8d5f75db2..5072da025 100644 --- a/sharktank/sharktank/models/t5/export.py +++ b/sharktank/sharktank/models/t5/export.py @@ -5,20 +5,21 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import functools -from typing import Optional, Union +from typing import Any, Optional, Union from pathlib import Path import torch from copy import copy +import transformers from .t5 import T5Config, T5Encoder -from ...types import Dataset +from ...types import Dataset, Theta, DefaultPrimitiveTensor from ...transforms.dataset import set_float_dtype from iree.turbine.aot import FxProgramsBuilder, export __all__ = [ "export_encoder_mlir", "export_encoder_iree_parameters", - "prune_decoder_parameters", + "import_encoder_dataset_from_hugging_face", ] @@ -33,13 +34,8 @@ def export_encoder_mlir( """ if isinstance(model, (Path, str)): dataset = Dataset.load(model) - config = T5Config.from_gguf_properties( + config = T5Config.from_properties( dataset.properties, - # TODO: add this property to our HuggingFace-to-GGUF conversion script. - # We currently use llama.cpp's converter and it can not make a distinction - # between T5 V1 and V1.1. - # V1 uses ReLU and V1.1 uses gated GeLU. - feed_forward_proj="gated-gelu", ) model = T5Encoder(theta=dataset.root_theta, config=config) @@ -82,18 +78,6 @@ def _( output.save_mlir(mlir_output_path) -def prune_decoder_parameters(dataset: Dataset): - # Remove decoder tensors/parameters if present. - try: - del dataset.root_theta.tree["dec"] - except KeyError: - pass - try: - del dataset.properties["t5.decoder_start_token_id"] - except KeyError: - pass - - def export_encoder_iree_parameters( model_path_or_dataset: str | Dataset, output_path: str, @@ -107,5 +91,35 @@ def export_encoder_iree_parameters( dataset.root_theta = dataset.root_theta.transform( functools.partial(set_float_dtype, dtype=dtype) ) - prune_decoder_parameters(dataset) dataset.save(output_path) + + +def import_encoder_dataset_from_hugging_face( + repo_id_or_model: transformers.T5EncoderModel | str, + /, + *, + tokenizer_config: dict[str, Any] | None = None, +) -> Dataset: + model = repo_id_or_model + if not isinstance(repo_id_or_model, transformers.T5EncoderModel): + model = transformers.T5EncoderModel.from_pretrained(repo_id_or_model) + from transformers.models.auto.tokenization_auto import get_tokenizer_config + + if tokenizer_config is None: + tokenizer_config = get_tokenizer_config(repo_id_or_model) + else: + if tokenizer_config is None: + raise ValueError( + "When providing a model directly tokenizer_config must also be provided." + ) + + theta = Theta( + { + name: DefaultPrimitiveTensor(data=param, name=name) + for name, param in model.named_parameters() + } + ) + config = T5Config.from_hugging_face_config( + model.config, tokenizer_config=tokenizer_config + ) + return Dataset(properties=config.to_properties(), root_theta=theta) diff --git a/sharktank/sharktank/models/t5/t5.py b/sharktank/sharktank/models/t5/t5.py index a2a8958be..9af2f5636 100644 --- a/sharktank/sharktank/models/t5/t5.py +++ b/sharktank/sharktank/models/t5/t5.py @@ -54,12 +54,27 @@ def __init__( activation_dtype: torch.dtype, ): super().__init__() + + ffn_theta = theta("DenseReluDense") + ffn_theta_dict = {} + if is_gated_act: + ffn_theta_dict["ffn_gate"] = ffn_theta("wi_0").tree + ffn_theta_dict["ffn_up"] = ffn_theta("wi_1").tree + else: + ffn_theta_dict["ffn_up"] = ffn_theta("wi").tree + ffn_theta_dict["ffn_down"] = ffn_theta("wo").tree + ffn_theta = Theta(ffn_theta_dict) + self.dense_activation_dense = FFN( - theta=theta, is_gated=is_gated_act, activation_fn=ACT2FN[dense_act_fn] + theta=ffn_theta, + is_gated=is_gated_act, + activation_fn=ACT2FN[dense_act_fn], ) self.layer_norm = RMSNormLayer( - theta=theta("ffn_norm"), epsilon=layer_norm_epsilon, dtype=activation_dtype + theta=theta("layer_norm"), + epsilon=layer_norm_epsilon, + dtype=activation_dtype, ) def forward(self, hidden_states): @@ -93,14 +108,14 @@ def __init__( self.inner_dim = self.n_heads * self.key_value_proj_dim self.activation_dtype = activation_dtype - self.q = LinearLayer(theta("attn_q")) - self.k = LinearLayer(theta("attn_k")) - self.v = LinearLayer(theta("attn_v")) - self.o = LinearLayer(theta("attn_o")) + self.q = LinearLayer(theta("q")) + self.k = LinearLayer(theta("k")) + self.v = LinearLayer(theta("v")) + self.o = LinearLayer(theta("o")) if self.has_relative_attention_bias: self.relative_attention_bias = TokenEmbeddingLayer( - theta("attn_rel_b"), dtype=activation_dtype + theta("relative_attention_bias"), dtype=activation_dtype ) self.pruned_heads = set() @@ -360,7 +375,7 @@ def __init__( ): super().__init__() self.attention = T5Attention( - theta=theta, + theta=theta("SelfAttention"), is_decoder=is_decoder, relative_attention_num_buckets=relative_attention_num_buckets, relative_attention_max_distance=relative_attention_max_distance, @@ -371,7 +386,9 @@ def __init__( has_relative_attention_bias=has_relative_attention_bias, ) self.layer_norm = RMSNormLayer( - theta=theta("attn_norm"), epsilon=layer_norm_epsilon, dtype=activation_dtype + theta=theta("layer_norm"), + epsilon=layer_norm_epsilon, + dtype=activation_dtype, ) def forward( @@ -482,7 +499,7 @@ def __init__( self.layer = nn.ModuleList() self.layer.append( T5SelfAttention( - theta=theta, + theta=theta(f"layer.{len(self.layer)}"), is_decoder=is_decoder, relative_attention_num_buckets=relative_attention_num_buckets, relative_attention_max_distance=relative_attention_max_distance, @@ -497,7 +514,7 @@ def __init__( if self.is_decoder: self.layer.append( T5CrossAttention( - theta=theta, + theta=theta(f"layer.{len(self.layer)}"), is_decoder=is_decoder, relative_attention_num_buckets=relative_attention_num_buckets, relative_attention_max_distance=relative_attention_max_distance, @@ -511,7 +528,7 @@ def __init__( self.layer.append( T5LayerFF( - theta=theta, + theta=theta(f"layer.{len(self.layer)}"), is_gated_act=is_gated_act, dense_act_fn=dense_act_fn, layer_norm_epsilon=layer_norm_epsilon, @@ -654,12 +671,11 @@ def __init__(self, theta: Theta, config: T5Config, embed_tokens=None): self.embed_tokens = embed_tokens self.config = config self.is_decoder = config.is_decoder - theta_prefix = "dec" if config.is_decoder else "enc" self.block = torch.nn.ModuleList( [ T5Block( - theta=theta(f"{theta_prefix}.blk.{i}"), + theta=theta(f"block.{i}"), is_decoder=config.is_decoder, relative_attention_num_buckets=config.relative_attention_num_buckets, relative_attention_max_distance=config.relative_attention_max_distance, @@ -678,7 +694,7 @@ def __init__(self, theta: Theta, config: T5Config, embed_tokens=None): self.add_module( "final_layer_norm", RMSNormLayer( - theta(f"{theta_prefix}.output_norm"), + theta(f"final_layer_norm"), epsilon=config.layer_norm_epsilon, dtype=config.activation_dtype, ), @@ -1043,7 +1059,7 @@ def __init__(self, theta: Theta, config: T5Config): self.add_module( "token_embedding", TokenEmbeddingLayer( - theta("token_embd"), dtype=theta("token_embd").tensor("weight").dtype + theta("shared"), dtype=theta("shared").tensor("weight").dtype ), ) @@ -1051,7 +1067,9 @@ def __init__(self, theta: Theta, config: T5Config): encoder_config.use_cache = False encoder_config.is_encoder_decoder = False self.encoder = T5Stack( - theta=theta, config=encoder_config, embed_tokens=self.token_embedding + theta=theta("encoder"), + config=encoder_config, + embed_tokens=self.token_embedding, ) @property diff --git a/sharktank/sharktank/models/vae/model.py b/sharktank/sharktank/models/vae/model.py index 126c05f2c..881923b8b 100644 --- a/sharktank/sharktank/models/vae/model.py +++ b/sharktank/sharktank/models/vae/model.py @@ -37,8 +37,8 @@ def __init__(self, hp, theta: Theta): self.mid_block = self._create_mid_block(theta("decoder")("mid_block")) # up self.up_blocks = nn.ModuleList([]) - self.upscale_dtype = theta("decoder")("up_blocks")(0)("resnets")(0)("conv1")( - "weight" + self.upscale_dtype = unbox_tensor( + theta("decoder")("up_blocks")(0)("resnets")(0)("conv1")("weight") ).dtype for i, up_block_name in enumerate(hp.up_block_types): up_block_theta = theta("decoder")("up_blocks")(i) @@ -74,6 +74,16 @@ def forward( "latent_embeds": latent_embeds, }, ) + if not self.hp.use_post_quant_conv: + sample = rearrange( + sample, + "b (h w) (c ph pw) -> b c (h ph) (w pw)", + h=math.ceil(1024 / 16), + w=math.ceil(1024 / 16), + ph=2, + pw=2, + ) + sample = sample / self.hp.scaling_factor + self.hp.shift_factor if self.hp.use_post_quant_conv: diff --git a/sharktank/sharktank/pipelines/flux/export_components.py b/sharktank/sharktank/pipelines/flux/export_components.py index 85c7888de..334dbdadb 100644 --- a/sharktank/sharktank/pipelines/flux/export_components.py +++ b/sharktank/sharktank/pipelines/flux/export_components.py @@ -267,15 +267,7 @@ def __init__(self, weight_file, height=1024, width=1024, precision="fp32"): self.width = width def forward(self, z): - d_in = rearrange( - z, - "b (h w) (c ph pw) -> b c (h ph) (w pw)", - h=math.ceil(self.height / 16), - w=math.ceil(self.width / 16), - ph=2, - pw=2, - ) - return self.ae.forward(d_in) + return self.ae.forward(z) def get_ae_model_and_inputs( diff --git a/sharktank/sharktank/tools/import_hf_dataset.py b/sharktank/sharktank/tools/import_hf_dataset.py index d4e9c43be..deeb741b2 100644 --- a/sharktank/sharktank/tools/import_hf_dataset.py +++ b/sharktank/sharktank/tools/import_hf_dataset.py @@ -52,7 +52,9 @@ def import_hf_dataset( for params_path in param_paths: with safetensors.safe_open(params_path, framework="pt", device="cpu") as st: tensors += [ - DefaultPrimitiveTensor(name=name, data=st.get_tensor(name)) + DefaultPrimitiveTensor( + name=name, data=st.get_tensor(name).to(target_dtype) + ) for name in st.keys() ] diff --git a/sharktank/sharktank/utils/iree.py b/sharktank/sharktank/utils/iree.py index c38976e01..6ae0a6a88 100644 --- a/sharktank/sharktank/utils/iree.py +++ b/sharktank/sharktank/utils/iree.py @@ -128,7 +128,7 @@ def torch_tensor_to_device_array( if tensor.dtype == torch.bfloat16: tensor_as_int16 = tensor.view(dtype=torch.int16) device_array_as_int16 = iree.runtime.asdevicearray( - device, unbox_tensor(tensor_as_int16).to("cpu").numpy() + device, unbox_tensor(tensor_as_int16).to("cpu").detach().numpy() ) buffer_view = iree.runtime.HalBufferView( buffer=device_array_as_int16._buffer_view.get_buffer(), @@ -137,7 +137,9 @@ def torch_tensor_to_device_array( ) return iree.runtime.DeviceArray(device, buffer_view) - return iree.runtime.asdevicearray(device, unbox_tensor(tensor).to("cpu").numpy()) + return iree.runtime.asdevicearray( + device, unbox_tensor(tensor).to("cpu").detach().numpy() + ) def run_iree_module_function( @@ -162,7 +164,7 @@ def run_iree_module_function( for i, arg in enumerate(args): np.save( f"{trace_path_prefix}{function_name}_arg{i}.npy", - promote_bfloat16_to_float32(device_array_to_host(arg)).numpy(), + promote_bfloat16_to_float32(device_array_to_host(arg)).detach().numpy(), ) results = invoker(*args) if isinstance(results, iree.runtime.DeviceArray): @@ -172,12 +174,12 @@ def run_iree_module_function( for i, arg in enumerate(args): np.save( f"{trace_path_prefix}{function_name}_arg{i}_post_call.npy", - device_array_to_host(arg).numpy(), + device_array_to_host(arg).detach().numpy(), ) for i, arg in enumerate(results): np.save( f"{trace_path_prefix}{function_name}_result{i}.npy", - promote_bfloat16_to_float32(device_array_to_host(arg)).numpy(), + promote_bfloat16_to_float32(device_array_to_host(arg)).detach().numpy(), ) return results @@ -233,7 +235,7 @@ def call_torch_module_function( for i, arg in enumerate(flat_args): np.save( f"{trace_path_prefix}{function_name}_arg{i}.npy", - promote_bfloat16_to_float32(arg.to("cpu")).numpy(), + promote_bfloat16_to_float32(arg.to("cpu")).detach().numpy(), ) res = getattr(module, function_name)(*args, **kwargs) if trace_path_prefix is not None: @@ -241,7 +243,7 @@ def call_torch_module_function( for i, arg in enumerate(flat_args): np.save( f"{trace_path_prefix}{function_name}_arg{i}.npy", - promote_bfloat16_to_float32(arg.to("cpu")).numpy(), + promote_bfloat16_to_float32(arg.to("cpu")).detach().numpy(), ) results = ( (res,) @@ -258,7 +260,7 @@ def call_torch_module_function( for i, result in enumerate(flat_results): np.save( f"{trace_path_prefix}{function_name}_result{i}.npy", - result.to("cpu").numpy(), + result.to("cpu").detach().numpy(), ) return res diff --git a/sharktank/tests/models/t5/t5_test.py b/sharktank/tests/models/t5/t5_test.py index 629354a15..d15f951c6 100644 --- a/sharktank/tests/models/t5/t5_test.py +++ b/sharktank/tests/models/t5/t5_test.py @@ -5,6 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import functools +from copy import copy from transformers.models.t5.modeling_t5 import ( T5Attention as ReferenceT5Attention, T5LayerSelfAttention as ReferenceT5LayerSelfAttention, @@ -15,13 +16,14 @@ T5EncoderModel as ReferenceT5EncoderModel, T5Config as ReferenceT5Config, ) +from transformers.models.auto.tokenization_auto import get_tokenizer_config from typing import Optional import os from collections import OrderedDict import logging import pytest import torch -from torch.utils._pytree import tree_map, tree_unflatten, tree_flatten_with_path +from torch.utils._pytree import tree_map from unittest import TestCase from parameterized import parameterized from sharktank.types import ( @@ -29,7 +31,6 @@ DefaultPrimitiveTensor, unbox_tensor, Dataset, - dtype_to_serialized_short_name, ) from sharktank.models.t5 import ( T5Attention, @@ -38,7 +39,7 @@ T5Encoder, T5LayerFF, export_encoder_mlir, - export_encoder_iree_parameters, + import_encoder_dataset_from_hugging_face, ) from sharktank.utils.testing import ( assert_text_encoder_state_close, @@ -168,17 +169,12 @@ def runTestV1_1CompareTorchEagerAgainstHuggingFace( ) reference_model.eval() - target_model_name = ( - f"{huggingface_repo_id.replace('/', '__').replace('-', '_')}_f32_model" - ) - target_model_path = getattr(self, target_model_name) - dataset = Dataset.load(target_model_path) + dataset = import_encoder_dataset_from_hugging_face(huggingface_repo_id) dataset.root_theta = dataset.root_theta.transform( functools.partial(set_float_dtype, dtype=target_dtype) ) - config = T5Config.from_gguf_properties( + config = T5Config.from_properties( dataset.properties, - feed_forward_proj="gated-gelu", ) input_ids = tokenizer( @@ -258,6 +254,7 @@ def testV1_1SmallBf16CompareTorchEagerAgainstHuggingFaceF32(self): "google/t5-v1_1-small", reference_dtype=torch.float32, target_dtype=torch.bfloat16, + # The observed error is 0.055. atol=1e-1, ) @@ -285,6 +282,7 @@ def testV1_1XxlBf16CompareTorchEagerAgainstHuggingFaceF32(self): "google/t5-v1_1-xxl", reference_dtype=torch.float32, target_dtype=torch.bfloat16, + # The observed error is 0.026. atol=5e-2, ) @@ -310,19 +308,20 @@ def runTestV1_1CompareIreeAgainstTorchEager( ).download() tokenizer = AutoTokenizer.from_pretrained(huggingface_repo_id) - huggingface_repo_id_as_path = ( - f"{huggingface_repo_id.replace('/', '__').replace('-', '_')}" + reference_dataset = import_encoder_dataset_from_hugging_face( + huggingface_repo_id ) - source_model_name = f"{huggingface_repo_id_as_path}_f32_model" - source_model_path = getattr(self, source_model_name) + target_dataset = copy(reference_dataset) - reference_dataset = Dataset.load(source_model_path) reference_dataset.root_theta = reference_dataset.root_theta.transform( functools.partial(set_float_dtype, dtype=reference_dtype) ) - config = T5Config.from_gguf_properties( + config = T5Config.from_properties( reference_dataset.properties, - feed_forward_proj="gated-gelu", + ) + + target_dataset.root_theta = target_dataset.root_theta.transform( + functools.partial(set_float_dtype, dtype=target_dtype) ) input_ids = tokenizer( @@ -334,32 +333,26 @@ def runTestV1_1CompareIreeAgainstTorchEager( input_args = OrderedDict([("input_ids", input_ids)]) batch_size = input_ids.shape[0] - reference_dtype_name = dtype_to_serialized_short_name(reference_dtype) - target_dtype_name = dtype_to_serialized_short_name(target_dtype) - target_model_path_prefix = f"{self.path_prefix}{huggingface_repo_id_as_path}_encoder_{target_dtype_name}" - reference_model = T5Encoder(theta=reference_dataset.root_theta, config=config) reference_result_dict = call_torch_module_function( module=reference_model, function_name="forward", kwargs=input_args, - trace_path_prefix=f"{self.path_prefix}{huggingface_repo_id_as_path}_encoder_{reference_dtype_name}_torch_", + trace_path_prefix=f"{self.path_prefix}torch_", ) reference_result = flatten_for_iree_signature(reference_result_dict) - parameters_path = f"{target_model_path_prefix}.irpa" + parameters_path = f"{self.path_prefix}parameters.irpa" if not self.caching or not os.path.exists(parameters_path): - export_encoder_iree_parameters( - source_model_path, parameters_path, dtype=target_dtype - ) + target_dataset.save(parameters_path) - mlir_path = f"{target_model_path_prefix}.mlir" + mlir_path = f"{self.path_prefix}model.mlir" if not self.caching or not os.path.exists(mlir_path): logger.info("Exporting T5 encoder model to MLIR...") export_encoder_mlir( parameters_path, batch_sizes=[batch_size], mlir_output_path=mlir_path ) - iree_module_path = f"{target_model_path_prefix}.vmfb" + iree_module_path = f"{self.path_prefix}model.vmfb" if not self.caching or not os.path.exists(iree_module_path): logger.info("Compiling MLIR file...") iree.compiler.compile_file( @@ -386,7 +379,7 @@ def runTestV1_1CompareIreeAgainstTorchEager( args=iree_args, device=iree_devices[0], function_name=f"forward_bs{batch_size}", - trace_path_prefix=f"{target_model_path_prefix}_iree_", + trace_path_prefix=f"{self.path_prefix}iree_", ) ) iree_result = [ @@ -498,19 +491,19 @@ def testCompareAgainstTransformers( theta = Theta( { - "attn_q.weight": DefaultPrimitiveTensor( + "q.weight": DefaultPrimitiveTensor( data=reference_model.q.weight.to(dtype=target_dtype) ), - "attn_k.weight": DefaultPrimitiveTensor( + "k.weight": DefaultPrimitiveTensor( data=reference_model.k.weight.to(dtype=target_dtype) ), - "attn_v.weight": DefaultPrimitiveTensor( + "v.weight": DefaultPrimitiveTensor( data=reference_model.v.weight.to(dtype=target_dtype) ), - "attn_o.weight": DefaultPrimitiveTensor( + "o.weight": DefaultPrimitiveTensor( data=reference_model.o.weight.to(dtype=target_dtype) ), - "attn_rel_b.weight": DefaultPrimitiveTensor( + "relative_attention_bias.weight": DefaultPrimitiveTensor( data=reference_model.relative_attention_bias.weight.to( dtype=target_dtype ) @@ -593,24 +586,24 @@ def testCompareSelfAttentionAgainstTransformers( theta = Theta( { - "attn_q.weight": DefaultPrimitiveTensor( + "SelfAttention.q.weight": DefaultPrimitiveTensor( data=reference_model.SelfAttention.q.weight.to(dtype=target_dtype) ), - "attn_k.weight": DefaultPrimitiveTensor( + "SelfAttention.k.weight": DefaultPrimitiveTensor( data=reference_model.SelfAttention.k.weight.to(dtype=target_dtype) ), - "attn_v.weight": DefaultPrimitiveTensor( + "SelfAttention.v.weight": DefaultPrimitiveTensor( data=reference_model.SelfAttention.v.weight.to(dtype=target_dtype) ), - "attn_o.weight": DefaultPrimitiveTensor( + "SelfAttention.o.weight": DefaultPrimitiveTensor( data=reference_model.SelfAttention.o.weight.to(dtype=target_dtype) ), - "attn_rel_b.weight": DefaultPrimitiveTensor( + "SelfAttention.relative_attention_bias.weight": DefaultPrimitiveTensor( data=reference_model.SelfAttention.relative_attention_bias.weight.to( dtype=target_dtype ) ), - "attn_norm.weight": DefaultPrimitiveTensor( + "layer_norm.weight": DefaultPrimitiveTensor( data=reference_model.layer_norm.weight.to(dtype=target_dtype) ), } @@ -708,20 +701,20 @@ def testCompareAgainstTransformers( theta = Theta( { - "ffn_gate.weight": DefaultPrimitiveTensor( + "DenseReluDense.wi_0.weight": DefaultPrimitiveTensor( data=reference_model.DenseReluDense.wi_0.weight.to( dtype=target_dtype ) ), - "ffn_up.weight": DefaultPrimitiveTensor( + "DenseReluDense.wi_1.weight": DefaultPrimitiveTensor( data=reference_model.DenseReluDense.wi_1.weight.to( dtype=target_dtype ) ), - "ffn_down.weight": DefaultPrimitiveTensor( + "DenseReluDense.wo.weight": DefaultPrimitiveTensor( data=reference_model.DenseReluDense.wo.weight.to(dtype=target_dtype) ), - "ffn_norm.weight": DefaultPrimitiveTensor( + "layer_norm.weight": DefaultPrimitiveTensor( data=reference_model.layer_norm.weight.to(dtype=target_dtype) ), } diff --git a/sharktank/tests/models/vae/vae_test.py b/sharktank/tests/models/vae/vae_test.py index 1e0a1a8f3..43a989f2f 100644 --- a/sharktank/tests/models/vae/vae_test.py +++ b/sharktank/tests/models/vae/vae_test.py @@ -277,7 +277,7 @@ def testVaeIreeVsHuggingFace(self): model_f32 = VaeDecoderModel.from_dataset(ds_f32).to(device="cpu") # TODO: Decomposing attention due to https://github.com/iree-org/iree/issues/19286, remove once issue is resolved - module = export_vae(model, inputs, True) + module = export_vae(model, inputs.to(dtype=dtype), True) module_f32 = export_vae(model_f32, inputs, True) module.save_mlir("{self._temp_dir}/flux_vae_bf16.mlir") @@ -317,7 +317,7 @@ def testVaeIreeVsHuggingFace(self): parameters_path="{self._temp_dir}/flux_vae_bf16.irpa", ) - input_args = OrderedDict([("inputs", inputs)]) + input_args = OrderedDict([("inputs", inputs.to(dtype=dtype))]) iree_args = flatten_for_iree_signature(input_args) iree_args = prepare_iree_module_function_args(