Recycle-and-Distill: Universal Compression Strategy for Transformer-based Speech SSL Models with Attention Map Reusing and Masking Distillation, INTERSPEECH 2023.
Kangwook Jang*,
Sungnyun Kim*,
Se-Young Yun, Hoirin Kim
* equal contribution
- Attention Map Reusing: Reuse the previous layer's attention map to remove the key and query parameters in the Transformer.
- Masking Distillation: Distill masked and unmasked frames separately (a minimal sketch of both ideas follows this list).
- Parameters and MACs of ARMHuBERT have decreased to 28% and 30% of the teacher, HuBERT Base, respectively.
- ARMHuBERT achieves a PER of 7.72% and a WER of 9.96% on the SUPERB benchmark with end-to-end (E2E) distillation.
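
Below is a minimal PyTorch sketch of the two ideas above, not the authors' implementation: a Transformer layer that reuses an earlier layer's attention map (and therefore keeps no query/key projections of its own), and a masking distillation loss that treats masked and unmasked frames separately. Module names, shapes, the L1 criterion, and the unweighted loss sum are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ReuseAttentionLayer(nn.Module):
    """Transformer layer that reuses a previous layer's attention map,
    so it carries no query/key projections of its own."""
    def __init__(self, dim: int, n_heads: int):
        super().__init__()
        self.n_heads = n_heads
        self.v_proj = nn.Linear(dim, dim)    # value projection is kept
        self.out_proj = nn.Linear(dim, dim)  # output projection is kept
        self.ffn = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
        self.norm1, self.norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor, attn_map: torch.Tensor) -> torch.Tensor:
        # x: (B, T, D); attn_map: (B, n_heads, T, T) computed by an earlier layer
        B, T, D = x.shape
        v = self.v_proj(x).view(B, T, self.n_heads, D // self.n_heads).transpose(1, 2)
        ctx = (attn_map @ v).transpose(1, 2).reshape(B, T, D)  # reuse: no Q/K here
        x = self.norm1(x + self.out_proj(ctx))
        return self.norm2(x + self.ffn(x))

def masking_distillation_loss(student_feat, teacher_feat, mask):
    """Distill masked and unmasked frames with separate loss terms.
    student_feat, teacher_feat: (B, T, D), assumed already projected to a
    common dimension; mask: (B, T) bool, True where the student's input
    frames were masked."""
    loss_masked = F.l1_loss(student_feat[mask], teacher_feat[mask])
    loss_unmasked = F.l1_loss(student_feat[~mask], teacher_feat[~mask])
    return loss_masked + loss_unmasked  # per-term weights are a design choice
```

In practice the reused attention map would come from an adjacent lower layer, and each distillation term typically gets its own weight.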
📌 Check out our model's performance in SUPERB Leaderboard!
For our model's checkpoints, go check this link!
Model name | Parameters | Teacher | Training dataset | Link |
---|---|---|---|---|
ARMHuBERT-960h | 26.45M | HuBERT | LibriSpeech-960h | HF Model |
ARMHuBERT-S-100h | 22.39M | HuBERT | LibriSpeech-100h | HF Model |
ARMHuBERT-S-960h | 22.39M | HuBERT | LibriSpeech-960h | HF Model |
ARMwavLM-S-100h | 22.39M | wavLM | LibriSpeech-100h | HF Model |
ARMwavLM-S-960h | 22.39M | wavLM | LibriSpeech-960h | HF Model |
MaskHuBERT-960h | 26.64M | HuBERT | LibriSpeech-960h | HF Model |
Install the necessary packages with:
$ pip install -r requirements.txt
- Download the teacher model checkpoint to perform knowledge distillation, and place it under the root path, `./`.
- Download the LibriSpeech dataset.
  - For 100h distillation, download `train-clean-100`.
  - For 960h distillation, download the whole dataset: `train-clean-100`, `train-clean-360`, and `train-other-500`.
  - For validation, download `dev-clean`.
    - You can also validate your model with test-clean or test-other. In this case, please download `test-clean`, and modify `self.eval_data` in the `train.py` file.
- Modify the configuration file in `./conf/[model_name]/[config].yaml`.
  - For example, the configuration file `./conf/armhubert/armhubert-960.yaml` contains all the settings for reproducing ARMHuBERT on the LibriSpeech 960h dataset.
  - Set the path to the teacher model checkpoint at `teacher_model`, and the root path to the LibriSpeech dataset at `libri_root`.
- Then, run the following command:

  python train.py -c ./conf/[model_name]/[config].yaml

  For ARMHuBERT:

  python train.py -c ./conf/armhubert/armhubert-960.yaml

  After training, the model checkpoints and the corresponding configuration file will be created at `./results/pretrain/`.
If you don't feel like training your own model, feel free to use our checkpoints.

- Clone and install the S3PRL toolkit with `pip install -e ".[all]"` (dev mode).
- Copy the entire `./models/[model_name]` folder into `<s3prl root>/s3prl/upstream/`.
- Add the upstream importing line in `<s3prl root>/s3prl/hub.py`:

  from s3prl.upstream.[model_name].hubconf import *

  For ARMHuBERT:

  from s3prl.upstream.armhubert.hubconf import *

- Change each config file of the S3PRL downstream tasks as follows:
  - Uncomment the learning rate scheduler.
  - Scale the learning rate to 10x for the speaker identification (SID) task.
- Run the following command to fine-tune the ARMHuBERT model. For automatic speech recognition (ASR) as an example:

  python run_downstream.py \
      -m train \
      -n ARMHuBERT-ASR \  # You can set your exp name to whatever you want
      -u armhubert \
      -d asr \
      -k <path to .ckpt file in <git root>/results/pretrain/> \
      -g <path to .yaml file in <git root>/results/pretrain/>

Note: Refer to the SUPERB docs for more information on usage details and data preparation. A quick sanity check for the upstream registration is sketched below.
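
If you want to verify the upstream registration before launching downstream fine-tuning, the sketch below loads the model through S3PRL's torch.hub interface and runs a dummy waveform through it. This is an assumption-laden example, not part of the official instructions: it assumes the `armhubert` entry in the copied `hubconf.py` accepts a `ckpt` keyword mirroring the `-k` flag above, and that the upstream follows S3PRL's convention of returning a dict with a `hidden_states` key.

```python
import torch

# Hypothetical sanity check: 'armhubert' only becomes available after the
# folder is copied into s3prl/upstream/ and hub.py imports its hubconf.
upstream = torch.hub.load(
    "s3prl/s3prl", "armhubert",
    ckpt="<path to .ckpt file in <git root>/results/pretrain/>",  # assumed kwarg, mirrors -k
)
upstream.eval()

wavs = [torch.randn(16000)]  # one second of 16 kHz audio per utterance
with torch.no_grad():
    out = upstream(wavs)
# S3PRL upstreams conventionally expose per-layer features under "hidden_states"
print(out["hidden_states"][-1].shape)
```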
We evaluate our student models on the SUPERB benchmark.
MaskHuBERT substantially improves performance on content- and semantics-related tasks; see PR, ASR, SF, and IC.
ARMHuBERT shows promising improvements over MaskHuBERT in the SF and SID tasks, while exhibiting a similar level of performance on the other tasks.
ARMHuBERT achieves a better overall score of 78.1 with fewer parameters than MaskHuBERT. This is state-of-the-art performance among end-to-end distillation approaches such as Deep-versus-wide 12-L or FitHuBERT.
You can also check that our strategy works on another Transformer backbone model, wavLM.
We have only performed evaluation on HuBERT-based models, but this strategy can be applied identically to any speech model with a Transformer backbone, e.g., AST (Audio Spectrogram Transformer).
If you find this repo useful for your research, please consider citing our paper:
@inproceedings{jang2023recycleanddistill,
title={Recycle-and-Distill: Universal Compression Strategy for Transformer-based Speech SSL Models with Attention Map Reusing and Masking Distillation},
author={Kangwook Jang and Sungnyun Kim and Se-Young Yun and Hoirin Kim},
booktitle={Proc. INTERSPEECH 2023},
pages={316--320},
year={2023}
}
🎉 Update (Apr 12, 2024): Our new paper, STaR, has been selected as a Best Student Paper at ICASSP 2024!
🎉 Check out our model's performance in SUPERB Leaderboard!
STaR: Distilling Speech Temporal Relation for Lightweight Speech Self-Supervised Learning Models, ICASSP 2024.
Kangwook Jang,
Sungnyun Kim,
Hoirin Kim
- Speech Temporal Relation (STaR): Distill the knowledge by focusing on the pairwise temporal relation between two speech frames.
- Temporal Gram Matrix (TGM): Propose the Temporal Gram Matrix, which aggregates channel information at two time steps.
- Layer-wise TGM: Distill the TGM for every Transformer layer.
- Intra-layer TGM: Modify the TGM to compute the temporal relation between the input and output of a single Transformer layer.
- Incorporating the two TGMs together as distillation objectives, our student models STaRHuBERT (22M & 26M) achieve state-of-the-art performance on the SUPERB benchmark in terms of overall and generalizability scores.
- For further compression (9.39M & 14.1M), our approach shows robust performance against degradation compared to other works. (A minimal sketch of the TGM follows this list.)
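
The following PyTorch sketch illustrates the TGM idea under stated assumptions: the cosine-style normalization, MSE matching, and uniform layer averaging are ours for illustration, not necessarily the paper's exact recipe. Because the TGM is a T x T matrix over time, it is agnostic to the student's and teacher's hidden dimensions.

```python
import torch
import torch.nn.functional as F

def temporal_gram(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Pairwise temporal relation between two feature sequences.
    x, y: (B, T, D). Entry (i, j) aggregates the channel information of
    frame i in x and frame j in y, giving a (B, T, T) matrix."""
    x = F.normalize(x, dim=-1)  # assumed normalization; details may differ
    y = F.normalize(y, dim=-1)
    return x @ y.transpose(1, 2)

def layer_wise_tgm_loss(student_layers, teacher_layers):
    """Match each layer's output-output TGM between student and teacher."""
    losses = [F.mse_loss(temporal_gram(s, s), temporal_gram(t, t))
              for s, t in zip(student_layers, teacher_layers)]
    return sum(losses) / len(losses)

def intra_layer_tgm_loss(x_in_s, x_out_s, x_in_t, x_out_t):
    """Match the input-output TGM of a single Transformer layer."""
    return F.mse_loss(temporal_gram(x_in_s, x_out_s),
                      temporal_gram(x_in_t, x_out_t))
```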
For our model's checkpoints, please check the following links. All models are distilled from HuBERT base.
- STaRHuBERT-L (26.6M): ckpt, yaml
- STaRHuBERT (22.3M): ckpt, yaml
- STaRHuBERT-S (14.1M): ckpt, yaml
- STaRHuBERT-XS (9.39M): ckpt, yaml
We have also added models distilled from WavLM Base!
We do not offer an official implementation code for the distillation. Nevertheless, since STaRHuBERT is built upon the ARMHuBERT backbone, you can easily re-implement our approach with this ARMHuBERT repository.
You can reproduce our results with the given checkpoints. Please follow the steps below (these are almost the same as for ARMHuBERT).
- Clone and install the S3PRL toolkit with `pip install -e ".[all]"` (dev mode).
- Copy the entire `./models/starhubert` folder into `<s3prl root>/s3prl/upstream/`.
- Add the upstream importing line in `<s3prl root>/s3prl/hub.py`:

  from s3prl.upstream.starhubert.hubconf import *

- Change each config file of the S3PRL downstream tasks as follows:
  - Uncomment the learning rate scheduler.
  - Scale the learning rate to 10x for the speaker identification (SID) task.
- Run the following command to fine-tune the STaRHuBERT model. For automatic speech recognition (ASR) as an example:

  python run_downstream.py \
      -m train \
      -n STaRHuBERT-ASR \  # You can set your exp name to whatever you want
      -u starhubert \
      -d asr \
      -k <path to .ckpt file in <git root>/results/pretrain/> \
      -g <path to .yaml file in <git root>/results/pretrain/>

Note: Refer to the SUPERB docs for more information on usage details and data preparation.
If you find this repo useful for your research, please consider citing our paper:
@inproceedings{jang2024star,
title={STaR: Distilling Speech Temporal Relation for Lightweight Speech Self-Supervised Learning Models},
author={Jang, Kangwook and Kim, Sungnyun and Kim, Hoirin},
booktitle={ICASSP 2024-2024 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
pages={10721--10725},
year={2024},
organization={IEEE}
}
For any details or clarification, please reach out to
- Kangwook Jang: dnrrkdwkd12@kaist.ac.kr
- Sungnyun Kim: ksn4397@kaist.ac.kr