apply patch from felixdittrich92@27bc838
odulcy-mindee committed Mar 26, 2024
1 parent cc795f6 commit 34e32eb
Showing 1 changed file with 44 additions and 76 deletions.
120 changes: 44 additions & 76 deletions references/detection/train_pytorch.py
@@ -213,46 +213,6 @@ def evaluate(model, val_loader, batch_transforms, val_metric, amp=False):
recall, precision, mean_iou = val_metric.summary()
return val_loss, recall, precision, mean_iou

-@torch.no_grad()
-def sec_evaluate(model, val_loader, batch_transforms, val_metric, amp=False):
-    # Model in eval mode
-    model.eval()
-    # Reset val metric
-    val_metric.reset()
-    last_progress = 0
-    interval_progress = 5
-    pbar = tqdm(val_loader)
-    send_on_slack(str(pbar))
-    # Validation loop
-    val_loss, batch_cnt = 0, 0
-    for images, targets in pbar:
-        if torch.cuda.is_available():
-            images = images.cuda()
-        images = batch_transforms(images)
-        targets = [{CLASS_NAME: t["boxes"]} for t in targets]
-        if amp:
-            with torch.cuda.amp.autocast():
-                out = model(images, targets, return_preds=True)
-        else:
-            out = model(images, targets, return_preds=True)
-        # Compute metric
-        loc_preds = out["preds"]
-        for target, loc_pred in zip(targets, loc_preds):
-            for boxes_gt, boxes_pred in zip(target.values(), loc_pred.values()):
-                # Remove scores
-                val_metric.update(gts=boxes_gt, preds=boxes_pred[:, :-1])
-
-        current_progress = pbar.n / pbar.total * 100
-        if current_progress - last_progress > interval_progress:
-            send_on_slack(str(pbar))
-            last_progress = int(current_progress)
-        val_loss += out["loss"].item()
-        batch_cnt += 1
-
-    val_loss /= batch_cnt
-    recall, precision, mean_iou = val_metric.summary()
-    return val_loss, recall, precision, mean_iou
-

def main(args):
print(args)
@@ -307,24 +267,27 @@ def main(args):

batch_transforms = Normalize(mean=(0.798, 0.785, 0.772), std=(0.264, 0.2749, 0.287))

-    funsd_ds = datasets.FUNSD(
-        train=True,
-        download=True,
-        use_polygons=args.rotation,
-        sample_transforms=T.Resize((args.input_size, args.input_size)),
-    )
-    # Monkeypatch
-    subfolder = funsd_ds.root.split("/")[-2:]
-    funsd_ds.root = str(Path(funsd_ds.root).parent.parent)
-    funsd_ds.data = [(os.path.join(*subfolder, name), target) for name, target in funsd_ds.data]
-    _funsd_ds = datasets.FUNSD(
-        train=False,
-        download=True,
-        use_polygons=args.rotation,
-        sample_transforms=T.Resize((args.input_size, args.input_size)),
+    funsd_ds = DetectionDataset(
+        img_folder=os.path.join(args.funsd_path, "images"),
+        label_path=os.path.join(args.funsd_path, "labels.json"),
+        sample_transforms=T.SampleCompose(
+            (
+                [T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True)]
+                if not args.rotation or args.eval_straight
+                else []
+            )
+            + (
+                [
+                    T.Resize(args.input_size, preserve_aspect_ratio=True),  # This does not pad
+                    T.RandomApply(T.RandomRotate(90, expand=True), 0.5),
+                    T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True),
+                ]
+                if args.rotation and not args.eval_straight
+                else []
+            )
+        ),
+        use_polygons=args.rotation and not args.eval_straight,
     )
-    subfolder = _funsd_ds.root.split("/")[-2:]
-    funsd_ds.data.extend([(os.path.join(*subfolder, name), target) for name, target in _funsd_ds.data])

funsd_test_loader = DataLoader(
funsd_ds,
@@ -338,24 +301,27 @@ def main(args):
print(f"FUNSD Test set loaded in {time.time() - st:.4}s ({len(funsd_ds)} samples in " f"{len(funsd_test_loader)} batches)")


-    cord_ds = datasets.CORD(
-        train=True,
-        download=True,
-        use_polygons=args.rotation,
-        sample_transforms=T.Resize((args.input_size, args.input_size)),
-    )
-    # Monkeypatch
-    subfolder = cord_ds.root.split("/")[-2:]
-    cord_ds.root = str(Path(cord_ds.root).parent.parent)
-    cord_ds.data = [(os.path.join(*subfolder, name), target) for name, target in cord_ds.data]
-    _cord_ds = datasets.CORD(
-        train=False,
-        download=True,
-        use_polygons=args.rotation,
-        sample_transforms=T.Resize((args.input_size, args.input_size)),
+    cord_ds = DetectionDataset(
+        img_folder=os.path.join(args.cord_path, "images"),
+        label_path=os.path.join(args.cord_path, "labels.json"),
+        sample_transforms=T.SampleCompose(
+            (
+                [T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True)]
+                if not args.rotation or args.eval_straight
+                else []
+            )
+            + (
+                [
+                    T.Resize(args.input_size, preserve_aspect_ratio=True),  # This does not pad
+                    T.RandomApply(T.RandomRotate(90, expand=True), 0.5),
+                    T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True),
+                ]
+                if args.rotation and not args.eval_straight
+                else []
+            )
+        ),
+        use_polygons=args.rotation and not args.eval_straight,
     )
-    subfolder = _cord_ds.root.split("/")[-2:]
-    cord_ds.data.extend([(os.path.join(*subfolder, name), target) for name, target in _cord_ds.data])

cord_test_loader = DataLoader(
cord_ds,
@@ -545,13 +511,13 @@ def main(args):
funsd_recall, funsd_precision, funsd_mean_iou = 0.0, 0.0, 0.0
cord_recall, cord_precision, cord_mean_iou = 0.0, 0.0, 0.0
try:
-_, funsd_recall, funsd_precision, funsd_mean_iou = sec_evaluate(
+_, funsd_recall, funsd_precision, funsd_mean_iou = evaluate(
model, funsd_test_loader, batch_transforms, funsd_val_metric, amp=args.amp
)
except Exception:
pass
try:
-_, cord_recall, cord_precision, cord_mean_iou = sec_evaluate(
+_, cord_recall, cord_precision, cord_mean_iou = evaluate(
model, cord_test_loader, batch_transforms, cord_val_metric, amp=args.amp
)
except Exception:
@@ -603,6 +569,8 @@ def parse_args():

parser.add_argument("train_path", type=str, help="path to training data folder")
parser.add_argument("val_path", type=str, help="path to validation data folder")
parser.add_argument("funsd_path", type=str, help="path to FUNSD data folder")
parser.add_argument("cord_path", type=str, help="path to Cord data folder")
parser.add_argument("arch", type=str, help="text-detection model to train")
parser.add_argument("--name", type=str, default=None, help="Name of your training experiment")
parser.add_argument("--epochs", type=int, default=10, help="number of epochs to train the model on")
