File tree 2 files changed +3
-3
lines changed
2 files changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -744,7 +744,7 @@ def main():
744
744
distributed = args .distributed ,
745
745
collate_fn = collate_fn ,
746
746
pin_memory = args .pin_mem ,
747
- img_dtype = model_dtype ,
747
+ img_dtype = model_dtype or torch . float32 ,
748
748
device = device ,
749
749
use_prefetcher = args .prefetcher ,
750
750
use_multi_epochs_loader = args .use_multi_epochs_loader ,
@@ -769,7 +769,7 @@ def main():
769
769
distributed = args .distributed ,
770
770
crop_pct = data_config ['crop_pct' ],
771
771
pin_memory = args .pin_mem ,
772
- img_dtype = model_dtype ,
772
+ img_dtype = model_dtype or torch . float32 ,
773
773
device = device ,
774
774
use_prefetcher = args .prefetcher ,
775
775
)
Original file line number Diff line number Diff line change @@ -307,7 +307,7 @@ def validate(args):
307
307
crop_border_pixels = args .crop_border_pixels ,
308
308
pin_memory = args .pin_mem ,
309
309
device = device ,
310
- img_dtype = model_dtype ,
310
+ img_dtype = model_dtype or torch . float32 ,
311
311
tf_preprocessing = args .tf_preprocessing ,
312
312
)
313
313
You can’t perform that action at this time.
0 commit comments