From d4e5b53e7d91b752e622c71f50d85a6a1cac0acf Mon Sep 17 00:00:00 2001 From: leondgarse Date: Fri, 4 Aug 2023 21:37:11 +0800 Subject: [PATCH] add keras_core_functional.py --- .../keras_core_functional.py | 34 +++++++++++++++ keras_cv_attention_models/llama2/__init__.py | 41 +++++++++++++++++++ .../pytorch_backend/initializers.py | 5 ++- 3 files changed, 79 insertions(+), 1 deletion(-) create mode 100644 keras_cv_attention_models/keras_core_functional.py create mode 100644 keras_cv_attention_models/llama2/__init__.py diff --git a/keras_cv_attention_models/keras_core_functional.py b/keras_cv_attention_models/keras_core_functional.py new file mode 100644 index 00000000..88e6a5d7 --- /dev/null +++ b/keras_cv_attention_models/keras_core_functional.py @@ -0,0 +1,34 @@ +import keras_core +from keras_core.ops import * +from keras_core.ops import concatenate as concat +from keras_core.ops import mean as reduce_mean +from keras_core.ops import sum as reduce_sum +from keras_core.ops import max as reduce_max +from keras_core.ops import min as reduce_min +from keras_core.ops import power as pow +from keras_core.ops import clip as clip_by_value +from keras_core.ops.image import extract_patches + + +def resize(images, size, method="bilinear", preserve_aspect_ratio=False, antialias=False, name=None): + return keras_core.ops.image.resize(images, size, interpolation=method, antialias=antialias, data_format=keras_core.backend.image_data_format()) + + +def split(inputs, num_or_size_splits, axis=0, num=None, name="split"): + if isinstance(num_or_size_splits, int): + return keras_core.ops.split(inputs, num_or_size_splits, axis=axis) + + axis = (len(inputs.shape) + axis) if axis < 0 else axis + split_axis_shape = inputs.shape[axis] + assert split_axis_shape is not None + + size_splits = num_or_size_splits + size_splits = [0 if ii is None or ii == -1 else ii for ii in size_splits] + num_unknown_dim = sum([ii == 0 for ii in size_splits]) + assert num_unknown_dim < 2, "At most one unknown dimension in num_or_size_splits: {}".format(num_or_size_splits) + + if num_unknown_dim == 1: + size_splits = [(split_axis_shape - sum(size_splits)) if ii == 0 else ii for ii in size_splits] + + cum_split = [sum(num_or_size_splits[: id + 1]) for id, _ in enumerate(size_splits[:-1])] + return keras_core.ops.split(inputs, cum_split, axis=axis) diff --git a/keras_cv_attention_models/llama2/__init__.py b/keras_cv_attention_models/llama2/__init__.py new file mode 100644 index 00000000..2d11c739 --- /dev/null +++ b/keras_cv_attention_models/llama2/__init__.py @@ -0,0 +1,41 @@ +from keras_cv_attention_models.llama2.llama2 import Llama2, Llama2_7B, RunPrediction, PositionalEncodingFourierRot1D, RMSNorm + +__head_doc__ = """ +Keras implementation of [Github openai/gpt-2](https://github.com/openai/gpt-2). +Paper [Language Models are Unsupervised Multitask Learners](https://d4mucfpksywv.cloudfront.net/better-language-models/language-models.pdf). +""" + +__tail_doc__ = """ vocab_size: model vocab size. + max_block_size: number of tokens generated in each sample. + include_top: boolena value if including output Dense head layer. Set false to exclude the head layer. + dropout: float value for drop out rate for Embedding layer and attention blocks. + activation: activation used in whole model, default `gelu/app`. + pretrained: None or one of ["webtext", "huggingface"]. + - if "webtext", will try to download and load ported weights if available. + - if "huggingface", will try converting and loading weights from huggingface `transformers` pacakge. + - if None, will initialize model with ranbdom weights. + +Returns: + A `keras.Model` instance. +""" + +Llama2.__doc__ = __head_doc__ + """ +Args: + num_blocks: . + embedding_size: . + num_heads: . + block_use_bias: . + model_name: string, model name. +""" + __tail_doc__ + """ +Model architectures: + | Model | Params | FLOPs | vocab_size | LAMBADA PPL | + | ------------| ------- | ------- | ---------- | ----------- | + | GPT2_Base | 163.04M | 146.42G | 50257 | 35.13 | + | GPT2_Medium | 406.29M | 415.07G | 50257 | 15.60 | + | GPT2_Large | 838.36M | 890.28G | 50257 | 10.87 | + | GPT2_XLarge | 1.638B | 1758.3G | 50257 | 8.63 | +""" + +Llama2_7B.__doc__ = __head_doc__ + """ +Args: +""" + __tail_doc__ diff --git a/keras_cv_attention_models/pytorch_backend/initializers.py b/keras_cv_attention_models/pytorch_backend/initializers.py index 7002db35..78bcd8b0 100644 --- a/keras_cv_attention_models/pytorch_backend/initializers.py +++ b/keras_cv_attention_models/pytorch_backend/initializers.py @@ -63,7 +63,10 @@ def __init__(self, value=0): super().__init__(seed=None) def __call__(self, shape, dtype=None, **kwargs): - return torch.nn.init.constant_(torch.empty(shape), val=self.value) + if hasattr(self.value, "shape") and tuple(self.value.shape) == tuple(shape): + return self.value + else: + return torch.nn.init.constant_(torch.empty(shape), val=self.value) def get_config(self): return {"value": self.value}