Mirror of https://github.com/open-mmlab/mmengine.git (synced 2025-06-03 21:54:44 +08:00)
[Feature] Add optimizer wrapper (#265)
* Support multiple optimizers
* minor refinement
* improve unit tests
* minor fix
* Update unit tests for resuming or saving ckpt for multiple optimizers
* refine docstring
* refine docstring
* fix typo
* update docstring
* refactor the logic to build multiple optimizers
* resolve comments
* ParamSchedulers support multiple optimizers
* add optimizer_wrapper
* fix comment and docstring
* fix unit test
* add unit test
* refine docstring
* RuntimeInfoHook supports printing multiple learning rates
* resolve comments
* add optimizer_wrapper
* fix mypy
* fix lint
* fix OptimizerWrapperDict docstring and add unit test
* rename OptimizerWrapper to OptimWrapper, make OptimWrapperDict inherit from OptimWrapper, and fix as commented
* Fix AmpOptimizerWrapper
* rename build_optmizer_wrapper to build_optim_wrapper
* refine optimizer wrapper
* fix AmpOptimWrapper.step, docstring
* resolve conflict
* rename DefaultOptimConstructor
* fix as commented
* rename clip grad arguments
* refactor optim_wrapper config
* fix docstring of DefaultOptimWrapperConstructor
* add get_lr method to OptimWrapper and OptimWrapperDict
* skip some amp unit tests
* fix unit test
* fix get_lr, get_momentum docstring
* refactor get_lr, get_momentum, fix as commented
* fix error message

Co-authored-by: zhouzaida <zhouzaida@163.com>
This commit is contained in:
parent 987e5b83f9
commit 3e3866c1b9
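For readers skimming the diff below, a minimal before/after sketch of the config change this commit introduces (the SGD settings and clip_grad value are illustrative, not taken from this commit):

# Before: a bare optimizer config
optimizer = dict(type='SGD', lr=0.01, momentum=0.9)

# After: the optimizer is nested inside an optimizer wrapper config, and
# wrapper-level options such as gradient clipping and accumulation sit beside it
optim_wrapper = dict(
    type='OptimWrapper',                  # or 'AmpOptimWrapper' for AMP training
    optimizer=dict(type='SGD', lr=0.01, momentum=0.9),
    clip_grad=dict(max_norm=1.0),         # forwarded to torch.nn.utils.clip_grad
    accumulative_iters=1)                 # >1 enables gradient accumulation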
@@ -83,7 +83,7 @@ class OptimizerHook(Hook):
                 In order to keep this interface consistent with other hooks,
                 we keep ``outputs`` here. Defaults to None.
         """
-        runner.optimizer.zero_grad()
+        runner.optim_wrapper.zero_grad()
         if self.detect_anomalous_params:
             self.detect_anomalous_parameters(runner.outputs['loss'], runner)
         runner.outputs['loss'].backward()
@@ -94,7 +94,7 @@ class OptimizerHook(Hook):
             # Add grad norm to the logger
             runner.message_hub.update_scalar('train/grad_norm',
                                              float(grad_norm))
-        runner.optimizer.step()
+        runner.optim_wrapper.step()

     def detect_anomalous_parameters(self, loss: torch.Tensor, runner) -> None:
         """Detect anomalous parameters that are not included in the graph.
@@ -41,13 +41,16 @@ class RuntimeInfoHook(Hook):
         """Update current iter and learning rate information before every
         iteration."""
         runner.message_hub.update_info('iter', runner.iter)
-        if isinstance(runner.optimizer, dict):
-            for name, optimizer in runner.optimizer.items():
-                runner.message_hub.update_scalar(
-                    f'train/{name}.lr', optimizer.param_groups[0]['lr'])
-        else:
-            runner.message_hub.update_scalar(
-                'train/lr', runner.optimizer.param_groups[0]['lr'])
+        lr_dict = runner.optim_wrapper.get_lr()
+        assert isinstance(lr_dict, dict), (
+            '`runner.optim_wrapper.get_lr()` should return a dict '
+            'of learning rate when training with OptimWrapper(single '
+            'optimizer) or OptimWrapperDict(multiple optimizer), '
+            f'but got {type(lr_dict)} please check your optimizer '
+            'constructor return an `OptimWrapper` or `OptimWrapperDict` '
+            'instance')
+        for name, lr in lr_dict.items():
+            runner.message_hub.update_scalar(f'train/{name}', lr[0])

     def after_train_iter(self,
                          runner,
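The shape of ``lr_dict`` decides which scalar names end up in the logs. A minimal sketch of both cases (learning rates illustrative; a default MMLogger/MessageHub instance is assumed to be available):

import torch.nn as nn
from torch.optim import SGD
from mmengine.optim import OptimWrapper, OptimWrapperDict

gen, disc = nn.Linear(1, 1), nn.Linear(1, 1)

# Single optimizer -> logged as 'train/lr'
print(OptimWrapper(SGD(gen.parameters(), lr=0.01)).get_lr())
# {'lr': [0.01]}

# Multiple optimizers -> logged as 'train/generator.lr', 'train/discriminator.lr'
wrappers = OptimWrapperDict(
    generator=OptimWrapper(SGD(gen.parameters(), lr=0.01)),
    discriminator=OptimWrapper(SGD(disc.parameters(), lr=0.004)))
print(wrappers.get_lr())
# {'generator.lr': [0.01], 'discriminator.lr': [0.004]}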
@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .optimizer import (OPTIMIZER_CONSTRUCTORS, OPTIMIZERS,
-                        DefaultOptimizerConstructor, build_optimizer)
+from .optimizer import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
+                        AmpOptimWrapper, DefaultOptimWrapperConstructor,
+                        OptimWrapper, OptimWrapperDict, build_optim_wrapper)
 from .scheduler import (ConstantLR, ConstantMomentum, ConstantParamScheduler,
                         CosineAnnealingLR, CosineAnnealingMomentum,
                         CosineAnnealingParamScheduler, ExponentialLR,
@@ -11,12 +12,12 @@ from .scheduler import (ConstantLR, ConstantMomentum, ConstantParamScheduler,
                         StepParamScheduler, _ParamScheduler)

 __all__ = [
-    'OPTIMIZER_CONSTRUCTORS', 'OPTIMIZERS', 'build_optimizer',
-    'DefaultOptimizerConstructor', 'ConstantLR', 'CosineAnnealingLR',
+    'OPTIM_WRAPPER_CONSTRUCTORS', 'OPTIMIZERS', 'build_optim_wrapper',
+    'DefaultOptimWrapperConstructor', 'ConstantLR', 'CosineAnnealingLR',
     'ExponentialLR', 'LinearLR', 'MultiStepLR', 'StepLR', 'ConstantMomentum',
     'CosineAnnealingMomentum', 'ExponentialMomentum', 'LinearMomentum',
     'MultiStepMomentum', 'StepMomentum', 'ConstantParamScheduler',
     'CosineAnnealingParamScheduler', 'ExponentialParamScheduler',
     'LinearParamScheduler', 'MultiStepParamScheduler', 'StepParamScheduler',
-    '_ParamScheduler'
+    '_ParamScheduler', 'OptimWrapper', 'AmpOptimWrapper', 'OptimWrapperDict'
 ]
@@ -1,8 +1,13 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .builder import OPTIMIZER_CONSTRUCTORS, OPTIMIZERS, build_optimizer
-from .default_constructor import DefaultOptimizerConstructor
+from .amp_optimizer_wrapper import AmpOptimWrapper
+from .builder import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
+                      build_optim_wrapper)
+from .default_constructor import DefaultOptimWrapperConstructor
+from .optimizer_wrapper import OptimWrapper
+from .optimizer_wrapper_dict import OptimWrapperDict

 __all__ = [
-    'OPTIMIZER_CONSTRUCTORS', 'OPTIMIZERS', 'DefaultOptimizerConstructor',
-    'build_optimizer'
+    'OPTIM_WRAPPER_CONSTRUCTORS', 'OPTIMIZERS',
+    'DefaultOptimWrapperConstructor', 'build_optim_wrapper', 'OptimWrapper',
+    'AmpOptimWrapper', 'OptimWrapperDict'
 ]
mmengine/optim/optimizer/amp_optimizer_wrapper.py (new file, 110 lines)
@@ -0,0 +1,110 @@
# Copyright (c) OpenMMLab. All rights reserved.
from contextlib import contextmanager

import torch
from torch.cuda.amp import GradScaler

from mmengine.registry import OPTIM_WRAPPERS
from mmengine.utils import TORCH_VERSION, digit_version
from .optimizer_wrapper import OptimWrapper


@OPTIM_WRAPPERS.register_module()
class AmpOptimWrapper(OptimWrapper):
    """A subclass of :class:`OptimWrapper` that supports automatic mixed
    precision training based on torch.cuda.amp.

    ``AmpOptimWrapper`` provides a unified interface with
    ``OptimWrapper``, so ``AmpOptimWrapper`` can be used in the same way
    as ``OptimWrapper``.

    Warnings:
        ``AmpOptimWrapper`` requires PyTorch >= 1.6.

    Args:
        loss_scale (float or str or dict): The initial configuration of
            `torch.cuda.amp.GradScaler`. See more specific arguments
            introduction at `PyTorch AMP <https://pytorch.org/docs/stable/amp.html?highlight=gradscalertorch.cuda.amp.GradScaler>`_ # noqa: E501

            - "dynamic": Initialize GradScale without any arguments.
            - float: Initialize GradScaler with ``init_scale``.
            - dict: Initialize GradScaler with more detail configuration.

        **kwargs: Keyword arguments passed to OptimWrapper.
    """

    def __init__(self, loss_scale=512., **kwargs):
        assert digit_version(TORCH_VERSION) >= digit_version('1.6.0'), (
            '`torch.cuda.amp` is only available when pytorch version >= 1.6')
        assert torch.cuda.is_available(), (
            '``AmpOptimizerWrapper`` is only available training on gpu')
        super().__init__(**kwargs)
        self._scale_update_param = None
        if loss_scale == 'dynamic':
            # If loss_scale is a string, it must be 'dynamic', then dynamic
            # loss scaling will be used.
            self.loss_scaler = GradScaler()
        elif isinstance(loss_scale, float):
            # Static loss scaling
            self._scale_update_param = loss_scale
            self.loss_scaler = GradScaler(init_scale=loss_scale)
        elif isinstance(loss_scale, dict):
            # More specific configuration.
            self.loss_scaler = GradScaler(**loss_scale)
        else:
            raise TypeError('loss_scale must be of type float, dict, or '
                            f'"dynamic", but got {loss_scale}')

    def backward(self, loss: torch.Tensor):
        """Perform gradient back propagation with :attr:`loss_scaler`.

        Args:
            loss (torch.Tensor): The loss of current iteration.
        """
        self.loss_scaler.scale(loss).backward()

    def step(self):
        """Update parameters with :attr:`loss_scaler`."""
        if self.clip_grad_kwargs:
            self.loss_scaler.unscale_(self.optimizer)
            self._clip_grad()
        self.loss_scaler.step(self.optimizer)
        self.loss_scaler.update(self._scale_update_param)

    def state_dict(self) -> dict:
        """Get the state dictionary of :attr:`optimizer` and
        :attr:`loss_scaler`.

        Based on the state dictionary of the optimizer, the returned state
        dictionary will add a key named "loss_scaler".

        Returns:
            dict: The merged state dict of :attr:`loss_scaler` and
            :attr:`optimizer`.
        """
        # save state_dict of loss_scaler
        state_dict = self.optimizer.state_dict()
        state_dict['loss_scaler'] = self.loss_scaler.state_dict()
        return state_dict

    def load_state_dict(self, state_dict: dict):
        """Load and parse the state dictionary of :attr:`optimizer` and
        :attr:`loss_scaler`.

        If state_dict contains "loss_scaler.", the :attr:`loss_scaler` will
        load the corresponding keys. Otherwise, only the :attr:`optimizer`
        will load the state dictionary.

        Args:
            state_dict (dict): The state dict of :attr:`optimizer` and
            :attr:`loss_scaler`
        """
        if 'loss_scaler' in state_dict:
            self.loss_scaler.load_state_dict(state_dict.pop('loss_scaler'))
        self.optimizer.load_state_dict(state_dict)

    @contextmanager
    def precision_context(self):
        """A wrapper of ``torch.cuda.amp.autocast``"""
        with torch.cuda.amp.autocast():
            yield
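A short config sketch of the three accepted ``loss_scale`` forms (optimizer settings illustrative; AMP requires CUDA and PyTorch >= 1.6 per the assertions above):

# dynamic loss scaling: GradScaler() with default arguments
optim_wrapper = dict(type='AmpOptimWrapper', loss_scale='dynamic',
                     optimizer=dict(type='SGD', lr=0.01))

# static loss scaling: GradScaler(init_scale=512.)
optim_wrapper = dict(type='AmpOptimWrapper', loss_scale=512.,
                     optimizer=dict(type='SGD', lr=0.01))

# fully customized GradScaler
optim_wrapper = dict(type='AmpOptimWrapper',
                     loss_scale=dict(init_scale=512., growth_interval=100),
                     optimizer=dict(type='SGD', lr=0.01))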
@ -1,12 +1,14 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import inspect
|
||||
from typing import List
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from mmengine.registry import OPTIMIZER_CONSTRUCTORS, OPTIMIZERS
|
||||
from mmengine.config import Config, ConfigDict
|
||||
from mmengine.registry import OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS
|
||||
from .optimizer_wrapper import OptimWrapper
|
||||
|
||||
|
||||
def register_torch_optimizers() -> List[str]:
|
||||
@ -30,31 +32,31 @@ def register_torch_optimizers() -> List[str]:
|
||||
TORCH_OPTIMIZERS = register_torch_optimizers()
|
||||
|
||||
|
||||
def build_optimizer(model: nn.Module, cfg: dict) -> torch.optim.Optimizer:
|
||||
"""Build function of optimizer.
|
||||
def build_optim_wrapper(model: nn.Module,
|
||||
cfg: Union[dict, Config, ConfigDict]) -> OptimWrapper:
|
||||
"""Build function of OptimWrapper.
|
||||
|
||||
If ``constructor`` is set in the ``cfg``, this method will build an
|
||||
optimizer constructor, and use optimizer constructor to build the
|
||||
optimizer. If ``constructor`` is not set, the
|
||||
``DefaultOptimizerConstructor`` will be used by default.
|
||||
optimizer wrapper constructor, and use optimizer wrapper constructor to
|
||||
build the optimizer wrapper. If ``constructor`` is not set, the
|
||||
``DefaultOptimWrapperConstructor`` will be used by default.
|
||||
|
||||
Args:
|
||||
model (nn.Module): Model to be optimized.
|
||||
cfg (dict): Config of optimizer and optimizer constructor.
|
||||
default_scope (str, optional): The ``default_scope`` is used to
|
||||
reset the current registry. Defaults to None.
|
||||
cfg (dict): Config of optimizer wrapper, optimizer constructor and
|
||||
optimizer.
|
||||
|
||||
Returns:
|
||||
torch.optim.Optimizer: The built optimizer.
|
||||
OptimWrapper: The built optimizer wrapper.
|
||||
"""
|
||||
optimizer_cfg = copy.deepcopy(cfg)
|
||||
constructor_type = optimizer_cfg.pop('constructor',
|
||||
'DefaultOptimizerConstructor')
|
||||
paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None)
|
||||
optim_constructor = OPTIMIZER_CONSTRUCTORS.build(
|
||||
optim_wrapper_cfg = copy.deepcopy(cfg)
|
||||
constructor_type = optim_wrapper_cfg.pop('constructor',
|
||||
'DefaultOptimWrapperConstructor')
|
||||
paramwise_cfg = optim_wrapper_cfg.pop('paramwise_cfg', None)
|
||||
optim_wrapper_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build(
|
||||
dict(
|
||||
type=constructor_type,
|
||||
optimizer_cfg=optimizer_cfg,
|
||||
optim_wrapper_cfg=optim_wrapper_cfg,
|
||||
paramwise_cfg=paramwise_cfg))
|
||||
optimizer = optim_constructor(model)
|
||||
return optimizer
|
||||
optim_wrapper = optim_wrapper_constructor(model)
|
||||
return optim_wrapper
|
||||
|
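A usage sketch of the new builder above (the SGD settings and clip_grad value are illustrative):

import torch.nn as nn
from mmengine.optim import build_optim_wrapper

model = nn.Linear(1, 1)
optim_wrapper = build_optim_wrapper(
    model,
    dict(type='OptimWrapper',                 # optional, this is the default
         optimizer=dict(type='SGD', lr=0.01),
         clip_grad=dict(max_norm=1.0)))       # forwarded to clip_grad_norm_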
@ -6,17 +6,19 @@ import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import GroupNorm, LayerNorm
|
||||
|
||||
from mmengine.logging.logger import print_log
|
||||
from mmengine.registry import OPTIMIZER_CONSTRUCTORS, OPTIMIZERS
|
||||
from mmengine.logging import print_log
|
||||
from mmengine.registry import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS,
|
||||
OPTIMIZERS)
|
||||
from mmengine.utils import is_list_of, mmcv_full_available
|
||||
from mmengine.utils.parrots_wrapper import _BatchNorm, _InstanceNorm
|
||||
from .optimizer_wrapper import OptimWrapper
|
||||
|
||||
|
||||
@OPTIMIZER_CONSTRUCTORS.register_module()
|
||||
class DefaultOptimizerConstructor:
|
||||
@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
|
||||
class DefaultOptimWrapperConstructor:
|
||||
"""Default constructor for optimizers.
|
||||
|
||||
By default each parameter share the same optimizer settings, and we
|
||||
By default, each parameter share the same optimizer settings, and we
|
||||
provide an argument ``paramwise_cfg`` to specify parameter-wise settings.
|
||||
It is a dict and may contain the following fields:
|
||||
|
||||
@ -62,49 +64,65 @@ class DefaultOptimizerConstructor:
|
||||
model contains multiple DCN layers in places other than backbone.
|
||||
|
||||
Args:
|
||||
optimizer_cfg (dict): The config dict of the optimizer.
|
||||
optim_wrapper_cfg (dict): The config dict of the optimizer wrapper.
|
||||
Positional fields are
|
||||
|
||||
- `type`: class name of the optimizer.
|
||||
- ``type``: class name of the OptimizerWrapper
|
||||
- ``optimizer``: The configuration of optimizer.
|
||||
|
||||
Optional fields are
|
||||
|
||||
- any arguments of the corresponding optimizer type, e.g.,
|
||||
lr, weight_decay, momentum, etc.
|
||||
- any arguments of the corresponding optimizer wrapper type,
|
||||
e.g., accumulative_iters, clip_grad, etc.
|
||||
|
||||
The positional fields of ``optimizer`` are
|
||||
|
||||
- `type`: class name of the optimizer.
|
||||
|
||||
Optional fields are
|
||||
|
||||
- any arguments of the corresponding optimizer type, e.g.,
|
||||
lr, weight_decay, momentum, etc.
|
||||
|
||||
paramwise_cfg (dict, optional): Parameter-wise options.
|
||||
|
||||
Example 1:
|
||||
>>> model = torch.nn.modules.Conv1d(1, 1, 1)
|
||||
>>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9,
|
||||
>>> weight_decay=0.0001)
|
||||
>>> optim_wrapper_cfg = dict(
|
||||
>>> dict(type=OptimWrapper, optimizer=dict(type='SGD', lr=0.01,
|
||||
>>> momentum=0.9, weight_decay=0.0001))
|
||||
>>> paramwise_cfg = dict(norm_decay_mult=0.)
|
||||
>>> optim_builder = DefaultOptimizerConstructor(
|
||||
>>> optimizer_cfg, paramwise_cfg)
|
||||
>>> optimizer = optim_builder(model)
|
||||
>>> optim_wrapper_builder = DefaultOptimWrapperConstructor(
|
||||
>>> optim_wrapper_cfg, paramwise_cfg)
|
||||
>>> optim_wrapper = optim_wrapper_builder(model)
|
||||
|
||||
Example 2:
|
||||
>>> # assume model have attribute model.backbone and model.cls_head
|
||||
>>> optimizer_cfg = dict(type='SGD', lr=0.01, weight_decay=0.95)
|
||||
>>> optim_wrapper_cfg = dict(type=OptimWrapper, optimizer=dict(
|
||||
>>> type='SGD', lr=0.01, weight_decay=0.95))
|
||||
>>> paramwise_cfg = dict(custom_keys={
|
||||
'.backbone': dict(lr_mult=0.1, decay_mult=0.9)})
|
||||
>>> optim_builder = DefaultOptimizerConstructor(
|
||||
>>> optimizer_cfg, paramwise_cfg)
|
||||
>>> optimizer = optim_builder(model)
|
||||
>>> '.backbone': dict(lr_mult=0.1, decay_mult=0.9)})
|
||||
>>> optim_wrapper_builder = DefaultOptimWrapperConstructor(
|
||||
>>> optim_wrapper_cfg, paramwise_cfg)
|
||||
>>> optim_wrapper = optim_wrapper_builder(model)
|
||||
>>> # Then the `lr` and `weight_decay` for model.backbone is
|
||||
>>> # (0.01 * 0.1, 0.95 * 0.9). `lr` and `weight_decay` for
|
||||
>>> # model.cls_head is (0.01, 0.95).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
optimizer_cfg: dict,
|
||||
optim_wrapper_cfg: dict,
|
||||
paramwise_cfg: Optional[dict] = None):
|
||||
if not isinstance(optimizer_cfg, dict):
|
||||
if not isinstance(optim_wrapper_cfg, dict):
|
||||
raise TypeError('optimizer_cfg should be a dict',
|
||||
f'but got {type(optimizer_cfg)}')
|
||||
self.optimizer_cfg = optimizer_cfg
|
||||
f'but got {type(optim_wrapper_cfg)}')
|
||||
assert 'optimizer' in optim_wrapper_cfg, (
|
||||
'`optim_wrapper_cfg` must contain "optimizer" config')
|
||||
self.optim_wrapper_cfg = optim_wrapper_cfg.copy()
|
||||
self.optimizer_cfg = self.optim_wrapper_cfg.pop('optimizer')
|
||||
self.paramwise_cfg = {} if paramwise_cfg is None else paramwise_cfg
|
||||
self.base_lr = optimizer_cfg.get('lr', None)
|
||||
self.base_wd = optimizer_cfg.get('weight_decay', None)
|
||||
self.base_lr = self.optimizer_cfg.get('lr', None)
|
||||
self.base_wd = self.optimizer_cfg.get('weight_decay', None)
|
||||
self._validate_cfg()
|
||||
|
||||
def _validate_cfg(self) -> None:
|
||||
@ -249,19 +267,23 @@ class DefaultOptimizerConstructor:
|
||||
prefix=child_prefix,
|
||||
is_dcn_module=is_dcn_module)
|
||||
|
||||
def __call__(self, model: nn.Module) -> torch.optim.Optimizer:
|
||||
def __call__(self, model: nn.Module) -> OptimWrapper:
|
||||
if hasattr(model, 'module'):
|
||||
model = model.module
|
||||
|
||||
optim_wrapper_cfg = self.optim_wrapper_cfg.copy()
|
||||
optim_wrapper_cfg.setdefault('type', 'OptimWrapper')
|
||||
optimizer_cfg = self.optimizer_cfg.copy()
|
||||
# if no paramwise option is specified, just use the global setting
|
||||
if not self.paramwise_cfg:
|
||||
optimizer_cfg['params'] = model.parameters()
|
||||
return OPTIMIZERS.build(optimizer_cfg)
|
||||
|
||||
# set param-wise lr and weight decay recursively
|
||||
params: List = []
|
||||
self.add_params(params, model)
|
||||
optimizer_cfg['params'] = params
|
||||
|
||||
return OPTIMIZERS.build(optimizer_cfg)
|
||||
optimizer = OPTIMIZERS.build(optimizer_cfg)
|
||||
else:
|
||||
# set param-wise lr and weight decay recursively
|
||||
params: List = []
|
||||
self.add_params(params, model)
|
||||
optimizer_cfg['params'] = params
|
||||
optimizer = OPTIMIZERS.build(optimizer_cfg)
|
||||
optim_wrapper = OPTIM_WRAPPERS.build(
|
||||
optim_wrapper_cfg, default_args=dict(optimizer=optimizer))
|
||||
return optim_wrapper
|
||||
|
mmengine/optim/optimizer/optimizer_wrapper.py (new file, 349 lines)
@@ -0,0 +1,349 @@
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from contextlib import contextmanager
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.utils import clip_grad
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from mmengine.logging import MessageHub, MMLogger
|
||||
from mmengine.registry import OPTIM_WRAPPERS
|
||||
from mmengine.utils import has_batch_norm
|
||||
|
||||
|
||||
@OPTIM_WRAPPERS.register_module()
|
||||
class OptimWrapper:
|
||||
"""Optimizer wrapper provides a common interface for updating parameters.
|
||||
|
||||
Optimizer wrapper provides a unified interface for single precision
|
||||
training and automatic mixed precision training with different hardware.
|
||||
OptimWrapper encapsulates optimizer to provide simplified interfaces
|
||||
for commonly used training techniques such as gradient accumulative and
|
||||
grad clips. ``OptimWrapper`` implements the basic logic of gradient
|
||||
accumulation and gradient clipping based on ``torch.optim.Optimizer``.
|
||||
The subclasses only need to override some methods to implement the mixed
|
||||
precision training. See more information in :class:`AmpOptimWrapper`.
|
||||
|
||||
Args:
|
||||
optimizer (Optimizer): Optimizer used to update model parameters.
|
||||
accumulative_iters (int): The number of iterations to accumulate
|
||||
gradients. The parameters will be updated per
|
||||
``accumulative_iters``.
|
||||
clip_grad (dict, optional): If ``clip_grad`` is not None, it will be
|
||||
the arguments of ``torch.nn.utils.clip_grad``.
|
||||
|
||||
Warnings:
|
||||
If ``accumulative_iters`` is larger than 1, :meth:`update_params` must
|
||||
be called in the context of ``accumulate_grad``.
|
||||
|
||||
Examples:
|
||||
>>> # Config sample of OptimWrapper.
|
||||
>>> optim_wrapper_cfg = dict(
|
||||
>>> type='OptimWrapper',
|
||||
>>> accumulative_iters=3,
|
||||
>>> clip_grad=dict(max_norm=0.2))
|
||||
>>> # Use OptimWrapper to update model.
|
||||
>>> import torch.nn as nn
|
||||
>>> import torch
|
||||
>>> from torch.optim import SGD
|
||||
>>> from torch.utils.data import DataLoader
|
||||
>>> from mmengine.optim import OptimWrapper
|
||||
>>>
|
||||
>>> model = nn.Linear(1, 1)
|
||||
>>> dataset = torch.randn(10, 1, 1)
|
||||
>>> dataloader = DataLoader(dataset)
|
||||
>>> optimizer = SGD(model.parameters(), lr=0.1)
|
||||
>>> optim_wrapper = OptimWrapper(optimizer)
|
||||
>>>
|
||||
>>> for data in dataloader:
|
||||
>>> loss = model(data)
|
||||
>>> optim_wrapper.update_params(loss)
|
||||
>>> # Enable gradient accumulation. If model is a subclass instance of
|
||||
>>> # DistributedDataParallel, ``accumulate_grad`` context manager can
|
||||
>>> # avoid unnecessary gradient synchronize.
|
||||
>>> for iter, data in enumerate(dataloader):
|
||||
>>> with optim_wrapper.accumulate_grad(
|
||||
>>> model, iter, len(dataloader)):
|
||||
>>> loss = model(data)
|
||||
>>> optim_wrapper.update_params(loss)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
optimizer: Optimizer,
|
||||
accumulative_iters: int = 1,
|
||||
clip_grad: Optional[dict] = None):
|
||||
assert accumulative_iters > 0, (
|
||||
'accumulative_iters at least greater than or equal to 1')
|
||||
self.accumulative_iters = accumulative_iters
|
||||
# `max_iters` and `cur_iter` is only valid in gradient accumulative
|
||||
# mode (`accumulative_iters` > 1). `cur_iter` and `max_iter` will be
|
||||
# updated in the ``accumulate_grad`` context that is enabled in
|
||||
# `runner.train_loop`.
|
||||
self.cur_iter = 0
|
||||
self.max_iters = 0
|
||||
assert isinstance(optimizer, Optimizer), (
|
||||
'optimizer must be a `torch.optim.Optimizer` instance, but got '
|
||||
f'{type(optimizer)}')
|
||||
self.optimizer = optimizer
|
||||
|
||||
if clip_grad is not None:
|
||||
# clip_grad_kwargs should not be non-empty dict.
|
||||
assert isinstance(clip_grad, dict) and clip_grad, (
|
||||
'If `clip_grad_kwargs` is not None, it should be a `dict` '
|
||||
'which is the arguments of `torch.nn.utils.clip_grad`')
|
||||
self.clip_grad_kwargs = clip_grad
|
||||
self.logger = MMLogger.get_current_instance()
|
||||
# Used to update `grad_norm` log message.
|
||||
self.message_hub = MessageHub.get_current_instance()
|
||||
self.iter_status_initialized = False
|
||||
|
||||
def update_params(self, loss: torch.Tensor) -> None:
|
||||
"""Update parameters in :attr:`optimizer`.
|
||||
|
||||
Args:
|
||||
loss (torch.Tensor): A tensor for back propagation.
|
||||
"""
|
||||
if self.accumulative_iters == 1:
|
||||
# update parameters without gradient accumulation. The gradient
|
||||
# should not be rescaled and `loss_factor=1`.
|
||||
loss_factor = 1
|
||||
else:
|
||||
# gradient accumulation must be called in the context of
|
||||
# ``accumulate_grad``.
|
||||
assert hasattr(self, 'divisible_iters'), (
|
||||
'gradient accumulation must be performed in the context of'
|
||||
'`OptimWrapper.accumulate_grad`')
|
||||
# if `self.accumulative_iters > 1`, the gradient needs to be
|
||||
# rescaled and accumulated. In most cases, `loss_factor` equals to
|
||||
# `self.accumulative_iters`. However `self.max_iters` may not be
|
||||
# divisible `self.by accumulative_iters`, so the `loss_scale` for
|
||||
# the last few iterations needs to be recalculated.
|
||||
if self.cur_iter < self.divisible_iters:
|
||||
loss_factor = self.accumulative_iters
|
||||
else:
|
||||
loss_factor = self.remainder_iters
|
||||
assert loss_factor > 0, (
|
||||
'loss_factor should be larger than zero! This error could '
|
||||
'happened when gradient accumulation context enabled with an '
|
||||
'error `cur_iter` or `max_iters` please check your loop')
|
||||
|
||||
loss = loss / loss_factor
|
||||
self.backward(loss)
|
||||
# Update parameters only if `self.cur_iter` is divisible by
|
||||
# `self.accumulative_iters` or `self.cur_iter` equals to
|
||||
# `self.max_iters`
|
||||
if self._should_update(self.cur_iter, self.max_iters):
|
||||
self.step()
|
||||
self.zero_grad()
|
||||
|
||||
def backward(self, loss: torch.Tensor) -> None:
|
||||
"""Perform gradient back propagation.
|
||||
|
||||
Provide unified ``backward`` interface compatible with automatic mixed
|
||||
precision training. Subclass can overload this method to implement the
|
||||
required logic. For example, ``torch.cuda.amp`` require some extra
|
||||
operation on GradScaler during backward process.
|
||||
|
||||
Args:
|
||||
loss (torch.Tensor): The loss of current iteration.
|
||||
"""
|
||||
loss.backward()
|
||||
|
||||
def zero_grad(self) -> None:
|
||||
"""A wrapper of ``Optimizer.zero_grad``.
|
||||
|
||||
Provide unified ``zero_grad`` interface compatible with automatic mixed
|
||||
precision training. Subclass can overload this method to implement the
|
||||
required logic.
|
||||
"""
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
def step(self) -> None:
|
||||
"""A wrapper of ``Optimizer.step``.
|
||||
|
||||
Provide unified ``step`` interface compatible with automatic mixed
|
||||
precision training. Subclass can overload this method to implement the
|
||||
required logic. For example, ``torch.cuda.amp`` require some extra
|
||||
operation on ``GradScaler`` during step process.
|
||||
|
||||
Clip grad if :attr:`clip_grad_kwargs` is not None, and then update
|
||||
parameters.
|
||||
"""
|
||||
if self.clip_grad_kwargs:
|
||||
self._clip_grad()
|
||||
self.optimizer.step()
|
||||
|
||||
def state_dict(self) -> dict:
|
||||
"""A wrapper of ``Optimizer.state_dict``.
|
||||
|
||||
Provide unified ``state_dict`` interface compatible with automatic
|
||||
mixed precision training. Subclass can overload this method to
|
||||
implement the required logic. For example, the state dictionary of
|
||||
GradScaler should be saved when training with ``torch.cuda.amp``.
|
||||
|
||||
Returns:
|
||||
dict: The state dictionary of :attr:`optimizer`.
|
||||
"""
|
||||
return self.optimizer.state_dict()
|
||||
|
||||
def load_state_dict(self, state_dict: dict) -> None:
|
||||
"""A wrapper of ``Optimizer.load_state_dict``. load the state dict of
|
||||
:attr:`optimizer`.
|
||||
|
||||
Provide unified ``load_state_dict`` interface compatible with automatic
|
||||
mixed precision training. Subclass can overload this method to
|
||||
implement the required logic. For example, the state dictionary of
|
||||
GradScaler should be loaded when training with ``torch.cuda.amp``.
|
||||
|
||||
Args:
|
||||
state_dict (dict): The state dictionary of :attr:`optimizer`.
|
||||
"""
|
||||
self.optimizer.load_state_dict(state_dict)
|
||||
|
||||
@property
|
||||
def param_groups(self) -> List[dict]:
|
||||
"""A wrapper of ``Optimizer.param_groups``.
|
||||
|
||||
Make OptimizeWrapper compatible with :class:`_ParamScheduler`.
|
||||
|
||||
Returns:
|
||||
dict: the ``param_groups`` of :attr:`optimizer`.
|
||||
"""
|
||||
return self.optimizer.param_groups
|
||||
|
||||
def get_lr(self) -> Dict[str, List[float]]:
|
||||
"""Get the learning rate of the optimizer.
|
||||
|
||||
Provide unified interface to get learning rate of optimizer.
|
||||
|
||||
Returns:
|
||||
List[float]: Learning rate of the optimizer.
|
||||
"""
|
||||
lr = [group['lr'] for group in self.param_groups]
|
||||
return dict(lr=lr)
|
||||
|
||||
def get_momentum(self) -> Dict[str, List[float]]:
|
||||
"""Get the momentum of the optimizer.
|
||||
|
||||
Provide unified interface to get momentum of optimizer.
|
||||
|
||||
Returns:
|
||||
List[float]: Momentum of the optimizer.
|
||||
"""
|
||||
momentum = []
|
||||
for group in self.param_groups:
|
||||
# Get momentum of SGD.
|
||||
if 'momentum' in group.keys():
|
||||
momentum.append(group['momentum'])
|
||||
# Get momentum of Adam.
|
||||
elif 'betas' in group.keys():
|
||||
momentum.append(group['betas'][0])
|
||||
else:
|
||||
momentum.append(0)
|
||||
return dict(momentum=momentum)
|
||||
|
||||
@contextmanager
|
||||
def accumulate_grad(self, model: nn.Module, cur_iter: int, max_iters: int):
|
||||
"""A Context manager for gradient accumulation and avoiding unnecessary
|
||||
gradient synchronization during gradient accumulation.
|
||||
|
||||
If model is an instance with ``no_sync`` method (which means
|
||||
blocking the gradient synchronization) and
|
||||
``self.accumulative_iters != 1``. The model will not automatically
|
||||
synchronize gradients if ``cur_iter`` is divisible by
|
||||
``self.accumulative_iters``. Otherwise, this method will enable an
|
||||
empty context.
|
||||
|
||||
Warnings:
|
||||
This context manager must be enabled if you want to use
|
||||
gradient accumulation.
|
||||
|
||||
Args:
|
||||
model (nn.Module): The training model.
|
||||
cur_iter (int): Current iteration during training process.
|
||||
max_iters (int): Maximum training iteration.
|
||||
"""
|
||||
assert max_iters > 0, '`max_iters` must be larger than zero'
|
||||
self.cur_iter = cur_iter
|
||||
self.max_iters = max_iters
|
||||
if not self.iter_status_initialized:
|
||||
self._initilize_iter_status(model)
|
||||
# During gradient accumulation process, the gradient synchronize
|
||||
# should only happen before updating parameters.
|
||||
if (not self._should_update(cur_iter, max_iters)
|
||||
and hasattr(model, 'no_sync')):
|
||||
with model.no_sync():
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
|
||||
@contextmanager
|
||||
def precision_context(self):
|
||||
"""precision context which enables an empty context by default.
|
||||
|
||||
The subclass used for mixed or low precision training needs to override
|
||||
this method.
|
||||
"""
|
||||
yield
|
||||
|
||||
def _clip_grad(self) -> None:
|
||||
"""Clip the gradients of parameters."""
|
||||
params: List[torch.Tensor] = []
|
||||
for param_group in self.optimizer.param_groups:
|
||||
params.extend(param_group['params'])
|
||||
|
||||
params = list(
|
||||
filter(lambda p: p.requires_grad and p.grad is not None, params))
|
||||
if len(params) > 0:
|
||||
grad_norm = clip_grad.clip_grad_norm_(params,
|
||||
**self.clip_grad_kwargs)
|
||||
self.message_hub.update_scalar('train/grad_norm', float(grad_norm))
|
||||
|
||||
def _initilize_iter_status(self, model: nn.Module) -> None:
|
||||
"""Initialize gradient accumulation related attributes.
|
||||
|
||||
Args:
|
||||
model (nn.Module): Training model
|
||||
"""
|
||||
if self.max_iters % self.accumulative_iters != 0:
|
||||
self.logger.warning(
|
||||
'Resume iter number is not divisible by accumulative_iters in '
|
||||
'GradientCumulativeOptimizerHook, which means the gradient of '
|
||||
'some iters is lost and the result may be influenced slightly.'
|
||||
)
|
||||
|
||||
if has_batch_norm(model) and self.accumulative_iters > 1:
|
||||
self.logger.warning(
|
||||
'Gradient accumulative may slightly decrease '
|
||||
'performance because the model has BatchNorm layers.')
|
||||
residual_iters = self.max_iters - self.cur_iter
|
||||
# The maximum number of training iteration that is divisible by
|
||||
# accumulative_iters.
|
||||
self.divisible_iters = (
|
||||
residual_iters // self.accumulative_iters *
|
||||
self.accumulative_iters)
|
||||
# Remainder of ``self.max_iters`` divided by ``self.max_iters``
|
||||
self.remainder_iters = residual_iters - self.divisible_iters
|
||||
self.iter_status_initialized = True
|
||||
|
||||
def _should_update(self, cur_iter: int, max_iters: int) -> bool:
|
||||
"""Should optim_wrapper update parameters or synchronized gradient at
|
||||
current iteration.
|
||||
|
||||
Args:
|
||||
cur_iter (int): Current iteration of training process.
|
||||
max_iters (int): Maximum iterations of training process.
|
||||
|
||||
Returns:
|
||||
bool: Whether to update parameters or synchronized gradient.
|
||||
"""
|
||||
return ((cur_iter + 1) % self.accumulative_iters == 0
|
||||
or cur_iter + 1 == max_iters)
|
||||
|
||||
def __repr__(self):
|
||||
wrapper_info = f'Type: {type(self).__name__}\n' \
|
||||
f'accumulative_iters: {self.accumulative_iters}\n' \
|
||||
f'optimizer: \n'
|
||||
optimizer_str = repr(self.optimizer) + '\n'
|
||||
return wrapper_info + optimizer_str
|
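A small worked example of the accumulation bookkeeping implemented above (numbers chosen for illustration only):

# max_iters = 10, accumulative_iters = 3, status initialized at cur_iter = 0
# residual_iters  = 10 - 0 = 10
# divisible_iters = 10 // 3 * 3 = 9
# remainder_iters = 10 - 9 = 1
# iters 0..8: update_params divides the loss by 3 (full accumulation windows)
# iter 9:     update_params divides the loss by 1 (the leftover window)
# step() and zero_grad() fire after iters 2, 5, 8 and 9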
mmengine/optim/optimizer/optimizer_wrapper_dict.py (new file, 208 lines)
@@ -0,0 +1,208 @@
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
from contextlib import ExitStack, contextmanager
|
||||
from typing import Dict, Iterator, List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .optimizer_wrapper import OptimWrapper
|
||||
|
||||
|
||||
class OptimWrapperDict(OptimWrapper):
|
||||
"""A dictionary container of :obj:`OptimWrapper`.
|
||||
|
||||
If runner is training with multiple optimizers, all optimizer wrappers
|
||||
should be managed by :obj:`OptimWrapperDict` which is built by
|
||||
``CustomOptimWrapperConstructor``. ``OptimWrapperDict`` will load and save
|
||||
the state dictionary of all optimizer wrappers.
|
||||
|
||||
Consider the semantic ambiguity of calling :meth:``update_params``,
|
||||
:meth:`backward` of all optimizer wrappers, ``OptimWrapperDict`` will not
|
||||
implement these methods.
|
||||
|
||||
Examples:
|
||||
>>> import torch.nn as nn
|
||||
>>> from torch.optim import SGD
|
||||
>>> from mmengine.optim import OptimWrapperDict, OptimWrapper
|
||||
>>> model1 = nn.Linear(1, 1)
|
||||
>>> model2 = nn.Linear(1, 1)
|
||||
>>> optim_wrapper1 = OptimWrapper(SGD(model1.parameters(), lr=0.1))
|
||||
>>> optim_wrapper2 = OptimWrapper(SGD(model2.parameters(), lr=0.1))
|
||||
>>> optim_wrapper_dict = OptimWrapperDict(model1=optim_wrapper1,
|
||||
>>> model2=optim_wrapper2)
|
||||
|
||||
Note:
|
||||
The optimizer wrapper contained in ``OptimWrapperDict`` can be accessed
|
||||
in the same way as `dict`.
|
||||
|
||||
Args:
|
||||
**optim_wrappers: A dictionary of ``OptimWrapper`` instance.
|
||||
"""
|
||||
|
||||
def __init__(self, **optim_wrapper_dict: OptimWrapper):
|
||||
first_key = next(iter(optim_wrapper_dict))
|
||||
first_optim_wrapper = optim_wrapper_dict[first_key]
|
||||
assert isinstance(first_optim_wrapper, OptimWrapper), (
|
||||
'Each argument of `OptimWrapperDict` must be an `OptimWrapper '
|
||||
'instance`')
|
||||
optim_wrapper_class = type(first_optim_wrapper)
|
||||
for key, value in optim_wrapper_dict.items():
|
||||
assert type(value) == optim_wrapper_class, (
|
||||
f'All optimizer wrappers should have the same type, but found'
|
||||
f' {key}: {type(value)} and {first_key}: {optim_wrapper_class}'
|
||||
)
|
||||
if value.accumulative_iters != 1:
|
||||
warnings.warn(
|
||||
f'The `accumulative_iters` of {key} is '
|
||||
f'{value.accumulative_iters}. OptimWrapperDict '
|
||||
'will not enable any `accumulate_grad` context of its '
|
||||
'optimizer wrappers. You should access the corresponding '
|
||||
'optimizer wrapper to enable the context.')
|
||||
self.optim_wrappers = optim_wrapper_dict
|
||||
|
||||
def update_params(self, loss: torch.Tensor) -> None:
|
||||
"""Update all optimizer wrappers would lead to a duplicate backward
|
||||
errors, and OptimWrapperDict does not know which optimizer wrapper
|
||||
should be updated.
|
||||
|
||||
Therefore, this method is not implemented. The optimizer wrapper of
|
||||
OptimWrapperDict should be accessed and call its `update_params.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
'You should access the OptimWrapper of the '
|
||||
'OptimWrapperDict and call its `update_params`')
|
||||
|
||||
def backward(self, loss: torch.Tensor) -> None:
|
||||
"""Since OptimWrapperDict doesn't know which optimizer wrapper's
|
||||
backward method should be called (``loss_scaler`` maybe different in
|
||||
different :obj:AmpOptimWrapper), this method is not implemented.
|
||||
|
||||
The optimizer wrapper of OptimWrapperDict should be accessed and call
|
||||
its `backward.
|
||||
"""
|
||||
raise NotImplementedError('You should access the OptimWrapper of the '
|
||||
'OptimWrapperDict and call its `backward`')
|
||||
|
||||
def step(self) -> None:
|
||||
"""Since the backward method is not implemented, the step should not be
|
||||
implemented either."""
|
||||
raise NotImplementedError('You should access the OptimWrapper of the '
|
||||
'OptimWrapperDict and call its `step`')
|
||||
|
||||
def zero_grad(self) -> None:
|
||||
"""Set the gradients of all optimizer wrappers to zero."""
|
||||
for optim_wrapper in self.optim_wrappers.values():
|
||||
optim_wrapper.zero_grad()
|
||||
|
||||
@contextmanager
|
||||
def precision_context(self):
|
||||
optim_wrapper = next(iter(self.optim_wrappers.values()))
|
||||
with optim_wrapper.precision_context():
|
||||
yield
|
||||
|
||||
@contextmanager
|
||||
def accumulate_grad(self, model: nn.Module, cur_iter: int, max_iters: int):
|
||||
"""Enable ``accumulate_grad`` contexts of all optimizer wrappers.
|
||||
|
||||
Warning:
|
||||
Consider there is only one ``model`` arguments for all
|
||||
optimizer wrappers, all optimizer wrappers are working under the
|
||||
same ``model.no_sync`` context. For example, there is a model
|
||||
composed of model_a(optimizer_a) and model_b(optimizer_b).
|
||||
``OptimWrapperDict.accumulate_grad`` will further
|
||||
call ``model.no_sync``, which will block the gradient
|
||||
synchronization of both a and b. If optimizer_a and
|
||||
optimizer_b have different ``accumulative_iters``, and want to
|
||||
block the gradient synchronization of model_a and model_b
|
||||
separately, the model should not implement the ``no_sync``
|
||||
method(or enable an empty context). The ``accumulate_grad`` context
|
||||
should be enabled inside the model by accessing corresponding
|
||||
optimizer wrapper.
|
||||
"""
|
||||
with ExitStack() as stack:
|
||||
for optim_wrapper in self.optim_wrappers.values():
|
||||
stack.enter_context(
|
||||
optim_wrapper.accumulate_grad(model, cur_iter, max_iters))
|
||||
yield
|
||||
|
||||
def load_state_dict(self, state_dict: dict) -> None:
|
||||
"""Load the state dictionary from the ``state_dict``.
|
||||
|
||||
Args:
|
||||
state_dict (dict): Each key-value pair in `state_dict` represents
|
||||
the name and the state dictionary of corresponding
|
||||
:obj:`OptimWrapper`.
|
||||
"""
|
||||
for name, _state_dict in state_dict.items():
|
||||
assert name in self.optim_wrappers, (
|
||||
f'Mismatched `state_dict`! cannot found {name} in '
|
||||
'OptimWrapperDict')
|
||||
self.optim_wrappers[name].load_state_dict(_state_dict)
|
||||
|
||||
def get_lr(self) -> Dict[str, List[float]]:
|
||||
"""Get the learning rate of all optimizers.
|
||||
|
||||
Returns:
|
||||
Dict[str, List[float]]: Learning rate of all optimizers.
|
||||
"""
|
||||
lr_dict = dict()
|
||||
for name, optim_wrapper in self.optim_wrappers.items():
|
||||
lr_dict[f'{name}.lr'] = optim_wrapper.get_lr()['lr']
|
||||
return lr_dict
|
||||
|
||||
def get_momentum(self) -> Dict[str, List[float]]:
|
||||
"""Get the momentum of all optimizers.
|
||||
|
||||
Returns:
|
||||
Dict[str, List[float]]: momentum of all optimizers.
|
||||
"""
|
||||
momentum_dict = dict()
|
||||
for name, optim_wrapper in self.optim_wrappers.items():
|
||||
momentum_dict[f'{name}.momentum'] = optim_wrapper.get_momentum(
|
||||
)['momentum']
|
||||
return momentum_dict
|
||||
|
||||
def state_dict(self) -> dict:
|
||||
"""Get the state dictionary of all optimizer wrappers.
|
||||
|
||||
Returns:
|
||||
dict: Each key-value pair in the dictionary represents the name
|
||||
and state dictionary of corresponding :obj:`OptimWrapper`.
|
||||
"""
|
||||
state_dict = dict()
|
||||
for name, optim_wrapper in self.optim_wrappers.items():
|
||||
state_dict[name] = optim_wrapper.state_dict()
|
||||
return state_dict
|
||||
|
||||
def items(self) -> Iterator[Tuple[str, OptimWrapper]]:
|
||||
"""A generator to get the name and corresponding
|
||||
:obj:`OptimWrapper`"""
|
||||
yield from self.optim_wrappers.items()
|
||||
|
||||
def values(self) -> Iterator[OptimWrapper]:
|
||||
"""A generator to get :obj:`OptimWrapper`"""
|
||||
yield from self.optim_wrappers.values()
|
||||
|
||||
def keys(self) -> Iterator[str]:
|
||||
"""A generator to get the name of :obj:`OptimWrapper`"""
|
||||
yield from self.optim_wrappers.keys()
|
||||
|
||||
def __getitem__(self, key: str) -> OptimWrapper:
|
||||
assert key in self.optim_wrappers, (
|
||||
f'Cannot find {key} in OptimWrapperDict, please check '
|
||||
'your optimizer constructor.')
|
||||
return self.optim_wrappers[key]
|
||||
|
||||
def __contains__(self, key: str) -> bool:
|
||||
return key in self.optim_wrappers
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.optim_wrappers)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
desc = ''
|
||||
for name, optim_wrapper in self.optim_wrappers.items():
|
||||
desc += f'name: {name}\n'
|
||||
desc += repr(optim_wrapper)
|
||||
return desc
|
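Because ``update_params``, ``backward`` and ``step`` are deliberately left unimplemented on the container, each sub-wrapper is driven individually. A sketch continuing the docstring example above (data and losses are hypothetical placeholders):

# optim_wrapper_dict = OptimWrapperDict(model1=optim_wrapper1, model2=optim_wrapper2)
loss1 = model1(data1).sum()      # hypothetical forward pass of the first model
loss2 = model2(data2).sum()      # hypothetical forward pass of the second model
optim_wrapper_dict['model1'].update_params(loss1)
optim_wrapper_dict['model2'].update_params(loss2)
print(optim_wrapper_dict.get_lr())        # {'model1.lr': [0.1], 'model2.lr': [0.1]}
print(optim_wrapper_dict.get_momentum())  # {'model1.momentum': [0], 'model2.momentum': [0]}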
@ -22,7 +22,7 @@ class ConstantLR(LRSchedulerMixin, ConstantParamScheduler):
|
||||
changes to the learning rate value from outside this scheduler.
|
||||
|
||||
Args:
|
||||
optimizer (Optimizer): Wrapped optimizer.
|
||||
optimizer (Optimizer or OptimWrapper): Wrapped optimizer.
|
||||
factor (float): The number we multiply learning rate until the
|
||||
milestone. Defaults to 1./3.
|
||||
begin (int): Step at which to start updating the learning rate.
|
||||
@ -68,7 +68,7 @@ class CosineAnnealingLR(LRSchedulerMixin, CosineAnnealingParamScheduler):
|
||||
only implements the cosine annealing part of SGDR, and not the restarts.
|
||||
|
||||
Args:
|
||||
optimizer (Optimizer): Wrapped optimizer.
|
||||
optimizer (Optimizer or OptimWrapper): Wrapped optimizer.
|
||||
T_max (int): Maximum number of iterations.
|
||||
eta_min (float): Minimum learning rate. Defaults to 0.
|
||||
begin (int): Step at which to start updating the learning rate.
|
||||
@ -92,7 +92,7 @@ class ExponentialLR(LRSchedulerMixin, ExponentialParamScheduler):
|
||||
"""Decays the learning rate of each parameter group by gamma every epoch.
|
||||
|
||||
Args:
|
||||
optimizer (Optimizer): Wrapped optimizer.
|
||||
optimizer (Optimizer or OptimWrapper): Wrapped optimizer.
|
||||
gamma (float): Multiplicative factor of learning rate decay.
|
||||
begin (int): Step at which to start updating the learning rate.
|
||||
Defaults to 0.
|
||||
@ -116,7 +116,7 @@ class LinearLR(LRSchedulerMixin, LinearParamScheduler):
|
||||
Notice that such decay can happen simultaneously with other changes to the
|
||||
learning rate from outside this scheduler.
|
||||
Args:
|
||||
optimizer (Optimizer): Wrapped optimizer.
|
||||
optimizer (Optimizer or OptimWrapper): Wrapped optimizer.
|
||||
start_factor (float): The number we multiply learning rate in the
|
||||
first epoch. The multiplication factor changes towards end_factor
|
||||
in the following epochs. Defaults to 1./3.
|
||||
@ -143,7 +143,7 @@ class MultiStepLR(LRSchedulerMixin, MultiStepParamScheduler):
|
||||
outside this scheduler.
|
||||
|
||||
Args:
|
||||
optimizer (Optimizer): Wrapped optimizer.
|
||||
optimizer (Optimizer or OptimWrapper): Wrapped optimizer.
|
||||
milestones (list): List of epoch indices. Must be increasing.
|
||||
gamma (float): Multiplicative factor of learning rate decay.
|
||||
Defaults to 0.1.
|
||||
@ -167,7 +167,7 @@ class StepLR(LRSchedulerMixin, StepParamScheduler):
|
||||
other changes to the learning rate from outside this scheduler.
|
||||
|
||||
Args:
|
||||
optimizer (Optimizer): Wrapped optimizer.
|
||||
optimizer (Optimizer or OptimWrapper): Wrapped optimizer.
|
||||
step_size (int): Period of learning rate decay.
|
||||
gamma (float): Multiplicative factor of learning rate decay.
|
||||
Defaults to 0.1.
|
||||
@ -193,7 +193,7 @@ class PolyLR(LRSchedulerMixin, PolyParamScheduler):
|
||||
parameter value from outside this scheduler.
|
||||
|
||||
Args:
|
||||
optimizer (Optimizer): Wrapped optimizer.
|
||||
optimizer (Optimizer or OptimWrapper): Wrapped optimizer.
|
||||
eta_min (float): Minimum learning rate at the end of scheduling.
|
||||
Defaults to 0.
|
||||
power (float): The power of the polynomial. Defaults to 1.0.
|
||||
|
@ -22,7 +22,8 @@ class ConstantMomentum(MomentumSchedulerMixin, ConstantParamScheduler):
|
||||
momentum value from outside this scheduler.
|
||||
|
||||
Args:
|
||||
optimizer (Optimizer): Wrapped optimizer.
|
||||
optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
|
||||
optimizer.
|
||||
factor (float): The number we multiply momentum until the milestone.
|
||||
Defaults to 1./3.
|
||||
begin (int): Step at which to start updating the momentum.
|
||||
@ -69,7 +70,8 @@ class CosineAnnealingMomentum(MomentumSchedulerMixin,
|
||||
only implements the cosine annealing part of SGDR, and not the restarts.
|
||||
|
||||
Args:
|
||||
optimizer (Optimizer): Wrapped optimizer.
|
||||
optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
|
||||
optimizer.
|
||||
T_max (int): Maximum number of iterations.
|
||||
eta_min (float): Minimum momentum value. Defaults to 0.
|
||||
begin (int): Step at which to start updating the momentum.
|
||||
@ -93,7 +95,8 @@ class ExponentialMomentum(MomentumSchedulerMixin, ExponentialParamScheduler):
|
||||
"""Decays the momentum of each parameter group by gamma every epoch.
|
||||
|
||||
Args:
|
||||
optimizer (Optimizer): Wrapped optimizer.
|
||||
optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
|
||||
optimizer.
|
||||
gamma (float): Multiplicative factor of momentum value decay.
|
||||
begin (int): Step at which to start updating the momentum.
|
||||
Defaults to 0.
|
||||
@ -117,7 +120,8 @@ class LinearMomentum(MomentumSchedulerMixin, LinearParamScheduler):
|
||||
Notice that such decay can happen simultaneously with other changes to the
|
||||
momentum from outside this scheduler.
|
||||
Args:
|
||||
optimizer (Optimizer): Wrapped optimizer.
|
||||
optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
|
||||
optimizer.
|
||||
start_factor (float): The number we multiply momentum in the
|
||||
first epoch. The multiplication factor changes towards end_factor
|
||||
in the following epochs. Defaults to 1./3.
|
||||
@ -144,7 +148,8 @@ class MultiStepMomentum(MomentumSchedulerMixin, MultiStepParamScheduler):
|
||||
scheduler.
|
||||
|
||||
Args:
|
||||
optimizer (Optimizer): Wrapped optimizer.
|
||||
optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
|
||||
optimizer.
|
||||
milestones (list): List of epoch indices. Must be increasing.
|
||||
gamma (float): Multiplicative factor of momentum value decay.
|
||||
Defaults to 0.1.
|
||||
@ -168,7 +173,8 @@ class StepMomentum(MomentumSchedulerMixin, StepParamScheduler):
|
||||
to the momentum from outside this scheduler.
|
||||
|
||||
Args:
|
||||
optimizer (Optimizer): Wrapped optimizer.
|
||||
optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
|
||||
optimizer.
|
||||
step_size (int): Period of momentum value decay.
|
||||
gamma (float): Multiplicative factor of momentum value decay.
|
||||
Defaults to 0.1.
|
||||
@ -194,7 +200,8 @@ class PolyMomentum(MomentumSchedulerMixin, PolyParamScheduler):
|
||||
parameter value from outside this scheduler.
|
||||
|
||||
Args:
|
||||
optimizer (Optimizer): Wrapped optimizer.
|
||||
optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
|
||||
optimizer.
|
||||
eta_min (float): Minimum momentum at the end of scheduling.
|
||||
Defaults to 0.
|
||||
power (float): The power of the polynomial. Defaults to 1.0.
|
||||
|
@ -4,14 +4,17 @@ import warnings
|
||||
import weakref
|
||||
from collections import Counter
|
||||
from functools import wraps
|
||||
from typing import Callable, List
|
||||
from typing import Callable, List, Union
|
||||
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from mmengine.optim import OptimWrapper
|
||||
from mmengine.registry import PARAM_SCHEDULERS
|
||||
|
||||
INF = int(1e9)
|
||||
|
||||
OptimizerType = Union[OptimWrapper, Optimizer]
|
||||
|
||||
|
||||
class _ParamScheduler:
|
||||
"""Base class for parameter schedulers.
|
||||
@ -23,7 +26,7 @@ class _ParamScheduler:
|
||||
https://github.com/pytorch/pytorch/blob/master/torch/optim/lr_scheduler.py.
|
||||
|
||||
Args:
|
||||
optimizer (Optimizer): Wrapped optimizer.
|
||||
optimizer (OptimWrapper or Optimizer): Wrapped optimizer.
|
||||
param_name (str): Name of the parameter to be adjusted, such as
|
||||
``lr``, ``momentum``.
|
||||
begin (int): Step at which to start updating the parameters.
|
||||
@ -40,7 +43,7 @@ class _ParamScheduler:
|
||||
""" # noqa: E501
|
||||
|
||||
def __init__(self,
|
||||
optimizer: Optimizer,
|
||||
optimizer: OptimizerType,
|
||||
param_name: str,
|
||||
begin: int = 0,
|
||||
end: int = INF,
|
||||
@ -49,7 +52,7 @@ class _ParamScheduler:
|
||||
verbose: bool = False):
|
||||
|
||||
# Attach optimizer
|
||||
if not isinstance(optimizer, Optimizer):
|
||||
if not isinstance(optimizer, (Optimizer, OptimWrapper)):
|
||||
raise TypeError('``optimizer`` should be an Optimizer,'
|
||||
'but got {}'.format(type(optimizer).__name__))
|
||||
self.optimizer = optimizer
|
||||
@ -111,8 +114,8 @@ class _ParamScheduler:
|
||||
return wrapper
|
||||
|
||||
# add counter to optimizer
|
||||
self.optimizer.step = with_counter(self.optimizer.step)
|
||||
self.optimizer._global_step = -1
|
||||
self.optimizer.step = with_counter(self.optimizer.step) # type: ignore
|
||||
self.optimizer._global_step = -1 # type: ignore
|
||||
|
||||
self._global_step = -1
|
||||
self.verbose = verbose
|
||||
@ -218,7 +221,7 @@ class StepParamScheduler(_ParamScheduler):
|
||||
other changes to the parameter value from outside this scheduler.
|
||||
|
||||
Args:
|
||||
optimizer (Optimizer): Wrapped optimizer.
|
||||
optimizer (OptimWrapper or Optimizer): Wrapped optimizer.
|
||||
step_size (int): Period of parameter value decay.
|
||||
gamma (float): Multiplicative factor of parameter value decay.
|
||||
Defaults to 0.1.
|
||||
@ -235,7 +238,7 @@ class StepParamScheduler(_ParamScheduler):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
optimizer: Optimizer,
|
||||
optimizer: OptimizerType,
|
||||
param_name: str,
|
||||
step_size: int,
|
||||
gamma: float = 0.1,
|
||||
@ -304,7 +307,7 @@ class MultiStepParamScheduler(_ParamScheduler):
|
||||
scheduler.
|
||||
|
||||
Args:
|
||||
optimizer (Optimizer): Wrapped optimizer.
|
||||
optimizer (OptimWrapper or Optimizer): Wrapped optimizer.
|
||||
milestones (list): List of epoch indices. Must be increasing.
|
||||
gamma (float): Multiplicative factor of parameter value decay.
|
||||
Defaults to 0.1.
|
||||
@ -321,7 +324,7 @@ class MultiStepParamScheduler(_ParamScheduler):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
optimizer: Optimizer,
|
||||
optimizer: OptimizerType,
|
||||
param_name: str,
|
||||
milestones: List[int],
|
||||
gamma: float = 0.1,
|
||||
@ -391,7 +394,8 @@ class ConstantParamScheduler(_ParamScheduler):
|
||||
parameter value from outside this scheduler.
|
||||
|
||||
Args:
|
||||
optimizer (Optimizer): Wrapped optimizer.
|
||||
optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
|
||||
optimizer.
|
||||
factor (float): The number we multiply parameter value until the
|
||||
milestone. Defaults to 1./3.
|
||||
begin (int): Step at which to start updating the parameters.
|
||||
@ -407,7 +411,7 @@ class ConstantParamScheduler(_ParamScheduler):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
optimizer: Optimizer,
|
||||
optimizer: OptimizerType,
|
||||
param_name: str,
|
||||
factor: float = 1.0 / 3,
|
||||
begin: int = 0,
|
||||
@ -477,7 +481,8 @@ class ExponentialParamScheduler(_ParamScheduler):
|
||||
"""Decays the parameter value of each parameter group by gamma every epoch.
|
||||
|
||||
Args:
|
||||
optimizer (Optimizer): Wrapped optimizer.
|
||||
optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
|
||||
optimizer.
|
||||
gamma (float): Multiplicative factor of parameter value decay.
|
||||
begin (int): Step at which to start updating the parameters.
|
||||
Defaults to 0.
|
||||
@ -492,7 +497,7 @@ class ExponentialParamScheduler(_ParamScheduler):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
optimizer: Optimizer,
|
||||
optimizer: OptimizerType,
|
||||
param_name: str,
|
||||
gamma: float,
|
||||
begin: int = 0,
|
||||
@ -573,7 +578,8 @@ class CosineAnnealingParamScheduler(_ParamScheduler):
|
||||
only implements the cosine annealing part of SGDR, and not the restarts.
|
||||
|
||||
Args:
|
||||
optimizer (Optimizer): Wrapped optimizer.
|
||||
optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
|
||||
optimizer.
|
||||
T_max (int): Maximum number of iterations.
|
||||
eta_min (float): Minimum parameter value. Defaults to 0.
|
||||
begin (int): Step at which to start updating the parameters.
|
||||
@ -592,7 +598,7 @@ class CosineAnnealingParamScheduler(_ParamScheduler):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
optimizer: Optimizer,
|
||||
optimizer: Union[Optimizer, OptimWrapper],
|
||||
param_name: str,
|
||||
T_max: int,
|
||||
eta_min: float = 0.,
|
||||
@ -670,7 +676,8 @@ class LinearParamScheduler(_ParamScheduler):
|
||||
parameter value from outside this scheduler.
|
||||
|
||||
Args:
|
||||
optimizer (Optimizer): Wrapped optimizer.
|
||||
optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
|
||||
optimizer.
|
||||
start_factor (float): The number we multiply parameter value in the
|
||||
first epoch. The multiplication factor changes towards end_factor
|
||||
in the following epochs. Defaults to 1./3.
|
||||
@ -689,7 +696,7 @@ class LinearParamScheduler(_ParamScheduler):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
optimizer: Optimizer,
|
||||
optimizer: Union[Optimizer, OptimWrapper],
|
||||
param_name: str,
|
||||
start_factor: float = 1.0 / 3,
|
||||
end_factor: float = 1.0,
|
||||
@ -765,7 +772,8 @@ class PolyParamScheduler(_ParamScheduler):
|
||||
parameter value from outside this scheduler.
|
||||
|
||||
Args:
|
||||
optimizer (Optimizer): Wrapped optimizer.
|
||||
optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
|
||||
optimizer.
|
||||
eta_min (float): Minimum parameter value at the end of scheduling.
|
||||
Defaults to 0.
|
||||
power (float): The power of the polynomial. Defaults to 1.0.
|
||||
@ -782,7 +790,7 @@ class PolyParamScheduler(_ParamScheduler):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
optimizer: Optimizer,
|
||||
optimizer: Union[Optimizer, OptimWrapper],
|
||||
param_name: str,
|
||||
eta_min: float = 0,
|
||||
power: float = 1.0,
|
||||
|
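Since _ParamScheduler now accepts either an Optimizer or an OptimWrapper, a scheduler can be attached directly to the wrapper; the scheduler simply reads and writes ``optim_wrapper.param_groups``, which proxies the underlying optimizer. A sketch with illustrative milestone values:

import torch.nn as nn
from torch.optim import SGD
from mmengine.optim import MultiStepLR, OptimWrapper

model = nn.Linear(1, 1)
optim_wrapper = OptimWrapper(SGD(model.parameters(), lr=0.1))
scheduler = MultiStepLR(optim_wrapper, milestones=[8, 11], gamma=0.1)
for epoch in range(12):
    ...              # train one epoch
    scheduler.step()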
@@ -2,17 +2,17 @@
 from .default_scope import DefaultScope
 from .registry import Registry, build_from_cfg
 from .root import (DATA_SAMPLERS, DATASETS, HOOKS, LOG_PROCESSORS, LOOPS,
-                   METRICS, MODEL_WRAPPERS, MODELS, OPTIMIZER_CONSTRUCTORS,
-                   OPTIMIZERS, PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS,
-                   TASK_UTILS, TRANSFORMS, VISBACKENDS, VISUALIZERS,
-                   WEIGHT_INITIALIZERS)
+                   METRICS, MODEL_WRAPPERS, MODELS, OPTIM_WRAPPER_CONSTRUCTORS,
+                   OPTIM_WRAPPERS, OPTIMIZERS, PARAM_SCHEDULERS,
+                   RUNNER_CONSTRUCTORS, RUNNERS, TASK_UTILS, TRANSFORMS,
+                   VISBACKENDS, VISUALIZERS, WEIGHT_INITIALIZERS)
 from .utils import count_registered_modules, traverse_registry_tree

 __all__ = [
     'Registry', 'build_from_cfg', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS',
     'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS',
-    'OPTIMIZERS', 'OPTIMIZER_CONSTRUCTORS', 'TASK_UTILS', 'PARAM_SCHEDULERS',
-    'METRICS', 'MODEL_WRAPPERS', 'LOOPS', 'VISBACKENDS', 'VISUALIZERS',
-    'LOG_PROCESSORS', 'DefaultScope', 'traverse_registry_tree',
-    'count_registered_modules'
+    'OPTIMIZERS', 'OPTIM_WRAPPER_CONSTRUCTORS', 'TASK_UTILS',
+    'PARAM_SCHEDULERS', 'METRICS', 'MODEL_WRAPPERS', 'OPTIM_WRAPPERS', 'LOOPS',
+    'VISBACKENDS', 'VISUALIZERS', 'LOG_PROCESSORS', 'DefaultScope',
+    'traverse_registry_tree', 'count_registered_modules'
 ]
@ -32,7 +32,7 @@ WEIGHT_INITIALIZERS = Registry('weight initializer')
|
||||
# manage all kinds of optimizers like `SGD` and `Adam`
|
||||
OPTIMIZERS = Registry('optimizer')
|
||||
# manage constructors that customize the optimization hyperparameters.
|
||||
OPTIMIZER_CONSTRUCTORS = Registry('optimizer constructor')
|
||||
OPTIM_WRAPPER_CONSTRUCTORS = Registry('optimizer wrapper constructor')
|
||||
# manage all kinds of parameter schedulers like `MultiStepLR`
|
||||
PARAM_SCHEDULERS = Registry('parameter scheduler')
|
||||
# manage all kinds of metrics
|
||||
@ -48,3 +48,6 @@ VISBACKENDS = Registry('vis_backend')
|
||||
|
||||
# manage logprocessor
|
||||
LOG_PROCESSORS = Registry('log_processor')
|
||||
|
||||
# manage optimizer wrapper
|
||||
OPTIM_WRAPPERS = Registry('optim_wrapper')
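A hypothetical sketch of how the new ``OPTIM_WRAPPERS`` registry could be used; ``MyOptimWrapper`` is an invented name for illustration and is not part of this commit.

import torch.nn as nn
from torch.optim import SGD

from mmengine.optim import OptimWrapper
from mmengine.registry import OPTIM_WRAPPERS


@OPTIM_WRAPPERS.register_module()
class MyOptimWrapper(OptimWrapper):
    """Custom wrapper that could, for example, log extra statistics."""


# Registered wrappers are built from config like any other component.
model = nn.Linear(1, 1)
wrapper = OPTIM_WRAPPERS.build(
    dict(type='MyOptimWrapper', optimizer=SGD(model.parameters(), lr=0.01)))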
|
||||
|
@ -6,6 +6,7 @@ import random
|
||||
import shutil
|
||||
import time
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from functools import partial
|
||||
from typing import Callable, Dict, List, Optional, Sequence, Union
|
||||
|
||||
@ -25,7 +26,8 @@ from mmengine.evaluator import Evaluator
|
||||
from mmengine.hooks import Hook
|
||||
from mmengine.logging import LogProcessor, MessageHub, MMLogger
|
||||
from mmengine.model import is_model_wrapper
|
||||
from mmengine.optim import _ParamScheduler, build_optimizer
|
||||
from mmengine.optim import (OptimWrapper, OptimWrapperDict, _ParamScheduler,
|
||||
build_optim_wrapper)
|
||||
from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS,
|
||||
MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS,
|
||||
VISUALIZERS, DefaultScope,
|
||||
@ -44,6 +46,7 @@ from .priority import Priority, get_priority
|
||||
ConfigType = Union[Dict, Config, ConfigDict]
|
||||
ParamSchedulerType = Union[List[_ParamScheduler], Dict[str,
|
||||
List[_ParamScheduler]]]
|
||||
OptimWrapperType = Union[OptimWrapper, OptimWrapperDict]
|
||||
|
||||
|
||||
class Runner:
|
||||
@ -97,10 +100,14 @@ class Runner:
|
||||
If ``test_cfg`` specified, :attr:`test_dataloader` should also be
|
||||
specified. Defaults to None.
|
||||
See :meth:`build_test_loop` for more details.
|
||||
optimizer (Optimizer or dict, optional): Computing gradient of model
|
||||
parameters. If specified, :attr:`train_dataloader` should also be
|
||||
specified. Defaults to None.
|
||||
See :meth:`build_optimizer` for examples.
|
||||
optim_wrapper (OptimWrapper or dict, optional):
|
||||
Computing the gradient of model parameters. If specified,
:attr:`train_dataloader` should also be specified. If automatic
mixed precision or gradient accumulation training is required,
the type of ``optim_wrapper`` should be ``AmpOptimWrapper``.
See :meth:`build_optim_wrapper` for examples. Defaults to None.
|
||||
|
||||
param_scheduler (_ParamScheduler or dict or list, optional):
|
||||
Parameter scheduler for updating optimizer parameters. If
|
||||
specified, :attr:`optimizer` should also be specified.
|
||||
@ -177,7 +184,8 @@ class Runner:
|
||||
>>> sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
>>> batch_size=1,
|
||||
>>> num_workers=0),
|
||||
>>> optimizer=dict(type='SGD', lr=0.01),
|
||||
>>> optim_wrapper=dict(type='OptimizerWrapper', optimizer=dict(
|
||||
>>> type='SGD', lr=0.01)),
|
||||
>>> param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]),
|
||||
>>> val_evaluator=dict(type='ToyEvaluator'),
|
||||
>>> test_evaluator=dict(type='ToyEvaluator'),
|
||||
@ -217,7 +225,7 @@ class Runner:
|
||||
train_cfg: Optional[Dict] = None,
|
||||
val_cfg: Optional[Dict] = None,
|
||||
test_cfg: Optional[Dict] = None,
|
||||
optimizer: Optional[Union[Optimizer, Dict]] = None,
|
||||
optim_wrapper: Optional[Union[OptimWrapper, Dict]] = None,
|
||||
param_scheduler: Optional[Union[_ParamScheduler, Dict, List]] = None,
|
||||
val_evaluator: Optional[Union[Evaluator, Dict, List]] = None,
|
||||
test_evaluator: Optional[Union[Evaluator, Dict, List]] = None,
|
||||
@ -249,7 +257,7 @@ class Runner:
|
||||
self.cfg = Config(dict())
|
||||
|
||||
# lazy initialization
|
||||
training_related = [train_dataloader, train_cfg, optimizer]
|
||||
training_related = [train_dataloader, train_cfg, optim_wrapper]
|
||||
if not (all(item is None for item in training_related)
|
||||
or all(item is not None for item in training_related)):
|
||||
raise ValueError(
|
||||
@ -257,14 +265,16 @@ class Runner:
|
||||
'all None or not None, but got '
|
||||
f'train_dataloader={train_dataloader}, '
|
||||
f'train_cfg={train_cfg}, '
|
||||
f'optimizer={optimizer}.')
|
||||
f'optim_wrapper={optim_wrapper}.')
|
||||
self._train_dataloader = train_dataloader
|
||||
self._train_loop = train_cfg
|
||||
self.optimizer = optimizer
|
||||
|
||||
self.optim_wrapper: Optional[Union[OptimWrapper, dict]]
|
||||
self.optim_wrapper = optim_wrapper
|
||||
|
||||
# If there is no need to adjust learning rate, momentum or other
|
||||
# parameters of optimizer, param_scheduler can be None
|
||||
if param_scheduler is not None and self.optimizer is None:
|
||||
if param_scheduler is not None and self.optim_wrapper is None:
|
||||
raise ValueError(
|
||||
'param_scheduler should be None when optim_wrapper is None, '
|
||||
f'but got {param_scheduler}')
|
||||
@ -400,7 +410,7 @@ class Runner:
|
||||
train_cfg=cfg.get('train_cfg'),
|
||||
val_cfg=cfg.get('val_cfg'),
|
||||
test_cfg=cfg.get('test_cfg'),
|
||||
optimizer=cfg.get('optimizer'),
|
||||
optim_wrapper=cfg.get('optim_wrapper'),
|
||||
param_scheduler=cfg.get('param_scheduler'),
|
||||
val_evaluator=cfg.get('val_evaluator'),
|
||||
test_evaluator=cfg.get('test_evaluator'),
|
||||
@ -803,21 +813,25 @@ class Runner:
|
||||
|
||||
return model
|
||||
|
||||
def build_optimizer(
|
||||
self, optimizer: Union[Optimizer, Dict]
|
||||
) -> Union[Optimizer, Dict[str, Optimizer]]:
|
||||
"""Build an optimizer or multiple optimizers.
|
||||
def build_optim_wrapper(
|
||||
self, optim_wrapper: Union[Optimizer, OptimWrapper, Dict]
|
||||
) -> Union[OptimWrapper, OptimWrapperDict]:
|
||||
"""Build optimizer wrapper.
|
||||
|
||||
Args:
|
||||
optimizer (Optimizer or dict): An Optimizer object or a dict to
|
||||
build Optimizer objects. If ``optimizer`` is an Optimizer
|
||||
object, just returns itself.
|
||||
optim_wrapper (OptimWrapper or dict): An OptimWrapper object or a
|
||||
dict to build OptimWrapper objects. If ``optim_wrapper`` is an
|
||||
OptimWrapper, it is returned directly.
|
||||
|
||||
Examples:
|
||||
>>> # build an optimizer
|
||||
>>> optim_cfg = dict(type='SGD', lr=0.01)
|
||||
>>> optimizer = runner.build_optimizer(optim_cfg)
|
||||
>>> optimizer
|
||||
>>> optim_wrapper_cfg = dict(type='OptimWrapper', optimizer=dict(
|
||||
... type='SGD', lr=0.01))
|
||||
>>> optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
|
||||
>>> optim_wrapper
|
||||
Type: OptimWrapper
|
||||
accumulative_iters: 1
|
||||
optimizer:
|
||||
SGD (
|
||||
Parameter Group 0
|
||||
dampening: 0
|
||||
@ -828,71 +842,85 @@ class Runner:
|
||||
)
|
||||
|
||||
>>> # build multiple optimizers
|
||||
>>> optim_cfg = dict(
|
||||
... generator=dict(type='SGD', lr=0.01),
|
||||
... discriminator=dict(type='Adam',lr=0.02)
|
||||
>>> optim_wrapper_cfg = dict(
|
||||
... generator=dict(type='OptimWrapper', optimizer=dict(
|
||||
... type='SGD', lr=0.01)),
|
||||
... discriminator=dict(type='OptimWrapper', optimizer=dict(
|
||||
... type='Adam', lr=0.001))
|
||||
... # need to customize a multiple optimizer constructor
|
||||
... constructor='CustomizedMultipleOptimizersConstructor',
|
||||
...)
|
||||
>>> optimizer = runner.build_optimizer(optim_cfg)
|
||||
>>> optimizer
|
||||
{'generator': SGD (
|
||||
>>> optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
|
||||
>>> optim_wrapper
|
||||
name: generator
|
||||
Type: OptimWrapper
|
||||
accumulative_iters: 1
|
||||
optimizer:
|
||||
SGD (
|
||||
Parameter Group 0
|
||||
dampening: 0
|
||||
lr: 0.01
|
||||
lr: 0.1
|
||||
momentum: 0
|
||||
nesterov: False
|
||||
weight_decay: 0
|
||||
),
|
||||
'discriminator': SGD (
|
||||
)
|
||||
name: discriminator
|
||||
Type: OptimWrapper
|
||||
accumulative_iters: 1
|
||||
optimizer:
|
||||
Adam (
|
||||
Parameter Group 0
|
||||
dampening: 0
|
||||
lr: 0.02
|
||||
momentum: 0
|
||||
nesterov: False
|
||||
weight_decay: 0
|
||||
)}
|
||||
)
|
||||
|
||||
Important:
|
||||
If you need to build multiple optimizers, you should implement a
|
||||
MultipleOptimizerConstructor which gets parameters passed to
|
||||
corresponding optimizers. More details about how to customize
|
||||
OptimizerConstructor can be found at `optimizer-docs`_.
|
||||
corresponding optimizers and composes them into an ``OptimWrapperDict``.
|
||||
More details about how to customize OptimizerConstructor can be
|
||||
found at `optimizer-docs`_.
|
||||
|
||||
Returns:
|
||||
Optimizer or dict[str, Optimizer]: Optimizer build from
|
||||
``optimizer``.
|
||||
OptimWrapper: Optimizer wrapper built from ``optim_wrapper``.
|
||||
|
||||
.. _optimizer-docs:
|
||||
https://mmengine.readthedocs.io/en/latest/tutorials/optimizer.html
|
||||
"""
|
||||
if isinstance(optimizer, Optimizer):
|
||||
return optimizer
|
||||
elif isinstance(optimizer, dict):
|
||||
if 'type' not in optimizer and 'constructor' not in optimizer:
|
||||
for name, optim in optimizer.items():
|
||||
if not isinstance(optim, Optimizer):
|
||||
if isinstance(optim_wrapper, OptimWrapper):
|
||||
return optim_wrapper
|
||||
elif isinstance(optim_wrapper, (dict, ConfigDict, Config)):
|
||||
if 'type' not in optim_wrapper and ('constructor'
|
||||
not in optim_wrapper):
|
||||
optim_wrappers = OrderedDict()
|
||||
for name, optim in optim_wrapper.items():
|
||||
if not isinstance(optim, OptimWrapper):
|
||||
raise ValueError(
|
||||
'each item must be an OptimWrapper object when "type"'
' and "constructor" are not set in optim_wrapper, '
|
||||
f'but got {name}={optim}')
|
||||
return optimizer
|
||||
|
||||
return build_optimizer(self.model, optimizer)
|
||||
optim_wrappers[name] = optim
|
||||
return OptimWrapperDict(**optim_wrappers)
|
||||
else:
|
||||
optim_wrapper = build_optim_wrapper(self.model, optim_wrapper)
|
||||
return optim_wrapper
|
||||
else:
|
||||
raise TypeError('optimizer should be an Optimizer object or dict, '
|
||||
f'but got {optimizer}')
|
||||
raise TypeError('optimizer wrapper should be an OptimWrapper '
|
||||
f'object or dict, but got {optim_wrapper}')
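The ``OptimWrapperDict`` branch above can also be reproduced standalone. The sketch below only restates the logic of this method with toy modules for illustration.

import torch.nn as nn
from torch.optim import SGD, Adam

from mmengine.optim import OptimWrapper, OptimWrapperDict

gen, disc = nn.Linear(1, 1), nn.Linear(1, 1)

# Equivalent of the "no `type` and no `constructor`" branch: every value is
# already an OptimWrapper, so the wrappers are simply collected into an
# OptimWrapperDict keyed by sub-model name.
optim_wrapper_dict = OptimWrapperDict(
    generator=OptimWrapper(SGD(gen.parameters(), lr=0.01)),
    discriminator=OptimWrapper(Adam(disc.parameters(), lr=0.001)))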
|
||||
|
||||
def _build_param_scheduler(self, scheduler: Union[_ParamScheduler, Dict,
|
||||
List],
|
||||
optimizer: Optimizer) -> List[_ParamScheduler]:
|
||||
def _build_param_scheduler(
|
||||
self, scheduler: Union[_ParamScheduler, Dict, List],
|
||||
optim_wrapper: OptimWrapper) -> List[_ParamScheduler]:
|
||||
"""Build parameter schedulers for a single optimizer.
|
||||
|
||||
Args:
|
||||
scheduler (_ParamScheduler or dict or list): A Param Scheduler
|
||||
object or a dict or list of dict to build parameter schedulers.
|
||||
optimizer (Optimizer): An optimizer object is passed to construnct
|
||||
ParamScheduler object.
|
||||
optim_wrapper (OptimWrapper): An optimizer wrapper object is
|
||||
passed to construct ParamScheduler object.
|
||||
|
||||
Returns:
|
||||
list[_ParamScheduler]: List of parameter schedulers build from
|
||||
@ -922,7 +950,7 @@ class Runner:
|
||||
cls = PARAM_SCHEDULERS.get(_scheduler.pop('type'))
|
||||
param_schedulers.append(
|
||||
cls.build_iter_from_epoch( # type: ignore
|
||||
optimizer=self.optimizer,
|
||||
optimizer=optim_wrapper,
|
||||
**_scheduler,
|
||||
epoch_length=len(
|
||||
self.train_dataloader), # type: ignore
|
||||
@ -931,11 +959,11 @@ class Runner:
|
||||
param_schedulers.append(
|
||||
PARAM_SCHEDULERS.build(
|
||||
_scheduler,
|
||||
default_args=dict(optimizer=optimizer)))
|
||||
default_args=dict(optimizer=optim_wrapper)))
|
||||
else:
|
||||
raise TypeError(
|
||||
'_scheduler should be a _ParamScheduler object or dict, '
|
||||
f'but got {_scheduler}')
|
||||
'scheduler should be a _ParamScheduler object or dict, '
|
||||
f'but got {scheduler}')
|
||||
|
||||
return param_schedulers
|
||||
|
||||
@ -944,9 +972,10 @@ class Runner:
|
||||
List]) -> ParamSchedulerType:
|
||||
"""Build parameter schedulers.
|
||||
|
||||
``build_param_scheduler`` should be called after ``build_optimizer``
|
||||
because the building logic will change according to the number of
|
||||
optimizers built by the runner. The cases are as below:
|
||||
``build_param_scheduler`` should be called after
|
||||
``build_optim_wrapper`` because the building logic will change
|
||||
according to the number of optimizers built by the runner.
|
||||
The cases are as below:
|
||||
|
||||
- Single optimizer: When only one optimizer is built and used in the
|
||||
runner, ``build_param_scheduler`` will return a list of
|
||||
@ -968,7 +997,8 @@ class Runner:
|
||||
Examples:
|
||||
>>> # build one scheduler
|
||||
>>> optim_cfg = dict(dict(type='SGD', lr=0.01))
|
||||
>>> runner.optimizer = runner.build_optimizer(optim_cfg)
|
||||
>>> runner.optim_wrapper = runner.build_optim_wrapper(
|
||||
>>> optim_cfg)
|
||||
>>> scheduler_cfg = dict(type='MultiStepLR', milestones=[1, 2])
|
||||
>>> schedulers = runner.build_param_scheduler(scheduler_cfg)
|
||||
>>> schedulers
|
||||
@ -998,20 +1028,23 @@ class Runner:
|
||||
https://mmengine.readthedocs.io/en/latest/tutorials/optimizer.html
|
||||
"""
|
||||
param_schedulers: ParamSchedulerType
|
||||
if isinstance(self.optimizer, Optimizer):
|
||||
if not isinstance(self.optim_wrapper, OptimWrapperDict):
|
||||
# Since `OptimWrapperDict` inherits from `OptimWrapper`,
|
||||
# `isinstance(self.optim_wrapper, OptimWrapper)` cannot tell
|
||||
# whether `self.optim_wrapper` is an `OptimizerWrapper` or
|
||||
# `OptimWrapperDict` instance. Therefore, here we simply check
|
||||
# self.optim_wrapper is not an `OptimWrapperDict` instance and
|
||||
# then assert it is an OptimWrapper instance.
|
||||
assert isinstance(self.optim_wrapper, OptimWrapper), (
|
||||
'`build_optim_wrapper` should be called before '
|
||||
'`build_param_scheduler` because the latter depends '
|
||||
'on the former')
|
||||
param_schedulers = self._build_param_scheduler(
|
||||
scheduler, self.optimizer)
|
||||
scheduler, self.optim_wrapper) # type: ignore
|
||||
return param_schedulers
|
||||
else:
|
||||
assert isinstance(self.optimizer, dict)
|
||||
param_schedulers = dict()
|
||||
for name, optimizer in self.optimizer.items():
|
||||
if not isinstance(optimizer, Optimizer):
|
||||
raise RuntimeError(
|
||||
'`build_optimizer` should be called before'
|
||||
'`build_param_scheduler` because the latter depends '
|
||||
'on the former')
|
||||
|
||||
for name, optimizer in self.optim_wrapper.items():
|
||||
if isinstance(scheduler, dict) and 'type' not in scheduler:
|
||||
# scheduler is a dict and each item is a ParamScheduler
|
||||
# object or a config to build ParamScheduler objects
|
||||
@ -1356,7 +1389,7 @@ class Runner:
|
||||
|
||||
# `build_optim_wrapper` should be called before `build_param_scheduler`
|
||||
# because the latter depends on the former
|
||||
self.optimizer = self.build_optimizer(self.optimizer)
|
||||
self.optim_wrapper = self.build_optim_wrapper(self.optim_wrapper)
|
||||
|
||||
if self.param_schedulers:
|
||||
self.param_schedulers = self.build_param_scheduler( # type: ignore
|
||||
@ -1418,9 +1451,6 @@ class Runner:
|
||||
fn_name (str): The function name in each hook to be called, such as
|
||||
"before_train_epoch".
|
||||
**kwargs: Keyword arguments passed to hook.
|
||||
|
||||
Raises:
|
||||
TypeError: if Hook got unexpected arguments.
|
||||
"""
|
||||
for hook in self._hooks:
|
||||
# support adding additional custom hook methods
|
||||
@ -1645,12 +1675,9 @@ class Runner:
|
||||
|
||||
# resume optimizer
|
||||
if 'optimizer' in checkpoint and resume_optimizer:
|
||||
self.optimizer = self.build_optimizer(self.optimizer)
|
||||
if isinstance(self.optimizer, dict):
|
||||
for name, optimizer in self.optimizer.items():
|
||||
optimizer.load_state_dict(checkpoint['optimizer'][name])
|
||||
else:
|
||||
self.optimizer.load_state_dict(checkpoint['optimizer'])
|
||||
self.optim_wrapper = self.build_optim_wrapper(self.optim_wrapper)
|
||||
self.optim_wrapper.load_state_dict( # type: ignore
|
||||
checkpoint['optimizer'])
|
||||
|
||||
# resume param scheduler
|
||||
if 'param_schedulers' in checkpoint and resume_param_scheduler:
|
||||
@ -1771,16 +1798,13 @@ class Runner:
|
||||
}
|
||||
# save optimizer state dict to checkpoint
|
||||
if save_optimizer:
|
||||
if isinstance(self.optimizer, Optimizer):
|
||||
checkpoint['optimizer'] = self.optimizer.state_dict()
|
||||
elif isinstance(self.optimizer, dict):
|
||||
checkpoint['optimizer'] = dict()
|
||||
for name, optimizer in self.optimizer.items():
|
||||
checkpoint['optimizer'][name] = optimizer.state_dict()
|
||||
if isinstance(self.optim_wrapper, OptimWrapper):
|
||||
checkpoint['optimizer'] = self.optim_wrapper.state_dict()
|
||||
else:
|
||||
raise TypeError(
|
||||
'self.optimizer should be an optimizer or a dict '
|
||||
f'containing optimizer, but got {self.optimizer}')
|
||||
'self.optim_wrapper should be an `OptimWrapper` '
|
||||
'or `OptimWrapperDict` instance, but got '
|
||||
f'{self.optim_wrapper}')
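Because both ``OptimWrapper`` and ``OptimWrapperDict`` expose ``state_dict()`` and ``load_state_dict()``, saving and resuming no longer branch on a dict of optimizers. A minimal standalone sketch with a toy model, outside the Runner:

import torch.nn as nn
from torch.optim import SGD

from mmengine.optim import OptimWrapper

model = nn.Linear(1, 1)
optim_wrapper = OptimWrapper(SGD(model.parameters(), lr=0.01))

checkpoint = {'optimizer': optim_wrapper.state_dict()}   # save
optim_wrapper.load_state_dict(checkpoint['optimizer'])   # resume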
|
||||
|
||||
# save param scheduler state dict
|
||||
if save_param_scheduler:
|
||||
|
@ -2,7 +2,7 @@
|
||||
from .hub import load_url
|
||||
from .manager import ManagerMeta, ManagerMixin
|
||||
from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
|
||||
find_latest_checkpoint, has_method,
|
||||
find_latest_checkpoint, has_batch_norm, has_method,
|
||||
import_modules_from_strings, is_list_of,
|
||||
is_method_overridden, is_seq_of, is_str, is_tuple_of,
|
||||
iter_cast, list_cast, mmcv_full_available,
|
||||
@ -28,5 +28,5 @@ __all__ = [
|
||||
'is_method_overridden', 'has_method', 'mmcv_full_available',
|
||||
'digit_version', 'get_git_hash', 'TORCH_VERSION', 'load_url',
|
||||
'find_latest_checkpoint', 'ManagerMeta', 'ManagerMixin',
|
||||
'set_multi_processing'
|
||||
'set_multi_processing', 'has_batch_norm'
|
||||
]
|
||||
|
@ -514,3 +514,20 @@ def find_latest_checkpoint(path: str, suffix: str = 'pth'):
|
||||
latest_path = checkpoint
|
||||
|
||||
return latest_path
|
||||
|
||||
|
||||
def has_batch_norm(model: nn.Module) -> bool:
|
||||
"""Detect whether model has a BatchNormalization layer.
|
||||
|
||||
Args:
|
||||
model (nn.Module): training model.
|
||||
|
||||
Returns:
|
||||
bool: whether model has a BatchNormalization layer
|
||||
"""
|
||||
if isinstance(model, _BatchNorm):
|
||||
return True
|
||||
for m in model.children():
|
||||
if has_batch_norm(m):
|
||||
return True
|
||||
return False
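A usage sketch for the new helper, assuming the ``mmengine.utils`` import path that the ``__all__`` change above exposes:

import torch.nn as nn

from mmengine.utils import has_batch_norm

# True: the Sequential contains a BatchNorm2d child.
assert has_batch_norm(nn.Sequential(nn.Conv2d(1, 1, 1), nn.BatchNorm2d(1)))
# False: a plain Linear layer has no batch normalization.
assert not has_batch_norm(nn.Linear(1, 1))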
|
||||
|
@ -10,6 +10,7 @@ from torch.utils.data import Dataset
|
||||
|
||||
from mmengine.hooks import EMAHook
|
||||
from mmengine.model import ExponentialMovingAverage
|
||||
from mmengine.optim import OptimWrapper
|
||||
from mmengine.registry import DATASETS, MODEL_WRAPPERS
|
||||
from mmengine.runner import Runner
|
||||
|
||||
@ -79,7 +80,8 @@ class TestEMAHook(TestCase):
|
||||
num_workers=0),
|
||||
val_evaluator=evaluator,
|
||||
work_dir=self.temp_dir.name,
|
||||
optimizer=torch.optim.Adam(ToyModel().parameters()),
|
||||
optim_wrapper=OptimWrapper(
|
||||
torch.optim.Adam(ToyModel().parameters())),
|
||||
train_cfg=dict(by_epoch=True, max_epochs=2),
|
||||
val_cfg=dict(interval=1),
|
||||
default_hooks=dict(logger=None),
|
||||
|
@ -46,8 +46,8 @@ class TestOptimizerHook:
|
||||
x = torch.rand(1, 1, 3, 3)
|
||||
|
||||
dummy_runner = MagicMock()
|
||||
dummy_runner.optimizer.zero_grad = Mock(return_value=None)
|
||||
dummy_runner.optimizer.step = Mock(return_value=None)
|
||||
dummy_runner.optim_wrapper.zero_grad = Mock(return_value=None)
|
||||
dummy_runner.optim_wrapper.step = Mock(return_value=None)
|
||||
dummy_runner.model = model
|
||||
dummy_runner.outputs = dict()
|
||||
|
||||
@ -82,7 +82,7 @@ class TestOptimizerHook:
|
||||
assert 'conv3.bias' in dummy_runner.logger.msg
|
||||
assert 'conv1.weight' not in dummy_runner.logger.msg
|
||||
assert 'conv1.bias' not in dummy_runner.logger.msg
|
||||
dummy_runner.optimizer.step.assert_called()
|
||||
dummy_runner.optim_wrapper.step.assert_called()
|
||||
dummy_runner.outputs['loss'].backward.assert_called()
|
||||
optimizer_hook.clip_grads.assert_called()
|
||||
optimizer_hook.detect_anomalous_parameters.assert_called()
|
||||
@ -109,7 +109,7 @@ class TestOptimizerHook:
|
||||
|
||||
optimizer_hook.after_train_iter(dummy_runner, 0)
|
||||
|
||||
dummy_runner.optimizer.step.assert_called()
|
||||
dummy_runner.optim_wrapper.step.assert_called()
|
||||
dummy_runner.outputs['loss'].backward.assert_called()
|
||||
optimizer_hook.clip_grads.assert_not_called()
|
||||
optimizer_hook.detect_anomalous_parameters.assert_not_called()
|
||||
|
@ -2,8 +2,12 @@
|
||||
from unittest import TestCase
|
||||
from unittest.mock import Mock
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.optim import SGD
|
||||
|
||||
from mmengine.hooks import RuntimeInfoHook
|
||||
from mmengine.logging import MessageHub
|
||||
from mmengine.optim import OptimWrapper, OptimWrapperDict
|
||||
|
||||
|
||||
class TestRuntimeInfoHook(TestCase):
|
||||
@ -47,18 +51,31 @@ class TestRuntimeInfoHook(TestCase):
|
||||
self.assertEqual(message_hub.get_info('epoch'), 9)
|
||||
|
||||
def test_before_train_iter(self):
|
||||
model = nn.Linear(1, 1)
|
||||
optim1 = SGD(model.parameters(), lr=0.01)
|
||||
optim2 = SGD(model.parameters(), lr=0.02)
|
||||
optim_wrapper1 = OptimWrapper(optim1)
|
||||
optim_wrapper2 = OptimWrapper(optim2)
|
||||
optim_wrapper_dict = OptimWrapperDict(
|
||||
key1=optim_wrapper1, key2=optim_wrapper2)
|
||||
# single optimizer
|
||||
message_hub = MessageHub.get_instance(
|
||||
'runtime_info_hook_test_before_train_iter')
|
||||
runner = Mock()
|
||||
runner.iter = 9
|
||||
runner.optimizer.param_groups = [{'lr': 0.01}]
|
||||
runner.optim_wrapper = optim_wrapper1
|
||||
runner.message_hub = message_hub
|
||||
hook = RuntimeInfoHook()
|
||||
hook.before_train_iter(runner, batch_idx=2, data_batch=None)
|
||||
self.assertEqual(message_hub.get_info('iter'), 9)
|
||||
self.assertEqual(message_hub.get_scalar('train/lr').current(), 0.01)
|
||||
|
||||
with self.assertRaisesRegex(AssertionError,
|
||||
'runner.optim_wrapper.get_lr()'):
|
||||
runner.optim_wrapper = Mock()
|
||||
runner.optim_wrapper.get_lr = Mock(return_value='error type')
|
||||
hook.before_train_iter(runner, batch_idx=2, data_batch=None)
|
||||
|
||||
# multiple optimizers
|
||||
message_hub = MessageHub.get_instance(
|
||||
'runtime_info_hook_test_before_train_iter')
|
||||
@ -68,8 +85,8 @@ class TestRuntimeInfoHook(TestCase):
|
||||
optimizer1.param_groups = [{'lr': 0.01}]
|
||||
optimizer2 = Mock()
|
||||
optimizer2.param_groups = [{'lr': 0.02}]
|
||||
runner.optimizer = dict(key1=optimizer1, key2=optimizer2)
|
||||
runner.message_hub = message_hub
|
||||
runner.optim_wrapper = optim_wrapper_dict
|
||||
hook = RuntimeInfoHook()
|
||||
hook.before_train_iter(runner, batch_idx=2, data_batch=None)
|
||||
self.assertEqual(message_hub.get_info('iter'), 9)
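For reference, a sketch of the values being logged here. The single-wrapper return value matches the ``get_lr`` test added below; the key format for ``OptimWrapperDict`` (name-prefixed, e.g. ``key1.lr``) is an assumption inferred from the ``train/{name}`` scalar keys, not something spelled out in this diff.

import torch.nn as nn
from torch.optim import SGD

from mmengine.optim import OptimWrapper, OptimWrapperDict

model = nn.Linear(1, 1)
wrapper = OptimWrapper(SGD(model.parameters(), lr=0.01))
print(wrapper.get_lr())   # {'lr': [0.01]}

wrappers = OptimWrapperDict(
    key1=OptimWrapper(SGD(model.parameters(), lr=0.01)),
    key2=OptimWrapper(SGD(model.parameters(), lr=0.02)))
print(wrappers.get_lr())  # assumed: {'key1.lr': [0.01], 'key2.lr': [0.02]}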
|
||||
|
@ -37,6 +37,12 @@ def test_is_model_wrapper():
|
||||
if hasattr(torch.distributed, '_verify_model_across_ranks'):
|
||||
torch.distributed._verify_model_across_ranks = mock
|
||||
|
||||
# `_verify_params_across_processes` replaces `_verify_model_across_ranks`
# in torch 1.11.0, so we should check whether it is a member of
# torch.distributed before mocking
|
||||
if hasattr(torch.distributed, '_verify_params_across_processes'):
|
||||
torch.distributed._verify_params_across_processes = mock
|
||||
|
||||
model = Model()
|
||||
assert not is_model_wrapper(model)
|
||||
|
||||
|
@ -6,8 +6,9 @@ from unittest.mock import MagicMock
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from mmengine.optim import (OPTIMIZER_CONSTRUCTORS, OPTIMIZERS,
|
||||
DefaultOptimizerConstructor, build_optimizer)
|
||||
from mmengine.optim import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
|
||||
DefaultOptimWrapperConstructor, OptimWrapper,
|
||||
build_optim_wrapper)
|
||||
from mmengine.optim.optimizer.builder import TORCH_OPTIMIZERS
|
||||
from mmengine.registry import build_from_cfg
|
||||
from mmengine.utils import mmcv_full_available
|
||||
@ -201,30 +202,52 @@ class TestBuilder(TestCase):
|
||||
|
||||
def test_build_optimizer(self):
|
||||
# test build function without ``constructor`` and ``paramwise_cfg``
|
||||
optimizer_cfg = dict(
|
||||
type='SGD',
|
||||
lr=self.base_lr,
|
||||
weight_decay=self.base_wd,
|
||||
momentum=self.momentum)
|
||||
optimizer = build_optimizer(self.model, optimizer_cfg)
|
||||
self._check_default_optimizer(optimizer, self.model)
|
||||
optim_wrapper_cfg = dict(
|
||||
type='OptimWrapper',
|
||||
optimizer=dict(
|
||||
type='SGD',
|
||||
lr=self.base_lr,
|
||||
weight_decay=self.base_wd,
|
||||
momentum=self.momentum))
|
||||
optim_wrapper = build_optim_wrapper(self.model, optim_wrapper_cfg)
|
||||
self._check_default_optimizer(optim_wrapper.optimizer, self.model)
|
||||
|
||||
# test build optimizer without type in optim_wrapper_cfg
|
||||
optim_wrapper_cfg = dict(
|
||||
optimizer=dict(
|
||||
type='SGD',
|
||||
lr=self.base_lr,
|
||||
weight_decay=self.base_wd,
|
||||
momentum=self.momentum))
|
||||
optim_wrapper = build_optim_wrapper(self.model, optim_wrapper_cfg)
|
||||
self.assertIsInstance(optim_wrapper, OptimWrapper)
|
||||
self._check_default_optimizer(optim_wrapper.optimizer, self.model)
|
||||
|
||||
# test build function with invalid ``constructor``
|
||||
with self.assertRaises(KeyError):
|
||||
optimizer_cfg['constructor'] = 'INVALID_CONSTRUCTOR'
|
||||
build_optimizer(self.model, optimizer_cfg)
|
||||
optim_wrapper_cfg['constructor'] = 'INVALID_CONSTRUCTOR'
|
||||
build_optim_wrapper(self.model, optim_wrapper_cfg)
|
||||
|
||||
# test build function with invalid ``paramwise_cfg``
|
||||
with self.assertRaises(KeyError):
|
||||
optimizer_cfg['paramwise_cfg'] = dict(invalid_mult=1)
|
||||
build_optimizer(self.model, optimizer_cfg)
|
||||
optim_wrapper_cfg['paramwise_cfg'] = dict(invalid_mult=1)
|
||||
build_optim_wrapper(self.model, optim_wrapper_cfg)
|
||||
|
||||
optim_wrapper_cfg.pop('optimizer')
|
||||
optim_wrapper_cfg.pop('constructor')
|
||||
optim_wrapper_cfg.pop('paramwise_cfg')
|
||||
self.assertRaisesRegex(
|
||||
AssertionError, '`optim_wrapper_cfg` must contain',
|
||||
lambda: build_optim_wrapper(self.model, optim_wrapper_cfg))
|
||||
|
||||
def test_build_default_optimizer_constructor(self):
|
||||
optimizer_cfg = dict(
|
||||
type='SGD',
|
||||
lr=self.base_lr,
|
||||
weight_decay=self.base_wd,
|
||||
momentum=self.momentum)
|
||||
optim_wrapper = dict(
|
||||
type='OptimWrapper',
|
||||
optimizer=dict(
|
||||
type='SGD',
|
||||
lr=self.base_lr,
|
||||
weight_decay=self.base_wd,
|
||||
momentum=self.momentum))
|
||||
paramwise_cfg = dict(
|
||||
bias_lr_mult=2,
|
||||
bias_decay_mult=0.5,
|
||||
@ -232,22 +255,26 @@ class TestBuilder(TestCase):
|
||||
dwconv_decay_mult=0.1,
|
||||
dcn_offset_lr_mult=0.1)
|
||||
optim_constructor_cfg = dict(
|
||||
type='DefaultOptimizerConstructor',
|
||||
optimizer_cfg=optimizer_cfg,
|
||||
type='DefaultOptimWrapperConstructor',
|
||||
optim_wrapper_cfg=optim_wrapper,
|
||||
paramwise_cfg=paramwise_cfg)
|
||||
optim_constructor = OPTIMIZER_CONSTRUCTORS.build(optim_constructor_cfg)
|
||||
optimizer = optim_constructor(self.model)
|
||||
self._check_sgd_optimizer(optimizer, self.model, **paramwise_cfg)
|
||||
optim_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build(
|
||||
optim_constructor_cfg)
|
||||
optim_wrapper = optim_constructor(self.model)
|
||||
self._check_sgd_optimizer(optim_wrapper.optimizer, self.model,
|
||||
**paramwise_cfg)
|
||||
|
||||
def test_build_custom_optimizer_constructor(self):
|
||||
optimizer_cfg = dict(
|
||||
type='SGD',
|
||||
lr=self.base_lr,
|
||||
weight_decay=self.base_wd,
|
||||
momentum=self.momentum)
|
||||
optim_wrapper_cfg = dict(
|
||||
type='OptimWrapper',
|
||||
optimizer=dict(
|
||||
type='SGD',
|
||||
lr=self.base_lr,
|
||||
weight_decay=self.base_wd,
|
||||
momentum=self.momentum))
|
||||
|
||||
@OPTIMIZER_CONSTRUCTORS.register_module()
|
||||
class MyOptimizerConstructor(DefaultOptimizerConstructor):
|
||||
@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
|
||||
class MyOptimizerConstructor(DefaultOptimWrapperConstructor):
|
||||
|
||||
def __call__(self, model):
|
||||
if hasattr(model, 'module'):
|
||||
@ -268,9 +295,10 @@ class TestBuilder(TestCase):
|
||||
paramwise_cfg = dict(conv1_lr_mult=5)
|
||||
optim_constructor_cfg = dict(
|
||||
type='MyOptimizerConstructor',
|
||||
optimizer_cfg=optimizer_cfg,
|
||||
optim_wrapper_cfg=optim_wrapper_cfg,
|
||||
paramwise_cfg=paramwise_cfg)
|
||||
optim_constructor = OPTIMIZER_CONSTRUCTORS.build(optim_constructor_cfg)
|
||||
optim_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build(
|
||||
optim_constructor_cfg)
|
||||
optimizer = optim_constructor(self.model)
|
||||
|
||||
param_groups = optimizer.param_groups
|
||||
@ -291,153 +319,182 @@ class TestBuilder(TestCase):
|
||||
with self.assertRaises(TypeError):
|
||||
# optimizer_cfg must be a dict
|
||||
optimizer_cfg = []
|
||||
optim_constructor = DefaultOptimizerConstructor(optimizer_cfg)
|
||||
optim_constructor = DefaultOptimWrapperConstructor(optimizer_cfg)
|
||||
optim_constructor(self.model)
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
# paramwise_cfg must be a dict or None
|
||||
optimizer_cfg = dict(lr=0.0001)
|
||||
optim_wrapper_cfg = dict(
|
||||
type='OptimWrapper',
|
||||
optimizer=dict(lr=0.0001, weight_decay=None))
|
||||
paramwise_cfg = ['error']
|
||||
optim_constructor = DefaultOptimizerConstructor(
|
||||
optimizer_cfg, paramwise_cfg)
|
||||
optim_constructor = DefaultOptimWrapperConstructor(
|
||||
optim_wrapper_cfg, paramwise_cfg)
|
||||
optim_constructor(self.model)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# bias_decay_mult/norm_decay_mult is specified but weight_decay
|
||||
# is None
|
||||
optimizer_cfg = dict(lr=0.0001, weight_decay=None)
|
||||
optim_wrapper_cfg = dict(
|
||||
type='OptimWrapper',
|
||||
optimizer=dict(lr=0.0001, weight_decay=None))
|
||||
paramwise_cfg = dict(bias_decay_mult=1, norm_decay_mult=1)
|
||||
optim_constructor = DefaultOptimizerConstructor(
|
||||
optimizer_cfg, paramwise_cfg)
|
||||
optim_constructor = DefaultOptimWrapperConstructor(
|
||||
optim_wrapper_cfg, paramwise_cfg)
|
||||
optim_constructor(self.model)
|
||||
|
||||
# basic config with ExampleModel
|
||||
optimizer_cfg = dict(
|
||||
type='SGD',
|
||||
lr=self.base_lr,
|
||||
weight_decay=self.base_wd,
|
||||
momentum=self.momentum)
|
||||
optim_constructor = DefaultOptimizerConstructor(optimizer_cfg)
|
||||
optimizer = optim_constructor(self.model)
|
||||
self._check_default_optimizer(optimizer, self.model)
|
||||
type='OptimWrapper',
|
||||
optimizer=dict(
|
||||
type='SGD',
|
||||
lr=self.base_lr,
|
||||
weight_decay=self.base_wd,
|
||||
momentum=self.momentum))
|
||||
optim_constructor = DefaultOptimWrapperConstructor(optimizer_cfg)
|
||||
optim_wrapper = optim_constructor(self.model)
|
||||
self._check_default_optimizer(optim_wrapper.optimizer, self.model)
|
||||
|
||||
def test_default_optimizer_constructor_with_model_wrapper(self):
|
||||
# basic config with pseudo data parallel
|
||||
model = PseudoDataParallel()
|
||||
optimizer_cfg = dict(
|
||||
type='SGD',
|
||||
lr=self.base_lr,
|
||||
weight_decay=self.base_wd,
|
||||
momentum=self.momentum)
|
||||
optim_wrapper_cfg = dict(
|
||||
type='OptimWrapper',
|
||||
optimizer=dict(
|
||||
type='SGD',
|
||||
lr=self.base_lr,
|
||||
weight_decay=self.base_wd,
|
||||
momentum=self.momentum))
|
||||
paramwise_cfg = None
|
||||
optim_constructor = DefaultOptimizerConstructor(optimizer_cfg)
|
||||
optimizer = optim_constructor(model)
|
||||
self._check_default_optimizer(optimizer, model, prefix='module.')
|
||||
optim_constructor = DefaultOptimWrapperConstructor(optim_wrapper_cfg)
|
||||
optim_wrapper = optim_constructor(model)
|
||||
self._check_default_optimizer(
|
||||
optim_wrapper.optimizer, model, prefix='module.')
|
||||
|
||||
# paramwise_cfg with pseudo data parallel
|
||||
model = PseudoDataParallel()
|
||||
optimizer_cfg = dict(
|
||||
type='SGD',
|
||||
lr=self.base_lr,
|
||||
weight_decay=self.base_wd,
|
||||
momentum=self.momentum)
|
||||
optim_wrapper_cfg = dict(
|
||||
type='OptimWrapper',
|
||||
optimizer=dict(
|
||||
type='SGD',
|
||||
lr=self.base_lr,
|
||||
weight_decay=self.base_wd,
|
||||
momentum=self.momentum))
|
||||
paramwise_cfg = dict(
|
||||
bias_lr_mult=2,
|
||||
bias_decay_mult=0.5,
|
||||
norm_decay_mult=0,
|
||||
dwconv_decay_mult=0.1,
|
||||
dcn_offset_lr_mult=0.1)
|
||||
optim_constructor = DefaultOptimizerConstructor(
|
||||
optimizer_cfg, paramwise_cfg)
|
||||
optimizer = optim_constructor(model)
|
||||
optim_constructor = DefaultOptimWrapperConstructor(
|
||||
optim_wrapper_cfg, paramwise_cfg)
|
||||
optim_wrapper = optim_constructor(model)
|
||||
self._check_sgd_optimizer(
|
||||
optimizer, model, prefix='module.', **paramwise_cfg)
|
||||
optim_wrapper.optimizer, model, prefix='module.', **paramwise_cfg)
|
||||
|
||||
# basic config with DataParallel
|
||||
if torch.cuda.is_available():
|
||||
model = torch.nn.DataParallel(ExampleModel())
|
||||
optimizer_cfg = dict(
|
||||
type='SGD',
|
||||
lr=self.base_lr,
|
||||
weight_decay=self.base_wd,
|
||||
momentum=self.momentum)
|
||||
optim_wrapper_cfg = dict(
|
||||
type='OptimWrapper',
|
||||
optimizer=dict(
|
||||
type='SGD',
|
||||
lr=self.base_lr,
|
||||
weight_decay=self.base_wd,
|
||||
momentum=self.momentum))
|
||||
paramwise_cfg = None
|
||||
optim_constructor = DefaultOptimizerConstructor(optimizer_cfg)
|
||||
optimizer = optim_constructor(model)
|
||||
self._check_default_optimizer(optimizer, model, prefix='module.')
|
||||
optim_constructor = DefaultOptimWrapperConstructor(
|
||||
optim_wrapper_cfg)
|
||||
optim_wrapper = optim_constructor(model)
|
||||
self._check_default_optimizer(
|
||||
optim_wrapper.optimizer, model, prefix='module.')
|
||||
|
||||
# paramwise_cfg with DataParallel
|
||||
if torch.cuda.is_available():
|
||||
model = torch.nn.DataParallel(self.model)
|
||||
optimizer_cfg = dict(
|
||||
type='SGD',
|
||||
lr=self.base_lr,
|
||||
weight_decay=self.base_wd,
|
||||
momentum=self.momentum)
|
||||
optim_wrapper_cfg = dict(
|
||||
type='OptimWrapper',
|
||||
optimizer=dict(
|
||||
type='SGD',
|
||||
lr=self.base_lr,
|
||||
weight_decay=self.base_wd,
|
||||
momentum=self.momentum))
|
||||
paramwise_cfg = dict(
|
||||
bias_lr_mult=2,
|
||||
bias_decay_mult=0.5,
|
||||
norm_decay_mult=0,
|
||||
dwconv_decay_mult=0.1,
|
||||
dcn_offset_lr_mult=0.1)
|
||||
optim_constructor = DefaultOptimizerConstructor(
|
||||
optimizer_cfg, paramwise_cfg)
|
||||
optimizer = optim_constructor(model)
|
||||
optim_constructor = DefaultOptimWrapperConstructor(
|
||||
optim_wrapper_cfg, paramwise_cfg)
|
||||
optim_wrapper = optim_constructor(model)
|
||||
self._check_sgd_optimizer(
|
||||
optimizer, model, prefix='module.', **paramwise_cfg)
|
||||
optim_wrapper.optimizer,
|
||||
model,
|
||||
prefix='module.',
|
||||
**paramwise_cfg)
|
||||
|
||||
def test_default_optimizer_constructor_with_empty_paramwise_cfg(self):
|
||||
# Empty paramwise_cfg with ExampleModel
|
||||
optimizer_cfg = dict(
|
||||
type='SGD',
|
||||
lr=self.base_lr,
|
||||
weight_decay=self.base_wd,
|
||||
momentum=self.momentum)
|
||||
optim_wrapper_cfg = dict(
|
||||
type='OptimWrapper',
|
||||
optimizer=dict(
|
||||
type='SGD',
|
||||
lr=self.base_lr,
|
||||
weight_decay=self.base_wd,
|
||||
momentum=self.momentum))
|
||||
paramwise_cfg = dict()
|
||||
optim_constructor = DefaultOptimizerConstructor(
|
||||
optimizer_cfg, paramwise_cfg)
|
||||
optimizer = optim_constructor(self.model)
|
||||
self._check_default_optimizer(optimizer, self.model)
|
||||
optim_constructor = DefaultOptimWrapperConstructor(
|
||||
optim_wrapper_cfg, paramwise_cfg)
|
||||
optim_wrapper = optim_constructor(self.model)
|
||||
self._check_default_optimizer(optim_wrapper.optimizer, self.model)
|
||||
|
||||
# Empty paramwise_cfg with ExampleModel and no grad
|
||||
model = ExampleModel()
|
||||
for param in model.parameters():
|
||||
param.requires_grad = False
|
||||
optimizer_cfg = dict(
|
||||
type='SGD',
|
||||
lr=self.base_lr,
|
||||
weight_decay=self.base_wd,
|
||||
momentum=self.momentum)
|
||||
optim_wrapper_cfg = dict(
|
||||
type='OptimWrapper',
|
||||
optimizer=dict(
|
||||
type='SGD',
|
||||
lr=self.base_lr,
|
||||
weight_decay=self.base_wd,
|
||||
momentum=self.momentum))
|
||||
paramwise_cfg = dict()
|
||||
optim_constructor = DefaultOptimizerConstructor(optimizer_cfg)
|
||||
optimizer = optim_constructor(model)
|
||||
self._check_default_optimizer(optimizer, model)
|
||||
optim_constructor = DefaultOptimWrapperConstructor(optim_wrapper_cfg)
|
||||
optim_wrapper = optim_constructor(model)
|
||||
self._check_default_optimizer(optim_wrapper.optimizer, model)
|
||||
|
||||
def test_default_optimizer_constructor_with_paramwise_cfg(self):
|
||||
# paramwise_cfg with ExampleModel
|
||||
optimizer_cfg = dict(
|
||||
type='SGD',
|
||||
lr=self.base_lr,
|
||||
weight_decay=self.base_wd,
|
||||
momentum=self.momentum)
|
||||
optim_wrapper_cfg = dict(
|
||||
type='OptimWrapper',
|
||||
optimizer=dict(
|
||||
type='SGD',
|
||||
lr=self.base_lr,
|
||||
weight_decay=self.base_wd,
|
||||
momentum=self.momentum))
|
||||
paramwise_cfg = dict(
|
||||
bias_lr_mult=2,
|
||||
bias_decay_mult=0.5,
|
||||
norm_decay_mult=0,
|
||||
dwconv_decay_mult=0.1,
|
||||
dcn_offset_lr_mult=0.1)
|
||||
optim_constructor = DefaultOptimizerConstructor(
|
||||
optimizer_cfg, paramwise_cfg)
|
||||
optimizer = optim_constructor(self.model)
|
||||
self._check_sgd_optimizer(optimizer, self.model, **paramwise_cfg)
|
||||
optim_constructor = DefaultOptimWrapperConstructor(
|
||||
optim_wrapper_cfg, paramwise_cfg)
|
||||
optim_wrapper = optim_constructor(self.model)
|
||||
self._check_sgd_optimizer(optim_wrapper.optimizer, self.model,
|
||||
**paramwise_cfg)
|
||||
|
||||
def test_default_optimizer_constructor_no_grad(self):
|
||||
# paramwise_cfg with ExampleModel and no grad
|
||||
optimizer_cfg = dict(
|
||||
type='SGD',
|
||||
lr=self.base_lr,
|
||||
weight_decay=self.base_wd,
|
||||
momentum=self.momentum)
|
||||
optim_wrapper_cfg = dict(
|
||||
type='OptimWrapper',
|
||||
optimizer=dict(
|
||||
type='SGD',
|
||||
lr=self.base_lr,
|
||||
weight_decay=self.base_wd,
|
||||
momentum=self.momentum))
|
||||
paramwise_cfg = dict(
|
||||
bias_lr_mult=2,
|
||||
bias_decay_mult=0.5,
|
||||
@ -447,11 +504,12 @@ class TestBuilder(TestCase):
|
||||
|
||||
for param in self.model.parameters():
|
||||
param.requires_grad = False
|
||||
optim_constructor = DefaultOptimizerConstructor(
|
||||
optimizer_cfg, paramwise_cfg)
|
||||
optimizer = optim_constructor(self.model)
|
||||
optim_constructor = DefaultOptimWrapperConstructor(
|
||||
optim_wrapper_cfg, paramwise_cfg)
|
||||
optim_wrapper = optim_constructor(self.model)
|
||||
optimizer = optim_wrapper.optimizer
|
||||
param_groups = optimizer.param_groups
|
||||
assert isinstance(optimizer, torch.optim.SGD)
|
||||
assert isinstance(optim_wrapper.optimizer, torch.optim.SGD)
|
||||
assert optimizer.defaults['lr'] == self.base_lr
|
||||
assert optimizer.defaults['momentum'] == self.momentum
|
||||
assert optimizer.defaults['weight_decay'] == self.base_wd
|
||||
@ -465,11 +523,13 @@ class TestBuilder(TestCase):
|
||||
def test_default_optimizer_constructor_bypass_duplicate(self):
|
||||
# paramwise_cfg with bypass_duplicate option
|
||||
model = ExampleDuplicateModel()
|
||||
optimizer_cfg = dict(
|
||||
type='SGD',
|
||||
lr=self.base_lr,
|
||||
weight_decay=self.base_wd,
|
||||
momentum=self.momentum)
|
||||
optim_wrapper_cfg = dict(
|
||||
type='OptimWrapper',
|
||||
optimizer=dict(
|
||||
type='SGD',
|
||||
lr=self.base_lr,
|
||||
weight_decay=self.base_wd,
|
||||
momentum=self.momentum))
|
||||
paramwise_cfg = dict(
|
||||
bias_lr_mult=2,
|
||||
bias_decay_mult=0.5,
|
||||
@ -479,8 +539,8 @@ class TestBuilder(TestCase):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
'some parameters appear in more than one parameter group'):
|
||||
optim_constructor = DefaultOptimizerConstructor(
|
||||
optimizer_cfg, paramwise_cfg)
|
||||
optim_constructor = DefaultOptimWrapperConstructor(
|
||||
optim_wrapper_cfg, paramwise_cfg)
|
||||
optim_constructor(model)
|
||||
|
||||
paramwise_cfg = dict(
|
||||
@ -490,27 +550,31 @@ class TestBuilder(TestCase):
|
||||
dwconv_decay_mult=0.1,
|
||||
dcn_offset_lr_mult=0.1,
|
||||
bypass_duplicate=True)
|
||||
optim_constructor = DefaultOptimizerConstructor(
|
||||
optimizer_cfg, paramwise_cfg)
|
||||
optim_constructor = DefaultOptimWrapperConstructor(
|
||||
optim_wrapper_cfg, paramwise_cfg)
|
||||
|
||||
self.assertWarnsRegex(
|
||||
Warning,
|
||||
'conv3.0 is duplicate. It is skipped since bypass_duplicate=True',
|
||||
lambda: optim_constructor(model))
|
||||
optimizer = optim_constructor(model)
|
||||
optim_wrapper = optim_constructor(model)
|
||||
model_parameters = list(model.parameters())
|
||||
num_params = 14 if MMCV_FULL_AVAILABLE else 11
|
||||
assert len(
|
||||
optimizer.param_groups) == len(model_parameters) == num_params
|
||||
self._check_sgd_optimizer(optimizer, model, **paramwise_cfg)
|
||||
assert len(optim_wrapper.optimizer.param_groups) == len(
|
||||
model_parameters) == num_params
|
||||
self._check_sgd_optimizer(optim_wrapper.optimizer, model,
|
||||
**paramwise_cfg)
|
||||
|
||||
def test_default_optimizer_constructor_custom_key(self):
|
||||
# test DefaultOptimizerConstructor with custom_keys and ExampleModel
|
||||
optimizer_cfg = dict(
|
||||
type='SGD',
|
||||
lr=self.base_lr,
|
||||
weight_decay=self.base_wd,
|
||||
momentum=self.momentum)
|
||||
# test DefaultOptimWrapperConstructor with custom_keys and
|
||||
# ExampleModel
|
||||
optim_wrapper_cfg = dict(
|
||||
type='OptimWrapper',
|
||||
optimizer=dict(
|
||||
type='SGD',
|
||||
lr=self.base_lr,
|
||||
weight_decay=self.base_wd,
|
||||
momentum=self.momentum))
|
||||
paramwise_cfg = dict(
|
||||
custom_keys={
|
||||
'param1': dict(lr_mult=10),
|
||||
@ -523,23 +587,24 @@ class TestBuilder(TestCase):
|
||||
with self.assertRaises(TypeError):
|
||||
# custom_keys should be a dict
|
||||
paramwise_cfg_ = dict(custom_keys=[0.1, 0.0001])
|
||||
optim_constructor = DefaultOptimizerConstructor(
|
||||
optimizer_cfg, paramwise_cfg_)
|
||||
optim_constructor = DefaultOptimWrapperConstructor(
|
||||
optim_wrapper_cfg, paramwise_cfg_)
|
||||
optimizer = optim_constructor(self.model)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# if 'decay_mult' is specified in custom_keys, weight_decay
|
||||
# should be specified
|
||||
optimizer_cfg_ = dict(type='SGD', lr=0.01)
|
||||
optim_wrapper_cfg_ = dict(
|
||||
type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01))
|
||||
paramwise_cfg_ = dict(
|
||||
custom_keys={'.backbone': dict(decay_mult=0.5)})
|
||||
optim_constructor = DefaultOptimizerConstructor(
|
||||
optimizer_cfg_, paramwise_cfg_)
|
||||
optimizer = optim_constructor(self.model)
|
||||
optim_constructor = DefaultOptimWrapperConstructor(
|
||||
optim_wrapper_cfg_, paramwise_cfg_)
|
||||
optim_constructor(self.model)
|
||||
|
||||
optim_constructor = DefaultOptimizerConstructor(
|
||||
optimizer_cfg, paramwise_cfg)
|
||||
optimizer = optim_constructor(self.model)
|
||||
optim_constructor = DefaultOptimWrapperConstructor(
|
||||
optim_wrapper_cfg, paramwise_cfg)
|
||||
optimizer = optim_constructor(self.model).optimizer
|
||||
# check optimizer type and default config
|
||||
assert isinstance(optimizer, torch.optim.SGD)
|
||||
assert optimizer.defaults['lr'] == self.base_lr
|
||||
@ -598,14 +663,17 @@ class TestBuilder(TestCase):
|
||||
assert param_groups[i][setting] == settings[
|
||||
setting], f'{name} {setting}'
|
||||
|
||||
# test DefaultOptimizerConstructor with custom_keys and ExampleModel 2
|
||||
optimizer_cfg = dict(
|
||||
type='SGD', lr=self.base_lr, momentum=self.momentum)
|
||||
# test DefaultOptimWrapperConstructor with custom_keys and
|
||||
# ExampleModel 2
|
||||
optim_wrapper_cfg = dict(
|
||||
type='OptimWrapper',
|
||||
optimizer=dict(
|
||||
type='SGD', lr=self.base_lr, momentum=self.momentum))
|
||||
paramwise_cfg = dict(custom_keys={'param1': dict(lr_mult=10)})
|
||||
|
||||
optim_constructor = DefaultOptimizerConstructor(
|
||||
optimizer_cfg, paramwise_cfg)
|
||||
optimizer = optim_constructor(self.model)
|
||||
optim_constructor = DefaultOptimWrapperConstructor(
|
||||
optim_wrapper_cfg, paramwise_cfg)
|
||||
optimizer = optim_constructor(self.model).optimizer
|
||||
# check optimizer type and default config
|
||||
assert isinstance(optimizer, torch.optim.SGD)
|
||||
assert optimizer.defaults['lr'] == self.base_lr
|
||||
|
tests/test_optim/test_optimizer/test_optimizer_wrapper.py (new file, 384 lines)
@ -0,0 +1,384 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os
|
||||
import unittest
|
||||
from unittest import TestCase
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import torch
|
||||
import torch.distributed as torch_dist
|
||||
import torch.nn as nn
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel.distributed import DistributedDataParallel
|
||||
from torch.optim import SGD, Adam, Optimizer
|
||||
|
||||
from mmengine import MessageHub, MMLogger
|
||||
from mmengine.dist import all_gather
|
||||
from mmengine.optim import AmpOptimWrapper, OptimWrapper
|
||||
from mmengine.testing import assert_allclose
|
||||
from mmengine.testing._internal import MultiProcessTestCase
|
||||
from mmengine.utils import TORCH_VERSION, digit_version
|
||||
|
||||
|
||||
class ToyModel(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv2d(1, 1, 1)
|
||||
self.conv2 = nn.Conv2d(1, 1, 1)
|
||||
self.conv3 = nn.Conv2d(1, 1, 1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
return x
|
||||
|
||||
|
||||
class ToyModel2(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(1, 1, 1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class TestOptimWrapper(MultiProcessTestCase):
|
||||
# Test that `OptimWrapper.accumulate_grad` blocks gradient
# synchronization when using the gradient accumulation strategy in distributed
|
||||
# data parallel training.
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
self._spawn_processes()
|
||||
|
||||
def run_test(self, test_name: str, parent_pipe) -> None:
|
||||
self.model = ToyModel()
|
||||
self.optimizer = SGD(self.model.parameters(), lr=0.1)
|
||||
self.logger = MMLogger.get_instance('test_optim_wrapper')
|
||||
self.message_hub = MessageHub.get_instance('test_optim_wrapper_init')
|
||||
super().run_test(test_name, parent_pipe)
|
||||
|
||||
def test_init(self):
|
||||
optim_wrapper = OptimWrapper(self.optimizer)
|
||||
self.assertEqual(optim_wrapper.optimizer, self.optimizer)
|
||||
self.assertIsNone(optim_wrapper.clip_grad_kwargs)
|
||||
self.assertEqual(optim_wrapper.accumulative_iters, 1)
|
||||
self.assertIs(optim_wrapper.logger, self.logger)
|
||||
self.assertIs(optim_wrapper.message_hub, self.message_hub)
|
||||
|
||||
with self.assertRaisesRegex(AssertionError,
|
||||
'If `clip_grad_kwargs` is not None'):
|
||||
OptimWrapper(self.optimizer, clip_grad=[])
|
||||
|
||||
def test_update_params(self):
|
||||
# Test update params every iteration.
|
||||
optim_wrapper = OptimWrapper(self.optimizer, accumulative_iters=1)
|
||||
self._mock_method(optim_wrapper)
|
||||
loss = torch.tensor(1)
|
||||
optim_wrapper.update_params(loss)
|
||||
optim_wrapper.backward.assert_called_with(torch.tensor(1))
|
||||
optim_wrapper.step.assert_called_with()
|
||||
optim_wrapper.zero_grad.assert_called_with()
|
||||
|
||||
with optim_wrapper.accumulate_grad(self.model, 2, 100):
|
||||
optim_wrapper.update_params(torch.tensor(1))
|
||||
optim_wrapper.backward.assert_called_with(torch.tensor(1))
|
||||
optim_wrapper.step.assert_called_with()
|
||||
optim_wrapper.zero_grad.assert_called_with()
|
||||
|
||||
# It will raise an error if `accumulative_iters > 1` and
|
||||
# `accumulate_grad` is not enabled.
|
||||
optim_wrapper = OptimWrapper(self.optimizer, accumulative_iters=3)
|
||||
self._mock_method(optim_wrapper)
|
||||
with self.assertRaisesRegex(AssertionError,
|
||||
'gradient accumulation must be'):
|
||||
optim_wrapper.update_params(loss)
|
||||
|
||||
# `iter=0`: inside the accumulation window, so `step` should not be called yet.
|
||||
with optim_wrapper.accumulate_grad(
|
||||
self.model, cur_iter=0, max_iters=100):
|
||||
loss = torch.tensor(1)
|
||||
optim_wrapper.update_params(loss)
|
||||
optim_wrapper.backward.assert_called_with(torch.tensor(1) / 3)
|
||||
optim_wrapper.step.assert_not_called()
|
||||
optim_wrapper.zero_grad.assert_not_called()
|
||||
|
||||
# `iter=2`: the accumulation window ends, so `step` is called for the first time.
|
||||
with optim_wrapper.accumulate_grad(
|
||||
self.model, cur_iter=2, max_iters=100):
|
||||
optim_wrapper.update_params(loss)
|
||||
optim_wrapper.step.assert_called()
|
||||
optim_wrapper.zero_grad.assert_called()
|
||||
self._mock_method(optim_wrapper)
|
||||
# Test end of training.
|
||||
with optim_wrapper.accumulate_grad(
|
||||
self.model, cur_iter=99, max_iters=100):
|
||||
optim_wrapper.update_params(loss)
|
||||
optim_wrapper.step.assert_called()
|
||||
optim_wrapper.zero_grad.assert_called()
|
||||
optim_wrapper.backward.assert_called_with(1)
|
||||
|
||||
# If ``accumulative_iters > 1``, call ``update_params`` with
|
||||
# non-accumulate_grad context will raise an Assertion error
|
||||
optim_wrapper = OptimWrapper(self.optimizer, accumulative_iters=1)
|
||||
optim_wrapper.accumulative_iters = 2
|
||||
with self.assertRaisesRegex(AssertionError,
|
||||
'gradient accumulation must be performed'):
|
||||
optim_wrapper.update_params(loss)
|
||||
|
||||
def test_initilize_iter_status(self):
|
||||
optim_wrapper = OptimWrapper(self.optimizer, accumulative_iters=3)
|
||||
optim_wrapper._initilize_iter_status(self.model)
|
||||
self.assertEqual(optim_wrapper.divisible_iters, 0)
|
||||
self.assertEqual(optim_wrapper.remainder_iters, 0)
|
||||
|
||||
# Indivisible cur_iter will output warning.
|
||||
optim_wrapper = OptimWrapper(self.optimizer, accumulative_iters=3)
|
||||
optim_wrapper.cur_iter = 0
|
||||
optim_wrapper.max_iters = 100
|
||||
with self.assertLogs(self.logger) as cm:
|
||||
optim_wrapper._initilize_iter_status(self.model)
|
||||
self.assertEqual(len(cm.output), 1)
|
||||
self.assertRegex(cm.records[0].msg, 'Resume iter number is not')
|
||||
|
||||
# Model with batch norm will output warning.
|
||||
optim_wrapper = OptimWrapper(self.optimizer, accumulative_iters=3)
|
||||
optim_wrapper.cur_iter = 0
|
||||
optim_wrapper.max_iters = 99
|
||||
model = nn.BatchNorm2d(1)
|
||||
with self.assertLogs(self.logger) as cm:
|
||||
optim_wrapper._initilize_iter_status(model)
|
||||
self.assertEqual(len(cm.output), 1)
|
||||
self.assertRegex(cm.records[0].msg, 'Gradient accumulative')
|
||||
|
||||
def test_get_lr(self):
|
||||
model = ToyModel()
|
||||
optim = SGD(model.parameters(), lr=0.1)
|
||||
optim_wrapper = OptimWrapper(optim)
|
||||
self.assertEqual(optim_wrapper.get_lr(), dict(lr=[0.1]))
|
||||
|
||||
def test_get_momentum(self):
|
||||
# Get momentum from SGD
|
||||
model = ToyModel()
|
||||
optim = SGD(model.parameters(), lr=0., momentum=0.8)
|
||||
optim_wrapper = OptimWrapper(optim)
|
||||
self.assertEqual(optim_wrapper.get_momentum(), dict(momentum=[0.8]))
|
||||
# Get momentum from Adam
|
||||
optim = Adam(model.parameters(), lr=0., betas=(0.9, 0.9))
|
||||
optim_wrapper = OptimWrapper(optim)
|
||||
self.assertEqual(optim_wrapper.get_momentum(), dict(momentum=[0.9]))
|
||||
|
||||
def test_backward(self):
|
||||
loss = MagicMock()
|
||||
optim_wrapper = OptimWrapper(self.optimizer)
|
||||
optim_wrapper.backward(loss)
|
||||
loss.backward.assert_called()
|
||||
|
||||
def test_zero_grad(self):
|
||||
optimizer = MagicMock(spec=Optimizer)
|
||||
optim_wrapper = OptimWrapper(optimizer)
|
||||
optim_wrapper.zero_grad()
|
||||
optimizer.zero_grad.assert_called()
|
||||
|
||||
def test_step(self):
|
||||
optimizer = MagicMock(spec=Optimizer)
|
||||
optim_wrapper = OptimWrapper(optimizer)
|
||||
optim_wrapper.step()
|
||||
optimizer.step.assert_called()
|
||||
|
||||
def test_clip_grads(self):
|
||||
optim_wrapper = OptimWrapper(
|
||||
self.optimizer, clip_grad=dict(max_norm=35))
|
||||
loss = self.model(torch.Tensor(1, 1, 1, 1))
|
||||
loss.backward()
|
||||
optim_wrapper._clip_grad()
|
||||
log_scalars = self.message_hub.log_scalars
|
||||
self.assertIn('train/grad_norm', log_scalars)
|
||||
|
||||
def test_state_dict(self):
|
||||
optim_wrapper = OptimWrapper(self.optimizer)
|
||||
self.assertEqual(optim_wrapper.state_dict(),
|
||||
self.optimizer.state_dict())
|
||||
|
||||
def test_load_state_dict(self):
|
||||
optim_wrapper = OptimWrapper(self.optimizer)
|
||||
model = ToyModel()
|
||||
optimizer = SGD(model.parameters(), lr=0.1)
|
||||
optim_wrapper.load_state_dict(optimizer.state_dict())
|
||||
|
||||
self.assertEqual(optim_wrapper.state_dict(), optimizer.state_dict())
|
||||
|
||||
def test_param_groups(self):
|
||||
optim_wrapper = OptimWrapper(self.optimizer)
|
||||
self.assertEqual(optim_wrapper.param_groups,
|
||||
self.optimizer.param_groups)
|
||||
|
||||
def test_accumulate_grad(self):
|
||||
self._init_dist_env(self.rank, self.world_size)
|
||||
model = ToyModel2()
|
||||
ddp_model = DistributedDataParallel(model)
|
||||
optimizer = SGD(ddp_model.parameters(), lr=0.01)
|
||||
optim_wrapper = OptimWrapper(optimizer, accumulative_iters=1)
|
||||
optim_wrapper.zero_grad()
|
||||
with optim_wrapper.accumulate_grad(ddp_model, 0, 100):
|
||||
# Automatically sync grads if `accumulative_iters` = 1
|
||||
inputs = torch.randn(1, 1, 1, 1) * self.rank
|
||||
ddp_model(inputs).sum().backward()
|
||||
grad = model.conv.weight.grad
|
||||
all_grads = all_gather(grad)
|
||||
assert_allclose(all_grads[0], all_grads[1])
|
||||
|
||||
# Do not sync grads when `optim_wrapper.cur_iter` cannot be
|
||||
# divided by `optim_wrapper.accumulative_iters`
|
||||
optim_wrapper = OptimWrapper(optimizer, accumulative_iters=3)
|
||||
with optim_wrapper.accumulate_grad(ddp_model, 0, 100):
|
||||
ddp_model(inputs).sum().backward()
|
||||
all_grads = all_gather(model.conv.weight.grad)
|
||||
with self.assertRaises(AssertionError):
|
||||
assert_allclose(all_grads[0], all_grads[1])
|
||||
|
||||
# sync grads if `cur_iter == 2`
|
||||
with optim_wrapper.accumulate_grad(ddp_model, 2, 100):
|
||||
ddp_model(inputs).sum().backward()
|
||||
all_grads = all_gather(model.conv.weight.grad)
|
||||
assert_allclose(all_grads[0], all_grads[1])
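For context, the calls exercised in this test map onto a training loop roughly as sketched below. This is derived only from the API used above (``accumulate_grad`` wrapping one iteration, ``update_params`` doing backward plus the possibly deferred step/zero_grad); the toy model and data are placeholders.

import torch
import torch.nn as nn
from torch.optim import SGD

from mmengine.optim import OptimWrapper

model = nn.Conv2d(1, 1, 1)
optim_wrapper = OptimWrapper(
    SGD(model.parameters(), lr=0.01), accumulative_iters=3)

max_iters = 9
data = [torch.randn(1, 1, 1, 1) for _ in range(max_iters)]

for cur_iter, inputs in enumerate(data):
    with optim_wrapper.accumulate_grad(model, cur_iter, max_iters):
        loss = model(inputs).sum()
        # backward() runs every iteration; step() and zero_grad() run only
        # every `accumulative_iters` iterations and at the end of training.
        optim_wrapper.update_params(loss)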
|
||||
|
||||
def test_precision_context(self):
|
||||
optim_wrapper = OptimWrapper(self.optimizer)
|
||||
with optim_wrapper.precision_context():
|
||||
pass
|
||||
|
||||
def _init_dist_env(self, rank, world_size):
|
||||
"""Initialize the distributed environment."""
|
||||
os.environ['MASTER_ADDR'] = '127.0.0.1'
|
||||
os.environ['MASTER_PORT'] = '29515'
|
||||
os.environ['RANK'] = str(rank)
|
||||
torch_dist.init_process_group(
|
||||
backend='gloo', rank=rank, world_size=world_size)
|
||||
|
||||
# TODO: Test the real interface after adding a testing utility that can
# check whether a function or method is really called.
|
||||
def _mock_method(self, optim_wrapper):
|
||||
optim_wrapper.backward = MagicMock()
|
||||
optim_wrapper.step = MagicMock()
|
||||
optim_wrapper.zero_grad = MagicMock()
|
||||
|
||||
|
||||
class TestAmpOptimWrapper(TestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.model = ToyModel()
|
||||
self.optimizer = SGD(self.model.parameters(), lr=0.1)
|
||||
|
||||
@unittest.skipIf(
|
||||
not torch.cuda.is_available()
or digit_version(TORCH_VERSION) < digit_version('1.6.0'),
|
||||
reason='`torch.cuda.amp` is only available when pytorch-gpu version '
|
||||
'>= 1.6')
|
||||
def test_init(self):
|
||||
# Test with default arguments.
|
||||
amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer)
|
||||
self.assertIsInstance(amp_optim_wrapper.loss_scaler, GradScaler)
|
||||
|
||||
# Test with dynamic.
|
||||
amp_optim_wrapper = AmpOptimWrapper(
|
||||
'dynamic', optimizer=self.optimizer)
|
||||
self.assertIsNone(amp_optim_wrapper._scale_update_param)
|
||||
self.assertIsInstance(amp_optim_wrapper.loss_scaler, GradScaler)
|
||||
|
||||
# Test with dict loss_scale.
|
||||
amp_optim_wrapper = AmpOptimWrapper(
|
||||
dict(init_scale=1, growth_factor=2), optimizer=self.optimizer)
|
||||
self.assertIsInstance(amp_optim_wrapper.loss_scaler, GradScaler)
|
||||
self.assertIsNone(amp_optim_wrapper._scale_update_param)
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'loss_scale must be of type float'):
|
||||
AmpOptimWrapper(optimizer=self.optimizer, loss_scale='unknown')
|
||||
|
||||
@unittest.skipIf(
|
||||
not torch.cuda.is_available()
|
||||
and (digit_version(TORCH_VERSION) >= digit_version('1.6.0')),
|
||||
reason='`torch.cuda.amp` is only available when pytorch-gpu version '
|
||||
'>= 1.6')
|
||||
def test_step(self):
|
||||
optimizer = MagicMock(spec=Optimizer)
|
||||
amp_optim_wrapper = AmpOptimWrapper(optimizer=optimizer)
|
||||
amp_optim_wrapper.loss_scaler = MagicMock()
|
||||
amp_optim_wrapper.step()
|
||||
amp_optim_wrapper.loss_scaler.step.assert_called_with(
|
||||
amp_optim_wrapper.optimizer)
|
||||
amp_optim_wrapper.loss_scaler.update.assert_called_with(
|
||||
amp_optim_wrapper._scale_update_param)
|
||||
|
||||
@unittest.skipIf(
|
||||
not torch.cuda.is_available()
|
||||
and (digit_version(TORCH_VERSION) >= digit_version('1.6.0')),
|
||||
reason='`torch.cuda.amp` is only available when pytorch-gpu version '
|
||||
'>= 1.6')
|
||||
def test_backward(self):
|
||||
amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer)
|
||||
loss_scaler = MagicMock()
|
||||
scale_return = MagicMock()
|
||||
scale_fn = MagicMock(return_value=scale_return)
|
||||
loss_scaler.scale = scale_fn
|
||||
amp_optim_wrapper.loss_scaler = loss_scaler
|
||||
|
||||
amp_optim_wrapper.backward(1)
|
||||
loss_scaler.scale.assert_called_with(1)
|
||||
scale_return.backward.assert_called_with()
|
||||
|
||||
@unittest.skipIf(
|
||||
not torch.cuda.is_available()
|
||||
and (digit_version(TORCH_VERSION) >= digit_version('1.6.0')),
|
||||
reason='`torch.cuda.amp` is only available when pytorch-gpu version '
|
||||
'>= 1.6')
|
||||
def test_state_dict(self):
|
||||
self.model = self.model.cuda()
|
||||
amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer)
|
||||
loss = self.model(torch.Tensor(1, 1, 1, 1).cuda())
|
||||
amp_optim_wrapper.update_params(loss)
|
||||
state_dict = amp_optim_wrapper.state_dict()
|
||||
scalar_state_dict = state_dict.pop('loss_scaler')
|
||||
optim_state_dict = state_dict
|
||||
|
||||
self.assertDictEqual(optim_state_dict,
|
||||
amp_optim_wrapper.optimizer.state_dict())
|
||||
self.assertDictEqual(scalar_state_dict,
|
||||
amp_optim_wrapper.loss_scaler.state_dict())
|
||||
|
||||
@unittest.skipIf(
|
||||
not torch.cuda.is_available()
|
||||
and (digit_version(TORCH_VERSION) >= digit_version('1.6.0')),
|
||||
reason='`torch.cuda.amp` is only available when pytorch-gpu version '
|
||||
'>= 1.6')
|
||||
def test_load_state_dict(self):
|
||||
amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer)
|
||||
self.model = self.model.cuda()
|
||||
# Test load from optimizer
|
||||
optimizer = SGD(self.model.parameters(), lr=0.1)
|
||||
amp_optim_wrapper.load_state_dict(optimizer.state_dict())
|
||||
|
||||
self.assertDictEqual(optimizer.state_dict(),
|
||||
amp_optim_wrapper.optimizer.state_dict())
|
||||
# Test load from optim_wrapper
|
||||
amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer)
|
||||
amp_optim_wrapper_ = AmpOptimWrapper(
|
||||
optimizer=SGD(self.model.parameters(), lr=0.1))
|
||||
amp_optim_wrapper_.load_state_dict(amp_optim_wrapper.state_dict())
|
||||
self.assertDictEqual(amp_optim_wrapper.optimizer.state_dict(),
|
||||
amp_optim_wrapper_.optimizer.state_dict())
|
||||
self.assertDictEqual(amp_optim_wrapper.loss_scaler.state_dict(),
|
||||
amp_optim_wrapper_.loss_scaler.state_dict())
|
||||
|
||||
@unittest.skipIf(
|
||||
not torch.cuda.is_available()
|
||||
and (digit_version(TORCH_VERSION) >= digit_version('1.6.0')),
|
||||
reason='`torch.cuda.amp` is only available when pytorch-gpu version '
|
||||
'>= 1.6')
|
||||
def test_precision_context(self):
|
||||
amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer)
|
||||
with amp_optim_wrapper.precision_context():
|
||||
x = torch.randn(1, 1, 1, 1).cuda()
|
||||
y = nn.Conv2d(1, 1, 1).cuda()(x)
|
||||
self.assertEqual(y.dtype, torch.float16)
|
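The tests above drive OptimWrapper directly through its public surface (zero_grad, backward, step, state_dict, accumulate_grad, gradient clipping). For orientation, here is a minimal, self-contained sketch of the same call pattern in a hand-written loop; the toy model and loop are illustrative only and not part of the test suite, and the clip_grad config simply mirrors the one used in the clipping test above.

    import torch
    import torch.nn as nn
    from torch.optim import SGD

    from mmengine.optim import OptimWrapper

    model = nn.Conv2d(1, 1, 1)
    optim_wrapper = OptimWrapper(
        SGD(model.parameters(), lr=0.01),
        clip_grad=dict(max_norm=35))  # same clipping config as the test above

    for _ in range(3):
        loss = model(torch.randn(1, 1, 1, 1)).sum()
        optim_wrapper.zero_grad()     # clear gradients of the wrapped optimizer
        optim_wrapper.backward(loss)  # plain loss.backward(); AmpOptimWrapper scales the loss first
        optim_wrapper.step()          # steps the wrapped optimizer (clipping is handled by the wrapper)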
tests/test_optim/test_optimizer/test_optimizer_wrapper_dict.py (new file, 142 lines)
@ -0,0 +1,142 @@
# Copyright (c) OpenMMLab. All rights reserved.
from contextlib import contextmanager
from unittest import TestCase
from unittest.mock import patch

import torch.nn as nn
from torch.optim import SGD

from mmengine.optim import AmpOptimWrapper, OptimWrapper, OptimWrapperDict


class TestOptimWrapperDict(TestCase):

    def setUp(self) -> None:
        model1 = nn.Linear(1, 1)
        model2 = nn.Linear(1, 1)
        self.optim1 = SGD(model1.parameters(), lr=0.1, momentum=0.8)
        self.optim2 = SGD(model2.parameters(), lr=0.2, momentum=0.9)
        self.optim_wrapper1 = OptimWrapper(self.optim1)
        self.optim_wrapper2 = OptimWrapper(self.optim2)
        self.optimizers_wrappers = dict(
            optim1=self.optim_wrapper1, optim2=self.optim_wrapper2)

    @patch('torch.cuda.is_available', lambda: True)
    def test_init(self):
        optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
        self.assertEqual(optim_wrapper_dict.optim_wrappers,
                         self.optimizers_wrappers)
        # Different types of OptimWrapper will raise an error

        with self.assertRaisesRegex(
                AssertionError, 'All optimizer wrappers should have the same'):
            optim_wrapper2 = AmpOptimWrapper(optimizer=self.optim2)
            OptimWrapperDict(optim1=self.optim_wrapper1, optim2=optim_wrapper2)

        with self.assertWarnsRegex(UserWarning, 'The `accumulative_iters` of'):
            optim_wrapper2 = OptimWrapper(
                optimizer=self.optim2, accumulative_iters=2)
            OptimWrapperDict(optim1=self.optim_wrapper1, optim2=optim_wrapper2)

    def test_accumulate_grad(self):

        @contextmanager
        def context_a(a, b, *args, **kwargs):
            a[0] = 100
            yield
            a[0] = 1

        @contextmanager
        def context_b(a, b, *args, **kwargs):
            b[0] = 200
            yield
            b[0] = 2

        a = [0]
        b = [0]
        # Test entering the contexts of both `optim_wrapper1` and
        # `optim_wrapper2`.
        optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
        self.optim_wrapper1.accumulate_grad = context_a
        self.optim_wrapper2.accumulate_grad = context_b
        with optim_wrapper_dict.accumulate_grad(a, b, 0):
            self.assertEqual(a[0], 100)
            self.assertEqual(b[0], 200)

        self.assertEqual(a[0], 1)
        self.assertEqual(b[0], 2)

    def test_state_dict(self):
        optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
        state_dict = optim_wrapper_dict.state_dict()
        self.assertEqual(state_dict['optim1'],
                         self.optim_wrapper1.state_dict())
        self.assertEqual(state_dict['optim2'],
                         self.optim_wrapper2.state_dict())

    def test_load_state_dict(self):
        # Test OptimWrapperDict can load from saved state dict.
        model1 = nn.Linear(1, 1)
        model2 = nn.Linear(1, 1)
        optim1 = SGD(model1.parameters(), lr=0.1)
        optim2 = SGD(model2.parameters(), lr=0.1)
        optim_wrapper_load1 = OptimWrapper(optim1)
        optim_wrapper_load2 = OptimWrapper(optim2)

        optim_wrapper_dict_save = OptimWrapperDict(**self.optimizers_wrappers)
        optim_wrapper_dict_load = OptimWrapperDict(
            optim1=optim_wrapper_load1, optim2=optim_wrapper_load2)
        state_dict = optim_wrapper_dict_save.state_dict()
        optim_wrapper_dict_load.load_state_dict(state_dict)

        self.assertDictEqual(optim_wrapper_dict_load.state_dict(),
                             optim_wrapper_dict_save.state_dict())

    def test_items(self):
        optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
        self.assertListEqual(
            list(optim_wrapper_dict.items()),
            list(self.optimizers_wrappers.items()))

    def test_values(self):
        optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
        self.assertListEqual(
            list(optim_wrapper_dict.values()),
            list(self.optimizers_wrappers.values()))

    def test_keys(self):
        optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
        self.assertListEqual(
            list(optim_wrapper_dict.keys()),
            list(self.optimizers_wrappers.keys()))

    def test_get_lr(self):
        optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
        lr = optim_wrapper_dict.get_lr()
        self.assertEqual(lr['optim1.lr'], [0.1])
        self.assertEqual(lr['optim2.lr'], [0.2])

    def test_get_momentum(self):
        optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
        momentum = optim_wrapper_dict.get_momentum()
        self.assertEqual(momentum['optim1.momentum'], [0.8])
        self.assertEqual(momentum['optim2.momentum'], [0.9])

    def test_getitem(self):
        optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
        self.assertIs(self.optimizers_wrappers['optim1'],
                      optim_wrapper_dict['optim1'])
        self.assertIs(self.optimizers_wrappers['optim2'],
                      optim_wrapper_dict['optim2'])

    def test_len(self):
        optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
        self.assertEqual(len(optim_wrapper_dict), 2)

    def test_contain(self):
        optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
        self.assertIn('optim1', optim_wrapper_dict)

    def test_repr(self):
        optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
        desc = repr(optim_wrapper_dict)
        self.assertRegex(desc, 'name: optim1')
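OptimWrapperDict behaves like an ordered mapping of named OptimWrapper instances while still exposing the aggregate interface (get_lr, get_momentum, state_dict and so on). A minimal sketch of how it might be used outside the tests follows; the generator/discriminator names are purely illustrative and the reported values simply mirror the assertions in test_get_lr and test_get_momentum.

    import torch.nn as nn
    from torch.optim import SGD

    from mmengine.optim import OptimWrapper, OptimWrapperDict

    gen = nn.Linear(1, 1)
    disc = nn.Linear(1, 1)
    wrappers = OptimWrapperDict(
        generator=OptimWrapper(SGD(gen.parameters(), lr=0.1, momentum=0.8)),
        discriminator=OptimWrapper(SGD(disc.parameters(), lr=0.2, momentum=0.9)))

    # Learning rates and momentums are reported per wrapper, keyed as
    # '<name>.lr' and '<name>.momentum'.
    wrappers.get_lr()        # {'generator.lr': [0.1], 'discriminator.lr': [0.2]}
    wrappers.get_momentum()  # {'generator.momentum': [0.8], 'discriminator.momentum': [0.9]}

    # Each sub-wrapper stays individually addressable.
    wrappers['generator'].zero_grad()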
@ -16,13 +16,15 @@ from mmengine.config import Config
from mmengine.data import DefaultSampler
from mmengine.evaluator import BaseMetric, Evaluator
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, Hook,
                            IterTimerHook, LoggerHook, OptimizerHook,
                            ParamSchedulerHook)
                            IterTimerHook, LoggerHook, ParamSchedulerHook,
                            RuntimeInfoHook)
from mmengine.logging import LogProcessor, MessageHub, MMLogger
from mmengine.optim import DefaultOptimizerConstructor, MultiStepLR, StepLR
from mmengine.optim import (DefaultOptimWrapperConstructor, MultiStepLR,
                            OptimWrapper, OptimWrapperDict, StepLR)
from mmengine.registry import (DATASETS, HOOKS, LOG_PROCESSORS, LOOPS, METRICS,
                               MODEL_WRAPPERS, MODELS, OPTIMIZER_CONSTRUCTORS,
                               PARAM_SCHEDULERS, Registry)
                               MODEL_WRAPPERS, MODELS,
                               OPTIM_WRAPPER_CONSTRUCTORS, PARAM_SCHEDULERS,
                               Registry)
from mmengine.runner import (BaseLoop, EpochBasedTrainLoop, IterBasedTrainLoop,
                             Runner, TestLoop, ValLoop)
from mmengine.runner.priority import Priority, get_priority
@ -73,24 +75,24 @@ class CustomModelWrapper(nn.Module):
        self.model = model


@OPTIMIZER_CONSTRUCTORS.register_module()
@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class ToyMultipleOptimizerConstructor:

    def __init__(self, optimizer_cfg, paramwise_cfg=None):
        if not isinstance(optimizer_cfg, dict):
    def __init__(self, optim_wrapper_cfg, paramwise_cfg=None):
        if not isinstance(optim_wrapper_cfg, dict):
            raise TypeError('optimizer_cfg should be a dict',
                            f'but got {type(optimizer_cfg)}')
                            f'but got {type(optim_wrapper_cfg)}')
        assert paramwise_cfg is None, (
            'paramwise_cfg should be set in each optimizer separately')
        self.optimizer_cfg = optimizer_cfg
        self.optim_wrapper_cfg = optim_wrapper_cfg
        self.constructors = {}
        for key, cfg in self.optimizer_cfg.items():
        for key, cfg in self.optim_wrapper_cfg.items():
            _cfg = cfg.copy()
            paramwise_cfg_ = _cfg.pop('paramwise_cfg', None)
            self.constructors[key] = DefaultOptimizerConstructor(
            self.constructors[key] = DefaultOptimWrapperConstructor(
                _cfg, paramwise_cfg_)

    def __call__(self, model: nn.Module) -> torch.optim.Optimizer:
    def __call__(self, model: nn.Module) -> OptimWrapperDict:
        optimizers = {}
        while hasattr(model, 'module'):
            model = model.module
@ -98,7 +100,7 @@ class ToyMultipleOptimizerConstructor:
        for key, constructor in self.constructors.items():
            module = getattr(model, key)
            optimizers[key] = constructor(module)
        return optimizers
        return OptimWrapperDict(**optimizers)


@DATASETS.register_module()
@ -160,13 +162,6 @@ class ToyHook2(Hook):
        pass


@HOOKS.register_module()
class ToyHook3(Hook):

    def before_train_iter(self, runner, data_batch):
        pass


@LOOPS.register_module()
class CustomTrainLoop(BaseLoop):

@ -246,7 +241,8 @@ class TestRunner(TestCase):
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0),
            optimizer=dict(type='SGD', lr=0.01),
            optim_wrapper=dict(
                type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)),
            param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]),
            val_evaluator=dict(type='ToyMetric1'),
            test_evaluator=dict(type='ToyMetric1'),
@ -299,7 +295,7 @@ class TestRunner(TestCase):
        # also be None
        cfg.experiment_name = 'test_init2'
        cfg.pop('train_dataloader')
        cfg.pop('optimizer')
        cfg.pop('optim_wrapper')
        cfg.pop('param_scheduler')
        runner = Runner(**cfg)
        self.assertIsInstance(runner, Runner)
@ -324,7 +320,7 @@ class TestRunner(TestCase):
        cfg.experiment_name = 'test_init5'
        cfg.pop('train_cfg')
        cfg.pop('train_dataloader')
        cfg.pop('optimizer')
        cfg.pop('optim_wrapper')
        with self.assertRaisesRegex(ValueError, 'should be None'):
            runner = Runner(**cfg)

@ -398,7 +394,7 @@ class TestRunner(TestCase):
        self.assertIsInstance(runner._train_dataloader, dict)
        self.assertIsInstance(runner._val_dataloader, dict)
        self.assertIsInstance(runner._test_dataloader, dict)
        self.assertIsInstance(runner.optimizer, dict)
        self.assertIsInstance(runner.optim_wrapper, dict)
        self.assertIsInstance(runner.param_schedulers[0], dict)

        # After calling runner.train(),
@ -408,7 +404,7 @@ class TestRunner(TestCase):

        self.assertIsInstance(runner._train_loop, BaseLoop)
        self.assertIsInstance(runner.train_dataloader, DataLoader)
        self.assertIsInstance(runner.optimizer, SGD)
        self.assertIsInstance(runner.optim_wrapper, OptimWrapper)
        self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
        self.assertIsInstance(runner._val_loop, BaseLoop)
        self.assertIsInstance(runner._val_loop.dataloader, DataLoader)
@ -423,10 +419,10 @@ class TestRunner(TestCase):

        # 4. initialize runner with objects rather than config
        model = ToyModel()
        optimizer = SGD(
        optim_wrapper = OptimWrapper(SGD(
            model.parameters(),
            lr=0.01,
        )
        ))
        toy_hook = ToyHook()
        toy_hook2 = ToyHook2()

@ -438,8 +434,8 @@ class TestRunner(TestCase):
            work_dir=self.temp_dir,
            train_cfg=dict(by_epoch=True, max_epochs=3),
            train_dataloader=train_dataloader,
            optimizer=optimizer,
            param_scheduler=MultiStepLR(optimizer, milestones=[1, 2]),
            optim_wrapper=optim_wrapper,
            param_scheduler=MultiStepLR(optim_wrapper, milestones=[1, 2]),
            val_cfg=dict(interval=1, begin=1),
            val_dataloader=val_dataloader,
            val_evaluator=ToyMetric1(),
@ -618,79 +614,92 @@ class TestRunner(TestCase):
        runner = Runner.from_cfg(cfg)
        self.assertIsInstance(runner.model, CustomModelWrapper)

    def test_build_optimizer(self):
    def test_build_optim_wrapper(self):
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_build_optimizer'
        cfg.experiment_name = 'test_build_optim_wrapper'
        runner = Runner.from_cfg(cfg)

        # input should be an Optimizer object or dict
        with self.assertRaisesRegex(TypeError, 'optimizer should be'):
            runner.build_optimizer('invalid-type')
        with self.assertRaisesRegex(TypeError, 'optimizer wrapper should be'):
            runner.build_optim_wrapper('invalid-type')

        # 1. test one optimizer
        # 1.1 input is an Optimizer object
        _optimizer = SGD(runner.model.parameters(), lr=0.01)
        optimizer = runner.build_optimizer(_optimizer)
        self.assertEqual(id(_optimizer), id(optimizer))
        optimizer = SGD(runner.model.parameters(), lr=0.01)
        optim_wrapper = OptimWrapper(optimizer)
        optim_wrapper = runner.build_optim_wrapper(optim_wrapper)
        self.assertEqual(id(optimizer), id(optim_wrapper.optimizer))

        # 1.2 input is a dict
        optimizer = runner.build_optimizer(dict(type='SGD', lr=0.01))
        self.assertIsInstance(optimizer, SGD)
        optim_wrapper = runner.build_optim_wrapper(
            dict(type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)))
        self.assertIsInstance(optim_wrapper, OptimWrapper)

        # 2. test multiple optimizers
        # 2.1 input is a dict which contains multiple optimizer objects
        optimizer1 = SGD(runner.model.linear1.parameters(), lr=0.01)
        optim_wrapper1 = OptimWrapper(optimizer1)
        optimizer2 = Adam(runner.model.linear2.parameters(), lr=0.02)
        optim_cfg = dict(key1=optimizer1, key2=optimizer2)
        optimizer = runner.build_optimizer(optim_cfg)
        self.assertIsInstance(optimizer, dict)
        self.assertIsInstance(optimizer['key1'], SGD)
        self.assertIsInstance(optimizer['key2'], Adam)
        optim_wrapper2 = OptimWrapper(optimizer2)
        optim_wrapper_cfg = dict(key1=optim_wrapper1, key2=optim_wrapper2)
        optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
        self.assertIsInstance(optim_wrapper, OptimWrapperDict)
        self.assertIsInstance(optim_wrapper['key1'].optimizer, SGD)
        self.assertIsInstance(optim_wrapper['key2'].optimizer, Adam)

        # 2.2 each item mush be an optimizer object when "type" and
        # "constructor" are not in optimizer
        optimizer1 = SGD(runner.model.linear1.parameters(), lr=0.01)
        optimizer2 = dict(type='Adam', lr=0.02)
        optim_cfg = dict(key1=optimizer1, key2=optimizer2)
        optim_wrapper1 = OptimWrapper(optimizer1)
        optim_wrapper2 = dict(
            type='OptimWrapper', optimizer=dict(type='Adam', lr=0.01))
        optim_cfg = dict(key1=optim_wrapper1, key2=optim_wrapper2)
        with self.assertRaisesRegex(ValueError,
                                    'each item mush be an optimizer object'):
            runner.build_optimizer(optim_cfg)
            runner.build_optim_wrapper(optim_cfg)

        # 2.3 input is a dict which contains multiple configs
        optim_cfg = dict(
            linear1=dict(type='SGD', lr=0.01),
            linear2=dict(type='Adam', lr=0.02),
        optim_wrapper_cfg = dict(
            linear1=dict(
                type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)),
            linear2=dict(
                type='OptimWrapper', optimizer=dict(type='Adam', lr=0.02)),
            constructor='ToyMultipleOptimizerConstructor')
        optimizer = runner.build_optimizer(optim_cfg)
        self.assertIsInstance(optimizer, dict)
        self.assertIsInstance(optimizer['linear1'], SGD)
        self.assertIsInstance(optimizer['linear2'], Adam)
        optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
        self.assertIsInstance(optim_wrapper, OptimWrapperDict)
        self.assertIsInstance(optim_wrapper['linear1'].optimizer, SGD)
        self.assertIsInstance(optim_wrapper['linear2'].optimizer, Adam)

    def test_build_param_scheduler(self):
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_build_param_scheduler'
        runner = Runner.from_cfg(cfg)

        # `build_optimizer` should be called before `build_param_scheduler`
        # `build_optim_wrapper` should be called before
        # `build_param_scheduler`
        cfg = dict(type='MultiStepLR', milestones=[1, 2])
        runner.optimizer = dict(
            key1=dict(type='SGD', lr=0.01),
            key2=dict(type='Adam', lr=0.02),
        runner.optim_wrapper = dict(
            key1=dict(type=OptimWrapper, optimizer=dict(type='SGD', lr=0.01)),
            key2=dict(type=OptimWrapper, optimizer=dict(type='Adam', lr=0.02)),
        )
        with self.assertRaisesRegex(RuntimeError, 'should be called before'):
        with self.assertRaisesRegex(AssertionError, 'should be called before'):
            runner.build_param_scheduler(cfg)

        runner.optimizer = runner.build_optimizer(dict(type='SGD', lr=0.01))
        runner.optim_wrapper = runner.build_optim_wrapper(
            dict(type=OptimWrapper, optimizer=dict(type='SGD', lr=0.01)))
        param_schedulers = runner.build_param_scheduler(cfg)
        self.assertIsInstance(param_schedulers, list)
        self.assertEqual(len(param_schedulers), 1)
        self.assertIsInstance(param_schedulers[0], MultiStepLR)

        # 1. test one optimizer and one parameter scheduler
        # 1.1 input is a ParamScheduler object
        param_scheduler = MultiStepLR(runner.optimizer, milestones=[1, 2])
        param_scheduler = MultiStepLR(runner.optim_wrapper, milestones=[1, 2])
        param_schedulers = runner.build_param_scheduler(param_scheduler)
        self.assertEqual(len(param_schedulers), 1)
        self.assertEqual(id(param_schedulers[0]), id(param_scheduler))

        # 1.2 input is a dict
        cfg = dict(type='MultiStepLR', milestones=[1, 2])
        param_schedulers = runner.build_param_scheduler(param_scheduler)
        self.assertEqual(len(param_schedulers), 1)
        self.assertIsInstance(param_schedulers[0], MultiStepLR)
@ -715,9 +724,11 @@ class TestRunner(TestCase):

        # 3. test multiple optimizers and list of parameter schedulers
        optimizer1 = SGD(runner.model.linear1.parameters(), lr=0.01)
        optim_wrapper1 = OptimWrapper(optimizer1)
        optimizer2 = Adam(runner.model.linear2.parameters(), lr=0.02)
        optim_cfg = dict(key1=optimizer1, key2=optimizer2)
        runner.optimizer = runner.build_optimizer(optim_cfg)
        optim_wrapper2 = OptimWrapper(optimizer2)
        optim_wrapper_cfg = dict(key1=optim_wrapper1, key2=optim_wrapper2)
        runner.optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
        cfg = [
            dict(type='MultiStepLR', milestones=[1, 2]),
            dict(type='StepLR', step_size=1)
@ -743,7 +754,8 @@ class TestRunner(TestCase):
        self.assertEqual(len(param_schedulers['key2']), 2)

        # 5. test converting epoch-based scheduler to iter-based
        runner.optimizer = runner.build_optimizer(dict(type='SGD', lr=0.01))
        runner.optim_wrapper = runner.build_optim_wrapper(
            dict(type=OptimWrapper, optimizer=dict(type='SGD', lr=0.01)))

        # 5.1 train loop should be built before converting scheduler
        cfg = dict(
@ -752,7 +764,7 @@ class TestRunner(TestCase):
                AssertionError,
                'Scheduler can only be converted to iter-based when '
                'train loop is built.'):
            param_schedulers = runner.build_param_scheduler(cfg)
            runner.build_param_scheduler(cfg)

        # 5.2 convert epoch-based to iter-based scheduler
        cfg = dict(
@ -947,7 +959,7 @@ class TestRunner(TestCase):
        cfg.experiment_name = 'test_train1'
        cfg.pop('train_dataloader')
        cfg.pop('train_cfg')
        cfg.pop('optimizer')
        cfg.pop('optim_wrapper')
        cfg.pop('param_scheduler')
        runner = Runner.from_cfg(cfg)
        with self.assertRaisesRegex(RuntimeError, 'should not be None'):
@ -1063,7 +1075,7 @@ class TestRunner(TestCase):
        cfg.experiment_name = 'test_individually_val'
        cfg.pop('train_dataloader')
        cfg.pop('train_cfg')
        cfg.pop('optimizer')
        cfg.pop('optim_wrapper')
        cfg.pop('param_scheduler')
        cfg.pop('test_dataloader')
        cfg.pop('test_cfg')
@ -1091,7 +1103,7 @@ class TestRunner(TestCase):
        cfg.experiment_name = 'test_individually_test'
        cfg.pop('train_dataloader')
        cfg.pop('train_cfg')
        cfg.pop('optimizer')
        cfg.pop('optim_wrapper')
        cfg.pop('param_scheduler')
        cfg.pop('val_dataloader')
        cfg.pop('val_cfg')
@ -1132,15 +1144,15 @@ class TestRunner(TestCase):
            get_priority('BELOW_NORMAL'))

        # 1.3 `hook` is a hook object
        optimizer_hook = OptimizerHook()
        runner.register_hook(optimizer_hook)
        runtime_info_hook = RuntimeInfoHook()
        runner.register_hook(runtime_info_hook)
        self.assertEqual(len(runner._hooks), 2)
        # The priority of `OptimizerHook` is `HIGH` which is greater than
        # The priority of `runtime_info_hook` is `VERY_HIGH`, which is greater than
        # `IterTimerHook`, so the first item of `_hooks` should be
        # `OptimizerHook`
        self.assertTrue(isinstance(runner._hooks[0], OptimizerHook))
        # `runtime_info_hook`
        self.assertTrue(isinstance(runner._hooks[0], RuntimeInfoHook))
        self.assertEqual(
            get_priority(runner._hooks[0].priority), get_priority('HIGH'))
            get_priority(runner._hooks[0].priority), get_priority('VERY_HIGH'))

        # 2. test `priority` parameter
        # `priority` argument is not None and it will be set as priority of
@ -1198,29 +1210,6 @@ class TestRunner(TestCase):
        self.assertEqual(len(runner._hooks), 8)
        self.assertTrue(isinstance(runner._hooks[7], ToyHook))

    def test_call_hook(self):
        # test unexpected argument in `call_hook`
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_call_hook1'
        runner = Runner.from_cfg(cfg)
        runner._hooks = []
        custom_hooks = [dict(type='ToyHook3')]
        runner.register_custom_hooks(custom_hooks)
        with self.assertRaisesRegex(
                TypeError,
                r"got an unexpected keyword argument 'batch_idx' in "
                r'<test_runner.ToyHook3 object at \w+>'):
            runner.call_hook('before_train_iter', batch_idx=0, data_batch=None)

        # test call hook with expected arguments
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_call_hook2'
        runner = Runner.from_cfg(cfg)
        runner._hooks = []
        custom_hooks = [dict(type='ToyHook3')]
        runner.register_custom_hooks(custom_hooks)
        runner.call_hook('before_train_iter', data_batch=None)

    def test_register_hooks(self):
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_register_hooks'
@ -1229,7 +1218,7 @@ class TestRunner(TestCase):
        runner._hooks = []
        custom_hooks = [dict(type='ToyHook')]
        runner.register_hooks(custom_hooks=custom_hooks)
        # 7 default hooks + custom hook (ToyHook)
        # six default hooks + custom hook (ToyHook)
        self.assertEqual(len(runner._hooks), 8)
        self.assertTrue(isinstance(runner._hooks[7], ToyHook))

@ -1336,7 +1325,7 @@ class TestRunner(TestCase):
        # 1.2 test `load_checkpoint`
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_checkpoint2'
        cfg.optimizer = dict(type='SGD', lr=0.2)
        cfg.optim_wrapper = dict(type='SGD', lr=0.2)
        cfg.param_scheduler = dict(type='MultiStepLR', milestones=[1, 2, 3])
        runner = Runner.from_cfg(cfg)
        runner.load_checkpoint(path)
@ -1345,22 +1334,24 @@ class TestRunner(TestCase):
        self.assertTrue(runner._has_loaded)
        # load checkpoint will not initialize optimizer and param_schedulers
        # objects
        self.assertIsInstance(runner.optimizer, dict)
        self.assertIsInstance(runner.optim_wrapper, dict)
        self.assertIsInstance(runner.param_schedulers, list)
        self.assertIsInstance(runner.param_schedulers[0], dict)

        # 1.3 test `resume`
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_checkpoint3'
        cfg.optimizer = dict(type='SGD', lr=0.2)
        cfg.optim_wrapper = dict(
            type='OptimWrapper', optimizer=dict(type='SGD', lr=0.2))
        cfg.param_scheduler = dict(type='MultiStepLR', milestones=[1, 2, 3])
        runner = Runner.from_cfg(cfg)
        runner.resume(path)
        self.assertEqual(runner.epoch, 3)
        self.assertEqual(runner.iter, 12)
        self.assertTrue(runner._has_loaded)
        self.assertIsInstance(runner.optimizer, SGD)
        self.assertEqual(runner.optimizer.param_groups[0]['lr'], 0.0001)
        self.assertIsInstance(runner.optim_wrapper.optimizer, SGD)
        self.assertIsInstance(runner.optim_wrapper.optimizer, SGD)
        self.assertEqual(runner.optim_wrapper.param_groups[0]['lr'], 0.0001)
        self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
        self.assertEqual(runner.param_schedulers[0].milestones, {1: 1, 2: 1})

@ -1373,7 +1364,7 @@ class TestRunner(TestCase):
        self.assertEqual(runner.epoch, 3)
        self.assertEqual(runner.iter, 12)
        self.assertTrue(runner._has_loaded)
        self.assertIsInstance(runner.optimizer, SGD)
        self.assertIsInstance(runner.optim_wrapper.optimizer, SGD)
        self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)

        # 1.5 test resume from a specified checkpoint
@ -1386,46 +1377,48 @@ class TestRunner(TestCase):
        self.assertEqual(runner.epoch, 1)
        self.assertEqual(runner.iter, 4)
        self.assertTrue(runner._has_loaded)
        self.assertIsInstance(runner.optimizer, SGD)
        self.assertIsInstance(runner.optim_wrapper.optimizer, SGD)
        self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)

        # 1.6 multiple optimizers
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_checkpoint6'
        cfg.optimizer = dict(
            linear1=dict(type='SGD', lr=0.01),
            linear2=dict(type='Adam', lr=0.02),
            constructor='ToyMultipleOptimizerConstructor',
        )
        cfg.optim_wrapper = dict(
            linear1=dict(
                type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)),
            linear2=dict(
                type='OptimWrapper', optimizer=dict(type='Adam', lr=0.02)),
            constructor='ToyMultipleOptimizerConstructor')
        # disable OptimizerHook because it only works with one optimizer
        cfg.default_hooks = dict(optimizer=None)
        runner = Runner.from_cfg(cfg)
        runner.train()
        path = osp.join(self.temp_dir, 'epoch_3.pth')
        self.assertTrue(osp.exists(path))
        self.assertEqual(runner.optimizer['linear1'].param_groups[0]['lr'],
        self.assertEqual(runner.optim_wrapper['linear1'].param_groups[0]['lr'],
                         0.0001)
        self.assertIsInstance(runner.optimizer['linear2'], Adam)
        self.assertEqual(runner.optimizer['linear2'].param_groups[0]['lr'],
        self.assertIsInstance(runner.optim_wrapper['linear2'].optimizer, Adam)
        self.assertEqual(runner.optim_wrapper['linear2'].param_groups[0]['lr'],
                         0.0002)

        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_checkpoint7'
        cfg.optimizer = dict(
            linear1=dict(type='SGD', lr=0.2),
            linear2=dict(type='Adam', lr=0.03),
            constructor='ToyMultipleOptimizerConstructor',
        )
        cfg.optim_wrapper = dict(
            linear1=dict(
                type='OptimWrapper', optimizer=dict(type='SGD', lr=0.2)),
            linear2=dict(
                type='OptimWrapper', optimizer=dict(type='Adam', lr=0.03)),
            constructor='ToyMultipleOptimizerConstructor')
        cfg.param_scheduler = dict(type='MultiStepLR', milestones=[1, 2, 3])
        cfg.default_hooks = dict(optimizer=None)
        runner = Runner.from_cfg(cfg)
        runner.resume(path)
        self.assertIsInstance(runner.optimizer, dict)
        self.assertIsInstance(runner.optimizer['linear1'], SGD)
        self.assertEqual(runner.optimizer['linear1'].param_groups[0]['lr'],
        self.assertIsInstance(runner.optim_wrapper, OptimWrapperDict)
        self.assertIsInstance(runner.optim_wrapper['linear1'].optimizer, SGD)
        self.assertEqual(runner.optim_wrapper['linear1'].param_groups[0]['lr'],
                         0.0001)
        self.assertIsInstance(runner.optimizer['linear2'], Adam)
        self.assertEqual(runner.optimizer['linear2'].param_groups[0]['lr'],
        self.assertIsInstance(runner.optim_wrapper['linear2'].optimizer, Adam)
        self.assertEqual(runner.optim_wrapper['linear2'].param_groups[0]['lr'],
                         0.0002)
        self.assertIsInstance(runner.param_schedulers, dict)
        self.assertEqual(len(runner.param_schedulers['linear1']), 1)
@ -1479,7 +1472,7 @@ class TestRunner(TestCase):
        self.assertEqual(runner.epoch, 0)
        self.assertEqual(runner.iter, 12)
        self.assertTrue(runner._has_loaded)
        self.assertIsInstance(runner.optimizer, SGD)
        self.assertIsInstance(runner.optim_wrapper.optimizer, SGD)
        self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)

        # 2.4 test auto resume
@ -1491,7 +1484,7 @@ class TestRunner(TestCase):
        self.assertEqual(runner.epoch, 0)
        self.assertEqual(runner.iter, 12)
        self.assertTrue(runner._has_loaded)
        self.assertIsInstance(runner.optimizer, SGD)
        self.assertIsInstance(runner.optim_wrapper.optimizer, SGD)
        self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)

        # 2.5 test resume from a specified checkpoint
@ -1504,5 +1497,5 @@ class TestRunner(TestCase):
        self.assertEqual(runner.epoch, 0)
        self.assertEqual(runner.iter, 3)
        self.assertTrue(runner._has_loaded)
        self.assertIsInstance(runner.optimizer, SGD)
        self.assertIsInstance(runner.optim_wrapper.optimizer, SGD)
        self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
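Viewed end to end, the runner-side changes above amount to a config migration: wherever a bare optimizer config or object was passed to Runner, an optim_wrapper is passed instead. A condensed sketch of the two shapes exercised by these tests follows; the linear1/linear2 keys are the toy sub-modules used in the tests, and the AMP variant is an assumption based on the AmpOptimWrapper class tested above rather than something these tests configure.

    # Single optimizer: the optimizer config is nested inside an OptimWrapper config.
    optim_wrapper = dict(type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01))

    # Multiple optimizers: one OptimWrapper config per sub-module plus a custom
    # constructor that returns an OptimWrapperDict.
    optim_wrapper = dict(
        linear1=dict(type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)),
        linear2=dict(type='OptimWrapper', optimizer=dict(type='Adam', lr=0.02)),
        constructor='ToyMultipleOptimizerConstructor')

    # Mixed-precision training would swap in the AMP wrapper (assumption, not
    # covered by the tests in this diff):
    # optim_wrapper = dict(type='AmpOptimWrapper', optimizer=dict(type='SGD', lr=0.01))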