Commit dd1a4cb

init code
1 parent 7b118e1 commit dd1a4cb


69 files changed (+15858, -4 lines)

.gitignore

+13
@@ -0,0 +1,13 @@
+*.tar
+*.pth
+venv/
+.idea/
+
+*.err
+*.log
+*.out
+*.eps
+*.png
+*.pyc
+
+__pycache__/

LICENSE

+399
Large diffs are not rendered by default.

README.md

+92-4
@@ -1,4 +1,92 @@
-# CLsurvey
-Codebase of the continual learning survey:
-"Continual learning: A comparative study on how to defy forgetting in classification tasks." arXiv preprint arXiv:1909.08383 (2019).
-Code coming soon!
+# A continual learning survey: Defying forgetting in classification tasks
+Source code for the [Continual Learning survey paper](https://arxiv.org/abs/1909.08383):
+
+```
+@article{de2019continual,
+  title={A continual learning survey: Defying forgetting in classification tasks},
+  author={De Lange, Matthias and Aljundi, Rahaf and Masana, Marc and Parisot, Sarah and Jia, Xu and Leonardis, Ale{\v{s}} and Slabaugh, Gregory and Tuytelaars, Tinne},
+  journal={arXiv preprint arXiv:1909.08383},
+  year={2019}
+}
+```
+
+The code contains a generalizing framework for 11 SOTA methods and 4 baselines:
+- Methods: SI, EWC, MAS, mean/mode-IMM, LWF, EBLL, PackNet, HAT, GEM, iCaRL
+- Baselines
+  - Joint: Learn from all task data at once with a single head (multi-task learning baseline).
+  - Finetuning: standard SGD
+  - Finetuning with Full Memory replay: Allocate memory dynamically to incoming tasks.
+  - Finetuning with Partial Memory replay: Divide memory a priori over all tasks.
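The difference between the two replay baselines is only how the exemplar budget is split over tasks; a minimal sketch (function names and budget semantics are illustrative, not this repo's API):

```python
# Illustrative sketch of the two replay-memory policies above;
# names and semantics are assumptions, not the repo's API.

def partial_memory_per_task(total_budget, n_total_tasks):
    """Divide the exemplar budget a priori over ALL tasks."""
    return total_budget // n_total_tasks

def full_memory_per_task(total_budget, n_seen_tasks):
    """Reallocate the full budget dynamically over tasks seen so far."""
    return total_budget // n_seen_tasks

# With a budget of 1000 exemplars and 10 tasks in total, partial replay
# always stores 100 exemplars per task, while full replay stores 500 per
# task after the second task and shrinks the share as more tasks arrive.
```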
+
+This source code is released under an Attribution-NonCommercial 4.0 International
+license; find out more about it in the [LICENSE file](LICENSE).
+
+## Pipeline
+**Reproducibility**: Results from the paper can be obtained from [src/main_'dataset'.sh](src/main_tinyimagenet.sh).
+Full pipeline example in [src/main_tinyimagenet.sh](src/main_tinyimagenet.sh).
+
+**Pipeline**: Constructing a custom pipeline typically requires the following steps.
+1. Project setup
+   1. For all requirements see [requirements.txt](requirements.txt).
+      Main packages can be installed as in
+      ```
+      conda create --name <ENV-NAME> python=3.7
+      conda activate <ENV-NAME>
+
+      # Main packages
+      conda install -c conda-forge matplotlib tqdm
+      conda install pytorch torchvision cudatoolkit=10.2 -c pytorch
+
+      # For GEM QP
+      conda install -c omnia quadprog
+
+      # For PackNet: torchnet
+      pip install git+https://github.com/pytorch/tnt.git@master
+      ```
+   1. Set paths in '[config.init](src/config.init)' (or leave the defaults)
+      1. '{tr,test}_results_root_path': where to save training/testing results.
+      1. 'models_root_path': where to store initial models (to ensure the same initial model).
+      1. 'ds_root_path': root path of your datasets.
+   1. Prepare the dataset: see [src/data](src/data)/"dataset"_dataprep.py (e.g. [src/data/tinyimgnet_dataprep.py](src/data/tinyimgnet_dataprep.py))
+1. **Train** any of the 11 SOTA methods or 4 baselines
+   1. **Regularization-based/replay methods:** We run a *first task model dump* for Synaptic Intelligence (SI), as it acquires importance weights during training.
+      Other methods start from this same initial model.
+   1. **Baselines/parameter isolation methods**: Start the training sequence from scratch.
+1. **Evaluate** performance; the testing sequence for a task is saved in dictionary format under *test_results_root_path*, defined in [config.init](src/config.init).
+1. **Plot** the evaluation results, using one of the configuration files in [utilities/plot_configs](src/utilities/plot_configs)
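The regularization-based methods in the pipeline (SI, EWC, MAS) all penalize drift of parameters that were important for earlier tasks. A generic sketch of that shared penalty, with illustrative names rather than this codebase's API:

```python
# Generic importance-weighted quadratic penalty shared by SI/EWC/MAS-style
# methods: lam/2 * sum_i omega_i * (theta_i - theta_star_i)^2, where
# theta_star are the parameters after the previous task and omega their
# estimated importance. Plain floats for clarity; a real implementation
# operates on tensors.
def reg_penalty(params, old_params, importance, lam=1.0):
    return 0.5 * lam * sum(
        w * (p - q) ** 2
        for p, q, w in zip(params, old_params, importance)
    )
```

The methods differ mainly in how the importance weights are estimated: EWC uses the Fisher information after a task, SI accumulates them online during training (hence the first-task model dump mentioned above), and MAS uses the sensitivity of the output function.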
+
+## Implement Your Method
+1. Find the class "YourMethod" in [methods/method.py](src/methods/method.py) and implement the framework phases (documented in the code).
+1. Implement your task-based training script in [methods](src/methods): methods/"YourMethodDir".
+   The class "YourMethod" will call this code for training/eval/processing of a single task.
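As a rough skeleton of what such a wrapper class might look like (the method names here are hypothetical; the actual framework phases are documented in methods/method.py):

```python
# Hypothetical skeleton of a custom method wrapper; the real interface
# is defined in src/methods/method.py and may differ.
class YourMethod:
    def __init__(self, hyperparams=None):
        # Method-specific hyperparameters, e.g. regularization strength.
        self.hyperparams = hyperparams or {}

    def train_task(self, model, task_data):
        """Train on a single task; called once per task in the sequence."""
        raise NotImplementedError

    def eval_task(self, model, task_data):
        """Evaluate the model on one task after training."""
        raise NotImplementedError
```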
+
+## Project structure
+- [src/data](src/data): datasets and automated preparation scripts for Tiny Imagenet and iNaturalist.
+- [src/framework](src/framework): the novel task-incremental continual learning framework.
+  **main.py** starts the training pipeline; specify the *--test* argument to perform evaluation with **eval.py**.
+- [src/methods](src/methods): source code of all methods and the **method.py** wrapper.
+- [src/models](src/models): **net.py** for all model preprocessing.
+- [src/utilities](src/utilities): utils used across all modules, and plotting.
+- Config:
+  - [src/data](src/data)/{datasets/models}: default datasets and models directory (see [config.init](src/config.init))
+  - [src/results](src/results)/{train/test}: default training and testing results directory (see [config.init](src/config.init))
+
+## Credits
+- Consider citing our work upon using this repo.
+- Thanks to Huawei for funding this project.
+- Thanks to the following repositories:
+  - https://github.com/rahafaljundi/MAS-Memory-Aware-Synapses
+  - https://github.com/facebookresearch/GradientEpisodicMemory
+  - https://github.com/arunmallya/packnet
+  - https://github.com/joansj/hat
+- If you want to join the Continual Learning community, check out https://www.continualai.org
+
+## Support
+- If you have trouble, please open a Git issue.
+- Have you defined your method in the framework and want to share it with the community? Send a pull request!

requirements.txt

+70
@@ -0,0 +1,70 @@
+# This file may be used to create an environment using:
+# $ conda create --name <env> --file <this file>
+# platform: linux-64
+# Python 3.7
+# Original code in Pytorch 1.0.0
+# Compatibility patch for Pytorch 1.6 (Create git issue for remaining bugs)
+_libgcc_mutex=0.1=main
+blas=1.0=mkl
+ca-certificates=2020.6.20=hecda079_0
+certifi=2020.6.20=py37hc8dfbb8_0
+chardet=3.0.4=pypi_0
+cudatoolkit=10.2.89=hfd86e86_1
+cycler=0.10.0=py_2
+freetype=2.10.2=h5ab3b9f_0
+future=0.18.2=pypi_0
+idna=2.10=pypi_0
+intel-openmp=2020.2=254
+jpeg=9b=h024ee3a_2
+jsonpatch=1.26=pypi_0
+jsonpointer=2.0=pypi_0
+kiwisolver=1.2.0=py37h99015e2_0
+lcms2=2.11=h396b838_0
+ld_impl_linux-64=2.33.1=h53a641e_7
+libedit=3.1.20191231=h14c3975_1
+libffi=3.3=he6710b0_2
+libgcc-ng=9.1.0=hdf63c60_0
+libpng=1.6.37=hbc83047_0
+libstdcxx-ng=9.1.0=hdf63c60_0
+libtiff=4.1.0=h2733197_1
+lz4-c=1.9.2=he6710b0_1
+matplotlib=3.3.1=1
+matplotlib-base=3.3.1=py37hd478181_1
+mkl=2020.2=256
+mkl-service=2.3.0=py37he904b0f_0
+mkl_fft=1.1.0=py37h23d657b_0
+mkl_random=1.1.1=py37h0573a6f_0
+ncurses=6.2=he6710b0_1
+ninja=1.10.1=py37hfd86e86_0
+numpy=1.19.1=py37hbc911f0_0
+numpy-base=1.19.1=py37hfa32c7d_0
+olefile=0.46=py37_0
+openssl=1.1.1g=h516909a_1
+pillow=7.2.0=py37hb39fc2d_0
+pip=20.2.2=py37_0
+pyparsing=2.4.7=pyh9f0ad1d_0
+python=3.7.9=h7579374_0
+python-dateutil=2.8.1=py_0
+python_abi=3.7=1_cp37m
+pytorch=1.6.0=py3.7_cuda10.2.89_cudnn7.6.5_0
+pyzmq=19.0.2=pypi_0
+quadprog=0.1.6=py37_0
+readline=8.0=h7b6447c_0
+requests=2.24.0=pypi_0
+scipy=1.5.2=pypi_0
+setuptools=49.6.0=py37_0
+six=1.15.0=py_0
+sqlite=3.33.0=h62c20be_0
+tk=8.6.10=hbc83047_0
+torchfile=0.1.0=pypi_0
+torchnet=0.0.5.1=pypi_0
+torchvision=0.7.0=py37_cu102
+tornado=6.0.4=py37h8f50634_1
+tqdm=4.48.2=pyh9f0ad1d_0
+urllib3=1.25.10=pypi_0
+visdom=0.1.8.9=pypi_0
+websocket-client=0.57.0=pypi_0
+wheel=0.35.1=py_0
+xz=5.2.5=h7b6447c_0
+zlib=1.2.11=h7b6447c_3
+zstd=1.4.5=h9ceee32_0

src/config.init

+7
@@ -0,0 +1,7 @@
+# Project config file
+# May not contain spaces, surround paths with apostrophes
+[DEFAULT]
+test_results_root_path='./results/test'
+tr_results_root_path='./results/train'
+models_root_path='./data/models'
+ds_root_path='./data/datasets'
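The file is standard INI syntax, so it can be read with Python's stdlib configparser; a minimal sketch (note that the quoting convention stated in the file's comments means the apostrophes must be stripped from values):

```python
# Minimal sketch of reading config.init with the stdlib configparser.
# The apostrophes surrounding each path (per the file's own convention)
# are stripped before the values are returned.
import configparser

def read_config(path="src/config.init"):
    parser = configparser.ConfigParser()
    parser.read(path)
    return {key: value.strip("'") for key, value in parser["DEFAULT"].items()}
```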
