mirror of
https://github.com/open-mmlab/mmcv.git
synced 2025-06-03 21:54:52 +08:00
Add custom_group to DefaultOptimizerConstrutor (#347)
* feat: add custom_group to DefaultOptimConstrutor * refactor: move custom_groups validate to _validate_cfg * docs: add doc to explain custom_groups * feat: add unittest for non_exist_key * refactor: one param per group * fix: small fix * fix: name * docs: docstring * refactor: change to mult for only lr and wd custom * docs: docstring * docs: more explaination * feat: sort custom key * docs: add docstring * refactor: use reverse arg of sorted * docs: fix comment * docs: fix comment * refactor: small modi * refactor: small modi * refactor: small modi
This commit is contained in:
parent
657f03ad08
commit
c74d729d92
@ -14,6 +14,17 @@ class DefaultOptimizerConstructor:
|
||||
By default each parameter share the same optimizer settings, and we
|
||||
provide an argument ``paramwise_cfg`` to specify parameter-wise settings.
|
||||
It is a dict and may contain the following fields:
|
||||
|
||||
- ``custom_keys`` (dict): Specified parameters-wise settings by keys. If
|
||||
one of the keys in ``custom_keys`` is a substring of the name of one
|
||||
parameter, then the setting of the parameter will be specified by
|
||||
``custom_keys[key]`` and other setting like ``bias_lr_mult`` etc. will
|
||||
be ignored. It should be noted that the aforementioned ``key`` is the
|
||||
longest key that is a substring of the name of the parameter. If there
|
||||
are multiple matched keys with the same length, then the key with lower
|
||||
alphabet order will be chosen.
|
||||
``custom_keys[key]`` should be a dict and may contain fields ``lr_mult``
|
||||
and ``decay_mult``. See Example 2 below.
|
||||
- ``bias_lr_mult`` (float): It will be multiplied to the learning
|
||||
rate for all bias parameters (except for those in normalization
|
||||
layers).
|
||||
@ -39,7 +50,7 @@ class DefaultOptimizerConstructor:
|
||||
lr, weight_decay, momentum, etc.
|
||||
paramwise_cfg (dict, optional): Parameter-wise options.
|
||||
|
||||
Example:
|
||||
Example 1:
|
||||
>>> model = torch.nn.modules.Conv1d(1, 1, 1)
|
||||
>>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9,
|
||||
>>> weight_decay=0.0001)
|
||||
@ -47,6 +58,18 @@ class DefaultOptimizerConstructor:
|
||||
>>> optim_builder = DefaultOptimizerConstructor(
|
||||
>>> optimizer_cfg, paramwise_cfg)
|
||||
>>> optimizer = optim_builder(model)
|
||||
|
||||
Example 2:
|
||||
>>> # assume model have attribute model.backbone and model.cls_head
|
||||
>>> optimizer_cfg = dict(type='SGD', lr=0.01, weight_decay=0.95)
|
||||
>>> paramwise_cfg = dict(custom_keys={
|
||||
'.backbone': dict(lr_mult=0.1, decay_mult=0.9)})
|
||||
>>> optim_builder = DefaultOptimizerConstructor(
|
||||
>>> optimizer_cfg, paramwise_cfg)
|
||||
>>> optimizer = optim_builder(model)
|
||||
>>> # Then the `lr` and `weight_decay` for model.backbone is
|
||||
>>> # (0.01 * 0.1, 0.95 * 0.9). `lr` and `weight_decay` for
|
||||
>>> # model.cls_head is (0.01, 0.95).
|
||||
"""
|
||||
|
||||
def __init__(self, optimizer_cfg, paramwise_cfg=None):
|
||||
@ -63,6 +86,13 @@ class DefaultOptimizerConstructor:
|
||||
if not isinstance(self.paramwise_cfg, dict):
|
||||
raise TypeError('paramwise_cfg should be None or a dict, '
|
||||
f'but got {type(self.paramwise_cfg)}')
|
||||
|
||||
if ('custom_keys' in self.paramwise_cfg
|
||||
and not isinstance(self.paramwise_cfg['custom_keys'], dict)):
|
||||
raise TypeError(
|
||||
'If specified, custom_keys must be a dict, '
|
||||
f'but got {type(self.paramwise_cfg["custom_keys"])}')
|
||||
|
||||
# get base lr and weight decay
|
||||
# weight_decay must be explicitly specified if mult is specified
|
||||
if ('bias_decay_mult' in self.paramwise_cfg
|
||||
@ -93,6 +123,10 @@ class DefaultOptimizerConstructor:
|
||||
prefix (str): The prefix of the module
|
||||
"""
|
||||
# get param-wise options
|
||||
custom_keys = self.paramwise_cfg.get('custom_keys', {})
|
||||
# first sort with alphabet order and then sort with reversed len of str
|
||||
sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True)
|
||||
|
||||
bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', 1.)
|
||||
bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', 1.)
|
||||
norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', 1.)
|
||||
@ -115,23 +149,35 @@ class DefaultOptimizerConstructor:
|
||||
warnings.warn(f'{prefix} is duplicate. It is skipped since '
|
||||
f'bypass_duplicate={bypass_duplicate}')
|
||||
continue
|
||||
# bias_lr_mult affects all bias parameters except for norm.bias
|
||||
if name == 'bias' and not is_norm:
|
||||
param_group['lr'] = self.base_lr * bias_lr_mult
|
||||
# apply weight decay policies
|
||||
if self.base_wd is not None:
|
||||
# norm decay
|
||||
if is_norm:
|
||||
# if the parameter match one of the custom keys, ignore other rules
|
||||
is_custom = False
|
||||
for key in sorted_keys:
|
||||
if key in f'{prefix}.{name}':
|
||||
is_custom = True
|
||||
param_group['lr'] = self.base_lr * custom_keys[key].get(
|
||||
'lr_mult', 1.)
|
||||
param_group[
|
||||
'weight_decay'] = self.base_wd * norm_decay_mult
|
||||
# depth-wise conv
|
||||
elif is_dwconv:
|
||||
param_group[
|
||||
'weight_decay'] = self.base_wd * dwconv_decay_mult
|
||||
# bias lr and decay
|
||||
elif name == 'bias':
|
||||
param_group[
|
||||
'weight_decay'] = self.base_wd * bias_decay_mult
|
||||
'weight_decay'] = self.base_wd * custom_keys[key].get(
|
||||
'decay_mult', 1.)
|
||||
break
|
||||
if not is_custom:
|
||||
# bias_lr_mult affects all bias parameters except for norm.bias
|
||||
if name == 'bias' and not is_norm:
|
||||
param_group['lr'] = self.base_lr * bias_lr_mult
|
||||
# apply weight decay policies
|
||||
if self.base_wd is not None:
|
||||
# norm decay
|
||||
if is_norm:
|
||||
param_group[
|
||||
'weight_decay'] = self.base_wd * norm_decay_mult
|
||||
# depth-wise conv
|
||||
elif is_dwconv:
|
||||
param_group[
|
||||
'weight_decay'] = self.base_wd * dwconv_decay_mult
|
||||
# bias lr and decay
|
||||
elif name == 'bias':
|
||||
param_group[
|
||||
'weight_decay'] = self.base_wd * bias_decay_mult
|
||||
params.append(param_group)
|
||||
|
||||
for child_name, child_mod in module.named_children():
|
||||
|
@ -355,6 +355,86 @@ def test_default_optimizer_constructor():
|
||||
assert len(optimizer.param_groups) == len(model_parameters) == 11
|
||||
check_optimizer(optimizer, model, **paramwise_cfg)
|
||||
|
||||
# test DefaultOptimizerConstructor with custom_groups and ExampleModel
|
||||
model = ExampleModel()
|
||||
optimizer_cfg = dict(
|
||||
type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum)
|
||||
paramwise_cfg = dict(
|
||||
custom_keys={
|
||||
'param1': dict(lr_mult=10),
|
||||
'sub': dict(lr_mult=0.1, decay_mult=0),
|
||||
'sub.gn': dict(lr_mult=0.01),
|
||||
'non_exist_key': dict(lr_mult=0.0)
|
||||
},
|
||||
norm_decay_mult=0.5)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
# custom_keys should be a dict
|
||||
paramwise_cfg_ = dict(custom_keys=[0.1, 0.0001])
|
||||
optim_constructor = DefaultOptimizerConstructor(
|
||||
optim_constructor, paramwise_cfg_)
|
||||
optimizer = optim_constructor(model)
|
||||
|
||||
optim_constructor = DefaultOptimizerConstructor(optimizer_cfg,
|
||||
paramwise_cfg)
|
||||
optimizer = optim_constructor(model)
|
||||
# check optimizer type and default config
|
||||
assert isinstance(optimizer, torch.optim.SGD)
|
||||
assert optimizer.defaults['lr'] == base_lr
|
||||
assert optimizer.defaults['momentum'] == momentum
|
||||
assert optimizer.defaults['weight_decay'] == base_wd
|
||||
|
||||
# check params groups
|
||||
param_groups = optimizer.param_groups
|
||||
|
||||
groups = []
|
||||
group_settings = []
|
||||
# group 1, matches of 'param1'
|
||||
# 'param1' is the longest match for 'sub.param1'
|
||||
groups.append(['param1', 'sub.param1'])
|
||||
group_settings.append({
|
||||
'lr': base_lr * 10,
|
||||
'momentum': momentum,
|
||||
'weight_decay': base_wd,
|
||||
})
|
||||
# group 2, matches of 'sub.gn'
|
||||
groups.append(['sub.gn.weight', 'sub.gn.bias'])
|
||||
group_settings.append({
|
||||
'lr': base_lr * 0.01,
|
||||
'momentum': momentum,
|
||||
'weight_decay': base_wd,
|
||||
})
|
||||
# group 3, matches of 'sub'
|
||||
groups.append(['sub.conv1.weight', 'sub.conv1.bias'])
|
||||
group_settings.append({
|
||||
'lr': base_lr * 0.1,
|
||||
'momentum': momentum,
|
||||
'weight_decay': 0,
|
||||
})
|
||||
# group 4, bn is configured by 'norm_decay_mult'
|
||||
groups.append(['bn.weight', 'bn.bias'])
|
||||
group_settings.append({
|
||||
'lr': base_lr,
|
||||
'momentum': momentum,
|
||||
'weight_decay': base_wd * 0.5,
|
||||
})
|
||||
# group 5, default group
|
||||
groups.append(['conv1.weight', 'conv2.weight', 'conv2.bias'])
|
||||
group_settings.append({
|
||||
'lr': base_lr,
|
||||
'momentum': momentum,
|
||||
'weight_decay': base_wd
|
||||
})
|
||||
|
||||
assert len(param_groups) == 11
|
||||
for i, (name, param) in enumerate(model.named_parameters()):
|
||||
assert torch.equal(param_groups[i]['params'][0], param)
|
||||
for group, settings in zip(groups, group_settings):
|
||||
if name in group:
|
||||
for setting in settings:
|
||||
assert param_groups[i][setting] == settings[
|
||||
setting], f'{name} {setting}'
|
||||
|
||||
|
||||
def test_torch_optimizers():
|
||||
torch_optimizers = [
|
||||
|
Loading…
x
Reference in New Issue
Block a user