Commit 2d0ac6f

Merge pull request #2397 from huggingface/half_prec_trainval
Add half-precision (bfloat16, float16) support to train & validate scripts
2 parents 6f80214 + 1969528

3 files changed: +55, -26 lines

timm/data/loader.py (+13, -13)
@@ -77,18 +77,18 @@ class PrefetchLoader:
 
     def __init__(
             self,
-            loader,
-            mean=IMAGENET_DEFAULT_MEAN,
-            std=IMAGENET_DEFAULT_STD,
-            channels=3,
-            device=torch.device('cuda'),
-            img_dtype=torch.float32,
-            fp16=False,
-            re_prob=0.,
-            re_mode='const',
-            re_count=1,
-            re_num_splits=0):
+            loader: torch.utils.data.DataLoader,
+            mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
+            std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
+            channels: int = 3,
+            device: torch.device = torch.device('cuda'),
+            img_dtype: Optional[torch.dtype] = None,
+            fp16: bool = False,
+            re_prob: float = 0.,
+            re_mode: str = 'const',
+            re_count: int = 1,
+            re_num_splits: int = 0,
+    ):
         mean = adapt_to_chs(mean, channels)
         std = adapt_to_chs(std, channels)
         normalization_shape = (1, channels, 1, 1)

@@ -98,7 +98,7 @@ def __init__(
         if fp16:
             # fp16 arg is deprecated, but will override dtype arg if set for bwd compat
             img_dtype = torch.float16
-        self.img_dtype = img_dtype
+        self.img_dtype = img_dtype or torch.float32
         self.mean = torch.tensor(
             [x * 255 for x in mean], device=device, dtype=img_dtype).view(normalization_shape)
         self.std = torch.tensor(
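The switch to img_dtype: Optional[torch.dtype] = None lets train.py and validate.py pass their model dtype straight through, while an unset value still resolves to float32 inside the loader. A minimal sketch of that resolution logic, assuming only PyTorch (the resolve_img_dtype helper is illustrative, not part of timm):

from typing import Optional

import torch


def resolve_img_dtype(img_dtype: Optional[torch.dtype], fp16: bool = False) -> torch.dtype:
    # Mirrors the PrefetchLoader logic above: the deprecated fp16 flag wins,
    # then an unset (None) img_dtype falls back to float32.
    if fp16:
        img_dtype = torch.float16
    return img_dtype or torch.float32


assert resolve_img_dtype(None) is torch.float32            # default stays float32
assert resolve_img_dtype(torch.bfloat16) is torch.bfloat16  # model dtype passed through
assert resolve_img_dtype(torch.bfloat16, fp16=True) is torch.float16  # legacy flag overrides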

train.py (+28, -8)
@@ -178,6 +178,8 @@
                    help='lower precision AMP dtype (default: float16)')
 group.add_argument('--amp-impl', default='native', type=str,
                    help='AMP impl to use, "native" or "apex" (default: native)')
+group.add_argument('--model-dtype', default=None, type=str,
+                   help='Model dtype override (non-AMP) (default: float32)')
 group.add_argument('--no-ddp-bb', action='store_true', default=False,
                    help='Force broadcast buffers for native DDP to off.')
 group.add_argument('--synchronize-step', action='store_true', default=False,

@@ -434,10 +436,18 @@ def main():
         _logger.info(f'Training with a single process on 1 device ({args.device}).')
     assert args.rank >= 0

+    model_dtype = None
+    if args.model_dtype:
+        assert args.model_dtype in ('float32', 'float16', 'bfloat16')
+        model_dtype = getattr(torch, args.model_dtype)
+        if model_dtype == torch.float16:
+            _logger.warning('float16 is not recommended for training, for half precision bfloat16 is recommended.')
+
     # resolve AMP arguments based on PyTorch / Apex availability
     use_amp = None
     amp_dtype = torch.float16
     if args.amp:
+        assert model_dtype is None or model_dtype == torch.float32, 'float32 model dtype must be used with AMP'
         if args.amp_impl == 'apex':
             assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
             use_amp = 'apex'
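The added assert makes --model-dtype and AMP mutually exclusive: AMP keeps float32 weights and autocasts selected ops, while --model-dtype casts the whole model (and, via img_dtype, the input pipeline) to one dtype. A rough sketch of the difference on a toy module, assuming a CPU-only setup rather than an actual timm model:

import torch
import torch.nn as nn

x = torch.randn(2, 16)

# --model-dtype bfloat16: parameters and inputs all live in bfloat16, no autocast involved.
model_bf16 = nn.Linear(16, 4).to(dtype=torch.bfloat16)
out = model_bf16(x.to(dtype=torch.bfloat16))
assert out.dtype is torch.bfloat16

# AMP: the model stays float32; autocast runs selected ops in the lower dtype.
model_fp32 = nn.Linear(16, 4)
with torch.autocast(device_type='cpu', dtype=torch.bfloat16):
    out = model_fp32(x)
assert model_fp32.weight.dtype is torch.float32
print(out.dtype)  # typically torch.bfloat16: linear layers run in the autocast dtype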
@@ -517,7 +527,7 @@ def main():
         model = convert_splitbn_model(model, max(num_aug_splits, 2))

     # move model to GPU, enable channels last layout if set
-    model.to(device=device)
+    model.to(device=device, dtype=model_dtype)  # FIXME move model device & dtype into create_model
     if args.channels_last:
         model.to(memory_format=torch.channels_last)

@@ -587,7 +597,7 @@ def main():
             _logger.info('Using native Torch AMP. Training in mixed precision.')
     else:
         if utils.is_primary(args):
-            _logger.info('AMP not enabled. Training in float32.')
+            _logger.info(f'AMP not enabled. Training in {model_dtype or torch.float32}.')

     # optionally resume from a checkpoint
     resume_epoch = None

@@ -732,6 +742,7 @@ def main():
         distributed=args.distributed,
         collate_fn=collate_fn,
         pin_memory=args.pin_mem,
+        img_dtype=model_dtype,
         device=device,
         use_prefetcher=args.prefetcher,
         use_multi_epochs_loader=args.use_multi_epochs_loader,

@@ -756,6 +767,7 @@ def main():
         distributed=args.distributed,
         crop_pct=data_config['crop_pct'],
         pin_memory=args.pin_mem,
+        img_dtype=model_dtype,
         device=device,
         use_prefetcher=args.prefetcher,
     )

@@ -823,9 +835,13 @@ def main():
     if utils.is_primary(args) and args.log_wandb:
         if has_wandb:
             assert not args.wandb_resume_id or args.resume
-            wandb.init(project=args.experiment, config=args, tags=args.wandb_tags,
-                       resume='must' if args.wandb_resume_id else None,
-                       id=args.wandb_resume_id if args.wandb_resume_id else None)
+            wandb.init(
+                project=args.experiment,
+                config=args,
+                tags=args.wandb_tags,
+                resume='must' if args.wandb_resume_id else None,
+                id=args.wandb_resume_id if args.wandb_resume_id else None,
+            )
         else:
             _logger.warning(
                 "You've requested to log metrics to wandb but package not found. "

@@ -879,6 +895,7 @@ def main():
                 output_dir=output_dir,
                 amp_autocast=amp_autocast,
                 loss_scaler=loss_scaler,
+                model_dtype=model_dtype,
                 model_ema=model_ema,
                 mixup_fn=mixup_fn,
                 num_updates_total=num_epochs * updates_per_epoch,

@@ -897,6 +914,7 @@ def main():
                 args,
                 device=device,
                 amp_autocast=amp_autocast,
+                model_dtype=model_dtype,
             )

             if model_ema is not None and not args.model_ema_force_cpu:

@@ -979,6 +997,7 @@ def train_one_epoch(
         output_dir=None,
         amp_autocast=suppress,
         loss_scaler=None,
+        model_dtype=None,
         model_ema=None,
         mixup_fn=None,
         num_updates_total=None,

@@ -1015,7 +1034,7 @@ def train_one_epoch(
             accum_steps = last_accum_steps

         if not args.prefetcher:
-            input, target = input.to(device), target.to(device)
+            input, target = input.to(device=device, dtype=model_dtype), target.to(device=device)
             if mixup_fn is not None:
                 input, target = mixup_fn(input, target)
         if args.channels_last:
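Note that Tensor.to(device=..., dtype=None) leaves the dtype unchanged, so the non-prefetcher path above is a no-op cast (identical to the old behaviour) whenever --model-dtype is not set. A quick illustration:

import torch

x = torch.randn(2, 3)                         # float32, as produced by a normal DataLoader

y = x.to(device='cpu', dtype=None)            # model_dtype unset: dtype is left as float32
assert y.dtype is torch.float32

z = x.to(device='cpu', dtype=torch.bfloat16)  # model_dtype resolved from --model-dtype bfloat16
assert z.dtype is torch.bfloat16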
@@ -1142,6 +1161,7 @@ def validate(
         args,
         device=torch.device('cuda'),
         amp_autocast=suppress,
+        model_dtype=None,
         log_suffix=''
 ):
     batch_time_m = utils.AverageMeter()

@@ -1157,8 +1177,8 @@ def validate(
         for batch_idx, (input, target) in enumerate(loader):
             last_batch = batch_idx == last_idx
             if not args.prefetcher:
-                input = input.to(device)
-                target = target.to(device)
+                input = input.to(device=device, dtype=model_dtype)
+                target = target.to(device=device)
             if args.channels_last:
                 input = input.contiguous(memory_format=torch.channels_last)

validate.py (+14, -5)
@@ -123,6 +123,8 @@
                     help='lower precision AMP dtype (default: float16)')
 parser.add_argument('--amp-impl', default='native', type=str,
                     help='AMP impl to use, "native" or "apex" (default: native)')
+parser.add_argument('--model-dtype', default=None, type=str,
+                    help='Model dtype override (non-AMP) (default: float32)')
 parser.add_argument('--tf-preprocessing', action='store_true', default=False,
                     help='Use Tensorflow preprocessing pipeline (require CPU TF installed')
 parser.add_argument('--use-ema', dest='use_ema', action='store_true',

@@ -168,10 +170,16 @@ def validate(args):

     device = torch.device(args.device)

+    model_dtype = None
+    if args.model_dtype:
+        assert args.model_dtype in ('float32', 'float16', 'bfloat16')
+        model_dtype = getattr(torch, args.model_dtype)
+
     # resolve AMP arguments based on PyTorch / Apex availability
     use_amp = None
     amp_autocast = suppress
     if args.amp:
+        assert model_dtype is None or model_dtype == torch.float32, 'float32 model dtype must be used with AMP'
         if args.amp_impl == 'apex':
             assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
             assert args.amp_dtype == 'float16'

@@ -184,7 +192,7 @@ def validate(args):
         amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
         _logger.info('Validating in mixed precision with native PyTorch AMP.')
     else:
-        _logger.info('Validating in float32. AMP not enabled.')
+        _logger.info(f'Validating in {model_dtype or torch.float32}. AMP not enabled.')

     if args.fuser:
         set_jit_fuser(args.fuser)

@@ -231,7 +239,7 @@ def validate(args):
     if args.test_pool:
         model, test_time_pool = apply_test_time_pool(model, data_config)

-    model = model.to(device)
+    model = model.to(device=device, dtype=model_dtype)  # FIXME move model device & dtype into create_model
     if args.channels_last:
         model = model.to(memory_format=torch.channels_last)

@@ -299,6 +307,7 @@ def validate(args):
         crop_border_pixels=args.crop_border_pixels,
         pin_memory=args.pin_mem,
         device=device,
+        img_dtype=model_dtype,
         tf_preprocessing=args.tf_preprocessing,
     )

@@ -310,7 +319,7 @@ def validate(args):
     model.eval()
     with torch.no_grad():
         # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
-        input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).to(device)
+        input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).to(device=device, dtype=model_dtype)
         if args.channels_last:
             input = input.contiguous(memory_format=torch.channels_last)
         with amp_autocast():

@@ -319,8 +328,8 @@ def validate(args):
         end = time.time()
         for batch_idx, (input, target) in enumerate(loader):
             if args.no_prefetcher:
-                target = target.to(device)
-                input = input.to(device)
+                target = target.to(device=device)
+                input = input.to(device=device, dtype=model_dtype)
             if args.channels_last:
                 input = input.contiguous(memory_format=torch.channels_last)
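Taken together, the validate.py changes amount to casting the model once and casting each batch to the same dtype, with no autocast context involved. A self-contained sketch with a toy convolutional model and random batches standing in for a timm model and the real loader (CPU bfloat16 support is assumed):

import torch
import torch.nn as nn

model_dtype = torch.bfloat16  # what --model-dtype bfloat16 resolves to
model = nn.Sequential(
    nn.Conv2d(3, 8, 3), nn.ReLU(), nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10),
)
model = model.to(device='cpu', dtype=model_dtype)
model.eval()

with torch.no_grad():
    for _ in range(2):  # stand-in for iterating the real validation loader
        input = torch.randn(4, 3, 32, 32).to(device='cpu', dtype=model_dtype)
        output = model(input)
        assert output.dtype is model_dtype  # accuracy metrics are computed from these logits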
