support export with pir and no pir (#14379)

pull/14434/head
zhangyubo0722 2024-12-19 20:16:26 +08:00 committed by GitHub
parent 04c989b7fe
commit 0697d248f8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 53 additions and 10 deletions

View File

@ -16,11 +16,13 @@ import os
import yaml
import json
import copy
import shutil
import paddle
import paddle.nn as nn
from paddle.jit import to_static
from collections import OrderedDict
from packaging import version
from argparse import ArgumentParser, RawDescriptionHelpFormatter
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
@ -39,13 +41,15 @@ def setup_orderdict():
def dump_infer_config(config, path, logger):
setup_orderdict()
infer_cfg = OrderedDict()
if not os.path.exists(os.path.dirname(path)):
os.makedirs(os.path.dirname(path))
if config["Global"].get("pdx_model_name", None):
infer_cfg["Global"] = {"model_name": config["Global"]["pdx_model_name"]}
if config["Global"].get("uniform_output_enabled", None):
arch_config = config["Architecture"]
if arch_config["algorithm"] in ["SVTR_LCNet", "SVTR_HGNet"]:
common_dynamic_shapes = {
"x": [[1, 3, 48, 320], [1, 3, 48, 320], [8, 3, 48, 320]]
"x": [[1, 3, 24, 160], [1, 3, 48, 320], [8, 3, 96, 640]]
}
elif arch_config["model_type"] == "det":
common_dynamic_shapes = {
@ -53,7 +57,7 @@ def dump_infer_config(config, path, logger):
}
elif arch_config["algorithm"] == "SLANet":
common_dynamic_shapes = {
"x": [[1, 3, 32, 32], [1, 3, 64, 448], [8, 3, 192, 672]]
"x": [[1, 3, 32, 32], [1, 3, 64, 448], [8, 3, 488, 488]]
}
elif arch_config["algorithm"] == "LaTeXOCR":
common_dynamic_shapes = {
@ -101,9 +105,7 @@ def dump_infer_config(config, path, logger):
logger.info("Export inference config file to {}".format(os.path.join(path)))
def export_single_model(
model, arch_config, save_path, logger, input_shape=None, quanter=None
):
def dynamic_to_static(model, arch_config, logger, input_shape=None):
if arch_config["algorithm"] == "SRN":
max_text_length = arch_config["Head"]["max_text_length"]
other_shape = [
@ -262,9 +264,46 @@ def export_single_model(
for layer in model.sublayers():
if hasattr(layer, "rep") and not getattr(layer, "is_repped"):
layer.rep()
return model
def export_single_model(
model, arch_config, save_path, logger, yaml_path, input_shape=None, quanter=None
):
model = dynamic_to_static(model, arch_config, logger, input_shape)
if quanter is None:
paddle.jit.save(model, save_path)
paddle_version = version.parse(paddle.__version__)
if (
paddle_version >= version.parse("3.0.0b2")
or paddle_version == version.parse("0.0.0")
) and os.environ.get("FLAGS_enable_pir_api", None) not in ["0", "False"]:
save_path = os.path.dirname(save_path)
for enable_pir in [True, False]:
if not enable_pir:
save_path_no_pir = os.path.join(save_path, "inference")
model.forward.rollback()
with paddle.pir_utils.OldIrGuard():
model = dynamic_to_static(
model, arch_config, logger, input_shape
)
paddle.jit.save(model, save_path_no_pir)
else:
save_path_pir = os.path.join(
os.path.dirname(save_path),
f"{os.path.basename(save_path)}_pir",
"inference",
)
paddle.jit.save(model, save_path_pir)
shutil.copy(
yaml_path,
os.path.join(
os.path.dirname(save_path_pir), os.path.basename(yaml_path)
),
)
else:
paddle.jit.save(model, save_path)
else:
quanter.save_quantized_model(model, save_path)
logger.info("inference model is saved to {}".format(save_path))
@ -362,7 +401,7 @@ def export(config, base_model=None, save_path=None):
input_shape = rec_rs[0]["ABINetRecResizeImg"]["image_shape"] if rec_rs else None
else:
input_shape = None
dump_infer_config(config, yaml_path, logger)
if arch_config["algorithm"] in [
"Distillation",
]: # distillation model
@ -370,11 +409,14 @@ def export(config, base_model=None, save_path=None):
for idx, name in enumerate(model.model_name_list):
sub_model_save_path = os.path.join(save_path, name, "inference")
export_single_model(
model.model_list[idx], archs[idx], sub_model_save_path, logger
model.model_list[idx],
archs[idx],
sub_model_save_path,
logger,
yaml_path,
)
else:
save_path = os.path.join(save_path, "inference")
export_single_model(
model, arch_config, save_path, logger, input_shape=input_shape
model, arch_config, save_path, logger, yaml_path, input_shape=input_shape
)
dump_infer_config(config, yaml_path, logger)

View File

@ -14,3 +14,4 @@ requests
albumentations==1.4.10
# to be compatible with albumentations
albucore==0.0.13
packaging