Shorten T5Config.from_hugging_face_config
sogartar committed Feb 20, 2025
1 parent 6adaffc commit 26b0ad7
Showing 2 changed files with 9 additions and 27 deletions.
32 changes: 8 additions & 24 deletions sharktank/sharktank/layers/configs/llm_configs.py
@@ -14,7 +14,7 @@
 (and indeed, can bootstrap these off of GGUF files).
 """

-from dataclasses import asdict, dataclass, field
+from dataclasses import asdict, dataclass, field, fields
 from typing import Any, Optional
 import torch
 from transformers import T5Config as T5ConfigHf
@@ -236,29 +236,13 @@ def __post_init__(self):
     def from_hugging_face_config(
         config: T5ConfigHf, tokenizer_config: dict[str, Any], **kwargs
     ) -> "T5Config":
-        all_kwargs = {
-            "return_dict": config.return_dict,
-            "output_hidden_states": config.output_hidden_states,
-            "output_attentions": config.output_attentions,
-            "is_encoder_decoder": config.is_encoder_decoder,
-            "is_decoder": config.is_decoder,
-            "vocab_size": config.vocab_size,
-            "context_length": tokenizer_config["model_max_length"],
-            "d_model": config.d_model,
-            "d_kv": config.d_kv,
-            "d_ff": config.d_ff,
-            "num_layers": config.num_layers,
-            "num_decoder_layers": config.num_decoder_layers,
-            "num_heads": config.num_heads,
-            "relative_attention_num_buckets": config.relative_attention_num_buckets,
-            "relative_attention_max_distance": config.relative_attention_max_distance,
-            "layer_norm_epsilon": config.layer_norm_epsilon,
-            "feed_forward_proj": config.feed_forward_proj,
-            "use_cache": config.use_cache,
-            "pad_token_id": config.pad_token_id,
-            "eos_token_id": config.eos_token_id,
-            "decoder_start_token_id": config.decoder_start_token_id,
-        }
+        all_kwargs = {}
+        for f in fields(T5Config):
+            if hasattr(config, f.name):
+                all_kwargs[f.name] = getattr(config, f.name)
+        all_kwargs["context_length"] = tokenizer_config["model_max_length"]
+        del all_kwargs["is_gated_act"]
+        del all_kwargs["dense_act_fn"]
         all_kwargs.update(kwargs)
         return T5Config(**all_kwargs)

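The refactor replaces the hand-written keyword dictionary with introspection over `dataclasses.fields`, copying every attribute that `T5Config` and the Hugging Face config have in common. A minimal, self-contained sketch of that pattern follows; the `SourceConfig`/`TargetConfig` names are illustrative stand-ins, not part of the repository:

```python
from dataclasses import dataclass, fields


class SourceConfig:
    """Stand-in for a Hugging Face config object with extra attributes."""

    def __init__(self):
        self.d_model = 512
        self.num_layers = 6
        self.unrelated_attr = "ignored"  # no matching dataclass field, so it is skipped


@dataclass
class TargetConfig:
    d_model: int
    num_layers: int
    context_length: int = 512  # filled from elsewhere, e.g. the tokenizer config


source = SourceConfig()
# Copy only the attributes that both the source object and the dataclass share.
kwargs = {
    f.name: getattr(source, f.name)
    for f in fields(TargetConfig)
    if hasattr(source, f.name)
}
print(TargetConfig(**kwargs))
# TargetConfig(d_model=512, num_layers=6, context_length=512)
```

Compared with the removed dictionary, this keeps `from_hugging_face_config` in sync automatically as fields are added to `T5Config`; the two `del` statements drop `is_gated_act` and `dense_act_fn`, which appear to be derived from `feed_forward_proj` rather than passed in directly.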
4 changes: 1 addition & 3 deletions sharktank/tests/models/t5/t5_test.py
@@ -23,15 +23,14 @@
 import logging
 import pytest
 import torch
-from torch.utils._pytree import tree_map, tree_unflatten, tree_flatten_with_path
+from torch.utils._pytree import tree_map
 from unittest import TestCase
 from parameterized import parameterized
 from sharktank.types import (
     Theta,
     DefaultPrimitiveTensor,
     unbox_tensor,
     Dataset,
-    dtype_to_serialized_short_name,
 )
 from sharktank.models.t5 import (
     T5Attention,
@@ -40,7 +39,6 @@
     T5Encoder,
     T5LayerFF,
     export_encoder_mlir,
-    export_encoder_iree_parameters,
     import_encoder_dataset_from_hugging_face,
 )
 from sharktank.utils.testing import (

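For completeness, a hedged usage sketch of the shortened constructor; the default `T5ConfigHf()` hyperparameters and the one-key `tokenizer_config` dict are assumptions made for illustration, not values taken from this commit:

```python
from transformers import T5Config as T5ConfigHf

from sharktank.layers.configs.llm_configs import T5Config

# Default Hugging Face T5 hyperparameters, used here purely for illustration.
hf_config = T5ConfigHf()

config = T5Config.from_hugging_face_config(
    hf_config,
    tokenizer_config={"model_max_length": 512},  # the only key the constructor reads
)
print(config.context_length)  # 512
```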