Skip to content

Commit eeee38e

Browse files
committed
Avoid unecessary compat break btw train script and nearby timm versions w/ dtype addition.
1 parent deb9895 commit eeee38e

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

train.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -744,7 +744,7 @@ def main():
744744
distributed=args.distributed,
745745
collate_fn=collate_fn,
746746
pin_memory=args.pin_mem,
747-
img_dtype=model_dtype,
747+
img_dtype=model_dtype or torch.float32,
748748
device=device,
749749
use_prefetcher=args.prefetcher,
750750
use_multi_epochs_loader=args.use_multi_epochs_loader,
@@ -769,7 +769,7 @@ def main():
769769
distributed=args.distributed,
770770
crop_pct=data_config['crop_pct'],
771771
pin_memory=args.pin_mem,
772-
img_dtype=model_dtype,
772+
img_dtype=model_dtype or torch.float32,
773773
device=device,
774774
use_prefetcher=args.prefetcher,
775775
)

validate.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ def validate(args):
307307
crop_border_pixels=args.crop_border_pixels,
308308
pin_memory=args.pin_mem,
309309
device=device,
310-
img_dtype=model_dtype,
310+
img_dtype=model_dtype or torch.float32,
311311
tf_preprocessing=args.tf_preprocessing,
312312
)
313313

0 commit comments

Comments
 (0)