diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..6980c45 --- /dev/null +++ b/.gitignore @@ -0,0 +1,9 @@ +.idea/* +*__pycache__* +weights +experiments +pretrained_models +*.so +*.ipynb_checkpoints* +*.yml +*.json diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..500a598 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,37 @@ +FROM nvidia/cuda:10.1-cudnn7-devel-ubuntu18.04 + +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + git \ + curl \ + libglib2.0-0 \ + software-properties-common \ + python3.6-dev \ + python3-pip \ + python3-tk + +WORKDIR /tmp + +RUN pip3 install --upgrade pip +RUN pip3 install setuptools +RUN pip3 install matplotlib numpy pandas scipy tqdm pyyaml easydict scikit-image bridson Pillow ninja +RUN pip3 install imgaug mxboard graphviz +RUN pip3 install albumentations --no-deps +RUN pip3 install opencv-python-headless +RUN pip3 install Cython +RUN pip3 install torch +RUN pip3 install torchvision +RUN pip3 install scikit-learn +RUN pip3 install tensorboard + +RUN mkdir /work +WORKDIR /work +RUN chmod -R 777 /work && chmod -R 777 /root + +ENV TINI_VERSION v0.18.0 +ADD https://github.com/krallin/tini/releases/download/${TINI_VERSION}/tini /usr/bin/tini +RUN chmod +x /usr/bin/tini +ENTRYPOINT [ "/usr/bin/tini", "--" ] +CMD [ "/bin/bash" ] diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..17c5b89 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +The MIT License + +Copyright (c) 2021 Samsung Electronics Co., Ltd. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..27a42cc --- /dev/null +++ b/README.md @@ -0,0 +1,346 @@ +## Reviving Iterative Training with Mask Guidance for Interactive Segmentation + +

+ + + + + Open In Colab + + + The MIT License + +

+ +

+ drawing + drawing +

+ +This repository provides the source code for training and testing state-of-the-art click-based interactive segmentation models with the official PyTorch implementation of the following paper: + +> **Reviving Iterative Training with Mask Guidance for Interactive Segmentation**
+> [Konstantin Sofiiuk](https://github.com/ksofiyuk), [Ilia Petrov](https://github.com/ptrvilya), [Anton Konushin](https://scholar.google.com/citations?user=ZT_k-wMAAAAJ)
+> Samsung AI Center Moscow
+> https://arxiv.org/abs/ +> +> **Abstract:** *Recent works on click-based interactive segmentation have demonstrated state-of-the-art results by +> using various inference-time optimization schemes. These methods are considerably more computationally expensive +> compared to feedforward approaches, as they require performing backward passes through a network during inference and +> are hard to deploy on mobile frameworks that usually support only forward passes. In this paper, we extensively +> evaluate various design choices for interactive segmentation and discover that new state-of-the-art results can be +> obtained without any additional optimization schemes. Thus, we propose a simple feedforward model for click-based +> interactive segmentation that employs the segmentation masks from previous steps. It allows not only to segment an +> entirely new object, but also to start with an external mask and correct it. When analyzing the performance of models +> trained on different datasets, we observe that the choice of a training dataset greatly impacts the quality of +> interactive segmentation. We find that the models trained on a combination of COCO and LVIS with diverse and +> high-quality annotations show performance superior to all existing models.* + + +## Setting up an environment + +This framework is built using Python 3.6 and relies on the PyTorch 1.4.0+. The following command installs all +necessary packages: + +```.bash +pip3 install -r requirements.txt +``` + +You can also use our [Dockerfile](./Dockerfile) to build a container with the configured environment. + +If you want to run training or testing, you must configure the paths to the datasets in [config.yml](config.yml). + +## Interactive Segmentation Demo + +

+ drawing +

+ +The GUI is based on TkInter library and its Python bindings. You can try our interactive demo with any of the +[provided models](#pretrained-models). Our scripts automatically detect the architecture of the loaded model, just +specify the path to the corresponding checkpoint. + +Examples of the script usage: + +```.bash +# This command runs interactive demo with HRNet18 ITER-M model from cfg.INTERACTIVE_MODELS_PATH on GPU with id=0 +# --checkpoint can be relative to cfg.INTERACTIVE_MODELS_PATH or absolute path to the checkpoint +python3 demo.py --checkpoint=hrnet18_cocolvis_itermask_3p --gpu=0 + +# This command runs interactive demo with HRNet18 ITER-M model from /home/demo/isegm/weights/ +# If you also do not have a lot of GPU memory, you can reduce --limit-longest-size (default=800) +python3 demo.py --checkpoint=/home/demo/fBRS/weights/hrnet18_cocolvis_itermask_3p --limit-longest-size=400 + +# You can try the demo in CPU only mode +python3 demo.py --checkpoint=hrnet18_cocolvis_itermask_3p --cpu +``` + +
+Running demo in docker +
# activate xhost
+xhost +
+docker run -v "$PWD":/tmp/ \
+           -v /tmp/.X11-unix:/tmp/.X11-unix \
+           -e DISPLAY=$DISPLAY <id-or-tag-docker-built-image> \
+           python3 demo.py --checkpoint resnet34_dh128_sbd --cpu
+
+
+ +**Controls**: + +| Key | Description | +| ------------------------------------------------------------- | ---------------------------------- | +| Left Mouse Button | Place a positive click | +| Right Mouse Button | Place a negative click | +| Scroll Wheel | Zoom an image in and out | +| Right Mouse Button +
Move Mouse | Move an image | +| Space | Finish the current object mask | + +
+Initializing the ITER-M models with an external segmentation mask +

+ drawing +

+ +According to our paper, ITER-M models take an image, encoded user input, and a previous step mask as their input. Moreover, a user can initialize the model with an external mask before placing any clicks and correct this mask using the same interface. As it turns out, our models successfully handle this situation and make it possible to change the mask. + + +To initialize any ITER-M model with an external mask use the "Load mask" button in the menu bar. +
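+The same initialization can also be scripted with the repository's `InteractiveController` (a minimal sketch outside the GUI; the checkpoint filename, image path, and mask path below are placeholders — substitute your own downloaded ITER-M weights and files):
+
+```.python
+import cv2
+import torch
+
+from isegm.inference import utils
+from interactive_demo.controller import InteractiveController
+
+device = torch.device('cpu')  # or torch.device('cuda:0')
+
+# Placeholder path: any ITER-M checkpoint (a model trained with a previous-step mask input)
+model = utils.load_is_model('./weights/hrnet18_cocolvis_itermask_3p.pth', device, cpu_dist_maps=True)
+controller = InteractiveController(model, device,
+                                   predictor_params={'brs_mode': 'NoBRS'},
+                                   update_image_callback=lambda **kwargs: None)  # no GUI to redraw
+
+image = cv2.cvtColor(cv2.imread('./assets/test_imgs/sheep.jpg'), cv2.COLOR_BGR2RGB)
+controller.set_image(image)
+
+# External binary mask with the same height and width as the image; it is passed to the
+# network as the previous-step mask, which is what the "Load mask" button does.
+external_mask = cv2.imread('./my_mask.png')[:, :, 0] > 127
+controller.set_mask(external_mask)
+
+controller.add_click(x=200, y=150, is_positive=True)  # corrective click on top of the mask
+refined_mask = controller.result_mask
+```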
+ +
+Interactive segmentation options + +
+</p>
+</details>
+
+## Datasets
+
+We train all our models on SBD and COCO+LVIS and evaluate them on GrabCut, Berkeley, DAVIS, SBD and Pascal VOC. We also provide links to additional datasets, ADE20k and OpenImages, which are used in the ablation study.
+
+| Dataset   | Description                                  | Download Link                        |
+|-----------|----------------------------------------------|:------------------------------------:|
+|ADE20k     | 22k images with 434k instances (total)       | [official site][ADE20k]              |
+|OpenImages | 944k images with 2.6M instances (total)      | [official site][OpenImages]          |
+|MS COCO    | 118k images with 1.2M instances (train)      | [official site][MSCOCO]              |
+|LVIS v1.0  | 100k images with 1.2M instances (total)      | [official site][LVIS]                |
+|COCO+LVIS* | 99k images with 1.5M instances (train)       | [original LVIS images][LVIS] + <br>
[our combined annotations][COCOLVIS_annotation] | +|SBD | 8498 images with 20172 instances for (train)
2857 images with 6671 instances for (test) |[official site][SBD]| +|Grab Cut | 50 images with one object each (test) | [GrabCut.zip (11 MB)][GrabCut] | +|Berkeley | 96 images with 100 instances (test) | [Berkeley.zip (7 MB)][Berkeley] | +|DAVIS | 345 images with one object each (test) | [DAVIS.zip (43 MB)][DAVIS] | +|Pascal VOC | 1449 images with 3417 instances (validation)| [official site][PascalVOC] | +|COCO_MVal | 800 images with 800 instances (test) | [COCO_MVal.zip (127 MB)][COCO_MVal] | + +[ADE20k]: http://sceneparsing.csail.mit.edu/ +[OpenImages]: https://storage.googleapis.com/openimages/web/download.html +[MSCOCO]: https://cocodataset.org/#download +[LVIS]: https://www.lvisdataset.org/dataset +[SBD]: http://home.bharathh.info/pubs/codes/SBD/download.html +[GrabCut]: https://github.com/saic-vul/fbrs_interactive_segmentation/releases/download/v1.0/GrabCut.zip +[Berkeley]: https://github.com/saic-vul/fbrs_interactive_segmentation/releases/download/v1.0/Berkeley.zip +[DAVIS]: https://github.com/saic-vul/fbrs_interactive_segmentation/releases/download/v1.0/DAVIS.zip +[PascalVOC]: http://host.robots.ox.ac.uk/pascal/VOC/ +[COCOLVIS_annotation]: https://github.com/saic-vul/ritm_interactive_segmentation/releases/download/v1.0/cocolvis_annotation.tar.gz +[COCO_MVal]: https://github.com/saic-vul/fbrs_interactive_segmentation/releases/download/v1.0/COCO_MVal.zip + +Don't forget to change the paths to the datasets in [config.yml](config.yml) after downloading and unpacking. + +(*) To prepare COCO+LVIS, you need to download original LVIS v1.0, then download and unpack our +pre-processed annotations that are obtained by combining COCO and LVIS dataset into the folder with LVIS v1.0. + +## Testing + +### Pretrained models +We provide pretrained models with different backbones for interactive segmentation. + +You can find model weights and evaluation results in the tables below: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Train
Dataset
Model | GrabCut | Berkeley | SBD | DAVIS | Pascal
VOC
COCO
MVal
NoC
85%
NoC
90%
NoC
90%
NoC
85%
NoC
90%
NoC
85%
NoC
90%
NoC
85%
NoC
90%
SBD | HRNet18 IT-M
(38.8 MB)
1.76 | 2.04 | 3.22 | 3.39 | 5.43 | 4.94 | 6.71 | 2.51 | 4.39
COCO+
LVIS
HRNet18
(38.8 MB)
1.54 | 1.70 | 2.48 | 4.26 | 6.86 | 4.79 | 6.00 | 2.59 | 3.58
HRNet18s IT-M
(16.5 MB)
1.54 | 1.68 | 2.60 | 4.04 | 6.48 | 4.70 | 5.98 | 2.57 | 3.33
HRNet18 IT-M
(38.8 MB)
1.42 | 1.54 | 2.26 | 3.80 | 6.06 | 4.36 | 5.74 | 2.28 | 2.98
HRNet32 IT-M
(119 MB)
1.461.562.103.595.714.115.342.572.97
+ + +### Evaluation + +We provide the script to test all the presented models in all possible configurations on GrabCut, Berkeley, DAVIS, +Pascal VOC, and SBD. To test a model, you should download its weights and put them in `./weights` folder (you can +change this path in the [config.yml](config.yml), see `INTERACTIVE_MODELS_PATH` variable). To test any of our models, +just specify the path to the corresponding checkpoint. Our scripts automatically detect the architecture of the loaded model. + +The following command runs the NoC evaluation on all test datasets (other options are displayed using '-h'): + +```.bash +python3 scripts/evaluate_model.py --checkpoint= +``` + +Examples of the script usage: +```.bash +# This command evaluates HRNetV2-W18-C+OCR ITER-M model in NoBRS mode on all Datasets. +python3 scripts/evaluate_model.py NoBRS --checkpoint=hrnet18_cocolvis_itermask_3p + +# This command evaluates HRNet-W18-C-Small-v2+OCR ITER-M model in f-BRS-B mode on all Datasets. +python3 scripts/evaluate_model.py f-BRS-B --checkpoint=hrnet18s_cocolvis_itermask_3p + +# This command evaluates HRNetV2-W18-C+OCR ITER-M model in NoBRS mode on GrabCut and Berkeley datasets. +python3 scripts/evaluate_model.py NoBRS --checkpoint=hrnet18_cocolvis_itermask_3p --datasets=GrabCut,Berkeley +``` + +### Jupyter notebook + +You can also interactively experiment with our models using [test_any_model.ipynb](./notebooks/test_any_model.ipynb) Jupyter notebook. + +## Training + +We provide the scripts for training our models on the SBD dataset. You can start training with the following commands: +```.bash +# ResNet-34 non-iterative baseline model +python3 train.py models/noniterative_baselines/r34_dh128_cocolvis.py --gpus=0 --workers=4 --exp-name=first-try + +# HRNet-W18-C-Small-v2+OCR ITER-M model +python3 train.py models/iter_mask/hrnet18s_cocolvis_itermask_3p.py --gpus=0 --workers=4 --exp-name=first-try + +# HRNetV2-W18-C+OCR ITER-M model +python3 train.py models/iter_mask/hrnet18_cocolvis_itermask_3p.py --gpus=0,1 --workers=6 --exp-name=first-try + +# HRNetV2-W32-C+OCR ITER-M model +python3 train.py models/iter_mask/hrnet32_cocolvis_itermask_3p.py --gpus=0,1,2,3 --workers=12 --exp-name=first-try +``` + +For each experiment, a separate folder is created in the `./experiments` with Tensorboard logs, text logs, +visualization and checkpoints. You can specify another path in the [config.yml](config.yml) (see `EXPS_PATH` +variable). + +Please note that we trained ResNet-34 and HRNet-18s on 1 GPU, HRNet-18 on 2 GPUs, HRNet-32 on 4 GPUs +(we used Nvidia Tesla P40 for training). To train on a different GPU you should adjust the batch size using the command +line argument `--batch-size` or change the default value in a model script. + +We used the pre-trained HRNetV2 models from [the official repository](https://github.com/HRNet/HRNet-Image-Classification). +If you want to train interactive segmentation with these models, you need to download the weights and specify the paths to +them in [config.yml](config.yml). + +## License + +The code is released under the MIT License. It is a short, permissive software license. Basically, you can do whatever you want as long as you include the original copyright and license notice in any copy of the software/source. 
+## Citation + +If you find this work is useful for your research, please cite our papers: +``` +@article{reviving2021, + title={Reviving Iterative Training with Mask Guidance for Interactive Segmentation}, + author={Konstantin Sofiiuk, Ilia Petrov, Anton Konushin}, + journal={arXiv preprint arXiv:}, + year={2021} +} + +@article{fbrs2020, + title={f-BRS: Rethinking Backpropagating Refinement for Interactive Segmentation}, + author={Konstantin Sofiiuk, Ilia Petrov, Olga Barinova, Anton Konushin}, + journal={arXiv preprint arXiv:2001.10331}, + year={2020} +} +``` diff --git a/assets/img/demo_gui.jpg b/assets/img/demo_gui.jpg new file mode 100644 index 0000000..9ea4ca3 Binary files /dev/null and b/assets/img/demo_gui.jpg differ diff --git a/assets/img/miou_berkeley.png b/assets/img/miou_berkeley.png new file mode 100644 index 0000000..ed8d1e1 Binary files /dev/null and b/assets/img/miou_berkeley.png differ diff --git a/assets/img/modifying_external_mask.jpg b/assets/img/modifying_external_mask.jpg new file mode 100644 index 0000000..0243b2b Binary files /dev/null and b/assets/img/modifying_external_mask.jpg differ diff --git a/assets/img/teaser.gif b/assets/img/teaser.gif new file mode 100644 index 0000000..9a56a3b Binary files /dev/null and b/assets/img/teaser.gif differ diff --git a/assets/sbd_samples_weights.pkl b/assets/sbd_samples_weights.pkl new file mode 100644 index 0000000..ff0ed60 Binary files /dev/null and b/assets/sbd_samples_weights.pkl differ diff --git a/assets/test_imgs/apples_bowl.jpg b/assets/test_imgs/apples_bowl.jpg new file mode 100644 index 0000000..6fe14c0 Binary files /dev/null and b/assets/test_imgs/apples_bowl.jpg differ diff --git a/assets/test_imgs/parrots.jpg b/assets/test_imgs/parrots.jpg new file mode 100644 index 0000000..bfaf02d Binary files /dev/null and b/assets/test_imgs/parrots.jpg differ diff --git a/assets/test_imgs/sheep.jpg b/assets/test_imgs/sheep.jpg new file mode 100644 index 0000000..c2b81d4 Binary files /dev/null and b/assets/test_imgs/sheep.jpg differ diff --git a/demo.py b/demo.py new file mode 100644 index 0000000..753af13 --- /dev/null +++ b/demo.py @@ -0,0 +1,57 @@ +import argparse +import tkinter as tk + +import torch + +from isegm.utils import exp +from isegm.inference import utils +from interactive_demo.app import InteractiveDemoApp + + +def main(): + args, cfg = parse_args() + + torch.backends.cudnn.deterministic = True + checkpoint_path = utils.find_checkpoint(cfg.INTERACTIVE_MODELS_PATH, args.checkpoint) + model = utils.load_is_model(checkpoint_path, args.device, cpu_dist_maps=True) + + root = tk.Tk() + root.minsize(960, 480) + app = InteractiveDemoApp(root, args, model) + root.deiconify() + app.mainloop() + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument('--checkpoint', type=str, required=True, + help='The path to the checkpoint. ' + 'This can be a relative path (relative to cfg.INTERACTIVE_MODELS_PATH) ' + 'or an absolute path. 
The file extension can be omitted.') + + parser.add_argument('--gpu', type=int, default=0, + help='Id of GPU to use.') + + parser.add_argument('--cpu', action='store_true', default=False, + help='Use only CPU for inference.') + + parser.add_argument('--limit-longest-size', type=int, default=800, + help='If the largest side of an image exceeds this value, ' + 'it is resized so that its largest side is equal to this value.') + + parser.add_argument('--cfg', type=str, default="config.yml", + help='The path to the config file.') + + args = parser.parse_args() + if args.cpu: + args.device =torch.device('cpu') + else: + args.device = torch.device(f'cuda:{args.gpu}') + cfg = exp.load_config_file(args.cfg, return_edict=True) + + return args, cfg + + +if __name__ == '__main__': + main() diff --git a/interactive_demo/app.py b/interactive_demo/app.py new file mode 100644 index 0000000..47cf782 --- /dev/null +++ b/interactive_demo/app.py @@ -0,0 +1,348 @@ +import tkinter as tk +from tkinter import messagebox, filedialog, ttk + +import cv2 +import numpy as np +from PIL import Image + +from interactive_demo.canvas import CanvasImage +from interactive_demo.controller import InteractiveController +from interactive_demo.wrappers import BoundedNumericalEntry, FocusHorizontalScale, FocusCheckButton, \ + FocusButton, FocusLabelFrame + + +class InteractiveDemoApp(ttk.Frame): + def __init__(self, master, args, model): + super().__init__(master) + self.master = master + master.title("Reviving Iterative Training with Mask Guidance for Interactive Segmentation") + master.withdraw() + master.update_idletasks() + x = (master.winfo_screenwidth() - master.winfo_reqwidth()) / 2 + y = (master.winfo_screenheight() - master.winfo_reqheight()) / 2 + master.geometry("+%d+%d" % (x, y)) + self.pack(fill="both", expand=True) + + self.brs_modes = ['NoBRS', 'RGB-BRS', 'DistMap-BRS', 'f-BRS-A', 'f-BRS-B', 'f-BRS-C'] + self.limit_longest_size = args.limit_longest_size + + self.controller = InteractiveController(model, args.device, + predictor_params={'brs_mode': 'NoBRS'}, + update_image_callback=self._update_image) + + self._init_state() + self._add_menu() + self._add_canvas() + self._add_buttons() + + master.bind('', lambda event: self.controller.finish_object()) + master.bind('a', lambda event: self.controller.partially_finish_object()) + + self.state['zoomin_params']['skip_clicks'].trace(mode='w', callback=self._reset_predictor) + self.state['zoomin_params']['target_size'].trace(mode='w', callback=self._reset_predictor) + self.state['zoomin_params']['expansion_ratio'].trace(mode='w', callback=self._reset_predictor) + self.state['predictor_params']['net_clicks_limit'].trace(mode='w', callback=self._change_brs_mode) + self.state['lbfgs_max_iters'].trace(mode='w', callback=self._change_brs_mode) + self._change_brs_mode() + + def _init_state(self): + self.state = { + 'zoomin_params': { + 'use_zoom_in': tk.BooleanVar(value=True), + 'fixed_crop': tk.BooleanVar(value=True), + 'skip_clicks': tk.IntVar(value=-1), + 'target_size': tk.IntVar(value=min(400, self.limit_longest_size)), + 'expansion_ratio': tk.DoubleVar(value=1.4) + }, + + 'predictor_params': { + 'net_clicks_limit': tk.IntVar(value=8) + }, + 'brs_mode': tk.StringVar(value='NoBRS'), + 'prob_thresh': tk.DoubleVar(value=0.5), + 'lbfgs_max_iters': tk.IntVar(value=20), + + 'alpha_blend': tk.DoubleVar(value=0.5), + 'click_radius': tk.IntVar(value=3), + } + + def _add_menu(self): + self.menubar = FocusLabelFrame(self, bd=1) + self.menubar.pack(side=tk.TOP, fill='x') + + button = 
FocusButton(self.menubar, text='Load image', command=self._load_image_callback) + button.pack(side=tk.LEFT) + self.save_mask_btn = FocusButton(self.menubar, text='Save mask', command=self._save_mask_callback) + self.save_mask_btn.pack(side=tk.LEFT) + self.save_mask_btn.configure(state=tk.DISABLED) + + self.load_mask_btn = FocusButton(self.menubar, text='Load mask', command=self._load_mask_callback) + self.load_mask_btn.pack(side=tk.LEFT) + self.load_mask_btn.configure(state=tk.DISABLED) + + button = FocusButton(self.menubar, text='About', command=self._about_callback) + button.pack(side=tk.LEFT) + button = FocusButton(self.menubar, text='Exit', command=self.master.quit) + button.pack(side=tk.LEFT) + + def _add_canvas(self): + self.canvas_frame = FocusLabelFrame(self, text="Image") + self.canvas_frame.rowconfigure(0, weight=1) + self.canvas_frame.columnconfigure(0, weight=1) + + self.canvas = tk.Canvas(self.canvas_frame, highlightthickness=0, cursor="hand1", width=400, height=400) + self.canvas.grid(row=0, column=0, sticky='nswe', padx=5, pady=5) + + self.image_on_canvas = None + self.canvas_frame.pack(side=tk.LEFT, fill="both", expand=True, padx=5, pady=5) + + def _add_buttons(self): + self.control_frame = FocusLabelFrame(self, text="Controls") + self.control_frame.pack(side=tk.TOP, fill='x', padx=5, pady=5) + master = self.control_frame + + self.clicks_options_frame = FocusLabelFrame(master, text="Clicks management") + self.clicks_options_frame.pack(side=tk.TOP, fill=tk.X, padx=10, pady=3) + self.finish_object_button = \ + FocusButton(self.clicks_options_frame, text='Finish\nobject', bg='#b6d7a8', fg='black', width=10, height=2, + state=tk.DISABLED, command=self.controller.finish_object) + self.finish_object_button.pack(side=tk.LEFT, fill=tk.X, padx=10, pady=3) + self.undo_click_button = \ + FocusButton(self.clicks_options_frame, text='Undo click', bg='#ffe599', fg='black', width=10, height=2, + state=tk.DISABLED, command=self.controller.undo_click) + self.undo_click_button.pack(side=tk.LEFT, fill=tk.X, padx=10, pady=3) + self.reset_clicks_button = \ + FocusButton(self.clicks_options_frame, text='Reset clicks', bg='#ea9999', fg='black', width=10, height=2, + state=tk.DISABLED, command=self._reset_last_object) + self.reset_clicks_button.pack(side=tk.LEFT, fill=tk.X, padx=10, pady=3) + + self.zoomin_options_frame = FocusLabelFrame(master, text="ZoomIn options") + self.zoomin_options_frame.pack(side=tk.TOP, fill=tk.X, padx=10, pady=3) + FocusCheckButton(self.zoomin_options_frame, text='Use ZoomIn', command=self._reset_predictor, + variable=self.state['zoomin_params']['use_zoom_in']).grid(row=0, column=0, padx=10) + FocusCheckButton(self.zoomin_options_frame, text='Fixed crop', command=self._reset_predictor, + variable=self.state['zoomin_params']['fixed_crop']).grid(row=1, column=0, padx=10) + tk.Label(self.zoomin_options_frame, text="Skip clicks").grid(row=0, column=1, pady=1, sticky='e') + tk.Label(self.zoomin_options_frame, text="Target size").grid(row=1, column=1, pady=1, sticky='e') + tk.Label(self.zoomin_options_frame, text="Expand ratio").grid(row=2, column=1, pady=1, sticky='e') + BoundedNumericalEntry(self.zoomin_options_frame, variable=self.state['zoomin_params']['skip_clicks'], + min_value=-1, max_value=None, vartype=int, + name='zoom_in_skip_clicks').grid(row=0, column=2, padx=10, pady=1, sticky='w') + BoundedNumericalEntry(self.zoomin_options_frame, variable=self.state['zoomin_params']['target_size'], + min_value=100, max_value=self.limit_longest_size, vartype=int, + 
name='zoom_in_target_size').grid(row=1, column=2, padx=10, pady=1, sticky='w') + BoundedNumericalEntry(self.zoomin_options_frame, variable=self.state['zoomin_params']['expansion_ratio'], + min_value=1.0, max_value=2.0, vartype=float, + name='zoom_in_expansion_ratio').grid(row=2, column=2, padx=10, pady=1, sticky='w') + self.zoomin_options_frame.columnconfigure((0, 1, 2), weight=1) + + self.brs_options_frame = FocusLabelFrame(master, text="BRS options") + self.brs_options_frame.pack(side=tk.TOP, fill=tk.X, padx=10, pady=3) + menu = tk.OptionMenu(self.brs_options_frame, self.state['brs_mode'], + *self.brs_modes, command=self._change_brs_mode) + menu.config(width=11) + menu.grid(rowspan=2, column=0, padx=10) + self.net_clicks_label = tk.Label(self.brs_options_frame, text="Network clicks") + self.net_clicks_label.grid(row=0, column=1, pady=2, sticky='e') + self.net_clicks_entry = BoundedNumericalEntry(self.brs_options_frame, + variable=self.state['predictor_params']['net_clicks_limit'], + min_value=0, max_value=None, vartype=int, allow_inf=True, + name='net_clicks_limit') + self.net_clicks_entry.grid(row=0, column=2, padx=10, pady=2, sticky='w') + self.lbfgs_iters_label = tk.Label(self.brs_options_frame, text="L-BFGS\nmax iterations") + self.lbfgs_iters_label.grid(row=1, column=1, pady=2, sticky='e') + self.lbfgs_iters_entry = BoundedNumericalEntry(self.brs_options_frame, variable=self.state['lbfgs_max_iters'], + min_value=1, max_value=1000, vartype=int, + name='lbfgs_max_iters') + self.lbfgs_iters_entry.grid(row=1, column=2, padx=10, pady=2, sticky='w') + self.brs_options_frame.columnconfigure((0, 1), weight=1) + + self.prob_thresh_frame = FocusLabelFrame(master, text="Predictions threshold") + self.prob_thresh_frame.pack(side=tk.TOP, fill=tk.X, padx=10, pady=3) + FocusHorizontalScale(self.prob_thresh_frame, from_=0.0, to=1.0, command=self._update_prob_thresh, + variable=self.state['prob_thresh']).pack(padx=10) + + self.alpha_blend_frame = FocusLabelFrame(master, text="Alpha blending coefficient") + self.alpha_blend_frame.pack(side=tk.TOP, fill=tk.X, padx=10, pady=3) + FocusHorizontalScale(self.alpha_blend_frame, from_=0.0, to=1.0, command=self._update_blend_alpha, + variable=self.state['alpha_blend']).pack(padx=10, anchor=tk.CENTER) + + self.click_radius_frame = FocusLabelFrame(master, text="Visualisation click radius") + self.click_radius_frame.pack(side=tk.TOP, fill=tk.X, padx=10, pady=3) + FocusHorizontalScale(self.click_radius_frame, from_=0, to=7, resolution=1, command=self._update_click_radius, + variable=self.state['click_radius']).pack(padx=10, anchor=tk.CENTER) + + def _load_image_callback(self): + self.menubar.focus_set() + if self._check_entry(self): + filename = filedialog.askopenfilename(parent=self.master, filetypes=[ + ("Images", "*.jpg *.jpeg *.png *.bmp *.tiff"), + ("All files", "*.*"), + ], title="Chose an image") + + if len(filename) > 0: + image = cv2.cvtColor(cv2.imread(filename), cv2.COLOR_BGR2RGB) + self.controller.set_image(image) + self.save_mask_btn.configure(state=tk.NORMAL) + self.load_mask_btn.configure(state=tk.NORMAL) + + def _save_mask_callback(self): + self.menubar.focus_set() + if self._check_entry(self): + mask = self.controller.result_mask + if mask is None: + return + + filename = filedialog.asksaveasfilename(parent=self.master, initialfile='mask.png', filetypes=[ + ("PNG image", "*.png"), + ("BMP image", "*.bmp"), + ("All files", "*.*"), + ], title="Save the current mask as...") + + if len(filename) > 0: + if mask.max() < 256: + mask = 
mask.astype(np.uint8) + mask *= 255 // mask.max() + cv2.imwrite(filename, mask) + + def _load_mask_callback(self): + if not self.controller.net.with_prev_mask: + messagebox.showwarning("Warning", "The current model doesn't support loading external masks. " + "Please use ITER-M models for that purpose.") + return + + self.menubar.focus_set() + if self._check_entry(self): + filename = filedialog.askopenfilename(parent=self.master, filetypes=[ + ("Binary mask (png, bmp)", "*.png *.bmp"), + ("All files", "*.*"), + ], title="Chose an image") + + if len(filename) > 0: + mask = cv2.imread(filename)[:, :, 0] > 127 + self.controller.set_mask(mask) + self._update_image() + + def _about_callback(self): + self.menubar.focus_set() + + text = [ + "Developed by:", + "K.Sofiiuk and I. Petrov", + "The MIT License, 2021" + ] + + messagebox.showinfo("About Demo", '\n'.join(text)) + + def _reset_last_object(self): + self.state['alpha_blend'].set(0.5) + self.state['prob_thresh'].set(0.5) + self.controller.reset_last_object() + + def _update_prob_thresh(self, value): + if self.controller.is_incomplete_mask: + self.controller.prob_thresh = self.state['prob_thresh'].get() + self._update_image() + + def _update_blend_alpha(self, value): + self._update_image() + + def _update_click_radius(self, *args): + if self.image_on_canvas is None: + return + + self._update_image() + + def _change_brs_mode(self, *args): + if self.state['brs_mode'].get() == 'NoBRS': + self.net_clicks_entry.set('INF') + self.net_clicks_entry.configure(state=tk.DISABLED) + self.net_clicks_label.configure(state=tk.DISABLED) + self.lbfgs_iters_entry.configure(state=tk.DISABLED) + self.lbfgs_iters_label.configure(state=tk.DISABLED) + else: + if self.net_clicks_entry.get() == 'INF': + self.net_clicks_entry.set(8) + self.net_clicks_entry.configure(state=tk.NORMAL) + self.net_clicks_label.configure(state=tk.NORMAL) + self.lbfgs_iters_entry.configure(state=tk.NORMAL) + self.lbfgs_iters_label.configure(state=tk.NORMAL) + + self._reset_predictor() + + def _reset_predictor(self, *args, **kwargs): + brs_mode = self.state['brs_mode'].get() + prob_thresh = self.state['prob_thresh'].get() + net_clicks_limit = None if brs_mode == 'NoBRS' else self.state['predictor_params']['net_clicks_limit'].get() + + if self.state['zoomin_params']['use_zoom_in'].get(): + zoomin_params = { + 'skip_clicks': self.state['zoomin_params']['skip_clicks'].get(), + 'target_size': self.state['zoomin_params']['target_size'].get(), + 'expansion_ratio': self.state['zoomin_params']['expansion_ratio'].get() + } + if self.state['zoomin_params']['fixed_crop'].get(): + zoomin_params['target_size'] = (zoomin_params['target_size'], zoomin_params['target_size']) + else: + zoomin_params = None + + predictor_params = { + 'brs_mode': brs_mode, + 'prob_thresh': prob_thresh, + 'zoom_in_params': zoomin_params, + 'predictor_params': { + 'net_clicks_limit': net_clicks_limit, + 'max_size': self.limit_longest_size + }, + 'brs_opt_func_params': {'min_iou_diff': 1e-3}, + 'lbfgs_params': {'maxfun': self.state['lbfgs_max_iters'].get()} + } + self.controller.reset_predictor(predictor_params) + + def _click_callback(self, is_positive, x, y): + self.canvas.focus_set() + + if self.image_on_canvas is None: + messagebox.showwarning("Warning", "Please load an image first") + return + + if self._check_entry(self): + self.controller.add_click(x, y, is_positive) + + def _update_image(self, reset_canvas=False): + image = self.controller.get_visualization(alpha_blend=self.state['alpha_blend'].get(), + 
click_radius=self.state['click_radius'].get()) + if self.image_on_canvas is None: + self.image_on_canvas = CanvasImage(self.canvas_frame, self.canvas) + self.image_on_canvas.register_click_callback(self._click_callback) + + self._set_click_dependent_widgets_state() + if image is not None: + self.image_on_canvas.reload_image(Image.fromarray(image), reset_canvas) + + def _set_click_dependent_widgets_state(self): + after_1st_click_state = tk.NORMAL if self.controller.is_incomplete_mask else tk.DISABLED + before_1st_click_state = tk.DISABLED if self.controller.is_incomplete_mask else tk.NORMAL + + self.finish_object_button.configure(state=after_1st_click_state) + self.undo_click_button.configure(state=after_1st_click_state) + self.reset_clicks_button.configure(state=after_1st_click_state) + self.zoomin_options_frame.set_frame_state(before_1st_click_state) + self.brs_options_frame.set_frame_state(before_1st_click_state) + + if self.state['brs_mode'].get() == 'NoBRS': + self.net_clicks_entry.configure(state=tk.DISABLED) + self.net_clicks_label.configure(state=tk.DISABLED) + self.lbfgs_iters_entry.configure(state=tk.DISABLED) + self.lbfgs_iters_label.configure(state=tk.DISABLED) + + def _check_entry(self, widget): + all_checked = True + if widget.winfo_children is not None: + for w in widget.winfo_children(): + all_checked = all_checked and self._check_entry(w) + + if getattr(widget, "_check_bounds", None) is not None: + all_checked = all_checked and widget._check_bounds(widget.get(), '-1') + + return all_checked diff --git a/interactive_demo/canvas.py b/interactive_demo/canvas.py new file mode 100644 index 0000000..50e91f8 --- /dev/null +++ b/interactive_demo/canvas.py @@ -0,0 +1,320 @@ +# -*- coding: utf-8 -*- +""" Adopted from https://github.com/foobar167/junkyard/blob/master/manual_image_annotation1/polygon/gui_canvas.py """ +import os +import sys +import time +import math +import tkinter as tk + +from tkinter import ttk +from PIL import Image, ImageTk + + +def handle_exception(exit_code=0): + """ Use: @land.logger.handle_exception(0) + before every function which could cast an exception """ + + def wrapper(func): + def inner(*args, **kwargs): + try: + return func(*args, **kwargs) + except: + if exit_code != 0: # if zero, don't exit from the program + sys.exit(exit_code) # exit from the program + + return inner + + return wrapper + + +class AutoScrollbar(ttk.Scrollbar): + """ A scrollbar that hides itself if it's not needed. 
Works only for grid geometry manager """ + + def set(self, lo, hi): + if float(lo) <= 0.0 and float(hi) >= 1.0: + self.grid_remove() + else: + self.grid() + ttk.Scrollbar.set(self, lo, hi) + + @handle_exception(1) + def pack(self, **kw): + raise tk.TclError('Cannot use pack with the widget ' + self.__class__.__name__) + + @handle_exception(1) + def place(self, **kw): + raise tk.TclError('Cannot use place with the widget ' + self.__class__.__name__) + + +class CanvasImage: + """ Display and zoom image """ + + def __init__(self, canvas_frame, canvas): + """ Initialize the ImageFrame """ + self.current_scale = 1.0 # scale for the canvas image zoom, public for outer classes + self.__delta = 1.2 # zoom magnitude + self.__previous_state = 0 # previous state of the keyboard + # Create ImageFrame in placeholder widget + self.__imframe = canvas_frame + # Vertical and horizontal scrollbars for canvas + self.hbar = AutoScrollbar(canvas_frame, orient='horizontal') + self.vbar = AutoScrollbar(canvas_frame, orient='vertical') + self.hbar.grid(row=1, column=0, sticky='we') + self.vbar.grid(row=0, column=1, sticky='ns') + # Add scroll bars to canvas + self.canvas = canvas + self.canvas.configure(xscrollcommand=self.hbar.set, yscrollcommand=self.vbar.set) + self.hbar.configure(command=self.__scroll_x) # bind scrollbars to the canvas + self.vbar.configure(command=self.__scroll_y) + # Bind events to the Canvas + self.canvas.bind('', lambda event: self.__size_changed()) # canvas is resized + self.canvas.bind('', self.__left_mouse_button) # remember canvas position + self.canvas.bind('', self.__right_mouse_button_pressed) # remember canvas position + self.canvas.bind('', self.__right_mouse_button_released) # remember canvas position + self.canvas.bind('', self.__right_mouse_button_motion) # move canvas to the new position + self.canvas.bind('', self.__wheel) # zoom for Windows and MacOS, but not Linux + self.canvas.bind('', self.__wheel) # zoom for Linux, wheel scroll down + self.canvas.bind('', self.__wheel) # zoom for Linux, wheel scroll up + # Handle keystrokes in idle mode, because program slows down on a weak computers, + # when too many key stroke events in the same time + self.canvas.bind('', lambda event: self.canvas.after_idle(self.__keystroke, event)) + self.container = None + + self._click_callback = None + + def register_click_callback(self, click_callback): + self._click_callback = click_callback + + def reload_image(self, image, reset_canvas=True): + self.__original_image = image.copy() + self.__current_image = image.copy() + + if reset_canvas: + self.imwidth, self.imheight = self.__original_image.size + self.__min_side = min(self.imwidth, self.imheight) # get the smaller image side + + scale = min(self.canvas.winfo_width() / self.imwidth, self.canvas.winfo_height() / self.imheight) + if self.container: + self.canvas.delete(self.container) + + self.container = self.canvas.create_rectangle((0, 0, scale * self.imwidth, scale * self.imheight), width=0) + self.current_scale = scale + self._reset_canvas_offset() + + self.__show_image() # show image on the canvas + self.canvas.focus_set() # set focus on the canvas + + def grid(self, **kw): + """ Put CanvasImage widget on the parent widget """ + self.__imframe.grid(**kw) # place CanvasImage widget on the grid + self.__imframe.grid(sticky='nswe') # make frame container sticky + self.__imframe.rowconfigure(0, weight=1) # make canvas expandable + self.__imframe.columnconfigure(0, weight=1) + + def __show_image(self): + box_image = 
self.canvas.coords(self.container) # get image area + box_canvas = (self.canvas.canvasx(0), # get visible area of the canvas + self.canvas.canvasy(0), + self.canvas.canvasx(self.canvas.winfo_width()), + self.canvas.canvasy(self.canvas.winfo_height())) + box_img_int = tuple(map(int, box_image)) # convert to integer or it will not work properly + # Get scroll region box + box_scroll = [min(box_img_int[0], box_canvas[0]), min(box_img_int[1], box_canvas[1]), + max(box_img_int[2], box_canvas[2]), max(box_img_int[3], box_canvas[3])] + # Horizontal part of the image is in the visible area + if box_scroll[0] == box_canvas[0] and box_scroll[2] == box_canvas[2]: + box_scroll[0] = box_img_int[0] + box_scroll[2] = box_img_int[2] + # Vertical part of the image is in the visible area + if box_scroll[1] == box_canvas[1] and box_scroll[3] == box_canvas[3]: + box_scroll[1] = box_img_int[1] + box_scroll[3] = box_img_int[3] + # Convert scroll region to tuple and to integer + self.canvas.configure(scrollregion=tuple(map(int, box_scroll))) # set scroll region + x1 = max(box_canvas[0] - box_image[0], 0) # get coordinates (x1,y1,x2,y2) of the image tile + y1 = max(box_canvas[1] - box_image[1], 0) + x2 = min(box_canvas[2], box_image[2]) - box_image[0] + y2 = min(box_canvas[3], box_image[3]) - box_image[1] + + if int(x2 - x1) > 0 and int(y2 - y1) > 0: # show image if it in the visible area + border_width = 2 + sx1, sx2 = x1 / self.current_scale, x2 / self.current_scale + sy1, sy2 = y1 / self.current_scale, y2 / self.current_scale + crop_x, crop_y = max(0, math.floor(sx1 - border_width)), max(0, math.floor(sy1 - border_width)) + crop_w, crop_h = math.ceil(sx2 - sx1 + 2 * border_width), math.ceil(sy2 - sy1 + 2 * border_width) + crop_w = min(crop_w, self.__original_image.width - crop_x) + crop_h = min(crop_h, self.__original_image.height - crop_y) + + __current_image = self.__original_image.crop((crop_x, crop_y, + crop_x + crop_w, crop_y + crop_h)) + crop_zw = int(round(crop_w * self.current_scale)) + crop_zh = int(round(crop_h * self.current_scale)) + zoom_sx, zoom_sy = crop_zw / crop_w, crop_zh / crop_h + crop_zx, crop_zy = crop_x * zoom_sx, crop_y * zoom_sy + self.real_scale = (zoom_sx, zoom_sy) + + interpolation = Image.NEAREST if self.current_scale > 2.0 else Image.ANTIALIAS + __current_image = __current_image.resize((crop_zw, crop_zh), interpolation) + zx1, zy1 = x1 - crop_zx, y1 - crop_zy + zx2 = min(zx1 + self.canvas.winfo_width(), __current_image.width) + zy2 = min(zy1 + self.canvas.winfo_height(), __current_image.height) + + self.__current_image = __current_image.crop((zx1, zy1, zx2, zy2)) + + imagetk = ImageTk.PhotoImage(self.__current_image) + imageid = self.canvas.create_image(max(box_canvas[0], box_img_int[0]), + max(box_canvas[1], box_img_int[1]), + anchor='nw', image=imagetk) + self.canvas.lower(imageid) # set image into background + self.canvas.imagetk = imagetk # keep an extra reference to prevent garbage-collection + + def _get_click_coordinates(self, event): + x = self.canvas.canvasx(event.x) # get coordinates of the event on the canvas + y = self.canvas.canvasy(event.y) + + if self.outside(x, y): + return None + + box_image = self.canvas.coords(self.container) + x = max(x - box_image[0], 0) + y = max(y - box_image[1], 0) + + x = int(x / self.real_scale[0]) + y = int(y / self.real_scale[1]) + + return x, y + + # ================================================ Canvas Routines ================================================= + def _reset_canvas_offset(self): + 
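+        # Reset the scroll region and view offset so a newly loaded image is drawn from the canvas origin.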
self.canvas.configure(scrollregion=(0, 0, 5000, 5000)) + self.canvas.scan_mark(0, 0) + self.canvas.scan_dragto(int(self.canvas.canvasx(0)), int(self.canvas.canvasy(0)), gain=1) + + def _change_canvas_scale(self, relative_scale, x=0, y=0): + new_scale = self.current_scale * relative_scale + + if new_scale > 20: + return + + if new_scale * self.__original_image.width < self.canvas.winfo_width() and \ + new_scale * self.__original_image.height < self.canvas.winfo_height(): + return + + self.current_scale = new_scale + self.canvas.scale('all', x, y, relative_scale, relative_scale) # rescale all objects + + # noinspection PyUnusedLocal + def __scroll_x(self, *args, **kwargs): + """ Scroll canvas horizontally and redraw the image """ + self.canvas.xview(*args) # scroll horizontally + self.__show_image() # redraw the image + + # noinspection PyUnusedLocal + def __scroll_y(self, *args, **kwargs): + """ Scroll canvas vertically and redraw the image """ + self.canvas.yview(*args) # scroll vertically + self.__show_image() # redraw the image + + def __size_changed(self): + new_scale_w = self.canvas.winfo_width() / (self.current_scale * self.__original_image.width) + new_scale_h = self.canvas.winfo_height() / (self.current_scale * self.__original_image.height) + new_scale = min(new_scale_w, new_scale_h) + if new_scale > 1.0: + self._change_canvas_scale(new_scale) + self.__show_image() + + # ================================================ Mouse callbacks ================================================= + def __wheel(self, event): + """ Zoom with mouse wheel """ + x = self.canvas.canvasx(event.x) # get coordinates of the event on the canvas + y = self.canvas.canvasy(event.y) + if self.outside(x, y): return # zoom only inside image area + + scale = 1.0 + # Respond to Linux (event.num) or Windows (event.delta) wheel event + if event.num == 5 or event.delta == -120: # scroll down, zoom out, smaller + scale /= self.__delta + if event.num == 4 or event.delta == 120: # scroll up, zoom in, bigger + scale *= self.__delta + + self._change_canvas_scale(scale, x, y) + self.__show_image() + + def __left_mouse_button(self, event): + if self._click_callback is None: + return + + coords = self._get_click_coordinates(event) + + if coords is not None: + self._click_callback(is_positive=True, x=coords[0], y=coords[1]) + + def __right_mouse_button_pressed(self, event): + """ Remember previous coordinates for scrolling with the mouse """ + self._last_rb_click_time = time.time() + self._last_rb_click_event = event + self.canvas.scan_mark(event.x, event.y) + + def __right_mouse_button_released(self, event): + time_delta = time.time() - self._last_rb_click_time + move_delta = math.sqrt((event.x - self._last_rb_click_event.x) ** 2 + + (event.y - self._last_rb_click_event.y) ** 2) + if time_delta > 0.5 or move_delta > 3: + return + + if self._click_callback is None: + return + + coords = self._get_click_coordinates(self._last_rb_click_event) + + if coords is not None: + self._click_callback(is_positive=False, x=coords[0], y=coords[1]) + + def __right_mouse_button_motion(self, event): + """ Drag (move) canvas to the new position """ + move_delta = math.sqrt((event.x - self._last_rb_click_event.x) ** 2 + + (event.y - self._last_rb_click_event.y) ** 2) + if move_delta > 3: + self.canvas.scan_dragto(event.x, event.y, gain=1) + self.__show_image() # zoom tile and show it on the canvas + + def outside(self, x, y): + """ Checks if the point (x,y) is outside the image area """ + bbox = self.canvas.coords(self.container) # get image 
area + if bbox[0] < x < bbox[2] and bbox[1] < y < bbox[3]: + return False # point (x,y) is inside the image area + else: + return True # point (x,y) is outside the image area + + # ================================================= Keys Callback ================================================== + def __keystroke(self, event): + """ Scrolling with the keyboard. + Independent from the language of the keyboard, CapsLock, +, etc. """ + if event.state - self.__previous_state == 4: # means that the Control key is pressed + pass # do nothing if Control key is pressed + else: + self.__previous_state = event.state # remember the last keystroke state + # Up, Down, Left, Right keystrokes + self.keycodes = {} # init key codes + if os.name == 'nt': # Windows OS + self.keycodes = { + 'd': [68, 39, 102], + 'a': [65, 37, 100], + 'w': [87, 38, 104], + 's': [83, 40, 98], + } + else: # Linux OS + self.keycodes = { + 'd': [40, 114, 85], + 'a': [38, 113, 83], + 'w': [25, 111, 80], + 's': [39, 116, 88], + } + if event.keycode in self.keycodes['d']: # scroll right, keys 'd' or 'Right' + self.__scroll_x('scroll', 1, 'unit', event=event) + elif event.keycode in self.keycodes['a']: # scroll left, keys 'a' or 'Left' + self.__scroll_x('scroll', -1, 'unit', event=event) + elif event.keycode in self.keycodes['w']: # scroll up, keys 'w' or 'Up' + self.__scroll_y('scroll', -1, 'unit', event=event) + elif event.keycode in self.keycodes['s']: # scroll down, keys 's' or 'Down' + self.__scroll_y('scroll', 1, 'unit', event=event) diff --git a/interactive_demo/controller.py b/interactive_demo/controller.py new file mode 100644 index 0000000..6601781 --- /dev/null +++ b/interactive_demo/controller.py @@ -0,0 +1,154 @@ +import torch +import numpy as np +from tkinter import messagebox + +from isegm.inference import clicker +from isegm.inference.predictors import get_predictor +from isegm.utils.vis import draw_with_blend_and_clicks + + +class InteractiveController: + def __init__(self, net, device, predictor_params, update_image_callback, prob_thresh=0.5): + self.net = net + self.prob_thresh = prob_thresh + self.clicker = clicker.Clicker() + self.states = [] + self.probs_history = [] + self.object_count = 0 + self._result_mask = None + self._init_mask = None + + self.image = None + self.predictor = None + self.device = device + self.update_image_callback = update_image_callback + self.predictor_params = predictor_params + self.reset_predictor() + + def set_image(self, image): + self.image = image + self._result_mask = np.zeros(image.shape[:2], dtype=np.uint16) + self.object_count = 0 + self.reset_last_object(update_image=False) + self.update_image_callback(reset_canvas=True) + + def set_mask(self, mask): + if self.image.shape[:2] != mask.shape[:2]: + messagebox.showwarning("Warning", "A segmentation mask must have the same sizes as the current image!") + return + + if len(self.probs_history) > 0: + self.reset_last_object() + + self._init_mask = mask.astype(np.float32) + self.probs_history.append((np.zeros_like(self._init_mask), self._init_mask)) + self._init_mask = torch.tensor(self._init_mask, device=self.device).unsqueeze(0).unsqueeze(0) + self.clicker.click_indx_offset = 1 + + def add_click(self, x, y, is_positive): + self.states.append({ + 'clicker': self.clicker.get_state(), + 'predictor': self.predictor.get_states() + }) + + click = clicker.Click(is_positive=is_positive, coords=(y, x)) + self.clicker.add_click(click) + pred = self.predictor.get_prediction(self.clicker, prev_mask=self._init_mask) + if self._init_mask is not 
None and len(self.clicker) == 1: + pred = self.predictor.get_prediction(self.clicker, prev_mask=self._init_mask) + + torch.cuda.empty_cache() + + if self.probs_history: + self.probs_history.append((self.probs_history[-1][0], pred)) + else: + self.probs_history.append((np.zeros_like(pred), pred)) + + self.update_image_callback() + + def undo_click(self): + if not self.states: + return + + prev_state = self.states.pop() + self.clicker.set_state(prev_state['clicker']) + self.predictor.set_states(prev_state['predictor']) + self.probs_history.pop() + if not self.probs_history: + self.reset_init_mask() + self.update_image_callback() + + def partially_finish_object(self): + object_prob = self.current_object_prob + if object_prob is None: + return + + self.probs_history.append((object_prob, np.zeros_like(object_prob))) + self.states.append(self.states[-1]) + + self.clicker.reset_clicks() + self.reset_predictor() + self.reset_init_mask() + self.update_image_callback() + + def finish_object(self): + if self.current_object_prob is None: + return + + self._result_mask = self.result_mask + self.object_count += 1 + self.reset_last_object() + + def reset_last_object(self, update_image=True): + self.states = [] + self.probs_history = [] + self.clicker.reset_clicks() + self.reset_predictor() + self.reset_init_mask() + if update_image: + self.update_image_callback() + + def reset_predictor(self, predictor_params=None): + if predictor_params is not None: + self.predictor_params = predictor_params + self.predictor = get_predictor(self.net, device=self.device, + **self.predictor_params) + if self.image is not None: + self.predictor.set_input_image(self.image) + + def reset_init_mask(self): + self._init_mask = None + self.clicker.click_indx_offset = 0 + + @property + def current_object_prob(self): + if self.probs_history: + current_prob_total, current_prob_additive = self.probs_history[-1] + return np.maximum(current_prob_total, current_prob_additive) + else: + return None + + @property + def is_incomplete_mask(self): + return len(self.probs_history) > 0 + + @property + def result_mask(self): + result_mask = self._result_mask.copy() + if self.probs_history: + result_mask[self.current_object_prob > self.prob_thresh] = self.object_count + 1 + return result_mask + + def get_visualization(self, alpha_blend, click_radius): + if self.image is None: + return None + + results_mask_for_vis = self.result_mask + vis = draw_with_blend_and_clicks(self.image, mask=results_mask_for_vis, alpha=alpha_blend, + clicks_list=self.clicker.clicks_list, radius=click_radius) + if self.probs_history: + total_mask = self.probs_history[-1][0] > self.prob_thresh + results_mask_for_vis[np.logical_not(total_mask)] = 0 + vis = draw_with_blend_and_clicks(vis, mask=results_mask_for_vis, alpha=alpha_blend) + + return vis diff --git a/interactive_demo/wrappers.py b/interactive_demo/wrappers.py new file mode 100644 index 0000000..a219392 --- /dev/null +++ b/interactive_demo/wrappers.py @@ -0,0 +1,92 @@ +import tkinter as tk +from tkinter import messagebox, ttk + + +class BoundedNumericalEntry(tk.Entry): + def __init__(self, master=None, min_value=None, max_value=None, variable=None, + vartype=float, width=7, allow_inf=False, **kwargs): + if variable is None: + if vartype == float: + self.var = tk.DoubleVar() + elif vartype == int: + self.var = tk.IntVar() + else: + self.var = tk.StringVar() + else: + self.var = variable + + self.fake_var = tk.StringVar(value=self.var.get()) + self.vartype = vartype + self.old_value = self.var.get() + 
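+        # allow_inf lets this entry accept the literal string 'INF' (used for the unlimited network clicks setting in NoBRS mode).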
self.allow_inf = allow_inf + + self.min_value, self.max_value = min_value, max_value + self.get, self.set = self.fake_var.get, self.fake_var.set + + self.validate_command = master.register(self._check_bounds) + tk.Entry.__init__(self, master, textvariable=self.fake_var, validate="focus", width=width, + vcmd=(self.validate_command, '%P', '%d'), **kwargs) + + def _check_bounds(self, instr, action_type): + if self.allow_inf and instr == 'INF': + self.fake_var.set('INF') + return True + + if action_type == '-1': + try: + new_value = self.vartype(instr) + except ValueError: + pass + else: + if (self.min_value is None or new_value >= self.min_value) and \ + (self.max_value is None or new_value <= self.max_value): + if new_value != self.old_value: + self.old_value = self.vartype(self.fake_var.get()) + self.delete(0, tk.END) + self.insert(0, str(self.old_value)) + self.var.set(self.old_value) + return True + self.delete(0, tk.END) + self.insert(0, str(self.old_value)) + mn = '-inf' if self.min_value is None else str(self.min_value) + mx = '+inf' if self.max_value is None else str(self.max_value) + messagebox.showwarning("Incorrect value in input field", f"Value for {self._name} should be in " + f"[{mn}; {mx}] and of type {self.vartype.__name__}") + + return False + + +class FocusHorizontalScale(tk.Scale): + def __init__(self, *args, highlightthickness=0, sliderrelief=tk.GROOVE, resolution=0.01, + sliderlength=20, length=200, **kwargs): + tk.Scale.__init__(self, *args, orient=tk.HORIZONTAL, highlightthickness=highlightthickness, + sliderrelief=sliderrelief, resolution=resolution, + sliderlength=sliderlength, length=length, **kwargs) + self.bind("<1>", lambda event: self.focus_set()) + + +class FocusCheckButton(tk.Checkbutton): + def __init__(self, *args, highlightthickness=0, **kwargs): + tk.Checkbutton.__init__(self, *args, highlightthickness=highlightthickness, **kwargs) + self.bind("<1>", lambda event: self.focus_set()) + + +class FocusButton(tk.Button): + def __init__(self, *args, highlightthickness=0, **kwargs): + tk.Button.__init__(self, *args, highlightthickness=highlightthickness, **kwargs) + self.bind("<1>", lambda event: self.focus_set()) + + +class FocusLabelFrame(ttk.LabelFrame): + def __init__(self, *args, highlightthickness=0, relief=tk.RIDGE, borderwidth=2, **kwargs): + tk.LabelFrame.__init__(self, *args, highlightthickness=highlightthickness, relief=relief, + borderwidth=borderwidth, **kwargs) + self.bind("<1>", lambda event: self.focus_set()) + + def set_frame_state(self, state): + def set_widget_state(widget, state): + if widget.winfo_children is not None: + for w in widget.winfo_children(): + w.configure(state=state) + + set_widget_state(self, state) diff --git a/isegm/data/base.py b/isegm/data/base.py new file mode 100644 index 0000000..ee2a532 --- /dev/null +++ b/isegm/data/base.py @@ -0,0 +1,99 @@ +import random +import pickle +import numpy as np +import torch +from torchvision import transforms +from .points_sampler import MultiPointSampler +from .sample import DSample + + +class ISDataset(torch.utils.data.dataset.Dataset): + def __init__(self, + augmentator=None, + points_sampler=MultiPointSampler(max_num_points=12), + min_object_area=0, + keep_background_prob=0.0, + with_image_info=False, + samples_scores_path=None, + samples_scores_gamma=1.0, + epoch_len=-1): + super(ISDataset, self).__init__() + self.epoch_len = epoch_len + self.augmentator = augmentator + self.min_object_area = min_object_area + self.keep_background_prob = keep_background_prob + self.points_sampler = 
points_sampler + self.with_image_info = with_image_info + self.samples_precomputed_scores = self._load_samples_scores(samples_scores_path, samples_scores_gamma) + self.to_tensor = transforms.ToTensor() + + self.dataset_samples = None + + def __getitem__(self, index): + if self.samples_precomputed_scores is not None: + index = np.random.choice(self.samples_precomputed_scores['indices'], + p=self.samples_precomputed_scores['probs']) + else: + if self.epoch_len > 0: + index = random.randrange(0, len(self.dataset_samples)) + + sample = self.get_sample(index) + sample = self.augment_sample(sample) + sample.remove_small_objects(self.min_object_area) + + self.points_sampler.sample_object(sample) + points = np.array(self.points_sampler.sample_points()) + mask = self.points_sampler.selected_mask + + output = { + 'images': self.to_tensor(sample.image), + 'points': points.astype(np.float32), + 'instances': mask + } + + if self.with_image_info: + output['image_info'] = sample.sample_id + + return output + + def augment_sample(self, sample) -> DSample: + if self.augmentator is None: + return sample + + valid_augmentation = False + while not valid_augmentation: + sample.augment(self.augmentator) + keep_sample = (self.keep_background_prob < 0.0 or + random.random() < self.keep_background_prob) + valid_augmentation = len(sample) > 0 or keep_sample + + return sample + + def get_sample(self, index) -> DSample: + raise NotImplementedError + + def __len__(self): + if self.epoch_len > 0: + return self.epoch_len + else: + return self.get_samples_number() + + def get_samples_number(self): + return len(self.dataset_samples) + + @staticmethod + def _load_samples_scores(samples_scores_path, samples_scores_gamma): + if samples_scores_path is None: + return None + + with open(samples_scores_path, 'rb') as f: + images_scores = pickle.load(f) + + probs = np.array([(1.0 - x[2]) ** samples_scores_gamma for x in images_scores]) + probs /= probs.sum() + samples_scores = { + 'indices': [x[0] for x in images_scores], + 'probs': probs + } + print(f'Loaded {len(probs)} weights with gamma={samples_scores_gamma}') + return samples_scores diff --git a/isegm/data/compose.py b/isegm/data/compose.py new file mode 100644 index 0000000..e6e458c --- /dev/null +++ b/isegm/data/compose.py @@ -0,0 +1,39 @@ +import numpy as np +from math import isclose +from .base import ISDataset + + +class ComposeDataset(ISDataset): + def __init__(self, datasets, **kwargs): + super(ComposeDataset, self).__init__(**kwargs) + + self._datasets = datasets + self.dataset_samples = [] + for dataset_indx, dataset in enumerate(self._datasets): + self.dataset_samples.extend([(dataset_indx, i) for i in range(len(dataset))]) + + def get_sample(self, index): + dataset_indx, sample_indx = self.dataset_samples[index] + return self._datasets[dataset_indx].get_sample(sample_indx) + + +class ProportionalComposeDataset(ISDataset): + def __init__(self, datasets, ratios, **kwargs): + super().__init__(**kwargs) + + assert len(ratios) == len(datasets),\ + "The number of datasets must match the number of ratios" + assert isclose(sum(ratios), 1.0),\ + "The sum of ratios must be equal to 1" + + self._ratios = ratios + self._datasets = datasets + self.dataset_samples = [] + for dataset_indx, dataset in enumerate(self._datasets): + self.dataset_samples.extend([(dataset_indx, i) for i in range(len(dataset))]) + + def get_sample(self, index): + dataset_indx = np.random.choice(len(self._datasets), p=self._ratios) + sample_indx = np.random.choice(len(self._datasets[dataset_indx])) + + 
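+        # The incoming index is ignored: a dataset is drawn according to the given ratios, then a sample is drawn uniformly from it.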
return self._datasets[dataset_indx].get_sample(sample_indx) diff --git a/isegm/data/datasets/__init__.py b/isegm/data/datasets/__init__.py new file mode 100644 index 0000000..966ffff --- /dev/null +++ b/isegm/data/datasets/__init__.py @@ -0,0 +1,12 @@ +from isegm.data.compose import ComposeDataset, ProportionalComposeDataset +from .berkeley import BerkeleyDataset +from .coco import CocoDataset +from .davis import DavisDataset +from .grabcut import GrabCutDataset +from .coco_lvis import CocoLvisDataset +from .lvis import LvisDataset +from .openimages import OpenImagesDataset +from .sbd import SBDDataset, SBDEvaluationDataset +from .images_dir import ImagesDirDataset +from .ade20k import ADE20kDataset +from .pascalvoc import PascalVocDataset diff --git a/isegm/data/datasets/ade20k.py b/isegm/data/datasets/ade20k.py new file mode 100644 index 0000000..6791b83 --- /dev/null +++ b/isegm/data/datasets/ade20k.py @@ -0,0 +1,55 @@ +import os +import random +import pickle as pkl +from pathlib import Path + +import cv2 +import numpy as np + +from isegm.data.base import ISDataset +from isegm.data.sample import DSample +from isegm.utils.misc import get_labels_with_sizes + + +class ADE20kDataset(ISDataset): + def __init__(self, dataset_path, split='train', stuff_prob=0.0, **kwargs): + super().__init__(**kwargs) + assert split in {'train', 'val'} + + self.dataset_path = Path(dataset_path) + self.dataset_split = split + self.dataset_split_folder = 'training' if split == 'train' else 'validation' + self.stuff_prob = stuff_prob + + anno_path = self.dataset_path / f'{split}-annotations-object-segmentation.pkl' + if os.path.exists(anno_path): + with anno_path.open('rb') as f: + annotations = pkl.load(f) + else: + raise RuntimeError(f"Can't find annotations at {anno_path}") + self.annotations = annotations + self.dataset_samples = list(annotations.keys()) + + def get_sample(self, index) -> DSample: + image_id = self.dataset_samples[index] + sample_annos = self.annotations[image_id] + + image_path = str(self.dataset_path / sample_annos['folder'] / f'{image_id}.jpg') + image = cv2.imread(image_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + # select random mask for an image + layer = random.choice(sample_annos['layers']) + mask_path = str(self.dataset_path / sample_annos['folder'] / layer['mask_name']) + instances_mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)[:, :, 0] # the B channel holds instances + instances_mask = instances_mask.astype(np.int32) + object_ids, _ = get_labels_with_sizes(instances_mask) + + if (self.stuff_prob <= 0) or (random.random() > self.stuff_prob): + # remove stuff objects + for i, object_id in enumerate(object_ids): + if i in layer['stuff_instances']: + instances_mask[instances_mask == object_id] = 0 + object_ids, _ = get_labels_with_sizes(instances_mask) + + return DSample(image, instances_mask, objects_ids=object_ids, sample_id=index) diff --git a/isegm/data/datasets/berkeley.py b/isegm/data/datasets/berkeley.py new file mode 100644 index 0000000..5c269d8 --- /dev/null +++ b/isegm/data/datasets/berkeley.py @@ -0,0 +1,6 @@ +from .grabcut import GrabCutDataset + + +class BerkeleyDataset(GrabCutDataset): + def __init__(self, dataset_path, **kwargs): + super().__init__(dataset_path, images_dir_name='images', masks_dir_name='masks', **kwargs) diff --git a/isegm/data/datasets/coco.py b/isegm/data/datasets/coco.py new file mode 100644 index 0000000..985eb76 --- /dev/null +++ b/isegm/data/datasets/coco.py @@ -0,0 +1,74 @@ +import cv2 +import json +import random +import numpy 
as np +from pathlib import Path +from isegm.data.base import ISDataset +from isegm.data.sample import DSample + + +class CocoDataset(ISDataset): + def __init__(self, dataset_path, split='train', stuff_prob=0.0, **kwargs): + super(CocoDataset, self).__init__(**kwargs) + self.split = split + self.dataset_path = Path(dataset_path) + self.stuff_prob = stuff_prob + + self.load_samples() + + def load_samples(self): + annotation_path = self.dataset_path / 'annotations' / f'panoptic_{self.split}.json' + self.labels_path = self.dataset_path / 'annotations' / f'panoptic_{self.split}' + self.images_path = self.dataset_path / self.split + + with open(annotation_path, 'r') as f: + annotation = json.load(f) + + self.dataset_samples = annotation['annotations'] + + self._categories = annotation['categories'] + self._stuff_labels = [x['id'] for x in self._categories if x['isthing'] == 0] + self._things_labels = [x['id'] for x in self._categories if x['isthing'] == 1] + self._things_labels_set = set(self._things_labels) + self._stuff_labels_set = set(self._stuff_labels) + + def get_sample(self, index) -> DSample: + dataset_sample = self.dataset_samples[index] + + image_path = self.images_path / self.get_image_name(dataset_sample['file_name']) + label_path = self.labels_path / dataset_sample['file_name'] + + image = cv2.imread(str(image_path)) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + label = cv2.imread(str(label_path), cv2.IMREAD_UNCHANGED).astype(np.int32) + label = 256 * 256 * label[:, :, 0] + 256 * label[:, :, 1] + label[:, :, 2] + + instance_map = np.full_like(label, 0) + things_ids = [] + stuff_ids = [] + + for segment in dataset_sample['segments_info']: + class_id = segment['category_id'] + obj_id = segment['id'] + if class_id in self._things_labels_set: + if segment['iscrowd'] == 1: + continue + things_ids.append(obj_id) + else: + stuff_ids.append(obj_id) + + instance_map[label == obj_id] = obj_id + + if self.stuff_prob > 0 and random.random() < self.stuff_prob: + instances_ids = things_ids + stuff_ids + else: + instances_ids = things_ids + + for stuff_id in stuff_ids: + instance_map[instance_map == stuff_id] = 0 + + return DSample(image, instance_map, objects_ids=instances_ids) + + @classmethod + def get_image_name(cls, panoptic_name): + return panoptic_name.replace('.png', '.jpg') diff --git a/isegm/data/datasets/coco_lvis.py b/isegm/data/datasets/coco_lvis.py new file mode 100644 index 0000000..0369103 --- /dev/null +++ b/isegm/data/datasets/coco_lvis.py @@ -0,0 +1,67 @@ +from pathlib import Path +import pickle +import random +import numpy as np +import json +import cv2 +from copy import deepcopy +from isegm.data.base import ISDataset +from isegm.data.sample import DSample + + +class CocoLvisDataset(ISDataset): + def __init__(self, dataset_path, split='train', stuff_prob=0.0, + allow_list_name=None, anno_file='hannotation.pickle', **kwargs): + super(CocoLvisDataset, self).__init__(**kwargs) + dataset_path = Path(dataset_path) + self._split_path = dataset_path / split + self.split = split + self._images_path = self._split_path / 'images' + self._masks_path = self._split_path / 'masks' + self.stuff_prob = stuff_prob + + with open(self._split_path / anno_file, 'rb') as f: + self.dataset_samples = sorted(pickle.load(f).items()) + + if allow_list_name is not None: + allow_list_path = self._split_path / allow_list_name + with open(allow_list_path, 'r') as f: + allow_images_ids = json.load(f) + allow_images_ids = set(allow_images_ids) + + self.dataset_samples = [sample for sample in 
self.dataset_samples + if sample[0] in allow_images_ids] + + def get_sample(self, index) -> DSample: + image_id, sample = self.dataset_samples[index] + image_path = self._images_path / f'{image_id}.jpg' + + image = cv2.imread(str(image_path)) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + packed_masks_path = self._masks_path / f'{image_id}.pickle' + with open(packed_masks_path, 'rb') as f: + encoded_layers, objs_mapping = pickle.load(f) + layers = [cv2.imdecode(x, cv2.IMREAD_UNCHANGED) for x in encoded_layers] + layers = np.stack(layers, axis=2) + + instances_info = deepcopy(sample['hierarchy']) + for inst_id, inst_info in list(instances_info.items()): + if inst_info is None: + inst_info = {'children': [], 'parent': None, 'node_level': 0} + instances_info[inst_id] = inst_info + inst_info['mapping'] = objs_mapping[inst_id] + + if self.stuff_prob > 0 and random.random() < self.stuff_prob: + for inst_id in range(sample['num_instance_masks'], len(objs_mapping)): + instances_info[inst_id] = { + 'mapping': objs_mapping[inst_id], + 'parent': None, + 'children': [] + } + else: + for inst_id in range(sample['num_instance_masks'], len(objs_mapping)): + layer_indx, mask_id = objs_mapping[inst_id] + layers[:, :, layer_indx][layers[:, :, layer_indx] == mask_id] = 0 + + return DSample(image, layers, objects=instances_info) diff --git a/isegm/data/datasets/davis.py b/isegm/data/datasets/davis.py new file mode 100644 index 0000000..de36b96 --- /dev/null +++ b/isegm/data/datasets/davis.py @@ -0,0 +1,33 @@ +from pathlib import Path + +import cv2 +import numpy as np + +from isegm.data.base import ISDataset +from isegm.data.sample import DSample + + +class DavisDataset(ISDataset): + def __init__(self, dataset_path, + images_dir_name='img', masks_dir_name='gt', + **kwargs): + super(DavisDataset, self).__init__(**kwargs) + + self.dataset_path = Path(dataset_path) + self._images_path = self.dataset_path / images_dir_name + self._insts_path = self.dataset_path / masks_dir_name + + self.dataset_samples = [x.name for x in sorted(self._images_path.glob('*.*'))] + self._masks_paths = {x.stem: x for x in self._insts_path.glob('*.*')} + + def get_sample(self, index) -> DSample: + image_name = self.dataset_samples[index] + image_path = str(self._images_path / image_name) + mask_path = str(self._masks_paths[image_name.split('.')[0]]) + + image = cv2.imread(image_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + instances_mask = np.max(cv2.imread(mask_path).astype(np.int32), axis=2) + instances_mask[instances_mask > 0] = 1 + + return DSample(image, instances_mask, objects_ids=[1], sample_id=index) diff --git a/isegm/data/datasets/grabcut.py b/isegm/data/datasets/grabcut.py new file mode 100644 index 0000000..ff00446 --- /dev/null +++ b/isegm/data/datasets/grabcut.py @@ -0,0 +1,34 @@ +from pathlib import Path + +import cv2 +import numpy as np + +from isegm.data.base import ISDataset +from isegm.data.sample import DSample + + +class GrabCutDataset(ISDataset): + def __init__(self, dataset_path, + images_dir_name='data_GT', masks_dir_name='boundary_GT', + **kwargs): + super(GrabCutDataset, self).__init__(**kwargs) + + self.dataset_path = Path(dataset_path) + self._images_path = self.dataset_path / images_dir_name + self._insts_path = self.dataset_path / masks_dir_name + + self.dataset_samples = [x.name for x in sorted(self._images_path.glob('*.*'))] + self._masks_paths = {x.stem: x for x in self._insts_path.glob('*.*')} + + def get_sample(self, index) -> DSample: + image_name = self.dataset_samples[index] + 
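+        # GrabCut-style masks: pixel value 128 marks the ignore/boundary band and
+        # values above 128 mark the foreground object, as decoded below.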
image_path = str(self._images_path / image_name) + mask_path = str(self._masks_paths[image_name.split('.')[0]]) + + image = cv2.imread(image_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + instances_mask = cv2.imread(mask_path)[:, :, 0].astype(np.int32) + instances_mask[instances_mask == 128] = -1 + instances_mask[instances_mask > 128] = 1 + + return DSample(image, instances_mask, objects_ids=[1], ignore_ids=[-1], sample_id=index) diff --git a/isegm/data/datasets/images_dir.py b/isegm/data/datasets/images_dir.py new file mode 100644 index 0000000..db7d0fa --- /dev/null +++ b/isegm/data/datasets/images_dir.py @@ -0,0 +1,59 @@ +import cv2 +import numpy as np +from pathlib import Path + +from isegm.data.base import ISDataset +from isegm.data.sample import DSample + + +class ImagesDirDataset(ISDataset): + def __init__(self, dataset_path, + images_dir_name='images', masks_dir_name='masks', + **kwargs): + super(ImagesDirDataset, self).__init__(**kwargs) + + self.dataset_path = Path(dataset_path) + self._images_path = self.dataset_path / images_dir_name + self._insts_path = self.dataset_path / masks_dir_name + + images_list = [x for x in sorted(self._images_path.glob('*.*'))] + + samples = {x.stem: {'image': x, 'masks': []} for x in images_list} + for mask_path in self._insts_path.glob('*.*'): + mask_name = mask_path.stem + if mask_name in samples: + samples[mask_name]['masks'].append(mask_path) + continue + + mask_name_split = mask_name.split('_') + if mask_name_split[-1].isdigit(): + mask_name = '_'.join(mask_name_split[:-1]) + assert mask_name in samples + samples[mask_name]['masks'].append(mask_path) + + for x in samples.values(): + assert len(x['masks']) > 0, x['image'] + + self.dataset_samples = [v for k, v in sorted(samples.items())] + + def get_sample(self, index) -> DSample: + sample = self.dataset_samples[index] + image_path = str(sample['image']) + + objects = [] + ignored_regions = [] + masks = [] + for indx, mask_path in enumerate(sample['masks']): + gt_mask = cv2.imread(str(mask_path))[:, :, 0].astype(np.int32) + instances_mask = np.zeros_like(gt_mask) + instances_mask[gt_mask == 128] = 2 + instances_mask[gt_mask > 128] = 1 + masks.append(instances_mask) + objects.append((indx, 1)) + ignored_regions.append((indx, 2)) + + image = cv2.imread(image_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + return DSample(image, np.stack(masks, axis=2), + objects_ids=objects, ignore_ids=ignored_regions, sample_id=index) diff --git a/isegm/data/datasets/lvis.py b/isegm/data/datasets/lvis.py new file mode 100644 index 0000000..fd94b43 --- /dev/null +++ b/isegm/data/datasets/lvis.py @@ -0,0 +1,97 @@ +import json +import random +from collections import defaultdict +from pathlib import Path + +import cv2 +import numpy as np + +from isegm.data.base import ISDataset +from isegm.data.sample import DSample + + +class LvisDataset(ISDataset): + def __init__(self, dataset_path, split='train', + max_overlap_ratio=0.5, + **kwargs): + super(LvisDataset, self).__init__(**kwargs) + dataset_path = Path(dataset_path) + train_categories_path = dataset_path / 'train_categories.json' + self._train_path = dataset_path / 'train' + self._val_path = dataset_path / 'val' + + self.split = split + self.max_overlap_ratio = max_overlap_ratio + + with open( dataset_path / split / f'lvis_{self.split}.json', 'r') as f: + json_annotation = json.loads(f.read()) + + self.annotations = defaultdict(list) + for x in json_annotation['annotations']: + self.annotations[x['image_id']].append(x) + + if not 
train_categories_path.exists(): + self.generate_train_categories(dataset_path, train_categories_path) + self.dataset_samples = [x for x in json_annotation['images'] + if len(self.annotations[x['id']]) > 0] + + def get_sample(self, index) -> DSample: + image_info = self.dataset_samples[index] + image_id, image_url = image_info['id'], image_info['coco_url'] + image_filename = image_url.split('/')[-1] + image_annotations = self.annotations[image_id] + random.shuffle(image_annotations) + + # LVISv1 splits do not match older LVIS splits (some images in val may come from COCO train2017) + if 'train2017' in image_url: + image_path = self._train_path / 'images' / image_filename + else: + image_path = self._val_path / 'images' / image_filename + image = cv2.imread(str(image_path)) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + instances_mask = None + instances_area = defaultdict(int) + objects_ids = [] + for indx, obj_annotation in enumerate(image_annotations): + mask = self.get_mask_from_polygon(obj_annotation, image) + object_mask = mask > 0 + object_area = object_mask.sum() + + if instances_mask is None: + instances_mask = np.zeros_like(object_mask, dtype=np.int32) + + overlap_ids = np.bincount(instances_mask[object_mask].flatten()) + overlap_areas = [overlap_area / instances_area[inst_id] for inst_id, overlap_area in enumerate(overlap_ids) + if overlap_area > 0 and inst_id > 0] + overlap_ratio = np.logical_and(object_mask, instances_mask > 0).sum() / object_area + if overlap_areas: + overlap_ratio = max(overlap_ratio, max(overlap_areas)) + if overlap_ratio > self.max_overlap_ratio: + continue + + instance_id = indx + 1 + instances_mask[object_mask] = instance_id + instances_area[instance_id] = object_area + objects_ids.append(instance_id) + + return DSample(image, instances_mask, objects_ids=objects_ids) + + + @staticmethod + def get_mask_from_polygon(annotation, image): + mask = np.zeros(image.shape[:2], dtype=np.int32) + for contour_points in annotation['segmentation']: + contour_points = np.array(contour_points).reshape((-1, 2)) + contour_points = np.round(contour_points).astype(np.int32)[np.newaxis, :] + cv2.fillPoly(mask, contour_points, 1) + + return mask + + @staticmethod + def generate_train_categories(dataset_path, train_categories_path): + with open(dataset_path / 'train/lvis_train.json', 'r') as f: + annotation = json.load(f) + + with open(train_categories_path, 'w') as f: + json.dump(annotation['categories'], f, indent=1) diff --git a/isegm/data/datasets/openimages.py b/isegm/data/datasets/openimages.py new file mode 100644 index 0000000..d0a81cf --- /dev/null +++ b/isegm/data/datasets/openimages.py @@ -0,0 +1,58 @@ +import os +import random +import pickle as pkl +from pathlib import Path + +import cv2 +import numpy as np + +from isegm.data.base import ISDataset +from isegm.data.sample import DSample + + +class OpenImagesDataset(ISDataset): + def __init__(self, dataset_path, split='train', **kwargs): + super().__init__(**kwargs) + assert split in {'train', 'val', 'test'} + + self.dataset_path = Path(dataset_path) + self._split_path = self.dataset_path / split + self._images_path = self._split_path / 'images' + self._masks_path = self._split_path / 'masks' + self.dataset_split = split + + clean_anno_path = self._split_path / f'{split}-annotations-object-segmentation_clean.pkl' + if os.path.exists(clean_anno_path): + with clean_anno_path.open('rb') as f: + annotations = pkl.load(f) + else: + raise RuntimeError(f"Can't find annotations at {clean_anno_path}") + 
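+        # The cleaned annotation pickle is expected to hold the per-image mask file
+        # lists ('image_id_to_masks') and the ordered list of sample ids used below.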
self.image_id_to_masks = annotations['image_id_to_masks'] + self.dataset_samples = annotations['dataset_samples'] + + def get_sample(self, index) -> DSample: + image_id = self.dataset_samples[index] + + image_path = str(self._images_path / f'{image_id}.jpg') + image = cv2.imread(image_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + mask_paths = self.image_id_to_masks[image_id] + # select random mask for an image + mask_path = str(self._masks_path / random.choice(mask_paths)) + instances_mask = cv2.imread(mask_path) + instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY) + instances_mask[instances_mask > 0] = 1 + instances_mask = instances_mask.astype(np.int32) + + min_width = min(image.shape[1], instances_mask.shape[1]) + min_height = min(image.shape[0], instances_mask.shape[0]) + + if image.shape[0] != min_height or image.shape[1] != min_width: + image = cv2.resize(image, (min_width, min_height), interpolation=cv2.INTER_LINEAR) + if instances_mask.shape[0] != min_height or instances_mask.shape[1] != min_width: + instances_mask = cv2.resize(instances_mask, (min_width, min_height), interpolation=cv2.INTER_NEAREST) + + object_ids = [1] if instances_mask.sum() > 0 else [] + + return DSample(image, instances_mask, objects_ids=object_ids, sample_id=index) diff --git a/isegm/data/datasets/pascalvoc.py b/isegm/data/datasets/pascalvoc.py new file mode 100644 index 0000000..4e1ad48 --- /dev/null +++ b/isegm/data/datasets/pascalvoc.py @@ -0,0 +1,48 @@ +import pickle as pkl +from pathlib import Path + +import cv2 +import numpy as np + +from isegm.data.base import ISDataset +from isegm.data.sample import DSample + + +class PascalVocDataset(ISDataset): + def __init__(self, dataset_path, split='train', **kwargs): + super().__init__(**kwargs) + assert split in {'train', 'val', 'trainval', 'test'} + + self.dataset_path = Path(dataset_path) + self._images_path = self.dataset_path / "JPEGImages" + self._insts_path = self.dataset_path / "SegmentationObject" + self.dataset_split = split + + if split == 'test': + with open(self.dataset_path / f'ImageSets/Segmentation/test.pickle', 'rb') as f: + self.dataset_samples, self.instance_ids = pkl.load(f) + else: + with open(self.dataset_path / f'ImageSets/Segmentation/{split}.txt', 'r') as f: + self.dataset_samples = [name.strip() for name in f.readlines()] + + def get_sample(self, index) -> DSample: + sample_id = self.dataset_samples[index] + image_path = str(self._images_path / f'{sample_id}.jpg') + mask_path = str(self._insts_path / f'{sample_id}.png') + + image = cv2.imread(image_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + instances_mask = cv2.imread(mask_path) + instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype(np.int32) + if self.dataset_split == 'test': + instance_id = self.instance_ids[index] + mask = np.zeros_like(instances_mask) + mask[instances_mask == 220] = 220 # ignored area + mask[instances_mask == instance_id] = 1 + objects_ids = [1] + instances_mask = mask + else: + objects_ids = np.unique(instances_mask) + objects_ids = [x for x in objects_ids if x != 0 and x != 220] + + return DSample(image, instances_mask, objects_ids=objects_ids, ignore_ids=[220], sample_id=index) diff --git a/isegm/data/datasets/sbd.py b/isegm/data/datasets/sbd.py new file mode 100644 index 0000000..b6a05e4 --- /dev/null +++ b/isegm/data/datasets/sbd.py @@ -0,0 +1,111 @@ +import pickle as pkl +from pathlib import Path + +import cv2 +import numpy as np +from scipy.io import loadmat + +from isegm.utils.misc import 
get_bbox_from_mask, get_labels_with_sizes +from isegm.data.base import ISDataset +from isegm.data.sample import DSample + + +class SBDDataset(ISDataset): + def __init__(self, dataset_path, split='train', buggy_mask_thresh=0.08, **kwargs): + super(SBDDataset, self).__init__(**kwargs) + assert split in {'train', 'val'} + + self.dataset_path = Path(dataset_path) + self.dataset_split = split + self._images_path = self.dataset_path / 'img' + self._insts_path = self.dataset_path / 'inst' + self._buggy_objects = dict() + self._buggy_mask_thresh = buggy_mask_thresh + + with open(self.dataset_path / f'{split}.txt', 'r') as f: + self.dataset_samples = [x.strip() for x in f.readlines()] + + def get_sample(self, index): + image_name = self.dataset_samples[index] + image_path = str(self._images_path / f'{image_name}.jpg') + inst_info_path = str(self._insts_path / f'{image_name}.mat') + + image = cv2.imread(image_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + instances_mask = loadmat(str(inst_info_path))['GTinst'][0][0][0].astype(np.int32) + instances_mask = self.remove_buggy_masks(index, instances_mask) + instances_ids, _ = get_labels_with_sizes(instances_mask) + + return DSample(image, instances_mask, objects_ids=instances_ids, sample_id=index) + + def remove_buggy_masks(self, index, instances_mask): + if self._buggy_mask_thresh > 0.0: + buggy_image_objects = self._buggy_objects.get(index, None) + if buggy_image_objects is None: + buggy_image_objects = [] + instances_ids, _ = get_labels_with_sizes(instances_mask) + for obj_id in instances_ids: + obj_mask = instances_mask == obj_id + mask_area = obj_mask.sum() + bbox = get_bbox_from_mask(obj_mask) + bbox_area = (bbox[1] - bbox[0] + 1) * (bbox[3] - bbox[2] + 1) + obj_area_ratio = mask_area / bbox_area + if obj_area_ratio < self._buggy_mask_thresh: + buggy_image_objects.append(obj_id) + + self._buggy_objects[index] = buggy_image_objects + for obj_id in buggy_image_objects: + instances_mask[instances_mask == obj_id] = 0 + + return instances_mask + + +class SBDEvaluationDataset(ISDataset): + def __init__(self, dataset_path, split='val', **kwargs): + super(SBDEvaluationDataset, self).__init__(**kwargs) + assert split in {'train', 'val'} + + self.dataset_path = Path(dataset_path) + self.dataset_split = split + self._images_path = self.dataset_path / 'img' + self._insts_path = self.dataset_path / 'inst' + + with open(self.dataset_path / f'{split}.txt', 'r') as f: + self.dataset_samples = [x.strip() for x in f.readlines()] + + self.dataset_samples = self.get_sbd_images_and_ids_list() + + def get_sample(self, index) -> DSample: + image_name, instance_id = self.dataset_samples[index] + image_path = str(self._images_path / f'{image_name}.jpg') + inst_info_path = str(self._insts_path / f'{image_name}.mat') + + image = cv2.imread(image_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + instances_mask = loadmat(str(inst_info_path))['GTinst'][0][0][0].astype(np.int32) + instances_mask[instances_mask != instance_id] = 0 + instances_mask[instances_mask > 0] = 1 + + return DSample(image, instances_mask, objects_ids=[1], sample_id=index) + + def get_sbd_images_and_ids_list(self): + pkl_path = self.dataset_path / f'{self.dataset_split}_images_and_ids_list.pkl' + + if pkl_path.exists(): + with open(str(pkl_path), 'rb') as fp: + images_and_ids_list = pkl.load(fp) + else: + images_and_ids_list = [] + + for sample in self.dataset_samples: + inst_info_path = str(self._insts_path / f'{sample}.mat') + instances_mask = 
loadmat(str(inst_info_path))['GTinst'][0][0][0].astype(np.int32) + instances_ids, _ = get_labels_with_sizes(instances_mask) + + for instances_id in instances_ids: + images_and_ids_list.append((sample, instances_id)) + + with open(str(pkl_path), 'wb') as fp: + pkl.dump(images_and_ids_list, fp) + + return images_and_ids_list diff --git a/isegm/data/points_sampler.py b/isegm/data/points_sampler.py new file mode 100644 index 0000000..43cc638 --- /dev/null +++ b/isegm/data/points_sampler.py @@ -0,0 +1,305 @@ +import cv2 +import math +import random +import numpy as np +from functools import lru_cache +from .sample import DSample + + +class BasePointSampler: + def __init__(self): + self._selected_mask = None + self._selected_masks = None + + def sample_object(self, sample: DSample): + raise NotImplementedError + + def sample_points(self): + raise NotImplementedError + + @property + def selected_mask(self): + assert self._selected_mask is not None + return self._selected_mask + + @selected_mask.setter + def selected_mask(self, mask): + self._selected_mask = mask[np.newaxis, :].astype(np.float32) + + +class MultiPointSampler(BasePointSampler): + def __init__(self, max_num_points, prob_gamma=0.7, expand_ratio=0.1, + positive_erode_prob=0.9, positive_erode_iters=3, + negative_bg_prob=0.1, negative_other_prob=0.4, negative_border_prob=0.5, + merge_objects_prob=0.0, max_num_merged_objects=2, + use_hierarchy=False, soft_targets=False, + first_click_center=False, only_one_first_click=False, + sfc_inner_k=1.7, sfc_full_inner_prob=0.0): + super().__init__() + self.max_num_points = max_num_points + self.expand_ratio = expand_ratio + self.positive_erode_prob = positive_erode_prob + self.positive_erode_iters = positive_erode_iters + self.merge_objects_prob = merge_objects_prob + self.use_hierarchy = use_hierarchy + self.soft_targets = soft_targets + self.first_click_center = first_click_center + self.only_one_first_click = only_one_first_click + self.sfc_inner_k = sfc_inner_k + self.sfc_full_inner_prob = sfc_full_inner_prob + + if max_num_merged_objects == -1: + max_num_merged_objects = max_num_points + self.max_num_merged_objects = max_num_merged_objects + + self.neg_strategies = ['bg', 'other', 'border'] + self.neg_strategies_prob = [negative_bg_prob, negative_other_prob, negative_border_prob] + assert math.isclose(sum(self.neg_strategies_prob), 1.0) + + self._pos_probs = generate_probs(max_num_points, gamma=prob_gamma) + self._neg_probs = generate_probs(max_num_points + 1, gamma=prob_gamma) + self._neg_masks = None + + def sample_object(self, sample: DSample): + if len(sample) == 0: + bg_mask = sample.get_background_mask() + self.selected_mask = np.zeros_like(bg_mask, dtype=np.float32) + self._selected_masks = [[]] + self._neg_masks = {strategy: bg_mask for strategy in self.neg_strategies} + self._neg_masks['required'] = [] + return + + gt_mask, pos_masks, neg_masks = self._sample_mask(sample) + binary_gt_mask = gt_mask > 0.5 if self.soft_targets else gt_mask > 0 + + self.selected_mask = gt_mask + self._selected_masks = pos_masks + + neg_mask_bg = np.logical_not(binary_gt_mask) + neg_mask_border = self._get_border_mask(binary_gt_mask) + if len(sample) <= len(self._selected_masks): + neg_mask_other = neg_mask_bg + else: + neg_mask_other = np.logical_and(np.logical_not(sample.get_background_mask()), + np.logical_not(binary_gt_mask)) + + self._neg_masks = { + 'bg': neg_mask_bg, + 'other': neg_mask_other, + 'border': neg_mask_border, + 'required': neg_masks + } + + def _sample_mask(self, sample: DSample): + 
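+        # Select one root object (or merge several with probability merge_objects_prob),
+        # then build its ground-truth mask plus the positive/negative click regions.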
root_obj_ids = sample.root_objects + + if len(root_obj_ids) > 1 and random.random() < self.merge_objects_prob: + max_selected_objects = min(len(root_obj_ids), self.max_num_merged_objects) + num_selected_objects = np.random.randint(2, max_selected_objects + 1) + random_ids = random.sample(root_obj_ids, num_selected_objects) + else: + random_ids = [random.choice(root_obj_ids)] + + gt_mask = None + pos_segments = [] + neg_segments = [] + for obj_id in random_ids: + obj_gt_mask, obj_pos_segments, obj_neg_segments = self._sample_from_masks_layer(obj_id, sample) + if gt_mask is None: + gt_mask = obj_gt_mask + else: + gt_mask = np.maximum(gt_mask, obj_gt_mask) + + pos_segments.extend(obj_pos_segments) + neg_segments.extend(obj_neg_segments) + + pos_masks = [self._positive_erode(x) for x in pos_segments] + neg_masks = [self._positive_erode(x) for x in neg_segments] + + return gt_mask, pos_masks, neg_masks + + def _sample_from_masks_layer(self, obj_id, sample: DSample): + objs_tree = sample._objects + + if not self.use_hierarchy: + node_mask = sample.get_object_mask(obj_id) + gt_mask = sample.get_soft_object_mask(obj_id) if self.soft_targets else node_mask + return gt_mask, [node_mask], [] + + def _select_node(node_id): + node_info = objs_tree[node_id] + if not node_info['children'] or random.random() < 0.5: + return node_id + return _select_node(random.choice(node_info['children'])) + + selected_node = _select_node(obj_id) + node_info = objs_tree[selected_node] + node_mask = sample.get_object_mask(selected_node) + gt_mask = sample.get_soft_object_mask(selected_node) if self.soft_targets else node_mask + pos_mask = node_mask.copy() + + negative_segments = [] + if node_info['parent'] is not None and node_info['parent'] in objs_tree: + parent_mask = sample.get_object_mask(node_info['parent']) + negative_segments.append(np.logical_and(parent_mask, np.logical_not(node_mask))) + + for child_id in node_info['children']: + if objs_tree[child_id]['area'] / node_info['area'] < 0.10: + child_mask = sample.get_object_mask(child_id) + pos_mask = np.logical_and(pos_mask, np.logical_not(child_mask)) + + if node_info['children']: + max_disabled_children = min(len(node_info['children']), 3) + num_disabled_children = np.random.randint(0, max_disabled_children + 1) + disabled_children = random.sample(node_info['children'], num_disabled_children) + + for child_id in disabled_children: + child_mask = sample.get_object_mask(child_id) + pos_mask = np.logical_and(pos_mask, np.logical_not(child_mask)) + if self.soft_targets: + soft_child_mask = sample.get_soft_object_mask(child_id) + gt_mask = np.minimum(gt_mask, 1.0 - soft_child_mask) + else: + gt_mask = np.logical_and(gt_mask, np.logical_not(child_mask)) + negative_segments.append(child_mask) + + return gt_mask, [pos_mask], negative_segments + + def sample_points(self): + assert self._selected_mask is not None + pos_points = self._multi_mask_sample_points(self._selected_masks, + is_negative=[False] * len(self._selected_masks), + with_first_click=self.first_click_center) + + neg_strategy = [(self._neg_masks[k], prob) + for k, prob in zip(self.neg_strategies, self.neg_strategies_prob)] + neg_masks = self._neg_masks['required'] + [neg_strategy] + neg_points = self._multi_mask_sample_points(neg_masks, + is_negative=[False] * len(self._neg_masks['required']) + [True]) + + return pos_points + neg_points + + def _multi_mask_sample_points(self, selected_masks, is_negative, with_first_click=False): + selected_masks = selected_masks[:self.max_num_points] + + each_obj_points = [ + 
self._sample_points(mask, is_negative=is_negative[i], + with_first_click=with_first_click) + for i, mask in enumerate(selected_masks) + ] + each_obj_points = [x for x in each_obj_points if len(x) > 0] + + points = [] + if len(each_obj_points) == 1: + points = each_obj_points[0] + elif len(each_obj_points) > 1: + if self.only_one_first_click: + each_obj_points = each_obj_points[:1] + + points = [obj_points[0] for obj_points in each_obj_points] + + aggregated_masks_with_prob = [] + for indx, x in enumerate(selected_masks): + if isinstance(x, (list, tuple)) and x and isinstance(x[0], (list, tuple)): + for t, prob in x: + aggregated_masks_with_prob.append((t, prob / len(selected_masks))) + else: + aggregated_masks_with_prob.append((x, 1.0 / len(selected_masks))) + + other_points_union = self._sample_points(aggregated_masks_with_prob, is_negative=True) + if len(other_points_union) + len(points) <= self.max_num_points: + points.extend(other_points_union) + else: + points.extend(random.sample(other_points_union, self.max_num_points - len(points))) + + if len(points) < self.max_num_points: + points.extend([(-1, -1, -1)] * (self.max_num_points - len(points))) + + return points + + def _sample_points(self, mask, is_negative=False, with_first_click=False): + if is_negative: + num_points = np.random.choice(np.arange(self.max_num_points + 1), p=self._neg_probs) + else: + num_points = 1 + np.random.choice(np.arange(self.max_num_points), p=self._pos_probs) + + indices_probs = None + if isinstance(mask, (list, tuple)): + indices_probs = [x[1] for x in mask] + indices = [(np.argwhere(x), prob) for x, prob in mask] + if indices_probs: + assert math.isclose(sum(indices_probs), 1.0) + else: + indices = np.argwhere(mask) + + points = [] + for j in range(num_points): + first_click = with_first_click and j == 0 and indices_probs is None + + if first_click: + point_indices = get_point_candidates(mask, k=self.sfc_inner_k, full_prob=self.sfc_full_inner_prob) + elif indices_probs: + point_indices_indx = np.random.choice(np.arange(len(indices)), p=indices_probs) + point_indices = indices[point_indices_indx][0] + else: + point_indices = indices + + num_indices = len(point_indices) + if num_indices > 0: + point_indx = 0 if first_click else 100 + click = point_indices[np.random.randint(0, num_indices)].tolist() + [point_indx] + points.append(click) + + return points + + def _positive_erode(self, mask): + if random.random() > self.positive_erode_prob: + return mask + + kernel = np.ones((3, 3), np.uint8) + eroded_mask = cv2.erode(mask.astype(np.uint8), + kernel, iterations=self.positive_erode_iters).astype(np.bool) + + if eroded_mask.sum() > 10: + return eroded_mask + else: + return mask + + def _get_border_mask(self, mask): + expand_r = int(np.ceil(self.expand_ratio * np.sqrt(mask.sum()))) + kernel = np.ones((3, 3), np.uint8) + expanded_mask = cv2.dilate(mask.astype(np.uint8), kernel, iterations=expand_r) + expanded_mask[mask.astype(np.bool)] = 0 + return expanded_mask + + +@lru_cache(maxsize=None) +def generate_probs(max_num_points, gamma): + probs = [] + last_value = 1 + for i in range(max_num_points): + probs.append(last_value) + last_value *= gamma + + probs = np.array(probs) + probs /= probs.sum() + + return probs + + +def get_point_candidates(obj_mask, k=1.7, full_prob=0.0): + if full_prob > 0 and random.random() < full_prob: + return obj_mask + + padded_mask = np.pad(obj_mask, ((1, 1), (1, 1)), 'constant') + + dt = cv2.distanceTransform(padded_mask.astype(np.uint8), cv2.DIST_L2, 0)[1:-1, 1:-1] + if k > 0: + 
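+        # Restrict first-click candidates to the "inner" part of the object:
+        # pixels whose distance-transform value exceeds dt.max() / k.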
inner_mask = dt > dt.max() / k + return np.argwhere(inner_mask) + else: + prob_map = dt.flatten() + prob_map /= max(prob_map.sum(), 1e-6) + click_indx = np.random.choice(len(prob_map), p=prob_map) + click_coords = np.unravel_index(click_indx, dt.shape) + return np.array([click_coords]) diff --git a/isegm/data/sample.py b/isegm/data/sample.py new file mode 100644 index 0000000..d57794c --- /dev/null +++ b/isegm/data/sample.py @@ -0,0 +1,148 @@ +import numpy as np +from copy import deepcopy +from isegm.utils.misc import get_labels_with_sizes +from isegm.data.transforms import remove_image_only_transforms +from albumentations import ReplayCompose + + +class DSample: + def __init__(self, image, encoded_masks, objects=None, + objects_ids=None, ignore_ids=None, sample_id=None): + self.image = image + self.sample_id = sample_id + + if len(encoded_masks.shape) == 2: + encoded_masks = encoded_masks[:, :, np.newaxis] + self._encoded_masks = encoded_masks + self._ignored_regions = [] + + if objects_ids is not None: + if not objects_ids or not isinstance(objects_ids[0], tuple): + assert encoded_masks.shape[2] == 1 + objects_ids = [(0, obj_id) for obj_id in objects_ids] + + self._objects = dict() + for indx, obj_mapping in enumerate(objects_ids): + self._objects[indx] = { + 'parent': None, + 'mapping': obj_mapping, + 'children': [] + } + + if ignore_ids: + if isinstance(ignore_ids[0], tuple): + self._ignored_regions = ignore_ids + else: + self._ignored_regions = [(0, region_id) for region_id in ignore_ids] + else: + self._objects = deepcopy(objects) + + self._augmented = False + self._soft_mask_aug = None + self._original_data = self.image, self._encoded_masks, deepcopy(self._objects) + + def augment(self, augmentator): + self.reset_augmentation() + aug_output = augmentator(image=self.image, mask=self._encoded_masks) + self.image = aug_output['image'] + self._encoded_masks = aug_output['mask'] + + aug_replay = aug_output.get('replay', None) + if aug_replay: + assert len(self._ignored_regions) == 0 + mask_replay = remove_image_only_transforms(aug_replay) + self._soft_mask_aug = ReplayCompose._restore_for_replay(mask_replay) + + self._compute_objects_areas() + self.remove_small_objects(min_area=1) + + self._augmented = True + + def reset_augmentation(self): + if not self._augmented: + return + orig_image, orig_masks, orig_objects = self._original_data + self.image = orig_image + self._encoded_masks = orig_masks + self._objects = deepcopy(orig_objects) + self._augmented = False + self._soft_mask_aug = None + + def remove_small_objects(self, min_area): + if self._objects and not 'area' in list(self._objects.values())[0]: + self._compute_objects_areas() + + for obj_id, obj_info in list(self._objects.items()): + if obj_info['area'] < min_area: + self._remove_object(obj_id) + + def get_object_mask(self, obj_id): + layer_indx, mask_id = self._objects[obj_id]['mapping'] + obj_mask = (self._encoded_masks[:, :, layer_indx] == mask_id).astype(np.int32) + if self._ignored_regions: + for layer_indx, mask_id in self._ignored_regions: + ignore_mask = self._encoded_masks[:, :, layer_indx] == mask_id + obj_mask[ignore_mask] = -1 + + return obj_mask + + def get_soft_object_mask(self, obj_id): + assert self._soft_mask_aug is not None + original_encoded_masks = self._original_data[1] + layer_indx, mask_id = self._objects[obj_id]['mapping'] + obj_mask = (original_encoded_masks[:, :, layer_indx] == mask_id).astype(np.float32) + obj_mask = self._soft_mask_aug(image=obj_mask, mask=original_encoded_masks)['image'] + return 
np.clip(obj_mask, 0, 1) + + def get_background_mask(self): + return np.max(self._encoded_masks, axis=2) == 0 + + @property + def objects_ids(self): + return list(self._objects.keys()) + + @property + def gt_mask(self): + assert len(self._objects) == 1 + return self.get_object_mask(self.objects_ids[0]) + + @property + def root_objects(self): + return [obj_id for obj_id, obj_info in self._objects.items() if obj_info['parent'] is None] + + def _compute_objects_areas(self): + inverse_index = {node['mapping']: node_id for node_id, node in self._objects.items()} + ignored_regions_keys = set(self._ignored_regions) + + for layer_indx in range(self._encoded_masks.shape[2]): + objects_ids, objects_areas = get_labels_with_sizes(self._encoded_masks[:, :, layer_indx]) + for obj_id, obj_area in zip(objects_ids, objects_areas): + inv_key = (layer_indx, obj_id) + if inv_key in ignored_regions_keys: + continue + try: + self._objects[inverse_index[inv_key]]['area'] = obj_area + del inverse_index[inv_key] + except KeyError: + layer = self._encoded_masks[:, :, layer_indx] + layer[layer == obj_id] = 0 + self._encoded_masks[:, :, layer_indx] = layer + + for obj_id in inverse_index.values(): + self._objects[obj_id]['area'] = 0 + + def _remove_object(self, obj_id): + obj_info = self._objects[obj_id] + obj_parent = obj_info['parent'] + for child_id in obj_info['children']: + self._objects[child_id]['parent'] = obj_parent + + if obj_parent is not None: + parent_children = self._objects[obj_parent]['children'] + parent_children = [x for x in parent_children if x != obj_id] + self._objects[obj_parent]['children'] = parent_children + obj_info['children'] + + del self._objects[obj_id] + + def __len__(self): + return len(self._objects) diff --git a/isegm/data/transforms.py b/isegm/data/transforms.py new file mode 100644 index 0000000..0a3fd67 --- /dev/null +++ b/isegm/data/transforms.py @@ -0,0 +1,178 @@ +import cv2 +import random +import numpy as np + +from albumentations.core.serialization import SERIALIZABLE_REGISTRY +from albumentations import ImageOnlyTransform, DualTransform +from albumentations.core.transforms_interface import to_tuple +from albumentations.augmentations import functional as F +from isegm.utils.misc import get_bbox_from_mask, expand_bbox, clamp_bbox, get_labels_with_sizes + + +class UniformRandomResize(DualTransform): + def __init__(self, scale_range=(0.9, 1.1), interpolation=cv2.INTER_LINEAR, always_apply=False, p=1): + super().__init__(always_apply, p) + self.scale_range = scale_range + self.interpolation = interpolation + + def get_params_dependent_on_targets(self, params): + scale = random.uniform(*self.scale_range) + height = int(round(params['image'].shape[0] * scale)) + width = int(round(params['image'].shape[1] * scale)) + return {'new_height': height, 'new_width': width} + + def apply(self, img, new_height=0, new_width=0, interpolation=cv2.INTER_LINEAR, **params): + return F.resize(img, height=new_height, width=new_width, interpolation=interpolation) + + def apply_to_keypoint(self, keypoint, new_height=0, new_width=0, **params): + scale_x = new_width / params["cols"] + scale_y = new_height / params["rows"] + return F.keypoint_scale(keypoint, scale_x, scale_y) + + def get_transform_init_args_names(self): + return "scale_range", "interpolation" + + @property + def targets_as_params(self): + return ["image"] + + +class ZoomIn(DualTransform): + def __init__( + self, + height, + width, + bbox_jitter=0.1, + expansion_ratio=1.4, + min_crop_size=200, + min_area=100, + always_resize=False, + 
always_apply=False, + p=0.5, + ): + super(ZoomIn, self).__init__(always_apply, p) + self.height = height + self.width = width + self.bbox_jitter = to_tuple(bbox_jitter) + self.expansion_ratio = expansion_ratio + self.min_crop_size = min_crop_size + self.min_area = min_area + self.always_resize = always_resize + + def apply(self, img, selected_object, bbox, **params): + if selected_object is None: + if self.always_resize: + img = F.resize(img, height=self.height, width=self.width) + return img + + rmin, rmax, cmin, cmax = bbox + img = img[rmin:rmax + 1, cmin:cmax + 1] + img = F.resize(img, height=self.height, width=self.width) + + return img + + def apply_to_mask(self, mask, selected_object, bbox, **params): + if selected_object is None: + if self.always_resize: + mask = F.resize(mask, height=self.height, width=self.width, + interpolation=cv2.INTER_NEAREST) + return mask + + rmin, rmax, cmin, cmax = bbox + mask = mask[rmin:rmax + 1, cmin:cmax + 1] + if isinstance(selected_object, tuple): + layer_indx, mask_id = selected_object + obj_mask = mask[:, :, layer_indx] == mask_id + new_mask = np.zeros_like(mask) + new_mask[:, :, layer_indx][obj_mask] = mask_id + else: + obj_mask = mask == selected_object + new_mask = mask.copy() + new_mask[np.logical_not(obj_mask)] = 0 + + new_mask = F.resize(new_mask, height=self.height, width=self.width, + interpolation=cv2.INTER_NEAREST) + return new_mask + + def get_params_dependent_on_targets(self, params): + instances = params['mask'] + + is_mask_layer = len(instances.shape) > 2 + candidates = [] + if is_mask_layer: + for layer_indx in range(instances.shape[2]): + labels, areas = get_labels_with_sizes(instances[:, :, layer_indx]) + candidates.extend([(layer_indx, obj_id) + for obj_id, area in zip(labels, areas) + if area > self.min_area]) + else: + labels, areas = get_labels_with_sizes(instances) + candidates = [obj_id for obj_id, area in zip(labels, areas) + if area > self.min_area] + + selected_object = None + bbox = None + if candidates: + selected_object = random.choice(candidates) + if is_mask_layer: + layer_indx, mask_id = selected_object + obj_mask = instances[:, :, layer_indx] == mask_id + else: + obj_mask = instances == selected_object + + bbox = get_bbox_from_mask(obj_mask) + + if isinstance(self.expansion_ratio, tuple): + expansion_ratio = random.uniform(*self.expansion_ratio) + else: + expansion_ratio = self.expansion_ratio + + bbox = expand_bbox(bbox, expansion_ratio, self.min_crop_size) + bbox = self._jitter_bbox(bbox) + bbox = clamp_bbox(bbox, 0, obj_mask.shape[0] - 1, 0, obj_mask.shape[1] - 1) + + return { + 'selected_object': selected_object, + 'bbox': bbox + } + + def _jitter_bbox(self, bbox): + rmin, rmax, cmin, cmax = bbox + height = rmax - rmin + 1 + width = cmax - cmin + 1 + rmin = int(rmin + random.uniform(*self.bbox_jitter) * height) + rmax = int(rmax + random.uniform(*self.bbox_jitter) * height) + cmin = int(cmin + random.uniform(*self.bbox_jitter) * width) + cmax = int(cmax + random.uniform(*self.bbox_jitter) * width) + + return rmin, rmax, cmin, cmax + + def apply_to_bbox(self, bbox, **params): + raise NotImplementedError + + def apply_to_keypoint(self, keypoint, **params): + raise NotImplementedError + + @property + def targets_as_params(self): + return ["mask"] + + def get_transform_init_args_names(self): + return ("height", "width", "bbox_jitter", + "expansion_ratio", "min_crop_size", "min_area", "always_resize") + + +def remove_image_only_transforms(sdict): + if not 'transforms' in sdict: + return sdict + + keep_transforms = [] 
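+    # Recursively keep only transforms that are not ImageOnlyTransform, so the recorded
+    # augmentation replay can later be re-applied to masks alone (e.g. soft object masks).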
+ for tdict in sdict['transforms']: + cls = SERIALIZABLE_REGISTRY[tdict['__class_fullname__']] + if 'transforms' in tdict: + keep_transforms.append(remove_image_only_transforms(tdict)) + elif not issubclass(cls, ImageOnlyTransform): + keep_transforms.append(tdict) + sdict['transforms'] = keep_transforms + + return sdict diff --git a/isegm/engine/optimizer.py b/isegm/engine/optimizer.py new file mode 100644 index 0000000..fd03d8c --- /dev/null +++ b/isegm/engine/optimizer.py @@ -0,0 +1,27 @@ +import torch +import math +from isegm.utils.log import logger + + +def get_optimizer(model, opt_name, opt_kwargs): + params = [] + base_lr = opt_kwargs['lr'] + for name, param in model.named_parameters(): + param_group = {'params': [param]} + if not param.requires_grad: + params.append(param_group) + continue + + if not math.isclose(getattr(param, 'lr_mult', 1.0), 1.0): + logger.info(f'Applied lr_mult={param.lr_mult} to "{name}" parameter.') + param_group['lr'] = param_group.get('lr', base_lr) * param.lr_mult + + params.append(param_group) + + optimizer = { + 'sgd': torch.optim.SGD, + 'adam': torch.optim.Adam, + 'adamw': torch.optim.AdamW + }[opt_name.lower()](params, **opt_kwargs) + + return optimizer diff --git a/isegm/engine/trainer.py b/isegm/engine/trainer.py new file mode 100644 index 0000000..ba56323 --- /dev/null +++ b/isegm/engine/trainer.py @@ -0,0 +1,413 @@ +import os +import random +import logging +from copy import deepcopy +from collections import defaultdict + +import cv2 +import torch +import numpy as np +from tqdm import tqdm +from torch.utils.data import DataLoader + +from isegm.utils.log import logger, TqdmToLogger, SummaryWriterAvg +from isegm.utils.vis import draw_probmap, draw_points +from isegm.utils.misc import save_checkpoint +from isegm.utils.serialization import get_config_repr +from isegm.utils.distributed import get_dp_wrapper, get_sampler, reduce_loss_dict +from .optimizer import get_optimizer + + +class ISTrainer(object): + def __init__(self, model, cfg, model_cfg, loss_cfg, + trainset, valset, + optimizer='adam', + optimizer_params=None, + image_dump_interval=200, + checkpoint_interval=10, + tb_dump_period=25, + max_interactive_points=0, + lr_scheduler=None, + metrics=None, + additional_val_metrics=None, + net_inputs=('images', 'points'), + max_num_next_clicks=0, + click_models=None, + prev_mask_drop_prob=0.0, + ): + self.cfg = cfg + self.model_cfg = model_cfg + self.max_interactive_points = max_interactive_points + self.loss_cfg = loss_cfg + self.val_loss_cfg = deepcopy(loss_cfg) + self.tb_dump_period = tb_dump_period + self.net_inputs = net_inputs + self.max_num_next_clicks = max_num_next_clicks + + self.click_models = click_models + self.prev_mask_drop_prob = prev_mask_drop_prob + + if cfg.distributed: + cfg.batch_size //= cfg.ngpus + cfg.val_batch_size //= cfg.ngpus + + if metrics is None: + metrics = [] + self.train_metrics = metrics + self.val_metrics = deepcopy(metrics) + if additional_val_metrics is not None: + self.val_metrics.extend(additional_val_metrics) + + self.checkpoint_interval = checkpoint_interval + self.image_dump_interval = image_dump_interval + self.task_prefix = '' + self.sw = None + + self.trainset = trainset + self.valset = valset + + logger.info(f'Dataset of {trainset.get_samples_number()} samples was loaded for training.') + logger.info(f'Dataset of {valset.get_samples_number()} samples was loaded for validation.') + + self.train_data = DataLoader( + trainset, cfg.batch_size, + sampler=get_sampler(trainset, shuffle=True, 
distributed=cfg.distributed), + drop_last=True, pin_memory=True, + num_workers=cfg.workers + ) + + self.val_data = DataLoader( + valset, cfg.val_batch_size, + sampler=get_sampler(valset, shuffle=False, distributed=cfg.distributed), + drop_last=True, pin_memory=True, + num_workers=cfg.workers + ) + + self.optim = get_optimizer(model, optimizer, optimizer_params) + model = self._load_weights(model) + + if cfg.multi_gpu: + model = get_dp_wrapper(cfg.distributed)(model, device_ids=cfg.gpu_ids, + output_device=cfg.gpu_ids[0]) + + if self.is_master: + logger.info(model) + logger.info(get_config_repr(model._config)) + + self.device = cfg.device + self.net = model.to(self.device) + self.lr = optimizer_params['lr'] + + if lr_scheduler is not None: + self.lr_scheduler = lr_scheduler(optimizer=self.optim) + if cfg.start_epoch > 0: + for _ in range(cfg.start_epoch): + self.lr_scheduler.step() + + self.tqdm_out = TqdmToLogger(logger, level=logging.INFO) + + if self.click_models is not None: + for click_model in self.click_models: + for param in click_model.parameters(): + param.requires_grad = False + click_model.to(self.device) + click_model.eval() + + def run(self, num_epochs, start_epoch=None, validation=True): + if start_epoch is None: + start_epoch = self.cfg.start_epoch + + logger.info(f'Starting Epoch: {start_epoch}') + logger.info(f'Total Epochs: {num_epochs}') + for epoch in range(start_epoch, num_epochs): + self.training(epoch) + if validation: + self.validation(epoch) + + def training(self, epoch): + if self.sw is None and self.is_master: + self.sw = SummaryWriterAvg(log_dir=str(self.cfg.LOGS_PATH), + flush_secs=10, dump_period=self.tb_dump_period) + + if self.cfg.distributed: + self.train_data.sampler.set_epoch(epoch) + + log_prefix = 'Train' + self.task_prefix.capitalize() + tbar = tqdm(self.train_data, file=self.tqdm_out, ncols=100)\ + if self.is_master else self.train_data + + for metric in self.train_metrics: + metric.reset_epoch_stats() + + self.net.train() + train_loss = 0.0 + for i, batch_data in enumerate(tbar): + global_step = epoch * len(self.train_data) + i + + loss, losses_logging, splitted_batch_data, outputs = \ + self.batch_forward(batch_data) + + self.optim.zero_grad() + loss.backward() + self.optim.step() + + losses_logging['overall'] = loss + reduce_loss_dict(losses_logging) + + train_loss += losses_logging['overall'].item() + + if self.is_master: + for loss_name, loss_value in losses_logging.items(): + self.sw.add_scalar(tag=f'{log_prefix}Losses/{loss_name}', + value=loss_value.item(), + global_step=global_step) + + for k, v in self.loss_cfg.items(): + if '_loss' in k and hasattr(v, 'log_states') and self.loss_cfg.get(k + '_weight', 0.0) > 0: + v.log_states(self.sw, f'{log_prefix}Losses/{k}', global_step) + + if self.image_dump_interval > 0 and global_step % self.image_dump_interval == 0: + self.save_visualization(splitted_batch_data, outputs, global_step, prefix='train') + + self.sw.add_scalar(tag=f'{log_prefix}States/learning_rate', + value=self.lr if not hasattr(self, 'lr_scheduler') else self.lr_scheduler.get_lr()[-1], + global_step=global_step) + + tbar.set_description(f'Epoch {epoch}, training loss {train_loss/(i+1):.4f}') + for metric in self.train_metrics: + metric.log_states(self.sw, f'{log_prefix}Metrics/{metric.name}', global_step) + + if self.is_master: + for metric in self.train_metrics: + self.sw.add_scalar(tag=f'{log_prefix}Metrics/{metric.name}', + value=metric.get_epoch_value(), + global_step=epoch, disable_avg=True) + + save_checkpoint(self.net, 
self.cfg.CHECKPOINTS_PATH, prefix=self.task_prefix, + epoch=None, multi_gpu=self.cfg.multi_gpu) + + if isinstance(self.checkpoint_interval, (list, tuple)): + checkpoint_interval = [x for x in self.checkpoint_interval if x[0] <= epoch][-1][1] + else: + checkpoint_interval = self.checkpoint_interval + + if epoch % checkpoint_interval == 0: + save_checkpoint(self.net, self.cfg.CHECKPOINTS_PATH, prefix=self.task_prefix, + epoch=epoch, multi_gpu=self.cfg.multi_gpu) + + if hasattr(self, 'lr_scheduler'): + self.lr_scheduler.step() + + def validation(self, epoch): + if self.sw is None and self.is_master: + self.sw = SummaryWriterAvg(log_dir=str(self.cfg.LOGS_PATH), + flush_secs=10, dump_period=self.tb_dump_period) + + log_prefix = 'Val' + self.task_prefix.capitalize() + tbar = tqdm(self.val_data, file=self.tqdm_out, ncols=100) if self.is_master else self.val_data + + for metric in self.val_metrics: + metric.reset_epoch_stats() + + val_loss = 0 + losses_logging = defaultdict(list) + + self.net.eval() + for i, batch_data in enumerate(tbar): + global_step = epoch * len(self.val_data) + i + loss, batch_losses_logging, splitted_batch_data, outputs = \ + self.batch_forward(batch_data, validation=True) + + batch_losses_logging['overall'] = loss + reduce_loss_dict(batch_losses_logging) + for loss_name, loss_value in batch_losses_logging.items(): + losses_logging[loss_name].append(loss_value.item()) + + val_loss += batch_losses_logging['overall'].item() + + if self.is_master: + tbar.set_description(f'Epoch {epoch}, validation loss: {val_loss/(i + 1):.4f}') + for metric in self.val_metrics: + metric.log_states(self.sw, f'{log_prefix}Metrics/{metric.name}', global_step) + + if self.is_master: + for loss_name, loss_values in losses_logging.items(): + self.sw.add_scalar(tag=f'{log_prefix}Losses/{loss_name}', value=np.array(loss_values).mean(), + global_step=epoch, disable_avg=True) + + for metric in self.val_metrics: + self.sw.add_scalar(tag=f'{log_prefix}Metrics/{metric.name}', value=metric.get_epoch_value(), + global_step=epoch, disable_avg=True) + + def batch_forward(self, batch_data, validation=False): + metrics = self.val_metrics if validation else self.train_metrics + losses_logging = dict() + + with torch.set_grad_enabled(not validation): + batch_data = {k: v.to(self.device) for k, v in batch_data.items()} + image, gt_mask, points = batch_data['images'], batch_data['instances'], batch_data['points'] + orig_image, orig_gt_mask, orig_points = image.clone(), gt_mask.clone(), points.clone() + + prev_output = torch.zeros_like(image, dtype=torch.float32)[:, :1, :, :] + + last_click_indx = None + + with torch.no_grad(): + num_iters = random.randint(0, self.max_num_next_clicks) + + for click_indx in range(num_iters): + last_click_indx = click_indx + + if not validation: + self.net.eval() + + if self.click_models is None or click_indx >= len(self.click_models): + eval_model = self.net + else: + eval_model = self.click_models[click_indx] + + net_input = torch.cat((image, prev_output), dim=1) if self.net.with_prev_mask else image + prev_output = torch.sigmoid(eval_model(net_input, points)['instances']) + + points = get_next_points(prev_output, orig_gt_mask, points, click_indx + 1) + + if not validation: + self.net.train() + + if self.net.with_prev_mask and self.prev_mask_drop_prob > 0 and last_click_indx is not None: + zero_mask = np.random.random(size=prev_output.size(0)) < self.prev_mask_drop_prob + prev_output[zero_mask] = torch.zeros_like(prev_output[zero_mask]) + + batch_data['points'] = points + + net_input 
= torch.cat((image, prev_output), dim=1) if self.net.with_prev_mask else image + output = self.net(net_input, points) + + loss = 0.0 + loss = self.add_loss('instance_loss', loss, losses_logging, validation, + lambda: (output['instances'], batch_data['instances'])) + loss = self.add_loss('instance_aux_loss', loss, losses_logging, validation, + lambda: (output['instances_aux'], batch_data['instances'])) + + if self.is_master: + with torch.no_grad(): + for m in metrics: + m.update(*(output.get(x) for x in m.pred_outputs), + *(batch_data[x] for x in m.gt_outputs)) + return loss, losses_logging, batch_data, output + + def add_loss(self, loss_name, total_loss, losses_logging, validation, lambda_loss_inputs): + loss_cfg = self.loss_cfg if not validation else self.val_loss_cfg + loss_weight = loss_cfg.get(loss_name + '_weight', 0.0) + if loss_weight > 0.0: + loss_criterion = loss_cfg.get(loss_name) + loss = loss_criterion(*lambda_loss_inputs()) + loss = torch.mean(loss) + losses_logging[loss_name] = loss + loss = loss_weight * loss + total_loss = total_loss + loss + + return total_loss + + def save_visualization(self, splitted_batch_data, outputs, global_step, prefix): + output_images_path = self.cfg.VIS_PATH / prefix + if self.task_prefix: + output_images_path /= self.task_prefix + + if not output_images_path.exists(): + output_images_path.mkdir(parents=True) + image_name_prefix = f'{global_step:06d}' + + def _save_image(suffix, image): + cv2.imwrite(str(output_images_path / f'{image_name_prefix}_{suffix}.jpg'), + image, [cv2.IMWRITE_JPEG_QUALITY, 85]) + + images = splitted_batch_data['images'] + points = splitted_batch_data['points'] + instance_masks = splitted_batch_data['instances'] + + gt_instance_masks = instance_masks.cpu().numpy() + predicted_instance_masks = torch.sigmoid(outputs['instances']).detach().cpu().numpy() + points = points.detach().cpu().numpy() + + image_blob, points = images[0], points[0] + gt_mask = np.squeeze(gt_instance_masks[0], axis=0) + predicted_mask = np.squeeze(predicted_instance_masks[0], axis=0) + + image = image_blob.cpu().numpy() * 255 + image = image.transpose((1, 2, 0)) + + image_with_points = draw_points(image, points[:self.max_interactive_points], (0, 255, 0)) + image_with_points = draw_points(image_with_points, points[self.max_interactive_points:], (0, 0, 255)) + + gt_mask[gt_mask < 0] = 0.25 + gt_mask = draw_probmap(gt_mask) + predicted_mask = draw_probmap(predicted_mask) + viz_image = np.hstack((image_with_points, gt_mask, predicted_mask)).astype(np.uint8) + + _save_image('instance_segmentation', viz_image[:, :, ::-1]) + + def _load_weights(self, net): + if self.cfg.weights is not None: + if os.path.isfile(self.cfg.weights): + load_weights(net, self.cfg.weights) + self.cfg.weights = None + else: + raise RuntimeError(f"=> no checkpoint found at '{self.cfg.weights}'") + elif self.cfg.resume_exp is not None: + checkpoints = list(self.cfg.CHECKPOINTS_PATH.glob(f'{self.cfg.resume_prefix}*.pth')) + assert len(checkpoints) == 1 + + checkpoint_path = checkpoints[0] + logger.info(f'Load checkpoint from path: {checkpoint_path}') + load_weights(net, str(checkpoint_path)) + return net + + @property + def is_master(self): + return self.cfg.local_rank == 0 + + +def get_next_points(pred, gt, points, click_indx, pred_thresh=0.49): + assert click_indx > 0 + pred = pred.cpu().numpy()[:, 0, :, :] + gt = gt.cpu().numpy()[:, 0, :, :] > 0.5 + + fn_mask = np.logical_and(gt, pred < pred_thresh) + fp_mask = np.logical_and(np.logical_not(gt), pred > pred_thresh) + + fn_mask = 
np.pad(fn_mask, ((0, 0), (1, 1), (1, 1)), 'constant').astype(np.uint8) + fp_mask = np.pad(fp_mask, ((0, 0), (1, 1), (1, 1)), 'constant').astype(np.uint8) + num_points = points.size(1) // 2 + points = points.clone() + + for bindx in range(fn_mask.shape[0]): + fn_mask_dt = cv2.distanceTransform(fn_mask[bindx], cv2.DIST_L2, 5)[1:-1, 1:-1] + fp_mask_dt = cv2.distanceTransform(fp_mask[bindx], cv2.DIST_L2, 5)[1:-1, 1:-1] + + fn_max_dist = np.max(fn_mask_dt) + fp_max_dist = np.max(fp_mask_dt) + + is_positive = fn_max_dist > fp_max_dist + dt = fn_mask_dt if is_positive else fp_mask_dt + inner_mask = dt > max(fn_max_dist, fp_max_dist) / 2.0 + indices = np.argwhere(inner_mask) + if len(indices) > 0: + coords = indices[np.random.randint(0, len(indices))] + if is_positive: + points[bindx, num_points - click_indx, 0] = float(coords[0]) + points[bindx, num_points - click_indx, 1] = float(coords[1]) + points[bindx, num_points - click_indx, 2] = float(click_indx) + else: + points[bindx, 2 * num_points - click_indx, 0] = float(coords[0]) + points[bindx, 2 * num_points - click_indx, 1] = float(coords[1]) + points[bindx, 2 * num_points - click_indx, 2] = float(click_indx) + + return points + + +def load_weights(model, path_to_weights): + current_state_dict = model.state_dict() + new_state_dict = torch.load(path_to_weights, map_location='cpu')['state_dict'] + current_state_dict.update(new_state_dict) + model.load_state_dict(current_state_dict) diff --git a/isegm/inference/clicker.py b/isegm/inference/clicker.py new file mode 100644 index 0000000..8789e11 --- /dev/null +++ b/isegm/inference/clicker.py @@ -0,0 +1,118 @@ +import numpy as np +from copy import deepcopy +import cv2 + + +class Clicker(object): + def __init__(self, gt_mask=None, init_clicks=None, ignore_label=-1, click_indx_offset=0): + self.click_indx_offset = click_indx_offset + if gt_mask is not None: + self.gt_mask = gt_mask == 1 + self.not_ignore_mask = gt_mask != ignore_label + else: + self.gt_mask = None + + self.reset_clicks() + + if init_clicks is not None: + for click in init_clicks: + self.add_click(click) + + def make_next_click(self, pred_mask): + assert self.gt_mask is not None + click = self._get_next_click(pred_mask) + self.add_click(click) + + def get_clicks(self, clicks_limit=None): + return self.clicks_list[:clicks_limit] + + def _get_next_click(self, pred_mask, padding=True): + fn_mask = np.logical_and(np.logical_and(self.gt_mask, np.logical_not(pred_mask)), self.not_ignore_mask) + fp_mask = np.logical_and(np.logical_and(np.logical_not(self.gt_mask), pred_mask), self.not_ignore_mask) + + if padding: + fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), 'constant') + fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), 'constant') + + fn_mask_dt = cv2.distanceTransform(fn_mask.astype(np.uint8), cv2.DIST_L2, 0) + fp_mask_dt = cv2.distanceTransform(fp_mask.astype(np.uint8), cv2.DIST_L2, 0) + + if padding: + fn_mask_dt = fn_mask_dt[1:-1, 1:-1] + fp_mask_dt = fp_mask_dt[1:-1, 1:-1] + + fn_mask_dt = fn_mask_dt * self.not_clicked_map + fp_mask_dt = fp_mask_dt * self.not_clicked_map + + fn_max_dist = np.max(fn_mask_dt) + fp_max_dist = np.max(fp_mask_dt) + + is_positive = fn_max_dist > fp_max_dist + if is_positive: + coords_y, coords_x = np.where(fn_mask_dt == fn_max_dist) # coords is [y, x] + else: + coords_y, coords_x = np.where(fp_mask_dt == fp_max_dist) # coords is [y, x] + + return Click(is_positive=is_positive, coords=(coords_y[0], coords_x[0])) + + def add_click(self, click): + coords = click.coords + + click.indx = self.click_indx_offset + 
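Note: the training-time `get_next_points` above and the evaluation-time `Clicker._get_next_click` simulate the next user click the same way: build false-negative and false-positive masks, run a distance transform inside each, and click deep inside the larger error region. Below is a minimal self-contained sketch of that rule (the toy masks are made up; unlike `get_next_points`, which samples a random point from the inner half of the error region, this sketch simply takes the maximum-distance pixel, as the `Clicker` does):

```python
import cv2
import numpy as np

def simulate_next_click(pred_mask, gt_mask):
    """Pick the next simulated click as (is_positive, (y, x)).

    The click lands at the pixel deepest inside the largest error region,
    mirroring the distance-transform logic of get_next_points / Clicker.
    """
    fn_mask = np.logical_and(gt_mask, np.logical_not(pred_mask))   # missed object pixels
    fp_mask = np.logical_and(np.logical_not(gt_mask), pred_mask)   # false-positive background

    # Pad by one pixel so the image border counts as a region boundary.
    fn_mask = np.pad(fn_mask, 1, 'constant').astype(np.uint8)
    fp_mask = np.pad(fp_mask, 1, 'constant').astype(np.uint8)

    fn_dt = cv2.distanceTransform(fn_mask, cv2.DIST_L2, 5)[1:-1, 1:-1]
    fp_dt = cv2.distanceTransform(fp_mask, cv2.DIST_L2, 5)[1:-1, 1:-1]

    is_positive = fn_dt.max() > fp_dt.max()
    dt = fn_dt if is_positive else fp_dt
    y, x = np.unravel_index(dt.argmax(), dt.shape)
    return is_positive, (int(y), int(x))

# Toy example: the prediction misses the right half of a square object,
# so the next click is positive and falls inside the missed half.
gt = np.zeros((64, 64), bool);   gt[20:40, 20:40] = True
pred = np.zeros((64, 64), bool); pred[20:40, 20:30] = True
print(simulate_next_click(pred, gt))
```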
self.num_pos_clicks + self.num_neg_clicks + if click.is_positive: + self.num_pos_clicks += 1 + else: + self.num_neg_clicks += 1 + + self.clicks_list.append(click) + if self.gt_mask is not None: + self.not_clicked_map[coords[0], coords[1]] = False + + def _remove_last_click(self): + click = self.clicks_list.pop() + coords = click.coords + + if click.is_positive: + self.num_pos_clicks -= 1 + else: + self.num_neg_clicks -= 1 + + if self.gt_mask is not None: + self.not_clicked_map[coords[0], coords[1]] = True + + def reset_clicks(self): + if self.gt_mask is not None: + self.not_clicked_map = np.ones_like(self.gt_mask, dtype=np.bool) + + self.num_pos_clicks = 0 + self.num_neg_clicks = 0 + + self.clicks_list = [] + + def get_state(self): + return deepcopy(self.clicks_list) + + def set_state(self, state): + self.reset_clicks() + for click in state: + self.add_click(click) + + def __len__(self): + return len(self.clicks_list) + + +class Click: + def __init__(self, is_positive, coords, indx=None): + self.is_positive = is_positive + self.coords = coords + self.indx = indx + + @property + def coords_and_indx(self): + return (*self.coords, self.indx) + + def copy(self, **kwargs): + self_copy = deepcopy(self) + for k, v in kwargs.items(): + setattr(self_copy, k, v) + return self_copy diff --git a/isegm/inference/evaluation.py b/isegm/inference/evaluation.py new file mode 100644 index 0000000..ef46e40 --- /dev/null +++ b/isegm/inference/evaluation.py @@ -0,0 +1,56 @@ +from time import time + +import numpy as np +import torch + +from isegm.inference import utils +from isegm.inference.clicker import Clicker + +try: + get_ipython() + from tqdm import tqdm_notebook as tqdm +except NameError: + from tqdm import tqdm + + +def evaluate_dataset(dataset, predictor, **kwargs): + all_ious = [] + + start_time = time() + for index in tqdm(range(len(dataset)), leave=False): + sample = dataset.get_sample(index) + + _, sample_ious, _ = evaluate_sample(sample.image, sample.gt_mask, predictor, + sample_id=index, **kwargs) + all_ious.append(sample_ious) + end_time = time() + elapsed_time = end_time - start_time + + return all_ious, elapsed_time + + +def evaluate_sample(image, gt_mask, predictor, max_iou_thr, + pred_thr=0.49, min_clicks=1, max_clicks=20, + sample_id=None, callback=None): + clicker = Clicker(gt_mask=gt_mask) + pred_mask = np.zeros_like(gt_mask) + ious_list = [] + + with torch.no_grad(): + predictor.set_input_image(image) + + for click_indx in range(max_clicks): + clicker.make_next_click(pred_mask) + pred_probs = predictor.get_prediction(clicker) + pred_mask = pred_probs > pred_thr + + if callback is not None: + callback(image, gt_mask, pred_probs, sample_id, click_indx, clicker.clicks_list) + + iou = utils.get_iou(gt_mask, pred_mask) + ious_list.append(iou) + + if iou >= max_iou_thr and click_indx + 1 >= min_clicks: + break + + return clicker.clicks_list, np.array(ious_list, dtype=np.float32), pred_probs diff --git a/isegm/inference/predictors/__init__.py b/isegm/inference/predictors/__init__.py new file mode 100644 index 0000000..1e5a4f7 --- /dev/null +++ b/isegm/inference/predictors/__init__.py @@ -0,0 +1,98 @@ +from .base import BasePredictor +from .brs import InputBRSPredictor, FeatureBRSPredictor, HRNetFeatureBRSPredictor +from .brs_functors import InputOptimizer, ScaleBiasOptimizer +from isegm.inference.transforms import ZoomIn +from isegm.model.is_hrnet_model import HRNetModel + + +def get_predictor(net, brs_mode, device, + prob_thresh=0.49, + with_flip=True, + zoom_in_params=dict(), + 
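Note: the evaluation loop in `evaluate_sample` can be exercised end-to-end with a stub predictor. The sketch below is illustration only: the `StubPredictor` class is made up, and real runs instead wrap a trained network with `get_predictor` below.

```python
import numpy as np
from isegm.inference.evaluation import evaluate_sample

class StubPredictor:
    """Illustration-only predictor that always 'predicts' the ground truth."""
    def __init__(self, gt_mask):
        self._probs = gt_mask.astype(np.float32)
    def set_input_image(self, image):
        pass
    def get_prediction(self, clicker):
        return self._probs   # per-pixel probabilities in [0, 1]

image = np.zeros((64, 64, 3), dtype=np.uint8)
gt_mask = np.zeros((64, 64), dtype=np.int32)
gt_mask[16:48, 16:48] = 1

clicks, ious, _ = evaluate_sample(image, gt_mask, StubPredictor(gt_mask),
                                  max_iou_thr=0.90, max_clicks=20)
print(len(clicks), ious)   # one click and IoU 1.0 for this perfect stub
```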
predictor_params=None, + brs_opt_func_params=None, + lbfgs_params=None): + lbfgs_params_ = { + 'm': 20, + 'factr': 0, + 'pgtol': 1e-8, + 'maxfun': 20, + } + + predictor_params_ = { + 'optimize_after_n_clicks': 1 + } + + if zoom_in_params is not None: + zoom_in = ZoomIn(**zoom_in_params) + else: + zoom_in = None + + if lbfgs_params is not None: + lbfgs_params_.update(lbfgs_params) + lbfgs_params_['maxiter'] = 2 * lbfgs_params_['maxfun'] + + if brs_opt_func_params is None: + brs_opt_func_params = dict() + + if isinstance(net, (list, tuple)): + assert brs_mode == 'NoBRS', "Multi-stage models support only NoBRS mode." + + if brs_mode == 'NoBRS': + if predictor_params is not None: + predictor_params_.update(predictor_params) + predictor = BasePredictor(net, device, zoom_in=zoom_in, with_flip=with_flip, **predictor_params_) + elif brs_mode.startswith('f-BRS'): + predictor_params_.update({ + 'net_clicks_limit': 8, + }) + if predictor_params is not None: + predictor_params_.update(predictor_params) + + insertion_mode = { + 'f-BRS-A': 'after_c4', + 'f-BRS-B': 'after_aspp', + 'f-BRS-C': 'after_deeplab' + }[brs_mode] + + opt_functor = ScaleBiasOptimizer(prob_thresh=prob_thresh, + with_flip=with_flip, + optimizer_params=lbfgs_params_, + **brs_opt_func_params) + + if isinstance(net, HRNetModel): + FeaturePredictor = HRNetFeatureBRSPredictor + insertion_mode = {'after_c4': 'A', 'after_aspp': 'A', 'after_deeplab': 'C'}[insertion_mode] + else: + FeaturePredictor = FeatureBRSPredictor + + predictor = FeaturePredictor(net, device, + opt_functor=opt_functor, + with_flip=with_flip, + insertion_mode=insertion_mode, + zoom_in=zoom_in, + **predictor_params_) + elif brs_mode == 'RGB-BRS' or brs_mode == 'DistMap-BRS': + use_dmaps = brs_mode == 'DistMap-BRS' + + predictor_params_.update({ + 'net_clicks_limit': 5, + }) + if predictor_params is not None: + predictor_params_.update(predictor_params) + + opt_functor = InputOptimizer(prob_thresh=prob_thresh, + with_flip=with_flip, + optimizer_params=lbfgs_params_, + **brs_opt_func_params) + + predictor = InputBRSPredictor(net, device, + optimize_target='dmaps' if use_dmaps else 'rgb', + opt_functor=opt_functor, + with_flip=with_flip, + zoom_in=zoom_in, + **predictor_params_) + else: + raise NotImplementedError + + return predictor diff --git a/isegm/inference/predictors/base.py b/isegm/inference/predictors/base.py new file mode 100644 index 0000000..8703117 --- /dev/null +++ b/isegm/inference/predictors/base.py @@ -0,0 +1,126 @@ +import torch +import torch.nn.functional as F +from torchvision import transforms +from isegm.inference.transforms import AddHorizontalFlip, SigmoidForPred, LimitLongestSide + + +class BasePredictor(object): + def __init__(self, model, device, + net_clicks_limit=None, + with_flip=False, + zoom_in=None, + max_size=None, + **kwargs): + self.with_flip = with_flip + self.net_clicks_limit = net_clicks_limit + self.original_image = None + self.device = device + self.zoom_in = zoom_in + self.prev_prediction = None + self.model_indx = 0 + self.click_models = None + self.net_state_dict = None + + if isinstance(model, tuple): + self.net, self.click_models = model + else: + self.net = model + + self.to_tensor = transforms.ToTensor() + + self.transforms = [zoom_in] if zoom_in is not None else [] + if max_size is not None: + self.transforms.append(LimitLongestSide(max_size=max_size)) + self.transforms.append(SigmoidForPred()) + if with_flip: + self.transforms.append(AddHorizontalFlip()) + + def set_input_image(self, image): + image_nd = 
self.to_tensor(image) + for transform in self.transforms: + transform.reset() + self.original_image = image_nd.to(self.device) + if len(self.original_image.shape) == 3: + self.original_image = self.original_image.unsqueeze(0) + self.prev_prediction = torch.zeros_like(self.original_image[:, :1, :, :]) + + def get_prediction(self, clicker, prev_mask=None): + clicks_list = clicker.get_clicks() + + if self.click_models is not None: + model_indx = min(clicker.click_indx_offset + len(clicks_list), len(self.click_models)) - 1 + if model_indx != self.model_indx: + self.model_indx = model_indx + self.net = self.click_models[model_indx] + + input_image = self.original_image + if prev_mask is None: + prev_mask = self.prev_prediction + if hasattr(self.net, 'with_prev_mask') and self.net.with_prev_mask: + input_image = torch.cat((input_image, prev_mask), dim=1) + image_nd, clicks_lists, is_image_changed = self.apply_transforms( + input_image, [clicks_list] + ) + + pred_logits = self._get_prediction(image_nd, clicks_lists, is_image_changed) + prediction = F.interpolate(pred_logits, mode='bilinear', align_corners=True, + size=image_nd.size()[2:]) + + for t in reversed(self.transforms): + prediction = t.inv_transform(prediction) + + if self.zoom_in is not None and self.zoom_in.check_possible_recalculation(): + return self.get_prediction(clicker) + + self.prev_prediction = prediction + return prediction.cpu().numpy()[0, 0] + + def _get_prediction(self, image_nd, clicks_lists, is_image_changed): + points_nd = self.get_points_nd(clicks_lists) + return self.net(image_nd, points_nd)['instances'] + + def _get_transform_states(self): + return [x.get_state() for x in self.transforms] + + def _set_transform_states(self, states): + assert len(states) == len(self.transforms) + for state, transform in zip(states, self.transforms): + transform.set_state(state) + + def apply_transforms(self, image_nd, clicks_lists): + is_image_changed = False + for t in self.transforms: + image_nd, clicks_lists = t.transform(image_nd, clicks_lists) + is_image_changed |= t.image_changed + + return image_nd, clicks_lists, is_image_changed + + def get_points_nd(self, clicks_lists): + total_clicks = [] + num_pos_clicks = [sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists] + num_neg_clicks = [len(clicks_list) - num_pos for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)] + num_max_points = max(num_pos_clicks + num_neg_clicks) + if self.net_clicks_limit is not None: + num_max_points = min(self.net_clicks_limit, num_max_points) + num_max_points = max(1, num_max_points) + + for clicks_list in clicks_lists: + clicks_list = clicks_list[:self.net_clicks_limit] + pos_clicks = [click.coords_and_indx for click in clicks_list if click.is_positive] + pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [(-1, -1, -1)] + + neg_clicks = [click.coords_and_indx for click in clicks_list if not click.is_positive] + neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [(-1, -1, -1)] + total_clicks.append(pos_clicks + neg_clicks) + + return torch.tensor(total_clicks, device=self.device) + + def get_states(self): + return { + 'transform_states': self._get_transform_states(), + 'prev_prediction': self.prev_prediction.clone() + } + + def set_states(self, states): + self._set_transform_states(states['transform_states']) + self.prev_prediction = states['prev_prediction'] diff --git a/isegm/inference/predictors/brs.py b/isegm/inference/predictors/brs.py new file mode 100644 index 0000000..910e3fd --- /dev/null +++ 
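Note: `BasePredictor.get_points_nd` packs clicks into the fixed layout the models consume: for each sample, positive clicks come first and negative clicks second, each half padded with `(-1, -1, -1)` to a common length, and every click is a `(y, x, click_index)` triple. A small illustrative helper (not part of the repo) that builds the same layout by hand:

```python
import torch

def encode_clicks(pos_clicks, neg_clicks, num_max_points=None):
    """Pack clicks into the (batch=1, 2*N, 3) tensor layout used by the models."""
    if num_max_points is None:
        num_max_points = max(1, len(pos_clicks), len(neg_clicks))
    pad = [(-1, -1, -1)]
    pos = list(pos_clicks) + pad * (num_max_points - len(pos_clicks))
    neg = list(neg_clicks) + pad * (num_max_points - len(neg_clicks))
    return torch.tensor([pos + neg])

points = encode_clicks(pos_clicks=[(120, 200, 0), (90, 210, 2)],   # (y, x, click_index)
                       neg_clicks=[(40, 55, 1)])
print(points.shape)   # torch.Size([1, 4, 3]) -- two positive slots, then two negative slots
```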
b/isegm/inference/predictors/brs.py @@ -0,0 +1,307 @@ +import torch +import torch.nn.functional as F +import numpy as np +from scipy.optimize import fmin_l_bfgs_b + +from .base import BasePredictor + + +class BRSBasePredictor(BasePredictor): + def __init__(self, model, device, opt_functor, optimize_after_n_clicks=1, **kwargs): + super().__init__(model, device, **kwargs) + self.optimize_after_n_clicks = optimize_after_n_clicks + self.opt_functor = opt_functor + + self.opt_data = None + self.input_data = None + + def set_input_image(self, image): + super().set_input_image(image) + self.opt_data = None + self.input_data = None + + def _get_clicks_maps_nd(self, clicks_lists, image_shape, radius=1): + pos_clicks_map = np.zeros((len(clicks_lists), 1) + image_shape, dtype=np.float32) + neg_clicks_map = np.zeros((len(clicks_lists), 1) + image_shape, dtype=np.float32) + + for list_indx, clicks_list in enumerate(clicks_lists): + for click in clicks_list: + y, x = click.coords + y, x = int(round(y)), int(round(x)) + y1, x1 = y - radius, x - radius + y2, x2 = y + radius + 1, x + radius + 1 + + if click.is_positive: + pos_clicks_map[list_indx, 0, y1:y2, x1:x2] = True + else: + neg_clicks_map[list_indx, 0, y1:y2, x1:x2] = True + + with torch.no_grad(): + pos_clicks_map = torch.from_numpy(pos_clicks_map).to(self.device) + neg_clicks_map = torch.from_numpy(neg_clicks_map).to(self.device) + + return pos_clicks_map, neg_clicks_map + + def get_states(self): + return {'transform_states': self._get_transform_states(), 'opt_data': self.opt_data} + + def set_states(self, states): + self._set_transform_states(states['transform_states']) + self.opt_data = states['opt_data'] + + +class FeatureBRSPredictor(BRSBasePredictor): + def __init__(self, model, device, opt_functor, insertion_mode='after_deeplab', **kwargs): + super().__init__(model, device, opt_functor=opt_functor, **kwargs) + self.insertion_mode = insertion_mode + self._c1_features = None + + if self.insertion_mode == 'after_deeplab': + self.num_channels = model.feature_extractor.ch + elif self.insertion_mode == 'after_c4': + self.num_channels = model.feature_extractor.aspp_in_channels + elif self.insertion_mode == 'after_aspp': + self.num_channels = model.feature_extractor.ch + 32 + else: + raise NotImplementedError + + def _get_prediction(self, image_nd, clicks_lists, is_image_changed): + points_nd = self.get_points_nd(clicks_lists) + pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:]) + + num_clicks = len(clicks_lists[0]) + bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0] + + if self.opt_data is None or self.opt_data.shape[0] // (2 * self.num_channels) != bs: + self.opt_data = np.zeros((bs * 2 * self.num_channels), dtype=np.float32) + + if num_clicks <= self.net_clicks_limit or is_image_changed or self.input_data is None: + self.input_data = self._get_head_input(image_nd, points_nd) + + def get_prediction_logits(scale, bias): + scale = scale.view(bs, -1, 1, 1) + bias = bias.view(bs, -1, 1, 1) + if self.with_flip: + scale = scale.repeat(2, 1, 1, 1) + bias = bias.repeat(2, 1, 1, 1) + + scaled_backbone_features = self.input_data * scale + scaled_backbone_features = scaled_backbone_features + bias + if self.insertion_mode == 'after_c4': + x = self.net.feature_extractor.aspp(scaled_backbone_features) + x = F.interpolate(x, mode='bilinear', size=self._c1_features.size()[2:], + align_corners=True) + x = torch.cat((x, self._c1_features), dim=1) + scaled_backbone_features = self.net.feature_extractor.head(x) + elif 
self.insertion_mode == 'after_aspp': + scaled_backbone_features = self.net.feature_extractor.head(scaled_backbone_features) + + pred_logits = self.net.head(scaled_backbone_features) + pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear', + align_corners=True) + return pred_logits + + self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device) + if num_clicks > self.optimize_after_n_clicks: + opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data, + **self.opt_functor.optimizer_params) + self.opt_data = opt_result[0] + + with torch.no_grad(): + if self.opt_functor.best_prediction is not None: + opt_pred_logits = self.opt_functor.best_prediction + else: + opt_data_nd = torch.from_numpy(self.opt_data).to(self.device) + opt_vars, _ = self.opt_functor.unpack_opt_params(opt_data_nd) + opt_pred_logits = get_prediction_logits(*opt_vars) + + return opt_pred_logits + + def _get_head_input(self, image_nd, points): + with torch.no_grad(): + image_nd, prev_mask = self.net.prepare_input(image_nd) + coord_features = self.net.get_coord_features(image_nd, prev_mask, points) + + if self.net.rgb_conv is not None: + x = self.net.rgb_conv(torch.cat((image_nd, coord_features), dim=1)) + additional_features = None + elif hasattr(self.net, 'maps_transform'): + x = image_nd + additional_features = self.net.maps_transform(coord_features) + + if self.insertion_mode == 'after_c4' or self.insertion_mode == 'after_aspp': + c1, _, c3, c4 = self.net.feature_extractor.backbone(x, additional_features) + c1 = self.net.feature_extractor.skip_project(c1) + + if self.insertion_mode == 'after_aspp': + x = self.net.feature_extractor.aspp(c4) + x = F.interpolate(x, size=c1.size()[2:], mode='bilinear', align_corners=True) + x = torch.cat((x, c1), dim=1) + backbone_features = x + else: + backbone_features = c4 + self._c1_features = c1 + else: + backbone_features = self.net.feature_extractor(x, additional_features)[0] + + return backbone_features + + +class HRNetFeatureBRSPredictor(BRSBasePredictor): + def __init__(self, model, device, opt_functor, insertion_mode='A', **kwargs): + super().__init__(model, device, opt_functor=opt_functor, **kwargs) + self.insertion_mode = insertion_mode + self._c1_features = None + + if self.insertion_mode == 'A': + self.num_channels = sum(k * model.feature_extractor.width for k in [1, 2, 4, 8]) + elif self.insertion_mode == 'C': + self.num_channels = 2 * model.feature_extractor.ocr_width + else: + raise NotImplementedError + + def _get_prediction(self, image_nd, clicks_lists, is_image_changed): + points_nd = self.get_points_nd(clicks_lists) + pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:]) + num_clicks = len(clicks_lists[0]) + bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0] + + if self.opt_data is None or self.opt_data.shape[0] // (2 * self.num_channels) != bs: + self.opt_data = np.zeros((bs * 2 * self.num_channels), dtype=np.float32) + + if num_clicks <= self.net_clicks_limit or is_image_changed or self.input_data is None: + self.input_data = self._get_head_input(image_nd, points_nd) + + def get_prediction_logits(scale, bias): + scale = scale.view(bs, -1, 1, 1) + bias = bias.view(bs, -1, 1, 1) + if self.with_flip: + scale = scale.repeat(2, 1, 1, 1) + bias = bias.repeat(2, 1, 1, 1) + + scaled_backbone_features = self.input_data * scale + scaled_backbone_features = scaled_backbone_features + bias + if self.insertion_mode == 'A': + if self.net.feature_extractor.ocr_width > 0: + out_aux 
= self.net.feature_extractor.aux_head(scaled_backbone_features) + feats = self.net.feature_extractor.conv3x3_ocr(scaled_backbone_features) + + context = self.net.feature_extractor.ocr_gather_head(feats, out_aux) + feats = self.net.feature_extractor.ocr_distri_head(feats, context) + else: + feats = scaled_backbone_features + pred_logits = self.net.feature_extractor.cls_head(feats) + elif self.insertion_mode == 'C': + pred_logits = self.net.feature_extractor.cls_head(scaled_backbone_features) + else: + raise NotImplementedError + + pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear', + align_corners=True) + return pred_logits + + self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device) + if num_clicks > self.optimize_after_n_clicks: + opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data, + **self.opt_functor.optimizer_params) + self.opt_data = opt_result[0] + + with torch.no_grad(): + if self.opt_functor.best_prediction is not None: + opt_pred_logits = self.opt_functor.best_prediction + else: + opt_data_nd = torch.from_numpy(self.opt_data).to(self.device) + opt_vars, _ = self.opt_functor.unpack_opt_params(opt_data_nd) + opt_pred_logits = get_prediction_logits(*opt_vars) + + return opt_pred_logits + + def _get_head_input(self, image_nd, points): + with torch.no_grad(): + image_nd, prev_mask = self.net.prepare_input(image_nd) + coord_features = self.net.get_coord_features(image_nd, prev_mask, points) + + if self.net.rgb_conv is not None: + x = self.net.rgb_conv(torch.cat((image_nd, coord_features), dim=1)) + additional_features = None + elif hasattr(self.net, 'maps_transform'): + x = image_nd + additional_features = self.net.maps_transform(coord_features) + + feats = self.net.feature_extractor.compute_hrnet_feats(x, additional_features) + + if self.insertion_mode == 'A': + backbone_features = feats + elif self.insertion_mode == 'C': + out_aux = self.net.feature_extractor.aux_head(feats) + feats = self.net.feature_extractor.conv3x3_ocr(feats) + + context = self.net.feature_extractor.ocr_gather_head(feats, out_aux) + backbone_features = self.net.feature_extractor.ocr_distri_head(feats, context) + else: + raise NotImplementedError + + return backbone_features + + +class InputBRSPredictor(BRSBasePredictor): + def __init__(self, model, device, opt_functor, optimize_target='rgb', **kwargs): + super().__init__(model, device, opt_functor=opt_functor, **kwargs) + self.optimize_target = optimize_target + + def _get_prediction(self, image_nd, clicks_lists, is_image_changed): + points_nd = self.get_points_nd(clicks_lists) + pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:]) + num_clicks = len(clicks_lists[0]) + + if self.opt_data is None or is_image_changed: + if self.optimize_target == 'dmaps': + opt_channels = self.net.coord_feature_ch - 1 if self.net.with_prev_mask else self.net.coord_feature_ch + else: + opt_channels = 3 + bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0] + self.opt_data = torch.zeros((bs, opt_channels, image_nd.shape[2], image_nd.shape[3]), + device=self.device, dtype=torch.float32) + + def get_prediction_logits(opt_bias): + input_image, prev_mask = self.net.prepare_input(image_nd) + dmaps = self.net.get_coord_features(input_image, prev_mask, points_nd) + + if self.optimize_target == 'rgb': + input_image = input_image + opt_bias + elif self.optimize_target == 'dmaps': + if self.net.with_prev_mask: + dmaps[:, 1:, :, :] = dmaps[:, 1:, :, :] + opt_bias + else: + dmaps 
= dmaps + opt_bias + + if self.net.rgb_conv is not None: + x = self.net.rgb_conv(torch.cat((input_image, dmaps), dim=1)) + if self.optimize_target == 'all': + x = x + opt_bias + coord_features = None + elif hasattr(self.net, 'maps_transform'): + x = input_image + coord_features = self.net.maps_transform(dmaps) + + pred_logits = self.net.backbone_forward(x, coord_features=coord_features)['instances'] + pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear', align_corners=True) + + return pred_logits + + self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device, + shape=self.opt_data.shape) + if num_clicks > self.optimize_after_n_clicks: + opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data.cpu().numpy().ravel(), + **self.opt_functor.optimizer_params) + + self.opt_data = torch.from_numpy(opt_result[0]).view(self.opt_data.shape).to(self.device) + + with torch.no_grad(): + if self.opt_functor.best_prediction is not None: + opt_pred_logits = self.opt_functor.best_prediction + else: + opt_vars, _ = self.opt_functor.unpack_opt_params(self.opt_data) + opt_pred_logits = get_prediction_logits(*opt_vars) + + return opt_pred_logits diff --git a/isegm/inference/predictors/brs_functors.py b/isegm/inference/predictors/brs_functors.py new file mode 100644 index 0000000..f919e13 --- /dev/null +++ b/isegm/inference/predictors/brs_functors.py @@ -0,0 +1,109 @@ +import torch +import numpy as np + +from isegm.model.metrics import _compute_iou +from .brs_losses import BRSMaskLoss + + +class BaseOptimizer: + def __init__(self, optimizer_params, + prob_thresh=0.49, + reg_weight=1e-3, + min_iou_diff=0.01, + brs_loss=BRSMaskLoss(), + with_flip=False, + flip_average=False, + **kwargs): + self.brs_loss = brs_loss + self.optimizer_params = optimizer_params + self.prob_thresh = prob_thresh + self.reg_weight = reg_weight + self.min_iou_diff = min_iou_diff + self.with_flip = with_flip + self.flip_average = flip_average + + self.best_prediction = None + self._get_prediction_logits = None + self._opt_shape = None + self._best_loss = None + self._click_masks = None + self._last_mask = None + self.device = None + + def init_click(self, get_prediction_logits, pos_mask, neg_mask, device, shape=None): + self.best_prediction = None + self._get_prediction_logits = get_prediction_logits + self._click_masks = (pos_mask, neg_mask) + self._opt_shape = shape + self._last_mask = None + self.device = device + + def __call__(self, x): + opt_params = torch.from_numpy(x).float().to(self.device) + opt_params.requires_grad_(True) + + with torch.enable_grad(): + opt_vars, reg_loss = self.unpack_opt_params(opt_params) + result_before_sigmoid = self._get_prediction_logits(*opt_vars) + result = torch.sigmoid(result_before_sigmoid) + + pos_mask, neg_mask = self._click_masks + if self.with_flip and self.flip_average: + result, result_flipped = torch.chunk(result, 2, dim=0) + result = 0.5 * (result + torch.flip(result_flipped, dims=[3])) + pos_mask, neg_mask = pos_mask[:result.shape[0]], neg_mask[:result.shape[0]] + + loss, f_max_pos, f_max_neg = self.brs_loss(result, pos_mask, neg_mask) + loss = loss + reg_loss + + f_val = loss.detach().cpu().numpy() + if self.best_prediction is None or f_val < self._best_loss: + self.best_prediction = result_before_sigmoid.detach() + self._best_loss = f_val + + if f_max_pos < (1 - self.prob_thresh) and f_max_neg < self.prob_thresh: + return [f_val, np.zeros_like(x)] + + current_mask = result > self.prob_thresh + if self._last_mask is not None and 
self.min_iou_diff > 0: + diff_iou = _compute_iou(current_mask, self._last_mask) + if len(diff_iou) > 0 and diff_iou.mean() > 1 - self.min_iou_diff: + return [f_val, np.zeros_like(x)] + self._last_mask = current_mask + + loss.backward() + f_grad = opt_params.grad.cpu().numpy().ravel().astype(np.float) + + return [f_val, f_grad] + + def unpack_opt_params(self, opt_params): + raise NotImplementedError + + +class InputOptimizer(BaseOptimizer): + def unpack_opt_params(self, opt_params): + opt_params = opt_params.view(self._opt_shape) + if self.with_flip: + opt_params_flipped = torch.flip(opt_params, dims=[3]) + opt_params = torch.cat([opt_params, opt_params_flipped], dim=0) + reg_loss = self.reg_weight * torch.sum(opt_params**2) + + return (opt_params,), reg_loss + + +class ScaleBiasOptimizer(BaseOptimizer): + def __init__(self, *args, scale_act=None, reg_bias_weight=10.0, **kwargs): + super().__init__(*args, **kwargs) + self.scale_act = scale_act + self.reg_bias_weight = reg_bias_weight + + def unpack_opt_params(self, opt_params): + scale, bias = torch.chunk(opt_params, 2, dim=0) + reg_loss = self.reg_weight * (torch.sum(scale**2) + self.reg_bias_weight * torch.sum(bias**2)) + + if self.scale_act == 'tanh': + scale = torch.tanh(scale) + elif self.scale_act == 'sin': + scale = torch.sin(scale) + + return (1 + scale, bias), reg_loss diff --git a/isegm/inference/predictors/brs_losses.py b/isegm/inference/predictors/brs_losses.py new file mode 100644 index 0000000..ea98824 --- /dev/null +++ b/isegm/inference/predictors/brs_losses.py @@ -0,0 +1,58 @@ +import torch + +from isegm.model.losses import SigmoidBinaryCrossEntropyLoss + + +class BRSMaskLoss(torch.nn.Module): + def __init__(self, eps=1e-5): + super().__init__() + self._eps = eps + + def forward(self, result, pos_mask, neg_mask): + pos_diff = (1 - result) * pos_mask + pos_target = torch.sum(pos_diff ** 2) + pos_target = pos_target / (torch.sum(pos_mask) + self._eps) + + neg_diff = result * neg_mask + neg_target = torch.sum(neg_diff ** 2) + neg_target = neg_target / (torch.sum(neg_mask) + self._eps) + + loss = pos_target + neg_target + + with torch.no_grad(): + f_max_pos = torch.max(torch.abs(pos_diff)).item() + f_max_neg = torch.max(torch.abs(neg_diff)).item() + + return loss, f_max_pos, f_max_neg + + +class OracleMaskLoss(torch.nn.Module): + def __init__(self): + super().__init__() + self.gt_mask = None + self.loss = SigmoidBinaryCrossEntropyLoss(from_sigmoid=True) + self.predictor = None + self.history = [] + + def set_gt_mask(self, gt_mask): + self.gt_mask = gt_mask + self.history = [] + + def forward(self, result, pos_mask, neg_mask): + gt_mask = self.gt_mask.to(result.device) + if self.predictor.object_roi is not None: + r1, r2, c1, c2 = self.predictor.object_roi[:4] + gt_mask = gt_mask[:, :, r1:r2 + 1, c1:c2 + 1] + gt_mask = torch.nn.functional.interpolate(gt_mask, result.size()[2:], mode='bilinear', align_corners=True) + + if result.shape[0] == 2: + gt_mask_flipped = torch.flip(gt_mask, dims=[3]) + gt_mask = torch.cat([gt_mask, gt_mask_flipped], dim=0) + + loss = self.loss(result, gt_mask) + self.history.append(loss.detach().cpu().numpy()[0]) + + if len(self.history) > 5 and abs(self.history[-5] - self.history[-1]) < 1e-5: + return 0, 0, 0 + + return loss, 1.0, 1.0 diff --git a/isegm/inference/transforms/__init__.py b/isegm/inference/transforms/__init__.py new file mode 100644 index 0000000..cbd54e3 --- /dev/null +++ b/isegm/inference/transforms/__init__.py @@ -0,0 +1,5 @@ +from .base import SigmoidForPred +from .flip import 
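Note: all BRS variants share one mechanism: a scalar objective (`BRSMaskLoss` plus an L2 regularizer over the optimized variables) is minimized with SciPy's `fmin_l_bfgs_b`, with gradients obtained by backpropagating through the frozen network. Below is a toy sketch of that loop, optimizing a per-pixel bias added to fixed logits; the logits and click masks are made up, while the real predictors optimize either the network input (RGB-BRS / DistMap-BRS) or intermediate feature scales and biases (f-BRS) through the same kind of functor.

```python
import numpy as np
import torch
from scipy.optimize import fmin_l_bfgs_b

logits = torch.zeros(1, 1, 8, 8)                               # stand-in for frozen network logits
pos_mask = torch.zeros_like(logits); pos_mask[..., 2, 2] = 1   # one positive click
neg_mask = torch.zeros_like(logits); neg_mask[..., 6, 6] = 1   # one negative click
reg_weight = 1e-3

def functor(x):
    bias = torch.from_numpy(x).float().view_as(logits).requires_grad_(True)
    probs = torch.sigmoid(logits + bias)
    # Same structure as BRSMaskLoss: push probabilities up at positive clicks, down at negative ones.
    pos_loss = ((1 - probs) * pos_mask).pow(2).sum() / (pos_mask.sum() + 1e-5)
    neg_loss = (probs * neg_mask).pow(2).sum() / (neg_mask.sum() + 1e-5)
    loss = pos_loss + neg_loss + reg_weight * (bias ** 2).sum()
    loss.backward()
    return loss.item(), bias.grad.numpy().ravel().astype(np.float64)

x0 = np.zeros(logits.numel(), dtype=np.float64)
x_opt, f_opt, _ = fmin_l_bfgs_b(functor, x0, m=20, factr=0, pgtol=1e-8, maxfun=20)
print(f_opt)   # final value of the corrective objective
```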
AddHorizontalFlip +from .zoom_in import ZoomIn +from .limit_longest_side import LimitLongestSide +from .crops import Crops diff --git a/isegm/inference/transforms/base.py b/isegm/inference/transforms/base.py new file mode 100644 index 0000000..eb5a2de --- /dev/null +++ b/isegm/inference/transforms/base.py @@ -0,0 +1,38 @@ +import torch + + +class BaseTransform(object): + def __init__(self): + self.image_changed = False + + def transform(self, image_nd, clicks_lists): + raise NotImplementedError + + def inv_transform(self, prob_map): + raise NotImplementedError + + def reset(self): + raise NotImplementedError + + def get_state(self): + raise NotImplementedError + + def set_state(self, state): + raise NotImplementedError + + +class SigmoidForPred(BaseTransform): + def transform(self, image_nd, clicks_lists): + return image_nd, clicks_lists + + def inv_transform(self, prob_map): + return torch.sigmoid(prob_map) + + def reset(self): + pass + + def get_state(self): + return None + + def set_state(self, state): + pass diff --git a/isegm/inference/transforms/crops.py b/isegm/inference/transforms/crops.py new file mode 100644 index 0000000..428d977 --- /dev/null +++ b/isegm/inference/transforms/crops.py @@ -0,0 +1,97 @@ +import math + +import torch +import numpy as np +from typing import List + +from isegm.inference.clicker import Click +from .base import BaseTransform + + +class Crops(BaseTransform): + def __init__(self, crop_size=(320, 480), min_overlap=0.2): + super().__init__() + self.crop_height, self.crop_width = crop_size + self.min_overlap = min_overlap + + self.x_offsets = None + self.y_offsets = None + self._counts = None + + def transform(self, image_nd, clicks_lists: List[List[Click]]): + assert image_nd.shape[0] == 1 and len(clicks_lists) == 1 + image_height, image_width = image_nd.shape[2:4] + self._counts = None + + if image_height < self.crop_height or image_width < self.crop_width: + return image_nd, clicks_lists + + self.x_offsets = get_offsets(image_width, self.crop_width, self.min_overlap) + self.y_offsets = get_offsets(image_height, self.crop_height, self.min_overlap) + self._counts = np.zeros((image_height, image_width)) + + image_crops = [] + for dy in self.y_offsets: + for dx in self.x_offsets: + self._counts[dy:dy + self.crop_height, dx:dx + self.crop_width] += 1 + image_crop = image_nd[:, :, dy:dy + self.crop_height, dx:dx + self.crop_width] + image_crops.append(image_crop) + image_crops = torch.cat(image_crops, dim=0) + self._counts = torch.tensor(self._counts, device=image_nd.device, dtype=torch.float32) + + clicks_list = clicks_lists[0] + clicks_lists = [] + for dy in self.y_offsets: + for dx in self.x_offsets: + crop_clicks = [x.copy(coords=(x.coords[0] - dy, x.coords[1] - dx)) for x in clicks_list] + clicks_lists.append(crop_clicks) + + return image_crops, clicks_lists + + def inv_transform(self, prob_map): + if self._counts is None: + return prob_map + + new_prob_map = torch.zeros((1, 1, *self._counts.shape), + dtype=prob_map.dtype, device=prob_map.device) + + crop_indx = 0 + for dy in self.y_offsets: + for dx in self.x_offsets: + new_prob_map[0, 0, dy:dy + self.crop_height, dx:dx + self.crop_width] += prob_map[crop_indx, 0] + crop_indx += 1 + new_prob_map = torch.div(new_prob_map, self._counts) + + return new_prob_map + + def get_state(self): + return self.x_offsets, self.y_offsets, self._counts + + def set_state(self, state): + self.x_offsets, self.y_offsets, self._counts = state + + def reset(self): + self.x_offsets = None + self.y_offsets = None + self._counts = 
None + + +def get_offsets(length, crop_size, min_overlap_ratio=0.2): + if length == crop_size: + return [0] + + N = (length / crop_size - min_overlap_ratio) / (1 - min_overlap_ratio) + N = math.ceil(N) + + overlap_ratio = (N - length / crop_size) / (N - 1) + overlap_width = int(crop_size * overlap_ratio) + + offsets = [0] + for i in range(1, N): + new_offset = offsets[-1] + crop_size - overlap_width + if new_offset + crop_size > length: + new_offset = length - crop_size + + offsets.append(new_offset) + + return offsets diff --git a/isegm/inference/transforms/flip.py b/isegm/inference/transforms/flip.py new file mode 100644 index 0000000..373640e --- /dev/null +++ b/isegm/inference/transforms/flip.py @@ -0,0 +1,37 @@ +import torch + +from typing import List +from isegm.inference.clicker import Click +from .base import BaseTransform + + +class AddHorizontalFlip(BaseTransform): + def transform(self, image_nd, clicks_lists: List[List[Click]]): + assert len(image_nd.shape) == 4 + image_nd = torch.cat([image_nd, torch.flip(image_nd, dims=[3])], dim=0) + + image_width = image_nd.shape[3] + clicks_lists_flipped = [] + for clicks_list in clicks_lists: + clicks_list_flipped = [click.copy(coords=(click.coords[0], image_width - click.coords[1] - 1)) + for click in clicks_list] + clicks_lists_flipped.append(clicks_list_flipped) + clicks_lists = clicks_lists + clicks_lists_flipped + + return image_nd, clicks_lists + + def inv_transform(self, prob_map): + assert len(prob_map.shape) == 4 and prob_map.shape[0] % 2 == 0 + num_maps = prob_map.shape[0] // 2 + prob_map, prob_map_flipped = prob_map[:num_maps], prob_map[num_maps:] + + return 0.5 * (prob_map + torch.flip(prob_map_flipped, dims=[3])) + + def get_state(self): + return None + + def set_state(self, state): + pass + + def reset(self): + pass diff --git a/isegm/inference/transforms/limit_longest_side.py b/isegm/inference/transforms/limit_longest_side.py new file mode 100644 index 0000000..50c5a53 --- /dev/null +++ b/isegm/inference/transforms/limit_longest_side.py @@ -0,0 +1,22 @@ +from .zoom_in import ZoomIn, get_roi_image_nd + + +class LimitLongestSide(ZoomIn): + def __init__(self, max_size=800): + super().__init__(target_size=max_size, skip_clicks=0) + + def transform(self, image_nd, clicks_lists): + assert image_nd.shape[0] == 1 and len(clicks_lists) == 1 + image_max_size = max(image_nd.shape[2:4]) + self.image_changed = False + + if image_max_size <= self.target_size: + return image_nd, clicks_lists + self._input_image = image_nd + + self._object_roi = (0, image_nd.shape[2] - 1, 0, image_nd.shape[3] - 1) + self._roi_image = get_roi_image_nd(image_nd, self._object_roi, self.target_size) + self.image_changed = True + + tclicks_lists = [self._transform_clicks(clicks_lists[0])] + return self._roi_image, tclicks_lists diff --git a/isegm/inference/transforms/zoom_in.py b/isegm/inference/transforms/zoom_in.py new file mode 100644 index 0000000..04b576a --- /dev/null +++ b/isegm/inference/transforms/zoom_in.py @@ -0,0 +1,175 @@ +import torch + +from typing import List +from isegm.inference.clicker import Click +from isegm.utils.misc import get_bbox_iou, get_bbox_from_mask, expand_bbox, clamp_bbox +from .base import BaseTransform + + +class ZoomIn(BaseTransform): + def __init__(self, + target_size=400, + skip_clicks=1, + expansion_ratio=1.4, + min_crop_size=200, + recompute_thresh_iou=0.5, + prob_thresh=0.50): + super().__init__() + self.target_size = target_size + self.min_crop_size = min_crop_size + self.skip_clicks = skip_clicks + self.expansion_ratio = 
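Note: `get_offsets` above determines where `Crops` places its tiles: it uses the smallest number of crops that cover the image with at least the requested overlap, clamping the last crop to the image border; `Crops.inv_transform` then averages overlapping predictions using the `_counts` map. A quick usage check (assuming the repo root is on `PYTHONPATH`):

```python
from isegm.inference.transforms.crops import get_offsets

# Tile a 600-pixel side with 256-pixel crops and at least 20% overlap.
print(get_offsets(600, 256, min_overlap_ratio=0.2))   # [0, 172, 344]
```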
expansion_ratio + self.recompute_thresh_iou = recompute_thresh_iou + self.prob_thresh = prob_thresh + + self._input_image_shape = None + self._prev_probs = None + self._object_roi = None + self._roi_image = None + + def transform(self, image_nd, clicks_lists: List[List[Click]]): + assert image_nd.shape[0] == 1 and len(clicks_lists) == 1 + self.image_changed = False + + clicks_list = clicks_lists[0] + if len(clicks_list) <= self.skip_clicks: + return image_nd, clicks_lists + + self._input_image_shape = image_nd.shape + + current_object_roi = None + if self._prev_probs is not None: + current_pred_mask = (self._prev_probs > self.prob_thresh)[0, 0] + if current_pred_mask.sum() > 0: + current_object_roi = get_object_roi(current_pred_mask, clicks_list, + self.expansion_ratio, self.min_crop_size) + + if current_object_roi is None: + if self.skip_clicks >= 0: + return image_nd, clicks_lists + else: + current_object_roi = 0, image_nd.shape[2] - 1, 0, image_nd.shape[3] - 1 + + update_object_roi = False + if self._object_roi is None: + update_object_roi = True + elif not check_object_roi(self._object_roi, clicks_list): + update_object_roi = True + elif get_bbox_iou(current_object_roi, self._object_roi) < self.recompute_thresh_iou: + update_object_roi = True + + if update_object_roi: + self._object_roi = current_object_roi + self.image_changed = True + self._roi_image = get_roi_image_nd(image_nd, self._object_roi, self.target_size) + + tclicks_lists = [self._transform_clicks(clicks_list)] + return self._roi_image.to(image_nd.device), tclicks_lists + + def inv_transform(self, prob_map): + if self._object_roi is None: + self._prev_probs = prob_map.cpu().numpy() + return prob_map + + assert prob_map.shape[0] == 1 + rmin, rmax, cmin, cmax = self._object_roi + prob_map = torch.nn.functional.interpolate(prob_map, size=(rmax - rmin + 1, cmax - cmin + 1), + mode='bilinear', align_corners=True) + + if self._prev_probs is not None: + new_prob_map = torch.zeros(*self._prev_probs.shape, device=prob_map.device, dtype=prob_map.dtype) + new_prob_map[:, :, rmin:rmax + 1, cmin:cmax + 1] = prob_map + else: + new_prob_map = prob_map + + self._prev_probs = new_prob_map.cpu().numpy() + + return new_prob_map + + def check_possible_recalculation(self): + if self._prev_probs is None or self._object_roi is not None or self.skip_clicks > 0: + return False + + pred_mask = (self._prev_probs > self.prob_thresh)[0, 0] + if pred_mask.sum() > 0: + possible_object_roi = get_object_roi(pred_mask, [], + self.expansion_ratio, self.min_crop_size) + image_roi = (0, self._input_image_shape[2] - 1, 0, self._input_image_shape[3] - 1) + if get_bbox_iou(possible_object_roi, image_roi) < 0.50: + return True + return False + + def get_state(self): + roi_image = self._roi_image.cpu() if self._roi_image is not None else None + return self._input_image_shape, self._object_roi, self._prev_probs, roi_image, self.image_changed + + def set_state(self, state): + self._input_image_shape, self._object_roi, self._prev_probs, self._roi_image, self.image_changed = state + + def reset(self): + self._input_image_shape = None + self._object_roi = None + self._prev_probs = None + self._roi_image = None + self.image_changed = False + + def _transform_clicks(self, clicks_list): + if self._object_roi is None: + return clicks_list + + rmin, rmax, cmin, cmax = self._object_roi + crop_height, crop_width = self._roi_image.shape[2:] + + transformed_clicks = [] + for click in clicks_list: + new_r = crop_height * (click.coords[0] - rmin) / (rmax - rmin + 1) + new_c = 
crop_width * (click.coords[1] - cmin) / (cmax - cmin + 1) + transformed_clicks.append(click.copy(coords=(new_r, new_c))) + return transformed_clicks + + +def get_object_roi(pred_mask, clicks_list, expansion_ratio, min_crop_size): + pred_mask = pred_mask.copy() + + for click in clicks_list: + if click.is_positive: + pred_mask[int(click.coords[0]), int(click.coords[1])] = 1 + + bbox = get_bbox_from_mask(pred_mask) + bbox = expand_bbox(bbox, expansion_ratio, min_crop_size) + h, w = pred_mask.shape[0], pred_mask.shape[1] + bbox = clamp_bbox(bbox, 0, h - 1, 0, w - 1) + + return bbox + + +def get_roi_image_nd(image_nd, object_roi, target_size): + rmin, rmax, cmin, cmax = object_roi + + height = rmax - rmin + 1 + width = cmax - cmin + 1 + + if isinstance(target_size, tuple): + new_height, new_width = target_size + else: + scale = target_size / max(height, width) + new_height = int(round(height * scale)) + new_width = int(round(width * scale)) + + with torch.no_grad(): + roi_image_nd = image_nd[:, :, rmin:rmax + 1, cmin:cmax + 1] + roi_image_nd = torch.nn.functional.interpolate(roi_image_nd, size=(new_height, new_width), + mode='bilinear', align_corners=True) + + return roi_image_nd + + +def check_object_roi(object_roi, clicks_list): + for click in clicks_list: + if click.is_positive: + if click.coords[0] < object_roi[0] or click.coords[0] >= object_roi[1]: + return False + if click.coords[1] < object_roi[2] or click.coords[1] >= object_roi[3]: + return False + + return True diff --git a/isegm/inference/utils.py b/isegm/inference/utils.py new file mode 100644 index 0000000..7102d40 --- /dev/null +++ b/isegm/inference/utils.py @@ -0,0 +1,143 @@ +from datetime import timedelta +from pathlib import Path + +import torch +import numpy as np + +from isegm.data.datasets import GrabCutDataset, BerkeleyDataset, DavisDataset, SBDEvaluationDataset, PascalVocDataset +from isegm.utils.serialization import load_model + + +def get_time_metrics(all_ious, elapsed_time): + n_images = len(all_ious) + n_clicks = sum(map(len, all_ious)) + + mean_spc = elapsed_time / n_clicks + mean_spi = elapsed_time / n_images + + return mean_spc, mean_spi + + +def load_is_model(checkpoint, device, **kwargs): + if isinstance(checkpoint, (str, Path)): + state_dict = torch.load(checkpoint, map_location='cpu') + else: + state_dict = checkpoint + + if isinstance(state_dict, list): + model = load_single_is_model(state_dict[0], device, **kwargs) + models = [load_single_is_model(x, device, **kwargs) for x in state_dict] + + return model, models + else: + return load_single_is_model(state_dict, device, **kwargs) + + +def load_single_is_model(state_dict, device, **kwargs): + model = load_model(state_dict['config'], **kwargs) + model.load_state_dict(state_dict['state_dict'], strict=False) + + for param in model.parameters(): + param.requires_grad = False + model.to(device) + model.eval() + + return model + + +def get_dataset(dataset_name, cfg): + if dataset_name == 'GrabCut': + dataset = GrabCutDataset(cfg.GRABCUT_PATH) + elif dataset_name == 'Berkeley': + dataset = BerkeleyDataset(cfg.BERKELEY_PATH) + elif dataset_name == 'DAVIS': + dataset = DavisDataset(cfg.DAVIS_PATH) + elif dataset_name == 'SBD': + dataset = SBDEvaluationDataset(cfg.SBD_PATH) + elif dataset_name == 'SBD_Train': + dataset = SBDEvaluationDataset(cfg.SBD_PATH, split='train') + elif dataset_name == 'PascalVOC': + dataset = PascalVocDataset(cfg.PASCALVOC_PATH, split='test') + elif dataset_name == 'COCO_MVal': + dataset = DavisDataset(cfg.COCO_MVAL_PATH) + else: + dataset = 
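Note: `ZoomIn` re-runs the network on a resized crop around the previous prediction, and `_transform_clicks` maps full-image click coordinates into that crop's frame (the inverse mapping for the probability map is handled by `inv_transform`). A small worked example of the coordinate remapping, with made-up numbers:

```python
# Map a full-image click into the resized-ROI frame, as ZoomIn._transform_clicks does.
rmin, rmax, cmin, cmax = 100, 299, 50, 449          # object ROI, inclusive bounds (made up)
crop_h, crop_w = 400, 800                           # size of the resized ROI image (made up)
y, x = 180, 120                                     # click in full-image coordinates

roi_y = crop_h * (y - rmin) / (rmax - rmin + 1)     # -> 160.0
roi_x = crop_w * (x - cmin) / (cmax - cmin + 1)     # -> 140.0
```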
None + + return dataset + + +def get_iou(gt_mask, pred_mask, ignore_label=-1): + ignore_gt_mask_inv = gt_mask != ignore_label + obj_gt_mask = gt_mask == 1 + + intersection = np.logical_and(np.logical_and(pred_mask, obj_gt_mask), ignore_gt_mask_inv).sum() + union = np.logical_and(np.logical_or(pred_mask, obj_gt_mask), ignore_gt_mask_inv).sum() + + return intersection / union + + +def compute_noc_metric(all_ious, iou_thrs, max_clicks=20): + def _get_noc(iou_arr, iou_thr): + vals = iou_arr >= iou_thr + return np.argmax(vals) + 1 if np.any(vals) else max_clicks + + noc_list = [] + over_max_list = [] + for iou_thr in iou_thrs: + scores_arr = np.array([_get_noc(iou_arr, iou_thr) + for iou_arr in all_ious], dtype=np.int) + + score = scores_arr.mean() + over_max = (scores_arr == max_clicks).sum() + + noc_list.append(score) + over_max_list.append(over_max) + + return noc_list, over_max_list + + +def find_checkpoint(weights_folder, checkpoint_name): + weights_folder = Path(weights_folder) + if ':' in checkpoint_name: + model_name, checkpoint_name = checkpoint_name.split(':') + models_candidates = [x for x in weights_folder.glob(f'{model_name}*') if x.is_dir()] + assert len(models_candidates) == 1 + model_folder = models_candidates[0] + else: + model_folder = weights_folder + + if checkpoint_name.endswith('.pth'): + if Path(checkpoint_name).exists(): + checkpoint_path = checkpoint_name + else: + checkpoint_path = weights_folder / checkpoint_name + else: + model_checkpoints = list(model_folder.rglob(f'{checkpoint_name}*.pth')) + assert len(model_checkpoints) == 1 + checkpoint_path = model_checkpoints[0] + + return str(checkpoint_path) + + +def get_results_table(noc_list, over_max_list, brs_type, dataset_name, mean_spc, elapsed_time, + n_clicks=20, model_name=None): + table_header = (f'|{"BRS Type":^13}|{"Dataset":^11}|' + f'{"NoC@80%":^9}|{"NoC@85%":^9}|{"NoC@90%":^9}|' + f'{">="+str(n_clicks)+"@85%":^9}|{">="+str(n_clicks)+"@90%":^9}|' + f'{"SPC,s":^7}|{"Time":^9}|') + row_width = len(table_header) + + header = f'Eval results for model: {model_name}\n' if model_name is not None else '' + header += '-' * row_width + '\n' + header += table_header + '\n' + '-' * row_width + + eval_time = str(timedelta(seconds=int(elapsed_time))) + table_row = f'|{brs_type:^13}|{dataset_name:^11}|' + table_row += f'{noc_list[0]:^9.2f}|' + table_row += f'{noc_list[1]:^9.2f}|' if len(noc_list) > 1 else f'{"?":^9}|' + table_row += f'{noc_list[2]:^9.2f}|' if len(noc_list) > 2 else f'{"?":^9}|' + table_row += f'{over_max_list[1]:^9}|' if len(noc_list) > 1 else f'{"?":^9}|' + table_row += f'{over_max_list[2]:^9}|' if len(noc_list) > 2 else f'{"?":^9}|' + table_row += f'{mean_spc:^7.3f}|{eval_time:^9}|' + + return header, table_row \ No newline at end of file diff --git a/isegm/model/initializer.py b/isegm/model/initializer.py new file mode 100644 index 0000000..470c7df --- /dev/null +++ b/isegm/model/initializer.py @@ -0,0 +1,105 @@ +import torch +import torch.nn as nn +import numpy as np + + +class Initializer(object): + def __init__(self, local_init=True, gamma=None): + self.local_init = local_init + self.gamma = gamma + + def __call__(self, m): + if getattr(m, '__initialized', False): + return + + if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, + nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d, + nn.GroupNorm, nn.SyncBatchNorm)) or 'BatchNorm' in m.__class__.__name__: + if m.weight is not None: + self._init_gamma(m.weight.data) + if m.bias is not None: + self._init_beta(m.bias.data) + else: + if 
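Note: `compute_noc_metric` above turns the per-image IoU curves returned by `evaluate_dataset` into the standard NoC numbers: the mean number of clicks needed to reach each IoU threshold, with images that never reach it counted as `max_clicks` and reported separately. A toy check with two made-up IoU curves:

```python
import numpy as np
from isegm.inference.utils import compute_noc_metric

all_ious = [np.array([0.55, 0.82, 0.91]),            # reaches 90% IoU at click 3
            np.array([0.40, 0.70, 0.78, 0.86])]      # never reaches 90% IoU
noc, over_max = compute_noc_metric(all_ious, iou_thrs=[0.80, 0.85, 0.90], max_clicks=20)
print(noc)        # [3.0, 3.5, 11.5] -- mean clicks to 80% / 85% / 90% IoU
print(over_max)   # [0, 0, 1]        -- images that never reached the threshold within 20 clicks
```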
getattr(m, 'weight', None) is not None: + self._init_weight(m.weight.data) + if getattr(m, 'bias', None) is not None: + self._init_bias(m.bias.data) + + if self.local_init: + object.__setattr__(m, '__initialized', True) + + def _init_weight(self, data): + nn.init.uniform_(data, -0.07, 0.07) + + def _init_bias(self, data): + nn.init.constant_(data, 0) + + def _init_gamma(self, data): + if self.gamma is None: + nn.init.constant_(data, 1.0) + else: + nn.init.normal_(data, 1.0, self.gamma) + + def _init_beta(self, data): + nn.init.constant_(data, 0) + + +class Bilinear(Initializer): + def __init__(self, scale, groups, in_channels, **kwargs): + super().__init__(**kwargs) + self.scale = scale + self.groups = groups + self.in_channels = in_channels + + def _init_weight(self, data): + """Reset the weight and bias.""" + bilinear_kernel = self.get_bilinear_kernel(self.scale) + weight = torch.zeros_like(data) + for i in range(self.in_channels): + if self.groups == 1: + j = i + else: + j = 0 + weight[i, j] = bilinear_kernel + data[:] = weight + + @staticmethod + def get_bilinear_kernel(scale): + """Generate a bilinear upsampling kernel.""" + kernel_size = 2 * scale - scale % 2 + scale = (kernel_size + 1) // 2 + center = scale - 0.5 * (1 + kernel_size % 2) + + og = np.ogrid[:kernel_size, :kernel_size] + kernel = (1 - np.abs(og[0] - center) / scale) * (1 - np.abs(og[1] - center) / scale) + + return torch.tensor(kernel, dtype=torch.float32) + + +class XavierGluon(Initializer): + def __init__(self, rnd_type='uniform', factor_type='avg', magnitude=3, **kwargs): + super().__init__(**kwargs) + + self.rnd_type = rnd_type + self.factor_type = factor_type + self.magnitude = float(magnitude) + + def _init_weight(self, arr): + fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(arr) + + if self.factor_type == 'avg': + factor = (fan_in + fan_out) / 2.0 + elif self.factor_type == 'in': + factor = fan_in + elif self.factor_type == 'out': + factor = fan_out + else: + raise ValueError('Incorrect factor type') + scale = np.sqrt(self.magnitude / factor) + + if self.rnd_type == 'uniform': + nn.init.uniform_(arr, -scale, scale) + elif self.rnd_type == 'gaussian': + nn.init.normal_(arr, 0, scale) + else: + raise ValueError('Unknown random type') diff --git a/isegm/model/is_deeplab_model.py b/isegm/model/is_deeplab_model.py new file mode 100644 index 0000000..45fa553 --- /dev/null +++ b/isegm/model/is_deeplab_model.py @@ -0,0 +1,25 @@ +import torch.nn as nn + +from isegm.utils.serialization import serialize +from .is_model import ISModel +from .modeling.deeplab_v3 import DeepLabV3Plus +from .modeling.basic_blocks import SepConvHead +from isegm.model.modifiers import LRMult + + +class DeeplabModel(ISModel): + @serialize + def __init__(self, backbone='resnet50', deeplab_ch=256, aspp_dropout=0.5, + backbone_norm_layer=None, backbone_lr_mult=0.1, norm_layer=nn.BatchNorm2d, **kwargs): + super().__init__(norm_layer=norm_layer, **kwargs) + + self.feature_extractor = DeepLabV3Plus(backbone=backbone, ch=deeplab_ch, project_dropout=aspp_dropout, + norm_layer=norm_layer, backbone_norm_layer=backbone_norm_layer) + self.feature_extractor.backbone.apply(LRMult(backbone_lr_mult)) + self.head = SepConvHead(1, in_channels=deeplab_ch, mid_channels=deeplab_ch // 2, + num_layers=2, norm_layer=norm_layer) + + def backbone_forward(self, image, coord_features=None): + backbone_features = self.feature_extractor(image, coord_features) + + return {'instances': self.head(backbone_features[0])} diff --git a/isegm/model/is_hrnet_model.py 
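Note: the `Bilinear` initializer writes a fixed bilinear-interpolation kernel into each channel of a layer's weight so that the layer performs bilinear upsampling from the start; for `scale=2` the generated kernel is the familiar 4x4 one:

```python
from isegm.model.initializer import Bilinear

print(Bilinear.get_bilinear_kernel(2))
# tensor([[0.0625, 0.1875, 0.1875, 0.0625],
#         [0.1875, 0.5625, 0.5625, 0.1875],
#         [0.1875, 0.5625, 0.5625, 0.1875],
#         [0.0625, 0.1875, 0.1875, 0.0625]])
```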
b/isegm/model/is_hrnet_model.py new file mode 100644 index 0000000..b8a82e7 --- /dev/null +++ b/isegm/model/is_hrnet_model.py @@ -0,0 +1,26 @@ +import torch.nn as nn + +from isegm.utils.serialization import serialize +from .is_model import ISModel +from .modeling.hrnet_ocr import HighResolutionNet +from isegm.model.modifiers import LRMult + + +class HRNetModel(ISModel): + @serialize + def __init__(self, width=48, ocr_width=256, small=False, backbone_lr_mult=0.1, + norm_layer=nn.BatchNorm2d, **kwargs): + super().__init__(norm_layer=norm_layer, **kwargs) + + self.feature_extractor = HighResolutionNet(width=width, ocr_width=ocr_width, small=small, + num_classes=1, norm_layer=norm_layer) + self.feature_extractor.apply(LRMult(backbone_lr_mult)) + if ocr_width > 0: + self.feature_extractor.ocr_distri_head.apply(LRMult(1.0)) + self.feature_extractor.ocr_gather_head.apply(LRMult(1.0)) + self.feature_extractor.conv3x3_ocr.apply(LRMult(1.0)) + + def backbone_forward(self, image, coord_features=None): + net_outputs = self.feature_extractor(image, coord_features) + + return {'instances': net_outputs[0], 'instances_aux': net_outputs[1]} diff --git a/isegm/model/is_model.py b/isegm/model/is_model.py new file mode 100644 index 0000000..f655540 --- /dev/null +++ b/isegm/model/is_model.py @@ -0,0 +1,141 @@ +import torch +import torch.nn as nn +import numpy as np + +from isegm.model.ops import DistMaps, ScaleLayer, BatchImageNormalize +from isegm.model.modifiers import LRMult + + +class ISModel(nn.Module): + def __init__(self, use_rgb_conv=True, with_aux_output=False, + norm_radius=260, use_disks=False, cpu_dist_maps=False, + clicks_groups=None, with_prev_mask=False, use_leaky_relu=False, + binary_prev_mask=False, conv_extend=False, norm_layer=nn.BatchNorm2d, + norm_mean_std=([.485, .456, .406], [.229, .224, .225])): + super().__init__() + self.with_aux_output = with_aux_output + self.clicks_groups = clicks_groups + self.with_prev_mask = with_prev_mask + self.binary_prev_mask = binary_prev_mask + self.normalization = BatchImageNormalize(norm_mean_std[0], norm_mean_std[1]) + + self.coord_feature_ch = 2 + if clicks_groups is not None: + self.coord_feature_ch *= len(clicks_groups) + + if self.with_prev_mask: + self.coord_feature_ch += 1 + + if use_rgb_conv: + rgb_conv_layers = [ + nn.Conv2d(in_channels=3 + self.coord_feature_ch, out_channels=6 + self.coord_feature_ch, kernel_size=1), + norm_layer(6 + self.coord_feature_ch), + nn.LeakyReLU(negative_slope=0.2) if use_leaky_relu else nn.ReLU(inplace=True), + nn.Conv2d(in_channels=6 + self.coord_feature_ch, out_channels=3, kernel_size=1) + ] + self.rgb_conv = nn.Sequential(*rgb_conv_layers) + elif conv_extend: + self.rgb_conv = None + self.maps_transform = nn.Conv2d(in_channels=self.coord_feature_ch, out_channels=64, + kernel_size=3, stride=2, padding=1) + self.maps_transform.apply(LRMult(0.1)) + else: + self.rgb_conv = None + mt_layers = [ + nn.Conv2d(in_channels=self.coord_feature_ch, out_channels=16, kernel_size=1), + nn.LeakyReLU(negative_slope=0.2) if use_leaky_relu else nn.ReLU(inplace=True), + nn.Conv2d(in_channels=16, out_channels=64, kernel_size=3, stride=2, padding=1), + ScaleLayer(init_value=0.05, lr_mult=1) + ] + self.maps_transform = nn.Sequential(*mt_layers) + + if self.clicks_groups is not None: + self.dist_maps = nn.ModuleList() + for click_radius in self.clicks_groups: + self.dist_maps.append(DistMaps(norm_radius=click_radius, spatial_scale=1.0, + cpu_mode=cpu_dist_maps, use_disks=use_disks)) + else: + self.dist_maps = 
DistMaps(norm_radius=norm_radius, spatial_scale=1.0, + cpu_mode=cpu_dist_maps, use_disks=use_disks) + + def forward(self, image, points): + image, prev_mask = self.prepare_input(image) + coord_features = self.get_coord_features(image, prev_mask, points) + + if self.rgb_conv is not None: + x = self.rgb_conv(torch.cat((image, coord_features), dim=1)) + outputs = self.backbone_forward(x) + else: + coord_features = self.maps_transform(coord_features) + outputs = self.backbone_forward(image, coord_features) + + outputs['instances'] = nn.functional.interpolate(outputs['instances'], size=image.size()[2:], + mode='bilinear', align_corners=True) + if self.with_aux_output: + outputs['instances_aux'] = nn.functional.interpolate(outputs['instances_aux'], size=image.size()[2:], + mode='bilinear', align_corners=True) + + return outputs + + def prepare_input(self, image): + prev_mask = None + if self.with_prev_mask: + prev_mask = image[:, 3:, :, :] + image = image[:, :3, :, :] + if self.binary_prev_mask: + prev_mask = (prev_mask > 0.5).float() + + image = self.normalization(image) + return image, prev_mask + + def backbone_forward(self, image, coord_features=None): + raise NotImplementedError + + def get_coord_features(self, image, prev_mask, points): + if self.clicks_groups is not None: + points_groups = split_points_by_order(points, groups=(2,) + (1, ) * (len(self.clicks_groups) - 2) + (-1,)) + coord_features = [dist_map(image, pg) for dist_map, pg in zip(self.dist_maps, points_groups)] + coord_features = torch.cat(coord_features, dim=1) + else: + coord_features = self.dist_maps(image, points) + + if prev_mask is not None: + coord_features = torch.cat((prev_mask, coord_features), dim=1) + + return coord_features + + +def split_points_by_order(tpoints: torch.Tensor, groups): + points = tpoints.cpu().numpy() + num_groups = len(groups) + bs = points.shape[0] + num_points = points.shape[1] // 2 + + groups = [x if x > 0 else num_points for x in groups] + group_points = [np.full((bs, 2 * x, 3), -1, dtype=np.float32) + for x in groups] + + last_point_indx_group = np.zeros((bs, num_groups, 2), dtype=np.int) + for group_indx, group_size in enumerate(groups): + last_point_indx_group[:, group_indx, 1] = group_size + + for bindx in range(bs): + for pindx in range(2 * num_points): + point = points[bindx, pindx, :] + group_id = int(point[2]) + if group_id < 0: + continue + + is_negative = int(pindx >= num_points) + if group_id >= num_groups or (group_id == 0 and is_negative): # disable negative first click + group_id = num_groups - 1 + + new_point_indx = last_point_indx_group[bindx, group_id, is_negative] + last_point_indx_group[bindx, group_id, is_negative] += 1 + + group_points[group_id][bindx, new_point_indx, :] = point + + group_points = [torch.tensor(x, dtype=tpoints.dtype, device=tpoints.device) + for x in group_points] + + return group_points diff --git a/isegm/model/losses.py b/isegm/model/losses.py new file mode 100644 index 0000000..b90f18f --- /dev/null +++ b/isegm/model/losses.py @@ -0,0 +1,161 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from isegm.utils import misc + + +class NormalizedFocalLossSigmoid(nn.Module): + def __init__(self, axis=-1, alpha=0.25, gamma=2, max_mult=-1, eps=1e-12, + from_sigmoid=False, detach_delimeter=True, + batch_axis=0, weight=None, size_average=True, + ignore_label=-1): + super(NormalizedFocalLossSigmoid, self).__init__() + self._axis = axis + self._alpha = alpha + self._gamma = gamma + self._ignore_label = ignore_label + 
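Note: when `with_prev_mask=True`, `ISModel.prepare_input` expects the previous prediction to be stacked onto the RGB channels, and the clicks to arrive as the `(batch, 2*N, 3)` points tensor described earlier; both the trainer and `BasePredictor` assemble the input this way. A sketch of one forward-pass input (the tensor sizes are made up, and the final call is left commented out because it needs a constructed `HRNetModel`/`DeeplabModel`):

```python
import torch

image = torch.rand(1, 3, 480, 640)                    # toy RGB batch
prev_mask = torch.zeros(1, 1, 480, 640)               # empty mask before the first click
net_input = torch.cat((image, prev_mask), dim=1)      # shape (1, 4, 480, 640)

# One positive click at (y=240, x=320) with click index 0; the negative half stays empty.
points = torch.tensor([[[240, 320, 0],
                        [-1, -1, -1]]], dtype=torch.float32)

# probs = torch.sigmoid(model(net_input, points)['instances'])   # model: an ISModel subclass
```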
self._weight = weight if weight is not None else 1.0 + self._batch_axis = batch_axis + + self._from_logits = from_sigmoid + self._eps = eps + self._size_average = size_average + self._detach_delimeter = detach_delimeter + self._max_mult = max_mult + self._k_sum = 0 + self._m_max = 0 + + def forward(self, pred, label): + one_hot = label > 0.5 + sample_weight = label != self._ignore_label + + if not self._from_logits: + pred = torch.sigmoid(pred) + + alpha = torch.where(one_hot, self._alpha * sample_weight, (1 - self._alpha) * sample_weight) + pt = torch.where(sample_weight, 1.0 - torch.abs(label - pred), torch.ones_like(pred)) + + beta = (1 - pt) ** self._gamma + + sw_sum = torch.sum(sample_weight, dim=(-2, -1), keepdim=True) + beta_sum = torch.sum(beta, dim=(-2, -1), keepdim=True) + mult = sw_sum / (beta_sum + self._eps) + if self._detach_delimeter: + mult = mult.detach() + beta = beta * mult + if self._max_mult > 0: + beta = torch.clamp_max(beta, self._max_mult) + + with torch.no_grad(): + ignore_area = torch.sum(label == self._ignore_label, dim=tuple(range(1, label.dim()))).cpu().numpy() + sample_mult = torch.mean(mult, dim=tuple(range(1, mult.dim()))).cpu().numpy() + if np.any(ignore_area == 0): + self._k_sum = 0.9 * self._k_sum + 0.1 * sample_mult[ignore_area == 0].mean() + + beta_pmax, _ = torch.flatten(beta, start_dim=1).max(dim=1) + beta_pmax = beta_pmax.mean().item() + self._m_max = 0.8 * self._m_max + 0.2 * beta_pmax + + loss = -alpha * beta * torch.log(torch.min(pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device))) + loss = self._weight * (loss * sample_weight) + + if self._size_average: + bsum = torch.sum(sample_weight, dim=misc.get_dims_with_exclusion(sample_weight.dim(), self._batch_axis)) + loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) / (bsum + self._eps) + else: + loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) + + return loss + + def log_states(self, sw, name, global_step): + sw.add_scalar(tag=name + '_k', value=self._k_sum, global_step=global_step) + sw.add_scalar(tag=name + '_m', value=self._m_max, global_step=global_step) + + +class FocalLoss(nn.Module): + def __init__(self, axis=-1, alpha=0.25, gamma=2, + from_logits=False, batch_axis=0, + weight=None, num_class=None, + eps=1e-9, size_average=True, scale=1.0, + ignore_label=-1): + super(FocalLoss, self).__init__() + self._axis = axis + self._alpha = alpha + self._gamma = gamma + self._ignore_label = ignore_label + self._weight = weight if weight is not None else 1.0 + self._batch_axis = batch_axis + + self._scale = scale + self._num_class = num_class + self._from_logits = from_logits + self._eps = eps + self._size_average = size_average + + def forward(self, pred, label, sample_weight=None): + one_hot = label > 0.5 + sample_weight = label != self._ignore_label + + if not self._from_logits: + pred = torch.sigmoid(pred) + + alpha = torch.where(one_hot, self._alpha * sample_weight, (1 - self._alpha) * sample_weight) + pt = torch.where(sample_weight, 1.0 - torch.abs(label - pred), torch.ones_like(pred)) + + beta = (1 - pt) ** self._gamma + + loss = -alpha * beta * torch.log(torch.min(pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device))) + loss = self._weight * (loss * sample_weight) + + if self._size_average: + tsum = torch.sum(sample_weight, dim=misc.get_dims_with_exclusion(label.dim(), self._batch_axis)) + loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) / (tsum + self._eps) + else: 
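# Compact reference sketch of the per-pixel binary focal loss that the focal losses above are
# built around (ignoring ignore-label masking and size averaging). NormalizedFocalLossSigmoid
# additionally rescales the (1 - pt)**gamma factor so that its mean over valid pixels stays
# close to 1; this snippet only shows the core formula.
import torch

def binary_focal_loss(pred_logits, target, alpha=0.25, gamma=2.0, eps=1e-12):
    p = torch.sigmoid(pred_logits)
    pt = torch.where(target > 0.5, p, 1.0 - p)               # probability assigned to the true class
    alpha_t = torch.where(target > 0.5,
                          torch.full_like(p, alpha),
                          torch.full_like(p, 1.0 - alpha))
    return -alpha_t * (1.0 - pt) ** gamma * torch.log(pt.clamp_min(eps))

logits = torch.tensor([2.0, -1.0, 0.0])
labels = torch.tensor([1.0, 0.0, 1.0])
print(binary_focal_loss(logits, labels))                     # confident correct pixels are down-weighted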
+ loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) + + return self._scale * loss + + +class SoftIoU(nn.Module): + def __init__(self, from_sigmoid=False, ignore_label=-1): + super().__init__() + self._from_sigmoid = from_sigmoid + self._ignore_label = ignore_label + + def forward(self, pred, label): + label = label.view(pred.size()) + sample_weight = label != self._ignore_label + + if not self._from_sigmoid: + pred = torch.sigmoid(pred) + + loss = 1.0 - torch.sum(pred * label * sample_weight, dim=(1, 2, 3)) \ + / (torch.sum(torch.max(pred, label) * sample_weight, dim=(1, 2, 3)) + 1e-8) + + return loss + + +class SigmoidBinaryCrossEntropyLoss(nn.Module): + def __init__(self, from_sigmoid=False, weight=None, batch_axis=0, ignore_label=-1): + super(SigmoidBinaryCrossEntropyLoss, self).__init__() + self._from_sigmoid = from_sigmoid + self._ignore_label = ignore_label + self._weight = weight if weight is not None else 1.0 + self._batch_axis = batch_axis + + def forward(self, pred, label): + label = label.view(pred.size()) + sample_weight = label != self._ignore_label + label = torch.where(sample_weight, label, torch.zeros_like(label)) + + if not self._from_sigmoid: + loss = torch.relu(pred) - pred * label + F.softplus(-torch.abs(pred)) + else: + eps = 1e-12 + loss = -(torch.log(pred + eps) * label + + torch.log(1. - pred + eps) * (1. - label)) + + loss = self._weight * (loss * sample_weight) + return torch.mean(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) diff --git a/isegm/model/metrics.py b/isegm/model/metrics.py new file mode 100644 index 0000000..a572dcd --- /dev/null +++ b/isegm/model/metrics.py @@ -0,0 +1,101 @@ +import torch +import numpy as np + +from isegm.utils import misc + + +class TrainMetric(object): + def __init__(self, pred_outputs, gt_outputs): + self.pred_outputs = pred_outputs + self.gt_outputs = gt_outputs + + def update(self, *args, **kwargs): + raise NotImplementedError + + def get_epoch_value(self): + raise NotImplementedError + + def reset_epoch_stats(self): + raise NotImplementedError + + def log_states(self, sw, tag_prefix, global_step): + pass + + @property + def name(self): + return type(self).__name__ + + +class AdaptiveIoU(TrainMetric): + def __init__(self, init_thresh=0.4, thresh_step=0.025, thresh_beta=0.99, iou_beta=0.9, + ignore_label=-1, from_logits=True, + pred_output='instances', gt_output='instances'): + super().__init__(pred_outputs=(pred_output,), gt_outputs=(gt_output,)) + self._ignore_label = ignore_label + self._from_logits = from_logits + self._iou_thresh = init_thresh + self._thresh_step = thresh_step + self._thresh_beta = thresh_beta + self._iou_beta = iou_beta + self._ema_iou = 0.0 + self._epoch_iou_sum = 0.0 + self._epoch_batch_count = 0 + + def update(self, pred, gt): + gt_mask = gt > 0.5 + if self._from_logits: + pred = torch.sigmoid(pred) + + gt_mask_area = torch.sum(gt_mask, dim=(1, 2)).detach().cpu().numpy() + if np.all(gt_mask_area == 0): + return + + ignore_mask = gt == self._ignore_label + max_iou = _compute_iou(pred > self._iou_thresh, gt_mask, ignore_mask).mean() + best_thresh = self._iou_thresh + for t in [best_thresh - self._thresh_step, best_thresh + self._thresh_step]: + temp_iou = _compute_iou(pred > t, gt_mask, ignore_mask).mean() + if temp_iou > max_iou: + max_iou = temp_iou + best_thresh = t + + self._iou_thresh = self._thresh_beta * self._iou_thresh + (1 - self._thresh_beta) * best_thresh + self._ema_iou = self._iou_beta * self._ema_iou + (1 - self._iou_beta) * 
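# Toy sketch of the per-image IoU that _compute_iou below evaluates for AdaptiveIoU: the ratio
# of intersection to union of the binarised prediction and ground-truth masks (ignored pixels
# are zeroed out before the computation in the real implementation).
import torch

pred = torch.tensor([[1, 1, 0, 0]], dtype=torch.bool)
gt = torch.tensor([[1, 0, 1, 0]], dtype=torch.bool)
intersection = (pred & gt).float().sum(dim=1)
union = (pred | gt).float().sum(dim=1)
print((intersection / union).item())                         # 1 shared pixel out of 3 -> ~0.333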
max_iou + self._epoch_iou_sum += max_iou + self._epoch_batch_count += 1 + + def get_epoch_value(self): + if self._epoch_batch_count > 0: + return self._epoch_iou_sum / self._epoch_batch_count + else: + return 0.0 + + def reset_epoch_stats(self): + self._epoch_iou_sum = 0.0 + self._epoch_batch_count = 0 + + def log_states(self, sw, tag_prefix, global_step): + sw.add_scalar(tag=tag_prefix + '_ema_iou', value=self._ema_iou, global_step=global_step) + sw.add_scalar(tag=tag_prefix + '_iou_thresh', value=self._iou_thresh, global_step=global_step) + + @property + def iou_thresh(self): + return self._iou_thresh + + +def _compute_iou(pred_mask, gt_mask, ignore_mask=None, keep_ignore=False): + if ignore_mask is not None: + pred_mask = torch.where(ignore_mask, torch.zeros_like(pred_mask), pred_mask) + + reduction_dims = misc.get_dims_with_exclusion(gt_mask.dim(), 0) + union = torch.mean((pred_mask | gt_mask).float(), dim=reduction_dims).detach().cpu().numpy() + intersection = torch.mean((pred_mask & gt_mask).float(), dim=reduction_dims).detach().cpu().numpy() + nonzero = union > 0 + + iou = intersection[nonzero] / union[nonzero] + if not keep_ignore: + return iou + else: + result = np.full_like(intersection, -1) + result[nonzero] = iou + return result diff --git a/isegm/model/modeling/basic_blocks.py b/isegm/model/modeling/basic_blocks.py new file mode 100644 index 0000000..13753e8 --- /dev/null +++ b/isegm/model/modeling/basic_blocks.py @@ -0,0 +1,71 @@ +import torch.nn as nn + +from isegm.model import ops + + +class ConvHead(nn.Module): + def __init__(self, out_channels, in_channels=32, num_layers=1, + kernel_size=3, padding=1, + norm_layer=nn.BatchNorm2d): + super(ConvHead, self).__init__() + convhead = [] + + for i in range(num_layers): + convhead.extend([ + nn.Conv2d(in_channels, in_channels, kernel_size, padding=padding), + nn.ReLU(), + norm_layer(in_channels) if norm_layer is not None else nn.Identity() + ]) + convhead.append(nn.Conv2d(in_channels, out_channels, 1, padding=0)) + + self.convhead = nn.Sequential(*convhead) + + def forward(self, *inputs): + return self.convhead(inputs[0]) + + +class SepConvHead(nn.Module): + def __init__(self, num_outputs, in_channels, mid_channels, num_layers=1, + kernel_size=3, padding=1, dropout_ratio=0.0, dropout_indx=0, + norm_layer=nn.BatchNorm2d): + super(SepConvHead, self).__init__() + + sepconvhead = [] + + for i in range(num_layers): + sepconvhead.append( + SeparableConv2d(in_channels=in_channels if i == 0 else mid_channels, + out_channels=mid_channels, + dw_kernel=kernel_size, dw_padding=padding, + norm_layer=norm_layer, activation='relu') + ) + if dropout_ratio > 0 and dropout_indx == i: + sepconvhead.append(nn.Dropout(dropout_ratio)) + + sepconvhead.append( + nn.Conv2d(in_channels=mid_channels, out_channels=num_outputs, kernel_size=1, padding=0) + ) + + self.layers = nn.Sequential(*sepconvhead) + + def forward(self, *inputs): + x = inputs[0] + + return self.layers(x) + + +class SeparableConv2d(nn.Module): + def __init__(self, in_channels, out_channels, dw_kernel, dw_padding, dw_stride=1, + activation=None, use_bias=False, norm_layer=None): + super(SeparableConv2d, self).__init__() + _activation = ops.select_activation_function(activation) + self.body = nn.Sequential( + nn.Conv2d(in_channels, in_channels, kernel_size=dw_kernel, stride=dw_stride, + padding=dw_padding, bias=use_bias, groups=in_channels), + nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=use_bias), + norm_layer(out_channels) if norm_layer is not None else 
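# Sketch of why SeparableConv2d above is cheaper than a dense convolution: a depthwise 3x3
# convolution (groups=in_channels) followed by a 1x1 pointwise convolution needs far fewer
# parameters than one dense 3x3 convolution with the same channel counts.
import torch.nn as nn

def n_params(module):
    return sum(p.numel() for p in module.parameters())

cin, cout = 256, 256
dense = nn.Conv2d(cin, cout, kernel_size=3, padding=1, bias=False)
separable = nn.Sequential(
    nn.Conv2d(cin, cin, kernel_size=3, padding=1, groups=cin, bias=False),   # depthwise
    nn.Conv2d(cin, cout, kernel_size=1, bias=False),                         # pointwise
)
print(n_params(dense), n_params(separable))                  # 589824 vs 67840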
nn.Identity(), + _activation() + ) + + def forward(self, x): + return self.body(x) diff --git a/isegm/model/modeling/deeplab_v3.py b/isegm/model/modeling/deeplab_v3.py new file mode 100644 index 0000000..8219a4e --- /dev/null +++ b/isegm/model/modeling/deeplab_v3.py @@ -0,0 +1,176 @@ +from contextlib import ExitStack + +import torch +from torch import nn +import torch.nn.functional as F + +from .basic_blocks import SeparableConv2d +from .resnet import ResNetBackbone +from isegm.model import ops + + +class DeepLabV3Plus(nn.Module): + def __init__(self, backbone='resnet50', norm_layer=nn.BatchNorm2d, + backbone_norm_layer=None, + ch=256, + project_dropout=0.5, + inference_mode=False, + **kwargs): + super(DeepLabV3Plus, self).__init__() + if backbone_norm_layer is None: + backbone_norm_layer = norm_layer + + self.backbone_name = backbone + self.norm_layer = norm_layer + self.backbone_norm_layer = backbone_norm_layer + self.inference_mode = False + self.ch = ch + self.aspp_in_channels = 2048 + self.skip_project_in_channels = 256 # layer 1 out_channels + + self._kwargs = kwargs + if backbone == 'resnet34': + self.aspp_in_channels = 512 + self.skip_project_in_channels = 64 + + self.backbone = ResNetBackbone(backbone=self.backbone_name, pretrained_base=False, + norm_layer=self.backbone_norm_layer, **kwargs) + + self.head = _DeepLabHead(in_channels=ch + 32, mid_channels=ch, out_channels=ch, + norm_layer=self.norm_layer) + self.skip_project = _SkipProject(self.skip_project_in_channels, 32, norm_layer=self.norm_layer) + self.aspp = _ASPP(in_channels=self.aspp_in_channels, + atrous_rates=[12, 24, 36], + out_channels=ch, + project_dropout=project_dropout, + norm_layer=self.norm_layer) + + if inference_mode: + self.set_prediction_mode() + + def load_pretrained_weights(self): + pretrained = ResNetBackbone(backbone=self.backbone_name, pretrained_base=True, + norm_layer=self.backbone_norm_layer, **self._kwargs) + backbone_state_dict = self.backbone.state_dict() + pretrained_state_dict = pretrained.state_dict() + + backbone_state_dict.update(pretrained_state_dict) + self.backbone.load_state_dict(backbone_state_dict) + + if self.inference_mode: + for param in self.backbone.parameters(): + param.requires_grad = False + + def set_prediction_mode(self): + self.inference_mode = True + self.eval() + + def forward(self, x, additional_features=None): + with ExitStack() as stack: + if self.inference_mode: + stack.enter_context(torch.no_grad()) + + c1, _, c3, c4 = self.backbone(x, additional_features) + c1 = self.skip_project(c1) + + x = self.aspp(c4) + x = F.interpolate(x, c1.size()[2:], mode='bilinear', align_corners=True) + x = torch.cat((x, c1), dim=1) + x = self.head(x) + + return x, + + +class _SkipProject(nn.Module): + def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm2d): + super(_SkipProject, self).__init__() + _activation = ops.select_activation_function("relu") + + self.skip_project = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), + norm_layer(out_channels), + _activation() + ) + + def forward(self, x): + return self.skip_project(x) + + +class _DeepLabHead(nn.Module): + def __init__(self, out_channels, in_channels, mid_channels=256, norm_layer=nn.BatchNorm2d): + super(_DeepLabHead, self).__init__() + + self.block = nn.Sequential( + SeparableConv2d(in_channels=in_channels, out_channels=mid_channels, dw_kernel=3, + dw_padding=1, activation='relu', norm_layer=norm_layer), + SeparableConv2d(in_channels=mid_channels, out_channels=mid_channels, dw_kernel=3, 
+ dw_padding=1, activation='relu', norm_layer=norm_layer), + nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1) + ) + + def forward(self, x): + return self.block(x) + + +class _ASPP(nn.Module): + def __init__(self, in_channels, atrous_rates, out_channels=256, + project_dropout=0.5, norm_layer=nn.BatchNorm2d): + super(_ASPP, self).__init__() + + b0 = nn.Sequential( + nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=False), + norm_layer(out_channels), + nn.ReLU() + ) + + rate1, rate2, rate3 = tuple(atrous_rates) + b1 = _ASPPConv(in_channels, out_channels, rate1, norm_layer) + b2 = _ASPPConv(in_channels, out_channels, rate2, norm_layer) + b3 = _ASPPConv(in_channels, out_channels, rate3, norm_layer) + b4 = _AsppPooling(in_channels, out_channels, norm_layer=norm_layer) + + self.concurent = nn.ModuleList([b0, b1, b2, b3, b4]) + + project = [ + nn.Conv2d(in_channels=5*out_channels, out_channels=out_channels, + kernel_size=1, bias=False), + norm_layer(out_channels), + nn.ReLU() + ] + if project_dropout > 0: + project.append(nn.Dropout(project_dropout)) + self.project = nn.Sequential(*project) + + def forward(self, x): + x = torch.cat([block(x) for block in self.concurent], dim=1) + + return self.project(x) + + +class _AsppPooling(nn.Module): + def __init__(self, in_channels, out_channels, norm_layer): + super(_AsppPooling, self).__init__() + + self.gap = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + nn.Conv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=1, bias=False), + norm_layer(out_channels), + nn.ReLU() + ) + + def forward(self, x): + pool = self.gap(x) + return F.interpolate(pool, x.size()[2:], mode='bilinear', align_corners=True) + + +def _ASPPConv(in_channels, out_channels, atrous_rate, norm_layer): + block = nn.Sequential( + nn.Conv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=3, padding=atrous_rate, + dilation=atrous_rate, bias=False), + norm_layer(out_channels), + nn.ReLU() + ) + + return block diff --git a/isegm/model/modeling/hrnet_ocr.py b/isegm/model/modeling/hrnet_ocr.py new file mode 100644 index 0000000..d386ee0 --- /dev/null +++ b/isegm/model/modeling/hrnet_ocr.py @@ -0,0 +1,416 @@ +import os +import numpy as np +import torch +import torch.nn as nn +import torch._utils +import torch.nn.functional as F +from .ocr import SpatialOCR_Module, SpatialGather_Module +from .resnetv1b import BasicBlockV1b, BottleneckV1b + +relu_inplace = True + + +class HighResolutionModule(nn.Module): + def __init__(self, num_branches, blocks, num_blocks, num_inchannels, + num_channels, fuse_method,multi_scale_output=True, + norm_layer=nn.BatchNorm2d, align_corners=True): + super(HighResolutionModule, self).__init__() + self._check_branches(num_branches, num_blocks, num_inchannels, num_channels) + + self.num_inchannels = num_inchannels + self.fuse_method = fuse_method + self.num_branches = num_branches + self.norm_layer = norm_layer + self.align_corners = align_corners + + self.multi_scale_output = multi_scale_output + + self.branches = self._make_branches( + num_branches, blocks, num_blocks, num_channels) + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU(inplace=relu_inplace) + + def _check_branches(self, num_branches, num_blocks, num_inchannels, num_channels): + if num_branches != len(num_blocks): + error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format( + num_branches, len(num_blocks)) + raise ValueError(error_msg) + + if num_branches != len(num_channels): + error_msg = 
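# Sketch of the _ASPP design above: parallel 3x3 convolutions with different dilation rates
# (padding == dilation keeps the spatial size) plus a global-average-pooling branch, all
# producing same-sized maps that can be concatenated along the channel dimension.
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.rand(1, 64, 32, 32)
branches = [nn.Conv2d(64, 16, kernel_size=3, padding=r, dilation=r, bias=False) for r in (12, 24, 36)]
outs = [b(x) for b in branches]
pooled = F.interpolate(F.adaptive_avg_pool2d(x, 1), size=x.shape[2:], mode='bilinear', align_corners=True)
print(torch.cat(outs + [pooled], dim=1).shape)               # torch.Size([1, 112, 32, 32])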
'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format( + num_branches, len(num_channels)) + raise ValueError(error_msg) + + if num_branches != len(num_inchannels): + error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format( + num_branches, len(num_inchannels)) + raise ValueError(error_msg) + + def _make_one_branch(self, branch_index, block, num_blocks, num_channels, + stride=1): + downsample = None + if stride != 1 or \ + self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.num_inchannels[branch_index], + num_channels[branch_index] * block.expansion, + kernel_size=1, stride=stride, bias=False), + self.norm_layer(num_channels[branch_index] * block.expansion), + ) + + layers = [] + layers.append(block(self.num_inchannels[branch_index], + num_channels[branch_index], stride, + downsample=downsample, norm_layer=self.norm_layer)) + self.num_inchannels[branch_index] = \ + num_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append(block(self.num_inchannels[branch_index], + num_channels[branch_index], + norm_layer=self.norm_layer)) + + return nn.Sequential(*layers) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + branches = [] + + for i in range(num_branches): + branches.append( + self._make_one_branch(i, block, num_blocks, num_channels)) + + return nn.ModuleList(branches) + + def _make_fuse_layers(self): + if self.num_branches == 1: + return None + + num_branches = self.num_branches + num_inchannels = self.num_inchannels + fuse_layers = [] + for i in range(num_branches if self.multi_scale_output else 1): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append(nn.Sequential( + nn.Conv2d(in_channels=num_inchannels[j], + out_channels=num_inchannels[i], + kernel_size=1, + bias=False), + self.norm_layer(num_inchannels[i]))) + elif j == i: + fuse_layer.append(None) + else: + conv3x3s = [] + for k in range(i - j): + if k == i - j - 1: + num_outchannels_conv3x3 = num_inchannels[i] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_outchannels_conv3x3, + kernel_size=3, stride=2, padding=1, bias=False), + self.norm_layer(num_outchannels_conv3x3))) + else: + num_outchannels_conv3x3 = num_inchannels[j] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_outchannels_conv3x3, + kernel_size=3, stride=2, padding=1, bias=False), + self.norm_layer(num_outchannels_conv3x3), + nn.ReLU(inplace=relu_inplace))) + fuse_layer.append(nn.Sequential(*conv3x3s)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def get_num_inchannels(self): + return self.num_inchannels + + def forward(self, x): + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i in range(self.num_branches): + x[i] = self.branches[i](x[i]) + + x_fuse = [] + for i in range(len(self.fuse_layers)): + y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) + for j in range(1, self.num_branches): + if i == j: + y = y + x[j] + elif j > i: + width_output = x[i].shape[-1] + height_output = x[i].shape[-2] + y = y + F.interpolate( + self.fuse_layers[i][j](x[j]), + size=[height_output, width_output], + mode='bilinear', align_corners=self.align_corners) + else: + y = y + self.fuse_layers[i][j](x[j]) + x_fuse.append(self.relu(y)) + + return x_fuse + + +class HighResolutionNet(nn.Module): + def __init__(self, width, num_classes, ocr_width=256, small=False, + norm_layer=nn.BatchNorm2d, align_corners=True): + 
super(HighResolutionNet, self).__init__() + self.norm_layer = norm_layer + self.width = width + self.ocr_width = ocr_width + self.align_corners = align_corners + + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = norm_layer(64) + self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False) + self.bn2 = norm_layer(64) + self.relu = nn.ReLU(inplace=relu_inplace) + + num_blocks = 2 if small else 4 + + stage1_num_channels = 64 + self.layer1 = self._make_layer(BottleneckV1b, 64, stage1_num_channels, blocks=num_blocks) + stage1_out_channel = BottleneckV1b.expansion * stage1_num_channels + + self.stage2_num_branches = 2 + num_channels = [width, 2 * width] + num_inchannels = [ + num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))] + self.transition1 = self._make_transition_layer( + [stage1_out_channel], num_inchannels) + self.stage2, pre_stage_channels = self._make_stage( + BasicBlockV1b, num_inchannels=num_inchannels, num_modules=1, num_branches=self.stage2_num_branches, + num_blocks=2 * [num_blocks], num_channels=num_channels) + + self.stage3_num_branches = 3 + num_channels = [width, 2 * width, 4 * width] + num_inchannels = [ + num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))] + self.transition2 = self._make_transition_layer( + pre_stage_channels, num_inchannels) + self.stage3, pre_stage_channels = self._make_stage( + BasicBlockV1b, num_inchannels=num_inchannels, + num_modules=3 if small else 4, num_branches=self.stage3_num_branches, + num_blocks=3 * [num_blocks], num_channels=num_channels) + + self.stage4_num_branches = 4 + num_channels = [width, 2 * width, 4 * width, 8 * width] + num_inchannels = [ + num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))] + self.transition3 = self._make_transition_layer( + pre_stage_channels, num_inchannels) + self.stage4, pre_stage_channels = self._make_stage( + BasicBlockV1b, num_inchannels=num_inchannels, num_modules=2 if small else 3, + num_branches=self.stage4_num_branches, + num_blocks=4 * [num_blocks], num_channels=num_channels) + + last_inp_channels = np.int(np.sum(pre_stage_channels)) + if self.ocr_width > 0: + ocr_mid_channels = 2 * self.ocr_width + ocr_key_channels = self.ocr_width + + self.conv3x3_ocr = nn.Sequential( + nn.Conv2d(last_inp_channels, ocr_mid_channels, + kernel_size=3, stride=1, padding=1), + norm_layer(ocr_mid_channels), + nn.ReLU(inplace=relu_inplace), + ) + self.ocr_gather_head = SpatialGather_Module(num_classes) + + self.ocr_distri_head = SpatialOCR_Module(in_channels=ocr_mid_channels, + key_channels=ocr_key_channels, + out_channels=ocr_mid_channels, + scale=1, + dropout=0.05, + norm_layer=norm_layer, + align_corners=align_corners) + self.cls_head = nn.Conv2d( + ocr_mid_channels, num_classes, kernel_size=1, stride=1, padding=0, bias=True) + + self.aux_head = nn.Sequential( + nn.Conv2d(last_inp_channels, last_inp_channels, + kernel_size=1, stride=1, padding=0), + norm_layer(last_inp_channels), + nn.ReLU(inplace=relu_inplace), + nn.Conv2d(last_inp_channels, num_classes, + kernel_size=1, stride=1, padding=0, bias=True) + ) + else: + self.cls_head = nn.Sequential( + nn.Conv2d(last_inp_channels, last_inp_channels, + kernel_size=3, stride=1, padding=1), + norm_layer(last_inp_channels), + nn.ReLU(inplace=relu_inplace), + nn.Conv2d(last_inp_channels, num_classes, + kernel_size=1, stride=1, padding=0, bias=True) + ) + + def _make_transition_layer( + self, num_channels_pre_layer, num_channels_cur_layer): + 
num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append(nn.Sequential( + nn.Conv2d(num_channels_pre_layer[i], + num_channels_cur_layer[i], + kernel_size=3, + stride=1, + padding=1, + bias=False), + self.norm_layer(num_channels_cur_layer[i]), + nn.ReLU(inplace=relu_inplace))) + else: + transition_layers.append(None) + else: + conv3x3s = [] + for j in range(i + 1 - num_branches_pre): + inchannels = num_channels_pre_layer[-1] + outchannels = num_channels_cur_layer[i] \ + if j == i - num_branches_pre else inchannels + conv3x3s.append(nn.Sequential( + nn.Conv2d(inchannels, outchannels, + kernel_size=3, stride=2, padding=1, bias=False), + self.norm_layer(outchannels), + nn.ReLU(inplace=relu_inplace))) + transition_layers.append(nn.Sequential(*conv3x3s)) + + return nn.ModuleList(transition_layers) + + def _make_layer(self, block, inplanes, planes, blocks, stride=1): + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + self.norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(inplanes, planes, stride, + downsample=downsample, norm_layer=self.norm_layer)) + inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(inplanes, planes, norm_layer=self.norm_layer)) + + return nn.Sequential(*layers) + + def _make_stage(self, block, num_inchannels, + num_modules, num_branches, num_blocks, num_channels, + fuse_method='SUM', + multi_scale_output=True): + modules = [] + for i in range(num_modules): + # multi_scale_output is only used last module + if not multi_scale_output and i == num_modules - 1: + reset_multi_scale_output = False + else: + reset_multi_scale_output = True + modules.append( + HighResolutionModule(num_branches, + block, + num_blocks, + num_inchannels, + num_channels, + fuse_method, + reset_multi_scale_output, + norm_layer=self.norm_layer, + align_corners=self.align_corners) + ) + num_inchannels = modules[-1].get_num_inchannels() + + return nn.Sequential(*modules), num_inchannels + + def forward(self, x, additional_features=None): + feats = self.compute_hrnet_feats(x, additional_features) + if self.ocr_width > 0: + out_aux = self.aux_head(feats) + feats = self.conv3x3_ocr(feats) + + context = self.ocr_gather_head(feats, out_aux) + feats = self.ocr_distri_head(feats, context) + out = self.cls_head(feats) + return [out, out_aux] + else: + return [self.cls_head(feats), None] + + def compute_hrnet_feats(self, x, additional_features): + x = self.compute_pre_stage_features(x, additional_features) + x = self.layer1(x) + + x_list = [] + for i in range(self.stage2_num_branches): + if self.transition1[i] is not None: + x_list.append(self.transition1[i](x)) + else: + x_list.append(x) + y_list = self.stage2(x_list) + + x_list = [] + for i in range(self.stage3_num_branches): + if self.transition2[i] is not None: + if i < self.stage2_num_branches: + x_list.append(self.transition2[i](y_list[i])) + else: + x_list.append(self.transition2[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage3(x_list) + + x_list = [] + for i in range(self.stage4_num_branches): + if self.transition3[i] is not None: + if i < self.stage3_num_branches: + 
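# Sketch of aggregate_hrnet_features below: the lower-resolution branch outputs are bilinearly
# upsampled to the highest-resolution branch and concatenated along channels. Channel counts
# here assume an HRNet-W18-like width (width=18), purely for illustration.
import torch
import torch.nn.functional as F

feats = [torch.rand(1, c, 64 // s, 64 // s) for c, s in [(18, 1), (36, 2), (72, 4), (144, 8)]]
upsampled = [feats[0]] + [
    F.interpolate(f, size=feats[0].shape[2:], mode='bilinear', align_corners=True) for f in feats[1:]
]
print(torch.cat(upsampled, dim=1).shape)                     # torch.Size([1, 270, 64, 64])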
x_list.append(self.transition3[i](y_list[i])) + else: + x_list.append(self.transition3[i](y_list[-1])) + else: + x_list.append(y_list[i]) + x = self.stage4(x_list) + + return self.aggregate_hrnet_features(x) + + def compute_pre_stage_features(self, x, additional_features): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + if additional_features is not None: + x = x + additional_features + x = self.conv2(x) + x = self.bn2(x) + return self.relu(x) + + def aggregate_hrnet_features(self, x): + # Upsampling + x0_h, x0_w = x[0].size(2), x[0].size(3) + x1 = F.interpolate(x[1], size=(x0_h, x0_w), + mode='bilinear', align_corners=self.align_corners) + x2 = F.interpolate(x[2], size=(x0_h, x0_w), + mode='bilinear', align_corners=self.align_corners) + x3 = F.interpolate(x[3], size=(x0_h, x0_w), + mode='bilinear', align_corners=self.align_corners) + + return torch.cat([x[0], x1, x2, x3], 1) + + def load_pretrained_weights(self, pretrained_path=''): + model_dict = self.state_dict() + + if not os.path.exists(pretrained_path): + print(f'\nFile "{pretrained_path}" does not exist.') + print('You need to specify the correct path to the pre-trained weights.\n' + 'You can download the weights for HRNet from the repository:\n' + 'https://github.com/HRNet/HRNet-Image-Classification') + exit(1) + pretrained_dict = torch.load(pretrained_path, map_location={'cuda:0': 'cpu'}) + pretrained_dict = {k.replace('last_layer', 'aux_head').replace('model.', ''): v for k, v in + pretrained_dict.items()} + + pretrained_dict = {k: v for k, v in pretrained_dict.items() + if k in model_dict.keys()} + + model_dict.update(pretrained_dict) + self.load_state_dict(model_dict) diff --git a/isegm/model/modeling/ocr.py b/isegm/model/modeling/ocr.py new file mode 100644 index 0000000..df3b4f6 --- /dev/null +++ b/isegm/model/modeling/ocr.py @@ -0,0 +1,141 @@ +import torch +import torch.nn as nn +import torch._utils +import torch.nn.functional as F + + +class SpatialGather_Module(nn.Module): + """ + Aggregate the context features according to the initial + predicted probability distribution. + Employ the soft-weighted method to aggregate the context. + """ + + def __init__(self, cls_num=0, scale=1): + super(SpatialGather_Module, self).__init__() + self.cls_num = cls_num + self.scale = scale + + def forward(self, feats, probs): + batch_size, c, h, w = probs.size(0), probs.size(1), probs.size(2), probs.size(3) + probs = probs.view(batch_size, c, -1) + feats = feats.view(batch_size, feats.size(1), -1) + feats = feats.permute(0, 2, 1) # batch x hw x c + probs = F.softmax(self.scale * probs, dim=2) # batch x k x hw + ocr_context = torch.matmul(probs, feats) \ + .permute(0, 2, 1).unsqueeze(3) # batch x k x c + return ocr_context + + +class SpatialOCR_Module(nn.Module): + """ + Implementation of the OCR module: + We aggregate the global object representation to update the representation for each pixel. 
+ """ + + def __init__(self, + in_channels, + key_channels, + out_channels, + scale=1, + dropout=0.1, + norm_layer=nn.BatchNorm2d, + align_corners=True): + super(SpatialOCR_Module, self).__init__() + self.object_context_block = ObjectAttentionBlock2D(in_channels, key_channels, scale, + norm_layer, align_corners) + _in_channels = 2 * in_channels + + self.conv_bn_dropout = nn.Sequential( + nn.Conv2d(_in_channels, out_channels, kernel_size=1, padding=0, bias=False), + nn.Sequential(norm_layer(out_channels), nn.ReLU(inplace=True)), + nn.Dropout2d(dropout) + ) + + def forward(self, feats, proxy_feats): + context = self.object_context_block(feats, proxy_feats) + + output = self.conv_bn_dropout(torch.cat([context, feats], 1)) + + return output + + +class ObjectAttentionBlock2D(nn.Module): + ''' + The basic implementation for object context block + Input: + N X C X H X W + Parameters: + in_channels : the dimension of the input feature map + key_channels : the dimension after the key/query transform + scale : choose the scale to downsample the input feature maps (save memory cost) + bn_type : specify the bn type + Return: + N X C X H X W + ''' + + def __init__(self, + in_channels, + key_channels, + scale=1, + norm_layer=nn.BatchNorm2d, + align_corners=True): + super(ObjectAttentionBlock2D, self).__init__() + self.scale = scale + self.in_channels = in_channels + self.key_channels = key_channels + self.align_corners = align_corners + + self.pool = nn.MaxPool2d(kernel_size=(scale, scale)) + self.f_pixel = nn.Sequential( + nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)), + nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)) + ) + self.f_object = nn.Sequential( + nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)), + nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)) + ) + self.f_down = nn.Sequential( + nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)) + ) + self.f_up = nn.Sequential( + nn.Conv2d(in_channels=self.key_channels, out_channels=self.in_channels, + kernel_size=1, stride=1, padding=0, bias=False), + nn.Sequential(norm_layer(self.in_channels), nn.ReLU(inplace=True)) + ) + + def forward(self, x, proxy): + batch_size, h, w = x.size(0), x.size(2), x.size(3) + if self.scale > 1: + x = self.pool(x) + + query = self.f_pixel(x).view(batch_size, self.key_channels, -1) + query = query.permute(0, 2, 1) + key = self.f_object(proxy).view(batch_size, self.key_channels, -1) + value = self.f_down(proxy).view(batch_size, self.key_channels, -1) + value = value.permute(0, 2, 1) + + sim_map = torch.matmul(query, key) + sim_map = (self.key_channels ** -.5) * sim_map + sim_map = F.softmax(sim_map, dim=-1) + + # add bg context ... 
+ context = torch.matmul(sim_map, value) + context = context.permute(0, 2, 1).contiguous() + context = context.view(batch_size, self.key_channels, *x.size()[2:]) + context = self.f_up(context) + if self.scale > 1: + context = F.interpolate(input=context, size=(h, w), + mode='bilinear', align_corners=self.align_corners) + + return context diff --git a/isegm/model/modeling/resnet.py b/isegm/model/modeling/resnet.py new file mode 100644 index 0000000..65fe949 --- /dev/null +++ b/isegm/model/modeling/resnet.py @@ -0,0 +1,43 @@ +import torch +from .resnetv1b import resnet34_v1b, resnet50_v1s, resnet101_v1s, resnet152_v1s + + +class ResNetBackbone(torch.nn.Module): + def __init__(self, backbone='resnet50', pretrained_base=True, dilated=True, **kwargs): + super(ResNetBackbone, self).__init__() + + if backbone == 'resnet34': + pretrained = resnet34_v1b(pretrained=pretrained_base, dilated=dilated, **kwargs) + elif backbone == 'resnet50': + pretrained = resnet50_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs) + elif backbone == 'resnet101': + pretrained = resnet101_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs) + elif backbone == 'resnet152': + pretrained = resnet152_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs) + else: + raise RuntimeError(f'unknown backbone: {backbone}') + + self.conv1 = pretrained.conv1 + self.bn1 = pretrained.bn1 + self.relu = pretrained.relu + self.maxpool = pretrained.maxpool + self.layer1 = pretrained.layer1 + self.layer2 = pretrained.layer2 + self.layer3 = pretrained.layer3 + self.layer4 = pretrained.layer4 + + def forward(self, x, additional_features=None): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + if additional_features is not None: + x = x + torch.nn.functional.pad(additional_features, + [0, 0, 0, 0, 0, x.size(1) - additional_features.size(1)], + mode='constant', value=0) + x = self.maxpool(x) + c1 = self.layer1(x) + c2 = self.layer2(c1) + c3 = self.layer3(c2) + c4 = self.layer4(c3) + + return c1, c2, c3, c4 diff --git a/isegm/model/modeling/resnetv1b.py b/isegm/model/modeling/resnetv1b.py new file mode 100644 index 0000000..4ad24ce --- /dev/null +++ b/isegm/model/modeling/resnetv1b.py @@ -0,0 +1,276 @@ +import torch +import torch.nn as nn +GLUON_RESNET_TORCH_HUB = 'rwightman/pytorch-pretrained-gluonresnet' + + +class BasicBlockV1b(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, + previous_dilation=1, norm_layer=nn.BatchNorm2d): + super(BasicBlockV1b, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, + padding=dilation, dilation=dilation, bias=False) + self.bn1 = norm_layer(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, + padding=previous_dilation, dilation=previous_dilation, bias=False) + self.bn2 = norm_layer(planes) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out = out + residual + out = self.relu(out) + + return out + + +class BottleneckV1b(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, + previous_dilation=1, norm_layer=nn.BatchNorm2d): + super(BottleneckV1b, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = 
norm_layer(planes) + + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=dilation, dilation=dilation, bias=False) + self.bn2 = norm_layer(planes) + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = norm_layer(planes * self.expansion) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out = out + residual + out = self.relu(out) + + return out + + +class ResNetV1b(nn.Module): + """ Pre-trained ResNetV1b Model, which produces the strides of 8 featuremaps at conv5. + + Parameters + ---------- + block : Block + Class for the residual block. Options are BasicBlockV1, BottleneckV1. + layers : list of int + Numbers of layers in each block + classes : int, default 1000 + Number of classification classes. + dilated : bool, default False + Applying dilation strategy to pretrained ResNet yielding a stride-8 model, + typically used in Semantic Segmentation. + norm_layer : object + Normalization layer used (default: :class:`nn.BatchNorm2d`) + deep_stem : bool, default False + Whether to replace the 7x7 conv1 with 3 3x3 convolution layers. + avg_down : bool, default False + Whether to use average pooling for projection skip connection between stages/downsample. + final_drop : float, default 0.0 + Dropout ratio before the final classification layer. + + Reference: + - He, Kaiming, et al. "Deep residual learning for image recognition." + Proceedings of the IEEE conference on computer vision and pattern recognition. 2016. + + - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions." 
+ """ + def __init__(self, block, layers, classes=1000, dilated=True, deep_stem=False, stem_width=32, + avg_down=False, final_drop=0.0, norm_layer=nn.BatchNorm2d): + self.inplanes = stem_width*2 if deep_stem else 64 + super(ResNetV1b, self).__init__() + if not deep_stem: + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) + else: + self.conv1 = nn.Sequential( + nn.Conv2d(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False), + norm_layer(stem_width), + nn.ReLU(True), + nn.Conv2d(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False), + norm_layer(stem_width), + nn.ReLU(True), + nn.Conv2d(stem_width, 2*stem_width, kernel_size=3, stride=1, padding=1, bias=False) + ) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(True) + self.maxpool = nn.MaxPool2d(3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0], avg_down=avg_down, + norm_layer=norm_layer) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, avg_down=avg_down, + norm_layer=norm_layer) + if dilated: + self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2, + avg_down=avg_down, norm_layer=norm_layer) + self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4, + avg_down=avg_down, norm_layer=norm_layer) + else: + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + avg_down=avg_down, norm_layer=norm_layer) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + avg_down=avg_down, norm_layer=norm_layer) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.drop = None + if final_drop > 0.0: + self.drop = nn.Dropout(final_drop) + self.fc = nn.Linear(512 * block.expansion, classes) + + def _make_layer(self, block, planes, blocks, stride=1, dilation=1, + avg_down=False, norm_layer=nn.BatchNorm2d): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = [] + if avg_down: + if dilation == 1: + downsample.append( + nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False) + ) + else: + downsample.append( + nn.AvgPool2d(kernel_size=1, stride=1, ceil_mode=True, count_include_pad=False) + ) + downsample.extend([ + nn.Conv2d(self.inplanes, out_channels=planes * block.expansion, + kernel_size=1, stride=1, bias=False), + norm_layer(planes * block.expansion) + ]) + downsample = nn.Sequential(*downsample) + else: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, out_channels=planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + norm_layer(planes * block.expansion) + ) + + layers = [] + if dilation in (1, 2): + layers.append(block(self.inplanes, planes, stride, dilation=1, downsample=downsample, + previous_dilation=dilation, norm_layer=norm_layer)) + elif dilation == 4: + layers.append(block(self.inplanes, planes, stride, dilation=2, downsample=downsample, + previous_dilation=dilation, norm_layer=norm_layer)) + else: + raise RuntimeError("=> unknown dilation size: {}".format(dilation)) + + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, dilation=dilation, + previous_dilation=dilation, norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + if self.drop is not None: + x = self.drop(x) + 
x = self.fc(x) + + return x + + +def _safe_state_dict_filtering(orig_dict, model_dict_keys): + filtered_orig_dict = {} + for k, v in orig_dict.items(): + if k in model_dict_keys: + filtered_orig_dict[k] = v + else: + print(f"[ERROR] Failed to load <{k}> in backbone") + return filtered_orig_dict + + +def resnet34_v1b(pretrained=False, **kwargs): + model = ResNetV1b(BasicBlockV1b, [3, 4, 6, 3], **kwargs) + if pretrained: + model_dict = model.state_dict() + filtered_orig_dict = _safe_state_dict_filtering( + torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet34_v1b', pretrained=True).state_dict(), + model_dict.keys() + ) + model_dict.update(filtered_orig_dict) + model.load_state_dict(model_dict) + return model + + +def resnet50_v1s(pretrained=False, **kwargs): + model = ResNetV1b(BottleneckV1b, [3, 4, 6, 3], deep_stem=True, stem_width=64, **kwargs) + if pretrained: + model_dict = model.state_dict() + filtered_orig_dict = _safe_state_dict_filtering( + torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet50_v1s', pretrained=True).state_dict(), + model_dict.keys() + ) + model_dict.update(filtered_orig_dict) + model.load_state_dict(model_dict) + return model + + +def resnet101_v1s(pretrained=False, **kwargs): + model = ResNetV1b(BottleneckV1b, [3, 4, 23, 3], deep_stem=True, stem_width=64, **kwargs) + if pretrained: + model_dict = model.state_dict() + filtered_orig_dict = _safe_state_dict_filtering( + torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet101_v1s', pretrained=True).state_dict(), + model_dict.keys() + ) + model_dict.update(filtered_orig_dict) + model.load_state_dict(model_dict) + return model + + +def resnet152_v1s(pretrained=False, **kwargs): + model = ResNetV1b(BottleneckV1b, [3, 8, 36, 3], deep_stem=True, stem_width=64, **kwargs) + if pretrained: + model_dict = model.state_dict() + filtered_orig_dict = _safe_state_dict_filtering( + torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet152_v1s', pretrained=True).state_dict(), + model_dict.keys() + ) + model_dict.update(filtered_orig_dict) + model.load_state_dict(model_dict) + return model diff --git a/isegm/model/modifiers.py b/isegm/model/modifiers.py new file mode 100644 index 0000000..0462218 --- /dev/null +++ b/isegm/model/modifiers.py @@ -0,0 +1,11 @@ + + +class LRMult(object): + def __init__(self, lr_mult=1.): + self.lr_mult = lr_mult + + def __call__(self, m): + if getattr(m, 'weight', None) is not None: + m.weight.lr_mult = self.lr_mult + if getattr(m, 'bias', None) is not None: + m.bias.lr_mult = self.lr_mult diff --git a/isegm/model/ops.py b/isegm/model/ops.py new file mode 100644 index 0000000..9be9c73 --- /dev/null +++ b/isegm/model/ops.py @@ -0,0 +1,116 @@ +import torch +from torch import nn as nn +import numpy as np +import isegm.model.initializer as initializer + + +def select_activation_function(activation): + if isinstance(activation, str): + if activation.lower() == 'relu': + return nn.ReLU + elif activation.lower() == 'softplus': + return nn.Softplus + else: + raise ValueError(f"Unknown activation type {activation}") + elif isinstance(activation, nn.Module): + return activation + else: + raise ValueError(f"Unknown activation type {activation}") + + +class BilinearConvTranspose2d(nn.ConvTranspose2d): + def __init__(self, in_channels, out_channels, scale, groups=1): + kernel_size = 2 * scale - scale % 2 + self.scale = scale + + super().__init__( + in_channels, out_channels, + kernel_size=kernel_size, + stride=scale, + padding=1, + groups=groups, + bias=False) + + self.apply(initializer.Bilinear(scale=scale, 
in_channels=in_channels, groups=groups)) + + +class DistMaps(nn.Module): + def __init__(self, norm_radius, spatial_scale=1.0, cpu_mode=False, use_disks=False): + super(DistMaps, self).__init__() + self.spatial_scale = spatial_scale + self.norm_radius = norm_radius + self.cpu_mode = cpu_mode + self.use_disks = use_disks + if self.cpu_mode: + from isegm.utils.cython import get_dist_maps + self._get_dist_maps = get_dist_maps + + def get_coord_features(self, points, batchsize, rows, cols): + if self.cpu_mode: + coords = [] + for i in range(batchsize): + norm_delimeter = 1.0 if self.use_disks else self.spatial_scale * self.norm_radius + coords.append(self._get_dist_maps(points[i].cpu().float().numpy(), rows, cols, + norm_delimeter)) + coords = torch.from_numpy(np.stack(coords, axis=0)).to(points.device).float() + else: + num_points = points.shape[1] // 2 + points = points.view(-1, points.size(2)) + points, points_order = torch.split(points, [2, 1], dim=1) + + invalid_points = torch.max(points, dim=1, keepdim=False)[0] < 0 + row_array = torch.arange(start=0, end=rows, step=1, dtype=torch.float32, device=points.device) + col_array = torch.arange(start=0, end=cols, step=1, dtype=torch.float32, device=points.device) + + coord_rows, coord_cols = torch.meshgrid(row_array, col_array) + coords = torch.stack((coord_rows, coord_cols), dim=0).unsqueeze(0).repeat(points.size(0), 1, 1, 1) + + add_xy = (points * self.spatial_scale).view(points.size(0), points.size(1), 1, 1) + coords.add_(-add_xy) + if not self.use_disks: + coords.div_(self.norm_radius * self.spatial_scale) + coords.mul_(coords) + + coords[:, 0] += coords[:, 1] + coords = coords[:, :1] + + coords[invalid_points, :, :, :] = 1e6 + + coords = coords.view(-1, num_points, 1, rows, cols) + coords = coords.min(dim=1)[0] # -> (bs * num_masks * 2) x 1 x h x w + coords = coords.view(-1, 2, rows, cols) + + if self.use_disks: + coords = (coords <= (self.norm_radius * self.spatial_scale) ** 2).float() + else: + coords.sqrt_().mul_(2).tanh_() + + return coords + + def forward(self, x, coords): + return self.get_coord_features(coords, x.shape[0], x.shape[2], x.shape[3]) + + +class ScaleLayer(nn.Module): + def __init__(self, init_value=1.0, lr_mult=1): + super().__init__() + self.lr_mult = lr_mult + self.scale = nn.Parameter( + torch.full((1,), init_value / lr_mult, dtype=torch.float32) + ) + + def forward(self, x): + scale = torch.abs(self.scale * self.lr_mult) + return x * scale + + +class BatchImageNormalize: + def __init__(self, mean, std, dtype=torch.float): + self.mean = torch.as_tensor(mean, dtype=dtype)[None, :, None, None] + self.std = torch.as_tensor(std, dtype=dtype)[None, :, None, None] + + def __call__(self, tensor): + tensor = tensor.clone() + + tensor.sub_(self.mean.to(tensor.device)).div_(self.std.to(tensor.device)) + return tensor diff --git a/isegm/utils/cython/__init__.py b/isegm/utils/cython/__init__.py new file mode 100644 index 0000000..eb66bdb --- /dev/null +++ b/isegm/utils/cython/__init__.py @@ -0,0 +1,2 @@ +# noinspection PyUnresolvedReferences +from .dist_maps import get_dist_maps \ No newline at end of file diff --git a/isegm/utils/cython/_get_dist_maps.pyx b/isegm/utils/cython/_get_dist_maps.pyx new file mode 100644 index 0000000..779a7f0 --- /dev/null +++ b/isegm/utils/cython/_get_dist_maps.pyx @@ -0,0 +1,63 @@ +import numpy as np +cimport cython +cimport numpy as np +from libc.stdlib cimport malloc, free + +ctypedef struct qnode: + int row + int col + int layer + int orig_row + int orig_col + +@cython.infer_types(True) 
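# Pure-PyTorch sketch of the use_disks=True branch of DistMaps above: each click is encoded as
# a binary disk of radius norm_radius in its positive or negative channel.
import torch

def click_disk(row, col, radius, height, width):
    rr = torch.arange(height, dtype=torch.float32).view(-1, 1)
    cc = torch.arange(width, dtype=torch.float32).view(1, -1)
    return ((rr - row) ** 2 + (cc - col) ** 2 <= radius ** 2).float()

disk = click_disk(row=32, col=48, radius=5, height=64, width=64)
print(disk.sum())                                            # roughly pi * r**2 pixels are set to 1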
+@cython.boundscheck(False) +@cython.wraparound(False) +@cython.nonecheck(False) +def get_dist_maps(np.ndarray[np.float32_t, ndim=2, mode="c"] points, + int height, int width, float norm_delimeter): + cdef np.ndarray[np.float32_t, ndim=3, mode="c"] dist_maps = \ + np.full((2, height, width), 1e6, dtype=np.float32, order="C") + + cdef int *dxy = [-1, 0, 0, -1, 0, 1, 1, 0] + cdef int i, j, x, y, dx, dy + cdef qnode v + cdef qnode *q = malloc((4 * height * width + 1) * sizeof(qnode)) + cdef int qhead = 0, qtail = -1 + cdef float ndist + + for i in range(points.shape[0]): + x, y = round(points[i, 0]), round(points[i, 1]) + if x >= 0: + qtail += 1 + q[qtail].row = x + q[qtail].col = y + q[qtail].orig_row = x + q[qtail].orig_col = y + if i >= points.shape[0] / 2: + q[qtail].layer = 1 + else: + q[qtail].layer = 0 + dist_maps[q[qtail].layer, x, y] = 0 + + while qtail - qhead + 1 > 0: + v = q[qhead] + qhead += 1 + + for k in range(4): + x = v.row + dxy[2 * k] + y = v.col + dxy[2 * k + 1] + + ndist = ((x - v.orig_row)/norm_delimeter) ** 2 + ((y - v.orig_col)/norm_delimeter) ** 2 + if (x >= 0 and y >= 0 and x < height and y < width and + dist_maps[v.layer, x, y] > ndist): + qtail += 1 + q[qtail].orig_col = v.orig_col + q[qtail].orig_row = v.orig_row + q[qtail].layer = v.layer + q[qtail].row = x + q[qtail].col = y + dist_maps[v.layer, x, y] = ndist + + free(q) + return dist_maps diff --git a/isegm/utils/cython/_get_dist_maps.pyxbld b/isegm/utils/cython/_get_dist_maps.pyxbld new file mode 100644 index 0000000..bd44517 --- /dev/null +++ b/isegm/utils/cython/_get_dist_maps.pyxbld @@ -0,0 +1,7 @@ +import numpy + +def make_ext(modname, pyxfilename): + from distutils.extension import Extension + return Extension(modname, [pyxfilename], + include_dirs=[numpy.get_include()], + extra_compile_args=['-O3'], language='c++') diff --git a/isegm/utils/cython/dist_maps.py b/isegm/utils/cython/dist_maps.py new file mode 100644 index 0000000..8ffa1e3 --- /dev/null +++ b/isegm/utils/cython/dist_maps.py @@ -0,0 +1,3 @@ +import pyximport; pyximport.install(pyximport=True, language_level=3) +# noinspection PyUnresolvedReferences +from ._get_dist_maps import get_dist_maps \ No newline at end of file diff --git a/isegm/utils/distributed.py b/isegm/utils/distributed.py new file mode 100644 index 0000000..a1e48f5 --- /dev/null +++ b/isegm/utils/distributed.py @@ -0,0 +1,67 @@ +import torch +from torch import distributed as dist +from torch.utils import data + + +def get_rank(): + if not dist.is_available() or not dist.is_initialized(): + return 0 + return dist.get_rank() + + +def synchronize(): + if not dist.is_available() or not dist.is_initialized() or dist.get_world_size() == 1: + return + dist.barrier() + + +def get_world_size(): + if not dist.is_available() or not dist.is_initialized(): + return 1 + + return dist.get_world_size() + + +def reduce_loss_dict(loss_dict): + world_size = get_world_size() + + if world_size < 2: + return loss_dict + + with torch.no_grad(): + keys = [] + losses = [] + + for k in loss_dict.keys(): + keys.append(k) + losses.append(loss_dict[k]) + + losses = torch.stack(losses, 0) + dist.reduce(losses, dst=0) + + if dist.get_rank() == 0: + losses /= world_size + + reduced_losses = {k: v for k, v in zip(keys, losses)} + + return reduced_losses + + +def get_sampler(dataset, shuffle, distributed): + if distributed: + return data.distributed.DistributedSampler(dataset, shuffle=shuffle) + + if shuffle: + return data.RandomSampler(dataset) + else: + return data.SequentialSampler(dataset) + + +def 
get_dp_wrapper(distributed): + class DPWrapper(torch.nn.parallel.DistributedDataParallel if distributed else torch.nn.DataParallel): + def __getattr__(self, name): + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.module, name) + return DPWrapper diff --git a/isegm/utils/exp.py b/isegm/utils/exp.py new file mode 100644 index 0000000..1ff63cc --- /dev/null +++ b/isegm/utils/exp.py @@ -0,0 +1,187 @@ +import os +import sys +import shutil +import pprint +from pathlib import Path +from datetime import datetime + +import yaml +import torch +from easydict import EasyDict as edict + +from .log import logger, add_logging +from .distributed import synchronize, get_world_size + + +def init_experiment(args, model_name): + model_path = Path(args.model_path) + ftree = get_model_family_tree(model_path, model_name=model_name) + + if ftree is None: + print('Models can only be located in the "models" directory in the root of the repository') + sys.exit(1) + + cfg = load_config(model_path) + update_config(cfg, args) + + cfg.distributed = args.distributed + cfg.local_rank = args.local_rank + if cfg.distributed: + torch.distributed.init_process_group(backend='nccl', init_method='env://') + if args.workers > 0: + torch.multiprocessing.set_start_method('forkserver', force=True) + + experiments_path = Path(cfg.EXPS_PATH) + exp_parent_path = experiments_path / '/'.join(ftree) + exp_parent_path.mkdir(parents=True, exist_ok=True) + + if cfg.resume_exp: + exp_path = find_resume_exp(exp_parent_path, cfg.resume_exp) + else: + last_exp_indx = find_last_exp_indx(exp_parent_path) + exp_name = f'{last_exp_indx:03d}' + if cfg.exp_name: + exp_name += '_' + cfg.exp_name + exp_path = exp_parent_path / exp_name + synchronize() + if cfg.local_rank == 0: + exp_path.mkdir(parents=True) + + cfg.EXP_PATH = exp_path + cfg.CHECKPOINTS_PATH = exp_path / 'checkpoints' + cfg.VIS_PATH = exp_path / 'vis' + cfg.LOGS_PATH = exp_path / 'logs' + + if cfg.local_rank == 0: + cfg.LOGS_PATH.mkdir(exist_ok=True) + cfg.CHECKPOINTS_PATH.mkdir(exist_ok=True) + cfg.VIS_PATH.mkdir(exist_ok=True) + + dst_script_path = exp_path / (model_path.stem + datetime.strftime(datetime.today(), '_%Y-%m-%d-%H-%M-%S.py')) + if args.temp_model_path: + shutil.copy(args.temp_model_path, dst_script_path) + os.remove(args.temp_model_path) + else: + shutil.copy(model_path, dst_script_path) + + synchronize() + + if cfg.gpus != '': + gpu_ids = [int(id) for id in cfg.gpus.split(',')] + else: + gpu_ids = list(range(max(cfg.ngpus, get_world_size()))) + cfg.gpus = ','.join([str(id) for id in gpu_ids]) + + cfg.gpu_ids = gpu_ids + cfg.ngpus = len(gpu_ids) + cfg.multi_gpu = cfg.ngpus > 1 + + if cfg.distributed: + cfg.device = torch.device('cuda') + cfg.gpu_ids = [cfg.gpu_ids[cfg.local_rank]] + torch.cuda.set_device(cfg.gpu_ids[0]) + else: + if cfg.multi_gpu: + os.environ['CUDA_VISIBLE_DEVICES'] = cfg.gpus + ngpus = torch.cuda.device_count() + assert ngpus == cfg.ngpus + cfg.device = torch.device(f'cuda:{cfg.gpu_ids[0]}') + + if cfg.local_rank == 0: + add_logging(cfg.LOGS_PATH, prefix='train_') + logger.info(f'Number of GPUs: {cfg.ngpus}') + if cfg.distributed: + logger.info(f'Multi-Process Multi-GPU Distributed Training') + + logger.info('Run experiment with config:') + logger.info(pprint.pformat(cfg, indent=4)) + + return cfg + + +def get_model_family_tree(model_path, terminate_name='models', model_name=None): + if model_name is None: + model_name = model_path.stem + family_tree = [model_name] + for x in model_path.parents: + if x.stem == 
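# Sketch of the config-loading convention used by load_config / load_config_file below: YAML
# files are read with yaml.safe_load, an optional per-model SUBCONFIGS entry is merged in, and
# the result is wrapped in an EasyDict for attribute-style access. The model name below is
# hypothetical, for illustration only.
import yaml
from easydict import EasyDict as edict

raw = yaml.safe_load("""
EXPS_PATH: ./experiments
SUBCONFIGS:
  hrnet18_cocolvis:
    batch_size: 32
""")
model_name = 'hrnet18_cocolvis'                              # hypothetical model name
if 'SUBCONFIGS' in raw:
    if model_name in raw['SUBCONFIGS']:
        raw.update(raw['SUBCONFIGS'][model_name])
    del raw['SUBCONFIGS']
cfg = edict(raw)
print(cfg.EXPS_PATH, cfg.batch_size)                         # ./experiments 32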
terminate_name: + break + family_tree.append(x.stem) + else: + return None + + return family_tree[::-1] + + +def find_last_exp_indx(exp_parent_path): + indx = 0 + for x in exp_parent_path.iterdir(): + if not x.is_dir(): + continue + + exp_name = x.stem + if exp_name[:3].isnumeric(): + indx = max(indx, int(exp_name[:3]) + 1) + + return indx + + +def find_resume_exp(exp_parent_path, exp_pattern): + candidates = sorted(exp_parent_path.glob(f'{exp_pattern}*')) + if len(candidates) == 0: + print(f'No experiments could be found that satisfies the pattern = "*{exp_pattern}"') + sys.exit(1) + elif len(candidates) > 1: + print('More than one experiment found:') + for x in candidates: + print(x) + sys.exit(1) + else: + exp_path = candidates[0] + print(f'Continue with experiment "{exp_path}"') + + return exp_path + + +def update_config(cfg, args): + for param_name, value in vars(args).items(): + if param_name.lower() in cfg or param_name.upper() in cfg: + continue + cfg[param_name] = value + + +def load_config(model_path): + model_name = model_path.stem + config_path = model_path.parent / (model_name + '.yml') + + if config_path.exists(): + cfg = load_config_file(config_path) + else: + cfg = dict() + + cwd = Path.cwd() + config_parent = config_path.parent.absolute() + while len(config_parent.parents) > 0: + config_path = config_parent / 'config.yml' + + if config_path.exists(): + local_config = load_config_file(config_path, model_name=model_name) + cfg.update({k: v for k, v in local_config.items() if k not in cfg}) + + if config_parent.absolute() == cwd: + break + config_parent = config_parent.parent + + return edict(cfg) + + +def load_config_file(config_path, model_name=None, return_edict=False): + with open(config_path, 'r') as f: + cfg = yaml.safe_load(f) + + if 'SUBCONFIGS' in cfg: + if model_name is not None and model_name in cfg['SUBCONFIGS']: + cfg.update(cfg['SUBCONFIGS'][model_name]) + del cfg['SUBCONFIGS'] + + return edict(cfg) if return_edict else cfg diff --git a/isegm/utils/exp_imports/default.py b/isegm/utils/exp_imports/default.py new file mode 100644 index 0000000..e78e21c --- /dev/null +++ b/isegm/utils/exp_imports/default.py @@ -0,0 +1,16 @@ +import torch +from functools import partial +from easydict import EasyDict as edict +from albumentations import * + +from isegm.data.datasets import * +from isegm.model.losses import * +from isegm.data.transforms import * +from isegm.engine.trainer import ISTrainer +from isegm.model.metrics import AdaptiveIoU +from isegm.data.points_sampler import MultiPointSampler +from isegm.utils.log import logger +from isegm.model import initializer + +from isegm.model.is_hrnet_model import HRNetModel +from isegm.model.is_deeplab_model import DeeplabModel \ No newline at end of file diff --git a/isegm/utils/log.py b/isegm/utils/log.py new file mode 100644 index 0000000..1f9f8bd --- /dev/null +++ b/isegm/utils/log.py @@ -0,0 +1,97 @@ +import io +import time +import logging +from datetime import datetime + +import numpy as np +from torch.utils.tensorboard import SummaryWriter + +LOGGER_NAME = 'root' +LOGGER_DATEFMT = '%Y-%m-%d %H:%M:%S' + +handler = logging.StreamHandler() + +logger = logging.getLogger(LOGGER_NAME) +logger.setLevel(logging.INFO) +logger.addHandler(handler) + + +def add_logging(logs_path, prefix): + log_name = prefix + datetime.strftime(datetime.today(), '%Y-%m-%d_%H-%M-%S') + '.log' + stdout_log_path = logs_path / log_name + + fh = logging.FileHandler(str(stdout_log_path)) + formatter = logging.Formatter(fmt='(%(levelname)s) %(asctime)s: 
%(message)s', + datefmt=LOGGER_DATEFMT) + fh.setFormatter(formatter) + logger.addHandler(fh) + + +class TqdmToLogger(io.StringIO): + logger = None + level = None + buf = '' + + def __init__(self, logger, level=None, mininterval=5): + super(TqdmToLogger, self).__init__() + self.logger = logger + self.level = level or logging.INFO + self.mininterval = mininterval + self.last_time = 0 + + def write(self, buf): + self.buf = buf.strip('\r\n\t ') + + def flush(self): + if len(self.buf) > 0 and time.time() - self.last_time > self.mininterval: + self.logger.log(self.level, self.buf) + self.last_time = time.time() + + +class SummaryWriterAvg(SummaryWriter): + def __init__(self, *args, dump_period=20, **kwargs): + super().__init__(*args, **kwargs) + self._dump_period = dump_period + self._avg_scalars = dict() + + def add_scalar(self, tag, value, global_step=None, disable_avg=False): + if disable_avg or isinstance(value, (tuple, list, dict)): + super().add_scalar(tag, np.array(value), global_step=global_step) + else: + if tag not in self._avg_scalars: + self._avg_scalars[tag] = ScalarAccumulator(self._dump_period) + avg_scalar = self._avg_scalars[tag] + avg_scalar.add(value) + + if avg_scalar.is_full(): + super().add_scalar(tag, avg_scalar.value, + global_step=global_step) + avg_scalar.reset() + + +class ScalarAccumulator(object): + def __init__(self, period): + self.sum = 0 + self.cnt = 0 + self.period = period + + def add(self, value): + self.sum += value + self.cnt += 1 + + @property + def value(self): + if self.cnt > 0: + return self.sum / self.cnt + else: + return 0 + + def reset(self): + self.cnt = 0 + self.sum = 0 + + def is_full(self): + return self.cnt >= self.period + + def __len__(self): + return self.cnt diff --git a/isegm/utils/misc.py b/isegm/utils/misc.py new file mode 100644 index 0000000..688c11e --- /dev/null +++ b/isegm/utils/misc.py @@ -0,0 +1,86 @@ +import torch +import numpy as np + +from .log import logger + + +def get_dims_with_exclusion(dim, exclude=None): + dims = list(range(dim)) + if exclude is not None: + dims.remove(exclude) + + return dims + + +def save_checkpoint(net, checkpoints_path, epoch=None, prefix='', verbose=True, multi_gpu=False): + if epoch is None: + checkpoint_name = 'last_checkpoint.pth' + else: + checkpoint_name = f'{epoch:03d}.pth' + + if prefix: + checkpoint_name = f'{prefix}_{checkpoint_name}' + + if not checkpoints_path.exists(): + checkpoints_path.mkdir(parents=True) + + checkpoint_path = checkpoints_path / checkpoint_name + if verbose: + logger.info(f'Save checkpoint to {str(checkpoint_path)}') + + net = net.module if multi_gpu else net + torch.save({'state_dict': net.state_dict(), + 'config': net._config}, str(checkpoint_path)) + + +def get_bbox_from_mask(mask): + rows = np.any(mask, axis=1) + cols = np.any(mask, axis=0) + rmin, rmax = np.where(rows)[0][[0, -1]] + cmin, cmax = np.where(cols)[0][[0, -1]] + + return rmin, rmax, cmin, cmax + + +def expand_bbox(bbox, expand_ratio, min_crop_size=None): + rmin, rmax, cmin, cmax = bbox + rcenter = 0.5 * (rmin + rmax) + ccenter = 0.5 * (cmin + cmax) + height = expand_ratio * (rmax - rmin + 1) + width = expand_ratio * (cmax - cmin + 1) + if min_crop_size is not None: + height = max(height, min_crop_size) + width = max(width, min_crop_size) + + rmin = int(round(rcenter - 0.5 * height)) + rmax = int(round(rcenter + 0.5 * height)) + cmin = int(round(ccenter - 0.5 * width)) + cmax = int(round(ccenter + 0.5 * width)) + + return rmin, rmax, cmin, cmax + + +def clamp_bbox(bbox, rmin, rmax, cmin, cmax): + return 
(max(rmin, bbox[0]), min(rmax, bbox[1]), + max(cmin, bbox[2]), min(cmax, bbox[3])) + + +def get_bbox_iou(b1, b2): + h_iou = get_segments_iou(b1[:2], b2[:2]) + w_iou = get_segments_iou(b1[2:4], b2[2:4]) + return h_iou * w_iou + + +def get_segments_iou(s1, s2): + a, b = s1 + c, d = s2 + intersection = max(0, min(b, d) - max(a, c) + 1) + union = max(1e-6, max(b, d) - min(a, c) + 1) + return intersection / union + + +def get_labels_with_sizes(x): + obj_sizes = np.bincount(x.flatten()) + labels = np.nonzero(obj_sizes)[0].tolist() + labels = [x for x in labels if x != 0] + return labels, obj_sizes[labels].tolist() diff --git a/isegm/utils/serialization.py b/isegm/utils/serialization.py new file mode 100644 index 0000000..c73935b --- /dev/null +++ b/isegm/utils/serialization.py @@ -0,0 +1,107 @@ +from functools import wraps +from copy import deepcopy +import inspect +import torch.nn as nn + + +def serialize(init): + parameters = list(inspect.signature(init).parameters) + + @wraps(init) + def new_init(self, *args, **kwargs): + params = deepcopy(kwargs) + for pname, value in zip(parameters[1:], args): + params[pname] = value + + config = { + 'class': get_classname(self.__class__), + 'params': dict() + } + specified_params = set(params.keys()) + + for pname, param in get_default_params(self.__class__).items(): + if pname not in params: + params[pname] = param.default + + for name, value in list(params.items()): + param_type = 'builtin' + if inspect.isclass(value): + param_type = 'class' + value = get_classname(value) + + config['params'][name] = { + 'type': param_type, + 'value': value, + 'specified': name in specified_params + } + + setattr(self, '_config', config) + init(self, *args, **kwargs) + + return new_init + + +def load_model(config, **kwargs): + model_class = get_class_from_str(config['class']) + model_default_params = get_default_params(model_class) + + model_args = dict() + for pname, param in config['params'].items(): + value = param['value'] + if param['type'] == 'class': + value = get_class_from_str(value) + + if pname not in model_default_params and not param['specified']: + continue + + assert pname in model_default_params + if not param['specified'] and model_default_params[pname].default == value: + continue + model_args[pname] = value + + model_args.update(kwargs) + + return model_class(**model_args) + + +def get_config_repr(config): + config_str = f'Model: {config["class"]}\n' + for pname, param in config['params'].items(): + value = param["value"] + if param['type'] == 'class': + value = value.split('.')[-1] + param_str = f'{pname:<22} = {str(value):<12}' + if not param['specified']: + param_str += ' (default)' + config_str += param_str + '\n' + return config_str + + +def get_default_params(some_class): + params = dict() + for mclass in some_class.mro(): + if mclass is nn.Module or mclass is object: + continue + + mclass_params = inspect.signature(mclass.__init__).parameters + for pname, param in mclass_params.items(): + if param.default != param.empty and pname not in params: + params[pname] = param + + return params + + +def get_classname(cls): + module = cls.__module__ + name = cls.__qualname__ + if module is not None and module != "__builtin__": + name = module + "." 
+ name + return name + + +def get_class_from_str(class_str): + components = class_str.split('.') + mod = __import__('.'.join(components[:-1])) + for comp in components[1:]: + mod = getattr(mod, comp) + return mod diff --git a/isegm/utils/vis.py b/isegm/utils/vis.py new file mode 100644 index 0000000..9790a4c --- /dev/null +++ b/isegm/utils/vis.py @@ -0,0 +1,135 @@ +from functools import lru_cache + +import cv2 +import numpy as np + + +def visualize_instances(imask, bg_color=255, + boundaries_color=None, boundaries_width=1, boundaries_alpha=0.8): + num_objects = imask.max() + 1 + palette = get_palette(num_objects) + if bg_color is not None: + palette[0] = bg_color + + result = palette[imask].astype(np.uint8) + if boundaries_color is not None: + boundaries_mask = get_boundaries(imask, boundaries_width=boundaries_width) + tresult = result.astype(np.float32) + tresult[boundaries_mask] = boundaries_color + tresult = tresult * boundaries_alpha + (1 - boundaries_alpha) * result + result = tresult.astype(np.uint8) + + return result + + +@lru_cache(maxsize=16) +def get_palette(num_cls): + palette = np.zeros(3 * num_cls, dtype=np.int32) + + for j in range(0, num_cls): + lab = j + i = 0 + + while lab > 0: + palette[j*3 + 0] |= (((lab >> 0) & 1) << (7-i)) + palette[j*3 + 1] |= (((lab >> 1) & 1) << (7-i)) + palette[j*3 + 2] |= (((lab >> 2) & 1) << (7-i)) + i = i + 1 + lab >>= 3 + + return palette.reshape((-1, 3)) + + +def visualize_mask(mask, num_cls): + palette = get_palette(num_cls) + mask[mask == -1] = 0 + + return palette[mask].astype(np.uint8) + + +def visualize_proposals(proposals_info, point_color=(255, 0, 0), point_radius=1): + proposal_map, colors, candidates = proposals_info + + proposal_map = draw_probmap(proposal_map) + for x, y in candidates: + proposal_map = cv2.circle(proposal_map, (y, x), point_radius, point_color, -1) + + return proposal_map + + +def draw_probmap(x): + return cv2.applyColorMap((x * 255).astype(np.uint8), cv2.COLORMAP_HOT) + + +def draw_points(image, points, color, radius=3): + image = image.copy() + for p in points: + if p[0] < 0: + continue + if len(p) == 3: + pradius = {0: 8, 1: 6, 2: 4}[p[2]] if p[2] < 3 else 2 + else: + pradius = radius + image = cv2.circle(image, (int(p[1]), int(p[0])), pradius, color, -1) + + return image + + +def draw_instance_map(x, palette=None): + num_colors = x.max() + 1 + if palette is None: + palette = get_palette(num_colors) + + return palette[x].astype(np.uint8) + + +def blend_mask(image, mask, alpha=0.6): + if mask.min() == -1: + mask = mask.copy() + 1 + + imap = draw_instance_map(mask) + result = (image * (1 - alpha) + alpha * imap).astype(np.uint8) + return result + + +def get_boundaries(instances_masks, boundaries_width=1): + boundaries = np.zeros((instances_masks.shape[0], instances_masks.shape[1]), dtype=np.bool) + + for obj_id in np.unique(instances_masks.flatten()): + if obj_id == 0: + continue + + obj_mask = instances_masks == obj_id + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)) + inner_mask = cv2.erode(obj_mask.astype(np.uint8), kernel, iterations=boundaries_width).astype(np.bool) + + obj_boundary = np.logical_xor(obj_mask, np.logical_and(inner_mask, obj_mask)) + boundaries = np.logical_or(boundaries, obj_boundary) + return boundaries + + +def draw_with_blend_and_clicks(img, mask=None, alpha=0.6, clicks_list=None, pos_color=(0, 255, 0), + neg_color=(255, 0, 0), radius=4): + result = img.copy() + + if mask is not None: + palette = get_palette(np.max(mask) + 1) + rgb_mask = palette[mask.astype(np.uint8)] + + 
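# inside the mask, blend (1 - alpha) of the image with alpha of the palette colour; pixels outside the mask keep the original image +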
mask_region = (mask > 0).astype(np.uint8) + result = result * (1 - mask_region[:, :, np.newaxis]) + \ + (1 - alpha) * mask_region[:, :, np.newaxis] * result + \ + alpha * rgb_mask + result = result.astype(np.uint8) + + # result = (result * (1 - alpha) + alpha * rgb_mask).astype(np.uint8) + + if clicks_list is not None and len(clicks_list) > 0: + pos_points = [click.coords for click in clicks_list if click.is_positive] + neg_points = [click.coords for click in clicks_list if not click.is_positive] + + result = draw_points(result, pos_points, pos_color, radius=radius) + result = draw_points(result, neg_points, neg_color, radius=radius) + + return result + diff --git a/models/cocolvis_loss_ablation/hrnet18_ocr64_bce.py b/models/cocolvis_loss_ablation/hrnet18_ocr64_bce.py new file mode 100755 index 0000000..1cb1c36 --- /dev/null +++ b/models/cocolvis_loss_ablation/hrnet18_ocr64_bce.py @@ -0,0 +1,89 @@ +from isegm.utils.exp_imports.default import * +MODEL_NAME = 'hrnet18' + + +def main(cfg): + model, model_cfg = init_model(cfg) + train(model, cfg, model_cfg) + + +def init_model(cfg): + model_cfg = edict() + model_cfg.crop_size = (320, 480) + model_cfg.num_max_points = 24 + + model = HRNetModel(width=18, ocr_width=64, with_aux_output=True, use_leaky_relu=True, + use_rgb_conv=False, use_disks=True, norm_radius=5) + + model.to(cfg.device) + model.apply(initializer.XavierGluon(rnd_type='gaussian', magnitude=2.0)) + model.feature_extractor.load_pretrained_weights(cfg.IMAGENET_PRETRAINED_MODELS.HRNETV2_W18) + + return model, model_cfg + + +def train(model, cfg, model_cfg): + cfg.batch_size = 28 if cfg.batch_size < 1 else cfg.batch_size + cfg.val_batch_size = cfg.batch_size + crop_size = model_cfg.crop_size + + loss_cfg = edict() + loss_cfg.instance_loss = SigmoidBinaryCrossEntropyLoss() + loss_cfg.instance_loss_weight = 1.0 + loss_cfg.instance_aux_loss = SigmoidBinaryCrossEntropyLoss() + loss_cfg.instance_aux_loss_weight = 0.4 + + train_augmentator = Compose([ + UniformRandomResize(scale_range=(0.75, 1.40)), + HorizontalFlip(), + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size), + RandomBrightnessContrast(brightness_limit=(-0.25, 0.25), contrast_limit=(-0.15, 0.4), p=0.75), + RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.75) + ], p=1.0) + + val_augmentator = Compose([ + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size) + ], p=1.0) + + points_sampler = MultiPointSampler(model_cfg.num_max_points, prob_gamma=0.8, + merge_objects_prob=0.15, + max_num_merged_objects=2) + + trainset = CocoLvisDataset( + cfg.LVIS_PATH, + split='train', + augmentator=train_augmentator, + min_object_area=80, + keep_background_prob=0.0, + points_sampler=points_sampler, + epoch_len=30000, + stuff_prob=0.30 + ) + + valset = CocoLvisDataset( + cfg.LVIS_PATH, + split='val', + augmentator=val_augmentator, + min_object_area=80, + points_sampler=points_sampler, + epoch_len=2000 + ) + + optimizer_params = { + 'lr': 5e-4, 'betas': (0.9, 0.999), 'eps': 1e-8 + } + + lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR, + milestones=[140, 155], gamma=0.1) + trainer = ISTrainer(model, cfg, model_cfg, loss_cfg, + trainset, valset, + optimizer='adam', + optimizer_params=optimizer_params, + lr_scheduler=lr_scheduler, + checkpoint_interval=5, + image_dump_interval=2000, + metrics=[AdaptiveIoU()], + max_interactive_points=model_cfg.num_max_points) + trainer.run(num_epochs=160) diff --git 
a/models/cocolvis_loss_ablation/hrnet18_ocr64_fl.py b/models/cocolvis_loss_ablation/hrnet18_ocr64_fl.py new file mode 100755 index 0000000..5466a27 --- /dev/null +++ b/models/cocolvis_loss_ablation/hrnet18_ocr64_fl.py @@ -0,0 +1,89 @@ +from isegm.utils.exp_imports.default import * +MODEL_NAME = 'hrnet18' + + +def main(cfg): + model, model_cfg = init_model(cfg) + train(model, cfg, model_cfg) + + +def init_model(cfg): + model_cfg = edict() + model_cfg.crop_size = (320, 480) + model_cfg.num_max_points = 24 + + model = HRNetModel(width=18, ocr_width=64, with_aux_output=True, use_leaky_relu=True, + use_rgb_conv=False, use_disks=True, norm_radius=5) + + model.to(cfg.device) + model.apply(initializer.XavierGluon(rnd_type='gaussian', magnitude=2.0)) + model.feature_extractor.load_pretrained_weights(cfg.IMAGENET_PRETRAINED_MODELS.HRNETV2_W18) + + return model, model_cfg + + +def train(model, cfg, model_cfg): + cfg.batch_size = 28 if cfg.batch_size < 1 else cfg.batch_size + cfg.val_batch_size = cfg.batch_size + crop_size = model_cfg.crop_size + + loss_cfg = edict() + loss_cfg.instance_loss = FocalLoss(alpha=0.5, gamma=2) + loss_cfg.instance_loss_weight = 1.0 + loss_cfg.instance_aux_loss = SigmoidBinaryCrossEntropyLoss() + loss_cfg.instance_aux_loss_weight = 0.4 + + train_augmentator = Compose([ + UniformRandomResize(scale_range=(0.75, 1.40)), + HorizontalFlip(), + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size), + RandomBrightnessContrast(brightness_limit=(-0.25, 0.25), contrast_limit=(-0.15, 0.4), p=0.75), + RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.75) + ], p=1.0) + + val_augmentator = Compose([ + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size) + ], p=1.0) + + points_sampler = MultiPointSampler(model_cfg.num_max_points, prob_gamma=0.8, + merge_objects_prob=0.15, + max_num_merged_objects=2) + + trainset = CocoLvisDataset( + cfg.LVIS_PATH, + split='train', + augmentator=train_augmentator, + min_object_area=80, + keep_background_prob=0.0, + points_sampler=points_sampler, + epoch_len=30000, + stuff_prob=0.30 + ) + + valset = CocoLvisDataset( + cfg.LVIS_PATH, + split='val', + augmentator=val_augmentator, + min_object_area=80, + points_sampler=points_sampler, + epoch_len=2000 + ) + + optimizer_params = { + 'lr': 5e-4, 'betas': (0.9, 0.999), 'eps': 1e-8 + } + + lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR, + milestones=[140, 155], gamma=0.1) + trainer = ISTrainer(model, cfg, model_cfg, loss_cfg, + trainset, valset, + optimizer='adam', + optimizer_params=optimizer_params, + lr_scheduler=lr_scheduler, + checkpoint_interval=5, + image_dump_interval=2000, + metrics=[AdaptiveIoU()], + max_interactive_points=model_cfg.num_max_points) + trainer.run(num_epochs=160) diff --git a/models/cocolvis_loss_ablation/hrnet18_ocr64_nfl.py b/models/cocolvis_loss_ablation/hrnet18_ocr64_nfl.py new file mode 100755 index 0000000..a8179ec --- /dev/null +++ b/models/cocolvis_loss_ablation/hrnet18_ocr64_nfl.py @@ -0,0 +1,89 @@ +from isegm.utils.exp_imports.default import * +MODEL_NAME = 'hrnet18' + + +def main(cfg): + model, model_cfg = init_model(cfg) + train(model, cfg, model_cfg) + + +def init_model(cfg): + model_cfg = edict() + model_cfg.crop_size = (320, 480) + model_cfg.num_max_points = 24 + + model = HRNetModel(width=18, ocr_width=64, with_aux_output=True, use_leaky_relu=True, + use_rgb_conv=False, use_disks=True, norm_radius=5) + + model.to(cfg.device) + 
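# Xavier (Gaussian) initialization for all layers, then ImageNet-pretrained HRNetV2-W18 weights are loaded into the backbone +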
model.apply(initializer.XavierGluon(rnd_type='gaussian', magnitude=2.0)) + model.feature_extractor.load_pretrained_weights(cfg.IMAGENET_PRETRAINED_MODELS.HRNETV2_W18) + + return model, model_cfg + + +def train(model, cfg, model_cfg): + cfg.batch_size = 28 if cfg.batch_size < 1 else cfg.batch_size + cfg.val_batch_size = cfg.batch_size + crop_size = model_cfg.crop_size + + loss_cfg = edict() + loss_cfg.instance_loss = NormalizedFocalLossSigmoid(alpha=0.5, gamma=2) + loss_cfg.instance_loss_weight = 1.0 + loss_cfg.instance_aux_loss = SigmoidBinaryCrossEntropyLoss() + loss_cfg.instance_aux_loss_weight = 0.4 + + train_augmentator = Compose([ + UniformRandomResize(scale_range=(0.75, 1.40)), + HorizontalFlip(), + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size), + RandomBrightnessContrast(brightness_limit=(-0.25, 0.25), contrast_limit=(-0.15, 0.4), p=0.75), + RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.75) + ], p=1.0) + + val_augmentator = Compose([ + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size) + ], p=1.0) + + points_sampler = MultiPointSampler(model_cfg.num_max_points, prob_gamma=0.8, + merge_objects_prob=0.15, + max_num_merged_objects=2) + + trainset = CocoLvisDataset( + cfg.LVIS_PATH, + split='train', + augmentator=train_augmentator, + min_object_area=80, + keep_background_prob=0.0, + points_sampler=points_sampler, + epoch_len=30000, + stuff_prob=0.30 + ) + + valset = CocoLvisDataset( + cfg.LVIS_PATH, + split='val', + augmentator=val_augmentator, + min_object_area=80, + points_sampler=points_sampler, + epoch_len=2000 + ) + + optimizer_params = { + 'lr': 5e-4, 'betas': (0.9, 0.999), 'eps': 1e-8 + } + + lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR, + milestones=[140, 155], gamma=0.1) + trainer = ISTrainer(model, cfg, model_cfg, loss_cfg, + trainset, valset, + optimizer='adam', + optimizer_params=optimizer_params, + lr_scheduler=lr_scheduler, + checkpoint_interval=5, + image_dump_interval=2000, + metrics=[AdaptiveIoU()], + max_interactive_points=model_cfg.num_max_points) + trainer.run(num_epochs=160) diff --git a/models/cocolvis_loss_ablation/hrnet18_ocr64_softiou.py b/models/cocolvis_loss_ablation/hrnet18_ocr64_softiou.py new file mode 100755 index 0000000..5568bb2 --- /dev/null +++ b/models/cocolvis_loss_ablation/hrnet18_ocr64_softiou.py @@ -0,0 +1,89 @@ +from isegm.utils.exp_imports.default import * +MODEL_NAME = 'hrnet18' + + +def main(cfg): + model, model_cfg = init_model(cfg) + train(model, cfg, model_cfg) + + +def init_model(cfg): + model_cfg = edict() + model_cfg.crop_size = (320, 480) + model_cfg.num_max_points = 24 + + model = HRNetModel(width=18, ocr_width=64, with_aux_output=True, use_leaky_relu=True, + use_rgb_conv=False, use_disks=True, norm_radius=5) + + model.to(cfg.device) + model.apply(initializer.XavierGluon(rnd_type='gaussian', magnitude=2.0)) + model.feature_extractor.load_pretrained_weights(cfg.IMAGENET_PRETRAINED_MODELS.HRNETV2_W18) + + return model, model_cfg + + +def train(model, cfg, model_cfg): + cfg.batch_size = 28 if cfg.batch_size < 1 else cfg.batch_size + cfg.val_batch_size = cfg.batch_size + crop_size = model_cfg.crop_size + + loss_cfg = edict() + loss_cfg.instance_loss = SoftIoU() + loss_cfg.instance_loss_weight = 1.0 + loss_cfg.instance_aux_loss = SigmoidBinaryCrossEntropyLoss() + loss_cfg.instance_aux_loss_weight = 0.4 + + train_augmentator = Compose([ + UniformRandomResize(scale_range=(0.75, 1.40)), + 
HorizontalFlip(), + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size), + RandomBrightnessContrast(brightness_limit=(-0.25, 0.25), contrast_limit=(-0.15, 0.4), p=0.75), + RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.75) + ], p=1.0) + + val_augmentator = Compose([ + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size) + ], p=1.0) + + points_sampler = MultiPointSampler(model_cfg.num_max_points, prob_gamma=0.8, + merge_objects_prob=0.15, + max_num_merged_objects=2) + + trainset = CocoLvisDataset( + cfg.LVIS_PATH, + split='train', + augmentator=train_augmentator, + min_object_area=80, + keep_background_prob=0.0, + points_sampler=points_sampler, + epoch_len=30000, + stuff_prob=0.30 + ) + + valset = CocoLvisDataset( + cfg.LVIS_PATH, + split='val', + augmentator=val_augmentator, + min_object_area=80, + points_sampler=points_sampler, + epoch_len=2000 + ) + + optimizer_params = { + 'lr': 5e-4, 'betas': (0.9, 0.999), 'eps': 1e-8 + } + + lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR, + milestones=[140, 155], gamma=0.1) + trainer = ISTrainer(model, cfg, model_cfg, loss_cfg, + trainset, valset, + optimizer='adam', + optimizer_params=optimizer_params, + lr_scheduler=lr_scheduler, + checkpoint_interval=5, + image_dump_interval=2000, + metrics=[AdaptiveIoU()], + max_interactive_points=model_cfg.num_max_points) + trainer.run(num_epochs=160) diff --git a/models/iter_mask/hrnet18_cocolvis_itermask_3p.py b/models/iter_mask/hrnet18_cocolvis_itermask_3p.py new file mode 100644 index 0000000..6166a80 --- /dev/null +++ b/models/iter_mask/hrnet18_cocolvis_itermask_3p.py @@ -0,0 +1,91 @@ +from isegm.utils.exp_imports.default import * +MODEL_NAME = 'cocolvis_hrnet18' + + +def main(cfg): + model, model_cfg = init_model(cfg) + train(model, cfg, model_cfg) + + +def init_model(cfg): + model_cfg = edict() + model_cfg.crop_size = (320, 480) + model_cfg.num_max_points = 24 + + model = HRNetModel(width=18, ocr_width=64, with_aux_output=True, use_leaky_relu=True, + use_rgb_conv=False, use_disks=True, norm_radius=5, + with_prev_mask=True) + + model.to(cfg.device) + model.apply(initializer.XavierGluon(rnd_type='gaussian', magnitude=2.0)) + model.feature_extractor.load_pretrained_weights(cfg.IMAGENET_PRETRAINED_MODELS.HRNETV2_W18) + + return model, model_cfg + + +def train(model, cfg, model_cfg): + cfg.batch_size = 28 if cfg.batch_size < 1 else cfg.batch_size + cfg.val_batch_size = cfg.batch_size + crop_size = model_cfg.crop_size + + loss_cfg = edict() + loss_cfg.instance_loss = NormalizedFocalLossSigmoid(alpha=0.5, gamma=2) + loss_cfg.instance_loss_weight = 1.0 + loss_cfg.instance_aux_loss = SigmoidBinaryCrossEntropyLoss() + loss_cfg.instance_aux_loss_weight = 0.4 + + train_augmentator = Compose([ + UniformRandomResize(scale_range=(0.75, 1.40)), + HorizontalFlip(), + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size), + RandomBrightnessContrast(brightness_limit=(-0.25, 0.25), contrast_limit=(-0.15, 0.4), p=0.75), + RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.75) + ], p=1.0) + + val_augmentator = Compose([ + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size) + ], p=1.0) + + points_sampler = MultiPointSampler(model_cfg.num_max_points, prob_gamma=0.80, + merge_objects_prob=0.15, + max_num_merged_objects=2) + + trainset = CocoLvisDataset( + cfg.LVIS_v1_PATH, + split='train', 
+ augmentator=train_augmentator, + min_object_area=1000, + keep_background_prob=0.05, + points_sampler=points_sampler, + epoch_len=30000, + stuff_prob=0.30 + ) + + valset = CocoLvisDataset( + cfg.LVIS_v1_PATH, + split='val', + augmentator=val_augmentator, + min_object_area=1000, + points_sampler=points_sampler, + epoch_len=2000 + ) + + optimizer_params = { + 'lr': 5e-4, 'betas': (0.9, 0.999), 'eps': 1e-8 + } + + lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR, + milestones=[200, 220], gamma=0.1) + trainer = ISTrainer(model, cfg, model_cfg, loss_cfg, + trainset, valset, + optimizer='adam', + optimizer_params=optimizer_params, + lr_scheduler=lr_scheduler, + checkpoint_interval=[(0, 5), (200, 1)], + image_dump_interval=3000, + metrics=[AdaptiveIoU()], + max_interactive_points=model_cfg.num_max_points, + max_num_next_clicks=3) + trainer.run(num_epochs=230) \ No newline at end of file diff --git a/models/iter_mask/hrnet18_sbd_itermask_3p.py b/models/iter_mask/hrnet18_sbd_itermask_3p.py new file mode 100644 index 0000000..31fc7ec --- /dev/null +++ b/models/iter_mask/hrnet18_sbd_itermask_3p.py @@ -0,0 +1,94 @@ +from isegm.utils.exp_imports.default import * +MODEL_NAME = 'sbd_hrnet18' + + +def main(cfg): + model, model_cfg = init_model(cfg) + train(model, cfg, model_cfg) + + +def init_model(cfg): + model_cfg = edict() + model_cfg.crop_size = (320, 480) + model_cfg.num_max_points = 24 + + model = HRNetModel(width=18, ocr_width=64, with_aux_output=True, use_leaky_relu=True, + use_rgb_conv=False, use_disks=True, norm_radius=5, with_prev_mask=True) + + model.to(cfg.device) + model.apply(initializer.XavierGluon(rnd_type='gaussian', magnitude=2.0)) + model.feature_extractor.load_pretrained_weights(cfg.IMAGENET_PRETRAINED_MODELS.HRNETV2_W18) + + return model, model_cfg + + +def train(model, cfg, model_cfg): + cfg.batch_size = 28 if cfg.batch_size < 1 else cfg.batch_size + cfg.val_batch_size = cfg.batch_size + crop_size = model_cfg.crop_size + + loss_cfg = edict() + loss_cfg.instance_loss = NormalizedFocalLossSigmoid(alpha=0.5, gamma=2) + loss_cfg.instance_loss_weight = 1.0 + loss_cfg.instance_aux_loss = SigmoidBinaryCrossEntropyLoss() + loss_cfg.instance_aux_loss_weight = 0.4 + + train_augmentator = Compose([ + UniformRandomResize(scale_range=(0.75, 1.25)), + Flip(), + RandomRotate90(), + ShiftScaleRotate(shift_limit=0.03, scale_limit=0, + rotate_limit=(-3, 3), border_mode=0, p=0.75), + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size), + RandomBrightnessContrast(brightness_limit=(-0.25, 0.25), contrast_limit=(-0.15, 0.4), p=0.75), + RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.75) + ], p=1.0) + + val_augmentator = Compose([ + UniformRandomResize(scale_range=(0.75, 1.25)), + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size) + ], p=1.0) + + points_sampler = MultiPointSampler(model_cfg.num_max_points, prob_gamma=0.80, + merge_objects_prob=0.15, + max_num_merged_objects=2) + + trainset = SBDDataset( + cfg.SBD_PATH, + split='train', + augmentator=train_augmentator, + min_object_area=80, + keep_background_prob=0.01, + points_sampler=points_sampler, + samples_scores_path='./assets/sbd_samples_weights.pkl', + samples_scores_gamma=1.25 + ) + + valset = SBDDataset( + cfg.SBD_PATH, + split='val', + augmentator=val_augmentator, + min_object_area=80, + points_sampler=points_sampler, + epoch_len=500 + ) + + optimizer_params = { + 'lr': 5e-4, 'betas': (0.9, 0.999), 'eps': 1e-8 + } + 
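# MultiStepLR schedule: the learning rate is decayed by 10x at epochs 200 and 215; this SBD model is trained for 220 epochs in total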
+ lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR, + milestones=[200, 215], gamma=0.1) + trainer = ISTrainer(model, cfg, model_cfg, loss_cfg, + trainset, valset, + optimizer='adam', + optimizer_params=optimizer_params, + lr_scheduler=lr_scheduler, + checkpoint_interval=[(0, 5), (100, 1)], + image_dump_interval=3000, + metrics=[AdaptiveIoU()], + max_interactive_points=model_cfg.num_max_points, + max_num_next_clicks=3) + trainer.run(num_epochs=220) diff --git a/models/iter_mask/hrnet18s_cocolvis_itermask_3p.py b/models/iter_mask/hrnet18s_cocolvis_itermask_3p.py new file mode 100644 index 0000000..34b1e3e --- /dev/null +++ b/models/iter_mask/hrnet18s_cocolvis_itermask_3p.py @@ -0,0 +1,91 @@ +from isegm.utils.exp_imports.default import * +MODEL_NAME = 'cocolvis_hrnet18s' + + +def main(cfg): + model, model_cfg = init_model(cfg) + train(model, cfg, model_cfg) + + +def init_model(cfg): + model_cfg = edict() + model_cfg.crop_size = (320, 480) + model_cfg.num_max_points = 24 + + model = HRNetModel(width=18, ocr_width=48, small=True, with_aux_output=True, use_leaky_relu=True, + use_rgb_conv=False, use_disks=True, norm_radius=5, + with_prev_mask=True) + + model.to(cfg.device) + model.apply(initializer.XavierGluon(rnd_type='gaussian', magnitude=2.0)) + model.feature_extractor.load_pretrained_weights(cfg.IMAGENET_PRETRAINED_MODELS.HRNETV2_W18_SMALL) + + return model, model_cfg + + +def train(model, cfg, model_cfg): + cfg.batch_size = 32 if cfg.batch_size < 1 else cfg.batch_size + cfg.val_batch_size = cfg.batch_size + crop_size = model_cfg.crop_size + + loss_cfg = edict() + loss_cfg.instance_loss = NormalizedFocalLossSigmoid(alpha=0.5, gamma=2) + loss_cfg.instance_loss_weight = 1.0 + loss_cfg.instance_aux_loss = SigmoidBinaryCrossEntropyLoss() + loss_cfg.instance_aux_loss_weight = 0.4 + + train_augmentator = Compose([ + UniformRandomResize(scale_range=(0.75, 1.40)), + HorizontalFlip(), + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size), + RandomBrightnessContrast(brightness_limit=(-0.25, 0.25), contrast_limit=(-0.15, 0.4), p=0.75), + RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.75) + ], p=1.0) + + val_augmentator = Compose([ + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size) + ], p=1.0) + + points_sampler = MultiPointSampler(model_cfg.num_max_points, prob_gamma=0.80, + merge_objects_prob=0.15, + max_num_merged_objects=2) + + trainset = CocoLvisDataset( + cfg.LVIS_v1_PATH, + split='train', + augmentator=train_augmentator, + min_object_area=1000, + keep_background_prob=0.05, + points_sampler=points_sampler, + epoch_len=30000, + stuff_prob=0.30 + ) + + valset = CocoLvisDataset( + cfg.LVIS_v1_PATH, + split='val', + augmentator=val_augmentator, + min_object_area=1000, + points_sampler=points_sampler, + epoch_len=2000 + ) + + optimizer_params = { + 'lr': 5e-4, 'betas': (0.9, 0.999), 'eps': 1e-8 + } + + lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR, + milestones=[200, 220], gamma=0.1) + trainer = ISTrainer(model, cfg, model_cfg, loss_cfg, + trainset, valset, + optimizer='adam', + optimizer_params=optimizer_params, + lr_scheduler=lr_scheduler, + checkpoint_interval=[(0, 5), (200, 1)], + image_dump_interval=3000, + metrics=[AdaptiveIoU()], + max_interactive_points=model_cfg.num_max_points, + max_num_next_clicks=3) + trainer.run(num_epochs=230) \ No newline at end of file diff --git a/models/iter_mask/hrnet32_cocolvis_itermask_3p.py 
b/models/iter_mask/hrnet32_cocolvis_itermask_3p.py new file mode 100644 index 0000000..16253e2 --- /dev/null +++ b/models/iter_mask/hrnet32_cocolvis_itermask_3p.py @@ -0,0 +1,90 @@ +from isegm.utils.exp_imports.default import * +MODEL_NAME = 'cocolvis_hrnet32' + + +def main(cfg): + model, model_cfg = init_model(cfg) + train(model, cfg, model_cfg) + + +def init_model(cfg): + model_cfg = edict() + model_cfg.crop_size = (320, 480) + model_cfg.num_max_points = 24 + + model = HRNetModel(width=32, ocr_width=128, with_aux_output=True, use_leaky_relu=True, + use_rgb_conv=False, use_disks=True, norm_radius=5, with_prev_mask=True) + + model.to(cfg.device) + model.apply(initializer.XavierGluon(rnd_type='gaussian', magnitude=2.0)) + model.feature_extractor.load_pretrained_weights(cfg.IMAGENET_PRETRAINED_MODELS.HRNETV2_W32) + + return model, model_cfg + + +def train(model, cfg, model_cfg): + cfg.batch_size = 32 if cfg.batch_size < 1 else cfg.batch_size + cfg.val_batch_size = cfg.batch_size + crop_size = model_cfg.crop_size + + loss_cfg = edict() + loss_cfg.instance_loss = NormalizedFocalLossSigmoid(alpha=0.5, gamma=2) + loss_cfg.instance_loss_weight = 1.0 + loss_cfg.instance_aux_loss = SigmoidBinaryCrossEntropyLoss() + loss_cfg.instance_aux_loss_weight = 0.4 + + train_augmentator = Compose([ + UniformRandomResize(scale_range=(0.75, 1.40)), + HorizontalFlip(), + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size), + RandomBrightnessContrast(brightness_limit=(-0.25, 0.25), contrast_limit=(-0.15, 0.4), p=0.75), + RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.75) + ], p=1.0) + + val_augmentator = Compose([ + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size) + ], p=1.0) + + points_sampler = MultiPointSampler(model_cfg.num_max_points, prob_gamma=0.80, + merge_objects_prob=0.15, + max_num_merged_objects=2) + + trainset = CocoLvisDataset( + cfg.LVIS_v1_PATH, + split='train', + augmentator=train_augmentator, + min_object_area=1000, + keep_background_prob=0.05, + points_sampler=points_sampler, + epoch_len=30000, + stuff_prob=0.30 + ) + + valset = CocoLvisDataset( + cfg.LVIS_v1_PATH, + split='val', + augmentator=val_augmentator, + min_object_area=1000, + points_sampler=points_sampler, + epoch_len=2000 + ) + + optimizer_params = { + 'lr': 5e-4, 'betas': (0.9, 0.999), 'eps': 1e-8 + } + + lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR, + milestones=[200, 220], gamma=0.1) + trainer = ISTrainer(model, cfg, model_cfg, loss_cfg, + trainset, valset, + optimizer='adam', + optimizer_params=optimizer_params, + lr_scheduler=lr_scheduler, + checkpoint_interval=[(0, 5), (200, 1)], + image_dump_interval=3000, + metrics=[AdaptiveIoU()], + max_interactive_points=model_cfg.num_max_points, + max_num_next_clicks=3) + trainer.run(num_epochs=230) diff --git a/models/noniterative_baselines/hrnet18_ocr64_ade20k.py b/models/noniterative_baselines/hrnet18_ocr64_ade20k.py new file mode 100755 index 0000000..d03f57e --- /dev/null +++ b/models/noniterative_baselines/hrnet18_ocr64_ade20k.py @@ -0,0 +1,90 @@ +from isegm.utils.exp_imports.default import * +MODEL_NAME = 'hrnet18' + + +def main(cfg): + model, model_cfg = init_model(cfg) + train(model, cfg, model_cfg) + + +def init_model(cfg): + model_cfg = edict() + model_cfg.crop_size = (320, 480) + model_cfg.num_max_points = 24 + + model = HRNetModel(width=18, ocr_width=64, with_aux_output=True, use_leaky_relu=True, + use_rgb_conv=False, use_disks=True, 
norm_radius=5) + + model.to(cfg.device) + model.apply(initializer.XavierGluon(rnd_type='gaussian', magnitude=2.0)) + model.feature_extractor.load_pretrained_weights(cfg.IMAGENET_PRETRAINED_MODELS.HRNETV2_W18) + + return model, model_cfg + + +def train(model, cfg, model_cfg): + cfg.batch_size = 28 if cfg.batch_size < 1 else cfg.batch_size + cfg.val_batch_size = cfg.batch_size + crop_size = model_cfg.crop_size + + loss_cfg = edict() + loss_cfg.instance_loss = NormalizedFocalLossSigmoid(alpha=0.5, gamma=2) + loss_cfg.instance_loss_weight = 1.0 + loss_cfg.instance_aux_loss = SigmoidBinaryCrossEntropyLoss() + loss_cfg.instance_aux_loss_weight = 0.4 + + train_augmentator = Compose([ + UniformRandomResize(scale_range=(0.75, 1.40)), + HorizontalFlip(), + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size), + RandomBrightnessContrast(brightness_limit=(-0.25, 0.25), contrast_limit=(-0.15, 0.4), p=0.75), + RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.75) + ], p=1.0) + + val_augmentator = Compose([ + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size) + ], p=1.0) + + points_sampler = MultiPointSampler(model_cfg.num_max_points, prob_gamma=0.80, + merge_objects_prob=0.15, + max_num_merged_objects=2) + + trainset = ADE20kDataset( + cfg.ADE20K_PATH, + split='train', + augmentator=train_augmentator, + min_object_area=80, + stuff_prob=0.30, + keep_background_prob=0.05, + points_sampler=points_sampler, + epoch_len=30000 + ) + + valset = ADE20kDataset( + cfg.ADE20K_PATH, + split='val', + augmentator=val_augmentator, + min_object_area=80, + keep_background_prob=0.05, + points_sampler=points_sampler, + epoch_len=2000 + ) + + optimizer_params = { + 'lr': 5e-4, 'betas': (0.9, 0.999), 'eps': 1e-8 + } + + lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR, + milestones=[100, 115], gamma=0.1) + trainer = ISTrainer(model, cfg, model_cfg, loss_cfg, + trainset, valset, + optimizer='adam', + optimizer_params=optimizer_params, + lr_scheduler=lr_scheduler, + checkpoint_interval=5, + image_dump_interval=2000, + metrics=[AdaptiveIoU()], + max_interactive_points=model_cfg.num_max_points) + trainer.run(num_epochs=120) diff --git a/models/noniterative_baselines/hrnet18_ocr64_coco.py b/models/noniterative_baselines/hrnet18_ocr64_coco.py new file mode 100755 index 0000000..639a9f7 --- /dev/null +++ b/models/noniterative_baselines/hrnet18_ocr64_coco.py @@ -0,0 +1,90 @@ +from isegm.utils.exp_imports.default import * +MODEL_NAME = 'hrnet18' + + +def main(cfg): + model, model_cfg = init_model(cfg) + train(model, cfg, model_cfg) + + +def init_model(cfg): + model_cfg = edict() + model_cfg.crop_size = (320, 480) + model_cfg.num_max_points = 24 + + model = HRNetModel(width=18, ocr_width=64, with_aux_output=True, use_leaky_relu=True, + use_rgb_conv=False, use_disks=True, norm_radius=5) + + model.to(cfg.device) + model.apply(initializer.XavierGluon(rnd_type='gaussian', magnitude=2.0)) + model.feature_extractor.load_pretrained_weights(cfg.IMAGENET_PRETRAINED_MODELS.HRNETV2_W18) + + return model, model_cfg + + +def train(model, cfg, model_cfg): + cfg.batch_size = 28 if cfg.batch_size < 1 else cfg.batch_size + cfg.val_batch_size = cfg.batch_size + crop_size = model_cfg.crop_size + + loss_cfg = edict() + loss_cfg.instance_loss = NormalizedFocalLossSigmoid(alpha=0.5, gamma=2) + loss_cfg.instance_loss_weight = 1.0 + loss_cfg.instance_aux_loss = SigmoidBinaryCrossEntropyLoss() + loss_cfg.instance_aux_loss_weight = 
0.4 + + train_augmentator = Compose([ + UniformRandomResize(scale_range=(0.75, 1.40)), + HorizontalFlip(), + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size), + RandomBrightnessContrast(brightness_limit=(-0.25, 0.25), contrast_limit=(-0.15, 0.4), p=0.75), + RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.75) + ], p=1.0) + + val_augmentator = Compose([ + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size) + ], p=1.0) + + points_sampler = MultiPointSampler(model_cfg.num_max_points, prob_gamma=0.8, + merge_objects_prob=0.15, + max_num_merged_objects=2) + + trainset = CocoDataset( + cfg.COCO_PATH, + split='train', + augmentator=train_augmentator, + min_object_area=1000, + keep_background_prob=0.05, + points_sampler=points_sampler, + stuff_prob=0.30, + epoch_len=30000 + ) + + valset = CocoDataset( + cfg.COCO_PATH, + split='val', + augmentator=val_augmentator, + min_object_area=1000, + keep_background_prob=0.05, + points_sampler=points_sampler, + epoch_len=2000 + ) + + optimizer_params = { + 'lr': 5e-4, 'betas': (0.9, 0.999), 'eps': 1e-8 + } + + lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR, + milestones=[120, 135], gamma=0.1) + trainer = ISTrainer(model, cfg, model_cfg, loss_cfg, + trainset, valset, + optimizer='adam', + optimizer_params=optimizer_params, + lr_scheduler=lr_scheduler, + checkpoint_interval=5, + image_dump_interval=2000, + metrics=[AdaptiveIoU()], + max_interactive_points=model_cfg.num_max_points) + trainer.run(num_epochs=140) diff --git a/models/noniterative_baselines/hrnet18_ocr64_cocolvis.py b/models/noniterative_baselines/hrnet18_ocr64_cocolvis.py new file mode 100755 index 0000000..e2bc28d --- /dev/null +++ b/models/noniterative_baselines/hrnet18_ocr64_cocolvis.py @@ -0,0 +1,89 @@ +from isegm.utils.exp_imports.default import * +MODEL_NAME = 'hrnet18' + + +def main(cfg): + model, model_cfg = init_model(cfg) + train(model, cfg, model_cfg) + + +def init_model(cfg): + model_cfg = edict() + model_cfg.crop_size = (320, 480) + model_cfg.num_max_points = 24 + + model = HRNetModel(width=18, ocr_width=64, with_aux_output=True, use_leaky_relu=True, + use_rgb_conv=False, use_disks=True, norm_radius=5) + + model.to(cfg.device) + model.apply(initializer.XavierGluon(rnd_type='gaussian', magnitude=2.0)) + model.feature_extractor.load_pretrained_weights(cfg.IMAGENET_PRETRAINED_MODELS.HRNETV2_W18) + + return model, model_cfg + + +def train(model, cfg, model_cfg): + cfg.batch_size = 28 if cfg.batch_size < 1 else cfg.batch_size + cfg.val_batch_size = cfg.batch_size + crop_size = model_cfg.crop_size + + loss_cfg = edict() + loss_cfg.instance_loss = NormalizedFocalLossSigmoid(alpha=0.5, gamma=2) + loss_cfg.instance_loss_weight = 1.0 + loss_cfg.instance_aux_loss = SigmoidBinaryCrossEntropyLoss() + loss_cfg.instance_aux_loss_weight = 0.4 + + train_augmentator = Compose([ + UniformRandomResize(scale_range=(0.75, 1.40)), + HorizontalFlip(), + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size), + RandomBrightnessContrast(brightness_limit=(-0.25, 0.25), contrast_limit=(-0.15, 0.4), p=0.75), + RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.75) + ], p=1.0) + + val_augmentator = Compose([ + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size) + ], p=1.0) + + points_sampler = MultiPointSampler(model_cfg.num_max_points, prob_gamma=0.8, + 
merge_objects_prob=0.15, + max_num_merged_objects=2) + + trainset = CocoLvisDataset( + cfg.LVIS_PATH, + split='train', + augmentator=train_augmentator, + min_object_area=1000, + keep_background_prob=0.05, + points_sampler=points_sampler, + epoch_len=30000, + stuff_prob=0.30 + ) + + valset = CocoLvisDataset( + cfg.LVIS_PATH, + split='val', + augmentator=val_augmentator, + min_object_area=1000, + points_sampler=points_sampler, + epoch_len=2000 + ) + + optimizer_params = { + 'lr': 5e-4, 'betas': (0.9, 0.999), 'eps': 1e-8 + } + + lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR, + milestones=[200, 220], gamma=0.1) + trainer = ISTrainer(model, cfg, model_cfg, loss_cfg, + trainset, valset, + optimizer='adam', + optimizer_params=optimizer_params, + lr_scheduler=lr_scheduler, + checkpoint_interval=[(0, 5), (200, 1)], + image_dump_interval=3000, + metrics=[AdaptiveIoU()], + max_interactive_points=model_cfg.num_max_points) + trainer.run(num_epochs=230) diff --git a/models/noniterative_baselines/hrnet18_ocr64_lvis.py b/models/noniterative_baselines/hrnet18_ocr64_lvis.py new file mode 100755 index 0000000..d8ca8b0 --- /dev/null +++ b/models/noniterative_baselines/hrnet18_ocr64_lvis.py @@ -0,0 +1,88 @@ +from isegm.utils.exp_imports.default import * +MODEL_NAME = 'hrnet18' + + +def main(cfg): + model, model_cfg = init_model(cfg) + train(model, cfg, model_cfg) + + +def init_model(cfg): + model_cfg = edict() + model_cfg.crop_size = (320, 480) + model_cfg.num_max_points = 24 + + model = HRNetModel(width=18, ocr_width=64, with_aux_output=True, use_leaky_relu=True, + use_rgb_conv=False, use_disks=True, norm_radius=5) + + model.to(cfg.device) + model.apply(initializer.XavierGluon(rnd_type='gaussian', magnitude=2.0)) + model.feature_extractor.load_pretrained_weights(cfg.IMAGENET_PRETRAINED_MODELS.HRNETV2_W18) + + return model, model_cfg + + +def train(model, cfg, model_cfg): + cfg.batch_size = 28 if cfg.batch_size < 1 else cfg.batch_size + cfg.val_batch_size = cfg.batch_size + crop_size = model_cfg.crop_size + + loss_cfg = edict() + loss_cfg.instance_loss = NormalizedFocalLossSigmoid(alpha=0.5, gamma=2) + loss_cfg.instance_loss_weight = 1.0 + loss_cfg.instance_aux_loss = SigmoidBinaryCrossEntropyLoss() + loss_cfg.instance_aux_loss_weight = 0.4 + + train_augmentator = Compose([ + UniformRandomResize(scale_range=(0.75, 1.40)), + HorizontalFlip(), + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size), + RandomBrightnessContrast(brightness_limit=(-0.25, 0.25), contrast_limit=(-0.15, 0.4), p=0.75), + RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.75) + ], p=1.0) + + val_augmentator = Compose([ + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size) + ], p=1.0) + + points_sampler = MultiPointSampler(model_cfg.num_max_points, prob_gamma=0.8, + merge_objects_prob=0.15, + max_num_merged_objects=2) + + trainset = LvisDataset( + cfg.LVIS_PATH, + split='train', + augmentator=train_augmentator, + min_object_area=1000, + keep_background_prob=0.05, + points_sampler=points_sampler, + epoch_len=30000 + ) + + valset = LvisDataset( + cfg.LVIS_PATH, + split='val', + augmentator=val_augmentator, + min_object_area=1000, + points_sampler=points_sampler, + epoch_len=2000 + ) + + optimizer_params = { + 'lr': 5e-4, 'betas': (0.9, 0.999), 'eps': 1e-8 + } + + lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR, + milestones=[140, 155], gamma=0.1) + trainer = ISTrainer(model, cfg, model_cfg, loss_cfg, + 
trainset, valset, + optimizer='adam', + optimizer_params=optimizer_params, + lr_scheduler=lr_scheduler, + checkpoint_interval=5, + image_dump_interval=2000, + metrics=[AdaptiveIoU()], + max_interactive_points=model_cfg.num_max_points) + trainer.run(num_epochs=160) \ No newline at end of file diff --git a/models/noniterative_baselines/hrnet18_ocr64_openimages.py b/models/noniterative_baselines/hrnet18_ocr64_openimages.py new file mode 100755 index 0000000..8548271 --- /dev/null +++ b/models/noniterative_baselines/hrnet18_ocr64_openimages.py @@ -0,0 +1,89 @@ +from isegm.utils.exp_imports.default import * +MODEL_NAME = 'hrnet18' + + +def main(cfg): + model, model_cfg = init_model(cfg) + train(model, cfg, model_cfg) + + +def init_model(cfg): + model_cfg = edict() + model_cfg.crop_size = (320, 480) + model_cfg.num_max_points = 24 + + model = HRNetModel(width=18, ocr_width=64, with_aux_output=True, use_leaky_relu=True, + use_rgb_conv=False, use_disks=True, norm_radius=5) + + model.to(cfg.device) + model.apply(initializer.XavierGluon(rnd_type='gaussian', magnitude=2.0)) + model.feature_extractor.load_pretrained_weights(cfg.IMAGENET_PRETRAINED_MODELS.HRNETV2_W18) + + return model, model_cfg + + +def train(model, cfg, model_cfg): + cfg.batch_size = 28 if cfg.batch_size < 1 else cfg.batch_size + cfg.val_batch_size = cfg.batch_size + crop_size = model_cfg.crop_size + + loss_cfg = edict() + loss_cfg.instance_loss = NormalizedFocalLossSigmoid(alpha=0.5, gamma=2) + loss_cfg.instance_loss_weight = 1.0 + loss_cfg.instance_aux_loss = SigmoidBinaryCrossEntropyLoss() + loss_cfg.instance_aux_loss_weight = 0.4 + + train_augmentator = Compose([ + UniformRandomResize(scale_range=(0.75, 1.40)), + HorizontalFlip(), + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size), + RandomBrightnessContrast(brightness_limit=(-0.25, 0.25), contrast_limit=(-0.15, 0.4), p=0.75), + RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.75) + ], p=1.0) + + val_augmentator = Compose([ + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size) + ], p=1.0) + + points_sampler = MultiPointSampler(model_cfg.num_max_points, prob_gamma=0.8, + merge_objects_prob=0.15, + max_num_merged_objects=2) + + trainset = OpenImagesDataset( + cfg.OPENIMAGES_PATH, + split='train', + augmentator=train_augmentator, + min_object_area=1000, + keep_background_prob=0.05, + points_sampler=points_sampler, + epoch_len=30000 + ) + + valset = OpenImagesDataset( + cfg.OPENIMAGES_PATH, + split='val', + augmentator=val_augmentator, + min_object_area=1000, + keep_background_prob=0.05, + points_sampler=points_sampler, + epoch_len=2000 + ) + + optimizer_params = { + 'lr': 5e-4, 'betas': (0.9, 0.999), 'eps': 1e-8 + } + + lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR, + milestones=[120, 135], gamma=0.1) + trainer = ISTrainer(model, cfg, model_cfg, loss_cfg, + trainset, valset, + optimizer='adam', + optimizer_params=optimizer_params, + lr_scheduler=lr_scheduler, + checkpoint_interval=5, + image_dump_interval=2000, + metrics=[AdaptiveIoU()], + max_interactive_points=model_cfg.num_max_points) + trainer.run(num_epochs=140) diff --git a/models/noniterative_baselines/hrnet18_ocr64_sbd.py b/models/noniterative_baselines/hrnet18_ocr64_sbd.py new file mode 100755 index 0000000..49e0aa0 --- /dev/null +++ b/models/noniterative_baselines/hrnet18_ocr64_sbd.py @@ -0,0 +1,88 @@ +from isegm.utils.exp_imports.default import * +MODEL_NAME = 'hrnet18' + + +def main(cfg): 
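+ # entry point: build the HRNet-18 non-iterative baseline and train it on SBD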
+ model, model_cfg = init_model(cfg) + train(model, cfg, model_cfg) + + +def init_model(cfg): + model_cfg = edict() + model_cfg.crop_size = (320, 480) + model_cfg.num_max_points = 24 + + model = HRNetModel(width=18, ocr_width=64, with_aux_output=True, use_leaky_relu=True, + use_rgb_conv=False, use_disks=True, norm_radius=5) + + model.to(cfg.device) + model.apply(initializer.XavierGluon(rnd_type='gaussian', magnitude=2.0)) + model.feature_extractor.load_pretrained_weights(cfg.IMAGENET_PRETRAINED_MODELS.HRNETV2_W18) + + return model, model_cfg + + +def train(model, cfg, model_cfg): + cfg.batch_size = 28 if cfg.batch_size < 1 else cfg.batch_size + cfg.val_batch_size = cfg.batch_size + crop_size = model_cfg.crop_size + + loss_cfg = edict() + loss_cfg.instance_loss = NormalizedFocalLossSigmoid(alpha=0.5, gamma=2) + loss_cfg.instance_loss_weight = 1.0 + loss_cfg.instance_aux_loss = SigmoidBinaryCrossEntropyLoss() + loss_cfg.instance_aux_loss_weight = 0.4 + + train_augmentator = Compose([ + UniformRandomResize(scale_range=(0.75, 1.40)), + HorizontalFlip(), + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size), + RandomBrightnessContrast(brightness_limit=(-0.25, 0.25), contrast_limit=(-0.15, 0.4), p=0.75), + RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.75) + ], p=1.0) + + val_augmentator = Compose([ + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size) + ], p=1.0) + + points_sampler = MultiPointSampler(model_cfg.num_max_points, prob_gamma=0.80, + merge_objects_prob=0.15, + max_num_merged_objects=2) + + trainset = SBDDataset( + cfg.SBD_PATH, + split='train', + augmentator=train_augmentator, + min_object_area=80, + keep_background_prob=0.0, + points_sampler=points_sampler, + samples_scores_path='./assets/sbd_samples_weights.pkl', + samples_scores_gamma=1.25 + ) + + valset = SBDDataset( + cfg.SBD_PATH, + split='val', + augmentator=val_augmentator, + min_object_area=1000, + points_sampler=points_sampler, + ) + + optimizer_params = { + 'lr': 5e-4, 'betas': (0.9, 0.999), 'eps': 1e-8 + } + + lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR, + milestones=[100, 115], gamma=0.1) + trainer = ISTrainer(model, cfg, model_cfg, loss_cfg, + trainset, valset, + optimizer='adam', + optimizer_params=optimizer_params, + lr_scheduler=lr_scheduler, + checkpoint_interval=5, + image_dump_interval=2000, + metrics=[AdaptiveIoU()], + max_interactive_points=model_cfg.num_max_points) + trainer.run(num_epochs=120) diff --git a/models/noniterative_baselines/hrnet18_ocr64_vocsbd.py b/models/noniterative_baselines/hrnet18_ocr64_vocsbd.py new file mode 100755 index 0000000..44ee3ef --- /dev/null +++ b/models/noniterative_baselines/hrnet18_ocr64_vocsbd.py @@ -0,0 +1,93 @@ +from isegm.utils.exp_imports.default import * +MODEL_NAME = 'hrnet18' + + +def main(cfg): + model, model_cfg = init_model(cfg) + train(model, cfg, model_cfg) + + +def init_model(cfg): + model_cfg = edict() + model_cfg.crop_size = (320, 480) + model_cfg.num_max_points = 24 + + model = HRNetModel(width=18, ocr_width=64, with_aux_output=True, use_leaky_relu=True, + use_rgb_conv=False, use_disks=True, norm_radius=5) + + model.to(cfg.device) + model.apply(initializer.XavierGluon(rnd_type='gaussian', magnitude=2.0)) + model.feature_extractor.load_pretrained_weights(cfg.IMAGENET_PRETRAINED_MODELS.HRNETV2_W18) + + return model, model_cfg + + +def train(model, cfg, model_cfg): + cfg.batch_size = 28 if cfg.batch_size < 1 else cfg.batch_size + 
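# use a default batch size of 28 when none is given; validation reuses the training batch size +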
cfg.val_batch_size = cfg.batch_size + crop_size = model_cfg.crop_size + + loss_cfg = edict() + loss_cfg.instance_loss = NormalizedFocalLossSigmoid(alpha=0.5, gamma=2) + loss_cfg.instance_loss_weight = 1.0 + loss_cfg.instance_aux_loss = SigmoidBinaryCrossEntropyLoss() + loss_cfg.instance_aux_loss_weight = 0.4 + + train_augmentator = Compose([ + UniformRandomResize(scale_range=(0.75, 1.40)), + HorizontalFlip(), + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size), + RandomBrightnessContrast(brightness_limit=(-0.25, 0.25), contrast_limit=(-0.15, 0.4), p=0.75), + RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.75) + ], p=1.0) + + val_augmentator = Compose([ + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size) + ], p=1.0) + + points_sampler = MultiPointSampler(model_cfg.num_max_points, prob_gamma=0.80, + merge_objects_prob=0.15, + max_num_merged_objects=2) + + trainset = ComposeDataset( + [ + PascalVocDataset(cfg.PASCALVOC_PATH, split='train'), + SBDDataset( + cfg.SBD_PATH, + split='train', + samples_scores_path='./assets/sbd_samples_weights.pkl', + samples_scores_gamma=1.25 + ) + ], + augmentator=train_augmentator, + min_object_area=80, + keep_background_prob=0.0, + points_sampler=points_sampler + ) + + valset = SBDDataset( + cfg.SBD_PATH, + split='val', + augmentator=val_augmentator, + min_object_area=1000, + points_sampler=points_sampler, + ) + + optimizer_params = { + 'lr': 5e-4, 'betas': (0.9, 0.999), 'eps': 1e-8 + } + + lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR, + milestones=[100, 115], gamma=0.1) + trainer = ISTrainer(model, cfg, model_cfg, loss_cfg, + trainset, valset, + optimizer='adam', + optimizer_params=optimizer_params, + lr_scheduler=lr_scheduler, + checkpoint_interval=5, + image_dump_interval=2000, + metrics=[AdaptiveIoU()], + max_interactive_points=model_cfg.num_max_points) + trainer.run(num_epochs=120) diff --git a/models/noniterative_baselines/r34_dh128_ade20k.py b/models/noniterative_baselines/r34_dh128_ade20k.py new file mode 100755 index 0000000..23cc7a0 --- /dev/null +++ b/models/noniterative_baselines/r34_dh128_ade20k.py @@ -0,0 +1,88 @@ +from isegm.utils.exp_imports.default import * +MODEL_NAME = 'resnet34' + + +def main(cfg): + model, model_cfg = init_model(cfg) + train(model, cfg, model_cfg) + + +def init_model(cfg): + model_cfg = edict() + model_cfg.crop_size = (320, 480) + model_cfg.num_max_points = 24 + + model = DeeplabModel(backbone='resnet34', deeplab_ch=128, aspp_dropout=0.20, use_leaky_relu=True, + use_rgb_conv=False, use_disks=True, norm_radius=5) + + model.to(cfg.device) + model.apply(initializer.XavierGluon(rnd_type='gaussian', magnitude=2.0)) + model.feature_extractor.load_pretrained_weights() + + return model, model_cfg + + +def train(model, cfg, model_cfg): + cfg.batch_size = 28 if cfg.batch_size < 1 else cfg.batch_size + cfg.val_batch_size = cfg.batch_size + crop_size = model_cfg.crop_size + + loss_cfg = edict() + loss_cfg.instance_loss = NormalizedFocalLossSigmoid(alpha=0.5, gamma=2) + loss_cfg.instance_loss_weight = 1.0 + + train_augmentator = Compose([ + UniformRandomResize(scale_range=(0.75, 1.40)), + HorizontalFlip(), + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size), + RandomBrightnessContrast(brightness_limit=(-0.25, 0.25), contrast_limit=(-0.15, 0.4), p=0.75), + RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.75) + ], p=1.0) + + 
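# validation applies only padding to the crop size and a random crop, with no photometric augmentation +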
val_augmentator = Compose([ + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size) + ], p=1.0) + + points_sampler = MultiPointSampler(model_cfg.num_max_points, prob_gamma=0.8, + merge_objects_prob=0.15, + max_num_merged_objects=2) + + trainset = ADE20kDataset( + cfg.ADE20K_PATH, + split='train', + augmentator=train_augmentator, + min_object_area=80, + stuff_prob=0.30, + keep_background_prob=0.05, + points_sampler=points_sampler, + epoch_len=30000 + ) + + valset = ADE20kDataset( + cfg.ADE20K_PATH, + split='val', + augmentator=val_augmentator, + min_object_area=80, + keep_background_prob=0.05, + points_sampler=points_sampler, + epoch_len=2000 + ) + + optimizer_params = { + 'lr': 5e-4, 'betas': (0.9, 0.999), 'eps': 1e-8 + } + + lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR, + milestones=[100, 115], gamma=0.1) + trainer = ISTrainer(model, cfg, model_cfg, loss_cfg, + trainset, valset, + optimizer='adam', + optimizer_params=optimizer_params, + lr_scheduler=lr_scheduler, + checkpoint_interval=5, + image_dump_interval=2000, + metrics=[AdaptiveIoU()], + max_interactive_points=model_cfg.num_max_points) + trainer.run(num_epochs=120) diff --git a/models/noniterative_baselines/r34_dh128_coco.py b/models/noniterative_baselines/r34_dh128_coco.py new file mode 100755 index 0000000..dcb5299 --- /dev/null +++ b/models/noniterative_baselines/r34_dh128_coco.py @@ -0,0 +1,88 @@ +from isegm.utils.exp_imports.default import * +MODEL_NAME = 'resnet34' + + +def main(cfg): + model, model_cfg = init_model(cfg) + train(model, cfg, model_cfg) + + +def init_model(cfg): + model_cfg = edict() + model_cfg.crop_size = (320, 480) + model_cfg.num_max_points = 24 + + model = DeeplabModel(backbone='resnet34', deeplab_ch=128, aspp_dropout=0.20, use_leaky_relu=True, + use_rgb_conv=False, use_disks=True, norm_radius=5) + + model.to(cfg.device) + model.apply(initializer.XavierGluon(rnd_type='gaussian', magnitude=2.0)) + model.feature_extractor.load_pretrained_weights() + + return model, model_cfg + + +def train(model, cfg, model_cfg): + cfg.batch_size = 28 if cfg.batch_size < 1 else cfg.batch_size + cfg.val_batch_size = cfg.batch_size + crop_size = model_cfg.crop_size + + loss_cfg = edict() + loss_cfg.instance_loss = NormalizedFocalLossSigmoid(alpha=0.5, gamma=2) + loss_cfg.instance_loss_weight = 1.0 + + train_augmentator = Compose([ + UniformRandomResize(scale_range=(0.75, 1.40)), + HorizontalFlip(), + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size), + RandomBrightnessContrast(brightness_limit=(-0.25, 0.25), contrast_limit=(-0.15, 0.4), p=0.75), + RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.75) + ], p=1.0) + + val_augmentator = Compose([ + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size) + ], p=1.0) + + points_sampler = MultiPointSampler(model_cfg.num_max_points, prob_gamma=0.8, + merge_objects_prob=0.15, + max_num_merged_objects=2) + + trainset = CocoDataset( + cfg.COCO_PATH, + split='train', + augmentator=train_augmentator, + min_object_area=1000, + keep_background_prob=0.05, + points_sampler=points_sampler, + stuff_prob=0.30, + epoch_len=30000 + ) + + valset = CocoDataset( + cfg.COCO_PATH, + split='val', + augmentator=val_augmentator, + min_object_area=1000, + keep_background_prob=0.05, + points_sampler=points_sampler, + epoch_len=2000 + ) + + optimizer_params = { + 'lr': 5e-4, 'betas': (0.9, 0.999), 'eps': 1e-8 + } + + lr_scheduler = 
partial(torch.optim.lr_scheduler.MultiStepLR, + milestones=[120, 135], gamma=0.1) + trainer = ISTrainer(model, cfg, model_cfg, loss_cfg, + trainset, valset, + optimizer='adam', + optimizer_params=optimizer_params, + lr_scheduler=lr_scheduler, + checkpoint_interval=5, + image_dump_interval=2000, + metrics=[AdaptiveIoU()], + max_interactive_points=model_cfg.num_max_points) + trainer.run(num_epochs=140) diff --git a/models/noniterative_baselines/r34_dh128_cocolvis.py b/models/noniterative_baselines/r34_dh128_cocolvis.py new file mode 100755 index 0000000..1e094e1 --- /dev/null +++ b/models/noniterative_baselines/r34_dh128_cocolvis.py @@ -0,0 +1,87 @@ +from isegm.utils.exp_imports.default import * +MODEL_NAME = 'resnet34' + + +def main(cfg): + model, model_cfg = init_model(cfg) + train(model, cfg, model_cfg) + + +def init_model(cfg): + model_cfg = edict() + model_cfg.crop_size = (320, 480) + model_cfg.num_max_points = 24 + + model = DeeplabModel(backbone='resnet34', deeplab_ch=128, aspp_dropout=0.20, use_leaky_relu=True, + use_rgb_conv=False, use_disks=True, norm_radius=5) + + model.to(cfg.device) + model.apply(initializer.XavierGluon(rnd_type='gaussian', magnitude=2.0)) + model.feature_extractor.load_pretrained_weights() + + return model, model_cfg + + +def train(model, cfg, model_cfg): + cfg.batch_size = 28 if cfg.batch_size < 1 else cfg.batch_size + cfg.val_batch_size = cfg.batch_size + crop_size = model_cfg.crop_size + + loss_cfg = edict() + loss_cfg.instance_loss = NormalizedFocalLossSigmoid(alpha=0.5, gamma=2) + loss_cfg.instance_loss_weight = 1.0 + + train_augmentator = Compose([ + UniformRandomResize(scale_range=(0.75, 1.40)), + HorizontalFlip(), + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size), + RandomBrightnessContrast(brightness_limit=(-0.25, 0.25), contrast_limit=(-0.15, 0.4), p=0.75), + RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.75) + ], p=1.0) + + val_augmentator = Compose([ + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size) + ], p=1.0) + + points_sampler = MultiPointSampler(model_cfg.num_max_points, prob_gamma=0.8, + merge_objects_prob=0.15, + max_num_merged_objects=2) + + trainset = CocoLvisDataset( + cfg.LVIS_PATH, + split='train', + augmentator=train_augmentator, + min_object_area=80, + keep_background_prob=0.0, + points_sampler=points_sampler, + epoch_len=30000, + stuff_prob=0.30 + ) + + valset = CocoLvisDataset( + cfg.LVIS_PATH, + split='val', + augmentator=val_augmentator, + min_object_area=80, + points_sampler=points_sampler, + epoch_len=2000 + ) + + optimizer_params = { + 'lr': 5e-4, 'betas': (0.9, 0.999), 'eps': 1e-8 + } + + lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR, + milestones=[140, 155], gamma=0.1) + trainer = ISTrainer(model, cfg, model_cfg, loss_cfg, + trainset, valset, + optimizer='adam', + optimizer_params=optimizer_params, + lr_scheduler=lr_scheduler, + checkpoint_interval=5, + image_dump_interval=2000, + metrics=[AdaptiveIoU()], + max_interactive_points=model_cfg.num_max_points) + trainer.run(num_epochs=160) diff --git a/models/noniterative_baselines/r34_dh128_lvis.py b/models/noniterative_baselines/r34_dh128_lvis.py new file mode 100755 index 0000000..4d1bc9c --- /dev/null +++ b/models/noniterative_baselines/r34_dh128_lvis.py @@ -0,0 +1,86 @@ +from isegm.utils.exp_imports.default import * +MODEL_NAME = 'resnet34' + + +def main(cfg): + model, model_cfg = init_model(cfg) + train(model, cfg, model_cfg) + + +def 
init_model(cfg): + model_cfg = edict() + model_cfg.crop_size = (320, 480) + model_cfg.num_max_points = 24 + + model = DeeplabModel(backbone='resnet34', deeplab_ch=128, aspp_dropout=0.20, use_leaky_relu=True, + use_rgb_conv=False, use_disks=True, norm_radius=5) + + model.to(cfg.device) + model.apply(initializer.XavierGluon(rnd_type='gaussian', magnitude=2.0)) + model.feature_extractor.load_pretrained_weights() + + return model, model_cfg + + +def train(model, cfg, model_cfg): + cfg.batch_size = 28 if cfg.batch_size < 1 else cfg.batch_size + cfg.val_batch_size = cfg.batch_size + crop_size = model_cfg.crop_size + + loss_cfg = edict() + loss_cfg.instance_loss = NormalizedFocalLossSigmoid(alpha=0.5, gamma=2) + loss_cfg.instance_loss_weight = 1.0 + + train_augmentator = Compose([ + UniformRandomResize(scale_range=(0.75, 1.40)), + HorizontalFlip(), + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size), + RandomBrightnessContrast(brightness_limit=(-0.25, 0.25), contrast_limit=(-0.15, 0.4), p=0.75), + RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.75) + ], p=1.0) + + val_augmentator = Compose([ + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size) + ], p=1.0) + + points_sampler = MultiPointSampler(model_cfg.num_max_points, prob_gamma=0.8, + merge_objects_prob=0.15, + max_num_merged_objects=2) + + trainset = LvisDataset( + cfg.LVIS_PATH, + split='train', + augmentator=train_augmentator, + min_object_area=1000, + keep_background_prob=0.05, + points_sampler=points_sampler, + epoch_len=30000 + ) + + valset = LvisDataset( + cfg.LVIS_PATH, + split='val', + augmentator=val_augmentator, + min_object_area=1000, + points_sampler=points_sampler, + epoch_len=2000 + ) + + optimizer_params = { + 'lr': 5e-4, 'betas': (0.9, 0.999), 'eps': 1e-8 + } + + lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR, + milestones=[140, 155], gamma=0.1) + trainer = ISTrainer(model, cfg, model_cfg, loss_cfg, + trainset, valset, + optimizer='adam', + optimizer_params=optimizer_params, + lr_scheduler=lr_scheduler, + checkpoint_interval=5, + image_dump_interval=2000, + metrics=[AdaptiveIoU()], + max_interactive_points=model_cfg.num_max_points) + trainer.run(num_epochs=160) diff --git a/models/noniterative_baselines/r34_dh128_openimages.py b/models/noniterative_baselines/r34_dh128_openimages.py new file mode 100755 index 0000000..eb727bd --- /dev/null +++ b/models/noniterative_baselines/r34_dh128_openimages.py @@ -0,0 +1,87 @@ +from isegm.utils.exp_imports.default import * +MODEL_NAME = 'resnet34' + + +def main(cfg): + model, model_cfg = init_model(cfg) + train(model, cfg, model_cfg) + + +def init_model(cfg): + model_cfg = edict() + model_cfg.crop_size = (320, 480) + model_cfg.num_max_points = 24 + + model = DeeplabModel(backbone='resnet34', deeplab_ch=128, aspp_dropout=0.20, use_leaky_relu=True, + use_rgb_conv=False, use_disks=True, norm_radius=5) + + model.to(cfg.device) + model.apply(initializer.XavierGluon(rnd_type='gaussian', magnitude=2.0)) + model.feature_extractor.load_pretrained_weights() + + return model, model_cfg + + +def train(model, cfg, model_cfg): + cfg.batch_size = 28 if cfg.batch_size < 1 else cfg.batch_size + cfg.val_batch_size = cfg.batch_size + crop_size = model_cfg.crop_size + + loss_cfg = edict() + loss_cfg.instance_loss = NormalizedFocalLossSigmoid(alpha=0.5, gamma=2) + loss_cfg.instance_loss_weight = 1.0 + + train_augmentator = Compose([ + UniformRandomResize(scale_range=(0.75, 
1.40)), + HorizontalFlip(), + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size), + RandomBrightnessContrast(brightness_limit=(-0.25, 0.25), contrast_limit=(-0.15, 0.4), p=0.75), + RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.75) + ], p=1.0) + + val_augmentator = Compose([ + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size) + ], p=1.0) + + points_sampler = MultiPointSampler(model_cfg.num_max_points, prob_gamma=0.8, + merge_objects_prob=0.15, + max_num_merged_objects=2) + + trainset = OpenImagesDataset( + cfg.OPENIMAGES_PATH, + split='train', + augmentator=train_augmentator, + min_object_area=1000, + keep_background_prob=0.05, + points_sampler=points_sampler, + epoch_len=30000 + ) + + valset = OpenImagesDataset( + cfg.OPENIMAGES_PATH, + split='val', + augmentator=val_augmentator, + min_object_area=1000, + keep_background_prob=0.05, + points_sampler=points_sampler, + epoch_len=2000 + ) + + optimizer_params = { + 'lr': 5e-4, 'betas': (0.9, 0.999), 'eps': 1e-8 + } + + lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR, + milestones=[120, 135], gamma=0.1) + trainer = ISTrainer(model, cfg, model_cfg, loss_cfg, + trainset, valset, + optimizer='adam', + optimizer_params=optimizer_params, + lr_scheduler=lr_scheduler, + checkpoint_interval=5, + image_dump_interval=2000, + metrics=[AdaptiveIoU()], + max_interactive_points=model_cfg.num_max_points) + trainer.run(num_epochs=140) diff --git a/models/noniterative_baselines/r34_dh128_sbd.py b/models/noniterative_baselines/r34_dh128_sbd.py new file mode 100755 index 0000000..a93830b --- /dev/null +++ b/models/noniterative_baselines/r34_dh128_sbd.py @@ -0,0 +1,86 @@ +from isegm.utils.exp_imports.default import * +MODEL_NAME = 'resnet34' + + +def main(cfg): + model, model_cfg = init_model(cfg) + train(model, cfg, model_cfg) + + +def init_model(cfg): + model_cfg = edict() + model_cfg.crop_size = (320, 480) + model_cfg.num_max_points = 24 + + model = DeeplabModel(backbone='resnet34', deeplab_ch=128, aspp_dropout=0.20, use_leaky_relu=True, + use_rgb_conv=False, use_disks=True, norm_radius=5) + + model.to(cfg.device) + model.apply(initializer.XavierGluon(rnd_type='gaussian', magnitude=2.0)) + model.feature_extractor.load_pretrained_weights() + + return model, model_cfg + + +def train(model, cfg, model_cfg): + cfg.batch_size = 28 if cfg.batch_size < 1 else cfg.batch_size + cfg.val_batch_size = cfg.batch_size + crop_size = model_cfg.crop_size + + loss_cfg = edict() + loss_cfg.instance_loss = NormalizedFocalLossSigmoid(alpha=0.5, gamma=2) + loss_cfg.instance_loss_weight = 1.0 + + train_augmentator = Compose([ + UniformRandomResize(scale_range=(0.75, 1.40)), + HorizontalFlip(), + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size), + RandomBrightnessContrast(brightness_limit=(-0.25, 0.25), contrast_limit=(-0.15, 0.4), p=0.75), + RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.75) + ], p=1.0) + + val_augmentator = Compose([ + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size) + ], p=1.0) + + points_sampler = MultiPointSampler(model_cfg.num_max_points, prob_gamma=0.8, + merge_objects_prob=0.15, + max_num_merged_objects=2) + + trainset = SBDDataset( + cfg.SBD_PATH, + split='train', + augmentator=train_augmentator, + min_object_area=80, + keep_background_prob=0.0, + points_sampler=points_sampler, + 
samples_scores_path='./assets/sbd_samples_weights.pkl', + samples_scores_gamma=1.25 + ) + + valset = SBDDataset( + cfg.SBD_PATH, + split='val', + augmentator=val_augmentator, + min_object_area=1000, + points_sampler=points_sampler, + ) + + optimizer_params = { + 'lr': 5e-4, 'betas': (0.9, 0.999), 'eps': 1e-8 + } + + lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR, + milestones=[100, 115], gamma=0.1) + trainer = ISTrainer(model, cfg, model_cfg, loss_cfg, + trainset, valset, + optimizer='adam', + optimizer_params=optimizer_params, + lr_scheduler=lr_scheduler, + checkpoint_interval=5, + image_dump_interval=2000, + metrics=[AdaptiveIoU()], + max_interactive_points=model_cfg.num_max_points) + trainer.run(num_epochs=120) diff --git a/models/noniterative_baselines/r34_dh128_vocsbd.py b/models/noniterative_baselines/r34_dh128_vocsbd.py new file mode 100755 index 0000000..f2508d6 --- /dev/null +++ b/models/noniterative_baselines/r34_dh128_vocsbd.py @@ -0,0 +1,91 @@ +from isegm.utils.exp_imports.default import * +MODEL_NAME = 'resnet34' + + +def main(cfg): + model, model_cfg = init_model(cfg) + train(model, cfg, model_cfg) + + +def init_model(cfg): + model_cfg = edict() + model_cfg.crop_size = (320, 480) + model_cfg.num_max_points = 24 + + model = DeeplabModel(backbone='resnet34', deeplab_ch=128, aspp_dropout=0.20, use_leaky_relu=True, + use_rgb_conv=False, use_disks=True, norm_radius=5) + + model.to(cfg.device) + model.apply(initializer.XavierGluon(rnd_type='gaussian', magnitude=2.0)) + model.feature_extractor.load_pretrained_weights() + + return model, model_cfg + + +def train(model, cfg, model_cfg): + cfg.batch_size = 28 if cfg.batch_size < 1 else cfg.batch_size + cfg.val_batch_size = cfg.batch_size + crop_size = model_cfg.crop_size + + loss_cfg = edict() + loss_cfg.instance_loss = NormalizedFocalLossSigmoid(alpha=0.5, gamma=2) + loss_cfg.instance_loss_weight = 1.0 + + train_augmentator = Compose([ + UniformRandomResize(scale_range=(0.75, 1.40)), + HorizontalFlip(), + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size), + RandomBrightnessContrast(brightness_limit=(-0.25, 0.25), contrast_limit=(-0.15, 0.4), p=0.75), + RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.75) + ], p=1.0) + + val_augmentator = Compose([ + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size) + ], p=1.0) + + points_sampler = MultiPointSampler(model_cfg.num_max_points, prob_gamma=0.8, + merge_objects_prob=0.15, + max_num_merged_objects=2) + + trainset = ComposeDataset( + [ + PascalVocDataset(cfg.PASCALVOC_PATH, split='train'), + SBDDataset( + cfg.SBD_PATH, + split='train', + samples_scores_path='./assets/sbd_samples_weights.pkl', + samples_scores_gamma=1.25 + ) + ], + augmentator=train_augmentator, + min_object_area=80, + keep_background_prob=0.0, + points_sampler=points_sampler + ) + + valset = SBDDataset( + cfg.SBD_PATH, + split='val', + augmentator=val_augmentator, + min_object_area=1000, + points_sampler=points_sampler, + ) + + optimizer_params = { + 'lr': 5e-4, 'betas': (0.9, 0.999), 'eps': 1e-8 + } + + lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR, + milestones=[100, 115], gamma=0.1) + trainer = ISTrainer(model, cfg, model_cfg, loss_cfg, + trainset, valset, + optimizer='adam', + optimizer_params=optimizer_params, + lr_scheduler=lr_scheduler, + checkpoint_interval=5, + image_dump_interval=2000, + metrics=[AdaptiveIoU()], + max_interactive_points=model_cfg.num_max_points) + 
trainer.run(num_epochs=120) diff --git a/notebooks/colab_test_any_model.ipynb b/notebooks/colab_test_any_model.ipynb new file mode 100644 index 0000000..1e66762 --- /dev/null +++ b/notebooks/colab_test_any_model.ipynb @@ -0,0 +1,221 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "colab_test_any_model.ipynb", + "provenance": [], + "collapsed_sections": [], + "authorship_tag": "ABX9TyM+L4AdVeLMu90VmvOnwlTB" + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "zKd4Z6wz8WHv" + }, + "source": [ + "### Clone repository, download models and data, install necessary packages" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "aOhRdxx01YH0" + }, + "source": [ + "!git clone -q https://github.com/saic-vul/ritm_interactive_segmentation\n", + "\n", + "URL_PREFIX = \"https://github.com/saic-vul/ritm_interactive_segmentation/releases/download/v1.0\"\n", + "DATA_FOLDER = \"./ritm_interactive_segmentation/datasets\" \n", + "WEIGHTS_FOLDER = \"./ritm_interactive_segmentation/weights\"\n", + "\n", + "!mkdir -p {DATA_FOLDER}\n", + "!mkdir -p {WEIGHTS_FOLDER}\n", + "\n", + "# CHOOSE MODEL HERE\n", + "# possible choices are: coco_lvis_h18s_itermask, coco_lvis_h18_baseline, coco_lvis_h18_itermask,\n", + "# coco_lvis_h18_itermask, sbd_h18_itermask\n", + "MODEL_NAME = \"coco_lvis_h18s_itermask\"\n", + "WEIGHTS_URL = f\"{URL_PREFIX}/{MODEL_NAME}.pth\"\n", + "!wget -q -P {WEIGHTS_FOLDER} {WEIGHTS_URL}\n", + "\n", + "for dataset in ['GrabCut', 'Berkeley', 'DAVIS', 'COCO_MVal']:\n", + " dataset_url = f\"{URL_PREFIX}/{dataset}.zip\"\n", + " dataset_path = f\"{DATA_FOLDER}/{dataset}.zip\"\n", + " !wget -q -O {dataset_path} {dataset_url}\n", + " !unzip -q {dataset_path} -d {DATA_FOLDER}\n", + " !rm {dataset_path}\n", + "\n", + "!pip3 install -q -r ./ritm_interactive_segmentation/requirements.txt\n", + "%cd ritm_interactive_segmentation/" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9cwOspcT8gDb" + }, + "source": [ + "### Imports" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "E8RethT83nRc" + }, + "source": [ + "%matplotlib inline\n", + "import matplotlib.pyplot as plt\n", + "\n", + "import sys\n", + "import torch\n", + "import numpy as np\n", + "\n", + "sys.path.insert(0, './')\n", + "\n", + "from isegm.utils import vis, exp\n", + "from isegm.inference import utils\n", + "from isegm.inference.evaluation import evaluate_dataset, evaluate_sample\n", + "\n", + "device = torch.device('cuda:0')\n", + "cfg = exp.load_config_file('./config.yml', return_edict=True)\n" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9T3CWGZc8kZt" + }, + "source": [ + "### Init dataset" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "qvokqHDgCsHi" + }, + "source": [ + "# Possible choices: 'GrabCut', 'Berkeley', 'DAVIS', 'SBD'\n", + "DATASET = 'GrabCut'\n", + "dataset = utils.get_dataset(DATASET, cfg)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "p_DZ60HVCsZi" + }, + "source": [ + "### Init model" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "dK-1alE08m8m" + }, + "source": [ + "from isegm.inference.predictors import get_predictor\n", + "\n", + "EVAL_MAX_CLICKS = 20\n", + "MODEL_THRESH = 0.49\n", + "\n", + "checkpoint_path = utils.find_checkpoint(cfg.INTERACTIVE_MODELS_PATH, 
'resnet34_dh128_sbd')\n", + "model = utils.load_is_model(checkpoint_path, device)\n", + "\n", + "# Possible choices: 'NoBRS', 'f-BRS-A', 'f-BRS-B', 'f-BRS-C', 'RGB-BRS', 'DistMap-BRS'\n", + "brs_mode = 'f-BRS-B'\n", + "predictor = get_predictor(model, brs_mode, device, prob_thresh=MODEL_THRESH)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3xpSOo-e8pyt" + }, + "source": [ + "### Dataset evaluation" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "4SINqtTo8o-n" + }, + "source": [ + "TARGET_IOU = 0.9\n", + "\n", + "all_ious, elapsed_time = evaluate_dataset(dataset, predictor, pred_thr=MODEL_THRESH, \n", + " max_iou_thr=TARGET_IOU, max_clicks=EVAL_MAX_CLICKS)\n", + "mean_spc, mean_spi = utils.get_time_metrics(all_ious, elapsed_time)\n", + "noc_list, over_max_list = utils.compute_noc_metric(all_ious,\n", + " iou_thrs=[0.8, 0.85, 0.9],\n", + " max_clicks=EVAL_MAX_CLICKS)\n", + "\n", + "header, table_row = utils.get_results_table(noc_list, over_max_list, brs_mode, DATASET,\n", + " mean_spc, elapsed_time, EVAL_MAX_CLICKS)\n", + "print(header)\n", + "print(table_row)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SP0yyV-08s8d" + }, + "source": [ + "### Single sample eval" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "1JHMy1dQ8us1" + }, + "source": [ + "sample_id = 12\n", + "TARGET_IOU = 0.95\n", + "\n", + "sample = dataset.get_sample(sample_id)\n", + "gt_mask = sample.gt_mask\n", + "\n", + "clicks_list, ious_arr, pred = evaluate_sample(sample.image, gt_mask, predictor, \n", + " pred_thr=MODEL_THRESH, \n", + " max_iou_thr=TARGET_IOU, max_clicks=EVAL_MAX_CLICKS)\n", + "\n", + "pred_mask = pred > MODEL_THRESH\n", + "draw = vis.draw_with_blend_and_clicks(sample.image, mask=pred_mask, clicks_list=clicks_list)\n", + "draw = np.concatenate((draw,\n", + " 255 * pred_mask[:, :, np.newaxis].repeat(3, axis=2),\n", + " 255 * (gt_mask > 0)[:, :, np.newaxis].repeat(3, axis=2)\n", + "), axis=1)\n", + "\n", + "print(ious_arr)\n", + "\n", + "plt.figure(figsize=(20, 30))\n", + "plt.imshow(draw)\n", + "plt.show()" + ], + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/notebooks/test_any_model.ipynb b/notebooks/test_any_model.ipynb new file mode 100644 index 0000000..e72e6f3 --- /dev/null +++ b/notebooks/test_any_model.ipynb @@ -0,0 +1,200 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-23T15:40:40.146768Z", + "start_time": "2020-01-23T15:40:39.277344Z" + } + }, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "import matplotlib.pyplot as plt\n", + "\n", + "import sys\n", + "import numpy as np\n", + "import torch\n", + "\n", + "sys.path.insert(0, '..')\n", + "from isegm.utils import vis, exp\n", + "\n", + "from isegm.inference import utils\n", + "from isegm.inference.evaluation import evaluate_dataset, evaluate_sample\n", + "\n", + "device = torch.device('cuda:0')\n", + "cfg = exp.load_config_file('../config.yml', return_edict=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Init dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-23T15:40:40.540120Z", + "start_time": "2020-01-23T15:40:40.535379Z" + } + }, + "outputs": [], + "source": [ + "# Possible choices: 'GrabCut', 'Berkeley', 'DAVIS', 'COCO_MVal', 'SBD'\n", + 
"DATASET = 'GrabCut'\n", + "dataset = utils.get_dataset(DATASET, cfg)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Init model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-23T15:40:46.953312Z", + "start_time": "2020-01-23T15:40:41.632849Z" + } + }, + "outputs": [], + "source": [ + "from isegm.inference.predictors import get_predictor\n", + "\n", + "EVAL_MAX_CLICKS = 20\n", + "MODEL_THRESH = 0.49\n", + "\n", + "checkpoint_path = utils.find_checkpoint(cfg.INTERACTIVE_MODELS_PATH, 'coco_lvis_h18s_itermask')\n", + "model = utils.load_is_model(checkpoint_path, device)\n", + "\n", + "# Possible choices: 'NoBRS', 'f-BRS-A', 'f-BRS-B', 'f-BRS-C', 'RGB-BRS', 'DistMap-BRS'\n", + "brs_mode = 'f-BRS-B'\n", + "predictor = get_predictor(model, brs_mode, device, prob_thresh=MODEL_THRESH)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Dataset evaluation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-23T15:41:05.430871Z", + "start_time": "2020-01-23T15:40:46.956196Z" + } + }, + "outputs": [], + "source": [ + "TARGET_IOU = 0.9\n", + "\n", + "all_ious, elapsed_time = evaluate_dataset(dataset, predictor, pred_thr=MODEL_THRESH, \n", + " max_iou_thr=TARGET_IOU, max_clicks=EVAL_MAX_CLICKS)\n", + "mean_spc, mean_spi = utils.get_time_metrics(all_ious, elapsed_time)\n", + "noc_list, over_max_list = utils.compute_noc_metric(all_ious,\n", + " iou_thrs=[0.8, 0.85, 0.9],\n", + " max_clicks=EVAL_MAX_CLICKS)\n", + "\n", + "header, table_row = utils.get_results_table(noc_list, over_max_list, brs_mode, DATASET,\n", + " mean_spc, elapsed_time, EVAL_MAX_CLICKS)\n", + "print(header)\n", + "print(table_row)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Single sample eval" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2019-12-04T10:53:23.817566Z", + "start_time": "2019-12-04T10:53:22.592826Z" + } + }, + "outputs": [], + "source": [ + "sample_id = 12\n", + "TARGET_IOU = 0.95\n", + "\n", + "sample = dataset.get_sample(sample_id)\n", + "gt_mask = sample.gt_mask\n", + "\n", + "clicks_list, ious_arr, pred = evaluate_sample(sample.image, gt_mask, predictor, \n", + " pred_thr=MODEL_THRESH, \n", + " max_iou_thr=TARGET_IOU, max_clicks=EVAL_MAX_CLICKS)\n", + "\n", + "pred_mask = pred > MODEL_THRESH\n", + "draw = vis.draw_with_blend_and_clicks(sample.image, mask=pred_mask, clicks_list=clicks_list)\n", + "draw = np.concatenate((draw,\n", + " 255 * pred_mask[:, :, np.newaxis].repeat(3, axis=2),\n", + " 255 * (gt_mask > 0)[:, :, np.newaxis].repeat(3, axis=2)\n", + "), axis=1)\n", + "\n", + "print(ious_arr)\n", + "\n", + "plt.figure(figsize=(20, 30))\n", + "plt.imshow(draw)\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": { + "height": "calc(100% - 180px)", 
+ "left": "10px", + "top": "150px", + "width": "294px" + }, + "toc_section_display": true, + "toc_window_display": true + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..f56198e --- /dev/null +++ b/requirements.txt @@ -0,0 +1,19 @@ +scipy +numpy +Cython +scikit-image +opencv-python-headless +Pillow +matplotlib +imgaug +albumentations +graphviz +tqdm +pyyaml +easydict +torch>=1.4.0 +torchvision>=0.5.0 +tensorboard +future +cffi +ninja \ No newline at end of file diff --git a/scripts/annotations_conversion/ade20k.py b/scripts/annotations_conversion/ade20k.py new file mode 100644 index 0000000..01193ad --- /dev/null +++ b/scripts/annotations_conversion/ade20k.py @@ -0,0 +1,83 @@ +import pickle as pkl +from pathlib import Path +from scipy.io import loadmat + +from scripts.annotations_conversion.common import parallel_map + + +ADE20K_STUFF_CLASSES = ['water', 'wall', 'snow', 'sky', 'sea', 'sand', 'road', 'route', 'river', 'path', 'mountain', + 'mount', 'land', 'ground', 'soil', 'hill', 'grass', 'floor', 'flooring', 'field', 'earth', + 'ground', 'fence', 'ceiling', 'wave', 'crosswalk', 'hay bale', 'bridge', 'span', 'building', + 'edifice', 'cabinet', 'cushion', 'curtain', 'drape', 'drapery', 'mantle', 'pall', 'door', + 'fencing', 'house', 'pole', 'seat', 'windowpane', 'window', 'tree', 'towel', 'table', + 'stairs', 'steps', 'streetlight', 'street lamp', 'sofa', 'couch', 'lounge', 'skyscraper', + 'signboard', 'sign', 'sidewalk', 'pavement', 'shrub', 'bush', 'rug', 'carpet'] + + +def worker_annotations_loader(anno_pair, dataset_path): + image_id, folder = anno_pair + n_masks = len(list((dataset_path / folder).glob(f'{image_id}_*.png'))) + + # each image has several layers with instances, + # each layer has mask name and instance_to_class mapping + layers = [{ + 'mask_name': f'{image_id}_{suffix}.png', + 'instance_to_class': {}, + 'object_instances': [], + 'stuff_instances': [] + } for suffix in ['seg'] + [f'parts_{i}' for i in range(1, n_masks)]] + + # parse txt with instance to class mappings + with (dataset_path / folder / (image_id + "_atr.txt")).open('r') as f: + for line in f: + # instance_id layer_n is_occluded class_names class_name_raw attributes + line = line.strip().split('#') + inst_id, layer_n, class_names = int(line[0]), int(line[1]), line[3] + + # there may be more than one class name for each instance + class_names = [name.strip() for name in class_names.split(',')] + + # check if any of classes is stuff + if set(class_names) & set(ADE20K_STUFF_CLASSES): + layers[layer_n]['stuff_instances'].append(inst_id) + else: + layers[layer_n]['object_instances'].append(inst_id) + layers[layer_n]['instance_to_class'][inst_id] = class_names + + return layers + + +def load_and_parse_annotations(dataset_path, dataset_split, n_jobs=1): + dataset_split_folder = 'training' if dataset_split == 'train' else 'validation' + + orig_annotations = loadmat(dataset_path / 'index_ade20k.mat', squeeze_me=True, struct_as_record=True) + image_ids = [image_id.split('.')[0] for image_id in orig_annotations['index'].item()[0] + if dataset_split in image_id] + folders = [Path(folder).relative_to('ADE20K_2016_07_26') for folder in orig_annotations['index'].item()[1] + if dataset_split_folder in folder] + + # list of dictionaries with filename and instance to class mapping + all_layers = parallel_map(list(zip(image_ids, folders)), worker_annotations_loader, n_jobs=n_jobs, + use_kwargs=False, const_args={ + 
'dataset_path': dataset_path + }) + + return image_ids, folders, all_layers + + +def create_annotations(dataset_path, dataset_split='train', n_jobs=1): + anno_path = dataset_path / f'{dataset_split}-annotations-object-segmentation.pkl' + image_ids, folders, all_layers = load_and_parse_annotations(dataset_path, dataset_split, n_jobs=n_jobs) + + # create dictionary with annotations + annotations = {} + for index, image_id in enumerate(image_ids): + annotations[image_id] = { + 'folder': folders[index], + 'layers': all_layers[index] + } + + with anno_path.open('wb') as f: + pkl.dump(annotations, f) + + return annotations diff --git a/scripts/annotations_conversion/coco_lvis.py b/scripts/annotations_conversion/coco_lvis.py new file mode 100644 index 0000000..1c9b9e1 --- /dev/null +++ b/scripts/annotations_conversion/coco_lvis.py @@ -0,0 +1,140 @@ +import cv2 +import pickle +import numpy as np +from pathlib import Path +from tqdm import tqdm + +from isegm.data.datasets import LvisDataset, CocoDataset +from isegm.utils.misc import get_bbox_from_mask, get_bbox_iou +from scripts.annotations_conversion.common import get_masks_hierarchy, get_iou, encode_masks + + +def create_annotations(lvis_path: Path, coco_path: Path, dataset_split='train', min_object_area=80): + lvis_dataset = LvisDataset(lvis_path, split=dataset_split) + lvis_samples = lvis_dataset.dataset_samples + lvis_annotations = lvis_dataset.annotations + + coco_dataset = CocoDataset(coco_path, split=dataset_split + '2017') + + coco_lvis_mapping = [] + lvis_images = {x['coco_url'].split('/')[-1].split('.')[0]: lvis_indx + for lvis_indx, x in enumerate(lvis_samples)} + for indx, coco_sample in enumerate(coco_dataset.dataset_samples): + lvis_indx = lvis_images.get(coco_sample['file_name'].split('.')[0], None) + if lvis_indx is not None: + coco_lvis_mapping.append((indx, lvis_indx)) + + output_masks_path = lvis_path / dataset_split / 'masks' + output_masks_path.mkdir(parents=True, exist_ok=True) + + hlvis_annotation = dict() + for coco_indx, lvis_indx in tqdm(coco_lvis_mapping): + coco_sample = get_coco_sample(coco_dataset, coco_indx) + + lvis_info = lvis_samples[lvis_indx] + lvis_annotation = lvis_annotations[lvis_info['id']] + empty_mask = np.zeros((lvis_info['height'], lvis_info['width'])) + image_name = lvis_info['coco_url'].split('/')[-1].split('.')[0] + + lvis_masks = [] + lvis_bboxes = [] + for obj_annotation in lvis_annotation: + obj_mask = lvis_dataset.get_mask_from_polygon(obj_annotation, empty_mask) + obj_mask = obj_mask == 1 + if obj_mask.sum() >= min_object_area: + lvis_masks.append(obj_mask) + lvis_bboxes.append(get_bbox_from_mask(obj_mask)) + + coco_bboxes = [] + coco_masks = [] + for inst_id in coco_sample['instances_info'].keys(): + obj_mask = coco_sample['instances_mask'] == inst_id + if obj_mask.sum() >= min_object_area: + coco_masks.append(obj_mask) + coco_bboxes.append(get_bbox_from_mask(obj_mask)) + + masks = [] + for coco_j, coco_bbox in enumerate(coco_bboxes): + for lvis_i, lvis_bbox in enumerate(lvis_bboxes): + if get_bbox_iou(lvis_bbox, coco_bbox) > 0.70 and \ + get_iou(lvis_masks[lvis_i], coco_masks[coco_j]) > 0.70: + break + else: + masks.append(coco_masks[coco_j]) + + for ti, (lvis_mask, lvis_bbox) in enumerate(zip(lvis_masks, lvis_bboxes)): + for tj_mask, tj_bbox in zip(lvis_masks[ti + 1:], lvis_bboxes[ti + 1:]): + bbox_iou = get_bbox_iou(lvis_bbox, tj_bbox) + if bbox_iou > 0.7 and get_iou(lvis_mask, tj_mask) > 0.85: + break + else: + masks.append(lvis_mask) + + masks_meta = [(get_bbox_from_mask(x), x.sum()) for x 
in masks] + if not masks: + continue + + hierarchy = get_masks_hierarchy(masks, masks_meta) + + for obj_id, obj_info in list(hierarchy.items()): + if obj_info['parent'] is None and len(obj_info['children']) == 0: + hierarchy[obj_id] = None + + merged_mask = np.max(masks, axis=0) + num_instance_masks = len(masks) + for obj_id in coco_sample['semantic_info'].keys(): + obj_mask = coco_sample['semantic_map'] == obj_id + obj_mask = np.logical_and(obj_mask, np.logical_not(merged_mask)) + if obj_mask.sum() > 500: + masks.append(obj_mask) + + hlvis_annotation[image_name] = { + 'num_instance_masks': num_instance_masks, + 'hierarchy': hierarchy + } + + with open(output_masks_path / f'{image_name}.pickle', 'wb') as f: + pickle.dump(encode_masks(masks), f) + + with open(lvis_path / dataset_split / 'hannotation.pickle', 'wb') as f: + pickle.dump(hlvis_annotation, f, protocol=pickle.HIGHEST_PROTOCOL) + + +def get_coco_sample(dataset, index): + dataset_sample = dataset.dataset_samples[index] + + image_path = dataset.images_path / dataset.get_image_name(dataset_sample['file_name']) + label_path = dataset.labels_path / dataset_sample['file_name'] + + image = cv2.imread(str(image_path)) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + label = cv2.imread(str(label_path), cv2.IMREAD_UNCHANGED).astype(np.int32) + label = 256 * 256 * label[:, :, 0] + 256 * label[:, :, 1] + label[:, :, 2] + + instance_map = np.full_like(label, 0) + semantic_map = np.full_like(label, 0) + semantic_info = dict() + instances_info = dict() + for segment in dataset_sample['segments_info']: + class_id = segment['category_id'] + obj_id = segment['id'] + if class_id not in dataset._things_labels_set: + semantic_map[label == obj_id] = obj_id + semantic_info[obj_id] = {'ignore': False} + continue + + instance_map[label == obj_id] = obj_id + ignore = segment['iscrowd'] == 1 + instances_info[obj_id] = { + 'ignore': ignore + } + + sample = { + 'image': image, + 'instances_mask': instance_map, + 'instances_info': instances_info, + 'semantic_map': semantic_map, + 'semantic_info': semantic_info + } + + return sample diff --git a/scripts/annotations_conversion/common.py b/scripts/annotations_conversion/common.py new file mode 100644 index 0000000..c594dcc --- /dev/null +++ b/scripts/annotations_conversion/common.py @@ -0,0 +1,179 @@ +from collections import defaultdict +from concurrent.futures import ProcessPoolExecutor, as_completed + +import cv2 +from tqdm import tqdm +import numpy as np + + +def parallel_map(array, worker, const_args=None, n_jobs=16, use_kwargs=False, front_num=3, drop_none=False): + """ + A parallel version of the map function with a progress bar. + + Args: + array (array-like): A list to iterate over + worker (function): A python function to apply to the elements of array + const_args (dict, default=None): Constant arguments, shared between all processes + n_jobs (int, default=16): The number of cores to use + use_kwargs (boolean, default=False): Whether to consider the elements of array as dictionaries of + keyword arguments to function + front_num (int, default=3): The number of iterations to run serially before kicking off the parallel job + drop_none (boolean, default=False): Whether to drop None values from the list of results or not + Returns: + [worker(**list[0], **const_args), worker(**list[1], **const_args), ...] 
+ """ + # Replace None with empty dict + const_args = dict() if const_args is None else const_args + # We run the first few iterations serially to catch bugs + if front_num > 0: + front = [worker(**a, **const_args) if use_kwargs else worker(a, **const_args) for a in array[:front_num]] + else: + front = [] + # If we set n_jobs to 1, just run a list comprehension. This is useful for benchmarking and debugging. + if n_jobs == 1: + return front + [worker(**a, **const_args) if use_kwargs else + worker(a, **const_args) for a in tqdm(array[front_num:])] + # Assemble the workers + with ProcessPoolExecutor(max_workers=n_jobs) as pool: + # Pass the elements of array into function + if use_kwargs: + futures = [pool.submit(worker, **a, **const_args) for a in array[front_num:]] + else: + futures = [pool.submit(worker, a, **const_args) for a in array[front_num:]] + tqdm_kwargs = { + 'total': len(futures), + 'unit': 'it', + 'unit_scale': True, + 'leave': True, + 'ncols': 100 + } + # Print out the progress as tasks complete + for _ in tqdm(as_completed(futures), **tqdm_kwargs): + pass + out = [] + # Get the results from the futures. + for i, future in enumerate(futures): + try: + out.append(future.result()) + except Exception as e: + print(f"Caught {str(e)} on {i}-th input.") + out.append(None) + + if drop_none: + return [v for v in front+out if v is not None] + else: + return front + out + + +def get_masks_hierarchy(masks, masks_meta): + order = sorted(list(enumerate(masks_meta)), key=lambda x: x[1][1])[::-1] + hierarchy = defaultdict(list) + + def check_inter(i, j): + assert masks_meta[i][1] >= masks_meta[j][1] + bbox_i, bbox_j = masks_meta[i][0], masks_meta[j][0] + bbox_score = get_bbox_intersection(bbox_i, bbox_j) / get_bbox_area(bbox_j) + if bbox_score < 0.7: + return False + + mask_i, mask_j = masks[i], masks[j] + mask_score = np.logical_and(mask_i, mask_j).sum() / masks_meta[j][1] + return mask_score > 0.8 + + def get_root_indx(root_indx, check_indx): + children = hierarchy[root_indx] + for child_indx in children: + if masks_meta[child_indx][1] < masks_meta[check_indx][1]: + continue + result_indx = get_root_indx(child_indx, check_indx) + if result_indx is not None: + return result_indx + + if check_inter(root_indx, check_indx): + return root_indx + + return None + + used_masks = np.zeros(len(masks), dtype=np.bool) + parents = [None] * len(masks) + node_level = [0] * len(masks) + for ti in range(len(masks) - 1): + for tj in range(ti + 1, len(masks)): + i = order[ti][0] + j = order[tj][0] + + assert i != j + if used_masks[j] or not check_inter(i, j): + continue + + ni = get_root_indx(i, j) + assert ni != j and parents[j] is None + hierarchy[ni].append(j) + used_masks[j] = True + parents[j] = ni + node_level[j] = node_level[ni] + 1 + + hierarchy = [hierarchy[i] for i in range(len(masks))] + hierarchy = {i: {'children': hierarchy[i], + 'parent': parents[i], + 'node_level': node_level[i] + } + for i in range(len(masks))} + return hierarchy + + +def get_bbox_intersection(b1, b2): + h_i = get_segments_intersection(b1[:2], b2[:2]) + w_i = get_segments_intersection(b1[2:4], b2[2:4]) + return h_i * w_i + + +def get_segments_intersection(s1, s2): + a, b = s1 + c, d = s2 + return max(0, min(b, d) - max(a, c) + 1) + + +def get_bbox_area(bbox): + return (bbox[1] - bbox[0] + 1) * (bbox[3] - bbox[2] + 1) + + +def get_iou(mask1, mask2): + intersection_area = np.logical_and(mask1, mask2).sum() + union_area = np.logical_or(mask1, mask2).sum() + return intersection_area / union_area + + +def encode_masks(masks): + 
layers = [np.zeros(masks[0].shape, dtype=np.uint8)] + layers_objs = [[]] + objs_mapping = [(None, None)] * len(masks) + ordered_masks = sorted(list(enumerate(masks)), key=lambda x: x[1].sum())[::-1] + for global_id, obj_mask in ordered_masks: + for layer_indx, (layer_mask, layer_objs) in enumerate(zip(layers, layers_objs)): + if len(layer_objs) >= 255: + continue + if np.all(layer_mask[obj_mask] == 0): + layer_objs.append(global_id) + local_id = len(layer_objs) + layer_mask[obj_mask] = local_id + objs_mapping[global_id] = (layer_indx, local_id) + break + else: + new_layer = np.zeros_like(layers[-1]) + new_layer[obj_mask] = 1 + objs_mapping[global_id] = (len(layers), 1) + layers.append(new_layer) + layers_objs.append([global_id]) + + layers = [cv2.imencode('.png', x)[1] for x in layers] + return layers, objs_mapping + + +def decode_masks(packed_data): + layers, objs_mapping = packed_data + layers = [cv2.imdecode(x, cv2.IMREAD_UNCHANGED) for x in layers] + masks = [] + for layer_indx, obj_id in objs_mapping: + masks.append(layers[layer_indx] == obj_id) + return masks diff --git a/scripts/annotations_conversion/openimages.py b/scripts/annotations_conversion/openimages.py new file mode 100644 index 0000000..246a201 --- /dev/null +++ b/scripts/annotations_conversion/openimages.py @@ -0,0 +1,33 @@ +import csv +import pickle as pkl +from pathlib import Path +from collections import defaultdict + + +def create_annotations(dataset_path, dataset_split='train'): + dataset_path = Path(dataset_path) + _split_path = dataset_path / dataset_split + _images_path = _split_path / 'images' + _masks_path = _split_path / 'masks' + clean_anno_path = _split_path / f'{dataset_split}-annotations-object-segmentation_clean.pkl' + + annotations = { + 'image_id_to_masks': defaultdict(list), # mapping from image_id to a list of masks + 'dataset_samples': [] # list of unique image ids + } + + with open(_split_path / f'{dataset_split}-annotations-object-segmentation.csv', 'r') as f: + reader = csv.DictReader(f, delimiter=',') + for row in reader: + image_id = row['ImageID'] + mask_path = row['MaskPath'] + + if (_images_path / f'{image_id}.jpg').is_file() \ + and (_masks_path / mask_path).is_file(): + annotations['image_id_to_masks'][image_id].append(mask_path) + annotations['dataset_samples'] = list(annotations['image_id_to_masks'].keys()) + + with clean_anno_path.open('wb') as f: + pkl.dump(annotations, f) + + return annotations diff --git a/scripts/convert_annotations.py b/scripts/convert_annotations.py new file mode 100644 index 0000000..27cdbac --- /dev/null +++ b/scripts/convert_annotations.py @@ -0,0 +1,40 @@ +import sys +import argparse +import multiprocessing as mp +from pathlib import Path + +sys.path.insert(0, '.') +from isegm.utils.exp import load_config_file +from scripts.annotations_conversion import openimages, ade20k, coco_lvis + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument('dataset', choices=['openimages', 'ade20k', 'coco_lvis'], help='') + parser.add_argument('--split', nargs='+', choices=['train', 'val', 'test'], type=str, default=['train', 'val'], + help='') + parser.add_argument('--n-jobs', type=int, default=10) + parser.add_argument('--config-path', type=str, default='./config.yml', + help='The path to the config file.') + + args = parser.parse_args() + cfg = load_config_file(args.config_path, return_edict=True) + return args, cfg + + +def main(): + mp.set_start_method('spawn') + args, cfg = parse_args() + + for split in args.split: + if args.dataset == 
'openimages': + openimages.create_annotations(Path(cfg.OPENIMAGES_PATH), dataset_split=split) + elif args.dataset == 'ade20k' and split != 'test': + ade20k.create_annotations(Path(cfg.ADE20K_PATH), dataset_split=split, n_jobs=args.n_jobs) + elif args.dataset == 'coco_lvis': + coco_lvis.create_annotations(Path(cfg.LVIS_PATH), Path(cfg.COCO_PATH), dataset_split=split) + + +if __name__ == '__main__': + main() diff --git a/scripts/evaluate_model.py b/scripts/evaluate_model.py new file mode 100644 index 0000000..70e281e --- /dev/null +++ b/scripts/evaluate_model.py @@ -0,0 +1,279 @@ +import sys +import pickle +import argparse +from pathlib import Path + +import cv2 +import torch +import numpy as np + +sys.path.insert(0, '.') +from isegm.inference import utils +from isegm.utils.exp import load_config_file +from isegm.utils.vis import draw_probmap, draw_with_blend_and_clicks +from isegm.inference.predictors import get_predictor +from isegm.inference.evaluation import evaluate_dataset + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument('mode', choices=['NoBRS', 'RGB-BRS', 'DistMap-BRS', + 'f-BRS-A', 'f-BRS-B', 'f-BRS-C'], + help='') + + group_checkpoints = parser.add_mutually_exclusive_group(required=True) + group_checkpoints.add_argument('--checkpoint', type=str, default='', + help='The path to the checkpoint. ' + 'This can be a relative path (relative to cfg.INTERACTIVE_MODELS_PATH) ' + 'or an absolute path. The file extension can be omitted.') + group_checkpoints.add_argument('--exp-path', type=str, default='', + help='The relative path to the experiment with checkpoints.' + '(relative to cfg.EXPS_PATH)') + + parser.add_argument('--datasets', type=str, default='GrabCut,Berkeley,DAVIS,SBD,PascalVOC', + help='List of datasets on which the model should be tested. ' + 'Datasets are separated by a comma. Possible choices: ' + 'GrabCut, Berkeley, DAVIS, SBD, PascalVOC') + + group_device = parser.add_mutually_exclusive_group() + group_device.add_argument('--gpus', type=str, default='0', + help='ID of used GPU.') + group_device.add_argument('--cpu', action='store_true', default=False, + help='Use only CPU for inference.') + + group_iou_thresh = parser.add_mutually_exclusive_group() + group_iou_thresh.add_argument('--target-iou', type=float, default=0.90, + help='Target IoU threshold for the NoC metric. (min possible value = 0.8)') + group_iou_thresh.add_argument('--iou-analysis', action='store_true', default=False, + help='Plot mIoU(number of clicks) with target_iou=1.0.') + + parser.add_argument('--n-clicks', type=int, default=20, + help='Maximum number of clicks for the NoC metric.') + parser.add_argument('--min-n-clicks', type=int, default=1, + help='Minimum number of clicks for the evaluation.') + parser.add_argument('--thresh', type=float, required=False, default=0.49, + help='The segmentation mask is obtained from the probability outputs using this threshold.') + parser.add_argument('--clicks-limit', type=int, default=None) + parser.add_argument('--eval-mode', type=str, default='cvpr', + help='Possible choices: cvpr, fixed (e.g. 
fixed400, fixed600).') + + parser.add_argument('--save-ious', action='store_true', default=False) + parser.add_argument('--print-ious', action='store_true', default=False) + parser.add_argument('--vis-preds', action='store_true', default=False) + parser.add_argument('--model-name', type=str, default=None, + help='The model name that is used for making plots.') + parser.add_argument('--config-path', type=str, default='./config.yml', + help='The path to the config file.') + parser.add_argument('--logs-path', type=str, default='', + help='The path to the evaluation logs. Default path: cfg.EXPS_PATH/evaluation_logs.') + + args = parser.parse_args() + if args.cpu: + args.device = torch.device('cpu') + else: + args.device = torch.device(f"cuda:{args.gpus.split(',')[0]}") + + if (args.iou_analysis or args.print_ious) and args.min_n_clicks <= 1: + args.target_iou = 1.01 + else: + args.target_iou = max(0.8, args.target_iou) + + cfg = load_config_file(args.config_path, return_edict=True) + cfg.EXPS_PATH = Path(cfg.EXPS_PATH) + + if args.logs_path == '': + args.logs_path = cfg.EXPS_PATH / 'evaluation_logs' + else: + args.logs_path = Path(args.logs_path) + + return args, cfg + + +def main(): + args, cfg = parse_args() + + checkpoints_list, logs_path, logs_prefix = get_checkpoints_list_and_logs_path(args, cfg) + logs_path.mkdir(parents=True, exist_ok=True) + + single_model_eval = len(checkpoints_list) == 1 + assert not args.iou_analysis if not single_model_eval else True, \ + "Can't perform IoU analysis for multiple checkpoints" + print_header = single_model_eval + for dataset_name in args.datasets.split(','): + dataset = utils.get_dataset(dataset_name, cfg) + + for checkpoint_path in checkpoints_list: + model = utils.load_is_model(checkpoint_path, args.device) + + predictor_params, zoomin_params = get_predictor_and_zoomin_params(args, dataset_name) + predictor = get_predictor(model, args.mode, args.device, + prob_thresh=args.thresh, + predictor_params=predictor_params, + zoom_in_params=zoomin_params) + + vis_callback = get_prediction_vis_callback(logs_path, dataset_name, args.thresh) if args.vis_preds else None + dataset_results = evaluate_dataset(dataset, predictor, pred_thr=args.thresh, + max_iou_thr=args.target_iou, + min_clicks=args.min_n_clicks, + max_clicks=args.n_clicks, + callback=vis_callback) + + row_name = args.mode if single_model_eval else checkpoint_path.stem + if args.iou_analysis: + save_iou_analysis_data(args, dataset_name, logs_path, + logs_prefix, dataset_results, + model_name=args.model_name) + + save_results(args, row_name, dataset_name, logs_path, logs_prefix, dataset_results, + save_ious=single_model_eval and args.save_ious, + single_model_eval=single_model_eval, + print_header=print_header) + print_header = False + + +def get_predictor_and_zoomin_params(args, dataset_name): + predictor_params = {} + + if args.clicks_limit is not None: + if args.clicks_limit == -1: + args.clicks_limit = args.n_clicks + predictor_params['net_clicks_limit'] = args.clicks_limit + + if args.eval_mode == 'cvpr': + zoom_in_params = { + 'target_size': 600 if dataset_name == 'DAVIS' else 400 + } + elif args.eval_mode.startswith('fixed'): + crop_size = int(args.eval_mode[5:]) + zoom_in_params = { + 'skip_clicks': -1, + 'target_size': (crop_size, crop_size) + } + else: + raise NotImplementedError + + return predictor_params, zoom_in_params + + +def get_checkpoints_list_and_logs_path(args, cfg): + logs_prefix = '' + if args.exp_path: + rel_exp_path = args.exp_path + checkpoint_prefix = '' + if ':' in 
rel_exp_path: + rel_exp_path, checkpoint_prefix = rel_exp_path.split(':') + + exp_path_prefix = cfg.EXPS_PATH / rel_exp_path + candidates = list(exp_path_prefix.parent.glob(exp_path_prefix.stem + '*')) + assert len(candidates) == 1, "Invalid experiment path." + exp_path = candidates[0] + checkpoints_list = sorted((exp_path / 'checkpoints').glob(checkpoint_prefix + '*.pth'), reverse=True) + assert len(checkpoints_list) > 0, "Couldn't find any checkpoints." + + if checkpoint_prefix: + if len(checkpoints_list) == 1: + logs_prefix = checkpoints_list[0].stem + else: + logs_prefix = f'all_{checkpoint_prefix}' + else: + logs_prefix = 'all_checkpoints' + + logs_path = args.logs_path / exp_path.relative_to(cfg.EXPS_PATH) + else: + checkpoints_list = [Path(utils.find_checkpoint(cfg.INTERACTIVE_MODELS_PATH, args.checkpoint))] + logs_path = args.logs_path / 'others' / checkpoints_list[0].stem + + return checkpoints_list, logs_path, logs_prefix + + +def save_results(args, row_name, dataset_name, logs_path, logs_prefix, dataset_results, + save_ious=False, print_header=True, single_model_eval=False): + all_ious, elapsed_time = dataset_results + mean_spc, mean_spi = utils.get_time_metrics(all_ious, elapsed_time) + + iou_thrs = np.arange(0.8, min(0.95, args.target_iou) + 0.001, 0.05).tolist() + noc_list, over_max_list = utils.compute_noc_metric(all_ious, iou_thrs=iou_thrs, max_clicks=args.n_clicks) + + row_name = 'last' if row_name == 'last_checkpoint' else row_name + model_name = str(logs_path.relative_to(args.logs_path)) + ':' + logs_prefix if logs_prefix else logs_path.stem + header, table_row = utils.get_results_table(noc_list, over_max_list, row_name, dataset_name, + mean_spc, elapsed_time, args.n_clicks, + model_name=model_name) + + if args.print_ious: + min_num_clicks = min(len(x) for x in all_ious) + mean_ious = np.array([x[:min_num_clicks] for x in all_ious]).mean(axis=0) + miou_str = ' '.join([f'mIoU@{click_id}={mean_ious[click_id - 1]:.2%};' + for click_id in [1, 2, 3, 5, 10, 20] if click_id <= min_num_clicks]) + table_row += '; ' + miou_str + else: + target_iou_int = int(args.target_iou * 100) + if target_iou_int not in [80, 85, 90]: + noc_list, over_max_list = utils.compute_noc_metric(all_ious, iou_thrs=[args.target_iou], + max_clicks=args.n_clicks) + table_row += f' NoC@{args.target_iou:.1%} = {noc_list[0]:.2f};' + table_row += f' >={args.n_clicks}@{args.target_iou:.1%} = {over_max_list[0]}' + + if print_header: + print(header) + print(table_row) + + if save_ious: + ious_path = logs_path / 'ious' / (logs_prefix if logs_prefix else '') + ious_path.mkdir(parents=True, exist_ok=True) + with open(ious_path / f'{dataset_name}_{args.eval_mode}_{args.mode}_{args.n_clicks}.pkl', 'wb') as fp: + pickle.dump(all_ious, fp) + + name_prefix = '' + if logs_prefix: + name_prefix = logs_prefix + '_' + if not single_model_eval: + name_prefix += f'{dataset_name}_' + + log_path = logs_path / f'{name_prefix}{args.eval_mode}_{args.mode}_{args.n_clicks}.txt' + if log_path.exists(): + with open(log_path, 'a') as f: + f.write(table_row + '\n') + else: + with open(log_path, 'w') as f: + if print_header: + f.write(header + '\n') + f.write(table_row + '\n') + + +def save_iou_analysis_data(args, dataset_name, logs_path, logs_prefix, dataset_results, model_name=None): + all_ious, _ = dataset_results + + name_prefix = '' + if logs_prefix: + name_prefix = logs_prefix + '_' + name_prefix += dataset_name + '_' + if model_name is None: + model_name = str(logs_path.relative_to(args.logs_path)) + ':' + logs_prefix if logs_prefix 
            else logs_path.stem
+
+    pkl_path = logs_path / f'plots/{name_prefix}{args.eval_mode}_{args.mode}_{args.n_clicks}.pickle'
+    pkl_path.parent.mkdir(parents=True, exist_ok=True)
+    with pkl_path.open('wb') as f:
+        pickle.dump({
+            'dataset_name': dataset_name,
+            'model_name': f'{model_name}_{args.mode}',
+            'all_ious': all_ious
+        }, f)
+
+
+def get_prediction_vis_callback(logs_path, dataset_name, prob_thresh):
+    save_path = logs_path / 'predictions_vis' / dataset_name
+    save_path.mkdir(parents=True, exist_ok=True)
+
+    def callback(image, gt_mask, pred_probs, sample_id, click_indx, clicks_list):
+        sample_path = save_path / f'{sample_id}_{click_indx}.jpg'
+        prob_map = draw_probmap(pred_probs)
+        image_with_mask = draw_with_blend_and_clicks(image, pred_probs > prob_thresh, clicks_list=clicks_list)
+        cv2.imwrite(str(sample_path), np.concatenate((image_with_mask, prob_map), axis=1)[:, :, ::-1])
+
+    return callback
+
+
+if __name__ == '__main__':
+    main()
diff --git a/scripts/plot_ious_analysis.py b/scripts/plot_ious_analysis.py
new file mode 100644
index 0000000..d4986c1
--- /dev/null
+++ b/scripts/plot_ious_analysis.py
@@ -0,0 +1,143 @@
+import sys
+import pickle
+import argparse
+from pathlib import Path
+from collections import defaultdict
+
+import numpy as np
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+
+sys.path.insert(0, '.')
+from isegm.utils.exp import load_config_file
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+
+    group_pkl_path = parser.add_mutually_exclusive_group(required=True)
+    group_pkl_path.add_argument('--folder', type=str, default=None,
+                                help='Path to a folder with .pickle files.')
+    group_pkl_path.add_argument('--files', nargs='+', default=None,
+                                help='List of paths to .pickle files, separated by spaces.')
+    group_pkl_path.add_argument('--model-dirs', nargs='+', default=None,
+                                help="List of paths to model directories with a 'plots' folder "
+                                     "containing .pickle files, separated by spaces.")
+    group_pkl_path.add_argument('--exp-models', nargs='+', default=None,
+                                help='List of experiment path suffixes (relative to cfg.EXPS_PATH/evaluation_logs). '
+                                     'For each experiment, the checkpoint prefix must be specified '
+                                     'by using the ":" delimiter at the end.')
+
+    parser.add_argument('--mode', choices=['NoBRS', 'RGB-BRS', 'DistMap-BRS',
+                                           'f-BRS-A', 'f-BRS-B', 'f-BRS-C'],
+                        default=None, nargs='*',
+                        help='Plot only the results obtained with the specified evaluation modes.')
+    parser.add_argument('--datasets', type=str, default='GrabCut,Berkeley,DAVIS,COCO_MVal,SBD',
+                        help='List of datasets for plotting the IoU analysis. '
+                             'Datasets are separated by a comma. Possible choices: '
+                             'GrabCut, Berkeley, DAVIS, COCO_MVal, SBD')
+    parser.add_argument('--config-path', type=str, default='./config.yml',
+                        help='The path to the config file.')
+    parser.add_argument('--n-clicks', type=int, default=-1,
+                        help='Maximum number of clicks to plot (-1 plots all recorded clicks).')
+    parser.add_argument('--plots-path', type=str, default='',
+                        help='The path to the evaluation logs. '
+                             'Default path: cfg.EXPS_PATH/evaluation_logs/iou_analysis.')
+
+    args = parser.parse_args()
+
+    cfg = load_config_file(args.config_path, return_edict=True)
+    cfg.EXPS_PATH = Path(cfg.EXPS_PATH)
+
+    args.datasets = args.datasets.split(',')
+    if args.plots_path == '':
+        args.plots_path = cfg.EXPS_PATH / 'evaluation_logs/iou_analysis'
+    else:
+        args.plots_path = Path(args.plots_path)
+    print(args.plots_path)
+    args.plots_path.mkdir(parents=True, exist_ok=True)
+
+    return args, cfg
+
+
+def main():
+    args, cfg = parse_args()
+
+    files_list = get_files_list(args, cfg)
+
+    # Dict of dicts with mapping dataset_name -> model_name -> results
+    aggregated_plot_data = defaultdict(dict)
+    for file in files_list:
+        with open(file, 'rb') as f:
+            data = pickle.load(f)
+            if args.n_clicks != -1:  # -1 means "keep all recorded clicks"
+                data['all_ious'] = [x[:args.n_clicks] for x in data['all_ious']]
+        aggregated_plot_data[data['dataset_name']][data['model_name']] = np.array(data['all_ious']).mean(0)
+
+    for dataset_name, dataset_results in aggregated_plot_data.items():
+        plt.figure(figsize=(12, 7))
+
+        max_clicks = 0
+        for model_name, model_results in dataset_results.items():
+            if args.n_clicks != -1:
+                model_results = model_results[:args.n_clicks]
+
+            n_clicks = len(model_results)
+            max_clicks = max(max_clicks, n_clicks)
+
+            miou_str = ' '.join([f'mIoU@{click_id}={model_results[click_id-1]:.2%};'
+                                 for click_id in [1, 3, 5, 10, 20] if click_id <= len(model_results)])
+            print(f'{model_name} on {dataset_name}:\n{miou_str}\n')
+
+            plt.plot(1 + np.arange(n_clicks), model_results, linewidth=2, label=model_name)
+
+        plt.title(f'mIoU after every click for {dataset_name}', fontsize='x-large')
+        plt.grid()
+        plt.legend(loc=4, fontsize='x-large')
+        plt.yticks(fontsize='x-large')
+        plt.xticks(1 + np.arange(max_clicks), fontsize='x-large')
+
+        fig_path = get_target_file_path(args.plots_path, dataset_name)
+        plt.savefig(str(fig_path))
+
+
+def get_target_file_path(plots_path, dataset_name):
+    previous_plots = sorted(plots_path.glob(f'{dataset_name}_*.png'))
+    if len(previous_plots) == 0:
+        index = 0
+    else:
+        index = int(previous_plots[-1].stem.split('_')[-1]) + 1
+
+    return str(plots_path / f'{dataset_name}_{index:03d}.png')
+
+
+def get_files_list(args, cfg):
+    if args.folder is not None:
+        files_list = Path(args.folder).glob('*.pickle')
+    elif args.files is not None:
+        files_list = [Path(file) for file in args.files]
+    elif args.model_dirs is not None:
+        files_list = []
+        for folder in args.model_dirs:
+            folder = Path(folder) / 'plots'
+            files_list.extend(folder.glob('*.pickle'))
+    elif args.exp_models is not None:
+        files_list = []
+        for rel_exp_path in args.exp_models:
+            rel_exp_path, checkpoint_prefix = rel_exp_path.split(':')
+            exp_path_prefix = cfg.EXPS_PATH / 'evaluation_logs' / rel_exp_path
+            candidates = list(exp_path_prefix.parent.glob(exp_path_prefix.stem + '*'))
+            assert len(candidates) == 1, "Invalid experiment path."
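+            # Collect the checkpoint-specific IoU pickles from the matched experiment's 'plots' folder.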
+            exp_path = candidates[0]
+            files_list.extend(sorted((exp_path / 'plots').glob(checkpoint_prefix + '*.pickle')))
+
+    if args.mode is not None:
+        files_list = [file for file in files_list
+                      if any(mode in file.stem for mode in args.mode)]
+    files_list = [file for file in files_list
+                  if any(dataset in file.stem for dataset in args.datasets)]
+
+    return files_list
+
+
+if __name__ == '__main__':
+    main()
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..3bfbbe2
--- /dev/null
+++ b/train.py
@@ -0,0 +1,82 @@
+import os
+import argparse
+import importlib.util
+
+import torch
+from isegm.utils.exp import init_experiment
+
+
+def main():
+    args = parse_args()
+    if args.temp_model_path:
+        model_script = load_module(args.temp_model_path)
+    else:
+        model_script = load_module(args.model_path)
+
+    model_base_name = getattr(model_script, 'MODEL_NAME', None)
+
+    args.distributed = 'WORLD_SIZE' in os.environ
+    cfg = init_experiment(args, model_base_name)
+
+    torch.backends.cudnn.benchmark = True
+    torch.multiprocessing.set_sharing_strategy('file_system')
+
+    model_script.main(cfg)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument('model_path', type=str,
+                        help='Path to the model script.')
+
+    parser.add_argument('--exp-name', type=str, default='',
+                        help='Here you can specify the name of the experiment. '
+                             'It will be added as a suffix to the experiment folder name.')
+
+    parser.add_argument('--workers', type=int, default=4,
+                        metavar='N', help='Number of dataloader workers.')
+
+    parser.add_argument('--batch-size', type=int, default=-1,
+                        help='You can override the model batch size by specifying a positive number.')
+
+    parser.add_argument('--ngpus', type=int, default=1,
+                        help='Number of GPUs. '
+                             'If you specify only the "--gpus" argument, the ngpus value will be calculated automatically. '
+                             'You should use either this argument or "--gpus".')
+
+    parser.add_argument('--gpus', type=str, default='', required=False,
+                        help='IDs of the GPUs to use. You should use either this argument or "--ngpus".')
+
+    parser.add_argument('--resume-exp', type=str, default=None,
+                        help='The prefix of the name of the experiment to be continued. '
+                             'If you use this field, you must specify the "--resume-prefix" argument.')
+
+    parser.add_argument('--resume-prefix', type=str, default='latest',
+                        help='The prefix of the name of the checkpoint to be loaded.')
+
+    parser.add_argument('--start-epoch', type=int, default=0,
+                        help='The number of the starting epoch from which training will continue '
+                             '(important for correct logging and learning rate scheduling).')
+
+    parser.add_argument('--weights', type=str, default=None,
+                        help='Model weights will be loaded from the specified path if you use this argument.')
+
+    parser.add_argument('--temp-model-path', type=str, default='',
+                        help='Do not use this argument (for internal purposes).')
+
+    parser.add_argument("--local_rank", type=int, default=0)
+
+    return parser.parse_args()
+
+
+def load_module(script_path):
+    spec = importlib.util.spec_from_file_location("model_script", script_path)
+    model_script = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(model_script)
+
+    return model_script
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file