[Feature] Add ApexOptimWrapper (#742)
* add ApexOptimWrapper
* typo fix
* add apex amp.initialize in optim_context
* assert apex_amp
* polish code
* add parameters of apex_amp.initialize
* add docs
* polish code
* polish code
* polish code
* fix calling of apex amp load_state_dict
* polish
* add comments
* Update apex_optimizer_wrapper.py
* Update apex_optimizer_wrapper.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Commit e35ed5fd2e (parent bc49e0c0a1)
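For orientation, here is a minimal sketch of how a wrapper like this is typically selected from a runner config through the OPTIM_WRAPPERS registry, which the new module registers itself with. The optimizer settings and the opt_level/loss_scale values below are illustrative assumptions, not values taken from this commit:

# Illustrative config sketch (values are placeholders, not from this commit).
# The dict would be passed as the `optim_wrapper` field of a runner config;
# `optimizer` is forwarded to OptimWrapper, and the remaining keys map to the
# apex_amp.initialize options exposed by the new ApexOptimWrapper.
optim_wrapper = dict(
    type='ApexOptimWrapper',                          # mixed precision via apex.amp
    optimizer=dict(type='SGD', lr=0.01, momentum=0.9),
    opt_level='O1',                                   # apex optimization level
    loss_scale='dynamic')                             # or a fixed float such as 128.0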
@@ -20,6 +20,7 @@ Optimizer
    :template: classtemplate.rst

    AmpOptimWrapper
+   ApexOptimWrapper
    OptimWrapper
    OptimWrapperDict
    DefaultOptimWrapperConstructor

@@ -20,6 +20,7 @@ Optimizer
    :template: classtemplate.rst

    AmpOptimWrapper
+   ApexOptimWrapper
    OptimWrapper
    OptimWrapperDict
    DefaultOptimWrapperConstructor
@@ -1,7 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .optimizer import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
-                        AmpOptimWrapper, DefaultOptimWrapperConstructor,
-                        OptimWrapper, OptimWrapperDict, build_optim_wrapper)
+                        AmpOptimWrapper, ApexOptimWrapper,
+                        DefaultOptimWrapperConstructor, OptimWrapper,
+                        OptimWrapperDict, build_optim_wrapper)
 # yapf: disable
 from .scheduler import (ConstantLR, ConstantMomentum, ConstantParamScheduler,
                         CosineAnnealingLR, CosineAnnealingMomentum,
@@ -25,8 +26,8 @@ __all__ = [
     'MultiStepMomentum', 'StepMomentum', 'ConstantParamScheduler',
     'CosineAnnealingParamScheduler', 'ExponentialParamScheduler',
     'LinearParamScheduler', 'MultiStepParamScheduler', 'StepParamScheduler',
-    '_ParamScheduler', 'OptimWrapper', 'AmpOptimWrapper', 'OptimWrapperDict',
-    'OneCycleParamScheduler', 'OneCycleLR', 'PolyLR', 'PolyMomentum',
-    'PolyParamScheduler', 'ReduceOnPlateauLR', 'ReduceOnPlateauMomentum',
-    'ReduceOnPlateauParamScheduler'
+    '_ParamScheduler', 'OptimWrapper', 'AmpOptimWrapper', 'ApexOptimWrapper',
+    'OptimWrapperDict', 'OneCycleParamScheduler', 'OneCycleLR', 'PolyLR',
+    'PolyMomentum', 'PolyParamScheduler', 'ReduceOnPlateauLR',
+    'ReduceOnPlateauMomentum', 'ReduceOnPlateauParamScheduler'
 ]
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .amp_optimizer_wrapper import AmpOptimWrapper
+from .apex_optimizer_wrapper import ApexOptimWrapper
 from .builder import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
                       build_optim_wrapper)
 from .default_constructor import DefaultOptimWrapperConstructor
@@ -10,5 +11,6 @@ from .zero_optimizer import ZeroRedundancyOptimizer
 __all__ = [
     'OPTIM_WRAPPER_CONSTRUCTORS', 'OPTIMIZERS',
     'DefaultOptimWrapperConstructor', 'build_optim_wrapper', 'OptimWrapper',
-    'AmpOptimWrapper', 'OptimWrapperDict', 'ZeroRedundancyOptimizer'
+    'AmpOptimWrapper', 'ApexOptimWrapper', 'OptimWrapperDict',
+    'ZeroRedundancyOptimizer'
 ]
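With the re-exports above in place, the new class is importable from mmengine.optim alongside the existing wrappers. A quick sanity check, as a sketch that assumes mmengine >= 0.6.0 is installed:

from mmengine.optim import ApexOptimWrapper, OptimWrapper

# ApexOptimWrapper subclasses OptimWrapper, so it can be dropped in wherever
# an OptimWrapper is expected; importing it does not require apex itself,
# since the module only falls back to `apex_amp = None` when apex is missing.
assert issubclass(ApexOptimWrapper, OptimWrapper)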
mmengine/optim/optimizer/apex_optimizer_wrapper.py (new file, 200 lines)
@@ -0,0 +1,200 @@
# Copyright (c) OpenMMLab. All rights reserved.
from contextlib import contextmanager
from typing import Optional, Union

import torch
import torch.nn as nn

# a circular import will be caused by
# from mmengine.model.wrappers import is_model_wrapper
import mmengine
from mmengine.registry import OPTIM_WRAPPERS
from .optimizer_wrapper import OptimWrapper

try:
    import apex.amp as apex_amp
except ImportError:
    apex_amp = None


@OPTIM_WRAPPERS.register_module()
class ApexOptimWrapper(OptimWrapper):
    """A subclass of :class:`OptimWrapper` that supports automatic mixed
    precision training based on apex.amp.

    ``ApexOptimWrapper`` provides a unified interface with
    ``OptimWrapper``, so it can be used in the same way as ``OptimWrapper``.

    Warning:
        ``ApexOptimWrapper`` requires `nvidia apex <https://github.com/NVIDIA/apex>`_ to be installed.

    Args:
        opt_level (str): Pure or mixed precision optimization level. Accepted
            values are "O0", "O1", "O2", and "O3". Defaults to "O1".
        loss_scale (float or str, optional): If passed as a string, must be a
            string representing a number, e.g., "128.0", or the string
            "dynamic". Defaults to "dynamic".
        enabled (bool): If False, renders all Amp calls no-ops, so your script
            should run as if Amp were not present. Defaults to True.
        cast_model_type (torch.dtype, optional): Casts the model's parameters
            and buffers to the desired type. Defaults to None.
        patch_torch_functions (bool, optional): Patch all Torch functions
            and Tensor methods to perform Tensor Core-friendly ops like GEMMs
            and convolutions in FP16, and any ops that benefit from FP32
            precision in FP32. Defaults to None.
        keep_batchnorm_fp32 (bool or str, optional): To enhance precision
            and enable cudnn batchnorm (which improves performance),
            it's often beneficial to keep batchnorm weights in FP32
            even if the rest of the model is FP16.
            If passed as a string, must be the string "True" or "False".
            Defaults to None.
        master_weights (bool, optional): Maintain FP32 master weights to
            accompany any FP16 model weights. FP32 master weights are stepped
            by the optimizer to enhance precision and capture small gradients.
            Defaults to None.
        cast_model_outputs (torch.dtype, optional): Option to ensure that
            the outputs of your model(s) are always cast to a particular type
            regardless of ``opt_level``. Defaults to None.
        num_losses (int): Option to tell Amp in advance how many
            losses/backward passes you plan to use. Defaults to 1.
        verbosity (int): Set to 0 to suppress Amp-related output.
            Defaults to 1.
        min_loss_scale (float, optional): Sets a floor for the loss scale
            values that can be chosen by dynamic loss scaling.
            The default value of None means that no floor is imposed.
            If dynamic loss scaling is not used, `min_loss_scale` is ignored.
            Defaults to None.
        max_loss_scale (float, optional): Sets a ceiling for the loss scale
            values that can be chosen by dynamic loss scaling. If dynamic
            loss scaling is not used, `max_loss_scale` is ignored.
            Defaults to 2.**24.
        **kwargs: Keyword arguments passed to OptimWrapper.

    Note:
        If you use ``IterBasedRunner`` and enable gradient accumulation,
        the original `max_iters` should be multiplied by
        ``accumulative_counts``.

    Note:
        `New in version 0.6.0.`
    """  # noqa: E501

    def __init__(self,
                 opt_level: str = 'O1',
                 loss_scale: Union[float, str, None] = 'dynamic',
                 enabled: Optional[bool] = True,
                 cast_model_type: Optional[torch.dtype] = None,
                 patch_torch_functions: Optional[bool] = None,
                 keep_batchnorm_fp32: Union[bool, str, None] = None,
                 master_weights: Optional[bool] = None,
                 cast_model_outputs: Optional[torch.dtype] = None,
                 num_losses: int = 1,
                 verbosity: int = 1,
                 min_loss_scale: Optional[float] = None,
                 max_loss_scale: Optional[float] = 2.**24,
                 **kwargs):
        assert apex_amp is not None, \
            'Apex is not installed. Please check ' \
            'https://github.com/NVIDIA/apex#linux.'
        super().__init__(**kwargs)
        self.opt_level = opt_level
        self.loss_scale = loss_scale
        self.enabled = enabled
        self.cast_model_type = cast_model_type
        self.patch_torch_functions = patch_torch_functions
        self.keep_batchnorm_fp32 = keep_batchnorm_fp32
        self.master_weights = master_weights
        self.cast_model_outputs = cast_model_outputs
        self.num_losses = num_losses
        self.verbosity = verbosity
        self.min_loss_scale = min_loss_scale
        self.max_loss_scale = max_loss_scale
        self._apex_amp_state_dict = None

    def backward(self, loss: torch.Tensor, **kwargs) -> None:
        """Perform gradient back propagation with :attr:`loss_scaler`.

        Args:
            loss (torch.Tensor): The loss of the current iteration.
            kwargs: Keyword arguments passed to :meth:`torch.Tensor.backward`.
        """
        with apex_amp.scale_loss(loss, self.optimizer) as scaled_loss:
            scaled_loss.backward(**kwargs)
        self._inner_count += 1

    def state_dict(self) -> dict:
        """Get the state dictionary of :attr:`optimizer` and
        :attr:`apex_amp`.

        Based on the state dictionary of the optimizer, the returned state
        dictionary will add a key named "apex_amp".

        Returns:
            dict: The merged state dict of :attr:`apex_amp` and
            :attr:`optimizer`.
        """
        state_dict = self.optimizer.state_dict()
        state_dict['apex_amp'] = apex_amp.state_dict()
        return state_dict

    def load_state_dict(self, state_dict: dict) -> None:
        """Load and parse the state dictionary of :attr:`optimizer` and
        :attr:`apex_amp`.

        If ``state_dict`` contains the "apex_amp" key, :attr:`apex_amp` will
        load the corresponding keys. Otherwise, only the :attr:`optimizer`
        will load the state dictionary.

        Note:
            :meth:`load_state_dict` should be called after
            `apex_amp.initialize` is called.

        Args:
            state_dict (dict): The state dict of :attr:`optimizer` and
                :attr:`apex_amp`.
        """
        if 'apex_amp' in state_dict:
            # when `apex_amp` is not initialized, calling `load_state_dict`
            # will raise an error, so we temporarily cache the apex_amp
            # part, and then load it into `apex_amp` after completing
            # the `apex_amp` initialization in the `optim_context` method
            if hasattr(self.optimizer, '_amp_stash'):
                apex_amp.load_state_dict(state_dict.pop('apex_amp'))
            else:
                self._apex_amp_state_dict = state_dict.pop('apex_amp')
        self.optimizer.load_state_dict(state_dict)

    @contextmanager
    def optim_context(self, model: nn.Module):
        """Enables the context for mixed precision training, and enables the
        context for disabling gradient synchronization during gradient
        accumulation.

        Args:
            model (nn.Module): The training model.
        """
        with super().optim_context(model):
            # when a given optimizer is passed through apex_amp.initialize,
            # the "_amp_stash" property will be added
            if not hasattr(self.optimizer, '_amp_stash'):
                if mmengine.model.wrappers.is_model_wrapper(model):
                    model = model.module
                model, self.optimizer = apex_amp.initialize(
                    model,
                    self.optimizer,
                    opt_level=self.opt_level,
                    loss_scale=self.loss_scale,
                    enabled=self.enabled,
                    cast_model_type=self.cast_model_type,
                    patch_torch_functions=self.patch_torch_functions,
                    keep_batchnorm_fp32=self.keep_batchnorm_fp32,
                    master_weights=self.master_weights,
                    cast_model_outputs=self.cast_model_outputs,
                    num_losses=self.num_losses,
                    verbosity=self.verbosity,
                    min_loss_scale=self.min_loss_scale,
                    max_loss_scale=self.max_loss_scale)
                # load apex_amp state_dict after initialization of apex_amp
                if self._apex_amp_state_dict is not None:
                    apex_amp.load_state_dict(self._apex_amp_state_dict)
                    self._apex_amp_state_dict = None
            yield

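To make the flow of optim_context, backward, and the state-dict handling above concrete, here is a hedged sketch of driving the wrapper by hand outside a Runner. The toy model, data shapes, and learning rate are placeholders, and it assumes a CUDA device with apex installed:

import torch
import torch.nn as nn
from torch.optim import SGD

from mmengine.optim import ApexOptimWrapper

# Placeholder model and optimizer; a real setup would come from a config.
model = nn.Conv2d(3, 8, 3).cuda()
optim_wrapper = ApexOptimWrapper(
    optimizer=SGD(model.parameters(), lr=0.01), opt_level='O1')

for _ in range(3):
    inputs = torch.randn(4, 3, 32, 32).cuda()
    # On first entry, optim_context calls apex_amp.initialize (which adds
    # `_amp_stash` to the optimizer) and loads any cached amp state dict.
    with optim_wrapper.optim_context(model):
        loss = model(inputs).mean()
        # update_params is inherited from OptimWrapper; it routes through the
        # overridden backward (apex_amp.scale_loss), then step and zero_grad.
        optim_wrapper.update_params(loss)

# state_dict adds an 'apex_amp' key next to the optimizer state, and
# load_state_dict defers the amp part until apex has been initialized.
checkpoint = optim_wrapper.state_dict()
optim_wrapper.load_state_dict(checkpoint)

This mirrors what the new unit tests below exercise (test_step, test_state_dict, test_load_state_dict), just without the unittest scaffolding.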
@@ -14,12 +14,19 @@ from torch.optim import SGD, Adam, Optimizer

 from mmengine.dist import all_gather
 from mmengine.logging import MessageHub, MMLogger
-from mmengine.optim import AmpOptimWrapper, OptimWrapper
+from mmengine.optim import AmpOptimWrapper, ApexOptimWrapper, OptimWrapper
 from mmengine.testing import assert_allclose
 from mmengine.testing._internal import MultiProcessTestCase
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION

+is_apex_available = False
+try:
+    import apex.amp as apex_amp
+    is_apex_available = True
+except ImportError:
+    pass
+

 class ToyModel(nn.Module):

@@ -283,6 +290,101 @@ class TestOptimWrapper(MultiProcessTestCase):
         optim_wrapper.zero_grad = MagicMock()


+@unittest.skipIf(not torch.cuda.is_available(), reason='need gpu to test Apex')
+class TestApexOptimWrapper(TestCase):
+
+    def setUp(self) -> None:
+        self.model = ToyModel().cuda()
+        self.optimizer = SGD(self.model.parameters(), lr=0.1)
+
+    @unittest.skipIf(
+        not is_apex_available,
+        reason='`apex` is not available, please install apex from '
+        'https://www.github.com/nvidia/apex')
+    def test_init(self):
+        apex_optim_wrapper = ApexOptimWrapper(
+            optimizer=self.optimizer, opt_level='O1', loss_scale=1)
+        with apex_optim_wrapper.optim_context(self.model):
+            pass
+
+    @unittest.skipIf(
+        not is_apex_available,
+        reason='`apex` is not available, please install apex from '
+        'https://www.github.com/nvidia/apex')
+    def test_step(self):
+        optimizer = MagicMock(spec=Optimizer)
+        apex_optim_wrapper = ApexOptimWrapper(
+            optimizer=optimizer, opt_level='O1', loss_scale=1)
+        with apex_optim_wrapper.optim_context(self.model):
+            loss = self.model(torch.Tensor(1, 1, 1, 1).cuda())
+            apex_optim_wrapper.backward(loss)
+            apex_optim_wrapper.step()
+
+    @unittest.skipIf(
+        not is_apex_available,
+        reason='`apex` is not available, please install apex from '
+        'https://www.github.com/nvidia/apex')
+    def test_backward(self):
+        apex_optim_wrapper = ApexOptimWrapper(
+            optimizer=self.optimizer, opt_level='O1', loss_scale=1)
+        with apex_optim_wrapper.optim_context(self.model):
+            loss = self.model(torch.Tensor(1, 1, 1, 1).cuda())
+            apex_optim_wrapper.backward(loss)
+
+    @unittest.skipIf(
+        not is_apex_available,
+        reason='`apex` is not available, please install apex from '
+        'https://www.github.com/nvidia/apex')
+    def test_state_dict(self):
+        apex_optim_wrapper = ApexOptimWrapper(
+            optimizer=self.optimizer, opt_level='O1', loss_scale=1)
+        with apex_optim_wrapper.optim_context(self.model):
+            loss = self.model(torch.Tensor(1, 1, 1, 1).cuda())
+            apex_optim_wrapper.update_params(loss)
+            state_dict = apex_optim_wrapper.state_dict()
+            amp_state_dict = state_dict.pop('apex_amp')
+            optim_state_dict = state_dict
+
+            self.assertDictEqual(optim_state_dict,
+                                 apex_optim_wrapper.optimizer.state_dict())
+            self.assertDictEqual(amp_state_dict, apex_amp.state_dict())
+
+    @unittest.skipIf(
+        not is_apex_available,
+        reason='`apex` is not available, please install apex from '
+        'https://www.github.com/nvidia/apex')
+    def test_load_state_dict(self):
+        apex_optim_wrapper = ApexOptimWrapper(
+            optimizer=self.optimizer, opt_level='O1', loss_scale=1)
+        with apex_optim_wrapper.optim_context(self.model):
+            # Test load from optimizer
+            optimizer = SGD(self.model.parameters(), lr=0.1)
+            apex_optim_wrapper.load_state_dict(optimizer.state_dict())
+
+            self.assertDictEqual(optimizer.state_dict(),
+                                 apex_optim_wrapper.optimizer.state_dict())
+            # Test load from optim_wrapper
+            apex_optim_wrapper = ApexOptimWrapper(optimizer=self.optimizer)
+            apex_optim_wrapper_ = ApexOptimWrapper(
+                optimizer=SGD(self.model.parameters(), lr=0.1))
+            apex_optim_wrapper_.load_state_dict(
+                apex_optim_wrapper.state_dict())
+            self.assertDictEqual(apex_optim_wrapper.optimizer.state_dict(),
+                                 apex_optim_wrapper_.optimizer.state_dict())
+
+    @unittest.skipIf(
+        not is_apex_available,
+        reason='`apex` is not available, please install apex from '
+        'https://www.github.com/nvidia/apex')
+    def test_optim_context(self):
+        apex_optim_wrapper = ApexOptimWrapper(
+            optimizer=self.optimizer, opt_level='O1', loss_scale=1)
+        with apex_optim_wrapper.optim_context(self.model):
+            x = torch.randn(1, 1, 1, 1).cuda()
+            y = nn.Conv2d(1, 1, 1).cuda()(x)
+            self.assertEqual(y.dtype, torch.float16)
+
+
 class TestAmpOptimWrapper(TestCase):

     def setUp(self) -> None: