diff --git a/ppocr/utils/export_model.py b/ppocr/utils/export_model.py index c58e5be5f..327fa0542 100644 --- a/ppocr/utils/export_model.py +++ b/ppocr/utils/export_model.py @@ -16,11 +16,13 @@ import os import yaml import json import copy +import shutil import paddle import paddle.nn as nn from paddle.jit import to_static from collections import OrderedDict +from packaging import version from argparse import ArgumentParser, RawDescriptionHelpFormatter from ppocr.modeling.architectures import build_model from ppocr.postprocess import build_post_process @@ -39,13 +41,15 @@ def setup_orderdict(): def dump_infer_config(config, path, logger): setup_orderdict() infer_cfg = OrderedDict() + if not os.path.exists(os.path.dirname(path)): + os.makedirs(os.path.dirname(path)) if config["Global"].get("pdx_model_name", None): infer_cfg["Global"] = {"model_name": config["Global"]["pdx_model_name"]} if config["Global"].get("uniform_output_enabled", None): arch_config = config["Architecture"] if arch_config["algorithm"] in ["SVTR_LCNet", "SVTR_HGNet"]: common_dynamic_shapes = { - "x": [[1, 3, 48, 320], [1, 3, 48, 320], [8, 3, 48, 320]] + "x": [[1, 3, 24, 160], [1, 3, 48, 320], [8, 3, 96, 640]] } elif arch_config["model_type"] == "det": common_dynamic_shapes = { @@ -53,7 +57,7 @@ def dump_infer_config(config, path, logger): } elif arch_config["algorithm"] == "SLANet": common_dynamic_shapes = { - "x": [[1, 3, 32, 32], [1, 3, 64, 448], [8, 3, 192, 672]] + "x": [[1, 3, 32, 32], [1, 3, 64, 448], [8, 3, 488, 488]] } elif arch_config["algorithm"] == "LaTeXOCR": common_dynamic_shapes = { @@ -101,9 +105,7 @@ def dump_infer_config(config, path, logger): logger.info("Export inference config file to {}".format(os.path.join(path))) -def export_single_model( - model, arch_config, save_path, logger, input_shape=None, quanter=None -): +def dynamic_to_static(model, arch_config, logger, input_shape=None): if arch_config["algorithm"] == "SRN": max_text_length = arch_config["Head"]["max_text_length"] other_shape = [ @@ -262,9 +264,46 @@ def export_single_model( for layer in model.sublayers(): if hasattr(layer, "rep") and not getattr(layer, "is_repped"): layer.rep() + return model + + +def export_single_model( + model, arch_config, save_path, logger, yaml_path, input_shape=None, quanter=None +): + + model = dynamic_to_static(model, arch_config, logger, input_shape) if quanter is None: - paddle.jit.save(model, save_path) + paddle_version = version.parse(paddle.__version__) + if ( + paddle_version >= version.parse("3.0.0b2") + or paddle_version == version.parse("0.0.0") + ) and os.environ.get("FLAGS_enable_pir_api", None) not in ["0", "False"]: + save_path = os.path.dirname(save_path) + for enable_pir in [True, False]: + if not enable_pir: + save_path_no_pir = os.path.join(save_path, "inference") + model.forward.rollback() + with paddle.pir_utils.OldIrGuard(): + model = dynamic_to_static( + model, arch_config, logger, input_shape + ) + paddle.jit.save(model, save_path_no_pir) + else: + save_path_pir = os.path.join( + os.path.dirname(save_path), + f"{os.path.basename(save_path)}_pir", + "inference", + ) + paddle.jit.save(model, save_path_pir) + shutil.copy( + yaml_path, + os.path.join( + os.path.dirname(save_path_pir), os.path.basename(yaml_path) + ), + ) + else: + paddle.jit.save(model, save_path) else: quanter.save_quantized_model(model, save_path) logger.info("inference model is saved to {}".format(save_path)) @@ -362,7 +401,7 @@ def export(config, base_model=None, save_path=None): input_shape = rec_rs[0]["ABINetRecResizeImg"]["image_shape"] if rec_rs else None else: input_shape = None - + dump_infer_config(config, yaml_path, logger) if arch_config["algorithm"] in [ "Distillation", ]: # distillation model @@ -370,11 +409,14 @@ def export(config, base_model=None, save_path=None): for idx, name in enumerate(model.model_name_list): sub_model_save_path = os.path.join(save_path, name, "inference") export_single_model( - model.model_list[idx], archs[idx], sub_model_save_path, logger + model.model_list[idx], + archs[idx], + sub_model_save_path, + logger, + yaml_path, ) else: save_path = os.path.join(save_path, "inference") export_single_model( - model, arch_config, save_path, logger, input_shape=input_shape + model, arch_config, save_path, logger, yaml_path, input_shape=input_shape ) - dump_infer_config(config, yaml_path, logger) diff --git a/requirements.txt b/requirements.txt index 87d0d1c11..4256d25d4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,3 +14,4 @@ requests albumentations==1.4.10 # to be compatible with albumentations albucore==0.0.13 +packaging