rename train result (#14217)
parent
0accd26000
commit
1d4e7a80a0
|
@ -268,7 +268,7 @@ def update_train_results(config, prefix, metric_info, done_flag=False, last_num=
|
|||
|
||||
assert last_num >= 1
|
||||
train_results_path = os.path.join(
|
||||
config["Global"]["save_model_dir"], "train_results.json"
|
||||
config["Global"]["save_model_dir"], "train_result.json"
|
||||
)
|
||||
save_model_tag = ["pdparams", "pdopt", "pdstates"]
|
||||
save_inference_tag = ["inference_config", "pdmodel", "pdiparams", "pdiparams.info"]
|
||||
|
|
|
@ -172,11 +172,11 @@ def main(config, device, logger, vdl_writer, seed):
|
|||
amp_custom_black_list = config["Global"].get("amp_custom_black_list", [])
|
||||
amp_custom_white_list = config["Global"].get("amp_custom_white_list", [])
|
||||
if os.path.exists(
|
||||
os.path.join(config["Global"]["save_model_dir"], "train_results.json")
|
||||
os.path.join(config["Global"]["save_model_dir"], "train_result.json")
|
||||
):
|
||||
try:
|
||||
os.remove(
|
||||
os.path.join(config["Global"]["save_model_dir"], "train_results.json")
|
||||
os.path.join(config["Global"]["save_model_dir"], "train_result.json")
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
|
Loading…
Reference in New Issue