refine optimizer/__init__.py

pull/1853/head
HydrogenSulfate 2022-04-20 13:10:31 +08:00
parent 80b8ca3f23
commit 15242df170
1 changed file with 17 additions and 3 deletions


@@ -58,6 +58,12 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
         }]
     optim_list = []
     lr_list = []
+    """NOTE:
+    Currently only supports the optim objects below.
+    1. single optimizer config.
+    2. next level under Arch, such as Arch.backbone, Arch.neck, Arch.head.
+    3. loss which has parameters, such as CenterLoss.
+    """
     for optim_item in optim_config:
         # optim_cfg = {optim_name: {scope: xxx, **optim_cfg}}
         # step1 build lr
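
Note: the three cases listed in the NOTE above map onto entries of optim_config shaped like the comment in step 1, {optim_name: {scope: xxx, **optim_cfg}}. A minimal sketch of what such a list might look like (the optimizer names, scope values, and keys below are illustrative assumptions, not taken from this commit):

    # illustrative only: each entry is {optim_name: {"scope": ..., **optim_cfg}}
    optim_config = [
        {"Momentum": {"scope": "all"}},       # case 1: a single optimizer for the whole model
        {"Momentum": {"scope": "backbone"}},  # case 2: a sub-module under Arch, i.e. Arch.backbone
        {"SGD": {"scope": "CenterLoss"}},     # case 3: a loss layer that owns trainable parameters
    ]

A real config would use either case 1 on its own, or a combination of cases 2 and 3, since each scope must resolve to exactly one optimized module (see the assert in the next hunk).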
@@ -91,11 +97,19 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
             if len(model_list[i].parameters()) == 0:
                 continue
             if optim_scope == "all":
+                # optimizer for all
                 optim_model.append(model_list[i])
             else:
-                for m in model_list[i].sublayers(True):
-                    if m.__class__.__name__ == optim_scope:
-                        optim_model.append(model_list[i])
+                if optim_scope.endswith("Loss"):
+                    # optimizer for loss
+                    for m in model_list[i].sublayers(True):
+                        if m.__class__.__name__ == optim_scope:
+                            optim_model.append(m)
+                else:
+                    # optimizer for module in model, such as backbone, neck, head...
+                    if hasattr(model_list[i], optim_scope):
+                        optim_model.append(getattr(model_list[i], optim_scope))
         assert len(optim_model) == 1, \
             "Invalid optim model for optim scope({}), number of optim_model={}".format(optim_scope, len(optim_model))
         optim = getattr(optimizer, optim_name)(
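
For reference, the selection logic added in this hunk could be sketched roughly as follows; this is a toy example with assumed names (Arch, CenterLoss, backbone) that are not defined in this diff, and it only illustrates how a scope string resolves to a module:

    import paddle.nn as nn

    class CenterLoss(nn.Layer):
        # a loss that owns trainable parameters (case 3 in the NOTE above)
        def __init__(self):
            super().__init__()
            self.centers = self.create_parameter(shape=[10, 64])

    class Arch(nn.Layer):
        def __init__(self):
            super().__init__()
            self.backbone = nn.Linear(8, 64)
            self.head = nn.Linear(64, 10)

    model_list = [Arch(), CenterLoss()]

    def resolve_scope(optim_scope):
        # mirrors the branch structure of the hunk above
        optim_model = []
        for model in model_list:
            if len(model.parameters()) == 0:
                continue
            if optim_scope == "all":
                optim_model.append(model)                        # the whole model
            elif optim_scope.endswith("Loss"):
                for m in model.sublayers(True):
                    if m.__class__.__name__ == optim_scope:
                        optim_model.append(m)                    # the matching loss layer
            elif hasattr(model, optim_scope):
                optim_model.append(getattr(model, optim_scope))  # a named sub-module
        return optim_model

    resolve_scope("backbone")    # -> [the backbone Linear layer of Arch]
    resolve_scope("CenterLoss")  # -> [the CenterLoss layer]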