Add regex support for selecting optim object

This commit is contained in:
tianyi1997 2023-02-07 14:58:13 +08:00 committed by HydrogenSulfate
parent 8d4a79e57f
commit 0cc6bc0bd3

View File

@ -16,6 +16,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
import copy
import paddle
from typing import Dict, List
@ -120,6 +121,11 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
optim_model.append(model_list[i])
elif hasattr(model_list[i], optim_scope):
optim_model.append(getattr(model_list[i], optim_scope))
else:
for name, layer in model_list[i].named_sublayers():
if len(layer.parameters()) != 0 \
and re.fullmatch(optim_scope, name):
optim_model.append(layer)
optim = getattr(optimizer, optim_name)(
learning_rate=lr, grad_clip=grad_clip,