                    help='lower precision AMP dtype (default: float16)')
 group.add_argument('--amp-impl', default='native', type=str,
                    help='AMP impl to use, "native" or "apex" (default: native)')
+group.add_argument('--model-dtype', default=None, type=str,
+                   help='Model dtype override (non-AMP) (default: float32)')
 group.add_argument('--no-ddp-bb', action='store_true', default=False,
                    help='Force broadcast buffers for native DDP to off.')
 group.add_argument('--synchronize-step', action='store_true', default=False,
@@ -436,10 +438,18 @@ def main():
         _logger.info(f'Training with a single process on 1 device ({args.device}).')
     assert args.rank >= 0

+    model_dtype = None
+    if args.model_dtype:
+        assert args.model_dtype in ('float32', 'float16', 'bfloat16')
+        model_dtype = getattr(torch, args.model_dtype)
+        if model_dtype == torch.float16:
+            _logger.warning('float16 is not recommended for training, for half precision bfloat16 is recommended.')
+
     # resolve AMP arguments based on PyTorch / Apex availability
     use_amp = None
     amp_dtype = torch.float16
     if args.amp:
+        assert model_dtype is None or model_dtype == torch.float32, 'float32 model dtype must be used with AMP'
         if args.amp_impl == 'apex':
             assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
             use_amp = 'apex'
@@ -519,7 +529,7 @@ def main():
         model = convert_splitbn_model(model, max(num_aug_splits, 2))

     # move model to GPU, enable channels last layout if set
-    model.to(device=device)
+    model.to(device=device, dtype=model_dtype)  # FIXME move model device & dtype into create_model
     if args.channels_last:
         model.to(memory_format=torch.channels_last)

@@ -589,7 +599,7 @@ def main():
             _logger.info('Using native Torch AMP. Training in mixed precision.')
     else:
         if utils.is_primary(args):
-            _logger.info('AMP not enabled. Training in float32.')
+            _logger.info(f'AMP not enabled. Training in {model_dtype or torch.float32}.')

     # optionally resume from a checkpoint
     resume_epoch = None
@@ -734,6 +744,7 @@ def main():
         distributed=args.distributed,
         collate_fn=collate_fn,
         pin_memory=args.pin_mem,
+        img_dtype=model_dtype,
         device=device,
         use_prefetcher=args.prefetcher,
         use_multi_epochs_loader=args.use_multi_epochs_loader,
@@ -758,6 +769,7 @@ def main():
         distributed=args.distributed,
         crop_pct=data_config['crop_pct'],
         pin_memory=args.pin_mem,
+        img_dtype=model_dtype,
         device=device,
         use_prefetcher=args.prefetcher,
     )
@@ -822,21 +834,21 @@ def main():
         with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
             f.write(args_text)

-    if utils.is_primary(args) and args.log_wandb:
-        if has_wandb:
-            assert not args.wandb_resume_id or args.resume
-            wandb.init(
-                project=args.wandb_project,
-                name=args.experiment,
-                config=args,
-                tags=args.wandb_tags,
-                resume="must" if args.wandb_resume_id else None,
-                id=args.wandb_resume_id if args.wandb_resume_id else None,
-            )
-        else:
-            _logger.warning(
-                "You've requested to log metrics to wandb but package not found. "
-                "Metrics not being logged to wandb, try `pip install wandb`")
+        if args.log_wandb:
+            if has_wandb:
+                assert not args.wandb_resume_id or args.resume
+                wandb.init(
+                    project=args.wandb_project,
+                    name=exp_name,
+                    config=args,
+                    tags=args.wandb_tags,
+                    resume="must" if args.wandb_resume_id else None,
+                    id=args.wandb_resume_id if args.wandb_resume_id else None,
+                )
+            else:
+                _logger.warning(
+                    "You've requested to log metrics to wandb but package not found. "
+                    "Metrics not being logged to wandb, try `pip install wandb`")

     # setup learning rate schedule and starting epoch
     updates_per_epoch = (len(loader_train) + args.grad_accum_steps - 1) // args.grad_accum_steps
@@ -886,6 +898,7 @@ def main():
                 output_dir=output_dir,
                 amp_autocast=amp_autocast,
                 loss_scaler=loss_scaler,
+                model_dtype=model_dtype,
                 model_ema=model_ema,
                 mixup_fn=mixup_fn,
                 num_updates_total=num_epochs * updates_per_epoch,
@@ -904,6 +917,7 @@ def main():
                     args,
                     device=device,
                     amp_autocast=amp_autocast,
+                    model_dtype=model_dtype,
                 )

                 if model_ema is not None and not args.model_ema_force_cpu:
@@ -986,6 +1000,7 @@ def train_one_epoch(
         output_dir=None,
         amp_autocast=suppress,
         loss_scaler=None,
+        model_dtype=None,
         model_ema=None,
         mixup_fn=None,
         num_updates_total=None,
@@ -1022,7 +1037,7 @@ def train_one_epoch(
             accum_steps = last_accum_steps

         if not args.prefetcher:
-            input, target = input.to(device), target.to(device)
+            input, target = input.to(device=device, dtype=model_dtype), target.to(device=device)
             if mixup_fn is not None:
                 input, target = mixup_fn(input, target)
         if args.channels_last:
@@ -1149,6 +1164,7 @@ def validate(
         args,
         device=torch.device('cuda'),
         amp_autocast=suppress,
+        model_dtype=None,
         log_suffix=''
 ):
     batch_time_m = utils.AverageMeter()
@@ -1164,8 +1180,8 @@ def validate(
         for batch_idx, (input, target) in enumerate(loader):
             last_batch = batch_idx == last_idx
             if not args.prefetcher:
-                input = input.to(device)
-                target = target.to(device)
+                input = input.to(device=device, dtype=model_dtype)
+                target = target.to(device=device)
             if args.channels_last:
                 input = input.contiguous(memory_format=torch.channels_last)

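A note on the default path: when `--model-dtype` is not given, `model_dtype` stays `None`, and passing `dtype=None` to `Module.to` / `Tensor.to` leaves the existing dtype untouched, so plain float32 training behaves exactly as before. The snippet below is a standalone illustration of that PyTorch behaviour, not part of the commit:

import torch

# dtype=None is a no-op for dtype: module and tensor stay float32,
# matching the default (no --model-dtype) path in the diff above.
model = torch.nn.Linear(4, 4)
x = torch.randn(2, 4)
model.to(device='cpu', dtype=None)
y = x.to(device='cpu', dtype=None)
assert y.dtype == torch.float32

# With an explicit override (e.g. --model-dtype bfloat16), both the weights
# and the inputs are cast to the same dtype before the forward pass.
model.to(device='cpu', dtype=torch.bfloat16)
y = x.to(device='cpu', dtype=torch.bfloat16)
assert model(y).dtype == torch.bfloat16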
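Usage note (not part of the diff): with these changes, a non-AMP low-precision run would presumably be launched with `--model-dtype bfloat16`, which casts the model weights (`model.to(device=device, dtype=model_dtype)`) and the loader output (`img_dtype=model_dtype`) to the same dtype, while combining a half-precision `--model-dtype` with `--amp` now trips the added assertion, since AMP is expected to run on float32 weights.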