mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-06-03 21:53:39 +08:00
opt load model
This commit is contained in:
parent
134e70704e
commit
8ffa50feb7
@ -57,7 +57,7 @@ def load_model(config, model, optimizer=None):
|
||||
if checkpoints.endswith('.pdparams'):
|
||||
checkpoints = checkpoints.replace('.pdparams', '')
|
||||
assert os.path.exists(checkpoints + ".pdparams"), \
|
||||
"The {}.pdparams does not exists!".format(checkpoints)
|
||||
"The {}.pdparams is not exists!".format(checkpoints)
|
||||
|
||||
# load params from trained model
|
||||
params = paddle.load(checkpoints + '.pdparams')
|
||||
@ -67,6 +67,7 @@ def load_model(config, model, optimizer=None):
|
||||
if key not in params:
|
||||
logger.warning("{} not in loaded params {} !".format(
|
||||
key, params.keys()))
|
||||
continue
|
||||
pre_value = params[key]
|
||||
if list(value.shape) == list(pre_value.shape):
|
||||
new_state_dict[key] = pre_value
|
||||
@ -76,9 +77,14 @@ def load_model(config, model, optimizer=None):
|
||||
format(key, value.shape, pre_value.shape))
|
||||
model.set_state_dict(new_state_dict)
|
||||
|
||||
optim_dict = paddle.load(checkpoints + '.pdopt')
|
||||
if optimizer is not None:
|
||||
optimizer.set_state_dict(optim_dict)
|
||||
if os.path.exists(checkpoints + '.pdopt'):
|
||||
optim_dict = paddle.load(checkpoints + '.pdopt')
|
||||
optimizer.set_state_dict(optim_dict)
|
||||
else:
|
||||
logger.warning(
|
||||
"{}.pdopt is not exists, params of optimizer is not loaded".
|
||||
format(checkpoints))
|
||||
|
||||
if os.path.exists(checkpoints + '.states'):
|
||||
with open(checkpoints + '.states', 'rb') as f:
|
||||
@ -100,7 +106,7 @@ def load_pretrained_params(model, path):
|
||||
if path.endswith('.pdparams'):
|
||||
path = path.replace('.pdparams', '')
|
||||
assert os.path.exists(path + ".pdparams"), \
|
||||
"The {}.pdparams does not exists!".format(path)
|
||||
"The {}.pdparams is not exists!".format(path)
|
||||
|
||||
params = paddle.load(path + '.pdparams')
|
||||
state_dict = model.state_dict()
|
||||
|
Loading…
x
Reference in New Issue
Block a user