fix_uniform

pull/3260/head
zhangyubo0722 2024-09-25 11:08:39 +00:00 committed by Tingquan Gao
parent a9a730e3b3
commit 7f35c77027
2 changed files with 12 additions and 5 deletions

View File

@ -236,6 +236,8 @@ def dump_infer_config(inference_config, path):
setup_orderdict()
infer_cfg = OrderedDict()
config = copy.deepcopy(inference_config)
if config["Global"].get("pdx_model_name", None):
infer_cfg["Global"] = {"model_name": config["Global"]["pdx_model_name"]}
if config.get("Infer"):
transforms = config["Infer"]["transforms"]
elif config["DataLoader"]["Eval"].get("Query"):
@ -251,9 +253,15 @@ def dump_infer_config(inference_config, path):
transform = next((item for item in transforms
if 'ResizeImage' in item), None)
if transform:
if isinstance(transform["ResizeImage"]["size"], list):
dynamic_shapes = transform["ResizeImage"]["size"][0]
elif isinstance(transform["ResizeImage"]["size"], int):
dynamic_shapes = transform["ResizeImage"]["size"]
else:
dynamic_shapes = 224
raise ValueError(
"ResizeImage size must be either a list or an int.")
else:
raise ValueError("No valid transform found.")
# Configuration required config for high-performance inference.
if config["Global"].get("hpi_config_path", None):
hpi_config = convert_to_dict(
@ -272,9 +280,6 @@ def dump_infer_config(inference_config, path):
hpi_config["Hpi"]["backend_config"]["tensorrt"][
"max_batch_size"] = 1
infer_cfg["Hpi"] = hpi_config["Hpi"]
if config["Global"].get("pdx_model_name", None):
infer_cfg["Global"] = {}
infer_cfg["Global"]["model_name"] = config["Global"]["pdx_model_name"]
for transform in transforms:
if "NormalizeImage" in transform:
transform["NormalizeImage"]["channel_num"] = 3

View File

@ -215,6 +215,8 @@ def save_model_info(model_info, save_path, prefix):
"""
save model info to the target path
"""
if paddle.distributed.get_rank() != 0:
return
save_path = os.path.join(save_path, prefix)
if not os.path.exists(save_path):
os.makedirs(save_path)