From 3442b1f7aef67b82d3bfa66ae1d7fed7ccc8fc5a Mon Sep 17 00:00:00 2001
From: leondgarse
Date: Tue, 10 May 2022 10:59:59 +0800
Subject: [PATCH] fix coat for non-square input

---
 keras_cv_attention_models/coat/README.md    |  2 +-
 keras_cv_attention_models/coat/coat.py      | 79 +++++++++++----------
 keras_cv_attention_models/coco/eval_func.py |  2 +-
 tests/test_models.py                        |  8 +++
 4 files changed, 52 insertions(+), 39 deletions(-)

diff --git a/keras_cv_attention_models/coat/README.md b/keras_cv_attention_models/coat/README.md
index 5c92e582..ccd94c78 100644
--- a/keras_cv_attention_models/coat/README.md
+++ b/keras_cv_attention_models/coat/README.md
@@ -50,7 +50,7 @@
   print(mm.output_shape)
   # [(None, 784, 216), (None, 196, 216), (None, 49, 216)]
   ```
-  Set `use_shared_cpe=False, use_shared_crpe=False` to disable using shared `ConvPositionalEncoding` and `ConvRelativePositionalEncoding` blocks. will have a better structure view using `netron` or other visualization tools.
+  Set `use_shared_cpe=False, use_shared_crpe=False` to disable using shared `ConvPositionalEncoding` and `ConvRelativePositionalEncoding` blocks, which gives a better structure view in `netron` or other visualization tools. Note this is for checking the model architecture only; keep input_shape `height == width` if either is set False.
  ```py
  mm = coat.CoaTMini(pretrained="imagenet", classifier_activation=None, input_shape=(224, 224, 3))
  mm.summary()
diff --git a/keras_cv_attention_models/coat/coat.py b/keras_cv_attention_models/coat/coat.py
index 63ec5402..d63715ae 100644
--- a/keras_cv_attention_models/coat/coat.py
+++ b/keras_cv_attention_models/coat/coat.py
@@ -23,14 +23,16 @@ def mlp_block(inputs, hidden_dim, activation="gelu", name=None):
 
 @keras.utils.register_keras_serializable(package="coat")
 class ConvPositionalEncoding(keras.layers.Layer):
-    def __init__(self, kernel_size=3, **kwargs):
+    def __init__(self, kernel_size=3, input_height=-1, **kwargs):
         super(ConvPositionalEncoding, self).__init__(**kwargs)
-        self.kernel_size = kernel_size
+        self.kernel_size, self.input_height = kernel_size, input_height
         self.pad = [[0, 0], [kernel_size // 2, kernel_size // 2], [kernel_size // 2, kernel_size // 2], [0, 0]]
         self.supports_masking = False
 
     def build(self, input_shape):
-        self.height = self.width = int(tf.math.sqrt(float(input_shape[1] - 1)))  # assume hh == ww
+        self.height = self.input_height if self.input_height > 0 else int(tf.math.sqrt(float(input_shape[1] - 1)))
+        self.width = (input_shape[1] - 1) // self.height
+
         self.channel = input_shape[-1]
         # Conv2D with goups=self.channel
         self.dconv = keras.layers.DepthwiseConv2D(
@@ -54,20 +56,21 @@ def compute_output_shape(self, input_shape):
 
     def get_config(self):
         base_config = super(ConvPositionalEncoding, self).get_config()
-        base_config.update({"kernel_size": self.kernel_size})
+        base_config.update({"kernel_size": self.kernel_size, "input_height": self.input_height})
         return base_config
 
 
 @keras.utils.register_keras_serializable(package="coat")
 class ConvRelativePositionalEncoding(keras.layers.Layer):
-    def __init__(self, head_splits=[2, 3, 3], head_kernel_size=[3, 5, 7], **kwargs):
+    def __init__(self, head_splits=[2, 3, 3], head_kernel_size=[3, 5, 7], input_height=-1, **kwargs):
         super(ConvRelativePositionalEncoding, self).__init__(**kwargs)
-        self.head_splits, self.head_kernel_size = head_splits, head_kernel_size
+        self.head_splits, self.head_kernel_size, self.input_height = head_splits, head_kernel_size, input_height
         self.supports_masking = False
 
     def build(self, query_shape):
         # print(query_shape)
-        self.height = self.width = int(tf.math.sqrt(float(query_shape[2] - 1)))  # assume hh == ww
+        self.height = self.input_height if self.input_height > 0 else int(tf.math.sqrt(float(query_shape[2] - 1)))
+        self.width = (query_shape[2] - 1) // self.height
         self.num_heads, self.query_dim = query_shape[1], query_shape[-1]
         self.channel_splits = [ii * self.query_dim for ii in self.head_splits]
 
@@ -104,7 +107,7 @@ def call(self, query, value, **kwargs):
 
     def get_config(self):
         base_config = super(ConvRelativePositionalEncoding, self).get_config()
-        base_config.update({"head_splits": self.head_splits, "head_kernel_size": self.head_kernel_size})
+        base_config.update({"head_splits": self.head_splits, "head_kernel_size": self.head_kernel_size, "input_height": self.input_height})
         return base_config
 
 
@@ -185,39 +188,35 @@ def serial_block(inputs, embed_dim, shared_cpe=None, shared_crpe=None, num_heads
     return out
 
 
-def resample(image, class_token=None, factor=1):
-    out_hh, out_ww = int(image.shape[1] * factor), int(image.shape[2] * factor)
-    out_image = tf.cast(tf.image.resize(image, [out_hh, out_ww], method="bilinear"), image.dtype)
-    # if factor > 1:
-    #     out_image = keras.layers.UpSampling2D(factor, interpolation='bilinear')(image)
-    # elif factor == 1:
-    #     out_image = image
-    # else:
-    #     size = int(1 / factor)
-    #     out_image = keras.layers.AvgPool2D(size, strides=size)(image)
+def resample(image, target_shape, class_token=None):
+    out_image = tf.cast(tf.image.resize(image, target_shape, method="bilinear"), image.dtype)
     if class_token is not None:
-        out_image = tf.reshape(out_image, [-1, out_hh * out_ww, out_image.shape[-1]])
+        out_image = tf.reshape(out_image, [-1, out_image.shape[1] * out_image.shape[2], out_image.shape[-1]])
         return tf.concat([class_token, out_image], axis=1)
     else:
         return out_image
 
 
-def parallel_block(inputs, shared_cpes=None, shared_crpes=None, num_heads=8, mlp_ratios=[], drop_rate=0, activation="gelu", name=""):
+def parallel_block(inputs, shared_cpes=None, shared_crpes=None, block_heights=[], num_heads=8, mlp_ratios=[], drop_rate=0, activation="gelu", name=""):
     # Conv-Attention.
-    cpe_outs, crpe_outs, crpe_images = [], [], []
+    # print(f">>>> {block_heights = }")
+    cpe_outs, crpe_outs, crpe_images, resample_shapes = [], [], [], []
+    block_heights = block_heights[1:]
     for id, (xx, shared_cpe, shared_crpe) in enumerate(zip(inputs[1:], shared_cpes[1:], shared_crpes[1:])):
         cur_name = name + "{}_".format(id + 2)
         cpe_out, crpe_out = __cpe_norm_crpe__(xx, shared_cpe, shared_crpe, num_heads, name=cur_name)
         cpe_outs.append(cpe_out)
         crpe_outs.append(crpe_out)
-        hh = ww = int(tf.math.sqrt(float(crpe_out.shape[1] - 1)))  # assume hh == ww
-        crpe_images.append(tf.reshape(crpe_out[:, 1:, :], [-1, hh, ww, crpe_out.shape[-1]]))
+        height = block_heights[id] if len(block_heights) > id else int(tf.math.sqrt(float(crpe_out.shape[1] - 1)))
+        width = (crpe_out.shape[1] - 1) // height
+        crpe_images.append(tf.reshape(crpe_out[:, 1:, :], [-1, height, width, crpe_out.shape[-1]]))
+        resample_shapes.append([height, width])
         # print(f">>>> {crpe_out.shape = }, {crpe_images[-1].shape = }")
     crpe_stack = [
         # [[None, 28, 28, 152], [None, 14, 14, 152], [None, 7, 7, 152]]
-        crpe_outs[0] + resample(crpe_images[1], crpe_outs[1][:, :1], factor=2) + resample(crpe_images[2], crpe_outs[2][:, :1], factor=4),
-        crpe_outs[1] + resample(crpe_images[2], crpe_outs[2][:, :1], factor=2) + resample(crpe_images[0], crpe_outs[0][:, :1], factor=1 / 2),
-        crpe_outs[2] + resample(crpe_images[1], crpe_outs[1][:, :1], factor=1 / 2) + resample(crpe_images[0], crpe_outs[0][:, :1], factor=1 / 4),
+        crpe_outs[0] + resample(crpe_images[1], resample_shapes[0], crpe_outs[1][:, :1]) + resample(crpe_images[2], resample_shapes[0], crpe_outs[2][:, :1]),
+        crpe_outs[1] + resample(crpe_images[2], resample_shapes[1], crpe_outs[2][:, :1]) + resample(crpe_images[0], resample_shapes[1], crpe_outs[0][:, :1]),
+        crpe_outs[2] + resample(crpe_images[1], resample_shapes[2], crpe_outs[1][:, :1]) + resample(crpe_images[0], resample_shapes[2], crpe_outs[0][:, :1]),
     ]
 
     # MLP
@@ -229,14 +228,16 @@ def parallel_block(inputs, shared_cpes=None, shared_crpes=None, num_heads=8, mlp
     return inputs[:1] + outs  # inputs[0] directly out
 
 
-def patch_embed(inputs, embed_dim, patch_size=2, name=""):
+def patch_embed(inputs, embed_dim, patch_size=2, input_height=-1, name=""):
     if len(inputs.shape) == 3:
-        height = width = int(tf.math.sqrt(float(inputs.shape[1])))  # assume hh == ww
-        inputs = keras.layers.Reshape([height, width, inputs.shape[-1]])(inputs)
-    nn = conv2d_no_bias(inputs, embed_dim, kernel_size=patch_size, strides=patch_size, use_bias=True, name=name)  # Try with Conv1D
+        input_height = input_height if input_height > 0 else int(tf.math.sqrt(float(inputs.shape[1])))
+        input_width = inputs.shape[1] // input_height
+        inputs = keras.layers.Reshape([input_height, input_width, inputs.shape[-1]])(inputs)
+    nn = conv2d_no_bias(inputs, embed_dim, kernel_size=patch_size, strides=patch_size, use_bias=True, name=name)
+    block_height = nn.shape[1]
     nn = keras.layers.Reshape([nn.shape[1] * nn.shape[2], nn.shape[-1]])(nn)  # flatten(2)
     nn = layer_norm(nn, name=name)
-    return nn
+    return nn, block_height
 
 
 def CoaT(
@@ -248,8 +249,8 @@ def CoaT(
     num_heads=8,
     head_splits=[2, 3, 3],
     head_kernel_size=[3, 5, 7],
-    use_shared_cpe=True,
-    use_shared_crpe=True,
+    use_shared_cpe=True,  # Set False only when checking model architecture; requires input_shape height == width
+    use_shared_crpe=True,  # Set False only when checking model architecture; requires input_shape height == width
     out_features=None,
     input_shape=(224, 224, 3),
     num_classes=1000,
@@ -267,14 +268,18 @@
     classfier_outs = []
     shared_cpes = []
     shared_crpes = []
+    block_heights = []
     for sid, (depth, embed_dim, mlp_ratio) in enumerate(zip(serial_depths, embed_dims, mlp_ratios)):
         name = "serial{}_".format(sid + 1)
         patch_size = patch_size if sid == 0 else 2
+        patch_input_height = -1 if sid == 0 else block_heights[-1]
         # print(f">>>> {nn.shape = }")
-        nn = patch_embed(nn, embed_dim, patch_size=patch_size, name=name + "patch_")
+        nn, block_height = patch_embed(nn, embed_dim, patch_size=patch_size, input_height=patch_input_height, name=name + "patch_")
+        block_heights.append(block_height)
+        # print(f">>>> {nn.shape = }, {block_height = }")
         nn = ClassToken(name=name + "class_token")(nn)
-        shared_cpe = ConvPositionalEncoding(kernel_size=3, name="cpe_" + str(sid + 1)) if use_shared_cpe else None
-        shared_crpe = ConvRelativePositionalEncoding(head_splits, head_kernel_size, name="crpe_" + str(sid + 1)) if use_shared_crpe else None
+        shared_cpe = ConvPositionalEncoding(kernel_size=3, input_height=block_height, name="cpe_" + str(sid + 1)) if use_shared_cpe else None
+        shared_crpe = ConvRelativePositionalEncoding(head_splits, head_kernel_size, block_height, name="crpe_" + str(sid + 1)) if use_shared_crpe else None
         for bid in range(depth):
             block_name = name + "block{}_".format(bid + 1)
             nn = serial_block(nn, embed_dim, shared_cpe, shared_crpe, num_heads, mlp_ratio, activation=activation, name=block_name)
@@ -286,7 +291,7 @@
     # Parallel blocks.
     for pid in range(parallel_depth):
         name = "parallel{}_".format(pid + 1)
-        classfier_outs = parallel_block(classfier_outs, shared_cpes, shared_crpes, num_heads, mlp_ratios, activation=activation, name=name)
+        classfier_outs = parallel_block(classfier_outs, shared_cpes, shared_crpes, block_heights, num_heads, mlp_ratios, activation=activation, name=name)
 
     if out_features is not None:  # Return intermediate features (for down-stream tasks).
         nn = [classfier_outs[id][:, 1:, :] for id in out_features]
diff --git a/keras_cv_attention_models/coco/eval_func.py b/keras_cv_attention_models/coco/eval_func.py
index 4aa25f2e..496d4618 100644
--- a/keras_cv_attention_models/coco/eval_func.py
+++ b/keras_cv_attention_models/coco/eval_func.py
@@ -296,7 +296,7 @@ def build(self, input_shape, output_shape):
         else:
             num_anchors = anchors_func.NUM_ANCHORS.get(self.anchors_mode, 9)
             pyramid_levels = anchors_func.get_pyramid_levels_by_anchors(input_shape, total_anchors=output_shape[1], num_anchors=num_anchors)
-        print(">>>> [COCOEvalCallback] input_shape: {}, pyramid_levels: {}, anchors_mode: {}".format(input_shape, pyramid_levels, self.anchors_mode))
+        print("\n>>>> [COCOEvalCallback] input_shape: {}, pyramid_levels: {}, anchors_mode: {}".format(input_shape, pyramid_levels, self.anchors_mode))
         # print(">>>>", self.dataset_kwargs)
         # print(">>>>", self.nms_kwargs)
         self.pred_decoder = DecodePredictions(input_shape, pyramid_levels, self.anchors_mode, anchor_scale=self.anchor_scale)
diff --git a/tests/test_models.py b/tests/test_models.py
index eed22ae4..2258e75b 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -221,6 +221,14 @@ def test_CMTTiny_new_shape_predict():
     assert out[1] == "Egyptian_cat"
 
 
+def test_CoaT_new_shape_predict():
+    mm = keras_cv_attention_models.coat.CoaTLiteMini(input_shape=(193, 117, 3), pretrained="imagenet")
+    pred = mm(mm.preprocess_input(chelsea()))  # Chelsea the cat
+    out = mm.decode_predictions(pred)[0][0]
+
+    assert out[1] == "Egyptian_cat"
+
+
 def test_CoAtNet_new_shape_predict():
     mm = keras_cv_attention_models.coatnet.CoAtNet0(input_shape=(320, 320, 3), pretrained="imagenet")
     pred = mm(mm.preprocess_input(chelsea()))  # Chelsea the cat
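
For reference, a minimal usage sketch of what this patch enables, mirroring the new `test_CoaT_new_shape_predict` test above. It assumes the patched `keras_cv_attention_models` package is installed and that `chelsea()` comes from `skimage.data`, as in `tests/test_models.py`:

```py
# Minimal sketch, assuming the patched package is installed.
# Before this patch, CoaT reshaped its token sequence assuming height == width,
# so a non-square input_shape such as (193, 117, 3) failed to build.
from skimage.data import chelsea  # sample cat image, shape (300, 451, 3)
from keras_cv_attention_models import coat

mm = coat.CoaTLiteMini(input_shape=(193, 117, 3), pretrained="imagenet")
pred = mm(mm.preprocess_input(chelsea()))
print(mm.decode_predictions(pred)[0][0])  # expected top-1 class: 'Egyptian_cat'
```

The fix threads the actual feature-map height through `patch_embed` into `ConvPositionalEncoding`, `ConvRelativePositionalEncoding`, and `parallel_block` via the new `input_height` / `block_heights` arguments, deriving the width as `(num_tokens - 1) // height` instead of assuming `sqrt(num_tokens - 1)`.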