Skip to content

Commit 561c1fe

Browse files
committed
update readme
1 parent 14b1644 commit 561c1fe

File tree

10 files changed

+97
-104
lines changed

10 files changed

+97
-104
lines changed

README.MD

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ only train on ICDAR2015 dataset
6060
| Method | image size (short size) |learning rate | Precision (%) | Recall (%) | F-measure (%) | FPS |
6161
|--------------------------|-------|--------|--------|------------|---------------|-----|
6262
| paper(resnet18) | 736 |x | x | x | 80.4 | 26.1 |
63-
| my (resnet18+FPEM_FFM+pse扩张) |736 |1e-3| 73.17 | 70.39 | 71.75 | 21.31 (P100)|
63+
| my (resnet18+FPEM_FFM+pse扩张) |736 |1e-3| 84.24 | 74.14 | 78.87 | 21.31 (P100)|
6464
| my (resnet50+FPEM_FFM+pse扩张) |736 |1e-3| 69.04 | 66.66 | 67.83 | 14.22 (P100)|
6565
| my (resnet18+FPEM_FFM+pse扩张) |736 |1e-4| 62.93 | 62.41 | 62.61 | 21.31 (P100)|
6666
| my (resnet50+FPEM_FFM+pse扩张) |736 |1e-4| 61.19 | 69.18 | 64.94 | 14.22 (P100)|

base/base_trainer.py

+14-11
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def __init__(self, config, model, criterion, weights_init):
2020
self.save_dir = os.path.join(config['trainer']['output_dir'], config['name'])
2121
self.checkpoint_dir = os.path.join(self.save_dir, 'checkpoint')
2222

23-
if config['trainer']['resume']['restart_training']:
23+
if config['trainer']['resume_checkpoint'] == '' and config['trainer']['finetune_checkpoint'] == '':
2424
shutil.rmtree(self.save_dir, ignore_errors=True)
2525
if not os.path.exists(self.checkpoint_dir):
2626
os.makedirs(self.checkpoint_dir)
@@ -42,7 +42,7 @@ def __init__(self, config, model, criterion, weights_init):
4242
self.logger = setup_logger(os.path.join(self.save_dir, 'train_log'))
4343
self.logger.info(pformat(self.config))
4444

45-
# device set
45+
# device
4646
torch.manual_seed(self.config['trainer']['seed']) # 为CPU设置随机种子
4747
if len(self.config['trainer']['gpus']) > 0 and torch.cuda.is_available():
4848
self.with_cuda = True
@@ -62,8 +62,10 @@ def __init__(self, config, model, criterion, weights_init):
6262

6363
self.optimizer = self._initialize('optimizer', torch.optim, model.parameters())
6464

65-
if self.config['trainer']['resume']['checkpoint'] != '':
66-
self._resume_checkpoint(self.config['trainer']['resume']['checkpoint'])
65+
if self.config['trainer']['resume_checkpoint'] != '':
66+
self._laod_checkpoint(self.config['trainer']['resume_checkpoint'], resume=True)
67+
elif self.config['trainer']['finetune_checkpoint'] != '':
68+
self._laod_checkpoint(self.config['trainer']['finetune_checkpoint'], resume=False)
6769
else:
6870
if weights_init is not None:
6971
model.apply(weights_init)
@@ -171,15 +173,15 @@ def _save_checkpoint(self, epoch, file_name, save_best=False):
171173
else:
172174
self.logger.info("Saving checkpoint: {}".format(filename))
173175

174-
def _resume_checkpoint(self, resume_path):
176+
def _laod_checkpoint(self, checkpoint_path, resume):
175177
"""
176178
Resume from saved checkpoints
177-
:param resume_path: Checkpoint path to be resumed
179+
:param checkpoint_path: Checkpoint path to be resumed
178180
"""
179-
self.logger.info("Loading checkpoint: {} ...".format(resume_path))
180-
checkpoint = torch.load(resume_path)
181+
self.logger.info("Loading checkpoint: {} ...".format(checkpoint_path))
182+
checkpoint = torch.load(checkpoint_path)
181183
self.model.load_state_dict(checkpoint['state_dict'])
182-
if not self.config['trainer']['resume']['restart_training']:
184+
if resume:
183185
self.global_step = checkpoint['global_step']
184186
self.start_epoch = checkpoint['epoch'] + 1
185187
self.config['lr_scheduler']['args']['last_epoch'] = self.start_epoch
@@ -192,8 +194,9 @@ def _resume_checkpoint(self, resume_path):
192194
for k, v in state.items():
193195
if isinstance(v, torch.Tensor):
194196
state[k] = v.to(self.device)
195-
# self.config = checkpoint['config']
196-
self.logger.info("Checkpoint '{}' (epoch {}) loaded".format(resume_path, self.start_epoch))
197+
self.logger.info("resume from checkpoint {} (epoch {})".format(checkpoint_path, self.start_epoch))
198+
else:
199+
self.logger.info("finetune from checkpoint {}".format(checkpoint_path))
197200

198201
def _initialize(self, name, module, *args, **kwargs):
199202
module_name = self.config[name]['type']

config.json

+2-4
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,8 @@
7777
"epochs": 600,
7878
"display_interval": 10,
7979
"show_images_interval": 50,
80-
"resume": {
81-
"restart_training": true,
82-
"checkpoint": ""
83-
},
80+
"resume_checkpoint": "",
81+
"finetune_checkpoint": "",
8482
"output_dir": "output",
8583
"tensorboard": true
8684
}

data_loader/dataset.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from PIL import Image
77
from torch.utils.data import Dataset, DataLoader
88
from data_loader.data_utils import image_label
9-
from utils import order_points_colckwise
9+
from utils import order_points_clockwise
1010

1111

1212
class ImageDataset(Dataset):
@@ -52,7 +52,7 @@ def _get_annotation(self, label_path: str) -> tuple:
5252
for line in f.readlines():
5353
params = line.strip().strip('\ufeff').strip('\xef\xbb\xbf').split(',')
5454
try:
55-
box = order_points_colckwise(np.array(list(map(float, params[:8]))).reshape(-2, 1))
55+
box = order_points_clockwise(np.array(list(map(float, params[:8]))).reshape(-1, 2))
5656
if cv2.arcLength(box, True) > 0:
5757
boxes.append(box)
5858
label = params[8]

eval.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@ def main(model_path, img_folder, save_path, gpu_id):
4141

4242

4343
if __name__ == '__main__':
44-
os.environ['CUDA_VISIBLE_DEVICES'] = str('2')
44+
os.environ['CUDA_VISIBLE_DEVICES'] = str('0')
4545
scale = 4
46-
model_path = 'output/PAN_pred_mask_resnet50/checkpoint/model_best.pth'
46+
model_path = 'output/PAN_resnet18_FPEM_FFM/checkpoint/model_best.pth'
4747
img_path = '/data2/dataset/ICD15/test/img'
4848
gt_path = '/data2/dataset/ICD15/test/gt'
4949
save_path = model_path.replace('checkpoint/model_best.pt', 'result/')

models/model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,10 @@ def forward(self, x):
6161
'fpem_repeat': 4, # fpem模块重复的次数
6262
'pretrained': False, # backbone 是否使用imagesnet的预训练模型
6363
'result_num':7,
64-
'segmentation_head': 'FPEM_FFM' # 分割头,FPN or FPEM
64+
'segmentation_head': 'FPEM_FFM' # 分割头,FPN or FPEM_FFM
6565
}
6666
model = Model(model_config=model_config).to(device)
6767
y = model(x)
6868
print(y.shape)
69-
print(model)
69+
# print(model)
7070
# torch.save(model.state_dict(), 'PAN.pth')

models/modules/segmentation_head.py

+64-58
Original file line numberDiff line numberDiff line change
@@ -16,42 +16,40 @@ def __init__(self, backbone_out_channels, **kwargs):
1616
result_num = kwargs.get('result_num', 6)
1717
inplace = True
1818
conv_out = 256
19-
20-
# Top layer
21-
self.toplayer = nn.Sequential(
22-
nn.Conv2d(backbone_out_channels[3], conv_out, kernel_size=1, stride=1, padding=0),
19+
# reduce layers
20+
self.reduce_conv_c2 = nn.Sequential(
21+
nn.Conv2d(backbone_out_channels[0], conv_out, kernel_size=1, stride=1, padding=0),
2322
nn.BatchNorm2d(conv_out),
2423
nn.ReLU(inplace=inplace)
2524
)
26-
# Lateral layers
27-
self.latlayer1 = nn.Sequential(
28-
nn.Conv2d(backbone_out_channels[2], conv_out, kernel_size=1, stride=1, padding=0),
25+
self.reduce_conv_c3 = nn.Sequential(
26+
nn.Conv2d(backbone_out_channels[1], conv_out, kernel_size=1, stride=1, padding=0),
2927
nn.BatchNorm2d(conv_out),
3028
nn.ReLU(inplace=inplace)
3129
)
32-
self.latlayer2 = nn.Sequential(
33-
nn.Conv2d(backbone_out_channels[1], conv_out, kernel_size=1, stride=1, padding=0),
30+
self.reduce_conv_c4 = nn.Sequential(
31+
nn.Conv2d(backbone_out_channels[2], conv_out, kernel_size=1, stride=1, padding=0),
3432
nn.BatchNorm2d(conv_out),
3533
nn.ReLU(inplace=inplace)
3634
)
37-
self.latlayer3 = nn.Sequential(
38-
nn.Conv2d(backbone_out_channels[0], conv_out, kernel_size=1, stride=1, padding=0),
35+
36+
self.reduce_conv_c5 = nn.Sequential(
37+
nn.Conv2d(backbone_out_channels[3], conv_out, kernel_size=1, stride=1, padding=0),
3938
nn.BatchNorm2d(conv_out),
4039
nn.ReLU(inplace=inplace)
4140
)
42-
4341
# Smooth layers
44-
self.smooth1 = nn.Sequential(
42+
self.smooth_p4 = nn.Sequential(
4543
nn.Conv2d(conv_out, conv_out, kernel_size=3, stride=1, padding=1),
4644
nn.BatchNorm2d(conv_out),
4745
nn.ReLU(inplace=inplace)
4846
)
49-
self.smooth2 = nn.Sequential(
47+
self.smooth_p3 = nn.Sequential(
5048
nn.Conv2d(conv_out, conv_out, kernel_size=3, stride=1, padding=1),
5149
nn.BatchNorm2d(conv_out),
5250
nn.ReLU(inplace=inplace)
5351
)
54-
self.smooth3 = nn.Sequential(
52+
self.smooth_p2 = nn.Sequential(
5553
nn.Conv2d(conv_out, conv_out, kernel_size=3, stride=1, padding=1),
5654
nn.BatchNorm2d(conv_out),
5755
nn.ReLU(inplace=inplace)
@@ -67,27 +65,27 @@ def __init__(self, backbone_out_channels, **kwargs):
6765
def forward(self, x):
6866
c2, c3, c4, c5 = x
6967
# Top-down
70-
p5 = self.toplayer(c5)
71-
p4 = self._upsample_add(p5, self.latlayer1(c4))
72-
p4 = self.smooth1(p4)
73-
p3 = self._upsample_add(p4, self.latlayer2(c3))
74-
p3 = self.smooth2(p3)
75-
p2 = self._upsample_add(p3, self.latlayer3(c2))
76-
p2 = self.smooth3(p2)
68+
p5 = self.reduce_conv_c5(c5)
69+
p4 = self._upsample_add(p5, self.reduce_conv_c4(c4))
70+
p4 = self.smooth_p4(p4)
71+
p3 = self._upsample_add(p4, self.reduce_conv_c3(c3))
72+
p3 = self.smooth_p3(p3)
73+
p2 = self._upsample_add(p3, self.reduce_conv_c2(c2))
74+
p2 = self.smooth_p2(p2)
7775

7876
x = self._upsample_cat(p2, p3, p4, p5)
7977
x = self.conv(x)
8078
x = self.out_conv(x)
8179
return x
8280

8381
def _upsample_add(self, x, y):
84-
return F.interpolate(x, size=y.size()[2:], mode='bilinear', align_corners=False) + y
82+
return F.interpolate(x, size=y.size()[2:], mode='bilinear') + y
8583

8684
def _upsample_cat(self, p2, p3, p4, p5):
8785
h, w = p2.size()[2:]
88-
p3 = F.interpolate(p3, size=(h, w), mode='bilinear', align_corners=False)
89-
p4 = F.interpolate(p4, size=(h, w), mode='bilinear', align_corners=False)
90-
p5 = F.interpolate(p5, size=(h, w), mode='bilinear', align_corners=False)
86+
p3 = F.interpolate(p3, size=(h, w), mode='bilinear')
87+
p4 = F.interpolate(p4, size=(h, w), mode='bilinear')
88+
p5 = F.interpolate(p5, size=(h, w), mode='bilinear')
9189
return torch.cat([p2, p3, p4, p5], dim=1)
9290

9391

@@ -99,22 +97,40 @@ def __init__(self, backbone_out_channels, **kwargs):
9997
"""
10098
super().__init__()
10199
fpem_repeat = kwargs.get('fpem_repeat', 2)
102-
self.conv_c2 = nn.Conv2d(in_channels=backbone_out_channels[0], out_channels=128, kernel_size=1)
103-
self.conv_c3 = nn.Conv2d(in_channels=backbone_out_channels[1], out_channels=128, kernel_size=1)
104-
self.conv_c4 = nn.Conv2d(in_channels=backbone_out_channels[2], out_channels=128, kernel_size=1)
105-
self.conv_c5 = nn.Conv2d(in_channels=backbone_out_channels[3], out_channels=128, kernel_size=1)
100+
conv_out = 128
101+
# reduce layers
102+
self.reduce_conv_c2 = nn.Sequential(
103+
nn.Conv2d(in_channels=backbone_out_channels[0], out_channels=conv_out, kernel_size=1),
104+
nn.BatchNorm2d(conv_out),
105+
nn.ReLU()
106+
)
107+
self.reduce_conv_c3 = nn.Sequential(
108+
nn.Conv2d(in_channels=backbone_out_channels[1], out_channels=conv_out, kernel_size=1),
109+
nn.BatchNorm2d(conv_out),
110+
nn.ReLU()
111+
)
112+
self.reduce_conv_c4 = nn.Sequential(
113+
nn.Conv2d(in_channels=backbone_out_channels[2], out_channels=conv_out, kernel_size=1),
114+
nn.BatchNorm2d(conv_out),
115+
nn.ReLU()
116+
)
117+
self.reduce_conv_c5 = nn.Sequential(
118+
nn.Conv2d(in_channels=backbone_out_channels[3], out_channels=conv_out, kernel_size=1),
119+
nn.BatchNorm2d(conv_out),
120+
nn.ReLU()
121+
)
106122
self.fpems = nn.ModuleList()
107123
for i in range(fpem_repeat):
108-
self.fpems.append(FPEM(128))
109-
self.out_conv = nn.Conv2d(in_channels=512, out_channels=6, kernel_size=1)
124+
self.fpems.append(FPEM(conv_out))
125+
self.out_conv = nn.Conv2d(in_channels=conv_out * 4, out_channels=6, kernel_size=1)
110126

111127
def forward(self, x):
112128
c2, c3, c4, c5 = x
113129
# reduce channel
114-
c2 = self.conv_c2(c2)
115-
c3 = self.conv_c3(c3)
116-
c4 = self.conv_c4(c4)
117-
c5 = self.conv_c5(c5)
130+
c2 = self.reduce_conv_c2(c2)
131+
c3 = self.reduce_conv_c3(c3)
132+
c4 = self.reduce_conv_c4(c4)
133+
c5 = self.reduce_conv_c5(c5)
118134

119135
# FPEM
120136
for i, fpem in enumerate(self.fpems):
@@ -142,38 +158,28 @@ def forward(self, x):
142158
class FPEM(nn.Module):
143159
def __init__(self, in_channels=128):
144160
super().__init__()
145-
# self.add_up = nn.Sequential(
146-
# nn.Conv2d(in_channels=in_channel, out_channels=in_channel, kernel_size=3, padding=1, groups=in_channel),
147-
# nn.Conv2d(in_channels=in_channel, out_channels=in_channel, kernel_size=1),
148-
# nn.BatchNorm2d(in_channel),
149-
# nn.ReLU()
150-
# )
151-
# self.add_down = nn.Sequential(
152-
# nn.Conv2d(in_channels=in_channel, out_channels=in_channel, kernel_size=3, padding=1, groups=in_channel,
153-
# stride=2),
154-
# nn.Conv2d(in_channels=in_channel, out_channels=in_channel, kernel_size=1),
155-
# nn.BatchNorm2d(in_channel),
156-
# nn.ReLU()
157-
# )
158161
self.up_add1 = SeparableConv2d(in_channels, in_channels, 1)
159162
self.up_add2 = SeparableConv2d(in_channels, in_channels, 1)
160-
self.up_add2 = SeparableConv2d(in_channels, in_channels, 1)
163+
self.up_add3 = SeparableConv2d(in_channels, in_channels, 1)
161164
self.down_add1 = SeparableConv2d(in_channels, in_channels, 2)
162165
self.down_add2 = SeparableConv2d(in_channels, in_channels, 2)
163166
self.down_add3 = SeparableConv2d(in_channels, in_channels, 2)
164167

165168
def forward(self, c2, c3, c4, c5):
166169
# up阶段
167-
c4 = self.up_add1(c4 + F.interpolate(c5, c4.size()[-2:], mode='bilinear', align_corners=True))
168-
c3 = self.up_add2(c3 + F.interpolate(c4, c3.size()[-2:], mode='bilinear', align_corners=True))
169-
c2 = self.up_add2(c2 + F.interpolate(c3, c2.size()[-2:], mode='bilinear', align_corners=True))
170+
c4 = self.up_add1(self._upsample_add(c5, c4))
171+
c3 = self.up_add2(self._upsample_add(c4, c3))
172+
c2 = self.up_add3(self._upsample_add(c3, c2))
170173

171174
# down 阶段
172-
c3 = self.down_add1(c2 + F.interpolate(c3, c2.size()[-2:], mode='bilinear', align_corners=True))
173-
c4 = self.down_add2(c3 + F.interpolate(c4, c3.size()[-2:], mode='bilinear', align_corners=True))
174-
c5 = self.down_add3(c4 + F.interpolate(c5, c4.size()[-2:], mode='bilinear', align_corners=True))
175+
c3 = self.down_add1(self._upsample_add(c3, c2))
176+
c4 = self.down_add2(self._upsample_add(c4, c3))
177+
c5 = self.down_add3(self._upsample_add(c5, c4))
175178
return c2, c3, c4, c5
176179

180+
def _upsample_add(self, x, y):
181+
return F.interpolate(x, size=y.size()[2:], mode='bilinear') + y
182+
177183

178184
class SeparableConv2d(nn.Module):
179185
def __init__(self, in_channels, out_channels, stride=1):
@@ -186,8 +192,8 @@ def __init__(self, in_channels, out_channels, stride=1):
186192
self.relu = nn.ReLU()
187193

188194
def forward(self, x):
189-
x = self.out_channels(x)
195+
x = self.depthwise_conv(x)
190196
x = self.pointwise_conv(x)
191197
x = self.bn(x)
192-
x = self.relu
198+
x = self.relu(x)
193199
return x

post_processing/__init__.py

+5-19
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def decode(preds, scale=1, threshold=0.7311, min_area=5):
2525
:param threshold: sigmoid的阈值
2626
:return: 最后的输出图和文本框
2727
"""
28-
from .pse import pse_cpp, get_points, get_sum
28+
from .pse import pse_cpp, get_points, get_num
2929
preds[:2, :, :] = torch.sigmoid(preds[:2, :, :])
3030
preds = preds.detach().cpu().numpy()
3131
score = preds[0].astype(np.float32)
@@ -35,42 +35,28 @@ def decode(preds, scale=1, threshold=0.7311, min_area=5):
3535

3636
label_num, label = cv2.connectedComponents(kernel.astype(np.uint8), connectivity=4)
3737
label_values = []
38-
label_sum = get_sum(label, label_num)
38+
label_sum = get_num(label, label_num)
3939
for label_idx in range(1, label_num):
4040
if label_sum[label_idx] < min_area:
4141
continue
4242
label_values.append(label_idx)
4343

4444
pred = pse_cpp(text.astype(np.uint8), similarity_vectors, label, label_num, 0.8)
4545
pred = pred.reshape(text.shape)
46-
bbox_list = []
47-
for label_value in label_values:
48-
points = np.array(np.where(pred == label_value)).transpose((1, 0))[:, ::-1]
49-
50-
if points.shape[0] < 100 / (scale * scale):
51-
continue
52-
53-
score_i = np.mean(score[pred == label_value])
54-
if score_i < 0.1:
55-
continue
56-
57-
rect = cv2.minAreaRect(points)
58-
bbox = cv2.boxPoints(rect)
59-
bbox_list.append([bbox[1], bbox[2], bbox[3], bbox[0]])
6046

6147
bbox_list = []
6248
label_points = get_points(pred, score, label_num)
63-
for label_value, label_point in label_points.item():
49+
for label_value, label_point in label_points.items():
6450
if label_value not in label_values:
6551
continue
6652
score_i = label_point[0]
6753
label_point = label_point[2:]
68-
points = np.array(label_point).reshape(-1, 2)
54+
points = np.array(label_point, dtype=int).reshape(-1, 2)
6955

7056
if points.shape[0] < 100 / (scale * scale):
7157
continue
7258

73-
if score_i < 0.1:
59+
if score_i < 0.93:
7460
continue
7561

7662
rect = cv2.minAreaRect(points)

0 commit comments

Comments
 (0)