fix bug for static graph
parent 05770197c3
commit aa26a8c1d8
@@ -224,7 +224,7 @@ class Engine(object):
         # build optimizer
         if self.mode == 'train':
             self.optimizer, self.lr_sch = build_optimizer(
-                self.config, self.config["Global"]["epochs"],
+                self.config["Optimizer"], self.config["Global"]["epochs"],
                 len(self.train_dataloader),
                 [self.model, self.train_loss_func])
@@ -44,8 +44,7 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
 
 # model_list is None in static graph
 def build_optimizer(config, epochs, step_each_epoch, model_list=None):
-    config = copy.deepcopy(config)
-    optim_config = config["Optimizer"]
+    optim_config = copy.deepcopy(config)
     if isinstance(optim_config, dict):
         # convert {'name': xxx, **optim_cfg} to [{name: {scope: xxx, **optim_cfg}}]
         optim_name = optim_config.pop("name")
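The `isinstance(optim_config, dict)` branch above normalizes a single-optimizer config into the list form the rest of the builder iterates over. A minimal illustration of the two shapes, assuming a Momentum entry; the default `scope` value is an assumption, not shown in this hunk:

# Illustration only: the dict form written in the config vs. the
# normalized list form [{name: {scope: xxx, **optim_cfg}}].
before = {"name": "Momentum", "momentum": 0.9,
          "lr": {"name": "Cosine", "learning_rate": 0.1}}
# "all" as the default scope is an assumption; the hunk only shows the
# comment describing the shape, not the value
after = [{"Momentum": {"scope": "all", "momentum": 0.9,
                       "lr": {"name": "Cosine", "learning_rate": 0.1}}}]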
@@ -93,6 +92,15 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
         else:
             grad_clip = None
         optim_model = []
+
+        # for static graph
+        if model_list is None:
+            optim = getattr(optimizer, optim_name)(
+                learning_rate=lr, grad_clip=grad_clip,
+                **optim_cfg)(model_list=optim_model)
+            return optim, lr
+
+        # for dynamic graph
         for i in range(len(model_list)):
             if len(model_list[i].parameters()) == 0:
                 continue
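Taken together, the two hunks agree: `Engine` now passes the `Optimizer` sub-config directly, and `build_optimizer` dispatches on `model_list`, where `None` takes the new static-graph early return and a list of layers takes the existing dynamic-graph path. A sketch of the two call sites, assuming the PaddleClas layout where the builder lives in `ppcls.optimizer`; `model`, `train_loss_func`, and the numeric values are placeholders, not values from this commit:

from ppcls.optimizer import build_optimizer

optim_cfg = {"name": "Momentum", "momentum": 0.9,
             "lr": {"name": "Cosine", "learning_rate": 0.1}}

# dynamic graph: pass the models so their parameters can be collected
# (model and train_loss_func are placeholders)
optim, lr_sch = build_optimizer(
    optim_cfg, epochs=120, step_each_epoch=500,
    model_list=[model, train_loss_func])

# static graph: model_list defaults to None, so the early-return branch
# builds the optimizer over the empty optim_model list
optim, lr_sch = build_optimizer(
    optim_cfg, epochs=120, step_each_epoch=500)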