This repository provides the PyTorch implementation of the following paper:
Equivariance-bridged SO(2) Invariant Representation Learning using Graph Convolutional Networks
Authored by Sungwon Hwang, Hyungtae Lim, and Hyun Myung
Abstract: Training a Convolutional Neural Network (CNN) to be robust against rotation has mostly been done with data augmentation. In this paper, another, more progressive research direction is highlighted that reduces dependence on data augmentation by achieving structural rotational invariance of a network. To realize this vision, the deep equivariance-bridged SO(2) invariant network is proposed, which consists of two main parts. First, the Self-Weighted Nearest Neighbors Graph Convolutional Network (SWN-GCN) is proposed to implement a Graph Convolutional Network (GCN) on the graph representation of an image to acquire a rotationally equivariant representation, as GCN is more suitable for constructing deeper networks than spectral graph convolution-based approaches. Then, an invariant representation is obtained with Global Average Pooling (GAP), a permutation-invariant operation suitable for aggregating high-dimensional representations, over the equivariant set of vertices retrieved from SWN-GCN. Our method achieves state-of-the-art image classification performance on rotated MNIST and CIFAR-10 images, where the models are trained with a non-augmented dataset only. Finally, quantitative and qualitative validations of the invariance and equivariance of the representations, respectively, are reported.
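The abstract's central argument is that rotation-equivariant vertex features become rotation-invariant once they are aggregated with a permutation-invariant readout such as GAP. The NumPy sketch below illustrates only that argument under simplifying assumptions; it is not the SWN-GCN implemented in this repository. The pixel-to-vertex conversion, the radius-based neighbourhood (swapped in for nearest neighbours so that equidistant grid pixels need no tie-breaking), the mean-aggregation layer, and the names `image_to_graph`, `gcn_layer`, `vertex_features` and `invariant_embedding` are all hypothetical. It checks only 90-degree rotations, which map a square pixel grid onto itself and therefore merely permute the vertices.

```python
import numpy as np


def image_to_graph(img, radius=1.5):
    """Hypothetical image-to-graph conversion with one vertex per pixel.

    Vertex positions are pixel coordinates centred on the image centre and
    vertex features are pixel intensities. Edges join vertices whose distance
    lies in (0, radius]; this radius graph is an assumption standing in for a
    nearest-neighbour graph, chosen so no tie-breaking is needed on the grid.
    """
    h, w = img.shape
    ys, xs = np.mgrid[0:h, 0:w]
    coords = np.stack([ys.ravel() - (h - 1) / 2.0,
                       xs.ravel() - (w - 1) / 2.0], axis=1)
    dist = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)
    adj = ((dist > 0) & (dist <= radius)).astype(float)
    feats = img.reshape(-1, 1).astype(float)
    return adj, feats


def gcn_layer(adj, feats, w_neigh, w_self):
    """Illustrative graph convolution: mean over neighbours plus a self term, then ReLU."""
    deg = adj.sum(axis=1, keepdims=True)
    neigh = (adj @ feats) / deg
    return np.maximum(neigh @ w_neigh + feats @ w_self, 0.0)


def vertex_features(img, weights):
    """Stack GCN layers; the per-vertex outputs are permutation-equivariant."""
    adj, h = image_to_graph(img)
    for w_neigh, w_self in weights:
        h = gcn_layer(adj, h, w_neigh, w_self)
    return h


def invariant_embedding(img, weights):
    """Global Average Pooling over vertices gives a permutation-invariant vector."""
    return vertex_features(img, weights).mean(axis=0)


rng = np.random.default_rng(0)
weights = [(rng.normal(size=(1, 8)), rng.normal(size=(1, 8))),
           (rng.normal(size=(8, 8)), rng.normal(size=(8, 8)))]
img = rng.random((6, 6))

z = invariant_embedding(img, weights)
z_rot = invariant_embedding(np.rot90(img), weights)
print(np.allclose(z, z_rot))  # True: rotating by 90 degrees only permutes the vertices
```

Because the adjacency depends only on distances between centred pixel coordinates, a grid rotation leaves it unchanged up to the same permutation applied to the features, so each layer's output is permuted rather than altered, and the vertex mean cancels the permutation. Arbitrary angles do not map pixels onto pixels, which is why this toy check is restricted to multiples of 90 degrees.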
Create and activate the conda environment:
$ conda env create -f environment.yml
$ conda activate wn_gcn
Evaluate the pretrained models on the rotated test sets:
- R-MNIST
  python train.py --test_only True --test_dataset 'RotNIST' --test_model_name './pretrained_models/pretrained_wngcn_mnist.pth.tar'
- R-CIFAR-10
  python train.py --test_only True --test_dataset 'RotCIFAR10' --test_model_name './pretrained_models/pretrained_wngcn_cifar10.pth.tar'
Train from scratch on the non-augmented training set, using the rotated test set for evaluation:
- MNIST
  python train.py --train_dataset 'MNIST' --test_dataset 'RotNIST' --m 0.05 --save_bestmodel_name './data/saved_models/wngcn_mnist.pth.tar'
- CIFAR-10
  python train.py --train_dataset 'CIFAR10' --test_dataset 'RotCIFAR10' --m 0.05 --save_bestmodel_name './data/saved_models/wngcn_cifar10.pth.tar'