mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-06-03 21:53:39 +08:00
rc版本适配
This commit is contained in:
parent
44840726ff
commit
4d775dc98f
@ -68,11 +68,11 @@ def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False):
|
|||||||
param_state_dict[key] = pre_state_dict[weight_name]
|
param_state_dict[key] = pre_state_dict[weight_name]
|
||||||
else:
|
else:
|
||||||
param_state_dict[key] = model_dict[key]
|
param_state_dict[key] = model_dict[key]
|
||||||
model.set_dict(param_state_dict)
|
model.set_state_dict(param_state_dict)
|
||||||
return
|
return
|
||||||
|
|
||||||
param_state_dict, optim_state_dict = paddle.load(path)
|
param_state_dict = paddle.load(path + '.pdparams')
|
||||||
model.set_dict(param_state_dict)
|
model.set_state_dict(param_state_dict)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
@ -91,7 +91,7 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
|
|||||||
"Given dir {}.pdopt not exist.".format(checkpoints)
|
"Given dir {}.pdopt not exist.".format(checkpoints)
|
||||||
para_dict = paddle.load(checkpoints + '.pdparams')
|
para_dict = paddle.load(checkpoints + '.pdparams')
|
||||||
opti_dict = paddle.load(checkpoints + '.pdopt')
|
opti_dict = paddle.load(checkpoints + '.pdopt')
|
||||||
model.set_dict(para_dict)
|
model.set_state_dict(para_dict)
|
||||||
if optimizer is not None:
|
if optimizer is not None:
|
||||||
optimizer.set_state_dict(opti_dict)
|
optimizer.set_state_dict(opti_dict)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user