[Refactor] Refactor the gradient accumulation implementation of OptimWrapper (#284)

* merge context

* update unit test

* add docstring

* fix bug in AmpOptimWrapper

* add docstring for backward

* add warning and docstring for gradient accumulation

* fix docstring

* fix docstring

* add param_groups method

* fix as comment

* fix as comment

* make the default value of loss_scale 'dynamic'

* Fix docstring

* decouple should update and should no sync

* rename attribute in OptimWrapper

* fix docstring

* fix comment

* fix comment

* fix as comment

* fix as comment and add unit test
Mashiro 2022-06-13 23:20:53 +08:00 committed by GitHub
parent fd295741ca
commit b7866021c4
12 changed files with 435 additions and 357 deletions


@ -111,8 +111,8 @@ class BaseModel(BaseModule):
Returns: Returns:
Dict[str, torch.Tensor]: A ``dict`` of tensor for logging. Dict[str, torch.Tensor]: A ``dict`` of tensor for logging.
""" """
# enable automatic mixed precision training context. # Enable automatic mixed precision training context.
with optim_wrapper.precision_context(): with optim_wrapper.optim_context(self):
batch_inputs, data_samples = self.data_preprocessor(data, True) batch_inputs, data_samples = self.data_preprocessor(data, True)
losses = self(batch_inputs, data_samples, mode='loss') losses = self(batch_inputs, data_samples, mode='loss')
parsed_losses, log_vars = self.parse_losses(losses) parsed_losses, log_vars = self.parse_losses(losses)
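
The hunk above replaces ``precision_context()`` with ``optim_context(self)``, so a single context now covers both mixed precision and, for DDP models, blocking gradient synchronization while gradients are still being accumulated. A rough sketch of how the surrounding ``train_step`` is expected to drive the new API; the full body of ``train_step`` is not shown in this hunk, so the placement of ``update_params`` and the variable names are assumptions rather than part of the commit:

def train_step(self, data, optim_wrapper):
    # optim_context enables autocast when optim_wrapper is an AmpOptimWrapper
    # and enters model.no_sync() when the gradient should not be synchronized.
    with optim_wrapper.optim_context(self):
        batch_inputs, data_samples = self.data_preprocessor(data, True)
        losses = self(batch_inputs, data_samples, mode='loss')
    parsed_losses, log_vars = self.parse_losses(losses)
    # Scales the loss, runs backward, and steps/zero_grads the optimizer
    # only once enough gradients have been accumulated.
    optim_wrapper.update_params(parsed_losses)
    return log_vars
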


@ -89,8 +89,8 @@ class MMDistributedDataParallel(DistributedDataParallel):
Returns: Returns:
Dict[str, torch.Tensor]: A ``dict`` of tensor for logging. Dict[str, torch.Tensor]: A ``dict`` of tensor for logging.
""" """
# enable automatic mixed precision training context. # Enable automatic mixed precision training context.
with optim_wrapper.precision_context(): with optim_wrapper.optim_context(self):
batch_inputs, data_samples = self.module.data_preprocessor( batch_inputs, data_samples = self.module.data_preprocessor(
data, training=True) data, training=True)
losses = self(batch_inputs, data_samples, mode='loss') losses = self(batch_inputs, data_samples, mode='loss')


@ -2,6 +2,7 @@
from contextlib import contextmanager from contextlib import contextmanager
import torch import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler from torch.cuda.amp import GradScaler
from mmengine.registry import OPTIM_WRAPPERS from mmengine.registry import OPTIM_WRAPPERS
@ -25,15 +26,21 @@ class AmpOptimWrapper(OptimWrapper):
loss_scale (float or str or dict): The initial configuration of loss_scale (float or str or dict): The initial configuration of
`torch.cuda.amp.GradScaler`. See more specific arguments `torch.cuda.amp.GradScaler`. See more specific arguments
introduction at `PyTorch AMP <https://pytorch.org/docs/stable/amp.html?highlight=gradscalertorch.cuda.amp.GradScaler>`_ # noqa: E501 introduction at `PyTorch AMP <https://pytorch.org/docs/stable/amp.html?highlight=gradscalertorch.cuda.amp.GradScaler>`_ # noqa: E501
Defaults to ``dynamic``.
- "dynamic": Initialize GradScale without any arguments. - "dynamic": Initialize GradScale without any arguments.
- float: Initialize GradScaler with ``init_scale``. - float: Initialize GradScaler with ``init_scale``.
- dict: Initialize GradScaler with more detail configuration. - dict: Initialize GradScaler with more detail configuration.
**kwargs: Keyword arguments passed to OptimWrapper. **kwargs: Keyword arguments passed to OptimWrapper.
Note:
If you use ``IterBasedRunner`` and enable gradient accumulation,
the original `max_iters` should be multiplied by
``accumulative_counts``.
""" """
def __init__(self, loss_scale=512., **kwargs): def __init__(self, loss_scale='dynamic', **kwargs):
assert digit_version(TORCH_VERSION) >= digit_version('1.6.0'), ( assert digit_version(TORCH_VERSION) >= digit_version('1.6.0'), (
'`torch.cuda.amp` is only available when pytorch version >= 1.6') '`torch.cuda.amp` is only available when pytorch version >= 1.6')
assert torch.cuda.is_available(), ( assert torch.cuda.is_available(), (
@ -62,6 +69,7 @@ class AmpOptimWrapper(OptimWrapper):
loss (torch.Tensor): The loss of current iteration. loss (torch.Tensor): The loss of current iteration.
""" """
self.loss_scaler.scale(loss).backward() self.loss_scaler.scale(loss).backward()
self._inner_count += 1
def step(self): def step(self):
"""Update parameters with :attr:`loss_scaler`.""" """Update parameters with :attr:`loss_scaler`."""
@ -104,7 +112,13 @@ class AmpOptimWrapper(OptimWrapper):
self.optimizer.load_state_dict(state_dict) self.optimizer.load_state_dict(state_dict)
@contextmanager @contextmanager
def precision_context(self): def optim_context(self, model: nn.Module):
"""A wrapper of ``torch.cuda.amp.autocast``""" """Enables the context for mixed precision training, and enables the
with torch.cuda.amp.autocast(): context for disabling gradient synchronization during gradient
accumulation context.
Args:
model (nn.Module): The training model.
"""
with super().optim_context(model), torch.cuda.amp.autocast():
yield yield
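
Two things change in ``AmpOptimWrapper`` here: the default ``loss_scale`` becomes ``'dynamic'``, and ``optim_context`` now stacks the parent context (the ``no_sync`` gating inherited from ``OptimWrapper``) with ``torch.cuda.amp.autocast``. A minimal sketch of how the three accepted ``loss_scale`` forms map onto ``GradScaler``, based on the docstring above; the helper name is illustrative and not part of the diff:

from torch.cuda.amp import GradScaler

def build_loss_scaler(loss_scale='dynamic'):
    if loss_scale == 'dynamic':
        # Let GradScaler choose and adapt the scale by itself.
        return GradScaler()
    elif isinstance(loss_scale, float):
        # A bare float is treated as the initial scale.
        return GradScaler(init_scale=loss_scale)
    elif isinstance(loss_scale, dict):
        # A dict is forwarded to GradScaler as keyword arguments.
        return GradScaler(**loss_scale)
    raise TypeError(
        f"loss_scale must be 'dynamic', a float or a dict, but got {loss_scale}")
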


@ -73,7 +73,7 @@ class DefaultOptimWrapperConstructor:
Optional fields are Optional fields are
- any arguments of the corresponding optimizer wrapper type, - any arguments of the corresponding optimizer wrapper type,
e.g., accumulative_iters, clip_grad, etc. e.g., accumulative_counts, clip_grad, etc.
The positional fields of ``optimizer`` are The positional fields of ``optimizer`` are


@ -27,21 +27,31 @@ class OptimWrapper:
Args: Args:
optimizer (Optimizer): Optimizer used to update model parameters. optimizer (Optimizer): Optimizer used to update model parameters.
accumulative_iters (int): The number of iterations to accumulate accumulative_counts (int): The number of iterations to accumulate
gradients. The parameters will be updated per gradients. The parameters will be updated per
``accumulative_iters``. ``accumulative_counts``.
clip_grad (dict, optional): If ``clip_grad`` is not None, it will be clip_grad (dict, optional): If ``clip_grad`` is not None, it will be
the arguments of ``torch.nn.utils.clip_grad``. the arguments of ``torch.nn.utils.clip_grad``.
Warnings: Note:
If ``accumulative_iters`` is larger than 1, :meth:`update_params` must If ``accumulative_counts`` is larger than 1, perform
be called in the context of ``accumulate_grad``. :meth:`update_params` under the context of ``optim_context``
could avoid unnecessary gradient synchronization.
Note:
If you use ``IterBasedRunner`` and enable gradient accumulation,
the original `max_iters` should be multiplied by
``accumulative_counts``.
Note:
The subclass should ensure that once :meth:`update_params` is called,
``_inner_count += 1`` is automatically performed.
Examples: Examples:
>>> # Config sample of OptimWrapper. >>> # Config sample of OptimWrapper.
>>> optim_wrapper_cfg = dict( >>> optim_wrapper_cfg = dict(
>>> type='OptimWrapper', >>> type='OptimWrapper',
>>> accumulative_iters=3, >>> _accumulative_counts=1,
>>> clip_grad=dict(max_norm=0.2)) >>> clip_grad=dict(max_norm=0.2))
>>> # Use OptimWrapper to update model. >>> # Use OptimWrapper to update model.
>>> import torch.nn as nn >>> import torch.nn as nn
@ -59,29 +69,32 @@ class OptimWrapper:
>>> for data in dataloader: >>> for data in dataloader:
>>> loss = model(data) >>> loss = model(data)
>>> optim_wrapper.update_params(loss) >>> optim_wrapper.update_params(loss)
>>> # Enable gradient accumulation. If model is a subclass instance of >>> # Enable gradient accumulation
>>> # DistributedDataParallel, ``accumulate_grad`` context manager can >>> optim_wrapper_cfg = dict(
>>> # avoid unnecessary gradient synchronize. >>> type='OptimWrapper',
>>> _accumulative_counts=3,
>>> clip_grad=dict(max_norm=0.2))
>>> ddp_model = DistributedDataParallel(model)
>>> optimizer = SGD(ddp_model.parameters(), lr=0.1)
>>> optim_wrapper = OptimWrapper(optimizer)
>>> optim_wrapper.initialize_count_status(0, len(dataloader))
>>> # If model is a subclass instance of DistributedDataParallel,
>>> # `optim_context` context manager can avoid unnecessary gradient
>>> # synchronize.
>>> for iter, data in enumerate(dataloader): >>> for iter, data in enumerate(dataloader):
>>> with optim_wrapper.accumulate_grad( >>> with optim_wrapper.optim_context(ddp_model):
>>> model, iter, len(dataloader)):
>>> loss = model(data) >>> loss = model(data)
>>> optim_wrapper.update_params(loss) >>> optim_wrapper.update_params(loss)
""" """
def __init__(self, def __init__(self,
optimizer: Optimizer, optimizer: Optimizer,
accumulative_iters: int = 1, accumulative_counts: int = 1,
clip_grad: Optional[dict] = None): clip_grad: Optional[dict] = None):
assert accumulative_iters > 0, ( assert accumulative_counts > 0, (
'accumulative_iters at least greater than or equal to 1') '_accumulative_counts at least greater than or equal to 1')
self.accumulative_iters = accumulative_iters self._accumulative_counts = accumulative_counts
# `max_iters` and `cur_iter` is only valid in gradient accumulative
# mode (`accumulative_iters` > 1). `cur_iter` and `max_iter` will be
# updated in the ``accumulate_grad`` context that is enabled in
# `runner.train_loop`.
self.cur_iter = 0
self.max_iters = 0
assert isinstance(optimizer, Optimizer), ( assert isinstance(optimizer, Optimizer), (
'optimizer must be a `torch.optim.Optimizer` instance, but got ' 'optimizer must be a `torch.optim.Optimizer` instance, but got '
f'{type(optimizer)}') f'{type(optimizer)}')
@ -90,13 +103,27 @@ class OptimWrapper:
if clip_grad is not None: if clip_grad is not None:
# clip_grad_kwargs should not be non-empty dict. # clip_grad_kwargs should not be non-empty dict.
assert isinstance(clip_grad, dict) and clip_grad, ( assert isinstance(clip_grad, dict) and clip_grad, (
'If `clip_grad_kwargs` is not None, it should be a `dict` ' 'If `clip_grad` is not None, it should be a `dict` '
'which is the arguments of `torch.nn.utils.clip_grad`') 'which is the arguments of `torch.nn.utils.clip_grad`')
self.clip_grad_kwargs = clip_grad self.clip_grad_kwargs = clip_grad
self.logger = MMLogger.get_current_instance() self.logger = MMLogger.get_current_instance()
# Used to update `grad_norm` log message. # Used to update `grad_norm` log message.
self.message_hub = MessageHub.get_current_instance() self.message_hub = MessageHub.get_current_instance()
self.iter_status_initialized = False self._inner_count = 0
# `_max_counts` means the total number of parameter updates. It
# ensures that the gradient of the last few iterations will not be
# lost when the `_max_counts` is not divisible by
# `accumulative_counts`.
self._max_counts = -1
# If `_inner_count` is smaller than `_divisible_counts`, the loss
# factor used for gradient accumulation should be the same as
# `_accumulative_counts`. If `_max_counts` has not been initialized,
# the loss factor will always be the same as `_accumulative_counts`.
self._divisible_counts = -1
# The `_remainder_iter` is used for calculating loss factor at the
# last few iterations. If `_max_counts` has not been initialized,
# the loss factor will always be the same as `_accumulative_counts`.
self._remainder_counts = -1
def update_params(self, loss: torch.Tensor) -> None: def update_params(self, loss: torch.Tensor) -> None:
"""Update parameters in :attr:`optimizer`. """Update parameters in :attr:`optimizer`.
@ -104,36 +131,12 @@ class OptimWrapper:
Args: Args:
loss (torch.Tensor): A tensor for back propagation. loss (torch.Tensor): A tensor for back propagation.
""" """
if self.accumulative_iters == 1: loss = self.scale_loss(loss)
# update parameters without gradient accumulation. The gradient
# should not be rescaled and `loss_factor=1`.
loss_factor = 1
else:
# gradient accumulation must be called in the context of
# ``accumulate_grad``.
assert hasattr(self, 'divisible_iters'), (
'gradient accumulation must be performed in the context of'
'`OptimWrapper.accumulate_grad`')
# if `self.accumulative_iters > 1`, the gradient needs to be
# rescaled and accumulated. In most cases, `loss_factor` equals to
# `self.accumulative_iters`. However `self.max_iters` may not be
# divisible `self.by accumulative_iters`, so the `loss_scale` for
# the last few iterations needs to be recalculated.
if self.cur_iter < self.divisible_iters:
loss_factor = self.accumulative_iters
else:
loss_factor = self.remainder_iters
assert loss_factor > 0, (
'loss_factor should be larger than zero! This error could '
'happened when gradient accumulation context enabled with an '
'error `cur_iter` or `max_iters` please check your loop')
loss = loss / loss_factor
self.backward(loss) self.backward(loss)
# Update parameters only if `self.cur_iter` is divisible by # Update parameters only if `self._inner_count` is divisible by
# `self.accumulative_iters` or `self.cur_iter` equals to # `self._accumulative_counts` or `self._inner_count` equals to
# `self.max_iters` # `self._max_counts`
if self._should_update(self.cur_iter, self.max_iters): if self.should_update():
self.step() self.step()
self.zero_grad() self.zero_grad()
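
The ``-1`` sentinels set in ``__init__`` above and the update condition in ``update_params`` only become concrete once ``initialize_count_status`` has been called. A small worked example of the bookkeeping, reusing the numbers that also appear in the unit tests below (counting starts at 0, 100 counts in total, ``accumulative_counts=3``):

accumulative_counts = 3
init_counts, max_counts = 0, 100

residual = max_counts - init_counts                                 # 100
divisible = residual // accumulative_counts * accumulative_counts   # 99
remainder = residual - divisible                                    # 1

# Inner counts 1..99 scale the loss by 3 and the optimizer steps at counts
# 3, 6, ..., 99; the single trailing count 100 scales the loss by 1 and
# steps because it equals max_counts, so the last gradient is not dropped.
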
@ -145,10 +148,15 @@ class OptimWrapper:
required logic. For example, ``torch.cuda.amp`` require some extra required logic. For example, ``torch.cuda.amp`` require some extra
operation on GradScaler during backward process. operation on GradScaler during backward process.
Note:
If subclasses inherit from ``OptimWrapper`` override
``backward``, ``_inner_count +=1`` must be implemented.
Args: Args:
loss (torch.Tensor): The loss of current iteration. loss (torch.Tensor): The loss of current iteration.
""" """
loss.backward() loss.backward()
self._inner_count += 1
def zero_grad(self) -> None: def zero_grad(self) -> None:
"""A wrapper of ``Optimizer.zero_grad``. """A wrapper of ``Optimizer.zero_grad``.
@ -218,7 +226,7 @@ class OptimWrapper:
Provide unified interface to get learning rate of optimizer. Provide unified interface to get learning rate of optimizer.
Returns: Returns:
List[float]: Learning rate of the optimizer. Dict[str, List[float]]: Learning rate of the optimizer.
""" """
lr = [group['lr'] for group in self.param_groups] lr = [group['lr'] for group in self.param_groups]
return dict(lr=lr) return dict(lr=lr)
@ -229,7 +237,7 @@ class OptimWrapper:
Provide unified interface to get momentum of optimizer. Provide unified interface to get momentum of optimizer.
Returns: Returns:
List[float]: Momentum of the optimizer. Dict[str, List[float]]: Momentum of the optimizer.
""" """
momentum = [] momentum = []
for group in self.param_groups: for group in self.param_groups:
@ -244,49 +252,35 @@ class OptimWrapper:
return dict(momentum=momentum) return dict(momentum=momentum)
@contextmanager @contextmanager
def accumulate_grad(self, model: nn.Module, cur_iter: int, max_iters: int): def optim_context(self, model: nn.Module):
"""A Context manager for gradient accumulation and avoiding unnecessary """A Context for gradient accumulation and automatic mix precision
gradient synchronization during gradient accumulation. training.
If subclasses need to enable the context for mix precision training,
e.g., ``:class:`AmpOptimWrapper``, the corresponding context should be
enabled in `optim_context`. Since ``OptimWrapper`` uses default fp32
training, ``optim_context`` will only enable the context for
blocking the unnecessary gradient synchronization during gradient
accumulation
If model is an instance with ``no_sync`` method (which means If model is an instance with ``no_sync`` method (which means
blocking the gradient synchronization) and blocking the gradient synchronization) and
``self.accumulative_iters != 1``. The model will not automatically ``self._accumulative_counts != 1``. The model will not automatically
synchronize gradients if ``cur_iter`` is divisible by synchronize gradients if ``cur_iter`` is divisible by
``self.accumulative_iters``. Otherwise, this method will enable an ``self._accumulative_counts``. Otherwise, this method will enable an
empty context. empty context.
Warnings:
This context manager must be enabled if you want to use
gradient accumulation.
Args: Args:
model (nn.Module): The training model. model (nn.Module): The training model.
cur_iter (int): Current iteration during training process.
max_iters (int): Maximum training iteration.
""" """
assert max_iters > 0, '`max_iters` must be larger than zero'
self.cur_iter = cur_iter
self.max_iters = max_iters
if not self.iter_status_initialized:
self._initilize_iter_status(model)
# During gradient accumulation process, the gradient synchronize # During gradient accumulation process, the gradient synchronize
# should only happen before updating parameters. # should only happen before updating parameters.
if (not self._should_update(cur_iter, max_iters) if not self.should_sync() and hasattr(model, 'no_sync'):
and hasattr(model, 'no_sync')):
with model.no_sync(): with model.no_sync():
yield yield
else: else:
yield yield
@contextmanager
def precision_context(self):
"""precision context which enables an empty context by default.
The subclass used for mixed or low precision training needs to override
this method.
"""
yield
def _clip_grad(self) -> None: def _clip_grad(self) -> None:
"""Clip the gradients of parameters.""" """Clip the gradients of parameters."""
params: List[torch.Tensor] = [] params: List[torch.Tensor] = []
@ -300,50 +294,117 @@ class OptimWrapper:
**self.clip_grad_kwargs) **self.clip_grad_kwargs)
self.message_hub.update_scalar('train/grad_norm', float(grad_norm)) self.message_hub.update_scalar('train/grad_norm', float(grad_norm))
def _initilize_iter_status(self, model: nn.Module) -> None: def initialize_count_status(self, model: nn.Module, init_counts: int,
max_counts: int) -> None:
"""Initialize gradient accumulation related attributes. """Initialize gradient accumulation related attributes.
``OptimWrapper`` can be used without calling
``initialize_iter_status``. However, Consider the case of ``len(
dataloader) == 10``, and the ``accumulative_iter == 3``. Since 10 is
not divisible by 3, the last iteration will not trigger
``optimizer.step()``, resulting in one less parameter updating.
Args: Args:
model (nn.Module): Training model model (nn.Module): Training model
init_counts (int): The initial value of the inner count.
max_counts (int): The maximum value of the inner count.
""" """
if self.max_iters % self.accumulative_iters != 0: self._inner_count = init_counts
self._max_counts = max_counts
if self._inner_count % self._accumulative_counts != 0:
self.logger.warning( self.logger.warning(
'Resume iter number is not divisible by accumulative_iters in ' 'Resumed iteration number is not divisible by '
'GradientCumulativeOptimizerHook, which means the gradient of ' '`_accumulative_counts` in `GradientCumulativeOptimizerHook`, '
'some iters is lost and the result may be influenced slightly.' 'which means the gradient of some iterations is lost and the '
) 'result may be influenced slightly.')
if has_batch_norm(model) and self.accumulative_iters > 1: if has_batch_norm(model) and self._accumulative_counts > 1:
self.logger.warning( self.logger.warning(
'Gradient accumulative may slightly decrease ' 'Gradient accumulative may slightly decrease '
'performance because the model has BatchNorm layers.') 'performance because the model has BatchNorm layers.')
residual_iters = self.max_iters - self.cur_iter residual_counts = max_counts - init_counts
# The maximum number of training iteration that is divisible by # The maximum number of training iteration that is divisible by
# accumulative_iters. # `_accumulative_counts`.
self.divisible_iters = ( self._divisible_counts = (
residual_iters // self.accumulative_iters * residual_counts // self._accumulative_counts *
self.accumulative_iters) self._accumulative_counts)
# Remainder of ``self.max_iters`` divided by ``self.max_iters`` # Remainder of `_max_counts` divided by `_accumulative_counts`
self.remainder_iters = residual_iters - self.divisible_iters self._remainder_counts = residual_counts - self._divisible_counts
self.iter_status_initialized = True
def _should_update(self, cur_iter: int, max_iters: int) -> bool: def should_update(self) -> bool:
"""Should optim_wrapper update parameters or synchronized gradient at """Decide whether the parameters should be updated at the current
current iteration. iteration.
Args: Called by :meth:`update_params` and check whether the optimizer
cur_iter (int): Current iteration of training process. wrapper should update parameters at current iteration.
max_iters (int): Maximum iterations of training process.
Returns: Returns:
bool: Whether to update parameters or synchronized gradient. bool: Whether to update parameters.
""" """
return ((cur_iter + 1) % self.accumulative_iters == 0 return (self._inner_count % self._accumulative_counts == 0
or cur_iter + 1 == max_iters) or self._inner_count == self._max_counts)
def should_sync(self) -> bool:
"""Decide whether the automatic gradient synchronization should be
allowed at the current iteration.
It takes effect when gradient accumulation is used to skip
synchronization at the iterations where the parameter is not updated.
Since ``should_sync`` is called by :meth:`optim_context`, and it is
called before :meth:`backward` which means ``self._inner_count += 1``
has not happened yet. Therefore, ``self._inner_count += 1`` should be
performed manually here.
Returns:
bool: Whether to block the automatic gradient synchronization.
"""
return ((self._inner_count + 1) % self._accumulative_counts == 0
or (self._inner_count + 1) == self._max_counts)
def scale_loss(self, loss: torch.Tensor) -> torch.Tensor:
"""Get scaled loss according to ``_accumulative_counts``,
``_inner_count`` and max_counts.
Args:
loss (torch.Tensor): Original loss calculated by model.
Returns:
loss (torch.Tensor): Scaled loss.
"""
if self._accumulative_counts == 1:
# update parameters without gradient accumulation. The gradient
# should not be rescaled and `loss_factor=1`.
loss_factor = 1
elif self._max_counts == -1:
loss_factor = self._accumulative_counts
else:
# if `self._accumulative_counts > 1`, the gradient needs to be
# rescaled and accumulated. In most cases, `loss_factor` equals to
# `self._accumulative_counts`. However, `self._max_counts` may not
# be divisible by `self._accumulative_counts`, so the
# `loss_scale` for the last few iterations needs to be
# recalculated.
if self._inner_count < self._divisible_counts:
loss_factor = self._accumulative_counts
else:
loss_factor = self._remainder_counts
assert loss_factor > 0, (
'loss_factor should be larger than zero! This error could '
'happened when initialize_iter_status called with an '
'error `init_counts` or `max_counts`')
loss = loss / loss_factor
return loss
@property
def inner_count(self):
"""Get the number of updating parameters of optimizer wrapper."""
return self._inner_count
def __repr__(self): def __repr__(self):
wrapper_info = f'Type: {type(self).__name__}\n' \ wrapper_info = (f'Type: {type(self).__name__}\n'
f'accumulative_iters: {self.accumulative_iters}\n' \ f'_accumulative_counts: {self._accumulative_counts}\n'
f'optimizer: \n' 'optimizer: \n')
optimizer_str = repr(self.optimizer) + '\n' optimizer_str = repr(self.optimizer) + '\n'
return wrapper_info + optimizer_str return wrapper_info + optimizer_str
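
A hedged end-to-end sketch tying ``initialize_count_status``, ``optim_context``, ``update_params`` and the counting helpers together. Only the OptimWrapper calls mirror the code above; the toy model and loop are illustrative, and with a plain (non-DDP) module ``optim_context`` is effectively a no-op because there is no ``no_sync`` method:

import torch
import torch.nn as nn
from torch.optim import SGD
from mmengine.optim import OptimWrapper

model = nn.Linear(1, 1)
optim_wrapper = OptimWrapper(SGD(model.parameters(), lr=0.1),
                             accumulative_counts=3)
# 0 iterations finished so far, 10 iterations in total.
optim_wrapper.initialize_count_status(model, 0, 10)

for i in range(10):
    # should_sync() is True only right before a parameter update,
    # i.e. at inner counts 2, 5, 8 and at the final count 9.
    with optim_wrapper.optim_context(model):
        loss = model(torch.randn(1, 1)).sum()
    # scale_loss divides by 3 (or by the remainder, 1, for the last count),
    # backward() accumulates the gradient and bumps _inner_count, and
    # step()/zero_grad() only run when should_update() is True,
    # i.e. after inner counts 3, 6, 9 and at max_counts == 10.
    optim_wrapper.update_params(loss)
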


@ -1,6 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import warnings from contextlib import contextmanager
from contextlib import ExitStack, contextmanager
from typing import Dict, Iterator, List, Tuple from typing import Dict, Iterator, List, Tuple
import torch import torch
@ -41,24 +40,10 @@ class OptimWrapperDict(OptimWrapper):
""" """
def __init__(self, **optim_wrapper_dict: OptimWrapper): def __init__(self, **optim_wrapper_dict: OptimWrapper):
first_key = next(iter(optim_wrapper_dict))
first_optim_wrapper = optim_wrapper_dict[first_key]
assert isinstance(first_optim_wrapper, OptimWrapper), (
'Each argument of `OptimWrapperDict` must be an `OptimWrapper '
'instance`')
optim_wrapper_class = type(first_optim_wrapper)
for key, value in optim_wrapper_dict.items(): for key, value in optim_wrapper_dict.items():
assert type(value) == optim_wrapper_class, ( assert isinstance(value, OptimWrapper), (
f'All optimizer wrappers should have the same type, but found' '`OptimWrapperDict` only accept OptimWrapper instance, '
f' {key}: {type(value)} and {first_key}: {optim_wrapper_class}' f'but got {key}: {type(value)}')
)
if value.accumulative_iters != 1:
warnings.warn(
f'The `accumulative_iters` of {key} is '
f'{value.accumulative_iters}. OptimWrapperDict '
'will not enable any `accumulate_grad` context of its '
'optimizer wrappers. You should access the corresponding '
'optimizer wrapper to enable the context.')
self.optim_wrappers = optim_wrapper_dict self.optim_wrappers = optim_wrapper_dict
def update_params(self, loss: torch.Tensor) -> None: def update_params(self, loss: torch.Tensor) -> None:
@ -69,9 +54,8 @@ class OptimWrapperDict(OptimWrapper):
Therefore, this method is not implemented. The optimizer wrapper of Therefore, this method is not implemented. The optimizer wrapper of
OptimWrapperDict should be accessed and call its `update_params. OptimWrapperDict should be accessed and call its `update_params.
""" """
raise NotImplementedError( raise NotImplementedError('`update_params` should be called by each '
'You should access the OptimWrapper of the ' 'optimizer separately`')
'OptimWrapperDict and call its `update_params`')
def backward(self, loss: torch.Tensor) -> None: def backward(self, loss: torch.Tensor) -> None:
"""Since OptimWrapperDict doesn't know which optimizer wrapper's """Since OptimWrapperDict doesn't know which optimizer wrapper's
@ -81,14 +65,14 @@ class OptimWrapperDict(OptimWrapper):
The optimizer wrapper of OptimWrapperDict should be accessed and call The optimizer wrapper of OptimWrapperDict should be accessed and call
its `backward. its `backward.
""" """
raise NotImplementedError('You should access the OptimWrapper of the ' raise NotImplementedError('`backward` should be called by each '
'OptimWrapperDict and call its `backward`') 'optimizer separately`')
def step(self) -> None: def step(self) -> None:
"""Since the backward method is not implemented, the step should not be """Since the backward method is not implemented, the step should not be
implemented either.""" implemented either."""
raise NotImplementedError('You should access the OptimWrapper of the ' raise NotImplementedError('`step` should be called by each '
'OptimWrapperDict and call its `step`') 'optimizer separately`')
def zero_grad(self) -> None: def zero_grad(self) -> None:
"""Set the gradients of all optimizer wrappers to zero.""" """Set the gradients of all optimizer wrappers to zero."""
@ -96,49 +80,29 @@ class OptimWrapperDict(OptimWrapper):
optim_wrapper.zero_grad() optim_wrapper.zero_grad()
@contextmanager @contextmanager
def precision_context(self): def optim_context(self, model: nn.Module):
optim_wrapper = next(iter(self.optim_wrappers.values())) """``optim_context`` should be called by each optimizer separately."""
with optim_wrapper.precision_context(): raise NotImplementedError(
yield '`optim_context` should be called by each optimizer separately')
@contextmanager def initialize_count_status(self, model: nn.Module, cur_iter,
def accumulate_grad(self, model: nn.Module, cur_iter: int, max_iters: int): max_iters) -> None:
"""Enable ``accumulate_grad`` contexts of all optimizer wrappers. """Do nothing but provide unified interface for :obj:`OptimWrapper`
Warning: Since ``OptimWrapperDict`` does not know the correspondence between
Consider there is only one ``model`` arguments for all model and optimizer wrapper. ``initialize_iter_status`` will do nothing
optimizer wrappers, all optimizer wrappers are working under the and each optimizer wrapper should call ``initialize_iter_status``
same ``model.no_sync`` context. For example, there is a model separately.
composed of model_a(optimizer_a) and model_b(optimizer_b).
``OptimWrapperDict.accumulate_grad`` will further
call ``model.no_sync``, which will block the gradient
synchronization of both a and b. If optimizer_a and
optimizer_b have different ``accumulative_iters``, and want to
block the gradient synchronization of model_a and model_b
separately, the model should not implement the ``no_sync``
method(or enable an empty context). The ``accumulate_grad`` context
should be enabled inside the model by accessing corresponding
optimizer wrapper.
""" """
with ExitStack() as stack: return
for optim_wrapper in self.optim_wrappers.values():
stack.enter_context(
optim_wrapper.accumulate_grad(model, cur_iter, max_iters))
yield
def load_state_dict(self, state_dict: dict) -> None: @property
"""Load the state dictionary from the ``state_dict``. def param_groups(self):
"""Returns the parameter groups of each OptimWrapper."""
Args: param_groups = dict()
state_dict (dict): Each key-value pair in `state_dict` represents for key, value in self.optim_wrappers.items():
the name and the state dictionary of corresponding param_groups[key] = value.param_groups
:obj:`OptimWrapper`. return param_groups
"""
for name, _state_dict in state_dict.items():
assert name in self.optim_wrappers, (
f'Mismatched `state_dict`! cannot found {name} in '
'OptimWrapperDict')
self.optim_wrappers[name].load_state_dict(_state_dict)
def get_lr(self) -> Dict[str, List[float]]: def get_lr(self) -> Dict[str, List[float]]:
"""Get the learning rate of all optimizers. """Get the learning rate of all optimizers.
@ -175,6 +139,20 @@ class OptimWrapperDict(OptimWrapper):
state_dict[name] = optim_wrapper.state_dict() state_dict[name] = optim_wrapper.state_dict()
return state_dict return state_dict
def load_state_dict(self, state_dict: dict) -> None:
"""Load the state dictionary from the ``state_dict``.
Args:
state_dict (dict): Each key-value pair in `state_dict` represents
the name and the state dictionary of corresponding
:obj:`OptimWrapper`.
"""
for name, _state_dict in state_dict.items():
assert name in self.optim_wrappers, (
f'Mismatched `state_dict`! cannot found {name} in '
'OptimWrapperDict')
self.optim_wrappers[name].load_state_dict(_state_dict)
def items(self) -> Iterator[Tuple[str, OptimWrapper]]: def items(self) -> Iterator[Tuple[str, OptimWrapper]]:
"""A generator to get the name and corresponding """A generator to get the name and corresponding
:obj:`OptimWrapper`""" :obj:`OptimWrapper`"""
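
After this refactor ``OptimWrapperDict`` is a plain container: ``update_params``, ``backward``, ``step``, ``optim_context`` and ``initialize_count_status`` must be driven through the individual wrappers, while read-only helpers such as ``param_groups``, ``get_lr``, ``get_momentum``, ``state_dict``/``load_state_dict`` and ``zero_grad`` fan out over all of them. A brief usage sketch; the key names and toy modules are made up:

import torch
import torch.nn as nn
from torch.optim import SGD
from mmengine.optim import OptimWrapper, OptimWrapperDict

gen, disc = nn.Linear(1, 1), nn.Linear(1, 1)
wrappers = OptimWrapperDict(
    generator=OptimWrapper(SGD(gen.parameters(), lr=0.1)),
    discriminator=OptimWrapper(SGD(disc.parameters(), lr=0.2)))

wrappers.get_lr()      # {'generator.lr': [0.1], 'discriminator.lr': [0.2]}
wrappers.param_groups  # {'generator': [...], 'discriminator': [...]}

# Parameter updates go through one wrapper at a time.
with wrappers['generator'].optim_context(gen):
    loss = gen(torch.randn(1, 1)).sum()
wrappers['generator'].update_params(loss)
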


@ -101,9 +101,6 @@ class EpochBasedTrainLoop(BaseLoop):
'before_train_iter', batch_idx=idx, data_batch=data_batch) 'before_train_iter', batch_idx=idx, data_batch=data_batch)
# Enable gradient accumulation mode and avoid unnecessary gradient # Enable gradient accumulation mode and avoid unnecessary gradient
# synchronization during gradient accumulation process. # synchronization during gradient accumulation process.
with self.runner.optim_wrapper.accumulate_grad(self.runner.model,
self._iter,
self._max_iters):
# outputs should be a dict of loss. # outputs should be a dict of loss.
outputs = self.runner.model.train_step( outputs = self.runner.model.train_step(
data_batch, optim_wrapper=self.runner.optim_wrapper) data_batch, optim_wrapper=self.runner.optim_wrapper)
@ -204,19 +201,16 @@ class IterBasedTrainLoop(BaseLoop):
'before_train_iter', batch_idx=self._iter, data_batch=data_batch) 'before_train_iter', batch_idx=self._iter, data_batch=data_batch)
# Enable gradient accumulation mode and avoid unnecessary gradient # Enable gradient accumulation mode and avoid unnecessary gradient
# synchronization during gradient accumulation process. # synchronization during gradient accumulation process.
with self.runner.optim_wrapper.accumulate_grad(self.runner.model, # outputs should be a dict of loss.
self._iter, outputs = self.runner.model.train_step(
self._max_iters):
# train_logs should be a dict of loss.
train_logs = self.runner.model.train_step(
data_batch, optim_wrapper=self.runner.optim_wrapper) data_batch, optim_wrapper=self.runner.optim_wrapper)
self.runner.message_hub.update_info('train_logs', train_logs) self.runner.message_hub.update_info('train_logs', outputs)
self.runner.call_hook( self.runner.call_hook(
'after_train_iter', 'after_train_iter',
batch_idx=self._iter, batch_idx=self._iter,
data_batch=data_batch, data_batch=data_batch,
outputs=train_logs) outputs=outputs)
self._iter += 1 self._iter += 1


@ -926,7 +926,7 @@ class Runner:
>>> optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg) >>> optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
>>> optim_wrapper >>> optim_wrapper
Type: OptimWrapper Type: OptimWrapper
accumulative_iters: 1 accumulative_counts: 1
optimizer: optimizer:
SGD ( SGD (
Parameter Group 0 Parameter Group 0
@ -941,7 +941,7 @@ class Runner:
>>> optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg) >>> optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
>>> optim_wrapper >>> optim_wrapper
Type: OptimWrapper Type: OptimWrapper
accumulative_iters: 1 accumulative_counts: 1
optimizer: optimizer:
SGD ( SGD (
Parameter Group 0 Parameter Group 0
@ -965,7 +965,7 @@ class Runner:
>>> optim_wrapper >>> optim_wrapper
name: generator name: generator
Type: OptimWrapper Type: OptimWrapper
accumulative_iters: 1 accumulative_counts: 1
optimizer: optimizer:
SGD ( SGD (
Parameter Group 0 Parameter Group 0
@ -977,7 +977,7 @@ class Runner:
) )
name: discriminator name: discriminator
Type: OptimWrapper Type: OptimWrapper
accumulative_iters: 1 accumulative_counts: 1
optimizer: optimizer:
'discriminator': Adam ( 'discriminator': Adam (
Parameter Group 0 Parameter Group 0
@ -1528,7 +1528,6 @@ class Runner:
# `build_optimizer` should be called before `build_param_scheduler` # `build_optimizer` should be called before `build_param_scheduler`
# because the latter depends on the former # because the latter depends on the former
self.optim_wrapper = self.build_optim_wrapper(self.optim_wrapper) self.optim_wrapper = self.build_optim_wrapper(self.optim_wrapper)
# Automatically scaling lr by linear scaling rule # Automatically scaling lr by linear scaling rule
self.scale_lr(self.optim_wrapper, self.auto_scale_lr) self.scale_lr(self.optim_wrapper, self.auto_scale_lr)
@ -1540,9 +1539,14 @@ class Runner:
self._val_loop = self.build_val_loop( self._val_loop = self.build_val_loop(
self._val_loop) # type: ignore self._val_loop) # type: ignore
# TODO: add a contextmanager to avoid calling `before_run` many times
self.call_hook('before_run') self.call_hook('before_run')
# Initiate inner count of `optim_wrapper`.
self.optim_wrapper.initialize_count_status(
self.model,
self._train_loop.iter, # type: ignore
self._train_loop.max_iters) # type: ignore
# TODO: add a contextmanager to avoid calling `before_run` many times
# make sure checkpoint-related hooks are triggered after `before_run` # make sure checkpoint-related hooks are triggered after `before_run`
self.load_or_resume() self.load_or_resume()
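
The Runner now seeds the wrapper's counters from the train loop, which makes the earlier notes about multiplying ``max_iters`` by ``accumulative_counts`` concrete: every dataloader iteration advances ``_inner_count`` by one, but parameters only move every ``accumulative_counts`` iterations. A small worked example with illustrative numbers:

# To keep the same number of optimizer steps as a non-accumulating run:
accumulative_counts = 4
optimizer_steps_wanted = 10_000
max_iters = optimizer_steps_wanted * accumulative_counts  # 40_000

# runner.train() then initializes the counters roughly as shown in the
# hunk above:
#   optim_wrapper.initialize_count_status(model,
#                                         train_loop.iter,       # e.g. 0
#                                         train_loop.max_iters)  # e.g. 40_000
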


@ -8,9 +8,10 @@ import torch.distributed as torch_dist
import torch.nn as nn import torch.nn as nn
from torch.optim import SGD from torch.optim import SGD
from mmengine.dist import all_gather
from mmengine.model import (BaseModel, MMDistributedDataParallel, from mmengine.model import (BaseModel, MMDistributedDataParallel,
MMSeparateDistributedDataParallel) MMSeparateDistributedDataParallel)
from mmengine.optim import OptimWrapper, OptimWrapperDict from mmengine.optim import AmpOptimWrapper, OptimWrapper, OptimWrapperDict
from mmengine.testing import assert_allclose from mmengine.testing import assert_allclose
from mmengine.testing._internal import MultiProcessTestCase from mmengine.testing._internal import MultiProcessTestCase
@ -23,9 +24,9 @@ class ToyModel(BaseModel):
self.conv2 = nn.Conv2d(1, 1, 1) self.conv2 = nn.Conv2d(1, 1, 1)
def forward(self, x, data_samples=None, mode='tensor'): def forward(self, x, data_samples=None, mode='tensor'):
if mode == 'loss':
x = self.conv1(x) x = self.conv1(x)
x = self.conv2(x) x = self.conv2(x)
if mode == 'loss':
return dict(loss=x) return dict(loss=x)
elif mode == 'predict': elif mode == 'predict':
return x return x
@ -58,24 +59,41 @@ class ComplexModel(BaseModel):
pass pass
class TestModelWrapper(MultiProcessTestCase): class TestDistributedDataParallel(MultiProcessTestCase):
def setUp(self) -> None: def setUp(self) -> None:
super().setUp() super().setUp()
self._spawn_processes() self._spawn_processes()
@unittest.skipIf(
not torch.cuda.is_available(), reason='cuda should be available')
def test_train_step(self): def test_train_step(self):
self._init_dist_env(self.rank, self.world_size) self._init_dist_env(self.rank, self.world_size)
# Test `optim_wrapper` is a instance of `OptimWrapper` # Mixed precision training and gradient asynchronous should be valid at
model = ToyModel() # the same time
model = ToyModel().cuda()
ddp_model = MMDistributedDataParallel(module=model) ddp_model = MMDistributedDataParallel(module=model)
optimizer = SGD(ddp_model.parameters(), lr=0) optimizer = SGD(ddp_model.parameters(), lr=0)
optim_wrapper = OptimWrapper(optimizer, accumulative_iters=1) optim_wrapper = AmpOptimWrapper(
inputs = torch.randn(3, 1, 1) * self.rank * 255 optimizer=optimizer, accumulative_counts=3)
inputs = torch.randn(3, 1, 1).cuda() * self.rank * 255
data = dict(inputs=inputs, data_sample=MagicMock()) data = dict(inputs=inputs, data_sample=MagicMock())
res = ddp_model.train_step([data], optim_wrapper=optim_wrapper)['loss']
self.assertIs(res.dtype, torch.float16)
grad = ddp_model.module.conv1.weight.grad
all_grads = all_gather(grad)
with self.assertRaises(AssertionError):
assert_allclose(all_grads[0], all_grads[1])
# Gradient accumulation
ddp_model.train_step([data], optim_wrapper=optim_wrapper)
# Test update params and clean grads.
ddp_model.train_step([data], optim_wrapper=optim_wrapper) ddp_model.train_step([data], optim_wrapper=optim_wrapper)
grad = ddp_model.module.conv1.weight.grad grad = ddp_model.module.conv1.weight.grad
assert_allclose(grad, torch.zeros_like(grad)) all_grads = all_gather(grad)
assert_allclose(all_grads[0], torch.zeros_like(all_grads[0]))
assert_allclose(all_grads[1], torch.zeros_like(all_grads[0]))
def test_val_step(self): def test_val_step(self):
self._init_dist_env(self.rank, self.world_size) self._init_dist_env(self.rank, self.world_size)
@ -107,7 +125,7 @@ class TestModelWrapper(MultiProcessTestCase):
@unittest.skipIf( @unittest.skipIf(
not torch.cuda.is_available(), reason='cuda should be available') not torch.cuda.is_available(), reason='cuda should be available')
class TestMMSeparateDistributedDataParallel(TestModelWrapper): class TestMMSeparateDistributedDataParallel(TestDistributedDataParallel):
def test_train_step(self): def test_train_step(self):
self._init_dist_env(self.rank, self.world_size) self._init_dist_env(self.rank, self.world_size)


@ -45,7 +45,7 @@ class ToyModel2(nn.Module):
class TestOptimWrapper(MultiProcessTestCase): class TestOptimWrapper(MultiProcessTestCase):
# Test `OptimWrapper.accumulate_grad` will block the gradient # Test `OptimWrapper.optim_context` will block the gradient
# synchronization when using gradient accumulation strategy in distributed # synchronization when using gradient accumulation strategy in distributed
# data parallel training. # data parallel training.
def setUp(self) -> None: def setUp(self) -> None:
@ -61,94 +61,84 @@ class TestOptimWrapper(MultiProcessTestCase):
def test_init(self): def test_init(self):
optim_wrapper = OptimWrapper(self.optimizer) optim_wrapper = OptimWrapper(self.optimizer)
self.assertEqual(optim_wrapper.optimizer, self.optimizer) self.assertIs(optim_wrapper.optimizer, self.optimizer)
self.assertIsNone(optim_wrapper.clip_grad_kwargs) self.assertIsNone(optim_wrapper.clip_grad_kwargs)
self.assertEqual(optim_wrapper.accumulative_iters, 1) self.assertEqual(optim_wrapper._accumulative_counts, 1)
self.assertIs(optim_wrapper.logger, self.logger) self.assertIs(optim_wrapper.logger, self.logger)
self.assertIs(optim_wrapper.message_hub, self.message_hub) self.assertIs(optim_wrapper.message_hub, self.message_hub)
self.assertEqual(optim_wrapper._inner_count, 0)
self.assertEqual(optim_wrapper._max_counts, -1)
self.assertEqual(optim_wrapper._divisible_counts, -1)
self.assertEqual(optim_wrapper._remainder_counts, -1)
with self.assertRaisesRegex(AssertionError, with self.assertRaisesRegex(AssertionError,
'If `clip_grad_kwargs` is not None'): 'If `clip_grad` is not None'):
OptimWrapper(self.optimizer, clip_grad=[]) OptimWrapper(self.optimizer, clip_grad=[])
def test_update_params(self): def test_update_params(self):
# Test update params every iteration. # Test update params every iteration.
optim_wrapper = OptimWrapper(self.optimizer, accumulative_iters=1) optim_wrapper = OptimWrapper(self.optimizer, accumulative_counts=1)
self._mock_method(optim_wrapper) self._mock_method(optim_wrapper)
loss = torch.tensor(1) loss = torch.tensor(1)
optim_wrapper.update_params(loss) optim_wrapper.update_params(loss)
optim_wrapper.backward.assert_called_with(torch.tensor(1)) self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1))
optim_wrapper.step.assert_called_with() optim_wrapper.step.assert_called_with()
optim_wrapper.zero_grad.assert_called_with() optim_wrapper.zero_grad.assert_called_with()
with optim_wrapper.accumulate_grad(self.model, 2, 100): # Test gradient accumulation.
optim_wrapper.update_params(torch.tensor(1)) optim_wrapper = OptimWrapper(self.optimizer, accumulative_counts=3)
optim_wrapper.backward.assert_called_with(torch.tensor(1))
optim_wrapper.step.assert_called_with()
optim_wrapper.zero_grad.assert_called_with()
# It will raise an error if `accumulative_iters > 1` and
# `accumulate_grad` is not enabled.
optim_wrapper = OptimWrapper(self.optimizer, accumulative_iters=3)
self._mock_method(optim_wrapper) self._mock_method(optim_wrapper)
with self.assertRaisesRegex(AssertionError, # `iter=0`, accumulate gradient and do not update params.
'gradient accumulation must be'):
optim_wrapper.update_params(loss)
# `iter=0`, Call `optimizer_step` first time.
with optim_wrapper.accumulate_grad(
self.model, cur_iter=0, max_iters=100):
loss = torch.tensor(1) loss = torch.tensor(1)
optim_wrapper.update_params(loss) optim_wrapper.update_params(loss)
optim_wrapper.backward.assert_called_with(torch.tensor(1) / 3) self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1) / 3)
optim_wrapper.step.assert_not_called() optim_wrapper.step.assert_not_called()
optim_wrapper.zero_grad.assert_not_called() optim_wrapper.zero_grad.assert_not_called()
# `iter=2`, Call `optimizer_step` first time. # gradient accumulate
with optim_wrapper.accumulate_grad( optim_wrapper.update_params(loss)
self.model, cur_iter=2, max_iters=100): self.assertEqual(optim_wrapper._inner_count, 2)
# `iter=2`, update params.
optim_wrapper.update_params(loss) optim_wrapper.update_params(loss)
optim_wrapper.step.assert_called() optim_wrapper.step.assert_called()
optim_wrapper.zero_grad.assert_called() optim_wrapper.zero_grad.assert_called()
self._mock_method(optim_wrapper) self._mock_method(optim_wrapper)
# Test end of training.
with optim_wrapper.accumulate_grad( # Test end of training without calling `initialize_iter_status`
self.model, cur_iter=99, max_iters=100): optim_wrapper._inner_count = 99
optim_wrapper.update_params(loss)
optim_wrapper.step.assert_not_called()
optim_wrapper.zero_grad.assert_not_called()
self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1) / 3)
self._mock_method(optim_wrapper)
# After calling `initialize_iter_status`, params will be updated at the
# last iteration, and the `loss_scaler` will be adjusted.
optim_wrapper.initialize_count_status(self.model, 99, 100)
optim_wrapper.update_params(loss) optim_wrapper.update_params(loss)
optim_wrapper.step.assert_called() optim_wrapper.step.assert_called()
optim_wrapper.zero_grad.assert_called() optim_wrapper.zero_grad.assert_called()
optim_wrapper.backward.assert_called_with(1) self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1))
# If ``accumulative_iters > 1``, call ``update_params`` with def test_initialize_iter_status(self):
# non-accumulate_grad context will raise an Assertion error optim_wrapper = OptimWrapper(self.optimizer, accumulative_counts=3)
optim_wrapper = OptimWrapper(self.optimizer, accumulative_iters=1) optim_wrapper.initialize_count_status(self.model, 0, 100)
optim_wrapper.accumulative_iters = 2 self.assertEqual(optim_wrapper._divisible_counts, 99)
with self.assertRaisesRegex(AssertionError, self.assertEqual(optim_wrapper._remainder_counts, 1)
'gradient accumulation must be performed'):
optim_wrapper.update_params(loss)
def test_initilize_iter_status(self):
optim_wrapper = OptimWrapper(self.optimizer, accumulative_iters=3)
optim_wrapper._initilize_iter_status(self.model)
self.assertEqual(optim_wrapper.divisible_iters, 0)
self.assertEqual(optim_wrapper.remainder_iters, 0)
# Indivisible cur_iter will output warning. # Indivisible cur_iter will output warning.
optim_wrapper = OptimWrapper(self.optimizer, accumulative_iters=3) optim_wrapper = OptimWrapper(self.optimizer, accumulative_counts=3)
optim_wrapper.cur_iter = 0
optim_wrapper.max_iters = 100
with self.assertLogs(self.logger) as cm: with self.assertLogs(self.logger) as cm:
optim_wrapper._initilize_iter_status(self.model) optim_wrapper.initialize_count_status(self.model, 2, 100)
self.assertEqual(len(cm.output), 1) self.assertEqual(len(cm.output), 1)
self.assertRegex(cm.records[0].msg, 'Resume iter number is not') self.assertRegex(cm.records[0].msg, 'Resumed iteration number')
# Model with batch norm will output warning. # Model with batch norm will output warning.
optim_wrapper = OptimWrapper(self.optimizer, accumulative_iters=3) optim_wrapper = OptimWrapper(self.optimizer, accumulative_counts=3)
optim_wrapper.cur_iter = 0
optim_wrapper.max_iters = 99
model = nn.BatchNorm2d(1) model = nn.BatchNorm2d(1)
with self.assertLogs(self.logger) as cm: with self.assertLogs(self.logger) as cm:
optim_wrapper._initilize_iter_status(model) optim_wrapper.initialize_count_status(model, 0, 99)
self.assertEqual(len(cm.output), 1) self.assertEqual(len(cm.output), 1)
self.assertRegex(cm.records[0].msg, 'Gradient accumulative') self.assertRegex(cm.records[0].msg, 'Gradient accumulative')
@ -214,15 +204,16 @@ class TestOptimWrapper(MultiProcessTestCase):
self.assertEqual(optim_wrapper.param_groups, self.assertEqual(optim_wrapper.param_groups,
self.optimizer.param_groups) self.optimizer.param_groups)
def test_accumulate_grad(self): def test_optim_context(self):
self._init_dist_env(self.rank, self.world_size) self._init_dist_env(self.rank, self.world_size)
model = ToyModel2() model = ToyModel2()
ddp_model = DistributedDataParallel(model) ddp_model = DistributedDataParallel(model)
optimizer = SGD(ddp_model.parameters(), lr=0.01) optimizer = SGD(ddp_model.parameters(), lr=0.01)
optim_wrapper = OptimWrapper(optimizer, accumulative_iters=1) optim_wrapper = OptimWrapper(optimizer, accumulative_counts=1)
optim_wrapper.zero_grad() optim_wrapper.zero_grad()
with optim_wrapper.accumulate_grad(ddp_model, 0, 100):
# Automatically sync grads if `accumulative_iters` = 1 # Automatically sync grads if `accumulative_counts` = 1
optim_wrapper.initialize_count_status(model, 0, 100)
inputs = torch.randn(1, 1, 1, 1) * self.rank inputs = torch.randn(1, 1, 1, 1) * self.rank
ddp_model(inputs).sum().backward() ddp_model(inputs).sum().backward()
grad = model.conv.weight.grad grad = model.conv.weight.grad
@ -230,25 +221,24 @@ class TestOptimWrapper(MultiProcessTestCase):
assert_allclose(all_grads[0], all_grads[1]) assert_allclose(all_grads[0], all_grads[1])
# Do not sync grads when `optim_wrapper.cur_iter` cannot be # Do not sync grads when `optim_wrapper.cur_iter` cannot be
# divided by `optim_wrapper.accumulative_iters` # divided by `optim_wrapper._accumulative_counts`
optim_wrapper = OptimWrapper(optimizer, accumulative_iters=3) optim_wrapper = OptimWrapper(optimizer, accumulative_counts=3)
with optim_wrapper.accumulate_grad(ddp_model, 0, 100): optim_wrapper.initialize_count_status(model, 0, 100)
ddp_model(inputs).sum().backward() with optim_wrapper.optim_context(ddp_model):
loss = ddp_model(inputs).sum()
loss.backward()
all_grads = all_gather(model.conv.weight.grad) all_grads = all_gather(model.conv.weight.grad)
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
assert_allclose(all_grads[0], all_grads[1]) assert_allclose(all_grads[0], all_grads[1])
# sync grads if `cur_iter == 2` # sync grads if `cur_iter == 2`
with optim_wrapper.accumulate_grad(ddp_model, 2, 100): optim_wrapper.initialize_count_status(model, 2, 100)
ddp_model(inputs).sum().backward() with optim_wrapper.optim_context(ddp_model):
loss = ddp_model(inputs).sum()
loss.backward()
all_grads = all_gather(model.conv.weight.grad) all_grads = all_gather(model.conv.weight.grad)
assert_allclose(all_grads[0], all_grads[1]) assert_allclose(all_grads[0], all_grads[1])
def test_precision_context(self):
optim_wrapper = OptimWrapper(self.optimizer)
with optim_wrapper.precision_context():
pass
def _init_dist_env(self, rank, world_size): def _init_dist_env(self, rank, world_size):
"""Initialize the distributed environment.""" """Initialize the distributed environment."""
os.environ['MASTER_ADDR'] = '127.0.0.1' os.environ['MASTER_ADDR'] = '127.0.0.1'
@ -260,7 +250,12 @@ class TestOptimWrapper(MultiProcessTestCase):
# TODO Test the real interface after add testing tool function which can # TODO Test the real interface after add testing tool function which can
# test the function or method is read called. # test the function or method is read called.
def _mock_method(self, optim_wrapper): def _mock_method(self, optim_wrapper):
optim_wrapper.backward = MagicMock()
def mock_methd(loss):
optim_wrapper._inner_count += 1
optim_wrapper.scaled_loss = loss
optim_wrapper.backward = mock_methd
optim_wrapper.step = MagicMock() optim_wrapper.step = MagicMock()
optim_wrapper.zero_grad = MagicMock() optim_wrapper.zero_grad = MagicMock()
@ -376,9 +371,9 @@ class TestAmpOptimWrapper(TestCase):
and (digit_version(TORCH_VERSION) >= digit_version('1.6.0')), and (digit_version(TORCH_VERSION) >= digit_version('1.6.0')),
reason='`torch.cuda.amp` is only available when pytorch-gpu version ' reason='`torch.cuda.amp` is only available when pytorch-gpu version '
'>= 1.6') '>= 1.6')
def test_precision_context(self): def test_optim_context(self):
amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer) amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer)
with amp_optim_wrapper.precision_context(): with amp_optim_wrapper.optim_context(self.model):
x = torch.randn(1, 1, 1, 1).cuda() x = torch.randn(1, 1, 1, 1).cuda()
y = nn.Conv2d(1, 1, 1).cuda()(x) y = nn.Conv2d(1, 1, 1).cuda()(x)
self.assertEqual(y.dtype, torch.float16) self.assertEqual(y.dtype, torch.float16)


@ -1,69 +1,92 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from contextlib import contextmanager
from unittest import TestCase from unittest import TestCase
from unittest.mock import patch
import torch
import torch.nn as nn import torch.nn as nn
from torch.optim import SGD from torch.optim import SGD
from mmengine.optim import AmpOptimWrapper, OptimWrapper, OptimWrapperDict from mmengine.optim import OptimWrapper, OptimWrapperDict
class TestOptimWrapperDict(TestCase): class TestOptimWrapperDict(TestCase):
def setUp(self) -> None: def setUp(self) -> None:
model1 = nn.Linear(1, 1) self.model1 = nn.Linear(1, 1)
model2 = nn.Linear(1, 1) self.model2 = nn.Linear(1, 1)
self.optim1 = SGD(model1.parameters(), lr=0.1, momentum=0.8) self.optim1 = SGD(self.model1.parameters(), lr=0.1, momentum=0.8)
self.optim2 = SGD(model2.parameters(), lr=0.2, momentum=0.9) self.optim2 = SGD(self.model2.parameters(), lr=0.2, momentum=0.9)
self.optim_wrapper1 = OptimWrapper(self.optim1) self.optim_wrapper1 = OptimWrapper(self.optim1)
self.optim_wrapper2 = OptimWrapper(self.optim2) self.optim_wrapper2 = OptimWrapper(self.optim2)
self.optimizers_wrappers = dict( self.optimizers_wrappers = dict(
optim1=self.optim_wrapper1, optim2=self.optim_wrapper2) optim1=self.optim_wrapper1, optim2=self.optim_wrapper2)
@patch('torch.cuda.is_available', lambda: True)
def test_init(self): def test_init(self):
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
self.assertEqual(optim_wrapper_dict.optim_wrappers, self.assertEqual(optim_wrapper_dict.optim_wrappers,
self.optimizers_wrappers) self.optimizers_wrappers)
# Different types of OptimWrapper will raise an error with self.assertRaisesRegex(AssertionError,
'`OptimWrapperDict` only accept'):
OptimWrapperDict(**dict(optim1=self.optim1, optim2=self.optim2))
with self.assertRaisesRegex( def test_update_params(self):
AssertionError, 'All optimizer wrappers should have the same'):
optim_wrapper2 = AmpOptimWrapper(optimizer=self.optim2)
OptimWrapperDict(optim1=self.optim_wrapper1, optim2=optim_wrapper2)
with self.assertWarnsRegex(UserWarning, 'The `accumulative_iters` of'):
optim_wrapper2 = OptimWrapper(
optimizer=self.optim2, accumulative_iters=2)
OptimWrapperDict(optim1=self.optim_wrapper1, optim2=optim_wrapper2)
def test_accumulate_grad(self):
@contextmanager
def context_a(a, b, *args, **kwargs):
a[0] = 100
yield
a[0] = 1
@contextmanager
def context_b(a, b, *args, **kwargs):
b[0] = 200
yield
b[0] = 2
a = [0]
b = [0]
# Test enter the context both of `optim_wrapper1` and `optim_wrapper1`.
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
self.optim_wrapper1.accumulate_grad = context_a with self.assertRaisesRegex(NotImplementedError,
self.optim_wrapper2.accumulate_grad = context_b '`update_params` should be called'):
with optim_wrapper_dict.accumulate_grad(a, b, 0): optim_wrapper_dict.update_params(1)
self.assertEqual(a[0], 100)
self.assertEqual(b[0], 200)
self.assertEqual(a[0], 1) def test_backward(self):
self.assertEqual(b[0], 2) optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
with self.assertRaisesRegex(NotImplementedError,
'`backward` should be called'):
optim_wrapper_dict.backward(1)
def test_step(self):
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
with self.assertRaisesRegex(NotImplementedError,
'`step` should be called'):
optim_wrapper_dict.step()
def test_zero_grad(self):
# Test clear all grad
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
self.model1(torch.randn(1, 1)).sum().backward()
self.model2(torch.randn(1, 1)).sum().backward()
self.assertTrue((self.model1.weight.grad != 0).any())
self.assertTrue((self.model2.weight.grad != 0).any())
optim_wrapper_dict.zero_grad()
self.assertTrue((self.model1.weight.grad == 0).all())
self.assertTrue((self.model2.weight.grad == 0).all())
def test_optim_context(self):
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
with self.assertRaisesRegex(NotImplementedError,
'`optim_context` should be called'):
with optim_wrapper_dict.optim_context(self.model1):
yield
def test_initialize_count_status(self):
# Test `initialize_count_status` can be called.
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
optim_wrapper_dict.initialize_count_status(self.model1, 1, 1)
def test_param_groups(self):
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
self.assertEqual(optim_wrapper_dict.param_groups['optim1'],
self.optim1.param_groups)
self.assertEqual(optim_wrapper_dict.param_groups['optim2'],
self.optim2.param_groups)
def test_get_lr(self):
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
lr = optim_wrapper_dict.get_lr()
self.assertEqual(lr['optim1.lr'], [0.1])
self.assertEqual(lr['optim2.lr'], [0.2])
def test_get_momentum(self):
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
momentum = optim_wrapper_dict.get_momentum()
self.assertEqual(momentum['optim1.momentum'], [0.8])
self.assertEqual(momentum['optim2.momentum'], [0.9])
def test_state_dict(self): def test_state_dict(self):
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
@ -109,18 +132,6 @@ class TestOptimWrapperDict(TestCase):
list(optim_wrapper_dict.keys()), list(optim_wrapper_dict.keys()),
list(self.optimizers_wrappers.keys())) list(self.optimizers_wrappers.keys()))
def test_get_lr(self):
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
lr = optim_wrapper_dict.get_lr()
self.assertEqual(lr['optim1.lr'], [0.1])
self.assertEqual(lr['optim2.lr'], [0.2])
def test_get_momentum(self):
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
momentum = optim_wrapper_dict.get_momentum()
self.assertEqual(momentum['optim1.momentum'], [0.8])
self.assertEqual(momentum['optim2.momentum'], [0.9])
def test_getitem(self): def test_getitem(self):
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
self.assertIs(self.optimizers_wrappers['optim1'], self.assertIs(self.optimizers_wrappers['optim1'],


@ -1135,8 +1135,9 @@ class TestRunner(TestCase):
cfg.custom_hooks = [dict(type='TestEpochHook', priority=50)] cfg.custom_hooks = [dict(type='TestEpochHook', priority=50)]
cfg.train_cfg = dict(by_epoch=True, max_epochs=3, val_begin=2) cfg.train_cfg = dict(by_epoch=True, max_epochs=3, val_begin=2)
runner = Runner.from_cfg(cfg) runner = Runner.from_cfg(cfg)
runner.train() runner.train()
self.assertEqual(runner.optim_wrapper._inner_count, 12)
self.assertEqual(runner.optim_wrapper._max_counts, 12)
assert isinstance(runner.train_loop, EpochBasedTrainLoop) assert isinstance(runner.train_loop, EpochBasedTrainLoop)
@ -1183,6 +1184,8 @@ class TestRunner(TestCase):
runner = Runner.from_cfg(cfg) runner = Runner.from_cfg(cfg)
runner.train() runner.train()
self.assertEqual(runner.optim_wrapper._inner_count, 12)
self.assertEqual(runner.optim_wrapper._max_counts, 12)
assert isinstance(runner.train_loop, IterBasedTrainLoop) assert isinstance(runner.train_loop, IterBasedTrainLoop)
self.assertEqual(len(epoch_results), 1) self.assertEqual(len(epoch_results), 1)