support export with pir and no pir (#14379)
parent 04c989b7fe
commit 0697d248f8
@@ -16,11 +16,13 @@ import os
 import yaml
 import json
 import copy
+import shutil
 import paddle
 import paddle.nn as nn
 from paddle.jit import to_static
 
 from collections import OrderedDict
+from packaging import version
 from argparse import ArgumentParser, RawDescriptionHelpFormatter
 from ppocr.modeling.architectures import build_model
 from ppocr.postprocess import build_post_process
@@ -39,13 +41,15 @@ def setup_orderdict():
 def dump_infer_config(config, path, logger):
     setup_orderdict()
     infer_cfg = OrderedDict()
+    if not os.path.exists(os.path.dirname(path)):
+        os.makedirs(os.path.dirname(path))
     if config["Global"].get("pdx_model_name", None):
         infer_cfg["Global"] = {"model_name": config["Global"]["pdx_model_name"]}
     if config["Global"].get("uniform_output_enabled", None):
         arch_config = config["Architecture"]
         if arch_config["algorithm"] in ["SVTR_LCNet", "SVTR_HGNet"]:
             common_dynamic_shapes = {
-                "x": [[1, 3, 48, 320], [1, 3, 48, 320], [8, 3, 48, 320]]
+                "x": [[1, 3, 24, 160], [1, 3, 48, 320], [8, 3, 96, 640]]
             }
         elif arch_config["model_type"] == "det":
             common_dynamic_shapes = {
@@ -53,7 +57,7 @@ def dump_infer_config(config, path, logger):
             }
         elif arch_config["algorithm"] == "SLANet":
             common_dynamic_shapes = {
-                "x": [[1, 3, 32, 32], [1, 3, 64, 448], [8, 3, 192, 672]]
+                "x": [[1, 3, 32, 32], [1, 3, 64, 448], [8, 3, 488, 488]]
             }
         elif arch_config["algorithm"] == "LaTeXOCR":
             common_dynamic_shapes = {
@@ -101,9 +105,7 @@ def dump_infer_config(config, path, logger):
     logger.info("Export inference config file to {}".format(os.path.join(path)))
 
 
-def export_single_model(
-    model, arch_config, save_path, logger, input_shape=None, quanter=None
-):
+def dynamic_to_static(model, arch_config, logger, input_shape=None):
     if arch_config["algorithm"] == "SRN":
         max_text_length = arch_config["Head"]["max_text_length"]
         other_shape = [
@@ -262,9 +264,46 @@ def export_single_model
     for layer in model.sublayers():
         if hasattr(layer, "rep") and not getattr(layer, "is_repped"):
             layer.rep()
+    return model
+
+
+def export_single_model(
+    model, arch_config, save_path, logger, yaml_path, input_shape=None, quanter=None
+):
+
+    model = dynamic_to_static(model, arch_config, logger, input_shape)
 
     if quanter is None:
-        paddle.jit.save(model, save_path)
+        paddle_version = version.parse(paddle.__version__)
+        if (
+            paddle_version >= version.parse("3.0.0b2")
+            or paddle_version == version.parse("0.0.0")
+        ) and os.environ.get("FLAGS_enable_pir_api", None) not in ["0", "False"]:
+            save_path = os.path.dirname(save_path)
+            for enable_pir in [True, False]:
+                if not enable_pir:
+                    save_path_no_pir = os.path.join(save_path, "inference")
+                    model.forward.rollback()
+                    with paddle.pir_utils.OldIrGuard():
+                        model = dynamic_to_static(
+                            model, arch_config, logger, input_shape
+                        )
+                        paddle.jit.save(model, save_path_no_pir)
+                else:
+                    save_path_pir = os.path.join(
+                        os.path.dirname(save_path),
+                        f"{os.path.basename(save_path)}_pir",
+                        "inference",
+                    )
+                    paddle.jit.save(model, save_path_pir)
+                    shutil.copy(
+                        yaml_path,
+                        os.path.join(
+                            os.path.dirname(save_path_pir), os.path.basename(yaml_path)
+                        ),
+                    )
+        else:
+            paddle.jit.save(model, save_path)
     else:
         quanter.save_quantized_model(model, save_path)
     logger.info("inference model is saved to {}".format(save_path))
@@ -362,7 +401,7 @@ def export(config, base_model=None, save_path=None):
         input_shape = rec_rs[0]["ABINetRecResizeImg"]["image_shape"] if rec_rs else None
     else:
         input_shape = None
-
+    dump_infer_config(config, yaml_path, logger)
     if arch_config["algorithm"] in [
         "Distillation",
     ]:  # distillation model
@@ -370,11 +409,14 @@ def export(config, base_model=None, save_path=None):
         for idx, name in enumerate(model.model_name_list):
             sub_model_save_path = os.path.join(save_path, name, "inference")
             export_single_model(
-                model.model_list[idx], archs[idx], sub_model_save_path, logger
+                model.model_list[idx],
+                archs[idx],
+                sub_model_save_path,
+                logger,
+                yaml_path,
             )
     else:
         save_path = os.path.join(save_path, "inference")
         export_single_model(
-            model, arch_config, save_path, logger, input_shape=input_shape
+            model, arch_config, save_path, logger, yaml_path, input_shape=input_shape
         )
-    dump_infer_config(config, yaml_path, logger)
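For orientation: export() now passes yaml_path through and calls dump_infer_config before export_single_model, because the PIR branch earlier copies that yaml next to the PIR model with shutil.copy. The path arithmetic in the patch produces a sibling "_pir" directory; a small sketch of where the two exports land, assuming a hypothetical save directory ./output/rec (paddle is not needed here):

import os

# Hypothetical value of Global.save_inference_dir; export() appends "inference".
save_path = os.path.join("./output/rec", "inference")

base_dir = os.path.dirname(save_path)  # the patch reassigns save_path to this
no_pir_prefix = os.path.join(base_dir, "inference")
pir_prefix = os.path.join(
    os.path.dirname(base_dir), f"{os.path.basename(base_dir)}_pir", "inference"
)

print(no_pir_prefix)  # ./output/rec/inference      (legacy-IR program)
print(pir_prefix)     # ./output/rec_pir/inference  (PIR program, plus the copied yaml)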
@@ -14,3 +14,4 @@ requests
 albumentations==1.4.10
 # to be compatible with albumentations
 albucore==0.0.13
+packaging
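The dependency hunk above (evidently the project's requirements file) adds packaging because the export script now parses and compares Paddle version strings. A quick check of the comparison semantics the gate relies on (PEP 440 pre-releases such as 3.0.0b2 sort before the final 3.0.0):

from packaging import version

assert version.parse("3.0.0b2") < version.parse("3.0.0")
assert version.parse("3.0.0rc1") > version.parse("3.0.0b2")
# The gate also matches "0.0.0" explicitly, the version reported by dev builds.
assert version.parse("0.0.0") == version.parse("0.0.0")
print("version comparisons behave as the export gate expects")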