parent
ef74ef5d3d
commit
0cc9870eb3
|
@ -20,6 +20,7 @@ import errno
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
import json
|
import json
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
import paddle
|
import paddle
|
||||||
|
|
||||||
|
@ -288,7 +289,8 @@ def update_train_results(config, prefix, metric_info, done_flag=False, last_num=
|
||||||
config["Global"]["save_model_dir"], "train_result.json"
|
config["Global"]["save_model_dir"], "train_result.json"
|
||||||
)
|
)
|
||||||
save_model_tag = ["pdparams", "pdopt", "pdstates"]
|
save_model_tag = ["pdparams", "pdopt", "pdstates"]
|
||||||
if FLAGS_json_format_model:
|
paddle_version = version.parse(paddle.__version__)
|
||||||
|
if FLAGS_json_format_model or paddle_version >= version.parse("3.0.0"):
|
||||||
save_inference_files = {
|
save_inference_files = {
|
||||||
"inference_config": "inference.yml",
|
"inference_config": "inference.yml",
|
||||||
"pdmodel": "inference.json",
|
"pdmodel": "inference.json",
|
||||||
|
|
Loading…
Reference in New Issue