add TinySAM
leondgarse committed Dec 27, 2023
1 parent 41b4006 commit 6ade92c
Showing 9 changed files with 51 additions and 9 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -1665,6 +1665,7 @@
| Model | Params | FLOPs | Input | COCO val mIoU | T4 Inference |
| ------------------- | ------ | ----- | ----- | ------------- | ------------ |
| [MobileSAM](https://github.com/leondgarse/keras_cv_attention_models/releases/download/segment_anything/mobile_sam_5m_image_encoder_1024_sam.h5) | 5.74M | 39.4G | 1024 | 72.8 | |
| [TinySAM](https://github.com/leondgarse/keras_cv_attention_models/releases/download/segment_anything/tinysam_5m_image_encoder_1024_sam.h5) | 5.74M | 39.4G | 1024 | | |
| [EfficientViT_SAM_L0](https://github.com/leondgarse/keras_cv_attention_models/releases/download/segment_anything/efficientvit_sam_l0_image_encoder_1024_sam.h5) | 30.73M | 35.4G | 512 | 74.45 | |
***

4 changes: 2 additions & 2 deletions keras_cv_attention_models/coco/eval_func.py
@@ -256,10 +256,10 @@ def init_eval_dataset(

# dataset = data.detection_dataset_from_custom_json(data_name) if data_name.endswith(".json") else tfds.load(data_name)
if data_name.endswith(".json"):
import tensorflow_datasets as tfds

dataset, _, num_classes = data.detection_dataset_from_custom_json(data_name, with_info=True)
else:
import tensorflow_datasets as tfds

dataset, info = tfds.load(data_name, with_info=True)
num_classes = info.features["objects"]["label"].num_classes

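Reading this hunk with the lost +/- markers restored (removed lines precede added ones within a hunk), the change moves the `tensorflow_datasets` import out of the custom-json branch and into the branch that actually calls `tfds.load`. A condensed sketch of the function after this commit, assuming that reading; `data` is the module's coco data helper:

```python
def init_eval_dataset(data_name="coco/2017", **kwargs):  # condensed; full signature in eval_func.py
    if data_name.endswith(".json"):
        # Custom json annotations: tensorflow_datasets is never imported on this path.
        dataset, _, num_classes = data.detection_dataset_from_custom_json(data_name, with_info=True)
    else:
        import tensorflow_datasets as tfds  # deferred import, only needed for TFDS-hosted datasets

        dataset, info = tfds.load(data_name, with_info=True)
        num_classes = info.features["objects"]["label"].num_classes
```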
4 changes: 4 additions & 0 deletions keras_cv_attention_models/segment_anything/README.md
@@ -4,21 +4,25 @@
## Summary
- Paper [PDF 2304.02643 Segment Anything](https://arxiv.org/abs/2304.02643)
- Paper [PDF 2306.14289 FASTER SEGMENT ANYTHING: TOWARDS LIGHTWEIGHT SAM FOR MOBILE APPLICATIONS](https://arxiv.org/pdf/2306.14289.pdf)
- Paper [PDF 2312.13789 TinySAM: Pushing the Envelope for Efficient Segment Anything Model](https://arxiv.org/pdf/2312.13789.pdf)
- [Github facebookresearch/segment-anything](https://github.com/facebookresearch/segment-anything)
- [Github ChaoningZhang/MobileSAM](https://github.com/ChaoningZhang/MobileSAM)
- [Github xinghaochen/TinySAM](https://github.com/xinghaochen/TinySAM)
- MobileSAM weights ported from [Github ChaoningZhang/MobileSAM](https://github.com/ChaoningZhang/MobileSAM)
- EfficientViT_SAM weights ported from [Github mit-han-lab/efficientvit](https://github.com/mit-han-lab/efficientvit)
## Models
| Model | Params | FLOPs | Input | COCO val mIoU | Download |
| ------------------- | ------ | ----- | ----- | ------------- | -------- |
| MobileSAM | 5.74M | 39.4G | 1024 | 72.8 | [mobile_sam_5m_image_encoder](https://github.com/leondgarse/keras_cv_attention_models/releases/download/segment_anything/mobile_sam_5m_image_encoder_1024_sam.h5) |
| TinySAM | 5.74M | 39.4G | 1024 | | [tinysam_5m_image_encoder](https://github.com/leondgarse/keras_cv_attention_models/releases/download/segment_anything/tinysam_5m_image_encoder_1024_sam.h5) |
| EfficientViT_SAM_L0 | 30.73M | 35.4G | 512 | 74.45 | [efficientvit_sam_l0_image_encoder](https://github.com/leondgarse/keras_cv_attention_models/releases/download/segment_anything/efficientvit_sam_l0_image_encoder_1024_sam.h5) |

Model differences are only in the `ImageEncoder`; the SAM `PromptEncoder` and `MaskDecoder` are shared across models.

| Model | Params | FLOPs | Download |
| ----------------------- | ------ | ----- | -------- |
| MaskDecoder | 4.06M | 1.78G | [sam_mask_decoder_sam.h5](https://github.com/leondgarse/keras_cv_attention_models/releases/download/segment_anything/sam_mask_decoder_sam.h5) |
| - tiny_sam | 4.06M | 1.78G | [tiny_sam_mask_decoder_sam.h5](https://github.com/leondgarse/keras_cv_attention_models/releases/download/segment_anything/tiny_sam_mask_decoder_sam.h5) |
| PointsEncoder | 768 | 0 | [sam_prompt_encoder_points_encoder_sam.h5](https://github.com/leondgarse/keras_cv_attention_models/releases/download/segment_anything/sam_prompt_encoder_points_encoder_sam.h5) |
| BboxesEncoder | 512 | 256 | [sam_prompt_encoder_bboxes_encoder_sam.h5](https://github.com/leondgarse/keras_cv_attention_models/releases/download/segment_anything/sam_prompt_encoder_bboxes_encoder_sam.h5) |
| MaskEncoder | 4684 | 0 | [sam_prompt_encoder_mask_encoder_sam.h5](https://github.com/leondgarse/keras_cv_attention_models/releases/download/segment_anything/sam_prompt_encoder_mask_encoder_sam.h5) |
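For reference, a minimal usage sketch for the new model, mirroring the MobileSAM smoke test added to `tests/test_models.py` in this commit (the image and point-prompt values are arbitrary; output shapes are those asserted for MobileSAM):

```python
import numpy as np
from keras_cv_attention_models import segment_anything

mm = segment_anything.TinySAM()  # TinyViT_5M image encoder + tiny_sam mask decoder weights
image = np.random.uniform(size=[1024, 1024, 3]).astype("float32")  # any RGB image array
points, labels = np.array([(0.5, 0.8)]), np.array([1])  # normalized (x, y) point, label 1 == foreground
masks, iou_predictions, low_res_masks = mm(image, points, labels)
print(masks.shape, iou_predictions.shape, low_res_masks.shape)  # (4, 512, 512) (4,) (4, 256, 256)
```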
4 changes: 3 additions & 1 deletion keras_cv_attention_models/segment_anything/__init__.py
@@ -1,4 +1,4 @@
from keras_cv_attention_models.segment_anything.sam import SAM, MobileSAM, EfficientViT_SAM_L0
from keras_cv_attention_models.segment_anything.sam import SAM, MobileSAM, TinySAM, EfficientViT_SAM_L0

__head_doc__ = """
Keras implementation of [Github facebookresearch/segment-anything](https://github.com/facebookresearch/segment-anything).
@@ -44,12 +44,14 @@
SAM.__doc__ = __head_doc__ + """
Init args:
image_encoder: string or a built image encoder model. Currently the string can be one of ["TinyViT_5M", "EfficientViT_L0"].
mask_decoder: string or a built mask decoder model. Currently the string can be one of ["sam_mask_decoder", "tiny_sam_mask_decoder"].
name: string, model name.
""" + __tail_doc__ + """
Model architectures:
| Model | Params | FLOPs | Input | COCO val mIoU |
| ------------------- | ------ | ----- | ----- | ------------- |
| MobileSAM | 5.74M | 39.4G | 1024 | 72.8 |
| TinySAM | 5.74M | 39.4G | 1024 | |
| EfficientViT_SAM_L0 | 30.73M | 35.4G | 512 | 74.45 |
"""

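The new `mask_decoder` argument can also be exercised on `SAM` directly; a sketch equivalent to the registered `TinySAM` shortcut, with names taken from this diff:

```python
from keras_cv_attention_models.segment_anything import SAM, TinySAM

# Both lines build the TinyViT_5M image encoder with the tiny_sam mask decoder weights.
mm = SAM(image_encoder="TinyViT_5M", mask_decoder="tiny_sam_mask_decoder", name="tiny_sam_5m")
mm = TinySAM()  # same configuration via the shortcut
```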
1 change: 1 addition & 0 deletions keras_cv_attention_models/segment_anything/image_encoders.py
@@ -8,6 +8,7 @@

PRETRAINED_DICT = {
"mobile_sam_5m_image_encoder": {"sam": {1024: "d9e48e1b5109b8f677625454a5f9c257"}},
"tiny_sam_5m_image_encoder": {"sam": {1024: "ae58fa89388f5e1d414e86c33b21a71a"}},
"efficientvit_sam_l0_image_encoder": {"sam": {1024: "d91f40cf7f46b375a859bef4b2c87bdb"}},
}

1 change: 1 addition & 0 deletions keras_cv_attention_models/segment_anything/mask_decoders.py
@@ -12,6 +12,7 @@
LAYER_NORM_EPSILON = 1e-5
PRETRAINED_DICT = {
"sam_mask_decoder": {"sam": "86ccca20e41dd15578fbbd067035fa70"},
"tiny_sam_mask_decoder": {"sam": "34f68eb047de515721f4658106e4ccb5"},
}


29 changes: 23 additions & 6 deletions keras_cv_attention_models/segment_anything/sam.py
@@ -5,15 +5,23 @@
from keras_cv_attention_models.attention_layers import BiasLayer, PureWeigths, batchnorm_with_activation, conv2d_no_bias, layer_norm
from keras_cv_attention_models.models import register_model, FakeModelWrapper, torch_no_grad
from keras_cv_attention_models.download_and_load import reload_model_weights
from keras_cv_attention_models.segment_anything import image_encoders, mask_decoder, prompt_encoder
from keras_cv_attention_models.segment_anything import image_encoders, mask_decoders, prompt_encoders

LAYER_NORM_EPSILON = 1e-6


@register_model
class SAM(FakeModelWrapper): # FakeModelWrapper providing save / load / cuda class methods
def __init__(
self, image_encoder="TinyViT_5M", image_shape=(1024, 1024), embed_dims=256, mask_hidden_dims=16, pretrained="sam", name="mobile_sam_5m", kwargs=None
self,
image_encoder="TinyViT_5M",
mask_decoder="sam_mask_decoder", # string or built mask decoder model. Currently string can be one of ["sam_mask_decoder", "tiny_sam_mask_decoder"]
image_shape=(1024, 1024),
embed_dims=256,
mask_hidden_dims=16,
pretrained="sam",
name="mobile_sam_5m",
kwargs=None, # Not used, just receiving the parameter
):
self.image_shape = image_shape[:2] if isinstance(image_shape, (list, tuple)) else [image_shape, image_shape]
self.embed_dims = embed_dims
@@ -25,9 +25,13 @@ def __init__(
self.prompt_mask_shape = [int(self.image_embedding_shape[0] * 16), int(self.image_embedding_shape[1] * 16)] # [64, 64] -> [1024, 1024]
self.masks_input_shape = [int(self.image_embedding_shape[0] * 4), int(self.image_embedding_shape[1] * 4)] # [64, 64] -> [256, 256]

if isinstance(mask_decoder, str):
self.mask_decoder = mask_decoders.MaskDecoder(input_shape=[*self.image_embedding_shape, embed_dims], name=mask_decoder)
else:
self.mask_decoder = mask_decoder

# prompt_encoder is also a subclass of FakeModelWrapper; `name` is not passed here
self.prompt_encoder = prompt_encoder.PromptEncoder(embed_dims, mask_hidden_dims, self.prompt_mask_shape, self.masks_input_shape, pretrained=pretrained)
self.mask_decoder = mask_decoder.MaskDecoder(input_shape=[*self.image_embedding_shape, embed_dims], pretrained=pretrained)
self.prompt_encoder = prompt_encoders.PromptEncoder(embed_dims, mask_hidden_dims, self.prompt_mask_shape, self.masks_input_shape, pretrained=pretrained)
self.models = [self.image_encoder, self.mask_decoder] + self.prompt_encoder.models
super().__init__(self.models, name=name)

@@ -135,9 +147,14 @@ def MobileSAM(image_shape=(1024, 1024), pretrained="sam", name="mobile_sam_5m",
return SAM(image_encoder="TinyViT_5M", **locals(), **kwargs)


@register_model
def TinySAM(image_shape=(1024, 1024), mask_decoder="tiny_sam_mask_decoder", pretrained="sam", name="tiny_sam_5m", **kwargs):
return SAM(image_encoder="TinyViT_5M", **locals(), **kwargs)


@register_model
def EfficientViT_SAM_L0(image_shape=(512, 512), pretrained="sam", name="efficientvit_sam_l0", **kwargs):
mask_decoder.LAYER_NORM_EPSILON = 1e-6
mask_decoders.LAYER_NORM_EPSILON = 1e-6
model = SAM(image_encoder="EfficientViT_L0", **locals(), **kwargs)
mask_decoder.LAYER_NORM_EPSILON = 1e-5
mask_decoders.LAYER_NORM_EPSILON = 1e-5
return model
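`EfficientViT_SAM_L0` temporarily swaps the module-level `LAYER_NORM_EPSILON` around construction and then restores the 1e-5 default. The same effect could be packaged as a context manager so the default is restored even if construction raises; a sketch of that alternative, not what the commit itself does:

```python
from contextlib import contextmanager

from keras_cv_attention_models.segment_anything import mask_decoders

@contextmanager
def layer_norm_epsilon(value, default=1e-5):
    mask_decoders.LAYER_NORM_EPSILON = value  # picked up by layers built inside the block
    try:
        yield
    finally:
        mask_decoders.LAYER_NORM_EPSILON = default  # restored even on exception

# Usage: with layer_norm_epsilon(1e-6): model = SAM(image_encoder="EfficientViT_L0", ...)
```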
16 changes: 16 additions & 0 deletions tests/test_models.py
@@ -739,6 +739,9 @@ def test_YOLOV8_S_dynamic_predict():
assert COCO_80_LABEL_DICT[pred_label[0]] == "cat"


""" Stable Diffusion """


def test_stable_diffusion_no_weights_predict():
mm = keras_cv_attention_models.stable_diffusion.StableDiffusion(pretrained=None)
image = keras_cv_attention_models.backend.numpy_image_resize(cat(), [256, 256])
@@ -750,3 +753,16 @@ def test_stable_diffusion_no_weights_predict():
out = out.numpy()
assert out.shape == (1, 256, 256, 3)
assert out.min() > -6 and out.max() < 6 # It should be within this range


""" Segment Anything """


def test_MobileSAM_predict():
mm = keras_cv_attention_models.segment_anything.MobileSAM()
points, labels = np.array([(0.5, 0.8)]), np.array([1])
masks, iou_predictions, low_res_masks = mm(cat(), points, labels)

assert masks.shape == (4, 512, 512) and iou_predictions.shape == (4,) and low_res_masks.shape == (4, 256, 256)
assert np.allclose(iou_predictions, np.array([0.98725945, 0.83492416, 0.9997821, 0.96904826]), atol=1e-3)
assert np.allclose([ii.sum() for ii in masks], [140151, 121550, 139295, 149360], atol=10)
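
No TinySAM counterpart is added here; a sketch of what one might look like, with shape-only assertions since no golden IoU values for TinySAM are recorded in this commit (shapes assumed to match MobileSAM's, as both use 1024 input):

```python
def test_TinySAM_predict():
    mm = keras_cv_attention_models.segment_anything.TinySAM()
    points, labels = np.array([(0.5, 0.8)]), np.array([1])
    masks, iou_predictions, low_res_masks = mm(cat(), points, labels)

    # Same input resolution and decoder structure as MobileSAM, so the same output shapes are expected.
    assert masks.shape == (4, 512, 512) and iou_predictions.shape == (4,) and low_res_masks.shape == (4, 256, 256)
```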
