SIGSEGV while running train.py on a multi GPU setup #1308

Closed
chandraka opened this issue Oct 26, 2019 · 6 comments

@chandraka
chandraka commented Oct 26, 2019

I have set up an Ubuntu 18.04 environment with 4 CPUs and 4 GPUs to run the LibriSpeech dataset training.

The prepare step went through fine.

But when I launch the training using:
python train.py ./librispeech-workdir/preprocessed-data/ --save-dir ./librispeech-workdir/train-output/ --max-epoch 80 --task speech_recognition --arch vggtransformer_2 --optimizer adadelta --lr 1.0 --adadelta-eps 1e-8 --adadelta-rho 0.95 --clip-norm 10.0 --max-tokens 5000 --log-format json --log-interval 1 --criterion cross_entropy_acc --user-dir examples/speech_recognition/

I get the following error right at the outset:

| model vggtransformer_2, criterion CrossEntropyWithAccCriterion
| num. model params: 315190057 (num. trained: 315190057)
| training on 4 GPUs
| max tokens per GPU = 5000 and max sentences per GPU = None
| no existing checkpoint found ./librispeech-workdir/train-output/checkpoint_last.pt
| loading train data for epoch 0
Traceback (most recent call last):
File "train.py", line 343, in
cli_main()
File "train.py", line 335, in cli_main
nprocs=args.distributed_world_size,
File "/home/chandraka/anaconda3/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 171, in spawn
while not spawn_context.join():
File "/home/chandraka/anaconda3/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 107, in join
(error_index, name)
Exception: process 0 terminated with signal SIGSEGV

Unable to proceed in the absence of any clues as to what might be causing it.

Please help.

It starts out with:


| distributed init (rank 3): tcp://localhost:15160
| distributed init (rank 0): tcp://localhost:15160
| distributed init (rank 2): tcp://localhost:15160
| distributed init (rank 1): tcp://localhost:15160
| initialized host espresso-2 as rank 2
| initialized host espresso-2 as rank 1
| initialized host espresso-2 as rank 3
| initialized host espresso-2 as rank 0
Namespace(adadelta_eps=1e-08, adadelta_rho=0.95, anneal_eps=False, arch='vggtransformer_2', best_checkpoint_metric='loss', bpe=None, bucket_cap_mb=25, clip_norm=10.0, conv_dec_config='((256, 3, True),) * 4', cpu=False, criterion='cross_entropy_acc', curriculum=0, data='./librispeech-workdir/preprocessed-data/', dataset_impl=None, ddp_backend='c10d', device_id=0, disable_validation=False, distributed_backend='nccl', distributed_init_method='tcp://localhost:15160', distributed_no_spawn=False, distributed_port=-1, distributed_rank=0, distributed_world_size=4, empty_cache_freq=0, enc_output_dim=1024, fast_stat_sync=False, find_unused_parameters=False, fix_batches_to_gpus=False, fixed_validation_seed=None, force_anneal=None, fp16=False, fp16_init_scale=128, fp16_scale_tolerance=0.0, fp16_scale_window=None, input_feat_per_channel=80, keep_interval_updates=-1, keep_last_epochs=-1, log_format='json', log_interval=1, lr=[1.0], lr_scheduler='fixed', lr_shrink=0.1, max_epoch=80, max_sentences=None, max_sentences_valid=None, max_tokens=5000, max_tokens_valid=5000, max_update=0, maximize_best_checkpoint_metric=False, memory_efficient_fp16=False, min_loss_scale=0.0001, min_lr=-1, no_epoch_checkpoints=False, no_last_checkpoints=False, no_progress_bar=False, no_save=False, no_save_optimizer_state=False, num_workers=1, optimizer='adadelta', optimizer_overrides='{}', required_batch_size_multiple=8, reset_dataloader=False, reset_lr_scheduler=False, reset_meters=False, reset_optimizer=False, restore_file='checkpoint_last.pt', save_dir='./librispeech-workdir/train-output/', save_interval=1, save_interval_updates=0, seed=1, sentence_avg=False, silence_token='▁', skip_invalid_size_inputs_valid_test=False, task='speech_recognition_e', tbmf_wrapper=False, tensorboard_logdir='', tgt_embed_dim=512, threshold_loss_scale=None, tokenizer=None, train_subset='train', transformer_dec_config='((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 6', transformer_enc_config='((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 16', update_freq=[1], use_bmuf=False, user_dir='examples/speech_recognition/', valid_subset='valid', validate_interval=1, vggblock_enc_config='[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]', warmup_updates=0, weight_decay=0.0)
| dictionary: 5001 types
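
The spawn wrapper only reports that process 0 died with SIGSEGV, so the Python frames at the crash site are lost. A generic way to recover them (standard library only, nothing fairseq-specific) is to enable faulthandler before the workers are started, e.g.:

```python
# Generic SIGSEGV debugging aid: dump the Python traceback of every thread
# when a fatal signal arrives. Put this near the top of train.py (or run
# `python -X faulthandler train.py ...`) so it is active before
# torch.multiprocessing.spawn starts the worker processes.
import faulthandler

faulthandler.enable(all_threads=True)
```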


@chandraka
Author

I have tried this with a checkout of fairseq -- I get the same SIGSEGV.

I have also tried reducing --max-tokens from 5000 to 1000.

@Shujian2015

Shujian2015 commented Nov 18, 2019

Hi, I had the same issue, although I was using espresso's transformer model (based on Fairseq's transformer). I used gpleiss/efficient_densenet_pytorch#47 (comment) to check what was wrong, and it turned out that the SIGSEGV came from FusedLayerNorm in https://github.com/pytorch/fairseq/blob/master/fairseq/modules/layer_norm.py#L13.

I am not sure what exactly happened there, but taking this piece out (commenting out lines 10-15) solved my issue.

Maybe it is better to use LayerNorm like this: https://github.com/pytorch/fairseq/blob/master/fairseq/models/wav2vec.py#L233
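
For reference, the suggested fallback amounts to something like the sketch below. This is a minimal approximation of what commenting out the apex block in layer_norm.py achieves, not the exact code at either link (the factory signature here is assumed):

```python
# Minimal sketch of the workaround: always return the plain PyTorch LayerNorm
# instead of apex's optional CUDA-fused FusedLayerNorm.
import torch


def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
    # torch.nn.LayerNorm is the portable implementation; apex's FusedLayerNorm
    # is only a drop-in CUDA kernel that can be skipped entirely.
    return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
```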

@Shujian2015

And this may be related: NVIDIA/apex#156

@myleott
Contributor

myleott commented Dec 16, 2019

This seems to be related to apex. You can comment this out to disable apex's LayerNorm: https://github.com/pytorch/fairseq/blob/master/fairseq/modules/layer_norm.py#L10-L15

Please re-open if this continues to happen after disabling apex.
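
After commenting that block out (or uninstalling apex), a quick check like the one below, assuming LayerNorm is exported from fairseq.modules as in the linked file, confirms that the factory now falls back to the stock PyTorch module:

```python
# Sanity check: with apex's FusedLayerNorm disabled, fairseq's LayerNorm
# factory should construct the stock torch.nn.LayerNorm.
from fairseq.modules import LayerNorm

ln = LayerNorm(512)
print(type(ln))  # expected: <class 'torch.nn.modules.normalization.LayerNorm'>
```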

@myleott myleott closed this as completed Dec 16, 2019
facebook-github-bot pushed a commit that referenced this issue Dec 16, 2019
…i GPU setup (#1454)

Summary:
#1308
tgt in AsrDataset is a list of torch tensors, and it causes a SIGSEGV error because tgt has too many objects to create shared memory for in the multiprocessing of the dataloaders.
Pull Request resolved: #1454

Differential Revision: D18929874

Pulled By: okhonko

fbshipit-source-id: 5582b126890a93177258f5e053f32d5c6d32e9ab
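
The failure mode described in that summary (a dataset field holding a very large number of small tensors, each of which torch.multiprocessing has to move into shared memory) can be avoided in a generic way. The sketch below is a hypothetical dataset, not the actual #1454 patch: it keeps targets as plain Python lists and only materializes tensors per item.

```python
# Hypothetical illustration of the mitigation (not the actual #1454 change):
# keep targets as plain Python lists of ints rather than thousands of tiny
# torch tensors, so worker processes do not need a shared-memory segment per
# target when the dataset is passed around by torch.multiprocessing.
import torch
from torch.utils.data import Dataset


class ListBackedAsrDataset(Dataset):
    def __init__(self, features, targets):
        self.features = features                              # e.g. list of numpy arrays
        self.targets = [list(map(int, t)) for t in targets]   # no per-utterance tensors

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        # Tensors are created lazily, one item at a time.
        return (
            torch.as_tensor(self.features[idx], dtype=torch.float32),
            torch.as_tensor(self.targets[idx], dtype=torch.long),
        )
```

Another mitigation sometimes used for this class of error is torch.multiprocessing.set_sharing_strategy('file_system'), although it only changes how tensors are shared rather than reducing the number of shared objects.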
moussaKam pushed a commit to moussaKam/language-adaptive-pretraining that referenced this issue Sep 29, 2020
yzpang pushed a commit to yzpang/gold-off-policy-text-gen-iclr21 that referenced this issue Feb 19, 2021
@brando90

brando90 commented Mar 2, 2021

(The comment quotes the original issue report above in full.)

@chandraka did you ever figure out what was wrong and how to solve your issue?

@chandraka
Author

chandraka commented Mar 3, 2021 via email
