Grad accumulation - testing
odulcy-mindee committed Jan 16, 2025
1 parent 8f5ccf9 commit 934730e
Showing 1 changed file with 25 additions and 16 deletions.
41 changes: 25 additions & 16 deletions references/detection/train_pytorch.py
@@ -119,7 +119,7 @@ def record_lr(
     return lr_recorder[: len(loss_recorder)], loss_recorder
 
 
-def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, amp=False, clearml_log=False):
+def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, grad_accumulation_steps, amp=False, clearml_log=False):
     if amp:
         scaler = torch.cuda.amp.GradScaler()
 
@@ -131,34 +131,38 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, a
 
     # Iterate over the batches of the dataset
     pbar = tqdm(train_loader, position=1)
-    for images, targets in pbar:
+    for step, (images, targets) in enumerate(pbar):
         if torch.cuda.is_available():
             images = images.cuda()
         images = batch_transforms(images)
 
-        optimizer.zero_grad()
         if amp:
             with torch.cuda.amp.autocast():
-                train_loss = model(images, targets)["loss"]
+                train_loss = model(images, targets)["loss"] / grad_accumulation_steps
             scaler.scale(train_loss).backward()
             # Gradient clipping
             scaler.unscale_(optimizer)
             torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
-            # Update the params
-            scaler.step(optimizer)
-            scaler.update()
         else:
-            train_loss = model(images, targets)["loss"]
+            train_loss = model(images, targets)["loss"] / grad_accumulation_steps
            train_loss.backward()
             torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
-            optimizer.step()
 
-        scheduler.step()
-        pbar.set_description(f"Training loss: {train_loss.item():.6}")
+        if (step + 1) % grad_accumulation_steps == 0 or step + 1 == len(train_loader):
+            if amp:
+                scaler.step(optimizer)
+                scaler.update()
+            else:
+                optimizer.step()
+
+            optimizer.zero_grad()
+            scheduler.step()
+
+        pbar.set_description(f"Training loss: {train_loss.item() * grad_accumulation_steps:.6f}")
         if clearml_log:
             global iteration
             logger.report_scalar(
-                title="Training Loss", series="train_loss", value=train_loss.item(), iteration=iteration
+                title="Training Loss", series="train_loss", value=train_loss.item() * grad_accumulation_steps, iteration=iteration
             )
             iteration += 1
     send_on_slack(f"Final training loss: {train_loss.item():.6}")
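
The hunk above divides each batch loss by grad_accumulation_steps, calls backward() every batch, and only steps the optimizer and scheduler every grad_accumulation_steps batches (or on the last batch of the epoch); the displayed loss is multiplied back so the progress bar still reports the unscaled per-batch value. Below is a minimal, self-contained sketch of that pattern; the toy model, data, and hyperparameters are illustrative only and are not part of this commit.

# Minimal sketch of the accumulation pattern used in the hunk above.
# The linear model and random batches are stand-ins, not repository code.
import torch
from torch import nn

grad_accumulation_steps = 4
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# Hypothetical stand-in for train_loader: 10 batches of (inputs, targets)
batches = [(torch.randn(2, 10), torch.randn(2, 1)) for _ in range(10)]

optimizer.zero_grad()
for step, (x, y) in enumerate(batches):
    # Scale the loss so the accumulated gradient matches a single large-batch update
    loss = nn.functional.mse_loss(model(x), y) / grad_accumulation_steps
    loss.backward()
    # Step every N batches, and once more on the final (possibly partial) group
    if (step + 1) % grad_accumulation_steps == 0 or step + 1 == len(batches):
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
        optimizer.step()
        optimizer.zero_grad()
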
@@ -471,12 +475,16 @@ def main(args):
         return
 
     # Scheduler
+    # Effective steps per epoch (due to grad accumulation)
+    grad_steps = args.grad_accumulation
+    effective_steps_per_epoch = len(train_loader) // grad_steps
+    total_steps = args.epochs * effective_steps_per_epoch
     if args.sched == "cosine":
-        scheduler = CosineAnnealingLR(optimizer, args.epochs * len(train_loader), eta_min=args.lr / 25e4)
+        scheduler = CosineAnnealingLR(optimizer, total_steps, eta_min=args.lr / 25e4)
     elif args.sched == "onecycle":
-        scheduler = OneCycleLR(optimizer, args.lr, args.epochs * len(train_loader))
+        scheduler = OneCycleLR(optimizer, args.lr, total_steps)
     elif args.sched == "poly":
-        scheduler = PolynomialLR(optimizer, args.epochs * len(train_loader))
+        scheduler = PolynomialLR(optimizer, total_steps)
 
     # Training monitoring
     current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
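
Since the scheduler now steps once per optimizer update rather than once per batch, its horizon shrinks accordingly. A quick sanity check of the step arithmetic above, with made-up numbers (the loader length and epoch count are illustrative, not taken from the repository):

# Illustrative numbers only, not taken from the repository
epochs, loader_len, grad_steps = 10, 1000, 4
effective_steps_per_epoch = loader_len // grad_steps  # 250 optimizer updates per epoch
total_steps = epochs * effective_steps_per_epoch      # 2500 scheduler steps instead of 10000
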
@@ -525,7 +533,7 @@ def main(args):
     # Training loop
     for epoch in range(args.epochs):
         fit_one_epoch(
-            model, train_loader, batch_transforms, optimizer, scheduler, amp=args.amp, clearml_log=args.clearml
+            model, train_loader, batch_transforms, optimizer, scheduler, grad_steps, amp=args.amp, clearml_log=args.clearml
         )
         # Validation loop at the end of each epoch
         val_loss, recall, precision, mean_iou = evaluate(model, val_loader, batch_transforms, val_metric, amp=args.amp)
@@ -606,6 +614,7 @@ def parse_args():
     parser.add_argument("--name", type=str, default=None, help="Name of your training experiment")
     parser.add_argument("--epochs", type=int, default=10, help="number of epochs to train the model on")
     parser.add_argument("-b", "--batch_size", type=int, default=2, help="batch size for training")
+    parser.add_argument("--grad_accumulation", type=int, default=1, help="gradient accumulation steps")
     parser.add_argument("--device", default=None, type=int, help="device")
     parser.add_argument(
         "--save-interval-epoch", dest="save_interval_epoch", action="store_true", help="Save model every epoch"
