[Feature]: Add optimizer and constructor. (#25)

* [Feature]: Add optimizer and constructor.

* refactor unit tests

* add optimizer doc

* add parrots wrapper

* add parrots wrapper

* solve comments

* resolve comments
RangiLyu 2022-02-19 14:09:37 +08:00 committed by GitHub
parent dd6fb223e6
commit 7353778b7c
10 changed files with 1263 additions and 11 deletions

@@ -0,0 +1,127 @@
# Optimizer
During model training, we need optimization algorithms to update the model's parameters. PyTorch's `torch.optim` module provides implementations of various optimization algorithms, and the classes implementing them are called optimizers.
In PyTorch, you optimize the parameters of a model by constructing an optimizer object. A simple example:
```python
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0001)
for input, target in dataset:
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
```
For a detailed introduction to PyTorch optimizers, please refer to the [PyTorch optimizer documentation](https://pytorch.org/docs/stable/optim.html#).
MMEngine supports all PyTorch optimizers: you can construct a PyTorch optimizer object directly and pass it to the [Runner](https://mmengine.readthedocs.io/zh_CN/latest/tutorials/runner.html).
Unlike the example in the PyTorch documentation, you usually do not need to write the training loop or call `optimizer.step()` yourself in MMEngine; the runner automatically back-propagates the loss and calls the optimizer's `step` method to update the model parameters.
MMEngine also supports building optimizers from the registry via config files. Going further, it provides optimizer constructors for finer-grained control over how the model is optimized.
## Building an optimizer from a config file
MMEngine automatically registers all PyTorch optimizers in the `OPTIMIZERS` registry, so you can specify an optimizer by setting the `optimizer` field of the config file; see the [list of PyTorch optimizers](https://pytorch.org/docs/stable/optim.html#algorithms) for all supported optimizers.
Take configuring an SGD optimizer as an example:
```python
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
```
We only need to set `type` to `SGD` in the `optimizer` field and specify the learning rate and other arguments; the runner then builds the optimizer automatically from this field and the parameters of the model it holds.
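The same config can also be turned into an optimizer programmatically with the `build_optimizer` helper shown in the code later in this commit. A minimal sketch (the toy `nn.Conv2d` model is only for illustration):
```python
import torch.nn as nn
from mmengine.optim import build_optimizer

model = nn.Conv2d(3, 8, kernel_size=3)
optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
# builds a DefaultOptimizerConstructor (no ``constructor`` field is given)
# and lets it create the torch.optim.SGD instance from the model parameters
optimizer = build_optimizer(model, optimizer_cfg)
print(type(optimizer))  # <class 'torch.optim.sgd.SGD'>
```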
## Fine-grained adjustment of model hyperparameters
PyTorch optimizers allow different hyperparameters to be set for different parameters of a model, for example different learning rates for the backbone and the head of a classification model:
```python
optim.SGD([
{'params': model.backbone.parameters()},
{'params': model.head.parameters(), 'lr': 1e-3}
], lr=0.01, momentum=0.9)
```
In the example above, the backbone of the model uses a learning rate of 0.01 while the head uses a learning rate of 1e-3.
You can pass the optimizer a list of dicts, each pairing a group of the model's parameters with its own hyperparameters, to fine-tune how the model is optimized.
In MMEngine, the optimizer constructor lets you assign different hyperparameters to different parts of the model simply by setting the `paramwise_cfg` field of the optimizer config, without modifying any code.
### Setting hyperparameter multipliers for different types of parameters
The default optimizer constructor provided by MMEngine supports setting different hyperparameter multipliers for different types of parameters in the model.
For example, we can set `norm_decay_mult=0` in `paramwise_cfg` to set the weight decay of the weights and biases of normalization layers to 0,
which implements the trick from the [Bag of Tricks](https://arxiv.org/abs/1812.01187) paper of not applying weight decay to normalization layers.
Example:
```python
optimizer = dict(type='SGD',
lr=0.01,
weight_decay=0.0001,
paramwise_cfg=dict(norm_decay_mult=0))
```
Besides the weight decay of normalization layers, the `paramwise_cfg` of MMEngine's default optimizer constructor supports hyperparameter multipliers for several other kinds of parameters (a combined example follows the list):
- `bias_lr_mult`: learning rate multiplier for biases (excluding the biases of normalization layers and the offsets of deformable convolutions), defaults to 1
- `bias_decay_mult`: weight decay multiplier for biases (excluding the biases of normalization layers, depthwise convolutions, and the offsets of deformable convolutions), defaults to 1
- `norm_decay_mult`: weight decay multiplier for the weights and biases of normalization layers, defaults to 1
- `dwconv_decay_mult`: weight decay multiplier for depthwise convolutions, defaults to 1
- `bypass_duplicate`: whether to skip duplicated parameters, defaults to `False`
- `dcn_offset_lr_mult`: learning rate multiplier for the offsets of deformable convolutions (DCN), defaults to 1
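For instance, a config that combines several of these multipliers (the values are purely illustrative, mirroring the settings exercised in the unit tests of this commit):
```python
optimizer = dict(type='SGD',
                 lr=0.01,
                 momentum=0.9,
                 weight_decay=0.0001,
                 paramwise_cfg=dict(
                     bias_lr_mult=2,
                     bias_decay_mult=0.5,
                     norm_decay_mult=0,
                     dwconv_decay_mult=0.1,
                     dcn_offset_lr_mult=0.1))
```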
### Setting hyperparameter multipliers for different parts of the model
In addition, just like the PyTorch example above, MMEngine lets you set different hyperparameters for any module of the model; simply set `custom_keys` in `paramwise_cfg`:
```python
optimizer = dict(type='SGD',
lr=0.01,
weight_decay=0.0001,
paramwise_cfg=dict(
custom_keys={
'backbone.layer0': dict(lr_mult=0, decay_mult=0),
'backbone': dict(lr_mult=1),
'head': dict(lr_mult=0.1),
}
))
```
The config above sets the learning rate and weight decay of the first layer of the backbone to 0, uses a learning rate of 0.01 for the rest of the backbone, and a learning rate of 1e-3 for the head of the model.
### Advanced usage: implementing a custom optimizer constructor
Like other modules in MMEngine, optimizer constructors are managed by a [registry](https://mmengine.readthedocs.io/zh_CN/latest/tutorials/registry.html).
You can implement your own construction strategy for custom hyperparameter settings and add it to the `OPTIMIZER_CONSTRUCTORS` registry.
For example, suppose we want an optimizer constructor called `LayerDecayOptimizerConstructor` that automatically assigns decreasing learning rates to layers at different depths of the model.
We can implement this strategy by inheriting from `DefaultOptimizerConstructor` and registering the new class:
```python
@OPTIMIZER_CONSTRUCTORS.register_module()
class LayerDecayOptimizerConstructor(DefaultOptimizerConstructor):
def add_params(self, params, module, prefix='', is_dcn_module=None):
...
```
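For illustration, the overridden `add_params` could look roughly like the sketch below. This is only a hypothetical example: the `decay_rate` and `num_layers` keys in `paramwise_cfg` and the `backbone.layers.{i}` parameter naming scheme are assumptions made for the sketch, not an MMEngine API.
```python
from mmengine.optim import DefaultOptimizerConstructor
from mmengine.registry import OPTIMIZER_CONSTRUCTORS


@OPTIMIZER_CONSTRUCTORS.register_module()
class LayerDecayOptimizerConstructor(DefaultOptimizerConstructor):

    def add_params(self, params, module, prefix='', is_dcn_module=None):
        # assumed custom keys carried in ``paramwise_cfg``
        decay_rate = self.paramwise_cfg.get('decay_rate', 0.9)
        num_layers = self.paramwise_cfg.get('num_layers', 12)
        for name, param in module.named_parameters(recurse=False):
            if not param.requires_grad:
                continue
            group = {'params': [param], 'lr': self.base_lr}
            parts = prefix.split('.')
            # parameters under ``backbone.layers.{i}`` get a learning rate of
            # base_lr * decay_rate ** (num_layers - i - 1), so deeper layers
            # keep a larger fraction of the base learning rate
            if len(parts) >= 3 and parts[:2] == ['backbone', 'layers']:
                layer_id = int(parts[2])
                group['lr'] = self.base_lr * decay_rate**(
                    num_layers - layer_id - 1)
            if self.base_wd is not None:
                group['weight_decay'] = self.base_wd
            params.append(group)
        # recurse into submodules, extending the prefix with the child name
        for child_name, child_mod in module.named_children():
            child_prefix = f'{prefix}.{child_name}' if prefix else child_name
            self.add_params(params, child_mod, prefix=child_prefix)
```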
Then set the `constructor` field of the optimizer config to the class name to use this custom optimizer constructor:
```python
optimizer = dict(type='SGD',
lr=0.01,
weight_decay=0.0001,
constructor='LayerDecayOptimizerConstructor')
```
## Adjusting hyperparameters during training
The hyperparameters of an optimizer can only be set to fixed values at construction time; with an optimizer alone, the learning rate and other hyperparameters cannot be adjusted during training.
In MMEngine we provide parameter schedulers so that these parameters can be adjusted as training progresses. For their usage, please see the [parameter scheduler tutorial](https://mmengine.readthedocs.io/zh_CN/latest/tutorials/param_scheduler.html).
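As a quick taste, here is a minimal sketch that drives the `MultiStepLR` scheduler exported in the `__init__.py` below by hand; it assumes the scheduler follows the familiar `torch.optim.lr_scheduler`-style interface of stepping once per epoch:
```python
import torch.nn as nn
from mmengine.optim import MultiStepLR, build_optimizer

model = nn.Conv2d(3, 8, kernel_size=3)
optimizer = build_optimizer(model, dict(type='SGD', lr=0.01, momentum=0.9))
# decay the learning rate by 10x at the end of epoch 8 and epoch 11
scheduler = MultiStepLR(optimizer, milestones=[8, 11], gamma=0.1)
for epoch in range(12):
    ...  # run one epoch of training here
    scheduler.step()
```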

@@ -0,0 +1,23 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .optimizer import (OPTIMIZER_CONSTRUCTORS, OPTIMIZERS,
DefaultOptimizerConstructor, build_optimizer,
build_optimizer_constructor)
from .scheduler import (ConstantLR, ConstantMomentum, ConstantParamScheduler,
CosineAnnealingLR, CosineAnnealingMomentum,
CosineAnnealingParamScheduler, ExponentialLR,
ExponentialMomentum, ExponentialParamScheduler,
LinearLR, LinearMomentum, LinearParamScheduler,
MultiStepLR, MultiStepMomentum,
MultiStepParamScheduler, StepLR, StepMomentum,
StepParamScheduler, _ParamScheduler)
__all__ = [
'OPTIMIZER_CONSTRUCTORS', 'OPTIMIZERS', 'build_optimizer',
'build_optimizer_constructor', 'DefaultOptimizerConstructor', 'ConstantLR',
'CosineAnnealingLR', 'ExponentialLR', 'LinearLR', 'MultiStepLR', 'StepLR',
'ConstantMomentum', 'CosineAnnealingMomentum', 'ExponentialMomentum',
'LinearMomentum', 'MultiStepMomentum', 'StepMomentum',
'ConstantParamScheduler', 'CosineAnnealingParamScheduler',
'ExponentialParamScheduler', 'LinearParamScheduler',
'MultiStepParamScheduler', 'StepParamScheduler', '_ParamScheduler'
]

@@ -0,0 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import (OPTIMIZER_CONSTRUCTORS, OPTIMIZERS, build_optimizer,
build_optimizer_constructor)
from .default_constructor import DefaultOptimizerConstructor
__all__ = [
'OPTIMIZER_CONSTRUCTORS', 'OPTIMIZERS', 'DefaultOptimizerConstructor',
'build_optimizer', 'build_optimizer_constructor'
]

@@ -0,0 +1,43 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import inspect
from typing import Callable, List
import torch
import torch.nn as nn
from mmengine.registry import OPTIMIZER_CONSTRUCTORS, OPTIMIZERS
def register_torch_optimizers() -> List[str]:
    """Register all optimizer classes in ``torch.optim`` to the
    ``OPTIMIZERS`` registry and return their class names."""
    torch_optimizers = []
for module_name in dir(torch.optim):
if module_name.startswith('__'):
continue
_optim = getattr(torch.optim, module_name)
if inspect.isclass(_optim) and issubclass(_optim,
torch.optim.Optimizer):
OPTIMIZERS.register_module(module=_optim)
torch_optimizers.append(module_name)
return torch_optimizers
TORCH_OPTIMIZERS = register_torch_optimizers()
def build_optimizer_constructor(cfg: dict) -> Callable:
    """Build an optimizer constructor from the given config."""
    return OPTIMIZER_CONSTRUCTORS.build(cfg)
def build_optimizer(model: nn.Module, cfg: dict) -> torch.optim.Optimizer:
    """Build an optimizer for ``model`` from ``cfg``.

    The optional ``constructor`` field (defaulting to
    ``DefaultOptimizerConstructor``) and the ``paramwise_cfg`` field are
    popped from the config; the remaining fields make up the optimizer config.
    """
    optimizer_cfg = copy.deepcopy(cfg)
    constructor_type = optimizer_cfg.pop('constructor',
                                         'DefaultOptimizerConstructor')
paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None)
optim_constructor = build_optimizer_constructor(
dict(
type=constructor_type,
optimizer_cfg=optimizer_cfg,
paramwise_cfg=paramwise_cfg))
optimizer = optim_constructor(model)
return optimizer

@@ -0,0 +1,260 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import List, Optional, Union
import torch
import torch.nn as nn
from torch.nn import GroupNorm, LayerNorm
from mmengine.registry import (OPTIMIZER_CONSTRUCTORS, OPTIMIZERS,
build_from_cfg)
from mmengine.utils import is_list_of, mmcv_full_available
from mmengine.utils.parrots_wrapper import _BatchNorm, _InstanceNorm
@OPTIMIZER_CONSTRUCTORS.register_module()
class DefaultOptimizerConstructor:
"""Default constructor for optimizers.
    By default each parameter shares the same optimizer settings, and we
provide an argument ``paramwise_cfg`` to specify parameter-wise settings.
It is a dict and may contain the following fields:
- ``custom_keys`` (dict): Specified parameters-wise settings by keys. If
one of the keys in ``custom_keys`` is a substring of the name of one
parameter, then the setting of the parameter will be specified by
      ``custom_keys[key]`` and other settings like ``bias_lr_mult`` etc. will
be ignored. It should be noted that the aforementioned ``key`` is the
longest key that is a substring of the name of the parameter. If there
      are multiple matched keys with the same length, then the key with the
      lower alphabetical order will be chosen.
``custom_keys[key]`` should be a dict and may contain fields ``lr_mult``
and ``decay_mult``. See Example 2 below.
- ``bias_lr_mult`` (float): It will be multiplied to the learning
rate for all bias parameters (except for those in normalization
layers and offset layers of DCN).
- ``bias_decay_mult`` (float): It will be multiplied to the weight
decay for all bias parameters (except for those in
normalization layers, depthwise conv layers, offset layers of DCN).
- ``norm_decay_mult`` (float): It will be multiplied to the weight
decay for all weight and bias parameters of normalization
layers.
- ``dwconv_decay_mult`` (float): It will be multiplied to the weight
decay for all weight and bias parameters of depthwise conv
layers.
- ``dcn_offset_lr_mult`` (float): It will be multiplied to the learning
rate for parameters of offset layer in the deformable convs
of a model.
- ``bypass_duplicate`` (bool): If true, the duplicate parameters
would not be added into optimizer. Default: False.
Note:
1. If the option ``dcn_offset_lr_mult`` is used, the constructor will
override the effect of ``bias_lr_mult`` in the bias of offset layer.
So be careful when using both ``bias_lr_mult`` and
``dcn_offset_lr_mult``. If you wish to apply both of them to the offset
layer in deformable convs, set ``dcn_offset_lr_mult`` to the original
``dcn_offset_lr_mult`` * ``bias_lr_mult``.
2. If the option ``dcn_offset_lr_mult`` is used, the constructor will
apply it to all the DCN layers in the model. So be careful when the
model contains multiple DCN layers in places other than backbone.
Args:
optimizer_cfg (dict): The config dict of the optimizer.
Positional fields are
- `type`: class name of the optimizer.
Optional fields are
- any arguments of the corresponding optimizer type, e.g.,
lr, weight_decay, momentum, etc.
paramwise_cfg (dict, optional): Parameter-wise options.
Example 1:
>>> model = torch.nn.modules.Conv1d(1, 1, 1)
>>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9,
>>> weight_decay=0.0001)
>>> paramwise_cfg = dict(norm_decay_mult=0.)
>>> optim_builder = DefaultOptimizerConstructor(
>>> optimizer_cfg, paramwise_cfg)
>>> optimizer = optim_builder(model)
Example 2:
>>> # assume model have attribute model.backbone and model.cls_head
>>> optimizer_cfg = dict(type='SGD', lr=0.01, weight_decay=0.95)
>>> paramwise_cfg = dict(custom_keys={
'.backbone': dict(lr_mult=0.1, decay_mult=0.9)})
>>> optim_builder = DefaultOptimizerConstructor(
>>> optimizer_cfg, paramwise_cfg)
>>> optimizer = optim_builder(model)
>>> # Then the `lr` and `weight_decay` for model.backbone is
>>> # (0.01 * 0.1, 0.95 * 0.9). `lr` and `weight_decay` for
>>> # model.cls_head is (0.01, 0.95).
"""
def __init__(self,
optimizer_cfg: dict,
paramwise_cfg: Optional[dict] = None):
if not isinstance(optimizer_cfg, dict):
            raise TypeError('optimizer_cfg should be a dict, '
                            f'but got {type(optimizer_cfg)}')
self.optimizer_cfg = optimizer_cfg
self.paramwise_cfg = {} if paramwise_cfg is None else paramwise_cfg
self.base_lr = optimizer_cfg.get('lr', None)
self.base_wd = optimizer_cfg.get('weight_decay', None)
self._validate_cfg()
def _validate_cfg(self) -> None:
"""verify the correctness of the config."""
if not isinstance(self.paramwise_cfg, dict):
raise TypeError('paramwise_cfg should be None or a dict, '
f'but got {type(self.paramwise_cfg)}')
if 'custom_keys' in self.paramwise_cfg:
if not isinstance(self.paramwise_cfg['custom_keys'], dict):
raise TypeError(
'If specified, custom_keys must be a dict, '
f'but got {type(self.paramwise_cfg["custom_keys"])}')
if self.base_wd is None:
for key in self.paramwise_cfg['custom_keys']:
if 'decay_mult' in self.paramwise_cfg['custom_keys'][key]:
raise ValueError('base_wd should not be None')
# get base lr and weight decay
# weight_decay must be explicitly specified if mult is specified
if ('bias_decay_mult' in self.paramwise_cfg
or 'norm_decay_mult' in self.paramwise_cfg
or 'dwconv_decay_mult' in self.paramwise_cfg):
if self.base_wd is None:
raise ValueError('base_wd should not be None')
def _is_in(self, param_group: dict, param_group_list: list) -> bool:
"""check whether the `param_group` is in the`param_group_list`"""
assert is_list_of(param_group_list, dict)
param = set(param_group['params'])
param_set = set()
for group in param_group_list:
param_set.update(set(group['params']))
return not param.isdisjoint(param_set)
def add_params(self,
params: List[dict],
module: nn.Module,
prefix: str = '',
is_dcn_module: Optional[Union[int, float]] = None) -> None:
"""Add all parameters of module to the params list.
The parameters of the given module will be added to the list of param
groups, with specific rules defined by paramwise_cfg.
Args:
params (list[dict]): A list of param groups, it will be modified
in place.
module (nn.Module): The module to be added.
prefix (str): The prefix of the module
is_dcn_module (int|float|None): If the current module is a
submodule of DCN, `is_dcn_module` will be passed to
control conv_offset layer's learning rate. Defaults to None.
"""
# get param-wise options
custom_keys = self.paramwise_cfg.get('custom_keys', {})
# first sort with alphabet order and then sort with reversed len of str
sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True)
bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', 1.)
bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', 1.)
norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', 1.)
dwconv_decay_mult = self.paramwise_cfg.get('dwconv_decay_mult', 1.)
bypass_duplicate = self.paramwise_cfg.get('bypass_duplicate', False)
dcn_offset_lr_mult = self.paramwise_cfg.get('dcn_offset_lr_mult', 1.)
# special rules for norm layers and depth-wise conv layers
is_norm = isinstance(module,
(_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm))
is_dwconv = (
isinstance(module, torch.nn.Conv2d)
and module.in_channels == module.groups)
for name, param in module.named_parameters(recurse=False):
param_group = {'params': [param]}
if not param.requires_grad:
params.append(param_group)
continue
if bypass_duplicate and self._is_in(param_group, params):
warnings.warn(f'{prefix} is duplicate. It is skipped since '
f'bypass_duplicate={bypass_duplicate}')
continue
            # if the parameter matches a custom key, ignore other rules
is_custom = False
for key in sorted_keys:
if key in f'{prefix}.{name}':
is_custom = True
lr_mult = custom_keys[key].get('lr_mult', 1.)
param_group['lr'] = self.base_lr * lr_mult
if self.base_wd is not None:
decay_mult = custom_keys[key].get('decay_mult', 1.)
param_group['weight_decay'] = self.base_wd * decay_mult
break
if not is_custom:
# bias_lr_mult affects all bias parameters
# except for norm.bias dcn.conv_offset.bias
if name == 'bias' and not (is_norm or is_dcn_module):
param_group['lr'] = self.base_lr * bias_lr_mult
if (prefix.find('conv_offset') != -1 and is_dcn_module
and isinstance(module, torch.nn.Conv2d)):
# deal with both dcn_offset's bias & weight
param_group['lr'] = self.base_lr * dcn_offset_lr_mult
# apply weight decay policies
if self.base_wd is not None:
# norm decay
if is_norm:
param_group[
'weight_decay'] = self.base_wd * norm_decay_mult
# depth-wise conv
elif is_dwconv:
param_group[
'weight_decay'] = self.base_wd * dwconv_decay_mult
# bias lr and decay
elif name == 'bias' and not is_dcn_module:
                        # TODO: bias_decay_mult currently also affects DCN
param_group[
'weight_decay'] = self.base_wd * bias_decay_mult
params.append(param_group)
if mmcv_full_available():
from mmcv.ops import DeformConv2d, ModulatedDeformConv2d
is_dcn_module = isinstance(module,
(DeformConv2d, ModulatedDeformConv2d))
else:
is_dcn_module = False
for child_name, child_mod in module.named_children():
child_prefix = f'{prefix}.{child_name}' if prefix else child_name
self.add_params(
params,
child_mod,
prefix=child_prefix,
is_dcn_module=is_dcn_module)
def __call__(self, model: nn.Module) -> torch.optim.Optimizer:
if hasattr(model, 'module'):
model = model.module
optimizer_cfg = self.optimizer_cfg.copy()
# if no paramwise option is specified, just use the global setting
if not self.paramwise_cfg:
optimizer_cfg['params'] = model.parameters()
return build_from_cfg(optimizer_cfg, OPTIMIZERS)
# set param-wise lr and weight decay recursively
params: List = []
self.add_params(params, model)
optimizer_cfg['params'] = params
return build_from_cfg(optimizer_cfg, OPTIMIZERS)

@@ -2,7 +2,7 @@
import inspect
import sys
from collections.abc import Callable
from typing import Dict, List, Optional, Tuple, Type, Union
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from ..config import Config, ConfigDict
from ..utils import is_seq_of
@@ -11,8 +11,7 @@ from ..utils import is_seq_of
def build_from_cfg(
cfg: Union[dict, ConfigDict, Config],
registry: 'Registry',
default_args: Optional[Union[dict, ConfigDict,
Config]] = None) -> object:
default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> Any:
"""Build a module from config dict.
At least one of the ``cfg`` and ``default_args`` contains the key "type"
@@ -357,7 +356,7 @@ class Registry:
def build(self,
*args,
default_scope: Optional[str] = None,
**kwargs) -> None:
**kwargs) -> Any:
"""Build an instance.
Build an instance by calling :attr:`build_func`. If

@@ -2,9 +2,10 @@
from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
has_method, import_modules_from_strings, is_list_of,
is_method_overridden, is_seq_of, is_str, is_tuple_of,
iter_cast, list_cast, requires_executable, requires_package,
slice_list, to_1tuple, to_2tuple, to_3tuple, to_4tuple,
to_ntuple, tuple_cast)
iter_cast, list_cast, mmcv_full_available,
requires_executable, requires_package, slice_list,
to_1tuple, to_2tuple, to_3tuple, to_4tuple, to_ntuple,
tuple_cast)
from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist,
scandir, symlink)
@@ -15,5 +16,5 @@ __all__ = [
'is_filepath', 'fopen', 'check_file_exist', 'mkdir_or_exist', 'symlink',
'scandir', 'deprecated_api_warning', 'import_modules_from_strings',
'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple',
'is_method_overridden', 'has_method'
'is_method_overridden', 'has_method', 'mmcv_full_available'
]

@@ -2,13 +2,18 @@
import collections.abc
import functools
import itertools
import pkgutil
import subprocess
import warnings
from collections import abc
from importlib import import_module
from inspect import getfullargspec
from itertools import repeat
from typing import Sequence, Type
from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union
import torch.nn as nn
from .parrots_wrapper import _BatchNorm, _InstanceNorm
# From PyTorch internals
@@ -296,7 +301,8 @@ def requires_executable(prerequisites):
return check_prerequisites(prerequisites, checker=_check_executable)
def deprecated_api_warning(name_dict, cls_name=None):
def deprecated_api_warning(name_dict: dict,
cls_name: Optional[str] = None) -> Callable:
"""A decorator to check if some arguments are deprecate and try to replace
deprecate src_arg_name to dst_arg_name.
@@ -356,7 +362,8 @@ def deprecated_api_warning(name_dict, cls_name=None):
return api_warning_wrapper
def is_method_overridden(method, base_class, derived_class):
def is_method_overridden(method: str, base_class: type,
derived_class: Union[type, Any]) -> bool:
"""Check if a method of base class is overridden in derived class.
Args:
@@ -386,3 +393,43 @@ def has_method(obj: object, method: str) -> bool:
bool: True if the object has the method else False.
"""
return hasattr(obj, method) and callable(getattr(obj, method))
def mmcv_full_available() -> bool:
"""Check whether mmcv-full is installed.
Returns:
bool: True if mmcv-full is installed else False.
"""
try:
import mmcv # noqa: F401
except ImportError:
return False
ext_loader = pkgutil.find_loader('mmcv._ext')
return ext_loader is not None
def is_norm(layer: nn.Module,
exclude: Optional[Union[type, Tuple[type]]] = None) -> bool:
"""Check if a layer is a normalization layer.
Args:
layer (nn.Module): The layer to be checked.
exclude (type, tuple[type], optional): Types to be excluded.
Returns:
bool: Whether the layer is a norm layer.
"""
if exclude is not None:
if not isinstance(exclude, tuple):
exclude = (exclude, )
if not is_tuple_of(exclude, type):
raise TypeError(
f'"exclude" must be either None or type or a tuple of types, '
f'but got {type(exclude)}: {exclude}')
if exclude and isinstance(layer, exclude):
return False
all_norm_bases = (_BatchNorm, _InstanceNorm, nn.GroupNorm, nn.LayerNorm)
return isinstance(layer, all_norm_bases)

@@ -0,0 +1,97 @@
# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial
from typing import Optional
import torch
TORCH_VERSION = torch.__version__
def is_rocm_pytorch() -> bool:
is_rocm = False
if TORCH_VERSION != 'parrots':
try:
from torch.utils.cpp_extension import ROCM_HOME
is_rocm = True if ((torch.version.hip is not None) and
(ROCM_HOME is not None)) else False
except ImportError:
pass
return is_rocm
def _get_cuda_home() -> Optional[str]:
if TORCH_VERSION == 'parrots':
from parrots.utils.build_extension import CUDA_HOME
else:
if is_rocm_pytorch():
from torch.utils.cpp_extension import ROCM_HOME
CUDA_HOME = ROCM_HOME
else:
from torch.utils.cpp_extension import CUDA_HOME
return CUDA_HOME
def get_build_config():
if TORCH_VERSION == 'parrots':
from parrots.config import get_build_info
return get_build_info()
else:
return torch.__config__.show()
def _get_conv() -> tuple:
if TORCH_VERSION == 'parrots':
from parrots.nn.modules.conv import _ConvNd, _ConvTransposeMixin
else:
from torch.nn.modules.conv import _ConvNd, _ConvTransposeMixin
return _ConvNd, _ConvTransposeMixin
def _get_dataloader() -> tuple:
if TORCH_VERSION == 'parrots':
from torch.utils.data import DataLoader, PoolDataLoader
else:
from torch.utils.data import DataLoader
PoolDataLoader = DataLoader
return DataLoader, PoolDataLoader
def _get_extension():
if TORCH_VERSION == 'parrots':
from parrots.utils.build_extension import BuildExtension, Extension
CppExtension = partial(Extension, cuda=False)
CUDAExtension = partial(Extension, cuda=True)
else:
from torch.utils.cpp_extension import (BuildExtension, CppExtension,
CUDAExtension)
return BuildExtension, CppExtension, CUDAExtension
def _get_pool() -> tuple:
if TORCH_VERSION == 'parrots':
from parrots.nn.modules.pool import (_AdaptiveAvgPoolNd,
_AdaptiveMaxPoolNd, _AvgPoolNd,
_MaxPoolNd)
else:
from torch.nn.modules.pooling import (_AdaptiveAvgPoolNd,
_AdaptiveMaxPoolNd, _AvgPoolNd,
_MaxPoolNd)
return _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd
def _get_norm() -> tuple:
if TORCH_VERSION == 'parrots':
from parrots.nn.modules.batchnorm import _BatchNorm, _InstanceNorm
SyncBatchNorm_ = torch.nn.SyncBatchNorm2d
else:
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.instancenorm import _InstanceNorm
SyncBatchNorm_ = torch.nn.SyncBatchNorm
return _BatchNorm, _InstanceNorm, SyncBatchNorm_
_ConvNd, _ConvTransposeMixin = _get_conv()
DataLoader, PoolDataLoader = _get_dataloader()
BuildExtension, CppExtension, CUDAExtension = _get_extension()
_BatchNorm, _InstanceNorm, SyncBatchNorm_ = _get_norm()
_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool()

@@ -0,0 +1,646 @@
# Copyright (c) OpenMMLab. All rights reserved.
import sys
from unittest import TestCase
from unittest.mock import MagicMock
import torch
import torch.nn as nn
from mmengine.optim import (OPTIMIZER_CONSTRUCTORS, OPTIMIZERS,
DefaultOptimizerConstructor, build_optimizer,
build_optimizer_constructor)
from mmengine.optim.optimizer.builder import TORCH_OPTIMIZERS
from mmengine.registry import build_from_cfg
from mmengine.utils import mmcv_full_available
MMCV_FULL_AVAILABLE = mmcv_full_available()
if not MMCV_FULL_AVAILABLE:
sys.modules['mmcv.ops'] = MagicMock(
DeformConv2d=dict, ModulatedDeformConv2d=dict)
class ExampleModel(nn.Module):
def __init__(self):
super().__init__()
self.param1 = nn.Parameter(torch.ones(1))
self.conv1 = nn.Conv2d(3, 4, kernel_size=1, bias=False)
self.conv2 = nn.Conv2d(4, 2, kernel_size=1)
self.bn = nn.BatchNorm2d(2)
self.sub = SubModel()
if MMCV_FULL_AVAILABLE:
from mmcv.ops import DeformConv2dPack
self.dcn = DeformConv2dPack(
3, 4, kernel_size=3, deformable_groups=1)
class ExampleDuplicateModel(nn.Module):
def __init__(self):
super().__init__()
self.param1 = nn.Parameter(torch.ones(1))
self.conv1 = nn.Sequential(nn.Conv2d(3, 4, kernel_size=1, bias=False))
self.conv2 = nn.Sequential(nn.Conv2d(4, 2, kernel_size=1))
self.bn = nn.BatchNorm2d(2)
self.sub = SubModel()
self.conv3 = nn.Sequential(nn.Conv2d(3, 4, kernel_size=1, bias=False))
self.conv3[0] = self.conv1[0]
if MMCV_FULL_AVAILABLE:
from mmcv.ops import DeformConv2dPack
self.dcn = DeformConv2dPack(
3, 4, kernel_size=3, deformable_groups=1)
def forward(self, x):
return x
class SubModel(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(2, 2, kernel_size=1, groups=2)
self.gn = nn.GroupNorm(2, 2)
self.param1 = nn.Parameter(torch.ones(1))
def forward(self, x):
return x
class PseudoDataParallel(nn.Module):
def __init__(self):
super().__init__()
self.module = ExampleModel()
def forward(self, x):
return x
class TestBuilder(TestCase):
def setUp(self):
self.model = ExampleModel()
self.base_lr = 0.01
self.momentum = 0.0001
self.base_wd = 0.9
def _check_default_optimizer(self, optimizer, model, prefix=''):
assert isinstance(optimizer, torch.optim.SGD)
assert optimizer.defaults['lr'] == self.base_lr
assert optimizer.defaults['momentum'] == self.momentum
assert optimizer.defaults['weight_decay'] == self.base_wd
param_groups = optimizer.param_groups[0]
if MMCV_FULL_AVAILABLE:
param_names = [
'param1', 'conv1.weight', 'conv2.weight', 'conv2.bias',
'bn.weight', 'bn.bias', 'sub.param1', 'sub.conv1.weight',
'sub.conv1.bias', 'sub.gn.weight', 'sub.gn.bias', 'dcn.weight',
'dcn.conv_offset.weight', 'dcn.conv_offset.bias'
]
else:
param_names = [
'param1', 'conv1.weight', 'conv2.weight', 'conv2.bias',
'bn.weight', 'bn.bias', 'sub.param1', 'sub.conv1.weight',
'sub.conv1.bias', 'sub.gn.weight', 'sub.gn.bias'
]
param_dict = dict(model.named_parameters())
assert len(param_groups['params']) == len(param_names)
for i in range(len(param_groups['params'])):
assert torch.equal(param_groups['params'][i],
param_dict[prefix + param_names[i]])
def _check_sgd_optimizer(self,
optimizer,
model,
prefix='',
bias_lr_mult=1,
bias_decay_mult=1,
norm_decay_mult=1,
dwconv_decay_mult=1,
dcn_offset_lr_mult=1,
bypass_duplicate=False):
param_groups = optimizer.param_groups
assert isinstance(optimizer, torch.optim.SGD)
assert optimizer.defaults['lr'] == self.base_lr
assert optimizer.defaults['momentum'] == self.momentum
assert optimizer.defaults['weight_decay'] == self.base_wd
model_parameters = list(model.parameters())
assert len(param_groups) == len(model_parameters)
for i, param in enumerate(model_parameters):
param_group = param_groups[i]
assert torch.equal(param_group['params'][0], param)
assert param_group['momentum'] == self.momentum
# param1
param1 = param_groups[0]
assert param1['lr'] == self.base_lr
assert param1['weight_decay'] == self.base_wd
# conv1.weight
conv1_weight = param_groups[1]
assert conv1_weight['lr'] == self.base_lr
assert conv1_weight['weight_decay'] == self.base_wd
# conv2.weight
conv2_weight = param_groups[2]
assert conv2_weight['lr'] == self.base_lr
assert conv2_weight['weight_decay'] == self.base_wd
# conv2.bias
conv2_bias = param_groups[3]
assert conv2_bias['lr'] == self.base_lr * bias_lr_mult
assert conv2_bias['weight_decay'] == self.base_wd * bias_decay_mult
# bn.weight
bn_weight = param_groups[4]
assert bn_weight['lr'] == self.base_lr
assert bn_weight['weight_decay'] == self.base_wd * norm_decay_mult
# bn.bias
bn_bias = param_groups[5]
assert bn_bias['lr'] == self.base_lr
assert bn_bias['weight_decay'] == self.base_wd * norm_decay_mult
# sub.param1
sub_param1 = param_groups[6]
assert sub_param1['lr'] == self.base_lr
assert sub_param1['weight_decay'] == self.base_wd
# sub.conv1.weight
sub_conv1_weight = param_groups[7]
assert sub_conv1_weight['lr'] == self.base_lr
assert sub_conv1_weight[
'weight_decay'] == self.base_wd * dwconv_decay_mult
# sub.conv1.bias
sub_conv1_bias = param_groups[8]
assert sub_conv1_bias['lr'] == self.base_lr * bias_lr_mult
assert sub_conv1_bias[
'weight_decay'] == self.base_wd * dwconv_decay_mult
# sub.gn.weight
sub_gn_weight = param_groups[9]
assert sub_gn_weight['lr'] == self.base_lr
assert sub_gn_weight['weight_decay'] == self.base_wd * norm_decay_mult
# sub.gn.bias
sub_gn_bias = param_groups[10]
assert sub_gn_bias['lr'] == self.base_lr
assert sub_gn_bias['weight_decay'] == self.base_wd * norm_decay_mult
if torch.cuda.is_available():
dcn_conv_weight = param_groups[11]
assert dcn_conv_weight['lr'] == self.base_lr
assert dcn_conv_weight['weight_decay'] == self.base_wd
dcn_offset_weight = param_groups[12]
assert dcn_offset_weight['lr'] == self.base_lr * dcn_offset_lr_mult
assert dcn_offset_weight['weight_decay'] == self.base_wd
dcn_offset_bias = param_groups[13]
assert dcn_offset_bias['lr'] == self.base_lr * dcn_offset_lr_mult
assert dcn_offset_bias['weight_decay'] == self.base_wd
def test_torch_optimizers(self):
torch_optimizers = [
'ASGD', 'Adadelta', 'Adagrad', 'Adam', 'AdamW', 'Adamax', 'LBFGS',
'Optimizer', 'RMSprop', 'Rprop', 'SGD', 'SparseAdam'
]
assert set(torch_optimizers).issubset(set(TORCH_OPTIMIZERS))
def test_build_optimizer(self):
# test build function without ``constructor`` and ``paramwise_cfg``
optimizer_cfg = dict(
type='SGD',
lr=self.base_lr,
weight_decay=self.base_wd,
momentum=self.momentum)
optimizer = build_optimizer(self.model, optimizer_cfg)
self._check_default_optimizer(optimizer, self.model)
# test build function with invalid ``constructor``
with self.assertRaises(KeyError):
optimizer_cfg['constructor'] = 'INVALID_CONSTRUCTOR'
build_optimizer(self.model, optimizer_cfg)
# test build function with invalid ``paramwise_cfg``
with self.assertRaises(KeyError):
optimizer_cfg['paramwise_cfg'] = dict(invalid_mult=1)
build_optimizer(self.model, optimizer_cfg)
def test_build_default_optimizer_constructor(self):
optimizer_cfg = dict(
type='SGD',
lr=self.base_lr,
weight_decay=self.base_wd,
momentum=self.momentum)
paramwise_cfg = dict(
bias_lr_mult=2,
bias_decay_mult=0.5,
norm_decay_mult=0,
dwconv_decay_mult=0.1,
dcn_offset_lr_mult=0.1)
optim_constructor_cfg = dict(
type='DefaultOptimizerConstructor',
optimizer_cfg=optimizer_cfg,
paramwise_cfg=paramwise_cfg)
optim_constructor = build_optimizer_constructor(optim_constructor_cfg)
optimizer = optim_constructor(self.model)
self._check_sgd_optimizer(optimizer, self.model, **paramwise_cfg)
def test_build_custom_optimizer_constructor(self):
optimizer_cfg = dict(
type='SGD',
lr=self.base_lr,
weight_decay=self.base_wd,
momentum=self.momentum)
@OPTIMIZER_CONSTRUCTORS.register_module()
class MyOptimizerConstructor(DefaultOptimizerConstructor):
def __call__(self, model):
if hasattr(model, 'module'):
model = model.module
conv1_lr_mult = self.paramwise_cfg.get('conv1_lr_mult', 1.)
params = []
for name, param in model.named_parameters():
param_group = {'params': [param]}
if name.startswith('conv1') and param.requires_grad:
param_group['lr'] = self.base_lr * conv1_lr_mult
params.append(param_group)
self.optimizer_cfg['params'] = params
return build_from_cfg(self.optimizer_cfg, OPTIMIZERS)
paramwise_cfg = dict(conv1_lr_mult=5)
optim_constructor_cfg = dict(
type='MyOptimizerConstructor',
optimizer_cfg=optimizer_cfg,
paramwise_cfg=paramwise_cfg)
optim_constructor = build_optimizer_constructor(optim_constructor_cfg)
optimizer = optim_constructor(self.model)
param_groups = optimizer.param_groups
assert isinstance(optimizer, torch.optim.SGD)
assert optimizer.defaults['lr'] == self.base_lr
assert optimizer.defaults['momentum'] == self.momentum
assert optimizer.defaults['weight_decay'] == self.base_wd
for i, param in enumerate(self.model.parameters()):
param_group = param_groups[i]
assert torch.equal(param_group['params'][0], param)
assert param_group['momentum'] == self.momentum
# conv1.weight
assert param_groups[1][
'lr'] == self.base_lr * paramwise_cfg['conv1_lr_mult']
assert param_groups[1]['weight_decay'] == self.base_wd
def test_default_optimizer_constructor(self):
with self.assertRaises(TypeError):
# optimizer_cfg must be a dict
optimizer_cfg = []
optim_constructor = DefaultOptimizerConstructor(optimizer_cfg)
optim_constructor(self.model)
with self.assertRaises(TypeError):
# paramwise_cfg must be a dict or None
optimizer_cfg = dict(lr=0.0001)
paramwise_cfg = ['error']
optim_constructor = DefaultOptimizerConstructor(
optimizer_cfg, paramwise_cfg)
optim_constructor(self.model)
with self.assertRaises(ValueError):
# bias_decay_mult/norm_decay_mult is specified but weight_decay
# is None
optimizer_cfg = dict(lr=0.0001, weight_decay=None)
paramwise_cfg = dict(bias_decay_mult=1, norm_decay_mult=1)
optim_constructor = DefaultOptimizerConstructor(
optimizer_cfg, paramwise_cfg)
optim_constructor(self.model)
# basic config with ExampleModel
optimizer_cfg = dict(
type='SGD',
lr=self.base_lr,
weight_decay=self.base_wd,
momentum=self.momentum)
optim_constructor = DefaultOptimizerConstructor(optimizer_cfg)
optimizer = optim_constructor(self.model)
self._check_default_optimizer(optimizer, self.model)
def test_default_optimizer_constructor_with_model_wrapper(self):
# basic config with pseudo data parallel
model = PseudoDataParallel()
optimizer_cfg = dict(
type='SGD',
lr=self.base_lr,
weight_decay=self.base_wd,
momentum=self.momentum)
paramwise_cfg = None
optim_constructor = DefaultOptimizerConstructor(optimizer_cfg)
optimizer = optim_constructor(model)
self._check_default_optimizer(optimizer, model, prefix='module.')
# paramwise_cfg with pseudo data parallel
model = PseudoDataParallel()
optimizer_cfg = dict(
type='SGD',
lr=self.base_lr,
weight_decay=self.base_wd,
momentum=self.momentum)
paramwise_cfg = dict(
bias_lr_mult=2,
bias_decay_mult=0.5,
norm_decay_mult=0,
dwconv_decay_mult=0.1,
dcn_offset_lr_mult=0.1)
optim_constructor = DefaultOptimizerConstructor(
optimizer_cfg, paramwise_cfg)
optimizer = optim_constructor(model)
self._check_sgd_optimizer(
optimizer, model, prefix='module.', **paramwise_cfg)
# basic config with DataParallel
if torch.cuda.is_available():
model = torch.nn.DataParallel(ExampleModel())
optimizer_cfg = dict(
type='SGD',
lr=self.base_lr,
weight_decay=self.base_wd,
momentum=self.momentum)
paramwise_cfg = None
optim_constructor = DefaultOptimizerConstructor(optimizer_cfg)
optimizer = optim_constructor(model)
self._check_default_optimizer(optimizer, model, prefix='module.')
# paramwise_cfg with DataParallel
if torch.cuda.is_available():
model = torch.nn.DataParallel(self.model)
optimizer_cfg = dict(
type='SGD',
lr=self.base_lr,
weight_decay=self.base_wd,
momentum=self.momentum)
paramwise_cfg = dict(
bias_lr_mult=2,
bias_decay_mult=0.5,
norm_decay_mult=0,
dwconv_decay_mult=0.1,
dcn_offset_lr_mult=0.1)
optim_constructor = DefaultOptimizerConstructor(
optimizer_cfg, paramwise_cfg)
optimizer = optim_constructor(model)
self._check_sgd_optimizer(
optimizer, model, prefix='module.', **paramwise_cfg)
def test_default_optimizer_constructor_with_empty_paramwise_cfg(self):
# Empty paramwise_cfg with ExampleModel
optimizer_cfg = dict(
type='SGD',
lr=self.base_lr,
weight_decay=self.base_wd,
momentum=self.momentum)
paramwise_cfg = dict()
optim_constructor = DefaultOptimizerConstructor(
optimizer_cfg, paramwise_cfg)
optimizer = optim_constructor(self.model)
self._check_default_optimizer(optimizer, self.model)
# Empty paramwise_cfg with ExampleModel and no grad
model = ExampleModel()
for param in model.parameters():
param.requires_grad = False
optimizer_cfg = dict(
type='SGD',
lr=self.base_lr,
weight_decay=self.base_wd,
momentum=self.momentum)
paramwise_cfg = dict()
optim_constructor = DefaultOptimizerConstructor(optimizer_cfg)
optimizer = optim_constructor(model)
self._check_default_optimizer(optimizer, model)
def test_default_optimizer_constructor_with_paramwise_cfg(self):
# paramwise_cfg with ExampleModel
optimizer_cfg = dict(
type='SGD',
lr=self.base_lr,
weight_decay=self.base_wd,
momentum=self.momentum)
paramwise_cfg = dict(
bias_lr_mult=2,
bias_decay_mult=0.5,
norm_decay_mult=0,
dwconv_decay_mult=0.1,
dcn_offset_lr_mult=0.1)
optim_constructor = DefaultOptimizerConstructor(
optimizer_cfg, paramwise_cfg)
optimizer = optim_constructor(self.model)
self._check_sgd_optimizer(optimizer, self.model, **paramwise_cfg)
def test_default_optimizer_constructor_no_grad(self):
# paramwise_cfg with ExampleModel and no grad
optimizer_cfg = dict(
type='SGD',
lr=self.base_lr,
weight_decay=self.base_wd,
momentum=self.momentum)
paramwise_cfg = dict(
bias_lr_mult=2,
bias_decay_mult=0.5,
norm_decay_mult=0,
dwconv_decay_mult=0.1,
dcn_offset_lr_mult=0.1)
for param in self.model.parameters():
param.requires_grad = False
optim_constructor = DefaultOptimizerConstructor(
optimizer_cfg, paramwise_cfg)
optimizer = optim_constructor(self.model)
param_groups = optimizer.param_groups
assert isinstance(optimizer, torch.optim.SGD)
assert optimizer.defaults['lr'] == self.base_lr
assert optimizer.defaults['momentum'] == self.momentum
assert optimizer.defaults['weight_decay'] == self.base_wd
for i, (name, param) in enumerate(self.model.named_parameters()):
param_group = param_groups[i]
assert torch.equal(param_group['params'][0], param)
assert param_group['momentum'] == self.momentum
assert param_group['lr'] == self.base_lr
assert param_group['weight_decay'] == self.base_wd
def test_default_optimizer_constructor_bypass_duplicate(self):
# paramwise_cfg with bypass_duplicate option
model = ExampleDuplicateModel()
optimizer_cfg = dict(
type='SGD',
lr=self.base_lr,
weight_decay=self.base_wd,
momentum=self.momentum)
paramwise_cfg = dict(
bias_lr_mult=2,
bias_decay_mult=0.5,
norm_decay_mult=0,
dwconv_decay_mult=0.1)
with self.assertRaisesRegex(
ValueError,
'some parameters appear in more than one parameter group'):
optim_constructor = DefaultOptimizerConstructor(
optimizer_cfg, paramwise_cfg)
optim_constructor(model)
paramwise_cfg = dict(
bias_lr_mult=2,
bias_decay_mult=0.5,
norm_decay_mult=0,
dwconv_decay_mult=0.1,
dcn_offset_lr_mult=0.1,
bypass_duplicate=True)
optim_constructor = DefaultOptimizerConstructor(
optimizer_cfg, paramwise_cfg)
self.assertWarnsRegex(
Warning,
'conv3.0 is duplicate. It is skipped since bypass_duplicate=True',
lambda: optim_constructor(model))
optimizer = optim_constructor(model)
model_parameters = list(model.parameters())
num_params = 14 if MMCV_FULL_AVAILABLE else 11
assert len(
optimizer.param_groups) == len(model_parameters) == num_params
self._check_sgd_optimizer(optimizer, model, **paramwise_cfg)
def test_default_optimizer_constructor_custom_key(self):
# test DefaultOptimizerConstructor with custom_keys and ExampleModel
optimizer_cfg = dict(
type='SGD',
lr=self.base_lr,
weight_decay=self.base_wd,
momentum=self.momentum)
paramwise_cfg = dict(
custom_keys={
'param1': dict(lr_mult=10),
'sub': dict(lr_mult=0.1, decay_mult=0),
'sub.gn': dict(lr_mult=0.01),
'non_exist_key': dict(lr_mult=0.0)
},
norm_decay_mult=0.5)
with self.assertRaises(TypeError):
# custom_keys should be a dict
paramwise_cfg_ = dict(custom_keys=[0.1, 0.0001])
optim_constructor = DefaultOptimizerConstructor(
optimizer_cfg, paramwise_cfg_)
optimizer = optim_constructor(self.model)
with self.assertRaises(ValueError):
# if 'decay_mult' is specified in custom_keys, weight_decay
# should be specified
optimizer_cfg_ = dict(type='SGD', lr=0.01)
paramwise_cfg_ = dict(
custom_keys={'.backbone': dict(decay_mult=0.5)})
optim_constructor = DefaultOptimizerConstructor(
optimizer_cfg_, paramwise_cfg_)
optimizer = optim_constructor(self.model)
optim_constructor = DefaultOptimizerConstructor(
optimizer_cfg, paramwise_cfg)
optimizer = optim_constructor(self.model)
# check optimizer type and default config
assert isinstance(optimizer, torch.optim.SGD)
assert optimizer.defaults['lr'] == self.base_lr
assert optimizer.defaults['momentum'] == self.momentum
assert optimizer.defaults['weight_decay'] == self.base_wd
# check params groups
param_groups = optimizer.param_groups
groups = []
group_settings = []
# group 1, matches of 'param1'
# 'param1' is the longest match for 'sub.param1'
groups.append(['param1', 'sub.param1'])
group_settings.append({
'lr': self.base_lr * 10,
'momentum': self.momentum,
'weight_decay': self.base_wd,
})
# group 2, matches of 'sub.gn'
groups.append(['sub.gn.weight', 'sub.gn.bias'])
group_settings.append({
'lr': self.base_lr * 0.01,
'momentum': self.momentum,
'weight_decay': self.base_wd,
})
# group 3, matches of 'sub'
groups.append(['sub.conv1.weight', 'sub.conv1.bias'])
group_settings.append({
'lr': self.base_lr * 0.1,
'momentum': self.momentum,
'weight_decay': 0,
})
# group 4, bn is configured by 'norm_decay_mult'
groups.append(['bn.weight', 'bn.bias'])
group_settings.append({
'lr': self.base_lr,
'momentum': self.momentum,
'weight_decay': self.base_wd * 0.5,
})
# group 5, default group
groups.append(['conv1.weight', 'conv2.weight', 'conv2.bias'])
group_settings.append({
'lr': self.base_lr,
'momentum': self.momentum,
'weight_decay': self.base_wd
})
num_params = 14 if MMCV_FULL_AVAILABLE else 11
assert len(param_groups) == num_params
for i, (name, param) in enumerate(self.model.named_parameters()):
assert torch.equal(param_groups[i]['params'][0], param)
for group, settings in zip(groups, group_settings):
if name in group:
for setting in settings:
assert param_groups[i][setting] == settings[
setting], f'{name} {setting}'
# test DefaultOptimizerConstructor with custom_keys and ExampleModel 2
optimizer_cfg = dict(
type='SGD', lr=self.base_lr, momentum=self.momentum)
paramwise_cfg = dict(custom_keys={'param1': dict(lr_mult=10)})
optim_constructor = DefaultOptimizerConstructor(
optimizer_cfg, paramwise_cfg)
optimizer = optim_constructor(self.model)
# check optimizer type and default config
assert isinstance(optimizer, torch.optim.SGD)
assert optimizer.defaults['lr'] == self.base_lr
assert optimizer.defaults['momentum'] == self.momentum
assert optimizer.defaults['weight_decay'] == 0
# check params groups
param_groups = optimizer.param_groups
groups = []
group_settings = []
# group 1, matches of 'param1'
groups.append(['param1', 'sub.param1'])
group_settings.append({
'lr': self.base_lr * 10,
'momentum': self.momentum,
'weight_decay': 0,
})
# group 2, default group
groups.append([
'sub.conv1.weight', 'sub.conv1.bias', 'sub.gn.weight',
'sub.gn.bias', 'conv1.weight', 'conv2.weight', 'conv2.bias',
'bn.weight', 'bn.bias'
])
group_settings.append({
'lr': self.base_lr,
'momentum': self.momentum,
'weight_decay': 0
})
num_params = 14 if MMCV_FULL_AVAILABLE else 11
assert len(param_groups) == num_params
for i, (name, param) in enumerate(self.model.named_parameters()):
assert torch.equal(param_groups[i]['params'][0], param)
for group, settings in zip(groups, group_settings):
if name in group:
for setting in settings:
assert param_groups[i][setting] == settings[
setting], f'{name} {setting}'