diff --git a/references/detection/train_pytorch.py b/references/detection/train_pytorch.py
index d43f610ab..4cde7df4b 100644
--- a/references/detection/train_pytorch.py
+++ b/references/detection/train_pytorch.py
@@ -267,72 +267,72 @@ def main(args):
 
     batch_transforms = Normalize(mean=(0.798, 0.785, 0.772), std=(0.264, 0.2749, 0.287))
 
-    funsd_ds = DetectionDataset(
-        img_folder=os.path.join(args.funsd_path, "images"),
-        label_path=os.path.join(args.funsd_path, "labels.json"),
-        sample_transforms=T.SampleCompose(
-            (
-                [T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True)]
-                if not args.rotation or args.eval_straight
-                else []
-            )
-            + (
-                [
-                    T.Resize(args.input_size, preserve_aspect_ratio=True),  # This does not pad
-                    T.RandomApply(T.RandomRotate(90, expand=True), 0.5),
-                    T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True),
-                ]
-                if args.rotation and not args.eval_straight
-                else []
-            )
-        ),
-        use_polygons=args.rotation and not args.eval_straight,
-    )
-
-    funsd_test_loader = DataLoader(
-        funsd_ds,
-        batch_size=args.batch_size,
-        drop_last=False,
-        num_workers=args.workers,
-        sampler=SequentialSampler(funsd_ds),
-        pin_memory=torch.cuda.is_available(),
-        collate_fn=funsd_ds.collate_fn,
-    )
-    print(f"FUNSD Test set loaded in {time.time() - st:.4}s ({len(funsd_ds)} samples in " f"{len(funsd_test_loader)} batches)")
-
-
-    cord_ds = DetectionDataset(
-        img_folder=os.path.join(args.cord_path, "images"),
-        label_path=os.path.join(args.cord_path, "labels.json"),
-        sample_transforms=T.SampleCompose(
-            (
-                [T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True)]
-                if not args.rotation or args.eval_straight
-                else []
-            )
-            + (
-                [
-                    T.Resize(args.input_size, preserve_aspect_ratio=True),  # This does not pad
-                    T.RandomApply(T.RandomRotate(90, expand=True), 0.5),
-                    T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True),
-                ]
-                if args.rotation and not args.eval_straight
-                else []
-            )
-        ),
-        use_polygons=args.rotation and not args.eval_straight,
-    )
-
-    cord_test_loader = DataLoader(
-        cord_ds,
-        batch_size=args.batch_size,
-        drop_last=False,
-        num_workers=args.workers,
-        sampler=SequentialSampler(cord_ds),
-        pin_memory=torch.cuda.is_available(),
-        collate_fn=cord_ds.collate_fn,
-    )
-    print(f"CORD Test set loaded in {time.time() - st:.4}s ({len(cord_ds)} samples in " f"{len(funsd_test_loader)} batches)")
+    #funsd_ds = DetectionDataset(
+    #    img_folder=os.path.join(args.funsd_path, "images"),
+    #    label_path=os.path.join(args.funsd_path, "labels.json"),
+    #    sample_transforms=T.SampleCompose(
+    #        (
+    #            [T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True)]
+    #            if not args.rotation or args.eval_straight
+    #            else []
+    #        )
+    #        + (
+    #            [
+    #                T.Resize(args.input_size, preserve_aspect_ratio=True),  # This does not pad
+    #                T.RandomApply(T.RandomRotate(90, expand=True), 0.5),
+    #                T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True),
+    #            ]
+    #            if args.rotation and not args.eval_straight
+    #            else []
+    #        )
+    #    ),
+    #    use_polygons=args.rotation and not args.eval_straight,
+    #)
+
+    #funsd_test_loader = DataLoader(
+    #    funsd_ds,
+    #    batch_size=args.batch_size,
+    #    drop_last=False,
+    #    num_workers=args.workers,
+    #    sampler=SequentialSampler(funsd_ds),
+    #    pin_memory=torch.cuda.is_available(),
+    #    collate_fn=funsd_ds.collate_fn,
+    #)
+    #print(f"FUNSD Test set loaded in {time.time() - st:.4}s ({len(funsd_ds)} samples in " f"{len(funsd_test_loader)} batches)")
+
+
+    #cord_ds = DetectionDataset(
+    #    img_folder=os.path.join(args.cord_path, "images"),
+    #    label_path=os.path.join(args.cord_path, "labels.json"),
+    #    sample_transforms=T.SampleCompose(
+    #        (
+    #            [T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True)]
+    #            if not args.rotation or args.eval_straight
+    #            else []
+    #        )
+    #        + (
+    #            [
+    #                T.Resize(args.input_size, preserve_aspect_ratio=True),  # This does not pad
+    #                T.RandomApply(T.RandomRotate(90, expand=True), 0.5),
+    #                T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True),
+    #            ]
+    #            if args.rotation and not args.eval_straight
+    #            else []
+    #        )
+    #    ),
+    #    use_polygons=args.rotation and not args.eval_straight,
+    #)
+
+    #cord_test_loader = DataLoader(
+    #    cord_ds,
+    #    batch_size=args.batch_size,
+    #    drop_last=False,
+    #    num_workers=args.workers,
+    #    sampler=SequentialSampler(cord_ds),
+    #    pin_memory=torch.cuda.is_available(),
+    #    collate_fn=cord_ds.collate_fn,
+    #)
+    #print(f"CORD Test set loaded in {time.time() - st:.4}s ({len(cord_ds)} samples in " f"{len(funsd_test_loader)} batches)")
 
     # Load doctr model
     model = detection.__dict__[args.arch](
@@ -369,16 +369,16 @@ def main(args):
         mask_shape=(args.input_size, args.input_size),
         use_broadcasting=True if system_available_memory > 62 else False,
     )
-    funsd_val_metric = LocalizationConfusion(
-        use_polygons=args.rotation and not args.eval_straight,
-        mask_shape=(args.input_size, args.input_size),
-        use_broadcasting=True if system_available_memory > 62 else False,
-    )
-    cord_val_metric = LocalizationConfusion(
-        use_polygons=args.rotation and not args.eval_straight,
-        mask_shape=(args.input_size, args.input_size),
-        use_broadcasting=True if system_available_memory > 62 else False,
-    )
+    #funsd_val_metric = LocalizationConfusion(
+    #    use_polygons=args.rotation and not args.eval_straight,
+    #    mask_shape=(args.input_size, args.input_size),
+    #    use_broadcasting=True if system_available_memory > 62 else False,
+    #)
+    #cord_val_metric = LocalizationConfusion(
+    #    use_polygons=args.rotation and not args.eval_straight,
+    #    mask_shape=(args.input_size, args.input_size),
+    #    use_broadcasting=True if system_available_memory > 62 else False,
+    #)
 
     if args.test_only:
         print("Running evaluation")
@@ -510,18 +510,18 @@ def main(args):
         val_loss, recall, precision, mean_iou = evaluate(model, val_loader, batch_transforms, val_metric, amp=args.amp)
         funsd_recall, funsd_precision, funsd_mean_iou = 0.0, 0.0, 0.0
         cord_recall, cord_precision, cord_mean_iou = 0.0, 0.0, 0.0
-        try:
-            _, funsd_recall, funsd_precision, funsd_mean_iou = evaluate(
-                model, funsd_test_loader, batch_transforms, funsd_val_metric, amp=args.amp
-            )
-        except Exception:
-            pass
-        try:
-            _, cord_recall, cord_precision, cord_mean_iou = evaluate(
-                model, cord_test_loader, batch_transforms, cord_val_metric, amp=args.amp
-            )
-        except Exception:
-            pass
+        #try:
+        #    _, funsd_recall, funsd_precision, funsd_mean_iou = evaluate(
+        #        model, funsd_test_loader, batch_transforms, funsd_val_metric, amp=args.amp
+        #    )
+        #except Exception:
+        #    pass
+        #try:
+        #    _, cord_recall, cord_precision, cord_mean_iou = evaluate(
+        #        model, cord_test_loader, batch_transforms, cord_val_metric, amp=args.amp
+        #    )
+        #except Exception:
+        #    pass
         if val_loss < min_loss:
             print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...")
             send_on_slack(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...")
@@ -569,8 +569,6 @@ def parse_args():
 
     parser.add_argument("train_path", type=str, help="path to training data folder")
     parser.add_argument("val_path", type=str, help="path to validation data folder")
-    parser.add_argument("funsd_path", type=str, help="path to FUNSD data folder")
-    parser.add_argument("cord_path", type=str, help="path to Cord data folder")
     parser.add_argument("arch", type=str, help="text-detection model to train")
     parser.add_argument("--name", type=str, default=None, help="Name of your training experiment")
    parser.add_argument("--epochs", type=int, default=10, help="number of epochs to train the model on")