mirror of
https://github.com/PaddlePaddle/PaddleClas.git
synced 2025-06-03 21:55:06 +08:00
save the inference model in json format by default
This commit is contained in:
parent
e5972f75ef
commit
d15bbf82cf
@ -19,6 +19,14 @@ import paddle
|
||||
from . import logger
|
||||
|
||||
|
||||
# just to determine the inference model file format
|
||||
def get_FLAGS_json_format_model():
|
||||
# json format by default
|
||||
return os.environ.get("FLAGS_json_format_model", "1").lower() in ("1", "true", "t")
|
||||
|
||||
FLAGS_json_format_model = get_FLAGS_json_format_model()
|
||||
|
||||
|
||||
def save_predict_result(save_path, result):
|
||||
if os.path.splitext(save_path)[-1] == '':
|
||||
if save_path[-1] == "/":
|
||||
@ -51,9 +59,19 @@ def update_train_results(config,
|
||||
train_results_path = os.path.join(config["Global"]["output_dir"],
|
||||
"train_result.json")
|
||||
save_model_tag = ["pdparams", "pdopt", "pdstates"]
|
||||
save_inference_tag = [
|
||||
"inference_config", "pdmodel", "pdiparams", "pdiparams.info"
|
||||
]
|
||||
if FLAGS_json_format_model:
|
||||
save_inference_files = {
|
||||
"inference_config": "inference.yml",
|
||||
"pdmodel": "inference.json",
|
||||
"pdiparams": "inference.pdiparams",
|
||||
}
|
||||
else:
|
||||
save_inference_files = {
|
||||
"inference_config": "inference.yml",
|
||||
"pdmodel": "inference.pdmodel",
|
||||
"pdiparams": "inference.pdiparams",
|
||||
"pdiparams.info": "inference.pdiparams.info"
|
||||
}
|
||||
if ema:
|
||||
save_model_tag.append("pdema")
|
||||
if os.path.exists(train_results_path):
|
||||
@ -81,10 +99,9 @@ def update_train_results(config,
|
||||
for tag in save_model_tag:
|
||||
train_results["models"]["best"][tag] = os.path.join(
|
||||
prefix, f"{prefix}.{tag}")
|
||||
for tag in save_inference_tag:
|
||||
train_results["models"]["best"][tag] = os.path.join(
|
||||
prefix, "inference", f"inference.{tag}"
|
||||
if tag != "inference_config" else "inference.yml")
|
||||
for key in save_inference_files:
|
||||
train_results["models"]["best"][key] = os.path.join(
|
||||
prefix, "inference", save_inference_files[key])
|
||||
else:
|
||||
for i in range(last_num - 1, 0, -1):
|
||||
train_results["models"][f"last_{i + 1}"] = train_results["models"][
|
||||
@ -93,10 +110,9 @@ def update_train_results(config,
|
||||
for tag in save_model_tag:
|
||||
train_results["models"][f"last_{1}"][tag] = os.path.join(
|
||||
prefix, f"{prefix}.{tag}")
|
||||
for tag in save_inference_tag:
|
||||
train_results["models"][f"last_{1}"][tag] = os.path.join(
|
||||
prefix, "inference", f"inference.{tag}"
|
||||
if tag != "inference_config" else "inference.yml")
|
||||
for key in save_inference_files:
|
||||
train_results["models"][f"last_{1}"][key] = os.path.join(
|
||||
prefix, "inference", save_inference_files[key])
|
||||
|
||||
with open(train_results_path, "w") as fp:
|
||||
json.dump(train_results, fp)
|
||||
|
Loading…
x
Reference in New Issue
Block a user