diff --git a/references/detection/train_pytorch.py b/references/detection/train_pytorch.py
index 2add7873c..d43f610ab 100644
--- a/references/detection/train_pytorch.py
+++ b/references/detection/train_pytorch.py
@@ -213,46 +213,6 @@ def evaluate(model, val_loader, batch_transforms, val_metric, amp=False):
     recall, precision, mean_iou = val_metric.summary()
     return val_loss, recall, precision, mean_iou
 
-@torch.no_grad()
-def sec_evaluate(model, val_loader, batch_transforms, val_metric, amp=False):
-    # Model in eval mode
-    model.eval()
-    # Reset val metric
-    val_metric.reset()
-    last_progress = 0
-    interval_progress = 5
-    pbar = tqdm(val_loader)
-    send_on_slack(str(pbar))
-    # Validation loop
-    val_loss, batch_cnt = 0, 0
-    for images, targets in pbar:
-        if torch.cuda.is_available():
-            images = images.cuda()
-        images = batch_transforms(images)
-        targets = [{CLASS_NAME: t["boxes"]} for t in targets]
-        if amp:
-            with torch.cuda.amp.autocast():
-                out = model(images, targets, return_preds=True)
-        else:
-            out = model(images, targets, return_preds=True)
-        # Compute metric
-        loc_preds = out["preds"]
-        for target, loc_pred in zip(targets, loc_preds):
-            for boxes_gt, boxes_pred in zip(target.values(), loc_pred.values()):
-                # Remove scores
-                val_metric.update(gts=boxes_gt, preds=boxes_pred[:, :-1])
-
-        current_progress = pbar.n / pbar.total * 100
-        if current_progress - last_progress > interval_progress:
-            send_on_slack(str(pbar))
-            last_progress = int(current_progress)
-        val_loss += out["loss"].item()
-        batch_cnt += 1
-
-    val_loss /= batch_cnt
-    recall, precision, mean_iou = val_metric.summary()
-    return val_loss, recall, precision, mean_iou
-
 
 def main(args):
     print(args)
@@ -307,24 +267,27 @@ def main(args):
 
     batch_transforms = Normalize(mean=(0.798, 0.785, 0.772), std=(0.264, 0.2749, 0.287))
 
-    funsd_ds = datasets.FUNSD(
-        train=True,
-        download=True,
-        use_polygons=args.rotation,
-        sample_transforms=T.Resize((args.input_size, args.input_size)),
-    )
-    # Monkeypatch
-    subfolder = funsd_ds.root.split("/")[-2:]
-    funsd_ds.root = str(Path(funsd_ds.root).parent.parent)
-    funsd_ds.data = [(os.path.join(*subfolder, name), target) for name, target in funsd_ds.data]
-    _funsd_ds = datasets.FUNSD(
-        train=False,
-        download=True,
-        use_polygons=args.rotation,
-        sample_transforms=T.Resize((args.input_size, args.input_size)),
+    funsd_ds = DetectionDataset(
+        img_folder=os.path.join(args.funsd_path, "images"),
+        label_path=os.path.join(args.funsd_path, "labels.json"),
+        sample_transforms=T.SampleCompose(
+            (
+                [T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True)]
+                if not args.rotation or args.eval_straight
+                else []
+            )
+            + (
+                [
+                    T.Resize(args.input_size, preserve_aspect_ratio=True),  # This does not pad
+                    T.RandomApply(T.RandomRotate(90, expand=True), 0.5),
+                    T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True),
+                ]
+                if args.rotation and not args.eval_straight
+                else []
+            )
+        ),
+        use_polygons=args.rotation and not args.eval_straight,
     )
-    subfolder = _funsd_ds.root.split("/")[-2:]
-    funsd_ds.data.extend([(os.path.join(*subfolder, name), target) for name, target in _funsd_ds.data])
 
     funsd_test_loader = DataLoader(
         funsd_ds,
@@ -338,24 +301,27 @@ def main(args):
     print(f"FUNSD Test set loaded in {time.time() - st:.4}s ({len(funsd_ds)} samples in "
           f"{len(funsd_test_loader)} batches)")
 
-    cord_ds = datasets.CORD(
-        train=True,
-        download=True,
-        use_polygons=args.rotation,
-        sample_transforms=T.Resize((args.input_size, args.input_size)),
-    )
-    # Monkeypatch
-    subfolder = cord_ds.root.split("/")[-2:]
-    cord_ds.root = str(Path(cord_ds.root).parent.parent)
-    cord_ds.data = [(os.path.join(*subfolder, name), target) for name, target in cord_ds.data]
-    _cord_ds = datasets.CORD(
-        train=False,
-        download=True,
-        use_polygons=args.rotation,
-        sample_transforms=T.Resize((args.input_size, args.input_size)),
+    cord_ds = DetectionDataset(
+        img_folder=os.path.join(args.cord_path, "images"),
+        label_path=os.path.join(args.cord_path, "labels.json"),
+        sample_transforms=T.SampleCompose(
+            (
+                [T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True)]
+                if not args.rotation or args.eval_straight
+                else []
+            )
+            + (
+                [
+                    T.Resize(args.input_size, preserve_aspect_ratio=True),  # This does not pad
+                    T.RandomApply(T.RandomRotate(90, expand=True), 0.5),
+                    T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True),
+                ]
+                if args.rotation and not args.eval_straight
+                else []
+            )
+        ),
+        use_polygons=args.rotation and not args.eval_straight,
     )
-    subfolder = _cord_ds.root.split("/")[-2:]
-    cord_ds.data.extend([(os.path.join(*subfolder, name), target) for name, target in _cord_ds.data])
 
     cord_test_loader = DataLoader(
         cord_ds,
@@ -545,13 +511,13 @@ def main(args):
         funsd_recall, funsd_precision, funsd_mean_iou = 0.0, 0.0, 0.0
         cord_recall, cord_precision, cord_mean_iou = 0.0, 0.0, 0.0
         try:
-            _, funsd_recall, funsd_precision, funsd_mean_iou = sec_evaluate(
+            _, funsd_recall, funsd_precision, funsd_mean_iou = evaluate(
                 model, funsd_test_loader, batch_transforms, funsd_val_metric, amp=args.amp
             )
         except Exception:
             pass
         try:
-            _, cord_recall, cord_precision, cord_mean_iou = sec_evaluate(
+            _, cord_recall, cord_precision, cord_mean_iou = evaluate(
                 model, cord_test_loader, batch_transforms, cord_val_metric, amp=args.amp
            )
         except Exception:
@@ -603,6 +569,8 @@ def parse_args():
 
     parser.add_argument("train_path", type=str, help="path to training data folder")
     parser.add_argument("val_path", type=str, help="path to validation data folder")
+    parser.add_argument("funsd_path", type=str, help="path to FUNSD data folder")
+    parser.add_argument("cord_path", type=str, help="path to Cord data folder")
    parser.add_argument("arch", type=str, help="text-detection model to train")
     parser.add_argument("--name", type=str, default=None, help="Name of your training experiment")
     parser.add_argument("--epochs", type=int, default=10, help="number of epochs to train the model on")
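
Note: with this patch, FUNSD and CORD are no longer downloaded and monkeypatched at runtime; the duplicated sec_evaluate helper is dropped in favour of the existing evaluate, and both benchmark sets are expected as pre-converted DetectionDataset folders (an images/ directory plus a labels.json, as wired above), passed through the two new positional arguments funsd_path and cord_path (after train_path and val_path, before arch). Below is a minimal sketch for sanity-checking such a folder before launching training; the /data/FUNSD path and the 1024 px size are illustrative assumptions, not part of the patch.

import os

from doctr import transforms as T
from doctr.datasets import DetectionDataset

# Hypothetical location of a pre-converted dataset folder
funsd_root = "/data/FUNSD"

# Same layout as used in the training script: <root>/images/ + <root>/labels.json
ds = DetectionDataset(
    img_folder=os.path.join(funsd_root, "images"),
    label_path=os.path.join(funsd_root, "labels.json"),
    sample_transforms=T.Resize((1024, 1024), preserve_aspect_ratio=True, symmetric_pad=True),
)

# Load one sample to confirm images and labels resolve correctly
img, target = ds[0]
print(f"{len(ds)} samples; first sample image shape: {tuple(img.shape)}")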