# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
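
# Semantic entity recognition (SER) inference: tokenize OCR results, run a
# LayoutLM/LayoutXLM token-classification model, and merge the token-level
# predictions back onto the original OCR segments for visualization.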
import os
import sys
import json
import cv2
import numpy as np
from copy import deepcopy

import paddle

# relative reference
from vqa_utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification

MODELS = {
    'LayoutXLM':
    (LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification),
    'LayoutLM':
    (LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification)
}
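# Each entry maps a ser_model_type value to its matching
# (tokenizer, base model, token-classification model) classes, so infer()
# can instantiate everything from a single model_name_or_path.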


def pad_sentences(tokenizer,
                  encoded_inputs,
                  max_seq_len=512,
                  pad_to_max_seq_len=True,
                  return_attention_mask=True,
                  return_token_type_ids=True,
                  return_overflowing_tokens=False,
                  return_special_tokens_mask=False):
    # Pad to the next multiple of max_seq_len so that split_page can later
    # reshape the sequence into [-1, max_seq_len] pages.
    max_seq_len = (
        len(encoded_inputs["input_ids"]) // max_seq_len + 1) * max_seq_len

    needs_to_be_padded = pad_to_max_seq_len and \
        max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len

    if needs_to_be_padded:
        difference = max_seq_len - len(encoded_inputs["input_ids"])
        if tokenizer.padding_side == 'right':
            if return_attention_mask:
                encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
                    "input_ids"]) + [0] * difference
            if return_token_type_ids:
                encoded_inputs["token_type_ids"] = (
                    encoded_inputs["token_type_ids"] +
                    [tokenizer.pad_token_type_id] * difference)
            if return_special_tokens_mask:
                encoded_inputs["special_tokens_mask"] = encoded_inputs[
                    "special_tokens_mask"] + [1] * difference
            encoded_inputs["input_ids"] = encoded_inputs[
                "input_ids"] + [tokenizer.pad_token_id] * difference
            encoded_inputs["bbox"] = encoded_inputs["bbox"] + [[0, 0, 0, 0]
                                                               ] * difference
        else:
            assert False, 'padding_side of tokenizer only supports "right", but got {}'.format(
                tokenizer.padding_side)
    else:
        if return_attention_mask:
            encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
                "input_ids"])

    return encoded_inputs
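
# A worked example (hypothetical sizes): with max_seq_len=512 and a 700-token
# document, pad_sentences pads "input_ids" to 1024 (the next multiple of 512)
# with tokenizer.pad_token_id, extends "attention_mask" with zeros and "bbox"
# with [0, 0, 0, 0] entries; split_page below can then reshape every field
# into two 512-token pages.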


def split_page(encoded_inputs, max_seq_len=512):
    """
    Split an over-length sequence into pages of max_seq_len tokens.
    Truncation is common during training; at inference time we keep every
    token and reshape the sequence into multiple pages instead.
    """
    for key in encoded_inputs:
        encoded_inputs[key] = paddle.to_tensor(encoded_inputs[key])
        if encoded_inputs[key].ndim <= 1:  # for input_ids, att_mask and so on
            encoded_inputs[key] = encoded_inputs[key].reshape([-1, max_seq_len])
        else:  # for bbox
            encoded_inputs[key] = encoded_inputs[key].reshape(
                [-1, max_seq_len, 4])
    return encoded_inputs
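
# Continuing the example above: after padding a 700-token document to 1024
# tokens, split_page reshapes "input_ids" from [1024] to [2, 512] and "bbox"
# from [1024, 4] to [2, 512, 4], treating each page as one batch element.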


def preprocess(
        tokenizer,
        ori_img,
        ocr_info,
        img_size=(224, 224),
        pad_token_label_id=-100,
        max_seq_len=512,
        add_special_ids=False,
        return_attention_mask=True, ):
    ocr_info = deepcopy(ocr_info)
    height = ori_img.shape[0]
    width = ori_img.shape[1]

    img = cv2.resize(ori_img, img_size).transpose([2, 0, 1]).astype(np.float32)

    segment_offset_id = []
    words_list = []
    bbox_list = []
    input_ids_list = []
    token_type_ids_list = []

    for info in ocr_info:
        # x1, y1, x2, y2
        bbox = info["bbox"]
        bbox[0] = int(bbox[0] * 1000.0 / width)
        bbox[2] = int(bbox[2] * 1000.0 / width)
        bbox[1] = int(bbox[1] * 1000.0 / height)
        bbox[3] = int(bbox[3] * 1000.0 / height)

        text = info["text"]
        encode_res = tokenizer.encode(
            text, pad_to_max_seq_len=False, return_attention_mask=True)

        if not add_special_ids:
            # TODO: use tok.all_special_ids to remove
            encode_res["input_ids"] = encode_res["input_ids"][1:-1]
            encode_res["token_type_ids"] = encode_res["token_type_ids"][1:-1]
            encode_res["attention_mask"] = encode_res["attention_mask"][1:-1]

        input_ids_list.extend(encode_res["input_ids"])
        token_type_ids_list.extend(encode_res["token_type_ids"])
        bbox_list.extend([bbox] * len(encode_res["input_ids"]))
        words_list.append(text)
        segment_offset_id.append(len(input_ids_list))

    encoded_inputs = {
        "input_ids": input_ids_list,
        "token_type_ids": token_type_ids_list,
        "bbox": bbox_list,
        "attention_mask": [1] * len(input_ids_list),
    }

    encoded_inputs = pad_sentences(
        tokenizer,
        encoded_inputs,
        max_seq_len=max_seq_len,
        return_attention_mask=return_attention_mask)

    encoded_inputs = split_page(encoded_inputs, max_seq_len=max_seq_len)

    fake_bs = encoded_inputs["input_ids"].shape[0]

    encoded_inputs["image"] = paddle.to_tensor(img).unsqueeze(0).expand(
        [fake_bs] + list(img.shape))

    encoded_inputs["segment_offset_id"] = segment_offset_id

    return encoded_inputs
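
# Note: preprocess scales every box into the 0-1000 coordinate range expected
# by LayoutLM-family models, and repeats the resized image along a "fake"
# batch axis so that each page of tokens is paired with the same page image.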


def postprocess(attention_mask, preds, label_map_path):
    if isinstance(preds, paddle.Tensor):
        preds = preds.numpy()
    preds = np.argmax(preds, axis=2)

    _, label_map = get_bio_label_maps(label_map_path)

    preds_list = [[] for _ in range(preds.shape[0])]

    # keep batch info
    for i in range(preds.shape[0]):
        for j in range(preds.shape[1]):
            if attention_mask[i][j] == 1:
                preds_list[i].append(label_map[preds[i][j]])

    return preds_list
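
# postprocess argmaxes the [pages, seq_len, num_labels] logits into label ids,
# maps them to BIO label strings, and keeps only positions whose
# attention_mask is 1 (i.e. it drops the padding added by pad_sentences).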


def merge_preds_list_with_ocr_info(label_map_path, ocr_info, segment_offset_id,
                                   preds_list):
    # must ensure the preds_list is generated from the same image
    preds = [p for pred in preds_list for p in pred]
    label2id_map, _ = get_bio_label_maps(label_map_path)
    # fold "I-" labels into the id of their "B-" counterpart so both count
    # towards the same entity class in the vote below
    for key in label2id_map:
        if key.startswith("I-"):
            label2id_map[key] = label2id_map["B" + key[1:]]

    id2label_map = dict()
    for key in label2id_map:
        val = label2id_map[key]
        if key.startswith("B-") or key.startswith("I-"):
            id2label_map[val] = key[2:]
        else:
            # covers "O" and any label without a BIO prefix
            id2label_map[val] = key

    for idx in range(len(segment_offset_id)):
        if idx == 0:
            start_id = 0
        else:
            start_id = segment_offset_id[idx - 1]

        end_id = segment_offset_id[idx]

        curr_pred = preds[start_id:end_id]
        curr_pred = [label2id_map[p] for p in curr_pred]

        if len(curr_pred) <= 0:
            pred_id = 0
        else:
            counts = np.bincount(curr_pred)
            pred_id = np.argmax(counts)
        ocr_info[idx]["pred_id"] = int(pred_id)
        ocr_info[idx]["pred"] = id2label_map[pred_id]
    return ocr_info
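
# merge_preds_list_with_ocr_info flattens the per-page predictions and
# assigns each OCR segment the majority label (via np.bincount) over the
# token span recorded in segment_offset_id during preprocess.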


@paddle.no_grad()
def infer(args):
    os.makedirs(args.output_dir, exist_ok=True)

    # init tokenizer and model
    tokenizer_class, base_model_class, model_class = MODELS[args.ser_model_type]
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    model = model_class.from_pretrained(args.model_name_or_path)

    model.eval()

    # load ocr results json
    ocr_results = dict()
    with open(args.ocr_json_path, "r", encoding='utf-8') as fin:
        lines = fin.readlines()
        for line in lines:
            img_name, json_info = line.split("\t")
            ocr_results[os.path.basename(img_name)] = json.loads(json_info)

    # get infer img list
    infer_imgs = get_image_file_list(args.infer_imgs)

    # loop for infer
    with open(
            os.path.join(args.output_dir, "infer_results.txt"),
            "w",
            encoding='utf-8') as fout:
        for idx, img_path in enumerate(infer_imgs):
            save_img_path = os.path.join(args.output_dir,
                                         os.path.basename(img_path))
            print("process: [{}/{}], save result to {}".format(
                idx + 1, len(infer_imgs), save_img_path))

            img = cv2.imread(img_path)

            ocr_info = ocr_results[os.path.basename(img_path)]["ocr_info"]
            inputs = preprocess(
                tokenizer=tokenizer,
                ori_img=img,
                ocr_info=ocr_info,
                max_seq_len=args.max_seq_length)
            if args.ser_model_type == 'LayoutLM':
                preds = model(
                    input_ids=inputs["input_ids"],
                    bbox=inputs["bbox"],
                    token_type_ids=inputs["token_type_ids"],
                    attention_mask=inputs["attention_mask"])
            elif args.ser_model_type == 'LayoutXLM':
                preds = model(
                    input_ids=inputs["input_ids"],
                    bbox=inputs["bbox"],
                    image=inputs["image"],
                    token_type_ids=inputs["token_type_ids"],
                    attention_mask=inputs["attention_mask"])
                # LayoutXLM returns a tuple; the logits are the first element
                preds = preds[0]

            preds = postprocess(inputs["attention_mask"], preds,
                                args.label_map_path)
            ocr_info = merge_preds_list_with_ocr_info(
                args.label_map_path, ocr_info, inputs["segment_offset_id"],
                preds)

            fout.write(img_path + "\t" + json.dumps(
                {
                    "ocr_info": ocr_info,
                }, ensure_ascii=False) + "\n")

            img_res = draw_ser_results(img, ocr_info)
            cv2.imwrite(save_img_path, img_res)

    return
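
# Example invocation (flag spellings assumed, since parse_args lives in
# vqa_utils and is not shown here; paths are placeholders):
#   python infer_ser.py \
#       --model_name_or_path ./output/ser_layoutxlm \
#       --ser_model_type LayoutXLM \
#       --ocr_json_path ./ocr_results.json \
#       --infer_imgs ./images \
#       --max_seq_length 512 \
#       --label_map_path ./labels/class_list.txt \
#       --output_dir ./output/ser_infer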


if __name__ == "__main__":
    args = parse_args()
    infer(args)