mirror of
https://github.com/PaddlePaddle/PaddleClas.git
synced 2025-06-03 21:55:06 +08:00
save the inference model in json format by default
This commit is contained in:
parent
e5972f75ef
commit
d15bbf82cf
@ -19,6 +19,14 @@ import paddle
|
|||||||
from . import logger
|
from . import logger
|
||||||
|
|
||||||
|
|
||||||
|
# just to determine the inference model file format
|
||||||
|
def get_FLAGS_json_format_model():
|
||||||
|
# json format by default
|
||||||
|
return os.environ.get("FLAGS_json_format_model", "1").lower() in ("1", "true", "t")
|
||||||
|
|
||||||
|
FLAGS_json_format_model = get_FLAGS_json_format_model()
|
||||||
|
|
||||||
|
|
||||||
def save_predict_result(save_path, result):
|
def save_predict_result(save_path, result):
|
||||||
if os.path.splitext(save_path)[-1] == '':
|
if os.path.splitext(save_path)[-1] == '':
|
||||||
if save_path[-1] == "/":
|
if save_path[-1] == "/":
|
||||||
@ -51,9 +59,19 @@ def update_train_results(config,
|
|||||||
train_results_path = os.path.join(config["Global"]["output_dir"],
|
train_results_path = os.path.join(config["Global"]["output_dir"],
|
||||||
"train_result.json")
|
"train_result.json")
|
||||||
save_model_tag = ["pdparams", "pdopt", "pdstates"]
|
save_model_tag = ["pdparams", "pdopt", "pdstates"]
|
||||||
save_inference_tag = [
|
if FLAGS_json_format_model:
|
||||||
"inference_config", "pdmodel", "pdiparams", "pdiparams.info"
|
save_inference_files = {
|
||||||
]
|
"inference_config": "inference.yml",
|
||||||
|
"pdmodel": "inference.json",
|
||||||
|
"pdiparams": "inference.pdiparams",
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
save_inference_files = {
|
||||||
|
"inference_config": "inference.yml",
|
||||||
|
"pdmodel": "inference.pdmodel",
|
||||||
|
"pdiparams": "inference.pdiparams",
|
||||||
|
"pdiparams.info": "inference.pdiparams.info"
|
||||||
|
}
|
||||||
if ema:
|
if ema:
|
||||||
save_model_tag.append("pdema")
|
save_model_tag.append("pdema")
|
||||||
if os.path.exists(train_results_path):
|
if os.path.exists(train_results_path):
|
||||||
@ -81,10 +99,9 @@ def update_train_results(config,
|
|||||||
for tag in save_model_tag:
|
for tag in save_model_tag:
|
||||||
train_results["models"]["best"][tag] = os.path.join(
|
train_results["models"]["best"][tag] = os.path.join(
|
||||||
prefix, f"{prefix}.{tag}")
|
prefix, f"{prefix}.{tag}")
|
||||||
for tag in save_inference_tag:
|
for key in save_inference_files:
|
||||||
train_results["models"]["best"][tag] = os.path.join(
|
train_results["models"]["best"][key] = os.path.join(
|
||||||
prefix, "inference", f"inference.{tag}"
|
prefix, "inference", save_inference_files[key])
|
||||||
if tag != "inference_config" else "inference.yml")
|
|
||||||
else:
|
else:
|
||||||
for i in range(last_num - 1, 0, -1):
|
for i in range(last_num - 1, 0, -1):
|
||||||
train_results["models"][f"last_{i + 1}"] = train_results["models"][
|
train_results["models"][f"last_{i + 1}"] = train_results["models"][
|
||||||
@ -93,10 +110,9 @@ def update_train_results(config,
|
|||||||
for tag in save_model_tag:
|
for tag in save_model_tag:
|
||||||
train_results["models"][f"last_{1}"][tag] = os.path.join(
|
train_results["models"][f"last_{1}"][tag] = os.path.join(
|
||||||
prefix, f"{prefix}.{tag}")
|
prefix, f"{prefix}.{tag}")
|
||||||
for tag in save_inference_tag:
|
for key in save_inference_files:
|
||||||
train_results["models"][f"last_{1}"][tag] = os.path.join(
|
train_results["models"][f"last_{1}"][key] = os.path.join(
|
||||||
prefix, "inference", f"inference.{tag}"
|
prefix, "inference", save_inference_files[key])
|
||||||
if tag != "inference_config" else "inference.yml")
|
|
||||||
|
|
||||||
with open(train_results_path, "w") as fp:
|
with open(train_results_path, "w") as fp:
|
||||||
json.dump(train_results, fp)
|
json.dump(train_results, fp)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user