Mirror of https://github.com/open-mmlab/mmengine.git (synced 2025-06-03 21:54:44 +08:00)
[Refactor] Refactor the gradient accumulation implementation of OptimWrapper (#284)
* merge context
* update unit test
* add docstring
* fix bug in AmpOptimWrapper
* add docstring for backward
* add warning and docstring for accumulate gradient
* fix docstring
* fix docstring
* add params_group method
* fix as comment
* fix as comment
* make the default value of loss_scale dynamic
* Fix docstring
* decouple should_update and should_sync
* rename attribute in OptimWrapper
* fix docstring
* fix comment
* fix comment
* fix as comment
* fix as comment and add unit test
parent: fd295741ca
commit: b7866021c4
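For orientation, a minimal sketch of the refactored usage this commit introduces: the old ``accumulate_grad``/``precision_context`` managers are folded into a single ``optim_context``, ``accumulative_iters`` is renamed to ``accumulative_counts``, and the iteration bookkeeping moves into ``initialize_count_status``. The toy model and training loop below are illustrative assumptions, not code taken from the commit.

import torch
import torch.nn as nn
from torch.optim import SGD

from mmengine.optim import OptimWrapper

# Illustrative toy setup (not part of this commit).
model = nn.Linear(1, 1)
dataloader = [torch.randn(2, 1) for _ in range(10)]

optimizer = SGD(model.parameters(), lr=0.1)
# Accumulate gradients over 3 iterations before each optimizer step
# (previously configured via `accumulative_iters`).
optim_wrapper = OptimWrapper(optimizer, accumulative_counts=3)
# Tell the wrapper the starting count and the total number of iterations so
# that the final, incomplete accumulation window (10 % 3 != 0) still
# triggers a parameter update with a correctly rescaled loss.
optim_wrapper.initialize_count_status(model, init_counts=0,
                                      max_counts=len(dataloader))

for inputs in dataloader:
    # `optim_context` replaces `accumulate_grad`/`precision_context`: for DDP
    # models it skips gradient synchronization on non-update iterations, and
    # `AmpOptimWrapper` additionally enables the autocast context here.
    with optim_wrapper.optim_context(model):
        loss = model(inputs).sum()
    # `update_params` scales the loss, calls `backward` (incrementing
    # `_inner_count`), and only steps/zeros the optimizer at an
    # accumulation boundary.
    optim_wrapper.update_params(loss)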
@ -111,8 +111,8 @@ class BaseModel(BaseModule):
|
||||
Returns:
|
||||
Dict[str, torch.Tensor]: A ``dict`` of tensor for logging.
|
||||
"""
|
||||
# enable automatic mixed precision training context.
|
||||
with optim_wrapper.precision_context():
|
||||
# Enable automatic mixed precision training context.
|
||||
with optim_wrapper.optim_context(self):
|
||||
batch_inputs, data_samples = self.data_preprocessor(data, True)
|
||||
losses = self(batch_inputs, data_samples, mode='loss')
|
||||
parsed_losses, log_vars = self.parse_losses(losses)
|
||||
|
@ -89,8 +89,8 @@ class MMDistributedDataParallel(DistributedDataParallel):
|
||||
Returns:
|
||||
Dict[str, torch.Tensor]: A ``dict`` of tensor for logging.
|
||||
"""
|
||||
# enable automatic mixed precision training context.
|
||||
with optim_wrapper.precision_context():
|
||||
# Enable automatic mixed precision training context.
|
||||
with optim_wrapper.optim_context(self):
|
||||
batch_inputs, data_samples = self.module.data_preprocessor(
|
||||
data, training=True)
|
||||
losses = self(batch_inputs, data_samples, mode='loss')
|
||||
|
@ -2,6 +2,7 @@
|
||||
from contextlib import contextmanager
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.cuda.amp import GradScaler
|
||||
|
||||
from mmengine.registry import OPTIM_WRAPPERS
|
||||
@ -25,15 +26,21 @@ class AmpOptimWrapper(OptimWrapper):
|
||||
loss_scale (float or str or dict): The initial configuration of
|
||||
`torch.cuda.amp.GradScaler`. See more specific arguments
|
||||
introduction at `PyTorch AMP <https://pytorch.org/docs/stable/amp.html?highlight=gradscalertorch.cuda.amp.GradScaler>`_ # noqa: E501
|
||||
Defaults to ``dynamic``.
|
||||
|
||||
- "dynamic": Initialize GradScale without any arguments.
|
||||
- float: Initialize GradScaler with ``init_scale``.
|
||||
- dict: Initialize GradScaler with more detail configuration.
|
||||
|
||||
**kwargs: Keyword arguments passed to OptimWrapper.
|
||||
|
||||
Note:
|
||||
If you use ``IterBasedRunner`` and enable gradient accumulation,
|
||||
the original `max_iters` should be multiplied by
|
||||
``accumulative_counts``.
|
||||
"""
|
||||
|
||||
def __init__(self, loss_scale=512., **kwargs):
|
||||
def __init__(self, loss_scale='dynamic', **kwargs):
|
||||
assert digit_version(TORCH_VERSION) >= digit_version('1.6.0'), (
|
||||
'`torch.cuda.amp` is only available when pytorch version >= 1.6')
|
||||
assert torch.cuda.is_available(), (
|
||||
@ -62,6 +69,7 @@ class AmpOptimWrapper(OptimWrapper):
|
||||
loss (torch.Tensor): The loss of current iteration.
|
||||
"""
|
||||
self.loss_scaler.scale(loss).backward()
|
||||
self._inner_count += 1
|
||||
|
||||
def step(self):
|
||||
"""Update parameters with :attr:`loss_scaler`."""
|
||||
@ -104,7 +112,13 @@ class AmpOptimWrapper(OptimWrapper):
|
||||
self.optimizer.load_state_dict(state_dict)
|
||||
|
||||
@contextmanager
|
||||
def precision_context(self):
|
||||
"""A wrapper of ``torch.cuda.amp.autocast``"""
|
||||
with torch.cuda.amp.autocast():
|
||||
def optim_context(self, model: nn.Module):
|
||||
"""Enables the context for mixed precision training, and enables the
|
||||
context for disabling gradient synchronization during gradient
|
||||
accumulation context.
|
||||
|
||||
Args:
|
||||
model (nn.Module): The training model.
|
||||
"""
|
||||
with super().optim_context(model), torch.cuda.amp.autocast():
|
||||
yield
|
||||
|
@ -73,7 +73,7 @@ class DefaultOptimWrapperConstructor:
|
||||
Optional fields are
|
||||
|
||||
- any arguments of the corresponding optimizer wrapper type,
|
||||
e.g., accumulative_iters, clip_grad, etc.
|
||||
e.g., accumulative_counts, clip_grad, etc.
|
||||
|
||||
The positional fields of ``optimizer`` are
|
||||
|
||||
|
@ -27,21 +27,31 @@ class OptimWrapper:
|
||||
|
||||
Args:
|
||||
optimizer (Optimizer): Optimizer used to update model parameters.
|
||||
accumulative_iters (int): The number of iterations to accumulate
|
||||
accumulative_counts (int): The number of iterations to accumulate
|
||||
gradients. The parameters will be updated per
|
||||
``accumulative_iters``.
|
||||
``accumulative_counts``.
|
||||
clip_grad (dict, optional): If ``clip_grad`` is not None, it will be
|
||||
the arguments of ``torch.nn.utils.clip_grad``.
|
||||
|
||||
Warnings:
|
||||
If ``accumulative_iters`` is larger than 1, :meth:`update_params` must
|
||||
be called in the context of ``accumulate_grad``.
|
||||
Note:
|
||||
If ``accumulative_counts`` is larger than 1, performing
:meth:`update_params` under the context of ``optim_context``
avoids unnecessary gradient synchronization.
|
||||
|
||||
Note:
|
||||
If you use ``IterBasedRunner`` and enable gradient accumulation,
|
||||
the original `max_iters` should be multiplied by
|
||||
``accumulative_counts``.
|
||||
|
||||
Note:
|
||||
The subclass should ensure that once :meth:`update_params` is called,
|
||||
``_inner_count += 1`` is automatically performed.
|
||||
|
||||
Examples:
|
||||
>>> # Config sample of OptimWrapper.
|
||||
>>> optim_wrapper_cfg = dict(
|
||||
>>> type='OptimWrapper',
|
||||
>>> accumulative_iters=3,
|
||||
>>> _accumulative_counts=1,
|
||||
>>> clip_grad=dict(max_norm=0.2))
|
||||
>>> # Use OptimWrapper to update model.
|
||||
>>> import torch.nn as nn
|
||||
@ -59,29 +69,32 @@ class OptimWrapper:
|
||||
>>> for data in dataloader:
|
||||
>>> loss = model(data)
|
||||
>>> optim_wrapper.update_params(loss)
|
||||
>>> # Enable gradient accumulation. If model is a subclass instance of
|
||||
>>> # DistributedDataParallel, ``accumulate_grad`` context manager can
|
||||
>>> # avoid unnecessary gradient synchronize.
|
||||
>>> # Enable gradient accumulation
|
||||
>>> optim_wrapper_cfg = dict(
|
||||
>>> type='OptimWrapper',
|
||||
>>> _accumulative_counts=3,
|
||||
>>> clip_grad=dict(max_norm=0.2))
|
||||
>>> ddp_model = DistributedDataParallel(model)
|
||||
>>> optimizer = SGD(ddp_model.parameters(), lr=0.1)
|
||||
>>> optim_wrapper = OptimWrapper(optimizer)
|
||||
>>> optim_wrapper.initialize_count_status(0, len(dataloader))
|
||||
>>> # If model is a subclass instance of DistributedDataParallel,
|
||||
>>> # the `optim_context` context manager can avoid unnecessary gradient
>>> # synchronization.
|
||||
>>> for iter, data in enumerate(dataloader):
|
||||
>>> with optim_wrapper.accumulate_grad(
|
||||
>>> model, iter, len(dataloader)):
|
||||
>>> with optim_wrapper.optim_context(ddp_model):
|
||||
>>> loss = model(data)
|
||||
>>> optim_wrapper.update_params(loss)
|
||||
>>> optim_wrapper.update_params(loss)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
optimizer: Optimizer,
|
||||
accumulative_iters: int = 1,
|
||||
accumulative_counts: int = 1,
|
||||
clip_grad: Optional[dict] = None):
|
||||
assert accumulative_iters > 0, (
|
||||
'accumulative_iters at least greater than or equal to 1')
|
||||
self.accumulative_iters = accumulative_iters
|
||||
# `max_iters` and `cur_iter` is only valid in gradient accumulative
|
||||
# mode (`accumulative_iters` > 1). `cur_iter` and `max_iter` will be
|
||||
# updated in the ``accumulate_grad`` context that is enabled in
|
||||
# `runner.train_loop`.
|
||||
self.cur_iter = 0
|
||||
self.max_iters = 0
|
||||
assert accumulative_counts > 0, (
|
||||
'accumulative_counts should be greater than or equal to 1')
|
||||
self._accumulative_counts = accumulative_counts
|
||||
|
||||
assert isinstance(optimizer, Optimizer), (
|
||||
'optimizer must be a `torch.optim.Optimizer` instance, but got '
|
||||
f'{type(optimizer)}')
|
||||
@ -90,13 +103,27 @@ class OptimWrapper:
|
||||
if clip_grad is not None:
|
||||
# clip_grad_kwargs should not be non-empty dict.
|
||||
assert isinstance(clip_grad, dict) and clip_grad, (
|
||||
'If `clip_grad_kwargs` is not None, it should be a `dict` '
|
||||
'If `clip_grad` is not None, it should be a `dict` '
|
||||
'which is the arguments of `torch.nn.utils.clip_grad`')
|
||||
self.clip_grad_kwargs = clip_grad
|
||||
self.logger = MMLogger.get_current_instance()
|
||||
# Used to update `grad_norm` log message.
|
||||
self.message_hub = MessageHub.get_current_instance()
|
||||
self.iter_status_initialized = False
|
||||
self._inner_count = 0
|
||||
# `_max_counts` means the total number of parameter updates. It
|
||||
# ensures that the gradient of the last few iterations will not be
|
||||
# lost when the `_max_counts` is not divisible by
|
||||
# `accumulative_counts`.
|
||||
self._max_counts = -1
|
||||
# If `_inner_count` is smaller than `_divisible_counts`, the loss
|
||||
# factor used for gradient accumulation should be the same as
|
||||
# `_accumulative_counts`. If `_max_counts` has not been initialized,
|
||||
# the loss factor will always be the same as `_accumulative_counts`.
|
||||
self._divisible_counts = -1
|
||||
# The `_remainder_counts` is used for calculating the loss factor at the
|
||||
# last few iterations. If `_max_counts` has not been initialized,
|
||||
# the loss factor will always be the same as `_accumulative_counts`.
|
||||
self._remainder_counts = -1
|
||||
|
||||
def update_params(self, loss: torch.Tensor) -> None:
|
||||
"""Update parameters in :attr:`optimizer`.
|
||||
@ -104,36 +131,12 @@ class OptimWrapper:
|
||||
Args:
|
||||
loss (torch.Tensor): A tensor for back propagation.
|
||||
"""
|
||||
if self.accumulative_iters == 1:
|
||||
# update parameters without gradient accumulation. The gradient
|
||||
# should not be rescaled and `loss_factor=1`.
|
||||
loss_factor = 1
|
||||
else:
|
||||
# gradient accumulation must be called in the context of
|
||||
# ``accumulate_grad``.
|
||||
assert hasattr(self, 'divisible_iters'), (
|
||||
'gradient accumulation must be performed in the context of'
|
||||
'`OptimWrapper.accumulate_grad`')
|
||||
# if `self.accumulative_iters > 1`, the gradient needs to be
|
||||
# rescaled and accumulated. In most cases, `loss_factor` equals to
|
||||
# `self.accumulative_iters`. However `self.max_iters` may not be
|
||||
# divisible `self.by accumulative_iters`, so the `loss_scale` for
|
||||
# the last few iterations needs to be recalculated.
|
||||
if self.cur_iter < self.divisible_iters:
|
||||
loss_factor = self.accumulative_iters
|
||||
else:
|
||||
loss_factor = self.remainder_iters
|
||||
assert loss_factor > 0, (
|
||||
'loss_factor should be larger than zero! This error could '
|
||||
'happened when gradient accumulation context enabled with an '
|
||||
'error `cur_iter` or `max_iters` please check your loop')
|
||||
|
||||
loss = loss / loss_factor
|
||||
loss = self.scale_loss(loss)
|
||||
self.backward(loss)
|
||||
# Update parameters only if `self.cur_iter` is divisible by
|
||||
# `self.accumulative_iters` or `self.cur_iter` equals to
|
||||
# `self.max_iters`
|
||||
if self._should_update(self.cur_iter, self.max_iters):
|
||||
# Update parameters only if `self._inner_count` is divisible by
|
||||
# `self._accumulative_counts` or `self._inner_count` equals to
|
||||
# `self._max_counts`
|
||||
if self.should_update():
|
||||
self.step()
|
||||
self.zero_grad()
|
||||
|
||||
@ -145,10 +148,15 @@ class OptimWrapper:
|
||||
required logic. For example, ``torch.cuda.amp`` require some extra
|
||||
operation on GradScaler during backward process.
|
||||
|
||||
Note:
|
||||
If subclasses inherit from ``OptimWrapper`` override
|
||||
``backward``, ``_inner_count +=1`` must be implemented.
|
||||
|
||||
Args:
|
||||
loss (torch.Tensor): The loss of current iteration.
|
||||
"""
|
||||
loss.backward()
|
||||
self._inner_count += 1
|
||||
|
||||
def zero_grad(self) -> None:
|
||||
"""A wrapper of ``Optimizer.zero_grad``.
|
||||
@ -218,7 +226,7 @@ class OptimWrapper:
|
||||
Provide unified interface to get learning rate of optimizer.
|
||||
|
||||
Returns:
|
||||
List[float]: Learning rate of the optimizer.
|
||||
Dict[str, List[float]]: Learning rate of the optimizer.
|
||||
"""
|
||||
lr = [group['lr'] for group in self.param_groups]
|
||||
return dict(lr=lr)
|
||||
@ -229,7 +237,7 @@ class OptimWrapper:
|
||||
Provide unified interface to get momentum of optimizer.
|
||||
|
||||
Returns:
|
||||
List[float]: Momentum of the optimizer.
|
||||
Dict[str, List[float]]: Momentum of the optimizer.
|
||||
"""
|
||||
momentum = []
|
||||
for group in self.param_groups:
|
||||
@ -244,49 +252,35 @@ class OptimWrapper:
|
||||
return dict(momentum=momentum)
|
||||
|
||||
@contextmanager
|
||||
def accumulate_grad(self, model: nn.Module, cur_iter: int, max_iters: int):
|
||||
"""A Context manager for gradient accumulation and avoiding unnecessary
|
||||
gradient synchronization during gradient accumulation.
|
||||
def optim_context(self, model: nn.Module):
|
||||
"""A Context for gradient accumulation and automatic mix precision
|
||||
training.
|
||||
|
||||
If subclasses need to enable the context for mix precision training,
|
||||
e.g., ``:class:`AmpOptimWrapper``, the corresponding context should be
|
||||
enabled in `optim_context`. Since ``OptimWrapper`` uses default fp32
|
||||
training, ``optim_context`` will only enable the context for
|
||||
blocking the unnecessary gradient synchronization during gradient
|
||||
accumulation
|
||||
|
||||
If model is an instance with ``no_sync`` method (which means
|
||||
blocking the gradient synchronization) and
|
||||
``self.accumulative_iters != 1``. The model will not automatically
|
||||
``self._accumulative_counts != 1``. The model will not automatically
|
||||
synchronize gradients if ``cur_iter`` is divisible by
|
||||
``self.accumulative_iters``. Otherwise, this method will enable an
|
||||
``self._accumulative_counts``. Otherwise, this method will enable an
|
||||
empty context.
|
||||
|
||||
Warnings:
|
||||
This context manager must be enabled if you want to use
|
||||
gradient accumulation.
|
||||
|
||||
Args:
|
||||
model (nn.Module): The training model.
|
||||
cur_iter (int): Current iteration during training process.
|
||||
max_iters (int): Maximum training iteration.
|
||||
"""
|
||||
assert max_iters > 0, '`max_iters` must be larger than zero'
|
||||
self.cur_iter = cur_iter
|
||||
self.max_iters = max_iters
|
||||
if not self.iter_status_initialized:
|
||||
self._initilize_iter_status(model)
|
||||
# During gradient accumulation process, the gradient synchronize
|
||||
# should only happen before updating parameters.
|
||||
if (not self._should_update(cur_iter, max_iters)
|
||||
and hasattr(model, 'no_sync')):
|
||||
if not self.should_sync() and hasattr(model, 'no_sync'):
|
||||
with model.no_sync():
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
|
||||
@contextmanager
|
||||
def precision_context(self):
|
||||
"""precision context which enables an empty context by default.
|
||||
|
||||
The subclass used for mixed or low precision training needs to override
|
||||
this method.
|
||||
"""
|
||||
yield
|
||||
|
||||
def _clip_grad(self) -> None:
|
||||
"""Clip the gradients of parameters."""
|
||||
params: List[torch.Tensor] = []
|
||||
@ -300,50 +294,117 @@ class OptimWrapper:
|
||||
**self.clip_grad_kwargs)
|
||||
self.message_hub.update_scalar('train/grad_norm', float(grad_norm))
|
||||
|
||||
def _initilize_iter_status(self, model: nn.Module) -> None:
|
||||
def initialize_count_status(self, model: nn.Module, init_counts: int,
|
||||
max_counts: int) -> None:
|
||||
"""Initialize gradient accumulation related attributes.
|
||||
|
||||
``OptimWrapper`` can be used without calling
``initialize_count_status``. However, consider the case of
``len(dataloader) == 10`` and ``accumulative_counts == 3``. Since 10 is
not divisible by 3, the last iteration will not trigger
``optimizer.step()``, resulting in one less parameter update.
|
||||
|
||||
Args:
|
||||
model (nn.Module): Training model
|
||||
init_counts (int): The initial value of the inner count.
|
||||
max_counts (int): The maximum value of the inner count.
|
||||
"""
|
||||
if self.max_iters % self.accumulative_iters != 0:
|
||||
self._inner_count = init_counts
|
||||
self._max_counts = max_counts
|
||||
if self._inner_count % self._accumulative_counts != 0:
|
||||
self.logger.warning(
|
||||
'Resume iter number is not divisible by accumulative_iters in '
|
||||
'GradientCumulativeOptimizerHook, which means the gradient of '
|
||||
'some iters is lost and the result may be influenced slightly.'
|
||||
)
|
||||
'Resumed iteration number is not divisible by '
|
||||
'`_accumulative_counts` in `GradientCumulativeOptimizerHook`, '
|
||||
'which means the gradient of some iterations is lost and the '
|
||||
'result may be influenced slightly.')
|
||||
|
||||
if has_batch_norm(model) and self.accumulative_iters > 1:
|
||||
if has_batch_norm(model) and self._accumulative_counts > 1:
|
||||
self.logger.warning(
|
||||
'Gradient accumulative may slightly decrease '
|
||||
'performance because the model has BatchNorm layers.')
|
||||
residual_iters = self.max_iters - self.cur_iter
|
||||
residual_counts = max_counts - init_counts
|
||||
# The maximum number of training iteration that is divisible by
|
||||
# accumulative_iters.
|
||||
self.divisible_iters = (
|
||||
residual_iters // self.accumulative_iters *
|
||||
self.accumulative_iters)
|
||||
# Remainder of ``self.max_iters`` divided by ``self.max_iters``
|
||||
self.remainder_iters = residual_iters - self.divisible_iters
|
||||
self.iter_status_initialized = True
|
||||
# `_accumulative_counts`.
|
||||
self._divisible_counts = (
|
||||
residual_counts // self._accumulative_counts *
|
||||
self._accumulative_counts)
|
||||
# Remainder of `_max_counts` divided by `_accumulative_counts`
|
||||
self._remainder_counts = residual_counts - self._divisible_counts
|
||||
|
||||
def _should_update(self, cur_iter: int, max_iters: int) -> bool:
|
||||
"""Should optim_wrapper update parameters or synchronized gradient at
|
||||
current iteration.
|
||||
def should_update(self) -> bool:
|
||||
"""Decide whether the parameters should be updated at the current
|
||||
iteration.
|
||||
|
||||
Args:
|
||||
cur_iter (int): Current iteration of training process.
|
||||
max_iters (int): Maximum iterations of training process.
|
||||
Called by :meth:`update_params` to check whether the optimizer
wrapper should update parameters at the current iteration.
|
||||
|
||||
Returns:
|
||||
bool: Whether to update parameters or synchronized gradient.
|
||||
bool: Whether to update parameters.
|
||||
"""
|
||||
return ((cur_iter + 1) % self.accumulative_iters == 0
|
||||
or cur_iter + 1 == max_iters)
|
||||
return (self._inner_count % self._accumulative_counts == 0
|
||||
or self._inner_count == self._max_counts)
|
||||
|
||||
def should_sync(self) -> bool:
|
||||
"""Decide whether the automatic gradient synchronization should be
|
||||
allowed at the current iteration.
|
||||
|
||||
It takes effect when gradient accumulation is used to skip
|
||||
synchronization at the iterations where the parameter is not updated.
|
||||
|
||||
Since ``should_sync`` is called by :meth:`optim_context` before
:meth:`backward` runs, ``self._inner_count`` has not been incremented
yet. Therefore, the comparison here manually adds 1 to
``self._inner_count``.
|
||||
|
||||
Returns:
|
||||
bool: Whether to allow the automatic gradient synchronization.
|
||||
"""
|
||||
return ((self._inner_count + 1) % self._accumulative_counts == 0
|
||||
or (self._inner_count + 1) == self._max_counts)
|
||||
|
||||
def scale_loss(self, loss: torch.Tensor) -> torch.Tensor:
|
||||
"""Get scaled loss according to ``_accumulative_counts``,
|
||||
``_inner_count`` and ``_max_counts``.
|
||||
|
||||
Args:
|
||||
loss (torch.Tensor): Original loss calculated by model.
|
||||
|
||||
Returns:
|
||||
loss (torch.Tensor): Scaled loss.
|
||||
"""
|
||||
if self._accumulative_counts == 1:
|
||||
# update parameters without gradient accumulation. The gradient
|
||||
# should not be rescaled and `loss_factor=1`.
|
||||
loss_factor = 1
|
||||
elif self._max_counts == -1:
|
||||
loss_factor = self._accumulative_counts
|
||||
else:
|
||||
# if `self._accumulative_counts > 1`, the gradient needs to be
|
||||
# rescaled and accumulated. In most cases, `loss_factor` equals to
|
||||
# `self._accumulative_counts`. However, `self._max_counts` may not
|
||||
# be divisible by `self._accumulative_counts`, so the
|
||||
# `loss_scale` for the last few iterations needs to be
|
||||
# recalculated.
|
||||
if self._inner_count < self._divisible_counts:
|
||||
loss_factor = self._accumulative_counts
|
||||
else:
|
||||
loss_factor = self._remainder_counts
|
||||
assert loss_factor > 0, (
|
||||
'loss_factor should be larger than zero! This error could '
'happen when `initialize_count_status` is called with a wrong '
'`init_counts` or `max_counts`')
|
||||
|
||||
loss = loss / loss_factor
|
||||
return loss
|
||||
|
||||
@property
|
||||
def inner_count(self):
|
||||
"""Get the number of updating parameters of optimizer wrapper."""
|
||||
return self._inner_count
|
||||
|
||||
def __repr__(self):
|
||||
wrapper_info = f'Type: {type(self).__name__}\n' \
|
||||
f'accumulative_iters: {self.accumulative_iters}\n' \
|
||||
f'optimizer: \n'
|
||||
wrapper_info = (f'Type: {type(self).__name__}\n'
|
||||
f'_accumulative_counts: {self._accumulative_counts}\n'
|
||||
'optimizer: \n')
|
||||
optimizer_str = repr(self.optimizer) + '\n'
|
||||
return wrapper_info + optimizer_str
|
||||
|
@ -1,6 +1,5 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
from contextlib import ExitStack, contextmanager
|
||||
from contextlib import contextmanager
|
||||
from typing import Dict, Iterator, List, Tuple
|
||||
|
||||
import torch
|
||||
@ -41,24 +40,10 @@ class OptimWrapperDict(OptimWrapper):
|
||||
"""
|
||||
|
||||
def __init__(self, **optim_wrapper_dict: OptimWrapper):
|
||||
first_key = next(iter(optim_wrapper_dict))
|
||||
first_optim_wrapper = optim_wrapper_dict[first_key]
|
||||
assert isinstance(first_optim_wrapper, OptimWrapper), (
|
||||
'Each argument of `OptimWrapperDict` must be an `OptimWrapper '
|
||||
'instance`')
|
||||
optim_wrapper_class = type(first_optim_wrapper)
|
||||
for key, value in optim_wrapper_dict.items():
|
||||
assert type(value) == optim_wrapper_class, (
|
||||
f'All optimizer wrappers should have the same type, but found'
|
||||
f' {key}: {type(value)} and {first_key}: {optim_wrapper_class}'
|
||||
)
|
||||
if value.accumulative_iters != 1:
|
||||
warnings.warn(
|
||||
f'The `accumulative_iters` of {key} is '
|
||||
f'{value.accumulative_iters}. OptimWrapperDict '
|
||||
'will not enable any `accumulate_grad` context of its '
|
||||
'optimizer wrappers. You should access the corresponding '
|
||||
'optimizer wrapper to enable the context.')
|
||||
assert isinstance(value, OptimWrapper), (
|
||||
'`OptimWrapperDict` only accept OptimWrapper instance, '
|
||||
f'but got {key}: {type(value)}')
|
||||
self.optim_wrappers = optim_wrapper_dict
|
||||
|
||||
def update_params(self, loss: torch.Tensor) -> None:
|
||||
@ -69,9 +54,8 @@ class OptimWrapperDict(OptimWrapper):
|
||||
Therefore, this method is not implemented. The optimizer wrapper of
|
||||
OptimWrapperDict should be accessed to call its ``update_params``.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
'You should access the OptimWrapper of the '
|
||||
'OptimWrapperDict and call its `update_params`')
|
||||
raise NotImplementedError('`update_params` should be called by each '
|
||||
'optimizer separately`')
|
||||
|
||||
def backward(self, loss: torch.Tensor) -> None:
|
||||
"""Since OptimWrapperDict doesn't know which optimizer wrapper's
|
||||
@ -81,14 +65,14 @@ class OptimWrapperDict(OptimWrapper):
|
||||
The optimizer wrapper of OptimWrapperDict should be accessed to call
its ``backward``.
|
||||
"""
|
||||
raise NotImplementedError('You should access the OptimWrapper of the '
|
||||
'OptimWrapperDict and call its `backward`')
|
||||
raise NotImplementedError('`backward` should be called by each '
|
||||
'optimizer separately`')
|
||||
|
||||
def step(self) -> None:
|
||||
"""Since the backward method is not implemented, the step should not be
|
||||
implemented either."""
|
||||
raise NotImplementedError('You should access the OptimWrapper of the '
|
||||
'OptimWrapperDict and call its `step`')
|
||||
raise NotImplementedError('`step` should be called by each '
|
||||
'optimizer separately`')
|
||||
|
||||
def zero_grad(self) -> None:
|
||||
"""Set the gradients of all optimizer wrappers to zero."""
|
||||
@ -96,49 +80,29 @@ class OptimWrapperDict(OptimWrapper):
|
||||
optim_wrapper.zero_grad()
|
||||
|
||||
@contextmanager
|
||||
def precision_context(self):
|
||||
optim_wrapper = next(iter(self.optim_wrappers.values()))
|
||||
with optim_wrapper.precision_context():
|
||||
yield
|
||||
def optim_context(self, model: nn.Module):
|
||||
"""``optim_context`` should be called by each optimizer separately."""
|
||||
raise NotImplementedError(
|
||||
'`optim_context` should be called by each optimizer separately')
|
||||
|
||||
@contextmanager
|
||||
def accumulate_grad(self, model: nn.Module, cur_iter: int, max_iters: int):
|
||||
"""Enable ``accumulate_grad`` contexts of all optimizer wrappers.
|
||||
def initialize_count_status(self, model: nn.Module, cur_iter,
|
||||
max_iters) -> None:
|
||||
"""Do nothing but provide unified interface for :obj:`OptimWrapper`
|
||||
|
||||
Warning:
|
||||
Consider there is only one ``model`` arguments for all
|
||||
optimizer wrappers, all optimizer wrappers are working under the
|
||||
same ``model.no_sync`` context. For example, there is a model
|
||||
composed of model_a(optimizer_a) and model_b(optimizer_b).
|
||||
``OptimWrapperDict.accumulate_grad`` will further
|
||||
call ``model.no_sync``, which will block the gradient
|
||||
synchronization of both a and b. If optimizer_a and
|
||||
optimizer_b have different ``accumulative_iters``, and want to
|
||||
block the gradient synchronization of model_a and model_b
|
||||
separately, the model should not implement the ``no_sync``
|
||||
method(or enable an empty context). The ``accumulate_grad`` context
|
||||
should be enabled inside the model by accessing corresponding
|
||||
optimizer wrapper.
|
||||
Since ``OptimWrapperDict`` does not know the correspondence between
models and optimizer wrappers, ``initialize_count_status`` does nothing
here, and each optimizer wrapper should call ``initialize_count_status``
separately.
|
||||
"""
|
||||
with ExitStack() as stack:
|
||||
for optim_wrapper in self.optim_wrappers.values():
|
||||
stack.enter_context(
|
||||
optim_wrapper.accumulate_grad(model, cur_iter, max_iters))
|
||||
yield
|
||||
return
|
||||
|
||||
def load_state_dict(self, state_dict: dict) -> None:
|
||||
"""Load the state dictionary from the ``state_dict``.
|
||||
|
||||
Args:
|
||||
state_dict (dict): Each key-value pair in `state_dict` represents
|
||||
the name and the state dictionary of corresponding
|
||||
:obj:`OptimWrapper`.
|
||||
"""
|
||||
for name, _state_dict in state_dict.items():
|
||||
assert name in self.optim_wrappers, (
|
||||
f'Mismatched `state_dict`! cannot found {name} in '
|
||||
'OptimWrapperDict')
|
||||
self.optim_wrappers[name].load_state_dict(_state_dict)
|
||||
@property
|
||||
def param_groups(self):
|
||||
"""Returns the parameter groups of each OptimWrapper."""
|
||||
param_groups = dict()
|
||||
for key, value in self.optim_wrappers.items():
|
||||
param_groups[key] = value.param_groups
|
||||
return param_groups
|
||||
|
||||
def get_lr(self) -> Dict[str, List[float]]:
|
||||
"""Get the learning rate of all optimizers.
|
||||
@ -175,6 +139,20 @@ class OptimWrapperDict(OptimWrapper):
|
||||
state_dict[name] = optim_wrapper.state_dict()
|
||||
return state_dict
|
||||
|
||||
def load_state_dict(self, state_dict: dict) -> None:
|
||||
"""Load the state dictionary from the ``state_dict``.
|
||||
|
||||
Args:
|
||||
state_dict (dict): Each key-value pair in `state_dict` represents
|
||||
the name and the state dictionary of corresponding
|
||||
:obj:`OptimWrapper`.
|
||||
"""
|
||||
for name, _state_dict in state_dict.items():
|
||||
assert name in self.optim_wrappers, (
|
||||
f'Mismatched `state_dict`! cannot found {name} in '
|
||||
'OptimWrapperDict')
|
||||
self.optim_wrappers[name].load_state_dict(_state_dict)
|
||||
|
||||
def items(self) -> Iterator[Tuple[str, OptimWrapper]]:
|
||||
"""A generator to get the name and corresponding
|
||||
:obj:`OptimWrapper`"""
|
||||
|
@ -101,12 +101,9 @@ class EpochBasedTrainLoop(BaseLoop):
|
||||
'before_train_iter', batch_idx=idx, data_batch=data_batch)
|
||||
# Enable gradient accumulation mode and avoid unnecessary gradient
|
||||
# synchronization during gradient accumulation process.
|
||||
with self.runner.optim_wrapper.accumulate_grad(self.runner.model,
|
||||
self._iter,
|
||||
self._max_iters):
|
||||
# outputs should be a dict of loss.
|
||||
outputs = self.runner.model.train_step(
|
||||
data_batch, optim_wrapper=self.runner.optim_wrapper)
|
||||
# outputs should be a dict of loss.
|
||||
outputs = self.runner.model.train_step(
|
||||
data_batch, optim_wrapper=self.runner.optim_wrapper)
|
||||
|
||||
self.runner.call_hook(
|
||||
'after_train_iter',
|
||||
@ -204,19 +201,16 @@ class IterBasedTrainLoop(BaseLoop):
|
||||
'before_train_iter', batch_idx=self._iter, data_batch=data_batch)
|
||||
# Enable gradient accumulation mode and avoid unnecessary gradient
|
||||
# synchronization during gradient accumulation process.
|
||||
with self.runner.optim_wrapper.accumulate_grad(self.runner.model,
|
||||
self._iter,
|
||||
self._max_iters):
|
||||
# train_logs should be a dict of loss.
|
||||
train_logs = self.runner.model.train_step(
|
||||
data_batch, optim_wrapper=self.runner.optim_wrapper)
|
||||
self.runner.message_hub.update_info('train_logs', train_logs)
|
||||
# outputs should be a dict of loss.
|
||||
outputs = self.runner.model.train_step(
|
||||
data_batch, optim_wrapper=self.runner.optim_wrapper)
|
||||
self.runner.message_hub.update_info('train_logs', outputs)
|
||||
|
||||
self.runner.call_hook(
|
||||
'after_train_iter',
|
||||
batch_idx=self._iter,
|
||||
data_batch=data_batch,
|
||||
outputs=train_logs)
|
||||
outputs=outputs)
|
||||
self._iter += 1
|
||||
|
||||
|
||||
|
@ -926,7 +926,7 @@ class Runner:
|
||||
>>> optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
|
||||
>>> optim_wrapper
|
||||
Type: OptimWrapper
|
||||
accumulative_iters: 1
|
||||
accumulative_counts: 1
|
||||
optimizer:
|
||||
SGD (
|
||||
Parameter Group 0
|
||||
@ -941,7 +941,7 @@ class Runner:
|
||||
>>> optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
|
||||
>>> optim_wrapper
|
||||
Type: OptimWrapper
|
||||
accumulative_iters: 1
|
||||
accumulative_counts: 1
|
||||
optimizer:
|
||||
SGD (
|
||||
Parameter Group 0
|
||||
@ -965,7 +965,7 @@ class Runner:
|
||||
>>> optim_wrapper
|
||||
name: generator
|
||||
Type: OptimWrapper
|
||||
accumulative_iters: 1
|
||||
accumulative_counts: 1
|
||||
optimizer:
|
||||
SGD (
|
||||
Parameter Group 0
|
||||
@ -977,7 +977,7 @@ class Runner:
|
||||
)
|
||||
name: discriminator
|
||||
Type: OptimWrapper
|
||||
accumulative_iters: 1
|
||||
accumulative_counts: 1
|
||||
optimizer:
|
||||
'discriminator': Adam (
|
||||
Parameter Group 0
|
||||
@ -1528,7 +1528,6 @@ class Runner:
|
||||
# `build_optimizer` should be called before `build_param_scheduler`
|
||||
# because the latter depends on the former
|
||||
self.optim_wrapper = self.build_optim_wrapper(self.optim_wrapper)
|
||||
|
||||
# Automatically scaling lr by linear scaling rule
|
||||
self.scale_lr(self.optim_wrapper, self.auto_scale_lr)
|
||||
|
||||
@ -1540,9 +1539,14 @@ class Runner:
|
||||
self._val_loop = self.build_val_loop(
|
||||
self._val_loop) # type: ignore
|
||||
|
||||
# TODO: add a contextmanager to avoid calling `before_run` many times
|
||||
self.call_hook('before_run')
|
||||
# Initiate inner count of `optim_wrapper`.
|
||||
self.optim_wrapper.initialize_count_status(
|
||||
self.model,
|
||||
self._train_loop.iter, # type: ignore
|
||||
self._train_loop.max_iters) # type: ignore
|
||||
|
||||
# TODO: add a contextmanager to avoid calling `before_run` many times
|
||||
# make sure checkpoint-related hooks are triggered after `before_run`
|
||||
self.load_or_resume()
|
||||
|
||||
|
@ -8,9 +8,10 @@ import torch.distributed as torch_dist
|
||||
import torch.nn as nn
|
||||
from torch.optim import SGD
|
||||
|
||||
from mmengine.dist import all_gather
|
||||
from mmengine.model import (BaseModel, MMDistributedDataParallel,
|
||||
MMSeparateDistributedDataParallel)
|
||||
from mmengine.optim import OptimWrapper, OptimWrapperDict
|
||||
from mmengine.optim import AmpOptimWrapper, OptimWrapper, OptimWrapperDict
|
||||
from mmengine.testing import assert_allclose
|
||||
from mmengine.testing._internal import MultiProcessTestCase
|
||||
|
||||
@ -23,9 +24,9 @@ class ToyModel(BaseModel):
|
||||
self.conv2 = nn.Conv2d(1, 1, 1)
|
||||
|
||||
def forward(self, x, data_samples=None, mode='tensor'):
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
if mode == 'loss':
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
return dict(loss=x)
|
||||
elif mode == 'predict':
|
||||
return x
|
||||
@ -58,24 +59,41 @@ class ComplexModel(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class TestModelWrapper(MultiProcessTestCase):
|
||||
class TestDistributedDataParallel(MultiProcessTestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
self._spawn_processes()
|
||||
|
||||
@unittest.skipIf(
|
||||
not torch.cuda.is_available(), reason='cuda should be available')
|
||||
def test_train_step(self):
|
||||
self._init_dist_env(self.rank, self.world_size)
|
||||
# Test `optim_wrapper` is a instance of `OptimWrapper`
|
||||
model = ToyModel()
|
||||
# Mixed precision training and skipping gradient synchronization should
# work at the same time.
|
||||
model = ToyModel().cuda()
|
||||
ddp_model = MMDistributedDataParallel(module=model)
|
||||
optimizer = SGD(ddp_model.parameters(), lr=0)
|
||||
optim_wrapper = OptimWrapper(optimizer, accumulative_iters=1)
|
||||
inputs = torch.randn(3, 1, 1) * self.rank * 255
|
||||
optim_wrapper = AmpOptimWrapper(
|
||||
optimizer=optimizer, accumulative_counts=3)
|
||||
inputs = torch.randn(3, 1, 1).cuda() * self.rank * 255
|
||||
data = dict(inputs=inputs, data_sample=MagicMock())
|
||||
res = ddp_model.train_step([data], optim_wrapper=optim_wrapper)['loss']
|
||||
self.assertIs(res.dtype, torch.float16)
|
||||
grad = ddp_model.module.conv1.weight.grad
|
||||
all_grads = all_gather(grad)
|
||||
with self.assertRaises(AssertionError):
|
||||
assert_allclose(all_grads[0], all_grads[1])
|
||||
|
||||
# Gradient accumulation
|
||||
ddp_model.train_step([data], optim_wrapper=optim_wrapper)
|
||||
|
||||
# Test update params and clean grads.
|
||||
ddp_model.train_step([data], optim_wrapper=optim_wrapper)
|
||||
grad = ddp_model.module.conv1.weight.grad
|
||||
assert_allclose(grad, torch.zeros_like(grad))
|
||||
all_grads = all_gather(grad)
|
||||
assert_allclose(all_grads[0], torch.zeros_like(all_grads[0]))
|
||||
assert_allclose(all_grads[1], torch.zeros_like(all_grads[0]))
|
||||
|
||||
def test_val_step(self):
|
||||
self._init_dist_env(self.rank, self.world_size)
|
||||
@ -107,7 +125,7 @@ class TestModelWrapper(MultiProcessTestCase):
|
||||
|
||||
@unittest.skipIf(
|
||||
not torch.cuda.is_available(), reason='cuda should be available')
|
||||
class TestMMSeparateDistributedDataParallel(TestModelWrapper):
|
||||
class TestMMSeparateDistributedDataParallel(TestDistributedDataParallel):
|
||||
|
||||
def test_train_step(self):
|
||||
self._init_dist_env(self.rank, self.world_size)
|
||||
|
@ -45,7 +45,7 @@ class ToyModel2(nn.Module):
|
||||
|
||||
|
||||
class TestOptimWrapper(MultiProcessTestCase):
|
||||
# Test `OptimWrapper.accumulate_grad` will block the gradient
|
||||
# Test `OptimWrapper.optim_context` will block the gradient
|
||||
# synchronization when using gradient accumulation strategy in distributed
|
||||
# data parallel training.
|
||||
def setUp(self) -> None:
|
||||
@ -61,94 +61,84 @@ class TestOptimWrapper(MultiProcessTestCase):
|
||||
|
||||
def test_init(self):
|
||||
optim_wrapper = OptimWrapper(self.optimizer)
|
||||
self.assertEqual(optim_wrapper.optimizer, self.optimizer)
|
||||
self.assertIs(optim_wrapper.optimizer, self.optimizer)
|
||||
self.assertIsNone(optim_wrapper.clip_grad_kwargs)
|
||||
self.assertEqual(optim_wrapper.accumulative_iters, 1)
|
||||
self.assertEqual(optim_wrapper._accumulative_counts, 1)
|
||||
self.assertIs(optim_wrapper.logger, self.logger)
|
||||
self.assertIs(optim_wrapper.message_hub, self.message_hub)
|
||||
self.assertEqual(optim_wrapper._inner_count, 0)
|
||||
self.assertEqual(optim_wrapper._max_counts, -1)
|
||||
self.assertEqual(optim_wrapper._divisible_counts, -1)
|
||||
self.assertEqual(optim_wrapper._remainder_counts, -1)
|
||||
|
||||
with self.assertRaisesRegex(AssertionError,
|
||||
'If `clip_grad_kwargs` is not None'):
|
||||
'If `clip_grad` is not None'):
|
||||
OptimWrapper(self.optimizer, clip_grad=[])
|
||||
|
||||
def test_update_params(self):
|
||||
# Test update params every iteration.
|
||||
optim_wrapper = OptimWrapper(self.optimizer, accumulative_iters=1)
|
||||
optim_wrapper = OptimWrapper(self.optimizer, accumulative_counts=1)
|
||||
self._mock_method(optim_wrapper)
|
||||
loss = torch.tensor(1)
|
||||
optim_wrapper.update_params(loss)
|
||||
optim_wrapper.backward.assert_called_with(torch.tensor(1))
|
||||
self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1))
|
||||
optim_wrapper.step.assert_called_with()
|
||||
optim_wrapper.zero_grad.assert_called_with()
|
||||
|
||||
with optim_wrapper.accumulate_grad(self.model, 2, 100):
|
||||
optim_wrapper.update_params(torch.tensor(1))
|
||||
optim_wrapper.backward.assert_called_with(torch.tensor(1))
|
||||
optim_wrapper.step.assert_called_with()
|
||||
optim_wrapper.zero_grad.assert_called_with()
|
||||
|
||||
# It will raise an error if `accumulative_iters > 1` and
|
||||
# `accumulate_grad` is not enabled.
|
||||
optim_wrapper = OptimWrapper(self.optimizer, accumulative_iters=3)
|
||||
# Test gradient accumulation.
|
||||
optim_wrapper = OptimWrapper(self.optimizer, accumulative_counts=3)
|
||||
self._mock_method(optim_wrapper)
|
||||
with self.assertRaisesRegex(AssertionError,
|
||||
'gradient accumulation must be'):
|
||||
optim_wrapper.update_params(loss)
|
||||
# `iter=0`, accumulate gradient and do not update params.
|
||||
loss = torch.tensor(1)
|
||||
optim_wrapper.update_params(loss)
|
||||
self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1) / 3)
|
||||
optim_wrapper.step.assert_not_called()
|
||||
optim_wrapper.zero_grad.assert_not_called()
|
||||
|
||||
# `iter=0`, Call `optimizer_step` first time.
|
||||
with optim_wrapper.accumulate_grad(
|
||||
self.model, cur_iter=0, max_iters=100):
|
||||
loss = torch.tensor(1)
|
||||
optim_wrapper.update_params(loss)
|
||||
optim_wrapper.backward.assert_called_with(torch.tensor(1) / 3)
|
||||
optim_wrapper.step.assert_not_called()
|
||||
optim_wrapper.zero_grad.assert_not_called()
|
||||
# gradient accumulate
|
||||
optim_wrapper.update_params(loss)
|
||||
self.assertEqual(optim_wrapper._inner_count, 2)
|
||||
|
||||
# `iter=2`, Call `optimizer_step` first time.
|
||||
with optim_wrapper.accumulate_grad(
|
||||
self.model, cur_iter=2, max_iters=100):
|
||||
optim_wrapper.update_params(loss)
|
||||
optim_wrapper.step.assert_called()
|
||||
optim_wrapper.zero_grad.assert_called()
|
||||
# `iter=2`, update params.
|
||||
optim_wrapper.update_params(loss)
|
||||
optim_wrapper.step.assert_called()
|
||||
optim_wrapper.zero_grad.assert_called()
|
||||
self._mock_method(optim_wrapper)
|
||||
# Test end of training.
|
||||
with optim_wrapper.accumulate_grad(
|
||||
self.model, cur_iter=99, max_iters=100):
|
||||
optim_wrapper.update_params(loss)
|
||||
optim_wrapper.step.assert_called()
|
||||
optim_wrapper.zero_grad.assert_called()
|
||||
optim_wrapper.backward.assert_called_with(1)
|
||||
|
||||
# If ``accumulative_iters > 1``, call ``update_params`` with
|
||||
# non-accumulate_grad context will raise an Assertion error
|
||||
optim_wrapper = OptimWrapper(self.optimizer, accumulative_iters=1)
|
||||
optim_wrapper.accumulative_iters = 2
|
||||
with self.assertRaisesRegex(AssertionError,
|
||||
'gradient accumulation must be performed'):
|
||||
optim_wrapper.update_params(loss)
|
||||
# Test end of training without calling `initialize_iter_status`
|
||||
optim_wrapper._inner_count = 99
|
||||
optim_wrapper.update_params(loss)
|
||||
optim_wrapper.step.assert_not_called()
|
||||
optim_wrapper.zero_grad.assert_not_called()
|
||||
self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1) / 3)
|
||||
self._mock_method(optim_wrapper)
|
||||
|
||||
def test_initilize_iter_status(self):
|
||||
optim_wrapper = OptimWrapper(self.optimizer, accumulative_iters=3)
|
||||
optim_wrapper._initilize_iter_status(self.model)
|
||||
self.assertEqual(optim_wrapper.divisible_iters, 0)
|
||||
self.assertEqual(optim_wrapper.remainder_iters, 0)
|
||||
# After calling `initialize_iter_status`, params will be updated at the
|
||||
# last iteration, and the `loss_scaler` will be adjusted.
|
||||
optim_wrapper.initialize_count_status(self.model, 99, 100)
|
||||
optim_wrapper.update_params(loss)
|
||||
optim_wrapper.step.assert_called()
|
||||
optim_wrapper.zero_grad.assert_called()
|
||||
self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1))
|
||||
|
||||
def test_initialize_iter_status(self):
|
||||
optim_wrapper = OptimWrapper(self.optimizer, accumulative_counts=3)
|
||||
optim_wrapper.initialize_count_status(self.model, 0, 100)
|
||||
self.assertEqual(optim_wrapper._divisible_counts, 99)
|
||||
self.assertEqual(optim_wrapper._remainder_counts, 1)
|
||||
|
||||
# Indivisible cur_iter will output warning.
|
||||
optim_wrapper = OptimWrapper(self.optimizer, accumulative_iters=3)
|
||||
optim_wrapper.cur_iter = 0
|
||||
optim_wrapper.max_iters = 100
|
||||
optim_wrapper = OptimWrapper(self.optimizer, accumulative_counts=3)
|
||||
with self.assertLogs(self.logger) as cm:
|
||||
optim_wrapper._initilize_iter_status(self.model)
|
||||
optim_wrapper.initialize_count_status(self.model, 2, 100)
|
||||
self.assertEqual(len(cm.output), 1)
|
||||
self.assertRegex(cm.records[0].msg, 'Resume iter number is not')
|
||||
self.assertRegex(cm.records[0].msg, 'Resumed iteration number')
|
||||
|
||||
# Model with batch norm will output warning.
|
||||
optim_wrapper = OptimWrapper(self.optimizer, accumulative_iters=3)
|
||||
optim_wrapper.cur_iter = 0
|
||||
optim_wrapper.max_iters = 99
|
||||
optim_wrapper = OptimWrapper(self.optimizer, accumulative_counts=3)
|
||||
model = nn.BatchNorm2d(1)
|
||||
with self.assertLogs(self.logger) as cm:
|
||||
optim_wrapper._initilize_iter_status(model)
|
||||
optim_wrapper.initialize_count_status(model, 0, 99)
|
||||
self.assertEqual(len(cm.output), 1)
|
||||
self.assertRegex(cm.records[0].msg, 'Gradient accumulative')
|
||||
|
||||
@ -214,40 +204,40 @@ class TestOptimWrapper(MultiProcessTestCase):
|
||||
self.assertEqual(optim_wrapper.param_groups,
|
||||
self.optimizer.param_groups)
|
||||
|
||||
def test_accumulate_grad(self):
|
||||
def test_optim_context(self):
|
||||
self._init_dist_env(self.rank, self.world_size)
|
||||
model = ToyModel2()
|
||||
ddp_model = DistributedDataParallel(model)
|
||||
optimizer = SGD(ddp_model.parameters(), lr=0.01)
|
||||
optim_wrapper = OptimWrapper(optimizer, accumulative_iters=1)
|
||||
optim_wrapper = OptimWrapper(optimizer, accumulative_counts=1)
|
||||
optim_wrapper.zero_grad()
|
||||
with optim_wrapper.accumulate_grad(ddp_model, 0, 100):
|
||||
# Automatically sync grads if `accumulative_iters` = 1
|
||||
inputs = torch.randn(1, 1, 1, 1) * self.rank
|
||||
ddp_model(inputs).sum().backward()
|
||||
grad = model.conv.weight.grad
|
||||
all_grads = all_gather(grad)
|
||||
assert_allclose(all_grads[0], all_grads[1])
|
||||
|
||||
# Automatically sync grads if `accumulative_counts` = 1
|
||||
optim_wrapper.initialize_count_status(model, 0, 100)
|
||||
inputs = torch.randn(1, 1, 1, 1) * self.rank
|
||||
ddp_model(inputs).sum().backward()
|
||||
grad = model.conv.weight.grad
|
||||
all_grads = all_gather(grad)
|
||||
assert_allclose(all_grads[0], all_grads[1])
|
||||
|
||||
# Do not sync grads when `optim_wrapper.cur_iter` cannot be
|
||||
# divided by `optim_wrapper.accumulative_iters`
|
||||
optim_wrapper = OptimWrapper(optimizer, accumulative_iters=3)
|
||||
with optim_wrapper.accumulate_grad(ddp_model, 0, 100):
|
||||
ddp_model(inputs).sum().backward()
|
||||
all_grads = all_gather(model.conv.weight.grad)
|
||||
with self.assertRaises(AssertionError):
|
||||
assert_allclose(all_grads[0], all_grads[1])
|
||||
|
||||
# sync grads if `cur_iter == 2`
|
||||
with optim_wrapper.accumulate_grad(ddp_model, 2, 100):
|
||||
ddp_model(inputs).sum().backward()
|
||||
all_grads = all_gather(model.conv.weight.grad)
|
||||
# divided by `optim_wrapper._accumulative_counts`
|
||||
optim_wrapper = OptimWrapper(optimizer, accumulative_counts=3)
|
||||
optim_wrapper.initialize_count_status(model, 0, 100)
|
||||
with optim_wrapper.optim_context(ddp_model):
|
||||
loss = ddp_model(inputs).sum()
|
||||
loss.backward()
|
||||
all_grads = all_gather(model.conv.weight.grad)
|
||||
with self.assertRaises(AssertionError):
|
||||
assert_allclose(all_grads[0], all_grads[1])
|
||||
|
||||
def test_precision_context(self):
|
||||
optim_wrapper = OptimWrapper(self.optimizer)
|
||||
with optim_wrapper.precision_context():
|
||||
pass
|
||||
# sync grads if `cur_iter == 2`
|
||||
optim_wrapper.initialize_count_status(model, 2, 100)
|
||||
with optim_wrapper.optim_context(ddp_model):
|
||||
loss = ddp_model(inputs).sum()
|
||||
loss.backward()
|
||||
all_grads = all_gather(model.conv.weight.grad)
|
||||
assert_allclose(all_grads[0], all_grads[1])
|
||||
|
||||
def _init_dist_env(self, rank, world_size):
|
||||
"""Initialize the distributed environment."""
|
||||
@ -260,7 +250,12 @@ class TestOptimWrapper(MultiProcessTestCase):
|
||||
# TODO: Test the real interface after adding a testing tool function
# which can check whether the function or method is really called.
|
||||
def _mock_method(self, optim_wrapper):
|
||||
optim_wrapper.backward = MagicMock()
|
||||
|
||||
def mock_methd(loss):
|
||||
optim_wrapper._inner_count += 1
|
||||
optim_wrapper.scaled_loss = loss
|
||||
|
||||
optim_wrapper.backward = mock_methd
|
||||
optim_wrapper.step = MagicMock()
|
||||
optim_wrapper.zero_grad = MagicMock()
|
||||
|
||||
@ -376,9 +371,9 @@ class TestAmpOptimWrapper(TestCase):
|
||||
and (digit_version(TORCH_VERSION) >= digit_version('1.6.0')),
|
||||
reason='`torch.cuda.amp` is only available when pytorch-gpu version '
|
||||
'>= 1.6')
|
||||
def test_precision_context(self):
|
||||
def test_optim_context(self):
|
||||
amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer)
|
||||
with amp_optim_wrapper.precision_context():
|
||||
with amp_optim_wrapper.optim_context(self.model):
|
||||
x = torch.randn(1, 1, 1, 1).cuda()
|
||||
y = nn.Conv2d(1, 1, 1).cuda()(x)
|
||||
self.assertEqual(y.dtype, torch.float16)
|
||||
|
@ -1,69 +1,92 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from contextlib import contextmanager
|
||||
from unittest import TestCase
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.optim import SGD
|
||||
|
||||
from mmengine.optim import AmpOptimWrapper, OptimWrapper, OptimWrapperDict
|
||||
from mmengine.optim import OptimWrapper, OptimWrapperDict
|
||||
|
||||
|
||||
class TestOptimWrapperDict(TestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
model1 = nn.Linear(1, 1)
|
||||
model2 = nn.Linear(1, 1)
|
||||
self.optim1 = SGD(model1.parameters(), lr=0.1, momentum=0.8)
|
||||
self.optim2 = SGD(model2.parameters(), lr=0.2, momentum=0.9)
|
||||
self.model1 = nn.Linear(1, 1)
|
||||
self.model2 = nn.Linear(1, 1)
|
||||
self.optim1 = SGD(self.model1.parameters(), lr=0.1, momentum=0.8)
|
||||
self.optim2 = SGD(self.model2.parameters(), lr=0.2, momentum=0.9)
|
||||
self.optim_wrapper1 = OptimWrapper(self.optim1)
|
||||
self.optim_wrapper2 = OptimWrapper(self.optim2)
|
||||
self.optimizers_wrappers = dict(
|
||||
optim1=self.optim_wrapper1, optim2=self.optim_wrapper2)
|
||||
|
||||
@patch('torch.cuda.is_available', lambda: True)
|
||||
def test_init(self):
|
||||
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
|
||||
self.assertEqual(optim_wrapper_dict.optim_wrappers,
|
||||
self.optimizers_wrappers)
|
||||
# Different types of OptimWrapper will raise an error
|
||||
with self.assertRaisesRegex(AssertionError,
|
||||
'`OptimWrapperDict` only accept'):
|
||||
OptimWrapperDict(**dict(optim1=self.optim1, optim2=self.optim2))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
AssertionError, 'All optimizer wrappers should have the same'):
|
||||
optim_wrapper2 = AmpOptimWrapper(optimizer=self.optim2)
|
||||
OptimWrapperDict(optim1=self.optim_wrapper1, optim2=optim_wrapper2)
|
||||
|
||||
with self.assertWarnsRegex(UserWarning, 'The `accumulative_iters` of'):
|
||||
optim_wrapper2 = OptimWrapper(
|
||||
optimizer=self.optim2, accumulative_iters=2)
|
||||
OptimWrapperDict(optim1=self.optim_wrapper1, optim2=optim_wrapper2)
|
||||
|
||||
def test_accumulate_grad(self):
|
||||
|
||||
@contextmanager
|
||||
def context_a(a, b, *args, **kwargs):
|
||||
a[0] = 100
|
||||
yield
|
||||
a[0] = 1
|
||||
|
||||
@contextmanager
|
||||
def context_b(a, b, *args, **kwargs):
|
||||
b[0] = 200
|
||||
yield
|
||||
b[0] = 2
|
||||
|
||||
a = [0]
|
||||
b = [0]
|
||||
# Test enter the context both of `optim_wrapper1` and `optim_wrapper1`.
|
||||
def test_update_params(self):
|
||||
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
|
||||
self.optim_wrapper1.accumulate_grad = context_a
|
||||
self.optim_wrapper2.accumulate_grad = context_b
|
||||
with optim_wrapper_dict.accumulate_grad(a, b, 0):
|
||||
self.assertEqual(a[0], 100)
|
||||
self.assertEqual(b[0], 200)
|
||||
with self.assertRaisesRegex(NotImplementedError,
|
||||
'`update_params` should be called'):
|
||||
optim_wrapper_dict.update_params(1)
|
||||
|
||||
self.assertEqual(a[0], 1)
|
||||
self.assertEqual(b[0], 2)
|
||||
def test_backward(self):
|
||||
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
|
||||
with self.assertRaisesRegex(NotImplementedError,
|
||||
'`backward` should be called'):
|
||||
optim_wrapper_dict.backward(1)
|
||||
|
||||
def test_step(self):
|
||||
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
|
||||
with self.assertRaisesRegex(NotImplementedError,
|
||||
'`step` should be called'):
|
||||
optim_wrapper_dict.step()
|
||||
|
||||
def test_zero_grad(self):
|
||||
# Test clear all grad
|
||||
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
|
||||
self.model1(torch.randn(1, 1)).sum().backward()
|
||||
self.model2(torch.randn(1, 1)).sum().backward()
|
||||
self.assertTrue((self.model1.weight.grad != 0).any())
|
||||
self.assertTrue((self.model2.weight.grad != 0).any())
|
||||
optim_wrapper_dict.zero_grad()
|
||||
self.assertTrue((self.model1.weight.grad == 0).all())
|
||||
self.assertTrue((self.model2.weight.grad == 0).all())
|
||||
|
||||
def test_optim_context(self):
|
||||
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
|
||||
with self.assertRaisesRegex(NotImplementedError,
|
||||
'`optim_context` should be called'):
|
||||
with optim_wrapper_dict.optim_context(self.model1):
|
||||
yield
|
||||
|
||||
def test_initialize_count_status(self):
|
||||
# Test `initialize_count_status` can be called.
|
||||
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
|
||||
optim_wrapper_dict.initialize_count_status(self.model1, 1, 1)
|
||||
|
||||
def test_param_groups(self):
|
||||
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
|
||||
self.assertEqual(optim_wrapper_dict.param_groups['optim1'],
|
||||
self.optim1.param_groups)
|
||||
self.assertEqual(optim_wrapper_dict.param_groups['optim2'],
|
||||
self.optim2.param_groups)
|
||||
|
||||
def test_get_lr(self):
|
||||
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
|
||||
lr = optim_wrapper_dict.get_lr()
|
||||
self.assertEqual(lr['optim1.lr'], [0.1])
|
||||
self.assertEqual(lr['optim2.lr'], [0.2])
|
||||
|
||||
def test_get_momentum(self):
|
||||
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
|
||||
momentum = optim_wrapper_dict.get_momentum()
|
||||
self.assertEqual(momentum['optim1.momentum'], [0.8])
|
||||
self.assertEqual(momentum['optim2.momentum'], [0.9])
|
||||
|
||||
def test_state_dict(self):
|
||||
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
|
||||
@ -109,18 +132,6 @@ class TestOptimWrapperDict(TestCase):
|
||||
list(optim_wrapper_dict.keys()),
|
||||
list(self.optimizers_wrappers.keys()))
|
||||
|
||||
def test_get_lr(self):
|
||||
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
|
||||
lr = optim_wrapper_dict.get_lr()
|
||||
self.assertEqual(lr['optim1.lr'], [0.1])
|
||||
self.assertEqual(lr['optim2.lr'], [0.2])
|
||||
|
||||
def test_get_momentum(self):
|
||||
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
|
||||
momentum = optim_wrapper_dict.get_momentum()
|
||||
self.assertEqual(momentum['optim1.momentum'], [0.8])
|
||||
self.assertEqual(momentum['optim2.momentum'], [0.9])
|
||||
|
||||
def test_getitem(self):
|
||||
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
|
||||
self.assertIs(self.optimizers_wrappers['optim1'],
|
||||
|
@ -1135,8 +1135,9 @@ class TestRunner(TestCase):
|
||||
cfg.custom_hooks = [dict(type='TestEpochHook', priority=50)]
|
||||
cfg.train_cfg = dict(by_epoch=True, max_epochs=3, val_begin=2)
|
||||
runner = Runner.from_cfg(cfg)
|
||||
|
||||
runner.train()
|
||||
self.assertEqual(runner.optim_wrapper._inner_count, 12)
|
||||
self.assertEqual(runner.optim_wrapper._max_counts, 12)
|
||||
|
||||
assert isinstance(runner.train_loop, EpochBasedTrainLoop)
|
||||
|
||||
@ -1183,6 +1184,8 @@ class TestRunner(TestCase):
|
||||
runner = Runner.from_cfg(cfg)
|
||||
runner.train()
|
||||
|
||||
self.assertEqual(runner.optim_wrapper._inner_count, 12)
|
||||
self.assertEqual(runner.optim_wrapper._max_counts, 12)
|
||||
assert isinstance(runner.train_loop, IterBasedTrainLoop)
|
||||
|
||||
self.assertEqual(len(epoch_results), 1)
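As a closing illustration of the count bookkeeping exercised by these tests, here is plain arithmetic mirroring the logic of ``initialize_count_status``, ``scale_loss``, and ``should_sync``; the concrete numbers (3 accumulation counts, 10 iterations) are assumptions for illustration, not values from the commit.

accumulative_counts = 3
init_counts, max_counts = 0, 10

residual_counts = max_counts - init_counts
# Largest multiple of `accumulative_counts` within the remaining iterations.
divisible_counts = residual_counts // accumulative_counts * accumulative_counts  # 9
remainder_counts = residual_counts - divisible_counts                            # 1

for inner_count in range(init_counts, max_counts):
    # Loss divisor: 3 for the first 9 backward passes, 1 for the last one,
    # so the trailing, incomplete window is not under-weighted.
    loss_factor = (accumulative_counts
                   if inner_count < divisible_counts else remainder_counts)
    # The sync/update check runs before `backward` increments the count,
    # hence the manual `+ 1`; parameters are updated at the same boundaries.
    will_sync_and_update = ((inner_count + 1) % accumulative_counts == 0
                            or inner_count + 1 == max_counts)
    print(inner_count, loss_factor, will_sync_and_update)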