Commit

Merge branch 'main' into llm-integration-caching

renxida authored Feb 21, 2025
2 parents f0598a2 + 5250392 commit 46d2adb
Showing 9 changed files with 173 additions and 151 deletions.
83 changes: 37 additions & 46 deletions sharktank/sharktank/layers/configs/llm_configs.py
@@ -14,12 +14,14 @@
(and indeed, can bootstrap these off of GGUF files).
"""

from dataclasses import asdict, dataclass, field
from dataclasses import asdict, dataclass, field, fields
from typing import Any, Optional
import torch
from transformers import T5Config as T5ConfigHf

from ...types.tensors import serialized_name_to_dtype, dtype_to_serialized_name


__all__ = ["ClipTextConfig", "LlamaHParams", "LlamaModelConfig", "T5Config"]


@@ -231,53 +233,41 @@ def __post_init__(self):
self.dense_act_fn = "gelu_new"

@staticmethod
def from_gguf_properties(properties: dict[str, Any], **kwargs):
assert properties["general.architecture"] == "t5"
assert (
properties["t5.attention.layer_norm_epsilon"]
== properties["t5.attention.layer_norm_rms_epsilon"]
)

all_kwargs = {"vocab_size": None, "feed_forward_proj": None}

gguf_to_config_names_map = {
"t5.context_length": ["context_length"],
"t5.embedding_length": ["d_model"],
"t5.feed_forward_length": ["d_ff"],
"t5.block_count": ["num_layers", "num_decoder_layers"],
"t5.attention.head_count": ["num_heads"],
"t5.attention.key_length": ["d_kv"],
"t5.attention.layer_norm_epsilon": ["layer_norm_epsilon"],
"t5.attention.relative_buckets_count": ["relative_attention_num_buckets"],
"tokenizer.ggml.eos_token_id": ["eos_token_id"],
"tokenizer.ggml.padding_token_id": ["pad_token_id"],
}
all_kwargs.update(
{
config_name: properties[gguf_name]
for gguf_name, config_names in gguf_to_config_names_map.items()
for config_name in config_names
}
)

gguf_to_optional_config_names_map = {
"t5.decoder_start_token_id": ["decoder_start_token_id"],
}
all_kwargs.update(
{
config_name: properties[gguf_name]
for gguf_name, config_names in gguf_to_optional_config_names_map.items()
for config_name in config_names
if gguf_name in properties
}
)

if "tokenizer.ggml.tokens" in properties:
all_kwargs["vocab_size"] = len(properties["tokenizer.ggml.tokens"])
def from_hugging_face_config(
config: T5ConfigHf, tokenizer_config: dict[str, Any], **kwargs
) -> "T5Config":
all_kwargs = {}
for field in fields(T5Config):
if hasattr(config, field.name):
all_kwargs[field.name] = getattr(config, field.name)
all_kwargs["context_length"] = tokenizer_config["model_max_length"]
del all_kwargs["is_gated_act"]
del all_kwargs["dense_act_fn"]
all_kwargs.update(kwargs)

return T5Config(**all_kwargs)

@staticmethod
def from_properties(properties: dict[str, Any]) -> "T5Config":
kwargs = dict(properties)
if "SHARK_DATASET_VERSION" in kwargs:
kwargs.pop("SHARK_DATASET_VERSION")
if "activation_dtype" in kwargs and kwargs["activation_dtype"] is not None:
kwargs["activation_dtype"] = serialized_name_to_dtype(
kwargs["activation_dtype"]
)
if "is_gated_act" in kwargs:
kwargs.pop("is_gated_act")
if "dense_act_fn" in kwargs:
kwargs.pop("dense_act_fn")

return T5Config(**kwargs)

def to_properties(self) -> dict[str, Any]:
res = asdict(self)
if self.activation_dtype is not None:
res["activation_dtype"] = dtype_to_serialized_name(self.activation_dtype)
return res


@dataclass
class ClipTextConfig:
@@ -336,7 +326,8 @@ def to_hugging_face_clip_text_model_config(self) -> "transformers.CLIPTextConfig
@staticmethod
def from_properties(properties: dict[str, Any]) -> "ClipTextConfig":
kwargs = dict(properties)
kwargs.pop("SHARK_DATASET_VERSION")
if "SHARK_DATASET_VERSION" in kwargs:
kwargs.pop("SHARK_DATASET_VERSION")
if "dtype" in kwargs and kwargs["dtype"] is not None:
kwargs["dtype"] = serialized_name_to_dtype(kwargs["dtype"])

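For reference, a minimal sketch of how the new Hugging Face config path fits together; the checkpoint id, the import path, and the presence of model_max_length in the tokenizer config are assumptions, not part of this change:

from transformers import T5Config as T5ConfigHf
from transformers.models.auto.tokenization_auto import get_tokenizer_config

from sharktank.layers.configs.llm_configs import T5Config

# Load the Hugging Face model config and tokenizer config (checkpoint id assumed).
hf_config = T5ConfigHf.from_pretrained("google/t5-v1_1-small")
tokenizer_config = get_tokenizer_config("google/t5-v1_1-small")

# Copy matching dataclass fields from the HF config; context_length is taken
# from tokenizer_config["model_max_length"].
config = T5Config.from_hugging_face_config(
    hf_config, tokenizer_config=tokenizer_config
)

# Properties round-trip: activation_dtype is serialized by name and restored.
restored = T5Config.from_properties(config.to_properties())
assert restored.context_length == config.context_length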
58 changes: 36 additions & 22 deletions sharktank/sharktank/models/t5/export.py
@@ -5,20 +5,21 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import functools
from typing import Optional, Union
from typing import Any, Optional, Union
from pathlib import Path
import torch
from copy import copy
import transformers

from .t5 import T5Config, T5Encoder
from ...types import Dataset
from ...types import Dataset, Theta, DefaultPrimitiveTensor
from ...transforms.dataset import set_float_dtype
from iree.turbine.aot import FxProgramsBuilder, export

__all__ = [
"export_encoder_mlir",
"export_encoder_iree_parameters",
"prune_decoder_parameters",
"import_encoder_dataset_from_hugging_face",
]


@@ -33,13 +34,8 @@ def export_encoder_mlir(
"""
if isinstance(model, (Path, str)):
dataset = Dataset.load(model)
config = T5Config.from_gguf_properties(
config = T5Config.from_properties(
dataset.properties,
# TODO: add this property to our HuggingFace-to-GGUF conversion script.
# We currently use llama.cpp's converter and it can not make a distinction
# between T5 V1 and V1.1.
# V1 uses ReLU and V1.1 uses gated GeLU.
feed_forward_proj="gated-gelu",
)
model = T5Encoder(theta=dataset.root_theta, config=config)

@@ -82,18 +78,6 @@ def _(
output.save_mlir(mlir_output_path)


def prune_decoder_parameters(dataset: Dataset):
# Remove decoder tensors/parameters if present.
try:
del dataset.root_theta.tree["dec"]
except KeyError:
pass
try:
del dataset.properties["t5.decoder_start_token_id"]
except KeyError:
pass


def export_encoder_iree_parameters(
model_path_or_dataset: str | Dataset,
output_path: str,
@@ -107,5 +91,35 @@ def export_encoder_iree_parameters(
dataset.root_theta = dataset.root_theta.transform(
functools.partial(set_float_dtype, dtype=dtype)
)
prune_decoder_parameters(dataset)
dataset.save(output_path)


def import_encoder_dataset_from_hugging_face(
repo_id_or_model: transformers.T5EncoderModel | str,
/,
*,
tokenizer_config: dict[str, Any] | None = None,
) -> Dataset:
model = repo_id_or_model
if not isinstance(repo_id_or_model, transformers.T5EncoderModel):
model = transformers.T5EncoderModel.from_pretrained(repo_id_or_model)
from transformers.models.auto.tokenization_auto import get_tokenizer_config

if tokenizer_config is None:
tokenizer_config = get_tokenizer_config(repo_id_or_model)
else:
if tokenizer_config is None:
raise ValueError(
"When providing a model directly tokenizer_config must also be provided."
)

theta = Theta(
{
name: DefaultPrimitiveTensor(data=param, name=name)
for name, param in model.named_parameters()
}
)
config = T5Config.from_hugging_face_config(
model.config, tokenizer_config=tokenizer_config
)
return Dataset(properties=config.to_properties(), root_theta=theta)
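A minimal end-to-end sketch of the new import path feeding the existing parameter export; the repo id, the output paths, and the dtype keyword of export_encoder_iree_parameters are assumptions:

import torch

from sharktank.models.t5.export import (
    export_encoder_iree_parameters,
    import_encoder_dataset_from_hugging_face,
)

# Build a Dataset (theta + properties) straight from a Hugging Face checkpoint.
dataset = import_encoder_dataset_from_hugging_face("google/t5-v1_1-small")
dataset.save("/tmp/t5_v1_1_small_encoder_f32.irpa")

# Re-export the parameters cast to bfloat16.
export_encoder_iree_parameters(
    "/tmp/t5_v1_1_small_encoder_f32.irpa",
    "/tmp/t5_v1_1_small_encoder_bf16.irpa",
    dtype=torch.bfloat16,
)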
52 changes: 35 additions & 17 deletions sharktank/sharktank/models/t5/t5.py
@@ -54,12 +54,27 @@ def __init__(
activation_dtype: torch.dtype,
):
super().__init__()

ffn_theta = theta("DenseReluDense")
ffn_theta_dict = {}
if is_gated_act:
ffn_theta_dict["ffn_gate"] = ffn_theta("wi_0").tree
ffn_theta_dict["ffn_up"] = ffn_theta("wi_1").tree
else:
ffn_theta_dict["ffn_up"] = ffn_theta("wi").tree
ffn_theta_dict["ffn_down"] = ffn_theta("wo").tree
ffn_theta = Theta(ffn_theta_dict)

self.dense_activation_dense = FFN(
theta=theta, is_gated=is_gated_act, activation_fn=ACT2FN[dense_act_fn]
theta=ffn_theta,
is_gated=is_gated_act,
activation_fn=ACT2FN[dense_act_fn],
)

self.layer_norm = RMSNormLayer(
theta=theta("ffn_norm"), epsilon=layer_norm_epsilon, dtype=activation_dtype
theta=theta("layer_norm"),
epsilon=layer_norm_epsilon,
dtype=activation_dtype,
)

def forward(self, hidden_states):
@@ -93,14 +108,14 @@ def __init__(
self.inner_dim = self.n_heads * self.key_value_proj_dim
self.activation_dtype = activation_dtype

self.q = LinearLayer(theta("attn_q"))
self.k = LinearLayer(theta("attn_k"))
self.v = LinearLayer(theta("attn_v"))
self.o = LinearLayer(theta("attn_o"))
self.q = LinearLayer(theta("q"))
self.k = LinearLayer(theta("k"))
self.v = LinearLayer(theta("v"))
self.o = LinearLayer(theta("o"))

if self.has_relative_attention_bias:
self.relative_attention_bias = TokenEmbeddingLayer(
theta("attn_rel_b"), dtype=activation_dtype
theta("relative_attention_bias"), dtype=activation_dtype
)
self.pruned_heads = set()

@@ -360,7 +375,7 @@ def __init__(
):
super().__init__()
self.attention = T5Attention(
theta=theta,
theta=theta("SelfAttention"),
is_decoder=is_decoder,
relative_attention_num_buckets=relative_attention_num_buckets,
relative_attention_max_distance=relative_attention_max_distance,
@@ -371,7 +386,9 @@
has_relative_attention_bias=has_relative_attention_bias,
)
self.layer_norm = RMSNormLayer(
theta=theta("attn_norm"), epsilon=layer_norm_epsilon, dtype=activation_dtype
theta=theta("layer_norm"),
epsilon=layer_norm_epsilon,
dtype=activation_dtype,
)

def forward(
@@ -482,7 +499,7 @@ def __init__(
self.layer = nn.ModuleList()
self.layer.append(
T5SelfAttention(
theta=theta,
theta=theta(f"layer.{len(self.layer)}"),
is_decoder=is_decoder,
relative_attention_num_buckets=relative_attention_num_buckets,
relative_attention_max_distance=relative_attention_max_distance,
@@ -497,7 +514,7 @@
if self.is_decoder:
self.layer.append(
T5CrossAttention(
theta=theta,
theta=theta(f"layer.{len(self.layer)}"),
is_decoder=is_decoder,
relative_attention_num_buckets=relative_attention_num_buckets,
relative_attention_max_distance=relative_attention_max_distance,
@@ -511,7 +528,7 @@

self.layer.append(
T5LayerFF(
theta=theta,
theta=theta(f"layer.{len(self.layer)}"),
is_gated_act=is_gated_act,
dense_act_fn=dense_act_fn,
layer_norm_epsilon=layer_norm_epsilon,
@@ -654,12 +671,11 @@ def __init__(self, theta: Theta, config: T5Config, embed_tokens=None):
self.embed_tokens = embed_tokens
self.config = config
self.is_decoder = config.is_decoder
theta_prefix = "dec" if config.is_decoder else "enc"

self.block = torch.nn.ModuleList(
[
T5Block(
theta=theta(f"{theta_prefix}.blk.{i}"),
theta=theta(f"block.{i}"),
is_decoder=config.is_decoder,
relative_attention_num_buckets=config.relative_attention_num_buckets,
relative_attention_max_distance=config.relative_attention_max_distance,
@@ -678,7 +694,7 @@ def __init__(self, theta: Theta, config: T5Config, embed_tokens=None):
self.add_module(
"final_layer_norm",
RMSNormLayer(
theta(f"{theta_prefix}.output_norm"),
theta(f"final_layer_norm"),
epsilon=config.layer_norm_epsilon,
dtype=config.activation_dtype,
),
@@ -1043,15 +1059,17 @@ def __init__(self, theta: Theta, config: T5Config):
self.add_module(
"token_embedding",
TokenEmbeddingLayer(
theta("token_embd"), dtype=theta("token_embd").tensor("weight").dtype
theta("shared"), dtype=theta("shared").tensor("weight").dtype
),
)

encoder_config = copy.deepcopy(config)
encoder_config.use_cache = False
encoder_config.is_encoder_decoder = False
self.encoder = T5Stack(
theta=theta, config=encoder_config, embed_tokens=self.token_embedding
theta=theta("encoder"),
config=encoder_config,
embed_tokens=self.token_embedding,
)

@property
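The theta paths above now follow the transformers module naming because import_encoder_dataset_from_hugging_face (see export.py above) keys tensors by model.named_parameters(). A small illustrative sketch of that key space; the checkpoint id and the printed names are assumptions:

import transformers

model = transformers.T5EncoderModel.from_pretrained("google/t5-v1_1-small")
for name, _ in list(model.named_parameters())[:6]:
    print(name)
# Expected (illustrative) output, matching the theta lookups used above:
#   shared.weight
#   encoder.block.0.layer.0.SelfAttention.q.weight
#   encoder.block.0.layer.0.SelfAttention.k.weight
#   encoder.block.0.layer.0.SelfAttention.v.weight
#   encoder.block.0.layer.0.SelfAttention.o.weight
#   encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight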
14 changes: 12 additions & 2 deletions sharktank/sharktank/models/vae/model.py
@@ -37,8 +37,8 @@ def __init__(self, hp, theta: Theta):
self.mid_block = self._create_mid_block(theta("decoder")("mid_block"))
# up
self.up_blocks = nn.ModuleList([])
self.upscale_dtype = theta("decoder")("up_blocks")(0)("resnets")(0)("conv1")(
"weight"
self.upscale_dtype = unbox_tensor(
theta("decoder")("up_blocks")(0)("resnets")(0)("conv1")("weight")
).dtype
for i, up_block_name in enumerate(hp.up_block_types):
up_block_theta = theta("decoder")("up_blocks")(i)
@@ -74,6 +74,16 @@ def forward(
"latent_embeds": latent_embeds,
},
)
if not self.hp.use_post_quant_conv:
sample = rearrange(
sample,
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
h=math.ceil(1024 / 16),
w=math.ceil(1024 / 16),
ph=2,
pw=2,
)

sample = sample / self.hp.scaling_factor + self.hp.shift_factor

if self.hp.use_post_quant_conv:
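A minimal shape check of the unpatchify rearrange moved into the VAE forward above; einops and the 16-channel latent layout are assumptions:

import math

import torch
from einops import rearrange

h = w = math.ceil(1024 / 16)  # 64 latent patches per side for a 1024x1024 image
z = torch.randn(1, h * w, 16 * 2 * 2)  # b, (h w), (c ph pw) with c=16, ph=pw=2

sample = rearrange(z, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h, w=w, ph=2, pw=2)
print(sample.shape)  # torch.Size([1, 16, 128, 128])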
10 changes: 1 addition & 9 deletions sharktank/sharktank/pipelines/flux/export_components.py
@@ -267,15 +267,7 @@ def __init__(self, weight_file, height=1024, width=1024, precision="fp32"):
self.width = width

def forward(self, z):
d_in = rearrange(
z,
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
h=math.ceil(self.height / 16),
w=math.ceil(self.width / 16),
ph=2,
pw=2,
)
return self.ae.forward(d_in)
return self.ae.forward(z)


def get_ae_model_and_inputs(
