[Refactor] Refactor the gradient accumulation implementation of OptimWrapper (#284)

* merge context

* update unit test

* add docstring

* fix bug in AmpOptimWrapper

* add docstring for backward

* add warning and docstring for accumulate gradient

* fix docstring

* fix docstring

* add params_group method

* fix as comment

* fix as comment

* make default value of loss_scale dynamic

* Fix docstring

* decouple should update and should no sync

* rename attribute in OptimWrapper

* fix docstring

* fix comment

* fix comment

* fix as comment

* fix as comment and add unit test
Mashiro 2022-06-13 23:20:53 +08:00 committed by GitHub
parent fd295741ca
commit b7866021c4
12 changed files with 435 additions and 357 deletions


@ -111,8 +111,8 @@ class BaseModel(BaseModule):
Returns:
Dict[str, torch.Tensor]: A ``dict`` of tensor for logging.
"""
# enable automatic mixed precision training context.
with optim_wrapper.precision_context():
# Enable automatic mixed precision training context.
with optim_wrapper.optim_context(self):
batch_inputs, data_samples = self.data_preprocessor(data, True)
losses = self(batch_inputs, data_samples, mode='loss')
parsed_losses, log_vars = self.parse_losses(losses)

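To illustrate the new API, here is a minimal sketch of a training step built around ``optim_context`` instead of the old ``precision_context``. The toy model, data and learning rate are placeholders, not the actual MMEngine implementation:

import torch
import torch.nn as nn
from torch.optim import SGD
from mmengine.optim import OptimWrapper

model = nn.Linear(4, 1)
optim_wrapper = OptimWrapper(SGD(model.parameters(), lr=0.01))

def train_step(data):
    # `optim_context` enables AMP autocast for AmpOptimWrapper and skips DDP
    # gradient synchronization while gradients are still being accumulated;
    # for a plain OptimWrapper on a non-DDP model it is effectively a no-op.
    with optim_wrapper.optim_context(model):
        loss = model(data).sum()
    optim_wrapper.update_params(loss)
    return loss

loss = train_step(torch.randn(2, 4))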

@ -89,8 +89,8 @@ class MMDistributedDataParallel(DistributedDataParallel):
Returns:
Dict[str, torch.Tensor]: A ``dict`` of tensor for logging.
"""
# enable automatic mixed precision training context.
with optim_wrapper.precision_context():
# Enable automatic mixed precision training context.
with optim_wrapper.optim_context(self):
batch_inputs, data_samples = self.module.data_preprocessor(
data, training=True)
losses = self(batch_inputs, data_samples, mode='loss')


@ -2,6 +2,7 @@
from contextlib import contextmanager
import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler
from mmengine.registry import OPTIM_WRAPPERS
@ -25,15 +26,21 @@ class AmpOptimWrapper(OptimWrapper):
loss_scale (float or str or dict): The initial configuration of
`torch.cuda.amp.GradScaler`. See more details of its arguments in
introduction at `PyTorch AMP <https://pytorch.org/docs/stable/amp.html?highlight=gradscalertorch.cuda.amp.GradScaler>`_ # noqa: E501
Defaults to ``dynamic``.
- "dynamic": Initialize GradScaler without any arguments.
- float: Initialize GradScaler with ``init_scale``.
- dict: Initialize GradScaler with more detail configuration.
**kwargs: Keyword arguments passed to OptimWrapper.
Note:
If you use ``IterBasedRunner`` and enable gradient accumulation,
the original `max_iters` should be multiplied by
``accumulative_counts``.
"""
def __init__(self, loss_scale=512., **kwargs):
def __init__(self, loss_scale='dynamic', **kwargs):
assert digit_version(TORCH_VERSION) >= digit_version('1.6.0'), (
'`torch.cuda.amp` is only available when pytorch version >= 1.6')
assert torch.cuda.is_available(), (
@ -62,6 +69,7 @@ class AmpOptimWrapper(OptimWrapper):
loss (torch.Tensor): The loss of current iteration.
"""
self.loss_scaler.scale(loss).backward()
self._inner_count += 1
def step(self):
"""Update parameters with :attr:`loss_scaler`."""
@ -104,7 +112,13 @@ class AmpOptimWrapper(OptimWrapper):
self.optimizer.load_state_dict(state_dict)
@contextmanager
def precision_context(self):
"""A wrapper of ``torch.cuda.amp.autocast``"""
with torch.cuda.amp.autocast():
def optim_context(self, model: nn.Module):
"""Enables the context for mixed precision training, and the context
for disabling gradient synchronization during gradient accumulation.
Args:
model (nn.Module): The training model.
"""
with super().optim_context(model), torch.cuda.amp.autocast():
yield

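For reference, a hedged sketch of the three accepted forms of ``loss_scale`` now that the default is ``'dynamic'``; the toy model and the scaler values below are placeholders, and ``AmpOptimWrapper`` requires CUDA and PyTorch >= 1.6:

import torch.nn as nn
from torch.optim import SGD
from mmengine.optim import AmpOptimWrapper

model = nn.Linear(4, 1).cuda()
optimizer = SGD(model.parameters(), lr=0.01)

# 'dynamic' (the new default): GradScaler is created with its own defaults.
amp_wrapper = AmpOptimWrapper(optimizer=optimizer)

# float: forwarded to GradScaler(init_scale=512.).
amp_wrapper = AmpOptimWrapper(loss_scale=512., optimizer=optimizer)

# dict: a full GradScaler configuration.
amp_wrapper = AmpOptimWrapper(
    loss_scale=dict(init_scale=512., growth_interval=2000),
    optimizer=optimizer)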

@ -73,7 +73,7 @@ class DefaultOptimWrapperConstructor:
Optional fields are
- any arguments of the corresponding optimizer wrapper type,
e.g., accumulative_iters, clip_grad, etc.
e.g., accumulative_counts, clip_grad, etc.
The positional fields of ``optimizer`` are


@ -27,21 +27,31 @@ class OptimWrapper:
Args:
optimizer (Optimizer): Optimizer used to update model parameters.
accumulative_iters (int): The number of iterations to accumulate
accumulative_counts (int): The number of iterations to accumulate
gradients. The parameters will be updated per
``accumulative_iters``.
``accumulative_counts``.
clip_grad (dict, optional): If ``clip_grad`` is not None, it will be
the arguments of ``torch.nn.utils.clip_grad``.
Warnings:
If ``accumulative_iters`` is larger than 1, :meth:`update_params` must
be called in the context of ``accumulate_grad``.
Note:
If ``accumulative_counts`` is larger than 1, calling
:meth:`update_params` under the ``optim_context`` context manager
can avoid unnecessary gradient synchronization.
Note:
If you use ``IterBasedRunner`` and enable gradient accumulation,
the original `max_iters` should be multiplied by
``accumulative_counts``.
Note:
The subclass should ensure that once :meth:`update_params` is called,
``_inner_count += 1`` is automatically performed.
Examples:
>>> # Config sample of OptimWrapper.
>>> optim_wrapper_cfg = dict(
>>> type='OptimWrapper',
>>> accumulative_iters=3,
>>> accumulative_counts=1,
>>> clip_grad=dict(max_norm=0.2))
>>> # Use OptimWrapper to update model.
>>> import torch.nn as nn
@ -59,29 +69,32 @@ class OptimWrapper:
>>> for data in dataloader:
>>> loss = model(data)
>>> optim_wrapper.update_params(loss)
>>> # Enable gradient accumulation. If model is a subclass instance of
>>> # DistributedDataParallel, ``accumulate_grad`` context manager can
>>> # avoid unnecessary gradient synchronize.
>>> # Enable gradient accumulation
>>> optim_wrapper_cfg = dict(
>>> type='OptimWrapper',
>>> accumulative_counts=3,
>>> clip_grad=dict(max_norm=0.2))
>>> ddp_model = DistributedDataParallel(model)
>>> optimizer = SGD(ddp_model.parameters(), lr=0.1)
>>> optim_wrapper = OptimWrapper(optimizer)
>>> optim_wrapper.initialize_count_status(0, len(dataloader))
>>> # If model is a subclass instance of DistributedDataParallel,
>>> # `optim_context` context manager can avoid unnecessary gradient
>>> # synchronization.
>>> for iter, data in enumerate(dataloader):
>>> with optim_wrapper.accumulate_grad(
>>> model, iter, len(dataloader)):
>>> with optim_wrapper.optim_context(ddp_model):
>>> loss = model(data)
>>> optim_wrapper.update_params(loss)
>>> optim_wrapper.update_params(loss)
"""
def __init__(self,
optimizer: Optimizer,
accumulative_iters: int = 1,
accumulative_counts: int = 1,
clip_grad: Optional[dict] = None):
assert accumulative_iters > 0, (
'accumulative_iters at least greater than or equal to 1')
self.accumulative_iters = accumulative_iters
# `max_iters` and `cur_iter` is only valid in gradient accumulative
# mode (`accumulative_iters` > 1). `cur_iter` and `max_iter` will be
# updated in the ``accumulate_grad`` context that is enabled in
# `runner.train_loop`.
self.cur_iter = 0
self.max_iters = 0
assert accumulative_counts > 0, (
'accumulative_counts should be greater than or equal to 1')
self._accumulative_counts = accumulative_counts
assert isinstance(optimizer, Optimizer), (
'optimizer must be a `torch.optim.Optimizer` instance, but got '
f'{type(optimizer)}')
@ -90,13 +103,27 @@ class OptimWrapper:
if clip_grad is not None:
# clip_grad should be a non-empty dict.
assert isinstance(clip_grad, dict) and clip_grad, (
'If `clip_grad_kwargs` is not None, it should be a `dict` '
'If `clip_grad` is not None, it should be a `dict` '
'which is the arguments of `torch.nn.utils.clip_grad`')
self.clip_grad_kwargs = clip_grad
self.logger = MMLogger.get_current_instance()
# Used to update `grad_norm` log message.
self.message_hub = MessageHub.get_current_instance()
self.iter_status_initialized = False
self._inner_count = 0
# `_max_counts` means the total number of training iterations, i.e. the
# total number of backward counts. It ensures that the gradient of the
# last few iterations will not be lost when `_max_counts` is not
# divisible by `accumulative_counts`.
self._max_counts = -1
# If `_inner_count` is smaller than `_divisible_counts`, the loss
# factor used for gradient accumulation should be the same as
# `_accumulative_counts`. If `_max_counts` has not been initialized,
# the loss factor will always be the same as `_accumulative_counts`.
self._divisible_counts = -1
# `_remainder_counts` is used for calculating the loss factor at the
# last few iterations. If `_max_counts` has not been initialized,
# the loss factor will always be the same as `_accumulative_counts`.
self._remainder_counts = -1
def update_params(self, loss: torch.Tensor) -> None:
"""Update parameters in :attr:`optimizer`.
@ -104,36 +131,12 @@ class OptimWrapper:
Args:
loss (torch.Tensor): A tensor for back propagation.
"""
if self.accumulative_iters == 1:
# update parameters without gradient accumulation. The gradient
# should not be rescaled and `loss_factor=1`.
loss_factor = 1
else:
# gradient accumulation must be called in the context of
# ``accumulate_grad``.
assert hasattr(self, 'divisible_iters'), (
'gradient accumulation must be performed in the context of'
'`OptimWrapper.accumulate_grad`')
# if `self.accumulative_iters > 1`, the gradient needs to be
# rescaled and accumulated. In most cases, `loss_factor` equals to
# `self.accumulative_iters`. However `self.max_iters` may not be
# divisible `self.by accumulative_iters`, so the `loss_scale` for
# the last few iterations needs to be recalculated.
if self.cur_iter < self.divisible_iters:
loss_factor = self.accumulative_iters
else:
loss_factor = self.remainder_iters
assert loss_factor > 0, (
'loss_factor should be larger than zero! This error could '
'happened when gradient accumulation context enabled with an '
'error `cur_iter` or `max_iters` please check your loop')
loss = loss / loss_factor
loss = self.scale_loss(loss)
self.backward(loss)
# Update parameters only if `self.cur_iter` is divisible by
# `self.accumulative_iters` or `self.cur_iter` equals to
# `self.max_iters`
if self._should_update(self.cur_iter, self.max_iters):
# Update parameters only if `self._inner_count` is divisible by
# `self._accumulative_counts` or `self._inner_count` equals to
# `self._max_counts`
if self.should_update():
self.step()
self.zero_grad()
@ -145,10 +148,15 @@ class OptimWrapper:
required logic. For example, ``torch.cuda.amp`` require some extra
operation on GradScaler during backward process.
Note:
If a subclass of ``OptimWrapper`` overrides ``backward``,
``_inner_count += 1`` must also be implemented.
Args:
loss (torch.Tensor): The loss of current iteration.
"""
loss.backward()
self._inner_count += 1
def zero_grad(self) -> None:
"""A wrapper of ``Optimizer.zero_grad``.
@ -218,7 +226,7 @@ class OptimWrapper:
Provide unified interface to get learning rate of optimizer.
Returns:
List[float]: Learning rate of the optimizer.
Dict[str, List[float]]: Learning rate of the optimizer.
"""
lr = [group['lr'] for group in self.param_groups]
return dict(lr=lr)
@ -229,7 +237,7 @@ class OptimWrapper:
Provide unified interface to get momentum of optimizer.
Returns:
List[float]: Momentum of the optimizer.
Dict[str, List[float]]: Momentum of the optimizer.
"""
momentum = []
for group in self.param_groups:
@ -244,49 +252,35 @@ class OptimWrapper:
return dict(momentum=momentum)
@contextmanager
def accumulate_grad(self, model: nn.Module, cur_iter: int, max_iters: int):
"""A Context manager for gradient accumulation and avoiding unnecessary
gradient synchronization during gradient accumulation.
def optim_context(self, model: nn.Module):
"""A context manager for gradient accumulation and automatic mixed
precision training.
If subclasses need to enable the context for mixed precision training,
e.g., :class:`AmpOptimWrapper`, the corresponding context should be
enabled in ``optim_context``. Since ``OptimWrapper`` uses default fp32
training, ``optim_context`` will only enable the context for
blocking the unnecessary gradient synchronization during gradient
accumulation.
If model is an instance with ``no_sync`` method (which means
blocking the gradient synchronization) and
``self.accumulative_iters != 1``. The model will not automatically
``self._accumulative_counts != 1``. The model will not automatically
synchronize gradients if ``cur_iter`` is divisible by
``self.accumulative_iters``. Otherwise, this method will enable an
``self._accumulative_counts``. Otherwise, this method will enable an
empty context.
Warnings:
This context manager must be enabled if you want to use
gradient accumulation.
Args:
model (nn.Module): The training model.
cur_iter (int): Current iteration during training process.
max_iters (int): Maximum training iteration.
"""
assert max_iters > 0, '`max_iters` must be larger than zero'
self.cur_iter = cur_iter
self.max_iters = max_iters
if not self.iter_status_initialized:
self._initilize_iter_status(model)
# During the gradient accumulation process, gradient synchronization
# should only happen before updating parameters.
if (not self._should_update(cur_iter, max_iters)
and hasattr(model, 'no_sync')):
if not self.should_sync() and hasattr(model, 'no_sync'):
with model.no_sync():
yield
else:
yield
@contextmanager
def precision_context(self):
"""precision context which enables an empty context by default.
The subclass used for mixed or low precision training needs to override
this method.
"""
yield
def _clip_grad(self) -> None:
"""Clip the gradients of parameters."""
params: List[torch.Tensor] = []
@ -300,50 +294,117 @@ class OptimWrapper:
**self.clip_grad_kwargs)
self.message_hub.update_scalar('train/grad_norm', float(grad_norm))
def _initilize_iter_status(self, model: nn.Module) -> None:
def initialize_count_status(self, model: nn.Module, init_counts: int,
max_counts: int) -> None:
"""Initialize gradient accumulation related attributes.
``OptimWrapper`` can be used without calling
``initialize_count_status``. However, consider the case where
``len(dataloader) == 10`` and ``accumulative_counts == 3``. Since 10 is
not divisible by 3, the last iteration would not trigger
``optimizer.step()``, resulting in one less parameter update.
Args:
model (nn.Module): Training model
init_counts (int): The initial value of the inner count.
max_counts (int): The maximum value of the inner count.
"""
if self.max_iters % self.accumulative_iters != 0:
self._inner_count = init_counts
self._max_counts = max_counts
if self._inner_count % self._accumulative_counts != 0:
self.logger.warning(
'Resume iter number is not divisible by accumulative_iters in '
'GradientCumulativeOptimizerHook, which means the gradient of '
'some iters is lost and the result may be influenced slightly.'
)
'Resumed iteration number is not divisible by '
'`_accumulative_counts` in `GradientCumulativeOptimizerHook`, '
'which means the gradient of some iterations is lost and the '
'result may be influenced slightly.')
if has_batch_norm(model) and self.accumulative_iters > 1:
if has_batch_norm(model) and self._accumulative_counts > 1:
self.logger.warning(
'Gradient accumulative may slightly decrease '
'performance because the model has BatchNorm layers.')
residual_iters = self.max_iters - self.cur_iter
residual_counts = max_counts - init_counts
# The maximum number of training iterations that is divisible by
# accumulative_iters.
self.divisible_iters = (
residual_iters // self.accumulative_iters *
self.accumulative_iters)
# Remainder of ``self.max_iters`` divided by ``self.max_iters``
self.remainder_iters = residual_iters - self.divisible_iters
self.iter_status_initialized = True
# `_accumulative_counts`.
self._divisible_counts = (
residual_counts // self._accumulative_counts *
self._accumulative_counts)
# Remainder of `_max_counts` divided by `_accumulative_counts`
self._remainder_counts = residual_counts - self._divisible_counts
def _should_update(self, cur_iter: int, max_iters: int) -> bool:
"""Should optim_wrapper update parameters or synchronized gradient at
current iteration.
def should_update(self) -> bool:
"""Decide whether the parameters should be updated at the current
iteration.
Args:
cur_iter (int): Current iteration of training process.
max_iters (int): Maximum iterations of training process.
Called by :meth:`update_params` to check whether the optimizer
wrapper should update parameters at the current iteration.
Returns:
bool: Whether to update parameters or synchronized gradient.
bool: Whether to update parameters.
"""
return ((cur_iter + 1) % self.accumulative_iters == 0
or cur_iter + 1 == max_iters)
return (self._inner_count % self._accumulative_counts == 0
or self._inner_count == self._max_counts)
def should_sync(self) -> bool:
"""Decide whether the automatic gradient synchronization should be
allowed at the current iteration.
It takes effect when gradient accumulation is used, to skip
synchronization at the iterations where the parameters are not updated.
Since ``should_sync`` is called by :meth:`optim_context` before
:meth:`backward`, ``self._inner_count`` has not been incremented yet.
Therefore, the check here uses ``self._inner_count + 1``.
Returns:
bool: Whether to block the automatic gradient synchronization.
"""
return ((self._inner_count + 1) % self._accumulative_counts == 0
or (self._inner_count + 1) == self._max_counts)
def scale_loss(self, loss: torch.Tensor) -> torch.Tensor:
"""Get scaled loss according to ``_accumulative_counts``,
``_inner_count`` and ``_max_counts``.
Args:
loss (torch.Tensor): Original loss calculated by model.
Returns:
loss (torch.Tensor): Scaled loss.
"""
if self._accumulative_counts == 1:
# update parameters without gradient accumulation. The gradient
# should not be rescaled and `loss_factor=1`.
loss_factor = 1
elif self._max_counts == -1:
loss_factor = self._accumulative_counts
else:
# if `self._accumulative_counts > 1`, the gradient needs to be
# rescaled and accumulated. In most cases, `loss_factor` equals to
# `self._accumulative_counts`. However, `self._max_counts` may not
# be divisible by `self._accumulative_counts`, so the
# `loss_scale` for the last few iterations needs to be
# recalculated.
if self._inner_count < self._divisible_counts:
loss_factor = self._accumulative_counts
else:
loss_factor = self._remainder_counts
assert loss_factor > 0, (
'loss_factor should be larger than zero! This error could '
'happen when `initialize_count_status` is called with a '
'wrong `init_counts` or `max_counts`')
loss = loss / loss_factor
return loss
@property
def inner_count(self):
"""Get the number of times :meth:`backward` has been called."""
return self._inner_count
def __repr__(self):
wrapper_info = f'Type: {type(self).__name__}\n' \
f'accumulative_iters: {self.accumulative_iters}\n' \
f'optimizer: \n'
wrapper_info = (f'Type: {type(self).__name__}\n'
f'_accumulative_counts: {self._accumulative_counts}\n'
'optimizer: \n')
optimizer_str = repr(self.optimizer) + '\n'
return wrapper_info + optimizer_str

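A numeric walk-through of the new counting logic may help; this is a sketch assuming a toy linear model, ``accumulative_counts=3`` and a dataloader of length 10, so ``max_counts=10``:

import torch
import torch.nn as nn
from torch.optim import SGD
from mmengine.optim import OptimWrapper

model = nn.Linear(4, 1)
optim_wrapper = OptimWrapper(
    SGD(model.parameters(), lr=0.01), accumulative_counts=3)
# residual_counts = 10, _divisible_counts = 9, _remainder_counts = 1
optim_wrapper.initialize_count_status(model, init_counts=0, max_counts=10)

for i in range(10):
    loss = model(torch.randn(2, 4)).sum()
    # Iterations 0-8 are scaled by 1/3, iteration 9 by 1/1 (the remainder).
    # `should_update` fires when `_inner_count` reaches 3, 6 and 9 (divisible)
    # and at 10 (`_inner_count == _max_counts`), so the gradient of the last
    # iteration is not lost.
    optim_wrapper.update_params(loss)

Without ``initialize_count_status``, ``_max_counts`` stays at -1, the loss factor is always ``accumulative_counts``, and the final, indivisible iteration would not trigger ``optimizer.step()``.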

@ -1,6 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from contextlib import ExitStack, contextmanager
from contextlib import contextmanager
from typing import Dict, Iterator, List, Tuple
import torch
@ -41,24 +40,10 @@ class OptimWrapperDict(OptimWrapper):
"""
def __init__(self, **optim_wrapper_dict: OptimWrapper):
first_key = next(iter(optim_wrapper_dict))
first_optim_wrapper = optim_wrapper_dict[first_key]
assert isinstance(first_optim_wrapper, OptimWrapper), (
'Each argument of `OptimWrapperDict` must be an `OptimWrapper '
'instance`')
optim_wrapper_class = type(first_optim_wrapper)
for key, value in optim_wrapper_dict.items():
assert type(value) == optim_wrapper_class, (
f'All optimizer wrappers should have the same type, but found'
f' {key}: {type(value)} and {first_key}: {optim_wrapper_class}'
)
if value.accumulative_iters != 1:
warnings.warn(
f'The `accumulative_iters` of {key} is '
f'{value.accumulative_iters}. OptimWrapperDict '
'will not enable any `accumulate_grad` context of its '
'optimizer wrappers. You should access the corresponding '
'optimizer wrapper to enable the context.')
assert isinstance(value, OptimWrapper), (
'`OptimWrapperDict` only accepts OptimWrapper instances, '
f'but got {key}: {type(value)}')
self.optim_wrappers = optim_wrapper_dict
def update_params(self, loss: torch.Tensor) -> None:
@ -69,9 +54,8 @@ class OptimWrapperDict(OptimWrapper):
Therefore, this method is not implemented. The optimizer wrapper of
OptimWrapperDict should be accessed to call its ``update_params``.
"""
raise NotImplementedError(
'You should access the OptimWrapper of the '
'OptimWrapperDict and call its `update_params`')
raise NotImplementedError('`update_params` should be called by each '
'optimizer separately')
def backward(self, loss: torch.Tensor) -> None:
"""Since OptimWrapperDict doesn't know which optimizer wrapper's
@ -81,14 +65,14 @@ class OptimWrapperDict(OptimWrapper):
The optimizer wrapper of OptimWrapperDict should be accessed and call
its ``backward``.
"""
raise NotImplementedError('You should access the OptimWrapper of the '
'OptimWrapperDict and call its `backward`')
raise NotImplementedError('`backward` should be called by each '
'optimizer separately')
def step(self) -> None:
"""Since the backward method is not implemented, the step should not be
implemented either."""
raise NotImplementedError('You should access the OptimWrapper of the '
'OptimWrapperDict and call its `step`')
raise NotImplementedError('`step` should be called by each '
'optimizer separately')
def zero_grad(self) -> None:
"""Set the gradients of all optimizer wrappers to zero."""
@ -96,49 +80,29 @@ class OptimWrapperDict(OptimWrapper):
optim_wrapper.zero_grad()
@contextmanager
def precision_context(self):
optim_wrapper = next(iter(self.optim_wrappers.values()))
with optim_wrapper.precision_context():
yield
def optim_context(self, model: nn.Module):
"""``optim_context`` should be called by each optimizer separately."""
raise NotImplementedError(
'`optim_context` should be called by each optimizer separately')
@contextmanager
def accumulate_grad(self, model: nn.Module, cur_iter: int, max_iters: int):
"""Enable ``accumulate_grad`` contexts of all optimizer wrappers.
def initialize_count_status(self, model: nn.Module, cur_iter,
max_iters) -> None:
"""Do nothing but provide a unified interface for :obj:`OptimWrapper`.
Warning:
Consider there is only one ``model`` arguments for all
optimizer wrappers, all optimizer wrappers are working under the
same ``model.no_sync`` context. For example, there is a model
composed of model_a(optimizer_a) and model_b(optimizer_b).
``OptimWrapperDict.accumulate_grad`` will further
call ``model.no_sync``, which will block the gradient
synchronization of both a and b. If optimizer_a and
optimizer_b have different ``accumulative_iters``, and want to
block the gradient synchronization of model_a and model_b
separately, the model should not implement the ``no_sync``
method(or enable an empty context). The ``accumulate_grad`` context
should be enabled inside the model by accessing corresponding
optimizer wrapper.
Since ``OptimWrapperDict`` does not know the correspondence between
models and optimizer wrappers, ``initialize_count_status`` will do
nothing, and each optimizer wrapper should call
``initialize_count_status`` separately.
"""
with ExitStack() as stack:
for optim_wrapper in self.optim_wrappers.values():
stack.enter_context(
optim_wrapper.accumulate_grad(model, cur_iter, max_iters))
yield
return
def load_state_dict(self, state_dict: dict) -> None:
"""Load the state dictionary from the ``state_dict``.
Args:
state_dict (dict): Each key-value pair in `state_dict` represents
the name and the state dictionary of corresponding
:obj:`OptimWrapper`.
"""
for name, _state_dict in state_dict.items():
assert name in self.optim_wrappers, (
f'Mismatched `state_dict`! cannot found {name} in '
'OptimWrapperDict')
self.optim_wrappers[name].load_state_dict(_state_dict)
@property
def param_groups(self):
"""Returns the parameter groups of each OptimWrapper."""
param_groups = dict()
for key, value in self.optim_wrappers.items():
param_groups[key] = value.param_groups
return param_groups
def get_lr(self) -> Dict[str, List[float]]:
"""Get the learning rate of all optimizers.
@ -175,6 +139,20 @@ class OptimWrapperDict(OptimWrapper):
state_dict[name] = optim_wrapper.state_dict()
return state_dict
def load_state_dict(self, state_dict: dict) -> None:
"""Load the state dictionary from the ``state_dict``.
Args:
state_dict (dict): Each key-value pair in `state_dict` represents
the name and the state dictionary of corresponding
:obj:`OptimWrapper`.
"""
for name, _state_dict in state_dict.items():
assert name in self.optim_wrappers, (
f'Mismatched `state_dict`! cannot found {name} in '
'OptimWrapperDict')
self.optim_wrappers[name].load_state_dict(_state_dict)
def items(self) -> Iterator[Tuple[str, OptimWrapper]]:
"""A generator to get the name and corresponding
:obj:`OptimWrapper`"""

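Since ``update_params``, ``backward``, ``step`` and ``optim_context`` now raise ``NotImplementedError`` on the dict itself, each wrapper has to be driven by name. A hedged sketch with a hypothetical generator/discriminator pair (the models, keys and learning rates are placeholders):

import torch
import torch.nn as nn
from torch.optim import SGD
from mmengine.optim import OptimWrapper, OptimWrapperDict

gen, disc = nn.Linear(2, 2), nn.Linear(2, 1)
optim_wrapper_dict = OptimWrapperDict(
    generator=OptimWrapper(SGD(gen.parameters(), lr=0.01)),
    discriminator=OptimWrapper(SGD(disc.parameters(), lr=0.01)))

x = torch.randn(4, 2)
# Access each sub-wrapper by name and update it separately.
optim_wrapper_dict['generator'].update_params(gen(x).sum())
optim_wrapper_dict['discriminator'].update_params(disc(x).sum())
# Aggregated interfaces such as `zero_grad`, `get_lr`, `state_dict` and
# `load_state_dict` still operate on all wrappers at once.
print(optim_wrapper_dict.get_lr())  # e.g. {'generator.lr': [0.01], ...}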

@ -101,12 +101,9 @@ class EpochBasedTrainLoop(BaseLoop):
'before_train_iter', batch_idx=idx, data_batch=data_batch)
# Enable gradient accumulation mode and avoid unnecessary gradient
# synchronization during gradient accumulation process.
with self.runner.optim_wrapper.accumulate_grad(self.runner.model,
self._iter,
self._max_iters):
# outputs should be a dict of loss.
outputs = self.runner.model.train_step(
data_batch, optim_wrapper=self.runner.optim_wrapper)
# outputs should be a dict of loss.
outputs = self.runner.model.train_step(
data_batch, optim_wrapper=self.runner.optim_wrapper)
self.runner.call_hook(
'after_train_iter',
@ -204,19 +201,16 @@ class IterBasedTrainLoop(BaseLoop):
'before_train_iter', batch_idx=self._iter, data_batch=data_batch)
# Enable gradient accumulation mode and avoid unnecessary gradient
# synchronization during gradient accumulation process.
with self.runner.optim_wrapper.accumulate_grad(self.runner.model,
self._iter,
self._max_iters):
# train_logs should be a dict of loss.
train_logs = self.runner.model.train_step(
data_batch, optim_wrapper=self.runner.optim_wrapper)
self.runner.message_hub.update_info('train_logs', train_logs)
# outputs should be a dict of loss.
outputs = self.runner.model.train_step(
data_batch, optim_wrapper=self.runner.optim_wrapper)
self.runner.message_hub.update_info('train_logs', outputs)
self.runner.call_hook(
'after_train_iter',
batch_idx=self._iter,
data_batch=data_batch,
outputs=train_logs)
outputs=outputs)
self._iter += 1


@ -926,7 +926,7 @@ class Runner:
>>> optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
>>> optim_wrapper
Type: OptimWrapper
accumulative_iters: 1
accumulative_counts: 1
optimizer:
SGD (
Parameter Group 0
@ -941,7 +941,7 @@ class Runner:
>>> optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
>>> optim_wrapper
Type: OptimWrapper
accumulative_iters: 1
accumulative_counts: 1
optimizer:
SGD (
Parameter Group 0
@ -965,7 +965,7 @@ class Runner:
>>> optim_wrapper
name: generator
Type: OptimWrapper
accumulative_iters: 1
accumulative_counts: 1
optimizer:
SGD (
Parameter Group 0
@ -977,7 +977,7 @@ class Runner:
)
name: discriminator
Type: OptimWrapper
accumulative_iters: 1
accumulative_counts: 1
optimizer:
'discriminator': Adam (
Parameter Group 0
@ -1528,7 +1528,6 @@ class Runner:
# `build_optimizer` should be called before `build_param_scheduler`
# because the latter depends on the former
self.optim_wrapper = self.build_optim_wrapper(self.optim_wrapper)
# Automatically scale lr by the linear scaling rule
self.scale_lr(self.optim_wrapper, self.auto_scale_lr)
@ -1540,9 +1539,14 @@ class Runner:
self._val_loop = self.build_val_loop(
self._val_loop) # type: ignore
# TODO: add a contextmanager to avoid calling `before_run` many times
self.call_hook('before_run')
# Initialize the inner count of `optim_wrapper`.
self.optim_wrapper.initialize_count_status(
self.model,
self._train_loop.iter, # type: ignore
self._train_loop.max_iters) # type: ignore
# TODO: add a contextmanager to avoid calling `before_run` many times
# make sure checkpoint-related hooks are triggered after `before_run`
self.load_or_resume()

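Related to the note repeated in the optimizer wrapper docstrings: with an iteration-based loop, ``max_iters`` counts data iterations rather than optimizer steps, so enabling accumulation multiplies it. A hedged config sketch with placeholder numbers:

# Originally: 90k iterations, one optimizer step per iteration.
train_cfg = dict(by_epoch=False, max_iters=90000)

# With `accumulative_counts=3` the optimizer steps once every 3 iterations,
# so `max_iters` is multiplied by 3 to keep the number of parameter updates
# unchanged.
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='SGD', lr=0.01),
    accumulative_counts=3)
train_cfg = dict(by_epoch=False, max_iters=270000)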

@ -8,9 +8,10 @@ import torch.distributed as torch_dist
import torch.nn as nn
from torch.optim import SGD
from mmengine.dist import all_gather
from mmengine.model import (BaseModel, MMDistributedDataParallel,
MMSeparateDistributedDataParallel)
from mmengine.optim import OptimWrapper, OptimWrapperDict
from mmengine.optim import AmpOptimWrapper, OptimWrapper, OptimWrapperDict
from mmengine.testing import assert_allclose
from mmengine.testing._internal import MultiProcessTestCase
@ -23,9 +24,9 @@ class ToyModel(BaseModel):
self.conv2 = nn.Conv2d(1, 1, 1)
def forward(self, x, data_samples=None, mode='tensor'):
x = self.conv1(x)
x = self.conv2(x)
if mode == 'loss':
x = self.conv1(x)
x = self.conv2(x)
return dict(loss=x)
elif mode == 'predict':
return x
@ -58,24 +59,41 @@ class ComplexModel(BaseModel):
pass
class TestModelWrapper(MultiProcessTestCase):
class TestDistributedDataParallel(MultiProcessTestCase):
def setUp(self) -> None:
super().setUp()
self._spawn_processes()
@unittest.skipIf(
not torch.cuda.is_available(), reason='cuda should be available')
def test_train_step(self):
self._init_dist_env(self.rank, self.world_size)
# Test `optim_wrapper` is a instance of `OptimWrapper`
model = ToyModel()
# Mixed precision training and blocking gradient synchronization should
# both take effect at the same time.
model = ToyModel().cuda()
ddp_model = MMDistributedDataParallel(module=model)
optimizer = SGD(ddp_model.parameters(), lr=0)
optim_wrapper = OptimWrapper(optimizer, accumulative_iters=1)
inputs = torch.randn(3, 1, 1) * self.rank * 255
optim_wrapper = AmpOptimWrapper(
optimizer=optimizer, accumulative_counts=3)
inputs = torch.randn(3, 1, 1).cuda() * self.rank * 255
data = dict(inputs=inputs, data_sample=MagicMock())
res = ddp_model.train_step([data], optim_wrapper=optim_wrapper)['loss']
self.assertIs(res.dtype, torch.float16)
grad = ddp_model.module.conv1.weight.grad
all_grads = all_gather(grad)
with self.assertRaises(AssertionError):
assert_allclose(all_grads[0], all_grads[1])
# Gradient accumulation
ddp_model.train_step([data], optim_wrapper=optim_wrapper)
# Test updating params and clearing grads.
ddp_model.train_step([data], optim_wrapper=optim_wrapper)
grad = ddp_model.module.conv1.weight.grad
assert_allclose(grad, torch.zeros_like(grad))
all_grads = all_gather(grad)
assert_allclose(all_grads[0], torch.zeros_like(all_grads[0]))
assert_allclose(all_grads[1], torch.zeros_like(all_grads[0]))
def test_val_step(self):
self._init_dist_env(self.rank, self.world_size)
@ -107,7 +125,7 @@ class TestModelWrapper(MultiProcessTestCase):
@unittest.skipIf(
not torch.cuda.is_available(), reason='cuda should be available')
class TestMMSeparateDistributedDataParallel(TestModelWrapper):
class TestMMSeparateDistributedDataParallel(TestDistributedDataParallel):
def test_train_step(self):
self._init_dist_env(self.rank, self.world_size)


@ -45,7 +45,7 @@ class ToyModel2(nn.Module):
class TestOptimWrapper(MultiProcessTestCase):
# Test `OptimWrapper.accumulate_grad` will block the gradient
# Test `OptimWrapper.optim_context` will block the gradient
# synchronization when using gradient accumulation strategy in distributed
# data parallel training.
def setUp(self) -> None:
@ -61,94 +61,84 @@ class TestOptimWrapper(MultiProcessTestCase):
def test_init(self):
optim_wrapper = OptimWrapper(self.optimizer)
self.assertEqual(optim_wrapper.optimizer, self.optimizer)
self.assertIs(optim_wrapper.optimizer, self.optimizer)
self.assertIsNone(optim_wrapper.clip_grad_kwargs)
self.assertEqual(optim_wrapper.accumulative_iters, 1)
self.assertEqual(optim_wrapper._accumulative_counts, 1)
self.assertIs(optim_wrapper.logger, self.logger)
self.assertIs(optim_wrapper.message_hub, self.message_hub)
self.assertEqual(optim_wrapper._inner_count, 0)
self.assertEqual(optim_wrapper._max_counts, -1)
self.assertEqual(optim_wrapper._divisible_counts, -1)
self.assertEqual(optim_wrapper._remainder_counts, -1)
with self.assertRaisesRegex(AssertionError,
'If `clip_grad_kwargs` is not None'):
'If `clip_grad` is not None'):
OptimWrapper(self.optimizer, clip_grad=[])
def test_update_params(self):
# Test update params every iteration.
optim_wrapper = OptimWrapper(self.optimizer, accumulative_iters=1)
optim_wrapper = OptimWrapper(self.optimizer, accumulative_counts=1)
self._mock_method(optim_wrapper)
loss = torch.tensor(1)
optim_wrapper.update_params(loss)
optim_wrapper.backward.assert_called_with(torch.tensor(1))
self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1))
optim_wrapper.step.assert_called_with()
optim_wrapper.zero_grad.assert_called_with()
with optim_wrapper.accumulate_grad(self.model, 2, 100):
optim_wrapper.update_params(torch.tensor(1))
optim_wrapper.backward.assert_called_with(torch.tensor(1))
optim_wrapper.step.assert_called_with()
optim_wrapper.zero_grad.assert_called_with()
# It will raise an error if `accumulative_iters > 1` and
# `accumulate_grad` is not enabled.
optim_wrapper = OptimWrapper(self.optimizer, accumulative_iters=3)
# Test gradient accumulation.
optim_wrapper = OptimWrapper(self.optimizer, accumulative_counts=3)
self._mock_method(optim_wrapper)
with self.assertRaisesRegex(AssertionError,
'gradient accumulation must be'):
optim_wrapper.update_params(loss)
# `iter=0`, accumulate gradient and do not update params.
loss = torch.tensor(1)
optim_wrapper.update_params(loss)
self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1) / 3)
optim_wrapper.step.assert_not_called()
optim_wrapper.zero_grad.assert_not_called()
# `iter=0`, Call `optimizer_step` first time.
with optim_wrapper.accumulate_grad(
self.model, cur_iter=0, max_iters=100):
loss = torch.tensor(1)
optim_wrapper.update_params(loss)
optim_wrapper.backward.assert_called_with(torch.tensor(1) / 3)
optim_wrapper.step.assert_not_called()
optim_wrapper.zero_grad.assert_not_called()
# Accumulate gradients without updating params.
optim_wrapper.update_params(loss)
self.assertEqual(optim_wrapper._inner_count, 2)
# `iter=2`, Call `optimizer_step` first time.
with optim_wrapper.accumulate_grad(
self.model, cur_iter=2, max_iters=100):
optim_wrapper.update_params(loss)
optim_wrapper.step.assert_called()
optim_wrapper.zero_grad.assert_called()
# `iter=2`, update params.
optim_wrapper.update_params(loss)
optim_wrapper.step.assert_called()
optim_wrapper.zero_grad.assert_called()
self._mock_method(optim_wrapper)
# Test end of training.
with optim_wrapper.accumulate_grad(
self.model, cur_iter=99, max_iters=100):
optim_wrapper.update_params(loss)
optim_wrapper.step.assert_called()
optim_wrapper.zero_grad.assert_called()
optim_wrapper.backward.assert_called_with(1)
# If ``accumulative_iters > 1``, call ``update_params`` with
# non-accumulate_grad context will raise an Assertion error
optim_wrapper = OptimWrapper(self.optimizer, accumulative_iters=1)
optim_wrapper.accumulative_iters = 2
with self.assertRaisesRegex(AssertionError,
'gradient accumulation must be performed'):
optim_wrapper.update_params(loss)
# Test end of training without calling `initialize_count_status`
optim_wrapper._inner_count = 99
optim_wrapper.update_params(loss)
optim_wrapper.step.assert_not_called()
optim_wrapper.zero_grad.assert_not_called()
self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1) / 3)
self._mock_method(optim_wrapper)
def test_initilize_iter_status(self):
optim_wrapper = OptimWrapper(self.optimizer, accumulative_iters=3)
optim_wrapper._initilize_iter_status(self.model)
self.assertEqual(optim_wrapper.divisible_iters, 0)
self.assertEqual(optim_wrapper.remainder_iters, 0)
# After calling `initialize_count_status`, params will be updated at the
# last iteration, and the loss factor will be adjusted accordingly.
optim_wrapper.initialize_count_status(self.model, 99, 100)
optim_wrapper.update_params(loss)
optim_wrapper.step.assert_called()
optim_wrapper.zero_grad.assert_called()
self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1))
def test_initialize_iter_status(self):
optim_wrapper = OptimWrapper(self.optimizer, accumulative_counts=3)
optim_wrapper.initialize_count_status(self.model, 0, 100)
self.assertEqual(optim_wrapper._divisible_counts, 99)
self.assertEqual(optim_wrapper._remainder_counts, 1)
# An indivisible `init_counts` will trigger a warning.
optim_wrapper = OptimWrapper(self.optimizer, accumulative_iters=3)
optim_wrapper.cur_iter = 0
optim_wrapper.max_iters = 100
optim_wrapper = OptimWrapper(self.optimizer, accumulative_counts=3)
with self.assertLogs(self.logger) as cm:
optim_wrapper._initilize_iter_status(self.model)
optim_wrapper.initialize_count_status(self.model, 2, 100)
self.assertEqual(len(cm.output), 1)
self.assertRegex(cm.records[0].msg, 'Resume iter number is not')
self.assertRegex(cm.records[0].msg, 'Resumed iteration number')
# A model with batch norm will trigger a warning.
optim_wrapper = OptimWrapper(self.optimizer, accumulative_iters=3)
optim_wrapper.cur_iter = 0
optim_wrapper.max_iters = 99
optim_wrapper = OptimWrapper(self.optimizer, accumulative_counts=3)
model = nn.BatchNorm2d(1)
with self.assertLogs(self.logger) as cm:
optim_wrapper._initilize_iter_status(model)
optim_wrapper.initialize_count_status(model, 0, 99)
self.assertEqual(len(cm.output), 1)
self.assertRegex(cm.records[0].msg, 'Gradient accumulative')
@ -214,40 +204,40 @@ class TestOptimWrapper(MultiProcessTestCase):
self.assertEqual(optim_wrapper.param_groups,
self.optimizer.param_groups)
def test_accumulate_grad(self):
def test_optim_context(self):
self._init_dist_env(self.rank, self.world_size)
model = ToyModel2()
ddp_model = DistributedDataParallel(model)
optimizer = SGD(ddp_model.parameters(), lr=0.01)
optim_wrapper = OptimWrapper(optimizer, accumulative_iters=1)
optim_wrapper = OptimWrapper(optimizer, accumulative_counts=1)
optim_wrapper.zero_grad()
with optim_wrapper.accumulate_grad(ddp_model, 0, 100):
# Automatically sync grads if `accumulative_iters` = 1
inputs = torch.randn(1, 1, 1, 1) * self.rank
ddp_model(inputs).sum().backward()
grad = model.conv.weight.grad
all_grads = all_gather(grad)
assert_allclose(all_grads[0], all_grads[1])
# Automatically sync grads if `accumulative_counts` = 1
optim_wrapper.initialize_count_status(model, 0, 100)
inputs = torch.randn(1, 1, 1, 1) * self.rank
ddp_model(inputs).sum().backward()
grad = model.conv.weight.grad
all_grads = all_gather(grad)
assert_allclose(all_grads[0], all_grads[1])
# Do not sync grads when `optim_wrapper._inner_count` cannot be
# divided by `optim_wrapper.accumulative_iters`
optim_wrapper = OptimWrapper(optimizer, accumulative_iters=3)
with optim_wrapper.accumulate_grad(ddp_model, 0, 100):
ddp_model(inputs).sum().backward()
all_grads = all_gather(model.conv.weight.grad)
with self.assertRaises(AssertionError):
assert_allclose(all_grads[0], all_grads[1])
# sync grads if `cur_iter == 2`
with optim_wrapper.accumulate_grad(ddp_model, 2, 100):
ddp_model(inputs).sum().backward()
all_grads = all_gather(model.conv.weight.grad)
# divided by `optim_wrapper._accumulative_counts`
optim_wrapper = OptimWrapper(optimizer, accumulative_counts=3)
optim_wrapper.initialize_count_status(model, 0, 100)
with optim_wrapper.optim_context(ddp_model):
loss = ddp_model(inputs).sum()
loss.backward()
all_grads = all_gather(model.conv.weight.grad)
with self.assertRaises(AssertionError):
assert_allclose(all_grads[0], all_grads[1])
def test_precision_context(self):
optim_wrapper = OptimWrapper(self.optimizer)
with optim_wrapper.precision_context():
pass
# Sync grads if `_inner_count == 2`
optim_wrapper.initialize_count_status(model, 2, 100)
with optim_wrapper.optim_context(ddp_model):
loss = ddp_model(inputs).sum()
loss.backward()
all_grads = all_gather(model.conv.weight.grad)
assert_allclose(all_grads[0], all_grads[1])
def _init_dist_env(self, rank, world_size):
"""Initialize the distributed environment."""
@ -260,7 +250,12 @@ class TestOptimWrapper(MultiProcessTestCase):
# TODO: Test the real interface after adding a testing tool function that
# can check whether the function or method is really called.
def _mock_method(self, optim_wrapper):
optim_wrapper.backward = MagicMock()
def mock_method(loss):
optim_wrapper._inner_count += 1
optim_wrapper.scaled_loss = loss
optim_wrapper.backward = mock_method
optim_wrapper.step = MagicMock()
optim_wrapper.zero_grad = MagicMock()
@ -376,9 +371,9 @@ class TestAmpOptimWrapper(TestCase):
and (digit_version(TORCH_VERSION) >= digit_version('1.6.0')),
reason='`torch.cuda.amp` is only available when pytorch-gpu version '
'>= 1.6')
def test_precision_context(self):
def test_optim_context(self):
amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer)
with amp_optim_wrapper.precision_context():
with amp_optim_wrapper.optim_context(self.model):
x = torch.randn(1, 1, 1, 1).cuda()
y = nn.Conv2d(1, 1, 1).cuda()(x)
self.assertEqual(y.dtype, torch.float16)


@ -1,69 +1,92 @@
# Copyright (c) OpenMMLab. All rights reserved.
from contextlib import contextmanager
from unittest import TestCase
from unittest.mock import patch
import torch
import torch.nn as nn
from torch.optim import SGD
from mmengine.optim import AmpOptimWrapper, OptimWrapper, OptimWrapperDict
from mmengine.optim import OptimWrapper, OptimWrapperDict
class TestOptimWrapperDict(TestCase):
def setUp(self) -> None:
model1 = nn.Linear(1, 1)
model2 = nn.Linear(1, 1)
self.optim1 = SGD(model1.parameters(), lr=0.1, momentum=0.8)
self.optim2 = SGD(model2.parameters(), lr=0.2, momentum=0.9)
self.model1 = nn.Linear(1, 1)
self.model2 = nn.Linear(1, 1)
self.optim1 = SGD(self.model1.parameters(), lr=0.1, momentum=0.8)
self.optim2 = SGD(self.model2.parameters(), lr=0.2, momentum=0.9)
self.optim_wrapper1 = OptimWrapper(self.optim1)
self.optim_wrapper2 = OptimWrapper(self.optim2)
self.optimizers_wrappers = dict(
optim1=self.optim_wrapper1, optim2=self.optim_wrapper2)
@patch('torch.cuda.is_available', lambda: True)
def test_init(self):
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
self.assertEqual(optim_wrapper_dict.optim_wrappers,
self.optimizers_wrappers)
# Different types of OptimWrapper will raise an error
with self.assertRaisesRegex(AssertionError,
'`OptimWrapperDict` only accept'):
OptimWrapperDict(**dict(optim1=self.optim1, optim2=self.optim2))
with self.assertRaisesRegex(
AssertionError, 'All optimizer wrappers should have the same'):
optim_wrapper2 = AmpOptimWrapper(optimizer=self.optim2)
OptimWrapperDict(optim1=self.optim_wrapper1, optim2=optim_wrapper2)
with self.assertWarnsRegex(UserWarning, 'The `accumulative_iters` of'):
optim_wrapper2 = OptimWrapper(
optimizer=self.optim2, accumulative_iters=2)
OptimWrapperDict(optim1=self.optim_wrapper1, optim2=optim_wrapper2)
def test_accumulate_grad(self):
@contextmanager
def context_a(a, b, *args, **kwargs):
a[0] = 100
yield
a[0] = 1
@contextmanager
def context_b(a, b, *args, **kwargs):
b[0] = 200
yield
b[0] = 2
a = [0]
b = [0]
# Test enter the context both of `optim_wrapper1` and `optim_wrapper1`.
def test_update_params(self):
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
self.optim_wrapper1.accumulate_grad = context_a
self.optim_wrapper2.accumulate_grad = context_b
with optim_wrapper_dict.accumulate_grad(a, b, 0):
self.assertEqual(a[0], 100)
self.assertEqual(b[0], 200)
with self.assertRaisesRegex(NotImplementedError,
'`update_params` should be called'):
optim_wrapper_dict.update_params(1)
self.assertEqual(a[0], 1)
self.assertEqual(b[0], 2)
def test_backward(self):
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
with self.assertRaisesRegex(NotImplementedError,
'`backward` should be called'):
optim_wrapper_dict.backward(1)
def test_step(self):
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
with self.assertRaisesRegex(NotImplementedError,
'`step` should be called'):
optim_wrapper_dict.step()
def test_zero_grad(self):
# Test clear all grad
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
self.model1(torch.randn(1, 1)).sum().backward()
self.model2(torch.randn(1, 1)).sum().backward()
self.assertTrue((self.model1.weight.grad != 0).any())
self.assertTrue((self.model2.weight.grad != 0).any())
optim_wrapper_dict.zero_grad()
self.assertTrue((self.model1.weight.grad == 0).all())
self.assertTrue((self.model2.weight.grad == 0).all())
def test_optim_context(self):
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
with self.assertRaisesRegex(NotImplementedError,
'`optim_context` should be called'):
with optim_wrapper_dict.optim_context(self.model1):
pass
def test_initialize_count_status(self):
# Test `initialize_count_status` can be called.
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
optim_wrapper_dict.initialize_count_status(self.model1, 1, 1)
def test_param_groups(self):
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
self.assertEqual(optim_wrapper_dict.param_groups['optim1'],
self.optim1.param_groups)
self.assertEqual(optim_wrapper_dict.param_groups['optim2'],
self.optim2.param_groups)
def test_get_lr(self):
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
lr = optim_wrapper_dict.get_lr()
self.assertEqual(lr['optim1.lr'], [0.1])
self.assertEqual(lr['optim2.lr'], [0.2])
def test_get_momentum(self):
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
momentum = optim_wrapper_dict.get_momentum()
self.assertEqual(momentum['optim1.momentum'], [0.8])
self.assertEqual(momentum['optim2.momentum'], [0.9])
def test_state_dict(self):
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
@ -109,18 +132,6 @@ class TestOptimWrapperDict(TestCase):
list(optim_wrapper_dict.keys()),
list(self.optimizers_wrappers.keys()))
def test_get_lr(self):
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
lr = optim_wrapper_dict.get_lr()
self.assertEqual(lr['optim1.lr'], [0.1])
self.assertEqual(lr['optim2.lr'], [0.2])
def test_get_momentum(self):
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
momentum = optim_wrapper_dict.get_momentum()
self.assertEqual(momentum['optim1.momentum'], [0.8])
self.assertEqual(momentum['optim2.momentum'], [0.9])
def test_getitem(self):
optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers)
self.assertIs(self.optimizers_wrappers['optim1'],


@ -1135,8 +1135,9 @@ class TestRunner(TestCase):
cfg.custom_hooks = [dict(type='TestEpochHook', priority=50)]
cfg.train_cfg = dict(by_epoch=True, max_epochs=3, val_begin=2)
runner = Runner.from_cfg(cfg)
runner.train()
self.assertEqual(runner.optim_wrapper._inner_count, 12)
self.assertEqual(runner.optim_wrapper._max_counts, 12)
assert isinstance(runner.train_loop, EpochBasedTrainLoop)
@ -1183,6 +1184,8 @@ class TestRunner(TestCase):
runner = Runner.from_cfg(cfg)
runner.train()
self.assertEqual(runner.optim_wrapper._inner_count, 12)
self.assertEqual(runner.optim_wrapper._max_counts, 12)
assert isinstance(runner.train_loop, IterBasedTrainLoop)
self.assertEqual(len(epoch_results), 1)