diff --git a/docs/zh_cn/tutorials/optimizer.md b/docs/zh_cn/tutorials/optimizer.md new file mode 100644 index 00000000..6980a59d --- /dev/null +++ b/docs/zh_cn/tutorials/optimizer.md @@ -0,0 +1,127 @@ +# 优化器(Optimizer) + +在模型训练过程中,我们需要使用优化算法对模型的参数进行优化。在 PyTorch 的 `torch.optim` 中包含了各种优化算法的实现,这些优化算法的类被称为优化器。 +在 PyTorch 中,用户可以通过构建一个优化器对象来优化模型的参数,下面是一个简单的例子: + +```python +optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0001) + +for input, target in dataset: + optimizer.zero_grad() + output = model(input) + loss = loss_fn(output, target) + loss.backward() + optimizer.step() +``` + +关于 PyTorch 优化器的详细介绍可以参考 [PyTorch 优化器文档](https://pytorch.org/docs/stable/optim.html#) + +MMEngine 支持所有的 PyTorch 优化器,用户可以直接构建 PyTorch 优化器对象并将它传给[执行器(Runner)](https://mmengine.readthedocs.io/zh_CN/latest/tutorials/runner.html) 。 +和 PyTorch 文档中所给示例不同,MMEngine 中通常不需要手动实现训练循环以及调用` optimizer.step()`,执行器会自动对损失函数进行反向传播并调用优化器的 `step` 方法更新模型参数。 + +同时,我们也支持通过配置文件从注册器中构建优化器。更进一步的,我们提供了优化器构造器(optimizer constructor)来对模型的优化进行更细粒度的调整。 + +## 使用配置文件构建优化器 + +MMEngine 会自动将 PyTorch 中的所有优化器都添加进 `OPTIMIZERS` 注册表中,用户可以通过设置配置文件中的 `optimizer` 字段来指定优化器,所有支持的优化器见 [PyTorch 优化器列表](https://pytorch.org/docs/stable/optim.html#algorithms)。 + +以配置一个 SGD 优化器为例: + +```python +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001) +``` + +我们只需要指定 `optimizer` 字段中的 `type` 为 SGD, 并设置学习率等参数,执行器会根据此字段以及执行器中的模型参数自动构建优化器。 + +## 细粒度调整模型超参 + +PyTorch 的优化器支持对模型中的不同参数设置不同的超参数,例如对一个分类模型的骨干(backbone)和分类头(head)设置不同的学习率: + +```python +optim.SGD([ + {'params': model.backbone.parameters()}, + {'params': model.head.parameters(), 'lr': 1e-3} + ], lr=0.01, momentum=0.9) +``` + +上面的例子中,模型的骨干部分使用了 0.01 学习率,而模型的头部则使用了 1e-3 学习率。 +用户可以将模型的不同部分参数和对应的超参组成一个字典的列表传给优化器,来实现对模型优化的细粒度调整。 + +在 MMEngine 中,我们通过优化器构造器(optimizer constructor),让用户能够直接通过设置优化器配置文件中的 `paramwise_cfg` 字段而非修改代码来实现对模型的不同部分设置不同的超参。 + +### 为不同类型的参数设置不同的超参系数 + +MMEngine 提供的默认优化器构造器支持对模型中不同类型的参数设置不同的超参系数。 +例如,我们可以在 `paramwise_cfg` 中设置 `norm_decay_mult=0` ,从而将正则化层(normalization layer)的权重(weight)和偏置(bias)的权值衰减系数(weight decay)设置为0, +来实现 [Bag of Tricks](https://arxiv.org/abs/1812.01187) 论文中提到的不对正则化层进行权值衰减的技巧。 + +示例: + +```python +optimizer = dict(type='SGD', + lr=0.01, + weight_decay=0.0001, + paramwise_cfg=dict(norm_decay_mult=0)) +``` + +除了可以对偏置的权重衰减进行配置外,MMEngine 的默认优化器构造器的 `paramwise_cfg` 还支持对更多不同类型的参数设置超参系数,支持的配置如下: + +`bias_lr_mult`:偏置的学习率系数(不包括正则化层的偏置以及可变形卷积的 offset),默认值为 1 + +`bias_decay_mult`:偏置的权值衰减系数(不包括正则化层的偏置以及可变形卷积的 offset),默认值为 1 + +`norm_decay_mult`:正则化层权重和偏置的权值衰减系数,默认值为 1 + +`dwconv_decay_mult`:Depth-wise 卷积的权值衰减系数,默认值为 1 + +`bypass_duplicate`:是否跳过重复的参数,默认为 `False` + +`dcn_offset_lr_mult`:可变形卷积(Deformable Convolution)的学习率系数,默认值为 1 + +### 为模型不同部分的参数设置不同的超参系数 + +此外,与上文 PyTorch 的示例一样,在 MMEngine 中我们也同样可以对模型中的任意模块设置不同的超参,只需要在 `paramwise_cfg` 中设置 `custom_keys` 即可: + +```python +optimizer = dict(type='SGD', + lr=0.01, + weight_decay=0.0001, + paramwise_cfg=dict( + custom_keys={ + 'backbone.layer0': dict(lr_mult=0, decay_mult=0), + 'backbone': dict(lr_mult=1), + 'head': dict(lr_mult=0.1), + } + )) +``` + +上面的配置文件实现了对模型的骨干第一层的学习率和权重衰减设置为 0,骨干的其余部分部分使用 0.01 学习率,而对模型的头部则使用 1e-3 学习率。 + +### 进阶用法:实现自定义的优化器构造器 + +与 MMEngine 中的其他模块一样,优化器构造器也同样由 [注册表](https://mmengine.readthedocs.io/zh_CN/latest/tutorials/param_scheduler.html) 来管理。 +用户可以实现自己的优化器构造策略来实现自定义的超参设置策略,并添加进 `OPTIMIZER_CONSTRUCTORS` 注册表中。 + +例如,我们想实现一个叫做`LayerDecayOptimizerConstructor`的优化器构造器,来实现对模型的不同深度的层自动设置递减的学习率。 +我们可以通过继承 `DefaultOptimizerConstructor` 来实现这一策略,并将其添加进注册表中: + +```python +@OPTIMIZER_CONSTRUCTORS.register_module() +class LayerDecayOptimizerConstructor(DefaultOptimizerConstructor): + def add_params(self, params, module, prefix='', is_dcn_module=None): + ... +``` + +然后将优化器配置文件中的 `constructor` 字段设置为类名来指定使用这个自定义的优化器构造器: + +```python +optimizer = dict(type='SGD', + lr=0.01, + weight_decay=0.0001, + constructor='LayerDecayOptimizerConstructor') +``` + +## 在训练过程中调整超参 + +优化器中的超参数在构造时只能设置为一个定值,仅仅使用优化器,并不能在训练过程中调整学习率等参数。 +在 MMEngine 中,我们实现了参数调度器(Parameter Scheduler),以便能够在训练过程中调整参数。关于参数调度器的用法请见[优化器参数调整策略](https://mmengine.readthedocs.io/zh_CN/latest/tutorials/param_scheduler.html) diff --git a/mmengine/optim/__init__.py b/mmengine/optim/__init__.py new file mode 100644 index 00000000..98832cfd --- /dev/null +++ b/mmengine/optim/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .optimizer import (OPTIMIZER_CONSTRUCTORS, OPTIMIZERS, + DefaultOptimizerConstructor, build_optimizer, + build_optimizer_constructor) +from .scheduler import (ConstantLR, ConstantMomentum, ConstantParamScheduler, + CosineAnnealingLR, CosineAnnealingMomentum, + CosineAnnealingParamScheduler, ExponentialLR, + ExponentialMomentum, ExponentialParamScheduler, + LinearLR, LinearMomentum, LinearParamScheduler, + MultiStepLR, MultiStepMomentum, + MultiStepParamScheduler, StepLR, StepMomentum, + StepParamScheduler, _ParamScheduler) + +__all__ = [ + 'OPTIMIZER_CONSTRUCTORS', 'OPTIMIZERS', 'build_optimizer', + 'build_optimizer_constructor', 'DefaultOptimizerConstructor', 'ConstantLR', + 'CosineAnnealingLR', 'ExponentialLR', 'LinearLR', 'MultiStepLR', 'StepLR', + 'ConstantMomentum', 'CosineAnnealingMomentum', 'ExponentialMomentum', + 'LinearMomentum', 'MultiStepMomentum', 'StepMomentum', + 'ConstantParamScheduler', 'CosineAnnealingParamScheduler', + 'ExponentialParamScheduler', 'LinearParamScheduler', + 'MultiStepParamScheduler', 'StepParamScheduler', '_ParamScheduler' +] diff --git a/mmengine/optim/optimizer/__init__.py b/mmengine/optim/optimizer/__init__.py new file mode 100644 index 00000000..77d5ac18 --- /dev/null +++ b/mmengine/optim/optimizer/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .builder import (OPTIMIZER_CONSTRUCTORS, OPTIMIZERS, build_optimizer, + build_optimizer_constructor) +from .default_constructor import DefaultOptimizerConstructor + +__all__ = [ + 'OPTIMIZER_CONSTRUCTORS', 'OPTIMIZERS', 'DefaultOptimizerConstructor', + 'build_optimizer', 'build_optimizer_constructor' +] diff --git a/mmengine/optim/optimizer/builder.py b/mmengine/optim/optimizer/builder.py new file mode 100644 index 00000000..a48f65e2 --- /dev/null +++ b/mmengine/optim/optimizer/builder.py @@ -0,0 +1,43 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import inspect +from typing import Callable, List + +import torch +import torch.nn as nn + +from mmengine.registry import OPTIMIZER_CONSTRUCTORS, OPTIMIZERS + + +def register_torch_optimizers() -> List[str]: + torch_optimizers = [] + for module_name in dir(torch.optim): + if module_name.startswith('__'): + continue + _optim = getattr(torch.optim, module_name) + if inspect.isclass(_optim) and issubclass(_optim, + torch.optim.Optimizer): + OPTIMIZERS.register_module(module=_optim) + torch_optimizers.append(module_name) + return torch_optimizers + + +TORCH_OPTIMIZERS = register_torch_optimizers() + + +def build_optimizer_constructor(cfg: dict) -> Callable: + return OPTIMIZER_CONSTRUCTORS.build(cfg) + + +def build_optimizer(model: nn.Module, cfg: dict) -> torch.optim.Optimizer: + optimizer_cfg = copy.deepcopy(cfg) + constructor_type = optimizer_cfg.pop('constructor', + 'DefaultOptimizerConstructor') + paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None) + optim_constructor = build_optimizer_constructor( + dict( + type=constructor_type, + optimizer_cfg=optimizer_cfg, + paramwise_cfg=paramwise_cfg)) + optimizer = optim_constructor(model) + return optimizer diff --git a/mmengine/optim/optimizer/default_constructor.py b/mmengine/optim/optimizer/default_constructor.py new file mode 100644 index 00000000..be573d3c --- /dev/null +++ b/mmengine/optim/optimizer/default_constructor.py @@ -0,0 +1,260 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from typing import List, Optional, Union + +import torch +import torch.nn as nn +from torch.nn import GroupNorm, LayerNorm + +from mmengine.registry import (OPTIMIZER_CONSTRUCTORS, OPTIMIZERS, + build_from_cfg) +from mmengine.utils import is_list_of, mmcv_full_available +from mmengine.utils.parrots_wrapper import _BatchNorm, _InstanceNorm + + +@OPTIMIZER_CONSTRUCTORS.register_module() +class DefaultOptimizerConstructor: + """Default constructor for optimizers. + + By default each parameter share the same optimizer settings, and we + provide an argument ``paramwise_cfg`` to specify parameter-wise settings. + It is a dict and may contain the following fields: + + - ``custom_keys`` (dict): Specified parameters-wise settings by keys. If + one of the keys in ``custom_keys`` is a substring of the name of one + parameter, then the setting of the parameter will be specified by + ``custom_keys[key]`` and other setting like ``bias_lr_mult`` etc. will + be ignored. It should be noted that the aforementioned ``key`` is the + longest key that is a substring of the name of the parameter. If there + are multiple matched keys with the same length, then the key with lower + alphabet order will be chosen. + ``custom_keys[key]`` should be a dict and may contain fields ``lr_mult`` + and ``decay_mult``. See Example 2 below. + - ``bias_lr_mult`` (float): It will be multiplied to the learning + rate for all bias parameters (except for those in normalization + layers and offset layers of DCN). + - ``bias_decay_mult`` (float): It will be multiplied to the weight + decay for all bias parameters (except for those in + normalization layers, depthwise conv layers, offset layers of DCN). + - ``norm_decay_mult`` (float): It will be multiplied to the weight + decay for all weight and bias parameters of normalization + layers. + - ``dwconv_decay_mult`` (float): It will be multiplied to the weight + decay for all weight and bias parameters of depthwise conv + layers. + - ``dcn_offset_lr_mult`` (float): It will be multiplied to the learning + rate for parameters of offset layer in the deformable convs + of a model. + - ``bypass_duplicate`` (bool): If true, the duplicate parameters + would not be added into optimizer. Default: False. + + Note: + + 1. If the option ``dcn_offset_lr_mult`` is used, the constructor will + override the effect of ``bias_lr_mult`` in the bias of offset layer. + So be careful when using both ``bias_lr_mult`` and + ``dcn_offset_lr_mult``. If you wish to apply both of them to the offset + layer in deformable convs, set ``dcn_offset_lr_mult`` to the original + ``dcn_offset_lr_mult`` * ``bias_lr_mult``. + + 2. If the option ``dcn_offset_lr_mult`` is used, the constructor will + apply it to all the DCN layers in the model. So be careful when the + model contains multiple DCN layers in places other than backbone. + + Args: + optimizer_cfg (dict): The config dict of the optimizer. + Positional fields are + + - `type`: class name of the optimizer. + + Optional fields are + + - any arguments of the corresponding optimizer type, e.g., + lr, weight_decay, momentum, etc. + paramwise_cfg (dict, optional): Parameter-wise options. + + Example 1: + >>> model = torch.nn.modules.Conv1d(1, 1, 1) + >>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9, + >>> weight_decay=0.0001) + >>> paramwise_cfg = dict(norm_decay_mult=0.) + >>> optim_builder = DefaultOptimizerConstructor( + >>> optimizer_cfg, paramwise_cfg) + >>> optimizer = optim_builder(model) + + Example 2: + >>> # assume model have attribute model.backbone and model.cls_head + >>> optimizer_cfg = dict(type='SGD', lr=0.01, weight_decay=0.95) + >>> paramwise_cfg = dict(custom_keys={ + '.backbone': dict(lr_mult=0.1, decay_mult=0.9)}) + >>> optim_builder = DefaultOptimizerConstructor( + >>> optimizer_cfg, paramwise_cfg) + >>> optimizer = optim_builder(model) + >>> # Then the `lr` and `weight_decay` for model.backbone is + >>> # (0.01 * 0.1, 0.95 * 0.9). `lr` and `weight_decay` for + >>> # model.cls_head is (0.01, 0.95). + """ + + def __init__(self, + optimizer_cfg: dict, + paramwise_cfg: Optional[dict] = None): + if not isinstance(optimizer_cfg, dict): + raise TypeError('optimizer_cfg should be a dict', + f'but got {type(optimizer_cfg)}') + self.optimizer_cfg = optimizer_cfg + self.paramwise_cfg = {} if paramwise_cfg is None else paramwise_cfg + self.base_lr = optimizer_cfg.get('lr', None) + self.base_wd = optimizer_cfg.get('weight_decay', None) + self._validate_cfg() + + def _validate_cfg(self) -> None: + """verify the correctness of the config.""" + if not isinstance(self.paramwise_cfg, dict): + raise TypeError('paramwise_cfg should be None or a dict, ' + f'but got {type(self.paramwise_cfg)}') + + if 'custom_keys' in self.paramwise_cfg: + if not isinstance(self.paramwise_cfg['custom_keys'], dict): + raise TypeError( + 'If specified, custom_keys must be a dict, ' + f'but got {type(self.paramwise_cfg["custom_keys"])}') + if self.base_wd is None: + for key in self.paramwise_cfg['custom_keys']: + if 'decay_mult' in self.paramwise_cfg['custom_keys'][key]: + raise ValueError('base_wd should not be None') + + # get base lr and weight decay + # weight_decay must be explicitly specified if mult is specified + if ('bias_decay_mult' in self.paramwise_cfg + or 'norm_decay_mult' in self.paramwise_cfg + or 'dwconv_decay_mult' in self.paramwise_cfg): + if self.base_wd is None: + raise ValueError('base_wd should not be None') + + def _is_in(self, param_group: dict, param_group_list: list) -> bool: + """check whether the `param_group` is in the`param_group_list`""" + assert is_list_of(param_group_list, dict) + param = set(param_group['params']) + param_set = set() + for group in param_group_list: + param_set.update(set(group['params'])) + + return not param.isdisjoint(param_set) + + def add_params(self, + params: List[dict], + module: nn.Module, + prefix: str = '', + is_dcn_module: Optional[Union[int, float]] = None) -> None: + """Add all parameters of module to the params list. + + The parameters of the given module will be added to the list of param + groups, with specific rules defined by paramwise_cfg. + + Args: + params (list[dict]): A list of param groups, it will be modified + in place. + module (nn.Module): The module to be added. + prefix (str): The prefix of the module + is_dcn_module (int|float|None): If the current module is a + submodule of DCN, `is_dcn_module` will be passed to + control conv_offset layer's learning rate. Defaults to None. + """ + # get param-wise options + custom_keys = self.paramwise_cfg.get('custom_keys', {}) + # first sort with alphabet order and then sort with reversed len of str + sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True) + + bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', 1.) + bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', 1.) + norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', 1.) + dwconv_decay_mult = self.paramwise_cfg.get('dwconv_decay_mult', 1.) + bypass_duplicate = self.paramwise_cfg.get('bypass_duplicate', False) + dcn_offset_lr_mult = self.paramwise_cfg.get('dcn_offset_lr_mult', 1.) + + # special rules for norm layers and depth-wise conv layers + is_norm = isinstance(module, + (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm)) + is_dwconv = ( + isinstance(module, torch.nn.Conv2d) + and module.in_channels == module.groups) + + for name, param in module.named_parameters(recurse=False): + param_group = {'params': [param]} + if not param.requires_grad: + params.append(param_group) + continue + if bypass_duplicate and self._is_in(param_group, params): + warnings.warn(f'{prefix} is duplicate. It is skipped since ' + f'bypass_duplicate={bypass_duplicate}') + continue + # if the parameter match one of the custom keys, ignore other rules + is_custom = False + for key in sorted_keys: + if key in f'{prefix}.{name}': + is_custom = True + lr_mult = custom_keys[key].get('lr_mult', 1.) + param_group['lr'] = self.base_lr * lr_mult + if self.base_wd is not None: + decay_mult = custom_keys[key].get('decay_mult', 1.) + param_group['weight_decay'] = self.base_wd * decay_mult + break + + if not is_custom: + # bias_lr_mult affects all bias parameters + # except for norm.bias dcn.conv_offset.bias + if name == 'bias' and not (is_norm or is_dcn_module): + param_group['lr'] = self.base_lr * bias_lr_mult + + if (prefix.find('conv_offset') != -1 and is_dcn_module + and isinstance(module, torch.nn.Conv2d)): + # deal with both dcn_offset's bias & weight + param_group['lr'] = self.base_lr * dcn_offset_lr_mult + + # apply weight decay policies + if self.base_wd is not None: + # norm decay + if is_norm: + param_group[ + 'weight_decay'] = self.base_wd * norm_decay_mult + # depth-wise conv + elif is_dwconv: + param_group[ + 'weight_decay'] = self.base_wd * dwconv_decay_mult + # bias lr and decay + elif name == 'bias' and not is_dcn_module: + # TODO: current bias_decay_mult will have affect on DCN + param_group[ + 'weight_decay'] = self.base_wd * bias_decay_mult + params.append(param_group) + + if mmcv_full_available(): + from mmcv.ops import DeformConv2d, ModulatedDeformConv2d + is_dcn_module = isinstance(module, + (DeformConv2d, ModulatedDeformConv2d)) + else: + is_dcn_module = False + for child_name, child_mod in module.named_children(): + child_prefix = f'{prefix}.{child_name}' if prefix else child_name + self.add_params( + params, + child_mod, + prefix=child_prefix, + is_dcn_module=is_dcn_module) + + def __call__(self, model: nn.Module) -> torch.optim.Optimizer: + if hasattr(model, 'module'): + model = model.module + + optimizer_cfg = self.optimizer_cfg.copy() + # if no paramwise option is specified, just use the global setting + if not self.paramwise_cfg: + optimizer_cfg['params'] = model.parameters() + return build_from_cfg(optimizer_cfg, OPTIMIZERS) + + # set param-wise lr and weight decay recursively + params: List = [] + self.add_params(params, model) + optimizer_cfg['params'] = params + + return build_from_cfg(optimizer_cfg, OPTIMIZERS) diff --git a/mmengine/registry/registry.py b/mmengine/registry/registry.py index 32c7d722..3c55ca14 100644 --- a/mmengine/registry/registry.py +++ b/mmengine/registry/registry.py @@ -2,7 +2,7 @@ import inspect import sys from collections.abc import Callable -from typing import Dict, List, Optional, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Tuple, Type, Union from ..config import Config, ConfigDict from ..utils import is_seq_of @@ -11,8 +11,7 @@ from ..utils import is_seq_of def build_from_cfg( cfg: Union[dict, ConfigDict, Config], registry: 'Registry', - default_args: Optional[Union[dict, ConfigDict, - Config]] = None) -> object: + default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> Any: """Build a module from config dict. At least one of the ``cfg`` and ``default_args`` contains the key "type" @@ -357,7 +356,7 @@ class Registry: def build(self, *args, default_scope: Optional[str] = None, - **kwargs) -> None: + **kwargs) -> Any: """Build an instance. Build an instance by calling :attr:`build_func`. If diff --git a/mmengine/utils/__init__.py b/mmengine/utils/__init__.py index 952b2dc3..c3ee0dd9 100644 --- a/mmengine/utils/__init__.py +++ b/mmengine/utils/__init__.py @@ -2,9 +2,10 @@ from .misc import (check_prerequisites, concat_list, deprecated_api_warning, has_method, import_modules_from_strings, is_list_of, is_method_overridden, is_seq_of, is_str, is_tuple_of, - iter_cast, list_cast, requires_executable, requires_package, - slice_list, to_1tuple, to_2tuple, to_3tuple, to_4tuple, - to_ntuple, tuple_cast) + iter_cast, list_cast, mmcv_full_available, + requires_executable, requires_package, slice_list, + to_1tuple, to_2tuple, to_3tuple, to_4tuple, to_ntuple, + tuple_cast) from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist, scandir, symlink) @@ -15,5 +16,5 @@ __all__ = [ 'is_filepath', 'fopen', 'check_file_exist', 'mkdir_or_exist', 'symlink', 'scandir', 'deprecated_api_warning', 'import_modules_from_strings', 'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple', - 'is_method_overridden', 'has_method' + 'is_method_overridden', 'has_method', 'mmcv_full_available' ] diff --git a/mmengine/utils/misc.py b/mmengine/utils/misc.py index 41823a7f..79977be0 100644 --- a/mmengine/utils/misc.py +++ b/mmengine/utils/misc.py @@ -2,13 +2,18 @@ import collections.abc import functools import itertools +import pkgutil import subprocess import warnings from collections import abc from importlib import import_module from inspect import getfullargspec from itertools import repeat -from typing import Sequence, Type +from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union + +import torch.nn as nn + +from .parrots_wrapper import _BatchNorm, _InstanceNorm # From PyTorch internals @@ -296,7 +301,8 @@ def requires_executable(prerequisites): return check_prerequisites(prerequisites, checker=_check_executable) -def deprecated_api_warning(name_dict, cls_name=None): +def deprecated_api_warning(name_dict: dict, + cls_name: Optional[str] = None) -> Callable: """A decorator to check if some arguments are deprecate and try to replace deprecate src_arg_name to dst_arg_name. @@ -356,7 +362,8 @@ def deprecated_api_warning(name_dict, cls_name=None): return api_warning_wrapper -def is_method_overridden(method, base_class, derived_class): +def is_method_overridden(method: str, base_class: type, + derived_class: Union[type, Any]) -> bool: """Check if a method of base class is overridden in derived class. Args: @@ -386,3 +393,43 @@ def has_method(obj: object, method: str) -> bool: bool: True if the object has the method else False. """ return hasattr(obj, method) and callable(getattr(obj, method)) + + +def mmcv_full_available() -> bool: + """Check whether mmcv-full is installed. + + Returns: + bool: True if mmcv-full is installed else False. + """ + try: + import mmcv # noqa: F401 + except ImportError: + return False + ext_loader = pkgutil.find_loader('mmcv._ext') + return ext_loader is not None + + +def is_norm(layer: nn.Module, + exclude: Optional[Union[type, Tuple[type]]] = None) -> bool: + """Check if a layer is a normalization layer. + + Args: + layer (nn.Module): The layer to be checked. + exclude (type, tuple[type], optional): Types to be excluded. + + Returns: + bool: Whether the layer is a norm layer. + """ + if exclude is not None: + if not isinstance(exclude, tuple): + exclude = (exclude, ) + if not is_tuple_of(exclude, type): + raise TypeError( + f'"exclude" must be either None or type or a tuple of types, ' + f'but got {type(exclude)}: {exclude}') + + if exclude and isinstance(layer, exclude): + return False + + all_norm_bases = (_BatchNorm, _InstanceNorm, nn.GroupNorm, nn.LayerNorm) + return isinstance(layer, all_norm_bases) diff --git a/mmengine/utils/parrots_wrapper.py b/mmengine/utils/parrots_wrapper.py new file mode 100644 index 00000000..8f8f2157 --- /dev/null +++ b/mmengine/utils/parrots_wrapper.py @@ -0,0 +1,97 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from functools import partial +from typing import Optional + +import torch + +TORCH_VERSION = torch.__version__ + + +def is_rocm_pytorch() -> bool: + is_rocm = False + if TORCH_VERSION != 'parrots': + try: + from torch.utils.cpp_extension import ROCM_HOME + is_rocm = True if ((torch.version.hip is not None) and + (ROCM_HOME is not None)) else False + except ImportError: + pass + return is_rocm + + +def _get_cuda_home() -> Optional[str]: + if TORCH_VERSION == 'parrots': + from parrots.utils.build_extension import CUDA_HOME + else: + if is_rocm_pytorch(): + from torch.utils.cpp_extension import ROCM_HOME + CUDA_HOME = ROCM_HOME + else: + from torch.utils.cpp_extension import CUDA_HOME + return CUDA_HOME + + +def get_build_config(): + if TORCH_VERSION == 'parrots': + from parrots.config import get_build_info + return get_build_info() + else: + return torch.__config__.show() + + +def _get_conv() -> tuple: + if TORCH_VERSION == 'parrots': + from parrots.nn.modules.conv import _ConvNd, _ConvTransposeMixin + else: + from torch.nn.modules.conv import _ConvNd, _ConvTransposeMixin + return _ConvNd, _ConvTransposeMixin + + +def _get_dataloader() -> tuple: + if TORCH_VERSION == 'parrots': + from torch.utils.data import DataLoader, PoolDataLoader + else: + from torch.utils.data import DataLoader + PoolDataLoader = DataLoader + return DataLoader, PoolDataLoader + + +def _get_extension(): + if TORCH_VERSION == 'parrots': + from parrots.utils.build_extension import BuildExtension, Extension + CppExtension = partial(Extension, cuda=False) + CUDAExtension = partial(Extension, cuda=True) + else: + from torch.utils.cpp_extension import (BuildExtension, CppExtension, + CUDAExtension) + return BuildExtension, CppExtension, CUDAExtension + + +def _get_pool() -> tuple: + if TORCH_VERSION == 'parrots': + from parrots.nn.modules.pool import (_AdaptiveAvgPoolNd, + _AdaptiveMaxPoolNd, _AvgPoolNd, + _MaxPoolNd) + else: + from torch.nn.modules.pooling import (_AdaptiveAvgPoolNd, + _AdaptiveMaxPoolNd, _AvgPoolNd, + _MaxPoolNd) + return _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd + + +def _get_norm() -> tuple: + if TORCH_VERSION == 'parrots': + from parrots.nn.modules.batchnorm import _BatchNorm, _InstanceNorm + SyncBatchNorm_ = torch.nn.SyncBatchNorm2d + else: + from torch.nn.modules.batchnorm import _BatchNorm + from torch.nn.modules.instancenorm import _InstanceNorm + SyncBatchNorm_ = torch.nn.SyncBatchNorm + return _BatchNorm, _InstanceNorm, SyncBatchNorm_ + + +_ConvNd, _ConvTransposeMixin = _get_conv() +DataLoader, PoolDataLoader = _get_dataloader() +BuildExtension, CppExtension, CUDAExtension = _get_extension() +_BatchNorm, _InstanceNorm, SyncBatchNorm_ = _get_norm() +_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool() diff --git a/tests/test_optim/test_optimizer/test_optimizer.py b/tests/test_optim/test_optimizer/test_optimizer.py new file mode 100644 index 00000000..e205c54f --- /dev/null +++ b/tests/test_optim/test_optimizer/test_optimizer.py @@ -0,0 +1,646 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import sys +from unittest import TestCase +from unittest.mock import MagicMock + +import torch +import torch.nn as nn + +from mmengine.optim import (OPTIMIZER_CONSTRUCTORS, OPTIMIZERS, + DefaultOptimizerConstructor, build_optimizer, + build_optimizer_constructor) +from mmengine.optim.optimizer.builder import TORCH_OPTIMIZERS +from mmengine.registry import build_from_cfg +from mmengine.utils import mmcv_full_available + +MMCV_FULL_AVAILABLE = mmcv_full_available() +if not MMCV_FULL_AVAILABLE: + sys.modules['mmcv.ops'] = MagicMock( + DeformConv2d=dict, ModulatedDeformConv2d=dict) + + +class ExampleModel(nn.Module): + + def __init__(self): + super().__init__() + self.param1 = nn.Parameter(torch.ones(1)) + self.conv1 = nn.Conv2d(3, 4, kernel_size=1, bias=False) + self.conv2 = nn.Conv2d(4, 2, kernel_size=1) + self.bn = nn.BatchNorm2d(2) + self.sub = SubModel() + if MMCV_FULL_AVAILABLE: + from mmcv.ops import DeformConv2dPack + self.dcn = DeformConv2dPack( + 3, 4, kernel_size=3, deformable_groups=1) + + +class ExampleDuplicateModel(nn.Module): + + def __init__(self): + super().__init__() + self.param1 = nn.Parameter(torch.ones(1)) + self.conv1 = nn.Sequential(nn.Conv2d(3, 4, kernel_size=1, bias=False)) + self.conv2 = nn.Sequential(nn.Conv2d(4, 2, kernel_size=1)) + self.bn = nn.BatchNorm2d(2) + self.sub = SubModel() + self.conv3 = nn.Sequential(nn.Conv2d(3, 4, kernel_size=1, bias=False)) + self.conv3[0] = self.conv1[0] + if MMCV_FULL_AVAILABLE: + from mmcv.ops import DeformConv2dPack + self.dcn = DeformConv2dPack( + 3, 4, kernel_size=3, deformable_groups=1) + + def forward(self, x): + return x + + +class SubModel(nn.Module): + + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(2, 2, kernel_size=1, groups=2) + self.gn = nn.GroupNorm(2, 2) + self.param1 = nn.Parameter(torch.ones(1)) + + def forward(self, x): + return x + + +class PseudoDataParallel(nn.Module): + + def __init__(self): + super().__init__() + self.module = ExampleModel() + + def forward(self, x): + return x + + +class TestBuilder(TestCase): + + def setUp(self): + self.model = ExampleModel() + self.base_lr = 0.01 + self.momentum = 0.0001 + self.base_wd = 0.9 + + def _check_default_optimizer(self, optimizer, model, prefix=''): + assert isinstance(optimizer, torch.optim.SGD) + assert optimizer.defaults['lr'] == self.base_lr + assert optimizer.defaults['momentum'] == self.momentum + assert optimizer.defaults['weight_decay'] == self.base_wd + param_groups = optimizer.param_groups[0] + if MMCV_FULL_AVAILABLE: + param_names = [ + 'param1', 'conv1.weight', 'conv2.weight', 'conv2.bias', + 'bn.weight', 'bn.bias', 'sub.param1', 'sub.conv1.weight', + 'sub.conv1.bias', 'sub.gn.weight', 'sub.gn.bias', 'dcn.weight', + 'dcn.conv_offset.weight', 'dcn.conv_offset.bias' + ] + else: + param_names = [ + 'param1', 'conv1.weight', 'conv2.weight', 'conv2.bias', + 'bn.weight', 'bn.bias', 'sub.param1', 'sub.conv1.weight', + 'sub.conv1.bias', 'sub.gn.weight', 'sub.gn.bias' + ] + param_dict = dict(model.named_parameters()) + assert len(param_groups['params']) == len(param_names) + for i in range(len(param_groups['params'])): + assert torch.equal(param_groups['params'][i], + param_dict[prefix + param_names[i]]) + + def _check_sgd_optimizer(self, + optimizer, + model, + prefix='', + bias_lr_mult=1, + bias_decay_mult=1, + norm_decay_mult=1, + dwconv_decay_mult=1, + dcn_offset_lr_mult=1, + bypass_duplicate=False): + param_groups = optimizer.param_groups + assert isinstance(optimizer, torch.optim.SGD) + assert optimizer.defaults['lr'] == self.base_lr + assert optimizer.defaults['momentum'] == self.momentum + assert optimizer.defaults['weight_decay'] == self.base_wd + model_parameters = list(model.parameters()) + assert len(param_groups) == len(model_parameters) + for i, param in enumerate(model_parameters): + param_group = param_groups[i] + assert torch.equal(param_group['params'][0], param) + assert param_group['momentum'] == self.momentum + + # param1 + param1 = param_groups[0] + assert param1['lr'] == self.base_lr + assert param1['weight_decay'] == self.base_wd + # conv1.weight + conv1_weight = param_groups[1] + assert conv1_weight['lr'] == self.base_lr + assert conv1_weight['weight_decay'] == self.base_wd + # conv2.weight + conv2_weight = param_groups[2] + assert conv2_weight['lr'] == self.base_lr + assert conv2_weight['weight_decay'] == self.base_wd + # conv2.bias + conv2_bias = param_groups[3] + assert conv2_bias['lr'] == self.base_lr * bias_lr_mult + assert conv2_bias['weight_decay'] == self.base_wd * bias_decay_mult + # bn.weight + bn_weight = param_groups[4] + assert bn_weight['lr'] == self.base_lr + assert bn_weight['weight_decay'] == self.base_wd * norm_decay_mult + # bn.bias + bn_bias = param_groups[5] + assert bn_bias['lr'] == self.base_lr + assert bn_bias['weight_decay'] == self.base_wd * norm_decay_mult + # sub.param1 + sub_param1 = param_groups[6] + assert sub_param1['lr'] == self.base_lr + assert sub_param1['weight_decay'] == self.base_wd + # sub.conv1.weight + sub_conv1_weight = param_groups[7] + assert sub_conv1_weight['lr'] == self.base_lr + assert sub_conv1_weight[ + 'weight_decay'] == self.base_wd * dwconv_decay_mult + # sub.conv1.bias + sub_conv1_bias = param_groups[8] + assert sub_conv1_bias['lr'] == self.base_lr * bias_lr_mult + assert sub_conv1_bias[ + 'weight_decay'] == self.base_wd * dwconv_decay_mult + # sub.gn.weight + sub_gn_weight = param_groups[9] + assert sub_gn_weight['lr'] == self.base_lr + assert sub_gn_weight['weight_decay'] == self.base_wd * norm_decay_mult + # sub.gn.bias + sub_gn_bias = param_groups[10] + assert sub_gn_bias['lr'] == self.base_lr + assert sub_gn_bias['weight_decay'] == self.base_wd * norm_decay_mult + + if torch.cuda.is_available(): + dcn_conv_weight = param_groups[11] + assert dcn_conv_weight['lr'] == self.base_lr + assert dcn_conv_weight['weight_decay'] == self.base_wd + + dcn_offset_weight = param_groups[12] + assert dcn_offset_weight['lr'] == self.base_lr * dcn_offset_lr_mult + assert dcn_offset_weight['weight_decay'] == self.base_wd + + dcn_offset_bias = param_groups[13] + assert dcn_offset_bias['lr'] == self.base_lr * dcn_offset_lr_mult + assert dcn_offset_bias['weight_decay'] == self.base_wd + + def test_torch_optimizers(self): + torch_optimizers = [ + 'ASGD', 'Adadelta', 'Adagrad', 'Adam', 'AdamW', 'Adamax', 'LBFGS', + 'Optimizer', 'RMSprop', 'Rprop', 'SGD', 'SparseAdam' + ] + assert set(torch_optimizers).issubset(set(TORCH_OPTIMIZERS)) + + def test_build_optimizer(self): + # test build function without ``constructor`` and ``paramwise_cfg`` + optimizer_cfg = dict( + type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum) + optimizer = build_optimizer(self.model, optimizer_cfg) + self._check_default_optimizer(optimizer, self.model) + + # test build function with invalid ``constructor`` + with self.assertRaises(KeyError): + optimizer_cfg['constructor'] = 'INVALID_CONSTRUCTOR' + build_optimizer(self.model, optimizer_cfg) + + # test build function with invalid ``paramwise_cfg`` + with self.assertRaises(KeyError): + optimizer_cfg['paramwise_cfg'] = dict(invalid_mult=1) + build_optimizer(self.model, optimizer_cfg) + + def test_build_default_optimizer_constructor(self): + optimizer_cfg = dict( + type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum) + paramwise_cfg = dict( + bias_lr_mult=2, + bias_decay_mult=0.5, + norm_decay_mult=0, + dwconv_decay_mult=0.1, + dcn_offset_lr_mult=0.1) + optim_constructor_cfg = dict( + type='DefaultOptimizerConstructor', + optimizer_cfg=optimizer_cfg, + paramwise_cfg=paramwise_cfg) + optim_constructor = build_optimizer_constructor(optim_constructor_cfg) + optimizer = optim_constructor(self.model) + self._check_sgd_optimizer(optimizer, self.model, **paramwise_cfg) + + def test_build_custom_optimizer_constructor(self): + optimizer_cfg = dict( + type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum) + + @OPTIMIZER_CONSTRUCTORS.register_module() + class MyOptimizerConstructor(DefaultOptimizerConstructor): + + def __call__(self, model): + if hasattr(model, 'module'): + model = model.module + + conv1_lr_mult = self.paramwise_cfg.get('conv1_lr_mult', 1.) + params = [] + + for name, param in model.named_parameters(): + param_group = {'params': [param]} + if name.startswith('conv1') and param.requires_grad: + param_group['lr'] = self.base_lr * conv1_lr_mult + params.append(param_group) + self.optimizer_cfg['params'] = params + + return build_from_cfg(self.optimizer_cfg, OPTIMIZERS) + + paramwise_cfg = dict(conv1_lr_mult=5) + optim_constructor_cfg = dict( + type='MyOptimizerConstructor', + optimizer_cfg=optimizer_cfg, + paramwise_cfg=paramwise_cfg) + optim_constructor = build_optimizer_constructor(optim_constructor_cfg) + optimizer = optim_constructor(self.model) + + param_groups = optimizer.param_groups + assert isinstance(optimizer, torch.optim.SGD) + assert optimizer.defaults['lr'] == self.base_lr + assert optimizer.defaults['momentum'] == self.momentum + assert optimizer.defaults['weight_decay'] == self.base_wd + for i, param in enumerate(self.model.parameters()): + param_group = param_groups[i] + assert torch.equal(param_group['params'][0], param) + assert param_group['momentum'] == self.momentum + # conv1.weight + assert param_groups[1][ + 'lr'] == self.base_lr * paramwise_cfg['conv1_lr_mult'] + assert param_groups[1]['weight_decay'] == self.base_wd + + def test_default_optimizer_constructor(self): + with self.assertRaises(TypeError): + # optimizer_cfg must be a dict + optimizer_cfg = [] + optim_constructor = DefaultOptimizerConstructor(optimizer_cfg) + optim_constructor(self.model) + + with self.assertRaises(TypeError): + # paramwise_cfg must be a dict or None + optimizer_cfg = dict(lr=0.0001) + paramwise_cfg = ['error'] + optim_constructor = DefaultOptimizerConstructor( + optimizer_cfg, paramwise_cfg) + optim_constructor(self.model) + + with self.assertRaises(ValueError): + # bias_decay_mult/norm_decay_mult is specified but weight_decay + # is None + optimizer_cfg = dict(lr=0.0001, weight_decay=None) + paramwise_cfg = dict(bias_decay_mult=1, norm_decay_mult=1) + optim_constructor = DefaultOptimizerConstructor( + optimizer_cfg, paramwise_cfg) + optim_constructor(self.model) + + # basic config with ExampleModel + optimizer_cfg = dict( + type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum) + optim_constructor = DefaultOptimizerConstructor(optimizer_cfg) + optimizer = optim_constructor(self.model) + self._check_default_optimizer(optimizer, self.model) + + def test_default_optimizer_constructor_with_model_wrapper(self): + # basic config with pseudo data parallel + model = PseudoDataParallel() + optimizer_cfg = dict( + type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum) + paramwise_cfg = None + optim_constructor = DefaultOptimizerConstructor(optimizer_cfg) + optimizer = optim_constructor(model) + self._check_default_optimizer(optimizer, model, prefix='module.') + + # paramwise_cfg with pseudo data parallel + model = PseudoDataParallel() + optimizer_cfg = dict( + type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum) + paramwise_cfg = dict( + bias_lr_mult=2, + bias_decay_mult=0.5, + norm_decay_mult=0, + dwconv_decay_mult=0.1, + dcn_offset_lr_mult=0.1) + optim_constructor = DefaultOptimizerConstructor( + optimizer_cfg, paramwise_cfg) + optimizer = optim_constructor(model) + self._check_sgd_optimizer( + optimizer, model, prefix='module.', **paramwise_cfg) + + # basic config with DataParallel + if torch.cuda.is_available(): + model = torch.nn.DataParallel(ExampleModel()) + optimizer_cfg = dict( + type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum) + paramwise_cfg = None + optim_constructor = DefaultOptimizerConstructor(optimizer_cfg) + optimizer = optim_constructor(model) + self._check_default_optimizer(optimizer, model, prefix='module.') + + # paramwise_cfg with DataParallel + if torch.cuda.is_available(): + model = torch.nn.DataParallel(self.model) + optimizer_cfg = dict( + type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum) + paramwise_cfg = dict( + bias_lr_mult=2, + bias_decay_mult=0.5, + norm_decay_mult=0, + dwconv_decay_mult=0.1, + dcn_offset_lr_mult=0.1) + optim_constructor = DefaultOptimizerConstructor( + optimizer_cfg, paramwise_cfg) + optimizer = optim_constructor(model) + self._check_sgd_optimizer( + optimizer, model, prefix='module.', **paramwise_cfg) + + def test_default_optimizer_constructor_with_empty_paramwise_cfg(self): + # Empty paramwise_cfg with ExampleModel + optimizer_cfg = dict( + type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum) + paramwise_cfg = dict() + optim_constructor = DefaultOptimizerConstructor( + optimizer_cfg, paramwise_cfg) + optimizer = optim_constructor(self.model) + self._check_default_optimizer(optimizer, self.model) + + # Empty paramwise_cfg with ExampleModel and no grad + model = ExampleModel() + for param in model.parameters(): + param.requires_grad = False + optimizer_cfg = dict( + type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum) + paramwise_cfg = dict() + optim_constructor = DefaultOptimizerConstructor(optimizer_cfg) + optimizer = optim_constructor(model) + self._check_default_optimizer(optimizer, model) + + def test_default_optimizer_constructor_with_paramwise_cfg(self): + # paramwise_cfg with ExampleModel + optimizer_cfg = dict( + type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum) + paramwise_cfg = dict( + bias_lr_mult=2, + bias_decay_mult=0.5, + norm_decay_mult=0, + dwconv_decay_mult=0.1, + dcn_offset_lr_mult=0.1) + optim_constructor = DefaultOptimizerConstructor( + optimizer_cfg, paramwise_cfg) + optimizer = optim_constructor(self.model) + self._check_sgd_optimizer(optimizer, self.model, **paramwise_cfg) + + def test_default_optimizer_constructor_no_grad(self): + # paramwise_cfg with ExampleModel and no grad + optimizer_cfg = dict( + type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum) + paramwise_cfg = dict( + bias_lr_mult=2, + bias_decay_mult=0.5, + norm_decay_mult=0, + dwconv_decay_mult=0.1, + dcn_offset_lr_mult=0.1) + + for param in self.model.parameters(): + param.requires_grad = False + optim_constructor = DefaultOptimizerConstructor( + optimizer_cfg, paramwise_cfg) + optimizer = optim_constructor(self.model) + param_groups = optimizer.param_groups + assert isinstance(optimizer, torch.optim.SGD) + assert optimizer.defaults['lr'] == self.base_lr + assert optimizer.defaults['momentum'] == self.momentum + assert optimizer.defaults['weight_decay'] == self.base_wd + for i, (name, param) in enumerate(self.model.named_parameters()): + param_group = param_groups[i] + assert torch.equal(param_group['params'][0], param) + assert param_group['momentum'] == self.momentum + assert param_group['lr'] == self.base_lr + assert param_group['weight_decay'] == self.base_wd + + def test_default_optimizer_constructor_bypass_duplicate(self): + # paramwise_cfg with bypass_duplicate option + model = ExampleDuplicateModel() + optimizer_cfg = dict( + type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum) + paramwise_cfg = dict( + bias_lr_mult=2, + bias_decay_mult=0.5, + norm_decay_mult=0, + dwconv_decay_mult=0.1) + + with self.assertRaisesRegex( + ValueError, + 'some parameters appear in more than one parameter group'): + optim_constructor = DefaultOptimizerConstructor( + optimizer_cfg, paramwise_cfg) + optim_constructor(model) + + paramwise_cfg = dict( + bias_lr_mult=2, + bias_decay_mult=0.5, + norm_decay_mult=0, + dwconv_decay_mult=0.1, + dcn_offset_lr_mult=0.1, + bypass_duplicate=True) + optim_constructor = DefaultOptimizerConstructor( + optimizer_cfg, paramwise_cfg) + + self.assertWarnsRegex( + Warning, + 'conv3.0 is duplicate. It is skipped since bypass_duplicate=True', + lambda: optim_constructor(model)) + optimizer = optim_constructor(model) + model_parameters = list(model.parameters()) + num_params = 14 if MMCV_FULL_AVAILABLE else 11 + assert len( + optimizer.param_groups) == len(model_parameters) == num_params + self._check_sgd_optimizer(optimizer, model, **paramwise_cfg) + + def test_default_optimizer_constructor_custom_key(self): + # test DefaultOptimizerConstructor with custom_keys and ExampleModel + optimizer_cfg = dict( + type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum) + paramwise_cfg = dict( + custom_keys={ + 'param1': dict(lr_mult=10), + 'sub': dict(lr_mult=0.1, decay_mult=0), + 'sub.gn': dict(lr_mult=0.01), + 'non_exist_key': dict(lr_mult=0.0) + }, + norm_decay_mult=0.5) + + with self.assertRaises(TypeError): + # custom_keys should be a dict + paramwise_cfg_ = dict(custom_keys=[0.1, 0.0001]) + optim_constructor = DefaultOptimizerConstructor( + optimizer_cfg, paramwise_cfg_) + optimizer = optim_constructor(self.model) + + with self.assertRaises(ValueError): + # if 'decay_mult' is specified in custom_keys, weight_decay + # should be specified + optimizer_cfg_ = dict(type='SGD', lr=0.01) + paramwise_cfg_ = dict( + custom_keys={'.backbone': dict(decay_mult=0.5)}) + optim_constructor = DefaultOptimizerConstructor( + optimizer_cfg_, paramwise_cfg_) + optimizer = optim_constructor(self.model) + + optim_constructor = DefaultOptimizerConstructor( + optimizer_cfg, paramwise_cfg) + optimizer = optim_constructor(self.model) + # check optimizer type and default config + assert isinstance(optimizer, torch.optim.SGD) + assert optimizer.defaults['lr'] == self.base_lr + assert optimizer.defaults['momentum'] == self.momentum + assert optimizer.defaults['weight_decay'] == self.base_wd + + # check params groups + param_groups = optimizer.param_groups + + groups = [] + group_settings = [] + # group 1, matches of 'param1' + # 'param1' is the longest match for 'sub.param1' + groups.append(['param1', 'sub.param1']) + group_settings.append({ + 'lr': self.base_lr * 10, + 'momentum': self.momentum, + 'weight_decay': self.base_wd, + }) + # group 2, matches of 'sub.gn' + groups.append(['sub.gn.weight', 'sub.gn.bias']) + group_settings.append({ + 'lr': self.base_lr * 0.01, + 'momentum': self.momentum, + 'weight_decay': self.base_wd, + }) + # group 3, matches of 'sub' + groups.append(['sub.conv1.weight', 'sub.conv1.bias']) + group_settings.append({ + 'lr': self.base_lr * 0.1, + 'momentum': self.momentum, + 'weight_decay': 0, + }) + # group 4, bn is configured by 'norm_decay_mult' + groups.append(['bn.weight', 'bn.bias']) + group_settings.append({ + 'lr': self.base_lr, + 'momentum': self.momentum, + 'weight_decay': self.base_wd * 0.5, + }) + # group 5, default group + groups.append(['conv1.weight', 'conv2.weight', 'conv2.bias']) + group_settings.append({ + 'lr': self.base_lr, + 'momentum': self.momentum, + 'weight_decay': self.base_wd + }) + + num_params = 14 if MMCV_FULL_AVAILABLE else 11 + assert len(param_groups) == num_params + for i, (name, param) in enumerate(self.model.named_parameters()): + assert torch.equal(param_groups[i]['params'][0], param) + for group, settings in zip(groups, group_settings): + if name in group: + for setting in settings: + assert param_groups[i][setting] == settings[ + setting], f'{name} {setting}' + + # test DefaultOptimizerConstructor with custom_keys and ExampleModel 2 + optimizer_cfg = dict( + type='SGD', lr=self.base_lr, momentum=self.momentum) + paramwise_cfg = dict(custom_keys={'param1': dict(lr_mult=10)}) + + optim_constructor = DefaultOptimizerConstructor( + optimizer_cfg, paramwise_cfg) + optimizer = optim_constructor(self.model) + # check optimizer type and default config + assert isinstance(optimizer, torch.optim.SGD) + assert optimizer.defaults['lr'] == self.base_lr + assert optimizer.defaults['momentum'] == self.momentum + assert optimizer.defaults['weight_decay'] == 0 + + # check params groups + param_groups = optimizer.param_groups + + groups = [] + group_settings = [] + # group 1, matches of 'param1' + groups.append(['param1', 'sub.param1']) + group_settings.append({ + 'lr': self.base_lr * 10, + 'momentum': self.momentum, + 'weight_decay': 0, + }) + # group 2, default group + groups.append([ + 'sub.conv1.weight', 'sub.conv1.bias', 'sub.gn.weight', + 'sub.gn.bias', 'conv1.weight', 'conv2.weight', 'conv2.bias', + 'bn.weight', 'bn.bias' + ]) + group_settings.append({ + 'lr': self.base_lr, + 'momentum': self.momentum, + 'weight_decay': 0 + }) + + num_params = 14 if MMCV_FULL_AVAILABLE else 11 + assert len(param_groups) == num_params + for i, (name, param) in enumerate(self.model.named_parameters()): + assert torch.equal(param_groups[i]['params'][0], param) + for group, settings in zip(groups, group_settings): + if name in group: + for setting in settings: + assert param_groups[i][setting] == settings[ + setting], f'{name} {setting}'