Skip to content

Commit e489017

Browse files
committed
fix bug on trainer
1 parent 561c1fe commit e489017

File tree

4 files changed

+35
-5
lines changed

4 files changed

+35
-5
lines changed

README.MD

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ only train on ICDAR2015 dataset
6060
| Method | image size (short size) |learning rate | Precision (%) | Recall (%) | F-measure (%) | FPS |
6161
|--------------------------|-------|--------|--------|------------|---------------|-----|
6262
| paper(resnet18) | 736 |x | x | x | 80.4 | 26.1 |
63-
| my (resnet18+FPEM_FFM+pse扩张) |736 |1e-3| 84.24 | 74.14 | 78.87 | 21.31 (P100)|
63+
| my (resnet18+FPEM_FFM+pse扩张) |736 |1e-3| 85.03 | 73.03 | 78.58 | 21.31 (P100)|
6464
| my (resnet50+FPEM_FFM+pse扩张) |736 |1e-3| 69.04 | 66.66 | 67.83 | 14.22 (P100)|
6565
| my (resnet18+FPEM_FFM+pse扩张) |736 |1e-4| 62.93 | 62.41 | 62.61 | 21.31 (P100)|
6666
| my (resnet50+FPEM_FFM+pse扩张) |736 |1e-4| 61.19 | 69.18 | 64.94 | 14.22 (P100)|

config.json

-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
"loader": {
2929
"validation_split": 0.1,
3030
"train_batch_size": 16,
31-
"val_batch_size": 4,
3231
"shuffle": true,
3332
"pin_memory": false,
3433
"num_workers": 6

data_loader/data_utils.py

+33-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# -*- coding: utf-8 -*-
22
# @Time : 2019/8/23 21:53
33
# @Author : zhoujun
4-
4+
import math
55
import random
66
import pyclipper
77
import numpy as np
@@ -33,6 +33,25 @@ def check_and_validate_polys(polys, xxx_todo_changeme):
3333
validated_polys.append(poly)
3434
return np.array(validated_polys)
3535

36+
def unshrink_offset(poly, ratio):
    """Solve the quadratic built from a polygon's area and perimeter.

    The contour area and closed-contour arc length of ``poly`` are measured
    with OpenCV and combined with ``ratio`` into the coefficients of
    ``8*x**2 + (peri - 4)*x + (1 - 0.5*peri - area/ratio) = 0``, which is
    then handed to ``quadratic``.

    NOTE(review): presumably ``x`` is the offset distance needed to expand a
    shrunk polygon back to its original size -- confirm the derivation of
    the coefficients against the caller / paper.

    :param poly: polygon points in a form accepted by ``cv2.contourArea``
        and ``cv2.arcLength``.
    :param ratio: divisor applied to the polygon area (must be non-zero).
    :return: whatever ``quadratic`` returns for these coefficients.
    """
    region_area = cv2.contourArea(poly)
    perimeter = cv2.arcLength(poly, True)  # True -> treat contour as closed

    quad_coef = 8
    lin_coef = perimeter - 4
    const_coef = 1 - 0.5 * perimeter - region_area / ratio
    return quadratic(quad_coef, lin_coef, const_coef)
43+
44+
def quadratic(a, b, c):
    """Return the real roots of ``a*x**2 + b*x + c = 0``.

    :param a: quadratic coefficient (a value of 0 leads to a division by
        zero).
    :param b: linear coefficient.
    :param c: constant term.
    :return: one of three shapes, depending on the discriminant:

        * negative -> the *string* ``'None'`` (not the ``None`` object);
        * positive -> a tuple ``(x, y)`` holding the ``+`` root first and
          the ``-`` root second;
        * zero     -> the single repeated root as a scalar.
    """
    discriminant = b * b - 4 * a * c
    if discriminant < 0:
        # No real solution; callers receive the literal string 'None'.
        return 'None'

    sqrt_disc = math.sqrt(discriminant)
    denominator = 2 * a
    if sqrt_disc > 0:
        return (- b + sqrt_disc) / denominator, (- b - sqrt_disc) / denominator
    # Discriminant is exactly zero: a single repeated root.
    return (- b) / denominator
3655

3756
def generate_rbox(im_size, text_polys, text_tags,training_mask, shrink_ratio):
3857
"""
@@ -48,7 +67,8 @@ def generate_rbox(im_size, text_polys, text_tags,training_mask, shrink_ratio):
4867
for i, (poly, tag) in enumerate(zip(text_polys, text_tags)):
4968
try:
5069
poly = poly.astype(np.int)
51-
d_i = cv2.contourArea(poly) * (1 - shrink_ratio * shrink_ratio) / cv2.arcLength(poly, True)
70+
# d_i = cv2.contourArea(poly) * (1 - shrink_ratio * shrink_ratio) / cv2.arcLength(poly, True)
71+
d_i = cv2.contourArea(poly) * (1 - shrink_ratio) / cv2.arcLength(poly, True) + 0.5
5272
pco = pyclipper.PyclipperOffset()
5373
pco.AddPath(poly, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
5474
shrinked_poly = np.array(pco.Execute(-d_i))
@@ -107,3 +127,14 @@ def image_label(im: np.ndarray, text_polys: np.ndarray, text_tags: list, input_s
107127
score_maps = np.array(score_maps, dtype=np.float32)
108128
imgs = data_aug.random_crop([im, score_maps.transpose((1, 2, 0)), training_mask], (input_size, input_size))
109129
return imgs[0], imgs[1].transpose((2, 0, 1)), imgs[2] # im,score_maps,training_mask#
130+
131+
if __name__ == '__main__':
    # Manual smoke check: shrink one hard-coded quadrilateral with pyclipper,
    # then print the offset used, the shrunk/original area ratio, and the
    # result of unshrink_offset on the shrunk polygon.
    shrink_ratio = 0.5
    poly = np.array([377, 117, 463, 117, 465, 130, 378, 130]).reshape(-1, 2)

    # Same offset formula as the one used in generate_rbox.
    d_i = cv2.contourArea(poly) * (1 - shrink_ratio) / cv2.arcLength(poly, True) + 0.5

    pco = pyclipper.PyclipperOffset()
    pco.AddPath(poly, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
    # A negative delta makes the offsetter shrink the polygon inward.
    shrinked_poly = np.array(pco.Execute(-d_i))

    print(d_i)
    print(cv2.contourArea(shrinked_poly.astype(int)) / cv2.contourArea(poly))
    print(unshrink_offset(shrinked_poly, shrink_ratio))

trainer/trainer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import torch
1212
import torchvision.utils as vutils
1313
from torchvision import transforms
14-
from post_processing import decode_np as decode
14+
from post_processing import decode
1515
from utils import PolynomialLR, runningScore, cal_text_score, cal_kernel_score, cal_recall_precison_f1
1616

1717
from base import BaseTrainer

0 commit comments

Comments
 (0)