diff --git a/ppcls/utils/config.py b/ppcls/utils/config.py index 15e408c22..5df438c40 100644 --- a/ppcls/utils/config.py +++ b/ppcls/utils/config.py @@ -236,6 +236,8 @@ def dump_infer_config(inference_config, path): setup_orderdict() infer_cfg = OrderedDict() config = copy.deepcopy(inference_config) + if config["Global"].get("pdx_model_name", None): + infer_cfg["Global"] = {"model_name": config["Global"]["pdx_model_name"]} if config.get("Infer"): transforms = config["Infer"]["transforms"] elif config["DataLoader"]["Eval"].get("Query"): @@ -251,9 +253,15 @@ def dump_infer_config(inference_config, path): transform = next((item for item in transforms if 'ResizeImage' in item), None) if transform: - dynamic_shapes = transform["ResizeImage"]["size"][0] + if isinstance(transform["ResizeImage"]["size"], list): + dynamic_shapes = transform["ResizeImage"]["size"][0] + elif isinstance(transform["ResizeImage"]["size"], int): + dynamic_shapes = transform["ResizeImage"]["size"] + else: + raise ValueError( + "ResizeImage size must be either a list or an int.") else: - dynamic_shapes = 224 + raise ValueError("No valid transform found.") # Configuration required config for high-performance inference. if config["Global"].get("hpi_config_path", None): hpi_config = convert_to_dict( @@ -272,9 +280,6 @@ def dump_infer_config(inference_config, path): hpi_config["Hpi"]["backend_config"]["tensorrt"][ "max_batch_size"] = 1 infer_cfg["Hpi"] = hpi_config["Hpi"] - if config["Global"].get("pdx_model_name", None): - infer_cfg["Global"] = {} - infer_cfg["Global"]["model_name"] = config["Global"]["pdx_model_name"] for transform in transforms: if "NormalizeImage" in transform: transform["NormalizeImage"]["channel_num"] = 3 diff --git a/ppcls/utils/save_load.py b/ppcls/utils/save_load.py index 470a96f29..13ccab2ad 100644 --- a/ppcls/utils/save_load.py +++ b/ppcls/utils/save_load.py @@ -215,6 +215,8 @@ def save_model_info(model_info, save_path, prefix): """ save model info to the target path """ + if paddle.distributed.get_rank() != 0: + return save_path = os.path.join(save_path, prefix) if not os.path.exists(save_path): os.makedirs(save_path)