fix_uniform
parent
a9a730e3b3
commit
7f35c77027
|
@ -236,6 +236,8 @@ def dump_infer_config(inference_config, path):
|
|||
setup_orderdict()
|
||||
infer_cfg = OrderedDict()
|
||||
config = copy.deepcopy(inference_config)
|
||||
if config["Global"].get("pdx_model_name", None):
|
||||
infer_cfg["Global"] = {"model_name": config["Global"]["pdx_model_name"]}
|
||||
if config.get("Infer"):
|
||||
transforms = config["Infer"]["transforms"]
|
||||
elif config["DataLoader"]["Eval"].get("Query"):
|
||||
|
@ -251,9 +253,15 @@ def dump_infer_config(inference_config, path):
|
|||
transform = next((item for item in transforms
|
||||
if 'ResizeImage' in item), None)
|
||||
if transform:
|
||||
if isinstance(transform["ResizeImage"]["size"], list):
|
||||
dynamic_shapes = transform["ResizeImage"]["size"][0]
|
||||
elif isinstance(transform["ResizeImage"]["size"], int):
|
||||
dynamic_shapes = transform["ResizeImage"]["size"]
|
||||
else:
|
||||
dynamic_shapes = 224
|
||||
raise ValueError(
|
||||
"ResizeImage size must be either a list or an int.")
|
||||
else:
|
||||
raise ValueError("No valid transform found.")
|
||||
# Configuration required config for high-performance inference.
|
||||
if config["Global"].get("hpi_config_path", None):
|
||||
hpi_config = convert_to_dict(
|
||||
|
@ -272,9 +280,6 @@ def dump_infer_config(inference_config, path):
|
|||
hpi_config["Hpi"]["backend_config"]["tensorrt"][
|
||||
"max_batch_size"] = 1
|
||||
infer_cfg["Hpi"] = hpi_config["Hpi"]
|
||||
if config["Global"].get("pdx_model_name", None):
|
||||
infer_cfg["Global"] = {}
|
||||
infer_cfg["Global"]["model_name"] = config["Global"]["pdx_model_name"]
|
||||
for transform in transforms:
|
||||
if "NormalizeImage" in transform:
|
||||
transform["NormalizeImage"]["channel_num"] = 3
|
||||
|
|
|
@ -215,6 +215,8 @@ def save_model_info(model_info, save_path, prefix):
|
|||
"""
|
||||
save model info to the target path
|
||||
"""
|
||||
if paddle.distributed.get_rank() != 0:
|
||||
return
|
||||
save_path = os.path.join(save_path, prefix)
|
||||
if not os.path.exists(save_path):
|
||||
os.makedirs(save_path)
|
||||
|
|
Loading…
Reference in New Issue