diff --git a/ppocr/utils/export_model.py b/ppocr/utils/export_model.py index 2c6b921e5..91c41dda3 100644 --- a/ppocr/utils/export_model.py +++ b/ppocr/utils/export_model.py @@ -318,42 +318,31 @@ def dynamic_to_static(model, arch_config, logger, input_shape=None): def export_single_model( - model, arch_config, save_path, logger, yaml_path, input_shape=None, quanter=None + model, + arch_config, + save_path, + logger, + yaml_path, + config, + input_shape=None, + quanter=None, ): model = dynamic_to_static(model, arch_config, logger, input_shape) if quanter is None: - paddle_version = version.parse(paddle.__version__) - if ( - paddle_version >= version.parse("3.0.0b2") - or paddle_version == version.parse("0.0.0") - ) and os.environ.get("FLAGS_enable_pir_api", None) not in ["0", "False"]: - save_path = os.path.dirname(save_path) - for enable_pir in [True, False]: - if not enable_pir: - save_path_no_pir = os.path.join(save_path, "inference") - model.forward.rollback() - with paddle.pir_utils.OldIrGuard(): - model = dynamic_to_static( - model, arch_config, logger, input_shape - ) - paddle.jit.save(model, save_path_no_pir) - else: - save_path_pir = os.path.join( - os.path.dirname(save_path), - f"{os.path.basename(save_path)}_pir", - "inference", - ) - paddle.jit.save(model, save_path_pir) - shutil.copy( - yaml_path, - os.path.join( - os.path.dirname(save_path_pir), os.path.basename(yaml_path) - ), - ) - else: + if config["Global"].get("export_with_pir", False): + paddle_version = version.parse(paddle.__version__) + assert ( + paddle_version >= version.parse("3.0.0b2") + or paddle_version == version.parse("0.0.0") + ) and os.environ.get("FLAGS_enable_pir_api", None) not in ["0", "False"] paddle.jit.save(model, save_path) + else: + model.forward.rollback() + with paddle.pir_utils.OldIrGuard(): + model = dynamic_to_static(model, arch_config, logger, input_shape) + paddle.jit.save(model, save_path) else: quanter.save_quantized_model(model, save_path) logger.info("inference model is saved to {}".format(save_path)) @@ -472,9 +461,16 @@ def export(config, base_model=None, save_path=None): sub_model_save_path, logger, yaml_path, + config, ) else: save_path = os.path.join(save_path, "inference") export_single_model( - model, arch_config, save_path, logger, yaml_path, input_shape=input_shape + model, + arch_config, + save_path, + logger, + yaml_path, + config, + input_shape=input_shape, )