unify kie and ser for vqa data format (#6704)
* unify kie and ser for vqa data format * fix config and label ops * fix doc * add distort bboxpull/6752/head^2
parent
466214f9f8
commit
e13ec733a6
|
@ -17,7 +17,7 @@ Global:
|
|||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
class_path: ./train_data/wildreceipt/class_list.txt
|
||||
class_path: &class_path ./train_data/wildreceipt/class_list.txt
|
||||
infer_img: ./train_data/wildreceipt/1.txt
|
||||
save_res_path: ./output/sdmgr_kie/predicts_kie.txt
|
||||
img_scale: [ 1024, 512 ]
|
||||
|
@ -72,6 +72,7 @@ Train:
|
|||
order: 'hwc'
|
||||
- KieLabelEncode: # Class handling label
|
||||
character_dict_path: ./train_data/wildreceipt/dict.txt
|
||||
class_path: *class_path
|
||||
- KieResize:
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
|
@ -88,7 +89,6 @@ Eval:
|
|||
data_dir: ./train_data/wildreceipt
|
||||
label_file_list:
|
||||
- ./train_data/wildreceipt/wildreceipt_test.txt
|
||||
# - /paddle/data/PaddleOCR/train_data/wildreceipt/1.txt
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: RGB
|
||||
|
|
|
@ -43,7 +43,7 @@ Optimizer:
|
|||
|
||||
PostProcess:
|
||||
name: VQASerTokenLayoutLMPostProcess
|
||||
class_path: &class_path ppstructure/vqa/labels/labels_ser.txt
|
||||
class_path: &class_path train_data/XFUND/class_list_xfun.txt
|
||||
|
||||
Metric:
|
||||
name: VQASerTokenMetric
|
||||
|
@ -54,7 +54,7 @@ Train:
|
|||
name: SimpleDataSet
|
||||
data_dir: train_data/XFUND/zh_train/image
|
||||
label_file_list:
|
||||
- train_data/XFUND/zh_train/xfun_normalize_train.json
|
||||
- train_data/XFUND/zh_train/train.json
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: RGB
|
||||
|
@ -89,7 +89,7 @@ Eval:
|
|||
name: SimpleDataSet
|
||||
data_dir: train_data/XFUND/zh_val/image
|
||||
label_file_list:
|
||||
- train_data/XFUND/zh_val/xfun_normalize_val.json
|
||||
- train_data/XFUND/zh_val/val.json
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: RGB
|
||||
|
|
|
@ -44,7 +44,7 @@ Optimizer:
|
|||
|
||||
PostProcess:
|
||||
name: VQASerTokenLayoutLMPostProcess
|
||||
class_path: &class_path ppstructure/vqa/labels/labels_ser.txt
|
||||
class_path: &class_path train_data/XFUND/class_list_xfun.txt
|
||||
|
||||
Metric:
|
||||
name: VQASerTokenMetric
|
||||
|
@ -55,7 +55,7 @@ Train:
|
|||
name: SimpleDataSet
|
||||
data_dir: train_data/XFUND/zh_train/image
|
||||
label_file_list:
|
||||
- train_data/XFUND/zh_train/xfun_normalize_train.json
|
||||
- train_data/XFUND/zh_train/train.json
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: RGB
|
||||
|
@ -90,7 +90,7 @@ Eval:
|
|||
name: SimpleDataSet
|
||||
data_dir: train_data/XFUND/zh_val/image
|
||||
label_file_list:
|
||||
- train_data/XFUND/zh_val/xfun_normalize_val.json
|
||||
- train_data/XFUND/zh_val/val.json
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: RGB
|
||||
|
|
|
@ -11,7 +11,7 @@ Global:
|
|||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
seed: 2022
|
||||
infer_img: doc/vqa/input/zh_val_42.jpg
|
||||
infer_img: ppstructure/docs/vqa/input/zh_val_42.jpg
|
||||
save_res_path: ./output/ser
|
||||
|
||||
Architecture:
|
||||
|
@ -54,7 +54,7 @@ Train:
|
|||
name: SimpleDataSet
|
||||
data_dir: train_data/XFUND/zh_train/image
|
||||
label_file_list:
|
||||
- train_data/XFUND/zh_train/xfun_normalize_train.json
|
||||
- train_data/XFUND/zh_train/train.json
|
||||
ratio_list: [ 1.0 ]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
|
@ -90,7 +90,7 @@ Eval:
|
|||
name: SimpleDataSet
|
||||
data_dir: train_data/XFUND/zh_val/image
|
||||
label_file_list:
|
||||
- train_data/XFUND/zh_val/xfun_normalize_val.json
|
||||
- train_data/XFUND/zh_val/val.json
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: RGB
|
||||
|
|
|
@ -259,15 +259,26 @@ class E2ELabelEncodeTrain(object):
|
|||
|
||||
|
||||
class KieLabelEncode(object):
|
||||
def __init__(self, character_dict_path, norm=10, directed=False, **kwargs):
|
||||
def __init__(self,
|
||||
character_dict_path,
|
||||
class_path,
|
||||
norm=10,
|
||||
directed=False,
|
||||
**kwargs):
|
||||
super(KieLabelEncode, self).__init__()
|
||||
self.dict = dict({'': 0})
|
||||
self.label2classid_map = dict()
|
||||
with open(character_dict_path, 'r', encoding='utf-8') as fr:
|
||||
idx = 1
|
||||
for line in fr:
|
||||
char = line.strip()
|
||||
self.dict[char] = idx
|
||||
idx += 1
|
||||
with open(class_path, "r") as fin:
|
||||
lines = fin.readlines()
|
||||
for idx, line in enumerate(lines):
|
||||
line = line.strip("\n")
|
||||
self.label2classid_map[line] = idx
|
||||
self.norm = norm
|
||||
self.directed = directed
|
||||
|
||||
|
@ -408,7 +419,7 @@ class KieLabelEncode(object):
|
|||
text_ind = [self.dict[c] for c in text if c in self.dict]
|
||||
text_inds.append(text_ind)
|
||||
if 'label' in ann.keys():
|
||||
labels.append(ann['label'])
|
||||
labels.append(self.label2classid_map[ann['label']])
|
||||
elif 'key_cls' in ann.keys():
|
||||
labels.append(ann['key_cls'])
|
||||
else:
|
||||
|
@ -876,15 +887,16 @@ class VQATokenLabelEncode(object):
|
|||
for info in ocr_info:
|
||||
if train_re:
|
||||
# for re
|
||||
if len(info["text"]) == 0:
|
||||
if len(info["transcription"]) == 0:
|
||||
empty_entity.add(info["id"])
|
||||
continue
|
||||
id2label[info["id"]] = info["label"]
|
||||
relations.extend([tuple(sorted(l)) for l in info["linking"]])
|
||||
# smooth_box
|
||||
info["bbox"] = self.trans_poly_to_bbox(info["points"])
|
||||
bbox = self._smooth_box(info["bbox"], height, width)
|
||||
|
||||
text = info["text"]
|
||||
text = info["transcription"]
|
||||
encode_res = self.tokenizer.encode(
|
||||
text, pad_to_max_seq_len=False, return_attention_mask=True)
|
||||
|
||||
|
@ -944,29 +956,29 @@ class VQATokenLabelEncode(object):
|
|||
data['entity_id_to_index_map'] = entity_id_to_index_map
|
||||
return data
|
||||
|
||||
def _load_ocr_info(self, data):
|
||||
def trans_poly_to_bbox(poly):
|
||||
def trans_poly_to_bbox(self, poly):
|
||||
x1 = np.min([p[0] for p in poly])
|
||||
x2 = np.max([p[0] for p in poly])
|
||||
y1 = np.min([p[1] for p in poly])
|
||||
y2 = np.max([p[1] for p in poly])
|
||||
return [x1, y1, x2, y2]
|
||||
|
||||
def _load_ocr_info(self, data):
|
||||
if self.infer_mode:
|
||||
ocr_result = self.ocr_engine.ocr(data['image'], cls=False)
|
||||
ocr_info = []
|
||||
for res in ocr_result:
|
||||
ocr_info.append({
|
||||
"text": res[1][0],
|
||||
"bbox": trans_poly_to_bbox(res[0]),
|
||||
"poly": res[0],
|
||||
"transcription": res[1][0],
|
||||
"bbox": self.trans_poly_to_bbox(res[0]),
|
||||
"points": res[0],
|
||||
})
|
||||
return ocr_info
|
||||
else:
|
||||
info = data['label']
|
||||
# read text info
|
||||
info_dict = json.loads(info)
|
||||
return info_dict["ocr_info"]
|
||||
return info_dict
|
||||
|
||||
def _smooth_box(self, bbox, height, width):
|
||||
bbox[0] = int(bbox[0] * 1000.0 / width)
|
||||
|
@ -977,7 +989,7 @@ class VQATokenLabelEncode(object):
|
|||
|
||||
def _parse_label(self, label, encode_res):
|
||||
gt_label = []
|
||||
if label.lower() == "other":
|
||||
if label.lower() in ["other", "others", "ignore"]:
|
||||
gt_label.extend([0] * len(encode_res["input_ids"]))
|
||||
else:
|
||||
gt_label.append(self.label2id_map[("b-" + label).upper()])
|
||||
|
|
|
@ -13,7 +13,12 @@
|
|||
# limitations under the License.
|
||||
|
||||
from .token import VQATokenPad, VQASerTokenChunk, VQAReTokenChunk, VQAReTokenRelation
|
||||
from .augment import DistortBBox
|
||||
|
||||
__all__ = [
|
||||
'VQATokenPad', 'VQASerTokenChunk', 'VQAReTokenChunk', 'VQAReTokenRelation'
|
||||
'VQATokenPad',
|
||||
'VQASerTokenChunk',
|
||||
'VQAReTokenChunk',
|
||||
'VQAReTokenRelation',
|
||||
'DistortBBox',
|
||||
]
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import sys
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
|
||||
class DistortBBox:
|
||||
def __init__(self, prob=0.5, max_scale=1, **kwargs):
|
||||
"""Random distort bbox
|
||||
"""
|
||||
self.prob = prob
|
||||
self.max_scale = max_scale
|
||||
|
||||
def __call__(self, data):
|
||||
if random.random() > self.prob:
|
||||
return data
|
||||
bbox = np.array(data['bbox'])
|
||||
rnd_scale = (np.random.rand(*bbox.shape) - 0.5) * 2 * self.max_scale
|
||||
bbox = np.round(bbox + rnd_scale).astype(bbox.dtype)
|
||||
data['bbox'] = np.clip(data['bbox'], 0, 1000)
|
||||
data['bbox'] = bbox.tolist()
|
||||
sys.stdout.flush()
|
||||
return data
|
|
@ -91,18 +91,19 @@ def check_and_read_gif(img_path):
|
|||
def load_vqa_bio_label_maps(label_map_path):
|
||||
with open(label_map_path, "r", encoding='utf-8') as fin:
|
||||
lines = fin.readlines()
|
||||
lines = [line.strip() for line in lines]
|
||||
if "O" not in lines:
|
||||
lines.insert(0, "O")
|
||||
labels = []
|
||||
for line in lines:
|
||||
if line == "O":
|
||||
labels.append("O")
|
||||
else:
|
||||
old_lines = [line.strip() for line in lines]
|
||||
lines = ["O"]
|
||||
for line in old_lines:
|
||||
# "O" has already been in lines
|
||||
if line.upper() in ["OTHER", "OTHERS", "IGNORE"]:
|
||||
continue
|
||||
lines.append(line)
|
||||
labels = ["O"]
|
||||
for line in lines[1:]:
|
||||
labels.append("B-" + line)
|
||||
labels.append("I-" + line)
|
||||
label2id_map = {label: idx for idx, label in enumerate(labels)}
|
||||
id2label_map = {idx: label for idx, label in enumerate(labels)}
|
||||
label2id_map = {label.upper(): idx for idx, label in enumerate(labels)}
|
||||
id2label_map = {idx: label.upper() for idx, label in enumerate(labels)}
|
||||
return label2id_map, id2label_map
|
||||
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ from PIL import Image, ImageDraw, ImageFont
|
|||
def draw_ser_results(image,
|
||||
ocr_results,
|
||||
font_path="doc/fonts/simfang.ttf",
|
||||
font_size=18):
|
||||
font_size=14):
|
||||
np.random.seed(2021)
|
||||
color = (np.random.permutation(range(255)),
|
||||
np.random.permutation(range(255)),
|
||||
|
@ -40,9 +40,15 @@ def draw_ser_results(image,
|
|||
if ocr_info["pred_id"] not in color_map:
|
||||
continue
|
||||
color = color_map[ocr_info["pred_id"]]
|
||||
text = "{}: {}".format(ocr_info["pred"], ocr_info["text"])
|
||||
text = "{}: {}".format(ocr_info["pred"], ocr_info["transcription"])
|
||||
|
||||
draw_box_txt(ocr_info["bbox"], text, draw, font, font_size, color)
|
||||
if "bbox" in ocr_info:
|
||||
# draw with ocr engine
|
||||
bbox = ocr_info["bbox"]
|
||||
else:
|
||||
# draw with ocr groundtruth
|
||||
bbox = trans_poly_to_bbox(ocr_info["points"])
|
||||
draw_box_txt(bbox, text, draw, font, font_size, color)
|
||||
|
||||
img_new = Image.blend(image, img_new, 0.5)
|
||||
return np.array(img_new)
|
||||
|
@ -62,6 +68,14 @@ def draw_box_txt(bbox, text, draw, font, font_size, color):
|
|||
draw.text((bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font)
|
||||
|
||||
|
||||
def trans_poly_to_bbox(poly):
|
||||
x1 = np.min([p[0] for p in poly])
|
||||
x2 = np.max([p[0] for p in poly])
|
||||
y1 = np.min([p[1] for p in poly])
|
||||
y2 = np.max([p[1] for p in poly])
|
||||
return [x1, y1, x2, y2]
|
||||
|
||||
|
||||
def draw_re_results(image,
|
||||
result,
|
||||
font_path="doc/fonts/simfang.ttf",
|
||||
|
@ -80,10 +94,10 @@ def draw_re_results(image,
|
|||
color_line = (0, 255, 0)
|
||||
|
||||
for ocr_info_head, ocr_info_tail in result:
|
||||
draw_box_txt(ocr_info_head["bbox"], ocr_info_head["text"], draw, font,
|
||||
font_size, color_head)
|
||||
draw_box_txt(ocr_info_tail["bbox"], ocr_info_tail["text"], draw, font,
|
||||
font_size, color_tail)
|
||||
draw_box_txt(ocr_info_head["bbox"], ocr_info_head["transcription"],
|
||||
draw, font, font_size, color_head)
|
||||
draw_box_txt(ocr_info_tail["bbox"], ocr_info_tail["transcription"],
|
||||
draw, font, font_size, color_tail)
|
||||
|
||||
center_head = (
|
||||
(ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2,
|
||||
|
|
|
@ -16,7 +16,7 @@ SDMGR是一个关键信息提取算法,将每个检测到的文本区域分类
|
|||
训练和测试的数据采用wildreceipt数据集,通过如下指令下载数据集:
|
||||
|
||||
```
|
||||
wget https://paddleocr.bj.bcebos.com/dygraph_v2.1/kie/wildreceipt.tar && tar xf wildreceipt.tar
|
||||
wget https://paddleocr.bj.bcebos.com/ppstructure/dataset/wildreceipt.tar && tar xf wildreceipt.tar
|
||||
```
|
||||
|
||||
执行预测:
|
||||
|
|
|
@ -15,7 +15,7 @@ This section provides a tutorial example on how to quickly use, train, and evalu
|
|||
[Wildreceipt dataset](https://paperswithcode.com/dataset/wildreceipt) is used for this tutorial. It contains 1765 photos, with 25 classes, and 50000 text boxes, which can be downloaded by wget:
|
||||
|
||||
```shell
|
||||
wget https://paddleocr.bj.bcebos.com/dygraph_v2.1/kie/wildreceipt.tar && tar xf wildreceipt.tar
|
||||
wget https://paddleocr.bj.bcebos.com/ppstructure/dataset/wildreceipt.tar && tar xf wildreceipt.tar
|
||||
```
|
||||
|
||||
Download the pretrained model and predict the result:
|
||||
|
|
|
@ -125,13 +125,13 @@ If you want to experience the prediction process directly, you can download the
|
|||
|
||||
* Download the processed dataset
|
||||
|
||||
The download address of the processed XFUND Chinese dataset: [https://paddleocr.bj.bcebos.com/dataset/XFUND.tar](https://paddleocr.bj.bcebos.com/dataset/XFUND.tar).
|
||||
The download address of the processed XFUND Chinese dataset: [link](https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar).
|
||||
|
||||
|
||||
Download and unzip the dataset, and place the dataset in the current directory after unzipping.
|
||||
|
||||
```shell
|
||||
wget https://paddleocr.bj.bcebos.com/dataset/XFUND.tar
|
||||
wget https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar
|
||||
````
|
||||
|
||||
* Convert the dataset
|
||||
|
|
|
@ -122,13 +122,13 @@ python3 -m pip install -r ppstructure/vqa/requirements.txt
|
|||
|
||||
* 下载处理好的数据集
|
||||
|
||||
处理好的XFUND中文数据集下载地址:[https://paddleocr.bj.bcebos.com/dataset/XFUND.tar](https://paddleocr.bj.bcebos.com/dataset/XFUND.tar)。
|
||||
处理好的XFUND中文数据集下载地址:[链接](https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar)。
|
||||
|
||||
|
||||
下载并解压该数据集,解压后将数据集放置在当前目录下。
|
||||
|
||||
```shell
|
||||
wget https://paddleocr.bj.bcebos.com/dataset/XFUND.tar
|
||||
wget https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar
|
||||
```
|
||||
|
||||
* 转换数据集
|
||||
|
|
|
@ -1,3 +0,0 @@
|
|||
QUESTION
|
||||
ANSWER
|
||||
HEADER
|
|
@ -21,26 +21,22 @@ def transfer_xfun_data(json_path=None, output_file=None):
|
|||
|
||||
json_info = json.loads(lines[0])
|
||||
documents = json_info["documents"]
|
||||
label_info = {}
|
||||
with open(output_file, "w", encoding='utf-8') as fout:
|
||||
for idx, document in enumerate(documents):
|
||||
label_info = []
|
||||
img_info = document["img"]
|
||||
document = document["document"]
|
||||
image_path = img_info["fname"]
|
||||
|
||||
label_info["height"] = img_info["height"]
|
||||
label_info["width"] = img_info["width"]
|
||||
|
||||
label_info["ocr_info"] = []
|
||||
|
||||
for doc in document:
|
||||
label_info["ocr_info"].append({
|
||||
"text": doc["text"],
|
||||
x1, y1, x2, y2 = doc["box"]
|
||||
points = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
|
||||
label_info.append({
|
||||
"transcription": doc["text"],
|
||||
"label": doc["label"],
|
||||
"bbox": doc["box"],
|
||||
"points": points,
|
||||
"id": doc["id"],
|
||||
"linking": doc["linking"],
|
||||
"words": doc["words"]
|
||||
"linking": doc["linking"]
|
||||
})
|
||||
|
||||
fout.write(image_path + "\t" + json.dumps(
|
||||
|
|
|
@ -39,13 +39,12 @@ import time
|
|||
|
||||
|
||||
def read_class_list(filepath):
|
||||
dict = {}
|
||||
ret = {}
|
||||
with open(filepath, "r") as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
key, value = line.split(" ")
|
||||
dict[key] = value.rstrip()
|
||||
return dict
|
||||
for idx, line in enumerate(lines):
|
||||
ret[idx] = line.strip("\n")
|
||||
return ret
|
||||
|
||||
|
||||
def draw_kie_result(batch, node, idx_to_cls, count):
|
||||
|
@ -71,7 +70,7 @@ def draw_kie_result(batch, node, idx_to_cls, count):
|
|||
x_min = int(min([point[0] for point in new_box]))
|
||||
y_min = int(min([point[1] for point in new_box]))
|
||||
|
||||
pred_label = str(node_pred_label[i])
|
||||
pred_label = node_pred_label[i]
|
||||
if pred_label in idx_to_cls:
|
||||
pred_label = idx_to_cls[pred_label]
|
||||
pred_score = '{:.2f}'.format(node_pred_score[i])
|
||||
|
@ -109,8 +108,7 @@ def main():
|
|||
save_res_path = config['Global']['save_res_path']
|
||||
class_path = config['Global']['class_path']
|
||||
idx_to_cls = read_class_list(class_path)
|
||||
if not os.path.exists(os.path.dirname(save_res_path)):
|
||||
os.makedirs(os.path.dirname(save_res_path))
|
||||
os.makedirs(os.path.dirname(save_res_path), exist_ok=True)
|
||||
|
||||
model.eval()
|
||||
|
||||
|
|
|
@ -86,15 +86,16 @@ class SerPredictor(object):
|
|||
]
|
||||
|
||||
transforms.append(op)
|
||||
if config["Global"].get("infer_mode", None) is None:
|
||||
global_config['infer_mode'] = True
|
||||
self.ops = create_operators(config['Eval']['dataset']['transforms'],
|
||||
global_config)
|
||||
self.model.eval()
|
||||
|
||||
def __call__(self, img_path):
|
||||
with open(img_path, 'rb') as f:
|
||||
def __call__(self, data):
|
||||
with open(data["img_path"], 'rb') as f:
|
||||
img = f.read()
|
||||
data = {'image': img}
|
||||
data["image"] = img
|
||||
batch = transform(data, self.ops)
|
||||
batch = to_tensor(batch)
|
||||
preds = self.model(batch)
|
||||
|
@ -112,20 +113,35 @@ if __name__ == '__main__':
|
|||
|
||||
ser_engine = SerPredictor(config)
|
||||
|
||||
if config["Global"].get("infer_mode", None) is False:
|
||||
data_dir = config['Eval']['dataset']['data_dir']
|
||||
with open(config['Global']['infer_img'], "rb") as f:
|
||||
infer_imgs = f.readlines()
|
||||
else:
|
||||
infer_imgs = get_image_file_list(config['Global']['infer_img'])
|
||||
|
||||
with open(
|
||||
os.path.join(config['Global']['save_res_path'],
|
||||
"infer_results.txt"),
|
||||
"w",
|
||||
encoding='utf-8') as fout:
|
||||
for idx, img_path in enumerate(infer_imgs):
|
||||
for idx, info in enumerate(infer_imgs):
|
||||
if config["Global"].get("infer_mode", None) is False:
|
||||
data_line = info.decode('utf-8')
|
||||
substr = data_line.strip("\n").split("\t")
|
||||
img_path = os.path.join(data_dir, substr[0])
|
||||
data = {'img_path': img_path, 'label': substr[1]}
|
||||
else:
|
||||
img_path = info
|
||||
data = {'img_path': img_path}
|
||||
|
||||
save_img_path = os.path.join(
|
||||
config['Global']['save_res_path'],
|
||||
os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg")
|
||||
logger.info("process: [{}/{}], save result to {}".format(
|
||||
idx, len(infer_imgs), save_img_path))
|
||||
|
||||
result, _ = ser_engine(img_path)
|
||||
result, _ = ser_engine(data)
|
||||
result = result[0]
|
||||
fout.write(img_path + "\t" + json.dumps(
|
||||
{
|
||||
|
|
|
@ -576,8 +576,8 @@ def preprocess(is_train=False):
|
|||
assert alg in [
|
||||
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
|
||||
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
|
||||
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR',
|
||||
'ViTSTR', 'ABINet'
|
||||
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
|
||||
'SVTR', 'ViTSTR', 'ABINet'
|
||||
]
|
||||
|
||||
if use_xpu:
|
||||
|
|
Loading…
Reference in New Issue