# PaddleClas/ppcls/utils/save_result.py
# (snapshot metadata: 2025-05-12 21:21:45 +08:00, 123 lines, 4.6 KiB, Python)
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import json
import yaml
import paddle
from packaging import version
from . import logger
# just to determine the inference model file format
def get_FLAGS_json_format_model():
    """Read the ``FLAGS_json_format_model`` environment variable as a boolean.

    Returns True when the variable is unset (JSON format is the default) or
    set to one of the truthy spellings "1", "true", "t" (case-insensitive).
    """
    raw_flag = os.environ.get("FLAGS_json_format_model", "1")
    return raw_flag.lower() in {"1", "true", "t"}


# Evaluated once at import time; consumed by update_train_results below.
FLAGS_json_format_model = get_FLAGS_json_format_model()
def save_predict_result(save_path, result):
    """Serialize ``result`` to ``save_path`` as a JSON file.

    Args:
        save_path (str): Destination path. An extension-less path gets
            ``.json`` appended (trailing "/" characters are stripped first);
            a ``.json`` path is used as-is; any other extension is rejected.
        result: A JSON-serializable object to dump.

    Raises:
        ValueError: If ``save_path`` is empty or carries a non-JSON extension.
    """
    ext = os.path.splitext(save_path)[-1]
    if ext == '':
        # Normalize a directory-style path ("out/" -> "out") before
        # appending the default extension. rstrip also guards against an
        # empty string, which previously raised an opaque IndexError.
        save_path = save_path.rstrip('/')
        if not save_path:
            raise ValueError(
                f"{save_path} is invalid input path, only files in json format are supported."
            )
        save_path = save_path + '.json'
    elif ext != '.json':
        # ValueError is a subclass of Exception, so existing callers that
        # caught the old bare Exception still work.
        raise ValueError(
            f"{save_path} is invalid input path, only files in json format are supported."
        )
    if os.path.exists(save_path):
        logger.warning(f"The file {save_path} will be overwritten.")
    with open(save_path, 'w', encoding='utf-8') as f:
        json.dump(result, f)
def update_train_results(config,
                         prefix,
                         metric_info,
                         done_flag=False,
                         last_num=5,
                         ema=False):
    """Update ``train_result.json`` under ``Global.output_dir`` with the
    latest checkpoint/inference artifact paths and metric score.

    Args:
        config (dict): Parsed training config; reads ``Global`` and
            (optionally) ``Infer.PostProcess.class_id_map_file``.
        prefix (str): Checkpoint prefix, e.g. "best_model" or an epoch tag.
            "best_model" updates the "best" slot; anything else rotates the
            rolling ``last_1`` .. ``last_{last_num}`` window.
        metric_info (dict): Must contain key "metric" (the score to record).
        done_flag (bool): Whether training has finished.
        last_num (int): Size of the rolling window of recent checkpoints.
        ema (bool): Whether an EMA weights file ("pdema") is also saved.
    """
    # Only the rank-0 process writes the summary in distributed training.
    if paddle.distributed.get_rank() != 0:
        return
    assert last_num >= 1
    train_results_path = os.path.join(config["Global"]["output_dir"],
                                      "train_result.json")
    save_model_tag = ["pdparams", "pdopt", "pdstates"]
    if ema:
        save_model_tag.append("pdema")
    # Paddle >= 3.0 (or the FLAGS_json_format_model env flag) saves the
    # program in JSON format; older versions use the binary pdmodel format
    # plus an extra .pdiparams.info file.
    paddle_version = version.parse(paddle.__version__)
    if FLAGS_json_format_model or paddle_version >= version.parse("3.0.0"):
        save_inference_files = {
            "inference_config": "inference.yml",
            "pdmodel": "inference.json",
            "pdiparams": "inference.pdiparams",
        }
    else:
        save_inference_files = {
            "inference_config": "inference.yml",
            "pdmodel": "inference.pdmodel",
            "pdiparams": "inference.pdiparams",
            "pdiparams.info": "inference.pdiparams.info"
        }
    if os.path.exists(train_results_path):
        with open(train_results_path, "r", encoding="utf-8") as fp:
            train_results = json.load(fp)
    else:
        # First call: initialize static metadata and the empty model slots.
        train_results = {}
        train_results["model_name"] = config["Global"].get("pdx_model_name",
                                                           None)
        # BUGFIX: the config section is "Infer" (capitalized). The previous
        # lookup used lowercase "infer", which never exists, so label_dict
        # was always set to "".
        if config.get("Infer", None):
            train_results["label_dict"] = config["Infer"]["PostProcess"].get(
                "class_id_map_file", "")
        else:
            train_results["label_dict"] = ""
        train_results["train_log"] = "train.log"
        train_results["visualdl_log"] = ""
        train_results["config"] = "config.yaml"
        train_results["models"] = {}
        for i in range(1, last_num + 1):
            train_results["models"][f"last_{i}"] = {}
        train_results["models"]["best"] = {}
    train_results["done_flag"] = done_flag
    if prefix == "best_model":
        train_results["models"]["best"]["score"] = metric_info["metric"]
        for tag in save_model_tag:
            train_results["models"]["best"][tag] = os.path.join(
                prefix, f"{prefix}.{tag}")
        for key in save_inference_files:
            train_results["models"]["best"][key] = os.path.join(
                prefix, "inference", save_inference_files[key])
    else:
        # Shift the rolling window: last_i -> last_{i+1}, dropping the
        # oldest entry, then record the new checkpoint as last_1.
        for i in range(last_num - 1, 0, -1):
            train_results["models"][f"last_{i + 1}"] = train_results["models"][
                f"last_{i}"].copy()
        train_results["models"]["last_1"]["score"] = metric_info["metric"]
        for tag in save_model_tag:
            train_results["models"]["last_1"][tag] = os.path.join(
                prefix, f"{prefix}.{tag}")
        for key in save_inference_files:
            train_results["models"]["last_1"][key] = os.path.join(
                prefix, "inference", save_inference_files[key])
    with open(train_results_path, "w", encoding="utf-8") as fp:
        json.dump(train_results, fp)