
Commit c173886: Merge branch 'main' into caojiaolong-main
2 parents: 40c19f3 + 2d0ac6f

3 files changed: +63 -38 lines

timm/data/loader.py (+13 -13)

@@ -77,18 +77,18 @@ class PrefetchLoader:
 
     def __init__(
             self,
-            loader,
-            mean=IMAGENET_DEFAULT_MEAN,
-            std=IMAGENET_DEFAULT_STD,
-            channels=3,
-            device=torch.device('cuda'),
-            img_dtype=torch.float32,
-            fp16=False,
-            re_prob=0.,
-            re_mode='const',
-            re_count=1,
-            re_num_splits=0):
-
+            loader: torch.utils.data.DataLoader,
+            mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
+            std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
+            channels: int = 3,
+            device: torch.device = torch.device('cuda'),
+            img_dtype: Optional[torch.dtype] = None,
+            fp16: bool = False,
+            re_prob: float = 0.,
+            re_mode: str = 'const',
+            re_count: int = 1,
+            re_num_splits: int = 0,
+    ):
         mean = adapt_to_chs(mean, channels)
         std = adapt_to_chs(std, channels)
         normalization_shape = (1, channels, 1, 1)
@@ -98,7 +98,7 @@ def __init__(
         if fp16:
             # fp16 arg is deprecated, but will override dtype arg if set for bwd compat
             img_dtype = torch.float16
-        self.img_dtype = img_dtype
+        self.img_dtype = img_dtype or torch.float32
         self.mean = torch.tensor(
             [x * 255 for x in mean], device=device, dtype=img_dtype).view(normalization_shape)
         self.std = torch.tensor(
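The behavioral change here is that PrefetchLoader's img_dtype now defaults to None and only falls back to float32 when neither it nor the deprecated fp16 flag is set, which lets callers pass a model dtype straight through. A minimal standalone sketch of that resolution order (resolve_img_dtype is an illustrative helper, not part of the timm API):

    import torch
    from typing import Optional

    def resolve_img_dtype(img_dtype: Optional[torch.dtype] = None, fp16: bool = False) -> torch.dtype:
        # the deprecated fp16 flag wins, for backwards compatibility, as in the diff above
        if fp16:
            return torch.float16
        # otherwise honour an explicit dtype; None falls back to float32
        return img_dtype or torch.float32

    assert resolve_img_dtype() is torch.float32
    assert resolve_img_dtype(torch.bfloat16) is torch.bfloat16
    assert resolve_img_dtype(fp16=True) is torch.float16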

train.py (+36 -20)

@@ -178,6 +178,8 @@
                    help='lower precision AMP dtype (default: float16)')
 group.add_argument('--amp-impl', default='native', type=str,
                    help='AMP impl to use, "native" or "apex" (default: native)')
+group.add_argument('--model-dtype', default=None, type=str,
+                   help='Model dtype override (non-AMP) (default: float32)')
 group.add_argument('--no-ddp-bb', action='store_true', default=False,
                    help='Force broadcast buffers for native DDP to off.')
 group.add_argument('--synchronize-step', action='store_true', default=False,
@@ -436,10 +438,18 @@ def main():
         _logger.info(f'Training with a single process on 1 device ({args.device}).')
     assert args.rank >= 0
 
+    model_dtype = None
+    if args.model_dtype:
+        assert args.model_dtype in ('float32', 'float16', 'bfloat16')
+        model_dtype = getattr(torch, args.model_dtype)
+        if model_dtype == torch.float16:
+            _logger.warning('float16 is not recommended for training, for half precision bfloat16 is recommended.')
+
     # resolve AMP arguments based on PyTorch / Apex availability
     use_amp = None
     amp_dtype = torch.float16
     if args.amp:
+        assert model_dtype is None or model_dtype == torch.float32, 'float32 model dtype must be used with AMP'
         if args.amp_impl == 'apex':
             assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
             use_amp = 'apex'
@@ -519,7 +529,7 @@ def main():
         model = convert_splitbn_model(model, max(num_aug_splits, 2))
 
     # move model to GPU, enable channels last layout if set
-    model.to(device=device)
+    model.to(device=device, dtype=model_dtype)  # FIXME move model device & dtype into create_model
     if args.channels_last:
         model.to(memory_format=torch.channels_last)
 
@@ -589,7 +599,7 @@ def main():
             _logger.info('Using native Torch AMP. Training in mixed precision.')
     else:
         if utils.is_primary(args):
-            _logger.info('AMP not enabled. Training in float32.')
+            _logger.info(f'AMP not enabled. Training in {model_dtype or torch.float32}.')
 
     # optionally resume from a checkpoint
     resume_epoch = None
@@ -734,6 +744,7 @@ def main():
         distributed=args.distributed,
         collate_fn=collate_fn,
         pin_memory=args.pin_mem,
+        img_dtype=model_dtype,
         device=device,
         use_prefetcher=args.prefetcher,
         use_multi_epochs_loader=args.use_multi_epochs_loader,
@@ -758,6 +769,7 @@ def main():
         distributed=args.distributed,
         crop_pct=data_config['crop_pct'],
         pin_memory=args.pin_mem,
+        img_dtype=model_dtype,
         device=device,
         use_prefetcher=args.prefetcher,
     )
@@ -822,21 +834,21 @@ def main():
         with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
             f.write(args_text)
 
-    if utils.is_primary(args) and args.log_wandb:
-        if has_wandb:
-            assert not args.wandb_resume_id or args.resume
-            wandb.init(
-                project=args.wandb_project,
-                name=args.experiment,
-                config=args,
-                tags=args.wandb_tags,
-                resume="must" if args.wandb_resume_id else None,
-                id=args.wandb_resume_id if args.wandb_resume_id else None,
-            )
-        else:
-            _logger.warning(
-                "You've requested to log metrics to wandb but package not found. "
-                "Metrics not being logged to wandb, try `pip install wandb`")
+        if args.log_wandb:
+            if has_wandb:
+                assert not args.wandb_resume_id or args.resume
+                wandb.init(
+                    project=args.wandb_project,
+                    name=exp_name,
+                    config=args,
+                    tags=args.wandb_tags,
+                    resume="must" if args.wandb_resume_id else None,
+                    id=args.wandb_resume_id if args.wandb_resume_id else None,
+                )
+            else:
+                _logger.warning(
+                    "You've requested to log metrics to wandb but package not found. "
+                    "Metrics not being logged to wandb, try `pip install wandb`")
 
     # setup learning rate schedule and starting epoch
     updates_per_epoch = (len(loader_train) + args.grad_accum_steps - 1) // args.grad_accum_steps
@@ -886,6 +898,7 @@ def main():
                 output_dir=output_dir,
                 amp_autocast=amp_autocast,
                 loss_scaler=loss_scaler,
+                model_dtype=model_dtype,
                 model_ema=model_ema,
                 mixup_fn=mixup_fn,
                 num_updates_total=num_epochs * updates_per_epoch,
@@ -904,6 +917,7 @@ def main():
                     args,
                     device=device,
                     amp_autocast=amp_autocast,
+                    model_dtype=model_dtype,
                 )
 
                 if model_ema is not None and not args.model_ema_force_cpu:
@@ -986,6 +1000,7 @@ def train_one_epoch(
         output_dir=None,
         amp_autocast=suppress,
         loss_scaler=None,
+        model_dtype=None,
         model_ema=None,
         mixup_fn=None,
         num_updates_total=None,
@@ -1022,7 +1037,7 @@ def train_one_epoch(
             accum_steps = last_accum_steps
 
         if not args.prefetcher:
-            input, target = input.to(device), target.to(device)
+            input, target = input.to(device=device, dtype=model_dtype), target.to(device=device)
            if mixup_fn is not None:
                input, target = mixup_fn(input, target)
        if args.channels_last:
@@ -1149,6 +1164,7 @@ def validate(
         args,
         device=torch.device('cuda'),
         amp_autocast=suppress,
+        model_dtype=None,
         log_suffix=''
 ):
     batch_time_m = utils.AverageMeter()
@@ -1164,8 +1180,8 @@ def validate(
         for batch_idx, (input, target) in enumerate(loader):
             last_batch = batch_idx == last_idx
             if not args.prefetcher:
-                input = input.to(device)
-                target = target.to(device)
+                input = input.to(device=device, dtype=model_dtype)
+                target = target.to(device=device)
             if args.channels_last:
                 input = input.contiguous(memory_format=torch.channels_last)
 
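Both train.py and validate.py resolve the new --model-dtype string the same way and refuse to combine a non-float32 model dtype with AMP. A condensed standalone sketch of that logic (resolve_model_dtype is an illustrative name, not a function in either script):

    import torch

    def resolve_model_dtype(model_dtype_arg=None, use_amp=False):
        # map the CLI string ('float32' | 'float16' | 'bfloat16') to a torch dtype
        model_dtype = None
        if model_dtype_arg:
            assert model_dtype_arg in ('float32', 'float16', 'bfloat16')
            model_dtype = getattr(torch, model_dtype_arg)
        # AMP autocast expects float32 weights, so any other model dtype is rejected
        if use_amp:
            assert model_dtype is None or model_dtype == torch.float32, \
                'float32 model dtype must be used with AMP'
        return model_dtype

    print(resolve_model_dtype('bfloat16'))           # torch.bfloat16
    print(resolve_model_dtype(None, use_amp=True))   # None, model stays float32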

validate.py (+14 -5)

@@ -123,6 +123,8 @@
                     help='lower precision AMP dtype (default: float16)')
 parser.add_argument('--amp-impl', default='native', type=str,
                     help='AMP impl to use, "native" or "apex" (default: native)')
+parser.add_argument('--model-dtype', default=None, type=str,
+                    help='Model dtype override (non-AMP) (default: float32)')
 parser.add_argument('--tf-preprocessing', action='store_true', default=False,
                     help='Use Tensorflow preprocessing pipeline (require CPU TF installed')
 parser.add_argument('--use-ema', dest='use_ema', action='store_true',
@@ -168,10 +170,16 @@ def validate(args):
 
     device = torch.device(args.device)
 
+    model_dtype = None
+    if args.model_dtype:
+        assert args.model_dtype in ('float32', 'float16', 'bfloat16')
+        model_dtype = getattr(torch, args.model_dtype)
+
     # resolve AMP arguments based on PyTorch / Apex availability
     use_amp = None
     amp_autocast = suppress
     if args.amp:
+        assert model_dtype is None or model_dtype == torch.float32, 'float32 model dtype must be used with AMP'
         if args.amp_impl == 'apex':
             assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
             assert args.amp_dtype == 'float16'
@@ -184,7 +192,7 @@ def validate(args):
             amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
             _logger.info('Validating in mixed precision with native PyTorch AMP.')
     else:
-        _logger.info('Validating in float32. AMP not enabled.')
+        _logger.info(f'Validating in {model_dtype or torch.float32}. AMP not enabled.')
 
     if args.fuser:
         set_jit_fuser(args.fuser)
@@ -231,7 +239,7 @@ def validate(args):
     if args.test_pool:
         model, test_time_pool = apply_test_time_pool(model, data_config)
 
-    model = model.to(device)
+    model = model.to(device=device, dtype=model_dtype)  # FIXME move model device & dtype into create_model
     if args.channels_last:
         model = model.to(memory_format=torch.channels_last)
 
@@ -299,6 +307,7 @@ def validate(args):
         crop_border_pixels=args.crop_border_pixels,
         pin_memory=args.pin_mem,
         device=device,
+        img_dtype=model_dtype,
         tf_preprocessing=args.tf_preprocessing,
     )
 
@@ -310,7 +319,7 @@ def validate(args):
     model.eval()
     with torch.no_grad():
         # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
-        input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).to(device)
+        input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).to(device=device, dtype=model_dtype)
         if args.channels_last:
             input = input.contiguous(memory_format=torch.channels_last)
         with amp_autocast():
@@ -319,8 +328,8 @@ def validate(args):
         end = time.time()
         for batch_idx, (input, target) in enumerate(loader):
             if args.no_prefetcher:
-                target = target.to(device)
-                input = input.to(device)
+                target = target.to(device=device)
+                input = input.to(device=device, dtype=model_dtype)
             if args.channels_last:
                 input = input.contiguous(memory_format=torch.channels_last)
 
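One detail behind the .to(device=device, dtype=model_dtype) pattern used throughout: passing dtype=None to Tensor.to (or Module.to) leaves the dtype unchanged, so the default --model-dtype of None preserves the existing float32 behaviour while an explicit value casts both model and inputs. A quick standalone check (a sketch, not project code):

    import torch

    x = torch.randn(2, 3)  # float32 by default
    # dtype=None is a no-op for dtype, so the unset --model-dtype path is unchanged
    assert x.to(device='cpu', dtype=None).dtype is torch.float32
    # an explicit dtype casts the batch to match, e.g., a bfloat16 model
    assert x.to(device='cpu', dtype=torch.bfloat16).dtype is torch.bfloat16

In practice this means running validate.py with --model-dtype bfloat16 evaluates the model and input pipeline in bfloat16 without AMP, while omitting the flag keeps everything in float32 as before.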
