fix bug for static graph

pull/1868/head
HydrogenSulfate 2022-04-21 16:31:28 +08:00
parent 05770197c3
commit aa26a8c1d8
2 changed files with 11 additions and 3 deletions

View File

@ -224,7 +224,7 @@ class Engine(object):
# build optimizer
if self.mode == 'train':
self.optimizer, self.lr_sch = build_optimizer(
self.config, self.config["Global"]["epochs"],
self.config["Optimizer"], self.config["Global"]["epochs"],
len(self.train_dataloader),
[self.model, self.train_loss_func])

View File

@ -44,8 +44,7 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
# model_list is None in static graph
def build_optimizer(config, epochs, step_each_epoch, model_list=None):
config = copy.deepcopy(config)
optim_config = config["Optimizer"]
optim_config = copy.deepcopy(config)
if isinstance(optim_config, dict):
# convert {'name': xxx, **optim_cfg} to [{name: {scope: xxx, **optim_cfg}}]
optim_name = optim_config.pop("name")
@ -93,6 +92,15 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
else:
grad_clip = None
optim_model = []
# for static graph
if model_list is None:
optim = getattr(optimizer, optim_name)(
learning_rate=lr, grad_clip=grad_clip,
**optim_cfg)(model_list=optim_model)
return optim, lr
# for dynamic graph
for i in range(len(model_list)):
if len(model_list[i].parameters()) == 0:
continue