Geo-SIC is an image classifier enhanced by joint geometric shape learning in atlas building. Read our paper on arXiv.
We've released an updated PyTorch implementation of Geo-SIC's joint learning. An integration between C++ and Python will be available soon, and we're expanding our testing procedure to include more diverse classifiers.
Geo-SIC operates as follows:
- Atlas Building: Feed the input images into an atlas-building network, which extracts geometric features in the latent space of transformation fields. An atlas is learned jointly during this process.
- Image Classifier: Feed the same images into a backbone classifier, then fuse the backbone features with the geometric features from the atlas-building network to predict a class label.
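The fusion step can be pictured with the minimal sketch below. It assumes late fusion by concatenating the backbone feature vector with the geometric feature vector and applying a single linear layer; the function name `fuse_and_classify` and the toy weights are hypothetical, and Geo-SIC's actual fusion module may differ.

```python
from typing import List, Sequence


def fuse_and_classify(
    backbone_feats: Sequence[float],
    geo_feats: Sequence[float],
    weights: List[List[float]],
    bias: Sequence[float],
) -> int:
    """Hypothetical late fusion: concatenate features, apply a linear head, return argmax class."""
    fused = list(backbone_feats) + list(geo_feats)  # concatenate both feature vectors
    if len(weights) != len(bias):
        raise ValueError("one bias entry is needed per class")
    if any(len(row) != len(fused) for row in weights):
        raise ValueError("each weight row must match the fused feature length")
    # One logit per class: dot product of that class's weight row with the fused vector
    logits = [sum(w * x for w, x in zip(row, fused)) + b for row, b in zip(weights, bias)]
    return max(range(len(logits)), key=logits.__getitem__)


# Two backbone features, one geometric feature, two classes
print(fuse_and_classify([1.0, 0.0], [0.5], [[1.0, 0.0, 0.0], [0.0, 0.0, 4.0]], [0.0, 0.0]))
```

In a PyTorch model the same idea would typically be written as `torch.cat([backbone, geo], dim=1)` followed by a `torch.nn.Linear` layer.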
Ensure that your Python version is greater than 3.12. If you wish to use the spatial version of LDDMM (Lagomorph) to produce geometric features, make sure to include the source code we provided and also install the lagomorph package:
pip install lagomorph
PyTorch version >= 0.4.0 is also required.
- optimizer: Specifies the optimizer used for training the model (e.g., 'Adam').
- scheduler: Specifies the learning rate scheduler (e.g., 'CosAn' for Cosine Annealing).
- loss: Specifies the loss function used for training (e.g., 'L2' for mean squared error, 'NCC' for normalized cross-correlation).
- augmentation: Boolean indicating whether data augmentation is enabled.
- reduced_dim: Specifies the reduced dimensions (e.g., 16, 16, 16).
- pretrain_epoch: Specifies the number of epochs for pretraining (e.g., 300).
- lr: Learning rate for both models (e.g., 1e-4).
- epochs: Total number of epochs for training (e.g., 1000).
- batch_size: Batch size for training (e.g., 1, 4, 8).
- weight_decay: Weight decay parameter (e.g., 1e-5).
- Euler_steps: Number of steps for Euler integration (e.g., 5, 10).
- Alpha: Alpha in shooting (e.g., 1.0, 2.0).
- Gamma: Gamma in shooting (e.g., 1.0).
- Lpow: The power of the Laplacian operator in shooting (e.g., 4.0, 6.0).
- Sigma: The noise variance on the image matching term in LDDMM (e.g., 0.02, 0.03).
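The parameters above could be collected in one place as in the sketch below. `AtlasTrainConfig` is a hypothetical container whose field names mirror this list and whose defaults are picked from the listed examples; it is not the repository's configuration format. The `operator_symbol` helper assumes the shooting regularizer has the form (-Alpha * Laplacian + Gamma * I)^Lpow, as in Fourier-based LDDMM; whether Geo-SIC uses exactly this form is an assumption.

```python
import math
from dataclasses import dataclass
from typing import Tuple


@dataclass
class AtlasTrainConfig:
    """Hypothetical config; field names and example values mirror the README list."""
    optimizer: str = "Adam"
    scheduler: str = "CosAn"   # cosine annealing
    loss: str = "L2"           # or "NCC"
    augmentation: bool = False  # default chosen for illustration only
    reduced_dim: Tuple[int, int, int] = (16, 16, 16)
    pretrain_epoch: int = 300
    lr: float = 1e-4
    epochs: int = 1000
    batch_size: int = 4
    weight_decay: float = 1e-5
    Euler_steps: int = 10
    Alpha: float = 2.0
    Gamma: float = 1.0
    Lpow: float = 6.0
    Sigma: float = 0.03

    def __post_init__(self) -> None:
        # Assumption: only the two losses named in the README are accepted here
        if self.loss not in ("L2", "NCC"):
            raise ValueError(f"unsupported loss: {self.loss!r}")
        if min(self.epochs, self.batch_size, self.Euler_steps) < 1:
            raise ValueError("epochs, batch_size and Euler_steps must be >= 1")
        if self.Sigma <= 0:
            raise ValueError("Sigma must be positive")


def operator_symbol(k: int, n: int, alpha: float, gamma: float, lpow: float) -> float:
    """Fourier symbol of (-alpha * Laplacian + gamma * I) ** lpow at frequency k
    on a periodic 1D grid with n points and unit spacing (assumed operator form)."""
    # Eigenvalue of the negative discrete Laplacian for the Fourier mode k
    neg_laplacian = 2.0 - 2.0 * math.cos(2.0 * math.pi * k / n)
    return (alpha * neg_laplacian + gamma) ** lpow


cfg = AtlasTrainConfig()
print([operator_symbol(k, 16, cfg.Alpha, cfg.Gamma, cfg.Lpow) for k in (0, 4, 8)])
```

Under this assumed form, the symbol grows with frequency, so larger Alpha or Lpow penalizes high-frequency velocity components more strongly and yields smoother transformations.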
- in_channels: Number of input channels (should be set to 1 for grayscale images).
- conv_channels: List containing the number of channels for each convolutional layer (e.g., [8, 16, 16, 32, 32]).
- conv_kernel_sizes: List containing the kernel sizes for each convolutional layer (e.g., [3, 3, 3, 3, 3]).
- activation: Activation function used in the convolutional layers (e.g., 'ReLU', 'PReLU').
- num_classes: Number of output classes for the classification task (e.g., 2).
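The sketch below shows how `in_channels`, `conv_channels`, and `conv_kernel_sizes` could chain into per-layer specifications, with each layer's output channels feeding the next layer's input. The function `conv_layer_specs` is hypothetical and does not build any real layers.

```python
from typing import List, Sequence, Tuple


def conv_layer_specs(
    in_channels: int,
    conv_channels: Sequence[int],
    conv_kernel_sizes: Sequence[int],
) -> List[Tuple[int, int, int]]:
    """Hypothetical helper: return (in_ch, out_ch, kernel_size) for each conv layer."""
    if len(conv_channels) != len(conv_kernel_sizes):
        raise ValueError("conv_channels and conv_kernel_sizes must have the same length")
    specs = []
    prev = in_channels
    for out_ch, k in zip(conv_channels, conv_kernel_sizes):
        specs.append((prev, out_ch, k))
        prev = out_ch  # this layer's output becomes the next layer's input
    return specs


print(conv_layer_specs(1, [8, 16, 16, 32, 32], [3, 3, 3, 3, 3]))
```

For 3D volumes (an assumption here), each tuple would map naturally onto `torch.nn.Conv3d(in_ch, out_ch, k)`, followed by the chosen activation.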
To build an atlas using our deep learning framework, run Run_Atlas_trainer.py. To run the entire training for classifiers, simply execute the script Run_trainer.py.
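If you want to run both stages in sequence, a small driver like the sketch below would work. It assumes the atlas should be built before classifier training and that both scripts run from the repository root without arguments; no command-line flags are passed because none are documented here. `run_pipeline` is a hypothetical helper and defaults to a dry run that only prints the commands.

```python
import subprocess
import sys
from typing import List

# Assumed order: build the atlas first, then train the classifier
SCRIPTS = ["Run_Atlas_trainer.py", "Run_trainer.py"]


def pipeline_commands() -> List[List[str]]:
    # Use the current interpreter so the same environment runs both stages
    return [[sys.executable, script] for script in SCRIPTS]


def run_pipeline(dry_run: bool = True) -> List[List[str]]:
    cmds = pipeline_commands()
    for cmd in cmds:
        if dry_run:
            print(" ".join(cmd))
        else:
            subprocess.run(cmd, check=True)  # stop at the first failing stage
    return cmds


if __name__ == "__main__":
    run_pipeline(dry_run=True)
```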
We've included different shooting models:
- Stationary Velocity Field in Voxelmorph (Code: GitHub)
- Vector-Momenta Shooting Method in Lagomorph (Code: GitHub)
- LDDMM based on Fourier representations (Code: Bitbucket)
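The `Euler_steps` parameter trades accuracy for speed when integrating a flow. The sketch below illustrates that trade-off with forward Euler integration of a stationary velocity field acting on 1D points. It is a toy illustration, not the integration scheme of any of the models listed above, and `euler_integrate` is a hypothetical name.

```python
import math
from typing import Callable, List


def euler_integrate(
    points: List[float], velocity: Callable[[float], float], euler_steps: int
) -> List[float]:
    """Flow 1D points through a stationary velocity field from t=0 to t=1 using forward Euler."""
    if euler_steps < 1:
        raise ValueError("euler_steps must be >= 1")
    dt = 1.0 / euler_steps
    out = list(points)
    for _ in range(euler_steps):
        # x_{t+dt} = x_t + dt * v(x_t): the velocity is re-evaluated at the moved points
        out = [x + dt * velocity(x) for x in out]
    return out


if __name__ == "__main__":
    # For v(x) = x, the exact time-1 flow of x0 = 1 is e; more steps get closer to it.
    for steps in (5, 10, 100):
        approx = euler_integrate([1.0], lambda x: x, steps)[0]
        print(steps, approx, abs(approx - math.exp(1.0)))
```

More steps reduce the integration error at the cost of more velocity evaluations, which is the same trade-off the 5 vs. 10 example values for `Euler_steps` suggest.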
If you have any questions, feel free to reach out to us by opening an issue.
This project is licensed under the terms of the MIT license. See the LICENSE file for details.