Support multiple components when building the optimizer

This commit is contained in:
HydrogenSulfate 2022-04-22 13:15:49 +08:00
parent 8ae8934358
commit 74b4574367

View File

@ -67,12 +67,13 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
# optim_cfg = {optim_name: {'scope': xxx, **optim_cfg}} # optim_cfg = {optim_name: {'scope': xxx, **optim_cfg}}
# step1 build lr # step1 build lr
optim_name = list(optim_item.keys())[0] # get optim_name optim_name = list(optim_item.keys())[0] # get optim_name
optim_scope = optim_item[optim_name].pop('scope') # get optim_scope optim_scope_list = optim_item[optim_name].pop('scope').split(
' ') # get optim_scope list
optim_cfg = optim_item[optim_name] # get optim_cfg optim_cfg = optim_item[optim_name] # get optim_cfg
lr = build_lr_scheduler(optim_cfg.pop('lr'), epochs, step_each_epoch) lr = build_lr_scheduler(optim_cfg.pop('lr'), epochs, step_each_epoch)
logger.info("build lr ({}) for scope ({}) success..".format( logger.info("build lr ({}) for scope ({}) success..".format(
lr.__class__.__name__, optim_scope)) lr.__class__.__name__, optim_scope_list))
# step2 build regularization # step2 build regularization
if 'regularizer' in optim_cfg and optim_cfg['regularizer'] is not None: if 'regularizer' in optim_cfg and optim_cfg['regularizer'] is not None:
if 'weight_decay' in optim_cfg: if 'weight_decay' in optim_cfg:
@ -84,11 +85,13 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
reg = getattr(paddle.regularizer, reg_name)(**reg_config) reg = getattr(paddle.regularizer, reg_name)(**reg_config)
optim_cfg["weight_decay"] = reg optim_cfg["weight_decay"] = reg
logger.info("build regularizer ({}) for scope ({}) success..". logger.info("build regularizer ({}) for scope ({}) success..".
format(reg.__class__.__name__, optim_scope)) format(reg.__class__.__name__, optim_scope_list))
# step3 build optimizer # step3 build optimizer
if 'clip_norm' in optim_cfg: if 'clip_norm' in optim_cfg:
clip_norm = optim_cfg.pop('clip_norm') clip_norm = optim_cfg.pop('clip_norm')
grad_clip = paddle.nn.ClipGradByNorm(clip_norm=clip_norm) grad_clip = paddle.nn.ClipGradByNorm(clip_norm=clip_norm)
logger.info("build gradclip ({}) for scope ({}) success..".format(
grad_clip.__class__.__name__, optim_scope_list))
else: else:
grad_clip = None grad_clip = None
optim_model = [] optim_model = []
@ -101,33 +104,34 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
return optim, lr return optim, lr
# for dynamic graph # for dynamic graph
if optim_scope == "all": for scope in optim_scope_list:
optim_model = model_list if scope == "all":
elif optim_scope == "model": optim_model += model_list
optim_model = [model_list[0], ] elif scope == "model":
elif optim_scope in ["backbone", "neck", "head"]: optim_model += [model_list[0], ]
optim_model = [getattr(model_list[0], optim_scope, None), ] elif scope in ["backbone", "neck", "head"]:
elif optim_scope == "loss": optim_model += [getattr(model_list[0], scope, None), ]
optim_model = [model_list[1], ] elif scope == "loss":
else: optim_model += [model_list[1], ]
optim_model = [ else:
model_list[1].loss_func[i] optim_model += [
for i in range(len(model_list[1].loss_func)) model_list[1].loss_func[i]
if model_list[1].loss_func[i].__class__.__name__ == optim_scope for i in range(len(model_list[1].loss_func))
] if model_list[1].loss_func[i].__class__.__name__ == scope
]
# remove invalid items
optim_model = [ optim_model = [
optim_model[i] for i in range(len(optim_model)) optim_model[i] for i in range(len(optim_model))
if (optim_model[i] is not None if (optim_model[i] is not None
) and (len(optim_model[i].parameters()) > 0) ) and (len(optim_model[i].parameters()) > 0)
] ]
assert len(optim_model) > 0, \ assert len(optim_model) > 0, \
f"optim_model is empty for optim_scope({optim_scope})" f"optim_model is empty for optim_scope({optim_scope_list})"
optim = getattr(optimizer, optim_name)( optim = getattr(optimizer, optim_name)(
learning_rate=lr, grad_clip=grad_clip, learning_rate=lr, grad_clip=grad_clip,
**optim_cfg)(model_list=optim_model) **optim_cfg)(model_list=optim_model)
logger.info("build optimizer ({}) for scope ({}) success..".format( logger.info("build optimizer ({}) for scope ({}) success..".format(
optim.__class__.__name__, optim_scope)) optim.__class__.__name__, optim_scope_list))
optim_list.append(optim) optim_list.append(optim)
lr_list.append(lr) lr_list.append(lr)
return optim_list, lr_list return optim_list, lr_list