pair param with key when load trained model params

pull/4748/head
WenmuZhou 2021-11-24 10:23:06 +00:00
parent 4a6f7ceca6
commit 7f1badf74c
1 changed files with 2 additions and 2 deletions

View File

@ -54,7 +54,7 @@ def load_model(config, model, optimizer=None):
pretrained_model = global_config.get('pretrained_model')
best_model_dict = {}
if checkpoints:
if checkpoints.endswith('pdparams'):
if checkpoints.endswith('.pdparams'):
checkpoints = checkpoints.replace('.pdparams', '')
assert os.path.exists(checkpoints + ".pdparams"), \
"The {}.pdparams does not exists!".format(checkpoints)
@ -97,7 +97,7 @@ def load_model(config, model, optimizer=None):
def load_pretrained_params(model, path):
logger = get_logger()
if path.endswith('pdparams'):
if path.endswith('.pdparams'):
path = path.replace('.pdparams', '')
assert os.path.exists(path + ".pdparams"), \
"The {}.pdparams does not exists!".format(path)