# PaddleOCR/ppstructure/vqa/infer_ser_e2e.py
#
# 157 lines
# 5.1 KiB
# Python

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
import json
import cv2
import numpy as np
from copy import deepcopy
from PIL import Image
import paddle
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification
# relative reference
from vqa_utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps
from vqa_utils import pad_sentences, split_page, preprocess, postprocess, merge_preds_list_with_ocr_info
# Registry keyed by the model-type name. Each value is a
# (tokenizer class, base model class, token-classification class) triple.
MODELS = {
    'LayoutXLM': (
        LayoutXLMTokenizer,
        LayoutXLMModel,
        LayoutXLMForTokenClassification,
    ),
    'LayoutLM': (
        LayoutLMTokenizer,
        LayoutLMModel,
        LayoutLMForTokenClassification,
    ),
}
def trans_poly_to_bbox(poly):
    """Convert a polygon into its axis-aligned bounding box.

    Args:
        poly: Non-empty sequence of points. Only index 0 (x) and index 1 (y)
            of each point are read.

    Returns:
        ``[x_min, y_min, x_max, y_max]`` as built-in Python numbers.

    Raises:
        ValueError: If ``poly`` is empty (raised by the NumPy reduction).
    """
    xs = [point[0] for point in poly]
    ys = [point[1] for point in poly]
    # ``.item()`` unwraps each NumPy scalar into a plain int/float. Without
    # it, integer coordinates come back as e.g. ``np.int64``, which
    # ``json.dumps`` refuses to serialize.
    return [
        np.min(xs).item(),
        np.min(ys).item(),
        np.max(xs).item(),
        np.max(ys).item(),
    ]
def parse_ocr_info_for_ser(ocr_result):
    """Reshape raw OCR output into a list of per-detection records.

    Args:
        ocr_result: Iterable of detections where ``res[0]`` is the text
            polygon and ``res[1][0]`` is the recognized text. ``None`` is
            treated the same as an empty result.

    Returns:
        A list with one dict per detection holding ``"text"``, ``"bbox"``
        (computed by ``trans_poly_to_bbox``) and the original ``"poly"``.
    """
    # Guard for an OCR call that yields None instead of an empty list;
    # iterating None would otherwise raise TypeError.
    if ocr_result is None:
        return []
    return [{
        "text": res[1][0],
        "bbox": trans_poly_to_bbox(res[0]),
        "poly": res[0],
    } for res in ocr_result]
class SerPredictor(object):
    """Run OCR on an image and classify the recognized text with a SER model.

    The tokenizer and token-classification model are selected from
    ``MODELS`` by ``args.ser_model_type`` and loaded from
    ``args.model_name_or_path``.
    """

    def __init__(self, args):
        """Load the tokenizer, model, OCR engine and label maps.

        Args:
            args: Parsed options. The attributes read here are
                ``max_seq_length``, ``ser_model_type``,
                ``model_name_or_path``, ``rec_model_dir``, ``det_model_dir``
                and ``label_map_path``.

        Raises:
            KeyError: If ``args.ser_model_type`` is not a key of ``MODELS``.
        """
        self.args = args
        self.max_seq_length = args.max_seq_length

        # init ser token and model (the base model class is not needed here)
        tokenizer_class, _, model_class = MODELS[args.ser_model_type]
        self.tokenizer = tokenizer_class.from_pretrained(
            args.model_name_or_path)
        self.model = model_class.from_pretrained(args.model_name_or_path)
        self.model.eval()

        # init ocr_engine
        from paddleocr import PaddleOCR
        self.ocr_engine = PaddleOCR(
            rec_model_dir=args.rec_model_dir,
            det_model_dir=args.det_model_dir,
            use_angle_cls=False,
            show_log=False)

        # init dict
        label2id_map, self.id2label_map = get_bio_label_maps(
            args.label_map_path)
        # Every "I-xxx" label is given the id of its "B-xxx" counterpart;
        # all other labels keep their own id.
        self.label2id_map_for_draw = dict()
        for key in label2id_map:
            if key.startswith("I-"):
                self.label2id_map_for_draw[key] = label2id_map["B" + key[1:]]
            else:
                self.label2id_map_for_draw[key] = label2id_map[key]

    def __call__(self, img):
        """Run OCR and SER on one image.

        Args:
            img: Image passed to both the OCR engine and ``preprocess``.

        Returns:
            A tuple ``(ocr_info, inputs)``: the OCR records after
            ``merge_preds_list_with_ocr_info`` and the model inputs
            produced by ``preprocess``.

        Raises:
            ValueError: If ``self.args.ser_model_type`` is neither
                ``'LayoutLM'`` nor ``'LayoutXLM'``.
        """
        ocr_result = self.ocr_engine.ocr(img, cls=False)
        ocr_info = parse_ocr_info_for_ser(ocr_result)

        inputs = preprocess(
            tokenizer=self.tokenizer,
            ori_img=img,
            ocr_info=ocr_info,
            max_seq_len=self.max_seq_length)

        model_type = self.args.ser_model_type
        # Inference only: eval() does not turn off autograd, so disable
        # gradient tracking explicitly for the forward pass.
        with paddle.no_grad():
            if model_type == 'LayoutLM':
                preds = self.model(
                    input_ids=inputs["input_ids"],
                    bbox=inputs["bbox"],
                    token_type_ids=inputs["token_type_ids"],
                    attention_mask=inputs["attention_mask"])
            elif model_type == 'LayoutXLM':
                preds = self.model(
                    input_ids=inputs["input_ids"],
                    bbox=inputs["bbox"],
                    image=inputs["image"],
                    token_type_ids=inputs["token_type_ids"],
                    attention_mask=inputs["attention_mask"])
                # NOTE(review): assumes only the LayoutXLM output needs
                # unwrapping with [0] -- confirm against the model API.
                preds = preds[0]
            else:
                raise ValueError(
                    "unsupported ser_model_type: {}".format(model_type))

        preds = postprocess(inputs["attention_mask"], preds,
                            self.id2label_map)
        ocr_info = merge_preds_list_with_ocr_info(
            ocr_info, inputs["segment_offset_id"], preds,
            self.label2id_map_for_draw)
        return ocr_info, inputs
# Script entry point: run SER on every input image, write one result line
# per image to <output_dir>/infer_results.txt and save a visualization
# image next to it.
if __name__ == "__main__":
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    # get infer img list
    infer_imgs = get_image_file_list(args.infer_imgs)

    # loop for infer
    ser_engine = SerPredictor(args)
    results_path = os.path.join(args.output_dir, "infer_results.txt")
    with open(results_path, "w", encoding='utf-8') as fout:
        for idx, img_path in enumerate(infer_imgs):
            save_img_path = os.path.join(
                args.output_dir,
                os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg")
            print("process: [{}/{}], save result to {}".format(
                idx, len(infer_imgs), save_img_path))

            img = cv2.imread(img_path)
            if img is None:
                # cv2.imread returns None instead of raising when a file
                # cannot be read; skip it rather than crash mid-run.
                print("skip {}: failed to read image".format(img_path))
                continue

            result, _ = ser_engine(img)
            # The key spelling "ser_resule" is part of the emitted file
            # format and is kept unchanged on purpose.
            fout.write(img_path + "\t" + json.dumps(
                {
                    "ser_resule": result,
                }, ensure_ascii=False) + "\n")

            img_res = draw_ser_results(img, result)
            cv2.imwrite(save_img_path, img_res)