pair param with key when load trained model params
parent
4a6f7ceca6
commit
7f1badf74c
|
@ -54,7 +54,7 @@ def load_model(config, model, optimizer=None):
|
|||
pretrained_model = global_config.get('pretrained_model')
|
||||
best_model_dict = {}
|
||||
if checkpoints:
|
||||
if checkpoints.endswith('pdparams'):
|
||||
if checkpoints.endswith('.pdparams'):
|
||||
checkpoints = checkpoints.replace('.pdparams', '')
|
||||
assert os.path.exists(checkpoints + ".pdparams"), \
|
||||
"The {}.pdparams does not exists!".format(checkpoints)
|
||||
|
@ -97,7 +97,7 @@ def load_model(config, model, optimizer=None):
|
|||
|
||||
def load_pretrained_params(model, path):
|
||||
logger = get_logger()
|
||||
if path.endswith('pdparams'):
|
||||
if path.endswith('.pdparams'):
|
||||
path = path.replace('.pdparams', '')
|
||||
assert os.path.exists(path + ".pdparams"), \
|
||||
"The {}.pdparams does not exists!".format(path)
|
||||
|
|
Loading…
Reference in New Issue