                   help='lower precision AMP dtype (default: float16)')
group.add_argument('--amp-impl', default='native', type=str,
                   help='AMP impl to use, "native" or "apex" (default: native)')
+group.add_argument('--model-dtype', default=None, type=str,
+                   help='Model dtype override (non-AMP) (default: float32)')
group.add_argument('--no-ddp-bb', action='store_true', default=False,
                   help='Force broadcast buffers for native DDP to off.')
group.add_argument('--synchronize-step', action='store_true', default=False,
@@ -434,10 +436,18 @@ def main():
        _logger.info(f'Training with a single process on 1 device ({args.device}).')
    assert args.rank >= 0

+    model_dtype = None
+    if args.model_dtype:
+        assert args.model_dtype in ('float32', 'float16', 'bfloat16')
+        model_dtype = getattr(torch, args.model_dtype)
+        if model_dtype == torch.float16:
+            _logger.warning('float16 is not recommended for training, for half precision bfloat16 is recommended.')
+
    # resolve AMP arguments based on PyTorch / Apex availability
    use_amp = None
    amp_dtype = torch.float16
    if args.amp:
+        assert model_dtype is None or model_dtype == torch.float32, 'float32 model dtype must be used with AMP'
        if args.amp_impl == 'apex':
            assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
            use_amp = 'apex'
@@ -517,7 +527,7 @@ def main():
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # move model to GPU, enable channels last layout if set
-    model.to(device=device)
+    model.to(device=device, dtype=model_dtype)  # FIXME move model device & dtype into create_model
    if args.channels_last:
        model.to(memory_format=torch.channels_last)

@@ -587,7 +597,7 @@ def main():
            _logger.info('Using native Torch AMP. Training in mixed precision.')
    else:
        if utils.is_primary(args):
-            _logger.info('AMP not enabled. Training in float32.')
+            _logger.info(f'AMP not enabled. Training in {model_dtype or torch.float32}.')

    # optionally resume from a checkpoint
    resume_epoch = None
@@ -732,6 +742,7 @@ def main():
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
+        img_dtype=model_dtype,
        device=device,
        use_prefetcher=args.prefetcher,
        use_multi_epochs_loader=args.use_multi_epochs_loader,
@@ -756,6 +767,7 @@ def main():
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
+        img_dtype=model_dtype,
        device=device,
        use_prefetcher=args.prefetcher,
    )
@@ -823,9 +835,13 @@ def main():
    if utils.is_primary(args) and args.log_wandb:
        if has_wandb:
            assert not args.wandb_resume_id or args.resume
-            wandb.init(project=args.experiment, config=args, tags=args.wandb_tags,
-                resume='must' if args.wandb_resume_id else None,
-                id=args.wandb_resume_id if args.wandb_resume_id else None)
+            wandb.init(
+                project=args.experiment,
+                config=args,
+                tags=args.wandb_tags,
+                resume='must' if args.wandb_resume_id else None,
+                id=args.wandb_resume_id if args.wandb_resume_id else None,
+            )
        else:
            _logger.warning(
                "You've requested to log metrics to wandb but package not found. "
@@ -879,6 +895,7 @@ def main():
                output_dir=output_dir,
                amp_autocast=amp_autocast,
                loss_scaler=loss_scaler,
+                model_dtype=model_dtype,
                model_ema=model_ema,
                mixup_fn=mixup_fn,
                num_updates_total=num_epochs * updates_per_epoch,
@@ -897,6 +914,7 @@ def main():
                    args,
                    device=device,
                    amp_autocast=amp_autocast,
+                    model_dtype=model_dtype,
                )

                if model_ema is not None and not args.model_ema_force_cpu:
@@ -979,6 +997,7 @@ def train_one_epoch(
        output_dir=None,
        amp_autocast=suppress,
        loss_scaler=None,
+        model_dtype=None,
        model_ema=None,
        mixup_fn=None,
        num_updates_total=None,
@@ -1015,7 +1034,7 @@ def train_one_epoch(
            accum_steps = last_accum_steps

        if not args.prefetcher:
-            input, target = input.to(device), target.to(device)
+            input, target = input.to(device=device, dtype=model_dtype), target.to(device=device)
            if mixup_fn is not None:
                input, target = mixup_fn(input, target)
        if args.channels_last:
@@ -1142,6 +1161,7 @@ def validate(
        args,
        device=torch.device('cuda'),
        amp_autocast=suppress,
+        model_dtype=None,
        log_suffix=''
):
    batch_time_m = utils.AverageMeter()
@@ -1157,8 +1177,8 @@ def validate(
        for batch_idx, (input, target) in enumerate(loader):
            last_batch = batch_idx == last_idx
            if not args.prefetcher:
-                input = input.to(device)
-                target = target.to(device)
+                input = input.to(device=device, dtype=model_dtype)
+                target = target.to(device=device)
            if args.channels_last:
                input = input.contiguous(memory_format=torch.channels_last)

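For reference, a minimal standalone sketch of the dtype handling this patch wires through train.py; the helper name `resolve_model_dtype` is illustrative only and not part of the change:

import torch

def resolve_model_dtype(model_dtype_arg, use_amp):
    # Map the --model-dtype string to a torch dtype (None keeps the float32 default).
    model_dtype = None
    if model_dtype_arg:
        assert model_dtype_arg in ('float32', 'float16', 'bfloat16')
        model_dtype = getattr(torch, model_dtype_arg)
        if model_dtype == torch.float16:
            print('Warning: bfloat16 is preferred over float16 for half precision training.')
    if use_amp:
        # AMP autocasts from a float32 base model, so other base dtypes are rejected.
        assert model_dtype is None or model_dtype == torch.float32, \
            'float32 model dtype must be used with AMP'
    return model_dtype

# e.g. resolve_model_dtype('bfloat16', use_amp=False) returns torch.bfloat16, which the
# patch then passes to model.to(), the loaders (img_dtype), and the non-prefetcher input.to().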