English | 中文
PSENet: Shape Robust Text Detection With Progressive Scale Expansion Network
PSENet是一种基于语义分割的文本检测算法。它可以精确定位具有任意形状的文本实例,而大多数基于anchor类的算法不能用来检测任意形状的文本实例。此外,两个彼此靠近的文本可能会导致模型做出错误的预测。因此,为了解决上述问题,PSENet还提出了一种渐进式尺度扩展算法(Progressive Scale Expansion Algorithm, PSE),利用该算法可以成功识别相邻的文本实例[1]。
图 1. PSENet整体架构图
PSENet的整体架构图如图1所示,包含以下阶段:
- 使用Resnet作为骨干网络,从2,3,4,5阶段进行不同层级的特征提取;
- 将提取到的特征放入FPN网络中,提取不同尺度的特征并拼接;
- 将第2阶段的特征采用PSE算法生成最后的分割结果,并生成文本边界框。
mindspore | ascend driver | firmware | cann toolkit/kernel |
---|---|---|---|
2.3.1 | 24.1.RC2 | 7.3.0.1.231 | 8.0.RC2.beta1 |
请参考MindOCR套件的安装指南 。
请从该网址下载ICDAR2015数据集,然后参考数据转换对数据集标注进行格式转换。
完成数据准备工作后,数据的目录结构应该如下所示:
.
├── test
│ ├── images
│ │ ├── img_1.jpg
│ │ ├── img_2.jpg
│ │ └── ...
│ └── test_det_gt.txt
└── train
├── images
│ ├── img_1.jpg
│ ├── img_2.jpg
│ └── ....jpg
└── train_det_gt.txt
请从该网址下载SCUT-CTW1500数据集,然后参考数据转换对数据集标注进行格式转换。
完成数据准备工作后,数据的目录结构应该如下所示:
ctw1500
├── test_images
│ ├── 1001.jpg
│ ├── 1002.jpg
│ ├── ...
├── train_images
│ ├── 0001.jpg
│ ├── 0002.jpg
│ ├── ...
├── test_det_gt.txt
├── train_det_gt.txt
在配置文件configs/det/psenet/pse_r152_icdar15.yaml
中更新如下文件路径。其中dataset_root
会分别和data_dir
以及label_file
拼接构成完整的数据集目录和标签文件路径。
...
train:
ckpt_save_dir: './tmp_det'
dataset_sink_mode: False
dataset:
type: DetDataset
dataset_root: dir/to/dataset <--- 更新
data_dir: train/images <--- 更新
label_file: train/train_det_gt.txt <--- 更新
...
eval:
dataset_sink_mode: False
dataset:
type: DetDataset
dataset_root: dir/to/dataset <--- 更新
data_dir: test/images <--- 更新
label_file: test/test_det_gt.txt <--- 更新
...
【可选】可以根据CPU核的数量设置
num_workers
参数的值。
PSENet由3个部分组成:backbone
、neck
和head
。具体来说:
model:
type: det
transform: null
backbone:
name: det_resnet152
pretrained: True # 是否使用ImageNet数据集上的预训练权重
neck:
name: PSEFPN # PSENet的特征金字塔网络
out_channels: 128
head:
name: PSEHead
hidden_size: 256
out_channels: 7 # kernels数量
- 后处理
训练前,请确保在/mindocr/postprocess/pse目录下按照以下方式编译后处理代码:
python3 setup.py build_ext --inplace
- 单卡训练
请确保yaml文件中的distribute
参数为False。
# train psenet on ic15 dataset
python tools/train.py --config configs/det/psenet/pse_r152_icdar15.yaml
- 分布式训练
请确保yaml文件中的distribute
参数为True。
# n is the number of NPUs
mpirun --allow-run-as-root -n 8 python tools/train.py --config configs/det/psenet/pse_r152_icdar15.yaml
训练结果(包括checkpoint、每个epoch的性能和曲线图)将被保存在yaml配置文件的ckpt_save_dir
参数配置的路径下,默认为./tmp_det
。
评估环节,在yaml配置文件中将ckpt_load_path
参数配置为checkpoint文件的路径,设置distribute
为False,然后运行:
python tools/eval.py --config configs/det/psenet/pse_r152_icdar15.yaml
请参考MindOCR 推理教程,基于MindSpore Lite在Ascend 310上进行模型的推理,包括以下步骤:
- 模型导出
请先下载已导出的MindIR文件,或者参考模型导出教程,使用以下命令将训练完成的ckpt导出为MindIR文件:
python tools/export.py --model_name_or_config psenet_resnet152 --data_shape 1472 2624 --local_ckpt_path /path/to/local_ckpt.ckpt
# or
python tools/export.py --model_name_or_config configs/det/psenet/pse_r152_icdar15.yaml --data_shape 1472 2624 --local_ckpt_path /path/to/local_ckpt.ckpt
其中,data_shape
是导出MindIR时的模型输入Shape的height和width,下载链接中MindIR对应的shape值见注释。
- 环境搭建
请参考环境安装教程,配置MindSpore Lite推理运行环境。
- 模型转换
请参考模型转换教程,使用converter_lite
工具对MindIR模型进行离线转换。
- 执行推理
在进行推理前,请确保PSENet的后处理部分已编译,参考训练的后处理部分。
假设在模型转换后得到output.mindir文件,在deploy/py_infer
目录下使用以下命令进行推理:
python infer.py \
--input_images_dir=/your_path_to/test_images \
--det_model_path=your_path_to/output.mindir \
--det_model_name_or_config=../../configs/det/psenet/pse_r152_icdar15.yaml \
--res_save_dir=results_dir
PSENet在ICDAR2015,SCUT-CTW1500数据集上训练。另外,我们在ImageNet数据集上进行了预训练,并提供预训练权重下载链接。所有训练结果如下:
model name | backbone | pretrained | cards | batch size | jit level | graph compile | ms/step | img/s | recall | precision | f-score | recipe | weight |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
PSENet | ResNet-152 | ImageNet | 8 | 8 | O2 | 225.02 s | 355.19 | 180.19 | 78.91% | 84.70% | 81.70% | yaml | ckpt | mindir |
PSENet | ResNet-50 | ImageNet | 1 | 8 | O2 | 185.16 s | 280.21 | 228.40 | 76.55% | 86.51% | 81.23% | yaml | ckpt | mindir |
PSENet | MobileNetV3 | ImageNet | 8 | 8 | O2 | 181.54 s | 175.23 | 365.23 | 73.95% | 67.78% | 70.73% | yaml | ckpt | mindir |
model name | backbone | pretrained | cards | batch size | jit level | graph compile | ms/step | img/s | recall | precision | f-score | recipe | weight |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
PSENet | ResNet-152 | ImageNet | 8 | 8 | O2 | 193.59 s | 318.94 | 200.66 | 74.11% | 73.45% | 73.78% | yaml | ckpt | mindir |
- PSENet的训练时长受数据处理部分超参和不同运行环境的影响非常大。
- 在ICDAR15数据集上,以ResNet-152为backbone的MindIR导出时的输入Shape为
(1,3,1472,2624)
,以ResNet-50或MobileNetV3为backbone的MindIR导出时的输入Shape为(1,3,736,1312)
。 - 在SCUT-CTW1500数据集上,MindIR导出时的输入Shape为
(1,3,1024,1024)
。
[1] Wang, Wenhai, et al. "Shape robust text detection with progressive scale expansion network." Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2019.