commit 1a9926a7fa: add_pdf2docx_api

README.md
@@ -27,11 +27,11 @@ PaddleOCR aims to create multilingual, awesome, leading, and practical OCR tools

## Recent updates

- **🔥2022.8.24 Release PaddleOCR [release/2.6](https://github.com/PaddlePaddle/PaddleOCR/tree/release/2.6)**
-  - Release [PP-Structurev2](./ppstructure/), with functions and performance fully upgraded, adapted to Chinese scenes, and new support for [Layout Recovery](./ppstructure/recovery) and **one line command to convert PDF to Word**;
+  - Release [PP-StructureV2](./ppstructure/), with functions and performance fully upgraded, adapted to Chinese scenes, and new support for [Layout Recovery](./ppstructure/recovery) and **one line command to convert PDF to Word** (see the Python sketch after this list);
  - [Layout Analysis](./ppstructure/layout) optimization: model storage reduced by 95%, while speed increased by 11 times, and the average CPU time-cost is only 41ms;
  - [Table Recognition](./ppstructure/table) optimization: 3 optimization strategies are designed, and the model accuracy is improved by 6% under comparable time consumption;
  - [Key Information Extraction](./ppstructure/kie) optimization: a visual-independent model structure is designed; the accuracy of semantic entity recognition is increased by 2.8%, and the accuracy of relation extraction is increased by 9.1%.

- **🔥2022.7 Release [OCR scene application collection](./applications/README_en.md)**
  - Release **9 vertical models** such as digital tube, LCD screen, license plate, handwriting recognition model, high-precision SVTR model, etc., covering the main OCR vertical applications in the general, manufacturing, finance, and transportation industries.
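The one-line PDF-to-Word conversion called out above also has a Python path; a minimal sketch, assuming the `PPStructure` recovery API and the helper names documented under `ppstructure/recovery` (import paths and argument order are assumptions, not guaranteed by this commit):

```python
# Hedged sketch: PP-StructureV2 layout recovery to .docx. Helper names,
# import paths, and the sample image path are assumptions for illustration.
# CLI equivalent (assumed): paddleocr --image_dir=doc.pdf --type=structure --recovery=true
import os
import cv2
from paddleocr import PPStructure
from ppstructure.recovery.recovery_to_doc import sorted_layout_boxes, convert_info_docx

engine = PPStructure(recovery=True)            # enable layout recovery
img_path = "ppstructure/docs/table/1.png"      # assumed sample image
img = cv2.imread(img_path)
result = engine(img)
h, w, _ = img.shape
result = sorted_layout_boxes(result, w)        # order regions in reading flow
img_name = os.path.basename(img_path).split(".")[0]
convert_info_docx(img, result, "./output", img_name)  # writes ./output/<name>.docx
```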
@@ -129,7 +129,7 @@ PaddleOCR support a variety of cutting-edge algorithms related to OCR, and developed
- [Text recognition](./doc/doc_en/algorithm_overview_en.md)
- [End-to-end OCR](./doc/doc_en/algorithm_overview_en.md)
- [Table Recognition](./doc/doc_en/algorithm_overview_en.md)
- [Key Information Extraction](./doc/doc_en/algorithm_overview_en.md)
- [Add New Algorithms to PaddleOCR](./doc/doc_en/add_new_algorithm_en.md)
- Data Annotation and Synthesis
- [Semi-automatic Annotation Tool: PPOCRLabel](./PPOCRLabel/README.md)
@@ -181,7 +181,7 @@ PaddleOCR support a variety of cutting-edge algorithms related to OCR, and developed
</details>

<details open>
-<summary>PP-Structurev2</summary>
+<summary>PP-StructureV2</summary>

- layout analysis + table recognition
<div align="center">
@@ -192,7 +192,7 @@ PaddleOCR support a variety of cutting-edge algorithms related to OCR, and developed
<div align="center">
<img src="https://user-images.githubusercontent.com/25809855/186094456-01a1dd11-1433-4437-9ab2-6480ac94ec0a.png" width="600">
</div>

<div align="center">
<img src="https://user-images.githubusercontent.com/14270174/185310636-6ce02f7c-790d-479f-b163-ea97a5a04808.jpg" width="600">
</div>
@@ -204,7 +204,7 @@ PaddleOCR support a variety of cutting-edge algorithms related to OCR, and developed
- RE (Relation Extraction)
<div align="center">
<img src="https://user-images.githubusercontent.com/25809855/186094813-3a8e16cc-42e5-4982-b9f4-0134dfb5688d.png" width="600">
</div>
</div>

<div align="center">
<img src="https://user-images.githubusercontent.com/14270174/185393805-c67ff571-cf7e-4217-a4b0-8b396c4f22bb.jpg" width="600">
README_ch.md

@@ -28,14 +28,14 @@ PaddleOCR aims to build a rich, leading, and practical OCR toolkit to help
## Recent updates

- **🔥2022.8.24 Release PaddleOCR [release/2.6](https://github.com/PaddlePaddle/PaddleOCR/tree/release/2.6)**
-  - Release [PP-Structurev2](./ppstructure/): system functions and performance fully upgraded, adapted to Chinese scenes, with new support for [Layout Recovery](./ppstructure/recovery) and **one-line command to convert PDF to Word**;
+  - Release [PP-StructureV2](./ppstructure/): system functions and performance fully upgraded, adapted to Chinese scenes, with new support for [Layout Recovery](./ppstructure/recovery) and **one-line command to convert PDF to Word**;
  - [Layout Analysis](./ppstructure/layout) model optimization: model storage reduced by 95%, speed improved 11x, average CPU time-cost only 41ms;
  - [Table Recognition](./ppstructure/table) model optimization: 3 major optimization strategies designed; model accuracy improved by 6% with unchanged prediction time;
  - [Key Information Extraction](./ppstructure/kie) model optimization: visual-independent model structure designed; semantic entity recognition accuracy improved by 2.8% and relation extraction accuracy by 9.1%.

- **🔥2022.8 Release [OCR scene application collection](./applications)**
  - Contains **9 vertical models** including digital tube, LCD screen, license plate, high-precision SVTR, and handwriting recognition, covering the main OCR vertical applications in the general, manufacturing, finance, and transportation industries.

- **2022.5.9 Release PaddleOCR [release/2.5](https://github.com/PaddlePaddle/PaddleOCR/tree/release/2.5)**
  - Release [PP-OCRv3](./doc/doc_ch/ppocr_introduction.md#pp-ocrv3): with comparable speed, accuracy in Chinese scenes improves by a further 5% over PP-OCRv2, English scenes by 11%, and the average recognition accuracy of the 80-language multilingual models by more than 5%;
  - Release the semi-automatic annotation tool [PPOCRLabelv2](./PPOCRLabel): new annotation functions for table text images, image key information extraction tasks, and irregular text images;
@@ -220,11 +220,11 @@ PaddleOCR aims to build a rich, leading, and practical OCR toolkit to help
<div align="center">
<img src="https://user-images.githubusercontent.com/14270174/185539517-ccf2372a-f026-4a7c-ad28-c741c770f60a.png" width="600">
</div>

<div align="center">
<img src="https://user-images.githubusercontent.com/25809855/186094456-01a1dd11-1433-4437-9ab2-6480ac94ec0a.png" width="600">
</div>

- RE (Relation Extraction)
<div align="center">
<img src="https://user-images.githubusercontent.com/14270174/185393805-c67ff571-cf7e-4217-a4b0-8b396c4f22bb.jpg" width="600">

@@ -237,7 +237,7 @@ PaddleOCR aims to build a rich, leading, and practical OCR toolkit to help
<div align="center">
<img src="https://user-images.githubusercontent.com/25809855/186094813-3a8e16cc-42e5-4982-b9f4-0134dfb5688d.png" width="600">
</div>

</details>

<a name="许可证书"></a>
@@ -54,6 +54,7 @@ PostProcess:
  box_thresh: 0.6
  max_candidates: 1000
  unclip_ratio: 1.5
+ det_box_type: 'quad' # 'quad' or 'poly'

Metric:
  name: DetMetric
  main_indicator: hmean

@@ -54,6 +54,7 @@ PostProcess:
  box_thresh: 0.5
  max_candidates: 1000
  unclip_ratio: 1.5
+ det_box_type: 'quad' # 'quad' or 'poly'

Metric:
  name: DetMetric
  main_indicator: hmean
@@ -0,0 +1,133 @@
Global:
  use_gpu: true
  epoch_num: 1200
  log_smooth_window: 20
  print_batch_step: 5
  save_model_dir: ./output/det_r50_drrg_ctw/
  save_epoch_step: 100
  # evaluation is run every 1260 iterations
  eval_batch_step: [37800, 1260]
  cal_metric_during_train: False
  pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained.pdparams
  checkpoints:
  save_inference_dir:
  use_visualdl: False
  infer_img: doc/imgs_en/img_10.jpg
  save_res_path: ./output/det_drrg/predicts_drrg.txt

Architecture:
  model_type: det
  algorithm: DRRG
  Transform:
  Backbone:
    name: ResNet_vd
    layers: 50
  Neck:
    name: FPN_UNet
    in_channels: [256, 512, 1024, 2048]
    out_channels: 32
  Head:
    name: DRRGHead
    in_channels: 32
    text_region_thr: 0.3
    center_region_thr: 0.4

Loss:
  name: DRRGLoss

Optimizer:
  name: Momentum
  momentum: 0.9
  lr:
    name: DecayLearningRate
    learning_rate: 0.028
    epochs: 1200
    factor: 0.9
    end_lr: 0.0000001
  weight_decay: 0.0001

PostProcess:
  name: DRRGPostprocess
  link_thr: 0.8

Metric:
  name: DetFCEMetric
  main_indicator: hmean

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/ctw1500/imgs/
    label_file_list:
      - ./train_data/ctw1500/imgs/training.txt
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
          ignore_orientation: True
      - DetLabelEncode: # Class handling label
      - ColorJitter:
          brightness: 0.12549019607843137
          saturation: 0.5
      - RandomScaling:
      - RandomCropFlip:
          crop_ratio: 0.5
      - RandomCropPolyInstances:
          crop_ratio: 0.8
          min_side_ratio: 0.3
      - RandomRotatePolyInstances:
          rotate_ratio: 0.5
          max_angle: 60
          pad_with_fixed_color: False
      - SquareResizePad:
          target_size: 800
          pad_ratio: 0.6
      - IaaAugment:
          augmenter_args:
            - { 'type': Fliplr, 'args': { 'p': 0.5 } }
      - DRRGTargets:
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: ['image', 'gt_text_mask', 'gt_center_region_mask', 'gt_mask',
                      'gt_top_height_map', 'gt_bot_height_map', 'gt_sin_map',
                      'gt_cos_map', 'gt_comp_attribs'] # dataloader will return list in this order
  loader:
    shuffle: True
    drop_last: False
    batch_size_per_card: 4
    num_workers: 8

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/ctw1500/imgs/
    label_file_list:
      - ./train_data/ctw1500/imgs/test.txt
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
          ignore_orientation: True
      - DetLabelEncode: # Class handling label
      - DetResizeForTest:
          limit_type: 'min'
          limit_side_len: 640
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - Pad:
      - ToCHWImage:
      - KeepKeys:
          keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 1 # must be 1
    num_workers: 2
@@ -70,16 +70,14 @@ Loss:
      mode: "l2"
      model_name_pairs:
      - ["Student", "Teacher"]
-     key: hidden_states
-     index: 5
+     key: hidden_states_5
      name: "loss_5"
  - DistillationVQADistanceLoss:
      weight: 0.5
      mode: "l2"
      model_name_pairs:
      - ["Student", "Teacher"]
-     key: hidden_states
-     index: 8
+     key: hidden_states_8
      name: "loss_8"

@@ -182,4 +180,3 @@ Eval:
    drop_last: False
    batch_size_per_card: 8
    num_workers: 4
@@ -0,0 +1,78 @@
# DRRG

- [1. Introduction](#1)
- [2. Environment](#2)
- [3. Model Training / Evaluation / Prediction](#3)
- [4. Inference and Deployment](#4)
  - [4.1 Python Inference](#4-1)
  - [4.2 C++ Inference](#4-2)
  - [4.3 Serving](#4-3)
  - [4.4 More](#4-4)
- [5. FAQ](#5)
- [Citation](#citation)

<a name="1"></a>
## 1. Introduction

Paper:
> [Deep Relational Reasoning Graph Network for Arbitrary Shape Text Detection](https://arxiv.org/abs/2003.07493)
> Zhang, Shi-Xue and Zhu, Xiaobin and Hou, Jie-Bo and Liu, Chang and Yang, Chun and Wang, Hongfa and Yin, Xu-Cheng
> CVPR, 2020

On the CTW1500 public text detection dataset, the reproduced results are as follows:

| Model | Backbone | Configuration | Precision | Recall | Hmean | Download |
| --- | --- | --- | --- | --- | --- | --- |
| DRRG | ResNet50_vd | [configs/det/det_r50_drrg_ctw.yml](../../configs/det/det_r50_drrg_ctw.yml) | 89.92% | 80.91% | 85.18% | [trained model](https://paddleocr.bj.bcebos.com/contribution/det_r50_drrg_ctw.tar) |

<a name="2"></a>
## 2. Environment

Please refer to [Environment Preparation](./environment.md) to set up the PaddleOCR environment, and to [Project Clone](./clone.md) to clone the project code.

<a name="3"></a>
## 3. Model Training / Evaluation / Prediction

The DRRG model above is trained on the CTW1500 public text detection dataset; see [ocr_datasets](./dataset/ocr_datasets.md) for the download.

After the data is downloaded, please follow the [text detection training tutorial](./detection.md) for training. PaddleOCR's code is modularized: training a different detection model only requires **switching the configuration file**, as in the sketch below.
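A minimal launch sketch, assuming the repository root as the working directory (the entry point and `-c`/`-o` flags follow PaddleOCR's standard training tooling; the output-directory override is illustrative):

```python
# Hedged sketch: start DRRG training with the config added in this commit.
import subprocess

subprocess.run(
    ["python3", "tools/train.py",
     "-c", "configs/det/det_r50_drrg_ctw.yml",
     "-o", "Global.save_model_dir=./output/det_r50_drrg_ctw/"],
    check=True)
```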
<a name="4"></a>
## 4. Inference and Deployment

<a name="4-1"></a>
### 4.1 Python Inference

Since the forward pass repeatedly converts tensors to NumPy arrays for computation, exporting DRRG from dynamic graph to static graph is not supported yet.

<a name="4-2"></a>
### 4.2 C++ Inference

Not supported yet.

<a name="4-3"></a>
### 4.3 Serving

Not supported yet.

<a name="4-4"></a>
### 4.4 More

Not supported yet.

<a name="5"></a>
## 5. FAQ

## Citation

```bibtex
@inproceedings{zhang2020deep,
  title={Deep relational reasoning graph network for arbitrary shape text detection},
  author={Zhang, Shi-Xue and Zhu, Xiaobin and Hou, Jie-Bo and Liu, Chang and Yang, Chun and Wang, Hongfa and Yin, Xu-Cheng},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={9699--9708},
  year={2020}
}
```
@@ -29,6 +29,7 @@ PaddleOCR will **continuously add** support for cutting-edge OCR algorithms and models; contributions welcome
- [x] [SAST](./algorithm_det_sast.md)
- [x] [PSENet](./algorithm_det_psenet.md)
- [x] [FCENet](./algorithm_det_fcenet.md)
+- [x] [DRRG](./algorithm_det_drrg.md)

On the ICDAR2015 public text detection dataset, the results are as follows:

@@ -54,6 +55,7 @@ PaddleOCR will **continuously add** support for cutting-edge OCR algorithms and models; contributions welcome
| Model | Backbone | Precision | Recall | Hmean | Download |
| --- | --- | --- | --- | --- | --- |
| FCE | ResNet50_dcn | 88.39% | 82.18% | 85.27% | [trained model](https://paddleocr.bj.bcebos.com/contribution/det_r50_dcn_fce_ctw_v2.0_train.tar) |
+| DRRG | ResNet50_vd | 89.92% | 80.91% | 85.18% | [trained model](https://paddleocr.bj.bcebos.com/contribution/det_r50_drrg_ctw.tar) |

**Note:** SAST training additionally used public datasets such as icdar2013, icdar2017, COCO-Text, and ArT for tuning. Download the English public datasets organized into the format used by PaddleOCR:
* [Baidu Cloud](https://pan.baidu.com/s/12cPnZcVuV1zn5DOd4mqjVw) (extraction code: 2bpi)
@@ -103,7 +105,7 @@ PaddleOCR will **continuously add** support for cutting-edge OCR algorithms and models; contributions welcome
| VisionLAN | Resnet45 | 90.30% | rec_r45_visionlan | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar) |
| SPIN | ResNet32 | 90.00% | rec_r32_gaspin_bilstm_att | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r32_gaspin_bilstm_att.tar) |
| RobustScanner | ResNet31 | 87.77% | rec_r31_robustscanner | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r31_robustscanner.tar) |
-| RFL | ResNetRFL | 88.63% | rec_resnet_rfl_att | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl.tar) |
+| RFL | ResNetRFL | 88.63% | rec_resnet_rfl_att | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl_att_train.tar) |

<a name="2"></a>
@@ -27,8 +27,8 @@

| Model | Backbone | Config | Acc | Download |
| --- | --- | --- | --- | --- |
-| RFL-CNT | ResNetRFL | [rec_resnet_rfl_visual.yml](../../configs/rec/rec_resnet_rfl_visual.yml) | 93.40% | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl.tar) |
-| RFL-Att | ResNetRFL | [rec_resnet_rfl_att.yml](../../configs/rec/rec_resnet_rfl_att.yml) | 88.63% | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl.tar) |
+| RFL-CNT | ResNetRFL | [rec_resnet_rfl_visual.yml](../../configs/rec/rec_resnet_rfl_visual.yml) | 93.40% | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl_visual_train.tar) |
+| RFL-Att | ResNetRFL | [rec_resnet_rfl_att.yml](../../configs/rec/rec_resnet_rfl_att.yml) | 88.63% | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl_att_train.tar) |

<a name="2"></a>
## 2. Environment
@@ -0,0 +1,79 @@
# DRRG

- [1. Introduction](#1)
- [2. Environment](#2)
- [3. Model Training / Evaluation / Prediction](#3)
  - [3.1 Training](#3-1)
  - [3.2 Evaluation](#3-2)
  - [3.3 Prediction](#3-3)
- [4. Inference and Deployment](#4)
  - [4.1 Python Inference](#4-1)
  - [4.2 C++ Inference](#4-2)
  - [4.3 Serving](#4-3)
  - [4.4 More](#4-4)
- [5. FAQ](#5)

<a name="1"></a>
## 1. Introduction

Paper:
> [Deep Relational Reasoning Graph Network for Arbitrary Shape Text Detection](https://arxiv.org/abs/2003.07493)
> Zhang, Shi-Xue and Zhu, Xiaobin and Hou, Jie-Bo and Liu, Chang and Yang, Chun and Wang, Hongfa and Yin, Xu-Cheng
> CVPR, 2020

On the CTW1500 dataset, the text detection result is as follows:

| Model | Backbone | Configuration | Precision | Recall | Hmean | Download |
| --- | --- | --- | --- | --- | --- | --- |
| DRRG | ResNet50_vd | [configs/det/det_r50_drrg_ctw.yml](../../configs/det/det_r50_drrg_ctw.yml) | 89.92% | 80.91% | 85.18% | [trained model](https://paddleocr.bj.bcebos.com/contribution/det_r50_drrg_ctw.tar) |

<a name="2"></a>
## 2. Environment

Please prepare your environment referring to [prepare the environment](./environment_en.md) and [clone the repo](./clone_en.md).

<a name="3"></a>
## 3. Model Training / Evaluation / Prediction

The above DRRG model is trained using the CTW1500 public text detection dataset; for the dataset download, please refer to [ocr_datasets](./dataset/ocr_datasets_en.md).

After the data download is complete, please refer to the [text detection training tutorial](./detection_en.md) for training. PaddleOCR's code is modularized, so training a different detection model only requires **replacing the configuration file**; an evaluation sketch follows.
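A minimal evaluation sketch, assuming the repository root as the working directory (the `tools/eval.py` entry point is PaddleOCR's standard one; the checkpoint path is an assumption for illustration):

```python
# Hedged sketch: evaluate a trained DRRG checkpoint on the CTW1500 test set.
import subprocess

subprocess.run(
    ["python3", "tools/eval.py",
     "-c", "configs/det/det_r50_drrg_ctw.yml",
     "-o", "Global.checkpoints=./output/det_r50_drrg_ctw/best_accuracy"],
    check=True)
```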
<a name="4"></a>
## 4. Inference and Deployment

<a name="4-1"></a>
### 4.1 Python Inference

Since the forward pass repeatedly converts tensors to NumPy arrays for computation, exporting DRRG from dynamic graph to static graph is not supported.

<a name="4-2"></a>
### 4.2 C++ Inference

Not supported

<a name="4-3"></a>
### 4.3 Serving

Not supported

<a name="4-4"></a>
### 4.4 More

Not supported

<a name="5"></a>
## 5. FAQ

## Citation

```bibtex
@inproceedings{zhang2020deep,
  title={Deep relational reasoning graph network for arbitrary shape text detection},
  author={Zhang, Shi-Xue and Zhu, Xiaobin and Hou, Jie-Bo and Liu, Chang and Yang, Chun and Wang, Hongfa and Yin, Xu-Cheng},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={9699--9708},
  year={2020}
}
```
@@ -27,6 +27,7 @@ Supported text detection algorithms (Click the link to get the tutorial):
- [x] [SAST](./algorithm_det_sast_en.md)
- [x] [PSENet](./algorithm_det_psenet_en.md)
- [x] [FCENet](./algorithm_det_fcenet_en.md)
+- [x] [DRRG](./algorithm_det_drrg_en.md)

On the ICDAR2015 dataset, the text detection result is as follows:

@@ -52,6 +53,7 @@ On CTW1500 dataset, the text detection result is as follows:
| Model | Backbone | Precision | Recall | Hmean | Download link |
| --- | --- | --- | --- | --- | --- |
| FCE | ResNet50_dcn | 88.39% | 82.18% | 85.27% | [trained model](https://paddleocr.bj.bcebos.com/contribution/det_r50_dcn_fce_ctw_v2.0_train.tar) |
+| DRRG | ResNet50_vd | 89.92% | 80.91% | 85.18% | [trained model](https://paddleocr.bj.bcebos.com/contribution/det_r50_drrg_ctw.tar) |

**Note:** Additional data, like icdar2013, icdar2017, COCO-Text, and ArT, was added to the model training of SAST. Download the English public datasets in the organized format used by PaddleOCR from:
* [Baidu Drive](https://pan.baidu.com/s/12cPnZcVuV1zn5DOd4mqjVw) (download code: 2bpi).
@@ -100,7 +102,7 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result
| VisionLAN | Resnet45 | 90.30% | rec_r45_visionlan | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar) |
| SPIN | ResNet32 | 90.00% | rec_r32_gaspin_bilstm_att | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r32_gaspin_bilstm_att.tar) |
| RobustScanner | ResNet31 | 87.77% | rec_r31_robustscanner | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r31_robustscanner.tar) |
-| RFL | ResNetRFL | 88.63% | rec_resnet_rfl_att | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl.tar) |
+| RFL | ResNetRFL | 88.63% | rec_resnet_rfl_att | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl_att_train.tar) |

<a name="2"></a>
@@ -25,8 +25,8 @@ Using MJSynth and SynthText two text recognition datasets for training, and evaluating
| Model | Backbone | Config | Acc | Download link |
| --- | --- | --- | --- | --- |
-| RFL-CNT | ResNetRFL | [rec_resnet_rfl_visual.yml](../../configs/rec/rec_resnet_rfl_visual.yml) | 93.40% | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl.tar) |
-| RFL-Att | ResNetRFL | [rec_resnet_rfl_att.yml](../../configs/rec/rec_resnet_rfl_att.yml) | 88.63% | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl.tar) |
+| RFL-CNT | ResNetRFL | [rec_resnet_rfl_visual.yml](../../configs/rec/rec_resnet_rfl_visual.yml) | 93.40% | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl_visual_train.tar) |
+| RFL-Att | ResNetRFL | [rec_resnet_rfl_att.yml](../../configs/rec/rec_resnet_rfl_att.yml) | 88.63% | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl_att_train.tar) |

<a name="2"></a>
## 2. Environment
@@ -85,9 +85,9 @@ For English recognition model inference, you can execute the following commands,

```
# download en model:
-wget https://paddleocr.bj.bcebos.com/PP-OCRv3/english/en_PP-OCRv3_det_infer.tar
-tar xf en_PP-OCRv3_det_infer.tar
-python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./en_PP-OCRv3_det_infer/" --rec_char_dict_path="ppocr/utils/en_dict.txt"
+wget https://paddleocr.bj.bcebos.com/PP-OCRv3/english/en_PP-OCRv3_rec_infer.tar
+tar xf en_PP-OCRv3_rec_infer.tar
+python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./en_PP-OCRv3_rec_infer/" --rec_char_dict_path="ppocr/utils/en_dict.txt"
```


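The same recognition step can also be driven from the `paddleocr` Python package; a minimal sketch, assuming the package is installed with `pip install paddleocr` (the exact result nesting varies by version):

```python
# Hedged sketch mirroring the corrected CLI above with the paddleocr package.
from paddleocr import PaddleOCR

ocr = PaddleOCR(lang="en", use_angle_cls=False)  # fetches the en rec model on first run
result = ocr.ocr("./doc/imgs_words/en/word_1.png", det=False)
print(result)  # (text, score) pairs, e.g. something like [('JOINT', 0.98...)]
```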
@@ -45,6 +45,7 @@ from .vqa import *
from .fce_aug import *
from .fce_targets import FCENetTargets
from .ct_process import *
+from .drrg_targets import DRRGTargets


def transform(data, ops=None):
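For context, operator names in a config's `transforms` list are resolved against the classes exported in this module, which is why the import above makes `- DRRGTargets:` usable in YAML configs. A small sketch of that convention (the parameter dicts are illustrative):

```python
# Hedged sketch: the imaug registry instantiates operators by name.
from ppocr.data.imaug import create_operators

ops = create_operators([
    {"DecodeImage": {"img_mode": "BGR", "channel_first": False}},
    {"DRRGTargets": None},  # None means default arguments, i.e. DRRGTargets()
])
print([type(op).__name__ for op in ops])  # ['DecodeImage', 'DRRGTargets']
```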
@@ -0,0 +1,696 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is referred from:
https://github.com/open-mmlab/mmocr/blob/main/mmocr/datasets/pipelines/textdet_targets/drrg_targets.py
"""

import cv2
import numpy as np
from lanms import merge_quadrangle_n9 as la_nms
from numpy.linalg import norm


class DRRGTargets(object):
    def __init__(self,
                 orientation_thr=2.0,
                 resample_step=8.0,
                 num_min_comps=9,
                 num_max_comps=600,
                 min_width=8.0,
                 max_width=24.0,
                 center_region_shrink_ratio=0.3,
                 comp_shrink_ratio=1.0,
                 comp_w_h_ratio=0.3,
                 text_comp_nms_thr=0.25,
                 min_rand_half_height=8.0,
                 max_rand_half_height=24.0,
                 jitter_level=0.2,
                 **kwargs):

        super().__init__()
        self.orientation_thr = orientation_thr
        self.resample_step = resample_step
        self.num_max_comps = num_max_comps
        self.num_min_comps = num_min_comps
        self.min_width = min_width
        self.max_width = max_width
        self.center_region_shrink_ratio = center_region_shrink_ratio
        self.comp_shrink_ratio = comp_shrink_ratio
        self.comp_w_h_ratio = comp_w_h_ratio
        self.text_comp_nms_thr = text_comp_nms_thr
        self.min_rand_half_height = min_rand_half_height
        self.max_rand_half_height = max_rand_half_height
        self.jitter_level = jitter_level
        self.eps = 1e-8

    def vector_angle(self, vec1, vec2):
        if vec1.ndim > 1:
            unit_vec1 = vec1 / (norm(vec1, axis=-1) + self.eps).reshape((-1, 1))
        else:
            unit_vec1 = vec1 / (norm(vec1, axis=-1) + self.eps)
        if vec2.ndim > 1:
            unit_vec2 = vec2 / (norm(vec2, axis=-1) + self.eps).reshape((-1, 1))
        else:
            unit_vec2 = vec2 / (norm(vec2, axis=-1) + self.eps)
        return np.arccos(
            np.clip(
                np.sum(unit_vec1 * unit_vec2, axis=-1), -1.0, 1.0))

    def vector_slope(self, vec):
        assert len(vec) == 2
        return abs(vec[1] / (vec[0] + self.eps))

    def vector_sin(self, vec):
        assert len(vec) == 2
        return vec[1] / (norm(vec) + self.eps)

    def vector_cos(self, vec):
        assert len(vec) == 2
        return vec[0] / (norm(vec) + self.eps)

    def find_head_tail(self, points, orientation_thr):
        """Find the head edge and tail edge (the two short ends) of a text
        polygon; returns their vertex index pairs."""

        assert points.ndim == 2
        assert points.shape[0] >= 4
        assert points.shape[1] == 2
        assert isinstance(orientation_thr, float)

        if len(points) > 4:
            pad_points = np.vstack([points, points[0]])
            edge_vec = pad_points[1:] - pad_points[:-1]

            theta_sum = []
            adjacent_vec_theta = []
            for i, edge_vec1 in enumerate(edge_vec):
                adjacent_ind = [x % len(edge_vec) for x in [i - 1, i + 1]]
                adjacent_edge_vec = edge_vec[adjacent_ind]
                temp_theta_sum = np.sum(
                    self.vector_angle(edge_vec1, adjacent_edge_vec))
                temp_adjacent_theta = self.vector_angle(adjacent_edge_vec[0],
                                                        adjacent_edge_vec[1])
                theta_sum.append(temp_theta_sum)
                adjacent_vec_theta.append(temp_adjacent_theta)
            theta_sum_score = np.array(theta_sum) / np.pi
            adjacent_theta_score = np.array(adjacent_vec_theta) / np.pi
            poly_center = np.mean(points, axis=0)
            edge_dist = np.maximum(
                norm(
                    pad_points[1:] - poly_center, axis=-1),
                norm(
                    pad_points[:-1] - poly_center, axis=-1))
            dist_score = edge_dist / (np.max(edge_dist) + self.eps)
            position_score = np.zeros(len(edge_vec))
            score = 0.5 * theta_sum_score + 0.15 * adjacent_theta_score
            score += 0.35 * dist_score
            if len(points) % 2 == 0:
                position_score[(len(score) // 2 - 1)] += 1
                position_score[-1] += 1
            score += 0.1 * position_score
            pad_score = np.concatenate([score, score])
            score_matrix = np.zeros((len(score), len(score) - 3))
            x = np.arange(len(score) - 3) / float(len(score) - 4)
            gaussian = 1. / (np.sqrt(2. * np.pi) * 0.5) * np.exp(-np.power(
                (x - 0.5) / 0.5, 2.) / 2)
            gaussian = gaussian / np.max(gaussian)
            for i in range(len(score)):
                score_matrix[i, :] = score[i] + pad_score[(i + 2):(i + len(
                    score) - 1)] * gaussian * 0.3

            head_start, tail_increment = np.unravel_index(score_matrix.argmax(),
                                                          score_matrix.shape)
            tail_start = (head_start + tail_increment + 2) % len(points)
            head_end = (head_start + 1) % len(points)
            tail_end = (tail_start + 1) % len(points)

            if head_end > tail_end:
                head_start, tail_start = tail_start, head_start
                head_end, tail_end = tail_end, head_end
            head_inds = [head_start, head_end]
            tail_inds = [tail_start, tail_end]
        else:
            if self.vector_slope(points[1] - points[0]) + self.vector_slope(
                    points[3] - points[2]) < self.vector_slope(points[
                        2] - points[1]) + self.vector_slope(points[0] - points[
                            3]):
                horizontal_edge_inds = [[0, 1], [2, 3]]
                vertical_edge_inds = [[3, 0], [1, 2]]
            else:
                horizontal_edge_inds = [[3, 0], [1, 2]]
                vertical_edge_inds = [[0, 1], [2, 3]]

            vertical_len_sum = norm(points[vertical_edge_inds[0][0]] - points[
                vertical_edge_inds[0][1]]) + norm(points[vertical_edge_inds[1][
                    0]] - points[vertical_edge_inds[1][1]])
            horizontal_len_sum = norm(points[horizontal_edge_inds[0][
                0]] - points[horizontal_edge_inds[0][1]]) + norm(points[
                    horizontal_edge_inds[1][0]] - points[horizontal_edge_inds[1]
                                                         [1]])

            if vertical_len_sum > horizontal_len_sum * orientation_thr:
                head_inds = horizontal_edge_inds[0]
                tail_inds = horizontal_edge_inds[1]
            else:
                head_inds = vertical_edge_inds[0]
                tail_inds = vertical_edge_inds[1]

        return head_inds, tail_inds

    def reorder_poly_edge(self, points):
        """Split the polygon boundary into head edge, tail edge, top sideline
        and bottom sideline."""

        assert points.ndim == 2
        assert points.shape[0] >= 4
        assert points.shape[1] == 2

        head_inds, tail_inds = self.find_head_tail(points, self.orientation_thr)
        head_edge, tail_edge = points[head_inds], points[tail_inds]

        pad_points = np.vstack([points, points])
        if tail_inds[1] < 1:
            tail_inds[1] = len(points)
        sideline1 = pad_points[head_inds[1]:tail_inds[1]]
        sideline2 = pad_points[tail_inds[1]:(head_inds[1] + len(points))]
        sideline_mean_shift = np.mean(
            sideline1, axis=0) - np.mean(
                sideline2, axis=0)

        if sideline_mean_shift[1] > 0:
            top_sideline, bot_sideline = sideline2, sideline1
        else:
            top_sideline, bot_sideline = sideline1, sideline2

        return head_edge, tail_edge, top_sideline, bot_sideline

    def cal_curve_length(self, line):
        """Return the per-edge lengths of a polyline and their sum."""

        assert line.ndim == 2
        assert len(line) >= 2

        edges_length = np.sqrt((line[1:, 0] - line[:-1, 0])**2 + (line[
            1:, 1] - line[:-1, 1])**2)
        total_length = np.sum(edges_length)
        return edges_length, total_length

    def resample_line(self, line, n):
        """Resample a polyline to n roughly equidistant points by linear
        interpolation along its arc length."""

        assert line.ndim == 2
        assert line.shape[0] >= 2
        assert line.shape[1] == 2
        assert isinstance(n, int)
        assert n > 2

        edges_length, total_length = self.cal_curve_length(line)
        t_org = np.insert(np.cumsum(edges_length), 0, 0)
        unit_t = total_length / (n - 1)
        t_equidistant = np.arange(1, n - 1, dtype=np.float32) * unit_t
        edge_ind = 0
        points = [line[0]]
        for t in t_equidistant:
            while edge_ind < len(edges_length) - 1 and t > t_org[edge_ind + 1]:
                edge_ind += 1
            t_l, t_r = t_org[edge_ind], t_org[edge_ind + 1]
            weight = np.array(
                [t_r - t, t - t_l], dtype=np.float32) / (t_r - t_l + self.eps)
            p_coords = np.dot(weight, line[[edge_ind, edge_ind + 1]])
            points.append(p_coords)
        points.append(line[-1])
        resampled_line = np.vstack(points)

        return resampled_line

    def resample_sidelines(self, sideline1, sideline2, resample_step):
        """Resample both sidelines to the same number of points, spaced about
        resample_step pixels apart."""

        assert sideline1.ndim == sideline2.ndim == 2
        assert sideline1.shape[1] == sideline2.shape[1] == 2
        assert sideline1.shape[0] >= 2
        assert sideline2.shape[0] >= 2
        assert isinstance(resample_step, float)

        _, length1 = self.cal_curve_length(sideline1)
        _, length2 = self.cal_curve_length(sideline2)

        avg_length = (length1 + length2) / 2
        resample_point_num = max(int(float(avg_length) / resample_step) + 1, 3)

        resampled_line1 = self.resample_line(sideline1, resample_point_num)
        resampled_line2 = self.resample_line(sideline2, resample_point_num)

        return resampled_line1, resampled_line2

    def dist_point2line(self, point, line):
        """Perpendicular distance from point(s) to the line through the two
        endpoints in `line`."""

        assert isinstance(line, tuple)
        point1, point2 = line
        d = abs(np.cross(point2 - point1, point - point1)) / (
            norm(point2 - point1) + 1e-8)
        return d

    def draw_center_region_maps(self, top_line, bot_line, center_line,
                                center_region_mask, top_height_map,
                                bot_height_map, sin_map, cos_map,
                                region_shrink_ratio):
        """Rasterize one instance's shrunk center region and fill its sin/cos
        and top/bottom height maps segment by segment."""

        assert top_line.shape == bot_line.shape == center_line.shape
        assert (center_region_mask.shape == top_height_map.shape ==
                bot_height_map.shape == sin_map.shape == cos_map.shape)
        assert isinstance(region_shrink_ratio, float)

        h, w = center_region_mask.shape
        for i in range(0, len(center_line) - 1):

            top_mid_point = (top_line[i] + top_line[i + 1]) / 2
            bot_mid_point = (bot_line[i] + bot_line[i + 1]) / 2

            sin_theta = self.vector_sin(top_mid_point - bot_mid_point)
            cos_theta = self.vector_cos(top_mid_point - bot_mid_point)

            tl = center_line[i] + (top_line[i] - center_line[i]
                                   ) * region_shrink_ratio
            tr = center_line[i + 1] + (top_line[i + 1] - center_line[i + 1]
                                       ) * region_shrink_ratio
            br = center_line[i + 1] + (bot_line[i + 1] - center_line[i + 1]
                                       ) * region_shrink_ratio
            bl = center_line[i] + (bot_line[i] - center_line[i]
                                   ) * region_shrink_ratio
            current_center_box = np.vstack([tl, tr, br, bl]).astype(np.int32)

            cv2.fillPoly(center_region_mask, [current_center_box], color=1)
            cv2.fillPoly(sin_map, [current_center_box], color=sin_theta)
            cv2.fillPoly(cos_map, [current_center_box], color=cos_theta)

            current_center_box[:, 0] = np.clip(current_center_box[:, 0], 0,
                                               w - 1)
            current_center_box[:, 1] = np.clip(current_center_box[:, 1], 0,
                                               h - 1)
            min_coord = np.min(current_center_box, axis=0).astype(np.int32)
            max_coord = np.max(current_center_box, axis=0).astype(np.int32)
            current_center_box = current_center_box - min_coord
            box_sz = (max_coord - min_coord + 1)

            center_box_mask = np.zeros((box_sz[1], box_sz[0]), dtype=np.uint8)
            cv2.fillPoly(center_box_mask, [current_center_box], color=1)

            inds = np.argwhere(center_box_mask > 0)
            inds = inds + (min_coord[1], min_coord[0])
            inds_xy = np.fliplr(inds)
            top_height_map[(inds[:, 0], inds[:, 1])] = self.dist_point2line(
                inds_xy, (top_line[i], top_line[i + 1]))
            bot_height_map[(inds[:, 0], inds[:, 1])] = self.dist_point2line(
                inds_xy, (bot_line[i], bot_line[i + 1]))

    def generate_center_mask_attrib_maps(self, img_size, text_polys):
        """Generate the center region mask and the geometry attribute maps
        (top/bottom height, sin, cos) for all text polygons."""

        assert isinstance(img_size, tuple)

        h, w = img_size

        center_lines = []
        center_region_mask = np.zeros((h, w), np.uint8)
        top_height_map = np.zeros((h, w), dtype=np.float32)
        bot_height_map = np.zeros((h, w), dtype=np.float32)
        sin_map = np.zeros((h, w), dtype=np.float32)
        cos_map = np.zeros((h, w), dtype=np.float32)

        for poly in text_polys:
            polygon_points = poly
            _, _, top_line, bot_line = self.reorder_poly_edge(polygon_points)
            resampled_top_line, resampled_bot_line = self.resample_sidelines(
                top_line, bot_line, self.resample_step)
            resampled_bot_line = resampled_bot_line[::-1]
            center_line = (resampled_top_line + resampled_bot_line) / 2

            if self.vector_slope(center_line[-1] - center_line[0]) > 2:
                if (center_line[-1] - center_line[0])[1] < 0:
                    center_line = center_line[::-1]
                    resampled_top_line = resampled_top_line[::-1]
                    resampled_bot_line = resampled_bot_line[::-1]
            else:
                if (center_line[-1] - center_line[0])[0] < 0:
                    center_line = center_line[::-1]
                    resampled_top_line = resampled_top_line[::-1]
                    resampled_bot_line = resampled_bot_line[::-1]

            line_head_shrink_len = np.clip(
                (norm(top_line[0] - bot_line[0]) * self.comp_w_h_ratio),
                self.min_width, self.max_width) / 2
            line_tail_shrink_len = np.clip(
                (norm(top_line[-1] - bot_line[-1]) * self.comp_w_h_ratio),
                self.min_width, self.max_width) / 2
            num_head_shrink = int(line_head_shrink_len // self.resample_step)
            num_tail_shrink = int(line_tail_shrink_len // self.resample_step)
            if len(center_line) > num_head_shrink + num_tail_shrink + 2:
                center_line = center_line[num_head_shrink:len(center_line) -
                                          num_tail_shrink]
                resampled_top_line = resampled_top_line[num_head_shrink:len(
                    resampled_top_line) - num_tail_shrink]
                resampled_bot_line = resampled_bot_line[num_head_shrink:len(
                    resampled_bot_line) - num_tail_shrink]
            center_lines.append(center_line.astype(np.int32))

            self.draw_center_region_maps(
                resampled_top_line, resampled_bot_line, center_line,
                center_region_mask, top_height_map, bot_height_map, sin_map,
                cos_map, self.center_region_shrink_ratio)

        return (center_lines, center_region_mask, top_height_map,
                bot_height_map, sin_map, cos_map)

    def generate_rand_comp_attribs(self, num_rand_comps, center_sample_mask):
        """Sample random negative text components (comp_label 0) inside the
        region allowed by center_sample_mask."""

        assert isinstance(num_rand_comps, int)
        assert num_rand_comps > 0
        assert center_sample_mask.ndim == 2

        h, w = center_sample_mask.shape

        max_rand_half_height = self.max_rand_half_height
        min_rand_half_height = self.min_rand_half_height
        max_rand_height = max_rand_half_height * 2
        max_rand_width = np.clip(max_rand_height * self.comp_w_h_ratio,
                                 self.min_width, self.max_width)
        margin = int(
            np.sqrt((max_rand_height / 2)**2 + (max_rand_width / 2)**2)) + 1

        if 2 * margin + 1 > min(h, w):

            assert min(h, w) > (np.sqrt(2) * (self.min_width + 1))
            max_rand_half_height = max(min(h, w) / 4, self.min_width / 2 + 1)
            min_rand_half_height = max(max_rand_half_height / 4,
                                       self.min_width / 2)

            max_rand_height = max_rand_half_height * 2
            max_rand_width = np.clip(max_rand_height * self.comp_w_h_ratio,
                                     self.min_width, self.max_width)
            margin = int(
                np.sqrt((max_rand_height / 2)**2 + (max_rand_width / 2)**2)) + 1

        inner_center_sample_mask = np.zeros_like(center_sample_mask)
        inner_center_sample_mask[margin:h - margin, margin:w - margin] = \
            center_sample_mask[margin:h - margin, margin:w - margin]
        kernel_size = int(np.clip(max_rand_half_height, 7, 21))
        inner_center_sample_mask = cv2.erode(
            inner_center_sample_mask,
            np.ones((kernel_size, kernel_size), np.uint8))

        center_candidates = np.argwhere(inner_center_sample_mask > 0)
        num_center_candidates = len(center_candidates)
        sample_inds = np.random.choice(num_center_candidates, num_rand_comps)
        rand_centers = center_candidates[sample_inds]

        rand_top_height = np.random.randint(
            min_rand_half_height,
            max_rand_half_height,
            size=(len(rand_centers), 1))
        rand_bot_height = np.random.randint(
            min_rand_half_height,
            max_rand_half_height,
            size=(len(rand_centers), 1))

        rand_cos = 2 * np.random.random(size=(len(rand_centers), 1)) - 1
        rand_sin = 2 * np.random.random(size=(len(rand_centers), 1)) - 1
        scale = np.sqrt(1.0 / (rand_cos**2 + rand_sin**2 + 1e-8))
        rand_cos = rand_cos * scale
        rand_sin = rand_sin * scale

        height = (rand_top_height + rand_bot_height)
        width = np.clip(height * self.comp_w_h_ratio, self.min_width,
                        self.max_width)

        rand_comp_attribs = np.hstack([
            rand_centers[:, ::-1], height, width, rand_cos, rand_sin,
            np.zeros_like(rand_sin)
        ]).astype(np.float32)

        return rand_comp_attribs

    def jitter_comp_attribs(self, comp_attribs, jitter_level):
        """Jitter text components attributes.

        Args:
            comp_attribs (ndarray): The text component attributes.
            jitter_level (float): The jitter level of text components
                attributes.

        Returns:
            jittered_comp_attribs (ndarray): The jittered text component
                attributes (x, y, h, w, cos, sin, comp_label).
        """

        assert comp_attribs.shape[1] == 7
        assert comp_attribs.shape[0] > 0
        assert isinstance(jitter_level, float)

        x = comp_attribs[:, 0].reshape((-1, 1))
        y = comp_attribs[:, 1].reshape((-1, 1))
        h = comp_attribs[:, 2].reshape((-1, 1))
        w = comp_attribs[:, 3].reshape((-1, 1))
        cos = comp_attribs[:, 4].reshape((-1, 1))
        sin = comp_attribs[:, 5].reshape((-1, 1))
        comp_labels = comp_attribs[:, 6].reshape((-1, 1))

        x += (np.random.random(size=(len(comp_attribs), 1)) - 0.5) * (
            h * np.abs(cos) + w * np.abs(sin)) * jitter_level
        y += (np.random.random(size=(len(comp_attribs), 1)) - 0.5) * (
            h * np.abs(sin) + w * np.abs(cos)) * jitter_level

        h += (np.random.random(size=(len(comp_attribs), 1)) - 0.5
              ) * h * jitter_level
        w += (np.random.random(size=(len(comp_attribs), 1)) - 0.5
              ) * w * jitter_level

        cos += (np.random.random(size=(len(comp_attribs), 1)) - 0.5
                ) * 2 * jitter_level
        sin += (np.random.random(size=(len(comp_attribs), 1)) - 0.5
                ) * 2 * jitter_level

        scale = np.sqrt(1.0 / (cos**2 + sin**2 + 1e-8))
        cos = cos * scale
        sin = sin * scale

        jittered_comp_attribs = np.hstack([x, y, h, w, cos, sin, comp_labels])

        return jittered_comp_attribs

    def generate_comp_attribs(self, center_lines, text_mask, center_region_mask,
                              top_height_map, bot_height_map, sin_map, cos_map):
        """Generate text component attributes.

        Args:
            center_lines (list[ndarray]): The list of text center lines.
            text_mask (ndarray): The text region mask.
            center_region_mask (ndarray): The text center region mask.
            top_height_map (ndarray): The map on which the distance from points
                to top side lines will be drawn for each pixel in text center
                regions.
            bot_height_map (ndarray): The map on which the distance from points
                to bottom side lines will be drawn for each pixel in text
                center regions.
            sin_map (ndarray): The sin(theta) map where theta is the angle
                between vector (top point - bottom point) and vector (1, 0).
            cos_map (ndarray): The cos(theta) map where theta is the angle
                between vector (top point - bottom point) and vector (1, 0).

        Returns:
            pad_comp_attribs (ndarray): The padded text component attributes
                of a fixed size.
        """

        assert isinstance(center_lines, list)
        assert (
            text_mask.shape == center_region_mask.shape == top_height_map.shape
            == bot_height_map.shape == sin_map.shape == cos_map.shape)

        center_lines_mask = np.zeros_like(center_region_mask)
        cv2.polylines(center_lines_mask, center_lines, 0, 1, 1)
        center_lines_mask = center_lines_mask * center_region_mask
        comp_centers = np.argwhere(center_lines_mask > 0)

        y = comp_centers[:, 0]
        x = comp_centers[:, 1]

        top_height = top_height_map[y, x].reshape(
            (-1, 1)) * self.comp_shrink_ratio
        bot_height = bot_height_map[y, x].reshape(
            (-1, 1)) * self.comp_shrink_ratio
        sin = sin_map[y, x].reshape((-1, 1))
        cos = cos_map[y, x].reshape((-1, 1))

        top_mid_points = comp_centers + np.hstack(
            [top_height * sin, top_height * cos])
        bot_mid_points = comp_centers - np.hstack(
            [bot_height * sin, bot_height * cos])

        width = (top_height + bot_height) * self.comp_w_h_ratio
        width = np.clip(width, self.min_width, self.max_width)
        r = width / 2

        tl = top_mid_points[:, ::-1] - np.hstack([-r * sin, r * cos])
        tr = top_mid_points[:, ::-1] + np.hstack([-r * sin, r * cos])
        br = bot_mid_points[:, ::-1] + np.hstack([-r * sin, r * cos])
        bl = bot_mid_points[:, ::-1] - np.hstack([-r * sin, r * cos])
        text_comps = np.hstack([tl, tr, br, bl]).astype(np.float32)

        score = np.ones((text_comps.shape[0], 1), dtype=np.float32)
        text_comps = np.hstack([text_comps, score])
        text_comps = la_nms(text_comps, self.text_comp_nms_thr)

        if text_comps.shape[0] >= 1:
            img_h, img_w = center_region_mask.shape
            text_comps[:, 0:8:2] = np.clip(text_comps[:, 0:8:2], 0, img_w - 1)
            text_comps[:, 1:8:2] = np.clip(text_comps[:, 1:8:2], 0, img_h - 1)

            comp_centers = np.mean(
                text_comps[:, 0:8].reshape((-1, 4, 2)), axis=1).astype(np.int32)
            x = comp_centers[:, 0]
            y = comp_centers[:, 1]

            height = (top_height_map[y, x] + bot_height_map[y, x]).reshape(
                (-1, 1))
            width = np.clip(height * self.comp_w_h_ratio, self.min_width,
                            self.max_width)

            cos = cos_map[y, x].reshape((-1, 1))
            sin = sin_map[y, x].reshape((-1, 1))

            _, comp_label_mask = cv2.connectedComponents(
                center_region_mask, connectivity=8)
            comp_labels = comp_label_mask[y, x].reshape(
                (-1, 1)).astype(np.float32)

            x = x.reshape((-1, 1)).astype(np.float32)
            y = y.reshape((-1, 1)).astype(np.float32)
            comp_attribs = np.hstack(
                [x, y, height, width, cos, sin, comp_labels])
            comp_attribs = self.jitter_comp_attribs(comp_attribs,
                                                    self.jitter_level)

            if comp_attribs.shape[0] < self.num_min_comps:
                num_rand_comps = self.num_min_comps - comp_attribs.shape[0]
                rand_comp_attribs = self.generate_rand_comp_attribs(
                    num_rand_comps, 1 - text_mask)
                comp_attribs = np.vstack([comp_attribs, rand_comp_attribs])
        else:
            comp_attribs = self.generate_rand_comp_attribs(self.num_min_comps,
                                                           1 - text_mask)

        num_comps = (np.ones(
            (comp_attribs.shape[0], 1),
            dtype=np.float32) * comp_attribs.shape[0])
        comp_attribs = np.hstack([num_comps, comp_attribs])

        if comp_attribs.shape[0] > self.num_max_comps:
            comp_attribs = comp_attribs[:self.num_max_comps, :]
            comp_attribs[:, 0] = self.num_max_comps

        pad_comp_attribs = np.zeros(
            (self.num_max_comps, comp_attribs.shape[1]), dtype=np.float32)
        pad_comp_attribs[:comp_attribs.shape[0], :] = comp_attribs

        return pad_comp_attribs

    def generate_text_region_mask(self, img_size, text_polys):
        """Generate the text region mask.

        Args:
            img_size (tuple): The image size (height, width).
            text_polys (list[list[ndarray]]): The list of text polygons.

        Returns:
            text_region_mask (ndarray): The text region mask.
        """

        assert isinstance(img_size, tuple)

        h, w = img_size
        text_region_mask = np.zeros((h, w), dtype=np.uint8)

        for poly in text_polys:
            polygon = np.array(poly, dtype=np.int32).reshape((1, -1, 2))
            cv2.fillPoly(text_region_mask, polygon, 1)

        return text_region_mask

    def generate_effective_mask(self, mask_size: tuple, polygons_ignore):
        """Generate effective mask by setting the ineffective regions to 0 and
        effective regions to 1.

        Args:
            mask_size (tuple): The mask size.
            polygons_ignore (list[ndarray]): The list of ignored text
                polygons.

        Returns:
            mask (ndarray): The effective mask of (height, width).
        """
        mask = np.ones(mask_size, dtype=np.uint8)

        for poly in polygons_ignore:
            instance = poly.astype(np.int32).reshape(1, -1, 2)
            cv2.fillPoly(mask, instance, 0)

        return mask

    def generate_targets(self, data):
        """Generate the gt targets for DRRG.

        Args:
            data (dict): The input result dictionary.

        Returns:
            data (dict): The output result dictionary.
        """

        assert isinstance(data, dict)

        image = data['image']
        polygons = data['polys']
        ignore_tags = data['ignore_tags']
        h, w, _ = image.shape

        polygon_masks = []
        polygon_masks_ignore = []
        for tag, polygon in zip(ignore_tags, polygons):
            if tag is True:
                polygon_masks_ignore.append(polygon)
            else:
                polygon_masks.append(polygon)

        gt_text_mask = self.generate_text_region_mask((h, w), polygon_masks)
        gt_mask = self.generate_effective_mask((h, w), polygon_masks_ignore)
        (center_lines, gt_center_region_mask, gt_top_height_map,
         gt_bot_height_map, gt_sin_map,
         gt_cos_map) = self.generate_center_mask_attrib_maps((h, w),
                                                             polygon_masks)

        gt_comp_attribs = self.generate_comp_attribs(
            center_lines, gt_text_mask, gt_center_region_mask,
            gt_top_height_map, gt_bot_height_map, gt_sin_map, gt_cos_map)

        mapping = {
            'gt_text_mask': gt_text_mask,
            'gt_center_region_mask': gt_center_region_mask,
            'gt_mask': gt_mask,
            'gt_top_height_map': gt_top_height_map,
            'gt_bot_height_map': gt_bot_height_map,
            'gt_sin_map': gt_sin_map,
            'gt_cos_map': gt_cos_map
        }

        data.update(mapping)
        data['gt_comp_attribs'] = gt_comp_attribs
        return data

    def __call__(self, data):
        data = self.generate_targets(data)
        return data
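A minimal usage sketch of the new operator on a synthetic rectangle (the image size and polygon are made up for illustration; the module path follows the import added to `ppocr/data/imaug/__init__.py`, and the `lanms` package must be installed for the NMS import above):

```python
# Hedged sketch: run DRRGTargets on one synthetic horizontal text instance.
import numpy as np
from ppocr.data.imaug.drrg_targets import DRRGTargets

target_gen = DRRGTargets()
data = {
    'image': np.zeros((320, 320, 3), dtype=np.uint8),       # dummy image
    'polys': [np.array([[40., 60.], [280., 60.], [280., 110.], [40., 110.]],
                       dtype=np.float32)],
    'ignore_tags': [False],
}
data = target_gen(data)
# Padded attributes: one row per component,
# (num_comps, x, y, h, w, cos, sin, comp_label).
print(data['gt_comp_attribs'].shape)  # (600, 8) with the default num_max_comps
print(data['gt_text_mask'].shape)     # (320, 320)
```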

@@ -0,0 +1 @@
from .roi_align_rotated.roi_align_rotated import RoIAlignRotated
@@ -0,0 +1,528 @@

// This code is referred from:
// https://github.com/open-mmlab/mmcv/blob/master/mmcv/ops/csrc/pytorch/cpu/roi_align_rotated.cpp

#include <cassert>
#include <cmath>
#include <vector>

#include "paddle/extension.h"

#define PADDLE_WITH_CUDA
#define CHECK_INPUT_SAME(x1, x2)                                               \
  PD_CHECK(x1.place() == x2.place(), "input must be same place.")
#define CHECK_INPUT_CPU(x) PD_CHECK(x.is_cpu(), #x " must be a CPU Tensor.")

template <typename T> struct PreCalc {
  int pos1;
  int pos2;
  int pos3;
  int pos4;
  T w1;
  T w2;
  T w3;
  T w4;
};

template <typename T>
void pre_calc_for_bilinear_interpolate(
    const int height, const int width, const int pooled_height,
    const int pooled_width, const int iy_upper, const int ix_upper,
    T roi_start_h, T roi_start_w, T bin_size_h, T bin_size_w,
    int roi_bin_grid_h, int roi_bin_grid_w, T roi_center_h, T roi_center_w,
    T cos_theta, T sin_theta, std::vector<PreCalc<T>> &pre_calc) {
  int pre_calc_index = 0;
  for (int ph = 0; ph < pooled_height; ph++) {
    for (int pw = 0; pw < pooled_width; pw++) {
      for (int iy = 0; iy < iy_upper; iy++) {
        const T yy = roi_start_h + ph * bin_size_h +
                     static_cast<T>(iy + .5f) * bin_size_h /
                         static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
        for (int ix = 0; ix < ix_upper; ix++) {
          const T xx = roi_start_w + pw * bin_size_w +
                       static_cast<T>(ix + .5f) * bin_size_w /
                           static_cast<T>(roi_bin_grid_w);

          // Rotate by theta around the center and translate
          // In image space, (y, x) is the order for Right Handed System,
          // and this is essentially multiplying the point by a rotation matrix
          // to rotate it counterclockwise through angle theta.
          T y = yy * cos_theta - xx * sin_theta + roi_center_h;
          T x = yy * sin_theta + xx * cos_theta + roi_center_w;
          // deal with: inverse elements are out of feature map boundary
          if (y < -1.0 || y > height || x < -1.0 || x > width) {
            // empty
            PreCalc<T> pc;
            pc.pos1 = 0;
            pc.pos2 = 0;
            pc.pos3 = 0;
            pc.pos4 = 0;
            pc.w1 = 0;
            pc.w2 = 0;
            pc.w3 = 0;
            pc.w4 = 0;
            pre_calc[pre_calc_index] = pc;
            pre_calc_index += 1;
            continue;
          }

          if (y < 0) {
            y = 0;
          }
          if (x < 0) {
            x = 0;
          }

          int y_low = (int)y;
          int x_low = (int)x;
          int y_high;
          int x_high;

          if (y_low >= height - 1) {
            y_high = y_low = height - 1;
            y = (T)y_low;
          } else {
            y_high = y_low + 1;
          }

          if (x_low >= width - 1) {
            x_high = x_low = width - 1;
            x = (T)x_low;
          } else {
            x_high = x_low + 1;
          }

          T ly = y - y_low;
          T lx = x - x_low;
          T hy = 1. - ly, hx = 1. - lx;
          T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;

          // save weights and indices
          PreCalc<T> pc;
          pc.pos1 = y_low * width + x_low;
          pc.pos2 = y_low * width + x_high;
          pc.pos3 = y_high * width + x_low;
          pc.pos4 = y_high * width + x_high;
          pc.w1 = w1;
          pc.w2 = w2;
          pc.w3 = w3;
          pc.w4 = w4;
          pre_calc[pre_calc_index] = pc;

          pre_calc_index += 1;
        }
      }
    }
  }
}

template <typename T>
void roi_align_rotated_cpu_forward(const int nthreads, const T *input,
                                   const T &spatial_scale, const bool aligned,
                                   const bool clockwise, const int channels,
                                   const int height, const int width,
                                   const int pooled_height,
                                   const int pooled_width,
                                   const int sampling_ratio, const T *rois,
                                   T *output) {
  int n_rois = nthreads / channels / pooled_width / pooled_height;
  // (n, c, ph, pw) is an element in the pooled output
  // can be parallelized using omp
  // #pragma omp parallel for num_threads(32)
  for (int n = 0; n < n_rois; n++) {
    int index_n = n * channels * pooled_width * pooled_height;

    const T *current_roi = rois + n * 6;
    int roi_batch_ind = current_roi[0];

    // Do not use rounding; this implementation detail is critical
    T offset = aligned ? (T)0.5 : (T)0.0;
    T roi_center_w = current_roi[1] * spatial_scale - offset;
    T roi_center_h = current_roi[2] * spatial_scale - offset;
    T roi_width = current_roi[3] * spatial_scale;
    T roi_height = current_roi[4] * spatial_scale;
    T theta = current_roi[5];
    if (clockwise) {
      theta = -theta; // If clockwise, the angle needs to be reversed.
    }
    T cos_theta = cos(theta);
    T sin_theta = sin(theta);

    if (aligned) {
      assert(roi_width >= 0 && roi_height >= 0);
    } else { // for backward-compatibility only
      roi_width = std::max(roi_width, (T)1.);
      roi_height = std::max(roi_height, (T)1.);
    }

    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

    // We use roi_bin_grid to sample the grid and mimic integral
    int roi_bin_grid_h = (sampling_ratio > 0)
                             ? sampling_ratio
                             : ceilf(roi_height / pooled_height); // e.g., = 2
    int roi_bin_grid_w =
        (sampling_ratio > 0) ? sampling_ratio : ceilf(roi_width / pooled_width);

    // We do average (integral) pooling inside a bin
    const T count = std::max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4

    // we want to precalculate indices and weights shared by all channels,
    // this is the key point of optimization
    std::vector<PreCalc<T>> pre_calc(roi_bin_grid_h * roi_bin_grid_w *
                                     pooled_width * pooled_height);

    // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
    // Appropriate translation needs to be applied after.
    T roi_start_h = -roi_height / 2.0;
    T roi_start_w = -roi_width / 2.0;

    pre_calc_for_bilinear_interpolate(
        height, width, pooled_height, pooled_width, roi_bin_grid_h,
        roi_bin_grid_w, roi_start_h, roi_start_w, bin_size_h, bin_size_w,
        roi_bin_grid_h, roi_bin_grid_w, roi_center_h, roi_center_w, cos_theta,
        sin_theta, pre_calc);

    for (int c = 0; c < channels; c++) {
      int index_n_c = index_n + c * pooled_width * pooled_height;
      const T *offset_input =
          input + (roi_batch_ind * channels + c) * height * width;
      int pre_calc_index = 0;

      for (int ph = 0; ph < pooled_height; ph++) {
        for (int pw = 0; pw < pooled_width; pw++) {
          int index = index_n_c + ph * pooled_width + pw;

          T output_val = 0.;
          for (int iy = 0; iy < roi_bin_grid_h; iy++) {
            for (int ix = 0; ix < roi_bin_grid_w; ix++) {
              PreCalc<T> pc = pre_calc[pre_calc_index];
              output_val += pc.w1 * offset_input[pc.pos1] +
                            pc.w2 * offset_input[pc.pos2] +
                            pc.w3 * offset_input[pc.pos3] +
                            pc.w4 * offset_input[pc.pos4];

              pre_calc_index += 1;
            }
          }
          output_val /= count;

          output[index] = output_val;
        } // for pw
      }   // for ph
    }     // for c
  }       // for n
}

template <typename T>
void bilinear_interpolate_gradient(const int height, const int width, T y, T x,
                                   T &w1, T &w2, T &w3, T &w4, int &x_low,
                                   int &x_high, int &y_low, int &y_high) {
  // deal with cases that inverse elements are out of feature map boundary
  if (y < -1.0 || y > height || x < -1.0 || x > width) {
    // empty
    w1 = w2 = w3 = w4 = 0.;
    x_low = x_high = y_low = y_high = -1;
    return;
  }

  if (y < 0) {
    y = 0;
  }

  if (x < 0) {
    x = 0;
  }

  y_low = (int)y;
  x_low = (int)x;

  if (y_low >= height - 1) {
    y_high = y_low = height - 1;
    y = (T)y_low;
  } else {
    y_high = y_low + 1;
  }

  if (x_low >= width - 1) {
    x_high = x_low = width - 1;
    x = (T)x_low;
  } else {
    x_high = x_low + 1;
  }

  T ly = y - y_low;
  T lx = x - x_low;
  T hy = 1. - ly, hx = 1. - lx;

  // reference in forward
  // T v1 = input[y_low * width + x_low];
  // T v2 = input[y_low * width + x_high];
  // T v3 = input[y_high * width + x_low];
  // T v4 = input[y_high * width + x_high];
  // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);

  w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;

  return;
}

template <class T> inline void add(T *address, const T &val) {
  *address += val;
}

template <typename T>
void roi_align_rotated_cpu_backward(
    const int nthreads,
    // may not be contiguous. should index using n_stride, etc
    const T *grad_output, const T &spatial_scale, const bool aligned,
    const bool clockwise, const int channels, const int height, const int width,
    const int pooled_height, const int pooled_width, const int sampling_ratio,
    T *grad_input, const T *rois, const int n_stride, const int c_stride,
    const int h_stride, const int w_stride) {
  for (int index = 0; index < nthreads; index++) {
    // (n, c, ph, pw) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int c = (index / pooled_width / pooled_height) % channels;
    int n = index / pooled_width / pooled_height / channels;

    const T *current_roi = rois + n * 6;
    int roi_batch_ind = current_roi[0];

    // Do not use rounding; this implementation detail is critical
    T offset = aligned ? (T)0.5 : (T)0.0;
    T roi_center_w = current_roi[1] * spatial_scale - offset;
    T roi_center_h = current_roi[2] * spatial_scale - offset;
    T roi_width = current_roi[3] * spatial_scale;
    T roi_height = current_roi[4] * spatial_scale;
    T theta = current_roi[5];
    if (clockwise) {
      theta = -theta; // If clockwise, the angle needs to be reversed.
|
||||
}
|
||||
T cos_theta = cos(theta);
|
||||
T sin_theta = sin(theta);
|
||||
|
||||
if (aligned) {
|
||||
assert(roi_width >= 0 && roi_height >= 0);
|
||||
} else { // for backward-compatibility only
|
||||
roi_width = std::max(roi_width, (T)1.);
|
||||
roi_height = std::max(roi_height, (T)1.);
|
||||
}
|
||||
|
||||
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
|
||||
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
|
||||
|
||||
T *offset_grad_input =
|
||||
grad_input + ((roi_batch_ind * channels + c) * height * width);
|
||||
|
||||
int output_offset = n * n_stride + c * c_stride;
|
||||
const T *offset_grad_output = grad_output + output_offset;
|
||||
const T grad_output_this_bin =
|
||||
offset_grad_output[ph * h_stride + pw * w_stride];
|
||||
|
||||
// We use roi_bin_grid to sample the grid and mimic integral
|
||||
int roi_bin_grid_h = (sampling_ratio > 0)
|
||||
? sampling_ratio
|
||||
: ceilf(roi_height / pooled_height); // e.g., = 2
|
||||
int roi_bin_grid_w =
|
||||
(sampling_ratio > 0) ? sampling_ratio : ceilf(roi_width / pooled_width);
|
||||
|
||||
// roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
|
||||
// Appropriate translation needs to be applied after.
|
||||
T roi_start_h = -roi_height / 2.0;
|
||||
T roi_start_w = -roi_width / 2.0;
|
||||
|
||||
// We do average (integral) pooling inside a bin
|
||||
const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
|
||||
|
||||
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
|
||||
const T yy = roi_start_h + ph * bin_size_h +
|
||||
static_cast<T>(iy + .5f) * bin_size_h /
|
||||
static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
|
||||
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
|
||||
const T xx = roi_start_w + pw * bin_size_w +
|
||||
static_cast<T>(ix + .5f) * bin_size_w /
|
||||
static_cast<T>(roi_bin_grid_w);
|
||||
|
||||
// Rotate by theta around the center and translate
|
||||
T y = yy * cos_theta - xx * sin_theta + roi_center_h;
|
||||
T x = yy * sin_theta + xx * cos_theta + roi_center_w;
|
||||
|
||||
T w1, w2, w3, w4;
|
||||
int x_low, x_high, y_low, y_high;
|
||||
|
||||
bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4,
|
||||
x_low, x_high, y_low, y_high);
|
||||
|
||||
T g1 = grad_output_this_bin * w1 / count;
|
||||
T g2 = grad_output_this_bin * w2 / count;
|
||||
T g3 = grad_output_this_bin * w3 / count;
|
||||
T g4 = grad_output_this_bin * w4 / count;
|
||||
|
||||
if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
|
||||
// atomic add is not needed for now since it is single threaded
|
||||
add(offset_grad_input + y_low * width + x_low, static_cast<T>(g1));
|
||||
add(offset_grad_input + y_low * width + x_high, static_cast<T>(g2));
|
||||
add(offset_grad_input + y_high * width + x_low, static_cast<T>(g3));
|
||||
add(offset_grad_input + y_high * width + x_high, static_cast<T>(g4));
|
||||
} // if
|
||||
} // ix
|
||||
} // iy
|
||||
} // for
|
||||
} // ROIAlignRotatedBackward
|
||||
|
||||
std::vector<paddle::Tensor>
|
||||
RoIAlignRotatedCPUForward(const paddle::Tensor &input,
|
||||
const paddle::Tensor &rois, int aligned_height,
|
||||
int aligned_width, float spatial_scale,
|
||||
int sampling_ratio, bool aligned, bool clockwise) {
|
||||
CHECK_INPUT_CPU(input);
|
||||
CHECK_INPUT_CPU(rois);
|
||||
|
||||
auto num_rois = rois.shape()[0];
|
||||
|
||||
auto channels = input.shape()[1];
|
||||
auto height = input.shape()[2];
|
||||
auto width = input.shape()[3];
|
||||
|
||||
auto output =
|
||||
paddle::empty({num_rois, channels, aligned_height, aligned_width},
|
||||
input.type(), paddle::CPUPlace());
|
||||
auto output_size = output.numel();
|
||||
|
||||
PD_DISPATCH_FLOATING_TYPES(
|
||||
input.type(), "roi_align_rotated_cpu_forward", ([&] {
|
||||
roi_align_rotated_cpu_forward<data_t>(
|
||||
output_size, input.data<data_t>(),
|
||||
static_cast<data_t>(spatial_scale), aligned, clockwise, channels,
|
||||
height, width, aligned_height, aligned_width, sampling_ratio,
|
||||
rois.data<data_t>(), output.data<data_t>());
|
||||
}));
|
||||
|
||||
return {output};
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> RoIAlignRotatedCPUBackward(
|
||||
const paddle::Tensor &input, const paddle::Tensor &rois,
|
||||
const paddle::Tensor &grad_output, int aligned_height, int aligned_width,
|
||||
float spatial_scale, int sampling_ratio, bool aligned, bool clockwise) {
|
||||
|
||||
auto batch_size = input.shape()[0];
|
||||
auto channels = input.shape()[1];
|
||||
auto height = input.shape()[2];
|
||||
auto width = input.shape()[3];
|
||||
|
||||
auto grad_input = paddle::full({batch_size, channels, height, width}, 0.0,
|
||||
input.type(), paddle::CPUPlace());
|
||||
|
||||
// get stride values to ensure indexing into gradients is correct.
|
||||
int n_stride = grad_output.shape()[0];
|
||||
int c_stride = grad_output.shape()[1];
|
||||
int h_stride = grad_output.shape()[2];
|
||||
int w_stride = grad_output.shape()[3];
|
||||
|
||||
PD_DISPATCH_FLOATING_TYPES(
|
||||
grad_output.type(), "roi_align_rotated_cpu_backward", [&] {
|
||||
roi_align_rotated_cpu_backward<data_t>(
|
||||
grad_output.numel(), grad_output.data<data_t>(),
|
||||
static_cast<data_t>(spatial_scale), aligned, clockwise, channels,
|
||||
height, width, aligned_height, aligned_width, sampling_ratio,
|
||||
grad_input.data<data_t>(), rois.data<data_t>(), n_stride, c_stride,
|
||||
h_stride, w_stride);
|
||||
});
|
||||
return {grad_input};
|
||||
}
|
||||
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
std::vector<paddle::Tensor>
|
||||
RoIAlignRotatedCUDAForward(const paddle::Tensor &input,
|
||||
const paddle::Tensor &rois, int aligned_height,
|
||||
int aligned_width, float spatial_scale,
|
||||
int sampling_ratio, bool aligned, bool clockwise);
|
||||
#endif
|
||||
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
std::vector<paddle::Tensor> RoIAlignRotatedCUDABackward(
|
||||
const paddle::Tensor &input, const paddle::Tensor &rois,
|
||||
const paddle::Tensor &grad_output, int aligned_height, int aligned_width,
|
||||
float spatial_scale, int sampling_ratio, bool aligned, bool clockwise);
|
||||
#endif
|
||||
|
||||
std::vector<paddle::Tensor>
|
||||
RoIAlignRotatedForward(const paddle::Tensor &input, const paddle::Tensor &rois,
|
||||
int aligned_height, int aligned_width,
|
||||
float spatial_scale, int sampling_ratio, bool aligned,
|
||||
bool clockwise) {
|
||||
CHECK_INPUT_SAME(input, rois);
|
||||
if (input.is_cpu()) {
|
||||
return RoIAlignRotatedCPUForward(input, rois, aligned_height, aligned_width,
|
||||
spatial_scale, sampling_ratio, aligned,
|
||||
clockwise);
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
} else if (input.is_gpu()) {
|
||||
return RoIAlignRotatedCUDAForward(input, rois, aligned_height,
|
||||
aligned_width, spatial_scale,
|
||||
sampling_ratio, aligned, clockwise);
|
||||
#endif
|
||||
} else {
|
||||
PD_THROW("Unsupported device type for forward function of roi align "
|
||||
"rotated operator.");
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor>
|
||||
RoIAlignRotatedBackward(const paddle::Tensor &input, const paddle::Tensor &rois,
|
||||
const paddle::Tensor &grad_output, int aligned_height,
|
||||
int aligned_width, float spatial_scale,
|
||||
int sampling_ratio, bool aligned, bool clockwise) {
|
||||
CHECK_INPUT_SAME(input, rois);
|
||||
if (input.is_cpu()) {
|
||||
return RoIAlignRotatedCPUBackward(input, rois, grad_output, aligned_height,
|
||||
aligned_width, spatial_scale,
|
||||
sampling_ratio, aligned, clockwise);
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
} else if (input.is_gpu()) {
|
||||
return RoIAlignRotatedCUDABackward(input, rois, grad_output, aligned_height,
|
||||
aligned_width, spatial_scale,
|
||||
sampling_ratio, aligned, clockwise);
|
||||
#endif
|
||||
} else {
|
||||
PD_THROW("Unsupported device type for forward function of roi align "
|
||||
"rotated operator.");
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> InferShape(std::vector<int64_t> input_shape,
|
||||
std::vector<int64_t> rois_shape) {
|
||||
return {{rois_shape[0], input_shape[1], input_shape[2], input_shape[3]}};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>>
|
||||
InferBackShape(std::vector<int64_t> input_shape,
|
||||
std::vector<int64_t> rois_shape) {
|
||||
return {input_shape};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> InferDtype(paddle::DataType input_dtype,
|
||||
paddle::DataType rois_dtype) {
|
||||
return {input_dtype};
|
||||
}
|
||||
|
||||
PD_BUILD_OP(roi_align_rotated)
|
||||
.Inputs({"Input", "Rois"})
|
||||
.Outputs({"Output"})
|
||||
.Attrs({"aligned_height: int", "aligned_width: int", "spatial_scale: float",
|
||||
"sampling_ratio: int", "aligned: bool", "clockwise: bool"})
|
||||
.SetKernelFn(PD_KERNEL(RoIAlignRotatedForward))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(InferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(InferDtype));
|
||||
|
||||
PD_BUILD_GRAD_OP(roi_align_rotated)
|
||||
.Inputs({"Input", "Rois", paddle::Grad("Output")})
|
||||
.Attrs({"aligned_height: int", "aligned_width: int", "spatial_scale: float",
|
||||
"sampling_ratio: int", "aligned: bool", "clockwise: bool"})
|
||||
.Outputs({paddle::Grad("Input")})
|
||||
.SetKernelFn(PD_KERNEL(RoIAlignRotatedBackward))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(InferBackShape));
|
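The CPU kernels above all reduce to the same four-corner bilinear weighting (w1 = hy*hx, w2 = hy*lx, w3 = ly*hx, w4 = ly*lx) with the same boundary clamping. A minimal NumPy sketch of that sampling rule, useful for sanity-checking the C++ against a toy grid (the function name and values here are illustrative, not part of the extension):

```python
import numpy as np

def bilinear_sample(feat, y, x):
    # same corner weights as the kernels above:
    # w1 = hy*hx, w2 = hy*lx, w3 = ly*hx, w4 = ly*lx
    h, w = feat.shape
    if y < -1.0 or y > h or x < -1.0 or x > w:
        return 0.0  # outside the feature map, mirror the kernel's early-out
    y = max(y, 0.0)
    x = max(x, 0.0)
    y0, x0 = int(y), int(x)
    if y0 >= h - 1:
        y0 = y1 = h - 1
        y = float(y0)
    else:
        y1 = y0 + 1
    if x0 >= w - 1:
        x0 = x1 = w - 1
        x = float(x0)
    else:
        x1 = x0 + 1
    ly, lx = y - y0, x - x0
    hy, hx = 1.0 - ly, 1.0 - lx
    return (hy * hx * feat[y0, x0] + hy * lx * feat[y0, x1] +
            ly * hx * feat[y1, x0] + ly * lx * feat[y1, x1])

feat = np.arange(16, dtype=np.float32).reshape(4, 4)
print(bilinear_sample(feat, 1.5, 2.5))  # 8.5: average of the 4 neighbors
```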
@@ -0,0 +1,380 @@
// This code is adapted from:
// https://github.com/open-mmlab/mmcv/blob/master/mmcv/ops/csrc/common/cuda/roi_align_rotated_cuda_kernel.cuh

#include <cassert>
#include <cmath>
#include <vector>

#include "paddle/extension.h"
#include <cuda.h>

#define CUDA_1D_KERNEL_LOOP(i, n)                                             \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n);                \
       i += blockDim.x * gridDim.x)

#define THREADS_PER_BLOCK 512

inline int GET_BLOCKS(const int N) {
  int optimal_block_num = (N + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
  int max_block_num = 4096;
  return min(optimal_block_num, max_block_num);
}

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600

static __inline__ __device__ double atomicAdd(double *address, double val) {
  unsigned long long int *address_as_ull = (unsigned long long int *)address;
  unsigned long long int old = *address_as_ull, assumed;
  if (val == 0.0)
    return __longlong_as_double(old);
  do {
    assumed = old;
    old = atomicCAS(address_as_ull, assumed,
                    __double_as_longlong(val + __longlong_as_double(assumed)));
  } while (assumed != old);
  return __longlong_as_double(old);
}

#endif

template <typename T>
__device__ T bilinear_interpolate(const T *input, const int height,
                                  const int width, T y, T x,
                                  const int index /* index for debug only*/) {
  // deal with cases that inverse elements are out of feature map boundary
  if (y < -1.0 || y > height || x < -1.0 || x > width)
    return 0;

  if (y <= 0)
    y = 0;
  if (x <= 0)
    x = 0;

  int y_low = (int)y;
  int x_low = (int)x;
  int y_high;
  int x_high;

  if (y_low >= height - 1) {
    y_high = y_low = height - 1;
    y = (T)y_low;
  } else {
    y_high = y_low + 1;
  }

  if (x_low >= width - 1) {
    x_high = x_low = width - 1;
    x = (T)x_low;
  } else {
    x_high = x_low + 1;
  }

  T ly = y - y_low;
  T lx = x - x_low;
  T hy = 1. - ly, hx = 1. - lx;
  // do bilinear interpolation
  T v1 = input[y_low * width + x_low];
  T v2 = input[y_low * width + x_high];
  T v3 = input[y_high * width + x_low];
  T v4 = input[y_high * width + x_high];
  T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;

  T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);

  return val;
}

template <typename T>
__device__ void
bilinear_interpolate_gradient(const int height, const int width, T y, T x,
                              T &w1, T &w2, T &w3, T &w4, int &x_low,
                              int &x_high, int &y_low, int &y_high,
                              const int index /* index for debug only*/) {
  // deal with cases that inverse elements are out of feature map boundary
  if (y < -1.0 || y > height || x < -1.0 || x > width) {
    // empty
    w1 = w2 = w3 = w4 = 0.;
    x_low = x_high = y_low = y_high = -1;
    return;
  }

  if (y <= 0)
    y = 0;
  if (x <= 0)
    x = 0;

  y_low = (int)y;
  x_low = (int)x;

  if (y_low >= height - 1) {
    y_high = y_low = height - 1;
    y = (T)y_low;
  } else {
    y_high = y_low + 1;
  }

  if (x_low >= width - 1) {
    x_high = x_low = width - 1;
    x = (T)x_low;
  } else {
    x_high = x_low + 1;
  }

  T ly = y - y_low;
  T lx = x - x_low;
  T hy = 1. - ly, hx = 1. - lx;

  // reference in forward
  // T v1 = input[y_low * width + x_low];
  // T v2 = input[y_low * width + x_high];
  // T v3 = input[y_high * width + x_low];
  // T v4 = input[y_high * width + x_high];
  // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);

  w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;

  return;
}

/*** Forward ***/
template <typename scalar_t>
__global__ void roi_align_rotated_cuda_forward_kernel(
    const int nthreads, const scalar_t *bottom_data,
    const scalar_t *bottom_rois, const scalar_t spatial_scale,
    const int sample_num, const bool aligned, const bool clockwise,
    const int channels, const int height, const int width,
    const int pooled_height, const int pooled_width, scalar_t *top_data) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    // (n, c, ph, pw) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int c = (index / pooled_width / pooled_height) % channels;
    int n = index / pooled_width / pooled_height / channels;

    const scalar_t *offset_bottom_rois = bottom_rois + n * 6;
    int roi_batch_ind = offset_bottom_rois[0];

    // Do not use rounding; this implementation detail is critical
    scalar_t offset = aligned ? (scalar_t)0.5 : (scalar_t)0.0;
    scalar_t roi_center_w = offset_bottom_rois[1] * spatial_scale - offset;
    scalar_t roi_center_h = offset_bottom_rois[2] * spatial_scale - offset;
    scalar_t roi_width = offset_bottom_rois[3] * spatial_scale;
    scalar_t roi_height = offset_bottom_rois[4] * spatial_scale;
    // scalar_t theta = offset_bottom_rois[5] * M_PI / 180.0;
    scalar_t theta = offset_bottom_rois[5];
    if (clockwise) {
      theta = -theta; // If clockwise, the angle needs to be reversed.
    }
    if (!aligned) { // for backward-compatibility only
      // Force malformed ROIs to be 1x1
      roi_width = max(roi_width, (scalar_t)1.);
      roi_height = max(roi_height, (scalar_t)1.);
    }
    scalar_t bin_size_h = static_cast<scalar_t>(roi_height) /
                          static_cast<scalar_t>(pooled_height);
    scalar_t bin_size_w =
        static_cast<scalar_t>(roi_width) / static_cast<scalar_t>(pooled_width);

    const scalar_t *offset_bottom_data =
        bottom_data + (roi_batch_ind * channels + c) * height * width;

    // We use roi_bin_grid to sample the grid and mimic integral
    int roi_bin_grid_h = (sample_num > 0)
                             ? sample_num
                             : ceilf(roi_height / pooled_height); // e.g., = 2
    int roi_bin_grid_w =
        (sample_num > 0) ? sample_num : ceilf(roi_width / pooled_width);

    // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
    // Appropriate translation needs to be applied after.
    scalar_t roi_start_h = -roi_height / 2.0;
    scalar_t roi_start_w = -roi_width / 2.0;
    scalar_t cosscalar_theta = cos(theta);
    scalar_t sinscalar_theta = sin(theta);

    // We do average (integral) pooling inside a bin
    const scalar_t count = max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4

    scalar_t output_val = 0.;
    for (int iy = 0; iy < roi_bin_grid_h; iy++) { // e.g., iy = 0, 1
      const scalar_t yy =
          roi_start_h + ph * bin_size_h +
          static_cast<scalar_t>(iy + .5f) * bin_size_h /
              static_cast<scalar_t>(roi_bin_grid_h); // e.g., 0.5, 1.5
      for (int ix = 0; ix < roi_bin_grid_w; ix++) {
        const scalar_t xx = roi_start_w + pw * bin_size_w +
                            static_cast<scalar_t>(ix + .5f) * bin_size_w /
                                static_cast<scalar_t>(roi_bin_grid_w);

        // Rotate by theta (counterclockwise) around the center and translate
        scalar_t y = yy * cosscalar_theta - xx * sinscalar_theta + roi_center_h;
        scalar_t x = yy * sinscalar_theta + xx * cosscalar_theta + roi_center_w;

        scalar_t val = bilinear_interpolate<scalar_t>(
            offset_bottom_data, height, width, y, x, index);
        output_val += val;
      }
    }
    output_val /= count;

    top_data[index] = output_val;
  }
}

/*** Backward ***/
template <typename scalar_t>
__global__ void roi_align_rotated_backward_cuda_kernel(
    const int nthreads, const scalar_t *top_diff, const scalar_t *bottom_rois,
    const scalar_t spatial_scale, const int sample_num, const bool aligned,
    const bool clockwise, const int channels, const int height, const int width,
    const int pooled_height, const int pooled_width, scalar_t *bottom_diff) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    // (n, c, ph, pw) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int c = (index / pooled_width / pooled_height) % channels;
    int n = index / pooled_width / pooled_height / channels;

    const scalar_t *offset_bottom_rois = bottom_rois + n * 6;
    int roi_batch_ind = offset_bottom_rois[0];

    // Do not round
    scalar_t offset = aligned ? (scalar_t)0.5 : (scalar_t)0.0;
    scalar_t roi_center_w = offset_bottom_rois[1] * spatial_scale - offset;
    scalar_t roi_center_h = offset_bottom_rois[2] * spatial_scale - offset;
    scalar_t roi_width = offset_bottom_rois[3] * spatial_scale;
    scalar_t roi_height = offset_bottom_rois[4] * spatial_scale;
    // scalar_t theta = offset_bottom_rois[5] * M_PI / 180.0;
    scalar_t theta = offset_bottom_rois[5];
    if (clockwise) {
      theta = -theta; // If clockwise, the angle needs to be reversed.
    }
    if (!aligned) { // for backward-compatibility only
      // Force malformed ROIs to be 1x1
      roi_width = max(roi_width, (scalar_t)1.);
      roi_height = max(roi_height, (scalar_t)1.);
    }
    scalar_t bin_size_h = static_cast<scalar_t>(roi_height) /
                          static_cast<scalar_t>(pooled_height);
    scalar_t bin_size_w =
        static_cast<scalar_t>(roi_width) / static_cast<scalar_t>(pooled_width);

    scalar_t *offset_bottom_diff =
        bottom_diff + (roi_batch_ind * channels + c) * height * width;

    int top_offset = (n * channels + c) * pooled_height * pooled_width;
    const scalar_t *offset_top_diff = top_diff + top_offset;
    const scalar_t top_diff_this_bin = offset_top_diff[ph * pooled_width + pw];

    // We use roi_bin_grid to sample the grid and mimic integral
    int roi_bin_grid_h = (sample_num > 0)
                             ? sample_num
                             : ceilf(roi_height / pooled_height); // e.g., = 2
    int roi_bin_grid_w =
        (sample_num > 0) ? sample_num : ceilf(roi_width / pooled_width);

    // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
    // Appropriate translation needs to be applied after.
    scalar_t roi_start_h = -roi_height / 2.0;
    scalar_t roi_start_w = -roi_width / 2.0;
    scalar_t cosTheta = cos(theta);
    scalar_t sinTheta = sin(theta);

    // We do average (integral) pooling inside a bin
    const scalar_t count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4

    for (int iy = 0; iy < roi_bin_grid_h; iy++) { // e.g., iy = 0, 1
      const scalar_t yy =
          roi_start_h + ph * bin_size_h +
          static_cast<scalar_t>(iy + .5f) * bin_size_h /
              static_cast<scalar_t>(roi_bin_grid_h); // e.g., 0.5, 1.5
      for (int ix = 0; ix < roi_bin_grid_w; ix++) {
        const scalar_t xx = roi_start_w + pw * bin_size_w +
                            static_cast<scalar_t>(ix + .5f) * bin_size_w /
                                static_cast<scalar_t>(roi_bin_grid_w);

        // Rotate by theta around the center and translate
        scalar_t y = yy * cosTheta - xx * sinTheta + roi_center_h;
        scalar_t x = yy * sinTheta + xx * cosTheta + roi_center_w;

        scalar_t w1, w2, w3, w4;
        int x_low, x_high, y_low, y_high;

        bilinear_interpolate_gradient<scalar_t>(height, width, y, x, w1, w2,
                                                w3, w4, x_low, x_high, y_low,
                                                y_high, index);

        scalar_t g1 = top_diff_this_bin * w1 / count;
        scalar_t g2 = top_diff_this_bin * w2 / count;
        scalar_t g3 = top_diff_this_bin * w3 / count;
        scalar_t g4 = top_diff_this_bin * w4 / count;

        if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
          atomicAdd(offset_bottom_diff + y_low * width + x_low, g1);
          atomicAdd(offset_bottom_diff + y_low * width + x_high, g2);
          atomicAdd(offset_bottom_diff + y_high * width + x_low, g3);
          atomicAdd(offset_bottom_diff + y_high * width + x_high, g4);
        } // if
      }   // ix
    }     // iy
  }       // CUDA_1D_KERNEL_LOOP
} // RoIAlignBackward

std::vector<paddle::Tensor>
RoIAlignRotatedCUDAForward(const paddle::Tensor &input,
                           const paddle::Tensor &rois, int aligned_height,
                           int aligned_width, float spatial_scale,
                           int sampling_ratio, bool aligned, bool clockwise) {

  auto num_rois = rois.shape()[0];

  auto channels = input.shape()[1];
  auto height = input.shape()[2];
  auto width = input.shape()[3];

  auto output =
      paddle::empty({num_rois, channels, aligned_height, aligned_width},
                    input.type(), paddle::GPUPlace());
  auto output_size = output.numel();

  PD_DISPATCH_FLOATING_TYPES(
      input.type(), "roi_align_rotated_cuda_forward_kernel", ([&] {
        roi_align_rotated_cuda_forward_kernel<
            data_t><<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
            output_size, input.data<data_t>(), rois.data<data_t>(),
            static_cast<data_t>(spatial_scale), sampling_ratio, aligned,
            clockwise, channels, height, width, aligned_height, aligned_width,
            output.data<data_t>());
      }));

  return {output};
}

std::vector<paddle::Tensor> RoIAlignRotatedCUDABackward(
    const paddle::Tensor &input, const paddle::Tensor &rois,
    const paddle::Tensor &grad_output, int aligned_height, int aligned_width,
    float spatial_scale, int sampling_ratio, bool aligned, bool clockwise) {

  auto num_rois = rois.shape()[0];

  auto batch_size = input.shape()[0];
  auto channels = input.shape()[1];
  auto height = input.shape()[2];
  auto width = input.shape()[3];

  auto grad_input = paddle::full({batch_size, channels, height, width}, 0.0,
                                 input.type(), paddle::GPUPlace());

  const int output_size = num_rois * aligned_height * aligned_width * channels;

  PD_DISPATCH_FLOATING_TYPES(
      grad_output.type(), "roi_align_rotated_backward_cuda_kernel", ([&] {
        roi_align_rotated_backward_cuda_kernel<
            data_t><<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
            output_size, grad_output.data<data_t>(), rois.data<data_t>(),
            spatial_scale, sampling_ratio, aligned, clockwise, channels, height,
            width, aligned_height, aligned_width, grad_input.data<data_t>());
      }));
  return {grad_input};
}
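The launch configuration above caps the grid at 4096 blocks of 512 threads; when the element count exceeds that, the grid-stride loop in `CUDA_1D_KERNEL_LOOP` has each thread process several elements. The same arithmetic in a quick Python sketch:

```python
THREADS_PER_BLOCK = 512
MAX_BLOCKS = 4096

def get_blocks(n):
    # mirrors GET_BLOCKS above: ceil-divide, then cap the grid size
    return min((n + THREADS_PER_BLOCK - 1) // THREADS_PER_BLOCK, MAX_BLOCKS)

n = 5_000_000
blocks = get_blocks(n)
print(blocks)                                   # 4096 (cap hit)
print(n / (blocks * THREADS_PER_BLOCK))         # ~2.4 elements per thread
```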
@@ -0,0 +1,66 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/open-mmlab/mmcv/blob/master/mmcv/ops/roi_align_rotated.py
"""

import paddle
import paddle.nn as nn
from paddle.utils.cpp_extension import load

custom_ops = load(
    name="custom_jit_ops",
    sources=[
        "ppocr/ext_op/roi_align_rotated/roi_align_rotated.cc",
        "ppocr/ext_op/roi_align_rotated/roi_align_rotated.cu"
    ])

roi_align_rotated = custom_ops.roi_align_rotated


class RoIAlignRotated(nn.Layer):
    """RoI align pooling layer for rotated proposals."""

    def __init__(self,
                 out_size,
                 spatial_scale,
                 sample_num=0,
                 aligned=True,
                 clockwise=False):
        super(RoIAlignRotated, self).__init__()

        if isinstance(out_size, int):
            self.out_h = out_size
            self.out_w = out_size
        elif isinstance(out_size, tuple):
            assert len(out_size) == 2
            assert isinstance(out_size[0], int)
            assert isinstance(out_size[1], int)
            self.out_h, self.out_w = out_size
        else:
            raise TypeError(
                '"out_size" must be an integer or tuple of integers')

        self.spatial_scale = float(spatial_scale)
        self.sample_num = int(sample_num)
        self.aligned = aligned
        self.clockwise = clockwise

    def forward(self, feats, rois):
        output = roi_align_rotated(feats, rois, self.out_h, self.out_w,
                                   self.spatial_scale, self.sample_num,
                                   self.aligned, self.clockwise)
        return output
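A quick usage sketch of the layer above, with illustrative shapes and values. The RoI layout of (batch_index, center_x, center_y, width, height, angle) follows the six `current_roi[0..5]` reads in the kernels; the angle is in radians, since the degree-to-radian conversion in the CUDA kernel is commented out:

```python
import paddle

# assumes the custom op compiled above is available
pooler = RoIAlignRotated(out_size=(3, 4), spatial_scale=0.25, sample_num=2)

feats = paddle.randn([1, 16, 32, 32])
# each RoI: (batch_index, center_x, center_y, width, height, angle_radians)
rois = paddle.to_tensor([[0., 64., 48., 40., 24., 0.3]], dtype='float32')

pooled = pooler(feats, rois)
print(pooled.shape)  # [1, 16, 3, 4] -> (num_rois, channels, out_h, out_w)
```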
@@ -26,6 +26,7 @@ from .det_sast_loss import SASTLoss
from .det_pse_loss import PSELoss
from .det_fce_loss import FCELoss
from .det_ct_loss import CTLoss
from .det_drrg_loss import DRRGLoss

# rec loss
from .rec_ctc_loss import CTCLoss
@@ -70,7 +71,7 @@ def build_loss(config):
        'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
        'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
        'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss',
        'SLALoss', 'CTLoss', 'RFLLoss'
        'SLALoss', 'CTLoss', 'RFLLoss', 'DRRGLoss'
    ]
    config = copy.deepcopy(config)
    module_name = config.pop('name')
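With `DRRGLoss` imported and added to the support list, it can be built from a config dict like the other losses. A minimal sketch of the expected usage (the config keys beyond `name` become constructor kwargs):

```python
from ppocr.losses import build_loss

# 'name' selects the loss class; 'ohem_ratio' is forwarded to DRRGLoss.__init__
loss_config = {'name': 'DRRGLoss', 'ohem_ratio': 3.0}
loss_fn = build_loss(loss_config)
```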
@@ -0,0 +1,224 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textdet/losses/drrg_loss.py
"""

import paddle
import paddle.nn.functional as F
from paddle import nn


class DRRGLoss(nn.Layer):
    def __init__(self, ohem_ratio=3.0):
        super().__init__()
        self.ohem_ratio = ohem_ratio
        self.downsample_ratio = 1.0

    def balance_bce_loss(self, pred, gt, mask):
        """Balanced Binary-CrossEntropy Loss.

        Args:
            pred (Tensor): Shape of :math:`(1, H, W)`.
            gt (Tensor): Shape of :math:`(1, H, W)`.
            mask (Tensor): Shape of :math:`(1, H, W)`.

        Returns:
            Tensor: Balanced bce loss.
        """
        assert pred.shape == gt.shape == mask.shape
        assert paddle.all(pred >= 0) and paddle.all(pred <= 1)
        assert paddle.all(gt >= 0) and paddle.all(gt <= 1)
        positive = gt * mask
        negative = (1 - gt) * mask
        positive_count = int(positive.sum())

        if positive_count > 0:
            loss = F.binary_cross_entropy(pred, gt, reduction='none')
            positive_loss = paddle.sum(loss * positive)
            negative_loss = loss * negative
            negative_count = min(
                int(negative.sum()), int(positive_count * self.ohem_ratio))
        else:
            positive_loss = paddle.to_tensor(0.0)
            loss = F.binary_cross_entropy(pred, gt, reduction='none')
            negative_loss = loss * negative
            negative_count = 100
        negative_loss, _ = paddle.topk(
            negative_loss.reshape([-1]), negative_count)

        balance_loss = (positive_loss + paddle.sum(negative_loss)) / (
            float(positive_count + negative_count) + 1e-5)

        return balance_loss

    def gcn_loss(self, gcn_data):
        """CrossEntropy Loss from gcn module.

        Args:
            gcn_data (tuple(Tensor, Tensor)): The first is the
                prediction with shape :math:`(N, 2)` and the
                second is the gt label with shape :math:`(m, n)`
                where :math:`m * n = N`.

        Returns:
            Tensor: CrossEntropy loss.
        """
        gcn_pred, gt_labels = gcn_data
        gt_labels = gt_labels.reshape([-1])
        loss = F.cross_entropy(gcn_pred, gt_labels)

        return loss

    def bitmasks2tensor(self, bitmasks, target_sz):
        """Convert Bitmasks to tensor.

        Args:
            bitmasks (list[BitmapMasks]): The BitmapMasks list. Each item is
                for one img.
            target_sz (tuple(int, int)): The target tensor of size
                :math:`(H, W)`.

        Returns:
            list[Tensor]: The list of kernel tensors. Each element stands for
                one kernel level.
        """
        batch_size = len(bitmasks)
        results = []

        kernel = []
        for batch_inx in range(batch_size):
            mask = bitmasks[batch_inx]
            # hxw
            mask_sz = mask.shape
            # left, right, top, bottom
            pad = [0, target_sz[1] - mask_sz[1], 0, target_sz[0] - mask_sz[0]]
            mask = F.pad(mask, pad, mode='constant', value=0)
            kernel.append(mask)
        kernel = paddle.stack(kernel)
        results.append(kernel)

        return results

    def forward(self, preds, labels):
        """Compute DRRG loss."""

        assert isinstance(preds, tuple)
        (gt_text_mask, gt_center_region_mask, gt_mask, gt_top_height_map,
         gt_bot_height_map, gt_sin_map, gt_cos_map) = labels[1:8]

        downsample_ratio = self.downsample_ratio

        pred_maps, gcn_data = preds
        pred_text_region = pred_maps[:, 0, :, :]
        pred_center_region = pred_maps[:, 1, :, :]
        pred_sin_map = pred_maps[:, 2, :, :]
        pred_cos_map = pred_maps[:, 3, :, :]
        pred_top_height_map = pred_maps[:, 4, :, :]
        pred_bot_height_map = pred_maps[:, 5, :, :]
        feature_sz = pred_maps.shape

        # bitmask 2 tensor
        mapping = {
            'gt_text_mask': paddle.cast(gt_text_mask, 'float32'),
            'gt_center_region_mask':
            paddle.cast(gt_center_region_mask, 'float32'),
            'gt_mask': paddle.cast(gt_mask, 'float32'),
            'gt_top_height_map': paddle.cast(gt_top_height_map, 'float32'),
            'gt_bot_height_map': paddle.cast(gt_bot_height_map, 'float32'),
            'gt_sin_map': paddle.cast(gt_sin_map, 'float32'),
            'gt_cos_map': paddle.cast(gt_cos_map, 'float32')
        }
        gt = {}
        for key, value in mapping.items():
            gt[key] = value
            if abs(downsample_ratio - 1.0) < 1e-2:
                gt[key] = self.bitmasks2tensor(gt[key], feature_sz[2:])
            else:
                gt[key] = [item.rescale(downsample_ratio) for item in gt[key]]
                gt[key] = self.bitmasks2tensor(gt[key], feature_sz[2:])
                if key in ['gt_top_height_map', 'gt_bot_height_map']:
                    gt[key] = [item * downsample_ratio for item in gt[key]]
            gt[key] = [item for item in gt[key]]

        scale = paddle.sqrt(1.0 / (pred_sin_map**2 + pred_cos_map**2 + 1e-8))
        pred_sin_map = pred_sin_map * scale
        pred_cos_map = pred_cos_map * scale

        loss_text = self.balance_bce_loss(
            F.sigmoid(pred_text_region), gt['gt_text_mask'][0],
            gt['gt_mask'][0])

        text_mask = (gt['gt_text_mask'][0] * gt['gt_mask'][0])
        negative_text_mask = ((1 - gt['gt_text_mask'][0]) * gt['gt_mask'][0])
        loss_center_map = F.binary_cross_entropy(
            F.sigmoid(pred_center_region),
            gt['gt_center_region_mask'][0],
            reduction='none')
        if int(text_mask.sum()) > 0:
            loss_center_positive = paddle.sum(
                loss_center_map * text_mask) / paddle.sum(text_mask)
        else:
            loss_center_positive = paddle.to_tensor(0.0)
        loss_center_negative = paddle.sum(
            loss_center_map *
            negative_text_mask) / paddle.sum(negative_text_mask)
        loss_center = loss_center_positive + 0.5 * loss_center_negative

        center_mask = (gt['gt_center_region_mask'][0] * gt['gt_mask'][0])
        if int(center_mask.sum()) > 0:
            map_sz = pred_top_height_map.shape
            ones = paddle.ones(map_sz, dtype='float32')
            loss_top = F.smooth_l1_loss(
                pred_top_height_map / (gt['gt_top_height_map'][0] + 1e-2),
                ones,
                reduction='none')
            loss_bot = F.smooth_l1_loss(
                pred_bot_height_map / (gt['gt_bot_height_map'][0] + 1e-2),
                ones,
                reduction='none')
            gt_height = (
                gt['gt_top_height_map'][0] + gt['gt_bot_height_map'][0])
            loss_height = paddle.sum(
                (paddle.log(gt_height + 1) *
                 (loss_top + loss_bot)) * center_mask) / paddle.sum(center_mask)

            loss_sin = paddle.sum(
                F.smooth_l1_loss(
                    pred_sin_map, gt['gt_sin_map'][0],
                    reduction='none') * center_mask) / paddle.sum(center_mask)
            loss_cos = paddle.sum(
                F.smooth_l1_loss(
                    pred_cos_map, gt['gt_cos_map'][0],
                    reduction='none') * center_mask) / paddle.sum(center_mask)
        else:
            loss_height = paddle.to_tensor(0.0)
            loss_sin = paddle.to_tensor(0.0)
            loss_cos = paddle.to_tensor(0.0)

        loss_gcn = self.gcn_loss(gcn_data)

        loss = loss_text + loss_center + loss_height + loss_sin + loss_cos + loss_gcn
        results = dict(
            loss=loss,
            loss_text=loss_text,
            loss_center=loss_center,
            loss_height=loss_height,
            loss_sin=loss_sin,
            loss_cos=loss_cos,
            loss_gcn=loss_gcn)

        return results
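The balanced BCE above keeps every positive pixel but only the hardest `ohem_ratio * positive_count` negatives (selected by `paddle.topk` on the per-pixel loss). A toy sketch of that selection on a 4-pixel map, with illustrative values:

```python
import paddle
import paddle.nn.functional as F

pred = paddle.to_tensor([[0.9, 0.2, 0.8, 0.1]])  # probabilities in [0, 1]
gt = paddle.to_tensor([[1.0, 0.0, 0.0, 0.0]])
mask = paddle.ones_like(gt)

loss = F.binary_cross_entropy(pred, gt, reduction='none')
positive = gt * mask
negative = (1 - gt) * mask
pos_count = int(positive.sum())                      # 1
neg_count = min(int(negative.sum()), pos_count * 3)  # ohem_ratio = 3

# keep only the hardest negatives (here the 0.8 false positive dominates)
neg_loss, _ = paddle.topk((loss * negative).reshape([-1]), neg_count)
balanced = (paddle.sum(loss * positive) + paddle.sum(neg_loss)) / (
    float(pos_count + neg_count) + 1e-5)
print(balanced)
```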
@@ -24,6 +24,7 @@ def build_head(config):
    from .det_fce_head import FCEHead
    from .e2e_pg_head import PGHead
    from .det_ct_head import CT_Head
    from .det_drrg_head import DRRGHead

    # rec head
    from .rec_ctc_head import CTCHead
@@ -54,7 +55,8 @@ def build_head(config):
        'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
        'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
        'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead',
        'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head', 'RFLHead'
        'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head', 'RFLHead',
        'DRRGHead'
    ]

    #table head
@@ -0,0 +1,191 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textdet/dense_heads/drrg_head.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import warnings
import cv2
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from .gcn import GCN
from .local_graph import LocalGraphs
from .proposal_local_graph import ProposalLocalGraphs


class DRRGHead(nn.Layer):
    def __init__(self,
                 in_channels,
                 k_at_hops=(8, 4),
                 num_adjacent_linkages=3,
                 node_geo_feat_len=120,
                 pooling_scale=1.0,
                 pooling_output_size=(4, 3),
                 nms_thr=0.3,
                 min_width=8.0,
                 max_width=24.0,
                 comp_shrink_ratio=1.03,
                 comp_ratio=0.4,
                 comp_score_thr=0.3,
                 text_region_thr=0.2,
                 center_region_thr=0.2,
                 center_region_area_thr=50,
                 local_graph_thr=0.7,
                 **kwargs):
        super().__init__()

        assert isinstance(in_channels, int)
        assert isinstance(k_at_hops, tuple)
        assert isinstance(num_adjacent_linkages, int)
        assert isinstance(node_geo_feat_len, int)
        assert isinstance(pooling_scale, float)
        assert isinstance(pooling_output_size, tuple)
        assert isinstance(comp_shrink_ratio, float)
        assert isinstance(nms_thr, float)
        assert isinstance(min_width, float)
        assert isinstance(max_width, float)
        assert isinstance(comp_ratio, float)
        assert isinstance(comp_score_thr, float)
        assert isinstance(text_region_thr, float)
        assert isinstance(center_region_thr, float)
        assert isinstance(center_region_area_thr, int)
        assert isinstance(local_graph_thr, float)

        self.in_channels = in_channels
        self.out_channels = 6
        self.downsample_ratio = 1.0
        self.k_at_hops = k_at_hops
        self.num_adjacent_linkages = num_adjacent_linkages
        self.node_geo_feat_len = node_geo_feat_len
        self.pooling_scale = pooling_scale
        self.pooling_output_size = pooling_output_size
        self.comp_shrink_ratio = comp_shrink_ratio
        self.nms_thr = nms_thr
        self.min_width = min_width
        self.max_width = max_width
        self.comp_ratio = comp_ratio
        self.comp_score_thr = comp_score_thr
        self.text_region_thr = text_region_thr
        self.center_region_thr = center_region_thr
        self.center_region_area_thr = center_region_area_thr
        self.local_graph_thr = local_graph_thr

        self.out_conv = nn.Conv2D(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=1,
            stride=1,
            padding=0)

        self.graph_train = LocalGraphs(
            self.k_at_hops, self.num_adjacent_linkages, self.node_geo_feat_len,
            self.pooling_scale, self.pooling_output_size, self.local_graph_thr)

        self.graph_test = ProposalLocalGraphs(
            self.k_at_hops, self.num_adjacent_linkages, self.node_geo_feat_len,
            self.pooling_scale, self.pooling_output_size, self.nms_thr,
            self.min_width, self.max_width, self.comp_shrink_ratio,
            self.comp_ratio, self.comp_score_thr, self.text_region_thr,
            self.center_region_thr, self.center_region_area_thr)

        pool_w, pool_h = self.pooling_output_size
        node_feat_len = (pool_w * pool_h) * (
            self.in_channels + self.out_channels) + self.node_geo_feat_len
        self.gcn = GCN(node_feat_len)

    def forward(self, inputs, targets=None):
        """
        Args:
            inputs (Tensor): Shape of :math:`(N, C, H, W)`.
            targets (list): In training, targets[7] holds the padded text
                component attributes (list[ndarray]), each of shape
                (num_component, 8).

        Returns:
            tuple: Returns (pred_maps, (gcn_pred, gt_labels)).

                - | pred_maps (Tensor): Prediction map with shape
                    :math:`(N, C_{out}, H, W)`.
                - | gcn_pred (Tensor): Prediction from GCN module, with
                    shape :math:`(N, 2)`.
                - | gt_labels (Tensor): Ground-truth label with shape
                    :math:`(N, 8)`.
        """
        if self.training:
            assert targets is not None
            gt_comp_attribs = targets[7]
            pred_maps = self.out_conv(inputs)
            feat_maps = paddle.concat([inputs, pred_maps], axis=1)
            node_feats, adjacent_matrices, knn_inds, gt_labels = self.graph_train(
                feat_maps, np.stack(gt_comp_attribs))

            gcn_pred = self.gcn(node_feats, adjacent_matrices, knn_inds)

            return pred_maps, (gcn_pred, gt_labels)
        else:
            return self.single_test(inputs)

    def single_test(self, feat_maps):
        r"""
        Args:
            feat_maps (Tensor): Shape of :math:`(N, C, H, W)`.

        Returns:
            tuple: Returns (edge, score, text_comps).

                - | edge (ndarray): The edge array of shape :math:`(N, 2)`
                    where each row is a pair of text component indices
                    that makes up an edge in the graph.
                - | score (ndarray): The score array of shape :math:`(N,)`,
                    corresponding to the edges above.
                - | text_comps (ndarray): The text components of shape
                    :math:`(N, 9)` where each row corresponds to one box and
                    its score: (x1, y1, x2, y2, x3, y3, x4, y4, score).
        """
        pred_maps = self.out_conv(feat_maps)
        feat_maps = paddle.concat([feat_maps, pred_maps], axis=1)

        none_flag, graph_data = self.graph_test(pred_maps, feat_maps)

        (local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
         pivot_local_graphs, text_comps) = graph_data

        if none_flag:
            return None, None, None
        gcn_pred = self.gcn(local_graphs_node_feat, adjacent_matrices,
                            pivots_knn_inds)
        pred_labels = F.softmax(gcn_pred, axis=1)

        edges = []
        scores = []
        pivot_local_graphs = pivot_local_graphs.squeeze().numpy()

        for pivot_ind, pivot_local_graph in enumerate(pivot_local_graphs):
            pivot = pivot_local_graph[0]
            for k_ind, neighbor_ind in enumerate(pivots_knn_inds[pivot_ind]):
                neighbor = pivot_local_graph[neighbor_ind.item()]
                edges.append([pivot, neighbor])
                scores.append(pred_labels[pivot_ind * pivots_knn_inds.shape[1] +
                                          k_ind, 1].item())

        edges = np.asarray(edges)
        scores = np.asarray(scores)

        return edges, scores, text_comps
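At inference, `single_test` above returns candidate edges between text components together with per-edge link scores; downstream grouping typically keeps only the high-scoring edges. A small hedged sketch of interpreting that output (the arrays and threshold are illustrative, not values fixed by the head):

```python
import numpy as np

# illustrative outputs shaped like DRRGHead.single_test's (edges, scores)
edges = np.array([[0, 1], [1, 2], [2, 5]])
scores = np.array([0.95, 0.40, 0.88])

link_thr = 0.8  # example threshold, chosen by the post-processing step
kept = edges[scores >= link_thr]
print(kept)  # [[0 1] [2 5]] -> component pairs judged to be the same text
```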
@@ -0,0 +1,113 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textdet/modules/gcn.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class BatchNorm1D(nn.BatchNorm1D):
    def __init__(self,
                 num_features,
                 eps=1e-05,
                 momentum=0.1,
                 affine=True,
                 track_running_stats=True):
        momentum = 1 - momentum
        weight_attr = None
        bias_attr = None
        if not affine:
            weight_attr = paddle.ParamAttr(learning_rate=0.0)
            bias_attr = paddle.ParamAttr(learning_rate=0.0)
        super().__init__(
            num_features,
            momentum=momentum,
            epsilon=eps,
            weight_attr=weight_attr,
            bias_attr=bias_attr,
            use_global_stats=track_running_stats)


class MeanAggregator(nn.Layer):
    def forward(self, features, A):
        x = paddle.bmm(A, features)
        return x


class GraphConv(nn.Layer):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.weight = self.create_parameter(
            [in_dim * 2, out_dim],
            default_initializer=nn.initializer.XavierUniform())
        self.bias = self.create_parameter(
            [out_dim],
            is_bias=True,
            default_initializer=nn.initializer.Assign([0] * out_dim))

        self.aggregator = MeanAggregator()

    def forward(self, features, A):
        b, n, d = features.shape
        assert d == self.in_dim
        agg_feats = self.aggregator(features, A)
        cat_feats = paddle.concat([features, agg_feats], axis=2)
        out = paddle.einsum('bnd,df->bnf', cat_feats, self.weight)
        out = F.relu(out + self.bias)
        return out


class GCN(nn.Layer):
    def __init__(self, feat_len):
        super(GCN, self).__init__()
        self.bn0 = BatchNorm1D(feat_len, affine=False)
        self.conv1 = GraphConv(feat_len, 512)
        self.conv2 = GraphConv(512, 256)
        self.conv3 = GraphConv(256, 128)
        self.conv4 = GraphConv(128, 64)
        self.classifier = nn.Sequential(
            nn.Linear(64, 32), nn.PReLU(32), nn.Linear(32, 2))

    def forward(self, x, A, knn_inds):

        num_local_graphs, num_max_nodes, feat_len = x.shape

        x = x.reshape([-1, feat_len])
        x = self.bn0(x)
        x = x.reshape([num_local_graphs, num_max_nodes, feat_len])

        x = self.conv1(x, A)
        x = self.conv2(x, A)
        x = self.conv3(x, A)
        x = self.conv4(x, A)
        k = knn_inds.shape[-1]
        mid_feat_len = x.shape[-1]
        edge_feat = paddle.zeros([num_local_graphs, k, mid_feat_len])
        for graph_ind in range(num_local_graphs):
            edge_feat[graph_ind, :, :] = x[graph_ind][paddle.to_tensor(knn_inds[
                graph_ind])]
        edge_feat = edge_feat.reshape([-1, mid_feat_len])
        pred = self.classifier(edge_feat)

        return pred
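A shape-level usage sketch of the GCN above, assuming gcn.py is importable; the feature length and graph sizes are illustrative (in DRRGHead the real value is `(pool_w * pool_h) * (in_channels + out_channels) + node_geo_feat_len`):

```python
import paddle

feat_len = 64 * 12 + 120  # illustrative node feature length
gcn = GCN(feat_len)

num_graphs, num_nodes, k = 2, 10, 4
x = paddle.randn([num_graphs, num_nodes, feat_len])  # node features
A = paddle.rand([num_graphs, num_nodes, num_nodes])  # adjacency matrices
knn_inds = paddle.randint(1, num_nodes, [num_graphs, k])  # pivot neighbors

pred = gcn(x, A, knn_inds)
print(pred.shape)  # [num_graphs * k, 2] -> per-edge linkage logits
```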
@ -0,0 +1,388 @@
|
|||
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This code is refer from:
|
||||
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textdet/modules/local_graph.py
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
from ppocr.ext_op import RoIAlignRotated
|
||||
|
||||
|
||||
def normalize_adjacent_matrix(A):
|
||||
assert A.ndim == 2
|
||||
assert A.shape[0] == A.shape[1]
|
||||
|
||||
A = A + np.eye(A.shape[0])
|
||||
d = np.sum(A, axis=0)
|
||||
d = np.clip(d, 0, None)
|
||||
d_inv = np.power(d, -0.5).flatten()
|
||||
d_inv[np.isinf(d_inv)] = 0.0
|
||||
d_inv = np.diag(d_inv)
|
||||
G = A.dot(d_inv).transpose().dot(d_inv)
|
||||
return G
|
||||
|
||||
|
||||
def euclidean_distance_matrix(A, B):
|
||||
"""Calculate the Euclidean distance matrix.
|
||||
|
||||
Args:
|
||||
A (ndarray): The point sequence.
|
||||
B (ndarray): The point sequence with the same dimensions as A.
|
||||
|
||||
returns:
|
||||
D (ndarray): The Euclidean distance matrix.
|
||||
"""
|
||||
assert A.ndim == 2
|
||||
assert B.ndim == 2
|
||||
assert A.shape[1] == B.shape[1]
|
||||
|
||||
m = A.shape[0]
|
||||
n = B.shape[0]
|
||||
|
||||
A_dots = (A * A).sum(axis=1).reshape((m, 1)) * np.ones(shape=(1, n))
|
||||
B_dots = (B * B).sum(axis=1) * np.ones(shape=(m, 1))
|
||||
D_squared = A_dots + B_dots - 2 * A.dot(B.T)
|
||||
|
||||
zero_mask = np.less(D_squared, 0.0)
|
||||
D_squared[zero_mask] = 0.0
|
||||
D = np.sqrt(D_squared)
|
||||
return D
|
||||
|
||||
|
||||
def feature_embedding(input_feats, out_feat_len):
|
||||
"""Embed features. This code was partially adapted from
|
||||
https://github.com/GXYM/DRRG licensed under the MIT license.
|
||||
|
||||
Args:
|
||||
input_feats (ndarray): The input features of shape (N, d), where N is
|
||||
the number of nodes in graph, d is the input feature vector length.
|
||||
out_feat_len (int): The length of output feature vector.
|
||||
|
||||
Returns:
|
||||
embedded_feats (ndarray): The embedded features.
|
||||
"""
|
||||
assert input_feats.ndim == 2
|
||||
assert isinstance(out_feat_len, int)
|
||||
assert out_feat_len >= input_feats.shape[1]
|
||||
|
||||
num_nodes = input_feats.shape[0]
|
||||
feat_dim = input_feats.shape[1]
|
||||
feat_repeat_times = out_feat_len // feat_dim
|
||||
residue_dim = out_feat_len % feat_dim
|
||||
|
||||
if residue_dim > 0:
|
||||
embed_wave = np.array([
|
||||
np.power(1000, 2.0 * (j // 2) / feat_repeat_times + 1)
|
||||
for j in range(feat_repeat_times + 1)
|
||||
]).reshape((feat_repeat_times + 1, 1, 1))
|
||||
repeat_feats = np.repeat(
|
||||
np.expand_dims(
|
||||
input_feats, axis=0), feat_repeat_times, axis=0)
|
||||
residue_feats = np.hstack([
|
||||
input_feats[:, 0:residue_dim], np.zeros(
|
||||
(num_nodes, feat_dim - residue_dim))
|
||||
])
|
||||
residue_feats = np.expand_dims(residue_feats, axis=0)
|
||||
repeat_feats = np.concatenate([repeat_feats, residue_feats], axis=0)
|
||||
embedded_feats = repeat_feats / embed_wave
|
||||
embedded_feats[:, 0::2] = np.sin(embedded_feats[:, 0::2])
|
||||
embedded_feats[:, 1::2] = np.cos(embedded_feats[:, 1::2])
|
||||
embedded_feats = np.transpose(embedded_feats, (1, 0, 2)).reshape(
|
||||
(num_nodes, -1))[:, 0:out_feat_len]
|
||||
else:
|
||||
embed_wave = np.array([
|
||||
np.power(1000, 2.0 * (j // 2) / feat_repeat_times)
|
||||
for j in range(feat_repeat_times)
|
||||
]).reshape((feat_repeat_times, 1, 1))
|
||||
repeat_feats = np.repeat(
|
||||
np.expand_dims(
|
||||
input_feats, axis=0), feat_repeat_times, axis=0)
|
||||
embedded_feats = repeat_feats / embed_wave
|
||||
embedded_feats[:, 0::2] = np.sin(embedded_feats[:, 0::2])
|
||||
embedded_feats[:, 1::2] = np.cos(embedded_feats[:, 1::2])
|
||||
embedded_feats = np.transpose(embedded_feats, (1, 0, 2)).reshape(
|
||||
(num_nodes, -1)).astype(np.float32)
|
||||
|
||||
return embedded_feats
|
||||


class LocalGraphs:
    def __init__(self, k_at_hops, num_adjacent_linkages, node_geo_feat_len,
                 pooling_scale, pooling_output_size, local_graph_thr):

        assert len(k_at_hops) == 2
        assert all(isinstance(n, int) for n in k_at_hops)
        assert isinstance(num_adjacent_linkages, int)
        assert isinstance(node_geo_feat_len, int)
        assert isinstance(pooling_scale, float)
        assert all(isinstance(n, int) for n in pooling_output_size)
        assert isinstance(local_graph_thr, float)

        self.k_at_hops = k_at_hops
        self.num_adjacent_linkages = num_adjacent_linkages
        self.node_geo_feat_dim = node_geo_feat_len
        self.pooling = RoIAlignRotated(pooling_output_size, pooling_scale)
        self.local_graph_thr = local_graph_thr

    def generate_local_graphs(self, sorted_dist_inds, gt_comp_labels):
        """Generate local graphs for GCN to predict which instance a text
        component belongs to.

        Args:
            sorted_dist_inds (ndarray): The complete graph node indices,
                sorted according to the Euclidean distance.
            gt_comp_labels (ndarray): The ground-truth labels that define the
                instance to which the text components (nodes in graphs)
                belong.

        Returns:
            pivot_local_graphs (list[list[int]]): The list of local graph
                neighbor indices of pivots.
            pivot_knns (list[list[int]]): The list of k-nearest neighbor
                indices of pivots.
        """

        assert sorted_dist_inds.ndim == 2
        assert (sorted_dist_inds.shape[0] == sorted_dist_inds.shape[1] ==
                gt_comp_labels.shape[0])

        knn_graph = sorted_dist_inds[:, 1:self.k_at_hops[0] + 1]
        pivot_local_graphs = []
        pivot_knns = []
        for pivot_ind, knn in enumerate(knn_graph):

            local_graph_neighbors = set(knn)

            for neighbor_ind in knn:
                local_graph_neighbors.update(
                    set(sorted_dist_inds[neighbor_ind,
                                         1:self.k_at_hops[1] + 1]))

            local_graph_neighbors.discard(pivot_ind)
            pivot_local_graph = list(local_graph_neighbors)
            pivot_local_graph.insert(0, pivot_ind)
            pivot_knn = [pivot_ind] + list(knn)

            if pivot_ind < 1:
                pivot_local_graphs.append(pivot_local_graph)
                pivot_knns.append(pivot_knn)
            else:
                add_flag = True
                for graph_ind, added_knn in enumerate(pivot_knns):
                    added_pivot_ind = added_knn[0]
                    added_local_graph = pivot_local_graphs[graph_ind]

                    union = len(
                        set(pivot_local_graph[1:]).union(
                            set(added_local_graph[1:])))
                    intersect = len(
                        set(pivot_local_graph[1:]).intersection(
                            set(added_local_graph[1:])))
                    local_graph_iou = intersect / (union + 1e-8)

                    if (local_graph_iou > self.local_graph_thr and
                            pivot_ind in added_knn and
                            gt_comp_labels[added_pivot_ind] ==
                            gt_comp_labels[pivot_ind] and
                            gt_comp_labels[pivot_ind] != 0):
                        add_flag = False
                        break
                if add_flag:
                    pivot_local_graphs.append(pivot_local_graph)
                    pivot_knns.append(pivot_knn)

        return pivot_local_graphs, pivot_knns

    def generate_gcn_input(self, node_feat_batch, node_label_batch,
                           local_graph_batch, knn_batch,
                           sorted_dist_ind_batch):
        """Generate graph convolution network input data.

        Args:
            node_feat_batch (List[Tensor]): The batched graph node features.
            node_label_batch (List[ndarray]): The batched text component
                labels.
            local_graph_batch (List[List[list[int]]]): The local graph node
                indices of image batch.
            knn_batch (List[List[list[int]]]): The knn graph node indices of
                image batch.
            sorted_dist_ind_batch (list[ndarray]): The node indices sorted
                according to the Euclidean distance.

        Returns:
            local_graphs_node_feat (Tensor): The node features of graph.
            adjacent_matrices (Tensor): The adjacency matrices of local
                graphs.
            pivots_knn_inds (Tensor): The k-nearest neighbor indices in
                local graph.
            gt_linkage (Tensor): The supervision signal of GCN for linkage
                prediction.
        """
        assert isinstance(node_feat_batch, list)
        assert isinstance(node_label_batch, list)
        assert isinstance(local_graph_batch, list)
        assert isinstance(knn_batch, list)
        assert isinstance(sorted_dist_ind_batch, list)

        num_max_nodes = max([
            len(pivot_local_graph)
            for pivot_local_graphs in local_graph_batch
            for pivot_local_graph in pivot_local_graphs
        ])

        local_graphs_node_feat = []
        adjacent_matrices = []
        pivots_knn_inds = []
        pivots_gt_linkage = []

        for batch_ind, sorted_dist_inds in enumerate(sorted_dist_ind_batch):
            node_feats = node_feat_batch[batch_ind]
            pivot_local_graphs = local_graph_batch[batch_ind]
            pivot_knns = knn_batch[batch_ind]
            node_labels = node_label_batch[batch_ind]

            for graph_ind, pivot_knn in enumerate(pivot_knns):
                pivot_local_graph = pivot_local_graphs[graph_ind]
                num_nodes = len(pivot_local_graph)
                pivot_ind = pivot_local_graph[0]
                node2ind_map = {j: i for i, j in enumerate(pivot_local_graph)}

                knn_inds = paddle.to_tensor(
                    [node2ind_map[i] for i in pivot_knn[1:]])
                pivot_feats = node_feats[pivot_ind]
                normalized_feats = node_feats[paddle.to_tensor(
                    pivot_local_graph)] - pivot_feats

                adjacent_matrix = np.zeros(
                    (num_nodes, num_nodes), dtype=np.float32)
                for node in pivot_local_graph:
                    neighbors = sorted_dist_inds[
                        node, 1:self.num_adjacent_linkages + 1]
                    for neighbor in neighbors:
                        if neighbor in pivot_local_graph:
                            adjacent_matrix[node2ind_map[node],
                                            node2ind_map[neighbor]] = 1
                            adjacent_matrix[node2ind_map[neighbor],
                                            node2ind_map[node]] = 1

                adjacent_matrix = normalize_adjacent_matrix(adjacent_matrix)
                pad_adjacent_matrix = paddle.zeros(
                    (num_max_nodes, num_max_nodes))
                pad_adjacent_matrix[:num_nodes, :num_nodes] = paddle.cast(
                    paddle.to_tensor(adjacent_matrix), 'float32')

                pad_normalized_feats = paddle.concat(
                    [
                        normalized_feats, paddle.zeros(
                            (num_max_nodes - num_nodes,
                             normalized_feats.shape[1]))
                    ],
                    axis=0)
                local_graph_labels = node_labels[pivot_local_graph]
                knn_labels = local_graph_labels[knn_inds.numpy()]
                link_labels = ((node_labels[pivot_ind] == knn_labels) &
                               (node_labels[pivot_ind] > 0)).astype(np.int64)
                link_labels = paddle.to_tensor(link_labels)

                local_graphs_node_feat.append(pad_normalized_feats)
                adjacent_matrices.append(pad_adjacent_matrix)
                pivots_knn_inds.append(knn_inds)
                pivots_gt_linkage.append(link_labels)

        local_graphs_node_feat = paddle.stack(local_graphs_node_feat, 0)
        adjacent_matrices = paddle.stack(adjacent_matrices, 0)
        pivots_knn_inds = paddle.stack(pivots_knn_inds, 0)
        pivots_gt_linkage = paddle.stack(pivots_gt_linkage, 0)

        return (local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
                pivots_gt_linkage)

    def __call__(self, feat_maps, comp_attribs):
        """Generate local graphs as GCN input.

        Args:
            feat_maps (Tensor): The feature maps to extract the content
                features of text components.
            comp_attribs (ndarray): The text component attributes.

        Returns:
            local_graphs_node_feat (Tensor): The node features of graph.
            adjacent_matrices (Tensor): The adjacency matrices of local
                graphs.
            pivots_knn_inds (Tensor): The k-nearest neighbor indices in
                local graph.
            gt_linkage (Tensor): The supervision signal of GCN for linkage
                prediction.
        """

        assert isinstance(feat_maps, paddle.Tensor)
        assert comp_attribs.ndim == 3
        assert comp_attribs.shape[2] == 8

        sorted_dist_inds_batch = []
        local_graph_batch = []
        knn_batch = []
        node_feat_batch = []
        node_label_batch = []

        for batch_ind in range(comp_attribs.shape[0]):
            num_comps = int(comp_attribs[batch_ind, 0, 0])
            comp_geo_attribs = comp_attribs[batch_ind, :num_comps, 1:7]
            node_labels = comp_attribs[batch_ind, :num_comps, 7].astype(
                np.int32)

            comp_centers = comp_geo_attribs[:, 0:2]
            distance_matrix = euclidean_distance_matrix(comp_centers,
                                                        comp_centers)

            # The feature maps are pooled per image below, so the roi batch
            # id is always zero.
            batch_id = np.zeros(
                (comp_geo_attribs.shape[0], 1), dtype=np.float32) * batch_ind
            comp_geo_attribs[:, -2] = np.clip(comp_geo_attribs[:, -2], -1, 1)
            angle = np.arccos(comp_geo_attribs[:, -2]) * np.sign(
                comp_geo_attribs[:, -1])
            angle = angle.reshape((-1, 1))
            rotated_rois = np.hstack(
                [batch_id, comp_geo_attribs[:, :-2], angle])
            rois = paddle.to_tensor(rotated_rois)
            content_feats = self.pooling(feat_maps[batch_ind].unsqueeze(0),
                                         rois)

            content_feats = content_feats.reshape(
                [content_feats.shape[0], -1])
            geo_feats = feature_embedding(comp_geo_attribs,
                                          self.node_geo_feat_dim)
            geo_feats = paddle.to_tensor(geo_feats)
            node_feats = paddle.concat([content_feats, geo_feats], axis=-1)

            sorted_dist_inds = np.argsort(distance_matrix, axis=1)
            pivot_local_graphs, pivot_knns = self.generate_local_graphs(
                sorted_dist_inds, node_labels)

            node_feat_batch.append(node_feats)
            node_label_batch.append(node_labels)
            local_graph_batch.append(pivot_local_graphs)
            knn_batch.append(pivot_knns)
            sorted_dist_inds_batch.append(sorted_dist_inds)

        (node_feats, adjacent_matrices, knn_inds, gt_linkage) = \
            self.generate_gcn_input(node_feat_batch,
                                    node_label_batch,
                                    local_graph_batch,
                                    knn_batch,
                                    sorted_dist_inds_batch)

        return node_feats, adjacent_matrices, knn_inds, gt_linkage
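

# A self-contained sketch of the two-hop neighborhood rule used by
# generate_local_graphs above, with an assumed k_at_hops of (2, 2); it
# mirrors the pivot expansion without needing RoIAlignRotated or labels.
def _demo_two_hop_neighbors():
    coords = np.array([[0, 0], [1, 0], [2, 0], [10, 0], [11, 0]],
                      dtype=np.float32)
    sorted_inds = np.argsort(
        euclidean_distance_matrix(coords, coords), axis=1)
    pivot = 0
    knn = sorted_inds[pivot, 1:3]       # first hop: the 2 nearest neighbors
    neighbors = set(knn)
    for n in knn:                       # second hop around each neighbor
        neighbors.update(sorted_inds[n, 1:3])
    neighbors.discard(pivot)
    return [pivot] + sorted(neighbors)  # pivot-first local graph: [0, 1, 2]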
@ -0,0 +1,412 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textdet/modules/proposal_local_graph.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import cv2
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from lanms import merge_quadrangle_n9 as la_nms

from ppocr.ext_op import RoIAlignRotated
from .local_graph import (euclidean_distance_matrix, feature_embedding,
                          normalize_adjacent_matrix)


def fill_hole(input_mask):
    h, w = input_mask.shape
    canvas = np.zeros((h + 2, w + 2), np.uint8)
    canvas[1:h + 1, 1:w + 1] = input_mask.copy()

    mask = np.zeros((h + 4, w + 4), np.uint8)

    cv2.floodFill(canvas, mask, (0, 0), 1)
    # np.bool was removed from recent NumPy releases; the builtin bool
    # works across all supported versions.
    canvas = canvas[1:h + 1, 1:w + 1].astype(bool)

    return ~canvas | input_mask
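

# A small self-contained check of fill_hole on a toy mask: a one-pixel hole
# enclosed by foreground is filled, while background connected to the image
# border stays background.
def _demo_fill_hole():
    toy_mask = np.zeros((5, 5), dtype=bool)
    toy_mask[1:4, 1:4] = True
    toy_mask[2, 2] = False              # enclosed hole
    filled = fill_hole(toy_mask)
    assert filled[2, 2] and not filled[0, 0]
    return filled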


class ProposalLocalGraphs:
    def __init__(self, k_at_hops, num_adjacent_linkages, node_geo_feat_len,
                 pooling_scale, pooling_output_size, nms_thr, min_width,
                 max_width, comp_shrink_ratio, comp_w_h_ratio, comp_score_thr,
                 text_region_thr, center_region_thr, center_region_area_thr):

        assert len(k_at_hops) == 2
        assert isinstance(k_at_hops, tuple)
        assert isinstance(num_adjacent_linkages, int)
        assert isinstance(node_geo_feat_len, int)
        assert isinstance(pooling_scale, float)
        assert isinstance(pooling_output_size, tuple)
        assert isinstance(nms_thr, float)
        assert isinstance(min_width, float)
        assert isinstance(max_width, float)
        assert isinstance(comp_shrink_ratio, float)
        assert isinstance(comp_w_h_ratio, float)
        assert isinstance(comp_score_thr, float)
        assert isinstance(text_region_thr, float)
        assert isinstance(center_region_thr, float)
        assert isinstance(center_region_area_thr, int)

        self.k_at_hops = k_at_hops
        self.active_connection = num_adjacent_linkages
        self.local_graph_depth = len(self.k_at_hops)
        self.node_geo_feat_dim = node_geo_feat_len
        self.pooling = RoIAlignRotated(pooling_output_size, pooling_scale)
        self.nms_thr = nms_thr
        self.min_width = min_width
        self.max_width = max_width
        self.comp_shrink_ratio = comp_shrink_ratio
        self.comp_w_h_ratio = comp_w_h_ratio
        self.comp_score_thr = comp_score_thr
        self.text_region_thr = text_region_thr
        self.center_region_thr = center_region_thr
        self.center_region_area_thr = center_region_area_thr

    def propose_comps(self, score_map, top_height_map, bot_height_map,
                      sin_map, cos_map, comp_score_thr, min_width, max_width,
                      comp_shrink_ratio, comp_w_h_ratio):
        """Propose text components.

        Args:
            score_map (ndarray): The score map for NMS.
            top_height_map (ndarray): The predicted text height map from each
                pixel in text center region to top sideline.
            bot_height_map (ndarray): The predicted text height map from each
                pixel in text center region to bottom sideline.
            sin_map (ndarray): The predicted sin(theta) map.
            cos_map (ndarray): The predicted cos(theta) map.
            comp_score_thr (float): The score threshold of text component.
            min_width (float): The minimum width of text components.
            max_width (float): The maximum width of text components.
            comp_shrink_ratio (float): The shrink ratio of text components.
            comp_w_h_ratio (float): The width to height ratio of text
                components.

        Returns:
            text_comps (ndarray): The text components.
        """

        comp_centers = np.argwhere(score_map > comp_score_thr)
        comp_centers = comp_centers[np.argsort(comp_centers[:, 0])]
        y = comp_centers[:, 0]
        x = comp_centers[:, 1]

        top_height = top_height_map[y, x].reshape((-1, 1)) * comp_shrink_ratio
        bot_height = bot_height_map[y, x].reshape((-1, 1)) * comp_shrink_ratio
        sin = sin_map[y, x].reshape((-1, 1))
        cos = cos_map[y, x].reshape((-1, 1))

        top_mid_pts = comp_centers + np.hstack(
            [top_height * sin, top_height * cos])
        bot_mid_pts = comp_centers - np.hstack(
            [bot_height * sin, bot_height * cos])

        width = (top_height + bot_height) * comp_w_h_ratio
        width = np.clip(width, min_width, max_width)
        r = width / 2

        tl = top_mid_pts[:, ::-1] - np.hstack([-r * sin, r * cos])
        tr = top_mid_pts[:, ::-1] + np.hstack([-r * sin, r * cos])
        br = bot_mid_pts[:, ::-1] + np.hstack([-r * sin, r * cos])
        bl = bot_mid_pts[:, ::-1] - np.hstack([-r * sin, r * cos])
        text_comps = np.hstack([tl, tr, br, bl]).astype(np.float32)

        score = score_map[y, x].reshape((-1, 1))
        text_comps = np.hstack([text_comps, score])

        return text_comps

    def propose_comps_and_attribs(self, text_region_map, center_region_map,
                                  top_height_map, bot_height_map, sin_map,
                                  cos_map):
        """Generate text components and attributes.

        Args:
            text_region_map (ndarray): The predicted text region probability
                map.
            center_region_map (ndarray): The predicted text center region
                probability map.
            top_height_map (ndarray): The predicted text height map from each
                pixel in text center region to top sideline.
            bot_height_map (ndarray): The predicted text height map from each
                pixel in text center region to bottom sideline.
            sin_map (ndarray): The predicted sin(theta) map.
            cos_map (ndarray): The predicted cos(theta) map.

        Returns:
            comp_attribs (ndarray): The text component attributes.
            text_comps (ndarray): The text components.
        """

        assert (text_region_map.shape == center_region_map.shape ==
                top_height_map.shape == bot_height_map.shape ==
                sin_map.shape == cos_map.shape)
        text_mask = text_region_map > self.text_region_thr
        center_region_mask = (
            center_region_map > self.center_region_thr) * text_mask

        scale = np.sqrt(1.0 / (sin_map**2 + cos_map**2 + 1e-8))
        sin_map, cos_map = sin_map * scale, cos_map * scale

        center_region_mask = fill_hole(center_region_mask)
        center_region_contours, _ = cv2.findContours(
            center_region_mask.astype(np.uint8), cv2.RETR_TREE,
            cv2.CHAIN_APPROX_SIMPLE)

        mask_sz = center_region_map.shape
        comp_list = []
        for contour in center_region_contours:
            current_center_mask = np.zeros(mask_sz)
            cv2.drawContours(current_center_mask, [contour], -1, 1, -1)
            if current_center_mask.sum() <= self.center_region_area_thr:
                continue
            score_map = text_region_map * current_center_mask

            text_comps = self.propose_comps(
                score_map, top_height_map, bot_height_map, sin_map, cos_map,
                self.comp_score_thr, self.min_width, self.max_width,
                self.comp_shrink_ratio, self.comp_w_h_ratio)

            text_comps = la_nms(text_comps, self.nms_thr)
            text_comp_mask = np.zeros(mask_sz)
            text_comp_boxes = text_comps[:, :8].reshape(
                (-1, 4, 2)).astype(np.int32)

            cv2.drawContours(text_comp_mask, text_comp_boxes, -1, 1, -1)
            if (text_comp_mask * text_mask).sum() < text_comp_mask.sum() * 0.5:
                continue
            if text_comps.shape[-1] > 0:
                comp_list.append(text_comps)

        if len(comp_list) <= 0:
            return None, None

        text_comps = np.vstack(comp_list)
        text_comp_boxes = text_comps[:, :8].reshape((-1, 4, 2))
        centers = np.mean(text_comp_boxes, axis=1).astype(np.int32)
        x = centers[:, 0]
        y = centers[:, 1]

        scores = []
        for text_comp_box in text_comp_boxes:
            text_comp_box[:, 0] = np.clip(text_comp_box[:, 0], 0,
                                          mask_sz[1] - 1)
            text_comp_box[:, 1] = np.clip(text_comp_box[:, 1], 0,
                                          mask_sz[0] - 1)
            min_coord = np.min(text_comp_box, axis=0).astype(np.int32)
            max_coord = np.max(text_comp_box, axis=0).astype(np.int32)
            text_comp_box = text_comp_box - min_coord
            box_sz = (max_coord - min_coord + 1)
            temp_comp_mask = np.zeros((box_sz[1], box_sz[0]), dtype=np.uint8)
            cv2.fillPoly(temp_comp_mask, [text_comp_box.astype(np.int32)], 1)
            temp_region_patch = text_region_map[min_coord[1]:(max_coord[1] + 1),
                                                min_coord[0]:(max_coord[0] + 1)]
            score = cv2.mean(temp_region_patch, temp_comp_mask)[0]
            scores.append(score)
        scores = np.array(scores).reshape((-1, 1))
        text_comps = np.hstack([text_comps[:, :-1], scores])

        h = top_height_map[y, x].reshape(
            (-1, 1)) + bot_height_map[y, x].reshape((-1, 1))
        w = np.clip(h * self.comp_w_h_ratio, self.min_width, self.max_width)
        sin = sin_map[y, x].reshape((-1, 1))
        cos = cos_map[y, x].reshape((-1, 1))

        x = x.reshape((-1, 1))
        y = y.reshape((-1, 1))
        comp_attribs = np.hstack([x, y, h, w, cos, sin])

        return comp_attribs, text_comps

    def generate_local_graphs(self, sorted_dist_inds, node_feats):
        """Generate local graphs and graph convolution network input data.

        Args:
            sorted_dist_inds (ndarray): The node indices sorted according to
                the Euclidean distance.
            node_feats (tensor): The features of nodes in graph.

        Returns:
            local_graphs_node_feats (tensor): The features of nodes in local
                graphs.
            adjacent_matrices (tensor): The adjacency matrices.
            pivots_knn_inds (tensor): The k-nearest neighbor indices in
                local graphs.
            pivots_local_graphs (tensor): The indices of nodes in local
                graphs.
        """

        assert sorted_dist_inds.ndim == 2
        assert (sorted_dist_inds.shape[0] == sorted_dist_inds.shape[1] ==
                node_feats.shape[0])

        knn_graph = sorted_dist_inds[:, 1:self.k_at_hops[0] + 1]
        pivot_local_graphs = []
        pivot_knns = []

        for pivot_ind, knn in enumerate(knn_graph):

            local_graph_neighbors = set(knn)

            for neighbor_ind in knn:
                local_graph_neighbors.update(
                    set(sorted_dist_inds[neighbor_ind,
                                         1:self.k_at_hops[1] + 1]))

            local_graph_neighbors.discard(pivot_ind)
            pivot_local_graph = list(local_graph_neighbors)
            pivot_local_graph.insert(0, pivot_ind)
            pivot_knn = [pivot_ind] + list(knn)

            pivot_local_graphs.append(pivot_local_graph)
            pivot_knns.append(pivot_knn)

        num_max_nodes = max([
            len(pivot_local_graph) for pivot_local_graph in pivot_local_graphs
        ])

        local_graphs_node_feat = []
        adjacent_matrices = []
        pivots_knn_inds = []
        pivots_local_graphs = []

        for graph_ind, pivot_knn in enumerate(pivot_knns):
            pivot_local_graph = pivot_local_graphs[graph_ind]
            num_nodes = len(pivot_local_graph)
            pivot_ind = pivot_local_graph[0]
            node2ind_map = {j: i for i, j in enumerate(pivot_local_graph)}

            knn_inds = paddle.cast(
                paddle.to_tensor([node2ind_map[i]
                                  for i in pivot_knn[1:]]), 'int64')
            pivot_feats = node_feats[pivot_ind]
            normalized_feats = node_feats[paddle.to_tensor(
                pivot_local_graph)] - pivot_feats

            adjacent_matrix = np.zeros(
                (num_nodes, num_nodes), dtype=np.float32)
            for node in pivot_local_graph:
                neighbors = sorted_dist_inds[node,
                                             1:self.active_connection + 1]
                for neighbor in neighbors:
                    if neighbor in pivot_local_graph:
                        adjacent_matrix[node2ind_map[node],
                                        node2ind_map[neighbor]] = 1
                        adjacent_matrix[node2ind_map[neighbor],
                                        node2ind_map[node]] = 1

            adjacent_matrix = normalize_adjacent_matrix(adjacent_matrix)
            pad_adjacent_matrix = paddle.zeros((num_max_nodes, num_max_nodes))
            pad_adjacent_matrix[:num_nodes, :num_nodes] = paddle.cast(
                paddle.to_tensor(adjacent_matrix), 'float32')

            pad_normalized_feats = paddle.concat(
                [
                    normalized_feats, paddle.zeros(
                        (num_max_nodes - num_nodes,
                         normalized_feats.shape[1]))
                ],
                axis=0)

            local_graph_nodes = paddle.to_tensor(pivot_local_graph)
            local_graph_nodes = paddle.concat(
                [
                    local_graph_nodes, paddle.zeros(
                        [num_max_nodes - num_nodes], dtype='int64')
                ],
                axis=-1)

            local_graphs_node_feat.append(pad_normalized_feats)
            adjacent_matrices.append(pad_adjacent_matrix)
            pivots_knn_inds.append(knn_inds)
            pivots_local_graphs.append(local_graph_nodes)

        local_graphs_node_feat = paddle.stack(local_graphs_node_feat, 0)
        adjacent_matrices = paddle.stack(adjacent_matrices, 0)
        pivots_knn_inds = paddle.stack(pivots_knn_inds, 0)
        pivots_local_graphs = paddle.stack(pivots_local_graphs, 0)

        return (local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
                pivots_local_graphs)

    def __call__(self, preds, feat_maps):
        """Generate local graphs and graph convolutional network input data.

        Args:
            preds (tensor): The predicted maps.
            feat_maps (tensor): The feature maps to extract content feature of
                text components.

        Returns:
            none_flag (bool): The flag showing whether the number of proposed
                text components is 0.
            local_graphs_node_feats (tensor): The features of nodes in local
                graphs.
            adjacent_matrices (tensor): The adjacency matrices.
            pivots_knn_inds (tensor): The k-nearest neighbor indices in
                local graphs.
            pivots_local_graphs (tensor): The indices of nodes in local
                graphs.
            text_comps (ndarray): The predicted text components.
        """
        if preds.ndim == 4:
            assert preds.shape[0] == 1
            preds = paddle.squeeze(preds)
        pred_text_region = F.sigmoid(preds[0]).numpy()
        pred_center_region = F.sigmoid(preds[1]).numpy()
        pred_sin_map = preds[2].numpy()
        pred_cos_map = preds[3].numpy()
        pred_top_height_map = preds[4].numpy()
        pred_bot_height_map = preds[5].numpy()

        comp_attribs, text_comps = self.propose_comps_and_attribs(
            pred_text_region, pred_center_region, pred_top_height_map,
            pred_bot_height_map, pred_sin_map, pred_cos_map)

        if comp_attribs is None or len(comp_attribs) < 2:
            none_flag = True
            return none_flag, (0, 0, 0, 0, 0)

        comp_centers = comp_attribs[:, 0:2]
        distance_matrix = euclidean_distance_matrix(comp_centers, comp_centers)

        geo_feats = feature_embedding(comp_attribs, self.node_geo_feat_dim)
        geo_feats = paddle.to_tensor(geo_feats)

        batch_id = np.zeros((comp_attribs.shape[0], 1), dtype=np.float32)
        comp_attribs = comp_attribs.astype(np.float32)
        angle = np.arccos(comp_attribs[:, -2]) * np.sign(comp_attribs[:, -1])
        angle = angle.reshape((-1, 1))
        rotated_rois = np.hstack([batch_id, comp_attribs[:, :-2], angle])
        rois = paddle.to_tensor(rotated_rois)

        content_feats = self.pooling(feat_maps, rois)
        content_feats = content_feats.reshape([content_feats.shape[0], -1])
        node_feats = paddle.concat([content_feats, geo_feats], axis=-1)

        sorted_dist_inds = np.argsort(distance_matrix, axis=1)
        (local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
         pivots_local_graphs) = self.generate_local_graphs(sorted_dist_inds,
                                                           node_feats)

        none_flag = False
        return none_flag, (local_graphs_node_feat, adjacent_matrices,
                           pivots_knn_inds, pivots_local_graphs, text_comps)
@ -27,11 +27,12 @@ def build_neck(config):
    from .pren_fpn import PRENFPN
    from .csp_pan import CSPPAN
    from .ct_fpn import CTFPN
    from .fpn_unet import FPN_UNet
    from .rf_adaptor import RFAdaptor
    support_dict = [
        'FPN', 'FCEFPN', 'LKPAN', 'DBFPN', 'RSEFPN', 'EASTFPN', 'SASTFPN',
        'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN', 'CSPPAN', 'CTFPN',
        'RFAdaptor'
        'RFAdaptor', 'FPN_UNet'
    ]

    module_name = config.pop('name')
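With FPN_UNet registered, a neck is reached through `build_neck` as sketched below; the channel values are illustrative assumptions (e.g. ResNet C2-C5 outputs), not values taken from a shipped config:

```python
# build_neck pops 'name' and forwards the remaining keys as kwargs.
config = {
    'name': 'FPN_UNet',
    'in_channels': [256, 512, 1024, 2048],  # assumed backbone channels
    'out_channels': 32,                     # assumed neck output channels
}
neck = build_neck(config)  # -> FPN_UNet(in_channels=..., out_channels=32)
```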
@ -0,0 +1,97 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textdet/necks/fpn_unet.py
"""

import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class UpBlock(nn.Layer):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        assert isinstance(in_channels, int)
        assert isinstance(out_channels, int)

        self.conv1x1 = nn.Conv2D(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.conv3x3 = nn.Conv2D(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.deconv = nn.Conv2DTranspose(
            out_channels, out_channels, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        x = F.relu(self.conv1x1(x))
        x = F.relu(self.conv3x3(x))
        x = self.deconv(x)
        return x


class FPN_UNet(nn.Layer):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        assert len(in_channels) == 4
        assert isinstance(out_channels, int)
        self.out_channels = out_channels

        blocks_out_channels = [out_channels] + [
            min(out_channels * 2**i, 256) for i in range(4)
        ]
        blocks_in_channels = [blocks_out_channels[1]] + [
            in_channels[i] + blocks_out_channels[i + 2] for i in range(3)
        ] + [in_channels[3]]

        self.up4 = nn.Conv2DTranspose(
            blocks_in_channels[4],
            blocks_out_channels[4],
            kernel_size=4,
            stride=2,
            padding=1)
        self.up_block3 = UpBlock(blocks_in_channels[3], blocks_out_channels[3])
        self.up_block2 = UpBlock(blocks_in_channels[2], blocks_out_channels[2])
        self.up_block1 = UpBlock(blocks_in_channels[1], blocks_out_channels[1])
        self.up_block0 = UpBlock(blocks_in_channels[0], blocks_out_channels[0])

    def forward(self, x):
        """
        Args:
            x (list[Tensor] | tuple[Tensor]): A list of four tensors of shape
                :math:`(N, C_i, H_i, W_i)`, representing C2, C3, C4, C5
                features respectively. :math:`C_i` should match the number in
                ``in_channels``.

        Returns:
            Tensor: Shape :math:`(N, C, H, W)` where :math:`H=4H_0` and
                :math:`W=4W_0`.
        """
        c2, c3, c4, c5 = x

        x = F.relu(self.up4(c5))

        x = paddle.concat([x, c4], axis=1)
        x = F.relu(self.up_block3(x))

        x = paddle.concat([x, c3], axis=1)
        x = F.relu(self.up_block2(x))

        x = paddle.concat([x, c2], axis=1)
        x = F.relu(self.up_block1(x))

        x = self.up_block0(x)
        return x
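

# A shape-check sketch for FPN_UNet with small, illustrative channel counts.
# Each up block doubles the spatial size, so the output resolution is 4x
# that of the C2 input, matching the docstring above.
def _demo_fpn_unet():
    neck = FPN_UNet(in_channels=[8, 16, 32, 64], out_channels=4)
    c2 = paddle.rand([1, 8, 64, 64])
    c3 = paddle.rand([1, 16, 32, 32])
    c4 = paddle.rand([1, 32, 16, 16])
    c5 = paddle.rand([1, 64, 8, 8])
    out = neck([c2, c3, c4, c5])
    assert tuple(out.shape) == (1, 4, 256, 256)
    return out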
@ -36,6 +36,7 @@ from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess, Di
from .table_postprocess import TableMasterLabelDecode, TableLabelDecode
from .picodet_postprocess import PicoDetPostProcess
from .ct_postprocess import CTPostProcess
from .drrg_postprocess import DRRGPostprocess


def build_post_process(config, global_config=None):

@ -49,7 +50,8 @@ def build_post_process(config, global_config=None):
        'DistillationSARLabelDecode', 'ViTSTRLabelDecode', 'ABINetLabelDecode',
        'TableMasterLabelDecode', 'SPINLabelDecode',
        'DistillationSerPostProcess', 'DistillationRePostProcess',
        'VLLabelDecode', 'PicoDetPostProcess', 'CTPostProcess', 'RFLLabelDecode'
        'VLLabelDecode', 'PicoDetPostProcess', 'CTPostProcess',
        'RFLLabelDecode', 'DRRGPostprocess'
    ]

    if config['name'] == 'PSEPostProcess':
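With DRRGPostprocess registered, a config-driven construction looks like the sketch below; the `link_thr` value is an illustrative assumption:

```python
# build_post_process reads 'name' from the config and instantiates
# DRRGPostprocess with the remaining keys as keyword arguments.
post_process = build_post_process({'name': 'DRRGPostprocess', 'link_thr': 0.8})
```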
@ -38,7 +38,7 @@ class DBPostProcess(object):
                 unclip_ratio=2.0,
                 use_dilation=False,
                 score_mode="fast",
                 use_polygon=False,
                 box_type='quad',
                 **kwargs):
        self.thresh = thresh
        self.box_thresh = box_thresh

@ -46,7 +46,7 @@ class DBPostProcess(object):
        self.unclip_ratio = unclip_ratio
        self.min_size = 3
        self.score_mode = score_mode
        self.use_polygon = use_polygon
        self.box_type = box_type
        assert score_mode in [
            "slow", "fast"
        ], "Score mode must be in [slow, fast] but got: {}".format(score_mode)

@ -233,12 +233,14 @@ class DBPostProcess(object):
                                         self.dilation_kernel)
            else:
                mask = segmentation[batch_index]
            if self.use_polygon is True:
            if self.box_type == 'poly':
                boxes, scores = self.polygons_from_bitmap(pred[batch_index],
                                                          mask, src_w, src_h)
            else:
            elif self.box_type == 'quad':
                boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
                                                       src_w, src_h)
            else:
                raise ValueError("box_type can only be one of ['quad', 'poly']")

            boxes_batch.append({'points': boxes})
        return boxes_batch

@ -254,7 +256,7 @@ class DistillationDBPostProcess(object):
                 unclip_ratio=1.5,
                 use_dilation=False,
                 score_mode="fast",
                 use_polygon=False,
                 box_type='quad',
                 **kwargs):
        self.model_name = model_name
        self.key = key

@ -265,7 +267,7 @@ class DistillationDBPostProcess(object):
            unclip_ratio=unclip_ratio,
            use_dilation=use_dilation,
            score_mode=score_mode,
            use_polygon=use_polygon)
            box_type=box_type)

    def __call__(self, predicts, shape_list):
        results = {}
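The former boolean `use_polygon` switch becomes a three-way `box_type` string; a migration sketch (other constructor arguments omitted):

```python
# Before this change: DBPostProcess(..., use_polygon=True)
# After: 'poly' selects polygons_from_bitmap, 'quad' (the default) selects
# boxes_from_bitmap, and any other value raises ValueError in __call__.
post_process = DBPostProcess(thresh=0.3, box_thresh=0.6, box_type='poly')
```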
@ -0,0 +1,326 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textdet/postprocess/drrg_postprocessor.py
"""

import functools
import operator

import numpy as np
import paddle
from numpy.linalg import norm
import cv2


class Node:
    def __init__(self, ind):
        self.__ind = ind
        self.__links = set()

    @property
    def ind(self):
        return self.__ind

    @property
    def links(self):
        return set(self.__links)

    def add_link(self, link_node):
        self.__links.add(link_node)
        link_node.__links.add(self)


def graph_propagation(edges, scores, text_comps, edge_len_thr=50.):
    assert edges.ndim == 2
    assert edges.shape[1] == 2
    assert edges.shape[0] == scores.shape[0]
    assert text_comps.ndim == 2
    assert isinstance(edge_len_thr, float)

    edges = np.sort(edges, axis=1)
    score_dict = {}
    for i, edge in enumerate(edges):
        if text_comps is not None:
            box1 = text_comps[edge[0], :8].reshape(4, 2)
            box2 = text_comps[edge[1], :8].reshape(4, 2)
            center1 = np.mean(box1, axis=0)
            center2 = np.mean(box2, axis=0)
            distance = norm(center1 - center2)
            if distance > edge_len_thr:
                scores[i] = 0
        if (edge[0], edge[1]) in score_dict:
            score_dict[edge[0], edge[1]] = 0.5 * (
                score_dict[edge[0], edge[1]] + scores[i])
        else:
            score_dict[edge[0], edge[1]] = scores[i]

    nodes = np.sort(np.unique(edges.flatten()))
    # np.int was removed from recent NumPy releases; use a concrete dtype.
    mapping = -1 * np.ones((np.max(nodes) + 1), dtype=np.int32)
    mapping[nodes] = np.arange(nodes.shape[0])
    order_inds = mapping[edges]
    vertices = [Node(node) for node in nodes]
    for ind in order_inds:
        vertices[ind[0]].add_link(vertices[ind[1]])

    return vertices, score_dict


def connected_components(nodes, score_dict, link_thr):
    assert isinstance(nodes, list)
    assert all([isinstance(node, Node) for node in nodes])
    assert isinstance(score_dict, dict)
    assert isinstance(link_thr, float)

    clusters = []
    nodes = set(nodes)
    while nodes:
        node = nodes.pop()
        cluster = {node}
        node_queue = [node]
        while node_queue:
            node = node_queue.pop(0)
            neighbors = set([
                neighbor for neighbor in node.links
                if score_dict[tuple(sorted([node.ind, neighbor.ind]))] >=
                link_thr
            ])
            neighbors.difference_update(cluster)
            nodes.difference_update(neighbors)
            cluster.update(neighbors)
            node_queue.extend(neighbors)
        clusters.append(list(cluster))
    return clusters
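

# A tiny self-contained check of connected_components: one scored edge joins
# nodes 0 and 1, node 2 has no links, so exactly two clusters come back.
def _demo_connected_components():
    a, b, c = Node(0), Node(1), Node(2)
    a.add_link(b)
    clusters = connected_components(
        [a, b, c], score_dict={(0, 1): 0.9}, link_thr=0.5)
    assert len(clusters) == 2
    return clusters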


def clusters2labels(clusters, num_nodes):
    assert isinstance(clusters, list)
    assert all([isinstance(cluster, list) for cluster in clusters])
    assert all(
        [isinstance(node, Node) for cluster in clusters for node in cluster])
    assert isinstance(num_nodes, int)

    node_labels = np.zeros(num_nodes)
    for cluster_ind, cluster in enumerate(clusters):
        for node in cluster:
            node_labels[node.ind] = cluster_ind
    return node_labels


def remove_single(text_comps, comp_pred_labels):
    assert text_comps.ndim == 2
    assert text_comps.shape[0] == comp_pred_labels.shape[0]

    single_flags = np.zeros_like(comp_pred_labels)
    pred_labels = np.unique(comp_pred_labels)
    for label in pred_labels:
        current_label_flag = (comp_pred_labels == label)
        if np.sum(current_label_flag) == 1:
            single_flags[np.where(current_label_flag)[0][0]] = 1
    keep_ind = [i for i in range(len(comp_pred_labels)) if not single_flags[i]]
    filtered_text_comps = text_comps[keep_ind, :]
    filtered_labels = comp_pred_labels[keep_ind]

    return filtered_text_comps, filtered_labels


def norm2(point1, point2):
    return ((point1[0] - point2[0])**2 + (point1[1] - point2[1])**2)**0.5


def min_connect_path(points):
    assert isinstance(points, list)
    assert all([isinstance(point, list) for point in points])
    assert all([isinstance(coord, int) for point in points for coord in point])

    points_queue = points.copy()
    shortest_path = []
    current_edge = [[], []]

    edge_dict0 = {}
    edge_dict1 = {}
    current_edge[0] = points_queue[0]
    current_edge[1] = points_queue[0]
    points_queue.remove(points_queue[0])
    while points_queue:
        for point in points_queue:
            length0 = norm2(point, current_edge[0])
            edge_dict0[length0] = [point, current_edge[0]]
            length1 = norm2(current_edge[1], point)
            edge_dict1[length1] = [current_edge[1], point]
        key0 = min(edge_dict0.keys())
        key1 = min(edge_dict1.keys())

        if key0 <= key1:
            start = edge_dict0[key0][0]
            end = edge_dict0[key0][1]
            shortest_path.insert(0, [points.index(start), points.index(end)])
            points_queue.remove(start)
            current_edge[0] = start
        else:
            start = edge_dict1[key1][0]
            end = edge_dict1[key1][1]
            shortest_path.append([points.index(start), points.index(end)])
            points_queue.remove(end)
            current_edge[1] = end

        edge_dict0 = {}
        edge_dict1 = {}

    shortest_path = functools.reduce(operator.concat, shortest_path)
    shortest_path = sorted(set(shortest_path), key=shortest_path.index)

    return shortest_path
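

# A small sketch of min_connect_path on three collinear points: the greedy
# chain visits them end to end and returns their indices in path order.
def _demo_min_connect_path():
    path = min_connect_path([[0, 0], [10, 0], [5, 0]])
    assert path == [1, 2, 0]  # [10, 0] -> [5, 0] -> [0, 0]
    return path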


def in_contour(cont, point):
    x, y = point
    is_inner = cv2.pointPolygonTest(cont, (int(x), int(y)), False) > 0.5
    return is_inner


def fix_corner(top_line, bot_line, start_box, end_box):
    assert isinstance(top_line, list)
    assert all(isinstance(point, list) for point in top_line)
    assert isinstance(bot_line, list)
    assert all(isinstance(point, list) for point in bot_line)
    assert start_box.shape == end_box.shape == (4, 2)

    contour = np.array(top_line + bot_line[::-1])
    start_left_mid = (start_box[0] + start_box[3]) / 2
    start_right_mid = (start_box[1] + start_box[2]) / 2
    end_left_mid = (end_box[0] + end_box[3]) / 2
    end_right_mid = (end_box[1] + end_box[2]) / 2
    if not in_contour(contour, start_left_mid):
        top_line.insert(0, start_box[0].tolist())
        bot_line.insert(0, start_box[3].tolist())
    elif not in_contour(contour, start_right_mid):
        top_line.insert(0, start_box[1].tolist())
        bot_line.insert(0, start_box[2].tolist())
    if not in_contour(contour, end_left_mid):
        top_line.append(end_box[0].tolist())
        bot_line.append(end_box[3].tolist())
    elif not in_contour(contour, end_right_mid):
        top_line.append(end_box[1].tolist())
        bot_line.append(end_box[2].tolist())
    return top_line, bot_line


def comps2boundaries(text_comps, comp_pred_labels):
    assert text_comps.ndim == 2
    assert len(text_comps) == len(comp_pred_labels)
    boundaries = []
    if len(text_comps) < 1:
        return boundaries
    for cluster_ind in range(0, int(np.max(comp_pred_labels)) + 1):
        cluster_comp_inds = np.where(comp_pred_labels == cluster_ind)
        text_comp_boxes = text_comps[cluster_comp_inds, :8].reshape(
            (-1, 4, 2)).astype(np.int32)
        score = np.mean(text_comps[cluster_comp_inds, -1])

        if text_comp_boxes.shape[0] < 1:
            continue

        elif text_comp_boxes.shape[0] > 1:
            centers = np.mean(text_comp_boxes, axis=1).astype(np.int32).tolist()
            shortest_path = min_connect_path(centers)
            text_comp_boxes = text_comp_boxes[shortest_path]
            top_line = np.mean(
                text_comp_boxes[:, 0:2, :], axis=1).astype(np.int32).tolist()
            bot_line = np.mean(
                text_comp_boxes[:, 2:4, :], axis=1).astype(np.int32).tolist()
            top_line, bot_line = fix_corner(
                top_line, bot_line, text_comp_boxes[0], text_comp_boxes[-1])
            boundary_points = top_line + bot_line[::-1]

        else:
            top_line = text_comp_boxes[0, 0:2, :].astype(np.int32).tolist()
            # The bottom edge in reverse order (points 3 and 2); the slice
            # 2:4:-1 used previously selects nothing.
            bot_line = text_comp_boxes[0, 3:1:-1, :].astype(np.int32).tolist()
            boundary_points = top_line + bot_line

        boundary = [p for coord in boundary_points for p in coord] + [score]
        boundaries.append(boundary)

    return boundaries


class DRRGPostprocess(object):
    """Merge text components and construct boundaries of text instances.

    Args:
        link_thr (float): The edge score threshold.
    """

    def __init__(self, link_thr, **kwargs):
        assert isinstance(link_thr, float)
        self.link_thr = link_thr

    def __call__(self, preds, shape_list):
        """
        Args:
            preds (tuple): A tuple of (edges, scores, text_comps): edges
                (ndarray) is an edge array of shape N * 2 whose rows are node
                index pairs making up the graph edges, scores (ndarray) holds
                the edge scores of shape (N,), and text_comps (ndarray) holds
                the text components.
            shape_list (ndarray): The shapes and resize ratios of the input
                images.

        Returns:
            list[dict]: A single-element batch whose dict carries the
                predicted boundary points and scores of text instances.
        """
        edges, scores, text_comps = preds
        if edges is not None:
            if isinstance(edges, paddle.Tensor):
                edges = edges.numpy()
            if isinstance(scores, paddle.Tensor):
                scores = scores.numpy()
            if isinstance(text_comps, paddle.Tensor):
                text_comps = text_comps.numpy()
            assert len(edges) == len(scores)
            assert text_comps.ndim == 2
            assert text_comps.shape[1] == 9

            vertices, score_dict = graph_propagation(edges, scores, text_comps)
            clusters = connected_components(vertices, score_dict,
                                            self.link_thr)
            pred_labels = clusters2labels(clusters, text_comps.shape[0])
            text_comps, pred_labels = remove_single(text_comps, pred_labels)
            boundaries = comps2boundaries(text_comps, pred_labels)
        else:
            boundaries = []

        boundaries, scores = self.resize_boundary(
            boundaries, (1 / shape_list[0, 2:]).tolist()[::-1])
        boxes_batch = [dict(points=boundaries, scores=scores)]
        return boxes_batch

    def resize_boundary(self, boundaries, scale_factor):
        """Rescale boundaries via scale_factor.

        Args:
            boundaries (list[list[float]]): The boundary list. Each boundary
                has size 2k+1 with k >= 4.
            scale_factor (list[float]): The x/y rescale factors.

        Returns:
            boundaries (list[list[float]]): The scaled boundaries.
        """
        boxes = []
        scores = []
        for b in boundaries:
            sz = len(b)
            scores.append(b[-1])
            b = (np.array(b[:sz - 1]) *
                 (np.tile(scale_factor[:2], int(
                     (sz - 1) / 2)).reshape(1, sz - 1))).flatten().tolist()
            boxes.append(np.array(b).reshape([-1, 2]))
        return boxes, scores
@ -15,15 +15,15 @@ English | [简体中文](README_ch.md)

PP-Structure is an intelligent document analysis system developed by the PaddleOCR team, which aims to help developers better complete tasks related to document understanding such as layout analysis and table recognition.

The pipeline of the PP-Structurev2 system is shown below. The document image first passes through the image direction correction module to identify the direction of the entire image and correct it. Then two kinds of tasks can be completed: layout information analysis and key information extraction.
The pipeline of the PP-StructureV2 system is shown below. The document image first passes through the image direction correction module to identify the direction of the entire image and correct it. Then two kinds of tasks can be completed: layout information analysis and key information extraction.

- In the layout analysis task, the image first goes through the layout analysis model to divide the image into different areas such as text, table, and figure, and then these areas are analyzed separately. For example, the table area is sent to the table recognition module for structured recognition, and the text area is sent to the OCR engine for text recognition. Finally, the layout recovery module restores everything to a word or pdf file with the same layout as the original image;
- In the key information extraction task, the OCR engine is first used to extract the text content, then the SER (semantic entity recognition) module obtains the semantic entities in the image, and finally the RE (relation extraction) module obtains the correspondence between the semantic entities, thereby extracting the required key information.

<img src="./docs/ppstructurev2_pipeline.png" width="100%"/>
<img src="https://user-images.githubusercontent.com/14270174/195265734-6f4b5a7f-59b1-4fcc-af6d-89afc9bd51e1.jpg" width="100%"/>
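The whole pipeline can also be driven from Python through the `paddleocr` package; the snippet below is a minimal sketch (the image path and the printed result keys follow the package's usual format and are shown here as assumptions):

```python
from paddleocr import PPStructure
import cv2

engine = PPStructure(show_log=False)              # layout analysis + table recognition
results = engine(cv2.imread('docs/table/1.png'))  # illustrative image path
for region in results:
    print(region['type'], region['bbox'])
```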

More technical details: 👉 [PP-Structurev2 Technical Report](docs/PP-Structurev2_introduction.md)
More technical details: 👉 [PP-StructureV2 Technical Report](https://arxiv.org/abs/2210.05391)

PP-Structurev2 supports independent use or flexible collocation of each module. For example, you can use layout analysis alone or table recognition alone. Click the corresponding link below to get the tutorial for each independent module:
PP-StructureV2 supports independent use or flexible collocation of each module. For example, you can use layout analysis alone or table recognition alone. Click the corresponding link below to get the tutorial for each independent module:

- [Layout Analysis](layout/README.md)
- [Table Recognition](table/README.md)

@ -32,7 +32,7 @@ PP-Structurev2 supports independent use or flexible collocation of each module.

## 2. Features

The main features of PP-Structurev2 are as follows:
The main features of PP-StructureV2 are as follows:
- Support layout analysis of documents in the form of images/pdfs, which can be divided into areas such as **text, titles, tables, figures, formulas, etc.**;
- Support common Chinese and English **table detection** tasks;
- Support structured table recognition, and output the final result to an **Excel file**;

@ -43,7 +43,7 @@ The main features of PP-Structurev2 are as follows:

## 3. Results

PP-Structurev2 supports the independent use or flexible collocation of each module. For example, layout analysis can be used alone, or table recognition can be used alone. Only the visualization effects of several representative usage methods are shown here.
PP-StructureV2 supports the independent use or flexible collocation of each module. For example, layout analysis can be used alone, or table recognition can be used alone. Only the visualization effects of several representative usage methods are shown here.

### 3.1 Layout analysis and table recognition


@ -59,7 +59,7 @@ The following figure shows the effect of layout recovery based on the results of

* SER

Different colored boxes in the figure represent different categories.

<div align="center">
<img src="https://user-images.githubusercontent.com/25809855/186094456-01a1dd11-1433-4437-9ab2-6480ac94ec0a.png" width="600">

@ -91,7 +91,7 @@ In the figure, the red box represents `Question`, the blue box represents `Answe

<div align="center">
<img src="https://user-images.githubusercontent.com/25809855/186095641-5843b4da-34d7-4c1c-943a-b1036a859fe3.png" width="600">
</div>

<div align="center">
<img src="https://user-images.githubusercontent.com/14270174/185393805-c67ff571-cf7e-4217-a4b0-8b396c4f22bb.jpg" width="600">

@ -114,4 +114,3 @@ For structural analysis related model downloads, please refer to:

For OCR related model downloads, please refer to:
- [PP-OCR Model Zoo](../doc/doc_en/models_list_en.md)
@ -16,14 +16,15 @@

PP-Structure is an intelligent document analysis system developed in-house by the PaddleOCR team, which aims to help developers better complete document-understanding tasks such as layout analysis and table recognition.

The pipeline of the PP-Structurev2 system is shown below: the document image first passes through the image rectification module, which judges the orientation of the whole image and corrects it; afterwards two kinds of tasks can be performed: layout analysis and key information extraction.
The pipeline of the PP-StructureV2 system is shown below: the document image first passes through the image rectification module, which judges the orientation of the whole image and corrects it; afterwards two kinds of tasks can be performed: layout analysis and key information extraction.
- In the layout analysis task, the image first goes through the layout analysis model, which divides it into regions such as text, tables, and figures; these regions are then recognized separately, e.g. table regions are sent to the table recognition module for structured recognition and text regions are sent to the OCR engine for text recognition; finally, the layout recovery module restores everything to a word or pdf file whose layout matches the original image;
- In the key information extraction task, the OCR engine first extracts the text content, the semantic entity recognition module then obtains the semantic entities in the image, and finally the relation extraction module obtains the correspondence between the semantic entities, thereby extracting the required key information.
<img src="./docs/ppstructurev2_pipeline.png" width="100%"/>

More technical details: 👉 [PP-Structurev2 Technical Report](docs/PP-Structurev2_introduction.md)
<img src="https://user-images.githubusercontent.com/14270174/195265734-6f4b5a7f-59b1-4fcc-af6d-89afc9bd51e1.jpg" width="100%"/>

PP-Structurev2 supports using each module independently or in flexible combinations; for example, layout analysis or table recognition can be used alone. Click the corresponding links below for the tutorial of each standalone module:
More technical details: 👉 PP-StructureV2 Technical Report, [Chinese](docs/PP-StructureV2_introduction.md) and [English](https://arxiv.org/abs/2210.05391) versions.

PP-StructureV2 supports using each module independently or in flexible combinations; for example, layout analysis or table recognition can be used alone. Click the corresponding links below for the tutorial of each standalone module:

- [Layout Analysis](layout/README_ch.md)
- [Table Recognition](table/README_ch.md)
@ -33,7 +34,7 @@ PP-Structurev2 supports using each module independently or in flexible combinations
<a name="2"></a>
## 2. Features

The main features of PP-Structurev2 are as follows:
The main features of PP-StructureV2 are as follows:
- Supports layout analysis of documents in image/pdf form, dividing them into regions such as **text, titles, tables, figures, formulas, etc.**;
- Supports general Chinese and English **table detection** tasks;
- Supports structured recognition of table regions, with the final result output as an **Excel file**;

@ -44,7 +45,7 @@ The main features of PP-Structurev2 are as follows:

<a name="3"></a>
## 3. Results
PP-Structurev2 supports using each module independently or in flexible combinations; for example, layout analysis or table recognition can be used alone. Only the visualization results of a few representative usage modes are shown here.
PP-StructureV2 supports using each module independently or in flexible combinations; for example, layout analysis or table recognition can be used alone. Only the visualization results of a few representative usage modes are shown here.

<a name="31"></a>
### 3.1 Layout analysis and table recognition

@ -102,7 +103,7 @@ PP-Structurev2 supports using each module independently or in flexible combinations

<div align="center">
<img src="https://user-images.githubusercontent.com/25809855/186095641-5843b4da-34d7-4c1c-943a-b1036a859fe3.png" width="600">
</div>

<a name="4"></a>
## 4. Quick start

@ -119,4 +120,3 @@ PP-Structurev2 supports using each module independently or in flexible combinations

For OCR-related model downloads, please refer to:
- [PP-OCR Model Zoo](../doc/doc_ch/models_list.md)
@ -1,4 +1,4 @@
|
|||
# PP-Structurev2
|
||||
# PP-StructureV2
|
||||
|
||||
## 目录
|
||||
|
||||
|
@ -16,11 +16,11 @@
|
|||
|
||||
现实场景中包含大量的文档图像,它们以图片等非结构化形式存储。基于文档图像的结构化分析与信息抽取对于数据的数字化存储以及产业的数字化转型至关重要。基于该考虑,PaddleOCR自研并发布了PP-Structure智能文档分析系统,旨在帮助开发者更好的完成版面分析、表格识别、关键信息抽取等文档理解相关任务。
|
||||
|
||||
近期,PaddleOCR团队针对PP-Structurev1的版面分析、表格识别、关键信息抽取模块,进行了共计8个方面的升级,同时新增整图方向矫正、文档复原等功能,打造出一个全新的、效果更优的文档分析系统:PP-Structurev2。
|
||||
近期,PaddleOCR团队针对PP-Structurev1的版面分析、表格识别、关键信息抽取模块,进行了共计8个方面的升级,同时新增整图方向矫正、文档复原等功能,打造出一个全新的、效果更优的文档分析系统:PP-StructureV2。
|
||||
|
||||
## 2. 简介
|
||||
|
||||
PP-Structurev2在PP-Structurev1的基础上进一步改进,主要有以下3个方面升级:
|
||||
PP-StructureV2在PP-Structurev1的基础上进一步改进,主要有以下3个方面升级:
|
||||
|
||||
* **系统功能升级** :新增图像矫正和版面复原模块,图像转word/pdf、关键信息抽取能力全覆盖!
|
||||
* **系统性能优化** :
|
||||
|
@ -29,7 +29,7 @@ PP-Structurev2在PP-Structurev1的基础上进一步改进,主要有以下3个
|
|||
* 关键信息抽取:设计视觉无关模型结构,语义实体识别精度提升**2.8%**,关系抽取精度提升**9.1%**。
|
||||
* **中文场景适配** :完成对版面分析与表格识别的中文场景适配,开源**开箱即用**的中文场景版面结构化模型!
|
||||
|
||||
PP-Structurev2系统流程图如下所示,文档图像首先经过图像矫正模块,判断整图方向并完成转正,随后可以完成版面信息分析与关键信息抽取2类任务。版面分析任务中,图像首先经过版面分析模型,将图像划分为文本、表格、图像等不同区域,随后对这些区域分别进行识别,如,将表格区域送入表格识别模块进行结构化识别,将文本区域送入OCR引擎进行文字识别,最后使用版面恢复模块将其恢复为与原始图像布局一致的word或者pdf格式的文件;关键信息抽取任务中,首先使用OCR引擎提取文本内容,然后由语义实体识别模块获取图像中的语义实体,最后经关系抽取模块获取语义实体之间的对应关系,从而提取需要的关键信息。
|
||||
PP-StructureV2系统流程图如下所示,文档图像首先经过图像矫正模块,判断整图方向并完成转正,随后可以完成版面信息分析与关键信息抽取2类任务。版面分析任务中,图像首先经过版面分析模型,将图像划分为文本、表格、图像等不同区域,随后对这些区域分别进行识别,如,将表格区域送入表格识别模块进行结构化识别,将文本区域送入OCR引擎进行文字识别,最后使用版面恢复模块将其恢复为与原始图像布局一致的word或者pdf格式的文件;关键信息抽取任务中,首先使用OCR引擎提取文本内容,然后由语义实体识别模块获取图像中的语义实体,最后经关系抽取模块获取语义实体之间的对应关系,从而提取需要的关键信息。
|
||||
|
||||
<div align="center">
|
||||
<img src="https://user-images.githubusercontent.com/14270174/185939247-57e53254-399c-46c4-a610-da4fa79232f5.png" width="1200">
|
||||
|
@ -62,7 +62,7 @@ PP-Structurev2系统流程图如下所示,文档图像首先经过图像矫正
|
|||
|
||||
## 3. 整图方向矫正
|
||||
|
||||
由于训练集一般以正方向图像为主,旋转过的文档图像直接输入模型会增加识别难度,影响识别效果。PP-Structurev2引入了整图方向矫正模块来判断含文字图像的方向,并将其进行方向调整。
|
||||
由于训练集一般以正方向图像为主,旋转过的文档图像直接输入模型会增加识别难度,影响识别效果。PP-StructureV2引入了整图方向矫正模块来判断含文字图像的方向,并将其进行方向调整。
|
||||
|
||||
我们直接调用PaddleClas中提供的文字图像方向分类模型-[PULC_text_image_orientation](https://github.com/PaddlePaddle/PaddleClas/blob/develop/docs/zh_CN/PULC/PULC_text_image_orientation.md),该模型部分数据集图像如下所示。不同于文本行方向分类器,文字图像方向分类模型针对整图进行方向判别。文字图像方向分类模型在验证集上精度高达99%,单张图像CPU预测耗时仅为`2.16ms`。
|
||||
|
||||
|
@ -76,7 +76,7 @@ PP-Structurev2系统流程图如下所示,文档图像首先经过图像矫正
|
|||
|
||||
版面分析指的是对图片形式的文档进行区域划分,定位其中的关键区域,如文字、标题、表格、图片等,PP-Structurev1使用了PaddleDetection中开源的高效检测算法PP-YOLOv2完成版面分析的任务。
|
||||
|
||||
在PP-Structurev2中,我们发布基于PP-PicoDet的轻量级版面分析模型,并针对版面分析场景定制图像尺度,同时使用FGD知识蒸馏算法,进一步提升模型精度。最终CPU上`41ms`即可完成版面分析过程(仅包含模型推理时间,数据预处理耗时大约50ms左右)。在公开数据集PubLayNet 上,消融实验如下:
|
||||
在PP-StructureV2中,我们发布基于PP-PicoDet的轻量级版面分析模型,并针对版面分析场景定制图像尺度,同时使用FGD知识蒸馏算法,进一步提升模型精度。最终CPU上`41ms`即可完成版面分析过程(仅包含模型推理时间,数据预处理耗时大约50ms左右)。在公开数据集PubLayNet 上,消融实验如下:
|
||||
|
||||
| 实验序号 | 策略 | 模型存储(M) | mAP | CPU预测耗时(ms) |
|
||||
|:------:|:------:|:------:|:------:|:------:|
|
||||
|
@ -95,7 +95,7 @@ PP-Structurev2系统流程图如下所示,文档图像首先经过图像矫正
|
|||
| 模型 | mAP | CPU预测耗时 |
|
||||
|-------------------|-----------|------------|
|
||||
| layoutparser (Detectron2) | 88.98% | 2.9s |
|
||||
| PP-Structurev2 (PP-PicoDet) | **94%** | 41.2ms |
|
||||
| PP-StructureV2 (PP-PicoDet) | **94%** | 41.2ms |
|
||||
|
||||
[PubLayNet](https://github.com/ibm-aur-nlp/PubLayNet)数据集是一个大型的文档图像数据集,包含Text、Title、Tale、Figure、List,共5个类别。数据集中包含335,703张训练集、11,245张验证集和11,405张测试集。训练数据与标注示例图如下所示:
|
||||
|
||||
|
@ -157,7 +157,7 @@ FGD(Focal and Global Knowledge Distillation for Detectors),是一种兼顾
|
|||
|
||||
### 4.2 表格识别
|
||||
|
||||
基于深度学习的表格识别算法种类丰富,PP-Structurev1中,我们基于文本识别算法RARE研发了端到端表格识别算法TableRec-RARE,模型输出为表格结构的HTML表示,进而可以方便地转化为Excel文件。PP-Structurev2中,我们对模型结构和损失函数等5个方面进行升级,提出了 SLANet (Structure Location Alignment Network) ,模型结构如下图所示:
|
||||
基于深度学习的表格识别算法种类丰富,PP-Structurev1中,我们基于文本识别算法RARE研发了端到端表格识别算法TableRec-RARE,模型输出为表格结构的HTML表示,进而可以方便地转化为Excel文件。PP-StructureV2中,我们对模型结构和损失函数等5个方面进行升级,提出了 SLANet (Structure Location Alignment Network) ,模型结构如下图所示:
|
||||
|
||||
<div align="center">
|
||||
<img src="https://user-images.githubusercontent.com/14270174/185940811-089c9265-4be9-4776-b365-6d1125606b4b.png" width="1200">
|
||||
|
@ -189,7 +189,7 @@ FGD(Focal and Global Knowledge Distillation for Detectors),是一种兼顾
|
|||
|
||||
**(1) CPU友好型轻量级骨干网络PP-LCNet**
|
||||
|
||||
PP-LCNet is a lightweight, high-performance backbone network designed around the inference characteristics of Intel CPUs; on image classification tasks it achieves a better accuracy-speed trade-off than lightweight models such as ShuffleNetV2, MobileNetV3, and GhostNet. In PP-Structurev2, we adopt PP-LCNet as the backbone, raising table recognition accuracy from 71.73% to 72.98%; loading image classification weights trained with the SSLD knowledge distillation scheme as the pretrained model further improves accuracy by 2.95%, to 74.71%.

PP-LCNet is a lightweight, high-performance backbone network designed around the inference characteristics of Intel CPUs; on image classification tasks it achieves a better accuracy-speed trade-off than lightweight models such as ShuffleNetV2, MobileNetV3, and GhostNet. In PP-StructureV2, we adopt PP-LCNet as the backbone, raising table recognition accuracy from 71.73% to 72.98%; loading image classification weights trained with the SSLD knowledge distillation scheme as the pretrained model further improves accuracy by 2.95%, to 74.71%.
**(2) CSP-PAN, a lightweight module for fusing high- and low-level features**
@ -199,7 +199,7 @@ PP-LCNet is a lightweight, high-performance backbone network designed around the inference characteristics of Intel CPUs
TableRec-RARE's TableAttentionHead is shown in figure (a) below. It first runs all decoding steps to obtain the final hidden-state representations (hiddens), which are then passed through the SDM (Structure Decode Module) and CLDM (Cell Location Decode Module) to generate all the structure tokens and cell coordinates. This design, however, ignores the one-to-one correspondence between cell tokens and coordinates.
In PP-Structurev2, we design the SLAHead module, which aligns cell tokens with their coordinates, as shown in figure (b) below. In SLAHead, the hidden state of each step is fed into SDM and CLDM separately to obtain that step's token and coordinates; the tokens and coordinates of all steps are then concatenated to produce the table's HTML representation and the coordinates of all cells. In addition, since cell location accuracy depends on correctly recognizing the table structure, we raise the loss weight ratio of the structure branch to the cell location branch from 1:1 to 8:1, and replace the MSE loss in the location branch with the more stably converging Smooth L1 loss. The final model accuracy improves from 75.68% to 77.7%.

In PP-StructureV2, we design the SLAHead module, which aligns cell tokens with their coordinates, as shown in figure (b) below. In SLAHead, the hidden state of each step is fed into SDM and CLDM separately to obtain that step's token and coordinates; the tokens and coordinates of all steps are then concatenated to produce the table's HTML representation and the coordinates of all cells. In addition, since cell location accuracy depends on correctly recognizing the table structure, we raise the loss weight ratio of the structure branch to the cell location branch from 1:1 to 8:1, and replace the MSE loss in the location branch with the more stably converging Smooth L1 loss. The final model accuracy improves from 75.68% to 77.7%.
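The reweighted objective is easy to state in code; the following is an illustrative simplification (ignoring padding masks and the exact SLANet head shapes), not the repo implementation:

```python
import paddle.nn as nn

# Structure branch: cross-entropy over structure tokens.
# Location branch: Smooth L1 (replacing MSE) over cell coordinates.
structure_loss = nn.CrossEntropyLoss()
location_loss = nn.SmoothL1Loss()

def slanet_style_loss(structure_logits, structure_targets, loc_preds, loc_targets):
    # Structure branch weighted 8:1 against the location branch.
    return 8.0 * structure_loss(structure_logits, structure_targets) \
         + 1.0 * location_loss(loc_preds, loc_targets)
```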
@ -211,7 +211,7 @@ In PP-Structurev2, we design the SLAHead module, which aligns cell tokens with their coordinates
In the TableRec-RARE algorithm, we use two separate tokens, `<td>` and `</td>`, to represent a non-spanning cell; this representation limits the network's ability to handle tables with many cells.
In PP-Structurev2, following the token handling in TableMaster, we merge `<td>` and `</td>` into a single token, `<td></td>`. After merging, images whose token length exceeds 500 in the validation set are also included in model evaluation; the final model accuracy drops to 76.31%, but the end-to-end TEDS improves by 1.04%.

In PP-StructureV2, following the token handling in TableMaster, we merge `<td>` and `</td>` into a single token, `<td></td>`. After merging, images whose token length exceeds 500 in the validation set are also included in model evaluation; the final model accuracy drops to 76.31%, but the end-to-end TEDS improves by 1.04%.
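A toy illustration of the merge, with a hypothetical helper operating on a decoded structure-token list:

```python
def merge_td_tokens(tokens):
    # Collapse adjacent '<td>' '</td>' pairs (non-spanning cells) into one token.
    merged, i = [], 0
    while i < len(tokens):
        if tokens[i] == '<td>' and i + 1 < len(tokens) and tokens[i + 1] == '</td>':
            merged.append('<td></td>')
            i += 2
        else:
            merged.append(tokens[i])
            i += 1
    return merged

print(merge_td_tokens(['<tr>', '<td>', '</td>', '<td>', '</td>', '</tr>']))
# ['<tr>', '<td></td>', '<td></td>', '</tr>']
```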
#### 4.2.2 Adaptation to Chinese Scenarios
@ -249,7 +249,7 @@ In PP-Structurev2, following the token handling in TableMaster, we merge `<td>`
### 4.3 Layout Recovery
Layout recovery means that, after a document image has been processed by OCR, layout analysis, table recognition, and other methods, the content is restored with the same typesetting as the original document and exported to a Word file or similar. In PP-Structurev2, our layout recovery system comprises layout analysis, table recognition, and OCR text detection and recognition sub-modules.

Layout recovery means that, after a document image has been processed by OCR, layout analysis, table recognition, and other methods, the content is restored with the same typesetting as the original document and exported to a Word file or similar. In PP-StructureV2, our layout recovery system comprises layout analysis, table recognition, and OCR text detection and recognition sub-modules.
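Programmatically, the recovery pipeline can be driven through the whl package; a minimal sketch following the release/2.6 quickstart (import path and helper names as documented there, image path illustrative):

```python
import os
import cv2
from paddleocr import PPStructure, save_structure_res
from paddleocr.ppstructure.recovery.recovery_to_doc import sorted_layout_boxes, convert_info_docx

engine = PPStructure(recovery=True)
img_path = 'ppstructure/docs/table/1.png'  # illustrative path
img = cv2.imread(img_path)
result = engine(img)
save_structure_res(result, './output', os.path.basename(img_path).split('.')[0])

# Sort regions into reading order, then rebuild a .docx with matching layout.
h, w, _ = img.shape
res = sorted_layout_boxes(result, w)
convert_info_docx(img, res, './output', os.path.basename(img_path).split('.')[0])
```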
The figure below shows layout recovery results:
@ -258,7 +258,7 @@ In PP-Structurev2, following the token handling in TableMaster, we merge `<td>`
## 5. Key Information Extraction
Key information extraction (KIE) extracts the key information that users care about from the text of a document image, such as the name and address fields on an ID card. PP-Structure supports semantic entity recognition (SER) and relation extraction (RE) tasks based on the multimodal LayoutLM family of models. In PP-Structurev2, we upgrade the model structure and the downstream-task training methods and propose VI-LayoutXLM (Visual-feature Independent LayoutXLM); the pipeline is shown below.

Key information extraction (KIE) extracts the key information that users care about from the text of a document image, such as the name and address fields on an ID card. PP-Structure supports semantic entity recognition (SER) and relation extraction (RE) tasks based on the multimodal LayoutLM family of models. In PP-StructureV2, we upgrade the model structure and the downstream-task training methods and propose VI-LayoutXLM (Visual-feature Independent LayoutXLM); the pipeline is shown below.
@ -394,7 +394,7 @@ The visualization results of the RE task are shown below.
| Experiment | Strategy | F1-score |
|:------:|:------:|:------:|
| 1 | LayoutXLM | 82.28% |
| 2 | PP-Structurev2 SER | **87.79%** |
| 2 | PP-StructureV2 SER | **87.79%** |
**RE task results**
@ -402,7 +402,7 @@ The visualization results of the RE task are shown below.
| Experiment | Strategy | F1-score |
|:------:|:------:|:------:|
| 1 | LayoutXLM | 53.13% |
| 2 | PP-Structurev2 RE | **74.87%** |
| 2 | PP-StructureV2 RE | **74.87%** |
## 6. Reference
@ -18,13 +18,13 @@ cd ppstructure
Download models
```bash
mkdir inference && cd inference
# Download the PP-Structurev2 layout analysis model and unzip it
# Download the PP-StructureV2 layout analysis model and unzip it
wget https://paddleocr.bj.bcebos.com/ppstructure/models/layout/picodet_lcnet_x1_0_layout_infer.tar && tar xf picodet_lcnet_x1_0_layout_infer.tar
# Download the PP-OCRv3 text detection model and unzip it
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar && tar xf ch_PP-OCRv3_det_infer.tar
# Download the PP-OCRv3 text recognition model and unzip it
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar && tar xf ch_PP-OCRv3_rec_infer.tar
# Download the PP-Structurev2 table recognition model and unzip it
# Download the PP-StructureV2 table recognition model and unzip it
wget https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar && tar xf ch_ppstructure_mobile_v2.0_SLANet_infer.tar
cd ..
```
@ -20,13 +20,13 @@ download model
```bash
mkdir inference && cd inference
# Download the PP-Structurev2 layout analysis model and unzip it
# Download the PP-StructureV2 layout analysis model and unzip it
wget https://paddleocr.bj.bcebos.com/ppstructure/models/layout/picodet_lcnet_x1_0_layout_infer.tar && tar xf picodet_lcnet_x1_0_layout_infer.tar
# Download the PP-OCRv3 text detection model and unzip it
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar && tar xf ch_PP-OCRv3_det_infer.tar
# Download the PP-OCRv3 text recognition model and unzip it
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar && tar xf ch_PP-OCRv3_rec_infer.tar
# Download the PP-Structurev2 form recognition model and unzip it
# Download the PP-StructureV2 form recognition model and unzip it
wget https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar && tar xf ch_ppstructure_mobile_v2.0_SLANet_infer.tar
cd ..
```
@ -227,7 +227,7 @@ for line in result:
<a name="225"></a>
#### 2.2.5 Key Information Extraction
Key information extraction is not yet supported via the whl package. For a detailed tutorial, please refer to: [Key Information Extraction tutorial](../kie/README_ch.md).

Key information extraction is not yet supported via the whl package. For a detailed tutorial, please refer to: [inference document](./inference.md).
<a name="226"></a>
@ -94,7 +94,7 @@ paddleocr --image_dir=ppstructure/docs/table/table.jpg --type=structure --layout
#### 2.1.5 Key Information Extraction
Key information extraction is not currently supported by the whl package. For a detailed tutorial, please refer to: [Key Information Extraction](../kie/README.md).

Key information extraction is not currently supported by the whl package. For a detailed tutorial, please refer to: [inference document](./inference_en.md).
<a name="216"></a>
#### 2.1.6 Layout Recovery
@ -42,7 +42,7 @@
## 2. Key Information Extraction Pipeline
PaddleOCR implements token-based algorithms such as LayoutXLM. In addition, in PP-Structurev2 we simplify the network structure of the LayoutXLM multimodal pretrained model by removing its visual backbone, yielding the visual-feature-independent VI-LayoutXLM model; we also introduce a text line sorting logic that follows human reading order, together with the UDML knowledge distillation strategy, which improves both the accuracy and the inference speed of the key information extraction model.

PaddleOCR implements token-based algorithms such as LayoutXLM. In addition, in PP-StructureV2 we simplify the network structure of the LayoutXLM multimodal pretrained model by removing its visual backbone, yielding the visual-feature-independent VI-LayoutXLM model; we also introduce a text line sorting logic that follows human reading order, together with the UDML knowledge distillation strategy, which improves both the accuracy and the inference speed of the key information extraction model.
The following describes how to complete key information extraction tasks with PaddleOCR.
@ -115,7 +115,7 @@ Train:
In terms of data, for relatively fixed scenarios, about **50** training images are generally enough to reach acceptable results; [PPOCRLabel](../../PPOCRLabel/README_ch.md) can be used to complete the KIE annotation.
In terms of model, we recommend the VI-LayoutXLM model proposed in PP-Structurev2. It improves on the LayoutXLM model by removing the visual feature extraction module, further speeding up inference with essentially no loss in accuracy. For more tutorials, please refer to: [VI-LayoutXLM introduction](../../doc/doc_ch/algorithm_kie_vi_layoutxlm.md) and [KIE tutorial](../../doc/doc_ch/kie.md).

In terms of model, we recommend the VI-LayoutXLM model proposed in PP-StructureV2. It improves on the LayoutXLM model by removing the visual feature extraction module, further speeding up inference with essentially no loss in accuracy. For more tutorials, please refer to: [VI-LayoutXLM introduction](../../doc/doc_ch/algorithm_kie_vi_layoutxlm.md) and [KIE tutorial](../../doc/doc_ch/kie.md).
#### 2.2.2 SER + RE
@ -145,7 +145,7 @@ Train:
In terms of data, for relatively fixed scenarios, about **50** training images are generally enough to reach acceptable results; PPOCRLabel can be used to complete the KIE annotation.
In terms of model, we recommend the VI-LayoutXLM model proposed in PP-Structurev2. It improves on the LayoutXLM model by removing the visual feature extraction module, further speeding up inference with essentially no loss in accuracy. For more tutorials, please refer to: [VI-LayoutXLM introduction](../../doc/doc_ch/algorithm_kie_vi_layoutxlm.md) and [KIE tutorial](../../doc/doc_ch/kie.md).

In terms of model, we recommend the VI-LayoutXLM model proposed in PP-StructureV2. It improves on the LayoutXLM model by removing the visual feature extraction module, further speeding up inference with essentially no loss in accuracy. For more tutorials, please refer to: [VI-LayoutXLM introduction](../../doc/doc_ch/algorithm_kie_vi_layoutxlm.md) and [KIE tutorial](../../doc/doc_ch/kie.md).
## 3. References
@ -48,7 +48,7 @@ For more detailed introduction of the algorithms, please refer to Chapter 6 of [
## 2. KIE Pipeline
Token-based methods such as LayoutXLM are implemented in PaddleOCR. Moreover, in PP-Structurev2, we simplify the LayoutXLM model and propose VI-LayoutXLM, in which the visual feature extraction module is removed for speed-up. A text line sorting strategy that follows human reading order and the UDML knowledge distillation strategy are used for higher model accuracy.

Token-based methods such as LayoutXLM are implemented in PaddleOCR. Moreover, in PP-StructureV2, we simplify the LayoutXLM model and propose VI-LayoutXLM, in which the visual feature extraction module is removed for speed-up. A text line sorting strategy that follows human reading order and the UDML knowledge distillation strategy are used for higher model accuracy.
In the non-end-to-end KIE method, KIE needs at least **2 steps**. First, the OCR model extracts the text and its position; second, the KIE model extracts the key information according to the image, the text positions, and the text content.
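As a concrete example of the two steps wired together, SER inference can be run with the repo's prediction script; the flags below mirror the `test_tipc` configs added in this commit, while the model directory is illustrative:

```python
import subprocess

subprocess.run([
    "python3", "ppstructure/kie/predict_kie_token_ser.py",
    "--kie_algorithm=LayoutXLM",
    "--ser_model_dir=./inference/ser_vi_layoutxlm_xfund_infer",  # illustrative
    "--ser_dict_path=train_data/XFUND/class_list_xfun.txt",
    "--image_dir=./ppstructure/docs/kie/input/zh_val_42.jpg",
    "--ocr_order_method=tb-yx",
], check=True)
```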
@ -125,7 +125,7 @@ Take the ID card scenario as an example. The key information generally includes
In terms of data, generally speaking, for relatively fixed scenes, **50** training images can achieve acceptable effects. You can use [PPOCRLabel](../../PPOCRLabel/README.md) to finish the labeling process.
In terms of model, it is recommended to use the VI-LayoutXLM model proposed in PP-Structurev2. It is improved based on the LayoutXLM model, removing the visual feature extraction module and further improving model inference speed without significant reduction in model accuracy. For more tutorials, please refer to [VI-LayoutXLM introduction](../../doc/doc_en/algorithm_kie_vi_layoutxlm_en.md) and [KIE tutorial](../../doc/doc_en/kie_en.md).

In terms of model, it is recommended to use the VI-LayoutXLM model proposed in PP-StructureV2. It is improved based on the LayoutXLM model, removing the visual feature extraction module and further improving model inference speed without significant reduction in model accuracy. For more tutorials, please refer to [VI-LayoutXLM introduction](../../doc/doc_en/algorithm_kie_vi_layoutxlm_en.md) and [KIE tutorial](../../doc/doc_en/kie_en.md).
#### 2.2.2 SER + RE
@ -155,7 +155,7 @@ For each textline, you need to add 'ID' and 'linking' field information. The 'ID
In terms of data, generally speaking, for relatively fixed scenes, about **50** training images can achieve acceptable effects.
In terms of model, it is recommended to use the VI-LayoutXLM model proposed in PP-Structurev2. It is improved based on the LayoutXLM model, removing the visual feature extraction module and further improving model inference speed without significant reduction in model accuracy. For more tutorials, please refer to [VI-LayoutXLM introduction](../../doc/doc_en/algorithm_kie_vi_layoutxlm_en.md) and [KIE tutorial](../../doc/doc_en/kie_en.md).

In terms of model, it is recommended to use the VI-LayoutXLM model proposed in PP-StructureV2. It is improved based on the LayoutXLM model, removing the visual feature extraction module and further improving model inference speed without significant reduction in model accuracy. For more tutorials, please refer to [VI-LayoutXLM introduction](../../doc/doc_en/algorithm_kie_vi_layoutxlm_en.md) and [KIE tutorial](../../doc/doc_en/kie_en.md).
@ -6,8 +6,8 @@ English | [简体中文](README_ch.md)
- [2. Install](#2)
- [2.1 Install PaddlePaddle](#2.1)
- [2.2 Install PaddleOCR](#2.2)
- [3. Quick Start using PDF parse](#3)
- [4. Quick Start using OCR](#4)
- [3. Quick Start using standard PDF parse](#3)
- [4. Quick Start using image format PDF parse](#4)
- [4.1 Download models](#4.1)
- [4.2 Layout recovery](#4.2)
- [5. More](#5)
@ -19,18 +19,18 @@ English | [简体中文](README_ch.md)
The layout recovery module is used to restore the image or pdf to an editable Word file consistent with the original image layout.
Two layout recovery methods are provided:

Two layout recovery methods are provided; you can choose according to the PDF format:
- PDF parse: the Python-based PDF-to-Word library [pdf2docx](https://github.com/dothinking/pdf2docx) is optimized; the method extracts data from the PDF with PyMuPDF, then parses the layout with rules, and finally generates the docx with python-docx.
- **Standard PDF parse (the input is a standard PDF)**: the Python-based PDF-to-Word library [pdf2docx](https://github.com/dothinking/pdf2docx) is optimized; the method extracts data from the PDF with PyMuPDF, then parses the layout with rules, and finally generates the docx with python-docx.
- OCR: layout recovery combines [layout analysis](../layout/README.md) and [table recognition](../table/README.md) to better recover images, tables, titles, etc.; it supports input files in PDF and document image formats, in Chinese and English.
- **Image format PDF parse (the input can be a standard PDF or an image format PDF)**: layout recovery combines [layout analysis](../layout/README.md) and [table recognition](../table/README.md) to better recover images, tables, titles, etc.; it supports input files in PDF and document image formats, in Chinese and English.
The input formats and application scenarios of the two methods are as follows:
| method | input formats | application scenarios/problems |
| :-----: | :----------: | :----------------------------------------------------------: |
| PDF parse | pdf | Advantages: better recovery for non-paper documents; each page remains on the same page after restoration<br>Disadvantages: English characters in some Chinese documents are garbled; some content still runs beyond the current page; whole pages may be restored as tables; some pictures are recovered poorly |
| OCR technique | pdf, picture | Advantages: more suitable for recovering paper-document content; better OCR recognition<br>Disadvantages: recovery is currently rule-based; content typesetting (spacing, fonts, etc.) needs further improvement; layout recovery quality depends on layout analysis |
| Standard PDF parse | pdf | Advantages: better recovery for non-paper documents; each page remains on the same page after restoration<br>Disadvantages: English characters in some Chinese documents are garbled; some content still runs beyond the current page; whole pages may be restored as tables; some pictures are recovered poorly |
| Image format PDF parse | pdf, picture | Advantages: more suitable for recovering paper-document content; better OCR recognition<br>Disadvantages: recovery is currently rule-based; content typesetting (spacing, fonts, etc.) needs further improvement; layout recovery quality depends on layout analysis |
The following figure shows the effect of restoring the layout of documents by using PDF parse:
@ -103,7 +103,7 @@ pip3 install pdf2docx-0.0.0-py3-none-any.whl
<a name="3"></a>
## 3. Quick Start using PDF parse
## 3. Quick Start using standard PDF parse
`use_pdf2docx_api` uses PDF parse for layout recovery. A whl package is also provided for quick use; follow the code above, and refer to [quickstart](../docs/quickstart_en.md) for more information.
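Under the hood, standard PDF parse is a thin layer over pdf2docx, which can also be called directly; a minimal sketch with illustrative file names:

```python
from pdf2docx import Converter

cv = Converter('input.pdf')                   # illustrative input file
cv.convert('output.docx', start=0, end=None)  # parse and rebuild every page
cv.close()
```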
@ -124,7 +124,7 @@ python3 predict_system.py \
```
<a name="4"></a>
## 4. Quick Start using OCR
## 4. Quick Start using image format PDF parse
Through layout analysis, we divide the image/PDF documents into regions, locate the key regions such as text, tables, and pictures, and record the location, category, and regional pixel information of each region. Different regions are then processed separately, where:
@ -6,8 +6,8 @@
- [2. Install](#2)
- [2.1 Install PaddlePaddle](#2.1)
- [2.2 Install PaddleOCR](#2.2)
- [3. Layout recovery using PDF parse](#3)
- [4. Layout recovery using OCR](#4)
- [3. Layout recovery using standard PDF parse](#3)
- [4. Layout recovery using image format PDF parse](#4)
- [4.1 Download models](#4.1)
- [4.2 Layout recovery](#4.2)
- [5. More](#5)
@ -18,17 +18,17 @@
Layout recovery outputs the content of an input image or PDF into a Word document (or similar) arranged just as in the original document, with paragraphs and their order unchanged.
Two layout recovery methods are provided:

Two layout recovery methods are provided; choose according to the format of the input PDF:
- PDF parse: built on the optimized Python PDF-to-Word library [pdf2docx](https://github.com/dothinking/pdf2docx); the method obtains page elements with PyMuPDF, then parses layout and styles such as sections, paragraphs, and tables with rules, and finally rebuilds the parsed content elements into a Word document with python-docx.
- OCR technique: combines [layout analysis](../layout/README_ch.md) and [table recognition](../table/README_ch.md) to better recover pictures, tables, titles, and so on; it supports Chinese and English PDF documents and document images as input.
- **Standard PDF parse (the input must be a standard PDF)**: built on the optimized Python PDF-to-Word library [pdf2docx](https://github.com/dothinking/pdf2docx); the method obtains page elements with PyMuPDF, then parses layout and styles such as sections, paragraphs, and tables with rules, and finally rebuilds the parsed content elements into a Word document with python-docx.
- **Image format PDF parse (the input can be a standard PDF or an image format PDF)**: combines [layout analysis](../layout/README_ch.md) and [table recognition](../table/README_ch.md) to better recover pictures, tables, titles, and so on; it supports Chinese and English PDF documents and document images as input.
The input formats and applicable scenarios of the two methods are as follows:
| Method | Supported inputs | Applicable scenarios / known issues |
| :-----: | :----------: | :----------------------------------------------------------: |
| PDF parse | pdf | Advantages: better recovery for non-paper documents; each page stays on the same page after recovery<br>Disadvantages: English text in some Chinese documents is garbled; some content still runs beyond the current page; whole pages may be recovered as tables; some pictures are recovered poorly |
| OCR technique | pdf, picture | Advantages: better suited to recovering the body of paper-style documents; good OCR accuracy on Chinese and English documents<br>Disadvantages: recovery is currently rule-based; typesetting (spacing, fonts, etc.) needs further improvement; layout recovery quality depends on layout analysis |
| Method | Supported inputs | Applicable scenarios / known issues |
| :-------------: | :----------: | :----------------------------------------------------------: |
| Standard PDF parse | pdf | Advantages: better recovery for non-paper documents; each page stays on the same page after recovery<br>Disadvantages: English text in some Chinese documents is garbled; some content still runs beyond the current page; whole pages may be recovered as tables; some pictures are recovered poorly |
| Image format PDF parse | pdf, picture | Advantages: better suited to recovering the body of paper-style documents; good OCR accuracy on Chinese and English documents<br>Disadvantages: recovery is currently rule-based; typesetting (spacing, fonts, etc.) needs further improvement; layout recovery quality depends on layout analysis |
The figure below shows the layout recovery result produced by PDF parse:
@ -99,7 +99,7 @@ pip3 install pdf2docx-0.0.0-py3-none-any.whl
<a name="3"></a>
## 3. Layout recovery using PDF parse
## 3. Layout recovery using standard PDF parse
`use_pdf2docx_api` indicates layout recovery via PDF parse; a whl package is provided for quick use. The code is as follows; see [quickstart](../docs/quickstart.md) for more information.
@ -121,7 +121,7 @@ python3 predict_system.py \
<a name="4"></a>
## 4. Layout recovery using OCR
## 4. Layout recovery using image format PDF parse
We use layout analysis to partition image/PDF documents into regions and locate the key regions such as text, tables, and pictures, recording the position, category, and pixel information of each region. The regions are then processed separately, where:
@ -66,7 +66,7 @@ mkdir inference && cd inference
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar && tar xf ch_PP-OCRv3_det_infer.tar
# Download the PP-OCRv3 text recognition model and unzip it
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar && tar xf ch_PP-OCRv3_rec_infer.tar
# Download the PP-Structurev2 form recognition model and unzip it
# Download the PP-StructureV2 form recognition model and unzip it
wget https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar && tar xf ch_ppstructure_mobile_v2.0_SLANet_infer.tar
cd ..
# run
@ -71,7 +71,7 @@ mkdir inference && cd inference
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar && tar xf ch_PP-OCRv3_det_infer.tar
# Download the PP-OCRv3 text recognition model and unzip it
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar && tar xf ch_PP-OCRv3_rec_infer.tar
# Download the PP-Structurev2 Chinese table recognition model and unzip it
# Download the PP-StructureV2 Chinese table recognition model and unzip it
wget https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar && tar xf ch_ppstructure_mobile_v2.0_SLANet_infer.tar
cd ..
# Run table recognition
@ -15,4 +15,5 @@ premailer
openpyxl
attrdict
Polygon3
PyMuPDF==1.19.0
lanms-neo==1.0.2
PyMuPDF==1.19.0
@ -12,7 +12,7 @@ Global:
checkpoints:
save_inference_dir: ./output/SLANet/infer
use_visualdl: False
infer_img: doc/table/table.jpg
infer_img: ppstructure/docs/table/table.jpg
# for data or label process
character_dict_path: ppocr/utils/dict/table_structure_dict.txt
character_type: en
@ -0,0 +1,53 @@
===========================train_params===========================
model_name:slanet
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:amp
Global.epoch_num:lite_train_lite_infer=3|whole_train_whole_infer=50
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=128
Global.pretrained_model:./pretrain_models/en_ppstructure_mobile_v2.0_SLANet_train/best_accuracy
train_model_name:latest
train_infer_img_dir:./ppstructure/docs/table/table.jpg
null:null
##
trainer:norm_train
norm_train:tools/train.py -c test_tipc/configs/slanet/SLANet.yml -o
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:null
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
norm_export:tools/export_model.py -c test_tipc/configs/slanet/SLANet.yml -o
quant_export:
fpgm_export:
distill_export:null
export1:null
export2:null
##
infer_model:./inference/en_ppstructure_mobile_v2.0_SLANet_train
infer_export:null
infer_quant:False
inference:ppstructure/table/predict_table.py --det_model_dir=./inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=./inference/en_ppocr_mobile_v2.0_table_rec_infer --rec_char_dict_path=./ppocr/utils/dict/table_dict.txt --table_char_dict_path=./ppocr/utils/dict/table_structure_dict.txt --image_dir=./ppstructure/docs/table/table.jpg --det_limit_side_len=736 --det_limit_type=min --output ./output/table
--use_gpu:True|False
--enable_mkldnn:False
--cpu_threads:6
--rec_batch_num:1
--use_tensorrt:True
--precision:fp32
--table_model_dir:
--image_dir:./ppstructure/docs/table/table.jpg
null:null
--benchmark:True
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,488,488]}]
@ -0,0 +1,53 @@
===========================train_params===========================
model_name:slanet_PACT
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:fp32
Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=50
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=2
Global.pretrained_model:./pretrain_models/en_ppstructure_mobile_v2.0_SLANet_train/best_accuracy
train_model_name:latest
train_infer_img_dir:./ppstructure/docs/table/table.jpg
null:null
##
trainer:pact_train
norm_train:null
pact_train:deploy/slim/quantization/quant.py -c test_tipc/configs/slanet/SLANet.yml -o
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:null
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
norm_export:null
quant_export:deploy/slim/quantization/export_model.py -c test_tipc/configs/slanet/SLANet.yml -o
fpgm_export:
distill_export:null
export1:null
export2:null
##
infer_model:./inference/en_ppocr_mobile_v2.0_table_structure_infer
infer_export:null
infer_quant:True
inference:ppstructure/table/predict_table.py --det_model_dir=./inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=./inference/en_ppocr_mobile_v2.0_table_rec_infer --rec_char_dict_path=./ppocr/utils/dict/table_dict.txt --table_char_dict_path=./ppocr/utils/dict/table_structure_dict.txt --image_dir=./ppstructure/docs/table/table.jpg --det_limit_side_len=736 --det_limit_type=min --output ./output/table
--use_gpu:True|False
--enable_mkldnn:False
--cpu_threads:6
--rec_batch_num:1
--use_tensorrt:False
--precision:fp32
--table_model_dir:
--image_dir:./ppstructure/docs/table/table.jpg
null:null
--benchmark:True
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,488,488]}]
@ -0,0 +1,21 @@
===========================train_params===========================
model_name:slanet_KL
python:python3.7
Global.pretrained_model:
Global.save_inference_dir:null
infer_model:./inference/en_ppstructure_mobile_v2.0_SLANet_infer/
infer_export:deploy/slim/quantization/quant_kl.py -c test_tipc/configs/slanet/SLANet.yml -o
infer_quant:True
inference:ppstructure/table/predict_table.py --det_model_dir=./inference/ch_PP-OCRv3_det_infer --rec_model_dir=./inference/ch_PP-OCRv3_rec_infer --rec_char_dict_path=./ppocr/utils/ppocr_keys_v1.txt --table_char_dict_path=./ppocr/utils/dict/table_structure_dict.txt --image_dir=./ppstructure/docs/table/table.jpg --det_limit_side_len=736 --det_limit_type=min --output ./output/table
--use_gpu:True|False
--enable_mkldnn:False
--cpu_threads:6
--rec_batch_num:1
--use_tensorrt:False
--precision:int8
--table_model_dir:
--image_dir:./ppstructure/docs/table/table.jpg
null:null
--benchmark:True
null:null
null:null
@ -0,0 +1,53 @@
===========================train_params===========================
model_name:vi_layoutxlm_ser
python:python3.7
gpu_list:192.168.0.1,192.168.0.2;0,1
Global.use_gpu:True|True
Global.auto_cast:fp32
Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=17
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=4|whole_train_whole_infer=8
Architecture.Backbone.checkpoints:null
train_model_name:latest
train_infer_img_dir:ppstructure/docs/kie/input/zh_val_42.jpg
null:null
##
trainer:norm_train
norm_train:tools/train.py -c ./configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yml -o
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:null
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Architecture.Backbone.checkpoints:
norm_export:tools/export_model.py -c ./configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yml -o
quant_export:
fpgm_export:
distill_export:null
export1:null
export2:null
##
infer_model:null
infer_export:null
infer_quant:False
inference:ppstructure/kie/predict_kie_token_ser.py --kie_algorithm=LayoutXLM --ser_dict_path=train_data/XFUND/class_list_xfun.txt --output=output --ocr_order_method=tb-yx
--use_gpu:True|False
--enable_mkldnn:False
--cpu_threads:6
--rec_batch_num:1
--use_tensorrt:False
--precision:fp32
--ser_model_dir:
--image_dir:./ppstructure/docs/kie/input/zh_val_42.jpg
null:null
--benchmark:True
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,224,224]}]
@ -0,0 +1,53 @@
===========================train_params===========================
model_name:vi_layoutxlm_ser
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:amp
Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=17
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=4|whole_train_whole_infer=8
Architecture.Backbone.checkpoints:null
train_model_name:latest
train_infer_img_dir:ppstructure/docs/kie/input/zh_val_42.jpg
null:null
##
trainer:norm_train
norm_train:tools/train.py -c ./configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yml -o
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:null
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Architecture.Backbone.checkpoints:
norm_export:tools/export_model.py -c ./configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yml -o
quant_export:
fpgm_export:
distill_export:null
export1:null
export2:null
##
infer_model:null
infer_export:null
infer_quant:False
inference:ppstructure/kie/predict_kie_token_ser.py --kie_algorithm=LayoutXLM --ser_dict_path=train_data/XFUND/class_list_xfun.txt --output=output --ocr_order_method=tb-yx
--use_gpu:True|False
--enable_mkldnn:False
--cpu_threads:6
--rec_batch_num:1
--use_tensorrt:False
--precision:fp32
--ser_model_dir:
--image_dir:./ppstructure/docs/kie/input/zh_val_42.jpg
null:null
--benchmark:True
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,224,224]}]
@ -0,0 +1,53 @@
===========================train_params===========================
model_name:vi_layoutxlm_ser_PACT
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:fp32
Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=17
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=4|whole_train_whole_infer=8
Architecture.Backbone.pretrained:./pretrain_models/ser_vi_layoutxlm_xfund_pretrained/best_accuracy
train_model_name:latest
train_infer_img_dir:ppstructure/docs/kie/input/zh_val_42.jpg
null:null
##
trainer:pact_train
norm_train:null
pact_train:deploy/slim/quantization/quant.py -c ./configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yml -o Global.eval_batch_step=[2000,10]
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:null
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Architecture.Backbone.checkpoints:
norm_export:null
quant_export:deploy/slim/quantization/export_model.py -c ./configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yml -o
fpgm_export: null
distill_export:null
export1:null
export2:null
##
infer_model:null
infer_export:null
infer_quant:False
inference:ppstructure/kie/predict_kie_token_ser.py --kie_algorithm=LayoutXLM --ser_dict_path=train_data/XFUND/class_list_xfun.txt --output=output --ocr_order_method=tb-yx
--use_gpu:True|False
--enable_mkldnn:False
--cpu_threads:6
--rec_batch_num:1
--use_tensorrt:False
--precision:fp32
--ser_model_dir:
--image_dir:./ppstructure/docs/kie/input/zh_val_42.jpg
null:null
--benchmark:True
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,224,224]}]
@ -0,0 +1,21 @@
===========================train_params===========================
model_name:vi_layoutxlm_ser_KL
python:python3.7
Global.pretrained_model:
Global.save_inference_dir:null
infer_model:./inference/ser_vi_layoutxlm_xfund_infer/
infer_export:deploy/slim/quantization/quant_kl.py -c ./configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yml -o Train.loader.batch_size_per_card=1 Eval.loader.batch_size_per_card=1
infer_quant:True
inference:ppstructure/kie/predict_kie_token_ser.py --kie_algorithm=LayoutXLM --ser_dict_path=train_data/XFUND/class_list_xfun.txt --output=output --ocr_order_method=tb-yx
--use_gpu:True|False
--enable_mkldnn:False
--cpu_threads:6
--rec_batch_num:1
--use_tensorrt:False
--precision:int8
--ser_model_dir:
--image_dir:./ppstructure/docs/kie/input/zh_val_42.jpg
null:null
--benchmark:True
null:null
null:null
@ -164,7 +164,7 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar --no-check-certificate
cd ./inference/ && tar xf en_ppocr_mobile_v2.0_table_det_infer.tar && tar xf en_ppocr_mobile_v2.0_table_rec_infer.tar && cd ../
fi
if [ ${model_name} == "slanet" ];then
if [[ ${model_name} =~ "slanet" ]];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/en_ppstructure_mobile_v2.0_SLANet_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf en_ppstructure_mobile_v2.0_SLANet_train.tar && cd ../
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar --no-check-certificate
@ -267,12 +267,16 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar --no-check-certificate
cd ./pretrain_models/ && tar xf ser_LayoutXLM_xfun_zh.tar && cd ../
fi
if [ ${model_name} == "vi_layoutxlm_ser" ]; then
if [[ ${model_name} =~ "vi_layoutxlm_ser" ]]; then
${python_name} -m pip install -r ppstructure/kie/requirements.txt
${python_name} -m pip install opencv-python -U
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar --no-check-certificate
cd ./train_data/ && tar xf XFUND.tar
cd ../
if [ ${model_name} == "vi_layoutxlm_ser_PACT" ]; then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar --no-check-certificate
cd ./pretrain_models/ && tar xf ser_vi_layoutxlm_xfund_pretrained.tar && cd ../
fi
fi
if [ ${model_name} == "det_r18_ct" ]; then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet18_vd_pretrained.pdparams --no-check-certificate
@ -356,7 +360,8 @@ elif [ ${MODE} = "whole_infer" ];then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar --no-check-certificate
cd ./inference && tar xf rec_inference.tar && tar xf ch_det_data_50.tar && cd ../
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar --no-check-certificate
cd ./train_data/ && tar xf XFUND.tar && cd ../
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/pubtabnet.tar --no-check-certificate
cd ./train_data/ && tar xf XFUND.tar && tar xf pubtabnet.tar && cd ../
head -n 2 train_data/XFUND/zh_val/val.json > train_data/XFUND/zh_val/val_lite.json
mv train_data/XFUND/zh_val/val_lite.json train_data/XFUND/zh_val/val.json
if [ ${model_name} = "ch_ppocr_mobile_v2_0_det" ]; then
@ -532,6 +537,18 @@ elif [ ${MODE} = "whole_infer" ];then
fi
cd ../
fi
if [[ ${model_name} =~ "slanet" ]];then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/en_ppstructure_mobile_v2.0_SLANet_infer.tar --no-check-certificate
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar --no-check-certificate
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar --no-check-certificate
cd ./inference/ && tar xf en_ppstructure_mobile_v2.0_SLANet_infer.tar && tar xf ch_PP-OCRv3_det_infer.tar && tar xf ch_PP-OCRv3_rec_infer.tar && cd ../
fi
if [[ ${model_name} =~ "vi_layoutxlm_ser" ]]; then
${python_name} -m pip install -r ppstructure/kie/requirements.txt
${python_name} -m pip install opencv-python -U
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar --no-check-certificate
cd ./inference/ && tar xf ser_vi_layoutxlm_xfund_infer.tar && cd ../
fi
if [[ ${model_name} =~ "layoutxlm_ser" ]]; then
${python_name} -m pip install -r ppstructure/kie/requirements.txt
${python_name} -m pip install opencv-python -U
@ -67,6 +67,7 @@ class TextDetector(object):
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
postprocess_params["use_dilation"] = args.use_dilation
postprocess_params["score_mode"] = args.det_db_score_mode
postprocess_params["box_type"] = args.det_box_type
elif self.det_algorithm == "DB++":
postprocess_params['name'] = 'DBPostProcess'
postprocess_params["thresh"] = args.det_db_thresh
@ -75,6 +76,7 @@ class TextDetector(object):
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
postprocess_params["use_dilation"] = args.use_dilation
postprocess_params["score_mode"] = args.det_db_score_mode
postprocess_params["box_type"] = args.det_box_type
pre_process_list[1] = {
'NormalizeImage': {
'std': [1.0, 1.0, 1.0],
@ -98,8 +100,8 @@ class TextDetector(object):
postprocess_params['name'] = 'SASTPostProcess'
postprocess_params["score_thresh"] = args.det_sast_score_thresh
postprocess_params["nms_thresh"] = args.det_sast_nms_thresh
self.det_sast_polygon = args.det_sast_polygon
if self.det_sast_polygon:
if args.det_box_type == 'poly':
postprocess_params["sample_pts_num"] = 6
postprocess_params["expand_scale"] = 1.2
postprocess_params["shrink_ratio_of_width"] = 0.2
@ -107,14 +109,14 @@ class TextDetector(object):
postprocess_params["sample_pts_num"] = 2
postprocess_params["expand_scale"] = 1.0
postprocess_params["shrink_ratio_of_width"] = 0.3

elif self.det_algorithm == "PSE":
postprocess_params['name'] = 'PSEPostProcess'
postprocess_params["thresh"] = args.det_pse_thresh
postprocess_params["box_thresh"] = args.det_pse_box_thresh
postprocess_params["min_area"] = args.det_pse_min_area
postprocess_params["box_type"] = args.det_pse_box_type
postprocess_params["box_type"] = args.det_box_type
postprocess_params["scale"] = args.det_pse_scale
self.det_pse_box_type = args.det_pse_box_type
elif self.det_algorithm == "FCE":
pre_process_list[0] = {
'DetResizeForTest': {
@ -126,7 +128,7 @@ class TextDetector(object):
postprocess_params["alpha"] = args.alpha
postprocess_params["beta"] = args.beta
postprocess_params["fourier_degree"] = args.fourier_degree
postprocess_params["box_type"] = args.det_fce_box_type
postprocess_params["box_type"] = args.det_box_type
elif self.det_algorithm == "CT":
pre_process_list[0] = {'ScaleAlignedShort': {'short_size': 640}}
postprocess_params['name'] = 'CTPostProcess'
@ -190,6 +192,8 @@ class TextDetector(object):
img_height, img_width = image_shape[0:2]
dt_boxes_new = []
for box in dt_boxes:
if type(box) is list:
box = np.array(box)
box = self.order_points_clockwise(box)
box = self.clip_det_res(box, img_height, img_width)
rect_width = int(np.linalg.norm(box[0] - box[1]))
@ -204,6 +208,8 @@ class TextDetector(object):
img_height, img_width = image_shape[0:2]
dt_boxes_new = []
for box in dt_boxes:
if type(box) is list:
box = np.array(box)
box = self.clip_det_res(box, img_height, img_width)
dt_boxes_new.append(box)
dt_boxes = np.array(dt_boxes_new)
@ -262,12 +268,10 @@ class TextDetector(object):
else:
raise NotImplementedError

#self.predictor.try_shrink_memory()
post_result = self.postprocess_op(preds, shape_list)
dt_boxes = post_result[0]['points']
if (self.det_algorithm == "SAST" and self.det_sast_polygon) or (
self.det_algorithm in ["PSE", "FCE", "CT"] and
self.postprocess_op.box_type == 'poly'):
if self.args.det_box_type == 'poly':
dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
else:
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
@ -50,6 +50,7 @@ def init_args():
parser.add_argument("--det_model_dir", type=str)
parser.add_argument("--det_limit_side_len", type=float, default=960)
parser.add_argument("--det_limit_type", type=str, default='max')
parser.add_argument("--det_box_type", type=str, default='quad')

# DB parmas
parser.add_argument("--det_db_thresh", type=float, default=0.3)
@ -58,6 +59,7 @@ def init_args():
parser.add_argument("--max_batch_size", type=int, default=10)
parser.add_argument("--use_dilation", type=str2bool, default=False)
parser.add_argument("--det_db_score_mode", type=str, default="fast")

# EAST parmas
parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
@ -66,13 +68,11 @@ def init_args():
# SAST parmas
parser.add_argument("--det_sast_score_thresh", type=float, default=0.5)
parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2)
parser.add_argument("--det_sast_polygon", type=str2bool, default=False)

# PSE parmas
parser.add_argument("--det_pse_thresh", type=float, default=0)
parser.add_argument("--det_pse_box_thresh", type=float, default=0.85)
parser.add_argument("--det_pse_min_area", type=float, default=16)
parser.add_argument("--det_pse_box_type", type=str, default='quad')
parser.add_argument("--det_pse_scale", type=int, default=1)

# FCE parmas
@ -80,7 +80,6 @@ def init_args():
parser.add_argument("--alpha", type=float, default=1.0)
parser.add_argument("--beta", type=float, default=1.0)
parser.add_argument("--fourier_degree", type=int, default=5)
parser.add_argument("--det_fce_box_type", type=str, default='poly')

# params for text recognizer
parser.add_argument("--rec_algorithm", type=str, default='SVTR_LCNet')
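The net effect of the flag changes above is that the per-algorithm polygon switches (`det_sast_polygon`, `det_pse_box_type`, `det_fce_box_type`) collapse into the single `det_box_type` argument; the following is a self-contained sketch of the resulting control flow, simplified from `predict_det.py`:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--det_algorithm", type=str, default="DB")
parser.add_argument("--det_box_type", type=str, default="quad")  # 'quad' or 'poly'
args = parser.parse_args(["--det_algorithm", "SAST", "--det_box_type", "poly"])

# Every detector's post-process now reads the same unified flag.
postprocess_params = {"box_type": args.det_box_type}
if args.det_algorithm == "SAST" and args.det_box_type == "poly":
    postprocess_params["sample_pts_num"] = 6  # polygon branch, as in the diff
print(postprocess_params)
```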
@ -220,7 +220,7 @@ def train(config,
use_srn = config['Architecture']['algorithm'] == "SRN"
extra_input_models = [
"SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "VisionLAN",
"RobustScanner", "RFL"
"RobustScanner", "RFL", 'DRRG'
]
extra_input = False
if config['Architecture']['algorithm'] == 'Distillation':
@ -629,7 +629,7 @@ def preprocess(is_train=False):
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN',
'Gestalt', 'SLANet', 'RobustScanner', 'CT', 'RFL'
'Gestalt', 'SLANet', 'RobustScanner', 'CT', 'RFL', 'DRRG'
]

if use_xpu: