rename train result (#14231)
parent
eaef336f9d
commit
54decf96d0
|
@ -273,7 +273,7 @@ def update_train_results(config, prefix, metric_info, done_flag=False, last_num=
|
|||
|
||||
assert last_num >= 1
|
||||
train_results_path = os.path.join(
|
||||
config["Global"]["save_model_dir"], "train_results.json"
|
||||
config["Global"]["save_model_dir"], "train_result.json"
|
||||
)
|
||||
save_model_tag = ["pdparams", "pdopt", "pdstates"]
|
||||
save_inference_tag = ["inference_config", "pdmodel", "pdiparams", "pdiparams.info"]
|
||||
|
|
|
@ -172,11 +172,11 @@ def main(config, device, logger, vdl_writer, seed):
|
|||
amp_custom_black_list = config["Global"].get("amp_custom_black_list", [])
|
||||
amp_custom_white_list = config["Global"].get("amp_custom_white_list", [])
|
||||
if os.path.exists(
|
||||
os.path.join(config["Global"]["save_model_dir"], "train_results.json")
|
||||
os.path.join(config["Global"]["save_model_dir"], "train_result.json")
|
||||
):
|
||||
try:
|
||||
os.remove(
|
||||
os.path.join(config["Global"]["save_model_dir"], "train_results.json")
|
||||
os.path.join(config["Global"]["save_model_dir"], "train_result.json")
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
|
Loading…
Reference in New Issue