add save and load

pull/167/head
WuHaobo 2020-06-12 10:55:05 +08:00
parent f1c0a59f0d
commit f5fd4f1fd4
1 changed files with 28 additions and 9 deletions

View File

@ -74,7 +74,9 @@ def load_params(exe, prog, path, ignore_params=None):
raise ValueError("Model pretrain path {} does not "
"exists.".format(path))
logger.info(logger.coloring('Loading parameters from {}...'.format(path), 'HEADER'))
logger.info(
logger.coloring('Loading parameters from {}...'.format(path),
'HEADER'))
ignore_set = set()
state = _load_state(path)
@ -100,37 +102,54 @@ def load_params(exe, prog, path, ignore_params=None):
if len(ignore_set) > 0:
for k in ignore_set:
if k in state:
logger.warning('variable {} is already excluded automatically'.format(k))
logger.warning(
'variable {} is already excluded automatically'.format(k))
del state[k]
fluid.io.set_program_state(prog, state)
def init_model(config, program, exe):
def init_model(config, net, optimizer):
"""
load model from checkpoint or pretrained_model
"""
checkpoints = config.get('checkpoints')
if checkpoints:
fluid.load(program, checkpoints, exe)
logger.info(logger.coloring("Finish initing model from {}".format(checkpoints),"HEADER"))
assert os.path.exists(checkpoints + ".pdparams"), \
"Given dir {}.pdparams not exist.".format(checkpoints)
assert os.path.exists(checkpoints + ".pdopt"), \
"Given dir {}.pdopt not exist.".format(checkpoints)
para_dict, opti_dict = fluid.dygraph.load_dygraph(checkpoints)
net.set_dict(para_dict)
optimizer.set_dict(opti_dict)
logger.info(
logger.coloring("Finish initing model from {}".format(checkpoints),
"HEADER"))
return
pretrained_model = config.get('pretrained_model')
if pretrained_model:
if not isinstance(pretrained_model, list):
pretrained_model = [pretrained_model]
# TODO: load pretrained_model
raise NotImplementedError
for pretrain in pretrained_model:
load_params(exe, program, pretrain)
logger.info(logger.coloring("Finish initing model from {}".format(pretrained_model),"HEADER"))
logger.info(
logger.coloring("Finish initing model from {}".format(
pretrained_model), "HEADER"))
def save_model(program, model_path, epoch_id, prefix='ppcls'):
def save_model(net, optimizer, model_path, epoch_id, prefix='ppcls'):
"""
save model to the target path
"""
model_path = os.path.join(model_path, str(epoch_id))
_mkdir_if_not_exist(model_path)
model_prefix = os.path.join(model_path, prefix)
fluid.save(program, model_prefix)
logger.info(logger.coloring("Already save model in {}".format(model_path),"HEADER"))
fluid.dygraph.save_dygraph(net.state_dict(), model_prefix)
fluid.dygraph.save_dygraph(optimizer.state_dict(), model_prefix)
logger.info(
logger.coloring("Already save model in {}".format(model_path),
"HEADER"))