[Feature] Add optimizer wrapper (#265)

* Support multiple optimizers

* minor refinement

* improve unit tests

* minor fix

* Update unit tests for resuming or saving ckpt for multiple optimizers

* refine docstring

* refine docstring

* fix typo

* update docstring

* refactor the logic to build multiple optimizers

* resolve comments

* ParamSchedulers support multiple optimizers

* add optimizer_wrapper

* fix comment and docstring

* fix unit test

* add unit test

* refine docstring

* RuntimeInfoHook supports printing multiple learning rates

* resolve comments

* add optimizer_wrapper

* fix mypy

* fix lint

* fix OptimizerWrapperDict docstring and add unit test

* rename OptimizerWrapper to OptimWrapper, make OptimWrapperDict inherit from OptimWrapper, and fix as per review comments

* Fix AmpOptimizerWrapper

* rename build_optimizer_wrapper to build_optim_wrapper

* refine optimizer wrapper

* fix AmpOptimWrapper.step, docstring

* resolve conflict

* rename DefaultOptimConstructor

* fix as per review comments

* rename clip grad arguments

* refactor optim_wrapper config

* fix docstring of DefaultOptimWrapperConstructor

* add get_lr method to OptimWrapper and OptimWrapperDict

* skip some amp unit test

* fix unit test

* fix get_lr, get_momentum docstring

* refactor get_lr, get_momentum, and fix as per review comments

* fix error message

Co-authored-by: zhouzaida <zhouzaida@163.com>
Mashiro 2022-06-01 18:04:38 +08:00 committed by GitHub
parent 987e5b83f9
commit 3e3866c1b9
25 changed files with 1858 additions and 487 deletions

View File

@ -83,7 +83,7 @@ class OptimizerHook(Hook):
In order to keep this interface consistent with other hooks,
we keep ``outputs`` here. Defaults to None.
"""
runner.optimizer.zero_grad()
runner.optim_wrapper.zero_grad()
if self.detect_anomalous_params:
self.detect_anomalous_parameters(runner.outputs['loss'], runner)
runner.outputs['loss'].backward()
@ -94,7 +94,7 @@ class OptimizerHook(Hook):
# Add grad norm to the logger
runner.message_hub.update_scalar('train/grad_norm',
float(grad_norm))
runner.optimizer.step()
runner.optim_wrapper.step()
def detect_anomalous_parameters(self, loss: torch.Tensor, runner) -> None:
"""Detect anomalous parameters that are not included in the graph.

View File

@ -41,13 +41,16 @@ class RuntimeInfoHook(Hook):
"""Update current iter and learning rate information before every
iteration."""
runner.message_hub.update_info('iter', runner.iter)
if isinstance(runner.optimizer, dict):
for name, optimizer in runner.optimizer.items():
runner.message_hub.update_scalar(
f'train/{name}.lr', optimizer.param_groups[0]['lr'])
else:
runner.message_hub.update_scalar(
'train/lr', runner.optimizer.param_groups[0]['lr'])
lr_dict = runner.optim_wrapper.get_lr()
assert isinstance(lr_dict, dict), (
'`runner.optim_wrapper.get_lr()` should return a dict '
'of learning rates when training with OptimWrapper (single '
'optimizer) or OptimWrapperDict (multiple optimizers), '
f'but got {type(lr_dict)}. Please check that your optimizer '
'constructor returns an `OptimWrapper` or `OptimWrapperDict` '
'instance.')
for name, lr in lr_dict.items():
runner.message_hub.update_scalar(f'train/{name}', lr[0])
def after_train_iter(self,
runner,

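As a hedged illustration of the new logging path (not part of the diff, and assuming a default MMLogger/MessageHub can be created outside a full Runner), `get_lr()` returns a plain dict whose keys become the scalar names logged above:

import torch.nn as nn
from torch.optim import SGD
from mmengine.optim import OptimWrapper, OptimWrapperDict

# Single optimizer: the key is logged as 'train/lr'.
wrapper = OptimWrapper(SGD(nn.Linear(1, 1).parameters(), lr=0.01))
print(wrapper.get_lr())  # {'lr': [0.01]}

# Multiple optimizers: keys are logged as 'train/generator.lr', etc.
wrappers = OptimWrapperDict(
    generator=OptimWrapper(SGD(nn.Linear(1, 1).parameters(), lr=0.01)),
    discriminator=OptimWrapper(SGD(nn.Linear(1, 1).parameters(), lr=0.02)))
print(wrappers.get_lr())  # {'generator.lr': [0.01], 'discriminator.lr': [0.02]}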
View File

@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .optimizer import (OPTIMIZER_CONSTRUCTORS, OPTIMIZERS,
DefaultOptimizerConstructor, build_optimizer)
from .optimizer import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
AmpOptimWrapper, DefaultOptimWrapperConstructor,
OptimWrapper, OptimWrapperDict, build_optim_wrapper)
from .scheduler import (ConstantLR, ConstantMomentum, ConstantParamScheduler,
CosineAnnealingLR, CosineAnnealingMomentum,
CosineAnnealingParamScheduler, ExponentialLR,
@ -11,12 +12,12 @@ from .scheduler import (ConstantLR, ConstantMomentum, ConstantParamScheduler,
StepParamScheduler, _ParamScheduler)
__all__ = [
'OPTIMIZER_CONSTRUCTORS', 'OPTIMIZERS', 'build_optimizer',
'DefaultOptimizerConstructor', 'ConstantLR', 'CosineAnnealingLR',
'OPTIM_WRAPPER_CONSTRUCTORS', 'OPTIMIZERS', 'build_optim_wrapper',
'DefaultOptimWrapperConstructor', 'ConstantLR', 'CosineAnnealingLR',
'ExponentialLR', 'LinearLR', 'MultiStepLR', 'StepLR', 'ConstantMomentum',
'CosineAnnealingMomentum', 'ExponentialMomentum', 'LinearMomentum',
'MultiStepMomentum', 'StepMomentum', 'ConstantParamScheduler',
'CosineAnnealingParamScheduler', 'ExponentialParamScheduler',
'LinearParamScheduler', 'MultiStepParamScheduler', 'StepParamScheduler',
'_ParamScheduler'
'_ParamScheduler', 'OptimWrapper', 'AmpOptimWrapper', 'OptimWrapperDict'
]

View File

@ -1,8 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import OPTIMIZER_CONSTRUCTORS, OPTIMIZERS, build_optimizer
from .default_constructor import DefaultOptimizerConstructor
from .amp_optimizer_wrapper import AmpOptimWrapper
from .builder import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
build_optim_wrapper)
from .default_constructor import DefaultOptimWrapperConstructor
from .optimizer_wrapper import OptimWrapper
from .optimizer_wrapper_dict import OptimWrapperDict
__all__ = [
'OPTIMIZER_CONSTRUCTORS', 'OPTIMIZERS', 'DefaultOptimizerConstructor',
'build_optimizer'
'OPTIM_WRAPPER_CONSTRUCTORS', 'OPTIMIZERS',
'DefaultOptimWrapperConstructor', 'build_optim_wrapper', 'OptimWrapper',
'AmpOptimWrapper', 'OptimWrapperDict'
]

View File

@ -0,0 +1,110 @@
# Copyright (c) OpenMMLab. All rights reserved.
from contextlib import contextmanager
import torch
from torch.cuda.amp import GradScaler
from mmengine.registry import OPTIM_WRAPPERS
from mmengine.utils import TORCH_VERSION, digit_version
from .optimizer_wrapper import OptimWrapper
@OPTIM_WRAPPERS.register_module()
class AmpOptimWrapper(OptimWrapper):
"""A subclass of :class:`OptimWrapper` that supports automatic mixed
precision training based on torch.cuda.amp.
``AmpOptimWrapper`` provides a unified interface with
``OptimWrapper``, so ``AmpOptimWrapper`` can be used in the same way
as ``OptimWrapper``.
Warnings:
``AmpOptimWrapper`` requires PyTorch >= 1.6.
Args:
loss_scale (float or str or dict): The initial configuration of
`torch.cuda.amp.GradScaler`. See more specific arguments
introduction at `PyTorch AMP <https://pytorch.org/docs/stable/amp.html?highlight=gradscalertorch.cuda.amp.GradScaler>`_ # noqa: E501
- "dynamic": Initialize GradScale without any arguments.
- float: Initialize GradScaler with ``init_scale``.
- dict: Initialize GradScaler with more detail configuration.
**kwargs: Keyword arguments passed to OptimWrapper.
"""
def __init__(self, loss_scale=512., **kwargs):
assert digit_version(TORCH_VERSION) >= digit_version('1.6.0'), (
'`torch.cuda.amp` is only available when pytorch version >= 1.6')
assert torch.cuda.is_available(), (
'``AmpOptimWrapper`` is only available when training on GPU')
super().__init__(**kwargs)
self._scale_update_param = None
if loss_scale == 'dynamic':
# If loss_scale is a string, it must be 'dynamic', then dynamic
# loss scaling will be used.
self.loss_scaler = GradScaler()
elif isinstance(loss_scale, float):
# Static loss scaling
self._scale_update_param = loss_scale
self.loss_scaler = GradScaler(init_scale=loss_scale)
elif isinstance(loss_scale, dict):
# More specific configuration.
self.loss_scaler = GradScaler(**loss_scale)
else:
raise TypeError('loss_scale must be of type float, dict, or '
f'"dynamic", but got {loss_scale}')
def backward(self, loss: torch.Tensor):
"""Perform gradient back propagation with :attr:`loss_scaler`.
Args:
loss (torch.Tensor): The loss of current iteration.
"""
self.loss_scaler.scale(loss).backward()
def step(self):
"""Update parameters with :attr:`loss_scaler`."""
if self.clip_grad_kwargs:
self.loss_scaler.unscale_(self.optimizer)
self._clip_grad()
self.loss_scaler.step(self.optimizer)
self.loss_scaler.update(self._scale_update_param)
def state_dict(self) -> dict:
"""Get the state dictionary of :attr:`optimizer` and
:attr:`loss_scaler`.
Based on the state dictionary of the optimizer, the returned state
dictionary will add a key named "loss_scaler".
Returns:
dict: The merged state dict of :attr:`loss_scaler` and
:attr:`optimizer`.
"""
# save state_dict of loss_scaler
state_dict = self.optimizer.state_dict()
state_dict['loss_scaler'] = self.loss_scaler.state_dict()
return state_dict
def load_state_dict(self, state_dict: dict):
"""Load and parse the state dictionary of :attr:`optimizer` and
:attr:`loss_scaler`.
If state_dict contains "loss_scaler.", the :attr:`loss_scaler` will
load the corresponding keys. Otherwise, only the :attr:`optimizer`
will load the state dictionary.
Args:
state_dict (dict): The state dict of :attr:`optimizer` and
:attr:`loss_scaler`
"""
if 'loss_scaler' in state_dict:
self.loss_scaler.load_state_dict(state_dict.pop('loss_scaler'))
self.optimizer.load_state_dict(state_dict)
@contextmanager
def precision_context(self):
"""A wrapper of ``torch.cuda.amp.autocast``"""
with torch.cuda.amp.autocast():
yield
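A hedged config sketch (new code, not part of the diff) of how ``AmpOptimWrapper`` is meant to be selected through ``build_optim_wrapper``; it assumes PyTorch >= 1.6 and an available GPU, as asserted above:

import torch.nn as nn
from mmengine.optim import build_optim_wrapper

model = nn.Linear(1, 1).cuda()
optim_wrapper_cfg = dict(
    type='AmpOptimWrapper',
    # 'dynamic', a float such as 512., or a dict of GradScaler arguments.
    loss_scale='dynamic',
    optimizer=dict(type='SGD', lr=0.01, momentum=0.9))
amp_optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)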

View File

@ -1,12 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import inspect
from typing import List
from typing import List, Union
import torch
import torch.nn as nn
from mmengine.registry import OPTIMIZER_CONSTRUCTORS, OPTIMIZERS
from mmengine.config import Config, ConfigDict
from mmengine.registry import OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS
from .optimizer_wrapper import OptimWrapper
def register_torch_optimizers() -> List[str]:
@ -30,31 +32,31 @@ def register_torch_optimizers() -> List[str]:
TORCH_OPTIMIZERS = register_torch_optimizers()
def build_optimizer(model: nn.Module, cfg: dict) -> torch.optim.Optimizer:
"""Build function of optimizer.
def build_optim_wrapper(model: nn.Module,
cfg: Union[dict, Config, ConfigDict]) -> OptimWrapper:
"""Build function of OptimWrapper.
If ``constructor`` is set in the ``cfg``, this method will build an
optimizer constructor, and use optimizer constructor to build the
optimizer. If ``constructor`` is not set, the
``DefaultOptimizerConstructor`` will be used by default.
optimizer wrapper constructor, and use optimizer wrapper constructor to
build the optimizer wrapper. If ``constructor`` is not set, the
``DefaultOptimWrapperConstructor`` will be used by default.
Args:
model (nn.Module): Model to be optimized.
cfg (dict): Config of optimizer and optimizer constructor.
default_scope (str, optional): The ``default_scope`` is used to
reset the current registry. Defaults to None.
cfg (dict): Config of optimizer wrapper, optimizer constructor and
optimizer.
Returns:
torch.optim.Optimizer: The built optimizer.
OptimWrapper: The built optimizer wrapper.
"""
optimizer_cfg = copy.deepcopy(cfg)
constructor_type = optimizer_cfg.pop('constructor',
'DefaultOptimizerConstructor')
paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None)
optim_constructor = OPTIMIZER_CONSTRUCTORS.build(
optim_wrapper_cfg = copy.deepcopy(cfg)
constructor_type = optim_wrapper_cfg.pop('constructor',
'DefaultOptimWrapperConstructor')
paramwise_cfg = optim_wrapper_cfg.pop('paramwise_cfg', None)
optim_wrapper_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build(
dict(
type=constructor_type,
optimizer_cfg=optimizer_cfg,
optim_wrapper_cfg=optim_wrapper_cfg,
paramwise_cfg=paramwise_cfg))
optimizer = optim_constructor(model)
return optimizer
optim_wrapper = optim_wrapper_constructor(model)
return optim_wrapper
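A minimal usage sketch (new code, not part of the diff, and assuming a default logger/message hub can be created outside a Runner) of the config layout this function expects: wrapper-level options at the top level and the optimizer config under the ``optimizer`` key. ``type`` falls back to ``OptimWrapper`` and ``constructor`` to ``DefaultOptimWrapperConstructor`` when omitted:

import torch.nn as nn
from mmengine.optim import build_optim_wrapper

model = nn.Linear(1, 1)
optim_wrapper = build_optim_wrapper(
    model,
    dict(
        type='OptimWrapper',
        optimizer=dict(type='SGD', lr=0.01, momentum=0.9),
        clip_grad=dict(max_norm=1.0)))
print(optim_wrapper)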

View File

@ -6,17 +6,19 @@ import torch
import torch.nn as nn
from torch.nn import GroupNorm, LayerNorm
from mmengine.logging.logger import print_log
from mmengine.registry import OPTIMIZER_CONSTRUCTORS, OPTIMIZERS
from mmengine.logging import print_log
from mmengine.registry import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS,
OPTIMIZERS)
from mmengine.utils import is_list_of, mmcv_full_available
from mmengine.utils.parrots_wrapper import _BatchNorm, _InstanceNorm
from .optimizer_wrapper import OptimWrapper
@OPTIMIZER_CONSTRUCTORS.register_module()
class DefaultOptimizerConstructor:
@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class DefaultOptimWrapperConstructor:
"""Default constructor for optimizers.
By default each parameter share the same optimizer settings, and we
By default, each parameter share the same optimizer settings, and we
provide an argument ``paramwise_cfg`` to specify parameter-wise settings.
It is a dict and may contain the following fields:
@ -62,49 +64,65 @@ class DefaultOptimizerConstructor:
model contains multiple DCN layers in places other than backbone.
Args:
optimizer_cfg (dict): The config dict of the optimizer.
optim_wrapper_cfg (dict): The config dict of the optimizer wrapper.
Positional fields are
- ``type``: class name of the optimizer wrapper.
- ``optimizer``: The configuration of optimizer.
Optional fields are
- any arguments of the corresponding optimizer wrapper type,
e.g., accumulative_iters, clip_grad, etc.
The positional fields of ``optimizer`` are
- `type`: class name of the optimizer.
Optional fields are
- any arguments of the corresponding optimizer type, e.g.,
lr, weight_decay, momentum, etc.
paramwise_cfg (dict, optional): Parameter-wise options.
Example 1:
>>> model = torch.nn.modules.Conv1d(1, 1, 1)
>>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9,
>>> weight_decay=0.0001)
>>> optim_wrapper_cfg = dict(type='OptimWrapper', optimizer=dict(
>>> type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001))
>>> paramwise_cfg = dict(norm_decay_mult=0.)
>>> optim_builder = DefaultOptimizerConstructor(
>>> optimizer_cfg, paramwise_cfg)
>>> optimizer = optim_builder(model)
>>> optim_wrapper_builder = DefaultOptimWrapperConstructor(
>>> optim_wrapper_cfg, paramwise_cfg)
>>> optim_wrapper = optim_wrapper_builder(model)
Example 2:
>>> # assume model have attribute model.backbone and model.cls_head
>>> optimizer_cfg = dict(type='SGD', lr=0.01, weight_decay=0.95)
>>> optim_wrapper_cfg = dict(type=OptimWrapper, optimizer=dict(
>>> type='SGD', lr=0.01, weight_decay=0.95))
>>> paramwise_cfg = dict(custom_keys={
'.backbone': dict(lr_mult=0.1, decay_mult=0.9)})
>>> optim_builder = DefaultOptimizerConstructor(
>>> optimizer_cfg, paramwise_cfg)
>>> optimizer = optim_builder(model)
>>> '.backbone': dict(lr_mult=0.1, decay_mult=0.9)})
>>> optim_wrapper_builder = DefaultOptimWrapperConstructor(
>>> optim_wrapper_cfg, paramwise_cfg)
>>> optim_wrapper = optim_wrapper_builder(model)
>>> # Then the `lr` and `weight_decay` for model.backbone is
>>> # (0.01 * 0.1, 0.95 * 0.9). `lr` and `weight_decay` for
>>> # model.cls_head is (0.01, 0.95).
"""
def __init__(self,
optimizer_cfg: dict,
optim_wrapper_cfg: dict,
paramwise_cfg: Optional[dict] = None):
if not isinstance(optimizer_cfg, dict):
if not isinstance(optim_wrapper_cfg, dict):
raise TypeError('optimizer_cfg should be a dict',
f'but got {type(optimizer_cfg)}')
self.optimizer_cfg = optimizer_cfg
f'but got {type(optim_wrapper_cfg)}')
assert 'optimizer' in optim_wrapper_cfg, (
'`optim_wrapper_cfg` must contain "optimizer" config')
self.optim_wrapper_cfg = optim_wrapper_cfg.copy()
self.optimizer_cfg = self.optim_wrapper_cfg.pop('optimizer')
self.paramwise_cfg = {} if paramwise_cfg is None else paramwise_cfg
self.base_lr = optimizer_cfg.get('lr', None)
self.base_wd = optimizer_cfg.get('weight_decay', None)
self.base_lr = self.optimizer_cfg.get('lr', None)
self.base_wd = self.optimizer_cfg.get('weight_decay', None)
self._validate_cfg()
def _validate_cfg(self) -> None:
@ -249,19 +267,23 @@ class DefaultOptimizerConstructor:
prefix=child_prefix,
is_dcn_module=is_dcn_module)
def __call__(self, model: nn.Module) -> torch.optim.Optimizer:
def __call__(self, model: nn.Module) -> OptimWrapper:
if hasattr(model, 'module'):
model = model.module
optim_wrapper_cfg = self.optim_wrapper_cfg.copy()
optim_wrapper_cfg.setdefault('type', 'OptimWrapper')
optimizer_cfg = self.optimizer_cfg.copy()
# if no paramwise option is specified, just use the global setting
if not self.paramwise_cfg:
optimizer_cfg['params'] = model.parameters()
return OPTIMIZERS.build(optimizer_cfg)
optimizer = OPTIMIZERS.build(optimizer_cfg)
else:
# set param-wise lr and weight decay recursively
params: List = []
self.add_params(params, model)
optimizer_cfg['params'] = params
return OPTIMIZERS.build(optimizer_cfg)
optimizer = OPTIMIZERS.build(optimizer_cfg)
optim_wrapper = OPTIM_WRAPPERS.build(
optim_wrapper_cfg, default_args=dict(optimizer=optimizer))
return optim_wrapper

View File

@ -0,0 +1,349 @@
# Copyright (c) OpenMMLab. All rights reserved.
from contextlib import contextmanager
from typing import Dict, List, Optional
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad
from torch.optim import Optimizer
from mmengine.logging import MessageHub, MMLogger
from mmengine.registry import OPTIM_WRAPPERS
from mmengine.utils import has_batch_norm
@OPTIM_WRAPPERS.register_module()
class OptimWrapper:
"""Optimizer wrapper provides a common interface for updating parameters.
Optimizer wrapper provides a unified interface for single precision
training and automatic mixed precision training with different hardware.
OptimWrapper encapsulates optimizer to provide simplified interfaces
for commonly used training techniques such as gradient accumulative and
grad clips. ``OptimWrapper`` implements the basic logic of gradient
accumulation and gradient clipping based on ``torch.optim.Optimizer``.
The subclasses only need to override some methods to implement the mixed
precision training. See more information in :class:`AmpOptimWrapper`.
Args:
optimizer (Optimizer): Optimizer used to update model parameters.
accumulative_iters (int): The number of iterations to accumulate
gradients. The parameters will be updated per
``accumulative_iters``.
clip_grad (dict, optional): If ``clip_grad`` is not None, it will be
the arguments of ``torch.nn.utils.clip_grad``.
Warnings:
If ``accumulative_iters`` is larger than 1, :meth:`update_params` must
be called in the context of ``accumulate_grad``.
Examples:
>>> # Config sample of OptimWrapper.
>>> optim_wrapper_cfg = dict(
>>> type='OptimWrapper',
>>> accumulative_iters=3,
>>> clip_grad=dict(max_norm=0.2))
>>> # Use OptimWrapper to update model.
>>> import torch.nn as nn
>>> import torch
>>> from torch.optim import SGD
>>> from torch.utils.data import DataLoader
>>> from mmengine.optim import OptimWrapper
>>>
>>> model = nn.Linear(1, 1)
>>> dataset = torch.randn(10, 1, 1)
>>> dataloader = DataLoader(dataset)
>>> optimizer = SGD(model.parameters(), lr=0.1)
>>> optim_wrapper = OptimWrapper(optimizer)
>>>
>>> for data in dataloader:
>>> loss = model(data)
>>> optim_wrapper.update_params(loss)
>>> # Enable gradient accumulation. If model is a subclass instance of
>>> # DistributedDataParallel, ``accumulate_grad`` context manager can
>>> # avoid unnecessary gradient synchronize.
>>> for iter, data in enumerate(dataloader):
>>> with optim_wrapper.accumulate_grad(
>>> model, iter, len(dataloader)):
>>> loss = model(data)
>>> optim_wrapper.update_params(loss)
"""
def __init__(self,
optimizer: Optimizer,
accumulative_iters: int = 1,
clip_grad: Optional[dict] = None):
assert accumulative_iters > 0, (
'accumulative_iters should be greater than or equal to 1')
self.accumulative_iters = accumulative_iters
# `max_iters` and `cur_iter` is only valid in gradient accumulative
# mode (`accumulative_iters` > 1). `cur_iter` and `max_iter` will be
# updated in the ``accumulate_grad`` context that is enabled in
# `runner.train_loop`.
self.cur_iter = 0
self.max_iters = 0
assert isinstance(optimizer, Optimizer), (
'optimizer must be a `torch.optim.Optimizer` instance, but got '
f'{type(optimizer)}')
self.optimizer = optimizer
if clip_grad is not None:
# clip_grad should be a non-empty dict.
assert isinstance(clip_grad, dict) and clip_grad, (
'If `clip_grad` is not None, it should be a non-empty `dict` '
'whose keys are the arguments of `torch.nn.utils.clip_grad`')
self.clip_grad_kwargs = clip_grad
self.logger = MMLogger.get_current_instance()
# Used to update `grad_norm` log message.
self.message_hub = MessageHub.get_current_instance()
self.iter_status_initialized = False
def update_params(self, loss: torch.Tensor) -> None:
"""Update parameters in :attr:`optimizer`.
Args:
loss (torch.Tensor): A tensor for back propagation.
"""
if self.accumulative_iters == 1:
# update parameters without gradient accumulation. The gradient
# should not be rescaled and `loss_factor=1`.
loss_factor = 1
else:
# gradient accumulation must be called in the context of
# ``accumulate_grad``.
assert hasattr(self, 'divisible_iters'), (
'gradient accumulation must be performed in the context of'
'`OptimWrapper.accumulate_grad`')
# if `self.accumulative_iters > 1`, the gradient needs to be
# rescaled and accumulated. In most cases, `loss_factor` equals to
# `self.accumulative_iters`. However, `self.max_iters` may not be
# divisible by `self.accumulative_iters`, so the `loss_factor` for
# the last few iterations needs to be recalculated.
if self.cur_iter < self.divisible_iters:
loss_factor = self.accumulative_iters
else:
loss_factor = self.remainder_iters
assert loss_factor > 0, (
'loss_factor should be larger than zero! This error could '
'be caused by enabling the gradient accumulation context with a '
'wrong `cur_iter` or `max_iters`, please check your loop.')
loss = loss / loss_factor
self.backward(loss)
# Update parameters only if `self.cur_iter` is divisible by
# `self.accumulative_iters` or `self.cur_iter` equals to
# `self.max_iters`
if self._should_update(self.cur_iter, self.max_iters):
self.step()
self.zero_grad()
def backward(self, loss: torch.Tensor) -> None:
"""Perform gradient back propagation.
Provide unified ``backward`` interface compatible with automatic mixed
precision training. Subclass can overload this method to implement the
required logic. For example, ``torch.cuda.amp`` require some extra
operation on GradScaler during backward process.
Args:
loss (torch.Tensor): The loss of current iteration.
"""
loss.backward()
def zero_grad(self) -> None:
"""A wrapper of ``Optimizer.zero_grad``.
Provide unified ``zero_grad`` interface compatible with automatic mixed
precision training. Subclass can overload this method to implement the
required logic.
"""
self.optimizer.zero_grad()
def step(self) -> None:
"""A wrapper of ``Optimizer.step``.
Provide unified ``step`` interface compatible with automatic mixed
precision training. Subclass can overload this method to implement the
required logic. For example, ``torch.cuda.amp`` require some extra
operation on ``GradScaler`` during step process.
Clip grad if :attr:`clip_grad_kwargs` is not None, and then update
parameters.
"""
if self.clip_grad_kwargs:
self._clip_grad()
self.optimizer.step()
def state_dict(self) -> dict:
"""A wrapper of ``Optimizer.state_dict``.
Provide unified ``state_dict`` interface compatible with automatic
mixed precision training. Subclass can overload this method to
implement the required logic. For example, the state dictionary of
GradScaler should be saved when training with ``torch.cuda.amp``.
Returns:
dict: The state dictionary of :attr:`optimizer`.
"""
return self.optimizer.state_dict()
def load_state_dict(self, state_dict: dict) -> None:
"""A wrapper of ``Optimizer.load_state_dict``. load the state dict of
:attr:`optimizer`.
Provide unified ``load_state_dict`` interface compatible with automatic
mixed precision training. Subclass can overload this method to
implement the required logic. For example, the state dictionary of
GradScaler should be loaded when training with ``torch.cuda.amp``.
Args:
state_dict (dict): The state dictionary of :attr:`optimizer`.
"""
self.optimizer.load_state_dict(state_dict)
@property
def param_groups(self) -> List[dict]:
"""A wrapper of ``Optimizer.param_groups``.
Make OptimWrapper compatible with :class:`_ParamScheduler`.
Returns:
dict: the ``param_groups`` of :attr:`optimizer`.
"""
return self.optimizer.param_groups
def get_lr(self) -> Dict[str, List[float]]:
"""Get the learning rate of the optimizer.
Provide unified interface to get learning rate of optimizer.
Returns:
Dict[str, List[float]]: Learning rate of the optimizer.
"""
lr = [group['lr'] for group in self.param_groups]
return dict(lr=lr)
def get_momentum(self) -> Dict[str, List[float]]:
"""Get the momentum of the optimizer.
Provide unified interface to get momentum of optimizer.
Returns:
Dict[str, List[float]]: Momentum of the optimizer.
"""
momentum = []
for group in self.param_groups:
# Get momentum of SGD.
if 'momentum' in group.keys():
momentum.append(group['momentum'])
# Get momentum of Adam.
elif 'betas' in group.keys():
momentum.append(group['betas'][0])
else:
momentum.append(0)
return dict(momentum=momentum)
@contextmanager
def accumulate_grad(self, model: nn.Module, cur_iter: int, max_iters: int):
"""A Context manager for gradient accumulation and avoiding unnecessary
gradient synchronization during gradient accumulation.
If the model has a ``no_sync`` method (which blocks gradient
synchronization) and ``self.accumulative_iters != 1``, gradients will
not be synchronized in the iterations where parameters are not
updated. Otherwise, this method enables an empty context.
Warnings:
This context manager must be enabled if you want to use
gradient accumulation.
Args:
model (nn.Module): The training model.
cur_iter (int): Current iteration during training process.
max_iters (int): Maximum training iteration.
"""
assert max_iters > 0, '`max_iters` must be larger than zero'
self.cur_iter = cur_iter
self.max_iters = max_iters
if not self.iter_status_initialized:
self._initilize_iter_status(model)
# During gradient accumulation process, the gradient synchronize
# should only happen before updating parameters.
if (not self._should_update(cur_iter, max_iters)
and hasattr(model, 'no_sync')):
with model.no_sync():
yield
else:
yield
@contextmanager
def precision_context(self):
"""precision context which enables an empty context by default.
The subclass used for mixed or low precision training needs to override
this method.
"""
yield
def _clip_grad(self) -> None:
"""Clip the gradients of parameters."""
params: List[torch.Tensor] = []
for param_group in self.optimizer.param_groups:
params.extend(param_group['params'])
params = list(
filter(lambda p: p.requires_grad and p.grad is not None, params))
if len(params) > 0:
grad_norm = clip_grad.clip_grad_norm_(params,
**self.clip_grad_kwargs)
self.message_hub.update_scalar('train/grad_norm', float(grad_norm))
def _initilize_iter_status(self, model: nn.Module) -> None:
"""Initialize gradient accumulation related attributes.
Args:
model (nn.Module): Training model
"""
if self.max_iters % self.accumulative_iters != 0:
self.logger.warning(
'Resumed iteration number is not divisible by accumulative_iters, '
'which means the gradient of some iterations is lost and the '
'result may be influenced slightly.'
)
if has_batch_norm(model) and self.accumulative_iters > 1:
self.logger.warning(
'Gradient accumulation may slightly decrease '
'performance because the model has BatchNorm layers.')
residual_iters = self.max_iters - self.cur_iter
# The maximum number of training iteration that is divisible by
# accumulative_iters.
self.divisible_iters = (
residual_iters // self.accumulative_iters *
self.accumulative_iters)
# Remainder of ``residual_iters`` divided by ``self.accumulative_iters``
self.remainder_iters = residual_iters - self.divisible_iters
self.iter_status_initialized = True
def _should_update(self, cur_iter: int, max_iters: int) -> bool:
"""Should optim_wrapper update parameters or synchronized gradient at
current iteration.
Args:
cur_iter (int): Current iteration of training process.
max_iters (int): Maximum iterations of training process.
Returns:
bool: Whether to update parameters or synchronize gradients.
"""
return ((cur_iter + 1) % self.accumulative_iters == 0
or cur_iter + 1 == max_iters)
def __repr__(self):
wrapper_info = f'Type: {type(self).__name__}\n' \
f'accumulative_iters: {self.accumulative_iters}\n' \
f'optimizer: \n'
optimizer_str = repr(self.optimizer) + '\n'
return wrapper_info + optimizer_str
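A worked sketch in plain Python (new code, not part of the diff) of the bookkeeping performed by ``_initilize_iter_status`` and ``_should_update``, assuming training starts at iteration 0 with ``accumulative_iters=3`` and ``max_iters=10``:

accumulative_iters, max_iters, start_iter = 3, 10, 0
residual_iters = max_iters - start_iter
# Largest multiple of accumulative_iters that fits in the remaining iterations.
divisible_iters = residual_iters // accumulative_iters * accumulative_iters  # 9
remainder_iters = residual_iters - divisible_iters                           # 1
for cur_iter in range(start_iter, max_iters):
    loss_factor = (accumulative_iters
                   if cur_iter < divisible_iters else remainder_iters)
    should_update = ((cur_iter + 1) % accumulative_iters == 0
                     or cur_iter + 1 == max_iters)
    print(cur_iter, loss_factor, should_update)
# Iterations 0-8 divide the loss by 3 and step at iterations 2, 5 and 8;
# iteration 9 falls in the remainder, divides by 1 and steps because
# cur_iter + 1 == max_iters.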

View File

@ -0,0 +1,208 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from contextlib import ExitStack, contextmanager
from typing import Dict, Iterator, List, Tuple
import torch
import torch.nn as nn
from .optimizer_wrapper import OptimWrapper
class OptimWrapperDict(OptimWrapper):
"""A dictionary container of :obj:`OptimWrapper`.
If runner is training with multiple optimizers, all optimizer wrappers
should be managed by :obj:`OptimWrapperDict` which is built by
``CustomOptimWrapperConstructor``. ``OptimWrapperDict`` will load and save
the state dictionary of all optimizer wrappers.
Considering the semantic ambiguity of calling :meth:`update_params` or
:meth:`backward` for all optimizer wrappers, ``OptimWrapperDict`` does
not implement these methods.
Examples:
>>> import torch.nn as nn
>>> from torch.optim import SGD
>>> from mmengine.optim import OptimWrapperDict, OptimWrapper
>>> model1 = nn.Linear(1, 1)
>>> model2 = nn.Linear(1, 1)
>>> optim_wrapper1 = OptimWrapper(SGD(model1.parameters(), lr=0.1))
>>> optim_wrapper2 = OptimWrapper(SGD(model2.parameters(), lr=0.1))
>>> optim_wrapper_dict = OptimWrapperDict(model1=optim_wrapper1,
>>> model2=optim_wrapper2)
Note:
The optimizer wrapper contained in ``OptimWrapperDict`` can be accessed
in the same way as `dict`.
Args:
**optim_wrappers: A dictionary of ``OptimWrapper`` instance.
"""
def __init__(self, **optim_wrapper_dict: OptimWrapper):
first_key = next(iter(optim_wrapper_dict))
first_optim_wrapper = optim_wrapper_dict[first_key]
assert isinstance(first_optim_wrapper, OptimWrapper), (
'Each argument of `OptimWrapperDict` must be an `OptimWrapper` '
'instance')
optim_wrapper_class = type(first_optim_wrapper)
for key, value in optim_wrapper_dict.items():
assert type(value) == optim_wrapper_class, (
f'All optimizer wrappers should have the same type, but found'
f' {key}: {type(value)} and {first_key}: {optim_wrapper_class}'
)
if value.accumulative_iters != 1:
warnings.warn(
f'The `accumulative_iters` of {key} is '
f'{value.accumulative_iters}. OptimWrapperDict '
'will not enable any `accumulate_grad` context of its '
'optimizer wrappers. You should access the corresponding '
'optimizer wrapper to enable the context.')
self.optim_wrappers = optim_wrapper_dict
def update_params(self, loss: torch.Tensor) -> None:
"""Update all optimizer wrappers would lead to a duplicate backward
errors, and OptimWrapperDict does not know which optimizer wrapper
should be updated.
Therefore, this method is not implemented. The optimizer wrapper of
OptimWrapperDict should be accessed and call its `update_params.
"""
raise NotImplementedError(
'You should access the OptimWrapper of the '
'OptimWrapperDict and call its `update_params`')
def backward(self, loss: torch.Tensor) -> None:
"""Since OptimWrapperDict doesn't know which optimizer wrapper's
backward method should be called (``loss_scaler`` maybe different in
different :obj:AmpOptimWrapper), this method is not implemented.
The optimizer wrapper of OptimWrapperDict should be accessed and call
its `backward.
"""
raise NotImplementedError('You should access the OptimWrapper of the '
'OptimWrapperDict and call its `backward`')
def step(self) -> None:
"""Since the backward method is not implemented, the step should not be
implemented either."""
raise NotImplementedError('You should access the OptimWrapper of the '
'OptimWrapperDict and call its `step`')
def zero_grad(self) -> None:
"""Set the gradients of all optimizer wrappers to zero."""
for optim_wrapper in self.optim_wrappers.values():
optim_wrapper.zero_grad()
@contextmanager
def precision_context(self):
optim_wrapper = next(iter(self.optim_wrappers.values()))
with optim_wrapper.precision_context():
yield
@contextmanager
def accumulate_grad(self, model: nn.Module, cur_iter: int, max_iters: int):
"""Enable ``accumulate_grad`` contexts of all optimizer wrappers.
Warning:
Since there is only one ``model`` argument for all optimizer
wrappers, all of them work under the same ``model.no_sync`` context.
For example, given a model composed of model_a (optimizer_a) and
model_b (optimizer_b), ``OptimWrapperDict.accumulate_grad`` will call
``model.no_sync``, which blocks the gradient synchronization of both
a and b. If optimizer_a and optimizer_b have different
``accumulative_iters`` and the gradient synchronization of model_a
and model_b should be blocked separately, the model should not
implement the ``no_sync`` method (or should enable an empty context),
and the ``accumulate_grad`` context should be enabled inside the
model via the corresponding optimizer wrapper.
"""
with ExitStack() as stack:
for optim_wrapper in self.optim_wrappers.values():
stack.enter_context(
optim_wrapper.accumulate_grad(model, cur_iter, max_iters))
yield
def load_state_dict(self, state_dict: dict) -> None:
"""Load the state dictionary from the ``state_dict``.
Args:
state_dict (dict): Each key-value pair in `state_dict` represents
the name and the state dictionary of corresponding
:obj:`OptimWrapper`.
"""
for name, _state_dict in state_dict.items():
assert name in self.optim_wrappers, (
f'Mismatched `state_dict`! Cannot find {name} in '
'OptimWrapperDict')
self.optim_wrappers[name].load_state_dict(_state_dict)
def get_lr(self) -> Dict[str, List[float]]:
"""Get the learning rate of all optimizers.
Returns:
Dict[str, List[float]]: Learning rate of all optimizers.
"""
lr_dict = dict()
for name, optim_wrapper in self.optim_wrappers.items():
lr_dict[f'{name}.lr'] = optim_wrapper.get_lr()['lr']
return lr_dict
def get_momentum(self) -> Dict[str, List[float]]:
"""Get the momentum of all optimizers.
Returns:
Dict[str, List[float]]: momentum of all optimizers.
"""
momentum_dict = dict()
for name, optim_wrapper in self.optim_wrappers.items():
momentum_dict[f'{name}.momentum'] = optim_wrapper.get_momentum(
)['momentum']
return momentum_dict
def state_dict(self) -> dict:
"""Get the state dictionary of all optimizer wrappers.
Returns:
dict: Each key-value pair in the dictionary represents the name
and state dictionary of corresponding :obj:`OptimWrapper`.
"""
state_dict = dict()
for name, optim_wrapper in self.optim_wrappers.items():
state_dict[name] = optim_wrapper.state_dict()
return state_dict
def items(self) -> Iterator[Tuple[str, OptimWrapper]]:
"""A generator to get the name and corresponding
:obj:`OptimWrapper`"""
yield from self.optim_wrappers.items()
def values(self) -> Iterator[OptimWrapper]:
"""A generator to get :obj:`OptimWrapper`"""
yield from self.optim_wrappers.values()
def keys(self) -> Iterator[str]:
"""A generator to get the name of :obj:`OptimWrapper`"""
yield from self.optim_wrappers.keys()
def __getitem__(self, key: str) -> OptimWrapper:
assert key in self.optim_wrappers, (
f'Cannot find {key} in OptimWrapperDict, please check '
'your optimizer constructor.')
return self.optim_wrappers[key]
def __contains__(self, key: str) -> bool:
return key in self.optim_wrappers
def __len__(self) -> int:
return len(self.optim_wrappers)
def __repr__(self) -> str:
desc = ''
for name, optim_wrapper in self.optim_wrappers.items():
desc += f'name: {name}\n'
desc += repr(optim_wrapper)
return desc
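A small usage sketch (new code, assuming a default logger/message hub can be created outside a full Runner) showing that each branch is updated through its own wrapper, while the container only aggregates things like ``state_dict``:

import torch
import torch.nn as nn
from torch.optim import SGD
from mmengine.optim import OptimWrapper, OptimWrapperDict

gen, disc = nn.Linear(1, 1), nn.Linear(1, 1)
wrappers = OptimWrapperDict(
    generator=OptimWrapper(SGD(gen.parameters(), lr=0.01)),
    discriminator=OptimWrapper(SGD(disc.parameters(), lr=0.02)))

# `update_params`/`backward`/`step` are intentionally not implemented on
# the container; update each branch through its own wrapper.
gen_loss = gen(torch.randn(2, 1)).sum()
wrappers['generator'].update_params(gen_loss)

print(list(wrappers.keys()))         # ['generator', 'discriminator']
print(wrappers.state_dict().keys())  # dict_keys(['generator', 'discriminator'])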

View File

@ -22,7 +22,7 @@ class ConstantLR(LRSchedulerMixin, ConstantParamScheduler):
changes to the learning rate value from outside this scheduler.
Args:
optimizer (Optimizer): Wrapped optimizer.
optimizer (Optimizer or OptimWrapper): Wrapped optimizer.
factor (float): The number we multiply learning rate until the
milestone. Defaults to 1./3.
begin (int): Step at which to start updating the learning rate.
@ -68,7 +68,7 @@ class CosineAnnealingLR(LRSchedulerMixin, CosineAnnealingParamScheduler):
only implements the cosine annealing part of SGDR, and not the restarts.
Args:
optimizer (Optimizer): Wrapped optimizer.
optimizer (Optimizer or OptimWrapper): Wrapped optimizer.
T_max (int): Maximum number of iterations.
eta_min (float): Minimum learning rate. Defaults to 0.
begin (int): Step at which to start updating the learning rate.
@ -92,7 +92,7 @@ class ExponentialLR(LRSchedulerMixin, ExponentialParamScheduler):
"""Decays the learning rate of each parameter group by gamma every epoch.
Args:
optimizer (Optimizer): Wrapped optimizer.
optimizer (Optimizer or OptimWrapper): Wrapped optimizer.
gamma (float): Multiplicative factor of learning rate decay.
begin (int): Step at which to start updating the learning rate.
Defaults to 0.
@ -116,7 +116,7 @@ class LinearLR(LRSchedulerMixin, LinearParamScheduler):
Notice that such decay can happen simultaneously with other changes to the
learning rate from outside this scheduler.
Args:
optimizer (Optimizer): Wrapped optimizer.
optimizer (Optimizer or OptimWrapper): Wrapped optimizer.
start_factor (float): The number we multiply learning rate in the
first epoch. The multiplication factor changes towards end_factor
in the following epochs. Defaults to 1./3.
@ -143,7 +143,7 @@ class MultiStepLR(LRSchedulerMixin, MultiStepParamScheduler):
outside this scheduler.
Args:
optimizer (Optimizer): Wrapped optimizer.
optimizer (Optimizer or OptimWrapper): Wrapped optimizer.
milestones (list): List of epoch indices. Must be increasing.
gamma (float): Multiplicative factor of learning rate decay.
Defaults to 0.1.
@ -167,7 +167,7 @@ class StepLR(LRSchedulerMixin, StepParamScheduler):
other changes to the learning rate from outside this scheduler.
Args:
optimizer (Optimizer): Wrapped optimizer.
optimizer (Optimizer or OptimWrapper): Wrapped optimizer.
step_size (int): Period of learning rate decay.
gamma (float): Multiplicative factor of learning rate decay.
Defaults to 0.1.
@ -193,7 +193,7 @@ class PolyLR(LRSchedulerMixin, PolyParamScheduler):
parameter value from outside this scheduler.
Args:
optimizer (Optimizer): Wrapped optimizer.
optimizer (Optimizer or OptimWrapper): Wrapped optimizer.
eta_min (float): Minimum learning rate at the end of scheduling.
Defaults to 0.
power (float): The power of the polynomial. Defaults to 1.0.

View File

@ -22,7 +22,8 @@ class ConstantMomentum(MomentumSchedulerMixin, ConstantParamScheduler):
momentum value from outside this scheduler.
Args:
optimizer (Optimizer): Wrapped optimizer.
optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
optimizer.
factor (float): The number we multiply momentum until the milestone.
Defaults to 1./3.
begin (int): Step at which to start updating the momentum.
@ -69,7 +70,8 @@ class CosineAnnealingMomentum(MomentumSchedulerMixin,
only implements the cosine annealing part of SGDR, and not the restarts.
Args:
optimizer (Optimizer): Wrapped optimizer.
optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
optimizer.
T_max (int): Maximum number of iterations.
eta_min (float): Minimum momentum value. Defaults to 0.
begin (int): Step at which to start updating the momentum.
@ -93,7 +95,8 @@ class ExponentialMomentum(MomentumSchedulerMixin, ExponentialParamScheduler):
"""Decays the momentum of each parameter group by gamma every epoch.
Args:
optimizer (Optimizer): Wrapped optimizer.
optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
optimizer.
gamma (float): Multiplicative factor of momentum value decay.
begin (int): Step at which to start updating the momentum.
Defaults to 0.
@ -117,7 +120,8 @@ class LinearMomentum(MomentumSchedulerMixin, LinearParamScheduler):
Notice that such decay can happen simultaneously with other changes to the
momentum from outside this scheduler.
Args:
optimizer (Optimizer): Wrapped optimizer.
optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
optimizer.
start_factor (float): The number we multiply momentum in the
first epoch. The multiplication factor changes towards end_factor
in the following epochs. Defaults to 1./3.
@ -144,7 +148,8 @@ class MultiStepMomentum(MomentumSchedulerMixin, MultiStepParamScheduler):
scheduler.
Args:
optimizer (Optimizer): Wrapped optimizer.
optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
optimizer.
milestones (list): List of epoch indices. Must be increasing.
gamma (float): Multiplicative factor of momentum value decay.
Defaults to 0.1.
@ -168,7 +173,8 @@ class StepMomentum(MomentumSchedulerMixin, StepParamScheduler):
to the momentum from outside this scheduler.
Args:
optimizer (Optimizer): Wrapped optimizer.
optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
optimizer.
step_size (int): Period of momentum value decay.
gamma (float): Multiplicative factor of momentum value decay.
Defaults to 0.1.
@ -194,7 +200,8 @@ class PolyMomentum(MomentumSchedulerMixin, PolyParamScheduler):
parameter value from outside this scheduler.
Args:
optimizer (Optimizer): Wrapped optimizer.
optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
optimizer.
eta_min (float): Minimum momentum at the end of scheduling.
Defaults to 0.
power (float): The power of the polynomial. Defaults to 1.0.

View File

@ -4,14 +4,17 @@ import warnings
import weakref
from collections import Counter
from functools import wraps
from typing import Callable, List
from typing import Callable, List, Union
from torch.optim import Optimizer
from mmengine.optim import OptimWrapper
from mmengine.registry import PARAM_SCHEDULERS
INF = int(1e9)
OptimizerType = Union[OptimWrapper, Optimizer]
class _ParamScheduler:
"""Base class for parameter schedulers.
@ -23,7 +26,7 @@ class _ParamScheduler:
https://github.com/pytorch/pytorch/blob/master/torch/optim/lr_scheduler.py.
Args:
optimizer (Optimizer): Wrapped optimizer.
optimizer (OptimWrapper or Optimizer): Wrapped optimizer.
param_name (str): Name of the parameter to be adjusted, such as
``lr``, ``momentum``.
begin (int): Step at which to start updating the parameters.
@ -40,7 +43,7 @@ class _ParamScheduler:
""" # noqa: E501
def __init__(self,
optimizer: Optimizer,
optimizer: OptimizerType,
param_name: str,
begin: int = 0,
end: int = INF,
@ -49,7 +52,7 @@ class _ParamScheduler:
verbose: bool = False):
# Attach optimizer
if not isinstance(optimizer, Optimizer):
if not isinstance(optimizer, (Optimizer, OptimWrapper)):
raise TypeError('``optimizer`` should be an Optimizer,'
'but got {}'.format(type(optimizer).__name__))
self.optimizer = optimizer
@ -111,8 +114,8 @@ class _ParamScheduler:
return wrapper
# add counter to optimizer
self.optimizer.step = with_counter(self.optimizer.step)
self.optimizer._global_step = -1
self.optimizer.step = with_counter(self.optimizer.step) # type: ignore
self.optimizer._global_step = -1 # type: ignore
self._global_step = -1
self.verbose = verbose
@ -218,7 +221,7 @@ class StepParamScheduler(_ParamScheduler):
other changes to the parameter value from outside this scheduler.
Args:
optimizer (Optimizer): Wrapped optimizer.
optimizer (OptimWrapper or Optimizer): Wrapped optimizer.
step_size (int): Period of parameter value decay.
gamma (float): Multiplicative factor of parameter value decay.
Defaults to 0.1.
@ -235,7 +238,7 @@ class StepParamScheduler(_ParamScheduler):
"""
def __init__(self,
optimizer: Optimizer,
optimizer: OptimizerType,
param_name: str,
step_size: int,
gamma: float = 0.1,
@ -304,7 +307,7 @@ class MultiStepParamScheduler(_ParamScheduler):
scheduler.
Args:
optimizer (Optimizer): Wrapped optimizer.
optimizer (OptimWrapper or Optimizer): Wrapped optimizer.
milestones (list): List of epoch indices. Must be increasing.
gamma (float): Multiplicative factor of parameter value decay.
Defaults to 0.1.
@ -321,7 +324,7 @@ class MultiStepParamScheduler(_ParamScheduler):
"""
def __init__(self,
optimizer: Optimizer,
optimizer: OptimizerType,
param_name: str,
milestones: List[int],
gamma: float = 0.1,
@ -391,7 +394,8 @@ class ConstantParamScheduler(_ParamScheduler):
parameter value from outside this scheduler.
Args:
optimizer (Optimizer): Wrapped optimizer.
optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
optimizer.
factor (float): The number we multiply parameter value until the
milestone. Defaults to 1./3.
begin (int): Step at which to start updating the parameters.
@ -407,7 +411,7 @@ class ConstantParamScheduler(_ParamScheduler):
"""
def __init__(self,
optimizer: Optimizer,
optimizer: OptimizerType,
param_name: str,
factor: float = 1.0 / 3,
begin: int = 0,
@ -477,7 +481,8 @@ class ExponentialParamScheduler(_ParamScheduler):
"""Decays the parameter value of each parameter group by gamma every epoch.
Args:
optimizer (Optimizer): Wrapped optimizer.
optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
optimizer.
gamma (float): Multiplicative factor of parameter value decay.
begin (int): Step at which to start updating the parameters.
Defaults to 0.
@ -492,7 +497,7 @@ class ExponentialParamScheduler(_ParamScheduler):
"""
def __init__(self,
optimizer: Optimizer,
optimizer: OptimizerType,
param_name: str,
gamma: float,
begin: int = 0,
@ -573,7 +578,8 @@ class CosineAnnealingParamScheduler(_ParamScheduler):
only implements the cosine annealing part of SGDR, and not the restarts.
Args:
optimizer (Optimizer): Wrapped optimizer.
optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
optimizer.
T_max (int): Maximum number of iterations.
eta_min (float): Minimum parameter value. Defaults to 0.
begin (int): Step at which to start updating the parameters.
@ -592,7 +598,7 @@ class CosineAnnealingParamScheduler(_ParamScheduler):
"""
def __init__(self,
optimizer: Optimizer,
optimizer: Union[Optimizer, OptimWrapper],
param_name: str,
T_max: int,
eta_min: float = 0.,
@ -670,7 +676,8 @@ class LinearParamScheduler(_ParamScheduler):
parameter value from outside this scheduler.
Args:
optimizer (Optimizer): Wrapped optimizer.
optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
optimizer.
start_factor (float): The number we multiply parameter value in the
first epoch. The multiplication factor changes towards end_factor
in the following epochs. Defaults to 1./3.
@ -689,7 +696,7 @@ class LinearParamScheduler(_ParamScheduler):
"""
def __init__(self,
optimizer: Optimizer,
optimizer: Union[Optimizer, OptimWrapper],
param_name: str,
start_factor: float = 1.0 / 3,
end_factor: float = 1.0,
@ -765,7 +772,8 @@ class PolyParamScheduler(_ParamScheduler):
parameter value from outside this scheduler.
Args:
optimizer (Optimizer): Wrapped optimizer.
optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
optimizer.
eta_min (float): Minimum parameter value at the end of scheduling.
Defaults to 0.
power (float): The power of the polynomial. Defaults to 1.0.
@ -782,7 +790,7 @@ class PolyParamScheduler(_ParamScheduler):
"""
def __init__(self,
optimizer: Optimizer,
optimizer: Union[Optimizer, OptimWrapper],
param_name: str,
eta_min: float = 0,
power: float = 1.0,

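A minimal sketch (new code, not part of the diff) of the newly accepted type: a scheduler can now wrap an ``OptimWrapper`` directly, since the wrapper exposes ``param_groups`` and ``step`` just like a ``torch.optim.Optimizer``. It assumes a default logger/message hub can be created outside a full Runner:

import torch.nn as nn
from torch.optim import SGD
from mmengine.optim import MultiStepLR, OptimWrapper

model = nn.Linear(1, 1)
optim_wrapper = OptimWrapper(SGD(model.parameters(), lr=0.1))
scheduler = MultiStepLR(optim_wrapper, milestones=[2, 4], gamma=0.1)
for _ in range(5):
    optim_wrapper.step()   # in practice called via update_params(loss)
    scheduler.step()
print(optim_wrapper.get_lr())  # lr decayed by `gamma` after each milestone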
View File

@ -2,17 +2,17 @@
from .default_scope import DefaultScope
from .registry import Registry, build_from_cfg
from .root import (DATA_SAMPLERS, DATASETS, HOOKS, LOG_PROCESSORS, LOOPS,
METRICS, MODEL_WRAPPERS, MODELS, OPTIMIZER_CONSTRUCTORS,
OPTIMIZERS, PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS,
TASK_UTILS, TRANSFORMS, VISBACKENDS, VISUALIZERS,
WEIGHT_INITIALIZERS)
METRICS, MODEL_WRAPPERS, MODELS, OPTIM_WRAPPER_CONSTRUCTORS,
OPTIM_WRAPPERS, OPTIMIZERS, PARAM_SCHEDULERS,
RUNNER_CONSTRUCTORS, RUNNERS, TASK_UTILS, TRANSFORMS,
VISBACKENDS, VISUALIZERS, WEIGHT_INITIALIZERS)
from .utils import count_registered_modules, traverse_registry_tree
__all__ = [
'Registry', 'build_from_cfg', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS',
'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS',
'OPTIMIZERS', 'OPTIMIZER_CONSTRUCTORS', 'TASK_UTILS', 'PARAM_SCHEDULERS',
'METRICS', 'MODEL_WRAPPERS', 'LOOPS', 'VISBACKENDS', 'VISUALIZERS',
'LOG_PROCESSORS', 'DefaultScope', 'traverse_registry_tree',
'count_registered_modules'
'OPTIMIZERS', 'OPTIM_WRAPPER_CONSTRUCTORS', 'TASK_UTILS',
'PARAM_SCHEDULERS', 'METRICS', 'MODEL_WRAPPERS', 'OPTIM_WRAPPERS', 'LOOPS',
'VISBACKENDS', 'VISUALIZERS', 'LOG_PROCESSORS', 'DefaultScope',
'traverse_registry_tree', 'count_registered_modules'
]

View File

@ -32,7 +32,7 @@ WEIGHT_INITIALIZERS = Registry('weight initializer')
# mangage all kinds of optimizers like `SGD` and `Adam`
OPTIMIZERS = Registry('optimizer')
# manage constructors that customize the optimization hyperparameters.
OPTIMIZER_CONSTRUCTORS = Registry('optimizer constructor')
OPTIM_WRAPPER_CONSTRUCTORS = Registry('optimizer wrapper constructor')
# mangage all kinds of parameter schedulers like `MultiStepLR`
PARAM_SCHEDULERS = Registry('parameter scheduler')
# manage all kinds of metrics
@ -48,3 +48,6 @@ VISBACKENDS = Registry('vis_backend')
# manage logprocessor
LOG_PROCESSORS = Registry('log_processor')
# manage optimizer wrapper
OPTIM_WRAPPERS = Registry('optim_wrapper')
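A hedged sketch of how a downstream project could use the new registry; ``MyOptimWrapper`` is a hypothetical name used only for illustration:

from mmengine.optim import OptimWrapper
from mmengine.registry import OPTIM_WRAPPERS

@OPTIM_WRAPPERS.register_module()
class MyOptimWrapper(OptimWrapper):
    """Hypothetical wrapper; could override `backward`, `step`, etc."""

# The wrapper can then be selected from a config, e.g.:
# optim_wrapper = dict(type='MyOptimWrapper',
#                      optimizer=dict(type='SGD', lr=0.01))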

View File

@ -6,6 +6,7 @@ import random
import shutil
import time
import warnings
from collections import OrderedDict
from functools import partial
from typing import Callable, Dict, List, Optional, Sequence, Union
@ -25,7 +26,8 @@ from mmengine.evaluator import Evaluator
from mmengine.hooks import Hook
from mmengine.logging import LogProcessor, MessageHub, MMLogger
from mmengine.model import is_model_wrapper
from mmengine.optim import _ParamScheduler, build_optimizer
from mmengine.optim import (OptimWrapper, OptimWrapperDict, _ParamScheduler,
build_optim_wrapper)
from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS,
MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS,
VISUALIZERS, DefaultScope,
@ -44,6 +46,7 @@ from .priority import Priority, get_priority
ConfigType = Union[Dict, Config, ConfigDict]
ParamSchedulerType = Union[List[_ParamScheduler], Dict[str,
List[_ParamScheduler]]]
OptimWrapperType = Union[OptimWrapper, OptimWrapperDict]
class Runner:
@ -97,10 +100,14 @@ class Runner:
If ``test_cfg`` specified, :attr:`test_dataloader` should also be
specified. Defaults to None.
See :meth:`build_test_loop` for more details.
optimizer (Optimizer or dict, optional): Computing gradient of model
parameters. If specified, :attr:`train_dataloader` should also be
specified. Defaults to None.
See :meth:`build_optimizer` for examples.
optim_wrapper (OptimWrapper or dict, optional):
Computes gradients of model parameters. If specified,
:attr:`train_dataloader` should also be specified. If automatic
mixed precision or gradient accumulation
training is required, the type of ``optim_wrapper`` should be
``AmpOptimWrapper``. See :meth:`build_optim_wrapper` for
examples. Defaults to None.
param_scheduler (_ParamScheduler or dict or list, optional):
Parameter scheduler for updating optimizer parameters. If
specified, :attr:`optimizer` should also be specified.
@ -177,7 +184,8 @@ class Runner:
>>> sampler=dict(type='DefaultSampler', shuffle=False),
>>> batch_size=1,
>>> num_workers=0),
>>> optimizer=dict(type='SGD', lr=0.01),
>>> optim_wrapper=dict(type='OptimWrapper', optimizer=dict(
>>> type='SGD', lr=0.01)),
>>> param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]),
>>> val_evaluator=dict(type='ToyEvaluator'),
>>> test_evaluator=dict(type='ToyEvaluator'),
@ -217,7 +225,7 @@ class Runner:
train_cfg: Optional[Dict] = None,
val_cfg: Optional[Dict] = None,
test_cfg: Optional[Dict] = None,
optimizer: Optional[Union[Optimizer, Dict]] = None,
optim_wrapper: Optional[Union[OptimWrapper, Dict]] = None,
param_scheduler: Optional[Union[_ParamScheduler, Dict, List]] = None,
val_evaluator: Optional[Union[Evaluator, Dict, List]] = None,
test_evaluator: Optional[Union[Evaluator, Dict, List]] = None,
@ -249,7 +257,7 @@ class Runner:
self.cfg = Config(dict())
# lazy initialization
training_related = [train_dataloader, train_cfg, optimizer]
training_related = [train_dataloader, train_cfg, optim_wrapper]
if not (all(item is None for item in training_related)
or all(item is not None for item in training_related)):
raise ValueError(
@ -257,14 +265,16 @@ class Runner:
'all None or not None, but got '
f'train_dataloader={train_dataloader}, '
f'train_cfg={train_cfg}, '
f'optimizer={optimizer}.')
f'optim_wrapper={optim_wrapper}.')
self._train_dataloader = train_dataloader
self._train_loop = train_cfg
self.optimizer = optimizer
self.optim_wrapper: Optional[Union[OptimWrapper, dict]]
self.optim_wrapper = optim_wrapper
# If there is no need to adjust learning rate, momentum or other
# parameters of optimizer, param_scheduler can be None
if param_scheduler is not None and self.optimizer is None:
if param_scheduler is not None and self.optim_wrapper is None:
raise ValueError(
'param_scheduler should be None when optim_wrapper is None, '
f'but got {param_scheduler}')
@ -400,7 +410,7 @@ class Runner:
train_cfg=cfg.get('train_cfg'),
val_cfg=cfg.get('val_cfg'),
test_cfg=cfg.get('test_cfg'),
optimizer=cfg.get('optimizer'),
optim_wrapper=cfg.get('optim_wrapper'),
param_scheduler=cfg.get('param_scheduler'),
val_evaluator=cfg.get('val_evaluator'),
test_evaluator=cfg.get('test_evaluator'),
@ -803,21 +813,25 @@ class Runner:
return model
def build_optimizer(
self, optimizer: Union[Optimizer, Dict]
) -> Union[Optimizer, Dict[str, Optimizer]]:
"""Build an optimizer or multiple optimizers.
def build_optim_wrapper(
self, optim_wrapper: Union[Optimizer, OptimWrapper, Dict]
) -> Union[OptimWrapper, OptimWrapperDict]:
"""Build optimizer wrapper.
Args:
optimizer (Optimizer or dict): An Optimizer object or a dict to
build Optimizer objects. If ``optimizer`` is an Optimizer
object, just returns itself.
optim_wrapper (OptimWrapper or dict): An OptimWrapper object or a
dict to build OptimWrapper objects. If ``optim_wrapper`` is an
OptimWrapper instance, it is returned directly.
Examples:
>>> # build an optimizer
>>> optim_cfg = dict(type='SGD', lr=0.01)
>>> optimizer = runner.build_optimizer(optim_cfg)
>>> optimizer
>>> optim_wrapper_cfg = dict(type='OptimWrapper', optimizer=dict(
... type='SGD', lr=0.01))
>>> optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
>>> optim_wrapper
Type: OptimWrapper
accumulative_iters: 1
optimizer:
SGD (
Parameter Group 0
dampening: 0
@ -828,71 +842,85 @@ class Runner:
)
>>> # build multiple optimizers
>>> optim_cfg = dict(
... generator=dict(type='SGD', lr=0.01),
... discriminator=dict(type='Adam',lr=0.02)
>>> optim_wrapper_cfg = dict(
... generator=dict(type='OptimWrapper', optimizer=dict(
... type='SGD', lr=0.01)),
... discriminator=dict(type='OptimWrapper', optimizer=dict(
... type='Adam', lr=0.001))
... # need to customize a multiple optimizer constructor
... constructor='CustomizedMultipleOptimizersConstructor',
...)
>>> optimizer = runner.build_optimizer(optim_cfg)
>>> optimizer
{'generator': SGD (
>>> optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
>>> optim_wrapper
name: generator
Type: OptimWrapper
accumulative_iters: 1
optimizer:
SGD (
Parameter Group 0
dampening: 0
lr: 0.01
lr: 0.1
momentum: 0
nesterov: False
weight_decay: 0
),
'discriminator': SGD (
)
name: discriminator
Type: OptimWrapper
accumulative_iters: 1
optimizer:
'discriminator': Adam (
Parameter Group 0
dampening: 0
lr: 0.02
momentum: 0
nesterov: False
weight_decay: 0
)}
)
Important:
If you need to build multiple optimizers, you should implement a
MultipleOptimizerConstructor which gets parameters passed to
corresponding optimizers. More details about how to customize
OptimizerConstructor can be found at `optimizer-docs`_.
corresponding optimizers and compose the ``OptimWrapperDict``.
More details about how to customize OptimizerConstructor can be
found at `optimizer-docs`_.
Returns:
Optimizer or dict[str, Optimizer]: Optimizer build from
``optimizer``.
OptimWrapper: Optimizer wrapper built from ``optim_wrapper``.
.. _optimizer-docs:
https://mmengine.readthedocs.io/en/latest/tutorials/optimizer.html
"""
if isinstance(optimizer, Optimizer):
return optimizer
elif isinstance(optimizer, dict):
if 'type' not in optimizer and 'constructor' not in optimizer:
for name, optim in optimizer.items():
if not isinstance(optim, Optimizer):
if isinstance(optim_wrapper, OptimWrapper):
return optim_wrapper
elif isinstance(optim_wrapper, (dict, ConfigDict, Config)):
if 'type' not in optim_wrapper and ('constructor'
not in optim_wrapper):
optim_wrappers = OrderedDict()
for name, optim in optim_wrapper.items():
if not isinstance(optim, OptimWrapper):
raise ValueError(
'each item must be an `OptimWrapper` object when "type" '
'and "constructor" are not set in `optim_wrapper`, '
f'but got {name}={optim}')
return optimizer
return build_optimizer(self.model, optimizer)
optim_wrappers[name] = optim
return OptimWrapperDict(**optim_wrappers)
else:
raise TypeError('optimizer should be an Optimizer object or dict, '
f'but got {optimizer}')
optim_wrapper = build_optim_wrapper(self.model, optim_wrapper)
return optim_wrapper
else:
raise TypeError('optimizer wrapper should be an OptimWrapper '
f'object or dict, but got {optim_wrapper}')
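A small sketch (new code, names are illustrative) of the branch above that accepts already-built wrappers: a plain dict without 'type' or 'constructor' whose values are ``OptimWrapper`` instances is packed into an ``OptimWrapperDict``:

import torch.nn as nn
from torch.optim import SGD, Adam
from mmengine.optim import OptimWrapper

gen, disc = nn.Linear(1, 1), nn.Linear(1, 1)
optim_wrapper = dict(
    generator=OptimWrapper(SGD(gen.parameters(), lr=0.01)),
    discriminator=OptimWrapper(Adam(disc.parameters(), lr=0.001)))
# runner.build_optim_wrapper(optim_wrapper) returns an OptimWrapperDict
# with keys 'generator' and 'discriminator'.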
def _build_param_scheduler(self, scheduler: Union[_ParamScheduler, Dict,
List],
optimizer: Optimizer) -> List[_ParamScheduler]:
def _build_param_scheduler(
self, scheduler: Union[_ParamScheduler, Dict, List],
optim_wrapper: OptimWrapper) -> List[_ParamScheduler]:
"""Build parameter schedulers for a single optimizer.
Args:
scheduler (_ParamScheduler or dict or list): A Param Scheduler
object or a dict or list of dict to build parameter schedulers.
optimizer (Optimizer): An optimizer object is passed to construnct
ParamScheduler object.
optim_wrapper (OptimWrapper): An optimizer wrapper object passed to
construct the ParamScheduler objects.
Returns:
list[_ParamScheduler]: List of parameter schedulers build from
@ -922,7 +950,7 @@ class Runner:
cls = PARAM_SCHEDULERS.get(_scheduler.pop('type'))
param_schedulers.append(
cls.build_iter_from_epoch( # type: ignore
optimizer=self.optimizer,
optimizer=optim_wrapper,
**_scheduler,
epoch_length=len(
self.train_dataloader), # type: ignore
@ -931,11 +959,11 @@ class Runner:
param_schedulers.append(
PARAM_SCHEDULERS.build(
_scheduler,
default_args=dict(optimizer=optimizer)))
default_args=dict(optimizer=optim_wrapper)))
else:
raise TypeError(
'_scheduler should be a _ParamScheduler object or dict, '
f'but got {_scheduler}')
'scheduler should be a _ParamScheduler object or dict, '
f'but got {scheduler}')
return param_schedulers
@ -944,9 +972,10 @@ class Runner:
List]) -> ParamSchedulerType:
"""Build parameter schedulers.
``build_param_scheduler`` should be called after ``build_optimizer``
because the building logic will change according to the number of
optimizers built by the runner. The cases are as below:
``build_param_scheduler`` should be called after
``build_optim_wrapper`` because the building logic will change
according to the number of optimizers built by the runner.
The cases are as below:
- Single optimizer: When only one optimizer is built and used in the
runner, ``build_param_scheduler`` will return a list of
@ -968,7 +997,8 @@ class Runner:
Examples:
>>> # build one scheduler
>>> optim_cfg = dict(
>>>     type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01))
>>> runner.optimizer = runner.build_optimizer(optim_cfg)
>>> runner.optim_wrapper = runner.build_optim_wrapper(
>>> optim_cfg)
>>> scheduler_cfg = dict(type='MultiStepLR', milestones=[1, 2])
>>> schedulers = runner.build_param_scheduler(scheduler_cfg)
>>> schedulers
@ -998,20 +1028,23 @@ class Runner:
https://mmengine.readthedocs.io/en/latest/tutorials/optimizer.html
"""
param_schedulers: ParamSchedulerType
if isinstance(self.optimizer, Optimizer):
param_schedulers = self._build_param_scheduler(
scheduler, self.optimizer)
return param_schedulers
else:
assert isinstance(self.optimizer, dict)
param_schedulers = dict()
for name, optimizer in self.optimizer.items():
if not isinstance(optimizer, Optimizer):
raise RuntimeError(
if not isinstance(self.optim_wrapper, OptimWrapperDict):
# Since `OptimWrapperDict` inherits from `OptimWrapper`,
# `isinstance(self.optim_wrapper, OptimWrapper)` cannot tell
# whether `self.optim_wrapper` is an `OptimWrapper` or
# `OptimWrapperDict` instance. Therefore, here we simply check
# self.optim_wrapper is not an `OptimWrapperDict` instance and
# then assert it is an OptimWrapper instance.
assert isinstance(self.optim_wrapper, OptimWrapper), (
'`build_optim_wrapper` should be called before '
'`build_param_scheduler` because the latter depends '
'on the former')
param_schedulers = self._build_param_scheduler(
scheduler, self.optim_wrapper) # type: ignore
return param_schedulers
else:
param_schedulers = dict()
for name, optimizer in self.optim_wrapper.items():
if isinstance(scheduler, dict) and 'type' not in scheduler:
# scheduler is a dict and each item is a ParamScheduler
# object or a config to build ParamScheduler objects
@ -1356,7 +1389,7 @@ class Runner:
# `build_optim_wrapper` should be called before `build_param_scheduler`
# because the latter depends on the former
self.optimizer = self.build_optimizer(self.optimizer)
self.optim_wrapper = self.build_optim_wrapper(self.optim_wrapper)
if self.param_schedulers:
self.param_schedulers = self.build_param_scheduler( # type: ignore
@ -1418,9 +1451,6 @@ class Runner:
fn_name (str): The function name in each hook to be called, such as
"before_train_epoch".
**kwargs: Keyword arguments passed to hook.
Raises:
TypeError: if Hook got unexpected arguments.
"""
for hook in self._hooks:
# support adding additional custom hook methods
@ -1645,12 +1675,9 @@ class Runner:
# resume optimizer
if 'optimizer' in checkpoint and resume_optimizer:
self.optimizer = self.build_optimizer(self.optimizer)
if isinstance(self.optimizer, dict):
for name, optimizer in self.optimizer.items():
optimizer.load_state_dict(checkpoint['optimizer'][name])
else:
self.optimizer.load_state_dict(checkpoint['optimizer'])
self.optim_wrapper = self.build_optim_wrapper(self.optim_wrapper)
self.optim_wrapper.load_state_dict( # type: ignore
checkpoint['optimizer'])
# resume param scheduler
if 'param_schedulers' in checkpoint and resume_param_scheduler:
@ -1771,16 +1798,13 @@ class Runner:
}
# save optimizer state dict to checkpoint
if save_optimizer:
if isinstance(self.optimizer, Optimizer):
checkpoint['optimizer'] = self.optimizer.state_dict()
elif isinstance(self.optimizer, dict):
checkpoint['optimizer'] = dict()
for name, optimizer in self.optimizer.items():
checkpoint['optimizer'][name] = optimizer.state_dict()
if isinstance(self.optim_wrapper, OptimWrapper):
checkpoint['optimizer'] = self.optim_wrapper.state_dict()
else:
raise TypeError(
'self.optimizer should be an optimizer or a dict '
f'containing optimizer, but got {self.optimizer}')
'self.optim_wrapper should be an `OptimWrapper` '
'or `OptimWrapperDict` instance, but got '
f'{self.optim_wrapper}')
# save param scheduler state dict
if save_param_scheduler:

View File

@ -2,7 +2,7 @@
from .hub import load_url
from .manager import ManagerMeta, ManagerMixin
from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
find_latest_checkpoint, has_method,
find_latest_checkpoint, has_batch_norm, has_method,
import_modules_from_strings, is_list_of,
is_method_overridden, is_seq_of, is_str, is_tuple_of,
iter_cast, list_cast, mmcv_full_available,
@ -28,5 +28,5 @@ __all__ = [
'is_method_overridden', 'has_method', 'mmcv_full_available',
'digit_version', 'get_git_hash', 'TORCH_VERSION', 'load_url',
'find_latest_checkpoint', 'ManagerMeta', 'ManagerMixin',
'set_multi_processing'
'set_multi_processing', 'has_batch_norm'
]

View File

@ -514,3 +514,20 @@ def find_latest_checkpoint(path: str, suffix: str = 'pth'):
latest_path = checkpoint
return latest_path
def has_batch_norm(model: nn.Module) -> bool:
"""Detect whether model has a BatchNormalization layer.
Args:
model (nn.Module): training model.
Returns:
bool: whether model has a BatchNormalization layer
"""
if isinstance(model, _BatchNorm):
return True
for m in model.children():
if has_batch_norm(m):
return True
return False
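# Example of the expected behavior (a quick sketch, assuming ``torch.nn`` is
# imported as ``nn``):
#
#   has_batch_norm(nn.Sequential(nn.Conv2d(1, 1, 1), nn.BatchNorm2d(1)))  # True
#   has_batch_norm(nn.Linear(1, 1))                                       # False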

View File

@ -10,6 +10,7 @@ from torch.utils.data import Dataset
from mmengine.hooks import EMAHook
from mmengine.model import ExponentialMovingAverage
from mmengine.optim import OptimWrapper
from mmengine.registry import DATASETS, MODEL_WRAPPERS
from mmengine.runner import Runner
@ -79,7 +80,8 @@ class TestEMAHook(TestCase):
num_workers=0),
val_evaluator=evaluator,
work_dir=self.temp_dir.name,
optimizer=torch.optim.Adam(ToyModel().parameters()),
optim_wrapper=OptimWrapper(
torch.optim.Adam(ToyModel().parameters())),
train_cfg=dict(by_epoch=True, max_epochs=2),
val_cfg=dict(interval=1),
default_hooks=dict(logger=None),

View File

@ -46,8 +46,8 @@ class TestOptimizerHook:
x = torch.rand(1, 1, 3, 3)
dummy_runner = MagicMock()
dummy_runner.optimizer.zero_grad = Mock(return_value=None)
dummy_runner.optimizer.step = Mock(return_value=None)
dummy_runner.optim_wrapper.zero_grad = Mock(return_value=None)
dummy_runner.optim_wrapper.step = Mock(return_value=None)
dummy_runner.model = model
dummy_runner.outputs = dict()
@ -82,7 +82,7 @@ class TestOptimizerHook:
assert 'conv3.bias' in dummy_runner.logger.msg
assert 'conv1.weight' not in dummy_runner.logger.msg
assert 'conv1.bias' not in dummy_runner.logger.msg
dummy_runner.optimizer.step.assert_called()
dummy_runner.optim_wrapper.step.assert_called()
dummy_runner.outputs['loss'].backward.assert_called()
optimizer_hook.clip_grads.assert_called()
optimizer_hook.detect_anomalous_parameters.assert_called()
@ -109,7 +109,7 @@ class TestOptimizerHook:
optimizer_hook.after_train_iter(dummy_runner, 0)
dummy_runner.optimizer.step.assert_called()
dummy_runner.optim_wrapper.step.assert_called()
dummy_runner.outputs['loss'].backward.assert_called()
optimizer_hook.clip_grads.assert_not_called()
optimizer_hook.detect_anomalous_parameters.assert_not_called()

View File

@ -2,8 +2,12 @@
from unittest import TestCase
from unittest.mock import Mock
import torch.nn as nn
from torch.optim import SGD
from mmengine.hooks import RuntimeInfoHook
from mmengine.logging import MessageHub
from mmengine.optim import OptimWrapper, OptimWrapperDict
class TestRuntimeInfoHook(TestCase):
@ -47,18 +51,31 @@ class TestRuntimeInfoHook(TestCase):
self.assertEqual(message_hub.get_info('epoch'), 9)
def test_before_train_iter(self):
model = nn.Linear(1, 1)
optim1 = SGD(model.parameters(), lr=0.01)
optim2 = SGD(model.parameters(), lr=0.02)
optim_wrapper1 = OptimWrapper(optim1)
optim_wrapper2 = OptimWrapper(optim2)
optim_wrapper_dict = OptimWrapperDict(
key1=optim_wrapper1, key2=optim_wrapper2)
# single optimizer
message_hub = MessageHub.get_instance(
'runtime_info_hook_test_before_train_iter')
runner = Mock()
runner.iter = 9
runner.optimizer.param_groups = [{'lr': 0.01}]
runner.optim_wrapper = optim_wrapper1
runner.message_hub = message_hub
hook = RuntimeInfoHook()
hook.before_train_iter(runner, batch_idx=2, data_batch=None)
self.assertEqual(message_hub.get_info('iter'), 9)
self.assertEqual(message_hub.get_scalar('train/lr').current(), 0.01)
with self.assertRaisesRegex(AssertionError,
'runner.optim_wrapper.get_lr()'):
runner.optim_wrapper = Mock()
runner.optim_wrapper.get_lr = Mock(return_value='error type')
hook.before_train_iter(runner, batch_idx=2, data_batch=None)
# multiple optimizers
message_hub = MessageHub.get_instance(
'runtime_info_hook_test_before_train_iter')
@ -68,8 +85,8 @@ class TestRuntimeInfoHook(TestCase):
optimizer1.param_groups = [{'lr': 0.01}]
optimizer2 = Mock()
optimizer2.param_groups = [{'lr': 0.02}]
runner.optimizer = dict(key1=optimizer1, key2=optimizer2)
runner.message_hub = message_hub
runner.optim_wrapper = optim_wrapper_dict
hook = RuntimeInfoHook()
hook.before_train_iter(runner, batch_idx=2, data_batch=None)
self.assertEqual(message_hub.get_info('iter'), 9)
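# For reference, a sketch of what the hook ends up logging in the two cases
# (the keys follow the `get_lr()` return values asserted elsewhere in this PR):
#
#   optim_wrapper1.get_lr()     -> {'lr': [0.01]}       # logged as 'train/lr'
#   optim_wrapper_dict.get_lr() -> {'key1.lr': [0.01],  # 'train/key1.lr'
#                                   'key2.lr': [0.02]}  # 'train/key2.lr'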

View File

@ -37,6 +37,12 @@ def test_is_model_wrapper():
if hasattr(torch.distributed, '_verify_model_across_ranks'):
torch.distributed._verify_model_across_ranks = mock
# `_verify_params_across_processes` is added in torch 1.11.0 (replacing
# `_verify_model_across_ranks`), so we should check whether it is a member
# of torch.distributed before mocking
if hasattr(torch.distributed, '_verify_params_across_processes'):
torch.distributed._verify_params_across_processes = mock
model = Model()
assert not is_model_wrapper(model)

View File

@ -6,8 +6,9 @@ from unittest.mock import MagicMock
import torch
import torch.nn as nn
from mmengine.optim import (OPTIMIZER_CONSTRUCTORS, OPTIMIZERS,
DefaultOptimizerConstructor, build_optimizer)
from mmengine.optim import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
DefaultOptimWrapperConstructor, OptimWrapper,
build_optim_wrapper)
from mmengine.optim.optimizer.builder import TORCH_OPTIMIZERS
from mmengine.registry import build_from_cfg
from mmengine.utils import mmcv_full_available
@ -201,30 +202,52 @@ class TestBuilder(TestCase):
def test_build_optimizer(self):
# test build function without ``constructor`` and ``paramwise_cfg``
optimizer_cfg = dict(
optim_wrapper_cfg = dict(
type='OptimWrapper',
optimizer=dict(
type='SGD',
lr=self.base_lr,
weight_decay=self.base_wd,
momentum=self.momentum)
optimizer = build_optimizer(self.model, optimizer_cfg)
self._check_default_optimizer(optimizer, self.model)
momentum=self.momentum))
optim_wrapper = build_optim_wrapper(self.model, optim_wrapper_cfg)
self._check_default_optimizer(optim_wrapper.optimizer, self.model)
# test build optimizer without type in optim_wrapper_cfg
optim_wrapper_cfg = dict(
optimizer=dict(
type='SGD',
lr=self.base_lr,
weight_decay=self.base_wd,
momentum=self.momentum))
optim_wrapper = build_optim_wrapper(self.model, optim_wrapper_cfg)
self.assertIsInstance(optim_wrapper, OptimWrapper)
self._check_default_optimizer(optim_wrapper.optimizer, self.model)
# test build function with invalid ``constructor``
with self.assertRaises(KeyError):
optimizer_cfg['constructor'] = 'INVALID_CONSTRUCTOR'
build_optimizer(self.model, optimizer_cfg)
optim_wrapper_cfg['constructor'] = 'INVALID_CONSTRUCTOR'
build_optim_wrapper(self.model, optim_wrapper_cfg)
# test build function with invalid ``paramwise_cfg``
with self.assertRaises(KeyError):
optimizer_cfg['paramwise_cfg'] = dict(invalid_mult=1)
build_optimizer(self.model, optimizer_cfg)
optim_wrapper_cfg['paramwise_cfg'] = dict(invalid_mult=1)
build_optim_wrapper(self.model, optim_wrapper_cfg)
optim_wrapper_cfg.pop('optimizer')
optim_wrapper_cfg.pop('constructor')
optim_wrapper_cfg.pop('paramwise_cfg')
self.assertRaisesRegex(
AssertionError, '`optim_wrapper_cfg` must contain',
lambda: build_optim_wrapper(self.model, optim_wrapper_cfg))
def test_build_default_optimizer_constructor(self):
optimizer_cfg = dict(
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(
type='SGD',
lr=self.base_lr,
weight_decay=self.base_wd,
momentum=self.momentum)
momentum=self.momentum))
paramwise_cfg = dict(
bias_lr_mult=2,
bias_decay_mult=0.5,
@ -232,22 +255,26 @@ class TestBuilder(TestCase):
dwconv_decay_mult=0.1,
dcn_offset_lr_mult=0.1)
optim_constructor_cfg = dict(
type='DefaultOptimizerConstructor',
optimizer_cfg=optimizer_cfg,
type='DefaultOptimWrapperConstructor',
optim_wrapper_cfg=optim_wrapper,
paramwise_cfg=paramwise_cfg)
optim_constructor = OPTIMIZER_CONSTRUCTORS.build(optim_constructor_cfg)
optimizer = optim_constructor(self.model)
self._check_sgd_optimizer(optimizer, self.model, **paramwise_cfg)
optim_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build(
optim_constructor_cfg)
optim_wrapper = optim_constructor(self.model)
self._check_sgd_optimizer(optim_wrapper.optimizer, self.model,
**paramwise_cfg)
def test_build_custom_optimizer_constructor(self):
optimizer_cfg = dict(
optim_wrapper_cfg = dict(
type='OptimWrapper',
optimizer=dict(
type='SGD',
lr=self.base_lr,
weight_decay=self.base_wd,
momentum=self.momentum)
momentum=self.momentum))
@OPTIMIZER_CONSTRUCTORS.register_module()
class MyOptimizerConstructor(DefaultOptimizerConstructor):
@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class MyOptimizerConstructor(DefaultOptimWrapperConstructor):
def __call__(self, model):
if hasattr(model, 'module'):
@ -268,9 +295,10 @@ class TestBuilder(TestCase):
paramwise_cfg = dict(conv1_lr_mult=5)
optim_constructor_cfg = dict(
type='MyOptimizerConstructor',
optimizer_cfg=optimizer_cfg,
optim_wrapper_cfg=optim_wrapper_cfg,
paramwise_cfg=paramwise_cfg)
optim_constructor = OPTIMIZER_CONSTRUCTORS.build(optim_constructor_cfg)
optim_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build(
optim_constructor_cfg)
optimizer = optim_constructor(self.model)
param_groups = optimizer.param_groups
@ -291,153 +319,182 @@ class TestBuilder(TestCase):
with self.assertRaises(TypeError):
# optimizer_cfg must be a dict
optimizer_cfg = []
optim_constructor = DefaultOptimizerConstructor(optimizer_cfg)
optim_constructor = DefaultOptimWrapperConstructor(optimizer_cfg)
optim_constructor(self.model)
with self.assertRaises(TypeError):
# paramwise_cfg must be a dict or None
optimizer_cfg = dict(lr=0.0001)
optim_wrapper_cfg = dict(
type='OptimWrapper',
optimizer=dict(lr=0.0001, weight_decay=None))
paramwise_cfg = ['error']
optim_constructor = DefaultOptimizerConstructor(
optimizer_cfg, paramwise_cfg)
optim_constructor = DefaultOptimWrapperConstructor(
optim_wrapper_cfg, paramwise_cfg)
optim_constructor(self.model)
with self.assertRaises(ValueError):
# bias_decay_mult/norm_decay_mult is specified but weight_decay
# is None
optimizer_cfg = dict(lr=0.0001, weight_decay=None)
optim_wrapper_cfg = dict(
type='OptimWrapper',
optimizer=dict(lr=0.0001, weight_decay=None))
paramwise_cfg = dict(bias_decay_mult=1, norm_decay_mult=1)
optim_constructor = DefaultOptimizerConstructor(
optimizer_cfg, paramwise_cfg)
optim_constructor = DefaultOptimWrapperConstructor(
optim_wrapper_cfg, paramwise_cfg)
optim_constructor(self.model)
# basic config with ExampleModel
optimizer_cfg = dict(
type='OptimWrapper',
optimizer=dict(
type='SGD',
lr=self.base_lr,
weight_decay=self.base_wd,
momentum=self.momentum)
optim_constructor = DefaultOptimizerConstructor(optimizer_cfg)
optimizer = optim_constructor(self.model)
self._check_default_optimizer(optimizer, self.model)
momentum=self.momentum))
optim_constructor = DefaultOptimWrapperConstructor(optimizer_cfg)
optim_wrapper = optim_constructor(self.model)
self._check_default_optimizer(optim_wrapper.optimizer, self.model)
def test_default_optimizer_constructor_with_model_wrapper(self):
# basic config with pseudo data parallel
model = PseudoDataParallel()
optimizer_cfg = dict(
optim_wrapper_cfg = dict(
type='OptimWrapper',
optimizer=dict(
type='SGD',
lr=self.base_lr,
weight_decay=self.base_wd,
momentum=self.momentum)
momentum=self.momentum))
paramwise_cfg = None
optim_constructor = DefaultOptimizerConstructor(optimizer_cfg)
optimizer = optim_constructor(model)
self._check_default_optimizer(optimizer, model, prefix='module.')
optim_constructor = DefaultOptimWrapperConstructor(optim_wrapper_cfg)
optim_wrapper = optim_constructor(model)
self._check_default_optimizer(
optim_wrapper.optimizer, model, prefix='module.')
# paramwise_cfg with pseudo data parallel
model = PseudoDataParallel()
optimizer_cfg = dict(
optim_wrapper_cfg = dict(
type='OptimWrapper',
optimizer=dict(
type='SGD',
lr=self.base_lr,
weight_decay=self.base_wd,
momentum=self.momentum)
momentum=self.momentum))
paramwise_cfg = dict(
bias_lr_mult=2,
bias_decay_mult=0.5,
norm_decay_mult=0,
dwconv_decay_mult=0.1,
dcn_offset_lr_mult=0.1)
optim_constructor = DefaultOptimizerConstructor(
optimizer_cfg, paramwise_cfg)
optimizer = optim_constructor(model)
optim_constructor = DefaultOptimWrapperConstructor(
optim_wrapper_cfg, paramwise_cfg)
optim_wrapper = optim_constructor(model)
self._check_sgd_optimizer(
optimizer, model, prefix='module.', **paramwise_cfg)
optim_wrapper.optimizer, model, prefix='module.', **paramwise_cfg)
# basic config with DataParallel
if torch.cuda.is_available():
model = torch.nn.DataParallel(ExampleModel())
optimizer_cfg = dict(
optim_wrapper_cfg = dict(
type='OptimWrapper',
optimizer=dict(
type='SGD',
lr=self.base_lr,
weight_decay=self.base_wd,
momentum=self.momentum)
momentum=self.momentum))
paramwise_cfg = None
optim_constructor = DefaultOptimizerConstructor(optimizer_cfg)
optimizer = optim_constructor(model)
self._check_default_optimizer(optimizer, model, prefix='module.')
optim_constructor = DefaultOptimWrapperConstructor(
optim_wrapper_cfg)
optim_wrapper = optim_constructor(model)
self._check_default_optimizer(
optim_wrapper.optimizer, model, prefix='module.')
# paramwise_cfg with DataParallel
if torch.cuda.is_available():
model = torch.nn.DataParallel(self.model)
optimizer_cfg = dict(
optim_wrapper_cfg = dict(
type='OptimWrapper',
optimizer=dict(
type='SGD',
lr=self.base_lr,
weight_decay=self.base_wd,
momentum=self.momentum)
momentum=self.momentum))
paramwise_cfg = dict(
bias_lr_mult=2,
bias_decay_mult=0.5,
norm_decay_mult=0,
dwconv_decay_mult=0.1,
dcn_offset_lr_mult=0.1)
optim_constructor = DefaultOptimizerConstructor(
optimizer_cfg, paramwise_cfg)
optimizer = optim_constructor(model)
optim_constructor = DefaultOptimWrapperConstructor(
optim_wrapper_cfg, paramwise_cfg)
optim_wrapper = optim_constructor(model)
self._check_sgd_optimizer(
optimizer, model, prefix='module.', **paramwise_cfg)
optim_wrapper.optimizer,
model,
prefix='module.',
**paramwise_cfg)
def test_default_optimizer_constructor_with_empty_paramwise_cfg(self):
# Empty paramwise_cfg with ExampleModel
optimizer_cfg = dict(
optim_wrapper_cfg = dict(
type='OptimWrapper',
optimizer=dict(
type='SGD',
lr=self.base_lr,
weight_decay=self.base_wd,
momentum=self.momentum)
momentum=self.momentum))
paramwise_cfg = dict()
optim_constructor = DefaultOptimizerConstructor(
optimizer_cfg, paramwise_cfg)
optimizer = optim_constructor(self.model)
self._check_default_optimizer(optimizer, self.model)
optim_constructor = DefaultOptimWrapperConstructor(
optim_wrapper_cfg, paramwise_cfg)
optim_wrapper = optim_constructor(self.model)
self._check_default_optimizer(optim_wrapper.optimizer, self.model)
# Empty paramwise_cfg with ExampleModel and no grad
model = ExampleModel()
for param in model.parameters():
param.requires_grad = False
optimizer_cfg = dict(
optim_wrapper_cfg = dict(
type='OptimWrapper',
optimizer=dict(
type='SGD',
lr=self.base_lr,
weight_decay=self.base_wd,
momentum=self.momentum)
momentum=self.momentum))
paramwise_cfg = dict()
optim_constructor = DefaultOptimizerConstructor(optimizer_cfg)
optimizer = optim_constructor(model)
self._check_default_optimizer(optimizer, model)
optim_constructor = DefaultOptimWrapperConstructor(optim_wrapper_cfg)
optim_wrapper = optim_constructor(model)
self._check_default_optimizer(optim_wrapper.optimizer, model)
def test_default_optimizer_constructor_with_paramwise_cfg(self):
# paramwise_cfg with ExampleModel
optimizer_cfg = dict(
optim_wrapper_cfg = dict(
type='OptimWrapper',
optimizer=dict(
type='SGD',
lr=self.base_lr,
weight_decay=self.base_wd,
momentum=self.momentum)
momentum=self.momentum))
paramwise_cfg = dict(
bias_lr_mult=2,
bias_decay_mult=0.5,
norm_decay_mult=0,
dwconv_decay_mult=0.1,
dcn_offset_lr_mult=0.1)
optim_constructor = DefaultOptimizerConstructor(
optimizer_cfg, paramwise_cfg)
optimizer = optim_constructor(self.model)
self._check_sgd_optimizer(optimizer, self.model, **paramwise_cfg)
optim_constructor = DefaultOptimWrapperConstructor(
optim_wrapper_cfg, paramwise_cfg)
optim_wrapper = optim_constructor(self.model)
self._check_sgd_optimizer(optim_wrapper.optimizer, self.model,
**paramwise_cfg)
def test_default_optimizer_constructor_no_grad(self):
# paramwise_cfg with ExampleModel and no grad
optimizer_cfg = dict(
optim_wrapper_cfg = dict(
type='OptimWrapper',
optimizer=dict(
type='SGD',
lr=self.base_lr,
weight_decay=self.base_wd,
momentum=self.momentum)
momentum=self.momentum))
paramwise_cfg = dict(
bias_lr_mult=2,
bias_decay_mult=0.5,
@ -447,11 +504,12 @@ class TestBuilder(TestCase):
for param in self.model.parameters():
param.requires_grad = False
optim_constructor = DefaultOptimizerConstructor(
optimizer_cfg, paramwise_cfg)
optimizer = optim_constructor(self.model)
optim_constructor = DefaultOptimWrapperConstructor(
optim_wrapper_cfg, paramwise_cfg)
optim_wrapper = optim_constructor(self.model)
optimizer = optim_wrapper.optimizer
param_groups = optimizer.param_groups
assert isinstance(optimizer, torch.optim.SGD)
assert isinstance(optim_wrapper.optimizer, torch.optim.SGD)
assert optimizer.defaults['lr'] == self.base_lr
assert optimizer.defaults['momentum'] == self.momentum
assert optimizer.defaults['weight_decay'] == self.base_wd
@ -465,11 +523,13 @@ class TestBuilder(TestCase):
def test_default_optimizer_constructor_bypass_duplicate(self):
# paramwise_cfg with bypass_duplicate option
model = ExampleDuplicateModel()
optimizer_cfg = dict(
optim_wrapper_cfg = dict(
type='OptimWrapper',
optimizer=dict(
type='SGD',
lr=self.base_lr,
weight_decay=self.base_wd,
momentum=self.momentum)
momentum=self.momentum))
paramwise_cfg = dict(
bias_lr_mult=2,
bias_decay_mult=0.5,
@ -479,8 +539,8 @@ class TestBuilder(TestCase):
with self.assertRaisesRegex(
ValueError,
'some parameters appear in more than one parameter group'):
optim_constructor = DefaultOptimizerConstructor(
optimizer_cfg, paramwise_cfg)
optim_constructor = DefaultOptimWrapperConstructor(
optim_wrapper_cfg, paramwise_cfg)
optim_constructor(model)
paramwise_cfg = dict(
@ -490,27 +550,31 @@ class TestBuilder(TestCase):
dwconv_decay_mult=0.1,
dcn_offset_lr_mult=0.1,
bypass_duplicate=True)
optim_constructor = DefaultOptimizerConstructor(
optimizer_cfg, paramwise_cfg)
optim_constructor = DefaultOptimWrapperConstructor(
optim_wrapper_cfg, paramwise_cfg)
self.assertWarnsRegex(
Warning,
'conv3.0 is duplicate. It is skipped since bypass_duplicate=True',
lambda: optim_constructor(model))
optimizer = optim_constructor(model)
optim_wrapper = optim_constructor(model)
model_parameters = list(model.parameters())
num_params = 14 if MMCV_FULL_AVAILABLE else 11
assert len(
optimizer.param_groups) == len(model_parameters) == num_params
self._check_sgd_optimizer(optimizer, model, **paramwise_cfg)
assert len(optim_wrapper.optimizer.param_groups) == len(
model_parameters) == num_params
self._check_sgd_optimizer(optim_wrapper.optimizer, model,
**paramwise_cfg)
def test_default_optimizer_constructor_custom_key(self):
# test DefaultOptimizerConstructor with custom_keys and ExampleModel
optimizer_cfg = dict(
# test DefaultOptimWrapperConstructor with custom_keys and
# ExampleModel
optim_wrapper_cfg = dict(
type='OptimWrapper',
optimizer=dict(
type='SGD',
lr=self.base_lr,
weight_decay=self.base_wd,
momentum=self.momentum)
momentum=self.momentum))
paramwise_cfg = dict(
custom_keys={
'param1': dict(lr_mult=10),
@ -523,23 +587,24 @@ class TestBuilder(TestCase):
with self.assertRaises(TypeError):
# custom_keys should be a dict
paramwise_cfg_ = dict(custom_keys=[0.1, 0.0001])
optim_constructor = DefaultOptimizerConstructor(
optimizer_cfg, paramwise_cfg_)
optim_constructor = DefaultOptimWrapperConstructor(
optim_wrapper_cfg, paramwise_cfg_)
optimizer = optim_constructor(self.model)
with self.assertRaises(ValueError):
# if 'decay_mult' is specified in custom_keys, weight_decay
# should be specified
optimizer_cfg_ = dict(type='SGD', lr=0.01)
optim_wrapper_cfg_ = dict(
type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01))
paramwise_cfg_ = dict(
custom_keys={'.backbone': dict(decay_mult=0.5)})
optim_constructor = DefaultOptimizerConstructor(
optimizer_cfg_, paramwise_cfg_)
optimizer = optim_constructor(self.model)
optim_constructor = DefaultOptimWrapperConstructor(
optim_wrapper_cfg_, paramwise_cfg_)
optim_constructor(self.model)
optim_constructor = DefaultOptimizerConstructor(
optimizer_cfg, paramwise_cfg)
optimizer = optim_constructor(self.model)
optim_constructor = DefaultOptimWrapperConstructor(
optim_wrapper_cfg, paramwise_cfg)
optimizer = optim_constructor(self.model).optimizer
# check optimizer type and default config
assert isinstance(optimizer, torch.optim.SGD)
assert optimizer.defaults['lr'] == self.base_lr
@ -598,14 +663,17 @@ class TestBuilder(TestCase):
assert param_groups[i][setting] == settings[
setting], f'{name} {setting}'
# test DefaultOptimizerConstructor with custom_keys and ExampleModel 2
optimizer_cfg = dict(
type='SGD', lr=self.base_lr, momentum=self.momentum)
# test DefaultOptimWrapperConstructor with custom_keys and
# ExampleModel 2
optim_wrapper_cfg = dict(
type='OptimWrapper',
optimizer=dict(
type='SGD', lr=self.base_lr, momentum=self.momentum))
paramwise_cfg = dict(custom_keys={'param1': dict(lr_mult=10)})
optim_constructor = DefaultOptimizerConstructor(
optimizer_cfg, paramwise_cfg)
optimizer = optim_constructor(self.model)
optim_constructor = DefaultOptimWrapperConstructor(
optim_wrapper_cfg, paramwise_cfg)
optimizer = optim_constructor(self.model).optimizer
# check optimizer type and default config
assert isinstance(optimizer, torch.optim.SGD)
assert optimizer.defaults['lr'] == self.base_lr

View File

@ -0,0 +1,384 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import unittest
from unittest import TestCase
from unittest.mock import MagicMock
import torch
import torch.distributed as torch_dist
import torch.nn as nn
from torch.cuda.amp import GradScaler
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.optim import SGD, Adam, Optimizer
from mmengine import MessageHub, MMLogger
from mmengine.dist import all_gather
from mmengine.optim import AmpOptimWrapper, OptimWrapper
from mmengine.testing import assert_allclose
from mmengine.testing._internal import MultiProcessTestCase
from mmengine.utils import TORCH_VERSION, digit_version
class ToyModel(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 1, 1)
self.conv2 = nn.Conv2d(1, 1, 1)
self.conv3 = nn.Conv2d(1, 1, 1)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
return x
class ToyModel2(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(1, 1, 1)
def forward(self, x):
x = self.conv(x)
return x
class TestOptimWrapper(MultiProcessTestCase):
# Test that `OptimWrapper.accumulate_grad` blocks gradient synchronization
# when using the gradient accumulation strategy in distributed data
# parallel training.
def setUp(self) -> None:
super().setUp()
self._spawn_processes()
def run_test(self, test_name: str, parent_pipe) -> None:
self.model = ToyModel()
self.optimizer = SGD(self.model.parameters(), lr=0.1)
self.logger = MMLogger.get_instance('test_optim_wrapper')
self.message_hub = MessageHub.get_instance('test_optim_wrapper_init')
super().run_test(test_name, parent_pipe)
def test_init(self):
optim_wrapper = OptimWrapper(self.optimizer)
self.assertEqual(optim_wrapper.optimizer, self.optimizer)
self.assertIsNone(optim_wrapper.clip_grad_kwargs)
self.assertEqual(optim_wrapper.accumulative_iters, 1)
self.assertIs(optim_wrapper.logger, self.logger)
self.assertIs(optim_wrapper.message_hub, self.message_hub)
with self.assertRaisesRegex(AssertionError,
'If `clip_grad_kwargs` is not None'):
OptimWrapper(self.optimizer, clip_grad=[])
def test_update_params(self):
# Test update params every iteration.
optim_wrapper = OptimWrapper(self.optimizer, accumulative_iters=1)
self._mock_method(optim_wrapper)
loss = torch.tensor(1)
optim_wrapper.update_params(loss)
optim_wrapper.backward.assert_called_with(torch.tensor(1))
optim_wrapper.step.assert_called_with()
optim_wrapper.zero_grad.assert_called_with()
with optim_wrapper.accumulate_grad(self.model, 2, 100):
optim_wrapper.update_params(torch.tensor(1))
optim_wrapper.backward.assert_called_with(torch.tensor(1))
optim_wrapper.step.assert_called_with()
optim_wrapper.zero_grad.assert_called_with()
# It will raise an error if `accumulative_iters > 1` and
# `accumulate_grad` is not enabled.
optim_wrapper = OptimWrapper(self.optimizer, accumulative_iters=3)
self._mock_method(optim_wrapper)
with self.assertRaisesRegex(AssertionError,
'gradient accumulation must be'):
optim_wrapper.update_params(loss)
# `cur_iter=0`: gradients are accumulated, so `step` is not called yet.
with optim_wrapper.accumulate_grad(
self.model, cur_iter=0, max_iters=100):
loss = torch.tensor(1)
optim_wrapper.update_params(loss)
optim_wrapper.backward.assert_called_with(torch.tensor(1) / 3)
optim_wrapper.step.assert_not_called()
optim_wrapper.zero_grad.assert_not_called()
# `cur_iter=2`: the accumulation window ends, so `step` and `zero_grad` are called.
with optim_wrapper.accumulate_grad(
self.model, cur_iter=2, max_iters=100):
optim_wrapper.update_params(loss)
optim_wrapper.step.assert_called()
optim_wrapper.zero_grad.assert_called()
self._mock_method(optim_wrapper)
# Test end of training.
with optim_wrapper.accumulate_grad(
self.model, cur_iter=99, max_iters=100):
optim_wrapper.update_params(loss)
optim_wrapper.step.assert_called()
optim_wrapper.zero_grad.assert_called()
optim_wrapper.backward.assert_called_with(1)
# If ``accumulative_iters > 1``, calling ``update_params`` outside of an
# ``accumulate_grad`` context will raise an AssertionError.
optim_wrapper = OptimWrapper(self.optimizer, accumulative_iters=1)
optim_wrapper.accumulative_iters = 2
with self.assertRaisesRegex(AssertionError,
'gradient accumulation must be performed'):
optim_wrapper.update_params(loss)
def test_initilize_iter_status(self):
optim_wrapper = OptimWrapper(self.optimizer, accumulative_iters=3)
optim_wrapper._initilize_iter_status(self.model)
self.assertEqual(optim_wrapper.divisible_iters, 0)
self.assertEqual(optim_wrapper.remainder_iters, 0)
# Indivisible cur_iter will output warning.
optim_wrapper = OptimWrapper(self.optimizer, accumulative_iters=3)
optim_wrapper.cur_iter = 0
optim_wrapper.max_iters = 100
with self.assertLogs(self.logger) as cm:
optim_wrapper._initilize_iter_status(self.model)
self.assertEqual(len(cm.output), 1)
self.assertRegex(cm.records[0].msg, 'Resume iter number is not')
# Model with batch norm will output warning.
optim_wrapper = OptimWrapper(self.optimizer, accumulative_iters=3)
optim_wrapper.cur_iter = 0
optim_wrapper.max_iters = 99
model = nn.BatchNorm2d(1)
with self.assertLogs(self.logger) as cm:
optim_wrapper._initilize_iter_status(model)
self.assertEqual(len(cm.output), 1)
self.assertRegex(cm.records[0].msg, 'Gradient accumulative')
def test_get_lr(self):
model = ToyModel()
optim = SGD(model.parameters(), lr=0.1)
optim_wrapper = OptimWrapper(optim)
self.assertEqual(optim_wrapper.get_lr(), dict(lr=[0.1]))
def test_get_momentum(self):
# Get momentum from SGD
model = ToyModel()
optim = SGD(model.parameters(), lr=0., momentum=0.8)
optim_wrapper = OptimWrapper(optim)
self.assertEqual(optim_wrapper.get_momentum(), dict(momentum=[0.8]))
# Get momentum from Adam
optim = Adam(model.parameters(), lr=0., betas=(0.9, 0.9))
optim_wrapper = OptimWrapper(optim)
self.assertEqual(optim_wrapper.get_momentum(), dict(momentum=[0.9]))
def test_backward(self):
loss = MagicMock()
optim_wrapper = OptimWrapper(self.optimizer)
optim_wrapper.backward(loss)
loss.backward.assert_called()
def test_zero_grad(self):
optimizer = MagicMock(spec=Optimizer)
optim_wrapper = OptimWrapper(optimizer)
optim_wrapper.zero_grad()
optimizer.zero_grad.assert_called()
def test_step(self):
optimizer = MagicMock(spec=Optimizer)
optim_wrapper = OptimWrapper(optimizer)
optim_wrapper.step()
optimizer.step.assert_called()
def test_clip_grads(self):
optim_wrapper = OptimWrapper(
self.optimizer, clip_grad=dict(max_norm=35))
loss = self.model(torch.Tensor(1, 1, 1, 1))
loss.backward()
optim_wrapper._clip_grad()
log_scalars = self.message_hub.log_scalars
self.assertIn('train/grad_norm', log_scalars)
def test_state_dict(self):
optim_wrapper = OptimWrapper(self.optimizer)
self.assertEqual(optim_wrapper.state_dict(),
self.optimizer.state_dict())
def test_load_state_dict(self):
optim_wrapper = OptimWrapper(self.optimizer)
model = ToyModel()
optimizer = SGD(model.parameters(), lr=0.1)
optim_wrapper.load_state_dict(optimizer.state_dict())
self.assertEqual(optim_wrapper.state_dict(), optimizer.state_dict())
def test_param_groups(self):
optim_wrapper = OptimWrapper(self.optimizer)
self.assertEqual(optim_wrapper.param_groups,
self.optimizer.param_groups)
def test_accumulate_grad(self):
self._init_dist_env(self.rank, self.world_size)
model = ToyModel2()
ddp_model = DistributedDataParallel(model)
optimizer = SGD(ddp_model.parameters(), lr=0.01)
optim_wrapper = OptimWrapper(optimizer, accumulative_iters=1)
optim_wrapper.zero_grad()
with optim_wrapper.accumulate_grad(ddp_model, 0, 100):
# Automatically sync grads if `accumulative_iters` = 1
inputs = torch.randn(1, 1, 1, 1) * self.rank
ddp_model(inputs).sum().backward()
grad = model.conv.weight.grad
all_grads = all_gather(grad)
assert_allclose(all_grads[0], all_grads[1])
# Grads are not synced while they are still being accumulated, i.e.
# before the last iteration of the accumulation window.
optim_wrapper = OptimWrapper(optimizer, accumulative_iters=3)
with optim_wrapper.accumulate_grad(ddp_model, 0, 100):
ddp_model(inputs).sum().backward()
all_grads = all_gather(model.conv.weight.grad)
with self.assertRaises(AssertionError):
assert_allclose(all_grads[0], all_grads[1])
# sync grads if `cur_iter == 2`
with optim_wrapper.accumulate_grad(ddp_model, 2, 100):
ddp_model(inputs).sum().backward()
all_grads = all_gather(model.conv.weight.grad)
assert_allclose(all_grads[0], all_grads[1])
def test_precision_context(self):
optim_wrapper = OptimWrapper(self.optimizer)
with optim_wrapper.precision_context():
pass
def _init_dist_env(self, rank, world_size):
"""Initialize the distributed environment."""
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29515'
os.environ['RANK'] = str(rank)
torch_dist.init_process_group(
backend='gloo', rank=rank, world_size=world_size)
# TODO: Test the real interface after adding a testing tool function which
# can check whether a function or method is really called.
def _mock_method(self, optim_wrapper):
optim_wrapper.backward = MagicMock()
optim_wrapper.step = MagicMock()
optim_wrapper.zero_grad = MagicMock()
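# A rough sketch (not part of the test suite) of how `OptimWrapper` is meant
# to drive a training loop with gradient accumulation; `dataloader` is a
# placeholder, not a tested fixture.
def _toy_training_loop(model, dataloader, max_iters=100):
    optim_wrapper = OptimWrapper(
        SGD(model.parameters(), lr=0.01), accumulative_iters=3)
    for cur_iter, data in enumerate(dataloader):
        # `accumulate_grad` rescales the loss and defers `step`/`zero_grad`
        # until `accumulative_iters` iterations have been accumulated.
        with optim_wrapper.accumulate_grad(model, cur_iter, max_iters):
            loss = model(data).sum()
            optim_wrapper.update_params(loss)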
class TestAmpOptimWrapper(TestCase):
def setUp(self) -> None:
self.model = ToyModel()
self.optimizer = SGD(self.model.parameters(), lr=0.1)
@unittest.skipIf(
not torch.cuda.is_available()
or digit_version(TORCH_VERSION) < digit_version('1.6.0'),
reason='`torch.cuda.amp` is only available when pytorch-gpu version '
'>= 1.6')
def test_init(self):
# Test with default arguments.
amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer)
self.assertIsInstance(amp_optim_wrapper.loss_scaler, GradScaler)
# Test with dynamic.
amp_optim_wrapper = AmpOptimWrapper(
'dynamic', optimizer=self.optimizer)
self.assertIsNone(amp_optim_wrapper._scale_update_param)
self.assertIsInstance(amp_optim_wrapper.loss_scaler, GradScaler)
# Test with dict loss_scale.
amp_optim_wrapper = AmpOptimWrapper(
dict(init_scale=1, growth_factor=2), optimizer=self.optimizer)
self.assertIsInstance(amp_optim_wrapper.loss_scaler, GradScaler)
self.assertIsNone(amp_optim_wrapper._scale_update_param)
with self.assertRaisesRegex(TypeError,
'loss_scale must be of type float'):
AmpOptimWrapper(optimizer=self.optimizer, loss_scale='unknown')
@unittest.skipIf(
not torch.cuda.is_available()
or digit_version(TORCH_VERSION) < digit_version('1.6.0'),
reason='`torch.cuda.amp` is only available when pytorch-gpu version '
'>= 1.6')
def test_step(self):
optimizer = MagicMock(spec=Optimizer)
amp_optim_wrapper = AmpOptimWrapper(optimizer=optimizer)
amp_optim_wrapper.loss_scaler = MagicMock()
amp_optim_wrapper.step()
amp_optim_wrapper.loss_scaler.step.assert_called_with(
amp_optim_wrapper.optimizer)
amp_optim_wrapper.loss_scaler.update.assert_called_with(
amp_optim_wrapper._scale_update_param)
@unittest.skipIf(
not torch.cuda.is_available()
or digit_version(TORCH_VERSION) < digit_version('1.6.0'),
reason='`torch.cuda.amp` is only available when pytorch-gpu version '
'>= 1.6')
def test_backward(self):
amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer)
loss_scaler = MagicMock()
scale_return = MagicMock()
scale_fn = MagicMock(return_value=scale_return)
loss_scaler.scale = scale_fn
amp_optim_wrapper.loss_scaler = loss_scaler
amp_optim_wrapper.backward(1)
loss_scaler.scale.assert_called_with(1)
scale_return.backward.assert_called_with()
@unittest.skipIf(
not torch.cuda.is_available()
or digit_version(TORCH_VERSION) < digit_version('1.6.0'),
reason='`torch.cuda.amp` is only available when pytorch-gpu version '
'>= 1.6')
def test_state_dict(self):
self.model = self.model.cuda()
amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer)
loss = self.model(torch.Tensor(1, 1, 1, 1).cuda())
amp_optim_wrapper.update_params(loss)
state_dict = amp_optim_wrapper.state_dict()
scalar_state_dict = state_dict.pop('loss_scaler')
optim_state_dict = state_dict
self.assertDictEqual(optim_state_dict,
amp_optim_wrapper.optimizer.state_dict())
self.assertDictEqual(scalar_state_dict,
amp_optim_wrapper.loss_scaler.state_dict())
@unittest.skipIf(
not torch.cuda.is_available()
or digit_version(TORCH_VERSION) < digit_version('1.6.0'),
reason='`torch.cuda.amp` is only available when pytorch-gpu version '
'>= 1.6')
def test_load_state_dict(self):
amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer)
self.model = self.model.cuda()
# Test load from optimizer
optimizer = SGD(self.model.parameters(), lr=0.1)
amp_optim_wrapper.load_state_dict(optimizer.state_dict())
self.assertDictEqual(optimizer.state_dict(),
amp_optim_wrapper.optimizer.state_dict())
# Test load from optim_wrapper
amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer)
amp_optim_wrapper_ = AmpOptimWrapper(
optimizer=SGD(self.model.parameters(), lr=0.1))
amp_optim_wrapper_.load_state_dict(amp_optim_wrapper.state_dict())
self.assertDictEqual(amp_optim_wrapper.optimizer.state_dict(),
amp_optim_wrapper_.optimizer.state_dict())
self.assertDictEqual(amp_optim_wrapper.loss_scaler.state_dict(),
amp_optim_wrapper_.loss_scaler.state_dict())
@unittest.skipIf(
not torch.cuda.is_available()
or digit_version(TORCH_VERSION) < digit_version('1.6.0'),
reason='`torch.cuda.amp` is only available when pytorch-gpu version '
'>= 1.6')
def test_precision_context(self):
amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer)
with amp_optim_wrapper.precision_context():
x = torch.randn(1, 1, 1, 1).cuda()
y = nn.Conv2d(1, 1, 1).cuda()(x)
self.assertEqual(y.dtype, torch.float16)
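# A rough sketch (not part of the test suite) of the intended AMP training
# step; `model` and `data` are placeholders and a CUDA device is assumed.
def _toy_amp_step(model, data):
    amp_optim_wrapper = AmpOptimWrapper(
        optimizer=SGD(model.parameters(), lr=0.1), loss_scale='dynamic')
    # The forward pass runs under autocast via `precision_context`; backward
    # and step go through the internal GradScaler.
    with amp_optim_wrapper.precision_context():
        loss = model(data).sum()
    amp_optim_wrapper.update_params(loss)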

View File

@ -0,0 +1,142 @@
# Copyright (c) OpenMMLab. All rights reserved.
from contextlib import contextmanager
from unittest import TestCase
from unittest.mock import patch
import torch.nn as nn
from torch.optim import SGD
from mmengine.optim import AmpOptimWrapper, OptimWrapper, OptimWrapperDict
class TestOptimWrapperDict(TestCase):
def setUp(self) -> None:
model1 = nn.Linear(1, 1)
model2 = nn.Linear(1, 1)
self.optim1 = SGD(model1.parameters(), lr=0.1, momentum=0.8)
self.optim2 = SGD(model2.parameters(), lr=0.2, momentum=0.9)
self.optim_wrapper1 = OptimWrapper(self.optim1)
self.optim_wrapper2 = OptimWrapper(self.optim2)
self.optimizers_wrappers = dict(
optim1=self.optim_wrapper1, optim2=self.optim_wrapper2)
@patch('torch.cuda.is_available', lambda: True)
def test_init(self):
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
self.assertEqual(optim_wrapper_dict.optim_wrappers,
self.optimizers_wrappers)
# Mixing different types of OptimWrapper will raise an error
with self.assertRaisesRegex(
AssertionError, 'All optimizer wrappers should have the same'):
optim_wrapper2 = AmpOptimWrapper(optimizer=self.optim2)
OptimWrapperDict(optim1=self.optim_wrapper1, optim2=optim_wrapper2)
with self.assertWarnsRegex(UserWarning, 'The `accumulative_iters` of'):
optim_wrapper2 = OptimWrapper(
optimizer=self.optim2, accumulative_iters=2)
OptimWrapperDict(optim1=self.optim_wrapper1, optim2=optim_wrapper2)
def test_accumulate_grad(self):
@contextmanager
def context_a(a, b, *args, **kwargs):
a[0] = 100
yield
a[0] = 1
@contextmanager
def context_b(a, b, *args, **kwargs):
b[0] = 200
yield
b[0] = 2
a = [0]
b = [0]
# Test entering the contexts of both `optim_wrapper1` and `optim_wrapper2`.
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
self.optim_wrapper1.accumulate_grad = context_a
self.optim_wrapper2.accumulate_grad = context_b
with optim_wrapper_dict.accumulate_grad(a, b, 0):
self.assertEqual(a[0], 100)
self.assertEqual(b[0], 200)
self.assertEqual(a[0], 1)
self.assertEqual(b[0], 2)
def test_state_dict(self):
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
state_dict = optim_wrapper_dict.state_dict()
self.assertEqual(state_dict['optim1'],
self.optim_wrapper1.state_dict())
self.assertEqual(state_dict['optim2'],
self.optim_wrapper2.state_dict())
def test_load_state_dict(self):
# Test OptimWrapperDict can load from saved state dict.
model1 = nn.Linear(1, 1)
model2 = nn.Linear(1, 1)
optim1 = SGD(model1.parameters(), lr=0.1)
optim2 = SGD(model2.parameters(), lr=0.1)
optim_wrapper_load1 = OptimWrapper(optim1)
optim_wrapper_load2 = OptimWrapper(optim2)
optim_wrapper_dict_save = OptimWrapperDict(**self.optimizers_wrappers)
optim_wrapper_dict_load = OptimWrapperDict(
optim1=optim_wrapper_load1, optim2=optim_wrapper_load2)
state_dict = optim_wrapper_dict_save.state_dict()
optim_wrapper_dict_load.load_state_dict(state_dict)
self.assertDictEqual(optim_wrapper_dict_load.state_dict(),
optim_wrapper_dict_save.state_dict())
def test_items(self):
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
self.assertListEqual(
list(optim_wrapper_dict.items()),
list(self.optimizers_wrappers.items()))
def test_values(self):
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
self.assertListEqual(
list(optim_wrapper_dict.values()),
list(self.optimizers_wrappers.values()))
def test_keys(self):
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
self.assertListEqual(
list(optim_wrapper_dict.keys()),
list(self.optimizers_wrappers.keys()))
def test_get_lr(self):
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
lr = optim_wrapper_dict.get_lr()
self.assertEqual(lr['optim1.lr'], [0.1])
self.assertEqual(lr['optim2.lr'], [0.2])
def test_get_momentum(self):
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
momentum = optim_wrapper_dict.get_momentum()
self.assertEqual(momentum['optim1.momentum'], [0.8])
self.assertEqual(momentum['optim2.momentum'], [0.9])
def test_getitem(self):
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
self.assertIs(self.optimizers_wrappers['optim1'],
optim_wrapper_dict['optim1'])
self.assertIs(self.optimizers_wrappers['optim2'],
optim_wrapper_dict['optim2'])
def test_len(self):
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
self.assertEqual(len(optim_wrapper_dict), 2)
def test_contain(self):
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
self.assertIn('optim1', optim_wrapper_dict)
def test_repr(self):
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
desc = repr(optim_wrapper_dict)
self.assertRegex(desc, 'name: optim1')

View File

@ -16,13 +16,15 @@ from mmengine.config import Config
from mmengine.data import DefaultSampler
from mmengine.evaluator import BaseMetric, Evaluator
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, Hook,
IterTimerHook, LoggerHook, OptimizerHook,
ParamSchedulerHook)
IterTimerHook, LoggerHook, ParamSchedulerHook,
RuntimeInfoHook)
from mmengine.logging import LogProcessor, MessageHub, MMLogger
from mmengine.optim import DefaultOptimizerConstructor, MultiStepLR, StepLR
from mmengine.optim import (DefaultOptimWrapperConstructor, MultiStepLR,
OptimWrapper, OptimWrapperDict, StepLR)
from mmengine.registry import (DATASETS, HOOKS, LOG_PROCESSORS, LOOPS, METRICS,
MODEL_WRAPPERS, MODELS, OPTIMIZER_CONSTRUCTORS,
PARAM_SCHEDULERS, Registry)
MODEL_WRAPPERS, MODELS,
OPTIM_WRAPPER_CONSTRUCTORS, PARAM_SCHEDULERS,
Registry)
from mmengine.runner import (BaseLoop, EpochBasedTrainLoop, IterBasedTrainLoop,
Runner, TestLoop, ValLoop)
from mmengine.runner.priority import Priority, get_priority
@ -73,24 +75,24 @@ class CustomModelWrapper(nn.Module):
self.model = model
@OPTIMIZER_CONSTRUCTORS.register_module()
@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class ToyMultipleOptimizerConstructor:
def __init__(self, optimizer_cfg, paramwise_cfg=None):
if not isinstance(optimizer_cfg, dict):
def __init__(self, optim_wrapper_cfg, paramwise_cfg=None):
if not isinstance(optim_wrapper_cfg, dict):
raise TypeError('optim_wrapper_cfg should be a dict, '
f'but got {type(optimizer_cfg)}')
f'but got {type(optim_wrapper_cfg)}')
assert paramwise_cfg is None, (
'paramwise_cfg should be set in each optimizer separately')
self.optimizer_cfg = optimizer_cfg
self.optim_wrapper_cfg = optim_wrapper_cfg
self.constructors = {}
for key, cfg in self.optimizer_cfg.items():
for key, cfg in self.optim_wrapper_cfg.items():
_cfg = cfg.copy()
paramwise_cfg_ = _cfg.pop('paramwise_cfg', None)
self.constructors[key] = DefaultOptimizerConstructor(
self.constructors[key] = DefaultOptimWrapperConstructor(
_cfg, paramwise_cfg_)
def __call__(self, model: nn.Module) -> torch.optim.Optimizer:
def __call__(self, model: nn.Module) -> OptimWrapperDict:
optimizers = {}
while hasattr(model, 'module'):
model = model.module
@ -98,7 +100,7 @@ class ToyMultipleOptimizerConstructor:
for key, constructor in self.constructors.items():
module = getattr(model, key)
optimizers[key] = constructor(module)
return optimizers
return OptimWrapperDict(**optimizers)
@DATASETS.register_module()
@ -160,13 +162,6 @@ class ToyHook2(Hook):
pass
@HOOKS.register_module()
class ToyHook3(Hook):
def before_train_iter(self, runner, data_batch):
pass
@LOOPS.register_module()
class CustomTrainLoop(BaseLoop):
@ -246,7 +241,8 @@ class TestRunner(TestCase):
sampler=dict(type='DefaultSampler', shuffle=False),
batch_size=3,
num_workers=0),
optimizer=dict(type='SGD', lr=0.01),
optim_wrapper=dict(
type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)),
param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]),
val_evaluator=dict(type='ToyMetric1'),
test_evaluator=dict(type='ToyMetric1'),
@ -299,7 +295,7 @@ class TestRunner(TestCase):
# also be None
cfg.experiment_name = 'test_init2'
cfg.pop('train_dataloader')
cfg.pop('optimizer')
cfg.pop('optim_wrapper')
cfg.pop('param_scheduler')
runner = Runner(**cfg)
self.assertIsInstance(runner, Runner)
@ -324,7 +320,7 @@ class TestRunner(TestCase):
cfg.experiment_name = 'test_init5'
cfg.pop('train_cfg')
cfg.pop('train_dataloader')
cfg.pop('optimizer')
cfg.pop('optim_wrapper')
with self.assertRaisesRegex(ValueError, 'should be None'):
runner = Runner(**cfg)
@ -398,7 +394,7 @@ class TestRunner(TestCase):
self.assertIsInstance(runner._train_dataloader, dict)
self.assertIsInstance(runner._val_dataloader, dict)
self.assertIsInstance(runner._test_dataloader, dict)
self.assertIsInstance(runner.optimizer, dict)
self.assertIsInstance(runner.optim_wrapper, dict)
self.assertIsInstance(runner.param_schedulers[0], dict)
# After calling runner.train(),
@ -408,7 +404,7 @@ class TestRunner(TestCase):
self.assertIsInstance(runner._train_loop, BaseLoop)
self.assertIsInstance(runner.train_dataloader, DataLoader)
self.assertIsInstance(runner.optimizer, SGD)
self.assertIsInstance(runner.optim_wrapper, OptimWrapper)
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
self.assertIsInstance(runner._val_loop, BaseLoop)
self.assertIsInstance(runner._val_loop.dataloader, DataLoader)
@ -423,10 +419,10 @@ class TestRunner(TestCase):
# 4. initialize runner with objects rather than config
model = ToyModel()
optimizer = SGD(
optim_wrapper = OptimWrapper(SGD(
model.parameters(),
lr=0.01,
)
))
toy_hook = ToyHook()
toy_hook2 = ToyHook2()
@ -438,8 +434,8 @@ class TestRunner(TestCase):
work_dir=self.temp_dir,
train_cfg=dict(by_epoch=True, max_epochs=3),
train_dataloader=train_dataloader,
optimizer=optimizer,
param_scheduler=MultiStepLR(optimizer, milestones=[1, 2]),
optim_wrapper=optim_wrapper,
param_scheduler=MultiStepLR(optim_wrapper, milestones=[1, 2]),
val_cfg=dict(interval=1, begin=1),
val_dataloader=val_dataloader,
val_evaluator=ToyMetric1(),
@ -618,79 +614,92 @@ class TestRunner(TestCase):
runner = Runner.from_cfg(cfg)
self.assertIsInstance(runner.model, CustomModelWrapper)
def test_build_optimizer(self):
def test_build_optim_wrapper(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_optimizer'
cfg.experiment_name = 'test_build_optim_wrapper'
runner = Runner.from_cfg(cfg)
# input should be an OptimWrapper object or dict
with self.assertRaisesRegex(TypeError, 'optimizer should be'):
runner.build_optimizer('invalid-type')
with self.assertRaisesRegex(TypeError, 'optimizer wrapper should be'):
runner.build_optim_wrapper('invalid-type')
# 1. test one optimizer
# 1.1 input is an Optimizer object
_optimizer = SGD(runner.model.parameters(), lr=0.01)
optimizer = runner.build_optimizer(_optimizer)
self.assertEqual(id(_optimizer), id(optimizer))
optimizer = SGD(runner.model.parameters(), lr=0.01)
optim_wrapper = OptimWrapper(optimizer)
optim_wrapper = runner.build_optim_wrapper(optim_wrapper)
self.assertEqual(id(optimizer), id(optim_wrapper.optimizer))
# 1.2 input is a dict
optimizer = runner.build_optimizer(dict(type='SGD', lr=0.01))
self.assertIsInstance(optimizer, SGD)
optim_wrapper = runner.build_optim_wrapper(
dict(type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)))
self.assertIsInstance(optim_wrapper, OptimWrapper)
# 2. test multiple optimizers
# 2.1 input is a dict which contains multiple optimizer objects
optimizer1 = SGD(runner.model.linear1.parameters(), lr=0.01)
optim_wrapper1 = OptimWrapper(optimizer1)
optimizer2 = Adam(runner.model.linear2.parameters(), lr=0.02)
optim_cfg = dict(key1=optimizer1, key2=optimizer2)
optimizer = runner.build_optimizer(optim_cfg)
self.assertIsInstance(optimizer, dict)
self.assertIsInstance(optimizer['key1'], SGD)
self.assertIsInstance(optimizer['key2'], Adam)
optim_wrapper2 = OptimWrapper(optimizer2)
optim_wrapper_cfg = dict(key1=optim_wrapper1, key2=optim_wrapper2)
optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
self.assertIsInstance(optim_wrapper, OptimWrapperDict)
self.assertIsInstance(optim_wrapper['key1'].optimizer, SGD)
self.assertIsInstance(optim_wrapper['key2'].optimizer, Adam)
# 2.2 each item mush be an optimizer object when "type" and
# "constructor" are not in optimizer
optimizer1 = SGD(runner.model.linear1.parameters(), lr=0.01)
optimizer2 = dict(type='Adam', lr=0.02)
optim_cfg = dict(key1=optimizer1, key2=optimizer2)
optim_wrapper1 = OptimWrapper(optimizer1)
optim_wrapper2 = dict(
type='OptimWrapper', optimizer=dict(type='Adam', lr=0.01))
optim_cfg = dict(key1=optim_wrapper1, key2=optim_wrapper2)
with self.assertRaisesRegex(ValueError,
'each item must be an optimizer object'):
runner.build_optimizer(optim_cfg)
runner.build_optim_wrapper(optim_cfg)
# 2.3 input is a dict which contains multiple configs
optim_cfg = dict(
linear1=dict(type='SGD', lr=0.01),
linear2=dict(type='Adam', lr=0.02),
optim_wrapper_cfg = dict(
linear1=dict(
type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)),
linear2=dict(
type='OptimWrapper', optimizer=dict(type='Adam', lr=0.02)),
constructor='ToyMultipleOptimizerConstructor')
optimizer = runner.build_optimizer(optim_cfg)
self.assertIsInstance(optimizer, dict)
self.assertIsInstance(optimizer['linear1'], SGD)
self.assertIsInstance(optimizer['linear2'], Adam)
optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
self.assertIsInstance(optim_wrapper, OptimWrapperDict)
self.assertIsInstance(optim_wrapper['linear1'].optimizer, SGD)
self.assertIsInstance(optim_wrapper['linear2'].optimizer, Adam)
def test_build_param_scheduler(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_param_scheduler'
runner = Runner.from_cfg(cfg)
# `build_optimizer` should be called before `build_param_scheduler`
# `build_optim_wrapper` should be called before
# `build_param_scheduler`
cfg = dict(type='MultiStepLR', milestones=[1, 2])
runner.optimizer = dict(
key1=dict(type='SGD', lr=0.01),
key2=dict(type='Adam', lr=0.02),
runner.optim_wrapper = dict(
key1=dict(type=OptimWrapper, optimizer=dict(type='SGD', lr=0.01)),
key2=dict(type=OptimWrapper, optimizer=dict(type='Adam', lr=0.02)),
)
with self.assertRaisesRegex(RuntimeError, 'should be called before'):
with self.assertRaisesRegex(AssertionError, 'should be called before'):
runner.build_param_scheduler(cfg)
runner.optimizer = runner.build_optimizer(dict(type='SGD', lr=0.01))
runner.optim_wrapper = runner.build_optim_wrapper(
dict(type=OptimWrapper, optimizer=dict(type='SGD', lr=0.01)))
param_schedulers = runner.build_param_scheduler(cfg)
self.assertIsInstance(param_schedulers, list)
self.assertEqual(len(param_schedulers), 1)
self.assertIsInstance(param_schedulers[0], MultiStepLR)
# 1. test one optimizer and one parameter scheduler
# 1.1 input is a ParamScheduler object
param_scheduler = MultiStepLR(runner.optimizer, milestones=[1, 2])
param_scheduler = MultiStepLR(runner.optim_wrapper, milestones=[1, 2])
param_schedulers = runner.build_param_scheduler(param_scheduler)
self.assertEqual(len(param_schedulers), 1)
self.assertEqual(id(param_schedulers[0]), id(param_scheduler))
# 1.2 input is a dict
cfg = dict(type='MultiStepLR', milestones=[1, 2])
param_schedulers = runner.build_param_scheduler(param_scheduler)
self.assertEqual(len(param_schedulers), 1)
self.assertIsInstance(param_schedulers[0], MultiStepLR)
@ -715,9 +724,11 @@ class TestRunner(TestCase):
# 3. test multiple optimizers and list of parameter schedulers
optimizer1 = SGD(runner.model.linear1.parameters(), lr=0.01)
optim_wrapper1 = OptimWrapper(optimizer1)
optimizer2 = Adam(runner.model.linear2.parameters(), lr=0.02)
optim_cfg = dict(key1=optimizer1, key2=optimizer2)
runner.optimizer = runner.build_optimizer(optim_cfg)
optim_wrapper2 = OptimWrapper(optimizer2)
optim_wrapper_cfg = dict(key1=optim_wrapper1, key2=optim_wrapper2)
runner.optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
cfg = [
dict(type='MultiStepLR', milestones=[1, 2]),
dict(type='StepLR', step_size=1)
@ -743,7 +754,8 @@ class TestRunner(TestCase):
self.assertEqual(len(param_schedulers['key2']), 2)
# 5. test converting epoch-based scheduler to iter-based
runner.optimizer = runner.build_optimizer(dict(type='SGD', lr=0.01))
runner.optim_wrapper = runner.build_optim_wrapper(
dict(type=OptimWrapper, optimizer=dict(type='SGD', lr=0.01)))
# 5.1 train loop should be built before converting scheduler
cfg = dict(
@@ -752,7 +764,7 @@ class TestRunner(TestCase):
AssertionError,
'Scheduler can only be converted to iter-based when '
'train loop is built.'):
param_schedulers = runner.build_param_scheduler(cfg)
runner.build_param_scheduler(cfg)
# 5.2 convert epoch-based to iter-based scheduler
cfg = dict(
@@ -947,7 +959,7 @@ class TestRunner(TestCase):
cfg.experiment_name = 'test_train1'
cfg.pop('train_dataloader')
cfg.pop('train_cfg')
cfg.pop('optimizer')
cfg.pop('optim_wrapper')
cfg.pop('param_scheduler')
runner = Runner.from_cfg(cfg)
with self.assertRaisesRegex(RuntimeError, 'should not be None'):
@@ -1063,7 +1075,7 @@ class TestRunner(TestCase):
cfg.experiment_name = 'test_individually_val'
cfg.pop('train_dataloader')
cfg.pop('train_cfg')
cfg.pop('optimizer')
cfg.pop('optim_wrapper')
cfg.pop('param_scheduler')
cfg.pop('test_dataloader')
cfg.pop('test_cfg')
@@ -1091,7 +1103,7 @@ class TestRunner(TestCase):
cfg.experiment_name = 'test_individually_test'
cfg.pop('train_dataloader')
cfg.pop('train_cfg')
cfg.pop('optimizer')
cfg.pop('optim_wrapper')
cfg.pop('param_scheduler')
cfg.pop('val_dataloader')
cfg.pop('val_cfg')
@@ -1132,15 +1144,15 @@ class TestRunner(TestCase):
get_priority('BELOW_NORMAL'))
# 1.3 `hook` is a hook object
optimizer_hook = OptimizerHook()
runner.register_hook(optimizer_hook)
runtime_info_hook = RuntimeInfoHook()
runner.register_hook(runtime_info_hook)
self.assertEqual(len(runner._hooks), 2)
# The priority of `OptimizerHook` is `HIGH` which is greater than
# The priority of `runtime_info_hook` is `VERY_HIGH` which is greater than
# `IterTimerHook`, so the first item of `_hooks` should be
# `OptimizerHook`
self.assertTrue(isinstance(runner._hooks[0], OptimizerHook))
# `runtime_info_hook`
self.assertTrue(isinstance(runner._hooks[0], RuntimeInfoHook))
self.assertEqual(
get_priority(runner._hooks[0].priority), get_priority('HIGH'))
get_priority(runner._hooks[0].priority), get_priority('VERY_HIGH'))
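# The ordering asserted above presumably holds because `RuntimeInfoHook`
# registers with `VERY_HIGH` priority while the default `IterTimerHook`
# uses a lower one (assumed `NORMAL`), so it sorts first in `_hooks`.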
# 2. test `priority` parameter
# `priority` argument is not None and it will be set as priority of
@@ -1198,29 +1210,6 @@ class TestRunner(TestCase):
self.assertEqual(len(runner._hooks), 8)
self.assertTrue(isinstance(runner._hooks[7], ToyHook))
def test_call_hook(self):
# test unexpected argument in `call_hook`
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_call_hook1'
runner = Runner.from_cfg(cfg)
runner._hooks = []
custom_hooks = [dict(type='ToyHook3')]
runner.register_custom_hooks(custom_hooks)
with self.assertRaisesRegex(
TypeError,
r"got an unexpected keyword argument 'batch_idx' in "
r'<test_runner.ToyHook3 object at \w+>'):
runner.call_hook('before_train_iter', batch_idx=0, data_batch=None)
# test call hook with expected arguments
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_call_hook2'
runner = Runner.from_cfg(cfg)
runner._hooks = []
custom_hooks = [dict(type='ToyHook3')]
runner.register_custom_hooks(custom_hooks)
runner.call_hook('before_train_iter', data_batch=None)
def test_register_hooks(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_register_hooks'
@@ -1229,7 +1218,7 @@ class TestRunner(TestCase):
runner._hooks = []
custom_hooks = [dict(type='ToyHook')]
runner.register_hooks(custom_hooks=custom_hooks)
# 7 default hooks + custom hook (ToyHook)
# six default hooks + custom hook (ToyHook)
self.assertEqual(len(runner._hooks), 8)
self.assertTrue(isinstance(runner._hooks[7], ToyHook))
@@ -1336,7 +1325,7 @@ class TestRunner(TestCase):
# 1.2 test `load_checkpoint`
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_checkpoint2'
cfg.optimizer = dict(type='SGD', lr=0.2)
cfg.optim_wrapper = dict(type='SGD', lr=0.2)
cfg.param_scheduler = dict(type='MultiStepLR', milestones=[1, 2, 3])
runner = Runner.from_cfg(cfg)
runner.load_checkpoint(path)
@@ -1345,22 +1334,24 @@ class TestRunner(TestCase):
self.assertTrue(runner._has_loaded)
# load checkpoint will not initialize optimizer and param_schedulers
# objects
self.assertIsInstance(runner.optimizer, dict)
self.assertIsInstance(runner.optim_wrapper, dict)
self.assertIsInstance(runner.param_schedulers, list)
self.assertIsInstance(runner.param_schedulers[0], dict)
# 1.3 test `resume`
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_checkpoint3'
cfg.optimizer = dict(type='SGD', lr=0.2)
cfg.optim_wrapper = dict(
type='OptimWrapper', optimizer=dict(type='SGD', lr=0.2))
cfg.param_scheduler = dict(type='MultiStepLR', milestones=[1, 2, 3])
runner = Runner.from_cfg(cfg)
runner.resume(path)
self.assertEqual(runner.epoch, 3)
self.assertEqual(runner.iter, 12)
self.assertTrue(runner._has_loaded)
self.assertIsInstance(runner.optimizer, SGD)
self.assertEqual(runner.optimizer.param_groups[0]['lr'], 0.0001)
self.assertIsInstance(runner.optim_wrapper.optimizer, SGD)
self.assertEqual(runner.optim_wrapper.param_groups[0]['lr'], 0.0001)
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
self.assertEqual(runner.param_schedulers[0].milestones, {1: 1, 2: 1})
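# The resumed lr of 0.0001 presumably comes from the checkpoint rather than
# the new config (lr=0.2): assuming the base config trains with lr=0.01 and
# MultiStepLR(milestones=[1, 2], gamma=0.1), three epochs give
# 0.01 * 0.1 * 0.1 = 0.0001, and the restored milestones {1: 1, 2: 1}
# likewise reflect the original run, not the new [1, 2, 3] setting.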
@@ -1373,7 +1364,7 @@ class TestRunner(TestCase):
self.assertEqual(runner.epoch, 3)
self.assertEqual(runner.iter, 12)
self.assertTrue(runner._has_loaded)
self.assertIsInstance(runner.optimizer, SGD)
self.assertIsInstance(runner.optim_wrapper.optimizer, SGD)
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
# 1.5 test resume from a specified checkpoint
@@ -1386,46 +1377,48 @@ class TestRunner(TestCase):
self.assertEqual(runner.epoch, 1)
self.assertEqual(runner.iter, 4)
self.assertTrue(runner._has_loaded)
self.assertIsInstance(runner.optimizer, SGD)
self.assertIsInstance(runner.optim_wrapper.optimizer, SGD)
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
# 1.6 multiple optimizers
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_checkpoint6'
cfg.optimizer = dict(
linear1=dict(type='SGD', lr=0.01),
linear2=dict(type='Adam', lr=0.02),
constructor='ToyMultipleOptimizerConstructor',
)
cfg.optim_wrapper = dict(
linear1=dict(
type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)),
linear2=dict(
type='OptimWrapper', optimizer=dict(type='Adam', lr=0.02)),
constructor='ToyMultipleOptimizerConstructor')
# disable OptimizerHook because it only works with one optimizer
cfg.default_hooks = dict(optimizer=None)
runner = Runner.from_cfg(cfg)
runner.train()
path = osp.join(self.temp_dir, 'epoch_3.pth')
self.assertTrue(osp.exists(path))
self.assertEqual(runner.optimizer['linear1'].param_groups[0]['lr'],
self.assertEqual(runner.optim_wrapper['linear1'].param_groups[0]['lr'],
0.0001)
self.assertIsInstance(runner.optimizer['linear2'], Adam)
self.assertEqual(runner.optimizer['linear2'].param_groups[0]['lr'],
self.assertIsInstance(runner.optim_wrapper['linear2'].optimizer, Adam)
self.assertEqual(runner.optim_wrapper['linear2'].param_groups[0]['lr'],
0.0002)
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_checkpoint7'
cfg.optimizer = dict(
linear1=dict(type='SGD', lr=0.2),
linear2=dict(type='Adam', lr=0.03),
constructor='ToyMultipleOptimizerConstructor',
)
cfg.optim_wrapper = dict(
linear1=dict(
type='OptimWrapper', optimizer=dict(type='SGD', lr=0.2)),
linear2=dict(
type='OptimWrapper', optimizer=dict(type='Adam', lr=0.03)),
constructor='ToyMultipleOptimizerConstructor')
cfg.param_scheduler = dict(type='MultiStepLR', milestones=[1, 2, 3])
cfg.default_hooks = dict(optimizer=None)
runner = Runner.from_cfg(cfg)
runner.resume(path)
self.assertIsInstance(runner.optimizer, dict)
self.assertIsInstance(runner.optimizer['linear1'], SGD)
self.assertEqual(runner.optimizer['linear1'].param_groups[0]['lr'],
self.assertIsInstance(runner.optim_wrapper, OptimWrapperDict)
self.assertIsInstance(runner.optim_wrapper['linear1'].optimizer, SGD)
self.assertEqual(runner.optim_wrapper['linear1'].param_groups[0]['lr'],
0.0001)
self.assertIsInstance(runner.optimizer['linear2'], Adam)
self.assertEqual(runner.optimizer['linear2'].param_groups[0]['lr'],
self.assertIsInstance(runner.optim_wrapper['linear2'].optimizer, Adam)
self.assertEqual(runner.optim_wrapper['linear2'].param_groups[0]['lr'],
0.0002)
self.assertIsInstance(runner.param_schedulers, dict)
self.assertEqual(len(runner.param_schedulers['linear1']), 1)
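# When resuming an `OptimWrapperDict`, each named wrapper presumably reloads
# its own optimizer state from the checkpoint, so the restored lrs
# (0.0001 / 0.0002) come from the earlier run rather than the new config
# (0.2 / 0.03), and `param_schedulers` becomes a dict keyed by the same
# optimizer names.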
@@ -1479,7 +1472,7 @@ class TestRunner(TestCase):
self.assertEqual(runner.epoch, 0)
self.assertEqual(runner.iter, 12)
self.assertTrue(runner._has_loaded)
self.assertIsInstance(runner.optimizer, SGD)
self.assertIsInstance(runner.optim_wrapper.optimizer, SGD)
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
# 2.4 test auto resume
@@ -1491,7 +1484,7 @@ class TestRunner(TestCase):
self.assertEqual(runner.epoch, 0)
self.assertEqual(runner.iter, 12)
self.assertTrue(runner._has_loaded)
self.assertIsInstance(runner.optimizer, SGD)
self.assertIsInstance(runner.optim_wrapper.optimizer, SGD)
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
# 2.5 test resume from a specified checkpoint
@@ -1504,5 +1497,5 @@ class TestRunner(TestCase):
self.assertEqual(runner.epoch, 0)
self.assertEqual(runner.iter, 3)
self.assertTrue(runner._has_loaded)
self.assertIsInstance(runner.optimizer, SGD)
self.assertIsInstance(runner.optim_wrapper.optimizer, SGD)
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)