From fa5acb125afdf57274ec064524b2228a33bc58a1 Mon Sep 17 00:00:00 2001
From: leondgarse
Date: Tue, 10 May 2022 11:11:40 +0800
Subject: [PATCH] beit support non-square input_shape

---
 keras_cv_attention_models/beit/beit.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/keras_cv_attention_models/beit/beit.py b/keras_cv_attention_models/beit/beit.py
index 9516b0aa..befcd0f5 100644
--- a/keras_cv_attention_models/beit/beit.py
+++ b/keras_cv_attention_models/beit/beit.py
@@ -39,7 +39,7 @@ def __init__(self, with_cls_token=True, attn_height=-1, num_heads=-1, **kwargs):
     def build(self, attn_shape):
         # print(attn_shape)
         if self.attn_height == -1:
-            height = width = int(tf.math.sqrt(float(attn_shape[2] - self.cls_token_len)))  # assume hh == ww, e.g. 14
+            height = width = int(tf.math.sqrt(float(attn_shape[2] - self.cls_token_len)))  # hh == ww, e.g. 14
         else:
             height = self.attn_height
             width = int(float(attn_shape[2] - self.cls_token_len) / height)
@@ -121,7 +121,7 @@ def show_pos_emb(self, rows=1, base_size=2):
         return fig
 
 
-def attention_block(inputs, num_heads=4, key_dim=0, out_weight=True, out_bias=False, qv_bias=True, attn_dropout=0, name=None):
+def attention_block(inputs, num_heads=4, key_dim=0, out_weight=True, out_bias=False, qv_bias=True, attn_height=-1, attn_dropout=0, name=None):
     _, bb, cc = inputs.shape
     key_dim = key_dim if key_dim > 0 else cc // num_heads
     qk_scale = float(1.0 / tf.math.sqrt(tf.cast(key_dim, "float32")))
@@ -146,7 +146,7 @@ def attention_block(inputs, num_heads=4, key_dim=0, out_weight=True, out_bias=Fa
     query *= qk_scale
     # [batch, num_heads, cls_token + hh * ww, cls_token + hh * ww]
     attention_scores = keras.layers.Lambda(lambda xx: tf.matmul(xx[0], xx[1]))([query, key])
-    attention_scores = MultiHeadRelativePositionalEmbedding(name=name and name + "pos_emb")(attention_scores)
+    attention_scores = MultiHeadRelativePositionalEmbedding(attn_height=attn_height, name=name and name + "pos_emb")(attention_scores)
     # attention_scores = tf.nn.softmax(attention_scores, axis=-1, name=name and name + "_attention_scores")
     attention_scores = keras.layers.Softmax(axis=-1, name=name and name + "attention_scores")(attention_scores)
 
@@ -225,6 +225,7 @@ def Beit(
 
     """ forward_embeddings """
     nn = conv2d_no_bias(inputs, embed_dim, patch_size, strides=patch_size, padding="valid", use_bias=True, name="stem_")
+    patch_height = nn.shape[1]
     nn = keras.layers.Reshape([-1, nn.shape[-1]])(nn)
     nn = ClassToken(name="cls_token")(nn)
 
@@ -234,6 +235,7 @@
         "qv_bias": attn_qv_bias,
         "out_weight": attn_out_weight,
         "out_bias": attn_out_bias,
+        "attn_height": patch_height,
         "attn_dropout": attn_dropout,
     }
 
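
Note below the patch (not part of the diff): a minimal sketch of the shape bookkeeping this change enables. The input_shape, patch_size, and numeric values here are illustrative assumptions, not taken from the repository. With a non-square input the patch grid is no longer square, so Beit() records the stem's patch height and attention_block forwards it as attn_height to MultiHeadRelativePositionalEmbedding, instead of the layer recovering a square side via a square root.

# Sketch only: mirrors the height/width recovery in
# MultiHeadRelativePositionalEmbedding.build() after this patch.
# input_shape, patch_size and cls_token_len are illustrative assumptions.
input_shape, patch_size, cls_token_len = (160, 256, 3), 16, 1

patch_height = input_shape[0] // patch_size  # what Beit() records from the stem output, e.g. 10
patch_width = input_shape[1] // patch_size   # e.g. 16
attn_blocks = cls_token_len + patch_height * patch_width  # attn_shape[2] seen in build()

# Previous behavior (attn_height == -1): assumes a square grid, which cannot
# reproduce a 10 x 16 patch layout.
square_side = int((attn_blocks - cls_token_len) ** 0.5)
assert square_side * square_side != patch_height * patch_width

# Patched behavior: height arrives via attn_height, width is derived from the token count.
height = patch_height
width = (attn_blocks - cls_token_len) // height
assert (height, width) == (10, 16)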