diff --git a/mmengine/model/base_model/base_model.py b/mmengine/model/base_model/base_model.py index fe970acb..e9e8b379 100644 --- a/mmengine/model/base_model/base_model.py +++ b/mmengine/model/base_model/base_model.py @@ -111,8 +111,8 @@ class BaseModel(BaseModule): Returns: Dict[str, torch.Tensor]: A ``dict`` of tensor for logging. """ - # enable automatic mixed precision training context. - with optim_wrapper.precision_context(): + # Enable automatic mixed precision training context. + with optim_wrapper.optim_context(self): batch_inputs, data_samples = self.data_preprocessor(data, True) losses = self(batch_inputs, data_samples, mode='loss') parsed_losses, log_vars = self.parse_losses(losses) diff --git a/mmengine/model/wrappers/distributed.py b/mmengine/model/wrappers/distributed.py index 4084dde7..813f42c6 100644 --- a/mmengine/model/wrappers/distributed.py +++ b/mmengine/model/wrappers/distributed.py @@ -89,8 +89,8 @@ class MMDistributedDataParallel(DistributedDataParallel): Returns: Dict[str, torch.Tensor]: A ``dict`` of tensor for logging. """ - # enable automatic mixed precision training context. - with optim_wrapper.precision_context(): + # Enable automatic mixed precision training context. + with optim_wrapper.optim_context(self): batch_inputs, data_samples = self.module.data_preprocessor( data, training=True) losses = self(batch_inputs, data_samples, mode='loss') diff --git a/mmengine/optim/optimizer/amp_optimizer_wrapper.py b/mmengine/optim/optimizer/amp_optimizer_wrapper.py index 16823e61..4bbd76a9 100644 --- a/mmengine/optim/optimizer/amp_optimizer_wrapper.py +++ b/mmengine/optim/optimizer/amp_optimizer_wrapper.py @@ -2,6 +2,7 @@ from contextlib import contextmanager import torch +import torch.nn as nn from torch.cuda.amp import GradScaler from mmengine.registry import OPTIM_WRAPPERS @@ -25,15 +26,21 @@ class AmpOptimWrapper(OptimWrapper): loss_scale (float or str or dict): The initial configuration of `torch.cuda.amp.GradScaler`. See more specific arguments introduction at `PyTorch AMP `_ # noqa: E501 + Defaults to ``dynamic``. - "dynamic": Initialize GradScale without any arguments. - float: Initialize GradScaler with ``init_scale``. - dict: Initialize GradScaler with more detail configuration. **kwargs: Keyword arguments passed to OptimWrapper. + + Note: + If you use ``IterBasedRunner`` and enable gradient accumulation, + the original `max_iters` should be multiplied by + ``accumulative_counts``. """ - def __init__(self, loss_scale=512., **kwargs): + def __init__(self, loss_scale='dynamic', **kwargs): assert digit_version(TORCH_VERSION) >= digit_version('1.6.0'), ( '`torch.cuda.amp` is only available when pytorch version >= 1.6') assert torch.cuda.is_available(), ( @@ -62,6 +69,7 @@ class AmpOptimWrapper(OptimWrapper): loss (torch.Tensor): The loss of current iteration. """ self.loss_scaler.scale(loss).backward() + self._inner_count += 1 def step(self): """Update parameters with :attr:`loss_scaler`.""" @@ -104,7 +112,13 @@ class AmpOptimWrapper(OptimWrapper): self.optimizer.load_state_dict(state_dict) @contextmanager - def precision_context(self): - """A wrapper of ``torch.cuda.amp.autocast``""" - with torch.cuda.amp.autocast(): + def optim_context(self, model: nn.Module): + """Enables the context for mixed precision training, and enables the + context for disabling gradient synchronization during gradient + accumulation context. + + Args: + model (nn.Module): The training model. 
+ """ + with super().optim_context(model), torch.cuda.amp.autocast(): yield diff --git a/mmengine/optim/optimizer/default_constructor.py b/mmengine/optim/optimizer/default_constructor.py index 0bc03463..073b3480 100644 --- a/mmengine/optim/optimizer/default_constructor.py +++ b/mmengine/optim/optimizer/default_constructor.py @@ -73,7 +73,7 @@ class DefaultOptimWrapperConstructor: Optional fields are - any arguments of the corresponding optimizer wrapper type, - e.g., accumulative_iters, clip_grad, etc. + e.g., accumulative_counts, clip_grad, etc. The positional fields of ``optimizer`` are diff --git a/mmengine/optim/optimizer/optimizer_wrapper.py b/mmengine/optim/optimizer/optimizer_wrapper.py index d71249bc..db1794fd 100644 --- a/mmengine/optim/optimizer/optimizer_wrapper.py +++ b/mmengine/optim/optimizer/optimizer_wrapper.py @@ -27,21 +27,31 @@ class OptimWrapper: Args: optimizer (Optimizer): Optimizer used to update model parameters. - accumulative_iters (int): The number of iterations to accumulate + accumulative_counts (int): The number of iterations to accumulate gradients. The parameters will be updated per - ``accumulative_iters``. + ``accumulative_counts``. clip_grad (dict, optional): If ``clip_grad`` is not None, it will be the arguments of ``torch.nn.utils.clip_grad``. - Warnings: - If ``accumulative_iters`` is larger than 1, :meth:`update_params` must - be called in the context of ``accumulate_grad``. + Note: + If ``accumulative_counts`` is larger than 1, perform + :meth:`update_params` under the context of ``optim_context`` + could avoid unnecessary gradient synchronization. + + Note: + If you use ``IterBasedRunner`` and enable gradient accumulation, + the original `max_iters` should be multiplied by + ``accumulative_counts``. + + Note: + The subclass should ensure that once :meth:`update_params` is called, + ``_inner_count += 1`` is automatically performed. Examples: >>> # Config sample of OptimWrapper. >>> optim_wrapper_cfg = dict( >>> type='OptimWrapper', - >>> accumulative_iters=3, + >>> _accumulative_counts=1, >>> clip_grad=dict(max_norm=0.2)) >>> # Use OptimWrapper to update model. >>> import torch.nn as nn @@ -59,29 +69,32 @@ class OptimWrapper: >>> for data in dataloader: >>> loss = model(data) >>> optim_wrapper.update_params(loss) - >>> # Enable gradient accumulation. If model is a subclass instance of - >>> # DistributedDataParallel, ``accumulate_grad`` context manager can - >>> # avoid unnecessary gradient synchronize. + >>> # Enable gradient accumulation + >>> optim_wrapper_cfg = dict( + >>> type='OptimWrapper', + >>> _accumulative_counts=3, + >>> clip_grad=dict(max_norm=0.2)) + >>> ddp_model = DistributedDataParallel(model) + >>> optimizer = SGD(ddp_model.parameters(), lr=0.1) + >>> optim_wrapper = OptimWrapper(optimizer) + >>> optim_wrapper.initialize_count_status(0, len(dataloader)) + >>> # If model is a subclass instance of DistributedDataParallel, + >>> # `optim_context` context manager can avoid unnecessary gradient + >>> # synchronize. 
>>> for iter, data in enumerate(dataloader): - >>> with optim_wrapper.accumulate_grad( - >>> model, iter, len(dataloader)): + >>> with optim_wrapper.optim_context(ddp_model): >>> loss = model(data) - >>> optim_wrapper.update_params(loss) + >>> optim_wrapper.update_params(loss) """ def __init__(self, optimizer: Optimizer, - accumulative_iters: int = 1, + accumulative_counts: int = 1, clip_grad: Optional[dict] = None): - assert accumulative_iters > 0, ( - 'accumulative_iters at least greater than or equal to 1') - self.accumulative_iters = accumulative_iters - # `max_iters` and `cur_iter` is only valid in gradient accumulative - # mode (`accumulative_iters` > 1). `cur_iter` and `max_iter` will be - # updated in the ``accumulate_grad`` context that is enabled in - # `runner.train_loop`. - self.cur_iter = 0 - self.max_iters = 0 + assert accumulative_counts > 0, ( + '_accumulative_counts at least greater than or equal to 1') + self._accumulative_counts = accumulative_counts + assert isinstance(optimizer, Optimizer), ( 'optimizer must be a `torch.optim.Optimizer` instance, but got ' f'{type(optimizer)}') @@ -90,13 +103,27 @@ class OptimWrapper: if clip_grad is not None: # clip_grad_kwargs should not be non-empty dict. assert isinstance(clip_grad, dict) and clip_grad, ( - 'If `clip_grad_kwargs` is not None, it should be a `dict` ' + 'If `clip_grad` is not None, it should be a `dict` ' 'which is the arguments of `torch.nn.utils.clip_grad`') self.clip_grad_kwargs = clip_grad self.logger = MMLogger.get_current_instance() # Used to update `grad_norm` log message. self.message_hub = MessageHub.get_current_instance() - self.iter_status_initialized = False + self._inner_count = 0 + # `_max_counts` means the total number of parameter updates. It + # ensures that the gradient of the last few iterations will not be + # lost when the `_max_counts` is not divisible by + # `accumulative_counts`. + self._max_counts = -1 + # If `_inner_count` is smaller than `_divisible_counts`, the loss + # factor used for gradient accumulation should be the same as + # `_accumulative_counts`. If `_max_counts` has not been initialized, + # the loss factor will always be the same as `_accumulative_counts`. + self._divisible_counts = -1 + # The `_remainder_iter` is used for calculating loss factor at the + # last few iterations. If `_max_counts` has not been initialized, + # the loss factor will always be the same as `_accumulative_counts`. + self._remainder_counts = -1 def update_params(self, loss: torch.Tensor) -> None: """Update parameters in :attr:`optimizer`. @@ -104,36 +131,12 @@ class OptimWrapper: Args: loss (torch.Tensor): A tensor for back propagation. """ - if self.accumulative_iters == 1: - # update parameters without gradient accumulation. The gradient - # should not be rescaled and `loss_factor=1`. - loss_factor = 1 - else: - # gradient accumulation must be called in the context of - # ``accumulate_grad``. - assert hasattr(self, 'divisible_iters'), ( - 'gradient accumulation must be performed in the context of' - '`OptimWrapper.accumulate_grad`') - # if `self.accumulative_iters > 1`, the gradient needs to be - # rescaled and accumulated. In most cases, `loss_factor` equals to - # `self.accumulative_iters`. However `self.max_iters` may not be - # divisible `self.by accumulative_iters`, so the `loss_scale` for - # the last few iterations needs to be recalculated. 
- if self.cur_iter < self.divisible_iters: - loss_factor = self.accumulative_iters - else: - loss_factor = self.remainder_iters - assert loss_factor > 0, ( - 'loss_factor should be larger than zero! This error could ' - 'happened when gradient accumulation context enabled with an ' - 'error `cur_iter` or `max_iters` please check your loop') - - loss = loss / loss_factor + loss = self.scale_loss(loss) self.backward(loss) - # Update parameters only if `self.cur_iter` is divisible by - # `self.accumulative_iters` or `self.cur_iter` equals to - # `self.max_iters` - if self._should_update(self.cur_iter, self.max_iters): + # Update parameters only if `self._inner_count` is divisible by + # `self._accumulative_counts` or `self._inner_count` equals to + # `self._max_counts` + if self.should_update(): self.step() self.zero_grad() @@ -145,10 +148,15 @@ class OptimWrapper: required logic. For example, ``torch.cuda.amp`` require some extra operation on GradScaler during backward process. + Note: + If subclasses inherit from ``OptimWrapper`` override + ``backward``, ``_inner_count +=1`` must be implemented. + Args: loss (torch.Tensor): The loss of current iteration. """ loss.backward() + self._inner_count += 1 def zero_grad(self) -> None: """A wrapper of ``Optimizer.zero_grad``. @@ -218,7 +226,7 @@ class OptimWrapper: Provide unified interface to get learning rate of optimizer. Returns: - List[float]: Learning rate of the optimizer. + Dict[str, List[float]]: Learning rate of the optimizer. """ lr = [group['lr'] for group in self.param_groups] return dict(lr=lr) @@ -229,7 +237,7 @@ class OptimWrapper: Provide unified interface to get momentum of optimizer. Returns: - List[float]: Momentum of the optimizer. + Dict[str, List[float]]: Momentum of the optimizer. """ momentum = [] for group in self.param_groups: @@ -244,49 +252,35 @@ class OptimWrapper: return dict(momentum=momentum) @contextmanager - def accumulate_grad(self, model: nn.Module, cur_iter: int, max_iters: int): - """A Context manager for gradient accumulation and avoiding unnecessary - gradient synchronization during gradient accumulation. + def optim_context(self, model: nn.Module): + """A Context for gradient accumulation and automatic mix precision + training. + + If subclasses need to enable the context for mix precision training, + e.g., ``:class:`AmpOptimWrapper``, the corresponding context should be + enabled in `optim_context`. Since ``OptimWrapper`` uses default fp32 + training, ``optim_context`` will only enable the context for + blocking the unnecessary gradient synchronization during gradient + accumulation If model is an instance with ``no_sync`` method (which means blocking the gradient synchronization) and - ``self.accumulative_iters != 1``. The model will not automatically + ``self._accumulative_counts != 1``. The model will not automatically synchronize gradients if ``cur_iter`` is divisible by - ``self.accumulative_iters``. Otherwise, this method will enable an + ``self._accumulative_counts``. Otherwise, this method will enable an empty context. - Warnings: - This context manager must be enabled if you want to use - gradient accumulation. - Args: model (nn.Module): The training model. - cur_iter (int): Current iteration during training process. - max_iters (int): Maximum training iteration. 
""" - assert max_iters > 0, '`max_iters` must be larger than zero' - self.cur_iter = cur_iter - self.max_iters = max_iters - if not self.iter_status_initialized: - self._initilize_iter_status(model) # During gradient accumulation process, the gradient synchronize # should only happen before updating parameters. - if (not self._should_update(cur_iter, max_iters) - and hasattr(model, 'no_sync')): + if not self.should_sync() and hasattr(model, 'no_sync'): with model.no_sync(): yield else: yield - @contextmanager - def precision_context(self): - """precision context which enables an empty context by default. - - The subclass used for mixed or low precision training needs to override - this method. - """ - yield - def _clip_grad(self) -> None: """Clip the gradients of parameters.""" params: List[torch.Tensor] = [] @@ -300,50 +294,117 @@ class OptimWrapper: **self.clip_grad_kwargs) self.message_hub.update_scalar('train/grad_norm', float(grad_norm)) - def _initilize_iter_status(self, model: nn.Module) -> None: + def initialize_count_status(self, model: nn.Module, init_counts: int, + max_counts: int) -> None: """Initialize gradient accumulation related attributes. + ``OptimWrapper`` can be used without calling + ``initialize_iter_status``. However, Consider the case of ``len( + dataloader) == 10``, and the ``accumulative_iter == 3``. Since 10 is + not divisible by 3, the last iteration will not trigger + ``optimizer.step()``, resulting in one less parameter updating. + Args: model (nn.Module): Training model + init_counts (int): The initial value of the inner count. + max_counts (int): The maximum value of the inner count. """ - if self.max_iters % self.accumulative_iters != 0: + self._inner_count = init_counts + self._max_counts = max_counts + if self._inner_count % self._accumulative_counts != 0: self.logger.warning( - 'Resume iter number is not divisible by accumulative_iters in ' - 'GradientCumulativeOptimizerHook, which means the gradient of ' - 'some iters is lost and the result may be influenced slightly.' - ) + 'Resumed iteration number is not divisible by ' + '`_accumulative_counts` in `GradientCumulativeOptimizerHook`, ' + 'which means the gradient of some iterations is lost and the ' + 'result may be influenced slightly.') - if has_batch_norm(model) and self.accumulative_iters > 1: + if has_batch_norm(model) and self._accumulative_counts > 1: self.logger.warning( 'Gradient accumulative may slightly decrease ' 'performance because the model has BatchNorm layers.') - residual_iters = self.max_iters - self.cur_iter + residual_counts = max_counts - init_counts # The maximum number of training iteration that is divisible by - # accumulative_iters. - self.divisible_iters = ( - residual_iters // self.accumulative_iters * - self.accumulative_iters) - # Remainder of ``self.max_iters`` divided by ``self.max_iters`` - self.remainder_iters = residual_iters - self.divisible_iters - self.iter_status_initialized = True + # `_accumulative_counts`. + self._divisible_counts = ( + residual_counts // self._accumulative_counts * + self._accumulative_counts) + # Remainder of `_max_counts` divided by `_accumulative_counts` + self._remainder_counts = residual_counts - self._divisible_counts - def _should_update(self, cur_iter: int, max_iters: int) -> bool: - """Should optim_wrapper update parameters or synchronized gradient at - current iteration. + def should_update(self) -> bool: + """Decide whether the parameters should be updated at the current + iteration. 
- Args: - cur_iter (int): Current iteration of training process. - max_iters (int): Maximum iterations of training process. + Called by :meth:`update_params` and check whether the optimizer + wrapper should update parameters at current iteration. Returns: - bool: Whether to update parameters or synchronized gradient. + bool: Whether to update parameters. """ - return ((cur_iter + 1) % self.accumulative_iters == 0 - or cur_iter + 1 == max_iters) + return (self._inner_count % self._accumulative_counts == 0 + or self._inner_count == self._max_counts) + + def should_sync(self) -> bool: + """Decide whether the automatic gradient synchronization should be + allowed at the current iteration. + + It takes effect when gradient accumulation is used to skip + synchronization at the iterations where the parameter is not updated. + + Since ``should_sync`` is called by :meth:`optim_context`, and it is + called before :meth:`backward` which means ``self._inner_count += 1`` + has not happened yet. Therefore, ``self._inner_count += 1`` should be + performed manually here. + + Returns: + bool: Whether to block the automatic gradient synchronization. + """ + return ((self._inner_count + 1) % self._accumulative_counts == 0 + or (self._inner_count + 1) == self._max_counts) + + def scale_loss(self, loss: torch.Tensor) -> torch.Tensor: + """Get scaled loss according to ``_accumulative_counts``, + ``_inner_count`` and max_counts. + + Args: + loss (torch.Tensor): Original loss calculated by model. + + Returns: + loss (torch.Tensor): Scaled loss. + """ + if self._accumulative_counts == 1: + # update parameters without gradient accumulation. The gradient + # should not be rescaled and `loss_factor=1`. + loss_factor = 1 + elif self._max_counts == -1: + loss_factor = self._accumulative_counts + else: + # if `self._accumulative_counts > 1`, the gradient needs to be + # rescaled and accumulated. In most cases, `loss_factor` equals to + # `self._accumulative_counts`. However, `self._max_counts` may not + # be divisible by `self._accumulative_counts`, so the + # `loss_scale` for the last few iterations needs to be + # recalculated. + if self._inner_count < self._divisible_counts: + loss_factor = self._accumulative_counts + else: + loss_factor = self._remainder_counts + assert loss_factor > 0, ( + 'loss_factor should be larger than zero! This error could ' + 'happened when initialize_iter_status called with an ' + 'error `init_counts` or `max_counts`') + + loss = loss / loss_factor + return loss + + @property + def inner_count(self): + """Get the number of updating parameters of optimizer wrapper.""" + return self._inner_count def __repr__(self): - wrapper_info = f'Type: {type(self).__name__}\n' \ - f'accumulative_iters: {self.accumulative_iters}\n' \ - f'optimizer: \n' + wrapper_info = (f'Type: {type(self).__name__}\n' + f'_accumulative_counts: {self._accumulative_counts}\n' + 'optimizer: \n') optimizer_str = repr(self.optimizer) + '\n' return wrapper_info + optimizer_str diff --git a/mmengine/optim/optimizer/optimizer_wrapper_dict.py b/mmengine/optim/optimizer/optimizer_wrapper_dict.py index 98293e9d..7b3859cf 100644 --- a/mmengine/optim/optimizer/optimizer_wrapper_dict.py +++ b/mmengine/optim/optimizer/optimizer_wrapper_dict.py @@ -1,6 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-import warnings -from contextlib import ExitStack, contextmanager +from contextlib import contextmanager from typing import Dict, Iterator, List, Tuple import torch @@ -41,24 +40,10 @@ class OptimWrapperDict(OptimWrapper): """ def __init__(self, **optim_wrapper_dict: OptimWrapper): - first_key = next(iter(optim_wrapper_dict)) - first_optim_wrapper = optim_wrapper_dict[first_key] - assert isinstance(first_optim_wrapper, OptimWrapper), ( - 'Each argument of `OptimWrapperDict` must be an `OptimWrapper ' - 'instance`') - optim_wrapper_class = type(first_optim_wrapper) for key, value in optim_wrapper_dict.items(): - assert type(value) == optim_wrapper_class, ( - f'All optimizer wrappers should have the same type, but found' - f' {key}: {type(value)} and {first_key}: {optim_wrapper_class}' - ) - if value.accumulative_iters != 1: - warnings.warn( - f'The `accumulative_iters` of {key} is ' - f'{value.accumulative_iters}. OptimWrapperDict ' - 'will not enable any `accumulate_grad` context of its ' - 'optimizer wrappers. You should access the corresponding ' - 'optimizer wrapper to enable the context.') + assert isinstance(value, OptimWrapper), ( + '`OptimWrapperDict` only accept OptimWrapper instance, ' + f'but got {key}: {type(value)}') self.optim_wrappers = optim_wrapper_dict def update_params(self, loss: torch.Tensor) -> None: @@ -69,9 +54,8 @@ class OptimWrapperDict(OptimWrapper): Therefore, this method is not implemented. The optimizer wrapper of OptimWrapperDict should be accessed and call its `update_params. """ - raise NotImplementedError( - 'You should access the OptimWrapper of the ' - 'OptimWrapperDict and call its `update_params`') + raise NotImplementedError('`update_params` should be called by each ' + 'optimizer separately`') def backward(self, loss: torch.Tensor) -> None: """Since OptimWrapperDict doesn't know which optimizer wrapper's @@ -81,14 +65,14 @@ class OptimWrapperDict(OptimWrapper): The optimizer wrapper of OptimWrapperDict should be accessed and call its `backward. """ - raise NotImplementedError('You should access the OptimWrapper of the ' - 'OptimWrapperDict and call its `backward`') + raise NotImplementedError('`backward` should be called by each ' + 'optimizer separately`') def step(self) -> None: """Since the backward method is not implemented, the step should not be implemented either.""" - raise NotImplementedError('You should access the OptimWrapper of the ' - 'OptimWrapperDict and call its `step`') + raise NotImplementedError('`step` should be called by each ' + 'optimizer separately`') def zero_grad(self) -> None: """Set the gradients of all optimizer wrappers to zero.""" @@ -96,49 +80,29 @@ class OptimWrapperDict(OptimWrapper): optim_wrapper.zero_grad() @contextmanager - def precision_context(self): - optim_wrapper = next(iter(self.optim_wrappers.values())) - with optim_wrapper.precision_context(): - yield + def optim_context(self, model: nn.Module): + """``optim_context`` should be called by each optimizer separately.""" + raise NotImplementedError( + '`optim_context` should be called by each optimizer separately') - @contextmanager - def accumulate_grad(self, model: nn.Module, cur_iter: int, max_iters: int): - """Enable ``accumulate_grad`` contexts of all optimizer wrappers. 
+ def initialize_count_status(self, model: nn.Module, cur_iter, + max_iters) -> None: + """Do nothing but provide unified interface for :obj:`OptimWrapper` - Warning: - Consider there is only one ``model`` arguments for all - optimizer wrappers, all optimizer wrappers are working under the - same ``model.no_sync`` context. For example, there is a model - composed of model_a(optimizer_a) and model_b(optimizer_b). - ``OptimWrapperDict.accumulate_grad`` will further - call ``model.no_sync``, which will block the gradient - synchronization of both a and b. If optimizer_a and - optimizer_b have different ``accumulative_iters``, and want to - block the gradient synchronization of model_a and model_b - separately, the model should not implement the ``no_sync`` - method(or enable an empty context). The ``accumulate_grad`` context - should be enabled inside the model by accessing corresponding - optimizer wrapper. + Since ``OptimWrapperDict`` does not know the correspondence between + model and optimizer wrapper. ``initialize_iter_status`` will do nothing + and each optimizer wrapper should call ``initialize_iter_status`` + separately. """ - with ExitStack() as stack: - for optim_wrapper in self.optim_wrappers.values(): - stack.enter_context( - optim_wrapper.accumulate_grad(model, cur_iter, max_iters)) - yield + return - def load_state_dict(self, state_dict: dict) -> None: - """Load the state dictionary from the ``state_dict``. - - Args: - state_dict (dict): Each key-value pair in `state_dict` represents - the name and the state dictionary of corresponding - :obj:`OptimWrapper`. - """ - for name, _state_dict in state_dict.items(): - assert name in self.optim_wrappers, ( - f'Mismatched `state_dict`! cannot found {name} in ' - 'OptimWrapperDict') - self.optim_wrappers[name].load_state_dict(_state_dict) + @property + def param_groups(self): + """Returns the parameter groups of each OptimWrapper.""" + param_groups = dict() + for key, value in self.optim_wrappers.items(): + param_groups[key] = value.param_groups + return param_groups def get_lr(self) -> Dict[str, List[float]]: """Get the learning rate of all optimizers. @@ -175,6 +139,20 @@ class OptimWrapperDict(OptimWrapper): state_dict[name] = optim_wrapper.state_dict() return state_dict + def load_state_dict(self, state_dict: dict) -> None: + """Load the state dictionary from the ``state_dict``. + + Args: + state_dict (dict): Each key-value pair in `state_dict` represents + the name and the state dictionary of corresponding + :obj:`OptimWrapper`. + """ + for name, _state_dict in state_dict.items(): + assert name in self.optim_wrappers, ( + f'Mismatched `state_dict`! cannot found {name} in ' + 'OptimWrapperDict') + self.optim_wrappers[name].load_state_dict(_state_dict) + def items(self) -> Iterator[Tuple[str, OptimWrapper]]: """A generator to get the name and corresponding :obj:`OptimWrapper`""" diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py index 6e186855..34546cc2 100644 --- a/mmengine/runner/loops.py +++ b/mmengine/runner/loops.py @@ -101,12 +101,9 @@ class EpochBasedTrainLoop(BaseLoop): 'before_train_iter', batch_idx=idx, data_batch=data_batch) # Enable gradient accumulation mode and avoid unnecessary gradient # synchronization during gradient accumulation process. - with self.runner.optim_wrapper.accumulate_grad(self.runner.model, - self._iter, - self._max_iters): - # outputs should be a dict of loss. 
- outputs = self.runner.model.train_step( - data_batch, optim_wrapper=self.runner.optim_wrapper) + # outputs should be a dict of loss. + outputs = self.runner.model.train_step( + data_batch, optim_wrapper=self.runner.optim_wrapper) self.runner.call_hook( 'after_train_iter', @@ -204,19 +201,16 @@ class IterBasedTrainLoop(BaseLoop): 'before_train_iter', batch_idx=self._iter, data_batch=data_batch) # Enable gradient accumulation mode and avoid unnecessary gradient # synchronization during gradient accumulation process. - with self.runner.optim_wrapper.accumulate_grad(self.runner.model, - self._iter, - self._max_iters): - # train_logs should be a dict of loss. - train_logs = self.runner.model.train_step( - data_batch, optim_wrapper=self.runner.optim_wrapper) - self.runner.message_hub.update_info('train_logs', train_logs) + # outputs should be a dict of loss. + outputs = self.runner.model.train_step( + data_batch, optim_wrapper=self.runner.optim_wrapper) + self.runner.message_hub.update_info('train_logs', outputs) self.runner.call_hook( 'after_train_iter', batch_idx=self._iter, data_batch=data_batch, - outputs=train_logs) + outputs=outputs) self._iter += 1 diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index 000b1c7e..0d829e13 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -926,7 +926,7 @@ class Runner: >>> optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg) >>> optim_wrapper Type: OptimWrapper - accumulative_iters: 1 + accumulative_counts: 1 optimizer: SGD ( Parameter Group 0 @@ -941,7 +941,7 @@ class Runner: >>> optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg) >>> optim_wrapper Type: OptimWrapper - accumulative_iters: 1 + accumulative_counts: 1 optimizer: SGD ( Parameter Group 0 @@ -965,7 +965,7 @@ class Runner: >>> optim_wrapper name: generator Type: OptimWrapper - accumulative_iters: 1 + accumulative_counts: 1 optimizer: SGD ( Parameter Group 0 @@ -977,7 +977,7 @@ class Runner: ) name: discriminator Type: OptimWrapper - accumulative_iters: 1 + accumulative_counts: 1 optimizer: 'discriminator': Adam ( Parameter Group 0 @@ -1528,7 +1528,6 @@ class Runner: # `build_optimizer` should be called before `build_param_scheduler` # because the latter depends on the former self.optim_wrapper = self.build_optim_wrapper(self.optim_wrapper) - # Automatically scaling lr by linear scaling rule self.scale_lr(self.optim_wrapper, self.auto_scale_lr) @@ -1540,9 +1539,14 @@ class Runner: self._val_loop = self.build_val_loop( self._val_loop) # type: ignore - # TODO: add a contextmanager to avoid calling `before_run` many times self.call_hook('before_run') + # Initiate inner count of `optim_wrapper`. 
+ self.optim_wrapper.initialize_count_status( + self.model, + self._train_loop.iter, # type: ignore + self._train_loop.max_iters) # type: ignore + # TODO: add a contextmanager to avoid calling `before_run` many times # make sure checkpoint-related hooks are triggered after `before_run` self.load_or_resume() diff --git a/tests/test_model/test_wrappers/test_model_wrapper.py b/tests/test_model/test_wrappers/test_model_wrapper.py index 2a7e74be..0efe338b 100644 --- a/tests/test_model/test_wrappers/test_model_wrapper.py +++ b/tests/test_model/test_wrappers/test_model_wrapper.py @@ -8,9 +8,10 @@ import torch.distributed as torch_dist import torch.nn as nn from torch.optim import SGD +from mmengine.dist import all_gather from mmengine.model import (BaseModel, MMDistributedDataParallel, MMSeparateDistributedDataParallel) -from mmengine.optim import OptimWrapper, OptimWrapperDict +from mmengine.optim import AmpOptimWrapper, OptimWrapper, OptimWrapperDict from mmengine.testing import assert_allclose from mmengine.testing._internal import MultiProcessTestCase @@ -23,9 +24,9 @@ class ToyModel(BaseModel): self.conv2 = nn.Conv2d(1, 1, 1) def forward(self, x, data_samples=None, mode='tensor'): + x = self.conv1(x) + x = self.conv2(x) if mode == 'loss': - x = self.conv1(x) - x = self.conv2(x) return dict(loss=x) elif mode == 'predict': return x @@ -58,24 +59,41 @@ class ComplexModel(BaseModel): pass -class TestModelWrapper(MultiProcessTestCase): +class TestDistributedDataParallel(MultiProcessTestCase): def setUp(self) -> None: super().setUp() self._spawn_processes() + @unittest.skipIf( + not torch.cuda.is_available(), reason='cuda should be available') def test_train_step(self): self._init_dist_env(self.rank, self.world_size) - # Test `optim_wrapper` is a instance of `OptimWrapper` - model = ToyModel() + # Mixed precision training and gradient asynchronous should be valid at + # the same time + model = ToyModel().cuda() ddp_model = MMDistributedDataParallel(module=model) optimizer = SGD(ddp_model.parameters(), lr=0) - optim_wrapper = OptimWrapper(optimizer, accumulative_iters=1) - inputs = torch.randn(3, 1, 1) * self.rank * 255 + optim_wrapper = AmpOptimWrapper( + optimizer=optimizer, accumulative_counts=3) + inputs = torch.randn(3, 1, 1).cuda() * self.rank * 255 data = dict(inputs=inputs, data_sample=MagicMock()) + res = ddp_model.train_step([data], optim_wrapper=optim_wrapper)['loss'] + self.assertIs(res.dtype, torch.float16) + grad = ddp_model.module.conv1.weight.grad + all_grads = all_gather(grad) + with self.assertRaises(AssertionError): + assert_allclose(all_grads[0], all_grads[1]) + + # Gradient accumulation + ddp_model.train_step([data], optim_wrapper=optim_wrapper) + + # Test update params and clean grads. 
ddp_model.train_step([data], optim_wrapper=optim_wrapper) grad = ddp_model.module.conv1.weight.grad - assert_allclose(grad, torch.zeros_like(grad)) + all_grads = all_gather(grad) + assert_allclose(all_grads[0], torch.zeros_like(all_grads[0])) + assert_allclose(all_grads[1], torch.zeros_like(all_grads[0])) def test_val_step(self): self._init_dist_env(self.rank, self.world_size) @@ -107,7 +125,7 @@ class TestModelWrapper(MultiProcessTestCase): @unittest.skipIf( not torch.cuda.is_available(), reason='cuda should be available') -class TestMMSeparateDistributedDataParallel(TestModelWrapper): +class TestMMSeparateDistributedDataParallel(TestDistributedDataParallel): def test_train_step(self): self._init_dist_env(self.rank, self.world_size) diff --git a/tests/test_optim/test_optimizer/test_optimizer_wrapper.py b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py index 22a60a20..e65ac448 100644 --- a/tests/test_optim/test_optimizer/test_optimizer_wrapper.py +++ b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py @@ -45,7 +45,7 @@ class ToyModel2(nn.Module): class TestOptimWrapper(MultiProcessTestCase): - # Test `OptimWrapper.accumulate_grad` will block the gradient + # Test `OptimWrapper.optim_context` will block the gradient # synchronization when using gradient accumulation strategy in distributed # data parallel training. def setUp(self) -> None: @@ -61,94 +61,84 @@ class TestOptimWrapper(MultiProcessTestCase): def test_init(self): optim_wrapper = OptimWrapper(self.optimizer) - self.assertEqual(optim_wrapper.optimizer, self.optimizer) + self.assertIs(optim_wrapper.optimizer, self.optimizer) self.assertIsNone(optim_wrapper.clip_grad_kwargs) - self.assertEqual(optim_wrapper.accumulative_iters, 1) + self.assertEqual(optim_wrapper._accumulative_counts, 1) self.assertIs(optim_wrapper.logger, self.logger) self.assertIs(optim_wrapper.message_hub, self.message_hub) + self.assertEqual(optim_wrapper._inner_count, 0) + self.assertEqual(optim_wrapper._max_counts, -1) + self.assertEqual(optim_wrapper._divisible_counts, -1) + self.assertEqual(optim_wrapper._remainder_counts, -1) with self.assertRaisesRegex(AssertionError, - 'If `clip_grad_kwargs` is not None'): + 'If `clip_grad` is not None'): OptimWrapper(self.optimizer, clip_grad=[]) def test_update_params(self): # Test update params every iteration. - optim_wrapper = OptimWrapper(self.optimizer, accumulative_iters=1) + optim_wrapper = OptimWrapper(self.optimizer, accumulative_counts=1) self._mock_method(optim_wrapper) loss = torch.tensor(1) optim_wrapper.update_params(loss) - optim_wrapper.backward.assert_called_with(torch.tensor(1)) + self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1)) optim_wrapper.step.assert_called_with() optim_wrapper.zero_grad.assert_called_with() - with optim_wrapper.accumulate_grad(self.model, 2, 100): - optim_wrapper.update_params(torch.tensor(1)) - optim_wrapper.backward.assert_called_with(torch.tensor(1)) - optim_wrapper.step.assert_called_with() - optim_wrapper.zero_grad.assert_called_with() - - # It will raise an error if `accumulative_iters > 1` and - # `accumulate_grad` is not enabled. - optim_wrapper = OptimWrapper(self.optimizer, accumulative_iters=3) + # Test gradient accumulation. + optim_wrapper = OptimWrapper(self.optimizer, accumulative_counts=3) self._mock_method(optim_wrapper) - with self.assertRaisesRegex(AssertionError, - 'gradient accumulation must be'): - optim_wrapper.update_params(loss) + # `iter=0`, accumulate gradient and do not update params. 
+ loss = torch.tensor(1) + optim_wrapper.update_params(loss) + self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1) / 3) + optim_wrapper.step.assert_not_called() + optim_wrapper.zero_grad.assert_not_called() - # `iter=0`, Call `optimizer_step` first time. - with optim_wrapper.accumulate_grad( - self.model, cur_iter=0, max_iters=100): - loss = torch.tensor(1) - optim_wrapper.update_params(loss) - optim_wrapper.backward.assert_called_with(torch.tensor(1) / 3) - optim_wrapper.step.assert_not_called() - optim_wrapper.zero_grad.assert_not_called() + # gradient accumulate + optim_wrapper.update_params(loss) + self.assertEqual(optim_wrapper._inner_count, 2) - # `iter=2`, Call `optimizer_step` first time. - with optim_wrapper.accumulate_grad( - self.model, cur_iter=2, max_iters=100): - optim_wrapper.update_params(loss) - optim_wrapper.step.assert_called() - optim_wrapper.zero_grad.assert_called() + # `iter=2`, update params. + optim_wrapper.update_params(loss) + optim_wrapper.step.assert_called() + optim_wrapper.zero_grad.assert_called() self._mock_method(optim_wrapper) - # Test end of training. - with optim_wrapper.accumulate_grad( - self.model, cur_iter=99, max_iters=100): - optim_wrapper.update_params(loss) - optim_wrapper.step.assert_called() - optim_wrapper.zero_grad.assert_called() - optim_wrapper.backward.assert_called_with(1) - # If ``accumulative_iters > 1``, call ``update_params`` with - # non-accumulate_grad context will raise an Assertion error - optim_wrapper = OptimWrapper(self.optimizer, accumulative_iters=1) - optim_wrapper.accumulative_iters = 2 - with self.assertRaisesRegex(AssertionError, - 'gradient accumulation must be performed'): - optim_wrapper.update_params(loss) + # Test end of training without calling `initialize_iter_status` + optim_wrapper._inner_count = 99 + optim_wrapper.update_params(loss) + optim_wrapper.step.assert_not_called() + optim_wrapper.zero_grad.assert_not_called() + self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1) / 3) + self._mock_method(optim_wrapper) - def test_initilize_iter_status(self): - optim_wrapper = OptimWrapper(self.optimizer, accumulative_iters=3) - optim_wrapper._initilize_iter_status(self.model) - self.assertEqual(optim_wrapper.divisible_iters, 0) - self.assertEqual(optim_wrapper.remainder_iters, 0) + # After calling `initialize_iter_status`, params will be updated at the + # last iteration, and the `loss_scaler` will be adjusted. + optim_wrapper.initialize_count_status(self.model, 99, 100) + optim_wrapper.update_params(loss) + optim_wrapper.step.assert_called() + optim_wrapper.zero_grad.assert_called() + self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1)) + + def test_initialize_iter_status(self): + optim_wrapper = OptimWrapper(self.optimizer, accumulative_counts=3) + optim_wrapper.initialize_count_status(self.model, 0, 100) + self.assertEqual(optim_wrapper._divisible_counts, 99) + self.assertEqual(optim_wrapper._remainder_counts, 1) # Indivisible cur_iter will output warning. 
- optim_wrapper = OptimWrapper(self.optimizer, accumulative_iters=3) - optim_wrapper.cur_iter = 0 - optim_wrapper.max_iters = 100 + optim_wrapper = OptimWrapper(self.optimizer, accumulative_counts=3) with self.assertLogs(self.logger) as cm: - optim_wrapper._initilize_iter_status(self.model) + optim_wrapper.initialize_count_status(self.model, 2, 100) self.assertEqual(len(cm.output), 1) - self.assertRegex(cm.records[0].msg, 'Resume iter number is not') + self.assertRegex(cm.records[0].msg, 'Resumed iteration number') # Model with batch norm will output warning. - optim_wrapper = OptimWrapper(self.optimizer, accumulative_iters=3) - optim_wrapper.cur_iter = 0 - optim_wrapper.max_iters = 99 + optim_wrapper = OptimWrapper(self.optimizer, accumulative_counts=3) model = nn.BatchNorm2d(1) with self.assertLogs(self.logger) as cm: - optim_wrapper._initilize_iter_status(model) + optim_wrapper.initialize_count_status(model, 0, 99) self.assertEqual(len(cm.output), 1) self.assertRegex(cm.records[0].msg, 'Gradient accumulative') @@ -214,40 +204,40 @@ class TestOptimWrapper(MultiProcessTestCase): self.assertEqual(optim_wrapper.param_groups, self.optimizer.param_groups) - def test_accumulate_grad(self): + def test_optim_context(self): self._init_dist_env(self.rank, self.world_size) model = ToyModel2() ddp_model = DistributedDataParallel(model) optimizer = SGD(ddp_model.parameters(), lr=0.01) - optim_wrapper = OptimWrapper(optimizer, accumulative_iters=1) + optim_wrapper = OptimWrapper(optimizer, accumulative_counts=1) optim_wrapper.zero_grad() - with optim_wrapper.accumulate_grad(ddp_model, 0, 100): - # Automatically sync grads if `accumulative_iters` = 1 - inputs = torch.randn(1, 1, 1, 1) * self.rank - ddp_model(inputs).sum().backward() - grad = model.conv.weight.grad - all_grads = all_gather(grad) - assert_allclose(all_grads[0], all_grads[1]) + + # Automatically sync grads if `accumulative_counts` = 1 + optim_wrapper.initialize_count_status(model, 0, 100) + inputs = torch.randn(1, 1, 1, 1) * self.rank + ddp_model(inputs).sum().backward() + grad = model.conv.weight.grad + all_grads = all_gather(grad) + assert_allclose(all_grads[0], all_grads[1]) # Do not sync grads when `optim_wrapper.cur_iter` cannot be - # divided by `optim_wrapper.accumulative_iters` - optim_wrapper = OptimWrapper(optimizer, accumulative_iters=3) - with optim_wrapper.accumulate_grad(ddp_model, 0, 100): - ddp_model(inputs).sum().backward() - all_grads = all_gather(model.conv.weight.grad) - with self.assertRaises(AssertionError): - assert_allclose(all_grads[0], all_grads[1]) - - # sync grads if `cur_iter == 2` - with optim_wrapper.accumulate_grad(ddp_model, 2, 100): - ddp_model(inputs).sum().backward() - all_grads = all_gather(model.conv.weight.grad) + # divided by `optim_wrapper._accumulative_counts` + optim_wrapper = OptimWrapper(optimizer, accumulative_counts=3) + optim_wrapper.initialize_count_status(model, 0, 100) + with optim_wrapper.optim_context(ddp_model): + loss = ddp_model(inputs).sum() + loss.backward() + all_grads = all_gather(model.conv.weight.grad) + with self.assertRaises(AssertionError): assert_allclose(all_grads[0], all_grads[1]) - def test_precision_context(self): - optim_wrapper = OptimWrapper(self.optimizer) - with optim_wrapper.precision_context(): - pass + # sync grads if `cur_iter == 2` + optim_wrapper.initialize_count_status(model, 2, 100) + with optim_wrapper.optim_context(ddp_model): + loss = ddp_model(inputs).sum() + loss.backward() + all_grads = all_gather(model.conv.weight.grad) + 
assert_allclose(all_grads[0], all_grads[1]) def _init_dist_env(self, rank, world_size): """Initialize the distributed environment.""" @@ -260,7 +250,12 @@ class TestOptimWrapper(MultiProcessTestCase): # TODO Test the real interface after add testing tool function which can # test the function or method is read called. def _mock_method(self, optim_wrapper): - optim_wrapper.backward = MagicMock() + + def mock_methd(loss): + optim_wrapper._inner_count += 1 + optim_wrapper.scaled_loss = loss + + optim_wrapper.backward = mock_methd optim_wrapper.step = MagicMock() optim_wrapper.zero_grad = MagicMock() @@ -376,9 +371,9 @@ class TestAmpOptimWrapper(TestCase): and (digit_version(TORCH_VERSION) >= digit_version('1.6.0')), reason='`torch.cuda.amp` is only available when pytorch-gpu version ' '>= 1.6') - def test_precision_context(self): + def test_optim_context(self): amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer) - with amp_optim_wrapper.precision_context(): + with amp_optim_wrapper.optim_context(self.model): x = torch.randn(1, 1, 1, 1).cuda() y = nn.Conv2d(1, 1, 1).cuda()(x) self.assertEqual(y.dtype, torch.float16) diff --git a/tests/test_optim/test_optimizer/test_optimizer_wrapper_dict.py b/tests/test_optim/test_optimizer/test_optimizer_wrapper_dict.py index f3259e06..b5dd2c42 100644 --- a/tests/test_optim/test_optimizer/test_optimizer_wrapper_dict.py +++ b/tests/test_optim/test_optimizer/test_optimizer_wrapper_dict.py @@ -1,69 +1,92 @@ # Copyright (c) OpenMMLab. All rights reserved. -from contextlib import contextmanager from unittest import TestCase -from unittest.mock import patch +import torch import torch.nn as nn from torch.optim import SGD -from mmengine.optim import AmpOptimWrapper, OptimWrapper, OptimWrapperDict +from mmengine.optim import OptimWrapper, OptimWrapperDict class TestOptimWrapperDict(TestCase): def setUp(self) -> None: - model1 = nn.Linear(1, 1) - model2 = nn.Linear(1, 1) - self.optim1 = SGD(model1.parameters(), lr=0.1, momentum=0.8) - self.optim2 = SGD(model2.parameters(), lr=0.2, momentum=0.9) + self.model1 = nn.Linear(1, 1) + self.model2 = nn.Linear(1, 1) + self.optim1 = SGD(self.model1.parameters(), lr=0.1, momentum=0.8) + self.optim2 = SGD(self.model2.parameters(), lr=0.2, momentum=0.9) self.optim_wrapper1 = OptimWrapper(self.optim1) self.optim_wrapper2 = OptimWrapper(self.optim2) self.optimizers_wrappers = dict( optim1=self.optim_wrapper1, optim2=self.optim_wrapper2) - @patch('torch.cuda.is_available', lambda: True) def test_init(self): optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) self.assertEqual(optim_wrapper_dict.optim_wrappers, self.optimizers_wrappers) - # Different types of OptimWrapper will raise an error + with self.assertRaisesRegex(AssertionError, + '`OptimWrapperDict` only accept'): + OptimWrapperDict(**dict(optim1=self.optim1, optim2=self.optim2)) - with self.assertRaisesRegex( - AssertionError, 'All optimizer wrappers should have the same'): - optim_wrapper2 = AmpOptimWrapper(optimizer=self.optim2) - OptimWrapperDict(optim1=self.optim_wrapper1, optim2=optim_wrapper2) - - with self.assertWarnsRegex(UserWarning, 'The `accumulative_iters` of'): - optim_wrapper2 = OptimWrapper( - optimizer=self.optim2, accumulative_iters=2) - OptimWrapperDict(optim1=self.optim_wrapper1, optim2=optim_wrapper2) - - def test_accumulate_grad(self): - - @contextmanager - def context_a(a, b, *args, **kwargs): - a[0] = 100 - yield - a[0] = 1 - - @contextmanager - def context_b(a, b, *args, **kwargs): - b[0] = 200 - yield - b[0] = 2 - - a = [0] - b 
= [0] - # Test enter the context both of `optim_wrapper1` and `optim_wrapper1`. + def test_update_params(self): optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) - self.optim_wrapper1.accumulate_grad = context_a - self.optim_wrapper2.accumulate_grad = context_b - with optim_wrapper_dict.accumulate_grad(a, b, 0): - self.assertEqual(a[0], 100) - self.assertEqual(b[0], 200) + with self.assertRaisesRegex(NotImplementedError, + '`update_params` should be called'): + optim_wrapper_dict.update_params(1) - self.assertEqual(a[0], 1) - self.assertEqual(b[0], 2) + def test_backward(self): + optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) + with self.assertRaisesRegex(NotImplementedError, + '`backward` should be called'): + optim_wrapper_dict.backward(1) + + def test_step(self): + optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) + with self.assertRaisesRegex(NotImplementedError, + '`step` should be called'): + optim_wrapper_dict.step() + + def test_zero_grad(self): + # Test clear all grad + optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) + self.model1(torch.randn(1, 1)).sum().backward() + self.model2(torch.randn(1, 1)).sum().backward() + self.assertTrue((self.model1.weight.grad != 0).any()) + self.assertTrue((self.model2.weight.grad != 0).any()) + optim_wrapper_dict.zero_grad() + self.assertTrue((self.model1.weight.grad == 0).all()) + self.assertTrue((self.model2.weight.grad == 0).all()) + + def test_optim_context(self): + optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) + with self.assertRaisesRegex(NotImplementedError, + '`optim_context` should be called'): + with optim_wrapper_dict.optim_context(self.model1): + yield + + def test_initialize_count_status(self): + # Test `initialize_count_status` can be called. 
+ optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) + optim_wrapper_dict.initialize_count_status(self.model1, 1, 1) + + def test_param_groups(self): + optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) + self.assertEqual(optim_wrapper_dict.param_groups['optim1'], + self.optim1.param_groups) + self.assertEqual(optim_wrapper_dict.param_groups['optim2'], + self.optim2.param_groups) + + def test_get_lr(self): + optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) + lr = optim_wrapper_dict.get_lr() + self.assertEqual(lr['optim1.lr'], [0.1]) + self.assertEqual(lr['optim2.lr'], [0.2]) + + def test_get_momentum(self): + optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) + momentum = optim_wrapper_dict.get_momentum() + self.assertEqual(momentum['optim1.momentum'], [0.8]) + self.assertEqual(momentum['optim2.momentum'], [0.9]) def test_state_dict(self): optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) @@ -109,18 +132,6 @@ class TestOptimWrapperDict(TestCase): list(optim_wrapper_dict.keys()), list(self.optimizers_wrappers.keys())) - def test_get_lr(self): - optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) - lr = optim_wrapper_dict.get_lr() - self.assertEqual(lr['optim1.lr'], [0.1]) - self.assertEqual(lr['optim2.lr'], [0.2]) - - def test_get_momentum(self): - optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) - momentum = optim_wrapper_dict.get_momentum() - self.assertEqual(momentum['optim1.momentum'], [0.8]) - self.assertEqual(momentum['optim2.momentum'], [0.9]) - def test_getitem(self): optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) self.assertIs(self.optimizers_wrappers['optim1'], diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py index a6650460..1097cffa 100644 --- a/tests/test_runner/test_runner.py +++ b/tests/test_runner/test_runner.py @@ -1135,8 +1135,9 @@ class TestRunner(TestCase): cfg.custom_hooks = [dict(type='TestEpochHook', priority=50)] cfg.train_cfg = dict(by_epoch=True, max_epochs=3, val_begin=2) runner = Runner.from_cfg(cfg) - runner.train() + self.assertEqual(runner.optim_wrapper._inner_count, 12) + self.assertEqual(runner.optim_wrapper._max_counts, 12) assert isinstance(runner.train_loop, EpochBasedTrainLoop) @@ -1183,6 +1184,8 @@ class TestRunner(TestCase): runner = Runner.from_cfg(cfg) runner.train() + self.assertEqual(runner.optim_wrapper._inner_count, 12) + self.assertEqual(runner.optim_wrapper._max_counts, 12) assert isinstance(runner.train_loop, IterBasedTrainLoop) self.assertEqual(len(epoch_results), 1)
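
Taken together, these changes rename `accumulative_iters` to `accumulative_counts`, fold the old `accumulate_grad`/`precision_context` managers into a single `optim_context` manager, and move the iteration bookkeeping into the wrapper itself via `_inner_count` and `initialize_count_status`. Below is a minimal usage sketch of the new interface, based only on the signatures shown in this patch; the toy model, optimizer settings, and iteration counts are illustrative and not part of the change.

```python
# Minimal usage sketch of the refactored optimizer-wrapper interface.
# The toy model, data, and iteration counts below are illustrative only.
import torch
import torch.nn as nn
from torch.optim import SGD

from mmengine.optim import OptimWrapper

model = nn.Linear(2, 1)
optimizer = SGD(model.parameters(), lr=0.01)

# Accumulate gradients over 3 iterations before each parameter update;
# `accumulative_counts` replaces the old `accumulative_iters` argument.
optim_wrapper = OptimWrapper(
    optimizer, accumulative_counts=3, clip_grad=dict(max_norm=0.2))

max_iters = 10  # assumed total number of training iterations
# Optional: lets the wrapper rescale the loss for the last few iterations
# when `max_iters` is not divisible by `accumulative_counts`. The runner now
# calls this automatically before training starts.
optim_wrapper.initialize_count_status(model, 0, max_iters)

for _ in range(max_iters):
    inputs = torch.randn(4, 2)
    # `optim_context` replaces `accumulate_grad`/`precision_context`. For a
    # DistributedDataParallel model it blocks gradient synchronization on the
    # iterations that do not update parameters.
    with optim_wrapper.optim_context(model):
        loss = model(inputs).sum()
    # Scales the loss, calls `backward`, and triggers `step`/`zero_grad` every
    # `accumulative_counts` iterations (or at the final iteration).
    optim_wrapper.update_params(loss)
```

With `AmpOptimWrapper` the same loop works unchanged: its `optim_context` wraps the parent context in `torch.cuda.amp.autocast`, and its `backward` keeps `_inner_count` in sync, so gradient accumulation and mixed precision compose without extra code in the training loop.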