From 6ade92c8dc427296527d0a91759644544dbd1a77 Mon Sep 17 00:00:00 2001
From: leondgarse
Date: Wed, 27 Dec 2023 19:59:21 +0800
Subject: [PATCH] add TinySAM

---
 README.md                                      |  1 +
 keras_cv_attention_models/coco/eval_func.py    |  4 +--
 .../segment_anything/README.md                 |  4 +++
 .../segment_anything/__init__.py               |  4 ++-
 .../segment_anything/image_encoders.py         |  1 +
 .../{mask_decoder.py => mask_decoders.py}      |  1 +
 .../{prompt_encoder.py => prompt_encoders.py}  |  0
 .../segment_anything/sam.py                    | 29 +++++++++++++++----
 tests/test_models.py                           | 16 ++++++++++
 9 files changed, 51 insertions(+), 9 deletions(-)
 rename keras_cv_attention_models/segment_anything/{mask_decoder.py => mask_decoders.py} (99%)
 rename keras_cv_attention_models/segment_anything/{prompt_encoder.py => prompt_encoders.py} (100%)

diff --git a/README.md b/README.md
index 35d2190c..b5a33388 100644
--- a/README.md
+++ b/README.md
@@ -1665,6 +1665,7 @@
   | Model | Params | FLOPs | Input | COCO val mIoU | T4 Inference |
   | ------------------- | ------ | ----- | ----- | ------------- | ------------ |
   | [MobileSAM](https://github.com/leondgarse/keras_cv_attention_models/releases/download/segment_anything/mobile_sam_5m_image_encoder_1024_sam.h5) | 5.74M | 39.4G | 1024 | 72.8 | |
+  | [TinySAM](https://github.com/leondgarse/keras_cv_attention_models/releases/download/segment_anything/tinysam_5m_image_encoder_1024_sam.h5) | 5.74M | 39.4G | 1024 | | |
   | [EfficientViT_SAM_L0](https://github.com/leondgarse/keras_cv_attention_models/releases/download/segment_anything/efficientvit_sam_l0_image_encoder_1024_sam.h5) | 30.73M | 35.4G | 512 | 74.45 | |
 
 ***
diff --git a/keras_cv_attention_models/coco/eval_func.py b/keras_cv_attention_models/coco/eval_func.py
index 400f506d..6f5873e3 100644
--- a/keras_cv_attention_models/coco/eval_func.py
+++ b/keras_cv_attention_models/coco/eval_func.py
@@ -256,10 +256,10 @@ def init_eval_dataset(
 
     # dataset = data.detection_dataset_from_custom_json(data_name) if data_name.endswith(".json") else tfds.load(data_name)
     if data_name.endswith(".json"):
-        import tensorflow_datasets as tfds
-
         dataset, _, num_classes = data.detection_dataset_from_custom_json(data_name, with_info=True)
     else:
+        import tensorflow_datasets as tfds
+
         dataset, info = tfds.load(data_name, with_info=True)
         num_classes = info.features["objects"]["label"].num_classes
 
diff --git a/keras_cv_attention_models/segment_anything/README.md b/keras_cv_attention_models/segment_anything/README.md
index 0f8d8c87..aed1de74 100644
--- a/keras_cv_attention_models/segment_anything/README.md
+++ b/keras_cv_attention_models/segment_anything/README.md
@@ -4,14 +4,17 @@
 ## Summary
   - Paper [PDF 2304.02643 Segment Anything](https://arxiv.org/abs/2304.02643)
   - Paper [PDF 2306.14289 FASTER SEGMENT ANYTHING: TOWARDS LIGHTWEIGHT SAM FOR MOBILE APPLICATIONS](https://arxiv.org/pdf/2306.14289.pdf)
+  - Paper [PDF 2312.13789 TinySAM: Pushing the Envelope for Efficient Segment Anything Model](https://arxiv.org/pdf/2312.13789.pdf)
   - [Github facebookresearch/segment-anything](https://github.com/facebookresearch/segment-anything)
   - [Github ChaoningZhang/MobileSAM](https://github.com/ChaoningZhang/MobileSAM)
+  - [Github xinghaochen/TinySAM](https://github.com/xinghaochen/TinySAM)
   - MobileSAM weights ported from [Github ChaoningZhang/MobileSAM](https://github.com/ChaoningZhang/MobileSAM)
   - EfficientViT_SAM weights ported from [Github mit-han-lab/efficientvit](https://github.com/mit-han-lab/efficientvit)
 ## Models
   | Model | Params | FLOPs | Input | COCO val mIoU | Download |
  | ------------------- | ------ | ----- | ----- | ------------- | -------- |
   | MobileSAM | 5.74M | 39.4G | 1024 | 72.8 | [mobile_sam_5m_image_encoder](https://github.com/leondgarse/keras_cv_attention_models/releases/download/segment_anything/mobile_sam_5m_image_encoder_1024_sam.h5) |
+  | TinySAM | 5.74M | 39.4G | 1024 | | [tinysam_5m_image_encoder](https://github.com/leondgarse/keras_cv_attention_models/releases/download/segment_anything/tinysam_5m_image_encoder_1024_sam.h5) |
   | EfficientViT_SAM_L0 | 30.73M | 35.4G | 512 | 74.45 | [efficientvit_sam_l0_image_encoder](https://github.com/leondgarse/keras_cv_attention_models/releases/download/segment_anything/efficientvit_sam_l0_image_encoder_1024_sam.h5) |
 
   Model differences only in `ImageEncoder`, the SAM `PromptEncoder` and `MaskDecoder` are sharing the same one
@@ -19,6 +22,7 @@
   | Model | Params | FLOPs | Download |
   | ----------------------- | ------ | ----- | -------- |
   | MaskDecoder | 4.06M | 1.78G | [sam_mask_decoder_sam.h5](https://github.com/leondgarse/keras_cv_attention_models/releases/download/segment_anything/sam_mask_decoder_sam.h5) |
+  | - tiny_sam | 4.06M | 1.78G | [tiny_sam_mask_decoder_sam.h5](https://github.com/leondgarse/keras_cv_attention_models/releases/download/segment_anything/tiny_sam_mask_decoder_sam.h5) |
   | PointsEncoder | 768 | 0 | [sam_prompt_encoder_points_encoder_sam.h5](https://github.com/leondgarse/keras_cv_attention_models/releases/download/segment_anything/sam_prompt_encoder_points_encoder_sam.h5) |
   | BboxesEncoder | 512 | 256 | [sam_prompt_encoder_bboxes_encoder_sam.h5](https://github.com/leondgarse/keras_cv_attention_models/releases/download/segment_anything/sam_prompt_encoder_bboxes_encoder_sam.h5) |
   | MaskEncoder | 4684 | 0 | [sam_prompt_encoder_mask_encoder_sam.h5](https://github.com/leondgarse/keras_cv_attention_models/releases/download/segment_anything/sam_prompt_encoder_mask_encoder_sam.h5) |
diff --git a/keras_cv_attention_models/segment_anything/__init__.py b/keras_cv_attention_models/segment_anything/__init__.py
index 3bf440bf..e46ce016 100644
--- a/keras_cv_attention_models/segment_anything/__init__.py
+++ b/keras_cv_attention_models/segment_anything/__init__.py
@@ -1,4 +1,4 @@
-from keras_cv_attention_models.segment_anything.sam import SAM, MobileSAM, EfficientViT_SAM_L0
+from keras_cv_attention_models.segment_anything.sam import SAM, MobileSAM, TinySAM, EfficientViT_SAM_L0
 
 __head_doc__ = """
 Keras implementation of [Github facebookresearch/segment-anything](https://github.com/facebookresearch/segment-anything).
@@ -44,12 +44,14 @@
 SAM.__doc__ = __head_doc__ + """
 Init args:
   image_encoder: string or built image encoder model. Currently string can be one of ["TinyViT_5M", "EfficientViT_L0"].
+  mask_decoder: string or built mask decoder model. Currently string can be one of ["sam_mask_decoder", "tiny_sam_mask_decoder"].
   name: string, model name.
""" + __tail_doc__ + """ Model architectures: | Model | Params | FLOPs | Input | COCO val mIoU | | ------------------- | ------ | ----- | ----- | ------------- | | MobileSAM | 5.74M | 39.4G | 1024 | 72.8 | + | TinySAM | 5.74M | 39.4G | 1024 | | | EfficientViT_SAM_L0 | 30.73M | 35.4G | 512 | 74.45 | """ diff --git a/keras_cv_attention_models/segment_anything/image_encoders.py b/keras_cv_attention_models/segment_anything/image_encoders.py index fde9f600..bcd77543 100644 --- a/keras_cv_attention_models/segment_anything/image_encoders.py +++ b/keras_cv_attention_models/segment_anything/image_encoders.py @@ -8,6 +8,7 @@ PRETRAINED_DICT = { "mobile_sam_5m_image_encoder": {"sam": {1024: "d9e48e1b5109b8f677625454a5f9c257"}}, + "tiny_sam_5m_image_encoder": {"sam": {1024: "ae58fa89388f5e1d414e86c33b21a71a"}}, "efficientvit_sam_l0_image_encoder": {"sam": {1024: "d91f40cf7f46b375a859bef4b2c87bdb"}}, } diff --git a/keras_cv_attention_models/segment_anything/mask_decoder.py b/keras_cv_attention_models/segment_anything/mask_decoders.py similarity index 99% rename from keras_cv_attention_models/segment_anything/mask_decoder.py rename to keras_cv_attention_models/segment_anything/mask_decoders.py index 7a6e5a7a..ca4e0fcd 100644 --- a/keras_cv_attention_models/segment_anything/mask_decoder.py +++ b/keras_cv_attention_models/segment_anything/mask_decoders.py @@ -12,6 +12,7 @@ LAYER_NORM_EPSILON = 1e-5 PRETRAINED_DICT = { "sam_mask_decoder": {"sam": "86ccca20e41dd15578fbbd067035fa70"}, + "tiny_sam_mask_decoder": {"sam": "34f68eb047de515721f4658106e4ccb5"}, } diff --git a/keras_cv_attention_models/segment_anything/prompt_encoder.py b/keras_cv_attention_models/segment_anything/prompt_encoders.py similarity index 100% rename from keras_cv_attention_models/segment_anything/prompt_encoder.py rename to keras_cv_attention_models/segment_anything/prompt_encoders.py diff --git a/keras_cv_attention_models/segment_anything/sam.py b/keras_cv_attention_models/segment_anything/sam.py index 6ed8ccfe..344e662e 100644 --- a/keras_cv_attention_models/segment_anything/sam.py +++ b/keras_cv_attention_models/segment_anything/sam.py @@ -5,7 +5,7 @@ from keras_cv_attention_models.attention_layers import BiasLayer, PureWeigths, batchnorm_with_activation, conv2d_no_bias, layer_norm from keras_cv_attention_models.models import register_model, FakeModelWrapper, torch_no_grad from keras_cv_attention_models.download_and_load import reload_model_weights -from keras_cv_attention_models.segment_anything import image_encoders, mask_decoder, prompt_encoder +from keras_cv_attention_models.segment_anything import image_encoders, mask_decoders, prompt_encoders LAYER_NORM_EPSILON = 1e-6 @@ -13,7 +13,15 @@ @register_model class SAM(FakeModelWrapper): # FakeModelWrapper providing save / load / cuda class methods def __init__( - self, image_encoder="TinyViT_5M", image_shape=(1024, 1024), embed_dims=256, mask_hidden_dims=16, pretrained="sam", name="mobile_sam_5m", kwargs=None + self, + image_encoder="TinyViT_5M", + mask_decoder="sam_mask_decoder", # string or built mask decoder model. 
+        image_shape=(1024, 1024),
+        embed_dims=256,
+        mask_hidden_dims=16,
+        pretrained="sam",
+        name="mobile_sam_5m",
+        kwargs=None,  # Not used, just receiving parameter
     ):
         self.image_shape = image_shape[:2] if isinstance(image_shape, (list, tuple)) else [image_shape, image_shape]
         self.embed_dims = embed_dims
@@ -25,9 +33,13 @@ def __init__(
         self.prompt_mask_shape = [int(self.image_embedding_shape[0] * 16), int(self.image_embedding_shape[1] * 16)]  # [64, 64] -> [1024, 1024]
         self.masks_input_shape = [int(self.image_embedding_shape[0] * 4), int(self.image_embedding_shape[1] * 4)]  # [64, 64] -> [256, 256]
 
+        if isinstance(mask_decoder, str):
+            self.mask_decoder = mask_decoders.MaskDecoder(input_shape=[*self.image_embedding_shape, embed_dims], pretrained=pretrained, name=mask_decoder)
+        else:
+            self.mask_decoder = mask_decoder
+
         # prompt_encoder is also a subclass of FakeModelWrapper, and here not passing the `name`
-        self.prompt_encoder = prompt_encoder.PromptEncoder(embed_dims, mask_hidden_dims, self.prompt_mask_shape, self.masks_input_shape, pretrained=pretrained)
-        self.mask_decoder = mask_decoder.MaskDecoder(input_shape=[*self.image_embedding_shape, embed_dims], pretrained=pretrained)
+        self.prompt_encoder = prompt_encoders.PromptEncoder(embed_dims, mask_hidden_dims, self.prompt_mask_shape, self.masks_input_shape, pretrained=pretrained)
         self.models = [self.image_encoder, self.mask_decoder] + self.prompt_encoder.models
         super().__init__(self.models, name=name)
 
@@ -135,9 +147,14 @@ def MobileSAM(image_shape=(1024, 1024), pretrained="sam", name="mobile_sam_5m",
     return SAM(image_encoder="TinyViT_5M", **locals(), **kwargs)
 
 
+@register_model
+def TinySAM(image_shape=(1024, 1024), mask_decoder="tiny_sam_mask_decoder", pretrained="sam", name="tiny_sam_5m", **kwargs):
+    return SAM(image_encoder="TinyViT_5M", **locals(), **kwargs)
+
+
 @register_model
 def EfficientViT_SAM_L0(image_shape=(512, 512), pretrained="sam", name="efficientvit_sam_l0", **kwargs):
-    mask_decoder.LAYER_NORM_EPSILON = 1e-6
+    mask_decoders.LAYER_NORM_EPSILON = 1e-6
     model = SAM(image_encoder="EfficientViT_L0", **locals(), **kwargs)
-    mask_decoder.LAYER_NORM_EPSILON = 1e-5
+    mask_decoders.LAYER_NORM_EPSILON = 1e-5
     return model
diff --git a/tests/test_models.py b/tests/test_models.py
index 38eb6f1c..7cba70d1 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -739,6 +739,9 @@ def test_YOLOV8_S_dynamic_predict():
     assert COCO_80_LABEL_DICT[pred_label[0]] == "cat"
 
 
+""" Stable Diffusion """
+
+
 def test_stable_diffusion_no_weights_predict():
     mm = keras_cv_attention_models.stable_diffusion.StableDiffusion(pretrained=None)
     image = keras_cv_attention_models.backend.numpy_image_resize(cat(), [256, 256])
@@ -750,3 +753,16 @@ def test_stable_diffusion_no_weights_predict():
     out = out.numpy()
     assert out.shape == (1, 256, 256, 3)
     assert out.min() > -6 and out.max() < 6  # It should be within this range
+
+
+""" Segment Anything """
+
+
+def test_MobileSAM_predict():
+    mm = keras_cv_attention_models.segment_anything.MobileSAM()
+    points, labels = np.array([(0.5, 0.8)]), np.array([1])
+    masks, iou_predictions, low_res_masks = mm(cat(), points, labels)
+
+    assert masks.shape == (4, 512, 512) and iou_predictions.shape == (4,) and low_res_masks.shape == (4, 256, 256)
+    assert np.allclose(iou_predictions, np.array([0.98725945, 0.83492416, 0.9997821, 0.96904826]), atol=1e-3)
+    assert np.allclose([ii.sum() for ii in masks], [140151, 121550, 139295, 149360], atol=10)
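
A minimal usage sketch for the new TinySAM entry point, assuming it follows the same point-prompt call convention exercised by test_MobileSAM_predict above; the test_images.cat() sample image helper and the expected output shapes are carried over from the repo's existing tests and examples, so treat them as assumptions rather than guaranteed values.

import numpy as np
from keras_cv_attention_models import segment_anything, test_images

# TinyViT_5M image encoder + tiny_sam_mask_decoder, loads "sam" pretrained weights by default
mm = segment_anything.TinySAM()
image = test_images.cat()  # assumed sample-image helper, the same cat() image the tests use
points, labels = np.array([(0.5, 0.8)]), np.array([1])  # one positive point prompt, coordinates normalized to [0, 1] as in the test
masks, iou_predictions, low_res_masks = mm(image, points, labels)
print(masks.shape, iou_predictions.shape, low_res_masks.shape)
# With the default 1024 input shape this is expected to mirror the MobileSAM test: (4, 512, 512) (4,) (4, 256, 256)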