This repository has been archived by the owner on Oct 13, 2022. It is now read-only.

Commit

Remove batch norm sync (hangs in epoch 9...) and increase batch size when using bucketing
pzelasko committed Feb 13, 2021
1 parent e1dfd41 · commit c396e55
Showing 1 changed file with 3 additions and 2 deletions.
egs/librispeech/asr/simple_v1/mmi_bigram_train.py (3 additions, 2 deletions)
@@ -377,7 +377,7 @@ def main():
         logging.info('Using BucketingSampler.')
         train_sampler = BucketingSampler(
             cuts_train,
-            max_frames=30000,
+            max_frames=40000,
             shuffle=True,
             num_buckets=30
         )
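For context, max_frames is the knob that controls the effective batch size here: the BucketingSampler packs cuts into a batch until their total frame count reaches the cap, so raising it from 30000 to 40000 yields larger batches. Below is a minimal sketch of how such a sampler is typically wired into a DataLoader; the dataset class, manifest path, and num_workers value are assumptions for illustration and are not part of this diff.

# Sketch only: assumes lhotse's CutSet / BucketingSampler / K2SpeechRecognitionDataset
# API as used by this script; the manifest path and num_workers are illustrative.
from torch.utils.data import DataLoader
from lhotse import CutSet
from lhotse.dataset import BucketingSampler, K2SpeechRecognitionDataset

cuts_train = CutSet.from_json('exp/data/cuts_train.json.gz')  # hypothetical manifest path
train_sampler = BucketingSampler(
    cuts_train,
    max_frames=40000,   # total frames per batch; raising this increases the batch size
    shuffle=True,
    num_buckets=30,     # group cuts of similar duration to reduce padding waste
)
# The sampler yields whole batches of cuts, so the DataLoader is created with batch_size=None.
train_dl = DataLoader(
    K2SpeechRecognitionDataset(cuts_train),
    sampler=train_sampler,
    batch_size=None,
    num_workers=4,
)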
@@ -450,7 +450,8 @@ def main():
                      'only the batches seen in the master process (the actual loss '
                      'includes batches from all GPUs, and the actual num_frames is '
                      f'approx. {args.world_size}x larger.')
-        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+        # For now do not sync BatchNorm across GPUs due to NCCL hanging in all_gather...
+        # model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
         model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank)
     describe(model)
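The second hunk disables the BatchNorm synchronization step that normally precedes DDP wrapping. Converting to SyncBatchNorm makes every forward pass issue extra NCCL collectives to share batch statistics across GPUs, and that traffic is what was observed to hang; skipping the conversion leaves each GPU normalizing with its own local statistics. A minimal sketch of the standard PyTorch pattern, with the conversion behind a flag, follows; the helper name and flag are illustrative, not part of the script.

# Sketch of the standard PyTorch pattern this hunk modifies; wrap_model_for_ddp and
# its sync_bn flag are illustrative names. Assumes torch.distributed is initialized.
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_model_for_ddp(model: torch.nn.Module, local_rank: int, sync_bn: bool = False):
    model = model.to(torch.device('cuda', local_rank))
    if sync_bn:
        # Converts BatchNorm layers so batch statistics are reduced across all
        # processes on every forward pass (the extra NCCL traffic disabled by this commit).
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    return DDP(model, device_ids=[local_rank], output_device=local_rank)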


0 comments on commit c396e55
