fix save distillation model (#567)
parent 0e6fe6f1cf
commit d0ecff1b5a
@@ -73,13 +73,13 @@ def load_dygraph_pretrain(model, path=None, load_static_weights=False):
 def load_distillation_model(model, pretrained_model, load_static_weights):
     logger.info("In distillation mode, teacher model will be "
                 "loaded firstly before student model.")
-    assert len(pretrained_model
-               ) == 2, "pretrained_model length should be 2 but got {}".format(
-                   len(pretrained_model))
-    assert len(
-        load_static_weights
-    ) == 2, "load_static_weights length should be 2 but got {}".format(
-        len(load_static_weights))
+
+    if not isinstance(pretrained_model, list):
+        pretrained_model = [pretrained_model]
+
+    if not isinstance(load_static_weights, list):
+        load_static_weights = [load_static_weights]
+
     teacher = model.teacher if hasattr(model,
                                        "teacher") else model._layers.teacher
     student = model.student if hasattr(model,
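The removed asserts required exactly two pretrained paths, so a teacher-only checkpoint could never be loaded; the new isinstance checks instead wrap a bare value in a one-element list. A minimal standalone sketch of that normalization (the helper name is ours, not part of the patch):

def normalize_to_list(value):
    # Same check the patched code inlines: wrap a scalar in a list
    # instead of asserting len(value) == 2.
    return value if isinstance(value, list) else [value]

# Teacher-only init, which the old asserts rejected, now passes through:
assert normalize_to_list("teacher.pdparams") == ["teacher.pdparams"]
# Teacher plus student still works unchanged:
assert normalize_to_list(["t.pdparams", "s.pdparams"]) == ["t.pdparams", "s.pdparams"]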
@@ -88,16 +88,16 @@ def load_distillation_model(model, pretrained_model, load_static_weights):
         teacher,
         path=pretrained_model[0],
         load_static_weights=load_static_weights[0])
-    logger.info(
-        logger.coloring("Finish initing teacher model from {}".format(
-            pretrained_model), "HEADER"))
-    load_dygraph_pretrain(
-        student,
-        path=pretrained_model[1],
-        load_static_weights=load_static_weights[1])
-    logger.info(
-        logger.coloring("Finish initing student model from {}".format(
-            pretrained_model), "HEADER"))
+    logger.info("Finish initing teacher model from {}".format(
+        pretrained_model))
+    # load student model
+    if len(pretrained_model) >= 2:
+        load_dygraph_pretrain(
+            student,
+            path=pretrained_model[1],
+            load_static_weights=load_static_weights[1])
+        logger.info("Finish initing student model from {}".format(
+            pretrained_model))
 
 
 def init_model(config, net, optimizer=None):
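With the length-2 asserts gone, pretrained_model[1] may not exist, so the student load becomes conditional on len(pretrained_model) >= 2, and the log lines drop logger.coloring in favor of plain logger.info. A rough sketch of the resulting control flow, with a stub loader standing in for load_dygraph_pretrain:

def load_for_distillation(pretrained_model):
    # Stubbed-down mirror of the patched flow; load() is a placeholder,
    # not a PaddleClas API.
    def load(part, path):
        print("loading {} from {}".format(part, path))

    load("teacher", pretrained_model[0])
    # load student model only when a second path was actually given
    if len(pretrained_model) >= 2:
        load("student", pretrained_model[1])

load_for_distillation(["t.pdparams"])                # teacher only
load_for_distillation(["t.pdparams", "s.pdparams"])  # teacher and student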
@@ -123,11 +123,7 @@ def init_model(config, net, optimizer=None):
     load_static_weights = config.get('load_static_weights', False)
     use_distillation = config.get('use_distillation', False)
     if pretrained_model:
-        if isinstance(pretrained_model,
-                      list):  # load distillation pretrained model
-            if not isinstance(load_static_weights, list):
-                load_static_weights = [load_static_weights] * len(
-                    pretrained_model)
+        if use_distillation:
             load_distillation_model(net, pretrained_model, load_static_weights)
         else:  # common load
             load_dygraph_pretrain(
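init_model previously inferred the distillation case from pretrained_model being a list (and normalized load_static_weights on the spot); the patch dispatches on the explicit use_distillation config flag instead and leaves all scalar-to-list handling to load_distillation_model. A sketch of the new dispatch, with print stubs in place of the real loaders:

def init_from_config(config, pretrained_model):
    # Illustrative stand-in for the patched init_model(); the branches
    # print instead of calling the real PaddleClas loaders.
    use_distillation = config.get('use_distillation', False)
    if pretrained_model:
        if use_distillation:
            print("distillation load:", pretrained_model)
        else:  # common load
            print("common load:", pretrained_model)

# A bare string now reaches the distillation branch when the flag is set:
init_from_config({'use_distillation': True}, "teacher.pdparams")
init_from_config({}, "backbone.pdparams")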