Add custom_group to DefaultOptimizerConstrutor (#347)

* feat: add custom_group to DefaultOptimConstrutor

* refactor: move custom_groups validate to _validate_cfg

* docs: add doc to explain custom_groups

* feat: add unittest for non_exist_key

* refactor: one param per group

* fix: small fix

* fix: name

* docs: docstring

* refactor: change to mult for only lr and wd custom

* docs: docstring

* docs: more explaination

* feat: sort custom key

* docs: add docstring

* refactor: use reverse arg of sorted

* docs: fix comment

* docs: fix comment

* refactor: small modi

* refactor: small modi

* refactor: small modi
This commit is contained in:
Harry 2020-06-17 23:39:50 +08:00 committed by GitHub
parent 657f03ad08
commit c74d729d92
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 143 additions and 17 deletions

View File

@ -14,6 +14,17 @@ class DefaultOptimizerConstructor:
By default each parameter share the same optimizer settings, and we
provide an argument ``paramwise_cfg`` to specify parameter-wise settings.
It is a dict and may contain the following fields:
- ``custom_keys`` (dict): Specified parameters-wise settings by keys. If
one of the keys in ``custom_keys`` is a substring of the name of one
parameter, then the setting of the parameter will be specified by
``custom_keys[key]`` and other setting like ``bias_lr_mult`` etc. will
be ignored. It should be noted that the aforementioned ``key`` is the
longest key that is a substring of the name of the parameter. If there
are multiple matched keys with the same length, then the key with lower
alphabet order will be chosen.
``custom_keys[key]`` should be a dict and may contain fields ``lr_mult``
and ``decay_mult``. See Example 2 below.
- ``bias_lr_mult`` (float): It will be multiplied to the learning
rate for all bias parameters (except for those in normalization
layers).
@ -39,7 +50,7 @@ class DefaultOptimizerConstructor:
lr, weight_decay, momentum, etc.
paramwise_cfg (dict, optional): Parameter-wise options.
Example:
Example 1:
>>> model = torch.nn.modules.Conv1d(1, 1, 1)
>>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9,
>>> weight_decay=0.0001)
@ -47,6 +58,18 @@ class DefaultOptimizerConstructor:
>>> optim_builder = DefaultOptimizerConstructor(
>>> optimizer_cfg, paramwise_cfg)
>>> optimizer = optim_builder(model)
Example 2:
>>> # assume model have attribute model.backbone and model.cls_head
>>> optimizer_cfg = dict(type='SGD', lr=0.01, weight_decay=0.95)
>>> paramwise_cfg = dict(custom_keys={
'.backbone': dict(lr_mult=0.1, decay_mult=0.9)})
>>> optim_builder = DefaultOptimizerConstructor(
>>> optimizer_cfg, paramwise_cfg)
>>> optimizer = optim_builder(model)
>>> # Then the `lr` and `weight_decay` for model.backbone is
>>> # (0.01 * 0.1, 0.95 * 0.9). `lr` and `weight_decay` for
>>> # model.cls_head is (0.01, 0.95).
"""
def __init__(self, optimizer_cfg, paramwise_cfg=None):
@ -63,6 +86,13 @@ class DefaultOptimizerConstructor:
if not isinstance(self.paramwise_cfg, dict):
raise TypeError('paramwise_cfg should be None or a dict, '
f'but got {type(self.paramwise_cfg)}')
if ('custom_keys' in self.paramwise_cfg
and not isinstance(self.paramwise_cfg['custom_keys'], dict)):
raise TypeError(
'If specified, custom_keys must be a dict, '
f'but got {type(self.paramwise_cfg["custom_keys"])}')
# get base lr and weight decay
# weight_decay must be explicitly specified if mult is specified
if ('bias_decay_mult' in self.paramwise_cfg
@ -93,6 +123,10 @@ class DefaultOptimizerConstructor:
prefix (str): The prefix of the module
"""
# get param-wise options
custom_keys = self.paramwise_cfg.get('custom_keys', {})
# first sort with alphabet order and then sort with reversed len of str
sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True)
bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', 1.)
bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', 1.)
norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', 1.)
@ -115,23 +149,35 @@ class DefaultOptimizerConstructor:
warnings.warn(f'{prefix} is duplicate. It is skipped since '
f'bypass_duplicate={bypass_duplicate}')
continue
# bias_lr_mult affects all bias parameters except for norm.bias
if name == 'bias' and not is_norm:
param_group['lr'] = self.base_lr * bias_lr_mult
# apply weight decay policies
if self.base_wd is not None:
# norm decay
if is_norm:
# if the parameter match one of the custom keys, ignore other rules
is_custom = False
for key in sorted_keys:
if key in f'{prefix}.{name}':
is_custom = True
param_group['lr'] = self.base_lr * custom_keys[key].get(
'lr_mult', 1.)
param_group[
'weight_decay'] = self.base_wd * norm_decay_mult
# depth-wise conv
elif is_dwconv:
param_group[
'weight_decay'] = self.base_wd * dwconv_decay_mult
# bias lr and decay
elif name == 'bias':
param_group[
'weight_decay'] = self.base_wd * bias_decay_mult
'weight_decay'] = self.base_wd * custom_keys[key].get(
'decay_mult', 1.)
break
if not is_custom:
# bias_lr_mult affects all bias parameters except for norm.bias
if name == 'bias' and not is_norm:
param_group['lr'] = self.base_lr * bias_lr_mult
# apply weight decay policies
if self.base_wd is not None:
# norm decay
if is_norm:
param_group[
'weight_decay'] = self.base_wd * norm_decay_mult
# depth-wise conv
elif is_dwconv:
param_group[
'weight_decay'] = self.base_wd * dwconv_decay_mult
# bias lr and decay
elif name == 'bias':
param_group[
'weight_decay'] = self.base_wd * bias_decay_mult
params.append(param_group)
for child_name, child_mod in module.named_children():

View File

@ -355,6 +355,86 @@ def test_default_optimizer_constructor():
assert len(optimizer.param_groups) == len(model_parameters) == 11
check_optimizer(optimizer, model, **paramwise_cfg)
# test DefaultOptimizerConstructor with custom_groups and ExampleModel
model = ExampleModel()
optimizer_cfg = dict(
type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum)
paramwise_cfg = dict(
custom_keys={
'param1': dict(lr_mult=10),
'sub': dict(lr_mult=0.1, decay_mult=0),
'sub.gn': dict(lr_mult=0.01),
'non_exist_key': dict(lr_mult=0.0)
},
norm_decay_mult=0.5)
with pytest.raises(TypeError):
# custom_keys should be a dict
paramwise_cfg_ = dict(custom_keys=[0.1, 0.0001])
optim_constructor = DefaultOptimizerConstructor(
optim_constructor, paramwise_cfg_)
optimizer = optim_constructor(model)
optim_constructor = DefaultOptimizerConstructor(optimizer_cfg,
paramwise_cfg)
optimizer = optim_constructor(model)
# check optimizer type and default config
assert isinstance(optimizer, torch.optim.SGD)
assert optimizer.defaults['lr'] == base_lr
assert optimizer.defaults['momentum'] == momentum
assert optimizer.defaults['weight_decay'] == base_wd
# check params groups
param_groups = optimizer.param_groups
groups = []
group_settings = []
# group 1, matches of 'param1'
# 'param1' is the longest match for 'sub.param1'
groups.append(['param1', 'sub.param1'])
group_settings.append({
'lr': base_lr * 10,
'momentum': momentum,
'weight_decay': base_wd,
})
# group 2, matches of 'sub.gn'
groups.append(['sub.gn.weight', 'sub.gn.bias'])
group_settings.append({
'lr': base_lr * 0.01,
'momentum': momentum,
'weight_decay': base_wd,
})
# group 3, matches of 'sub'
groups.append(['sub.conv1.weight', 'sub.conv1.bias'])
group_settings.append({
'lr': base_lr * 0.1,
'momentum': momentum,
'weight_decay': 0,
})
# group 4, bn is configured by 'norm_decay_mult'
groups.append(['bn.weight', 'bn.bias'])
group_settings.append({
'lr': base_lr,
'momentum': momentum,
'weight_decay': base_wd * 0.5,
})
# group 5, default group
groups.append(['conv1.weight', 'conv2.weight', 'conv2.bias'])
group_settings.append({
'lr': base_lr,
'momentum': momentum,
'weight_decay': base_wd
})
assert len(param_groups) == 11
for i, (name, param) in enumerate(model.named_parameters()):
assert torch.equal(param_groups[i]['params'][0], param)
for group, settings in zip(groups, group_settings):
if name in group:
for setting in settings:
assert param_groups[i][setting] == settings[
setting], f'{name} {setting}'
def test_torch_optimizers():
torch_optimizers = [