add save and load

pull/167/head
WuHaobo 2020-06-12 10:55:05 +08:00
parent f1c0a59f0d
commit f5fd4f1fd4
1 changed file with 28 additions and 9 deletions


@@ -74,7 +74,9 @@ def load_params(exe, prog, path, ignore_params=None):
         raise ValueError("Model pretrain path {} does not "
                          "exists.".format(path))
-    logger.info(logger.coloring('Loading parameters from {}...'.format(path), 'HEADER'))
+    logger.info(
+        logger.coloring('Loading parameters from {}...'.format(path),
+                        'HEADER'))
     ignore_set = set()
     state = _load_state(path)
@@ -100,37 +102,54 @@ def load_params(exe, prog, path, ignore_params=None):
     if len(ignore_set) > 0:
         for k in ignore_set:
             if k in state:
-                logger.warning('variable {} is already excluded automatically'.format(k))
+                logger.warning(
+                    'variable {} is already excluded automatically'.format(k))
                 del state[k]
     fluid.io.set_program_state(prog, state)
 
 
-def init_model(config, program, exe):
+def init_model(config, net, optimizer):
     """
     load model from checkpoint or pretrained_model
     """
     checkpoints = config.get('checkpoints')
     if checkpoints:
-        fluid.load(program, checkpoints, exe)
-        logger.info(logger.coloring("Finish initing model from {}".format(checkpoints),"HEADER"))
+        assert os.path.exists(checkpoints + ".pdparams"), \
+            "Given dir {}.pdparams not exist.".format(checkpoints)
+        assert os.path.exists(checkpoints + ".pdopt"), \
+            "Given dir {}.pdopt not exist.".format(checkpoints)
+        para_dict, opti_dict = fluid.dygraph.load_dygraph(checkpoints)
+        net.set_dict(para_dict)
+        optimizer.set_dict(opti_dict)
+        logger.info(
+            logger.coloring("Finish initing model from {}".format(checkpoints),
+                            "HEADER"))
         return
 
     pretrained_model = config.get('pretrained_model')
     if pretrained_model:
         if not isinstance(pretrained_model, list):
             pretrained_model = [pretrained_model]
+        # TODO: load pretrained_model
+        raise NotImplementedError
         for pretrain in pretrained_model:
             load_params(exe, program, pretrain)
-        logger.info(logger.coloring("Finish initing model from {}".format(pretrained_model),"HEADER"))
+        logger.info(
+            logger.coloring("Finish initing model from {}".format(
+                pretrained_model), "HEADER"))
 
 
-def save_model(program, model_path, epoch_id, prefix='ppcls'):
+def save_model(net, optimizer, model_path, epoch_id, prefix='ppcls'):
     """
     save model to the target path
     """
     model_path = os.path.join(model_path, str(epoch_id))
     _mkdir_if_not_exist(model_path)
     model_prefix = os.path.join(model_path, prefix)
-    fluid.save(program, model_prefix)
-    logger.info(logger.coloring("Already save model in {}".format(model_path),"HEADER"))
+    fluid.dygraph.save_dygraph(net.state_dict(), model_prefix)
+    fluid.dygraph.save_dygraph(optimizer.state_dict(), model_prefix)
+    logger.info(
+        logger.coloring("Already save model in {}".format(model_path),
+                        "HEADER"))