fix pdmodel to json (#15122)

Co-authored-by: zhangyubo0722 <zangyubo0722@163.com>
pull/15130/head
zhangyubo0722 2025-05-12 21:22:52 +08:00 committed by GitHub
parent ef74ef5d3d
commit 0cc9870eb3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 3 additions and 1 deletions

View File

@ -20,6 +20,7 @@ import errno
import os import os
import pickle import pickle
import json import json
from packaging import version
import paddle import paddle
@ -288,7 +289,8 @@ def update_train_results(config, prefix, metric_info, done_flag=False, last_num=
config["Global"]["save_model_dir"], "train_result.json" config["Global"]["save_model_dir"], "train_result.json"
) )
save_model_tag = ["pdparams", "pdopt", "pdstates"] save_model_tag = ["pdparams", "pdopt", "pdstates"]
if FLAGS_json_format_model: paddle_version = version.parse(paddle.__version__)
if FLAGS_json_format_model or paddle_version >= version.parse("3.0.0"):
save_inference_files = { save_inference_files = {
"inference_config": "inference.yml", "inference_config": "inference.yml",
"pdmodel": "inference.json", "pdmodel": "inference.json",