add save and load
parent
f1c0a59f0d
commit
f5fd4f1fd4
|
@ -74,7 +74,9 @@ def load_params(exe, prog, path, ignore_params=None):
|
|||
raise ValueError("Model pretrain path {} does not "
|
||||
"exists.".format(path))
|
||||
|
||||
logger.info(logger.coloring('Loading parameters from {}...'.format(path), 'HEADER'))
|
||||
logger.info(
|
||||
logger.coloring('Loading parameters from {}...'.format(path),
|
||||
'HEADER'))
|
||||
|
||||
ignore_set = set()
|
||||
state = _load_state(path)
|
||||
|
@ -100,37 +102,54 @@ def load_params(exe, prog, path, ignore_params=None):
|
|||
if len(ignore_set) > 0:
|
||||
for k in ignore_set:
|
||||
if k in state:
|
||||
logger.warning('variable {} is already excluded automatically'.format(k))
|
||||
logger.warning(
|
||||
'variable {} is already excluded automatically'.format(k))
|
||||
del state[k]
|
||||
|
||||
fluid.io.set_program_state(prog, state)
|
||||
|
||||
|
||||
def init_model(config, program, exe):
|
||||
def init_model(config, net, optimizer):
|
||||
"""
|
||||
load model from checkpoint or pretrained_model
|
||||
"""
|
||||
checkpoints = config.get('checkpoints')
|
||||
if checkpoints:
|
||||
fluid.load(program, checkpoints, exe)
|
||||
logger.info(logger.coloring("Finish initing model from {}".format(checkpoints),"HEADER"))
|
||||
assert os.path.exists(checkpoints + ".pdparams"), \
|
||||
"Given dir {}.pdparams not exist.".format(checkpoints)
|
||||
assert os.path.exists(checkpoints + ".pdopt"), \
|
||||
"Given dir {}.pdopt not exist.".format(checkpoints)
|
||||
para_dict, opti_dict = fluid.dygraph.load_dygraph(checkpoints)
|
||||
net.set_dict(para_dict)
|
||||
optimizer.set_dict(opti_dict)
|
||||
logger.info(
|
||||
logger.coloring("Finish initing model from {}".format(checkpoints),
|
||||
"HEADER"))
|
||||
return
|
||||
|
||||
pretrained_model = config.get('pretrained_model')
|
||||
if pretrained_model:
|
||||
if not isinstance(pretrained_model, list):
|
||||
pretrained_model = [pretrained_model]
|
||||
# TODO: load pretrained_model
|
||||
raise NotImplementedError
|
||||
for pretrain in pretrained_model:
|
||||
load_params(exe, program, pretrain)
|
||||
logger.info(logger.coloring("Finish initing model from {}".format(pretrained_model),"HEADER"))
|
||||
logger.info(
|
||||
logger.coloring("Finish initing model from {}".format(
|
||||
pretrained_model), "HEADER"))
|
||||
|
||||
|
||||
def save_model(program, model_path, epoch_id, prefix='ppcls'):
|
||||
def save_model(net, optimizer, model_path, epoch_id, prefix='ppcls'):
|
||||
"""
|
||||
save model to the target path
|
||||
"""
|
||||
model_path = os.path.join(model_path, str(epoch_id))
|
||||
_mkdir_if_not_exist(model_path)
|
||||
model_prefix = os.path.join(model_path, prefix)
|
||||
fluid.save(program, model_prefix)
|
||||
logger.info(logger.coloring("Already save model in {}".format(model_path),"HEADER"))
|
||||
|
||||
fluid.dygraph.save_dygraph(net.state_dict(), model_prefix)
|
||||
fluid.dygraph.save_dygraph(optimizer.state_dict(), model_prefix)
|
||||
logger.info(
|
||||
logger.coloring("Already save model in {}".format(model_path),
|
||||
"HEADER"))
|
||||
|
|
Loading…
Reference in New Issue