diff --git a/mmengine/hooks/optimizer_hook.py b/mmengine/hooks/optimizer_hook.py
index 6a8f7c11..c00d9dea 100644
--- a/mmengine/hooks/optimizer_hook.py
+++ b/mmengine/hooks/optimizer_hook.py
@@ -83,7 +83,7 @@ class OptimizerHook(Hook):
                 In order to keep this interface consistent with other hooks, we
                 keep ``outputs`` here. Defaults to None.
         """
-        runner.optimizer.zero_grad()
+        runner.optim_wrapper.zero_grad()
         if self.detect_anomalous_params:
             self.detect_anomalous_parameters(runner.outputs['loss'], runner)
         runner.outputs['loss'].backward()
@@ -94,7 +94,7 @@ class OptimizerHook(Hook):
             # Add grad norm to the logger
             runner.message_hub.update_scalar('train/grad_norm',
                                              float(grad_norm))
-        runner.optimizer.step()
+        runner.optim_wrapper.step()
 
     def detect_anomalous_parameters(self, loss: torch.Tensor, runner) -> None:
         """Detect anomalous parameters that are not included in the graph.
diff --git a/mmengine/hooks/runtime_info_hook.py b/mmengine/hooks/runtime_info_hook.py
index 68556464..56186ff5 100644
--- a/mmengine/hooks/runtime_info_hook.py
+++ b/mmengine/hooks/runtime_info_hook.py
@@ -41,13 +41,16 @@ class RuntimeInfoHook(Hook):
         """Update current iter and learning rate information before every
         iteration."""
         runner.message_hub.update_info('iter', runner.iter)
-        if isinstance(runner.optimizer, dict):
-            for name, optimizer in runner.optimizer.items():
-                runner.message_hub.update_scalar(
-                    f'train/{name}.lr', optimizer.param_groups[0]['lr'])
-        else:
-            runner.message_hub.update_scalar(
-                'train/lr', runner.optimizer.param_groups[0]['lr'])
+        lr_dict = runner.optim_wrapper.get_lr()
+        assert isinstance(lr_dict, dict), (
+            '`runner.optim_wrapper.get_lr()` should return a dict '
+            'of learning rates when training with OptimWrapper (single '
+            'optimizer) or OptimWrapperDict (multiple optimizers), '
+            f'but got {type(lr_dict)}. Please check that your optimizer '
+            'constructor returns an `OptimWrapper` or `OptimWrapperDict` '
+            'instance.')
+        for name, lr in lr_dict.items():
+            runner.message_hub.update_scalar(f'train/{name}', lr[0])
 
     def after_train_iter(self,
                          runner,
diff --git a/mmengine/optim/__init__.py b/mmengine/optim/__init__.py
index 029b7aa3..9c6979cd 100644
--- a/mmengine/optim/__init__.py
+++ b/mmengine/optim/__init__.py
@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
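The ``RuntimeInfoHook`` change above assumes that ``get_lr()`` returns a dict for both a single ``OptimWrapper`` and an ``OptimWrapperDict``. A minimal sketch of the keys that end up in the message hub (module and optimizer names are illustrative):

    import torch.nn as nn
    from torch.optim import SGD
    from mmengine.optim import OptimWrapper, OptimWrapperDict

    gen, disc = nn.Linear(1, 1), nn.Linear(1, 1)
    single = OptimWrapper(SGD(gen.parameters(), lr=0.01))
    single.get_lr()    # {'lr': [0.01]} -> logged as 'train/lr'

    multiple = OptimWrapperDict(
        generator=OptimWrapper(SGD(gen.parameters(), lr=0.01)),
        discriminator=OptimWrapper(SGD(disc.parameters(), lr=0.02)))
    multiple.get_lr()  # {'generator.lr': [0.01], 'discriminator.lr': [0.02]}
                       # -> logged as 'train/generator.lr' and 'train/discriminator.lr'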
-from .optimizer import (OPTIMIZER_CONSTRUCTORS, OPTIMIZERS, - DefaultOptimizerConstructor, build_optimizer) +from .optimizer import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS, + AmpOptimWrapper, DefaultOptimWrapperConstructor, + OptimWrapper, OptimWrapperDict, build_optim_wrapper) from .scheduler import (ConstantLR, ConstantMomentum, ConstantParamScheduler, CosineAnnealingLR, CosineAnnealingMomentum, CosineAnnealingParamScheduler, ExponentialLR, @@ -11,12 +12,12 @@ from .scheduler import (ConstantLR, ConstantMomentum, ConstantParamScheduler, StepParamScheduler, _ParamScheduler) __all__ = [ - 'OPTIMIZER_CONSTRUCTORS', 'OPTIMIZERS', 'build_optimizer', - 'DefaultOptimizerConstructor', 'ConstantLR', 'CosineAnnealingLR', + 'OPTIM_WRAPPER_CONSTRUCTORS', 'OPTIMIZERS', 'build_optim_wrapper', + 'DefaultOptimWrapperConstructor', 'ConstantLR', 'CosineAnnealingLR', 'ExponentialLR', 'LinearLR', 'MultiStepLR', 'StepLR', 'ConstantMomentum', 'CosineAnnealingMomentum', 'ExponentialMomentum', 'LinearMomentum', 'MultiStepMomentum', 'StepMomentum', 'ConstantParamScheduler', 'CosineAnnealingParamScheduler', 'ExponentialParamScheduler', 'LinearParamScheduler', 'MultiStepParamScheduler', 'StepParamScheduler', - '_ParamScheduler' + '_ParamScheduler', 'OptimWrapper', 'AmpOptimWrapper', 'OptimWrapperDict' ] diff --git a/mmengine/optim/optimizer/__init__.py b/mmengine/optim/optimizer/__init__.py index a74c6b8e..5d08a144 100644 --- a/mmengine/optim/optimizer/__init__.py +++ b/mmengine/optim/optimizer/__init__.py @@ -1,8 +1,13 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .builder import OPTIMIZER_CONSTRUCTORS, OPTIMIZERS, build_optimizer -from .default_constructor import DefaultOptimizerConstructor +from .amp_optimizer_wrapper import AmpOptimWrapper +from .builder import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS, + build_optim_wrapper) +from .default_constructor import DefaultOptimWrapperConstructor +from .optimizer_wrapper import OptimWrapper +from .optimizer_wrapper_dict import OptimWrapperDict __all__ = [ - 'OPTIMIZER_CONSTRUCTORS', 'OPTIMIZERS', 'DefaultOptimizerConstructor', - 'build_optimizer' + 'OPTIM_WRAPPER_CONSTRUCTORS', 'OPTIMIZERS', + 'DefaultOptimWrapperConstructor', 'build_optim_wrapper', 'OptimWrapper', + 'AmpOptimWrapper', 'OptimWrapperDict' ] diff --git a/mmengine/optim/optimizer/amp_optimizer_wrapper.py b/mmengine/optim/optimizer/amp_optimizer_wrapper.py new file mode 100644 index 00000000..16823e61 --- /dev/null +++ b/mmengine/optim/optimizer/amp_optimizer_wrapper.py @@ -0,0 +1,110 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from contextlib import contextmanager + +import torch +from torch.cuda.amp import GradScaler + +from mmengine.registry import OPTIM_WRAPPERS +from mmengine.utils import TORCH_VERSION, digit_version +from .optimizer_wrapper import OptimWrapper + + +@OPTIM_WRAPPERS.register_module() +class AmpOptimWrapper(OptimWrapper): + """A subclass of :class:`OptimWrapper` that supports automatic mixed + precision training based on torch.cuda.amp. + + ``AmpOptimWrapper`` provides a unified interface with + ``OptimWrapper``, so ``AmpOptimWrapper`` can be used in the same way + as ``OptimWrapper``. + + Warnings: + ``AmpOptimWrapper`` requires PyTorch >= 1.6. + + Args: + loss_scale (float or str or dict): The initial configuration of + `torch.cuda.amp.GradScaler`. See more specific arguments + introduction at `PyTorch AMP `_ # noqa: E501 + + - "dynamic": Initialize GradScale without any arguments. + - float: Initialize GradScaler with ``init_scale``. 
+ - dict: Initialize GradScaler with more detail configuration. + + **kwargs: Keyword arguments passed to OptimWrapper. + """ + + def __init__(self, loss_scale=512., **kwargs): + assert digit_version(TORCH_VERSION) >= digit_version('1.6.0'), ( + '`torch.cuda.amp` is only available when pytorch version >= 1.6') + assert torch.cuda.is_available(), ( + '``AmpOptimizerWrapper`` is only available training on gpu') + super().__init__(**kwargs) + self._scale_update_param = None + if loss_scale == 'dynamic': + # If loss_scale is a string, it must be 'dynamic', then dynamic + # loss scaling will be used. + self.loss_scaler = GradScaler() + elif isinstance(loss_scale, float): + # Static loss scaling + self._scale_update_param = loss_scale + self.loss_scaler = GradScaler(init_scale=loss_scale) + elif isinstance(loss_scale, dict): + # More specific configuration. + self.loss_scaler = GradScaler(**loss_scale) + else: + raise TypeError('loss_scale must be of type float, dict, or ' + f'"dynamic", but got {loss_scale}') + + def backward(self, loss: torch.Tensor): + """Perform gradient back propagation with :attr:`loss_scaler`. + + Args: + loss (torch.Tensor): The loss of current iteration. + """ + self.loss_scaler.scale(loss).backward() + + def step(self): + """Update parameters with :attr:`loss_scaler`.""" + if self.clip_grad_kwargs: + self.loss_scaler.unscale_(self.optimizer) + self._clip_grad() + self.loss_scaler.step(self.optimizer) + self.loss_scaler.update(self._scale_update_param) + + def state_dict(self) -> dict: + """Get the state dictionary of :attr:`optimizer` and + :attr:`loss_scaler`. + + Based on the state dictionary of the optimizer, the returned state + dictionary will add a key named "loss_scaler". + + Returns: + dict: The merged state dict of :attr:`loss_scaler` and + :attr:`optimizer`. + """ + # save state_dict of loss_scaler + state_dict = self.optimizer.state_dict() + state_dict['loss_scaler'] = self.loss_scaler.state_dict() + return state_dict + + def load_state_dict(self, state_dict: dict): + """Load and parse the state dictionary of :attr:`optimizer` and + :attr:`loss_scaler`. + + If state_dict contains "loss_scaler.", the :attr:`loss_scaler` will + load the corresponding keys. Otherwise, only the :attr:`optimizer` + will load the state dictionary. + + Args: + state_dict (dict): The state dict of :attr:`optimizer` and + :attr:`loss_scaler` + """ + if 'loss_scaler' in state_dict: + self.loss_scaler.load_state_dict(state_dict.pop('loss_scaler')) + self.optimizer.load_state_dict(state_dict) + + @contextmanager + def precision_context(self): + """A wrapper of ``torch.cuda.amp.autocast``""" + with torch.cuda.amp.autocast(): + yield diff --git a/mmengine/optim/optimizer/builder.py b/mmengine/optim/optimizer/builder.py index 31350f6f..01f26a99 100644 --- a/mmengine/optim/optimizer/builder.py +++ b/mmengine/optim/optimizer/builder.py @@ -1,12 +1,14 @@ # Copyright (c) OpenMMLab. All rights reserved. 
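As an aside on the ``AmpOptimWrapper`` introduced above, the three accepted forms of ``loss_scale`` map onto ``GradScaler`` as sketched below (a rough sketch: it assumes PyTorch >= 1.6 and an available GPU, and the model/optimizer are illustrative):

    import torch.nn as nn
    from torch.optim import SGD
    from mmengine.optim import AmpOptimWrapper

    model = nn.Linear(1, 1).cuda()
    optimizer = SGD(model.parameters(), lr=0.01)

    AmpOptimWrapper(loss_scale='dynamic', optimizer=optimizer)  # GradScaler()
    AmpOptimWrapper(loss_scale=512., optimizer=optimizer)       # GradScaler(init_scale=512.)
    amp_wrapper = AmpOptimWrapper(                              # GradScaler(**loss_scale)
        loss_scale=dict(init_scale=512., growth_interval=1000),
        optimizer=optimizer)
    # The scaler state is saved next to the optimizer state.
    assert 'loss_scaler' in amp_wrapper.state_dict()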
import copy import inspect -from typing import List +from typing import List, Union import torch import torch.nn as nn -from mmengine.registry import OPTIMIZER_CONSTRUCTORS, OPTIMIZERS +from mmengine.config import Config, ConfigDict +from mmengine.registry import OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS +from .optimizer_wrapper import OptimWrapper def register_torch_optimizers() -> List[str]: @@ -30,31 +32,31 @@ def register_torch_optimizers() -> List[str]: TORCH_OPTIMIZERS = register_torch_optimizers() -def build_optimizer(model: nn.Module, cfg: dict) -> torch.optim.Optimizer: - """Build function of optimizer. +def build_optim_wrapper(model: nn.Module, + cfg: Union[dict, Config, ConfigDict]) -> OptimWrapper: + """Build function of OptimWrapper. If ``constructor`` is set in the ``cfg``, this method will build an - optimizer constructor, and use optimizer constructor to build the - optimizer. If ``constructor`` is not set, the - ``DefaultOptimizerConstructor`` will be used by default. + optimizer wrapper constructor, and use optimizer wrapper constructor to + build the optimizer wrapper. If ``constructor`` is not set, the + ``DefaultOptimWrapperConstructor`` will be used by default. Args: model (nn.Module): Model to be optimized. - cfg (dict): Config of optimizer and optimizer constructor. - default_scope (str, optional): The ``default_scope`` is used to - reset the current registry. Defaults to None. + cfg (dict): Config of optimizer wrapper, optimizer constructor and + optimizer. Returns: - torch.optim.Optimizer: The built optimizer. + OptimWrapper: The built optimizer wrapper. """ - optimizer_cfg = copy.deepcopy(cfg) - constructor_type = optimizer_cfg.pop('constructor', - 'DefaultOptimizerConstructor') - paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None) - optim_constructor = OPTIMIZER_CONSTRUCTORS.build( + optim_wrapper_cfg = copy.deepcopy(cfg) + constructor_type = optim_wrapper_cfg.pop('constructor', + 'DefaultOptimWrapperConstructor') + paramwise_cfg = optim_wrapper_cfg.pop('paramwise_cfg', None) + optim_wrapper_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build( dict( type=constructor_type, - optimizer_cfg=optimizer_cfg, + optim_wrapper_cfg=optim_wrapper_cfg, paramwise_cfg=paramwise_cfg)) - optimizer = optim_constructor(model) - return optimizer + optim_wrapper = optim_wrapper_constructor(model) + return optim_wrapper diff --git a/mmengine/optim/optimizer/default_constructor.py b/mmengine/optim/optimizer/default_constructor.py index 1fb2102e..d6ed15d0 100644 --- a/mmengine/optim/optimizer/default_constructor.py +++ b/mmengine/optim/optimizer/default_constructor.py @@ -6,17 +6,19 @@ import torch import torch.nn as nn from torch.nn import GroupNorm, LayerNorm -from mmengine.logging.logger import print_log -from mmengine.registry import OPTIMIZER_CONSTRUCTORS, OPTIMIZERS +from mmengine.logging import print_log +from mmengine.registry import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS, + OPTIMIZERS) from mmengine.utils import is_list_of, mmcv_full_available from mmengine.utils.parrots_wrapper import _BatchNorm, _InstanceNorm +from .optimizer_wrapper import OptimWrapper -@OPTIMIZER_CONSTRUCTORS.register_module() -class DefaultOptimizerConstructor: +@OPTIM_WRAPPER_CONSTRUCTORS.register_module() +class DefaultOptimWrapperConstructor: """Default constructor for optimizers. - By default each parameter share the same optimizer settings, and we + By default, each parameter share the same optimizer settings, and we provide an argument ``paramwise_cfg`` to specify parameter-wise settings. 
It is a dict and may contain the following fields: @@ -62,49 +64,65 @@ class DefaultOptimizerConstructor: model contains multiple DCN layers in places other than backbone. Args: - optimizer_cfg (dict): The config dict of the optimizer. + optim_wrapper_cfg (dict): The config dict of the optimizer wrapper. Positional fields are - - `type`: class name of the optimizer. + - ``type``: class name of the OptimizerWrapper + - ``optimizer``: The configuration of optimizer. Optional fields are - - any arguments of the corresponding optimizer type, e.g., - lr, weight_decay, momentum, etc. + - any arguments of the corresponding optimizer wrapper type, + e.g., accumulative_iters, clip_grad, etc. + + The positional fields of ``optimizer`` are + + - `type`: class name of the optimizer. + + Optional fields are + + - any arguments of the corresponding optimizer type, e.g., + lr, weight_decay, momentum, etc. + paramwise_cfg (dict, optional): Parameter-wise options. Example 1: >>> model = torch.nn.modules.Conv1d(1, 1, 1) - >>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9, - >>> weight_decay=0.0001) + >>> optim_wrapper_cfg = dict( + >>> dict(type=OptimWrapper, optimizer=dict(type='SGD', lr=0.01, + >>> momentum=0.9, weight_decay=0.0001)) >>> paramwise_cfg = dict(norm_decay_mult=0.) - >>> optim_builder = DefaultOptimizerConstructor( - >>> optimizer_cfg, paramwise_cfg) - >>> optimizer = optim_builder(model) + >>> optim_wrapper_builder = DefaultOptimWrapperConstructor( + >>> optim_wrapper_cfg, paramwise_cfg) + >>> optim_wrapper = optim_wrapper_builder(model) Example 2: >>> # assume model have attribute model.backbone and model.cls_head - >>> optimizer_cfg = dict(type='SGD', lr=0.01, weight_decay=0.95) + >>> optim_wrapper_cfg = dict(type=OptimWrapper, optimizer=dict( + >>> type='SGD', lr=0.01, weight_decay=0.95)) >>> paramwise_cfg = dict(custom_keys={ - '.backbone': dict(lr_mult=0.1, decay_mult=0.9)}) - >>> optim_builder = DefaultOptimizerConstructor( - >>> optimizer_cfg, paramwise_cfg) - >>> optimizer = optim_builder(model) + >>> '.backbone': dict(lr_mult=0.1, decay_mult=0.9)}) + >>> optim_wrapper_builder = DefaultOptimWrapperConstructor( + >>> optim_wrapper_cfg, paramwise_cfg) + >>> optim_wrapper = optim_wrapper_builder(model) >>> # Then the `lr` and `weight_decay` for model.backbone is >>> # (0.01 * 0.1, 0.95 * 0.9). `lr` and `weight_decay` for >>> # model.cls_head is (0.01, 0.95). 
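+
+        A rough sketch (values are illustrative) of driving similar settings
+        through ``build_optim_wrapper`` instead of calling the constructor
+        directly:
+
+        >>> from mmengine.optim import build_optim_wrapper
+        >>> optim_wrapper_cfg = dict(
+        >>>     type='OptimWrapper',
+        >>>     optimizer=dict(type='SGD', lr=0.01, weight_decay=0.0001),
+        >>>     clip_grad=dict(max_norm=1.0),
+        >>>     paramwise_cfg=dict(norm_decay_mult=0.))
+        >>> optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)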
""" def __init__(self, - optimizer_cfg: dict, + optim_wrapper_cfg: dict, paramwise_cfg: Optional[dict] = None): - if not isinstance(optimizer_cfg, dict): + if not isinstance(optim_wrapper_cfg, dict): raise TypeError('optimizer_cfg should be a dict', - f'but got {type(optimizer_cfg)}') - self.optimizer_cfg = optimizer_cfg + f'but got {type(optim_wrapper_cfg)}') + assert 'optimizer' in optim_wrapper_cfg, ( + '`optim_wrapper_cfg` must contain "optimizer" config') + self.optim_wrapper_cfg = optim_wrapper_cfg.copy() + self.optimizer_cfg = self.optim_wrapper_cfg.pop('optimizer') self.paramwise_cfg = {} if paramwise_cfg is None else paramwise_cfg - self.base_lr = optimizer_cfg.get('lr', None) - self.base_wd = optimizer_cfg.get('weight_decay', None) + self.base_lr = self.optimizer_cfg.get('lr', None) + self.base_wd = self.optimizer_cfg.get('weight_decay', None) self._validate_cfg() def _validate_cfg(self) -> None: @@ -249,19 +267,23 @@ class DefaultOptimizerConstructor: prefix=child_prefix, is_dcn_module=is_dcn_module) - def __call__(self, model: nn.Module) -> torch.optim.Optimizer: + def __call__(self, model: nn.Module) -> OptimWrapper: if hasattr(model, 'module'): model = model.module + optim_wrapper_cfg = self.optim_wrapper_cfg.copy() + optim_wrapper_cfg.setdefault('type', 'OptimWrapper') optimizer_cfg = self.optimizer_cfg.copy() # if no paramwise option is specified, just use the global setting if not self.paramwise_cfg: optimizer_cfg['params'] = model.parameters() - return OPTIMIZERS.build(optimizer_cfg) - - # set param-wise lr and weight decay recursively - params: List = [] - self.add_params(params, model) - optimizer_cfg['params'] = params - - return OPTIMIZERS.build(optimizer_cfg) + optimizer = OPTIMIZERS.build(optimizer_cfg) + else: + # set param-wise lr and weight decay recursively + params: List = [] + self.add_params(params, model) + optimizer_cfg['params'] = params + optimizer = OPTIMIZERS.build(optimizer_cfg) + optim_wrapper = OPTIM_WRAPPERS.build( + optim_wrapper_cfg, default_args=dict(optimizer=optimizer)) + return optim_wrapper diff --git a/mmengine/optim/optimizer/optimizer_wrapper.py b/mmengine/optim/optimizer/optimizer_wrapper.py new file mode 100644 index 00000000..d71249bc --- /dev/null +++ b/mmengine/optim/optimizer/optimizer_wrapper.py @@ -0,0 +1,349 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from contextlib import contextmanager +from typing import Dict, List, Optional + +import torch +import torch.nn as nn +from torch.nn.utils import clip_grad +from torch.optim import Optimizer + +from mmengine.logging import MessageHub, MMLogger +from mmengine.registry import OPTIM_WRAPPERS +from mmengine.utils import has_batch_norm + + +@OPTIM_WRAPPERS.register_module() +class OptimWrapper: + """Optimizer wrapper provides a common interface for updating parameters. + + Optimizer wrapper provides a unified interface for single precision + training and automatic mixed precision training with different hardware. + OptimWrapper encapsulates optimizer to provide simplified interfaces + for commonly used training techniques such as gradient accumulative and + grad clips. ``OptimWrapper`` implements the basic logic of gradient + accumulation and gradient clipping based on ``torch.optim.Optimizer``. + The subclasses only need to override some methods to implement the mixed + precision training. See more information in :class:`AmpOptimWrapper`. + + Args: + optimizer (Optimizer): Optimizer used to update model parameters. 
+ accumulative_iters (int): The number of iterations to accumulate + gradients. The parameters will be updated per + ``accumulative_iters``. + clip_grad (dict, optional): If ``clip_grad`` is not None, it will be + the arguments of ``torch.nn.utils.clip_grad``. + + Warnings: + If ``accumulative_iters`` is larger than 1, :meth:`update_params` must + be called in the context of ``accumulate_grad``. + + Examples: + >>> # Config sample of OptimWrapper. + >>> optim_wrapper_cfg = dict( + >>> type='OptimWrapper', + >>> accumulative_iters=3, + >>> clip_grad=dict(max_norm=0.2)) + >>> # Use OptimWrapper to update model. + >>> import torch.nn as nn + >>> import torch + >>> from torch.optim import SGD + >>> from torch.utils.data import DataLoader + >>> from mmengine.optim import OptimWrapper + >>> + >>> model = nn.Linear(1, 1) + >>> dataset = torch.randn(10, 1, 1) + >>> dataloader = DataLoader(dataset) + >>> optimizer = SGD(model.parameters(), lr=0.1) + >>> optim_wrapper = OptimWrapper(optimizer) + >>> + >>> for data in dataloader: + >>> loss = model(data) + >>> optim_wrapper.update_params(loss) + >>> # Enable gradient accumulation. If model is a subclass instance of + >>> # DistributedDataParallel, ``accumulate_grad`` context manager can + >>> # avoid unnecessary gradient synchronize. + >>> for iter, data in enumerate(dataloader): + >>> with optim_wrapper.accumulate_grad( + >>> model, iter, len(dataloader)): + >>> loss = model(data) + >>> optim_wrapper.update_params(loss) + """ + + def __init__(self, + optimizer: Optimizer, + accumulative_iters: int = 1, + clip_grad: Optional[dict] = None): + assert accumulative_iters > 0, ( + 'accumulative_iters at least greater than or equal to 1') + self.accumulative_iters = accumulative_iters + # `max_iters` and `cur_iter` is only valid in gradient accumulative + # mode (`accumulative_iters` > 1). `cur_iter` and `max_iter` will be + # updated in the ``accumulate_grad`` context that is enabled in + # `runner.train_loop`. + self.cur_iter = 0 + self.max_iters = 0 + assert isinstance(optimizer, Optimizer), ( + 'optimizer must be a `torch.optim.Optimizer` instance, but got ' + f'{type(optimizer)}') + self.optimizer = optimizer + + if clip_grad is not None: + # clip_grad_kwargs should not be non-empty dict. + assert isinstance(clip_grad, dict) and clip_grad, ( + 'If `clip_grad_kwargs` is not None, it should be a `dict` ' + 'which is the arguments of `torch.nn.utils.clip_grad`') + self.clip_grad_kwargs = clip_grad + self.logger = MMLogger.get_current_instance() + # Used to update `grad_norm` log message. + self.message_hub = MessageHub.get_current_instance() + self.iter_status_initialized = False + + def update_params(self, loss: torch.Tensor) -> None: + """Update parameters in :attr:`optimizer`. + + Args: + loss (torch.Tensor): A tensor for back propagation. + """ + if self.accumulative_iters == 1: + # update parameters without gradient accumulation. The gradient + # should not be rescaled and `loss_factor=1`. + loss_factor = 1 + else: + # gradient accumulation must be called in the context of + # ``accumulate_grad``. + assert hasattr(self, 'divisible_iters'), ( + 'gradient accumulation must be performed in the context of' + '`OptimWrapper.accumulate_grad`') + # if `self.accumulative_iters > 1`, the gradient needs to be + # rescaled and accumulated. In most cases, `loss_factor` equals to + # `self.accumulative_iters`. 
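+            # (For instance, with `accumulative_iters=4` the loss of each
+            # iteration inside a full accumulation window is divided by 4
+            # before `backward`; the numbers here are illustrative.)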
However `self.max_iters` may not be + # divisible `self.by accumulative_iters`, so the `loss_scale` for + # the last few iterations needs to be recalculated. + if self.cur_iter < self.divisible_iters: + loss_factor = self.accumulative_iters + else: + loss_factor = self.remainder_iters + assert loss_factor > 0, ( + 'loss_factor should be larger than zero! This error could ' + 'happened when gradient accumulation context enabled with an ' + 'error `cur_iter` or `max_iters` please check your loop') + + loss = loss / loss_factor + self.backward(loss) + # Update parameters only if `self.cur_iter` is divisible by + # `self.accumulative_iters` or `self.cur_iter` equals to + # `self.max_iters` + if self._should_update(self.cur_iter, self.max_iters): + self.step() + self.zero_grad() + + def backward(self, loss: torch.Tensor) -> None: + """Perform gradient back propagation. + + Provide unified ``backward`` interface compatible with automatic mixed + precision training. Subclass can overload this method to implement the + required logic. For example, ``torch.cuda.amp`` require some extra + operation on GradScaler during backward process. + + Args: + loss (torch.Tensor): The loss of current iteration. + """ + loss.backward() + + def zero_grad(self) -> None: + """A wrapper of ``Optimizer.zero_grad``. + + Provide unified ``zero_grad`` interface compatible with automatic mixed + precision training. Subclass can overload this method to implement the + required logic. + """ + self.optimizer.zero_grad() + + def step(self) -> None: + """A wrapper of ``Optimizer.step``. + + Provide unified ``step`` interface compatible with automatic mixed + precision training. Subclass can overload this method to implement the + required logic. For example, ``torch.cuda.amp`` require some extra + operation on ``GradScaler`` during step process. + + Clip grad if :attr:`clip_grad_kwargs` is not None, and then update + parameters. + """ + if self.clip_grad_kwargs: + self._clip_grad() + self.optimizer.step() + + def state_dict(self) -> dict: + """A wrapper of ``Optimizer.state_dict``. + + Provide unified ``state_dict`` interface compatible with automatic + mixed precision training. Subclass can overload this method to + implement the required logic. For example, the state dictionary of + GradScaler should be saved when training with ``torch.cuda.amp``. + + Returns: + dict: The state dictionary of :attr:`optimizer`. + """ + return self.optimizer.state_dict() + + def load_state_dict(self, state_dict: dict) -> None: + """A wrapper of ``Optimizer.load_state_dict``. load the state dict of + :attr:`optimizer`. + + Provide unified ``load_state_dict`` interface compatible with automatic + mixed precision training. Subclass can overload this method to + implement the required logic. For example, the state dictionary of + GradScaler should be loaded when training with ``torch.cuda.amp``. + + Args: + state_dict (dict): The state dictionary of :attr:`optimizer`. + """ + self.optimizer.load_state_dict(state_dict) + + @property + def param_groups(self) -> List[dict]: + """A wrapper of ``Optimizer.param_groups``. + + Make OptimizeWrapper compatible with :class:`_ParamScheduler`. + + Returns: + dict: the ``param_groups`` of :attr:`optimizer`. + """ + return self.optimizer.param_groups + + def get_lr(self) -> Dict[str, List[float]]: + """Get the learning rate of the optimizer. + + Provide unified interface to get learning rate of optimizer. + + Returns: + List[float]: Learning rate of the optimizer. 
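+
+        For example (illustrative value), a single param group with
+        ``lr=0.1`` gives:
+
+            >>> optim_wrapper.get_lr()
+            {'lr': [0.1]}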
+ """ + lr = [group['lr'] for group in self.param_groups] + return dict(lr=lr) + + def get_momentum(self) -> Dict[str, List[float]]: + """Get the momentum of the optimizer. + + Provide unified interface to get momentum of optimizer. + + Returns: + List[float]: Momentum of the optimizer. + """ + momentum = [] + for group in self.param_groups: + # Get momentum of SGD. + if 'momentum' in group.keys(): + momentum.append(group['momentum']) + # Get momentum of Adam. + elif 'betas' in group.keys(): + momentum.append(group['betas'][0]) + else: + momentum.append(0) + return dict(momentum=momentum) + + @contextmanager + def accumulate_grad(self, model: nn.Module, cur_iter: int, max_iters: int): + """A Context manager for gradient accumulation and avoiding unnecessary + gradient synchronization during gradient accumulation. + + If model is an instance with ``no_sync`` method (which means + blocking the gradient synchronization) and + ``self.accumulative_iters != 1``. The model will not automatically + synchronize gradients if ``cur_iter`` is divisible by + ``self.accumulative_iters``. Otherwise, this method will enable an + empty context. + + Warnings: + This context manager must be enabled if you want to use + gradient accumulation. + + Args: + model (nn.Module): The training model. + cur_iter (int): Current iteration during training process. + max_iters (int): Maximum training iteration. + """ + assert max_iters > 0, '`max_iters` must be larger than zero' + self.cur_iter = cur_iter + self.max_iters = max_iters + if not self.iter_status_initialized: + self._initilize_iter_status(model) + # During gradient accumulation process, the gradient synchronize + # should only happen before updating parameters. + if (not self._should_update(cur_iter, max_iters) + and hasattr(model, 'no_sync')): + with model.no_sync(): + yield + else: + yield + + @contextmanager + def precision_context(self): + """precision context which enables an empty context by default. + + The subclass used for mixed or low precision training needs to override + this method. + """ + yield + + def _clip_grad(self) -> None: + """Clip the gradients of parameters.""" + params: List[torch.Tensor] = [] + for param_group in self.optimizer.param_groups: + params.extend(param_group['params']) + + params = list( + filter(lambda p: p.requires_grad and p.grad is not None, params)) + if len(params) > 0: + grad_norm = clip_grad.clip_grad_norm_(params, + **self.clip_grad_kwargs) + self.message_hub.update_scalar('train/grad_norm', float(grad_norm)) + + def _initilize_iter_status(self, model: nn.Module) -> None: + """Initialize gradient accumulation related attributes. + + Args: + model (nn.Module): Training model + """ + if self.max_iters % self.accumulative_iters != 0: + self.logger.warning( + 'Resume iter number is not divisible by accumulative_iters in ' + 'GradientCumulativeOptimizerHook, which means the gradient of ' + 'some iters is lost and the result may be influenced slightly.' + ) + + if has_batch_norm(model) and self.accumulative_iters > 1: + self.logger.warning( + 'Gradient accumulative may slightly decrease ' + 'performance because the model has BatchNorm layers.') + residual_iters = self.max_iters - self.cur_iter + # The maximum number of training iteration that is divisible by + # accumulative_iters. 
+ self.divisible_iters = ( + residual_iters // self.accumulative_iters * + self.accumulative_iters) + # Remainder of ``self.max_iters`` divided by ``self.max_iters`` + self.remainder_iters = residual_iters - self.divisible_iters + self.iter_status_initialized = True + + def _should_update(self, cur_iter: int, max_iters: int) -> bool: + """Should optim_wrapper update parameters or synchronized gradient at + current iteration. + + Args: + cur_iter (int): Current iteration of training process. + max_iters (int): Maximum iterations of training process. + + Returns: + bool: Whether to update parameters or synchronized gradient. + """ + return ((cur_iter + 1) % self.accumulative_iters == 0 + or cur_iter + 1 == max_iters) + + def __repr__(self): + wrapper_info = f'Type: {type(self).__name__}\n' \ + f'accumulative_iters: {self.accumulative_iters}\n' \ + f'optimizer: \n' + optimizer_str = repr(self.optimizer) + '\n' + return wrapper_info + optimizer_str diff --git a/mmengine/optim/optimizer/optimizer_wrapper_dict.py b/mmengine/optim/optimizer/optimizer_wrapper_dict.py new file mode 100644 index 00000000..98293e9d --- /dev/null +++ b/mmengine/optim/optimizer/optimizer_wrapper_dict.py @@ -0,0 +1,208 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from contextlib import ExitStack, contextmanager +from typing import Dict, Iterator, List, Tuple + +import torch +import torch.nn as nn + +from .optimizer_wrapper import OptimWrapper + + +class OptimWrapperDict(OptimWrapper): + """A dictionary container of :obj:`OptimWrapper`. + + If runner is training with multiple optimizers, all optimizer wrappers + should be managed by :obj:`OptimWrapperDict` which is built by + ``CustomOptimWrapperConstructor``. ``OptimWrapperDict`` will load and save + the state dictionary of all optimizer wrappers. + + Consider the semantic ambiguity of calling :meth:``update_params``, + :meth:`backward` of all optimizer wrappers, ``OptimWrapperDict`` will not + implement these methods. + + Examples: + >>> import torch.nn as nn + >>> from torch.optim import SGD + >>> from mmengine.optim import OptimWrapperDict, OptimWrapper + >>> model1 = nn.Linear(1, 1) + >>> model2 = nn.Linear(1, 1) + >>> optim_wrapper1 = OptimWrapper(SGD(model1.parameters(), lr=0.1)) + >>> optim_wrapper2 = OptimWrapper(SGD(model2.parameters(), lr=0.1)) + >>> optim_wrapper_dict = OptimWrapperDict(model1=optim_wrapper1, + >>> model2=optim_wrapper2) + + Note: + The optimizer wrapper contained in ``OptimWrapperDict`` can be accessed + in the same way as `dict`. + + Args: + **optim_wrappers: A dictionary of ``OptimWrapper`` instance. + """ + + def __init__(self, **optim_wrapper_dict: OptimWrapper): + first_key = next(iter(optim_wrapper_dict)) + first_optim_wrapper = optim_wrapper_dict[first_key] + assert isinstance(first_optim_wrapper, OptimWrapper), ( + 'Each argument of `OptimWrapperDict` must be an `OptimWrapper ' + 'instance`') + optim_wrapper_class = type(first_optim_wrapper) + for key, value in optim_wrapper_dict.items(): + assert type(value) == optim_wrapper_class, ( + f'All optimizer wrappers should have the same type, but found' + f' {key}: {type(value)} and {first_key}: {optim_wrapper_class}' + ) + if value.accumulative_iters != 1: + warnings.warn( + f'The `accumulative_iters` of {key} is ' + f'{value.accumulative_iters}. OptimWrapperDict ' + 'will not enable any `accumulate_grad` context of its ' + 'optimizer wrappers. 
You should access the corresponding ' + 'optimizer wrapper to enable the context.') + self.optim_wrappers = optim_wrapper_dict + + def update_params(self, loss: torch.Tensor) -> None: + """Update all optimizer wrappers would lead to a duplicate backward + errors, and OptimWrapperDict does not know which optimizer wrapper + should be updated. + + Therefore, this method is not implemented. The optimizer wrapper of + OptimWrapperDict should be accessed and call its `update_params. + """ + raise NotImplementedError( + 'You should access the OptimWrapper of the ' + 'OptimWrapperDict and call its `update_params`') + + def backward(self, loss: torch.Tensor) -> None: + """Since OptimWrapperDict doesn't know which optimizer wrapper's + backward method should be called (``loss_scaler`` maybe different in + different :obj:AmpOptimWrapper), this method is not implemented. + + The optimizer wrapper of OptimWrapperDict should be accessed and call + its `backward. + """ + raise NotImplementedError('You should access the OptimWrapper of the ' + 'OptimWrapperDict and call its `backward`') + + def step(self) -> None: + """Since the backward method is not implemented, the step should not be + implemented either.""" + raise NotImplementedError('You should access the OptimWrapper of the ' + 'OptimWrapperDict and call its `step`') + + def zero_grad(self) -> None: + """Set the gradients of all optimizer wrappers to zero.""" + for optim_wrapper in self.optim_wrappers.values(): + optim_wrapper.zero_grad() + + @contextmanager + def precision_context(self): + optim_wrapper = next(iter(self.optim_wrappers.values())) + with optim_wrapper.precision_context(): + yield + + @contextmanager + def accumulate_grad(self, model: nn.Module, cur_iter: int, max_iters: int): + """Enable ``accumulate_grad`` contexts of all optimizer wrappers. + + Warning: + Consider there is only one ``model`` arguments for all + optimizer wrappers, all optimizer wrappers are working under the + same ``model.no_sync`` context. For example, there is a model + composed of model_a(optimizer_a) and model_b(optimizer_b). + ``OptimWrapperDict.accumulate_grad`` will further + call ``model.no_sync``, which will block the gradient + synchronization of both a and b. If optimizer_a and + optimizer_b have different ``accumulative_iters``, and want to + block the gradient synchronization of model_a and model_b + separately, the model should not implement the ``no_sync`` + method(or enable an empty context). The ``accumulate_grad`` context + should be enabled inside the model by accessing corresponding + optimizer wrapper. + """ + with ExitStack() as stack: + for optim_wrapper in self.optim_wrappers.values(): + stack.enter_context( + optim_wrapper.accumulate_grad(model, cur_iter, max_iters)) + yield + + def load_state_dict(self, state_dict: dict) -> None: + """Load the state dictionary from the ``state_dict``. + + Args: + state_dict (dict): Each key-value pair in `state_dict` represents + the name and the state dictionary of corresponding + :obj:`OptimWrapper`. + """ + for name, _state_dict in state_dict.items(): + assert name in self.optim_wrappers, ( + f'Mismatched `state_dict`! cannot found {name} in ' + 'OptimWrapperDict') + self.optim_wrappers[name].load_state_dict(_state_dict) + + def get_lr(self) -> Dict[str, List[float]]: + """Get the learning rate of all optimizers. + + Returns: + Dict[str, List[float]]: Learning rate of all optimizers. 
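+
+        For example (illustrative names and values):
+
+            >>> optim_wrapper_dict.get_lr()
+            {'generator.lr': [0.01], 'discriminator.lr': [0.02]}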
+ """ + lr_dict = dict() + for name, optim_wrapper in self.optim_wrappers.items(): + lr_dict[f'{name}.lr'] = optim_wrapper.get_lr()['lr'] + return lr_dict + + def get_momentum(self) -> Dict[str, List[float]]: + """Get the momentum of all optimizers. + + Returns: + Dict[str, List[float]]: momentum of all optimizers. + """ + momentum_dict = dict() + for name, optim_wrapper in self.optim_wrappers.items(): + momentum_dict[f'{name}.momentum'] = optim_wrapper.get_momentum( + )['momentum'] + return momentum_dict + + def state_dict(self) -> dict: + """Get the state dictionary of all optimizer wrappers. + + Returns: + dict: Each key-value pair in the dictionary represents the name + and state dictionary of corresponding :obj:`OptimWrapper`. + """ + state_dict = dict() + for name, optim_wrapper in self.optim_wrappers.items(): + state_dict[name] = optim_wrapper.state_dict() + return state_dict + + def items(self) -> Iterator[Tuple[str, OptimWrapper]]: + """A generator to get the name and corresponding + :obj:`OptimWrapper`""" + yield from self.optim_wrappers.items() + + def values(self) -> Iterator[OptimWrapper]: + """A generator to get :obj:`OptimWrapper`""" + yield from self.optim_wrappers.values() + + def keys(self) -> Iterator[str]: + """A generator to get the name of :obj:`OptimWrapper`""" + yield from self.optim_wrappers.keys() + + def __getitem__(self, key: str) -> OptimWrapper: + assert key in self.optim_wrappers, ( + f'Cannot find {key} in OptimWrapperDict, please check ' + 'your optimizer constructor.') + return self.optim_wrappers[key] + + def __contains__(self, key: str) -> bool: + return key in self.optim_wrappers + + def __len__(self) -> int: + return len(self.optim_wrappers) + + def __repr__(self) -> str: + desc = '' + for name, optim_wrapper in self.optim_wrappers.items(): + desc += f'name: {name}\n' + desc += repr(optim_wrapper) + return desc diff --git a/mmengine/optim/scheduler/lr_scheduler.py b/mmengine/optim/scheduler/lr_scheduler.py index 69ba9f28..11e3ee58 100644 --- a/mmengine/optim/scheduler/lr_scheduler.py +++ b/mmengine/optim/scheduler/lr_scheduler.py @@ -22,7 +22,7 @@ class ConstantLR(LRSchedulerMixin, ConstantParamScheduler): changes to the learning rate value from outside this scheduler. Args: - optimizer (Optimizer): Wrapped optimizer. + optimizer (Optimizer or OptimWrapper): Wrapped optimizer. factor (float): The number we multiply learning rate until the milestone. Defaults to 1./3. begin (int): Step at which to start updating the learning rate. @@ -68,7 +68,7 @@ class CosineAnnealingLR(LRSchedulerMixin, CosineAnnealingParamScheduler): only implements the cosine annealing part of SGDR, and not the restarts. Args: - optimizer (Optimizer): Wrapped optimizer. + optimizer (Optimizer or OptimWrapper): Wrapped optimizer. T_max (int): Maximum number of iterations. eta_min (float): Minimum learning rate. Defaults to 0. begin (int): Step at which to start updating the learning rate. @@ -92,7 +92,7 @@ class ExponentialLR(LRSchedulerMixin, ExponentialParamScheduler): """Decays the learning rate of each parameter group by gamma every epoch. Args: - optimizer (Optimizer): Wrapped optimizer. + optimizer (Optimizer or OptimWrapper): Wrapped optimizer. gamma (float): Multiplicative factor of learning rate decay. begin (int): Step at which to start updating the learning rate. Defaults to 0. @@ -116,7 +116,7 @@ class LinearLR(LRSchedulerMixin, LinearParamScheduler): Notice that such decay can happen simultaneously with other changes to the learning rate from outside this scheduler. 
Args: - optimizer (Optimizer): Wrapped optimizer. + optimizer (Optimizer or OptimWrapper): Wrapped optimizer. start_factor (float): The number we multiply learning rate in the first epoch. The multiplication factor changes towards end_factor in the following epochs. Defaults to 1./3. @@ -143,7 +143,7 @@ class MultiStepLR(LRSchedulerMixin, MultiStepParamScheduler): outside this scheduler. Args: - optimizer (Optimizer): Wrapped optimizer. + optimizer (Optimizer or OptimWrapper): Wrapped optimizer. milestones (list): List of epoch indices. Must be increasing. gamma (float): Multiplicative factor of learning rate decay. Defaults to 0.1. @@ -167,7 +167,7 @@ class StepLR(LRSchedulerMixin, StepParamScheduler): other changes to the learning rate from outside this scheduler. Args: - optimizer (Optimizer): Wrapped optimizer. + optimizer (Optimizer or OptimWrapper): Wrapped optimizer. step_size (int): Period of learning rate decay. gamma (float): Multiplicative factor of learning rate decay. Defaults to 0.1. @@ -193,7 +193,7 @@ class PolyLR(LRSchedulerMixin, PolyParamScheduler): parameter value from outside this scheduler. Args: - optimizer (Optimizer): Wrapped optimizer. + optimizer (Optimizer or OptimWrapper): Wrapped optimizer. eta_min (float): Minimum learning rate at the end of scheduling. Defaults to 0. power (float): The power of the polynomial. Defaults to 1.0. diff --git a/mmengine/optim/scheduler/momentum_scheduler.py b/mmengine/optim/scheduler/momentum_scheduler.py index 5c789b2c..59c38af7 100644 --- a/mmengine/optim/scheduler/momentum_scheduler.py +++ b/mmengine/optim/scheduler/momentum_scheduler.py @@ -22,7 +22,8 @@ class ConstantMomentum(MomentumSchedulerMixin, ConstantParamScheduler): momentum value from outside this scheduler. Args: - optimizer (Optimizer): Wrapped optimizer. + optimizer (Optimizer or OptimWrapper): optimizer or Wrapped + optimizer. factor (float): The number we multiply momentum until the milestone. Defaults to 1./3. begin (int): Step at which to start updating the momentum. @@ -69,7 +70,8 @@ class CosineAnnealingMomentum(MomentumSchedulerMixin, only implements the cosine annealing part of SGDR, and not the restarts. Args: - optimizer (Optimizer): Wrapped optimizer. + optimizer (Optimizer or OptimWrapper): optimizer or Wrapped + optimizer. T_max (int): Maximum number of iterations. eta_min (float): Minimum momentum value. Defaults to 0. begin (int): Step at which to start updating the momentum. @@ -93,7 +95,8 @@ class ExponentialMomentum(MomentumSchedulerMixin, ExponentialParamScheduler): """Decays the momentum of each parameter group by gamma every epoch. Args: - optimizer (Optimizer): Wrapped optimizer. + optimizer (Optimizer or OptimWrapper): optimizer or Wrapped + optimizer. gamma (float): Multiplicative factor of momentum value decay. begin (int): Step at which to start updating the momentum. Defaults to 0. @@ -117,7 +120,8 @@ class LinearMomentum(MomentumSchedulerMixin, LinearParamScheduler): Notice that such decay can happen simultaneously with other changes to the momentum from outside this scheduler. Args: - optimizer (Optimizer): Wrapped optimizer. + optimizer (Optimizer or OptimWrapper): optimizer or Wrapped + optimizer. start_factor (float): The number we multiply momentum in the first epoch. The multiplication factor changes towards end_factor in the following epochs. Defaults to 1./3. @@ -144,7 +148,8 @@ class MultiStepMomentum(MomentumSchedulerMixin, MultiStepParamScheduler): scheduler. Args: - optimizer (Optimizer): Wrapped optimizer. 
+ optimizer (Optimizer or OptimWrapper): optimizer or Wrapped + optimizer. milestones (list): List of epoch indices. Must be increasing. gamma (float): Multiplicative factor of momentum value decay. Defaults to 0.1. @@ -168,7 +173,8 @@ class StepMomentum(MomentumSchedulerMixin, StepParamScheduler): to the momentum from outside this scheduler. Args: - optimizer (Optimizer): Wrapped optimizer. + optimizer (Optimizer or OptimWrapper): optimizer or Wrapped + optimizer. step_size (int): Period of momentum value decay. gamma (float): Multiplicative factor of momentum value decay. Defaults to 0.1. @@ -194,7 +200,8 @@ class PolyMomentum(MomentumSchedulerMixin, PolyParamScheduler): parameter value from outside this scheduler. Args: - optimizer (Optimizer): Wrapped optimizer. + optimizer (Optimizer or OptimWrapper): optimizer or Wrapped + optimizer. eta_min (float): Minimum momentum at the end of scheduling. Defaults to 0. power (float): The power of the polynomial. Defaults to 1.0. diff --git a/mmengine/optim/scheduler/param_scheduler.py b/mmengine/optim/scheduler/param_scheduler.py index 4c9a2732..93a5afce 100644 --- a/mmengine/optim/scheduler/param_scheduler.py +++ b/mmengine/optim/scheduler/param_scheduler.py @@ -4,14 +4,17 @@ import warnings import weakref from collections import Counter from functools import wraps -from typing import Callable, List +from typing import Callable, List, Union from torch.optim import Optimizer +from mmengine.optim import OptimWrapper from mmengine.registry import PARAM_SCHEDULERS INF = int(1e9) +OptimizerType = Union[OptimWrapper, Optimizer] + class _ParamScheduler: """Base class for parameter schedulers. @@ -23,7 +26,7 @@ class _ParamScheduler: https://github.com/pytorch/pytorch/blob/master/torch/optim/lr_scheduler.py. Args: - optimizer (Optimizer): Wrapped optimizer. + optimizer (OptimWrapper or Optimizer): Wrapped optimizer. param_name (str): Name of the parameter to be adjusted, such as ``lr``, ``momentum``. begin (int): Step at which to start updating the parameters. @@ -40,7 +43,7 @@ class _ParamScheduler: """ # noqa: E501 def __init__(self, - optimizer: Optimizer, + optimizer: OptimizerType, param_name: str, begin: int = 0, end: int = INF, @@ -49,7 +52,7 @@ class _ParamScheduler: verbose: bool = False): # Attach optimizer - if not isinstance(optimizer, Optimizer): + if not isinstance(optimizer, (Optimizer, OptimWrapper)): raise TypeError('``optimizer`` should be an Optimizer,' 'but got {}'.format(type(optimizer).__name__)) self.optimizer = optimizer @@ -111,8 +114,8 @@ class _ParamScheduler: return wrapper # add counter to optimizer - self.optimizer.step = with_counter(self.optimizer.step) - self.optimizer._global_step = -1 + self.optimizer.step = with_counter(self.optimizer.step) # type: ignore + self.optimizer._global_step = -1 # type: ignore self._global_step = -1 self.verbose = verbose @@ -218,7 +221,7 @@ class StepParamScheduler(_ParamScheduler): other changes to the parameter value from outside this scheduler. Args: - optimizer (Optimizer): Wrapped optimizer. + optimizer (OptimWrapper or Optimizer): Wrapped optimizer. step_size (int): Period of parameter value decay. gamma (float): Multiplicative factor of parameter value decay. Defaults to 0.1. @@ -235,7 +238,7 @@ class StepParamScheduler(_ParamScheduler): """ def __init__(self, - optimizer: Optimizer, + optimizer: OptimizerType, param_name: str, step_size: int, gamma: float = 0.1, @@ -304,7 +307,7 @@ class MultiStepParamScheduler(_ParamScheduler): scheduler. 
Args: - optimizer (Optimizer): Wrapped optimizer. + optimizer (OptimWrapper or Optimizer): Wrapped optimizer. milestones (list): List of epoch indices. Must be increasing. gamma (float): Multiplicative factor of parameter value decay. Defaults to 0.1. @@ -321,7 +324,7 @@ class MultiStepParamScheduler(_ParamScheduler): """ def __init__(self, - optimizer: Optimizer, + optimizer: OptimizerType, param_name: str, milestones: List[int], gamma: float = 0.1, @@ -391,7 +394,8 @@ class ConstantParamScheduler(_ParamScheduler): parameter value from outside this scheduler. Args: - optimizer (Optimizer): Wrapped optimizer. + optimizer (Optimizer or OptimWrapper): optimizer or Wrapped + optimizer. factor (float): The number we multiply parameter value until the milestone. Defaults to 1./3. begin (int): Step at which to start updating the parameters. @@ -407,7 +411,7 @@ class ConstantParamScheduler(_ParamScheduler): """ def __init__(self, - optimizer: Optimizer, + optimizer: OptimizerType, param_name: str, factor: float = 1.0 / 3, begin: int = 0, @@ -477,7 +481,8 @@ class ExponentialParamScheduler(_ParamScheduler): """Decays the parameter value of each parameter group by gamma every epoch. Args: - optimizer (Optimizer): Wrapped optimizer. + optimizer (Optimizer or OptimWrapper): optimizer or Wrapped + optimizer. gamma (float): Multiplicative factor of parameter value decay. begin (int): Step at which to start updating the parameters. Defaults to 0. @@ -492,7 +497,7 @@ class ExponentialParamScheduler(_ParamScheduler): """ def __init__(self, - optimizer: Optimizer, + optimizer: OptimizerType, param_name: str, gamma: float, begin: int = 0, @@ -573,7 +578,8 @@ class CosineAnnealingParamScheduler(_ParamScheduler): only implements the cosine annealing part of SGDR, and not the restarts. Args: - optimizer (Optimizer): Wrapped optimizer. + optimizer (Optimizer or OptimWrapper): optimizer or Wrapped + optimizer. T_max (int): Maximum number of iterations. eta_min (float): Minimum parameter value. Defaults to 0. begin (int): Step at which to start updating the parameters. @@ -592,7 +598,7 @@ class CosineAnnealingParamScheduler(_ParamScheduler): """ def __init__(self, - optimizer: Optimizer, + optimizer: Union[Optimizer, OptimWrapper], param_name: str, T_max: int, eta_min: float = 0., @@ -670,7 +676,8 @@ class LinearParamScheduler(_ParamScheduler): parameter value from outside this scheduler. Args: - optimizer (Optimizer): Wrapped optimizer. + optimizer (Optimizer or OptimWrapper): optimizer or Wrapped + optimizer. start_factor (float): The number we multiply parameter value in the first epoch. The multiplication factor changes towards end_factor in the following epochs. Defaults to 1./3. @@ -689,7 +696,7 @@ class LinearParamScheduler(_ParamScheduler): """ def __init__(self, - optimizer: Optimizer, + optimizer: Union[Optimizer, OptimWrapper], param_name: str, start_factor: float = 1.0 / 3, end_factor: float = 1.0, @@ -765,7 +772,8 @@ class PolyParamScheduler(_ParamScheduler): parameter value from outside this scheduler. Args: - optimizer (Optimizer): Wrapped optimizer. + optimizer (Optimizer or OptimWrapper): optimizer or Wrapped + optimizer. eta_min (float): Minimum parameter value at the end of scheduling. Defaults to 0. power (float): The power of the polynomial. Defaults to 1.0. 
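Because every scheduler above now accepts either a bare ``Optimizer`` or an ``OptimWrapper`` (which re-exposes ``param_groups``), a scheduler can be attached directly to a wrapper. A minimal sketch with illustrative hyperparameters:

    import torch.nn as nn
    from torch.optim import SGD
    from mmengine.optim import MultiStepLR, OptimWrapper

    model = nn.Linear(1, 1)
    optim_wrapper = OptimWrapper(SGD(model.parameters(), lr=0.1))
    scheduler = MultiStepLR(optim_wrapper, milestones=[8, 11], gamma=0.1)

    for epoch in range(12):
        # ... train one epoch with optim_wrapper.update_params(loss) ...
        scheduler.step()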
@@ -782,7 +790,7 @@ class PolyParamScheduler(_ParamScheduler): """ def __init__(self, - optimizer: Optimizer, + optimizer: Union[Optimizer, OptimWrapper], param_name: str, eta_min: float = 0, power: float = 1.0, diff --git a/mmengine/registry/__init__.py b/mmengine/registry/__init__.py index a67532b2..b43fbb2b 100644 --- a/mmengine/registry/__init__.py +++ b/mmengine/registry/__init__.py @@ -2,17 +2,17 @@ from .default_scope import DefaultScope from .registry import Registry, build_from_cfg from .root import (DATA_SAMPLERS, DATASETS, HOOKS, LOG_PROCESSORS, LOOPS, - METRICS, MODEL_WRAPPERS, MODELS, OPTIMIZER_CONSTRUCTORS, - OPTIMIZERS, PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS, - TASK_UTILS, TRANSFORMS, VISBACKENDS, VISUALIZERS, - WEIGHT_INITIALIZERS) + METRICS, MODEL_WRAPPERS, MODELS, OPTIM_WRAPPER_CONSTRUCTORS, + OPTIM_WRAPPERS, OPTIMIZERS, PARAM_SCHEDULERS, + RUNNER_CONSTRUCTORS, RUNNERS, TASK_UTILS, TRANSFORMS, + VISBACKENDS, VISUALIZERS, WEIGHT_INITIALIZERS) from .utils import count_registered_modules, traverse_registry_tree __all__ = [ 'Registry', 'build_from_cfg', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS', 'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS', - 'OPTIMIZERS', 'OPTIMIZER_CONSTRUCTORS', 'TASK_UTILS', 'PARAM_SCHEDULERS', - 'METRICS', 'MODEL_WRAPPERS', 'LOOPS', 'VISBACKENDS', 'VISUALIZERS', - 'LOG_PROCESSORS', 'DefaultScope', 'traverse_registry_tree', - 'count_registered_modules' + 'OPTIMIZERS', 'OPTIM_WRAPPER_CONSTRUCTORS', 'TASK_UTILS', + 'PARAM_SCHEDULERS', 'METRICS', 'MODEL_WRAPPERS', 'OPTIM_WRAPPERS', 'LOOPS', + 'VISBACKENDS', 'VISUALIZERS', 'LOG_PROCESSORS', 'DefaultScope', + 'traverse_registry_tree', 'count_registered_modules' ] diff --git a/mmengine/registry/root.py b/mmengine/registry/root.py index d3a47e47..5860167a 100644 --- a/mmengine/registry/root.py +++ b/mmengine/registry/root.py @@ -32,7 +32,7 @@ WEIGHT_INITIALIZERS = Registry('weight initializer') # mangage all kinds of optimizers like `SGD` and `Adam` OPTIMIZERS = Registry('optimizer') # manage constructors that customize the optimization hyperparameters. 
-OPTIMIZER_CONSTRUCTORS = Registry('optimizer constructor') +OPTIM_WRAPPER_CONSTRUCTORS = Registry('optimizer wrapper constructor') # mangage all kinds of parameter schedulers like `MultiStepLR` PARAM_SCHEDULERS = Registry('parameter scheduler') # manage all kinds of metrics @@ -48,3 +48,6 @@ VISBACKENDS = Registry('vis_backend') # manage logprocessor LOG_PROCESSORS = Registry('log_processor') + +# manage optimizer wrapper +OPTIM_WRAPPERS = Registry('optim_wrapper') diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index 06ecf071..75d0aaf6 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -6,6 +6,7 @@ import random import shutil import time import warnings +from collections import OrderedDict from functools import partial from typing import Callable, Dict, List, Optional, Sequence, Union @@ -25,7 +26,8 @@ from mmengine.evaluator import Evaluator from mmengine.hooks import Hook from mmengine.logging import LogProcessor, MessageHub, MMLogger from mmengine.model import is_model_wrapper -from mmengine.optim import _ParamScheduler, build_optimizer +from mmengine.optim import (OptimWrapper, OptimWrapperDict, _ParamScheduler, + build_optim_wrapper) from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS, MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS, VISUALIZERS, DefaultScope, @@ -44,6 +46,7 @@ from .priority import Priority, get_priority ConfigType = Union[Dict, Config, ConfigDict] ParamSchedulerType = Union[List[_ParamScheduler], Dict[str, List[_ParamScheduler]]] +OptimWrapperType = Union[OptimWrapper, OptimWrapperDict] class Runner: @@ -97,10 +100,14 @@ class Runner: If ``test_cfg`` specified, :attr:`test_dataloader` should also be specified. Defaults to None. See :meth:`build_test_loop` for more details. - optimizer (Optimizer or dict, optional): Computing gradient of model - parameters. If specified, :attr:`train_dataloader` should also be - specified. Defaults to None. - See :meth:`build_optimizer` for examples. + optim_wrapper (OptimWrapper or dict, optional): + Computing gradient of model parameters. If specified, + :attr:`train_dataloader` should also be specified. If automatic + mixed precision or gradient accmulation + training is required. The type of ``optim_wrapper`` should be + AmpOptimizerWrapper. See :meth:`build_optim_wrapper` for + examples. Defaults to None. + param_scheduler (_ParamScheduler or dict or list, optional): Parameter scheduler for updating optimizer parameters. If specified, :attr:`optimizer` should also be specified. 
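To make the ``optim_wrapper`` argument described above concrete, a hedged sketch of a config that enables mixed precision together with gradient accumulation (field values are illustrative):

    optim_wrapper = dict(
        type='AmpOptimWrapper',   # or 'OptimWrapper' for plain FP32 training
        loss_scale='dynamic',
        accumulative_iters=2,
        clip_grad=dict(max_norm=1.0),
        optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001))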
@@ -177,7 +184,8 @@ class Runner: >>> sampler=dict(type='DefaultSampler', shuffle=False), >>> batch_size=1, >>> num_workers=0), - >>> optimizer=dict(type='SGD', lr=0.01), + >>> optim_wrapper=dict(type='OptimizerWrapper', optimizer=dict( + >>> type='SGD', lr=0.01)), >>> param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]), >>> val_evaluator=dict(type='ToyEvaluator'), >>> test_evaluator=dict(type='ToyEvaluator'), @@ -217,7 +225,7 @@ class Runner: train_cfg: Optional[Dict] = None, val_cfg: Optional[Dict] = None, test_cfg: Optional[Dict] = None, - optimizer: Optional[Union[Optimizer, Dict]] = None, + optim_wrapper: Optional[Union[OptimWrapper, Dict]] = None, param_scheduler: Optional[Union[_ParamScheduler, Dict, List]] = None, val_evaluator: Optional[Union[Evaluator, Dict, List]] = None, test_evaluator: Optional[Union[Evaluator, Dict, List]] = None, @@ -249,7 +257,7 @@ class Runner: self.cfg = Config(dict()) # lazy initialization - training_related = [train_dataloader, train_cfg, optimizer] + training_related = [train_dataloader, train_cfg, optim_wrapper] if not (all(item is None for item in training_related) or all(item is not None for item in training_related)): raise ValueError( @@ -257,14 +265,16 @@ class Runner: 'all None or not None, but got ' f'train_dataloader={train_dataloader}, ' f'train_cfg={train_cfg}, ' - f'optimizer={optimizer}.') + f'optim_wrapper={optim_wrapper}.') self._train_dataloader = train_dataloader self._train_loop = train_cfg - self.optimizer = optimizer + + self.optim_wrapper: Optional[Union[OptimWrapper, dict]] + self.optim_wrapper = optim_wrapper # If there is no need to adjust learning rate, momentum or other # parameters of optimizer, param_scheduler can be None - if param_scheduler is not None and self.optimizer is None: + if param_scheduler is not None and self.optim_wrapper is None: raise ValueError( 'param_scheduler should be None when optimizer is None, ' f'but got {param_scheduler}') @@ -400,7 +410,7 @@ class Runner: train_cfg=cfg.get('train_cfg'), val_cfg=cfg.get('val_cfg'), test_cfg=cfg.get('test_cfg'), - optimizer=cfg.get('optimizer'), + optim_wrapper=cfg.get('optim_wrapper'), param_scheduler=cfg.get('param_scheduler'), val_evaluator=cfg.get('val_evaluator'), test_evaluator=cfg.get('test_evaluator'), @@ -803,21 +813,25 @@ class Runner: return model - def build_optimizer( - self, optimizer: Union[Optimizer, Dict] - ) -> Union[Optimizer, Dict[str, Optimizer]]: - """Build an optimizer or multiple optimizers. + def build_optim_wrapper( + self, optim_wrapper: Union[Optimizer, OptimWrapper, Dict] + ) -> Union[OptimWrapper, OptimWrapperDict]: + """Build optimizer wrapper. Args: - optimizer (Optimizer or dict): An Optimizer object or a dict to - build Optimizer objects. If ``optimizer`` is an Optimizer - object, just returns itself. + optim_wrapper (OptimWrapper or dict): An OptimWrapper object or a + dict to build OptimWrapper objects. If ``optim_wrapper`` is an + OptimWrapper, just return an ``OptimizeWrapper`` instance. Examples: >>> # build an optimizer - >>> optim_cfg = dict(type='SGD', lr=0.01) - >>> optimizer = runner.build_optimizer(optim_cfg) - >>> optimizer + >>> optim_wrapper_cfg = dict(type='OptimWrapper', optimizer=dict( + ... type='SGD', lr=0.01)) + >>> optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg) + >>> optim_wrapper + Type: OptimWrapper + accumulative_iters: 1 + optimizer: SGD ( Parameter Group 0 dampening: 0 @@ -828,71 +842,85 @@ class Runner: ) >>> # build multiple optimizers - >>> optim_cfg = dict( - ... 
generator=dict(type='SGD', lr=0.01), - ... discriminator=dict(type='Adam',lr=0.02) + >>> optim_wrapper_cfg = dict( + ... generator=dict(type='OptimWrapper', optimizer=dict( + ... type='SGD', lr=0.01)), + ... discriminator=dict(type='OptimWrapper', optimizer=dict( + ... type='Adam', lr=0.02)) ... # need to customize a multiple optimizer constructor ... constructor='CustomizedMultipleOptimizersConstructor', ...) - >>> optimizer = runner.build_optimizer(optim_cfg) - >>> optimizer - {'generator': SGD ( + >>> optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg) + >>> optim_wrapper + name: generator + Type: OptimWrapper + accumulative_iters: 1 + optimizer: + SGD ( Parameter Group 0 dampening: 0 - lr: 0.01 + lr: 0.01 momentum: 0 nesterov: False weight_decay: 0 - ), - 'discriminator': SGD ( + ) + name: discriminator + Type: OptimWrapper + accumulative_iters: 1 + optimizer: + Adam ( Parameter Group 0 dampening: 0 lr: 0.02 momentum: 0 nesterov: False weight_decay: 0 - )} + ) Important: If you need to build multiple optimizers, you should implement a MultipleOptimizerConstructor which gets parameters passed to - corresponding optimizers. More details about how to customize - OptimizerConstructor can be found at `optimizer-docs`_. + corresponding optimizers and composes the ``OptimWrapperDict``. + More details about how to customize OptimizerConstructor can be + found at `optimizer-docs`_. Returns: - Optimizer or dict[str, Optimizer]: Optimizer build from - ``optimizer``. + OptimWrapper: Optimizer wrapper built from ``optim_wrapper``. .. _optimizer-docs: https://mmengine.readthedocs.io/en/latest/tutorials/optimizer.html """ - if isinstance(optimizer, Optimizer): - return optimizer - elif isinstance(optimizer, dict): - if 'type' not in optimizer and 'constructor' not in optimizer: - for name, optim in optimizer.items(): - if not isinstance(optim, Optimizer): + if isinstance(optim_wrapper, OptimWrapper): + return optim_wrapper + elif isinstance(optim_wrapper, (dict, ConfigDict, Config)): + if 'type' not in optim_wrapper and ('constructor' + not in optim_wrapper): + optim_wrappers = OrderedDict() + for name, optim in optim_wrapper.items(): + if not isinstance(optim, OptimWrapper): raise ValueError( 'each item mush be an optimizer object when "type"' ' and "constructor" are not in optimizer, ' f'but got {name}={optim}') - return optimizer - - return build_optimizer(self.model, optimizer) + optim_wrappers[name] = optim + return OptimWrapperDict(**optim_wrappers) + else: + optim_wrapper = build_optim_wrapper(self.model, optim_wrapper) + return optim_wrapper else: - raise TypeError('optimizer should be an Optimizer object or dict, ' - f'but got {optimizer}') + raise TypeError('optimizer wrapper should be an OptimWrapper ' + f'object or dict, but got {optim_wrapper}') - def _build_param_scheduler(self, scheduler: Union[_ParamScheduler, Dict, - List], - optimizer: Optimizer) -> List[_ParamScheduler]: + def _build_param_scheduler( + self, scheduler: Union[_ParamScheduler, Dict, List], + optim_wrapper: OptimWrapper) -> List[_ParamScheduler]: """Build parameter schedulers for a single optimizer. Args: scheduler (_ParamScheduler or dict or list): A Param Scheduler object or a dict or list of dict to build parameter schedulers. - optimizer (Optimizer): An optimizer object is passed to construnct - ParamScheduler object. + optim_wrapper (OptimWrapper): An optimizer wrapper object is + passed to construct ParamScheduler object.
Returns: list[_ParamScheduler]: List of parameter schedulers build from @@ -922,7 +950,7 @@ class Runner: cls = PARAM_SCHEDULERS.get(_scheduler.pop('type')) param_schedulers.append( cls.build_iter_from_epoch( # type: ignore - optimizer=self.optimizer, + optimizer=optim_wrapper, **_scheduler, epoch_length=len( self.train_dataloader), # type: ignore @@ -931,11 +959,11 @@ class Runner: param_schedulers.append( PARAM_SCHEDULERS.build( _scheduler, - default_args=dict(optimizer=optimizer))) + default_args=dict(optimizer=optim_wrapper))) else: raise TypeError( - '_scheduler should be a _ParamScheduler object or dict, ' - f'but got {_scheduler}') + 'scheduler should be a _ParamScheduler object or dict, ' + f'but got {scheduler}') return param_schedulers @@ -944,9 +972,10 @@ class Runner: List]) -> ParamSchedulerType: """Build parameter schedulers. - ``build_param_scheduler`` should be called after ``build_optimizer`` - because the building logic will change according to the number of - optimizers built by the runner. The cases are as below: + ``build_param_scheduler`` should be called after + ``build_optim_wrapper`` because the building logic will change + according to the number of optimizers built by the runner. + The cases are as below: - Single optimizer: When only one optimizer is built and used in the runner, ``build_param_scheduler`` will return a list of @@ -968,7 +997,8 @@ class Runner: Examples: >>> # build one scheduler >>> optim_cfg = dict(dict(type='SGD', lr=0.01)) - >>> runner.optimizer = runner.build_optimizer(optim_cfg) + >>> runner.optim_wrapper = runner.build_optim_wrapper( + >>> optim_cfg) >>> scheduler_cfg = dict(type='MultiStepLR', milestones=[1, 2]) >>> schedulers = runner.build_param_scheduler(scheduler_cfg) >>> schedulers @@ -998,20 +1028,23 @@ class Runner: https://mmengine.readthedocs.io/en/latest/tutorials/optimizer.html """ param_schedulers: ParamSchedulerType - if isinstance(self.optimizer, Optimizer): + if not isinstance(self.optim_wrapper, OptimWrapperDict): + # Since `OptimWrapperDict` inherits from `OptimWrapper`, + # `isinstance(self.optim_wrapper, OptimWrapper)` cannot tell + # whether `self.optim_wrapper` is an `OptimizerWrapper` or + # `OptimWrapperDict` instance. Therefore, here we simply check + # self.optim_wrapper is not an `OptimWrapperDict` instance and + # then assert it is an OptimWrapper instance. 
+ assert isinstance(self.optim_wrapper, OptimWrapper), ( + '`build_optim_wrapper` should be called before ' + '`build_param_scheduler` because the latter depends ' + 'on the former') param_schedulers = self._build_param_scheduler( - scheduler, self.optimizer) + scheduler, self.optim_wrapper) # type: ignore return param_schedulers else: - assert isinstance(self.optimizer, dict) param_schedulers = dict() - for name, optimizer in self.optimizer.items(): - if not isinstance(optimizer, Optimizer): - raise RuntimeError( - '`build_optimizer` should be called before' - '`build_param_scheduler` because the latter depends ' - 'on the former') - + for name, optimizer in self.optim_wrapper.items(): if isinstance(scheduler, dict) and 'type' not in scheduler: # scheduler is a dict and each item is a ParamScheduler # object or a config to build ParamScheduler objects @@ -1356,7 +1389,7 @@ class Runner: # `build_optimizer` should be called before `build_param_scheduler` # because the latter depends on the former - self.optimizer = self.build_optimizer(self.optimizer) + self.optim_wrapper = self.build_optim_wrapper(self.optim_wrapper) if self.param_schedulers: self.param_schedulers = self.build_param_scheduler( # type: ignore @@ -1418,9 +1451,6 @@ class Runner: fn_name (str): The function name in each hook to be called, such as "before_train_epoch". **kwargs: Keyword arguments passed to hook. - - Raises: - TypeError: if Hook got unexpected arguments. """ for hook in self._hooks: # support adding additional custom hook methods @@ -1645,12 +1675,9 @@ class Runner: # resume optimizer if 'optimizer' in checkpoint and resume_optimizer: - self.optimizer = self.build_optimizer(self.optimizer) - if isinstance(self.optimizer, dict): - for name, optimizer in self.optimizer.items(): - optimizer.load_state_dict(checkpoint['optimizer'][name]) - else: - self.optimizer.load_state_dict(checkpoint['optimizer']) + self.optim_wrapper = self.build_optim_wrapper(self.optim_wrapper) + self.optim_wrapper.load_state_dict( # type: ignore + checkpoint['optimizer']) # resume param scheduler if 'param_schedulers' in checkpoint and resume_param_scheduler: @@ -1771,16 +1798,13 @@ class Runner: } # save optimizer state dict to checkpoint if save_optimizer: - if isinstance(self.optimizer, Optimizer): - checkpoint['optimizer'] = self.optimizer.state_dict() - elif isinstance(self.optimizer, dict): - checkpoint['optimizer'] = dict() - for name, optimizer in self.optimizer.items(): - checkpoint['optimizer'][name] = optimizer.state_dict() + if isinstance(self.optim_wrapper, OptimWrapper): + checkpoint['optimizer'] = self.optim_wrapper.state_dict() else: raise TypeError( - 'self.optimizer should be an optimizer or a dict ' - f'containing optimizer, but got {self.optimizer}') + 'self.optim_wrapper should be an `OptimWrapper` ' + 'or `OptimWrapperDict` instance, but got ' + f'{self.optim_wrapper}') # save param scheduler state dict if save_param_scheduler: diff --git a/mmengine/utils/__init__.py b/mmengine/utils/__init__.py index 01e30f86..690b3b39 100644 --- a/mmengine/utils/__init__.py +++ b/mmengine/utils/__init__.py @@ -2,7 +2,7 @@ from .hub import load_url from .manager import ManagerMeta, ManagerMixin from .misc import (check_prerequisites, concat_list, deprecated_api_warning, - find_latest_checkpoint, has_method, + find_latest_checkpoint, has_batch_norm, has_method, import_modules_from_strings, is_list_of, is_method_overridden, is_seq_of, is_str, is_tuple_of, iter_cast, list_cast, mmcv_full_available, @@ -28,5 +28,5 @@ __all__ = [
'is_method_overridden', 'has_method', 'mmcv_full_available', 'digit_version', 'get_git_hash', 'TORCH_VERSION', 'load_url', 'find_latest_checkpoint', 'ManagerMeta', 'ManagerMixin', - 'set_multi_processing' + 'set_multi_processing', 'has_batch_norm' ] diff --git a/mmengine/utils/misc.py b/mmengine/utils/misc.py index e99c6400..13b8e322 100644 --- a/mmengine/utils/misc.py +++ b/mmengine/utils/misc.py @@ -514,3 +514,20 @@ def find_latest_checkpoint(path: str, suffix: str = 'pth'): latest_path = checkpoint return latest_path + + +def has_batch_norm(model: nn.Module) -> bool: + """Detect whether model has a BatchNormalization layer. + + Args: + model (nn.Module): training model. + + Returns: + bool: whether model has a BatchNormalization layer + """ + if isinstance(model, _BatchNorm): + return True + for m in model.children(): + if has_batch_norm(m): + return True + return False diff --git a/tests/test_hook/test_ema_hook.py b/tests/test_hook/test_ema_hook.py index 995d6f8e..bc02b1e8 100644 --- a/tests/test_hook/test_ema_hook.py +++ b/tests/test_hook/test_ema_hook.py @@ -10,6 +10,7 @@ from torch.utils.data import Dataset from mmengine.hooks import EMAHook from mmengine.model import ExponentialMovingAverage +from mmengine.optim import OptimWrapper from mmengine.registry import DATASETS, MODEL_WRAPPERS from mmengine.runner import Runner @@ -79,7 +80,8 @@ class TestEMAHook(TestCase): num_workers=0), val_evaluator=evaluator, work_dir=self.temp_dir.name, - optimizer=torch.optim.Adam(ToyModel().parameters()), + optim_wrapper=OptimWrapper( + torch.optim.Adam(ToyModel().parameters())), train_cfg=dict(by_epoch=True, max_epochs=2), val_cfg=dict(interval=1), default_hooks=dict(logger=None), diff --git a/tests/test_hook/test_optimizer_hook.py b/tests/test_hook/test_optimizer_hook.py index dc11ee0f..1e665a93 100644 --- a/tests/test_hook/test_optimizer_hook.py +++ b/tests/test_hook/test_optimizer_hook.py @@ -46,8 +46,8 @@ class TestOptimizerHook: x = torch.rand(1, 1, 3, 3) dummy_runner = MagicMock() - dummy_runner.optimizer.zero_grad = Mock(return_value=None) - dummy_runner.optimizer.step = Mock(return_value=None) + dummy_runner.optim_wrapper.zero_grad = Mock(return_value=None) + dummy_runner.optim_wrapper.step = Mock(return_value=None) dummy_runner.model = model dummy_runner.outputs = dict() @@ -82,7 +82,7 @@ class TestOptimizerHook: assert 'conv3.bias' in dummy_runner.logger.msg assert 'conv1.weight' not in dummy_runner.logger.msg assert 'conv1.bias' not in dummy_runner.logger.msg - dummy_runner.optimizer.step.assert_called() + dummy_runner.optim_wrapper.step.assert_called() dummy_runner.outputs['loss'].backward.assert_called() optimizer_hook.clip_grads.assert_called() optimizer_hook.detect_anomalous_parameters.assert_called() @@ -109,7 +109,7 @@ class TestOptimizerHook: optimizer_hook.after_train_iter(dummy_runner, 0) - dummy_runner.optimizer.step.assert_called() + dummy_runner.optim_wrapper.step.assert_called() dummy_runner.outputs['loss'].backward.assert_called() optimizer_hook.clip_grads.assert_not_called() optimizer_hook.detect_anomalous_parameters.assert_not_called() diff --git a/tests/test_hook/test_runtime_info_hook.py b/tests/test_hook/test_runtime_info_hook.py index 29adeab3..2eb651c5 100644 --- a/tests/test_hook/test_runtime_info_hook.py +++ b/tests/test_hook/test_runtime_info_hook.py @@ -2,8 +2,12 @@ from unittest import TestCase from unittest.mock import Mock +import torch.nn as nn +from torch.optim import SGD + from mmengine.hooks import RuntimeInfoHook from mmengine.logging import MessageHub 
+from mmengine.optim import OptimWrapper, OptimWrapperDict class TestRuntimeInfoHook(TestCase): @@ -47,18 +51,31 @@ class TestRuntimeInfoHook(TestCase): self.assertEqual(message_hub.get_info('epoch'), 9) def test_before_train_iter(self): + model = nn.Linear(1, 1) + optim1 = SGD(model.parameters(), lr=0.01) + optim2 = SGD(model.parameters(), lr=0.02) + optim_wrapper1 = OptimWrapper(optim1) + optim_wrapper2 = OptimWrapper(optim2) + optim_wrapper_dict = OptimWrapperDict( + key1=optim_wrapper1, key2=optim_wrapper2) # single optimizer message_hub = MessageHub.get_instance( 'runtime_info_hook_test_before_train_iter') runner = Mock() runner.iter = 9 - runner.optimizer.param_groups = [{'lr': 0.01}] + runner.optim_wrapper = optim_wrapper1 runner.message_hub = message_hub hook = RuntimeInfoHook() hook.before_train_iter(runner, batch_idx=2, data_batch=None) self.assertEqual(message_hub.get_info('iter'), 9) self.assertEqual(message_hub.get_scalar('train/lr').current(), 0.01) + with self.assertRaisesRegex(AssertionError, + 'runner.optim_wrapper.get_lr()'): + runner.optim_wrapper = Mock() + runner.optim_wrapper.get_lr = Mock(return_value='error type') + hook.before_train_iter(runner, batch_idx=2, data_batch=None) + # multiple optimizers message_hub = MessageHub.get_instance( 'runtime_info_hook_test_before_train_iter') @@ -68,8 +85,8 @@ class TestRuntimeInfoHook(TestCase): optimizer1.param_groups = [{'lr': 0.01}] optimizer2 = Mock() optimizer2.param_groups = [{'lr': 0.02}] - runner.optimizer = dict(key1=optimizer1, key2=optimizer2) runner.message_hub = message_hub + runner.optim_wrapper = optim_wrapper_dict hook = RuntimeInfoHook() hook.before_train_iter(runner, batch_idx=2, data_batch=None) self.assertEqual(message_hub.get_info('iter'), 9) diff --git a/tests/test_model/test_wrappers/test_data_parallel.py b/tests/test_model/test_wrappers/test_data_parallel.py index 8bad4f6c..c1f96ac4 100644 --- a/tests/test_model/test_wrappers/test_data_parallel.py +++ b/tests/test_model/test_wrappers/test_data_parallel.py @@ -37,6 +37,12 @@ def test_is_model_wrapper(): if hasattr(torch.distributed, '_verify_model_across_ranks'): torch.distributed._verify_model_across_ranks = mock + # _verify_model_across_ranks is added in torch1.11.0 so we should check + # whether _verify_params_across_processes is the member of + # torch.distributed before mocking + if hasattr(torch.distributed, '_verify_params_across_processes'): + torch.distributed._verify_params_across_processes = mock + model = Model() assert not is_model_wrapper(model) diff --git a/tests/test_optim/test_optimizer/test_optimizer.py b/tests/test_optim/test_optimizer/test_optimizer.py index 2115c20c..c76b10dc 100644 --- a/tests/test_optim/test_optimizer/test_optimizer.py +++ b/tests/test_optim/test_optimizer/test_optimizer.py @@ -6,8 +6,9 @@ from unittest.mock import MagicMock import torch import torch.nn as nn -from mmengine.optim import (OPTIMIZER_CONSTRUCTORS, OPTIMIZERS, - DefaultOptimizerConstructor, build_optimizer) +from mmengine.optim import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS, + DefaultOptimWrapperConstructor, OptimWrapper, + build_optim_wrapper) from mmengine.optim.optimizer.builder import TORCH_OPTIMIZERS from mmengine.registry import build_from_cfg from mmengine.utils import mmcv_full_available @@ -201,30 +202,52 @@ class TestBuilder(TestCase): def test_build_optimizer(self): # test build function without ``constructor`` and ``paramwise_cfg`` - optimizer_cfg = dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - 
momentum=self.momentum) - optimizer = build_optimizer(self.model, optimizer_cfg) - self._check_default_optimizer(optimizer, self.model) + optim_wrapper_cfg = dict( + type='OptimWrapper', + optimizer=dict( + type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) + optim_wrapper = build_optim_wrapper(self.model, optim_wrapper_cfg) + self._check_default_optimizer(optim_wrapper.optimizer, self.model) + + # test build optimizer without type in optim_wrapper_cfg + optim_wrapper_cfg = dict( + optimizer=dict( + type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) + optim_wrapper = build_optim_wrapper(self.model, optim_wrapper_cfg) + self.assertIsInstance(optim_wrapper, OptimWrapper) + self._check_default_optimizer(optim_wrapper.optimizer, self.model) # test build function with invalid ``constructor`` with self.assertRaises(KeyError): - optimizer_cfg['constructor'] = 'INVALID_CONSTRUCTOR' - build_optimizer(self.model, optimizer_cfg) + optim_wrapper_cfg['constructor'] = 'INVALID_CONSTRUCTOR' + build_optim_wrapper(self.model, optim_wrapper_cfg) # test build function with invalid ``paramwise_cfg`` with self.assertRaises(KeyError): - optimizer_cfg['paramwise_cfg'] = dict(invalid_mult=1) - build_optimizer(self.model, optimizer_cfg) + optim_wrapper_cfg['paramwise_cfg'] = dict(invalid_mult=1) + build_optim_wrapper(self.model, optim_wrapper_cfg) + + optim_wrapper_cfg.pop('optimizer') + optim_wrapper_cfg.pop('constructor') + optim_wrapper_cfg.pop('paramwise_cfg') + self.assertRaisesRegex( + AssertionError, '`optim_wrapper_cfg` must contain', + lambda: build_optim_wrapper(self.model, optim_wrapper_cfg)) def test_build_default_optimizer_constructor(self): - optimizer_cfg = dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum) + optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict( + type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) paramwise_cfg = dict( bias_lr_mult=2, bias_decay_mult=0.5, @@ -232,22 +255,26 @@ class TestBuilder(TestCase): dwconv_decay_mult=0.1, dcn_offset_lr_mult=0.1) optim_constructor_cfg = dict( - type='DefaultOptimizerConstructor', - optimizer_cfg=optimizer_cfg, + type='DefaultOptimWrapperConstructor', + optim_wrapper_cfg=optim_wrapper, paramwise_cfg=paramwise_cfg) - optim_constructor = OPTIMIZER_CONSTRUCTORS.build(optim_constructor_cfg) - optimizer = optim_constructor(self.model) - self._check_sgd_optimizer(optimizer, self.model, **paramwise_cfg) + optim_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build( + optim_constructor_cfg) + optim_wrapper = optim_constructor(self.model) + self._check_sgd_optimizer(optim_wrapper.optimizer, self.model, + **paramwise_cfg) def test_build_custom_optimizer_constructor(self): - optimizer_cfg = dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum) + optim_wrapper_cfg = dict( + type='OptimWrapper', + optimizer=dict( + type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) - @OPTIMIZER_CONSTRUCTORS.register_module() - class MyOptimizerConstructor(DefaultOptimizerConstructor): + @OPTIM_WRAPPER_CONSTRUCTORS.register_module() + class MyOptimizerConstructor(DefaultOptimWrapperConstructor): def __call__(self, model): if hasattr(model, 'module'): @@ -268,9 +295,10 @@ class TestBuilder(TestCase): paramwise_cfg = dict(conv1_lr_mult=5) optim_constructor_cfg = dict( type='MyOptimizerConstructor', - optimizer_cfg=optimizer_cfg, + 
optim_wrapper_cfg=optim_wrapper_cfg, paramwise_cfg=paramwise_cfg) - optim_constructor = OPTIMIZER_CONSTRUCTORS.build(optim_constructor_cfg) + optim_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build( + optim_constructor_cfg) optimizer = optim_constructor(self.model) param_groups = optimizer.param_groups @@ -291,153 +319,182 @@ class TestBuilder(TestCase): with self.assertRaises(TypeError): # optimizer_cfg must be a dict optimizer_cfg = [] - optim_constructor = DefaultOptimizerConstructor(optimizer_cfg) + optim_constructor = DefaultOptimWrapperConstructor(optimizer_cfg) optim_constructor(self.model) with self.assertRaises(TypeError): # paramwise_cfg must be a dict or None - optimizer_cfg = dict(lr=0.0001) + optim_wrapper_cfg = dict( + type='OptimWrapper', + optimizer=dict(lr=0.0001, weight_decay=None)) paramwise_cfg = ['error'] - optim_constructor = DefaultOptimizerConstructor( - optimizer_cfg, paramwise_cfg) + optim_constructor = DefaultOptimWrapperConstructor( + optim_wrapper_cfg, paramwise_cfg) optim_constructor(self.model) with self.assertRaises(ValueError): # bias_decay_mult/norm_decay_mult is specified but weight_decay # is None - optimizer_cfg = dict(lr=0.0001, weight_decay=None) + optim_wrapper_cfg = dict( + type='OptimWrapper', + optimizer=dict(lr=0.0001, weight_decay=None)) paramwise_cfg = dict(bias_decay_mult=1, norm_decay_mult=1) - optim_constructor = DefaultOptimizerConstructor( - optimizer_cfg, paramwise_cfg) + optim_constructor = DefaultOptimWrapperConstructor( + optim_wrapper_cfg, paramwise_cfg) optim_constructor(self.model) # basic config with ExampleModel optimizer_cfg = dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum) - optim_constructor = DefaultOptimizerConstructor(optimizer_cfg) - optimizer = optim_constructor(self.model) - self._check_default_optimizer(optimizer, self.model) + type='OptimWrapper', + optimizer=dict( + type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) + optim_constructor = DefaultOptimWrapperConstructor(optimizer_cfg) + optim_wrapper = optim_constructor(self.model) + self._check_default_optimizer(optim_wrapper.optimizer, self.model) def test_default_optimizer_constructor_with_model_wrapper(self): # basic config with pseudo data parallel model = PseudoDataParallel() - optimizer_cfg = dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum) + optim_wrapper_cfg = dict( + type='OptimWrapper', + optimizer=dict( + type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) paramwise_cfg = None - optim_constructor = DefaultOptimizerConstructor(optimizer_cfg) - optimizer = optim_constructor(model) - self._check_default_optimizer(optimizer, model, prefix='module.') + optim_constructor = DefaultOptimWrapperConstructor(optim_wrapper_cfg) + optim_wrapper = optim_constructor(model) + self._check_default_optimizer( + optim_wrapper.optimizer, model, prefix='module.') # paramwise_cfg with pseudo data parallel model = PseudoDataParallel() - optimizer_cfg = dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum) + optim_wrapper_cfg = dict( + type='OptimWrapper', + optimizer=dict( + type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) paramwise_cfg = dict( bias_lr_mult=2, bias_decay_mult=0.5, norm_decay_mult=0, dwconv_decay_mult=0.1, dcn_offset_lr_mult=0.1) - optim_constructor = DefaultOptimizerConstructor( - optimizer_cfg, paramwise_cfg) - optimizer = 
optim_constructor(model) + optim_constructor = DefaultOptimWrapperConstructor( + optim_wrapper_cfg, paramwise_cfg) + optim_wrapper = optim_constructor(model) self._check_sgd_optimizer( - optimizer, model, prefix='module.', **paramwise_cfg) + optim_wrapper.optimizer, model, prefix='module.', **paramwise_cfg) # basic config with DataParallel if torch.cuda.is_available(): model = torch.nn.DataParallel(ExampleModel()) - optimizer_cfg = dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum) + optim_wrapper_cfg = dict( + type='OptimWrapper', + optimizer=dict( + type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) paramwise_cfg = None - optim_constructor = DefaultOptimizerConstructor(optimizer_cfg) - optimizer = optim_constructor(model) - self._check_default_optimizer(optimizer, model, prefix='module.') + optim_constructor = DefaultOptimWrapperConstructor( + optim_wrapper_cfg) + optim_wrapper = optim_constructor(model) + self._check_default_optimizer( + optim_wrapper.optimizer, model, prefix='module.') # paramwise_cfg with DataParallel if torch.cuda.is_available(): model = torch.nn.DataParallel(self.model) - optimizer_cfg = dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum) + optim_wrapper_cfg = dict( + type='OptimWrapper', + optimizer=dict( + type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) paramwise_cfg = dict( bias_lr_mult=2, bias_decay_mult=0.5, norm_decay_mult=0, dwconv_decay_mult=0.1, dcn_offset_lr_mult=0.1) - optim_constructor = DefaultOptimizerConstructor( - optimizer_cfg, paramwise_cfg) - optimizer = optim_constructor(model) + optim_constructor = DefaultOptimWrapperConstructor( + optim_wrapper_cfg, paramwise_cfg) + optim_wrapper = optim_constructor(model) self._check_sgd_optimizer( - optimizer, model, prefix='module.', **paramwise_cfg) + optim_wrapper.optimizer, + model, + prefix='module.', + **paramwise_cfg) def test_default_optimizer_constructor_with_empty_paramwise_cfg(self): # Empty paramwise_cfg with ExampleModel - optimizer_cfg = dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum) + optim_wrapper_cfg = dict( + type='OptimWrapper', + optimizer=dict( + type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) paramwise_cfg = dict() - optim_constructor = DefaultOptimizerConstructor( - optimizer_cfg, paramwise_cfg) - optimizer = optim_constructor(self.model) - self._check_default_optimizer(optimizer, self.model) + optim_constructor = DefaultOptimWrapperConstructor( + optim_wrapper_cfg, paramwise_cfg) + optim_wrapper = optim_constructor(self.model) + self._check_default_optimizer(optim_wrapper.optimizer, self.model) # Empty paramwise_cfg with ExampleModel and no grad model = ExampleModel() for param in model.parameters(): param.requires_grad = False - optimizer_cfg = dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum) + optim_wrapper_cfg = dict( + type='OptimWrapper', + optimizer=dict( + type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) paramwise_cfg = dict() - optim_constructor = DefaultOptimizerConstructor(optimizer_cfg) - optimizer = optim_constructor(model) - self._check_default_optimizer(optimizer, model) + optim_constructor = DefaultOptimWrapperConstructor(optim_wrapper_cfg) + optim_wrapper = optim_constructor(model) + self._check_default_optimizer(optim_wrapper.optimizer, model) 
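The constructor pattern these tests exercise, condensed into a standalone sketch; ``ToyNet`` is a toy stand-in for the ``ExampleModel`` fixture, and the paramwise keys are the ones checked above:

import torch.nn as nn
from mmengine.optim import DefaultOptimWrapperConstructor

class ToyNet(nn.Module):  # hypothetical model, not part of the patch
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.bn = nn.BatchNorm2d(8)

constructor = DefaultOptimWrapperConstructor(
    optim_wrapper_cfg=dict(
        type='OptimWrapper',
        optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=1e-4)),
    paramwise_cfg=dict(bias_lr_mult=2, norm_decay_mult=0))
optim_wrapper = constructor(ToyNet())  # returns an OptimWrapper
sgd = optim_wrapper.optimizer          # the wrapped torch.optim.SGD the _check_* helpers inspect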
def test_default_optimizer_constructor_with_paramwise_cfg(self): # paramwise_cfg with ExampleModel - optimizer_cfg = dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum) + optim_wrapper_cfg = dict( + type='OptimWrapper', + optimizer=dict( + type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) paramwise_cfg = dict( bias_lr_mult=2, bias_decay_mult=0.5, norm_decay_mult=0, dwconv_decay_mult=0.1, dcn_offset_lr_mult=0.1) - optim_constructor = DefaultOptimizerConstructor( - optimizer_cfg, paramwise_cfg) - optimizer = optim_constructor(self.model) - self._check_sgd_optimizer(optimizer, self.model, **paramwise_cfg) + optim_constructor = DefaultOptimWrapperConstructor( + optim_wrapper_cfg, paramwise_cfg) + optim_wrapper = optim_constructor(self.model) + self._check_sgd_optimizer(optim_wrapper.optimizer, self.model, + **paramwise_cfg) def test_default_optimizer_constructor_no_grad(self): # paramwise_cfg with ExampleModel and no grad - optimizer_cfg = dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum) + optim_wrapper_cfg = dict( + type='OptimWrapper', + optimizer=dict( + type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) paramwise_cfg = dict( bias_lr_mult=2, bias_decay_mult=0.5, @@ -447,11 +504,12 @@ class TestBuilder(TestCase): for param in self.model.parameters(): param.requires_grad = False - optim_constructor = DefaultOptimizerConstructor( - optimizer_cfg, paramwise_cfg) - optimizer = optim_constructor(self.model) + optim_constructor = DefaultOptimWrapperConstructor( + optim_wrapper_cfg, paramwise_cfg) + optim_wrapper = optim_constructor(self.model) + optimizer = optim_wrapper.optimizer param_groups = optimizer.param_groups - assert isinstance(optimizer, torch.optim.SGD) + assert isinstance(optim_wrapper.optimizer, torch.optim.SGD) assert optimizer.defaults['lr'] == self.base_lr assert optimizer.defaults['momentum'] == self.momentum assert optimizer.defaults['weight_decay'] == self.base_wd @@ -465,11 +523,13 @@ class TestBuilder(TestCase): def test_default_optimizer_constructor_bypass_duplicate(self): # paramwise_cfg with bypass_duplicate option model = ExampleDuplicateModel() - optimizer_cfg = dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum) + optim_wrapper_cfg = dict( + type='OptimWrapper', + optimizer=dict( + type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) paramwise_cfg = dict( bias_lr_mult=2, bias_decay_mult=0.5, @@ -479,8 +539,8 @@ class TestBuilder(TestCase): with self.assertRaisesRegex( ValueError, 'some parameters appear in more than one parameter group'): - optim_constructor = DefaultOptimizerConstructor( - optimizer_cfg, paramwise_cfg) + optim_constructor = DefaultOptimWrapperConstructor( + optim_wrapper_cfg, paramwise_cfg) optim_constructor(model) paramwise_cfg = dict( @@ -490,27 +550,31 @@ class TestBuilder(TestCase): dwconv_decay_mult=0.1, dcn_offset_lr_mult=0.1, bypass_duplicate=True) - optim_constructor = DefaultOptimizerConstructor( - optimizer_cfg, paramwise_cfg) + optim_constructor = DefaultOptimWrapperConstructor( + optim_wrapper_cfg, paramwise_cfg) self.assertWarnsRegex( Warning, 'conv3.0 is duplicate. 
It is skipped since bypass_duplicate=True', lambda: optim_constructor(model)) - optimizer = optim_constructor(model) + optim_wrapper = optim_constructor(model) model_parameters = list(model.parameters()) num_params = 14 if MMCV_FULL_AVAILABLE else 11 - assert len( - optimizer.param_groups) == len(model_parameters) == num_params - self._check_sgd_optimizer(optimizer, model, **paramwise_cfg) + assert len(optim_wrapper.optimizer.param_groups) == len( + model_parameters) == num_params + self._check_sgd_optimizer(optim_wrapper.optimizer, model, + **paramwise_cfg) def test_default_optimizer_constructor_custom_key(self): - # test DefaultOptimizerConstructor with custom_keys and ExampleModel - optimizer_cfg = dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum) + # test DefaultOptimWrapperConstructor with custom_keys and + # ExampleModel + optim_wrapper_cfg = dict( + type='OptimWrapper', + optimizer=dict( + type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) paramwise_cfg = dict( custom_keys={ 'param1': dict(lr_mult=10), @@ -523,23 +587,24 @@ class TestBuilder(TestCase): with self.assertRaises(TypeError): # custom_keys should be a dict paramwise_cfg_ = dict(custom_keys=[0.1, 0.0001]) - optim_constructor = DefaultOptimizerConstructor( - optimizer_cfg, paramwise_cfg_) + optim_constructor = DefaultOptimWrapperConstructor( + optim_wrapper_cfg, paramwise_cfg_) optimizer = optim_constructor(self.model) with self.assertRaises(ValueError): # if 'decay_mult' is specified in custom_keys, weight_decay # should be specified - optimizer_cfg_ = dict(type='SGD', lr=0.01) + optim_wrapper_cfg_ = dict( + type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)) paramwise_cfg_ = dict( custom_keys={'.backbone': dict(decay_mult=0.5)}) - optim_constructor = DefaultOptimizerConstructor( - optimizer_cfg_, paramwise_cfg_) - optimizer = optim_constructor(self.model) + optim_constructor = DefaultOptimWrapperConstructor( + optim_wrapper_cfg_, paramwise_cfg_) + optim_constructor(self.model) - optim_constructor = DefaultOptimizerConstructor( - optimizer_cfg, paramwise_cfg) - optimizer = optim_constructor(self.model) + optim_constructor = DefaultOptimWrapperConstructor( + optim_wrapper_cfg, paramwise_cfg) + optimizer = optim_constructor(self.model).optimizer # check optimizer type and default config assert isinstance(optimizer, torch.optim.SGD) assert optimizer.defaults['lr'] == self.base_lr @@ -598,14 +663,17 @@ class TestBuilder(TestCase): assert param_groups[i][setting] == settings[ setting], f'{name} {setting}' - # test DefaultOptimizerConstructor with custom_keys and ExampleModel 2 - optimizer_cfg = dict( - type='SGD', lr=self.base_lr, momentum=self.momentum) + # test DefaultOptimWrapperConstructor with custom_keys and + # ExampleModel 2 + optim_wrapper_cfg = dict( + type='OptimWrapper', + optimizer=dict( + type='SGD', lr=self.base_lr, momentum=self.momentum)) paramwise_cfg = dict(custom_keys={'param1': dict(lr_mult=10)}) - optim_constructor = DefaultOptimizerConstructor( - optimizer_cfg, paramwise_cfg) - optimizer = optim_constructor(self.model) + optim_constructor = DefaultOptimWrapperConstructor( + optim_wrapper_cfg, paramwise_cfg) + optimizer = optim_constructor(self.model).optimizer # check optimizer type and default config assert isinstance(optimizer, torch.optim.SGD) assert optimizer.defaults['lr'] == self.base_lr diff --git a/tests/test_optim/test_optimizer/test_optimizer_wrapper.py 
b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py new file mode 100644 index 00000000..22a60a20 --- /dev/null +++ b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py @@ -0,0 +1,384 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import unittest +from unittest import TestCase +from unittest.mock import MagicMock + +import torch +import torch.distributed as torch_dist +import torch.nn as nn +from torch.cuda.amp import GradScaler +from torch.nn.parallel.distributed import DistributedDataParallel +from torch.optim import SGD, Adam, Optimizer + +from mmengine import MessageHub, MMLogger +from mmengine.dist import all_gather +from mmengine.optim import AmpOptimWrapper, OptimWrapper +from mmengine.testing import assert_allclose +from mmengine.testing._internal import MultiProcessTestCase +from mmengine.utils import TORCH_VERSION, digit_version + + +class ToyModel(nn.Module): + + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(1, 1, 1) + self.conv2 = nn.Conv2d(1, 1, 1) + self.conv3 = nn.Conv2d(1, 1, 1) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + return x + + +class ToyModel2(nn.Module): + + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(1, 1, 1) + + def forward(self, x): + x = self.conv(x) + return x + + +class TestOptimWrapper(MultiProcessTestCase): + # Test `OptimWrapper.accumulate_grad` will block the gradient + # synchronization when using gradient accumulation strategy in distributed + # data parallel training. + def setUp(self) -> None: + super().setUp() + self._spawn_processes() + + def run_test(self, test_name: str, parent_pipe) -> None: + self.model = ToyModel() + self.optimizer = SGD(self.model.parameters(), lr=0.1) + self.logger = MMLogger.get_instance('test_optim_wrapper') + self.message_hub = MessageHub.get_instance('test_optim_wrapper_init') + super().run_test(test_name, parent_pipe) + + def test_init(self): + optim_wrapper = OptimWrapper(self.optimizer) + self.assertEqual(optim_wrapper.optimizer, self.optimizer) + self.assertIsNone(optim_wrapper.clip_grad_kwargs) + self.assertEqual(optim_wrapper.accumulative_iters, 1) + self.assertIs(optim_wrapper.logger, self.logger) + self.assertIs(optim_wrapper.message_hub, self.message_hub) + + with self.assertRaisesRegex(AssertionError, + 'If `clip_grad_kwargs` is not None'): + OptimWrapper(self.optimizer, clip_grad=[]) + + def test_update_params(self): + # Test update params every iteration. + optim_wrapper = OptimWrapper(self.optimizer, accumulative_iters=1) + self._mock_method(optim_wrapper) + loss = torch.tensor(1) + optim_wrapper.update_params(loss) + optim_wrapper.backward.assert_called_with(torch.tensor(1)) + optim_wrapper.step.assert_called_with() + optim_wrapper.zero_grad.assert_called_with() + + with optim_wrapper.accumulate_grad(self.model, 2, 100): + optim_wrapper.update_params(torch.tensor(1)) + optim_wrapper.backward.assert_called_with(torch.tensor(1)) + optim_wrapper.step.assert_called_with() + optim_wrapper.zero_grad.assert_called_with() + + # It will raise an error if `accumulative_iters > 1` and + # `accumulate_grad` is not enabled. + optim_wrapper = OptimWrapper(self.optimizer, accumulative_iters=3) + self._mock_method(optim_wrapper) + with self.assertRaisesRegex(AssertionError, + 'gradient accumulation must be'): + optim_wrapper.update_params(loss) + + # `iter=0`, Call `optimizer_step` first time. 
+ with optim_wrapper.accumulate_grad( + self.model, cur_iter=0, max_iters=100): + loss = torch.tensor(1) + optim_wrapper.update_params(loss) + optim_wrapper.backward.assert_called_with(torch.tensor(1) / 3) + optim_wrapper.step.assert_not_called() + optim_wrapper.zero_grad.assert_not_called() + + # `iter=2`, Call `optimizer_step` first time. + with optim_wrapper.accumulate_grad( + self.model, cur_iter=2, max_iters=100): + optim_wrapper.update_params(loss) + optim_wrapper.step.assert_called() + optim_wrapper.zero_grad.assert_called() + self._mock_method(optim_wrapper) + # Test end of training. + with optim_wrapper.accumulate_grad( + self.model, cur_iter=99, max_iters=100): + optim_wrapper.update_params(loss) + optim_wrapper.step.assert_called() + optim_wrapper.zero_grad.assert_called() + optim_wrapper.backward.assert_called_with(1) + + # If ``accumulative_iters > 1``, call ``update_params`` with + # non-accumulate_grad context will raise an Assertion error + optim_wrapper = OptimWrapper(self.optimizer, accumulative_iters=1) + optim_wrapper.accumulative_iters = 2 + with self.assertRaisesRegex(AssertionError, + 'gradient accumulation must be performed'): + optim_wrapper.update_params(loss) + + def test_initilize_iter_status(self): + optim_wrapper = OptimWrapper(self.optimizer, accumulative_iters=3) + optim_wrapper._initilize_iter_status(self.model) + self.assertEqual(optim_wrapper.divisible_iters, 0) + self.assertEqual(optim_wrapper.remainder_iters, 0) + + # Indivisible cur_iter will output warning. + optim_wrapper = OptimWrapper(self.optimizer, accumulative_iters=3) + optim_wrapper.cur_iter = 0 + optim_wrapper.max_iters = 100 + with self.assertLogs(self.logger) as cm: + optim_wrapper._initilize_iter_status(self.model) + self.assertEqual(len(cm.output), 1) + self.assertRegex(cm.records[0].msg, 'Resume iter number is not') + + # Model with batch norm will output warning. 
+ optim_wrapper = OptimWrapper(self.optimizer, accumulative_iters=3) + optim_wrapper.cur_iter = 0 + optim_wrapper.max_iters = 99 + model = nn.BatchNorm2d(1) + with self.assertLogs(self.logger) as cm: + optim_wrapper._initilize_iter_status(model) + self.assertEqual(len(cm.output), 1) + self.assertRegex(cm.records[0].msg, 'Gradient accumulative') + + def test_ger_lr(self): + model = ToyModel() + optim = SGD(model.parameters(), lr=0.1) + optim_wrapper = OptimWrapper(optim) + self.assertEqual(optim_wrapper.get_lr(), dict(lr=[0.1])) + + def test_get_momentum(self): + # Get momentum from SGD + model = ToyModel() + optim = SGD(model.parameters(), lr=0., momentum=0.8) + optim_wrapper = OptimWrapper(optim) + self.assertEqual(optim_wrapper.get_momentum(), dict(momentum=[0.8])) + # Get momentum from Adam + optim = Adam(model.parameters(), lr=0., betas=(0.9, 0.9)) + optim_wrapper = OptimWrapper(optim) + self.assertEqual(optim_wrapper.get_momentum(), dict(momentum=[0.9])) + + def test_backward(self): + loss = MagicMock() + optim_wrapper = OptimWrapper(self.optimizer) + optim_wrapper.backward(loss) + loss.backward.assert_called() + + def test_zero_grad(self): + optimizer = MagicMock(spec=Optimizer) + optim_wrapper = OptimWrapper(optimizer) + optim_wrapper.zero_grad() + optimizer.zero_grad.assert_called() + + def test_step(self): + optimizer = MagicMock(spec=Optimizer) + optim_wrapper = OptimWrapper(optimizer) + optim_wrapper.step() + optimizer.step.assert_called() + + def test_clip_grads(self): + optim_wrapper = OptimWrapper( + self.optimizer, clip_grad=dict(max_norm=35)) + loss = self.model(torch.Tensor(1, 1, 1, 1)) + loss.backward() + optim_wrapper._clip_grad() + log_scalars = self.message_hub.log_scalars + self.assertIn('train/grad_norm', log_scalars) + + def test_state_dict(self): + optim_wrapper = OptimWrapper(self.optimizer) + self.assertEqual(optim_wrapper.state_dict(), + self.optimizer.state_dict()) + + def test_load_state_dict(self): + optim_wrapper = OptimWrapper(self.optimizer) + model = ToyModel() + optimizer = SGD(model.parameters(), lr=0.1) + optim_wrapper.load_state_dict(optimizer.state_dict()) + + self.assertEqual(optim_wrapper.state_dict(), optimizer.state_dict()) + + def test_param_groups(self): + optim_wrapper = OptimWrapper(self.optimizer) + self.assertEqual(optim_wrapper.param_groups, + self.optimizer.param_groups) + + def test_accumulate_grad(self): + self._init_dist_env(self.rank, self.world_size) + model = ToyModel2() + ddp_model = DistributedDataParallel(model) + optimizer = SGD(ddp_model.parameters(), lr=0.01) + optim_wrapper = OptimWrapper(optimizer, accumulative_iters=1) + optim_wrapper.zero_grad() + with optim_wrapper.accumulate_grad(ddp_model, 0, 100): + # Automatically sync grads if `accumulative_iters` = 1 + inputs = torch.randn(1, 1, 1, 1) * self.rank + ddp_model(inputs).sum().backward() + grad = model.conv.weight.grad + all_grads = all_gather(grad) + assert_allclose(all_grads[0], all_grads[1]) + + # Do not sync grads when `optim_wrapper.cur_iter` cannot be + # divided by `optim_wrapper.accumulative_iters` + optim_wrapper = OptimWrapper(optimizer, accumulative_iters=3) + with optim_wrapper.accumulate_grad(ddp_model, 0, 100): + ddp_model(inputs).sum().backward() + all_grads = all_gather(model.conv.weight.grad) + with self.assertRaises(AssertionError): + assert_allclose(all_grads[0], all_grads[1]) + + # sync grads if `cur_iter == 2` + with optim_wrapper.accumulate_grad(ddp_model, 2, 100): + ddp_model(inputs).sum().backward() + all_grads = all_gather(model.conv.weight.grad) 
+ assert_allclose(all_grads[0], all_grads[1]) + + def test_precision_context(self): + optim_wrapper = OptimWrapper(self.optimizer) + with optim_wrapper.precision_context(): + pass + + def _init_dist_env(self, rank, world_size): + """Initialize the distributed environment.""" + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = '29515' + os.environ['RANK'] = str(rank) + torch_dist.init_process_group( + backend='gloo', rank=rank, world_size=world_size) + + # TODO Test the real interface after add testing tool function which can + # test the function or method is read called. + def _mock_method(self, optim_wrapper): + optim_wrapper.backward = MagicMock() + optim_wrapper.step = MagicMock() + optim_wrapper.zero_grad = MagicMock() + + +class TestAmpOptimWrapper(TestCase): + + def setUp(self) -> None: + self.model = ToyModel() + self.optimizer = SGD(self.model.parameters(), lr=0.1) + + @unittest.skipIf( + not torch.cuda.is_available() + and (digit_version(TORCH_VERSION) >= digit_version('1.6.0')), + reason='`torch.cuda.amp` is only available when pytorch-gpu version ' + '>= 1.6') + def test_init(self): + # Test with default arguments. + amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer) + self.assertIsInstance(amp_optim_wrapper.loss_scaler, GradScaler) + + # Test with dynamic. + amp_optim_wrapper = AmpOptimWrapper( + 'dynamic', optimizer=self.optimizer) + self.assertIsNone(amp_optim_wrapper._scale_update_param) + self.assertIsInstance(amp_optim_wrapper.loss_scaler, GradScaler) + + # Test with dict loss_scale. + amp_optim_wrapper = AmpOptimWrapper( + dict(init_scale=1, growth_factor=2), optimizer=self.optimizer) + self.assertIsInstance(amp_optim_wrapper.loss_scaler, GradScaler) + self.assertIsNone(amp_optim_wrapper._scale_update_param) + with self.assertRaisesRegex(TypeError, + 'loss_scale must be of type float'): + AmpOptimWrapper(optimizer=self.optimizer, loss_scale='unknown') + + @unittest.skipIf( + not torch.cuda.is_available() + and (digit_version(TORCH_VERSION) >= digit_version('1.6.0')), + reason='`torch.cuda.amp` is only available when pytorch-gpu version ' + '>= 1.6') + def test_step(self): + optimizer = MagicMock(spec=Optimizer) + amp_optim_wrapper = AmpOptimWrapper(optimizer=optimizer) + amp_optim_wrapper.loss_scaler = MagicMock() + amp_optim_wrapper.step() + amp_optim_wrapper.loss_scaler.step.assert_called_with( + amp_optim_wrapper.optimizer) + amp_optim_wrapper.loss_scaler.update.assert_called_with( + amp_optim_wrapper._scale_update_param) + + @unittest.skipIf( + not torch.cuda.is_available() + and (digit_version(TORCH_VERSION) >= digit_version('1.6.0')), + reason='`torch.cuda.amp` is only available when pytorch-gpu version ' + '>= 1.6') + def test_backward(self): + amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer) + loss_scaler = MagicMock() + scale_return = MagicMock() + scale_fn = MagicMock(return_value=scale_return) + loss_scaler.scale = scale_fn + amp_optim_wrapper.loss_scaler = loss_scaler + + amp_optim_wrapper.backward(1) + loss_scaler.scale.assert_called_with(1) + scale_return.backward.assert_called_with() + + @unittest.skipIf( + not torch.cuda.is_available() + and (digit_version(TORCH_VERSION) >= digit_version('1.6.0')), + reason='`torch.cuda.amp` is only available when pytorch-gpu version ' + '>= 1.6') + def test_state_dict(self): + self.model = self.model.cuda() + amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer) + loss = self.model(torch.Tensor(1, 1, 1, 1).cuda()) + amp_optim_wrapper.update_params(loss) + state_dict = 
amp_optim_wrapper.state_dict() + scalar_state_dict = state_dict.pop('loss_scaler') + optim_state_dict = state_dict + + self.assertDictEqual(optim_state_dict, + amp_optim_wrapper.optimizer.state_dict()) + self.assertDictEqual(scalar_state_dict, + amp_optim_wrapper.loss_scaler.state_dict()) + + @unittest.skipIf( + not torch.cuda.is_available() + and (digit_version(TORCH_VERSION) >= digit_version('1.6.0')), + reason='`torch.cuda.amp` is only available when pytorch-gpu version ' + '>= 1.6') + def test_load_state_dict(self): + amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer) + self.model = self.model.cuda() + # Test load from optimizer + optimizer = SGD(self.model.parameters(), lr=0.1) + amp_optim_wrapper.load_state_dict(optimizer.state_dict()) + + self.assertDictEqual(optimizer.state_dict(), + amp_optim_wrapper.optimizer.state_dict()) + # Test load from optim_wrapper + amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer) + amp_optim_wrapper_ = AmpOptimWrapper( + optimizer=SGD(self.model.parameters(), lr=0.1)) + amp_optim_wrapper_.load_state_dict(amp_optim_wrapper.state_dict()) + self.assertDictEqual(amp_optim_wrapper.optimizer.state_dict(), + amp_optim_wrapper_.optimizer.state_dict()) + self.assertDictEqual(amp_optim_wrapper.loss_scaler.state_dict(), + amp_optim_wrapper_.loss_scaler.state_dict()) + + @unittest.skipIf( + not torch.cuda.is_available() + and (digit_version(TORCH_VERSION) >= digit_version('1.6.0')), + reason='`torch.cuda.amp` is only available when pytorch-gpu version ' + '>= 1.6') + def test_precision_context(self): + amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer) + with amp_optim_wrapper.precision_context(): + x = torch.randn(1, 1, 1, 1).cuda() + y = nn.Conv2d(1, 1, 1).cuda()(x) + self.assertEqual(y.dtype, torch.float16) diff --git a/tests/test_optim/test_optimizer/test_optimizer_wrapper_dict.py b/tests/test_optim/test_optimizer/test_optimizer_wrapper_dict.py new file mode 100644 index 00000000..f3259e06 --- /dev/null +++ b/tests/test_optim/test_optimizer/test_optimizer_wrapper_dict.py @@ -0,0 +1,142 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from contextlib import contextmanager +from unittest import TestCase +from unittest.mock import patch + +import torch.nn as nn +from torch.optim import SGD + +from mmengine.optim import AmpOptimWrapper, OptimWrapper, OptimWrapperDict + + +class TestOptimWrapperDict(TestCase): + + def setUp(self) -> None: + model1 = nn.Linear(1, 1) + model2 = nn.Linear(1, 1) + self.optim1 = SGD(model1.parameters(), lr=0.1, momentum=0.8) + self.optim2 = SGD(model2.parameters(), lr=0.2, momentum=0.9) + self.optim_wrapper1 = OptimWrapper(self.optim1) + self.optim_wrapper2 = OptimWrapper(self.optim2) + self.optimizers_wrappers = dict( + optim1=self.optim_wrapper1, optim2=self.optim_wrapper2) + + @patch('torch.cuda.is_available', lambda: True) + def test_init(self): + optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) + self.assertEqual(optim_wrapper_dict.optim_wrappers, + self.optimizers_wrappers) + # Different types of OptimWrapper will raise an error + + with self.assertRaisesRegex( + AssertionError, 'All optimizer wrappers should have the same'): + optim_wrapper2 = AmpOptimWrapper(optimizer=self.optim2) + OptimWrapperDict(optim1=self.optim_wrapper1, optim2=optim_wrapper2) + + with self.assertWarnsRegex(UserWarning, 'The `accumulative_iters` of'): + optim_wrapper2 = OptimWrapper( + optimizer=self.optim2, accumulative_iters=2) + OptimWrapperDict(optim1=self.optim_wrapper1, optim2=optim_wrapper2) + + def test_accumulate_grad(self): + + @contextmanager + def context_a(a, b, *args, **kwargs): + a[0] = 100 + yield + a[0] = 1 + + @contextmanager + def context_b(a, b, *args, **kwargs): + b[0] = 200 + yield + b[0] = 2 + + a = [0] + b = [0] + # Test enter the context both of `optim_wrapper1` and `optim_wrapper1`. + optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) + self.optim_wrapper1.accumulate_grad = context_a + self.optim_wrapper2.accumulate_grad = context_b + with optim_wrapper_dict.accumulate_grad(a, b, 0): + self.assertEqual(a[0], 100) + self.assertEqual(b[0], 200) + + self.assertEqual(a[0], 1) + self.assertEqual(b[0], 2) + + def test_state_dict(self): + optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) + state_dict = optim_wrapper_dict.state_dict() + self.assertEqual(state_dict['optim1'], + self.optim_wrapper1.state_dict()) + self.assertEqual(state_dict['optim2'], + self.optim_wrapper2.state_dict()) + + def test_load_state_dict(self): + # Test OptimWrapperDict can load from saved state dict. 
+        model1 = nn.Linear(1, 1)
+        model2 = nn.Linear(1, 1)
+        optim1 = SGD(model1.parameters(), lr=0.1)
+        optim2 = SGD(model2.parameters(), lr=0.1)
+        optim_wrapper_load1 = OptimWrapper(optim1)
+        optim_wrapper_load2 = OptimWrapper(optim2)
+
+        optim_wrapper_dict_save = OptimWrapperDict(**self.optimizers_wrappers)
+        optim_wrapper_dict_load = OptimWrapperDict(
+            optim1=optim_wrapper_load1, optim2=optim_wrapper_load2)
+        state_dict = optim_wrapper_dict_save.state_dict()
+        optim_wrapper_dict_load.load_state_dict(state_dict)
+
+        self.assertDictEqual(optim_wrapper_dict_load.state_dict(),
+                             optim_wrapper_dict_save.state_dict())
+
+    def test_items(self):
+        optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
+        self.assertListEqual(
+            list(optim_wrapper_dict.items()),
+            list(self.optimizers_wrappers.items()))
+
+    def test_values(self):
+        optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
+        self.assertListEqual(
+            list(optim_wrapper_dict.values()),
+            list(self.optimizers_wrappers.values()))
+
+    def test_keys(self):
+        optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
+        self.assertListEqual(
+            list(optim_wrapper_dict.keys()),
+            list(self.optimizers_wrappers.keys()))
+
+    def test_get_lr(self):
+        optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
+        lr = optim_wrapper_dict.get_lr()
+        self.assertEqual(lr['optim1.lr'], [0.1])
+        self.assertEqual(lr['optim2.lr'], [0.2])
+
+    def test_get_momentum(self):
+        optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
+        momentum = optim_wrapper_dict.get_momentum()
+        self.assertEqual(momentum['optim1.momentum'], [0.8])
+        self.assertEqual(momentum['optim2.momentum'], [0.9])
+
+    def test_getitem(self):
+        optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
+        self.assertIs(self.optimizers_wrappers['optim1'],
+                      optim_wrapper_dict['optim1'])
+        self.assertIs(self.optimizers_wrappers['optim2'],
+                      optim_wrapper_dict['optim2'])
+
+    def test_len(self):
+        optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
+        self.assertEqual(len(optim_wrapper_dict), 2)
+
+    def test_contain(self):
+        optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
+        self.assertIn('optim1', optim_wrapper_dict)
+
+    def test_repr(self):
+        optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
+        desc = repr(optim_wrapper_dict)
+        self.assertRegex(desc, 'name: optim1')
diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py
index f1b4fea9..c88cf472 100644
--- a/tests/test_runner/test_runner.py
+++ b/tests/test_runner/test_runner.py
@@ -16,13 +16,15 @@ from mmengine.config import Config
 from mmengine.data import DefaultSampler
 from mmengine.evaluator import BaseMetric, Evaluator
 from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, Hook,
-                            IterTimerHook, LoggerHook, OptimizerHook,
-                            ParamSchedulerHook)
+                            IterTimerHook, LoggerHook, ParamSchedulerHook,
+                            RuntimeInfoHook)
 from mmengine.logging import LogProcessor, MessageHub, MMLogger
-from mmengine.optim import DefaultOptimizerConstructor, MultiStepLR, StepLR
+from mmengine.optim import (DefaultOptimWrapperConstructor, MultiStepLR,
+                            OptimWrapper, OptimWrapperDict, StepLR)
 from mmengine.registry import (DATASETS, HOOKS, LOG_PROCESSORS, LOOPS, METRICS,
-                               MODEL_WRAPPERS, MODELS, OPTIMIZER_CONSTRUCTORS,
-                               PARAM_SCHEDULERS, Registry)
+                               MODEL_WRAPPERS, MODELS,
+                               OPTIM_WRAPPER_CONSTRUCTORS, PARAM_SCHEDULERS,
+                               Registry)
 from mmengine.runner import (BaseLoop, EpochBasedTrainLoop, IterBasedTrainLoop,
                              Runner, TestLoop, ValLoop)
 from mmengine.runner.priority import Priority, get_priority
@@ -73,24 +75,24 @@ class CustomModelWrapper(nn.Module):
         self.model = model


-@OPTIMIZER_CONSTRUCTORS.register_module()
+@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
 class ToyMultipleOptimizerConstructor:

-    def __init__(self, optimizer_cfg, paramwise_cfg=None):
-        if not isinstance(optimizer_cfg, dict):
+    def __init__(self, optim_wrapper_cfg, paramwise_cfg=None):
+        if not isinstance(optim_wrapper_cfg, dict):
             raise TypeError('optimizer_cfg should be a dict',
-                            f'but got {type(optimizer_cfg)}')
+                            f'but got {type(optim_wrapper_cfg)}')
         assert paramwise_cfg is None, (
             'parawise_cfg should be set in each optimizer separately')
-        self.optimizer_cfg = optimizer_cfg
+        self.optim_wrapper_cfg = optim_wrapper_cfg
         self.constructors = {}
-        for key, cfg in self.optimizer_cfg.items():
+        for key, cfg in self.optim_wrapper_cfg.items():
             _cfg = cfg.copy()
             paramwise_cfg_ = _cfg.pop('paramwise_cfg', None)
-            self.constructors[key] = DefaultOptimizerConstructor(
+            self.constructors[key] = DefaultOptimWrapperConstructor(
                 _cfg, paramwise_cfg_)

-    def __call__(self, model: nn.Module) -> torch.optim.Optimizer:
+    def __call__(self, model: nn.Module) -> OptimWrapperDict:
         optimizers = {}
         while hasattr(model, 'module'):
             model = model.module
@@ -98,7 +100,7 @@ class ToyMultipleOptimizerConstructor:
         for key, constructor in self.constructors.items():
             module = getattr(model, key)
             optimizers[key] = constructor(module)
-        return optimizers
+        return OptimWrapperDict(**optimizers)


 @DATASETS.register_module()
@@ -160,13 +162,6 @@ class ToyHook2(Hook):
         pass


-@HOOKS.register_module()
-class ToyHook3(Hook):
-
-    def before_train_iter(self, runner, data_batch):
-        pass
-
-
 @LOOPS.register_module()
 class CustomTrainLoop(BaseLoop):

@@ -246,7 +241,8 @@ class TestRunner(TestCase):
                 sampler=dict(type='DefaultSampler', shuffle=False),
                 batch_size=3,
                 num_workers=0),
-            optimizer=dict(type='SGD', lr=0.01),
+            optim_wrapper=dict(
+                type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)),
             param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]),
             val_evaluator=dict(type='ToyMetric1'),
             test_evaluator=dict(type='ToyMetric1'),
@@ -299,7 +295,7 @@ class TestRunner(TestCase):
         # also be None
         cfg.experiment_name = 'test_init2'
         cfg.pop('train_dataloader')
-        cfg.pop('optimizer')
+        cfg.pop('optim_wrapper')
         cfg.pop('param_scheduler')
         runner = Runner(**cfg)
         self.assertIsInstance(runner, Runner)
@@ -324,7 +320,7 @@ class TestRunner(TestCase):
         cfg.experiment_name = 'test_init5'
         cfg.pop('train_cfg')
         cfg.pop('train_dataloader')
-        cfg.pop('optimizer')
+        cfg.pop('optim_wrapper')
         with self.assertRaisesRegex(ValueError, 'should be None'):
             runner = Runner(**cfg)

@@ -398,7 +394,7 @@ class TestRunner(TestCase):
         self.assertIsInstance(runner._train_dataloader, dict)
         self.assertIsInstance(runner._val_dataloader, dict)
         self.assertIsInstance(runner._test_dataloader, dict)
-        self.assertIsInstance(runner.optimizer, dict)
+        self.assertIsInstance(runner.optim_wrapper, dict)
         self.assertIsInstance(runner.param_schedulers[0], dict)

         # After calling runner.train(),
@@ -408,7 +404,7 @@ class TestRunner(TestCase):

         self.assertIsInstance(runner._train_loop, BaseLoop)
         self.assertIsInstance(runner.train_dataloader, DataLoader)
-        self.assertIsInstance(runner.optimizer, SGD)
+        self.assertIsInstance(runner.optim_wrapper, OptimWrapper)
         self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
         self.assertIsInstance(runner._val_loop, BaseLoop)
         self.assertIsInstance(runner._val_loop.dataloader, DataLoader)
@@ -423,10 +419,10 @@ class TestRunner(TestCase):

         # 4. initialize runner with objects rather than config
         model = ToyModel()
-        optimizer = SGD(
+        optim_wrapper = OptimWrapper(SGD(
             model.parameters(),
             lr=0.01,
-        )
+        ))
         toy_hook = ToyHook()
         toy_hook2 = ToyHook2()

@@ -438,8 +434,8 @@ class TestRunner(TestCase):
             work_dir=self.temp_dir,
             train_cfg=dict(by_epoch=True, max_epochs=3),
             train_dataloader=train_dataloader,
-            optimizer=optimizer,
-            param_scheduler=MultiStepLR(optimizer, milestones=[1, 2]),
+            optim_wrapper=optim_wrapper,
+            param_scheduler=MultiStepLR(optim_wrapper, milestones=[1, 2]),
             val_cfg=dict(interval=1, begin=1),
             val_dataloader=val_dataloader,
             val_evaluator=ToyMetric1(),
@@ -618,79 +614,92 @@ class TestRunner(TestCase):
         runner = Runner.from_cfg(cfg)
         self.assertIsInstance(runner.model, CustomModelWrapper)

-    def test_build_optimizer(self):
+    def test_build_optim_wrapper(self):
         cfg = copy.deepcopy(self.epoch_based_cfg)
-        cfg.experiment_name = 'test_build_optimizer'
+        cfg.experiment_name = 'test_build_optim_wrapper'
         runner = Runner.from_cfg(cfg)

         # input should be an Optimizer object or dict
-        with self.assertRaisesRegex(TypeError, 'optimizer should be'):
-            runner.build_optimizer('invalid-type')
+        with self.assertRaisesRegex(TypeError, 'optimizer wrapper should be'):
+            runner.build_optim_wrapper('invalid-type')

         # 1. test one optimizer
         # 1.1 input is an Optimizer object
-        _optimizer = SGD(runner.model.parameters(), lr=0.01)
-        optimizer = runner.build_optimizer(_optimizer)
-        self.assertEqual(id(_optimizer), id(optimizer))
+        optimizer = SGD(runner.model.parameters(), lr=0.01)
+        optim_wrapper = OptimWrapper(optimizer)
+        optim_wrapper = runner.build_optim_wrapper(optim_wrapper)
+        self.assertEqual(id(optimizer), id(optim_wrapper.optimizer))

         # 1.2 input is a dict
-        optimizer = runner.build_optimizer(dict(type='SGD', lr=0.01))
-        self.assertIsInstance(optimizer, SGD)
+        optim_wrapper = runner.build_optim_wrapper(
+            dict(type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)))
+        self.assertIsInstance(optim_wrapper, OptimWrapper)

         # 2. test multiple optmizers
         # 2.1 input is a dict which contains multiple optimizer objects
         optimizer1 = SGD(runner.model.linear1.parameters(), lr=0.01)
+        optim_wrapper1 = OptimWrapper(optimizer1)
         optimizer2 = Adam(runner.model.linear2.parameters(), lr=0.02)
-        optim_cfg = dict(key1=optimizer1, key2=optimizer2)
-        optimizer = runner.build_optimizer(optim_cfg)
-        self.assertIsInstance(optimizer, dict)
-        self.assertIsInstance(optimizer['key1'], SGD)
-        self.assertIsInstance(optimizer['key2'], Adam)
+        optim_wrapper2 = OptimWrapper(optimizer2)
+        optim_wrapper_cfg = dict(key1=optim_wrapper1, key2=optim_wrapper2)
+        optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
+        self.assertIsInstance(optim_wrapper, OptimWrapperDict)
+        self.assertIsInstance(optim_wrapper['key1'].optimizer, SGD)
+        self.assertIsInstance(optim_wrapper['key2'].optimizer, Adam)

         # 2.2 each item mush be an optimizer object when "type" and
         # "constructor" are not in optimizer
         optimizer1 = SGD(runner.model.linear1.parameters(), lr=0.01)
-        optimizer2 = dict(type='Adam', lr=0.02)
-        optim_cfg = dict(key1=optimizer1, key2=optimizer2)
+        optim_wrapper1 = OptimWrapper(optimizer1)
+        optim_wrapper2 = dict(
+            type='OptimWrapper', optimizer=dict(type='Adam', lr=0.01))
+        optim_cfg = dict(key1=optim_wrapper1, key2=optim_wrapper2)
         with self.assertRaisesRegex(ValueError,
                                     'each item mush be an optimizer object'):
-            runner.build_optimizer(optim_cfg)
+            runner.build_optim_wrapper(optim_cfg)

         # 2.3 input is a dict which contains multiple configs
-        optim_cfg = dict(
-            linear1=dict(type='SGD', lr=0.01),
-            linear2=dict(type='Adam', lr=0.02),
+        optim_wrapper_cfg = dict(
+            linear1=dict(
+                type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)),
+            linear2=dict(
+                type='OptimWrapper', optimizer=dict(type='Adam', lr=0.02)),
             constructor='ToyMultipleOptimizerConstructor')
-        optimizer = runner.build_optimizer(optim_cfg)
-        self.assertIsInstance(optimizer, dict)
-        self.assertIsInstance(optimizer['linear1'], SGD)
-        self.assertIsInstance(optimizer['linear2'], Adam)
+        optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
+        self.assertIsInstance(optim_wrapper, OptimWrapperDict)
+        self.assertIsInstance(optim_wrapper['linear1'].optimizer, SGD)
+        self.assertIsInstance(optim_wrapper['linear2'].optimizer, Adam)

     def test_build_param_scheduler(self):
         cfg = copy.deepcopy(self.epoch_based_cfg)
         cfg.experiment_name = 'test_build_param_scheduler'
         runner = Runner.from_cfg(cfg)

-        # `build_optimizer` should be called before `build_param_scheduler`
+        # `build_optim_wrapper` should be called before
+        # `build_param_scheduler`
         cfg = dict(type='MultiStepLR', milestones=[1, 2])
-        runner.optimizer = dict(
-            key1=dict(type='SGD', lr=0.01),
-            key2=dict(type='Adam', lr=0.02),
+        runner.optim_wrapper = dict(
+            key1=dict(type=OptimWrapper, optimizer=dict(type='SGD', lr=0.01)),
+            key2=dict(type=OptimWrapper, optimizer=dict(type='Adam', lr=0.02)),
         )
-        with self.assertRaisesRegex(RuntimeError, 'should be called before'):
+        with self.assertRaisesRegex(AssertionError, 'should be called before'):
             runner.build_param_scheduler(cfg)

-        runner.optimizer = runner.build_optimizer(dict(type='SGD', lr=0.01))
+        runner.optim_wrapper = runner.build_optim_wrapper(
+            dict(type=OptimWrapper, optimizer=dict(type='SGD', lr=0.01)))
+        param_schedulers = runner.build_param_scheduler(cfg)
+        self.assertIsInstance(param_schedulers, list)
+        self.assertEqual(len(param_schedulers), 1)
+        self.assertIsInstance(param_schedulers[0], MultiStepLR)

         # 1. test one optimizer and one parameter scheduler
         # 1.1 input is a ParamScheduler object
-        param_scheduler = MultiStepLR(runner.optimizer, milestones=[1, 2])
+        param_scheduler = MultiStepLR(runner.optim_wrapper, milestones=[1, 2])
         param_schedulers = runner.build_param_scheduler(param_scheduler)
         self.assertEqual(len(param_schedulers), 1)
         self.assertEqual(id(param_schedulers[0]), id(param_scheduler))

         # 1.2 input is a dict
-        cfg = dict(type='MultiStepLR', milestones=[1, 2])
         param_schedulers = runner.build_param_scheduler(param_scheduler)
         self.assertEqual(len(param_schedulers), 1)
         self.assertIsInstance(param_schedulers[0], MultiStepLR)
@@ -715,9 +724,11 @@ class TestRunner(TestCase):

         # 3. test multiple optimizers and list of parameter schedulers
         optimizer1 = SGD(runner.model.linear1.parameters(), lr=0.01)
+        optim_wrapper1 = OptimWrapper(optimizer1)
         optimizer2 = Adam(runner.model.linear2.parameters(), lr=0.02)
-        optim_cfg = dict(key1=optimizer1, key2=optimizer2)
-        runner.optimizer = runner.build_optimizer(optim_cfg)
+        optim_wrapper2 = OptimWrapper(optimizer2)
+        optim_wrapper_cfg = dict(key1=optim_wrapper1, key2=optim_wrapper2)
+        runner.optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
         cfg = [
             dict(type='MultiStepLR', milestones=[1, 2]),
             dict(type='StepLR', step_size=1)
@@ -743,7 +754,8 @@ class TestRunner(TestCase):
         self.assertEqual(len(param_schedulers['key2']), 2)

         # 5. test converting epoch-based scheduler to iter-based
-        runner.optimizer = runner.build_optimizer(dict(type='SGD', lr=0.01))
+        runner.optim_wrapper = runner.build_optim_wrapper(
+            dict(type=OptimWrapper, optimizer=dict(type='SGD', lr=0.01)))

         # 5.1 train loop should be built before converting scheduler
         cfg = dict(
@@ -752,7 +764,7 @@ class TestRunner(TestCase):
                 AssertionError,
                 'Scheduler can only be converted to iter-based when '
                 'train loop is built.'):
-            param_schedulers = runner.build_param_scheduler(cfg)
+            runner.build_param_scheduler(cfg)

         # 5.2 convert epoch-based to iter-based scheduler
         cfg = dict(
@@ -947,7 +959,7 @@ class TestRunner(TestCase):
         cfg.experiment_name = 'test_train1'
         cfg.pop('train_dataloader')
         cfg.pop('train_cfg')
-        cfg.pop('optimizer')
+        cfg.pop('optim_wrapper')
         cfg.pop('param_scheduler')
         runner = Runner.from_cfg(cfg)
         with self.assertRaisesRegex(RuntimeError, 'should not be None'):
@@ -1063,7 +1075,7 @@ class TestRunner(TestCase):
         cfg.experiment_name = 'test_individually_val'
         cfg.pop('train_dataloader')
         cfg.pop('train_cfg')
-        cfg.pop('optimizer')
+        cfg.pop('optim_wrapper')
         cfg.pop('param_scheduler')
         cfg.pop('test_dataloader')
         cfg.pop('test_cfg')
@@ -1091,7 +1103,7 @@ class TestRunner(TestCase):
         cfg.experiment_name = 'test_individually_test'
         cfg.pop('train_dataloader')
         cfg.pop('train_cfg')
-        cfg.pop('optimizer')
+        cfg.pop('optim_wrapper')
         cfg.pop('param_scheduler')
         cfg.pop('val_dataloader')
         cfg.pop('val_cfg')
@@ -1132,15 +1144,15 @@ class TestRunner(TestCase):
                          get_priority('BELOW_NORMAL'))

         # 1.3 `hook` is a hook object
-        optimizer_hook = OptimizerHook()
-        runner.register_hook(optimizer_hook)
+        runtime_info_hook = RuntimeInfoHook()
+        runner.register_hook(runtime_info_hook)
         self.assertEqual(len(runner._hooks), 2)
-        # The priority of `OptimizerHook` is `HIGH` which is greater than
+        # The priority of `runtime_info_hook` is `HIGH` which is greater than
         # `IterTimerHook`, so the first item of `_hooks` should be
-        # `OptimizerHook`
-        self.assertTrue(isinstance(runner._hooks[0], OptimizerHook))
+        # `runtime_info_hook`
+        self.assertTrue(isinstance(runner._hooks[0], RuntimeInfoHook))
         self.assertEqual(
-            get_priority(runner._hooks[0].priority), get_priority('HIGH'))
+            get_priority(runner._hooks[0].priority), get_priority('VERY_HIGH'))

         # 2. test `priority` parameter
         # `priority` argument is not None and it will be set as priority of
@@ -1198,29 +1210,6 @@ class TestRunner(TestCase):
         self.assertEqual(len(runner._hooks), 8)
         self.assertTrue(isinstance(runner._hooks[7], ToyHook))

-    def test_call_hook(self):
-        # test unexpected argument in `call_hook`
-        cfg = copy.deepcopy(self.epoch_based_cfg)
-        cfg.experiment_name = 'test_call_hook1'
-        runner = Runner.from_cfg(cfg)
-        runner._hooks = []
-        custom_hooks = [dict(type='ToyHook3')]
-        runner.register_custom_hooks(custom_hooks)
-        with self.assertRaisesRegex(
-                TypeError,
-                r"got an unexpected keyword argument 'batch_idx' in "
-                r''):
-            runner.call_hook('before_train_iter', batch_idx=0, data_batch=None)
-
-        # test call hook with expected arguments
-        cfg = copy.deepcopy(self.epoch_based_cfg)
-        cfg.experiment_name = 'test_call_hook2'
-        runner = Runner.from_cfg(cfg)
-        runner._hooks = []
-        custom_hooks = [dict(type='ToyHook3')]
-        runner.register_custom_hooks(custom_hooks)
-        runner.call_hook('before_train_iter', data_batch=None)
-
     def test_register_hooks(self):
         cfg = copy.deepcopy(self.epoch_based_cfg)
         cfg.experiment_name = 'test_register_hooks'
@@ -1229,7 +1218,7 @@ class TestRunner(TestCase):
         runner._hooks = []
         custom_hooks = [dict(type='ToyHook')]
         runner.register_hooks(custom_hooks=custom_hooks)
-        # 7 default hooks + custom hook (ToyHook)
+        # six default hooks + custom hook (ToyHook)
         self.assertEqual(len(runner._hooks), 8)
         self.assertTrue(isinstance(runner._hooks[7], ToyHook))

@@ -1336,7 +1325,7 @@ class TestRunner(TestCase):
         # 1.2 test `load_checkpoint`
         cfg = copy.deepcopy(self.epoch_based_cfg)
         cfg.experiment_name = 'test_checkpoint2'
-        cfg.optimizer = dict(type='SGD', lr=0.2)
+        cfg.optim_wrapper = dict(type='SGD', lr=0.2)
         cfg.param_scheduler = dict(type='MultiStepLR', milestones=[1, 2, 3])
         runner = Runner.from_cfg(cfg)
         runner.load_checkpoint(path)
@@ -1345,22 +1334,24 @@ class TestRunner(TestCase):
         self.assertTrue(runner._has_loaded)
         # load checkpoint will not initialize optimizer and param_schedulers
         # objects
-        self.assertIsInstance(runner.optimizer, dict)
+        self.assertIsInstance(runner.optim_wrapper, dict)
         self.assertIsInstance(runner.param_schedulers, list)
         self.assertIsInstance(runner.param_schedulers[0], dict)

         # 1.3 test `resume`
         cfg = copy.deepcopy(self.epoch_based_cfg)
         cfg.experiment_name = 'test_checkpoint3'
-        cfg.optimizer = dict(type='SGD', lr=0.2)
+        cfg.optim_wrapper = dict(
+            type='OptimWrapper', optimizer=dict(type='SGD', lr=0.2))
         cfg.param_scheduler = dict(type='MultiStepLR', milestones=[1, 2, 3])
         runner = Runner.from_cfg(cfg)
         runner.resume(path)
         self.assertEqual(runner.epoch, 3)
         self.assertEqual(runner.iter, 12)
         self.assertTrue(runner._has_loaded)
-        self.assertIsInstance(runner.optimizer, SGD)
-        self.assertEqual(runner.optimizer.param_groups[0]['lr'], 0.0001)
+        self.assertIsInstance(runner.optim_wrapper.optimizer, SGD)
+        self.assertIsInstance(runner.optim_wrapper.optimizer, SGD)
+        self.assertEqual(runner.optim_wrapper.param_groups[0]['lr'], 0.0001)
         self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
         self.assertEqual(runner.param_schedulers[0].milestones, {1: 1, 2: 1})

@@ -1373,7 +1364,7 @@ class TestRunner(TestCase):
         self.assertEqual(runner.epoch, 3)
         self.assertEqual(runner.iter, 12)
         self.assertTrue(runner._has_loaded)
-        self.assertIsInstance(runner.optimizer, SGD)
+        self.assertIsInstance(runner.optim_wrapper.optimizer, SGD)
         self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)

         # 1.5 test resume from a specified checkpoint
@@ -1386,46 +1377,48 @@ class TestRunner(TestCase):
         self.assertEqual(runner.epoch, 1)
         self.assertEqual(runner.iter, 4)
         self.assertTrue(runner._has_loaded)
-        self.assertIsInstance(runner.optimizer, SGD)
+        self.assertIsInstance(runner.optim_wrapper.optimizer, SGD)
         self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)

         # 1.6 multiple optimizers
         cfg = copy.deepcopy(self.epoch_based_cfg)
         cfg.experiment_name = 'test_checkpoint6'
-        cfg.optimizer = dict(
-            linear1=dict(type='SGD', lr=0.01),
-            linear2=dict(type='Adam', lr=0.02),
-            constructor='ToyMultipleOptimizerConstructor',
-        )
+        cfg.optim_wrapper = dict(
+            linear1=dict(
+                type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)),
+            linear2=dict(
+                type='OptimWrapper', optimizer=dict(type='Adam', lr=0.02)),
+            constructor='ToyMultipleOptimizerConstructor')
         # disable OptimizerHook because it only works with one optimizer
         cfg.default_hooks = dict(optimizer=None)
         runner = Runner.from_cfg(cfg)
         runner.train()
         path = osp.join(self.temp_dir, 'epoch_3.pth')
         self.assertTrue(osp.exists(path))
-        self.assertEqual(runner.optimizer['linear1'].param_groups[0]['lr'],
+        self.assertEqual(runner.optim_wrapper['linear1'].param_groups[0]['lr'],
                          0.0001)
-        self.assertIsInstance(runner.optimizer['linear2'], Adam)
-        self.assertEqual(runner.optimizer['linear2'].param_groups[0]['lr'],
+        self.assertIsInstance(runner.optim_wrapper['linear2'].optimizer, Adam)
+        self.assertEqual(runner.optim_wrapper['linear2'].param_groups[0]['lr'],
                          0.0002)

         cfg = copy.deepcopy(self.epoch_based_cfg)
         cfg.experiment_name = 'test_checkpoint7'
-        cfg.optimizer = dict(
-            linear1=dict(type='SGD', lr=0.2),
-            linear2=dict(type='Adam', lr=0.03),
-            constructor='ToyMultipleOptimizerConstructor',
-        )
+        cfg.optim_wrapper = dict(
+            linear1=dict(
+                type='OptimWrapper', optimizer=dict(type='SGD', lr=0.2)),
+            linear2=dict(
+                type='OptimWrapper', optimizer=dict(type='Adam', lr=0.03)),
+            constructor='ToyMultipleOptimizerConstructor')
         cfg.param_scheduler = dict(type='MultiStepLR', milestones=[1, 2, 3])
         cfg.default_hooks = dict(optimizer=None)
         runner = Runner.from_cfg(cfg)
         runner.resume(path)
-        self.assertIsInstance(runner.optimizer, dict)
-        self.assertIsInstance(runner.optimizer['linear1'], SGD)
-        self.assertEqual(runner.optimizer['linear1'].param_groups[0]['lr'],
+        self.assertIsInstance(runner.optim_wrapper, OptimWrapperDict)
+        self.assertIsInstance(runner.optim_wrapper['linear1'].optimizer, SGD)
+        self.assertEqual(runner.optim_wrapper['linear1'].param_groups[0]['lr'],
                          0.0001)
-        self.assertIsInstance(runner.optimizer['linear2'], Adam)
-        self.assertEqual(runner.optimizer['linear2'].param_groups[0]['lr'],
+        self.assertIsInstance(runner.optim_wrapper['linear2'].optimizer, Adam)
+        self.assertEqual(runner.optim_wrapper['linear2'].param_groups[0]['lr'],
                          0.0002)
         self.assertIsInstance(runner.param_schedulers, dict)
         self.assertEqual(len(runner.param_schedulers['linear1']), 1)
@@ -1479,7 +1472,7 @@ class TestRunner(TestCase):
         self.assertEqual(runner.epoch, 0)
         self.assertEqual(runner.iter, 12)
         self.assertTrue(runner._has_loaded)
-        self.assertIsInstance(runner.optimizer, SGD)
+        self.assertIsInstance(runner.optim_wrapper.optimizer, SGD)
         self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)

         # 2.4 test auto resume
@@ -1491,7 +1484,7 @@ class TestRunner(TestCase):
         self.assertEqual(runner.epoch, 0)
         self.assertEqual(runner.iter, 12)
         self.assertTrue(runner._has_loaded)
-        self.assertIsInstance(runner.optimizer, SGD)
+        self.assertIsInstance(runner.optim_wrapper.optimizer, SGD)
         self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)

         # 2.5 test resume from a specified checkpoint
@@ -1504,5 +1497,5 @@ class TestRunner(TestCase):
         self.assertEqual(runner.epoch, 0)
         self.assertEqual(runner.iter, 3)
         self.assertTrue(runner._has_loaded)
-        self.assertIsInstance(runner.optimizer, SGD)
+        self.assertIsInstance(runner.optim_wrapper.optimizer, SGD)
         self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
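
For reference, a minimal usage sketch (not part of the patch) of the OptimWrapper / OptimWrapperDict interface that the new tests above exercise. The key format such as 'optim1.lr' is taken from the assertions in test_get_lr and test_get_momentum; the construction of the two wrapped SGD optimizers mirrors the values asserted there, but the test fixture itself (self.optimizers_wrappers) is not shown in this patch and is assumed to look like this.

import torch.nn as nn
from torch.optim import SGD

from mmengine.optim import OptimWrapper, OptimWrapperDict

# Assumed fixture: one wrapped optimizer per sub-model, matching the
# lr/momentum values asserted in test_get_lr and test_get_momentum.
model1, model2 = nn.Linear(1, 1), nn.Linear(1, 1)
wrapper1 = OptimWrapper(SGD(model1.parameters(), lr=0.1, momentum=0.8))
wrapper2 = OptimWrapper(SGD(model2.parameters(), lr=0.2, momentum=0.9))

# OptimWrapperDict namespaces each wrapper's scalars by its keyword name.
wrappers = OptimWrapperDict(optim1=wrapper1, optim2=wrapper2)
print(wrappers.get_lr())        # {'optim1.lr': [0.1], 'optim2.lr': [0.2]}
print(wrappers.get_momentum())  # {'optim1.momentum': [0.8], 'optim2.momentum': [0.9]}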