This repository has been archived by the owner on Oct 13, 2022. It is now read-only.

[WIP] Initially working multi-gpu training #71

Merged · 18 commits · Feb 19, 2021
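
For background, the PR title refers to multi-GPU training, which in PyTorch is usually wired up with DistributedDataParallel (DDP). The sketch below is purely illustrative; the function names and environment-variable handling are our assumptions, not code from this PR or from snowfall.

```python
# Illustrative sketch only -- not code from this PR or from snowfall.
# Typical DistributedDataParallel wiring for one per-GPU training process.
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def setup_ddp() -> int:
    # Launchers such as torch.distributed.launch / torchrun export RANK,
    # LOCAL_RANK and WORLD_SIZE for every spawned process.
    dist.init_process_group(backend='nccl')
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    torch.cuda.set_device(local_rank)
    return local_rank


def wrap_model(model: torch.nn.Module, local_rank: int) -> DDP:
    # Gradients are all-reduced across ranks after each backward pass,
    # so every replica applies the same parameter update.
    return DDP(model.to(local_rank), device_ids=[local_rank])
```

Each rank would also need its own shard of the data, e.g. via a distributed-aware sampler, so that the GPUs do not all train on the same cuts.
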
8 changes: 5 additions & 3 deletions egs/aishell/asr/simple_v1/RESULTS.md
@@ -1,4 +1,5 @@
# Note: The following initial result is obtained by Pingfeng Luo.

```
2021-02-10 18:12:22,691 INFO [mmi_bigram_decode.py:263] %WER 10.06% [10542 / 104765, 436 ins, 495 del, 9611 sub ] exp-lstm-adam-mmi-bigram-musan
```
@@ -10,10 +11,11 @@
(Fangjun): Results of <https://github.com/k2-fsa/snowfall/pull/99>

TensorBoard log is available at <https://tensorboard.dev/experiment/5bMFoRjVT7OMRWVFd3qVAA/#scalars>
and the training log can be downloaded using <https://github.com/k2-fsa/snowfall/files/5971503/log-train-2021-02-12-14-19-11.txt>.
and the training log can be downloaded
using <https://github.com/k2-fsa/snowfall/files/5971503/log-train-2021-02-12-14-19-11.txt>.

Decoding results of each epoch (the first line is WER and the second CER) are listed below.
They are obtained using the latest k2 and lhotse as of today (2021-02-12).
Decoding results of each epoch (the first line is WER and the second CER) are listed below. They are obtained using the
latest k2 and lhotse as of today (2021-02-12).

```
# epoch 0
6 changes: 2 additions & 4 deletions egs/aishell/asr/simple_v1/ctc_decode.py
@@ -10,18 +10,16 @@
from k2 import Fsa, SymbolTable
from kaldialign import edit_distance
from pathlib import Path
from typing import List
from typing import Optional
from typing import Union

from lhotse import CutSet
from lhotse.dataset import K2SpeechRecognitionDataset
from lhotse.dataset import SingleCutSampler
from snowfall.common import find_first_disambig_symbol
from snowfall.common import get_phone_symbols
from snowfall.common import get_texts
from snowfall.common import load_checkpoint
from snowfall.common import setup_logger
from snowfall.common import get_texts
from snowfall.common import find_first_disambig_symbol
from snowfall.decoding.graph import compile_LG
from snowfall.models import AcousticModel
from snowfall.models.tdnn_lstm import TdnnLstm1b
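
ctc_decode.py imports edit_distance from kaldialign, and the WER/CER figures quoted in RESULTS.md above are edit-distance based. A minimal sketch of how such a figure is typically computed is given below; the helper name compute_error_rate is ours, not from the repository.

```python
# Hypothetical helper -- shows how WER (or CER, when given character lists)
# is typically derived from kaldialign's edit_distance.
from typing import Dict, List, Tuple

from kaldialign import edit_distance


def compute_error_rate(refs: List[List[str]],
                       hyps: List[List[str]]) -> Tuple[float, Dict[str, int]]:
    totals = {'ins': 0, 'del': 0, 'sub': 0}
    num_ref_tokens = 0
    for ref, hyp in zip(refs, hyps):
        stats = edit_distance(ref, hyp)  # dict with 'ins', 'del', 'sub', 'total'
        for key in totals:
            totals[key] += stats[key]
        num_ref_tokens += len(ref)
    rate = 100.0 * sum(totals.values()) / num_ref_tokens
    return rate, totals
```
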
25 changes: 12 additions & 13 deletions egs/aishell/asr/simple_v1/ctc_train.py
@@ -13,19 +13,18 @@
import torch.optim as optim
from datetime import datetime
from pathlib import Path
from torch import nn
from torch.nn.utils import clip_grad_value_
from torch.utils.tensorboard import SummaryWriter
from typing import Dict, Optional, Tuple

from lhotse import CutSet
from lhotse.dataset import CutConcatenate, CutMix, K2SpeechRecognitionDataset, SingleCutSampler
from lhotse.utils import fix_random_seed
from snowfall.common import describe
from snowfall.common import get_phone_symbols
from snowfall.common import load_checkpoint, save_checkpoint
from snowfall.common import save_training_info
from snowfall.common import setup_logger
from snowfall.common import describe
from snowfall.models import AcousticModel
from snowfall.models.tdnn_lstm import TdnnLstm1b
from snowfall.training.ctc_graph import CtcTrainingGraphCompiler
@@ -214,8 +213,8 @@ def train_one_epoch(dataloader: torch.utils.data.DataLoader,
100.0 * total_valid_frames / total_valid_all_frames))

tb_writer.add_scalar('train/global_valid_average_objf',
valid_average_objf,
global_batch_idx_train)
valid_average_objf,
global_batch_idx_train)
prev_timestamp = datetime.now()
return total_objf / total_frames, valid_average_objf, global_batch_idx_train

@@ -346,15 +345,15 @@ def main():
logging.info('epoch {}, learning rate {}'.format(
epoch, curr_learning_rate))
objf, valid_objf, global_batch_idx_train = train_one_epoch(dataloader=train_dl,
valid_dataloader=valid_dl,
model=model,
device=device,
graph_compiler=graph_compiler,
optimizer=optimizer,
current_epoch=epoch,
tb_writer=tb_writer,
num_epochs=num_epochs,
global_batch_idx_train=global_batch_idx_train)
valid_dataloader=valid_dl,
model=model,
device=device,
graph_compiler=graph_compiler,
optimizer=optimizer,
current_epoch=epoch,
tb_writer=tb_writer,
num_epochs=num_epochs,
global_batch_idx_train=global_batch_idx_train)
# the lower, the better
if valid_objf < best_valid_objf:
best_valid_objf = valid_objf
46 changes: 23 additions & 23 deletions egs/aishell/asr/simple_v1/local2/aishell_train_lms.sh
@@ -1,14 +1,13 @@
#!/usr/bin/env bash


# To be run from one directory above this script.
. ./path.sh

text=data/local/train/text
lexicon=data/local/dict_nosp/lexicon.txt

for f in "$text" "$lexicon"; do
[ ! -f $x ] && echo "$0: No such file $f" && exit 1;
[ ! -f $x ] && echo "$0: No such file $f" && exit 1
done

# This script takes no arguments. It assumes you have already run
@@ -19,7 +18,7 @@ done
dir=data/local/lm
mkdir -p $dir

kaldi_lm=`which train_lm.sh`
kaldi_lm=$(which train_lm.sh)
if [ -z $kaldi_lm ]; then
echo "$0: train_lm.sh is not found. That might mean it's not installed"
echo "$0: or it is not added to PATH"
@@ -35,28 +34,28 @@ cleantext=$dir/text.no_oov

cat $text | awk -v lex=$lexicon 'BEGIN{while((getline<lex) >0){ seen[$1]=1; } }
{for(n=1; n<=NF;n++) { if (seen[$n]) { printf("%s ", $n); } else {printf("<UNK> ");} } printf("\n");}' \
> $cleantext || exit 1;
>$cleantext || exit 1

cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | sort | uniq -c | \
sort -nr > $dir/word.counts || exit 1;
cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | sort | uniq -c |
sort -nr >$dir/word.counts || exit 1

# Get counts from acoustic training transcripts, and add one-count
# for each word in the lexicon (but not silence, we don't want it
# in the LM-- we'll add it optionally later).
cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | \
cat - <(grep -w -v '!SIL' $lexicon | awk '{print $1}') | \
sort | uniq -c | sort -nr > $dir/unigram.counts || exit 1;
cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' |
cat - <(grep -w -v '!SIL' $lexicon | awk '{print $1}') |
sort | uniq -c | sort -nr >$dir/unigram.counts || exit 1

# note: we probably won't really make use of <UNK> as there aren't any OOVs
cat $dir/unigram.counts | awk '{print $2}' | get_word_map.pl "<s>" "</s>" "<UNK>" > $dir/word_map \
|| exit 1;
cat $dir/unigram.counts | awk '{print $2}' | get_word_map.pl "<s>" "</s>" "<UNK>" >$dir/word_map ||
exit 1

# note: ignore 1st field of train.txt, it's the utterance-id.
cat $cleantext | awk -v wmap=$dir/word_map 'BEGIN{while((getline<wmap)>0)map[$1]=$2;}
{ for(n=2;n<=NF;n++) { printf map[$n]; if(n<NF){ printf " "; } else { print ""; }}}' | gzip -c >$dir/train.gz \
|| exit 1;
{ for(n=2;n<=NF;n++) { printf map[$n]; if(n<NF){ printf " "; } else { print ""; }}}' | gzip -c >$dir/train.gz ||
exit 1

train_lm.sh --arpa --lmtype 3gram-mincount $dir || exit 1;
train_lm.sh --arpa --lmtype 3gram-mincount $dir || exit 1

# LM is small enough that we don't need to prune it (only about 0.7M N-grams).
# Perplexity over 128254.000000 words is 90.446690
@@ -66,20 +65,21 @@

exit 0


# From here is some commands to do a baseline with SRILM (assuming
# you have it installed).
heldout_sent=10000 # Don't change this if you want result to be comparable with
# kaldi_lm results
# kaldi_lm results
sdir=$dir/srilm # in case we want to use SRILM to double-check perplexities.
mkdir -p $sdir
cat $cleantext | awk '{for(n=2;n<=NF;n++){ printf $n; if(n<NF) printf " "; else print ""; }}' | \
head -$heldout_sent > $sdir/heldout
cat $cleantext | awk '{for(n=2;n<=NF;n++){ printf $n; if(n<NF) printf " "; else print ""; }}' | \
tail -n +$heldout_sent > $sdir/train

cat $dir/word_map | awk '{print $1}' | cat - <(echo "<s>"; echo "</s>" ) > $sdir/wordlist
cat $cleantext | awk '{for(n=2;n<=NF;n++){ printf $n; if(n<NF) printf " "; else print ""; }}' |
head -$heldout_sent >$sdir/heldout
cat $cleantext | awk '{for(n=2;n<=NF;n++){ printf $n; if(n<NF) printf " "; else print ""; }}' |
tail -n +$heldout_sent >$sdir/train

cat $dir/word_map | awk '{print $1}' | cat - <(
echo "<s>"
echo "</s>"
) >$sdir/wordlist

ngram-count -text $sdir/train -order 3 -limit-vocab -vocab $sdir/wordlist -unk \
-map-unk "<UNK>" -kndiscount -interpolate -lm $sdir/srilm.o3g.kn.gz
@@ -88,5 +88,5 @@ ngram -lm $sdir/srilm.o3g.kn.gz -ppl $sdir/heldout

# Note: perplexity SRILM gives to Kaldi-LM model is same as kaldi-lm reports above.
# Difference in WSJ must have been due to different treatment of <UNK>.
ngram -lm $dir/3gram-mincount/lm_unpruned.gz -ppl $sdir/heldout
ngram -lm $dir/3gram-mincount/lm_unpruned.gz -ppl $sdir/heldout
# 0 zeroprobs, logprob= -250913 ppl= 90.4439 ppl1= 132.379
7 changes: 3 additions & 4 deletions egs/aishell/asr/simple_v1/mmi_bigram_decode.py
@@ -12,15 +12,14 @@
from kaldialign import edit_distance
from pathlib import Path
from typing import List
from typing import Optional
from typing import Union

from lhotse import CutSet
from lhotse.dataset import K2SpeechRecognitionDataset, SingleCutSampler
from snowfall.common import find_first_disambig_symbol
from snowfall.common import get_texts
from snowfall.common import load_checkpoint
from snowfall.common import setup_logger
from snowfall.common import get_texts
from snowfall.common import find_first_disambig_symbol
from snowfall.decoding.graph import compile_LG
from snowfall.models import AcousticModel
from snowfall.models.tdnn_lstm import TdnnLstm1b
@@ -108,7 +107,7 @@ def print_transition_probabilities(P: k2.Fsa, phone_symbol_table: SymbolTable,
num_phones = len(phone_ids)
table = np.zeros((num_phones + 1, num_phones + 2))
table[:, 0] = 0
table[0, -1] = 0 # the start state has no arcs to the final state
table[0, -1] = 0 # the start state has no arcs to the final state
assert P.arcs.dim0() == num_phones + 2
arcs = P.arcs.values()[:, :3]
probability = P.scores.exp().tolist()
2 changes: 1 addition & 1 deletion egs/aishell/asr/simple_v1/mmi_bigram_train.py
@@ -21,10 +21,10 @@
from lhotse import CutSet
from lhotse.dataset import CutConcatenate, CutMix, K2SpeechRecognitionDataset, SingleCutSampler
from lhotse.utils import fix_random_seed
from snowfall.common import describe
from snowfall.common import load_checkpoint, save_checkpoint
from snowfall.common import save_training_info
from snowfall.common import setup_logger
from snowfall.common import describe
from snowfall.models import AcousticModel
from snowfall.models.tdnn_lstm import TdnnLstm1b
from snowfall.training.diagnostics import measure_gradient_norms, optim_step_and_measure_param_change
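
mmi_bigram_train.py pulls in measure_gradient_norms and optim_step_and_measure_param_change from snowfall.training.diagnostics. As a rough illustration of what per-parameter gradient-norm diagnostics of this kind look like (this helper is our own sketch, not the snowfall implementation):

```python
# Rough sketch (assumption): per-parameter gradient norms for logging, in the
# spirit of snowfall.training.diagnostics; not the repository's code.
from typing import Dict

import torch


def gradient_norms(model: torch.nn.Module, norm_type: float = 2.0) -> Dict[str, float]:
    norms = {}
    for name, param in model.named_parameters():
        if param.grad is not None:
            norms[name] = param.grad.detach().norm(norm_type).item()
    return norms
```
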
8 changes: 5 additions & 3 deletions egs/aishell/asr/simple_v1/prepare.py
@@ -1,17 +1,18 @@
#!/usr/bin/env python3

from concurrent.futures import ProcessPoolExecutor

# Copyright (c) 2020 Xiaomi Corporation (authors: Junbo Zhang, Haowen Qiu)
# 2021 Pingfeng Luo
# Apache 2.0
import multiprocessing
import os
import sys
import subprocess
from concurrent.futures import ProcessPoolExecutor
import sys
import torch
from contextlib import contextmanager
from pathlib import Path

import torch
from lhotse import CutSet, Fbank, LilcomHdf5Writer, combine
from lhotse.recipes import prepare_aishell, prepare_musan

@@ -63,6 +64,7 @@ def locate_corpus(corpus_dirs, msg):
print(msg)
sys.exit(1)


def main():
corpus_dir = locate_corpus(
(Path('/mnt/cfs2/asr/database/AM/aishell'),
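
prepare.py calls a locate_corpus helper with a tuple of candidate directories and an error message, and the visible fragment ends with print(msg) followed by sys.exit(1). A helper of this shape usually returns the first existing directory and otherwise reports the message and exits; a reconstruction under that assumption (the actual snowfall code may differ):

```python
# Assumed reconstruction -- the actual implementation in snowfall may differ.
import sys
from pathlib import Path
from typing import Iterable


def locate_corpus(corpus_dirs: Iterable[Path], msg: str) -> Path:
    for path in corpus_dirs:
        if path.is_dir():
            return path
    # None of the candidate directories exist.
    print(msg)
    sys.exit(1)
```
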
6 changes: 3 additions & 3 deletions egs/aishell/asr/simple_v1/run.sh
@@ -9,9 +9,9 @@
set -eou pipefail

dataset_path=(
/mnt/cfs2/asr/database/AM/aishell
/root/fangjun/data/aishell
/home/storage04/zhuangweiji/data/open-source-data/SLR33-aishell/data
/mnt/cfs2/asr/database/AM/aishell
/root/fangjun/data/aishell
/home/storage04/zhuangweiji/data/open-source-data/SLR33-aishell/data
)

data=${dataset_path[0]}