save the inference model in json format by default

This commit is contained in:
gaotingquan 2025-04-15 09:04:04 +00:00 committed by Tingquan Gao
parent e5972f75ef
commit d15bbf82cf

View File

@ -19,6 +19,14 @@ import paddle
from . import logger
# just to determine the inference model file format
def get_FLAGS_json_format_model():
# json format by default
return os.environ.get("FLAGS_json_format_model", "1").lower() in ("1", "true", "t")
FLAGS_json_format_model = get_FLAGS_json_format_model()
def save_predict_result(save_path, result):
if os.path.splitext(save_path)[-1] == '':
if save_path[-1] == "/":
@ -51,9 +59,19 @@ def update_train_results(config,
train_results_path = os.path.join(config["Global"]["output_dir"],
"train_result.json")
save_model_tag = ["pdparams", "pdopt", "pdstates"]
save_inference_tag = [
"inference_config", "pdmodel", "pdiparams", "pdiparams.info"
]
if FLAGS_json_format_model:
save_inference_files = {
"inference_config": "inference.yml",
"pdmodel": "inference.json",
"pdiparams": "inference.pdiparams",
}
else:
save_inference_files = {
"inference_config": "inference.yml",
"pdmodel": "inference.pdmodel",
"pdiparams": "inference.pdiparams",
"pdiparams.info": "inference.pdiparams.info"
}
if ema:
save_model_tag.append("pdema")
if os.path.exists(train_results_path):
@ -81,10 +99,9 @@ def update_train_results(config,
for tag in save_model_tag:
train_results["models"]["best"][tag] = os.path.join(
prefix, f"{prefix}.{tag}")
for tag in save_inference_tag:
train_results["models"]["best"][tag] = os.path.join(
prefix, "inference", f"inference.{tag}"
if tag != "inference_config" else "inference.yml")
for key in save_inference_files:
train_results["models"]["best"][key] = os.path.join(
prefix, "inference", save_inference_files[key])
else:
for i in range(last_num - 1, 0, -1):
train_results["models"][f"last_{i + 1}"] = train_results["models"][
@ -93,10 +110,9 @@ def update_train_results(config,
for tag in save_model_tag:
train_results["models"][f"last_{1}"][tag] = os.path.join(
prefix, f"{prefix}.{tag}")
for tag in save_inference_tag:
train_results["models"][f"last_{1}"][tag] = os.path.join(
prefix, "inference", f"inference.{tag}"
if tag != "inference_config" else "inference.yml")
for key in save_inference_files:
train_results["models"][f"last_{1}"][key] = os.path.join(
prefix, "inference", save_inference_files[key])
with open(train_results_path, "w") as fp:
json.dump(train_results, fp)