From 212a9c340e0fe9930b39c501404de59dddaa1e9f Mon Sep 17 00:00:00 2001
From: linus
Date: Mon, 16 Dec 2019 15:50:13 -0800
Subject: [PATCH] Fix examples/speech_recognition when using multiple
 num_workers on a multi-GPU setup (#1454)

Summary:
https://github.com/pytorch/fairseq/issues/1308

tgt in AsrDataset is a list of torch tensors, which causes a SIGSEGV error
because tgt has too many objects to create shared memory for in the
dataloader's multiprocessing workers.

Pull Request resolved: https://github.com/pytorch/fairseq/pull/1454

Differential Revision: D18929874

Pulled By: okhonko

fbshipit-source-id: 5582b126890a93177258f5e053f32d5c6d32e9ab
---
 examples/speech_recognition/data/asr_dataset.py         | 9 +++++----
 examples/speech_recognition/data/collaters.py           | 2 ++
 examples/speech_recognition/tasks/speech_recognition.py | 4 ++--
 3 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/examples/speech_recognition/data/asr_dataset.py b/examples/speech_recognition/data/asr_dataset.py
index b95b71d6af..47969a2853 100644
--- a/examples/speech_recognition/data/asr_dataset.py
+++ b/examples/speech_recognition/data/asr_dataset.py
@@ -56,6 +56,11 @@ def __init__(
         self.frame_length = frame_length
         self.frame_shift = frame_shift
 
+        self.s2s_collater = Seq2SeqCollater(
+            0, 1, pad_index=self.tgt_dict.pad(),
+            eos_index=self.tgt_dict.eos(), move_eos_to_beginning=True
+        )
+
     def __getitem__(self, index):
         import torchaudio
         import torchaudio.compliance.kaldi as kaldi
@@ -72,10 +77,6 @@ def __getitem__(self, index):
             frame_shift=self.frame_shift
         )
         output_cmvn = data_utils.apply_mv_norm(output)
-        self.s2s_collater = Seq2SeqCollater(
-            0, 1, pad_index=self.tgt_dict.pad(),
-            eos_index=self.tgt_dict.eos(), move_eos_to_beginning=True
-        )
 
         return {"id": index, "data": [output_cmvn.detach(), tgt_item]}
 
diff --git a/examples/speech_recognition/data/collaters.py b/examples/speech_recognition/data/collaters.py
index 16166e55b2..14740d48b7 100644
--- a/examples/speech_recognition/data/collaters.py
+++ b/examples/speech_recognition/data/collaters.py
@@ -76,6 +76,8 @@ def collate(self, samples):
             target = s["data"][self.label_index]
             if isinstance(target, (np.ndarray, np.generic)):
                 target = torch.from_numpy(target).long()
+            elif isinstance(target, list):
+                target = torch.LongTensor(target)
 
             parsed_sample = {"id": s["id"], "source": source, "target": target}
             parsed_samples.append(parsed_sample)
diff --git a/examples/speech_recognition/tasks/speech_recognition.py b/examples/speech_recognition/tasks/speech_recognition.py
index 699fa4a290..bd671e46dd 100644
--- a/examples/speech_recognition/tasks/speech_recognition.py
+++ b/examples/speech_recognition/tasks/speech_recognition.py
@@ -56,11 +56,11 @@ def get_asr_dataset_from_json(data_json_path, tgt_dict):
         speakers.append(m.group(1) + "_" + m.group(2))
     frame_sizes = [s[1]["input"]["length_ms"] for s in sorted_samples]
     tgt = [
-        torch.LongTensor([int(i) for i in s[1]["output"]["tokenid"].split(", ")])
+        [int(i) for i in s[1]["output"]["tokenid"].split(", ")]
         for s in sorted_samples
     ]
     # append eos
-    tgt = [torch.cat([t, torch.LongTensor([tgt_dict.eos()])]) for t in tgt]
+    tgt = [[*t, tgt_dict.eos()] for t in tgt]
     return AsrDataset(aud_paths, frame_sizes, tgt, tgt_dict, ids, speakers)
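
Editor's note: the pattern behind this fix is to keep each per-utterance target
as a plain Python list of ints in the dataset and convert to tensors only
inside the collater, so DataLoader worker processes no longer pass one
shared-memory tensor object per utterance back to the parent. The following is
a minimal, self-contained sketch of that collate step; collate_targets and the
pad_index=1 value are illustrative assumptions, not fairseq's actual
Seq2SeqCollater interface.

    import torch

    def collate_targets(targets, pad_index):
        # targets: plain Python int lists, tensorized only here in the
        # collater rather than once per item in the dataset (hypothetical
        # helper mirroring the fix, not the fairseq API)
        tensors = [torch.LongTensor(t) for t in targets]
        max_len = max(t.size(0) for t in tensors)
        # right-pad every sequence to the longest one in the batch
        batch = torch.full((len(tensors), max_len), pad_index, dtype=torch.long)
        for i, t in enumerate(tensors):
            batch[i, : t.size(0)] = t
        return batch

    # e.g. two utterances of different lengths, padded with index 1
    print(collate_targets([[4, 9, 2], [7, 2]], pad_index=1))
    # tensor([[4, 9, 2],
    #         [7, 2, 1]])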