We use Graph Adversarial Knowledge Distillation (GAKD) to distill knowledge from the teacher model into the student model. The framework is based on the paper *Compressing Deep Graph Neural Networks via Adversarial Knowledge Distillation*. The basic idea is to train a discriminator that learns to distinguish between teacher and student embeddings, while the student tries to fool the discriminator.
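To make the two-player objective concrete, here is a minimal, hypothetical PyTorch sketch of an embedding discriminator and the two losses involved. The module, its size, and the exact loss form are illustrative assumptions rather than the code in `gakd.py`, and we assume teacher and student embeddings have already been projected to a common dimension:

```python
import torch
import torch.nn.functional as F
from torch import nn

class EmbeddingDiscriminator(nn.Module):
    """Tiny MLP that scores whether a graph embedding comes from the teacher or the student."""
    def __init__(self, dim: int):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, 1))

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        return self.net(h).squeeze(-1)  # raw logits, shape [batch_size]

def discriminator_loss(disc, h_teacher, h_student):
    """Discriminator step: label teacher embeddings as 1 and (detached) student embeddings as 0."""
    real = disc(h_teacher)
    fake = disc(h_student.detach())  # detach so this step does not update the student
    return (F.binary_cross_entropy_with_logits(real, torch.ones_like(real))
            + F.binary_cross_entropy_with_logits(fake, torch.zeros_like(fake)))

def student_adversarial_loss(disc, h_student):
    """Student step: try to make the discriminator label student embeddings as 'teacher'."""
    pred = disc(h_student)
    return F.binary_cross_entropy_with_logits(pred, torch.ones_like(pred))
```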
- We re-implemented the GAKD framework from the GraphGPS repository with better readability and configurability, and used a `GINE` model as the student. The official implementation uses `GCN` and `GIN` students.
- Our experiments revolve around full GAKD training, plus an exploration of the Representation Identifier and Logits Identifier effects in isolation.
- All experiments are done on the `OGBG-MolPCBA` dataset.
- Prerequisite: `GML` environment setup as described in the root README.md.
- Create a new directory for the GAKD experiments and a `logs` directory inside it.
- Copy `baseline.py`, `gakd.py`, and the `SBATCH` scripts from the `scripts/*` directory to the new directory.
- Modify the `SBATCH` script parameters according to your environment. Make sure to set the correct `BASE_DIR` (for all experiments) and `TEACHER_KNOWLEDGE_PATH` with the path to the teacher knowledge file (for `gakd` experiments).
- The output of the experiments will be available in the `$BASE_DIR/results` directory with the name:
  - `gine_results_<dataset_name>_<with/without>_virtual_node.csv` for `baseline` experiments.
  - `gine_student_gakd_<dataset_name>_<with/without>_virtual_node_discriminator_logits_<true/false>_discriminator_embeddings_<true/false>_k<discriminator_update_freq>_wd<student_optimizer_weight_decay>_drop<student_dropout>.csv` for `gakd` experiments (see the naming sketch below).
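To show how the placeholders in the `gakd` result-file name map to experiment parameters, here is a hypothetical helper that mirrors the documented naming scheme (our scripts build the name internally; this function and its arguments are only illustrative):

```python
def gakd_result_filename(dataset_name, virtual_node, disc_logits, disc_embeddings,
                         k, weight_decay, dropout):
    """Hypothetical helper mirroring the documented gakd result-file naming scheme."""
    vn = "with" if virtual_node else "without"
    return (f"gine_student_gakd_{dataset_name}_{vn}_virtual_node"
            f"_discriminator_logits_{str(disc_logits).lower()}"
            f"_discriminator_embeddings_{str(disc_embeddings).lower()}"
            f"_k{k}_wd{weight_decay}_drop{dropout}.csv")

# Example: full GAKD run without virtual node, K=5, no student weight decay, dropout 0.5.
print(gakd_result_filename("ogbg-molpcba", False, True, True, 5, 0, 0.5))
# gine_student_gakd_ogbg-molpcba_without_virtual_node_discriminator_logits_true_discriminator_embeddings_true_k5_wd0_drop0.5.csv
```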
- We use the `GINE` model as the student baseline. The training script for the baseline models is `baseline.py`.
- We used `Atom` and `Bond` encoders for encoding node and bond features (see the model sketch at the end of this baseline section).
- We trained the baseline with and without `Virtual Node Aggregation`.
- The rest of the parameters are the same for both experiments:
  - Number of runs: `5`
  - Starting seed: `42`
  - Dataset: `ogbg-molpcba`
  - Number of layers: `5`
  - Hidden dimension: `400`
  - Dropout: `0.5`
  - Learning rate: `0.001`
  - Batch size: `32`
  - Epochs: `100`
- Our training results for the `With Virtual Node Aggregation` experiment (Parameters: 12585067) are summarized in the following table:

  | Seed | Valid AP | Test AP |
  |------|----------|---------|
  | 42 | 0.214 | 0.214 |
  | 43 | 0.202 | 0.198 |
  | 44 | 0.219 | 0.213 |
  | 45 | 0.195 | 0.187 |
  | 46 | 0.202 | 0.192 |
  | Mean ± Std | 0.206 ± 0.009 | 0.201 ± 0.010 |

- The `Without Virtual Node Aggregation` experiment (Parameters: 3717733) results are summarized in the following table:

  | Seed | Valid AP | Test AP |
  |------|----------|---------|
  | 42 | 0.207 | 0.205 |
  | 43 | 0.210 | 0.203 |
  | 44 | 0.209 | 0.206 |
  | 45 | 0.201 | 0.202 |
  | 46* | 0.208 | 0.206 |
  | Mean ± Std | 0.207 ± 0.003 | 0.205 ± 0.002 |

- We can see that `Virtual Node Aggregation` does not improve the performance of the student model on average.
- Based on average performance, we selected the model without `Virtual Node Aggregation` as the baseline model.
- The best-performing model on `Test AP` in that class is the one with seed `46` (Test AP: `0.206`). We consider this model as the baseline model for the rest of the experiments.
- To reproduce the results, submit the following commands via `sbatch`:

  ```bash
  # with Virtual Node Aggregation
  sbatch scripts/gine-baseline-vn.sh

  # without Virtual Node Aggregation
  sbatch scripts/gine-baseline.sh
  ```
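For illustration, a GINE student with OGB atom/bond encoders and the hyperparameters above could be assembled roughly as follows. This is a simplified sketch, not our exact `baseline.py`; the class name, mean pooling, and the omission of the virtual node are assumptions:

```python
import torch
from torch import nn
from torch_geometric.nn import GINEConv, global_mean_pool
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder

class GINEStudent(nn.Module):
    """Minimal GINE student: OGB atom/bond encoders + stacked GINEConv layers + mean pooling."""
    def __init__(self, hidden_dim=400, num_layers=5, num_tasks=128, dropout=0.5):
        super().__init__()
        self.atom_encoder = AtomEncoder(emb_dim=hidden_dim)  # encodes node (atom) features
        self.bond_encoders = nn.ModuleList(
            [BondEncoder(emb_dim=hidden_dim) for _ in range(num_layers)]  # encodes edge (bond) features
        )
        self.convs = nn.ModuleList([
            GINEConv(nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                                   nn.Linear(hidden_dim, hidden_dim)))
            for _ in range(num_layers)
        ])
        self.dropout = nn.Dropout(dropout)
        self.head = nn.Linear(hidden_dim, num_tasks)  # ogbg-molpcba has 128 binary tasks

    def forward(self, data):
        x = self.atom_encoder(data.x)
        for conv, bond_enc in zip(self.convs, self.bond_encoders):
            edge_attr = bond_enc(data.edge_attr)
            x = self.dropout(torch.relu(conv(x, data.edge_index, edge_attr)))
        h_graph = global_mean_pool(x, data.batch)  # graph-level embedding (what the GAKD discriminator sees)
        return self.head(h_graph), h_graph
```

All Valid/Test AP numbers in this document are Average Precision as computed by the official OGB evaluator for `ogbg-molpcba` (`Evaluator(name="ogbg-molpcba")` from `ogb.graphproppred`).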
We have some notable training differences compared to the original implementation:

- We use a `GINE` model as the student model.
- Our teacher model is based on the `GraphGPS` model configuration for the `OGBG-MolPCBA` dataset.
- Their implementation, which uses `GCN` and `GIN` students, has a very large embedding dimension (`1024`), whereas we use `400` for `GINE`.
- Their batch size for the `OGBG-MolPCBA` dataset is `512`, whereas we use `32` because of GPU memory constraints.
- Their default discriminator update frequency (`K` in the paper, `Section C.1`) is `1` (update on every iteration), whereas we use `5` (update every 5 iterations) for the `OGBG-MolPCBA` dataset. They did not mention the reason for this choice, and `K=5` performed better than `K=1` in our initial experiments (see the training-loop sketch after this list).
- We execute `2 runs` for most experiments, each with `50` epochs. For the experiment with `100` epochs, we execute `1 run` due to time constraints, so for single-run experiments we cannot estimate the `standard deviation` of the results.
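To clarify how the discriminator update frequency `K` enters training, here is a hypothetical sketch of one epoch. All function and argument names are placeholders (the loss helpers are the ones sketched near the top of this document), and the structure is an assumption rather than our exact `gakd.py` loop:

```python
def gakd_train_epoch(student, discriminator, train_loader, teacher_bank,
                     task_loss, disc_loss_fn, adv_loss_fn,
                     student_opt, disc_opt, K=5):
    """One GAKD-style epoch. The student is updated on every batch; the
    discriminator only on every K-th batch (K=1 reproduces the paper's default)."""
    for step, batch in enumerate(train_loader):
        # Student step: supervised task loss + adversarial term that tries to fool the discriminator.
        logits, h_student = student(batch)
        loss_student = task_loss(logits, batch.y) + adv_loss_fn(discriminator, h_student)
        student_opt.zero_grad()
        loss_student.backward()
        student_opt.step()

        # Discriminator step, only every K-th batch.
        if step % K == 0:
            h_teacher = teacher_bank(batch)  # e.g., pre-computed teacher knowledge for this batch
            loss_disc = disc_loss_fn(discriminator, h_teacher, h_student.detach())
            disc_opt.zero_grad()
            loss_disc.backward()
            disc_opt.step()
```

With the discriminator settings listed in the experiments below, `disc_opt` could be, for example, `torch.optim.Adam(discriminator.parameters(), lr=0.01, weight_decay=0.0005)` (the choice of Adam is an assumption here).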
- We did two experiments on full GAKD, one with `50` epochs and `2 runs`, and one with `100` epochs and a single run.
- The parameters for the student are the same as for the baseline model.
- The parameters for the discriminators are:
  - Learning rate: `0.01`
  - Weight decay: `0.0005`
  - Discriminator update frequency (K): `5`
- The results are summarized in the following table:

  | Seed | Runs | Epochs | Valid AP | Test AP | Training Time |
  |------|------|--------|----------|---------|---------------|
  | 42-43 | 2 | 50 | 0.213 ± 0.001 | 0.211 ± 0.003 | ≈26 hours |
  | 42* | 1 | 100 | 0.219 | 0.219 | ≈26 hours |

- We can see that the performance of the student model on `Test AP` is best with `100` epochs, improving on the student baseline (`+0.01`).
- With `50` epochs, the performance gain is a little smaller than with `100` epochs (`+0.005`), but still better than the student baseline.
- This implies that the student model is able to learn more from the teacher model.
- Even with fewer epochs (i.e. `50`), we can still beat the baseline performance by a small margin.
- Ideally, a GAN needs a larger batch size and more epochs to converge. Here we use a batch size of `32` and `50` epochs due to memory and time constraints, and still achieve better performance than the student baseline.
- Further experiments with a larger batch size and more epochs would provide a clearer picture of the performance of the GAKD framework.
- To reproduce the results, submit the following commands via `sbatch`:

  ```bash
  # 50 epochs
  sbatch scripts/gine-gakd-k5-wd0-drop0.5-epoch50.sh

  # 100 epochs
  sbatch scripts/gine-gakd-k5-wd0-drop0.5-epoch100.sh
  ```
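Full GAKD trains both the Representation Identifier (on graph embeddings) and the Logits Identifier (on predicted logits); the isolation experiments below simply disable one of the two. A hypothetical sketch of how the student objective can be toggled (flag and function names are illustrative, not our `gakd.py` interface):

```python
import torch
import torch.nn.functional as F

def fool_loss(disc, student_out):
    """Adversarial term for the student: push the discriminator to label student outputs as 'teacher'."""
    logits = disc(student_out)
    return F.binary_cross_entropy_with_logits(logits, torch.ones_like(logits))

def student_total_loss(task_loss, repr_disc, logits_disc, h_student, y_student,
                       use_repr_identifier=True, use_logits_identifier=True):
    """Full GAKD uses both identifiers; the isolation experiments disable one of them."""
    loss = task_loss
    if use_repr_identifier:
        loss = loss + fool_loss(repr_disc, h_student)    # Representation Identifier on graph embeddings
    if use_logits_identifier:
        loss = loss + fool_loss(logits_disc, y_student)  # Logits Identifier on predicted logits
    return loss
```

These two toggles correspond to the `discriminator_embeddings_<true/false>` and `discriminator_logits_<true/false>` fields in the output file names.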
- We did two experiments with `50` epochs where we only trained the Representation Identifier discriminators.
- The parameters for the discriminators are:
  - Learning rate: `0.01`
  - Weight decay: `0.0005`
  - Discriminator update frequency (K): `5`
- The results are summarized in the following table:

  | Seed | Runs | Epochs | Valid AP | Test AP | Training Time |
  |------|------|--------|----------|---------|---------------|
  | 42-43 | 2 | 50 | 0.209 ± 0.009 | 0.209 ± 0.007 | ≈26 hours |

- The mean performance when training only the Representation Identifier discriminators is only slightly better (`+0.003`) than the student baseline.
- This result aligns with the paper, which mentions that both Identifiers are required to achieve better performance.
- However, due to the high variance in the results, we cannot say with confidence that full GAKD training is always better than training only the Representation Identifier discriminators.
- To reproduce the results, submit the following command via `sbatch`:

  ```bash
  sbatch scripts/gine-gakd-embeddings-k5-wd0-drop0.5-epoch50.sh
  ```
- We did two experiments with `50` epochs where we only trained the Logits Identifier discriminators.
- The parameters for the discriminators are:
  - Learning rate: `0.01`
  - Weight decay: `0.0005`
  - Discriminator update frequency (K): `5`
- The results are summarized in the following table:

  | Seed | Runs | Epochs | Valid AP | Test AP | Training Time |
  |------|------|--------|----------|---------|---------------|
  | 42-43 | 2 | 50 | 0.209 ± 0.011 | 0.208 ± 0.014 | ≈26 hours |

- The mean performance when training only the Logits Identifier discriminators is only slightly better (`+0.002`) than the student baseline.
- Similar to Experiment #2, we cannot achieve far better performance than the student baseline just by training the Logits Identifier discriminators.
- This suggests that we need to train both the Representation Identifier and Logits Identifier discriminators of the GAKD framework to achieve better performance.
- To reproduce the results, submit the following command via `sbatch`:

  ```bash
  sbatch scripts/gine-gakd-logits-k5-wd0-drop0.5-epoch50.sh
  ```
- We did one experiment with `50` epochs where we trained full GAKD with an increased discriminator update frequency (`K=1`, i.e. updating the discriminators on every iteration).
- We wanted to see if the performance of the student model improves when the discriminators are trained more frequently.
- The parameters for the discriminators are:
  - Learning rate: `0.01`
  - Weight decay: `0.0005`
  - Discriminator update frequency (K): `1`
- The results are summarized in the following table:

  | Seed | Runs | Epochs | Valid AP | Test AP | Training Time |
  |------|------|--------|----------|---------|---------------|
  | 42 | 1 | 50 | 0.1382 | 0.1412 | ≈13 hours |

- The above results indicate that updating the discriminators too frequently penalizes the performance, and we cannot even reach the performance of the student baseline.
- This highlights that we need to train the discriminators less frequently to achieve better performance.
- Further experiments are needed to find the optimal discriminator update frequency.
- To reproduce the results, submit the following command via `sbatch`:

  ```bash
  sbatch scripts/gine-gakd-k1-wd0.00001-drop0.5-epoch50.sh
  ```
- We are able to achieve better performance than the student baseline with full GAKD training.
- However, the performance gain is not as high as the paper suggests. This could be due to differences in the implementation and training environment (batch size, embedding dimension, epochs, etc.).
- The Representation Identifier and Logits Identifier discriminators achieve only a small performance gain over the student baseline when trained in isolation.
- We need to train the discriminators less frequently (`K > 1`) to achieve better performance.
- We did not get an estimate of the `standard deviation` for some experiments due to time constraints. In the future, we should run experiments with more runs and epochs to get a better estimate of the GAKD framework's performance.
- We need to explore the optimal discriminator update frequency (`K`) further.
- We can try a larger batch size, aligning with the original implementation.
- We can also run experiments with `GINE with Virtual Node Aggregation` to see if adding virtual nodes helps achieve better performance under the GAKD framework.
- The current training time is exceptionally high, and debugging is required to find the bottleneck.
- We can create t-SNE plots of the baseline, student, and teacher embeddings and compare Silhouette scores to see whether the student embeddings capture the structure of the teacher embeddings (a rough sketch follows).
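A rough sketch of what that analysis could look like, assuming the graph-level embeddings of the three models have been dumped to arrays (the data below are random placeholders, and the dimensions are assumptions):

```python
import numpy as np
from sklearn.manifold import TSNE
from sklearn.metrics import silhouette_score

# Placeholder embeddings: rows are graphs, columns are embedding dimensions.
# In practice these would be collected during evaluation of each model.
baseline_emb = np.random.randn(500, 400)
student_emb = np.random.randn(500, 400)
teacher_emb = np.random.randn(500, 400)  # teacher dim may differ; project to a common dim if needed

embeddings = np.concatenate([baseline_emb, student_emb, teacher_emb], axis=0)
labels = np.array([0] * len(baseline_emb) + [1] * len(student_emb) + [2] * len(teacher_emb))

# 2D t-SNE projection for plotting (e.g., with matplotlib, colored by `labels`).
proj = TSNE(n_components=2, init="pca", random_state=42).fit_transform(embeddings)

def model_separability(emb_a, emb_b):
    """Silhouette score treating the two models as two clusters; lower means the embeddings are harder to tell apart."""
    X = np.concatenate([emb_a, emb_b], axis=0)
    y = np.array([0] * len(emb_a) + [1] * len(emb_b))
    return silhouette_score(X, y)

print("teacher vs student :", model_separability(teacher_emb, student_emb))
print("teacher vs baseline:", model_separability(teacher_emb, baseline_emb))
```

A lower teacher-vs-student score than teacher-vs-baseline would suggest the GAKD-trained student embeddings sit closer to the teacher's than the baseline embeddings do.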