fix_uniform

zhangyubo0722 2024-09-25 11:08:39 +00:00 committed by Tingquan Gao
parent a9a730e3b3
commit 7f35c77027
2 changed files with 12 additions and 5 deletions


@@ -236,6 +236,8 @@ def dump_infer_config(inference_config, path):
     setup_orderdict()
     infer_cfg = OrderedDict()
     config = copy.deepcopy(inference_config)
+    if config["Global"].get("pdx_model_name", None):
+        infer_cfg["Global"] = {"model_name": config["Global"]["pdx_model_name"]}
     if config.get("Infer"):
         transforms = config["Infer"]["transforms"]
     elif config["DataLoader"]["Eval"].get("Query"):
@@ -251,9 +253,15 @@ def dump_infer_config(inference_config, path):
         transform = next((item for item in transforms
                           if 'ResizeImage' in item), None)
         if transform:
-            dynamic_shapes = transform["ResizeImage"]["size"][0]
+            if isinstance(transform["ResizeImage"]["size"], list):
+                dynamic_shapes = transform["ResizeImage"]["size"][0]
+            elif isinstance(transform["ResizeImage"]["size"], int):
+                dynamic_shapes = transform["ResizeImage"]["size"]
+            else:
+                raise ValueError(
+                    "ResizeImage size must be either a list or an int.")
         else:
-            dynamic_shapes = 224
+            raise ValueError("No valid transform found.")
     # Configuration required config for high-performance inference.
     if config["Global"].get("hpi_config_path", None):
         hpi_config = convert_to_dict(
@@ -272,9 +280,6 @@ def dump_infer_config(inference_config, path):
                 hpi_config["Hpi"]["backend_config"]["tensorrt"][
                     "max_batch_size"] = 1
         infer_cfg["Hpi"] = hpi_config["Hpi"]
-    if config["Global"].get("pdx_model_name", None):
-        infer_cfg["Global"] = {}
-        infer_cfg["Global"]["model_name"] = config["Global"]["pdx_model_name"]
     for transform in transforms:
         if "NormalizeImage" in transform:
             transform["NormalizeImage"]["channel_num"] = 3


@@ -215,6 +215,8 @@ def save_model_info(model_info, save_path, prefix):
     """
     save model info to the target path
     """
+    if paddle.distributed.get_rank() != 0:
+        return
     save_path = os.path.join(save_path, prefix)
     if not os.path.exists(save_path):
         os.makedirs(save_path)
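
This hunk adds a distributed-rank guard to save_model_info so that only the rank-0 process writes the model info, instead of every worker racing to create the same directory and file during multi-GPU training. A minimal sketch of the pattern, assuming paddle is installed and distributed training is initialized by the surrounding script; the body below is abbreviated, and only the guard mirrors the diff:

    import os
    import paddle

    def save_model_info(model_info, save_path, prefix):
        # Every process calls this helper during distributed training;
        # without the guard, all ranks would write the same files.
        if paddle.distributed.get_rank() != 0:
            return
        save_path = os.path.join(save_path, prefix)
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        # ... dump model_info into a file under save_path (omitted) ...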