This repository implements models described in recent computer vision literature, focusing on a simple classification task on a classical dataset (MNIST). Three base models are explored: spatial transformer networks, vision transformers, and SpinalNets. We also implement new variations of two of these three models by replacing standard convolutional layers with CoordConv layers. The implemented models are:
- Spatial transformer networks (STN)
- Spatial transformer networks + CoordConv layers
- Vision transformers
- SpinalNet
- SpinalNet + STN + CoordConv layers
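
To illustrate what a CoordConv layer adds compared to a standard convolution, here is a minimal NumPy sketch of the coordinate-channel step only. It is not the code used in this repository: the function name `add_coord_channels`, the `(batch, channels, height, width)` layout, and the `[-1, 1]` scaling are assumptions made for illustration, and the convolution that would follow is omitted.

```python
import numpy as np


def add_coord_channels(images: np.ndarray) -> np.ndarray:
    """Append normalized y and x coordinate channels to a batch of images.

    images: float array of shape (batch, channels, height, width).
    Returns an array of shape (batch, channels + 2, height, width).
    """
    batch, _, height, width = images.shape
    # Assumption: coordinates are scaled to [-1, 1] along each axis.
    ys = np.linspace(-1.0, 1.0, height).astype(images.dtype)
    xs = np.linspace(-1.0, 1.0, width).astype(images.dtype)
    yy, xx = np.meshgrid(ys, xs, indexing="ij")  # each of shape (height, width)
    coords = np.stack([yy, xx])  # shape (2, height, width)
    # Repeat the same coordinate maps for every image in the batch.
    coords = np.broadcast_to(coords, (batch, 2, height, width))
    return np.concatenate([images, coords], axis=1)


# Dummy batch with MNIST-sized images (1 channel, 28 x 28 pixels).
batch = np.zeros((4, 1, 28, 28), dtype=np.float32)
augmented = add_coord_channels(batch)
print(augmented.shape)
```

A regular convolution applied to `augmented` then sees the position of each pixel as two extra input channels, which is the idea behind CoordConv.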
A complete run of the experiments, together with results, comments, and references, is available in MNIST_benchmarks.ipynb. These can also be reproduced in the following Colab notebook:
A standalone script is also provided to reproduce the experiments. The few required dependencies are listed in requirements.txt.
usage: [-h] [--device {gpu,cpu}] [--workers WORKERS]
       [--bs BS] [--maxepochs MAX_EPOCHS]
       [--patience PATIENCE] [--mindelta MIN_DELTA]
       [--model {stn,stncoordconv,vit,spinal,spinalstn}]
       [--localization] [--lr LR] [--logs LOGPATH]

optional arguments:
  -h, --help            show this help message and exit
  --device {gpu,cpu}    Device on which to run the experiments. (default: cpu)
  --workers WORKERS     Number of workers for dataloaders. (default: 2)
  --bs BS               Batch size. (default: 64)
  --maxepochs MAX_EPOCHS
                        Maximum number of epochs to run the experiment for.
                        (default: 20)
  --patience PATIENCE   Number of epochs with no improvement before triggering
                        early stopping. (default: 5)
  --mindelta MIN_DELTA  Required improvement in the validation loss for early
                        stopping. (default: 0.005)
  --model {stn,stncoordconv,vit,spinal,spinalstn}
                        Type of model to train. (default: stn)
  --localization        Whether to use CoordConv in the localization network.
                        (default: False)
  --lr LR               Learning rate for SGD. (default: 0.01)
  --logs LOGPATH        Directory to store tensorboard logs. (default: logs/)
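
The `--patience` and `--mindelta` options interact, so a small sketch of one common early-stopping rule may help. This is an illustration under stated assumptions, not the script's actual implementation: the `EarlyStopping` class below is hypothetical, and it assumes that an epoch counts as an improvement only when the validation loss drops by more than `min_delta` below the best value seen so far.

```python
class EarlyStopping:
    """Hypothetical helper: stop after `patience` epochs without improvement."""

    def __init__(self, patience: int = 5, min_delta: float = 0.005) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.stale_epochs = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True if training should stop."""
        if self.best_loss - val_loss > self.min_delta:
            # Improvement larger than min_delta: remember it and reset the counter.
            self.best_loss = val_loss
            self.stale_epochs = 0
        else:
            # No sufficient improvement during this epoch.
            self.stale_epochs += 1
        return self.stale_epochs >= self.patience


# Made-up validation losses: the loss plateaus after the second epoch.
losses = [0.50, 0.40, 0.399, 0.399, 0.398, 0.398, 0.398, 0.30]
stopper = EarlyStopping(patience=5, min_delta=0.005)
stopped_at = None
for epoch, loss in enumerate(losses):
    if stopper.step(loss):
        stopped_at = epoch
        break
print(stopped_at)
```

With the default values, training in this sketch stops once five consecutive epochs fail to improve the best validation loss by more than 0.005, so the final loss of 0.30 in the list is never reached.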
TensorBoard is used to save the training and validation logs and metrics. By default, the logs are saved in logs/. More details are available in the TensorBoard documentation. To launch TensorBoard, use the following command:
tensorboard --logdir=logs/ --port <port> --host <host>