Skip to content

Commit e2c842a

Browse files
committed
train is ok
1 parent 5779c5a commit e2c842a

17 files changed

+350
-55
lines changed

base/base_trainer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def __init__(self, config, model, criterion, weights_init):
7474
else:
7575
if weights_init is not None:
7676
model.apply(weights_init)
77-
# self.scheduler = self._initialize('lr_scheduler', torch.optim.lr_scheduler, self.optimizer)
77+
self.scheduler = self._initialize('lr_scheduler', torch.optim.lr_scheduler, self.optimizer)
7878

7979
# 单机多卡
8080
num_gpus = torch.cuda.device_count()
@@ -102,8 +102,8 @@ def train(self):
102102
"""
103103
for epoch in range(self.start_epoch, self.epochs + 1):
104104
try:
105+
self.scheduler.step()
105106
self.epoch_result = self._train_epoch(epoch)
106-
107107
self._on_epoch_finish()
108108
except torch.cuda.CudaError:
109109
self._log_memory_usage()

config.json

+7-8
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
{
2-
"name": "PAN",
2+
"name": "PAN_pred_mask",
33
"data_loader": {
44
"type": "ImageDataset",
55
"args": {
66
"alphabet": "alphabet.npy",
77
"dataset": {
88
"train_data_path": [
99
[
10-
"E:\\zj\\dataset\\icdar2015\\train\\train.txt"
10+
"/data1/zj/ocr/icdar2015/train/train.txt"
1111
]
1212
],
1313
"train_data_ratio": [
@@ -32,7 +32,7 @@
3232
"arch": {
3333
"type": "PANModel",
3434
"args": {
35-
"backbone": "resnet18",
35+
"backbone": "resnet50",
3636
"fpem_repeat": 2,
3737
"pretrained": true
3838
}
@@ -50,9 +50,7 @@
5050
"optimizer": {
5151
"type": "Adam",
5252
"args": {
53-
"lr": 0.001,
54-
"weight_decay": 0,
55-
"amsgrad": true
53+
"lr": 0.001
5654
}
5755
},
5856
"lr_scheduler": {
@@ -65,15 +63,16 @@
6563
"trainer": {
6664
"seed": 2,
6765
"gpus": [
68-
0
66+
3
6967
],
7068
"epochs": 600,
7169
"display_interval": 10,
70+
"show_images_interval": 50,
7271
"resume": {
7372
"restart_training": true,
7473
"checkpoint": ""
7574
},
7675
"output_dir": "output",
7776
"tensorboard": true
7877
}
79-
}
78+
}

config/default.py

+1
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
'gpus': [0],
7575
'epochs': 100,
7676
'display_interval': 10,
77+
'show_images_interval': 50,
7778
'resume': resume,
7879
'output_dir': 'output',
7980
'tensorboard': True

data_loader/augment.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -121,17 +121,24 @@ def random_crop(self, imgs, img_size):
121121
return imgs
122122

123123
# label中存在文本实例,并且按照概率进行裁剪
124-
if np.max(imgs[1][:, :, -1]) > 0 and random.random() > 3.0 / 8.0:
124+
if np.max(imgs[1][:, :, 0]) > 0 and random.random() > 3.0 / 8.0:
125125
# 文本实例的top left点
126-
tl = np.min(np.where(imgs[1][:, :, -1] > 0), axis=1) - img_size
126+
tl = np.min(np.where(imgs[1][:, :, 0] > 0), axis=1) - img_size
127127
tl[tl < 0] = 0
128128
# 文本实例的 bottom right 点
129-
br = np.max(np.where(imgs[1][:, :, -1] > 0), axis=1) - img_size
129+
br = np.max(np.where(imgs[1][:, :, 0] > 0), axis=1) - img_size
130130
br[br < 0] = 0
131131
# 保证选到右下角点是,有足够的距离进行crop
132132
br[0] = min(br[0], h - th)
133133
br[1] = min(br[1], w - tw)
134-
134+
for _ in range(50000):
135+
i = random.randint(tl[0], br[0])
136+
j = random.randint(tl[1], br[1])
137+
# 保证最小的图有文本
138+
if imgs[1][:, :, -1][i:i + th, j:j + tw].sum() <= 0:
139+
continue
140+
else:
141+
break
135142
i = random.randint(tl[0], br[0])
136143
j = random.randint(tl[1], br[1])
137144
else:

data_loader/data_utils.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -52,25 +52,24 @@ def generate_rbox(im_size, text_polys, text_tags,training_mask, shrink_ratio):
5252
pco.AddPath(poly, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
5353
shrinked_poly = np.array(pco.Execute(-d_i))
5454
cv2.fillPoly(score_map, shrinked_poly, i + 1)
55-
if tag:
55+
if not tag:
5656
cv2.fillPoly(training_mask, shrinked_poly, 0)
5757
return score_map, training_mask
5858

5959

60-
def augmentation(im: np.ndarray, text_polys: np.ndarray, scales: np.ndarray, degrees: int, input_size: int) -> tuple:
60+
def augmentation(im: np.ndarray, text_polys: np.ndarray, scales: np.ndarray, degrees: int) -> tuple:
    """
    Randomly rescale, flip and rotate an image together with its text polygons.

    :param im: input image
    :param text_polys: text polygons belonging to ``im``
    :param scales: candidate scale ratios, forwarded to ``data_aug.random_scale``
    :param degrees: rotation argument forwarded to ``data_aug.random_rotate_img_bbox``
                    (presumably the maximum angle - confirm against that helper)
    :return: tuple of (augmented image, augmented text polygons)
    """
    # Rescaling is always applied.
    image, polys = data_aug.random_scale(im, text_polys, scales)

    # The horizontal flip and the rotation are each applied independently with
    # probability 0.5; the random draw happens right before the corresponding step.
    optional_steps = (
        (data_aug.horizontal_flip, ()),
        (data_aug.random_rotate_img_bbox, (degrees,)),
    )
    for step, extra_args in optional_steps:
        if random.random() < 0.5:
            image, polys = step(image, polys, *extra_args)
    return image, polys
7069

7170

7271
def image_label(im: np.ndarray, text_polys: np.ndarray, text_tags: list, input_size: int = 640,
73-
shrink_ratio: float = 0.5, defrees: int = 10,
72+
shrink_ratio: float = 0.5, degrees: int = 10,
7473
scales: np.ndarray = np.array([0.5, 1, 2.0, 3.0])) -> tuple:
7574
"""
7675
读取图片并生成label
@@ -79,14 +78,14 @@ def image_label(im: np.ndarray, text_polys: np.ndarray, text_tags: list, input_s
7978
:param text_tags: 是否忽略文本的标致:true 忽略, false 不忽略
8079
:param input_size: 输出图像的尺寸
8180
:param shrink_ratio: gt收缩的比例
82-
:param defrees: 随机旋转的角度
81+
:param degrees: 随机旋转的角度
8382
:param scales: 随机缩放的尺度
8483
:return:
8584
"""
8685
h, w, _ = im.shape
8786
# 检查越界
8887
text_polys = check_and_validate_polys(text_polys, (h, w))
89-
# im, text_polys, = augmentation(im, text_polys, scales, defrees, input_size)
88+
im, text_polys = augmentation(im, text_polys, scales, degrees)
9089

9190
h, w, _ = im.shape
9291
short_edge = min(h, w)

data_loader/dataset.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,9 @@ def _get_annotation(self, label_path: str) -> tuple:
5353
try:
5454
label = params[8]
5555
if label == '*' or label == '###':
56-
text_tags.append(True)
57-
else:
5856
text_tags.append(False)
57+
else:
58+
text_tags.append(True)
5959
# if label == '*' or label == '###':
6060
x1, y1, x2, y2, x3, y3, x4, y4 = list(map(float, params[:8]))
6161
boxes.append([[x1, y1], [x2, y2], [x3, y3], [x4, y4]])
@@ -135,9 +135,10 @@ def __next__(self):
135135
import matplotlib.pyplot as plt
136136
from torchvision import transforms
137137

138+
138139
train_data = ImageDataset(
139140
data_list=[
140-
(r'E:\zj\dataset\icdar2015\train\img\img_828.jpg', r'E:\zj\dataset\icdar2015\train\gt\gt_img_828.txt')],
141+
(r'/data1/zj/ocr/icdar2015/train/img/img_713.jpg','/data1/zj/ocr/icdar2015/train/gt/gt_img_713.txt')],
141142
input_size=640,
142143
img_channel=3,
143144
shrink_ratio=0.5,

models/__init__.py

+5
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# @Time : 2019/8/23 21:55
33
# @Author : zhoujun
44
from .model import PAN
5+
from .model_pse1 import PSENet
56
from .loss import PANLoss
67

78

@@ -11,6 +12,10 @@ def get_model(config):
1112
pretrained = config['arch']['args']['pretrained']
1213
return PAN(backbone=backbone, fpem_repeat=fpem_repeat, pretrained=pretrained)
1314

15+
def get_model_pse1(config):
    """
    Build a PSENet from the ``arch.args`` section of the config.

    :param config: mapping that contains ``config['arch']['args']`` with the keys
                   ``'backbone'`` and ``'pretrained'``
    :return: a PSENet instance
    """
    arch_args = config['arch']['args']
    return PSENet(backbone=arch_args['backbone'], pretrained=arch_args['pretrained'])
1419

1520
def get_loss(config):
1621
alpha = config['loss']['args']['alpha']

models/loss.py

+18-16
Original file line numberDiff line numberDiff line change
@@ -28,34 +28,30 @@ def __init__(self, alpha=0.5, beta=0.25, delta_agg=0.5, delta_dis=3, ohem_ratio=
2828
self.reduction = reduction
2929

3030
def forward(self, outputs, labels, training_masks):
31-
batch_size = outputs.size()[0]
3231
texts = outputs[:, 0, :, :]
3332
kernels = outputs[:, 1, :, :]
3433
gt_texts = labels[:, 0, :, :]
3534
gt_kernels = labels[:, 1, :, :]
3635

36+
37+
# 计算 agg loss 和 dis loss
38+
similarity_vectors = outputs[:, 2:, :, :]
39+
loss_aggs, loss_diss = self.agg_dis_loss(texts, kernels, gt_texts, gt_kernels, similarity_vectors)
40+
3741
# 计算 text loss
3842
selected_masks = self.ohem_batch(texts, gt_texts, training_masks)
3943
selected_masks = selected_masks.to(outputs.device)
4044

4145
loss_texts = self.dice_loss(texts, gt_texts, selected_masks)
4246

4347
# 计算 kernel loss
44-
selected_masks = ((gt_texts > 0.5) & (training_masks > 0.5)).float()
45-
48+
# selected_masks = ((gt_texts > 0.5) & (training_masks > 0.5)).float()
49+
mask0 = torch.sigmoid(texts).detach().cpu().numpy()
50+
mask1 = training_masks.data.cpu().numpy()
51+
selected_masks = ((mask0 > 0.5) & (mask1 > 0.5)).astype('float32')
52+
selected_masks = torch.from_numpy(selected_masks).float().to(texts.device)
4653
loss_kernels = self.dice_loss(kernels, gt_kernels, selected_masks)
4754

48-
# 计算 agg loss 和 dis loss
49-
similarity_vectors = outputs[:, 2:, :, :]
50-
51-
texts = texts.contiguous().reshape(batch_size, -1)
52-
kernels = kernels.contiguous().reshape(batch_size, -1)
53-
gt_texts = gt_texts.contiguous().reshape(batch_size, -1)
54-
gt_kernels = gt_kernels.contiguous().reshape(batch_size, -1)
55-
similarity_vectors = similarity_vectors.contiguous().view(batch_size, 4, -1)
56-
57-
loss_aggs, loss_diss = self.agg_dis_loss(texts, kernels, gt_texts, gt_kernels, similarity_vectors)
58-
5955
# mean or sum
6056
if self.reduction == 'mean':
6157
loss_text = loss_texts.mean()
@@ -83,7 +79,12 @@ def agg_dis_loss(self, texts, kernels, gt_texts, gt_kernels, similarity_vectors)
8379
:param similarity_vectors: 相似度向量的分割结果 batch_size * 4 *(w*h)
8480
:return:
8581
"""
86-
82+
batch_size = texts.size()[0]
83+
texts = texts.contiguous().reshape(batch_size, -1)
84+
kernels = kernels.contiguous().reshape(batch_size, -1)
85+
gt_texts = gt_texts.contiguous().reshape(batch_size, -1)
86+
gt_kernels = gt_kernels.contiguous().reshape(batch_size, -1)
87+
similarity_vectors = similarity_vectors.contiguous().view(batch_size, 4, -1)
8788
loss_aggs = []
8889
loss_diss = []
8990
for text_i, kernel_i, gt_text_i, gt_kernel_i, similarity_vector in zip(texts, kernels, gt_texts, gt_kernels,
@@ -133,7 +134,8 @@ def agg_dis_loss(self, texts, kernels, gt_texts, gt_kernels, similarity_vectors)
133134

134135
def dice_loss(self, input, target, mask):
135136
input = torch.sigmoid(input)
136-
137+
target[target <= 0.5] = 0
138+
target[target > 0.5] = 1
137139
input = input.contiguous().view(input.size()[0], -1)
138140
target = target.contiguous().view(target.size()[0], -1)
139141
mask = mask.contiguous().view(mask.size()[0], -1)

models/model_pse1.py

+92
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# -*- coding: utf-8 -*-
2+
# @Time : 2019/8/23 21:57
3+
# @Author : zhoujun
4+
5+
import torch
6+
from torch import nn
7+
import torch.nn.functional as F
8+
from models.modules import *
9+
10+
# Supported backbones, keyed by the name passed to PSENet.
#   'models': callable that builds the backbone (presumably provided by the
#             star import from models.modules - confirm).
#   'out':    channel counts of the four feature maps the backbone returns;
#             PSENet uses them as input channels of its 1x1 convolutions.
backbone_dict = {'resnet18': {'models': resnet18, 'out': [64, 128, 256, 512]},
                 'resnet34': {'models': resnet34, 'out': [64, 128, 256, 512]},
                 'resnet50': {'models': resnet50, 'out': [256, 512, 1024, 2048]},
                 'resnet101': {'models': resnet101, 'out': [256, 512, 1024, 2048]},
                 'resnet152': {'models': resnet152, 'out': [256, 512, 1024, 2048]},
                 'resnext50_32x4d': {'models': resnext50_32x4d, 'out': [256, 512, 1024, 2048]},
                 'resnext101_32x8d': {'models': resnext101_32x8d, 'out': [256, 512, 1024, 2048]}
                 }


# Entries that are currently disabled:
# 'MobileNetV3_Large': {'models': MobileNetV3_Large, 'out': [24, 40, 160, 160]},
# 'MobileNetV3_Small': {'models': MobileNetV3_Small, 'out': [16, 24, 48, 96]},
# 'shufflenetv2': {'models': shufflenet_v2_x1_0, 'out': [24, 116, 232, 464]}}

# ``inplace`` flag handed to the ReLU created inside PSENet.
inplace = True
25+
26+
class PSENet(nn.Module):
    """
    Segmentation network that fuses four backbone feature maps top-down.

    The four maps returned by the backbone are reduced to 256 channels with 1x1
    convolutions, merged from coarse to fine, smoothed with 3x3 convolutions,
    concatenated at the finest resolution and projected to ``result_num`` channels.

    :param backbone: key of ``backbone_dict`` selecting the feature extractor
    :param result_num: number of channels of the output map
    :param scale: the output is resized to (H // scale, W // scale) of the input
    :param pretrained: forwarded to the backbone constructor
    """

    def __init__(self, backbone, result_num=6, scale: int = 1, pretrained=False):
        super().__init__()
        assert backbone in backbone_dict, 'backbone must in: {}'.format(backbone_dict)
        self.name = backbone
        self.scale = scale
        conv_out = 256
        spec = backbone_dict[backbone]
        stage_channels = spec['out']
        self.backbone = spec['models'](pretrained=pretrained)

        def project(in_channels):
            # 1x1 convolution that brings a backbone stage down to conv_out channels.
            return nn.Conv2d(in_channels, conv_out, kernel_size=1)

        def smooth():
            # 3x3 convolution applied to a merged pyramid level.
            return nn.Conv2d(conv_out, conv_out, kernel_size=3, padding=1)

        # Layers are created in a fixed order (top, laterals, smooths, head) so
        # that parameter registration and seeded initialisation stay stable.
        self.toplayer = project(stage_channels[3])
        self.latlayer1 = project(stage_channels[2])
        self.latlayer2 = project(stage_channels[1])
        self.latlayer3 = project(stage_channels[0])

        self.smooth1 = smooth()
        self.smooth2 = smooth()
        self.smooth3 = smooth()

        # Head: fuse the four concatenated levels, then predict result_num maps.
        self.conv = nn.Sequential(
            nn.Conv2d(conv_out * 4, conv_out, kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(conv_out),
            nn.ReLU(inplace=inplace)
        )
        self.out_conv = nn.Conv2d(conv_out, result_num, kernel_size=1, stride=1)

    def forward(self, input: torch.Tensor):
        """
        Run the network on a 4-D input batch.

        :param input: tensor whose last two dimensions are height and width
        :return: tensor with ``result_num`` channels and spatial size
                 (H // scale, W // scale)
        """
        _batch, _channels, height, width = input.size()
        c2, c3, c4, c5 = self.backbone(input)

        # Top-down pathway: each level is the upsampled coarser level plus the
        # lateral projection of the matching backbone stage. The un-smoothed
        # sums feed the next finer level.
        top = self.toplayer(c5)
        merged4 = self._upsample_add(top, self.latlayer1(c4))
        merged3 = self._upsample_add(merged4, self.latlayer2(c3))
        merged2 = self._upsample_add(merged3, self.latlayer3(c2))

        # Smooth every merged level (the top level is used as is).
        smoothed4 = self.smooth1(merged4)
        smoothed3 = self.smooth2(merged3)
        smoothed2 = self.smooth3(merged2)

        fused = self._upsample_cat(smoothed2, smoothed3, smoothed4, top)
        logits = self.out_conv(self.conv(fused))

        target_size = (height // self.scale, width // self.scale)
        return F.interpolate(logits, size=target_size, mode='bilinear', align_corners=True)

    def _upsample_add(self, x, y):
        """Resize ``x`` to the spatial size of ``y`` and add the two tensors."""
        resized = F.interpolate(x, size=y.size()[2:], mode='bilinear', align_corners=False)
        return resized + y

    def _upsample_cat(self, p2, p3, p4, p5):
        """Resize p3, p4 and p5 to the size of p2 and concatenate all four on dim 1."""
        target = tuple(p2.size()[2:])
        resized = [
            F.interpolate(level, size=target, mode='bilinear', align_corners=False)
            for level in (p3, p4, p5)
        ]
        return torch.cat([p2, *resized], dim=1)
84+
85+
if __name__ == '__main__':
    # Smoke test: push one dummy batch through the network on the CPU and
    # print the shape of the result.
    device = torch.device('cpu')
    x = torch.zeros(1, 3, 640, 640).to(device)

    # Build the model defined in this module. The previous call used
    # PAN(..., fpem_repeat=2), a class that is not imported by name here and a
    # keyword that PSENet.__init__ does not accept.
    model = PSENet(backbone='resnet18', pretrained=True).to(device)
    y = model(x)
    print(y.shape)
    # torch.save(model.state_dict(), 'PSENet.pth')

0 commit comments

Comments
 (0)