diff --git a/ppcls/utils/save_load.py b/ppcls/utils/save_load.py index 4e27f12c1..7e7869d20 100644 --- a/ppcls/utils/save_load.py +++ b/ppcls/utils/save_load.py @@ -105,7 +105,7 @@ def init_model(config, net, optimizer=None, loss: paddle.nn.Layer=None): net.set_state_dict(para_dict) loss.set_state_dict(para_dict) for i in range(len(optimizer)): - optimizer[i].set_state_dict(opti_dict) + optimizer[i].set_state_dict(opti_dict[i]) logger.info("Finish load checkpoints from {}".format(checkpoints)) return metric_dict @@ -117,7 +117,7 @@ def init_model(config, net, optimizer=None, loss: paddle.nn.Layer=None): else: # common load load_dygraph_pretrain(net, path=pretrained_model) logger.info("Finish load pretrained model from {}".format( - pretrained_model)) + pretrained_model)) def save_model(net,