From 934730ebf789757f0dcbefa374134d3d99e7ab30 Mon Sep 17 00:00:00 2001
From: Olivier Dulcy
Date: Thu, 16 Jan 2025 18:34:12 +0100
Subject: [PATCH] Grad accumulation - testing

---
 references/detection/train_pytorch.py | 41 ++++++++++++++++-----------
 1 file changed, 25 insertions(+), 16 deletions(-)

diff --git a/references/detection/train_pytorch.py b/references/detection/train_pytorch.py
index d6af3f69a..80b05d427 100644
--- a/references/detection/train_pytorch.py
+++ b/references/detection/train_pytorch.py
@@ -119,7 +119,7 @@ def record_lr(
     return lr_recorder[: len(loss_recorder)], loss_recorder
 
 
-def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, amp=False, clearml_log=False):
+def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, grad_accumulation_steps, amp=False, clearml_log=False):
     if amp:
         scaler = torch.cuda.amp.GradScaler()
 
@@ -131,34 +131,38 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, a
 
     # Iterate over the batches of the dataset
     pbar = tqdm(train_loader, position=1)
-    for images, targets in pbar:
+    for step, (images, targets) in enumerate(pbar):
         if torch.cuda.is_available():
             images = images.cuda()
         images = batch_transforms(images)
 
-        optimizer.zero_grad()
         if amp:
             with torch.cuda.amp.autocast():
-                train_loss = model(images, targets)["loss"]
+                train_loss = model(images, targets)["loss"] / grad_accumulation_steps
             scaler.scale(train_loss).backward()
             # Gradient clipping
             scaler.unscale_(optimizer)
             torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
-            # Update the params
-            scaler.step(optimizer)
-            scaler.update()
         else:
-            train_loss = model(images, targets)["loss"]
+            train_loss = model(images, targets)["loss"] / grad_accumulation_steps
            train_loss.backward()
             torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
-            optimizer.step()
-        scheduler.step()
 
-        pbar.set_description(f"Training loss: {train_loss.item():.6}")
+        if (step + 1) % grad_accumulation_steps == 0 or step + 1 == len(train_loader):
+            if amp:
+                scaler.step(optimizer)
+                scaler.update()
+            else:
+                optimizer.step()
+
+            optimizer.zero_grad()
+            scheduler.step()
+
+        pbar.set_description(f"Training loss: {train_loss.item() * grad_accumulation_steps:.6f}")
         if clearml_log:
             global iteration
             logger.report_scalar(
-                title="Training Loss", series="train_loss", value=train_loss.item(), iteration=iteration
+                title="Training Loss", series="train_loss", value=train_loss.item() * grad_accumulation_steps, iteration=iteration
             )
             iteration += 1
     send_on_slack(f"Final training loss: {train_loss.item():.6}")
@@ -471,12 +475,16 @@ def main(args):
         return
 
     # Scheduler
+    # Effective steps per epoch (due to grad accumulation)
+    grad_steps = args.grad_accumulation
+    effective_steps_per_epoch = len(train_loader) // grad_steps
+    total_steps = args.epochs * effective_steps_per_epoch
     if args.sched == "cosine":
-        scheduler = CosineAnnealingLR(optimizer, args.epochs * len(train_loader), eta_min=args.lr / 25e4)
+        scheduler = CosineAnnealingLR(optimizer, total_steps, eta_min=args.lr / 25e4)
     elif args.sched == "onecycle":
-        scheduler = OneCycleLR(optimizer, args.lr, args.epochs * len(train_loader))
+        scheduler = OneCycleLR(optimizer, args.lr, total_steps)
     elif args.sched == "poly":
-        scheduler = PolynomialLR(optimizer, args.epochs * len(train_loader))
+        scheduler = PolynomialLR(optimizer, total_steps)
 
     # Training monitoring
     current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
@@ -525,7 +533,7 @@ def main(args):
     # Training loop
     for epoch in range(args.epochs):
         fit_one_epoch(
-            model, train_loader, batch_transforms, optimizer, scheduler, amp=args.amp, clearml_log=args.clearml
+            model, train_loader, batch_transforms, optimizer, scheduler, grad_steps, amp=args.amp, clearml_log=args.clearml
         )
         # Validation loop at the end of each epoch
         val_loss, recall, precision, mean_iou = evaluate(model, val_loader, batch_transforms, val_metric, amp=args.amp)
@@ -606,6 +614,7 @@ def parse_args():
     parser.add_argument("--name", type=str, default=None, help="Name of your training experiment")
     parser.add_argument("--epochs", type=int, default=10, help="number of epochs to train the model on")
     parser.add_argument("-b", "--batch_size", type=int, default=2, help="batch size for training")
+    parser.add_argument("--grad_accumulation", type=int, default=1, help="gradient accumulation steps")
     parser.add_argument("--device", default=None, type=int, help="device")
     parser.add_argument(
         "--save-interval-epoch", dest="save_interval_epoch", action="store_true", help="Save model every epoch"
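
Note on the patch: the per-batch loss is divided by grad_accumulation_steps, so the gradients accumulated over one window approximate those of a single batch of size batch_size * grad_accumulation_steps; the progress bar and the ClearML scalar multiply the value back for display. optimizer.step(), optimizer.zero_grad() and scheduler.step() now run only once every grad_accumulation_steps micro-batches (or on the last batch of the epoch), which is why the scheduler length is computed from effective_steps_per_epoch instead of len(train_loader). For reference, here is a minimal, self-contained sketch of the same update pattern with a dummy model and data, and no AMP, gradient clipping or logging (every name in it is illustrative, not part of the doctr script):

import torch
from torch import nn

# Dummy stand-ins for the real model, optimizer, scheduler and dataloader
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=3)
loader = [(torch.randn(2, 4), torch.randn(2, 1)) for _ in range(10)]

grad_accumulation_steps = 4

optimizer.zero_grad()
for step, (inputs, targets) in enumerate(loader):
    # Scale the loss so the summed gradients match one large batch
    loss = nn.functional.mse_loss(model(inputs), targets) / grad_accumulation_steps
    loss.backward()

    # Update only every grad_accumulation_steps micro-batches, and on the last one
    if (step + 1) % grad_accumulation_steps == 0 or step + 1 == len(loader):
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()

With the new flag, -b 2 --grad_accumulation 4 behaves roughly like a batch size of 8 as far as gradients are concerned, at the cost of grad_accumulation times fewer optimizer and scheduler steps per epoch.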