add vqa_ser to ppstructure predict pipeline
parent 585dbc3016
commit e16ae81e15
@ -153,7 +153,7 @@ wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_in
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
cd ..

-python3 predict_system.py --det_model_dir=inference/ch_ppocr_mobile_v2.0_det_infer --rec_model_dir=inference/ch_ppocr_mobile_v2.0_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=ch --output=../output/table --vis_font_path=../doc/fonts/simfang.ttf
+python3 predict_system.py --det_model_dir=inference/ch_ppocr_mobile_v2.0_det_infer --rec_model_dir=inference/ch_ppocr_mobile_v2.0_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --output=../output/table --vis_font_path=../doc/fonts/simfang.ttf
```
After running, each image has a directory with the same name under the directory specified by the `output` field. Each table in the image is stored as an Excel file, figure regions are cropped and saved, and the Excel and image file names are the coordinates of the table in the image.

@ -8,11 +8,37 @@ PP-Structure is an OCR toolkit for analyzing and processing documents with complex structure

- Supports structured analysis of table regions; the final result is exported as an Excel file
- Supports both the Python whl package and the command line; simple and easy to use
- Supports custom training for the two task types: layout analysis and table structuring
+- Supports document key information extraction: the SER and RE tasks

## 1. Demonstration

+### 1.1 Layout analysis and table recognition

<img src="../doc/table/ppstructure.GIF" width="100%"/>

+### 1.2 VQA
+
+* SER
+
+ | 
+---|---
+
+Boxes of different colors mark different categories. For the XFUN dataset there are 3 categories: `QUESTION`, `ANSWER`, `HEADER`
+
+* Dark purple: HEADER
+* Light purple: QUESTION
+* Army green: ANSWER
+
+The corresponding category and the OCR recognition result are also shown at the top left of each OCR detection box.
+
+* RE
+
+ | 
+---|---
+
+Red boxes mark questions and blue boxes mark answers; a question and its answer are connected by a green line. The corresponding category and the OCR recognition result are also shown at the top left of each OCR detection box.

## 2. Installation
@ -33,10 +59,16 @@ python3 -m pip install paddlepaddle-gpu==2.1.1 -i https://mirror.baidu.com/pypi/
```
For further requirements, follow the instructions in the [installation documentation](https://www.paddlepaddle.org.cn/install/quick).

-- **(2) Install Layout-Parser**
+- **(2) Install dependencies**

```bash
+# Layout-Parser, required for layout analysis
pip3 install -U https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
+
+# PaddleNLP, required for VQA
+git clone https://github.com/PaddlePaddle/PaddleNLP -b develop
+cd PaddleNLP
+pip3 install -e .
```

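Not part of the commit: after the two installs above, a quick import check confirms that both dependencies resolved — a minimal sketch, assuming both packages expose `__version__`:

```python
# sanity check: both new dependencies should import cleanly
import layoutparser as lp   # layout analysis
import paddlenlp            # VQA (SER/RE)

print(lp.__version__, paddlenlp.__version__)
```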
### 2.2 Install PaddleOCR (including PP-OCR and PP-Structure)

@ -44,7 +76,7 @@ pip3 install -U https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-
- **(1) Quick install of the PaddleOCR whl package via pip (inference only)**

```bash
-pip install "paddleocr>=2.2" # version 2.2+ is recommended
+pip3 install "paddleocr>=2.2" # version 2.2+ is recommended
```

- **(2) Clone the full PaddleOCR source code (inference + training)**
@ -63,12 +95,14 @@ git clone https://gitee.com/paddlepaddle/PaddleOCR

### 3.1 Command-line usage (default parameters, minimal)

+* Layout analysis + table recognition
```bash
paddleocr --image_dir=../doc/table/1.png --type=structure
```

### 3.2 Python script usage (custom parameters, flexible)

+* Layout analysis + table recognition
```python
import os
import cv2
@ -98,6 +132,7 @@ im_show.save('result.jpg')
### 3.3 Description of the returned result
PP-Structure returns a list of dicts. For example:

+* Layout analysis + table recognition
```shell
[
  { 'type': 'Text',
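For orientation (an editorial sketch, not part of this commit): the whl flow from sections 3.2–3.3 end to end. It assumes the `PPStructure` and `save_structure_res` names exported by the paddleocr whl, and reads results using the layout shown above, where a text region's `res` is a `(boxes, rec_res)` pair — as in `OCRSystem.__call__` later in this diff:

```python
import os
import cv2
from paddleocr import PPStructure, save_structure_res

table_engine = PPStructure(show_log=True)

img_path = '../doc/table/1.png'
img = cv2.imread(img_path)
result = table_engine(img)

# one folder per image; each table becomes an Excel file named by its bbox
save_structure_res(result, './output/table',
                   os.path.basename(img_path).split('.')[0])

for region in result:
    if region['type'] == 'Table':
        print('table at', region['bbox'])
    else:
        boxes, rec_res = region['res']   # text regions: boxes + (text, score)
        for text, score in rec_res:
            print(text, score)
```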
@ -130,7 +165,7 @@ The fields of the dict are described as follows
After running, each image has a directory with the same name under the directory specified by the `output` field. Each table in the image is stored as an Excel file, figure regions are cropped and saved, and the Excel and image file names are the coordinates of the table in the image.

-## 4. Introduction to the PP-Structure Pipeline
+## 4. Introduction to the PP-Structure layout analysis + table recognition Pipeline



@ -148,6 +183,8 @@ The fields of the dict are described as follows

Run the following commands to perform inference with the prediction engine.

+* Layout analysis + table recognition
+
```python
cd ppstructure

@ -161,9 +198,24 @@ wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_in
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
cd ..

-python3 predict_system.py --det_model_dir=inference/ch_ppocr_mobile_v2.0_det_infer --rec_model_dir=inference/ch_ppocr_mobile_v2.0_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=ch --output=../output/table --vis_font_path=../doc/fonts/simfang.ttf
+python3 predict_system.py --det_model_dir=inference/ch_ppocr_mobile_v2.0_det_infer --rec_model_dir=inference/ch_ppocr_mobile_v2.0_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --output=../output/table --vis_font_path=../doc/fonts/simfang.ttf
```
-After running, each image has a same-named directory under the directory specified by the `output` field; each table in the image is stored as an Excel file, figure regions are cropped and saved, and the Excel and image file names are the coordinates of the table in the image.
+After running, each image has a same-named directory under the `structure` subdirectory of the directory specified by the `output` field; each table in the image is stored as an Excel file, figure regions are cropped and saved, and the Excel and image file names are the coordinates of the table in the image.
+
+* VQA
+
+```python
+cd ppstructure
+
+# download models
+mkdir inference && cd inference
+# download the SER xfun model and untar it
+wget https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar && tar xf PP-Layout_v1.0_ser_pretrained.tar
+cd ..
+
+python3 predict_system.py --model_name_or_path=vqa/PP-Layout_v1.0_ser_pretrained/ --mode=vqa --image_dir=vqa/images/input/zh_val_0.jpg --vis_font_path=../doc/fonts/simfang.ttf
+```
+After running, the visualized result for each image is saved under the `vqa` subdirectory of the directory specified by the `output` field, with the same file name as the input image.
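The same VQA invocation can also be driven from Python through the classes this commit adds — a sketch (not in the commit), assuming it runs from the `ppstructure` directory with the model downloaded as above:

```python
import cv2

from ppstructure.predict_system import OCRSystem
from ppstructure.utility import parse_args

args = parse_args()  # reads sys.argv; defaults mirror the CLI flags
args.mode = 'vqa'
args.model_name_or_path = 'vqa/PP-Layout_v1.0_ser_pretrained/'

system = OCRSystem(args)   # mode == 'vqa' builds a SerPredictor internally
img = cv2.imread('vqa/images/input/zh_val_0.jpg')
ser_results = system(img)  # res_list returned by the SER engine
```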

**Model List**

@ -185,4 +237,11 @@ OCR and table recognition models

|en_ppocr_mobile_v2.0_table_rec|English table-scene text recognition, trained on the PubLayNet dataset|6.9M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar) / [training model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_rec_train.tar) |
|en_ppocr_mobile_v2.0_table_structure|English table-scene table structure prediction, trained on the PubLayNet dataset|18.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [training model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) |
+
+VQA
+
+|Model name|Description|Inference model size|Download|
+| --- | --- | --- | --- |
+|PP-Layout_v1.0_ser_pretrained|SER model based on LayoutXLM, trained on the xfun Chinese dataset|1.4G|[coming soon]() / [training model](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar) |
+|PP-Layout_v1.0_re_pretrained|RE model based on LayoutXLM, trained on the xfun Chinese dataset|1.4G|[coming soon]() / [training model](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_re_pretrained.tar) |

To use other models, download them from [model_list](../doc/doc_ch/models_list.md) or point the three fields `det_model_dir`, `rec_model_dir`, `table_model_dir` at your own trained models.

@ -30,6 +30,7 @@ from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger
from tools.infer.predict_system import TextSystem
from ppstructure.table.predict_table import TableSystem, to_excel
+from ppstructure.vqa.infer_ser_e2e import SerPredictor, draw_ser_results
from ppstructure.utility import parse_args, draw_structure_result

logger = get_logger()
@ -37,53 +38,75 @@ logger = get_logger()


class OCRSystem(object):
    def __init__(self, args):
-        import layoutparser as lp
-        # args.det_limit_type = 'resize_long'
-        args.drop_score = 0
-        if not args.show_log:
-            logger.setLevel(logging.INFO)
-        self.text_system = TextSystem(args)
-        self.table_system = TableSystem(args, self.text_system.text_detector, self.text_system.text_recognizer)
-
-        config_path = None
-        model_path = None
-        if os.path.isdir(args.layout_path_model):
-            model_path = args.layout_path_model
-        else:
-            config_path = args.layout_path_model
-        self.table_layout = lp.PaddleDetectionLayoutModel(config_path=config_path,
-                                                          model_path=model_path,
-                                                          threshold=0.5, enable_mkldnn=args.enable_mkldnn,
-                                                          enforce_cpu=not args.use_gpu, thread_num=args.cpu_threads)
-        self.use_angle_cls = args.use_angle_cls
-        self.drop_score = args.drop_score
+        self.mode = args.mode
+        if self.mode == 'structure':
+            import layoutparser as lp
+            # args.det_limit_type = 'resize_long'
+            args.drop_score = 0
+            if not args.show_log:
+                logger.setLevel(logging.INFO)
+            self.text_system = TextSystem(args)
+            self.table_system = TableSystem(args,
+                                            self.text_system.text_detector,
+                                            self.text_system.text_recognizer)
+
+            config_path = None
+            model_path = None
+            if os.path.isdir(args.layout_path_model):
+                model_path = args.layout_path_model
+            else:
+                config_path = args.layout_path_model
+            self.table_layout = lp.PaddleDetectionLayoutModel(
+                config_path=config_path,
+                model_path=model_path,
+                threshold=0.5,
+                enable_mkldnn=args.enable_mkldnn,
+                enforce_cpu=not args.use_gpu,
+                thread_num=args.cpu_threads)
+            self.use_angle_cls = args.use_angle_cls
+            self.drop_score = args.drop_score
+        elif self.mode == 'vqa':
+            self.vqa_engine = SerPredictor(args)

    def __call__(self, img):
-        ori_im = img.copy()
-        layout_res = self.table_layout.detect(img[..., ::-1])
-        res_list = []
-        for region in layout_res:
-            x1, y1, x2, y2 = region.coordinates
-            x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
-            roi_img = ori_im[y1:y2, x1:x2, :]
-            if region.type == 'Table':
-                res = self.table_system(roi_img)
-            else:
-                filter_boxes, filter_rec_res = self.text_system(roi_img)
-                filter_boxes = [x + [x1, y1] for x in filter_boxes]
-                filter_boxes = [x.reshape(-1).tolist() for x in filter_boxes]
-                # remove style char
-                style_token = ['<strike>', '<strike>', '<sup>', '</sub>', '<b>', '</b>', '<sub>', '</sup>',
-                               '<overline>', '</overline>', '<underline>', '</underline>', '<i>', '</i>']
-                filter_rec_res_tmp = []
-                for rec_res in filter_rec_res:
-                    rec_str, rec_conf = rec_res
-                    for token in style_token:
-                        if token in rec_str:
-                            rec_str = rec_str.replace(token, '')
-                    filter_rec_res_tmp.append((rec_str, rec_conf))
-                res = (filter_boxes, filter_rec_res_tmp)
-            res_list.append({'type': region.type, 'bbox': [x1, y1, x2, y2], 'img': roi_img, 'res': res})
+        if self.mode == 'structure':
+            ori_im = img.copy()
+            layout_res = self.table_layout.detect(img[..., ::-1])
+            res_list = []
+            for region in layout_res:
+                x1, y1, x2, y2 = region.coordinates
+                x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
+                roi_img = ori_im[y1:y2, x1:x2, :]
+                if region.type == 'Table':
+                    res = self.table_system(roi_img)
+                else:
+                    filter_boxes, filter_rec_res = self.text_system(roi_img)
+                    filter_boxes = [x + [x1, y1] for x in filter_boxes]
+                    filter_boxes = [
+                        x.reshape(-1).tolist() for x in filter_boxes
+                    ]
+                    # remove style char
+                    style_token = [
+                        '<strike>', '<strike>', '<sup>', '</sub>', '<b>',
+                        '</b>', '<sub>', '</sup>', '<overline>', '</overline>',
+                        '<underline>', '</underline>', '<i>', '</i>'
+                    ]
+                    filter_rec_res_tmp = []
+                    for rec_res in filter_rec_res:
+                        rec_str, rec_conf = rec_res
+                        for token in style_token:
+                            if token in rec_str:
+                                rec_str = rec_str.replace(token, '')
+                        filter_rec_res_tmp.append((rec_str, rec_conf))
+                    res = (filter_boxes, filter_rec_res_tmp)
+                res_list.append({
+                    'type': region.type,
+                    'bbox': [x1, y1, x2, y2],
+                    'img': roi_img,
+                    'res': res
+                })
+        elif self.mode == 'vqa':
+            res_list, _ = self.vqa_engine(img)
        return res_list

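An aside on the style-token stripping in `__call__` above: it is plain `str.replace` in a loop. A standalone sketch for clarity — note that the list in the diff repeats `'<strike>'` and crosses the `'<sup>'`/`'</sub>'` pairs, while this version uses a balanced set:

```python
STYLE_TOKENS = [
    '<strike>', '</strike>', '<sup>', '</sup>', '<sub>', '</sub>',
    '<b>', '</b>', '<overline>', '</overline>',
    '<underline>', '</underline>', '<i>', '</i>'
]

def strip_style(text: str) -> str:
    # drop cell-style markup, keep only the recognized text
    for token in STYLE_TOKENS:
        text = text.replace(token, '')
    return text

assert strip_style('<b>Gb3 +</b>') == 'Gb3 +'
```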
@ -91,29 +114,35 @@ def save_structure_res(res, save_folder, img_name):
    excel_save_folder = os.path.join(save_folder, img_name)
    os.makedirs(excel_save_folder, exist_ok=True)
    # save res
-    with open(os.path.join(excel_save_folder, 'res.txt'), 'w', encoding='utf8') as f:
+    with open(
+            os.path.join(excel_save_folder, 'res.txt'), 'w',
+            encoding='utf8') as f:
        for region in res:
            if region['type'] == 'Table':
-                excel_path = os.path.join(excel_save_folder, '{}.xlsx'.format(region['bbox']))
+                excel_path = os.path.join(excel_save_folder,
+                                          '{}.xlsx'.format(region['bbox']))
                to_excel(region['res'], excel_path)
            if region['type'] == 'Figure':
                roi_img = region['img']
-                img_path = os.path.join(excel_save_folder, '{}.jpg'.format(region['bbox']))
+                img_path = os.path.join(excel_save_folder,
+                                        '{}.jpg'.format(region['bbox']))
                cv2.imwrite(img_path, roi_img)
            else:
                for box, rec_res in zip(region['res'][0], region['res'][1]):
-                    f.write('{}\t{}\n'.format(np.array(box).reshape(-1).tolist(), rec_res))
+                    f.write('{}\t{}\n'.format(
+                        np.array(box).reshape(-1).tolist(), rec_res))


def main(args):
    image_file_list = get_image_file_list(args.image_dir)
    image_file_list = image_file_list
    image_file_list = image_file_list[args.process_id::args.total_process_num]
-    save_folder = args.output
-    os.makedirs(save_folder, exist_ok=True)
-
    structure_sys = OCRSystem(args)
    img_num = len(image_file_list)
+    save_folder = os.path.join(args.output, structure_sys.mode)
+    os.makedirs(save_folder, exist_ok=True)
+
    for i, image_file in enumerate(image_file_list):
        logger.info("[{}/{}] {}".format(i, img_num, image_file))
        img, flag = check_and_read_gif(image_file)
@ -126,10 +155,16 @@ def main(args):
            continue
        starttime = time.time()
        res = structure_sys(img)
-        save_structure_res(res, save_folder, img_name)
-        draw_img = draw_structure_result(img, res, args.vis_font_path)
-        cv2.imwrite(os.path.join(save_folder, img_name, 'show.jpg'), draw_img)
-        logger.info('result save to {}'.format(os.path.join(save_folder, img_name)))
+        if structure_sys.mode == 'structure':
+            save_structure_res(res, save_folder, img_name)
+            draw_img = draw_structure_result(img, res, args.vis_font_path)
+            img_save_path = os.path.join(save_folder, img_name, 'show.jpg')
+        elif structure_sys.mode == 'vqa':
+            draw_img = draw_ser_results(img, res, args.vis_font_path)
+            img_save_path = os.path.join(save_folder, img_name + '.jpg')
+        cv2.imwrite(img_save_path, draw_img)
+        logger.info('result save to {}'.format(img_save_path))
        elapse = time.time() - starttime
        logger.info("Predict time : {:.3f}s".format(elapse))

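To make the new output layout concrete — an illustrative helper (not in the commit) tracing where `main()` now writes the visualization in each branch:

```python
import os

def vis_save_path(output: str, mode: str, img_name: str) -> str:
    # mirrors main(): save_folder = os.path.join(args.output, mode)
    save_folder = os.path.join(output, mode)
    if mode == 'structure':
        return os.path.join(save_folder, img_name, 'show.jpg')
    return os.path.join(save_folder, img_name + '.jpg')  # 'vqa'

print(vis_save_path('./output', 'structure', '1'))   # ./output/structure/1/show.jpg
print(vis_save_path('./output', 'vqa', 'zh_val_0'))  # ./output/vqa/zh_val_0.jpg
```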
@ -20,9 +20,9 @@ We evaluated the algorithm on the PubTabNet<sup>[1]</sup> eval dataset, and the


|Method|[TEDS(Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src)|
| --- | --- |
| EDD<sup>[2]</sup> | 88.3 |
| Ours | 93.32 |

## 3. How to use

@ -41,7 +41,7 @@ wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_tab
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
cd ..
# run
-python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --output ../output/table
+python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_dict_path=../ppocr/utils/dict/en_dict.txt --det_limit_side_len=736 --det_limit_type=min --output ../output/table
```
Note: The above model is trained on the PubLayNet dataset and supports English scanned documents only. To recognize other scenarios, train a model yourself and replace the three fields `det_model_dir`, `rec_model_dir`, `table_model_dir`.

@ -82,8 +82,8 @@ python3 tools/train.py -c configs/table/table_mv3.yml -o Global.checkpoints=./yo
The table model uses [TEDS(Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src) as its evaluation metric. Before evaluation, the three models in the pipeline need to be exported as inference models (we have provided them), and the ground truth (gt) for evaluation needs to be prepared. An example gt:
```json
{"PMC4289340_004_00.png": [
["<html>", "<body>", "<table>", "<thead>", "<tr>", "<td>", "</td>", "<td>", "</td>", "<td>", "</td>", "</tr>", "</thead>", "<tbody>", "<tr>", "<td>", "</td>", "<td>", "</td>", "<td>", "</td>", "</tr>", "</tbody>", "</table>", "</body>", "</html>"],
[[1, 4, 29, 13], [137, 4, 161, 13], [215, 4, 236, 13], [1, 17, 30, 27], [137, 17, 147, 27], [215, 17, 225, 27]],
[["<b>", "F", "e", "a", "t", "u", "r", "e", "</b>"], ["<b>", "G", "b", "3", " ", "+", "</b>"], ["<b>", "G", "b", "3", " ", "-", "</b>"], ["<b>", "P", "a", "t", "i", "e", "n", "t", "s", "</b>"], ["6", "2"], ["4", "5"]]
]}
```

@ -95,7 +95,7 @@ In gt json, the key is the image name, the value is the corresponding gt, and gt
Use the following command to evaluate. After the evaluation is completed, the TEDS metric will be output.
```python
cd PaddleOCR/ppstructure
-python3 table/eval_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --gt_path=path/to/gt.json
+python3 table/eval_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --gt_path=path/to/gt.json
```

If the PubTabNet eval dataset is used, the output will be
|
@ -113,4 +113,4 @@ After running, the excel sheet of each picture will be saved in the directory sp
|
||||||
|
|
||||||
Reference
|
Reference
|
||||||
1. https://github.com/ibm-aur-nlp/PubTabNet
|
1. https://github.com/ibm-aur-nlp/PubTabNet
|
||||||
2. https://arxiv.org/pdf/1911.10683
|
2. https://arxiv.org/pdf/1911.10683
|
||||||
|
|
|
@ -34,9 +34,9 @@


|Method|[TEDS(Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src)|
| --- | --- |
| EDD<sup>[2]</sup> | 88.3 |
| Ours | 93.32 |

<a name="3"></a>
## 3. Usage
@ -56,7 +56,7 @@ wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_tab
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
cd ..
# run inference
-python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --output ../output/table
+python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_dict_path=../ppocr/utils/dict/en_dict.txt --det_limit_side_len=736 --det_limit_type=min --output ../output/table
```
After running, the Excel table for each image is saved under the directory specified by the `output` field

@ -94,8 +94,8 @@ python3 tools/train.py -c configs/table/table_mv3.yml -o Global.checkpoints=./yo
Tables use [TEDS(Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src) as the model evaluation metric. Before evaluating, the three models in the pipeline each need to be exported as inference models (we already provide them), and the gt for evaluation needs to be prepared. An example gt:
```json
{"PMC4289340_004_00.png": [
["<html>", "<body>", "<table>", "<thead>", "<tr>", "<td>", "</td>", "<td>", "</td>", "<td>", "</td>", "</tr>", "</thead>", "<tbody>", "<tr>", "<td>", "</td>", "<td>", "</td>", "<td>", "</td>", "</tr>", "</tbody>", "</table>", "</body>", "</html>"],
[[1, 4, 29, 13], [137, 4, 161, 13], [215, 4, 236, 13], [1, 17, 30, 27], [137, 17, 147, 27], [215, 17, 225, 27]],
[["<b>", "F", "e", "a", "t", "u", "r", "e", "</b>"], ["<b>", "G", "b", "3", " ", "+", "</b>"], ["<b>", "G", "b", "3", " ", "-", "</b>"], ["<b>", "P", "a", "t", "i", "e", "n", "t", "s", "</b>"], ["6", "2"], ["4", "5"]]
]}
```

@ -107,7 +107,7 @@ In the json, the key is the image name and the value is the corresponding gt, a list of three item
Once prepared, run the following command to evaluate; the TEDS metric is printed when evaluation finishes.
```python
cd PaddleOCR/ppstructure
-python3 table/eval_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --gt_path=path/to/gt.json
+python3 table/eval_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --gt_path=path/to/gt.json
```
If the PubTabNet eval dataset is used, the output will be
```bash
@ -123,4 +123,4 @@ python3 table/predict_table.py --det_model_dir=path/to/det_model_dir --rec_model

Reference
1. https://github.com/ibm-aur-nlp/PubTabNet
2. https://arxiv.org/pdf/1911.10683

@ -21,13 +21,31 @@ def init_args():
    parser = infer_args()

    # params for output
-    parser.add_argument("--output", type=str, default='./output/table')
+    parser.add_argument("--output", type=str, default='./output')
    # params for table structure
    parser.add_argument("--table_max_len", type=int, default=488)
    parser.add_argument("--table_model_dir", type=str)
    parser.add_argument("--table_char_type", type=str, default='en')
-    parser.add_argument("--table_char_dict_path", type=str, default="../ppocr/utils/dict/table_structure_dict.txt")
-    parser.add_argument("--layout_path_model", type=str, default="lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config")
+    parser.add_argument(
+        "--table_char_dict_path",
+        type=str,
+        default="../ppocr/utils/dict/table_structure_dict.txt")
+    parser.add_argument(
+        "--layout_path_model",
+        type=str,
+        default="lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config")
+
+    # params for ser
+    parser.add_argument("--model_name_or_path", type=str)
+    parser.add_argument("--max_seq_length", type=int, default=512)
+    parser.add_argument(
+        "--label_map_path", type=str, default='./vqa/labels/labels_ser.txt')
+
+    parser.add_argument(
+        "--mode",
+        type=str,
+        default='structure',
+        help='structure and vqa is supported')
    return parser

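A quick way to exercise the new flags (an editorial sketch; the argv values are hypothetical):

```python
import sys

from ppstructure.utility import parse_args

# simulate a command line using the flags added above
sys.argv = [
    'predict_system.py', '--mode', 'vqa',
    '--model_name_or_path', 'vqa/PP-Layout_v1.0_ser_pretrained/'
]
args = parse_args()
print(args.mode)            # vqa
print(args.max_seq_length)  # 512
print(args.output)          # ./output  (the new default)
```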
@ -48,5 +66,6 @@ def draw_structure_result(image, result, font_path):
        boxes.append(np.array(box).reshape(-1, 2))
        txts.append(rec_res[0])
        scores.append(rec_res[1])
-    im_show = draw_ocr_box_txt(image, boxes, txts, scores, font_path=font_path,drop_score=0)
+    im_show = draw_ocr_box_txt(
+        image, boxes, txts, scores, font_path=font_path, drop_score=0)
    return im_show

@ -23,12 +23,10 @@ from PIL import Image
import paddle
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification

-from paddleocr import PaddleOCR
-
# relative reference
-from utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps
+from .utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps

-from utils import pad_sentences, split_page, preprocess, postprocess, merge_preds_list_with_ocr_info
+from .utils import pad_sentences, split_page, preprocess, postprocess, merge_preds_list_with_ocr_info


def trans_poly_to_bbox(poly):
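Why the added dot matters (an aside, not in the commit): `from .utils import ...` resolves against the `ppstructure.vqa` package, which is exactly what the new `from ppstructure.vqa.infer_ser_e2e import ...` in predict_system.py needs; the trade-off is that the file can no longer be executed directly as a script:

```python
# works: the module is imported as part of its package
from ppstructure.vqa.infer_ser_e2e import SerPredictor, draw_ser_results

# breaks: running the file directly, e.g.
#   python3 ppstructure/vqa/infer_ser_e2e.py
# raises "ImportError: attempted relative import with no known parent package"
```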
@ -52,6 +50,7 @@ def parse_ocr_info_for_ser(ocr_result):

class SerPredictor(object):
    def __init__(self, args):

        self.max_seq_length = args.max_seq_length

        # init ser token and model
@ -62,9 +61,11 @@ class SerPredictor(object):
        self.model.eval()

        # init ocr_engine
+        from paddleocr import PaddleOCR
+
        self.ocr_engine = PaddleOCR(
-            rec_model_dir=args.ocr_rec_model_dir,
-            det_model_dir=args.ocr_det_model_dir,
+            rec_model_dir=args.rec_model_dir,
+            det_model_dir=args.det_model_dir,
            use_angle_cls=False,
            show_log=False)
        # init dict
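Moving `from paddleocr import PaddleOCR` into `__init__` is the usual deferred-import pattern: `infer_ser_e2e` can now be imported by predict_system.py without loading `paddleocr` at module-import time; the cost is paid only when a `SerPredictor` is built. A generic sketch of the idea:

```python
class LazyOCR:
    def __init__(self):
        # deferred import: resolved on first construction,
        # not when this module is imported
        from paddleocr import PaddleOCR
        self.engine = PaddleOCR(use_angle_cls=False, show_log=False)
```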
@ -380,8 +380,8 @@ def parse_args():
    parser.add_argument("--seed", type=int, default=2048,
                        help="random seed for initialization",)

-    parser.add_argument("--ocr_rec_model_dir", default=None, type=str, )
-    parser.add_argument("--ocr_det_model_dir", default=None, type=str, )
+    parser.add_argument("--rec_model_dir", default=None, type=str, )
+    parser.add_argument("--det_model_dir", default=None, type=str, )
    parser.add_argument(
        "--label_map_path", default="./labels/labels_ser.txt", type=str, required=False, )
    parser.add_argument("--infer_imgs", default=None, type=str, required=False)