diff --git a/egs/librispeech/asr/simple_v1/mmi_bigram_train.py b/egs/librispeech/asr/simple_v1/mmi_bigram_train.py
index f7d199aa..015cfa0c 100755
--- a/egs/librispeech/asr/simple_v1/mmi_bigram_train.py
+++ b/egs/librispeech/asr/simple_v1/mmi_bigram_train.py
@@ -377,7 +377,7 @@ def main():
         logging.info('Using BucketingSampler.')
         train_sampler = BucketingSampler(
             cuts_train,
-            max_frames=30000,
+            max_frames=40000,
             shuffle=True,
             num_buckets=30
         )
@@ -450,7 +450,8 @@ def main():
                      'only the batches seen in the master process (the actual loss '
                      'includes batches from all GPUs, and the actual num_frames is '
                      f'approx. {args.world_size}x larger.')
-        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+        # For now do not sync BatchNorm across GPUs due to NCCL hanging in all_gather...
+        # model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
         model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank)
 
     describe(model)
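
Note on the second hunk: it works around the NCCL all_gather hang by commenting out the SyncBatchNorm conversion entirely. A less drastic alternative would be to keep the conversion available behind an opt-in switch. The sketch below is only an illustration of that idea, not part of this patch: the --sync-bn flag, args.sync_bn attribute, and wrap_model_for_ddp helper are hypothetical names and do not exist in the original script.

# Hypothetical sketch (not part of this patch): keep SyncBatchNorm conversion
# available behind an opt-in flag instead of removing it.
import argparse

import torch
from torch.nn.parallel import DistributedDataParallel as DDP


def wrap_model_for_ddp(model: torch.nn.Module, args: argparse.Namespace) -> torch.nn.Module:
    # args.sync_bn and args.local_rank are assumed to come from the script's
    # argparse setup; args.sync_bn defaults to False here.
    if getattr(args, 'sync_bn', False):
        # Known to hang in NCCL all_gather on some setups, hence opt-in only.
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    return DDP(model, device_ids=[args.local_rank], output_device=args.local_rank)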