Skip to content

Commit

Permalink
Version 3.0.0
Browse files Browse the repository at this point in the history
  • Loading branch information
Labbeti committed Sep 30, 2022
1 parent 8de4dd3 commit 6190492
Show file tree
Hide file tree
Showing 386 changed files with 21,043 additions and 11,562 deletions.
22 changes: 22 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
[flake8]
max-line-length = 88
extend-ignore = W191,E203,E501,E402
show_source = True
exclude =
# No need to traverse our git directory
.git,
# There's no value in checking cache directories
__pycache__,
# The conf file is mostly autogenerated, ignore it
docs/source/conf.py,
# The old directory contains Flake8 2.0
old,
# This contains our built documentation
build,
# This contains builds of flake8 that we don't want to check
dist,
# Ignore notebook checkpoints
.ipynb_checkpoints
per-file-ignores =
# imported but unused
__init__.py: F401
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ img/
*.err
*.out
*.torch
*ign*
*.ign*
*_ign*
.tmp*

# Specific files
Expand Down
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

include README.md
graft notebooks/
graft conf/
132 changes: 89 additions & 43 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
<!-- # -*- coding: utf-8 -*- -->

<div align="center">

# Deep Semi-Supervised Learning with Holistic methods (SSLH)

<a href="https://www.python.org/"><img alt="Python" src="https://img.shields.io/badge/-Python 3.9+-blue?style=for-the-badge&logo=python&logoColor=white"></a>
<a href="https://pytorch.org/get-started/locally/"><img alt="PyTorch" src="https://img.shields.io/badge/-PyTorch 1.7.1-ee4c2c?style=for-the-badge&logo=pytorch&logoColor=white"></a>
<a href="https://black.readthedocs.io/en/stable/"><img alt="Code style: black" src="https://img.shields.io/badge/code%20style-black-black.svg?style=for-the-badge&labelColor=gray"></a>

Unofficial PyTorch and PyTorch-Lightning implementations of Deep Semi-Supervised Learning methods for audio tagging.

</div>

There is 4 SSL methods :
- [FixMatch (FM)](https://arxiv.org/pdf/2001.07685.pdf)
- [MixMatch (MM)](https://arxiv.org/pdf/1905.02249.pdf)
- [ReMixMatch (RMM)](https://arxiv.org/pdf/1911.09785.pdf)
- [Unsupervised Data Augmentation (UDA)](https://arxiv.org/pdf/1904.12848.pdf)
- [FixMatch (FM)](https://arxiv.org/pdf/2001.07685.pdf) [1]
- [MixMatch (MM)](https://arxiv.org/pdf/1905.02249.pdf) [2]
- [ReMixMatch (RMM)](https://arxiv.org/pdf/1911.09785.pdf) [3]
- [Unsupervised Data Augmentation (UDA)](https://arxiv.org/pdf/1904.12848.pdf) [4]

For the following datasets :
- [CIFAR-10 (CIFAR10)](https://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf)
Expand All @@ -15,33 +25,35 @@ For the following datasets :
- [Primate Vocalization Corpus (PVC)](https://arxiv.org/pdf/2101.10390.pdf)
- [UrbanSound8k (UBS8K)](http://www.justinsalamon.com/uploads/4/3/9/4/4394963/salamon_urbansound_acmmm14.pdf)

[comment]: <> (- [AudioSet &#40;ADS&#41;]&#40;https://static.googleusercontent.com/media/research.google.com/fr//pubs/archive/45857.pdf&#41;)
[comment]: <> (- [FSD50K]&#40;&#41;)

With 3 models :
- [WideResNet28 (WRN28)](https://arxiv.org/pdf/1605.07146.pdf)
- [MobileNetV1 (MNV1)](https://arxiv.org/pdf/1704.04861.pdf)
- [MobileNetV2 (MNV2)](https://arxiv.org/pdf/1801.04381.pdf)

The implementation of Mean Teacher (MT), Deep Co-Training (DCT) and Pseudo-Labeling (PL) are present in this repository but not fully tested.
**IMPORTANT NOTE: The implementation of Mean Teacher (MT), Deep Co-Training (DCT) and Pseudo-Labeling (PL) are present in this repository but not fully tested.**

You can find a more stable version of MT and DCT at https://github.com/lcances/semi-supervised.
The datasets AudioSet and FSD50K are in beta testing.
If you meet problems, you can contact me at


## Installation
#### Download & setup
- Clone the repository :
```bash
git clone https://github.com/Labbeti/SSLH
conda env create -n env_sslh -f environment.yaml
conda activate env_sslh
pip install -e SSLH --no-dependencies
```
- Set up the package in your environment :

#### Alternatives
- As python package :
```bash
cd SSLH
pip install -e .
pip install https://github.com/Labbeti/SSLH
```
The dependencies will be automatically installed with pip instead of conda, which means the the build versions can be slightly different.

The installation is now finished.

#### Alternatives
The project contains also a ```environment.yaml``` and ```requirements.txt``` for installing the packages respectively with conda or pip :
The project contains also a ```environment.yaml``` and ```requirements.txt``` for installing the packages respectively with conda or pip.
- With **conda** environment file :
```bash
conda env create -n env_sslh -f environment.yaml
Expand All @@ -56,57 +68,44 @@ pip install -e . --no-dependencies
```

## Datasets
CIFAR10, ESC10 and GoogleSpeechCommands are automatically downloaded and installed.
For UrbanSound8k, please read the [README of leocances](https://github.com/leocances/UrbanSound8K/blob/master/README.md#prepare-the-dataset), in section "Prepare the dataset".
CIFAR10, ESC10, GoogleSpeechCommands and FSD50K can be downloaded and installed.
For UrbanSound8k, please read the [README of leocances](https://github.com/leocances/UrbanSound8K/blob/master/README.md#prepare-the-dataset), in section "Prepare the dataset".
AudioSet (ADS) and Primate Vocalize Corpus (PVC) cannot be installed automatically by now.

To download a dataset, you can use the `data.download=true` option.

[comment]: <> (TODO : For Audioset install !)
[comment]: <> (TODO : For PVC install !)

## Usage
The main scripts available are in ```standalone``` directory :
```
standalone
├── deep_co_training.py
├── fixmatch.py
├── mean_teacher.py
├── mixmatch.py
├── mixup.py
├── pseudo_labeling.py
├── remixmatch.py
├── supervised.py
└── uda.py
```

The code use Hydra for parsing args. The syntax of setting an argument is "name=value" instead of "--name value".
This code use Hydra for parsing args. The syntax of setting an argument is "name=value" instead of "--name value".

Example 1 : MixMatch on ESC10
```bash
python mixmatch.py data=esc10
python -m sslh.mixmatch data=ssl_esc10 data.dm.download=true
```

Example 2 : Supervised+Weak on GSC
```bash
python supervised.py data=gsc expt.augm_train=weak bsize=256 epochs=300
python -m sslh.supervised data=sup_gsc aug@train_aug=weak data.dm.bsize=256 epochs=300 data.dm.download=true
```

Example 3 : FixMatch+MixUp on UBS8K
```bash
python fixmatch.py data=ubs8K expt=fixmatch_mixup bsize_s=128 bsize_u=128 epochs=300
python -m sslh.fixmatch data=ssl_ubs8K pl=fixmatch_mixup data.dm.bsize_s=128 data.dm.bsize_u=128 epochs=300 data.dm.download=true
```
(note: default folds used for UBS8K are in "config/data/ubs8k.yaml")

Example 4 : ReMixMatch on CIFAR-10
```bash
python remixmatch.py data=cifar10 model.n_input_channels=3
python -m sslh.remixmatch data=ssl_cifar10 model.n_input_channels=3 aug@weak_aug=img_weak aug@strong_aug=img_strong data.dm.download=true
```

## List of main arguments

| Name | Description | Values | Default |
| --- | --- | --- | --- |
| data | Dataset used | ads, cifar10, esc10, fsd50k, gsc, pvc, ubs8k | esc10 |
| expt | Training method (experiment) used | *(depends of the python script, see the filenames in config/expt/ folder)* | *(depends of the python script)* |
| data | Dataset used | (sup|ssl)_(ads|cifar10|esc10|fsd50k|gsc|pvc|ubs8k) | (sup|ssl)_esc10 |
| pl | Pytorch Lightning training method (experiment) used | *(depends of the python script, see the filenames in config/pl/ folder)* | *(depends of the python script)* |
| model | Pytorch model to use | mobilenetv1, mobilenetv2, vgg, wideresnet28 | wideresnet28 |
| optim | Optimizer used | adam, sgd | adam |
| sched | Learning rate scheduler | cosine, softcosine, none | softcosine |
Expand All @@ -127,7 +126,7 @@ sslh
│ ├── supervised
│ └── semi_supervised
├── datasets
├── expt
├── pl
│ ├── deep_co_training
│ ├── fixmatch
│ ├── mean_teacher
Expand All @@ -140,9 +139,13 @@ sslh
├── metrics
├── models
├── transforms
│ ├── augments
│ ├── get
│ ├── image
│ ├── other
│ ├── pools
│ └── self_transforms
│ ├── self_transforms
│ ├── spectrogram
│ └── waveform
└── utils
```

Expand Down Expand Up @@ -186,3 +189,46 @@ It contains also some code from the following authors :
| SUP | Supervised Learning |
| _u | Unsupervised |
| UBS8K | UrbanSound8K dataset |

## References

[1] K. Sohn, D. Berthelot, C.-L. Li, Z. Zhang, N. Carlini, E. D. Cubuk, A. Ku-
rakin, H. Zhang, and C. Raffel, “FixMatch: Simplifying Semi-Supervised
Learning with Consistency and Confidence,” p. 21.

[2] D. Berthelot, N. Carlini, I. Goodfellow, N. Papernot, A. Oliver, and
C. Raffel, “MixMatch: A Holistic Approach to Semi-Supervised Learning,”
Oct. 2019, number: arXiv:1905.02249 arXiv:1905.02249 [cs, stat]. [Online].
Available: http://arxiv.org/abs/1905.02249

[3] D. Berthelot, N. Carlini, E. D. Cubuk, A. Kurakin, K. Sohn,
H. Zhang, and C. Raffel, “ReMixMatch: Semi-Supervised Learning
with Distribution Alignment and Augmentation Anchoring,” Feb. 2020,
number: arXiv:1911.09785 arXiv:1911.09785 [cs, stat]. [Online]. Available:
http://arxiv.org/abs/1911.09785

[4] Q. Xie, Z. Dai, E. Hovy, M.-T. Luong, and Q. V. Le, “Unsu-
pervised Data Augmentation for Consistency Training,” Nov. 2020,
number: arXiv:1904.12848 arXiv:1904.12848 [cs, stat]. [Online]. Available:
http://arxiv.org/abs/1904.12848

<!-- Cances, L., Labbé, E. & Pellegrini, T. Comparison of semi-supervised deep learning algorithms for audio classification. J AUDIO SPEECH MUSIC PROC. 2022, 23 (2022). https://doi.org/10.1186/s13636-022-00255-6 -->

## Cite this repository
If you use this code, you can cite the following paper associated :
```
@article{cances_comparison_2022,
title = {Comparison of semi-supervised deep learning algorithms for audio classification},
volume = {2022},
issn = {1687-4722},
url = {https://doi.org/10.1186/s13636-022-00255-6},
doi = {10.1186/s13636-022-00255-6},
abstract = {In this article, we adapted five recent SSL methods to the task of audio classification. The first two methods, namely Deep Co-Training (DCT) and Mean Teacher (MT), involve two collaborative neural networks. The three other algorithms, called MixMatch (MM), ReMixMatch (RMM), and FixMatch (FM), are single-model methods that rely primarily on data augmentation strategies. Using the Wide-ResNet-28-2 architecture in all our experiments, 10\% of labeled data and the remaining 90\% as unlabeled data for training, we first compare the error rates of the five methods on three standard benchmark audio datasets: Environmental Sound Classification (ESC-10), UrbanSound8K (UBS8K), and Google Speech Commands (GSC). In all but one cases, MM, RMM, and FM outperformed MT and DCT significantly, MM and RMM being the best methods in most experiments. On UBS8K and GSC, MM achieved 18.02\% and 3.25\% error rate (ER), respectively, outperforming models trained with 100\% of the available labeled data, which reached 23.29\% and 4.94\%, respectively. RMM achieved the best results on ESC-10 (12.00\% ER), followed by FM which reached 13.33\%. Second, we explored adding the mixup augmentation, used in MM and RMM, to DCT, MT, and FM. In almost all cases, mixup brought consistent gains. For instance, on GSC, FM reached 4.44\% and 3.31\% ER without and with mixup. Our PyTorch code will be made available upon paper acceptance at https://github.com/Labbeti/SSLH.},
number = {1},
journal = {EURASIP Journal on Audio, Speech, and Music Processing},
author = {Cances, Léo and Labbé, Etienne and Pellegrini, Thomas},
month = sep,
year = {2022},
pages = {23},
}
```
4 changes: 4 additions & 0 deletions conf/activation/log_softmax.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# @package activation

_target_: "torch.nn.modules.activation.LogSoftmax"
dim: -1
3 changes: 3 additions & 0 deletions conf/activation/sigmoid.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# @package activation

_target_: "torch.nn.modules.activation.Sigmoid"
4 changes: 4 additions & 0 deletions conf/activation/softmax.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# @package activation

_target_: "torch.nn.modules.activation.Softmax"
dim: -1
5 changes: 5 additions & 0 deletions conf/aug/ident.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# @package aug

- type: "spectrogram"
aug:
_target_: "torch.nn.Identity"
14 changes: 14 additions & 0 deletions conf/aug/img_strong.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# @package aug

- type: "image"
aug:
_target_: "sslh.transforms.image.rand_augment.RandAugment"
n_augm_apply: 1
magnitude_policy: "random"
p: 1.0
- type: "image"
aug:
_target_: "sslh.transforms.image.pil.CutOutImgPIL"
scales: [0.2, 0.5]
fill_value: 0
p: 1.0
15 changes: 15 additions & 0 deletions conf/aug/img_weak.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# @package aug

- type: "image"
aug:
_target_: "torchvision.transforms.transforms.RandomVerticalFlip"
p: 0.5
- type: "image"
aug:
_target_: "torchvision.transforms.transforms.RandomVerticalFlip"
p: 0.25
- type: "image"
aug:
_target_: "torchvision.transforms.transforms.RandomCrop"
size: [32, 32]
padding: 8
7 changes: 7 additions & 0 deletions conf/aug/occlusion.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# @package aug

- type: "waveform"
aug:
_target_: "sslh.transforms.waveform.occlusion.Occlusion"
scales: [0.0, 0.25]
p: 0.5
14 changes: 14 additions & 0 deletions conf/aug/spec_aug.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# @package aug

- type: "spectrogram"
aug:
_target_: "sslh.transforms.spectrogram.spec_aug.SpecAugmentation"
# default hparams source : https://github.com/qiuqiangkong/audioset_tagging_cnn/blob/master/pytorch/models.py#L163
time_drop_width: 64
time_stripes_num: 2
freq_drop_width: 8
freq_stripes_num: 2
time_dim: 3
freq_dim: 2
inplace: true
p: 1.0
20 changes: 20 additions & 0 deletions conf/aug/strong.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# @package aug

- type: "waveform"
aug:
_target_: "sslh.transforms.waveform.occlusion.Occlusion"
scales: [0.0, 0.75]
p: 1.0
- type: "waveform"
aug:
_target_: "sslh.transforms.waveform.resample_pad_crop.ResamplePadCrop"
rates: [0.25, 1.75]
align: "random"
p: 1.0
- type: "spectrogram"
aug:
_target_: "sslh.transforms.spectrogram.cutoutspec.CutOutSpec"
freq_scales: [0.5, 1.0]
time_scales: [0.5, 1.0]
fill_value: -80.0
p: 1.0
20 changes: 20 additions & 0 deletions conf/aug/strong4.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# @package aug

- type: "spectrogram"
aug:
_target_: "sslh.transforms.spectrogram.spec_aug.SpecAugmentation"
freq_drop_width: 8
freq_stripes_num: 2
time_drop_width: 16
time_stripes_num: 2
time_dim: 3
freq_dim: 2
inplace: true
p: 1.0

- type: "waveform"
aug:
_target_: "sslh.transforms.waveform.resample_pad_crop.ResamplePadCrop"
rates: [0.25, 1.75]
align: "random"
p: 1.0
20 changes: 20 additions & 0 deletions conf/aug/weak.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# @package aug

- type: "waveform"
aug:
_target_: "sslh.transforms.waveform.occlusion.Occlusion"
scales: [0.0, 0.25]
p: 0.5
- type: "waveform"
aug:
_target_: "sslh.transforms.waveform.resample_pad_crop.ResamplePadCrop"
rates: [0.5, 1.5]
align: "random"
p: 0.5
- type: "spectrogram"
aug:
_target_: "sslh.transforms.spectrogram.cutoutspec.CutOutSpec"
freq_scales: [0.1, 0.5]
time_scales: [0.1, 0.5]
fill_value: -80.0
p: 0.5
Loading

0 comments on commit 6190492

Please sign in to comment.