mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-06-03 21:53:39 +08:00
update hpi config (#14187)
This commit is contained in:
parent
aae2a3e6c9
commit
eaef336f9d
@ -39,42 +39,42 @@ def setup_orderdict():
|
|||||||
def dump_infer_config(config, path, logger):
|
def dump_infer_config(config, path, logger):
|
||||||
setup_orderdict()
|
setup_orderdict()
|
||||||
infer_cfg = OrderedDict()
|
infer_cfg = OrderedDict()
|
||||||
if config["Global"].get("hpi_config_path", None):
|
|
||||||
hpi_config = yaml.safe_load(open(config["Global"]["hpi_config_path"], "r"))
|
|
||||||
rec_resize_img_dict = next(
|
|
||||||
(
|
|
||||||
item
|
|
||||||
for item in config["Eval"]["dataset"]["transforms"]
|
|
||||||
if "RecResizeImg" in item
|
|
||||||
),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
if rec_resize_img_dict:
|
|
||||||
dynamic_shapes = [1] + rec_resize_img_dict["RecResizeImg"]["image_shape"]
|
|
||||||
if hpi_config["Hpi"]["backend_config"].get("paddle_tensorrt", None):
|
|
||||||
hpi_config["Hpi"]["backend_config"]["paddle_tensorrt"][
|
|
||||||
"dynamic_shapes"
|
|
||||||
]["x"] = [dynamic_shapes for i in range(3)]
|
|
||||||
hpi_config["Hpi"]["backend_config"]["paddle_tensorrt"][
|
|
||||||
"max_batch_size"
|
|
||||||
] = 1
|
|
||||||
if hpi_config["Hpi"]["backend_config"].get("tensorrt", None):
|
|
||||||
hpi_config["Hpi"]["backend_config"]["tensorrt"]["dynamic_shapes"][
|
|
||||||
"x"
|
|
||||||
] = [dynamic_shapes for i in range(3)]
|
|
||||||
hpi_config["Hpi"]["backend_config"]["tensorrt"]["max_batch_size"] = 1
|
|
||||||
else:
|
|
||||||
if hpi_config["Hpi"]["backend_config"].get("paddle_tensorrt", None):
|
|
||||||
hpi_config["Hpi"]["supported_backends"]["gpu"].remove("paddle_tensorrt")
|
|
||||||
del hpi_config["Hpi"]["backend_config"]["paddle_tensorrt"]
|
|
||||||
if hpi_config["Hpi"]["backend_config"].get("tensorrt", None):
|
|
||||||
hpi_config["Hpi"]["supported_backends"]["gpu"].remove("tensorrt")
|
|
||||||
del hpi_config["Hpi"]["backend_config"]["tensorrt"]
|
|
||||||
hpi_config["Hpi"]["selected_backends"]["gpu"] = "paddle_infer"
|
|
||||||
infer_cfg["Hpi"] = hpi_config["Hpi"]
|
|
||||||
if config["Global"].get("pdx_model_name", None):
|
if config["Global"].get("pdx_model_name", None):
|
||||||
infer_cfg["Global"] = {}
|
infer_cfg["Global"] = {"model_name": config["Global"]["pdx_model_name"]}
|
||||||
infer_cfg["Global"]["model_name"] = config["Global"]["pdx_model_name"]
|
if config["Global"].get("uniform_output_enabled", None):
|
||||||
|
arch_config = config["Architecture"]
|
||||||
|
if arch_config["algorithm"] in ["SVTR_LCNet", "SVTR_HGNet"]:
|
||||||
|
common_dynamic_shapes = {
|
||||||
|
"x": [[1, 3, 48, 320], [1, 3, 48, 320], [8, 3, 48, 320]]
|
||||||
|
}
|
||||||
|
elif arch_config["model_type"] == "det":
|
||||||
|
common_dynamic_shapes = {
|
||||||
|
"x": [[1, 3, 160, 160], [1, 3, 160, 160], [1, 3, 1280, 1280]]
|
||||||
|
}
|
||||||
|
elif arch_config["algorithm"] == "SLANet":
|
||||||
|
common_dynamic_shapes = {
|
||||||
|
"x": [[1, 3, 32, 32], [1, 3, 64, 448], [8, 3, 192, 672]]
|
||||||
|
}
|
||||||
|
elif arch_config["algorithm"] == "LaTeXOCR":
|
||||||
|
common_dynamic_shapes = {
|
||||||
|
"x": [[1, 3, 224, 224], [1, 3, 448, 448], [8, 3, 1280, 1280]]
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
common_dynamic_shapes = None
|
||||||
|
|
||||||
|
backend_keys = ["paddle_infer", "tensorrt"]
|
||||||
|
hpi_config = {
|
||||||
|
"backend_configs": {
|
||||||
|
key: {
|
||||||
|
(
|
||||||
|
"dynamic_shapes" if key == "tensorrt" else "trt_dynamic_shapes"
|
||||||
|
): common_dynamic_shapes
|
||||||
|
}
|
||||||
|
for key in backend_keys
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if common_dynamic_shapes:
|
||||||
|
infer_cfg["Hpi"] = hpi_config
|
||||||
|
|
||||||
infer_cfg["PreProcess"] = {"transform_ops": config["Eval"]["dataset"]["transforms"]}
|
infer_cfg["PreProcess"] = {"transform_ops": config["Eval"]["dataset"]["transforms"]}
|
||||||
postprocess = OrderedDict()
|
postprocess = OrderedDict()
|
||||||
@ -96,10 +96,8 @@ def dump_infer_config(config, path, logger):
|
|||||||
|
|
||||||
infer_cfg["PostProcess"] = postprocess
|
infer_cfg["PostProcess"] = postprocess
|
||||||
|
|
||||||
with open(path, "w") as f:
|
with open(path, "w", encoding="utf-8") as f:
|
||||||
yaml.dump(
|
yaml.dump(infer_cfg, f, default_flow_style=False, allow_unicode=True)
|
||||||
infer_cfg, f, default_flow_style=False, encoding="utf-8", allow_unicode=True
|
|
||||||
)
|
|
||||||
logger.info("Export inference config file to {}".format(os.path.join(path)))
|
logger.info("Export inference config file to {}".format(os.path.join(path)))
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user