pair param with key when load trained model params
parent
91b79f6929
commit
4a6f7ceca6
|
@ -57,22 +57,23 @@ def load_model(config, model, optimizer=None):
|
|||
if checkpoints.endswith('pdparams'):
|
||||
checkpoints = checkpoints.replace('.pdparams', '')
|
||||
assert os.path.exists(checkpoints + ".pdparams"), \
|
||||
f"The {checkpoints}.pdparams does not exists!"
|
||||
|
||||
"The {}.pdparams does not exists!".format(checkpoints)
|
||||
|
||||
# load params from trained model
|
||||
params = paddle.load(checkpoints + '.pdparams')
|
||||
state_dict = model.state_dict()
|
||||
new_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
if key not in params:
|
||||
logger.warning(f"{key} not in loaded params {params.keys()} !")
|
||||
logger.warning("{} not in loaded params {} !".format(
|
||||
key, params.keys()))
|
||||
pre_value = params[key]
|
||||
if list(value.shape) == list(pre_value.shape):
|
||||
new_state_dict[key] = pre_value
|
||||
else:
|
||||
logger.warning(
|
||||
f"The shape of model params {key} {value.shape} not matched with loaded params shape {pre_value.shape} !"
|
||||
)
|
||||
"The shape of model params {} {} not matched with loaded params shape {} !".
|
||||
format(key, value.shape, pre_value.shape))
|
||||
model.set_state_dict(new_state_dict)
|
||||
|
||||
optim_dict = paddle.load(checkpoints + '.pdopt')
|
||||
|
@ -99,7 +100,7 @@ def load_pretrained_params(model, path):
|
|||
if path.endswith('pdparams'):
|
||||
path = path.replace('.pdparams', '')
|
||||
assert os.path.exists(path + ".pdparams"), \
|
||||
f"The {path}.pdparams does not exists!"
|
||||
"The {}.pdparams does not exists!".format(path)
|
||||
|
||||
params = paddle.load(path + '.pdparams')
|
||||
state_dict = model.state_dict()
|
||||
|
@ -109,10 +110,10 @@ def load_pretrained_params(model, path):
|
|||
new_state_dict[k1] = params[k2]
|
||||
else:
|
||||
logger.warning(
|
||||
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
|
||||
)
|
||||
"The shape of model params {} {} not matched with loaded params {} {} !".
|
||||
format(k1, state_dict[k1].shape, k2, params[k2].shape))
|
||||
model.set_state_dict(new_state_dict)
|
||||
logger.info(f"load pretrain successful from {path}")
|
||||
logger.info("load pretrain successful from {}".format(path))
|
||||
return model
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue