Skip to content

Commit

Permalink
Fix examples/speech_recognition while using multi num_workers on mult…
Browse files Browse the repository at this point in the history
…i GPU setup (#1454)

Summary:
#1308
tgt in AsrDataset is list of torch tensors and it cause SIGSEGV error because tgt has too many objects to create shared memory in multiprocessing of dataloaders.
Pull Request resolved: #1454

Differential Revision: D18929874

Pulled By: okhonko

fbshipit-source-id: 5582b126890a93177258f5e053f32d5c6d32e9ab
  • Loading branch information
linus authored and facebook-github-bot committed Dec 16, 2019
1 parent a9c9304 commit 212a9c3
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 6 deletions.
9 changes: 5 additions & 4 deletions examples/speech_recognition/data/asr_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ def __init__(
self.frame_length = frame_length
self.frame_shift = frame_shift

self.s2s_collater = Seq2SeqCollater(
0, 1, pad_index=self.tgt_dict.pad(),
eos_index=self.tgt_dict.eos(), move_eos_to_beginning=True
)

def __getitem__(self, index):
import torchaudio
import torchaudio.compliance.kaldi as kaldi
Expand All @@ -72,10 +77,6 @@ def __getitem__(self, index):
frame_shift=self.frame_shift
)
output_cmvn = data_utils.apply_mv_norm(output)
self.s2s_collater = Seq2SeqCollater(
0, 1, pad_index=self.tgt_dict.pad(),
eos_index=self.tgt_dict.eos(), move_eos_to_beginning=True
)

return {"id": index, "data": [output_cmvn.detach(), tgt_item]}

Expand Down
2 changes: 2 additions & 0 deletions examples/speech_recognition/data/collaters.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ def collate(self, samples):
target = s["data"][self.label_index]
if isinstance(target, (np.ndarray, np.generic)):
target = torch.from_numpy(target).long()
elif isinstance(target, list):
target = torch.LongTensor(target)

parsed_sample = {"id": s["id"], "source": source, "target": target}
parsed_samples.append(parsed_sample)
Expand Down
4 changes: 2 additions & 2 deletions examples/speech_recognition/tasks/speech_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ def get_asr_dataset_from_json(data_json_path, tgt_dict):
speakers.append(m.group(1) + "_" + m.group(2))
frame_sizes = [s[1]["input"]["length_ms"] for s in sorted_samples]
tgt = [
torch.LongTensor([int(i) for i in s[1]["output"]["tokenid"].split(", ")])
[int(i) for i in s[1]["output"]["tokenid"].split(", ")]
for s in sorted_samples
]
# append eos
tgt = [torch.cat([t, torch.LongTensor([tgt_dict.eos()])]) for t in tgt]
tgt = [[*t, tgt_dict.eos()] for t in tgt]
return AsrDataset(aud_paths, frame_sizes, tgt, tgt_dict, ids, speakers)


Expand Down

0 comments on commit 212a9c3

Please sign in to comment.