mirror of https://github.com/open-mmlab/mmcv.git
Add custom_group to DefaultOptimizerConstructor (#347)
* feat: add custom_group to DefaultOptimConstructor
* refactor: move custom_groups validation to _validate_cfg
* docs: add doc to explain custom_groups
* feat: add unittest for non_exist_key
* refactor: one param per group
* fix: small fix
* fix: name
* docs: docstring
* refactor: change to mult for only lr and wd custom
* docs: docstring
* docs: more explanation
* feat: sort custom key
* docs: add docstring
* refactor: use reverse arg of sorted
* docs: fix comment
* docs: fix comment
* refactor: small modification
* refactor: small modification
* refactor: small modification
This commit is contained in:
parent
657f03ad08
commit
c74d729d92
@@ -14,6 +14,17 @@ class DefaultOptimizerConstructor:
     By default each parameter share the same optimizer settings, and we
     provide an argument ``paramwise_cfg`` to specify parameter-wise settings.
     It is a dict and may contain the following fields:
+
+    - ``custom_keys`` (dict): Specify parameter-wise settings by keys. If
+      one of the keys in ``custom_keys`` is a substring of the name of one
+      parameter, then the setting of the parameter will be specified by
+      ``custom_keys[key]`` and other settings like ``bias_lr_mult`` etc. will
+      be ignored. It should be noted that the aforementioned ``key`` is the
+      longest key that is a substring of the name of the parameter. If there
+      are multiple matched keys with the same length, the key with the lower
+      alphabetical order will be chosen.
+      ``custom_keys[key]`` should be a dict and may contain fields ``lr_mult``
+      and ``decay_mult``. See Example 2 below.
     - ``bias_lr_mult`` (float): It will be multiplied to the learning
       rate for all bias parameters (except for those in normalization
       layers).
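The sketch below is not part of this commit; it is a minimal, standalone restatement of the matching rule described above (the helper name resolve_custom_key is hypothetical): the longest key that is a substring of the parameter name wins, and keys of equal length are tried in alphabetical order.

def resolve_custom_key(param_name, custom_keys):
    # alphabetical sort first, then a stable sort by descending length, so the
    # first matching key is the longest one and ties resolve alphabetically
    sorted_keys = sorted(sorted(custom_keys), key=len, reverse=True)
    for key in sorted_keys:
        if key in param_name:
            return key
    return None  # no custom key matches; the other paramwise rules apply

custom_keys = {'backbone': dict(lr_mult=0.1), 'backbone.layer1': dict(lr_mult=0.01)}
assert resolve_custom_key('backbone.layer1.conv.weight', custom_keys) == 'backbone.layer1'
assert resolve_custom_key('backbone.stem.conv.weight', custom_keys) == 'backbone'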
@@ -39,7 +50,7 @@ class DefaultOptimizerConstructor:
                   lr, weight_decay, momentum, etc.
         paramwise_cfg (dict, optional): Parameter-wise options.
 
-    Example:
+    Example 1:
         >>> model = torch.nn.modules.Conv1d(1, 1, 1)
         >>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9,
         >>>                      weight_decay=0.0001)
@@ -47,6 +58,18 @@ class DefaultOptimizerConstructor:
         >>> optim_builder = DefaultOptimizerConstructor(
         >>>     optimizer_cfg, paramwise_cfg)
         >>> optimizer = optim_builder(model)
+
+    Example 2:
+        >>> # assume the model has attributes model.backbone and model.cls_head
+        >>> optimizer_cfg = dict(type='SGD', lr=0.01, weight_decay=0.95)
+        >>> paramwise_cfg = dict(custom_keys={
+                '.backbone': dict(lr_mult=0.1, decay_mult=0.9)})
+        >>> optim_builder = DefaultOptimizerConstructor(
+        >>>     optimizer_cfg, paramwise_cfg)
+        >>> optimizer = optim_builder(model)
+        >>> # Then the `lr` and `weight_decay` for model.backbone are
+        >>> # (0.01 * 0.1, 0.95 * 0.9). `lr` and `weight_decay` for
+        >>> # model.cls_head are (0.01, 0.95).
     """
 
     def __init__(self, optimizer_cfg, paramwise_cfg=None):
@@ -63,6 +86,13 @@ class DefaultOptimizerConstructor:
         if not isinstance(self.paramwise_cfg, dict):
             raise TypeError('paramwise_cfg should be None or a dict, '
                             f'but got {type(self.paramwise_cfg)}')
+
+        if ('custom_keys' in self.paramwise_cfg
+                and not isinstance(self.paramwise_cfg['custom_keys'], dict)):
+            raise TypeError(
+                'If specified, custom_keys must be a dict, '
+                f'but got {type(self.paramwise_cfg["custom_keys"])}')
+
         # get base lr and weight decay
         # weight_decay must be explicitly specified if mult is specified
         if ('bias_decay_mult' in self.paramwise_cfg
@@ -93,6 +123,10 @@ class DefaultOptimizerConstructor:
             prefix (str): The prefix of the module
         """
         # get param-wise options
+        custom_keys = self.paramwise_cfg.get('custom_keys', {})
+        # sort alphabetically first, then by descending length of the key
+        sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True)
+
         bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', 1.)
         bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', 1.)
         norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', 1.)
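Because Python's sort is stable, the double sort above first orders the keys alphabetically and then moves longer keys to the front while keeping equal-length keys in alphabetical order. A quick illustration (not part of the diff):

custom_keys = ['sub', 'sub.gn', 'param1', 'head']
sorted_keys = sorted(sorted(custom_keys), key=len, reverse=True)
print(sorted_keys)  # ['param1', 'sub.gn', 'head', 'sub']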
@@ -115,6 +149,18 @@ class DefaultOptimizerConstructor:
                 warnings.warn(f'{prefix} is duplicate. It is skipped since '
                               f'bypass_duplicate={bypass_duplicate}')
                 continue
-            # bias_lr_mult affects all bias parameters except for norm.bias
-            if name == 'bias' and not is_norm:
-                param_group['lr'] = self.base_lr * bias_lr_mult
+            # if the parameter matches one of the custom keys, ignore other rules
+            is_custom = False
+            for key in sorted_keys:
+                if key in f'{prefix}.{name}':
+                    is_custom = True
+                    param_group['lr'] = self.base_lr * custom_keys[key].get(
+                        'lr_mult', 1.)
+                    param_group[
+                        'weight_decay'] = self.base_wd * custom_keys[key].get(
+                            'decay_mult', 1.)
+                    break
+            if not is_custom:
+                # bias_lr_mult affects all bias parameters except for norm.bias
+                if name == 'bias' and not is_norm:
+                    param_group['lr'] = self.base_lr * bias_lr_mult
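Read in isolation, the branch added above computes the effective settings for a single parameter roughly as follows (a sketch only; custom_param_options is an illustrative helper, not the library API, and the fall-through case is handled by the bias/norm rules that follow in the real code):

def custom_param_options(full_name, custom_keys, base_lr, base_wd):
    sorted_keys = sorted(sorted(custom_keys), key=len, reverse=True)
    for key in sorted_keys:
        if key in full_name:
            # a matching custom key overrides bias_lr_mult, norm_decay_mult, etc.
            return {
                'lr': base_lr * custom_keys[key].get('lr_mult', 1.),
                'weight_decay': base_wd * custom_keys[key].get('decay_mult', 1.),
            }
    return None  # no match: fall back to the bias/norm specific rules

opts = custom_param_options('backbone.conv1.weight',
                            {'backbone': dict(lr_mult=0.1, decay_mult=0.9)},
                            base_lr=0.01, base_wd=0.95)
# opts['lr'] == 0.01 * 0.1 and opts['weight_decay'] == 0.95 * 0.9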
@@ -355,6 +355,86 @@ def test_default_optimizer_constructor():
     assert len(optimizer.param_groups) == len(model_parameters) == 11
     check_optimizer(optimizer, model, **paramwise_cfg)
+
+    # test DefaultOptimizerConstructor with custom_keys and ExampleModel
+    model = ExampleModel()
+    optimizer_cfg = dict(
+        type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum)
+    paramwise_cfg = dict(
+        custom_keys={
+            'param1': dict(lr_mult=10),
+            'sub': dict(lr_mult=0.1, decay_mult=0),
+            'sub.gn': dict(lr_mult=0.01),
+            'non_exist_key': dict(lr_mult=0.0)
+        },
+        norm_decay_mult=0.5)
+
+    with pytest.raises(TypeError):
+        # custom_keys should be a dict
+        paramwise_cfg_ = dict(custom_keys=[0.1, 0.0001])
+        optim_constructor = DefaultOptimizerConstructor(
+            optimizer_cfg, paramwise_cfg_)
+        optimizer = optim_constructor(model)
+
+    optim_constructor = DefaultOptimizerConstructor(optimizer_cfg,
+                                                    paramwise_cfg)
+    optimizer = optim_constructor(model)
+    # check optimizer type and default config
+    assert isinstance(optimizer, torch.optim.SGD)
+    assert optimizer.defaults['lr'] == base_lr
+    assert optimizer.defaults['momentum'] == momentum
+    assert optimizer.defaults['weight_decay'] == base_wd
+
+    # check params groups
+    param_groups = optimizer.param_groups
+
+    groups = []
+    group_settings = []
+    # group 1, matches of 'param1'
+    # 'param1' is the longest match for 'sub.param1'
+    groups.append(['param1', 'sub.param1'])
+    group_settings.append({
+        'lr': base_lr * 10,
+        'momentum': momentum,
+        'weight_decay': base_wd,
+    })
+    # group 2, matches of 'sub.gn'
+    groups.append(['sub.gn.weight', 'sub.gn.bias'])
+    group_settings.append({
+        'lr': base_lr * 0.01,
+        'momentum': momentum,
+        'weight_decay': base_wd,
+    })
+    # group 3, matches of 'sub'
+    groups.append(['sub.conv1.weight', 'sub.conv1.bias'])
+    group_settings.append({
+        'lr': base_lr * 0.1,
+        'momentum': momentum,
+        'weight_decay': 0,
+    })
+    # group 4, bn is configured by 'norm_decay_mult'
+    groups.append(['bn.weight', 'bn.bias'])
+    group_settings.append({
+        'lr': base_lr,
+        'momentum': momentum,
+        'weight_decay': base_wd * 0.5,
+    })
+    # group 5, default group
+    groups.append(['conv1.weight', 'conv2.weight', 'conv2.bias'])
+    group_settings.append({
+        'lr': base_lr,
+        'momentum': momentum,
+        'weight_decay': base_wd
+    })
+
+    assert len(param_groups) == 11
+    for i, (name, param) in enumerate(model.named_parameters()):
+        assert torch.equal(param_groups[i]['params'][0], param)
+        for group, settings in zip(groups, group_settings):
+            if name in group:
+                for setting in settings:
+                    assert param_groups[i][setting] == settings[
+                        setting], f'{name} {setting}'
 
 
 def test_torch_optimizers():
     torch_optimizers = [