add ser to ppstructure system
parent
c647a6da28
commit
d4a4c07c56
|
@ -81,13 +81,14 @@ mkdir inference && cd inference
|
|||
# 下载SER XFUND 模型并解压
|
||||
wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar && tar -xf ser_vi_layoutxlm_xfund_infer.tar
|
||||
cd ..
|
||||
python3 kie/predict_kie_token_ser.py \
|
||||
python3 predict_system.py \
|
||||
--kie_algorithm=LayoutXLM \
|
||||
--ser_model_dir=../inference/ser_vi_layoutxlm_xfund_infer \
|
||||
--ser_model_dir=./inference/ser_vi_layoutxlm_xfund_infer \
|
||||
--image_dir=./docs/kie/input/zh_val_42.jpg \
|
||||
--ser_dict_path=../ppocr/utils/dict/kie_dict/xfund_class_list.txt \
|
||||
--vis_font_path=../doc/fonts/simfang.ttf \
|
||||
--ocr_order_method="tb-yx"
|
||||
--ocr_order_method="tb-yx" \
|
||||
--mode=kie
|
||||
```
|
||||
|
||||
运行完成后,每张图片会在`output`字段指定的目录下的`kie`目录下存放可视化之后的图片,图片名和输入图片名一致。
|
||||
|
|
|
@ -82,13 +82,14 @@ mkdir inference && cd inference
|
|||
# download model
|
||||
wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar && tar -xf ser_vi_layoutxlm_xfund_infer.tar
|
||||
cd ..
|
||||
python3 kie/predict_kie_token_ser.py \
|
||||
python3 predict_system.py \
|
||||
--kie_algorithm=LayoutXLM \
|
||||
--ser_model_dir=../inference/ser_vi_layoutxlm_xfund_infer \
|
||||
--ser_model_dir=./inference/ser_vi_layoutxlm_xfund_infer \
|
||||
--image_dir=./docs/kie/input/zh_val_42.jpg \
|
||||
--ser_dict_path=../ppocr/utils/dict/kie_dict/xfund_class_list.txt \
|
||||
--vis_font_path=../doc/fonts/simfang.ttf \
|
||||
--ocr_order_method="tb-yx"
|
||||
--ocr_order_method="tb-yx" \
|
||||
--mode=kie
|
||||
```
|
||||
|
||||
After the operation is completed, each image will store the visualized image in the `kie` directory under the directory specified by the `output` field, and the image name is the same as the input image name.
|
||||
|
|
|
@ -29,7 +29,7 @@ import tools.infer.utility as utility
|
|||
from tools.infer_kie_token_ser_re import make_input
|
||||
from ppocr.postprocess import build_post_process
|
||||
from ppocr.utils.logging import get_logger
|
||||
from ppocr.utils.visual import draw_re_results
|
||||
from ppocr.utils.visual import draw_ser_results, draw_re_results
|
||||
from ppocr.utils.utility import get_image_file_list, check_and_read
|
||||
from ppstructure.utility import parse_args
|
||||
from ppstructure.kie.predict_kie_token_ser import SerPredictor
|
||||
|
@ -41,15 +41,20 @@ class SerRePredictor(object):
|
|||
def __init__(self, args):
|
||||
self.use_visual_backbone = args.use_visual_backbone
|
||||
self.ser_engine = SerPredictor(args)
|
||||
|
||||
postprocess_params = {'name': 'VQAReTokenLayoutLMPostProcess'}
|
||||
self.postprocess_op = build_post_process(postprocess_params)
|
||||
self.predictor, self.input_tensor, self.output_tensors, self.config = \
|
||||
utility.create_predictor(args, 're', logger)
|
||||
if args.re_model_dir is not None:
|
||||
postprocess_params = {'name': 'VQAReTokenLayoutLMPostProcess'}
|
||||
self.postprocess_op = build_post_process(postprocess_params)
|
||||
self.predictor, self.input_tensor, self.output_tensors, self.config = \
|
||||
utility.create_predictor(args, 're', logger)
|
||||
else:
|
||||
self.predictor = None
|
||||
|
||||
def __call__(self, img):
|
||||
starttime = time.time()
|
||||
ser_results, ser_inputs, _ = self.ser_engine(img)
|
||||
ser_results, ser_inputs, ser_elapse = self.ser_engine(img)
|
||||
if self.predictor is None:
|
||||
return ser_results, ser_elapse
|
||||
|
||||
re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results)
|
||||
if self.use_visual_backbone == False:
|
||||
re_input.pop(4)
|
||||
|
@ -77,7 +82,7 @@ class SerRePredictor(object):
|
|||
|
||||
def main(args):
|
||||
image_file_list = get_image_file_list(args.image_dir)
|
||||
ser_predictor = SerRePredictor(args)
|
||||
ser_re_predictor = SerRePredictor(args)
|
||||
count = 0
|
||||
total_time = 0
|
||||
|
||||
|
@ -93,7 +98,7 @@ def main(args):
|
|||
if img is None:
|
||||
logger.info("error in loading image:{}".format(image_file))
|
||||
continue
|
||||
re_res, elapse = ser_predictor(img)
|
||||
re_res, elapse = ser_re_predictor(img)
|
||||
re_res = re_res[0]
|
||||
|
||||
res_str = '{}\t{}\n'.format(
|
||||
|
@ -103,14 +108,20 @@ def main(args):
|
|||
"ocr_info": re_res,
|
||||
}, ensure_ascii=False))
|
||||
f_w.write(res_str)
|
||||
|
||||
img_res = draw_re_results(
|
||||
image_file, re_res, font_path=args.vis_font_path)
|
||||
|
||||
img_save_path = os.path.join(
|
||||
args.output,
|
||||
os.path.splitext(os.path.basename(image_file))[0] +
|
||||
"_ser_re.jpg")
|
||||
if ser_re_predictor.predictor is not None:
|
||||
img_res = draw_re_results(
|
||||
image_file, re_res, font_path=args.vis_font_path)
|
||||
img_save_path = os.path.join(
|
||||
args.output,
|
||||
os.path.splitext(os.path.basename(image_file))[0] +
|
||||
"_ser_re.jpg")
|
||||
else:
|
||||
img_res = draw_ser_results(
|
||||
image_file, re_res, font_path=args.vis_font_path)
|
||||
img_save_path = os.path.join(
|
||||
args.output,
|
||||
os.path.splitext(os.path.basename(image_file))[0] +
|
||||
"_ser.jpg")
|
||||
|
||||
cv2.imwrite(img_save_path, img_res)
|
||||
logger.info("save vis result to {}".format(img_save_path))
|
||||
|
|
|
@ -30,7 +30,7 @@ from copy import deepcopy
|
|||
|
||||
from ppocr.utils.utility import get_image_file_list, check_and_read
|
||||
from ppocr.utils.logging import get_logger
|
||||
from ppocr.utils.visual import draw_re_results
|
||||
from ppocr.utils.visual import draw_ser_results, draw_re_results
|
||||
from tools.infer.predict_system import TextSystem
|
||||
from ppstructure.layout.predict_layout import LayoutPredictor
|
||||
from ppstructure.table.predict_table import TableSystem, to_excel
|
||||
|
@ -180,6 +180,7 @@ class StructureSystem(object):
|
|||
elif self.mode == 'kie':
|
||||
re_res, elapse = self.kie_predictor(img)
|
||||
time_dict['kie'] = elapse
|
||||
time_dict['all'] = elapse
|
||||
return re_res[0], time_dict
|
||||
return None, None
|
||||
|
||||
|
@ -246,8 +247,12 @@ def main(args):
|
|||
draw_img = draw_structure_result(img, res, args.vis_font_path)
|
||||
save_structure_res(res, save_folder, img_name, index)
|
||||
elif structure_sys.mode == 'kie':
|
||||
draw_img = draw_re_results(
|
||||
img, res, font_path=args.vis_font_path)
|
||||
if structure_sys.kie_predictor.predictor is not None:
|
||||
draw_img = draw_re_results(
|
||||
img, res, font_path=args.vis_font_path)
|
||||
else:
|
||||
draw_img = draw_ser_results(
|
||||
img, res, font_path=args.vis_font_path)
|
||||
|
||||
with open(
|
||||
os.path.join(save_folder, img_name,
|
||||
|
|
Loading…
Reference in New Issue