mirror of
https://github.com/PaddlePaddle/PaddleClas.git
synced 2025-06-03 21:55:06 +08:00
fix_uniform
This commit is contained in:
parent
a9a730e3b3
commit
7f35c77027
@ -236,6 +236,8 @@ def dump_infer_config(inference_config, path):
|
|||||||
setup_orderdict()
|
setup_orderdict()
|
||||||
infer_cfg = OrderedDict()
|
infer_cfg = OrderedDict()
|
||||||
config = copy.deepcopy(inference_config)
|
config = copy.deepcopy(inference_config)
|
||||||
|
if config["Global"].get("pdx_model_name", None):
|
||||||
|
infer_cfg["Global"] = {"model_name": config["Global"]["pdx_model_name"]}
|
||||||
if config.get("Infer"):
|
if config.get("Infer"):
|
||||||
transforms = config["Infer"]["transforms"]
|
transforms = config["Infer"]["transforms"]
|
||||||
elif config["DataLoader"]["Eval"].get("Query"):
|
elif config["DataLoader"]["Eval"].get("Query"):
|
||||||
@ -251,9 +253,15 @@ def dump_infer_config(inference_config, path):
|
|||||||
transform = next((item for item in transforms
|
transform = next((item for item in transforms
|
||||||
if 'ResizeImage' in item), None)
|
if 'ResizeImage' in item), None)
|
||||||
if transform:
|
if transform:
|
||||||
|
if isinstance(transform["ResizeImage"]["size"], list):
|
||||||
dynamic_shapes = transform["ResizeImage"]["size"][0]
|
dynamic_shapes = transform["ResizeImage"]["size"][0]
|
||||||
|
elif isinstance(transform["ResizeImage"]["size"], int):
|
||||||
|
dynamic_shapes = transform["ResizeImage"]["size"]
|
||||||
else:
|
else:
|
||||||
dynamic_shapes = 224
|
raise ValueError(
|
||||||
|
"ResizeImage size must be either a list or an int.")
|
||||||
|
else:
|
||||||
|
raise ValueError("No valid transform found.")
|
||||||
# Configuration required config for high-performance inference.
|
# Configuration required config for high-performance inference.
|
||||||
if config["Global"].get("hpi_config_path", None):
|
if config["Global"].get("hpi_config_path", None):
|
||||||
hpi_config = convert_to_dict(
|
hpi_config = convert_to_dict(
|
||||||
@ -272,9 +280,6 @@ def dump_infer_config(inference_config, path):
|
|||||||
hpi_config["Hpi"]["backend_config"]["tensorrt"][
|
hpi_config["Hpi"]["backend_config"]["tensorrt"][
|
||||||
"max_batch_size"] = 1
|
"max_batch_size"] = 1
|
||||||
infer_cfg["Hpi"] = hpi_config["Hpi"]
|
infer_cfg["Hpi"] = hpi_config["Hpi"]
|
||||||
if config["Global"].get("pdx_model_name", None):
|
|
||||||
infer_cfg["Global"] = {}
|
|
||||||
infer_cfg["Global"]["model_name"] = config["Global"]["pdx_model_name"]
|
|
||||||
for transform in transforms:
|
for transform in transforms:
|
||||||
if "NormalizeImage" in transform:
|
if "NormalizeImage" in transform:
|
||||||
transform["NormalizeImage"]["channel_num"] = 3
|
transform["NormalizeImage"]["channel_num"] = 3
|
||||||
|
@ -215,6 +215,8 @@ def save_model_info(model_info, save_path, prefix):
|
|||||||
"""
|
"""
|
||||||
save model info to the target path
|
save model info to the target path
|
||||||
"""
|
"""
|
||||||
|
if paddle.distributed.get_rank() != 0:
|
||||||
|
return
|
||||||
save_path = os.path.join(save_path, prefix)
|
save_path = os.path.join(save_path, prefix)
|
||||||
if not os.path.exists(save_path):
|
if not os.path.exists(save_path):
|
||||||
os.makedirs(save_path)
|
os.makedirs(save_path)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user