diff --git a/ppcls/utils/save_result.py b/ppcls/utils/save_result.py
index 5714676ca..bcd1b5076 100644
--- a/ppcls/utils/save_result.py
+++ b/ppcls/utils/save_result.py
@@ -19,6 +19,14 @@ import paddle
 
 from . import logger
 
+# just to determine the inference model file format
+def get_FLAGS_json_format_model():
+    # json format by default
+    return os.environ.get("FLAGS_json_format_model", "1").lower() in ("1", "true", "t")
+
+FLAGS_json_format_model = get_FLAGS_json_format_model()
+
+
 def save_predict_result(save_path, result):
     if os.path.splitext(save_path)[-1] == '':
         if save_path[-1] == "/":
@@ -51,9 +59,19 @@ def update_train_results(config,
     train_results_path = os.path.join(config["Global"]["output_dir"],
                                       "train_result.json")
     save_model_tag = ["pdparams", "pdopt", "pdstates"]
-    save_inference_tag = [
-        "inference_config", "pdmodel", "pdiparams", "pdiparams.info"
-    ]
+    if FLAGS_json_format_model:
+        save_inference_files = {
+            "inference_config": "inference.yml",
+            "pdmodel": "inference.json",
+            "pdiparams": "inference.pdiparams",
+        }
+    else:
+        save_inference_files = {
+            "inference_config": "inference.yml",
+            "pdmodel": "inference.pdmodel",
+            "pdiparams": "inference.pdiparams",
+            "pdiparams.info": "inference.pdiparams.info"
+        }
     if ema:
         save_model_tag.append("pdema")
     if os.path.exists(train_results_path):
@@ -81,10 +99,9 @@ def update_train_results(config,
         for tag in save_model_tag:
             train_results["models"]["best"][tag] = os.path.join(
                 prefix, f"{prefix}.{tag}")
-        for tag in save_inference_tag:
-            train_results["models"]["best"][tag] = os.path.join(
-                prefix, "inference", f"inference.{tag}"
-                if tag != "inference_config" else "inference.yml")
+        for key in save_inference_files:
+            train_results["models"]["best"][key] = os.path.join(
+                prefix, "inference", save_inference_files[key])
     else:
         for i in range(last_num - 1, 0, -1):
             train_results["models"][f"last_{i + 1}"] = train_results["models"][
@@ -93,10 +110,9 @@ def update_train_results(config,
         for tag in save_model_tag:
             train_results["models"][f"last_{1}"][tag] = os.path.join(
                 prefix, f"{prefix}.{tag}")
-        for tag in save_inference_tag:
-            train_results["models"][f"last_{1}"][tag] = os.path.join(
-                prefix, "inference", f"inference.{tag}"
-                if tag != "inference_config" else "inference.yml")
+        for key in save_inference_files:
+            train_results["models"][f"last_{1}"][key] = os.path.join(
+                prefix, "inference", save_inference_files[key])
 
     with open(train_results_path, "w") as fp:
         json.dump(train_results, fp)
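
For reference, a minimal standalone sketch (illustrative only, not part of the patch) of how the flag is resolved: FLAGS_json_format_model is read from the environment once at import time, defaults to the JSON format when unset, and any value other than "1"/"true"/"t" (case-insensitive) selects the legacy layout with inference.pdmodel and inference.pdiparams.info. The helper name uses_json_format below is hypothetical and simply mirrors get_FLAGS_json_format_model.

import os


def uses_json_format():
    # Unset, "1", "true" or "t" (any case) -> JSON format (inference.json).
    # Anything else -> legacy format (inference.pdmodel + inference.pdiparams.info).
    return os.environ.get("FLAGS_json_format_model", "1").lower() in ("1", "true", "t")


os.environ.pop("FLAGS_json_format_model", None)
print(uses_json_format())  # True: JSON format by default

os.environ["FLAGS_json_format_model"] = "0"
print(uses_json_format())  # False: legacy pdmodel layout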