rename train result (#14231)

pull/14375/head
zhangyubo0722 2024-11-14 19:12:57 +08:00 committed by GitHub
parent eaef336f9d
commit 54decf96d0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 3 additions and 3 deletions

View File

@ -273,7 +273,7 @@ def update_train_results(config, prefix, metric_info, done_flag=False, last_num=
assert last_num >= 1
train_results_path = os.path.join(
config["Global"]["save_model_dir"], "train_results.json"
config["Global"]["save_model_dir"], "train_result.json"
)
save_model_tag = ["pdparams", "pdopt", "pdstates"]
save_inference_tag = ["inference_config", "pdmodel", "pdiparams", "pdiparams.info"]

View File

@ -172,11 +172,11 @@ def main(config, device, logger, vdl_writer, seed):
amp_custom_black_list = config["Global"].get("amp_custom_black_list", [])
amp_custom_white_list = config["Global"].get("amp_custom_white_list", [])
if os.path.exists(
os.path.join(config["Global"]["save_model_dir"], "train_results.json")
os.path.join(config["Global"]["save_model_dir"], "train_result.json")
):
try:
os.remove(
os.path.join(config["Global"]["save_model_dir"], "train_results.json")
os.path.join(config["Global"]["save_model_dir"], "train_result.json")
)
except:
pass