diff --git a/ppstructure/docs/inference.md b/ppstructure/docs/inference.md index ed73b5b0f..7aa2fd0d9 100644 --- a/ppstructure/docs/inference.md +++ b/ppstructure/docs/inference.md @@ -81,13 +81,14 @@ mkdir inference && cd inference # 下载SER XFUND 模型并解压 wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar && tar -xf ser_vi_layoutxlm_xfund_infer.tar cd .. -python3 kie/predict_kie_token_ser.py \ +python3 predict_system.py \ --kie_algorithm=LayoutXLM \ - --ser_model_dir=../inference/ser_vi_layoutxlm_xfund_infer \ + --ser_model_dir=./inference/ser_vi_layoutxlm_xfund_infer \ --image_dir=./docs/kie/input/zh_val_42.jpg \ --ser_dict_path=../ppocr/utils/dict/kie_dict/xfund_class_list.txt \ --vis_font_path=../doc/fonts/simfang.ttf \ - --ocr_order_method="tb-yx" + --ocr_order_method="tb-yx" \ + --mode=kie ``` 运行完成后,每张图片会在`output`字段指定的目录下的`kie`目录下存放可视化之后的图片,图片名和输入图片名一致。 diff --git a/ppstructure/docs/inference_en.md b/ppstructure/docs/inference_en.md index ebf4aaf07..1bb683a68 100644 --- a/ppstructure/docs/inference_en.md +++ b/ppstructure/docs/inference_en.md @@ -82,13 +82,14 @@ mkdir inference && cd inference # download model wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar && tar -xf ser_vi_layoutxlm_xfund_infer.tar cd .. -python3 kie/predict_kie_token_ser.py \ +python3 predict_system.py \ --kie_algorithm=LayoutXLM \ - --ser_model_dir=../inference/ser_vi_layoutxlm_xfund_infer \ + --ser_model_dir=./inference/ser_vi_layoutxlm_xfund_infer \ --image_dir=./docs/kie/input/zh_val_42.jpg \ --ser_dict_path=../ppocr/utils/dict/kie_dict/xfund_class_list.txt \ --vis_font_path=../doc/fonts/simfang.ttf \ - --ocr_order_method="tb-yx" + --ocr_order_method="tb-yx" \ + --mode=kie ``` After the operation is completed, each image will store the visualized image in the `kie` directory under the directory specified by the `output` field, and the image name is the same as the input image name. diff --git a/ppstructure/kie/predict_kie_token_ser_re.py b/ppstructure/kie/predict_kie_token_ser_re.py index 2846749ab..b29a8f69d 100644 --- a/ppstructure/kie/predict_kie_token_ser_re.py +++ b/ppstructure/kie/predict_kie_token_ser_re.py @@ -29,7 +29,7 @@ import tools.infer.utility as utility from tools.infer_kie_token_ser_re import make_input from ppocr.postprocess import build_post_process from ppocr.utils.logging import get_logger -from ppocr.utils.visual import draw_re_results +from ppocr.utils.visual import draw_ser_results, draw_re_results from ppocr.utils.utility import get_image_file_list, check_and_read from ppstructure.utility import parse_args from ppstructure.kie.predict_kie_token_ser import SerPredictor @@ -41,15 +41,20 @@ class SerRePredictor(object): def __init__(self, args): self.use_visual_backbone = args.use_visual_backbone self.ser_engine = SerPredictor(args) - - postprocess_params = {'name': 'VQAReTokenLayoutLMPostProcess'} - self.postprocess_op = build_post_process(postprocess_params) - self.predictor, self.input_tensor, self.output_tensors, self.config = \ - utility.create_predictor(args, 're', logger) + if args.re_model_dir is not None: + postprocess_params = {'name': 'VQAReTokenLayoutLMPostProcess'} + self.postprocess_op = build_post_process(postprocess_params) + self.predictor, self.input_tensor, self.output_tensors, self.config = \ + utility.create_predictor(args, 're', logger) + else: + self.predictor = None def __call__(self, img): starttime = time.time() - ser_results, ser_inputs, _ = self.ser_engine(img) + ser_results, ser_inputs, ser_elapse = self.ser_engine(img) + if self.predictor is None: + return ser_results, ser_elapse + re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results) if self.use_visual_backbone == False: re_input.pop(4) @@ -77,7 +82,7 @@ class SerRePredictor(object): def main(args): image_file_list = get_image_file_list(args.image_dir) - ser_predictor = SerRePredictor(args) + ser_re_predictor = SerRePredictor(args) count = 0 total_time = 0 @@ -93,7 +98,7 @@ def main(args): if img is None: logger.info("error in loading image:{}".format(image_file)) continue - re_res, elapse = ser_predictor(img) + re_res, elapse = ser_re_predictor(img) re_res = re_res[0] res_str = '{}\t{}\n'.format( @@ -103,14 +108,20 @@ def main(args): "ocr_info": re_res, }, ensure_ascii=False)) f_w.write(res_str) - - img_res = draw_re_results( - image_file, re_res, font_path=args.vis_font_path) - - img_save_path = os.path.join( - args.output, - os.path.splitext(os.path.basename(image_file))[0] + - "_ser_re.jpg") + if ser_re_predictor.predictor is not None: + img_res = draw_re_results( + image_file, re_res, font_path=args.vis_font_path) + img_save_path = os.path.join( + args.output, + os.path.splitext(os.path.basename(image_file))[0] + + "_ser_re.jpg") + else: + img_res = draw_ser_results( + image_file, re_res, font_path=args.vis_font_path) + img_save_path = os.path.join( + args.output, + os.path.splitext(os.path.basename(image_file))[0] + + "_ser.jpg") cv2.imwrite(img_save_path, img_res) logger.info("save vis result to {}".format(img_save_path)) diff --git a/ppstructure/predict_system.py b/ppstructure/predict_system.py index e7b389b50..417002d1e 100644 --- a/ppstructure/predict_system.py +++ b/ppstructure/predict_system.py @@ -30,7 +30,7 @@ from copy import deepcopy from ppocr.utils.utility import get_image_file_list, check_and_read from ppocr.utils.logging import get_logger -from ppocr.utils.visual import draw_re_results +from ppocr.utils.visual import draw_ser_results, draw_re_results from tools.infer.predict_system import TextSystem from ppstructure.layout.predict_layout import LayoutPredictor from ppstructure.table.predict_table import TableSystem, to_excel @@ -180,6 +180,7 @@ class StructureSystem(object): elif self.mode == 'kie': re_res, elapse = self.kie_predictor(img) time_dict['kie'] = elapse + time_dict['all'] = elapse return re_res[0], time_dict return None, None @@ -246,8 +247,12 @@ def main(args): draw_img = draw_structure_result(img, res, args.vis_font_path) save_structure_res(res, save_folder, img_name, index) elif structure_sys.mode == 'kie': - draw_img = draw_re_results( - img, res, font_path=args.vis_font_path) + if structure_sys.kie_predictor.predictor is not None: + draw_img = draw_re_results( + img, res, font_path=args.vis_font_path) + else: + draw_img = draw_ser_results( + img, res, font_path=args.vis_font_path) with open( os.path.join(save_folder, img_name,