add re
|
@ -1,42 +1,62 @@
|
|||
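# Visual Question Answering (VQA)
# Document Visual Question Answering (DOC-VQA)

The main features of VQA are as follows:
DOC-VQA is a type of VQA task that poses questions about the text content of document images.

The DOC-VQA algorithms in PP-Structure are built on the PaddleNLP natural language processing library.

The main features are as follows:

- Integrates the [LayoutXLM](https://arxiv.org/pdf/2104.08836.pdf) model and the PP-OCR inference engine.
- Supports multimodal Semantic Entity Recognition (SER) and Relation Extraction (RE). The SER task recognizes and classifies the text in an image; the RE task extracts relations between text items in an image (for example, matching question-answer pairs)
- Supports end-to-end prediction and evaluation of the SER task combined with the OCR engine.
- Supports multimodal Semantic Entity Recognition (SER) and Relation Extraction (RE). The SER task recognizes and classifies the text in an image; the RE task extracts relations between text items in an image, such as matching question-answer pairs
- Supports custom training for the SER and RE tasks
- Supports end-to-end prediction and evaluation of OCR+SER.
- Supports end-to-end prediction of OCR+SER+RE.

This project is an open-source implementation of [LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding](https://arxiv.org/pdf/2104.08836.pdf) on Paddle 2.2,
and includes fine-tuning code for the [XFUND dataset](https://github.com/doc-analysis/XFUND).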
|
||||
## 1. Demo
## 1. Performance

We evaluated the algorithms on the [XFUND](https://github.com/doc-analysis/XFUND) evaluation set. The results are as follows.

| Task | Hmean | Model download |
|:---:|:---:|:---:|
| SER | 0.9056 | [link](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar) |
| RE | 0.7113 | [link](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_re_pretrained.tar) |

## 2. Demo

**Note:** The test images come from the XFUND dataset.
|
||||
|
||||
### 1.1 SER
|
||||
### 2.1 SER
|
||||
|
||||
<div align="center">
|
||||
<img src="./images/result_ser/zh_val_0_ser.jpg" width = "600" />
|
||||
</div>
|
||||
 | 
|
||||
---|---
|
||||
|
||||
<div align="center">
|
||||
<img src="./images/result_ser/zh_val_42_ser.jpg" width = "600" />
|
||||
</div>
|
||||
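In the figure, boxes of different colors denote different categories. The XFUND dataset has three categories: `QUESTION`, `ANSWER`, and `HEADER`.

Boxes of different colors denote different categories. The XFUND dataset has three categories: `QUESTION`, `ANSWER`, and `HEADER`. The predicted category and the OCR recognition result are also shown at the top-left corner of each OCR detection box.
* Dark purple: HEADER
* Light purple: QUESTION
* Army green: ANSWER

The predicted category and the OCR recognition result are also shown at the top-left corner of each OCR detection box.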
|
||||
|
||||
|
||||
### 1.2 RE
|
||||
### 2.2 RE
|
||||
|
||||
* Coming soon!
|
||||
 | 
|
||||
---|---
|
||||
|
||||
|
||||
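In the figure, red boxes denote questions, blue boxes denote answers, and each question is connected to its answer by a green line. The predicted category and the OCR recognition result are also shown at the top-left corner of each OCR detection box.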
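
As a rough illustration of the linking described above, the connecting line can be drawn between the centers of a question box and its answer box. The sketch below only computes those two endpoints. `bbox_center` and `link_endpoints` are hypothetical helper names, and boxes are assumed to be `[x_min, y_min, x_max, y_max]` integer lists.

```python
def bbox_center(bbox):
    """Return the integer center (x, y) of an [x_min, y_min, x_max, y_max] box."""
    x_min, y_min, x_max, y_max = bbox
    return ((x_min + x_max) // 2, (y_min + y_max) // 2)


def link_endpoints(question, answer):
    """Return the endpoints of the line linking a question box to an answer box."""
    return bbox_center(question["bbox"]), bbox_center(answer["bbox"])


# Hypothetical sample entities, for illustration only.
question = {"text": "Name", "bbox": [10, 20, 110, 60]}
answer = {"text": "Alice", "bbox": [200, 20, 300, 60]}
start, end = link_endpoints(question, answer)
```

The two tuples can then be passed to a line-drawing call such as `cv2.line`, which expects integer pixel coordinates.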
|
||||
|
||||
## 2. Installation

### 2.1 Install dependencies
## 3. Installation

### 3.1 Install dependencies

- **(1) Install PaddlePaddle**
|
||||
|
||||
|
@ -53,12 +73,12 @@ python3 -m pip install paddlepaddle==2.2 -i https://mirror.baidu.com/pypi/simple
|
|||
For other requirements, follow the instructions in the [installation guide](https://www.paddlepaddle.org.cn/install/quick).
|
||||
|
||||
|
||||
### 2.2 Install PaddleOCR (includes PP-OCR and VQA)
### 3.2 Install PaddleOCR (includes PP-OCR and VQA)

- **(1) Quick install of the PaddleOCR whl package via pip (inference only)**
|
||||
|
||||
```bash
|
||||
pip install "paddleocr>=2.2" # version 2.2+ is recommended
|
||||
pip install paddleocr
|
||||
```
|
||||
|
||||
- **(2) Download the VQA source code (inference + training)**
|
||||
|
@ -85,13 +105,14 @@ pip install -e .
|
|||
- **(4) Install the VQA `requirements`**
|
||||
|
||||
```bash
|
||||
cd ppstructure/vqa
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## 3. Usage
## 4. Usage

### 3.1 Data and pretrained model preparation
### 4.1 Data and pretrained model preparation

The preprocessed XFUND Chinese dataset can be downloaded from [https://paddleocr.bj.bcebos.com/dataset/XFUND.tar](https://paddleocr.bj.bcebos.com/dataset/XFUND.tar).
|
||||
|
||||
|
@ -104,18 +125,15 @@ wget https://paddleocr.bj.bcebos.com/dataset/XFUND.tar
|
|||
|
||||
To convert the XFUND data for other languages, refer to the [XFUND data conversion script](helper/trans_xfun_data.py).

To try inference directly, download the SER pretrained model we provide and skip training.

* SER pretrained model download: [link](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar)
* RE pretrained model download: coming soon!
To try inference directly, download the pretrained models we provide and skip training.

### 3.2 SER task
### 4.2 SER task

* Start training
|
||||
|
||||
```shell
|
||||
python train_ser.py \
|
||||
python3.7 train_ser.py \
|
||||
--model_name_or_path "layoutxlm-base-uncased" \
|
||||
--train_data_dir "XFUND/zh_train/image" \
|
||||
--train_label_path "XFUND/zh_train/xfun_normalize_train.json" \
|
||||
|
@ -131,13 +149,7 @@ python train_ser.py \
|
|||
--seed 2048
|
||||
```
|
||||
|
||||
Metrics such as `precision`, `recall`, and `f1` are printed at the end, as shown below.

```
best metrics: {'loss': 1.066644651549203, 'precision': 0.8770182068017863, 'recall': 0.9361936193619362, 'f1': 0.9056402979780063}
```

The model and training logs are saved in the `./output/ser/` folder.
Metrics such as `precision`, `recall`, and `f1` are printed at the end, and the model and training logs are saved in the `./output/ser/` folder.
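
For readers checking the reported numbers, `f1` (listed as Hmean in the performance table) is the harmonic mean of `precision` and `recall`. The following is a minimal sketch of that formula, not the training script's own code; `hmean` is a hypothetical helper name, and the input values are the sample metrics quoted in this README.

```python
def hmean(precision, recall):
    """Harmonic mean (F1) of precision and recall; 0.0 when both are zero."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# Sample precision and recall taken from the "best metrics" line in this README.
f1 = hmean(0.8770182068017863, 0.9361936193619362)
print(round(f1, 4))
```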
|
||||
|
||||
* Run prediction using the OCR results provided with the evaluation set
|
||||
|
||||
|
@ -159,21 +171,73 @@ export CUDA_VISIBLE_DEVICES=0
|
|||
python3.7 infer_ser_e2e.py \
|
||||
--model_name_or_path "./output/PP-Layout_v1.0_ser_pretrained/" \
|
||||
--max_seq_length 512 \
|
||||
--output_dir "output_res_e2e/"
|
||||
--output_dir "output_res_e2e/" \
|
||||
--infer_imgs "images/input/zh_val_0.jpg"
|
||||
```
|
||||
|
||||
* Run end-to-end evaluation of the `OCR engine + SER` prediction system
|
||||
|
||||
```shell
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
python helper/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt
|
||||
python3.7 helper/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt
|
||||
```
|
||||
|
||||
|
||||
3.3 RE task
|
||||
### 4.3 RE task
|
||||
|
||||
coming soon!
|
||||
* Start training
|
||||
|
||||
```shell
|
||||
python3 train_re.py \
|
||||
--model_name_or_path "layoutxlm-base-uncased" \
|
||||
--train_data_dir "XFUND/zh_train/image" \
|
||||
--train_label_path "XFUND/zh_train/xfun_normalize_train.json" \
|
||||
--eval_data_dir "XFUND/zh_val/image" \
|
||||
--eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
|
||||
--label_map_path 'labels/labels_ser.txt' \
|
||||
--num_train_epochs 2 \
|
||||
--eval_steps 10 \
|
||||
--save_steps 500 \
|
||||
--output_dir "output/re/" \
|
||||
--learning_rate 5e-5 \
|
||||
--warmup_steps 50 \
|
||||
--per_gpu_train_batch_size 8 \
|
||||
--per_gpu_eval_batch_size 8 \
|
||||
--evaluate_during_training \
|
||||
--seed 2048
|
||||
|
||||
```
|
||||
|
||||
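Metrics such as `precision`, `recall`, and `f1` are printed at the end, and the model and training logs are saved in the `./output/re/` folder.

* Run prediction using the OCR results provided with the evaluation set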
|
||||
|
||||
```shell
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
python3 infer_re.py \
|
||||
--model_name_or_path "./PP-Layout_v1.0_re_pretrained/" \
|
||||
--max_seq_length 512 \
|
||||
--eval_data_dir "XFUND/zh_val/image" \
|
||||
--eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
|
||||
--label_map_path 'labels/labels_ser.txt' \
|
||||
--output_dir "output_res" \
|
||||
--per_gpu_eval_batch_size 1 \
|
||||
--seed 2048
|
||||
```
|
||||
|
||||
The visualized prediction images and a prediction result text file named `infer_results.txt` are saved in the `output_res` directory.
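
A result file of this kind can be read back line by line. The sketch below assumes the tab-separated layout that the end-to-end scripts in this change write (image path, a tab, then a JSON object); `parse_result_line` is a hypothetical helper, and the sample line and its field values are made up for illustration.

```python
import json


def parse_result_line(line):
    """Split one 'image_path<TAB>json' line into (image_path, parsed dict)."""
    image_path, payload = line.rstrip("\n").split("\t", 1)
    return image_path, json.loads(payload)


# Hypothetical sample line, for illustration only.
sample = 'images/input/zh_val_0.jpg\t{"ocr_info": [{"text": "Name", "pred": "QUESTION"}]}\n'
path, info = parse_result_line(sample)
```

Splitting with `maxsplit=1` keeps the JSON payload intact even if a recognized text string happens to contain a tab character.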
|
||||
|
||||
* Run the chained `OCR engine + SER + RE` pipeline
|
||||
|
||||
```shell
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
python3.7 infer_ser_re_e2e.py \
|
||||
--model_name_or_path "./PP-Layout_v1.0_ser_pretrained/" \
|
||||
--re_model_name_or_path "./PP-Layout_v1.0_re_pretrained/" \
|
||||
--max_seq_length 512 \
|
||||
--output_dir "output_ser_re_e2e_train/" \
|
||||
--infer_imgs "images/input/zh_val_21.jpg"
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import paddle
import numbers
import numpy as np


class DataCollator:
    def __call__(self, batch):
        data_dict = {}
        to_tensor_keys = []
        for sample in batch:
            for k, v in sample.items():
                if k not in data_dict:
                    data_dict[k] = []
                if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
                    if k not in to_tensor_keys:
                        to_tensor_keys.append(k)
                data_dict[k].append(v)
        for k in to_tensor_keys:
            data_dict[k] = paddle.to_tensor(data_dict[k])
        return data_dict


class DataCollatorNoBatch:
    def __call__(self, batch):
        return batch[0]
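
To make the collation step easier to follow, here is a dependency-free sketch of the same grouping idea. It is an illustration under assumptions, not the class above: `collate` is a hypothetical function, it only treats plain Python numbers as tensor candidates, and it returns the list of numeric keys instead of converting them with `paddle.to_tensor`.

```python
import numbers


def collate(batch):
    """Group a list of sample dicts into one dict of lists.

    Also returns the keys whose values are numeric, which are the
    ones a tensor-based collator would convert afterwards.
    """
    data = {}
    numeric_keys = []
    for sample in batch:
        for key, value in sample.items():
            data.setdefault(key, []).append(value)
            if isinstance(value, numbers.Number) and key not in numeric_keys:
                numeric_keys.append(key)
    return data, numeric_keys


# Hypothetical samples, for illustration only.
batch = [{"id": 1, "text": "a"}, {"id": 2, "text": "b"}]
data, numeric_keys = collate(batch)
```

Non-numeric fields such as strings stay as plain Python lists, which is why the collator can carry entity and relation metadata alongside the tensors.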
|
After Width: | Height: | Size: 1.4 MiB |
After Width: | Height: | Size: 1.1 MiB |
After Width: | Height: | Size: 1.1 MiB |
After Width: | Height: | Size: 1005 KiB |
Before Width: | Height: | Size: 1.2 MiB After Width: | Height: | Size: 1.2 MiB |
Before Width: | Height: | Size: 1.6 MiB After Width: | Height: | Size: 1.6 MiB |
|
@ -0,0 +1,162 @@
|
|||
import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))

import random

import cv2
import matplotlib.pyplot as plt
import numpy as np
import paddle

from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction

from xfun import XFUNDataset
from utils import parse_args, get_bio_label_maps, draw_re_results
from data_collator import DataCollator

from ppocr.utils.logging import get_logger
|
||||
|
||||
|
||||
def infer(args):
    os.makedirs(args.output_dir, exist_ok=True)
    logger = get_logger()
    label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
    pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index

    tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)

    model = LayoutXLMForRelationExtraction.from_pretrained(
        args.model_name_or_path)

    eval_dataset = XFUNDataset(
        tokenizer,
        data_dir=args.eval_data_dir,
        label_path=args.eval_label_path,
        label2id_map=label2id_map,
        img_size=(224, 224),
        max_seq_len=args.max_seq_length,
        pad_token_label_id=pad_token_label_id,
        contains_re=True,
        add_special_ids=False,
        return_attention_mask=True,
        load_mode='all')

    eval_dataloader = paddle.io.DataLoader(
        eval_dataset,
        batch_size=args.per_gpu_eval_batch_size,
        num_workers=8,
        shuffle=False,
        collate_fn=DataCollator())

    # Load the ground-truth OCR data
    ocr_info_list = load_ocr(args.eval_data_dir, args.eval_label_path)

    for idx, batch in enumerate(eval_dataloader):
        logger.info("[Infer] process: {}/{}".format(idx, len(eval_dataloader)))
        with paddle.no_grad():
            outputs = model(**batch)
        pred_relations = outputs['pred_relations']

        ocr_info = ocr_info_list[idx]
        image_path = ocr_info['image_path']
        ocr_info = ocr_info['ocr_info']

        # Decode the tokens of each entity and drop the ocr_info entries that are not needed
        ocr_info = filter_bg_by_txt(ocr_info, batch, tokenizer)

        # Convert the predicted relations into OCR info pairs
        result = []
        used_tail_id = []
        for relations in pred_relations:
            for relation in relations:
                if relation['tail_id'] in used_tail_id:
                    continue
                if relation['head_id'] not in ocr_info or relation[
                        'tail_id'] not in ocr_info:
                    continue
                used_tail_id.append(relation['tail_id'])
                ocr_info_head = ocr_info[relation['head_id']]
                ocr_info_tail = ocr_info[relation['tail_id']]
                result.append((ocr_info_head, ocr_info_tail))

        img = cv2.imread(image_path)
        img_show = draw_re_results(img, result)
        save_path = os.path.join(args.output_dir, os.path.basename(image_path))
        cv2.imwrite(save_path, img_show)
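
The relation-to-pair conversion in the loop above can be isolated into a small pure-Python function. This is a sketch under assumptions rather than the script's code: `pair_relations` is a hypothetical name, `ocr_info` is assumed to be a dict keyed by entity index (as returned by `filter_bg_by_txt`), and a set replaces the `used_tail_id` list for membership checks.

```python
def pair_relations(relations, ocr_info):
    """Keep the first relation per tail_id whose head and tail both exist in ocr_info."""
    pairs = []
    used_tail_ids = set()
    for relation in relations:
        head_id, tail_id = relation["head_id"], relation["tail_id"]
        if tail_id in used_tail_ids:
            continue  # each answer is linked to at most one question
        if head_id not in ocr_info or tail_id not in ocr_info:
            continue  # entity was filtered out earlier
        used_tail_ids.add(tail_id)
        pairs.append((ocr_info[head_id], ocr_info[tail_id]))
    return pairs


# Hypothetical inputs, for illustration only.
ocr_info = {0: {"text": "Name"}, 1: {"text": "Alice"}, 2: {"text": "Age"}}
relations = [
    {"head_id": 0, "tail_id": 1},
    {"head_id": 2, "tail_id": 1},  # tail 1 already used, skipped
    {"head_id": 2, "tail_id": 5},  # tail 5 missing from ocr_info, skipped
]
pairs = pair_relations(relations, ocr_info)
```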
|
||||
|
||||
|
||||
def load_ocr(img_folder, json_path):
    import json
    d = []
    with open(json_path, "r") as fin:
        lines = fin.readlines()
        for line in lines:
            image_name, info_str = line.split("\t")
            info_dict = json.loads(info_str)
            info_dict['image_path'] = os.path.join(img_folder, image_name)
            d.append(info_dict)
    return d
|
||||
|
||||
|
||||
def filter_bg_by_txt(ocr_info, batch, tokenizer):
    entities = batch['entities'][0]
    input_ids = batch['input_ids'][0]

    new_info_dict = {}
    for i in range(len(entities['start'])):
        entitie_head = entities['start'][i]
        entitie_tail = entities['end'][i]
        word_input_ids = input_ids[entitie_head:entitie_tail].numpy().tolist()
        txt = tokenizer.convert_ids_to_tokens(word_input_ids)
        txt = tokenizer.convert_tokens_to_string(txt)

        for i, info in enumerate(ocr_info):
            if info['text'] == txt:
                new_info_dict[i] = info
    return new_info_dict
|
||||
|
||||
|
||||
def post_process(pred_relations, ocr_info, img):
    result = []
    for relations in pred_relations:
        for relation in relations:
            ocr_info_head = ocr_info[relation['head_id']]
            ocr_info_tail = ocr_info[relation['tail_id']]
            result.append((ocr_info_head, ocr_info_tail))
    return result
|
||||
|
||||
|
||||
def draw_re(result, image_path, output_folder):
    img = cv2.imread(image_path)

    from matplotlib import pyplot as plt
    for ocr_info_head, ocr_info_tail in result:
        cv2.rectangle(
            img,
            tuple(ocr_info_head['bbox'][:2]),
            tuple(ocr_info_head['bbox'][2:]), (255, 0, 0),
            thickness=2)
        cv2.rectangle(
            img,
            tuple(ocr_info_tail['bbox'][:2]),
            tuple(ocr_info_tail['bbox'][2:]), (0, 0, 255),
            thickness=2)
        center_p1 = [(ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2,
                     (ocr_info_head['bbox'][1] + ocr_info_head['bbox'][3]) // 2]
        center_p2 = [(ocr_info_tail['bbox'][0] + ocr_info_tail['bbox'][2]) // 2,
                     (ocr_info_tail['bbox'][1] + ocr_info_tail['bbox'][3]) // 2]
        cv2.line(
            img, tuple(center_p1), tuple(center_p2), (0, 255, 0), thickness=2)
    plt.imshow(img)
    plt.savefig(
        os.path.join(output_folder, os.path.basename(image_path)), dpi=600)
    # plt.show()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    args = parse_args()
    infer(args)
|
|
@ -23,8 +23,10 @@ from PIL import Image
|
|||
import paddle
|
||||
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
|
||||
|
||||
from paddleocr import PaddleOCR
|
||||
|
||||
# relative reference
|
||||
from utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps, build_ocr_engine
|
||||
from utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps
|
||||
|
||||
from utils import pad_sentences, split_page, preprocess, postprocess, merge_preds_list_with_ocr_info
|
||||
|
||||
|
@ -48,74 +50,82 @@ def parse_ocr_info_for_ser(ocr_result):
|
|||
return ocr_info
|
||||
|
||||
|
||||
@paddle.no_grad()
|
||||
def infer(args):
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
class SerPredictor(object):
|
||||
def __init__(self, args):
|
||||
self.max_seq_length = args.max_seq_length
|
||||
|
||||
# init token and model
|
||||
tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
|
||||
model = LayoutXLMForTokenClassification.from_pretrained(
|
||||
args.model_name_or_path)
|
||||
model.eval()
|
||||
# init ser token and model
|
||||
self.tokenizer = LayoutXLMTokenizer.from_pretrained(
|
||||
args.model_name_or_path)
|
||||
self.model = LayoutXLMForTokenClassification.from_pretrained(
|
||||
args.model_name_or_path)
|
||||
self.model.eval()
|
||||
|
||||
label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
|
||||
label2id_map_for_draw = dict()
|
||||
for key in label2id_map:
|
||||
if key.startswith("I-"):
|
||||
label2id_map_for_draw[key] = label2id_map["B" + key[1:]]
|
||||
else:
|
||||
label2id_map_for_draw[key] = label2id_map[key]
|
||||
# init ocr_engine
|
||||
self.ocr_engine = PaddleOCR(
|
||||
rec_model_dir=args.ocr_rec_model_dir,
|
||||
det_model_dir=args.ocr_det_model_dir,
|
||||
use_angle_cls=False,
|
||||
show_log=False)
|
||||
# init dict
|
||||
label2id_map, self.id2label_map = get_bio_label_maps(
|
||||
args.label_map_path)
|
||||
self.label2id_map_for_draw = dict()
|
||||
for key in label2id_map:
|
||||
if key.startswith("I-"):
|
||||
self.label2id_map_for_draw[key] = label2id_map["B" + key[1:]]
|
||||
else:
|
||||
self.label2id_map_for_draw[key] = label2id_map[key]
|
||||
|
||||
# get infer img list
|
||||
infer_imgs = get_image_file_list(args.infer_imgs)
|
||||
def __call__(self, img):
|
||||
ocr_result = self.ocr_engine.ocr(img, cls=False)
|
||||
|
||||
ocr_engine = build_ocr_engine(args.ocr_rec_model_dir,
|
||||
args.ocr_det_model_dir)
|
||||
ocr_info = parse_ocr_info_for_ser(ocr_result)
|
||||
|
||||
# loop for infer
|
||||
with open(os.path.join(args.output_dir, "infer_results.txt"), "w") as fout:
|
||||
for idx, img_path in enumerate(infer_imgs):
|
||||
print("process: [{}/{}]".format(idx, len(infer_imgs), img_path))
|
||||
inputs = preprocess(
|
||||
tokenizer=self.tokenizer,
|
||||
ori_img=img,
|
||||
ocr_info=ocr_info,
|
||||
max_seq_len=self.max_seq_length)
|
||||
|
||||
img = cv2.imread(img_path)
|
||||
outputs = self.model(
|
||||
input_ids=inputs["input_ids"],
|
||||
bbox=inputs["bbox"],
|
||||
image=inputs["image"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
attention_mask=inputs["attention_mask"])
|
||||
|
||||
ocr_result = ocr_engine.ocr(img_path, cls=False)
|
||||
|
||||
ocr_info = parse_ocr_info_for_ser(ocr_result)
|
||||
|
||||
inputs = preprocess(
|
||||
tokenizer=tokenizer,
|
||||
ori_img=img,
|
||||
ocr_info=ocr_info,
|
||||
max_seq_len=args.max_seq_length)
|
||||
|
||||
outputs = model(
|
||||
input_ids=inputs["input_ids"],
|
||||
bbox=inputs["bbox"],
|
||||
image=inputs["image"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
attention_mask=inputs["attention_mask"])
|
||||
|
||||
preds = outputs[0]
|
||||
preds = postprocess(inputs["attention_mask"], preds, id2label_map)
|
||||
ocr_info = merge_preds_list_with_ocr_info(
|
||||
ocr_info, inputs["segment_offset_id"], preds,
|
||||
label2id_map_for_draw)
|
||||
|
||||
fout.write(img_path + "\t" + json.dumps(
|
||||
{
|
||||
"ocr_info": ocr_info,
|
||||
}, ensure_ascii=False) + "\n")
|
||||
|
||||
img_res = draw_ser_results(img, ocr_info)
|
||||
cv2.imwrite(
|
||||
os.path.join(args.output_dir,
|
||||
os.path.splitext(os.path.basename(img_path))[0] +
|
||||
"_ser.jpg"), img_res)
|
||||
|
||||
return
|
||||
preds = outputs[0]
|
||||
preds = postprocess(inputs["attention_mask"], preds, self.id2label_map)
|
||||
ocr_info = merge_preds_list_with_ocr_info(
|
||||
ocr_info, inputs["segment_offset_id"], preds,
|
||||
self.label2id_map_for_draw)
|
||||
return ocr_info, inputs
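
The `label2id_map_for_draw` construction in `SerPredictor.__init__` folds every `I-` label onto the id of its `B-` counterpart, so all tokens of one entity type share a drawing color. A standalone sketch of that mapping follows; `build_draw_label_map` is a hypothetical helper, and the label ids in the example are invented for illustration.

```python
def build_draw_label_map(label2id):
    """Map every 'I-' label to the id of its matching 'B-' label."""
    draw_map = {}
    for label, label_id in label2id.items():
        if label.startswith("I-"):
            draw_map[label] = label2id["B" + label[1:]]
        else:
            draw_map[label] = label_id
    return draw_map


# Hypothetical BIO label ids, for illustration only.
label2id = {"O": 0, "B-QUESTION": 1, "I-QUESTION": 2, "B-ANSWER": 3, "I-ANSWER": 4}
draw_map = build_draw_label_map(label2id)
```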
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
infer(args)
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
# get infer img list
|
||||
infer_imgs = get_image_file_list(args.infer_imgs)
|
||||
|
||||
# loop for infer
|
||||
ser_engine = SerPredictor(args)
|
||||
with open(os.path.join(args.output_dir, "infer_results.txt"), "w") as fout:
|
||||
for idx, img_path in enumerate(infer_imgs):
|
||||
print("process: [{}/{}], {}".format(idx, len(infer_imgs), img_path))
|
||||
|
||||
img = cv2.imread(img_path)
|
||||
|
||||
result, _ = ser_engine(img)
|
||||
fout.write(img_path + "\t" + json.dumps(
|
||||
{
|
||||
"ser_resule": result,
|
||||
}, ensure_ascii=False) + "\n")
|
||||
|
||||
img_res = draw_ser_results(img, result)
|
||||
cv2.imwrite(
|
||||
os.path.join(args.output_dir,
|
||||
os.path.splitext(os.path.basename(img_path))[0] +
|
||||
"_ser.jpg"), img_res)
|
||||
|
|
|
@ -0,0 +1,131 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
import sys
import json
import cv2
import numpy as np
from copy import deepcopy
from PIL import Image

import paddle
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForRelationExtraction

# relative reference
from utils import parse_args, get_image_file_list, draw_re_results
from infer_ser_e2e import SerPredictor
|
||||
|
||||
|
||||
def make_input(ser_input, ser_result, max_seq_len=512):
    entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2}

    entities = ser_input['entities'][0]
    assert len(entities) == len(ser_result)

    # entities
    start = []
    end = []
    label = []
    entity_idx_dict = {}
    for i, (res, entity) in enumerate(zip(ser_result, entities)):
        if res['pred'] == 'O':
            continue
        entity_idx_dict[len(start)] = i
        start.append(entity['start'])
        end.append(entity['end'])
        label.append(entities_labels[res['pred']])
    entities = dict(start=start, end=end, label=label)

    # relations
    head = []
    tail = []
    for i in range(len(entities["label"])):
        for j in range(len(entities["label"])):
            if entities["label"][i] == 1 and entities["label"][j] == 2:
                head.append(i)
                tail.append(j)

    relations = dict(head=head, tail=tail)

    batch_size = ser_input["input_ids"].shape[0]
    entities_batch = []
    relations_batch = []
    for b in range(batch_size):
        entities_batch.append(entities)
        relations_batch.append(relations)

    ser_input['entities'] = entities_batch
    ser_input['relations'] = relations_batch

    ser_input.pop('segment_offset_id')
    return ser_input, entity_idx_dict
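
The `# relations` step in `make_input` enumerates every (QUESTION, ANSWER) combination as a candidate relation for the RE model to score. The sketch below reproduces only that enumeration; `candidate_pairs` is a hypothetical name, and the label ids follow the `entities_labels` dict above (HEADER 0, QUESTION 1, ANSWER 2).

```python
QUESTION, ANSWER = 1, 2  # label ids as defined by entities_labels in make_input


def candidate_pairs(labels):
    """Return all (question index, answer index) candidates as head/tail lists."""
    heads, tails = [], []
    for i, head_label in enumerate(labels):
        for j, tail_label in enumerate(labels):
            if head_label == QUESTION and tail_label == ANSWER:
                heads.append(i)
                tails.append(j)
    return {"head": heads, "tail": tails}


# One header, two questions and two answers (hypothetical example).
relations = candidate_pairs([0, 1, 2, 1, 2])
```

The number of candidates grows with the product of the question and answer counts, which is worth keeping in mind for dense forms.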
|
||||
|
||||
|
||||
class SerReSystem(object):
    def __init__(self, args):
        self.ser_engine = SerPredictor(args)
        self.tokenizer = LayoutXLMTokenizer.from_pretrained(
            args.re_model_name_or_path)
        self.model = LayoutXLMForRelationExtraction.from_pretrained(
            args.re_model_name_or_path)
        self.model.eval()

    def __call__(self, img):
        ser_result, ser_inputs = self.ser_engine(img)
        re_input, entity_idx_dict = make_input(ser_inputs, ser_result)

        re_result = self.model(**re_input)

        pred_relations = re_result['pred_relations'][0]
        # Convert the predicted relations into OCR info pairs
        result = []
        used_tail_id = []
        for relation in pred_relations:
            if relation['tail_id'] in used_tail_id:
                continue
            used_tail_id.append(relation['tail_id'])
            ocr_info_head = ser_result[entity_idx_dict[relation['head_id']]]
            ocr_info_tail = ser_result[entity_idx_dict[relation['tail_id']]]
            result.append((ocr_info_head, ocr_info_tail))

        return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    # get infer img list
    infer_imgs = get_image_file_list(args.infer_imgs)

    # loop for infer
    ser_re_engine = SerReSystem(args)
    with open(os.path.join(args.output_dir, "infer_results.txt"), "w") as fout:
        for idx, img_path in enumerate(infer_imgs):
            print("process: [{}/{}], {}".format(idx, len(infer_imgs), img_path))

            img = cv2.imread(img_path)

            result = ser_re_engine(img)
            fout.write(img_path + "\t" + json.dumps(
                {
                    "result": result,
                }, ensure_ascii=False) + "\n")

            img_res = draw_re_results(img, result)
            cv2.imwrite(
                os.path.join(args.output_dir,
                             os.path.splitext(os.path.basename(img_path))[0] +
                             "_re.jpg"), img_res)
|
|
@ -0,0 +1,175 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
import re

import numpy as np

import logging

logger = logging.getLogger(__name__)

PREFIX_CHECKPOINT_DIR = "checkpoint"
_re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d+)$")
|
||||
|
||||
|
||||
def get_last_checkpoint(folder):
    content = os.listdir(folder)
    checkpoints = [
        path for path in content
        if _re_checkpoint.search(path) is not None and os.path.isdir(
            os.path.join(folder, path))
    ]
    if len(checkpoints) == 0:
        return
    return os.path.join(
        folder,
        max(checkpoints,
            key=lambda x: int(_re_checkpoint.search(x).groups()[0])))
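
The selection rule in `get_last_checkpoint` compares the numeric suffix rather than the directory name as a string, so `checkpoint-1000` beats `checkpoint-90`. The sketch below shows that rule on a plain list of names without touching the filesystem; `latest_checkpoint_name` and `CHECKPOINT_PATTERN` are hypothetical names introduced for illustration.

```python
import re

CHECKPOINT_PATTERN = re.compile(r"^checkpoint-(\d+)$")


def latest_checkpoint_name(names):
    """Return the 'checkpoint-N' name with the largest N, or None if none match."""
    best_step, best_name = -1, None
    for name in names:
        match = CHECKPOINT_PATTERN.match(name)
        if match and int(match.group(1)) > best_step:
            best_step, best_name = int(match.group(1)), name
    return best_name


# Hypothetical directory listing, for illustration only.
names = ["checkpoint-500", "checkpoint-1000", "checkpoint-90", "logs", "checkpoint-final"]
latest = latest_checkpoint_name(names)
```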
|
||||
|
||||
|
||||
def re_score(pred_relations, gt_relations, mode="strict"):
    """Evaluate RE predictions

    Args:
        pred_relations (list) : list of list of predicted relations (several relations in each sentence)
        gt_relations (list) : list of list of ground truth relations

            rel = { "head": (start_idx (inclusive), end_idx (exclusive)),
                    "tail": (start_idx (inclusive), end_idx (exclusive)),
                    "head_type": ent_type,
                    "tail_type": ent_type,
                    "type": rel_type}

        vocab (Vocab) : dataset vocabulary
        mode (str) : in 'strict' or 'boundaries'"""

    assert mode in ["strict", "boundaries"]

    relation_types = [v for v in [0, 1] if not v == 0]
    scores = {
        rel: {
            "tp": 0,
            "fp": 0,
            "fn": 0
        }
        for rel in relation_types + ["ALL"]
    }

    # Count GT relations and Predicted relations
    n_sents = len(gt_relations)
    n_rels = sum([len([rel for rel in sent]) for sent in gt_relations])
    n_found = sum([len([rel for rel in sent]) for sent in pred_relations])

    # Count TP, FP and FN per type
    for pred_sent, gt_sent in zip(pred_relations, gt_relations):
        for rel_type in relation_types:
            # strict mode takes argument types into account
            if mode == "strict":
                pred_rels = {(rel["head"], rel["head_type"], rel["tail"],
                              rel["tail_type"])
                             for rel in pred_sent if rel["type"] == rel_type}
                gt_rels = {(rel["head"], rel["head_type"], rel["tail"],
                            rel["tail_type"])
                           for rel in gt_sent if rel["type"] == rel_type}

            # boundaries mode only takes argument spans into account
            elif mode == "boundaries":
                pred_rels = {(rel["head"], rel["tail"])
                             for rel in pred_sent if rel["type"] == rel_type}
                gt_rels = {(rel["head"], rel["tail"])
                           for rel in gt_sent if rel["type"] == rel_type}

            scores[rel_type]["tp"] += len(pred_rels & gt_rels)
            scores[rel_type]["fp"] += len(pred_rels - gt_rels)
            scores[rel_type]["fn"] += len(gt_rels - pred_rels)

    # Compute per entity Precision / Recall / F1
    for rel_type in scores.keys():
        if scores[rel_type]["tp"]:
            scores[rel_type]["p"] = scores[rel_type]["tp"] / (
                scores[rel_type]["fp"] + scores[rel_type]["tp"])
            scores[rel_type]["r"] = scores[rel_type]["tp"] / (
                scores[rel_type]["fn"] + scores[rel_type]["tp"])
        else:
            scores[rel_type]["p"], scores[rel_type]["r"] = 0, 0

        if not scores[rel_type]["p"] + scores[rel_type]["r"] == 0:
            scores[rel_type]["f1"] = (
                2 * scores[rel_type]["p"] * scores[rel_type]["r"] /
                (scores[rel_type]["p"] + scores[rel_type]["r"]))
        else:
            scores[rel_type]["f1"] = 0

    # Compute micro F1 Scores
    tp = sum([scores[rel_type]["tp"] for rel_type in relation_types])
    fp = sum([scores[rel_type]["fp"] for rel_type in relation_types])
    fn = sum([scores[rel_type]["fn"] for rel_type in relation_types])

    if tp:
        precision = tp / (tp + fp)
        recall = tp / (tp + fn)
        f1 = 2 * precision * recall / (precision + recall)

    else:
        precision, recall, f1 = 0, 0, 0

    scores["ALL"]["p"] = precision
    scores["ALL"]["r"] = recall
    scores["ALL"]["f1"] = f1
    scores["ALL"]["tp"] = tp
    scores["ALL"]["fp"] = fp
    scores["ALL"]["fn"] = fn

    # Compute Macro F1 Scores
    scores["ALL"]["Macro_f1"] = np.mean(
        [scores[ent_type]["f1"] for ent_type in relation_types])
    scores["ALL"]["Macro_p"] = np.mean(
        [scores[ent_type]["p"] for ent_type in relation_types])
    scores["ALL"]["Macro_r"] = np.mean(
        [scores[ent_type]["r"] for ent_type in relation_types])
|
||||
|
||||
# logger.info(f"RE Evaluation in *** {mode.upper()} *** mode")
|
||||
|
||||
# logger.info(
|
||||
# "processed {} sentences with {} relations; found: {} relations; correct: {}.".format(
|
||||
# n_sents, n_rels, n_found, tp
|
||||
# )
|
||||
# )
|
||||
# logger.info(
|
||||
# "\tALL\t TP: {};\tFP: {};\tFN: {}".format(scores["ALL"]["tp"], scores["ALL"]["fp"], scores["ALL"]["fn"])
|
||||
# )
|
||||
# logger.info("\t\t(m avg): precision: {:.2f};\trecall: {:.2f};\tf1: {:.2f} (micro)".format(precision, recall, f1))
|
||||
# logger.info(
|
||||
# "\t\t(M avg): precision: {:.2f};\trecall: {:.2f};\tf1: {:.2f} (Macro)\n".format(
|
||||
# scores["ALL"]["Macro_p"], scores["ALL"]["Macro_r"], scores["ALL"]["Macro_f1"]
|
||||
# )
|
||||
# )
|
||||
|
||||
# for rel_type in relation_types:
|
||||
# logger.info(
|
||||
# "\t{}: \tTP: {};\tFP: {};\tFN: {};\tprecision: {:.2f};\trecall: {:.2f};\tf1: {:.2f};\t{}".format(
|
||||
# rel_type,
|
||||
# scores[rel_type]["tp"],
|
||||
# scores[rel_type]["fp"],
|
||||
# scores[rel_type]["fn"],
|
||||
# scores[rel_type]["p"],
|
||||
# scores[rel_type]["r"],
|
||||
# scores[rel_type]["f1"],
|
||||
# scores[rel_type]["tp"] + scores[rel_type]["fp"],
|
||||
# )
|
||||
# )
|
||||
|
||||
    return scores
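
The counting at the heart of `re_score` in "boundaries" mode reduces to set operations on (head span, tail span) pairs. The following is a simplified single-document sketch with one relation type, not a replacement for the function above; `micro_prf` and the sample spans are hypothetical.

```python
def micro_prf(pred_pairs, gt_pairs):
    """Micro precision, recall and F1 from predicted and ground-truth relation pairs."""
    pred, gt = set(pred_pairs), set(gt_pairs)
    tp = len(pred & gt)   # predicted and correct
    fp = len(pred - gt)   # predicted but not in the ground truth
    fn = len(gt - pred)   # in the ground truth but missed
    if tp == 0:
        return {"tp": tp, "fp": fp, "fn": fn, "p": 0.0, "r": 0.0, "f1": 0.0}
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    return {"tp": tp, "fp": fp, "fn": fn, "p": precision, "r": recall, "f1": f1}


# Each relation is ((head_start, head_end), (tail_start, tail_end)); values are made up.
gt = [((0, 2), (3, 5)), ((6, 8), (9, 11))]
pred = [((0, 2), (3, 5)), ((6, 8), (12, 14)), ((15, 17), (18, 20))]
scores = micro_prf(pred, gt)
```

Here one of three predictions is correct and one of two ground-truth relations is found, so precision is 1/3 and recall is 1/2.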
|
|
@ -0,0 +1,261 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))

import random
import numpy as np
import paddle

from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction

from xfun import XFUNDataset
from utils import parse_args, get_bio_label_maps, print_arguments
from data_collator import DataCollator
from metric import re_score

from ppocr.utils.logging import get_logger
|
||||
|
||||
|
||||
def set_seed(seed):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
paddle.seed(seed)
|
||||
|
||||
|
||||
def cal_metric(re_preds, re_labels, entities):
|
||||
gt_relations = []
|
||||
for b in range(len(re_labels)):
|
||||
rel_sent = []
|
||||
for head, tail in zip(re_labels[b]["head"], re_labels[b]["tail"]):
|
||||
rel = {}
|
||||
rel["head_id"] = head
|
||||
rel["head"] = (entities[b]["start"][rel["head_id"]],
|
||||
entities[b]["end"][rel["head_id"]])
|
||||
rel["head_type"] = entities[b]["label"][rel["head_id"]]
|
||||
|
||||
rel["tail_id"] = tail
|
||||
rel["tail"] = (entities[b]["start"][rel["tail_id"]],
|
||||
entities[b]["end"][rel["tail_id"]])
|
||||
rel["tail_type"] = entities[b]["label"][rel["tail_id"]]
|
||||
|
||||
rel["type"] = 1
|
||||
rel_sent.append(rel)
|
||||
gt_relations.append(rel_sent)
|
||||
re_metrics = re_score(re_preds, gt_relations, mode="boundaries")
|
||||
return re_metrics
|
||||
|
||||
|
||||
def evaluate(model, eval_dataloader, logger, prefix=""):
|
||||
# Eval!
|
||||
logger.info("***** Running evaluation {} *****".format(prefix))
|
||||
logger.info(" Num examples = {}".format(len(eval_dataloader.dataset)))
|
||||
|
||||
re_preds = []
|
||||
re_labels = []
|
||||
entities = []
|
||||
eval_loss = 0.0
|
||||
model.eval()
|
||||
for idx, batch in enumerate(eval_dataloader):
|
||||
with paddle.no_grad():
|
||||
outputs = model(**batch)
|
||||
loss = outputs['loss'].mean().item()
|
||||
if paddle.distributed.get_rank() == 0:
|
||||
logger.info("[Eval] process: {}/{}, loss: {:.5f}".format(
|
||||
idx, len(eval_dataloader), loss))
|
||||
|
||||
eval_loss += loss
|
||||
re_preds.extend(outputs['pred_relations'])
|
||||
re_labels.extend(batch['relations'])
|
||||
entities.extend(batch['entities'])
|
||||
re_metrics = cal_metric(re_preds, re_labels, entities)
|
||||
re_metrics = {
|
||||
"precision": re_metrics["ALL"]["p"],
|
||||
"recall": re_metrics["ALL"]["r"],
|
||||
"f1": re_metrics["ALL"]["f1"],
|
||||
}
|
||||
model.train()
|
||||
return re_metrics
|
||||
|
||||
|
||||
def train(args):
|
||||
logger = get_logger(log_file=os.path.join(args.output_dir, "train.log"))
|
||||
print_arguments(args, logger)
|
||||
|
||||
# Added here for reproducibility (even between python 2 and 3)
|
||||
set_seed(args.seed)
|
||||
|
||||
label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
|
||||
pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
|
||||
|
||||
# dist mode
|
||||
if paddle.distributed.get_world_size() > 1:
|
||||
paddle.distributed.init_parallel_env()
|
||||
|
||||
tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
|
||||
|
||||
model = LayoutXLMModel.from_pretrained(args.model_name_or_path)
|
||||
model = LayoutXLMForRelationExtraction(model, dropout=None)
|
||||
|
||||
# dist mode
|
||||
if paddle.distributed.get_world_size() > 1:
|
||||
model = paddle.distributed.DataParallel(model)
|
||||
|
||||
train_dataset = XFUNDataset(
|
||||
tokenizer,
|
||||
data_dir=args.train_data_dir,
|
||||
label_path=args.train_label_path,
|
||||
label2id_map=label2id_map,
|
||||
img_size=(224, 224),
|
||||
max_seq_len=args.max_seq_length,
|
||||
pad_token_label_id=pad_token_label_id,
|
||||
contains_re=True,
|
||||
add_special_ids=False,
|
||||
return_attention_mask=True,
|
||||
load_mode='all')
|
||||
|
||||
eval_dataset = XFUNDataset(
|
||||
tokenizer,
|
||||
data_dir=args.eval_data_dir,
|
||||
label_path=args.eval_label_path,
|
||||
label2id_map=label2id_map,
|
||||
img_size=(224, 224),
|
||||
max_seq_len=args.max_seq_length,
|
||||
pad_token_label_id=pad_token_label_id,
|
||||
contains_re=True,
|
||||
add_special_ids=False,
|
||||
return_attention_mask=True,
|
||||
load_mode='all')
|
||||
|
||||
train_sampler = paddle.io.DistributedBatchSampler(
|
||||
train_dataset, batch_size=args.per_gpu_train_batch_size, shuffle=True)
|
||||
args.train_batch_size = args.per_gpu_train_batch_size * \
|
||||
max(1, paddle.distributed.get_world_size())
|
||||
train_dataloader = paddle.io.DataLoader(
|
||||
train_dataset,
|
||||
batch_sampler=train_sampler,
|
||||
num_workers=8,
|
||||
use_shared_memory=True,
|
||||
collate_fn=DataCollator())
|
||||
|
||||
eval_dataloader = paddle.io.DataLoader(
|
||||
eval_dataset,
|
||||
batch_size=args.per_gpu_eval_batch_size,
|
||||
num_workers=8,
|
||||
shuffle=False,
|
||||
collate_fn=DataCollator())
|
||||
|
||||
t_total = len(train_dataloader) * args.num_train_epochs
|
||||
|
||||
# build a linear-decay learning-rate schedule with optional warmup
|
||||
lr_scheduler = paddle.optimizer.lr.PolynomialDecay(
|
||||
learning_rate=args.learning_rate,
|
||||
decay_steps=t_total,
|
||||
end_lr=0.0,
|
||||
power=1.0)
|
||||
if args.warmup_steps > 0:
|
||||
lr_scheduler = paddle.optimizer.lr.LinearWarmup(
|
||||
lr_scheduler,
|
||||
args.warmup_steps,
|
||||
start_lr=0,
|
||||
end_lr=args.learning_rate, )
|
||||
grad_clip = paddle.nn.ClipGradByNorm(clip_norm=10)
|
||||
optimizer = paddle.optimizer.Adam(
|
||||
learning_rate=args.learning_rate,
|
||||
parameters=model.parameters(),
|
||||
epsilon=args.adam_epsilon,
|
||||
grad_clip=grad_clip,
|
||||
weight_decay=args.weight_decay)
|
||||
|
||||
# Train!
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num examples = {}".format(len(train_dataset)))
|
||||
logger.info(" Num Epochs = {}".format(args.num_train_epochs))
|
||||
logger.info(" Instantaneous batch size per GPU = {}".format(
|
||||
args.per_gpu_train_batch_size))
|
||||
logger.info(
|
||||
" Total train batch size (w. parallel, distributed & accumulation) = {}".
|
||||
format(args.train_batch_size))  # already multiplied by the world size above
|
||||
logger.info(" Total optimization steps = {}".format(t_total))
|
||||
|
||||
global_step = 0
|
||||
model.clear_gradients()
|
||||
train_dataloader_len = len(train_dataloader)
|
||||
best_metirc = {'f1': 0}
|
||||
model.train()
|
||||
|
||||
for epoch in range(int(args.num_train_epochs)):
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
outputs = model(**batch)
|
||||
# the RE model returns a dict; the loss is stored under the 'loss' key
|
||||
loss = outputs['loss']
|
||||
loss = loss.mean()
|
||||
|
||||
logger.info(
|
||||
"epoch: [{}/{}], iter: [{}/{}], global_step:{}, train loss: {}, lr: {}".
|
||||
format(epoch, args.num_train_epochs, step, train_dataloader_len,
|
||||
global_step, np.mean(loss.numpy()), optimizer.get_lr()))
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.clear_grad()
|
||||
# lr_scheduler.step()  # NOTE: disabled; the optimizer is built with the constant args.learning_rate, so lr_scheduler is currently unused
|
||||
|
||||
global_step += 1
|
||||
|
||||
if (paddle.distributed.get_rank() == 0 and args.eval_steps > 0 and
|
||||
global_step % args.eval_steps == 0):
|
||||
# Log metrics
|
||||
if (paddle.distributed.get_rank() == 0 and args.
|
||||
evaluate_during_training): # Only evaluate when single GPU otherwise metrics may not average well
|
||||
results = evaluate(model, eval_dataloader, logger)
|
||||
if results['f1'] > best_metirc['f1']:
|
||||
best_metirc = results
|
||||
output_dir = os.path.join(args.output_dir,
|
||||
"checkpoint-best")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
model.save_pretrained(output_dir)
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
paddle.save(args,
|
||||
os.path.join(output_dir,
|
||||
"training_args.bin"))
|
||||
logger.info("Saving model checkpoint to {}".format(
|
||||
output_dir))
|
||||
logger.info("eval results: {}".format(results))
|
||||
logger.info("best_metric: {}".format(best_metirc))
|
||||
|
||||
if (paddle.distributed.get_rank() == 0 and args.save_steps > 0 and
|
||||
global_step % args.save_steps == 0):
|
||||
# Save model checkpoint
|
||||
output_dir = os.path.join(args.output_dir, "checkpoint-latest")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
if paddle.distributed.get_rank() == 0:
|
||||
model.save_pretrained(output_dir)
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
paddle.save(args,
|
||||
os.path.join(output_dir, "training_args.bin"))
|
||||
logger.info("Saving model checkpoint to {}".format(
|
||||
output_dir))
|
||||
logger.info("best_metric: {}".format(best_metirc))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
train(args)
|
|
@ -12,8 +12,13 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import os
|
||||
import sys
|
||||
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
|
||||
|
||||
import random
|
||||
import copy
|
||||
import logging
|
||||
|
@ -26,8 +31,9 @@ from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLM
|
|||
from xfun import XFUNDataset
|
||||
from utils import parse_args
|
||||
from utils import get_bio_label_maps
|
||||
from utils import print_arguments
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from ppocr.utils.logging import get_logger
|
||||
|
||||
|
||||
def set_seed(args):
|
||||
|
@ -38,17 +44,8 @@ def set_seed(args):
|
|||
|
||||
def train(args):
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
logging.basicConfig(
|
||||
filename=os.path.join(args.output_dir, "train.log")
|
||||
if paddle.distributed.get_rank() == 0 else None,
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO
|
||||
if paddle.distributed.get_rank() == 0 else logging.WARN, )
|
||||
|
||||
ch = logging.StreamHandler()
|
||||
ch.setLevel(logging.DEBUG)
|
||||
logger.addHandler(ch)
|
||||
logger = get_logger(log_file=os.path.join(args.output_dir, "train.log"))
|
||||
print_arguments(args, logger)
|
||||
|
||||
label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
|
||||
pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
|
||||
|
@ -136,10 +133,10 @@ def train(args):
|
|||
loss = outputs[0]
|
||||
loss = loss.mean()
|
||||
logger.info(
|
||||
"[epoch {}/{}][iter: {}/{}] lr: {:.5f}, train loss: {:.5f}, ".
|
||||
"epoch: [{}/{}], iter: [{}/{}], global_step:{}, train loss: {}, lr: {}".
|
||||
format(epoch_id, args.num_train_epochs, step,
|
||||
len(train_dataloader),
|
||||
lr_scheduler.get_lr(), loss.numpy()[0]))
|
||||
len(train_dataloader), global_step,
|
||||
loss.numpy()[0], lr_scheduler.get_lr()))
|
||||
|
||||
loss.backward()
|
||||
tr_loss += loss.item()
|
||||
|
@ -154,13 +151,9 @@ def train(args):
|
|||
# Only evaluate when single GPU otherwise metrics may not average well
|
||||
if paddle.distributed.get_rank(
|
||||
) == 0 and args.evaluate_during_training:
|
||||
results, _ = evaluate(
|
||||
args,
|
||||
model,
|
||||
tokenizer,
|
||||
label2id_map,
|
||||
id2label_map,
|
||||
pad_token_label_id, )
|
||||
results, _ = evaluate(args, model, tokenizer, label2id_map,
|
||||
id2label_map, pad_token_label_id,
|
||||
logger)
|
||||
|
||||
if best_metrics is None or results["f1"] >= best_metrics[
|
||||
"f1"]:
|
||||
|
@ -204,6 +197,7 @@ def evaluate(args,
|
|||
label2id_map,
|
||||
id2label_map,
|
||||
pad_token_label_id,
|
||||
logger,
|
||||
prefix=""):
|
||||
eval_dataset = XFUNDataset(
|
||||
tokenizer,
|
||||
|
@ -299,15 +293,6 @@ def evaluate(args,
|
|||
return results, preds_list
|
||||
|
||||
|
||||
def print_arguments(args):
|
||||
"""print arguments"""
|
||||
print('----------- Configuration Arguments -----------')
|
||||
for arg, value in sorted(vars(args).items()):
|
||||
print('%s: %s' % (arg, value))
|
||||
print('------------------------------------------------')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
print_arguments(args)
|
||||
train(args)
|
||||
|
|
|
@ -24,8 +24,6 @@ import paddle
|
|||
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
from paddleocr import PaddleOCR
|
||||
|
||||
|
||||
def get_bio_label_maps(label_map_path):
|
||||
with open(label_map_path, "r") as fin:
|
||||
|
@ -66,9 +64,9 @@ def get_image_file_list(img_file):
|
|||
|
||||
def draw_ser_results(image,
|
||||
ocr_results,
|
||||
font_path="../doc/fonts/simfang.ttf",
|
||||
font_path="../../doc/fonts/simfang.ttf",
|
||||
font_size=18):
|
||||
np.random.seed(0)
|
||||
np.random.seed(2021)
|
||||
color = (np.random.permutation(range(255)),
|
||||
np.random.permutation(range(255)),
|
||||
np.random.permutation(range(255)))
|
||||
|
@ -82,38 +80,64 @@ def draw_ser_results(image,
|
|||
draw = ImageDraw.Draw(img_new)
|
||||
|
||||
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
|
||||
|
||||
for ocr_info in ocr_results:
|
||||
if ocr_info["pred_id"] not in color_map:
|
||||
continue
|
||||
color = color_map[ocr_info["pred_id"]]
|
||||
|
||||
# draw ocr results outline
|
||||
bbox = ocr_info["bbox"]
|
||||
bbox = ((bbox[0], bbox[1]), (bbox[2], bbox[3]))
|
||||
draw.rectangle(bbox, fill=color)
|
||||
|
||||
# draw ocr results
|
||||
text = "{}: {}".format(ocr_info["pred"], ocr_info["text"])
|
||||
start_y = max(0, bbox[0][1] - font_size)
|
||||
tw = font.getsize(text)[0]
|
||||
draw.rectangle(
|
||||
[(bbox[0][0] + 1, start_y), (bbox[0][0] + tw + 1,
|
||||
start_y + font_size)],
|
||||
fill=(0, 0, 255))
|
||||
draw.text(
|
||||
(bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font)
|
||||
|
||||
draw_box_txt(ocr_info["bbox"], text, draw, font, font_size, color)
|
||||
|
||||
img_new = Image.blend(image, img_new, 0.5)
|
||||
return np.array(img_new)
|
||||
|
||||
|
||||
def build_ocr_engine(rec_model_dir, det_model_dir):
|
||||
ocr_engine = PaddleOCR(
|
||||
rec_model_dir=rec_model_dir,
|
||||
det_model_dir=det_model_dir,
|
||||
use_angle_cls=False)
|
||||
return ocr_engine
|
||||
def draw_box_txt(bbox, text, draw, font, font_size, color):
|
||||
# draw ocr results outline
|
||||
bbox = ((bbox[0], bbox[1]), (bbox[2], bbox[3]))
|
||||
draw.rectangle(bbox, fill=color)
|
||||
|
||||
# draw ocr results
|
||||
start_y = max(0, bbox[0][1] - font_size)
|
||||
tw = font.getsize(text)[0]
|
||||
draw.rectangle(
|
||||
[(bbox[0][0] + 1, start_y), (bbox[0][0] + tw + 1, start_y + font_size)],
|
||||
fill=(0, 0, 255))
|
||||
draw.text((bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font)
|
||||
|
||||
|
||||
def draw_re_results(image,
|
||||
result,
|
||||
font_path="../../doc/fonts/simfang.ttf",
|
||||
font_size=18):
|
||||
np.random.seed(0)
|
||||
if isinstance(image, np.ndarray):
|
||||
image = Image.fromarray(image)
|
||||
img_new = image.copy()
|
||||
draw = ImageDraw.Draw(img_new)
|
||||
|
||||
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
|
||||
color_head = (0, 0, 255)
|
||||
color_tail = (255, 0, 0)
|
||||
color_line = (0, 255, 0)
|
||||
|
||||
for ocr_info_head, ocr_info_tail in result:
|
||||
draw_box_txt(ocr_info_head["bbox"], ocr_info_head["text"], draw, font,
|
||||
font_size, color_head)
|
||||
draw_box_txt(ocr_info_tail["bbox"], ocr_info_tail["text"], draw, font,
|
||||
font_size, color_tail)
|
||||
|
||||
center_head = (
|
||||
(ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2,
|
||||
(ocr_info_head['bbox'][1] + ocr_info_head['bbox'][3]) // 2)
|
||||
center_tail = (
|
||||
(ocr_info_tail['bbox'][0] + ocr_info_tail['bbox'][2]) // 2,
|
||||
(ocr_info_tail['bbox'][1] + ocr_info_tail['bbox'][3]) // 2)
|
||||
|
||||
draw.line([center_head, center_tail], fill=color_line, width=5)
|
||||
|
||||
img_new = Image.blend(image, img_new, 0.5)
|
||||
return np.array(img_new)
|
||||
|
||||
|
||||
# pad sentences
|
||||
|
@ -130,7 +154,7 @@ def pad_sentences(tokenizer,
|
|||
len(encoded_inputs["input_ids"]) // max_seq_len + 1) * max_seq_len
|
||||
|
||||
needs_to_be_padded = pad_to_max_seq_len and \
|
||||
max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len
|
||||
max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len
|
||||
|
||||
if needs_to_be_padded:
|
||||
difference = max_seq_len - len(encoded_inputs["input_ids"])
|
||||
|
@ -162,6 +186,9 @@ def split_page(encoded_inputs, max_seq_len=512):
|
|||
truncate is often used in training process
|
||||
"""
|
||||
for key in encoded_inputs:
|
||||
if key == 'entities':
|
||||
encoded_inputs[key] = [encoded_inputs[key]]
|
||||
continue
|
||||
encoded_inputs[key] = paddle.to_tensor(encoded_inputs[key])
|
||||
if encoded_inputs[key].ndim <= 1: # for input_ids, att_mask and so on
|
||||
encoded_inputs[key] = encoded_inputs[key].reshape([-1, max_seq_len])
|
||||
|
@ -192,6 +219,7 @@ def preprocess(
|
|||
bbox_list = []
|
||||
input_ids_list = []
|
||||
token_type_ids_list = []
|
||||
entities = []
|
||||
|
||||
for info in ocr_info:
|
||||
# x1, y1, x2, y2
|
||||
|
@ -211,6 +239,13 @@ def preprocess(
|
|||
encode_res["token_type_ids"] = encode_res["token_type_ids"][1:-1]
|
||||
encode_res["attention_mask"] = encode_res["attention_mask"][1:-1]
|
||||
|
||||
# for RE: record each OCR segment's token span (start/end offsets into input_ids)
|
||||
entities.append({
|
||||
"start": len(input_ids_list),
|
||||
"end": len(input_ids_list) + len(encode_res["input_ids"]),
|
||||
"label": "O",
|
||||
})
|
||||
|
||||
input_ids_list.extend(encode_res["input_ids"])
|
||||
token_type_ids_list.extend(encode_res["token_type_ids"])
|
||||
bbox_list.extend([bbox] * len(encode_res["input_ids"]))
|
||||
|
@ -222,6 +257,7 @@ def preprocess(
|
|||
"token_type_ids": token_type_ids_list,
|
||||
"bbox": bbox_list,
|
||||
"attention_mask": [1] * len(input_ids_list),
|
||||
"entities": entities
|
||||
}
|
||||
|
||||
encoded_inputs = pad_sentences(
|
||||
|
@ -294,35 +330,64 @@ def merge_preds_list_with_ocr_info(ocr_info, segment_offset_id, preds_list,
|
|||
return ocr_info
|
||||
|
||||
|
||||
def print_arguments(args, logger=None):
|
||||
"""print arguments"""
|
||||
print_func = logger.info if logger is not None else print
|
||||
print_func('----------- Configuration Arguments -----------')
|
||||
for arg, value in sorted(vars(args).items()):
|
||||
print_func('%s: %s' % (arg, value))
|
||||
print_func('------------------------------------------------')
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
# yapf: disable
|
||||
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,)
|
||||
parser.add_argument("--train_data_dir", default=None, type=str, required=False,)
|
||||
parser.add_argument("--train_label_path", default=None, type=str, required=False,)
|
||||
parser.add_argument("--eval_data_dir", default=None, type=str, required=False,)
|
||||
parser.add_argument("--eval_label_path", default=None, type=str, required=False,)
|
||||
parser.add_argument("--model_name_or_path",
|
||||
default=None, type=str, required=True,)
|
||||
parser.add_argument("--re_model_name_or_path",
|
||||
default=None, type=str, required=False,)
|
||||
parser.add_argument("--train_data_dir", default=None,
|
||||
type=str, required=False,)
|
||||
parser.add_argument("--train_label_path", default=None,
|
||||
type=str, required=False,)
|
||||
parser.add_argument("--eval_data_dir", default=None,
|
||||
type=str, required=False,)
|
||||
parser.add_argument("--eval_label_path", default=None,
|
||||
type=str, required=False,)
|
||||
parser.add_argument("--output_dir", default=None, type=str, required=True,)
|
||||
parser.add_argument("--max_seq_length", default=512, type=int,)
|
||||
parser.add_argument("--evaluate_during_training", action="store_true",)
|
||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.",)
|
||||
parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for eval.",)
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.",)
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.",)
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.",)
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.",)
|
||||
parser.add_argument("--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform.",)
|
||||
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.",)
|
||||
parser.add_argument("--eval_steps", type=int, default=10, help="eval every X updates steps.",)
|
||||
parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.",)
|
||||
parser.add_argument("--seed", type=int, default=2048, help="random seed for initialization",)
|
||||
parser.add_argument("--per_gpu_train_batch_size", default=8,
|
||||
type=int, help="Batch size per GPU/CPU for training.",)
|
||||
parser.add_argument("--per_gpu_eval_batch_size", default=8,
|
||||
type=int, help="Batch size per GPU/CPU for eval.",)
|
||||
parser.add_argument("--learning_rate", default=5e-5,
|
||||
type=float, help="The initial learning rate for Adam.",)
|
||||
parser.add_argument("--weight_decay", default=0.0,
|
||||
type=float, help="Weight decay if we apply some.",)
|
||||
parser.add_argument("--adam_epsilon", default=1e-8,
|
||||
type=float, help="Epsilon for Adam optimizer.",)
|
||||
parser.add_argument("--max_grad_norm", default=1.0,
|
||||
type=float, help="Max gradient norm.",)
|
||||
parser.add_argument("--num_train_epochs", default=3, type=int,
|
||||
help="Total number of training epochs to perform.",)
|
||||
parser.add_argument("--warmup_steps", default=0, type=int,
|
||||
help="Linear warmup over warmup_steps.",)
|
||||
parser.add_argument("--eval_steps", type=int, default=10,
|
||||
help="Evaluate every X update steps.",)
|
||||
parser.add_argument("--save_steps", type=int, default=50,
|
||||
help="Save a checkpoint every X update steps.",)
|
||||
parser.add_argument("--seed", type=int, default=2048,
|
||||
help="random seed for initialization",)
|
||||
|
||||
parser.add_argument("--ocr_rec_model_dir", default=None, type=str, )
|
||||
parser.add_argument("--ocr_det_model_dir", default=None, type=str, )
|
||||
parser.add_argument("--label_map_path", default="./labels/labels_ser.txt", type=str, required=False, )
|
||||
parser.add_argument(
|
||||
"--label_map_path", default="./labels/labels_ser.txt", type=str, required=False, )
|
||||
parser.add_argument("--infer_imgs", default=None, type=str, required=False)
|
||||
parser.add_argument("--ocr_json_path", default=None, type=str, required=False, help="ocr prediction results")
|
||||
parser.add_argument("--ocr_json_path", default=None,
|
||||
type=str, required=False, help="ocr prediction results")
|
||||
# yapf: enable
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
|