diff --git a/docs/en/api/optim.rst b/docs/en/api/optim.rst
index 3de2ade6..142d0f08 100644
--- a/docs/en/api/optim.rst
+++ b/docs/en/api/optim.rst
@@ -20,6 +20,7 @@ Optimizer
    :template: classtemplate.rst
 
    AmpOptimWrapper
+   ApexOptimWrapper
    OptimWrapper
    OptimWrapperDict
    DefaultOptimWrapperConstructor
diff --git a/docs/zh_cn/api/optim.rst b/docs/zh_cn/api/optim.rst
index 3de2ade6..142d0f08 100644
--- a/docs/zh_cn/api/optim.rst
+++ b/docs/zh_cn/api/optim.rst
@@ -20,6 +20,7 @@ Optimizer
    :template: classtemplate.rst
 
    AmpOptimWrapper
+   ApexOptimWrapper
    OptimWrapper
    OptimWrapperDict
    DefaultOptimWrapperConstructor
diff --git a/mmengine/optim/__init__.py b/mmengine/optim/__init__.py
index 72118b17..38426ea8 100644
--- a/mmengine/optim/__init__.py
+++ b/mmengine/optim/__init__.py
@@ -1,7 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .optimizer import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
-                        AmpOptimWrapper, DefaultOptimWrapperConstructor,
-                        OptimWrapper, OptimWrapperDict, build_optim_wrapper)
+                        AmpOptimWrapper, ApexOptimWrapper,
+                        DefaultOptimWrapperConstructor, OptimWrapper,
+                        OptimWrapperDict, build_optim_wrapper)
 # yapf: disable
 from .scheduler import (ConstantLR, ConstantMomentum, ConstantParamScheduler,
                         CosineAnnealingLR, CosineAnnealingMomentum,
@@ -25,8 +26,8 @@ __all__ = [
     'MultiStepMomentum', 'StepMomentum', 'ConstantParamScheduler',
     'CosineAnnealingParamScheduler', 'ExponentialParamScheduler',
     'LinearParamScheduler', 'MultiStepParamScheduler', 'StepParamScheduler',
-    '_ParamScheduler', 'OptimWrapper', 'AmpOptimWrapper', 'OptimWrapperDict',
-    'OneCycleParamScheduler', 'OneCycleLR', 'PolyLR', 'PolyMomentum',
-    'PolyParamScheduler', 'ReduceOnPlateauLR', 'ReduceOnPlateauMomentum',
-    'ReduceOnPlateauParamScheduler'
+    '_ParamScheduler', 'OptimWrapper', 'AmpOptimWrapper', 'ApexOptimWrapper',
+    'OptimWrapperDict', 'OneCycleParamScheduler', 'OneCycleLR', 'PolyLR',
+    'PolyMomentum', 'PolyParamScheduler', 'ReduceOnPlateauLR',
+    'ReduceOnPlateauMomentum', 'ReduceOnPlateauParamScheduler'
 ]
diff --git a/mmengine/optim/optimizer/__init__.py b/mmengine/optim/optimizer/__init__.py
index fdc77679..01169514 100644
--- a/mmengine/optim/optimizer/__init__.py
+++ b/mmengine/optim/optimizer/__init__.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .amp_optimizer_wrapper import AmpOptimWrapper
+from .apex_optimizer_wrapper import ApexOptimWrapper
 from .builder import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
                       build_optim_wrapper)
 from .default_constructor import DefaultOptimWrapperConstructor
@@ -10,5 +11,6 @@ from .zero_optimizer import ZeroRedundancyOptimizer
 __all__ = [
     'OPTIM_WRAPPER_CONSTRUCTORS', 'OPTIMIZERS',
     'DefaultOptimWrapperConstructor', 'build_optim_wrapper', 'OptimWrapper',
-    'AmpOptimWrapper', 'OptimWrapperDict', 'ZeroRedundancyOptimizer'
+    'AmpOptimWrapper', 'ApexOptimWrapper', 'OptimWrapperDict',
+    'ZeroRedundancyOptimizer'
 ]
diff --git a/mmengine/optim/optimizer/apex_optimizer_wrapper.py b/mmengine/optim/optimizer/apex_optimizer_wrapper.py
new file mode 100644
index 00000000..5f2f6f4a
--- /dev/null
+++ b/mmengine/optim/optimizer/apex_optimizer_wrapper.py
@@ -0,0 +1,200 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from contextlib import contextmanager
+from typing import Optional, Union
+
+import torch
+import torch.nn as nn
+
+# a circular import will be caused by
+# from mmengine.model.wrappers import is_model_wrapper
+import mmengine
+from mmengine.registry import OPTIM_WRAPPERS
+from .optimizer_wrapper import OptimWrapper
+
+try:
+    import apex.amp as apex_amp
+except ImportError:
+    apex_amp = None
+
+
+@OPTIM_WRAPPERS.register_module()
+class ApexOptimWrapper(OptimWrapper):
+    """A subclass of :class:`OptimWrapper` that supports automatic mixed
+    precision training based on apex.amp.
+
+    ``ApexOptimWrapper`` provides a unified interface with
+    ``OptimWrapper``, so it can be used in the same way as ``OptimWrapper``.
+
+    Warning:
+        ``ApexOptimWrapper`` requires `nvidia apex <https://github.com/NVIDIA/apex>`_.
+
+    Args:
+        opt_level (str): Pure or mixed precision optimization level. Accepted
+            values are "O0", "O1", "O2", and "O3". Defaults to "O1".
+        loss_scale (float or str, optional): If passed as a string, must be a
+            string representing a number, e.g., "128.0", or the string
+            "dynamic". Defaults to "dynamic".
+        enabled (bool): If False, renders all Amp calls no-ops, so your script
+            should run as if Amp were not present. Defaults to True.
+        cast_model_type (torch.dtype, optional): Casts the model's
+            parameters and buffers to the desired type. Defaults to None.
+        patch_torch_functions (bool, optional): Patch all Torch functions
+            and Tensor methods to perform Tensor Core-friendly ops like GEMMs
+            and convolutions in FP16, and any ops that benefit from FP32
+            precision in FP32. Defaults to None.
+        keep_batchnorm_fp32 (bool or str, optional): To enhance precision
+            and enable cudnn batchnorm (which improves performance),
+            it's often beneficial to keep batchnorm weights in FP32
+            even if the rest of the model is FP16.
+            If passed as a string, must be the string "True" or "False".
+            Defaults to None.
+        master_weights (bool, optional): Maintain FP32 master weights to
+            accompany any FP16 model weights. FP32 master weights are stepped
+            by the optimizer to enhance precision and capture small gradients.
+            Defaults to None.
+        cast_model_outputs (torch.dtype, optional): Option to ensure that
+            the outputs of your model(s) are always cast to a particular type
+            regardless of ``opt_level``. Defaults to None.
+        num_losses (int): Option to tell Amp in advance how many
+            losses/backward passes you plan to use. Defaults to 1.
+        verbosity (int): Set to 0 to suppress Amp-related output.
+            Defaults to 1.
+        min_loss_scale (float, optional): Sets a floor for the loss scale
+            values that can be chosen by dynamic loss scaling.
+            The default value of None means that no floor is imposed.
+            If dynamic loss scaling is not used, `min_loss_scale` is ignored.
+            Defaults to None.
+        max_loss_scale (float, optional): Sets a ceiling for the loss scale
+            values that can be chosen by dynamic loss scaling. If dynamic
+            loss scaling is not used, `max_loss_scale` is ignored.
+            Defaults to 2.**24.
+        **kwargs: Keyword arguments passed to OptimWrapper.
+
+    Note:
+        If you use ``IterBasedRunner`` and enable gradient accumulation,
+        the original `max_iters` should be multiplied by
+        ``accumulative_counts``.
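+
+    Examples:
+        A minimal usage sketch. ``ToyModel`` stands in for any CUDA
+        ``nn.Module``; it and the SGD settings are illustrative only,
+        mirroring the unit tests rather than requirements of the class::
+
+            >>> import torch
+            >>> from torch.optim import SGD
+            >>> from mmengine.optim import ApexOptimWrapper
+            >>> model = ToyModel().cuda()
+            >>> optimizer = SGD(model.parameters(), lr=0.1)
+            >>> optim_wrapper = ApexOptimWrapper(
+            ...     optimizer=optimizer, opt_level='O1',
+            ...     loss_scale='dynamic')
+            >>> with optim_wrapper.optim_context(model):
+            ...     loss = model(torch.randn(1, 1, 1, 1).cuda())
+            ...     optim_wrapper.update_params(loss)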
+
+    Note:
+        `New in version 0.6.0.`
+    """  # noqa: E501
+
+    def __init__(self,
+                 opt_level: str = 'O1',
+                 loss_scale: Union[float, str, None] = 'dynamic',
+                 enabled: Optional[bool] = True,
+                 cast_model_type: Optional[torch.dtype] = None,
+                 patch_torch_functions: Optional[bool] = None,
+                 keep_batchnorm_fp32: Union[bool, str, None] = None,
+                 master_weights: Optional[bool] = None,
+                 cast_model_outputs: Optional[torch.dtype] = None,
+                 num_losses: int = 1,
+                 verbosity: int = 1,
+                 min_loss_scale: Optional[float] = None,
+                 max_loss_scale: Optional[float] = 2.**24,
+                 **kwargs):
+        assert apex_amp is not None, \
+            'Apex is not installed. Please check ' \
+            'https://github.com/NVIDIA/apex#linux.'
+        super().__init__(**kwargs)
+        self.opt_level = opt_level
+        self.loss_scale = loss_scale
+        self.enabled = enabled
+        self.cast_model_type = cast_model_type
+        self.patch_torch_functions = patch_torch_functions
+        self.keep_batchnorm_fp32 = keep_batchnorm_fp32
+        self.master_weights = master_weights
+        self.cast_model_outputs = cast_model_outputs
+        self.num_losses = num_losses
+        self.verbosity = verbosity
+        self.min_loss_scale = min_loss_scale
+        self.max_loss_scale = max_loss_scale
+        self._apex_amp_state_dict = None
+
+    def backward(self, loss: torch.Tensor, **kwargs) -> None:
+        """Perform gradient back propagation with ``apex_amp.scale_loss``.
+
+        Args:
+            loss (torch.Tensor): The loss of current iteration.
+            kwargs: Keyword arguments passed to :meth:`torch.Tensor.backward`.
+        """
+        with apex_amp.scale_loss(loss, self.optimizer) as scaled_loss:
+            scaled_loss.backward(**kwargs)
+        self._inner_count += 1
+
+    def state_dict(self) -> dict:
+        """Get the state dictionary of :attr:`optimizer` and
+        :attr:`apex_amp`.
+
+        Based on the state dictionary of the optimizer, the returned state
+        dictionary will add a key named "apex_amp".
+
+        Returns:
+            dict: The merged state dict of :attr:`apex_amp` and
+            :attr:`optimizer`.
+        """
+        state_dict = self.optimizer.state_dict()
+        state_dict['apex_amp'] = apex_amp.state_dict()
+        return state_dict
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        """Load and parse the state dictionary of :attr:`optimizer` and
+        :attr:`apex_amp`.
+
+        If state_dict contains the key "apex_amp", :attr:`apex_amp` will
+        load the corresponding keys. Otherwise, only the :attr:`optimizer`
+        will load the state dictionary.
+
+        Note:
+            :meth:`load_state_dict` should be called after
+            `apex_amp.initialize` is called.
+
+        Args:
+            state_dict (dict): The state dict of :attr:`optimizer` and
+                :attr:`apex_amp`.
+        """
+        if 'apex_amp' in state_dict:
+            # when `apex_amp` is not initialized, calling `load_state_dict`
+            # will raise an error, so we temporarily cache the apex_amp
+            # part, and then load it into `apex_amp` after completing
+            # the `apex_amp` initialization in the `optim_context` method
+            if hasattr(self.optimizer, '_amp_stash'):
+                apex_amp.load_state_dict(state_dict.pop('apex_amp'))
+            else:
+                self._apex_amp_state_dict = state_dict.pop('apex_amp')
+        self.optimizer.load_state_dict(state_dict)
+
+    @contextmanager
+    def optim_context(self, model: nn.Module):
+        """Enables the context for mixed precision training, and for
+        disabling gradient synchronization during gradient accumulation.
+
+        Args:
+            model (nn.Module): The training model.
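+
+        Examples:
+            A minimal sketch of intended use (``optim_wrapper``, ``model``
+            and ``inputs`` are assumed to be created elsewhere, as in the
+            class-level example). ``apex_amp.initialize`` is called lazily
+            on the first entry, so the forward pass should run inside this
+            context::
+
+                >>> with optim_wrapper.optim_context(model):
+                ...     loss = model(inputs)
+                ...     optim_wrapper.update_params(loss)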
+ """ + with super().optim_context(model): + # when a given optimizer be passed through apex_amp.initialize, + # the "_amp_stash" property will be added + if not hasattr(self.optimizer, '_amp_stash'): + if mmengine.model.wrappers.is_model_wrapper(model): + model = model.module + model, self.optimizer = apex_amp.initialize( + model, + self.optimizer, + opt_level=self.opt_level, + loss_scale=self.loss_scale, + enabled=self.enabled, + cast_model_type=self.cast_model_type, + patch_torch_functions=self.patch_torch_functions, + keep_batchnorm_fp32=self.keep_batchnorm_fp32, + master_weights=self.master_weights, + cast_model_outputs=self.cast_model_outputs, + num_losses=self.num_losses, + verbosity=self.verbosity, + min_loss_scale=self.min_loss_scale, + max_loss_scale=self.max_loss_scale) + # loading apex_amp state_dict after initialization of apex_amp + if self._apex_amp_state_dict is not None: + apex_amp.load_state_dict(self._apex_amp_state_dict) + self._apex_amp_state_dict = None + yield diff --git a/tests/test_optim/test_optimizer/test_optimizer_wrapper.py b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py index 35984ce3..d00033f2 100644 --- a/tests/test_optim/test_optimizer/test_optimizer_wrapper.py +++ b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py @@ -14,12 +14,19 @@ from torch.optim import SGD, Adam, Optimizer from mmengine.dist import all_gather from mmengine.logging import MessageHub, MMLogger -from mmengine.optim import AmpOptimWrapper, OptimWrapper +from mmengine.optim import AmpOptimWrapper, ApexOptimWrapper, OptimWrapper from mmengine.testing import assert_allclose from mmengine.testing._internal import MultiProcessTestCase from mmengine.utils import digit_version from mmengine.utils.dl_utils import TORCH_VERSION +is_apex_available = False +try: + import apex.amp as apex_amp + is_apex_available = True +except ImportError: + pass + class ToyModel(nn.Module): @@ -283,6 +290,101 @@ class TestOptimWrapper(MultiProcessTestCase): optim_wrapper.zero_grad = MagicMock() +@unittest.skipIf(not torch.cuda.is_available(), reason='need gpu to test Apex') +class TestApexOptimWrapper(TestCase): + + def setUp(self) -> None: + self.model = ToyModel().cuda() + self.optimizer = SGD(self.model.parameters(), lr=0.1) + + @unittest.skipIf( + not is_apex_available, + reason='`apex` is not available, Please install apex from ' + 'https://www.github.com/nvidia/apex') + def test_init(self): + apex_optim_wrapper = ApexOptimWrapper( + optimizer=self.optimizer, opt_level='O1', loss_scale=1) + with apex_optim_wrapper.optim_context(self.model): + pass + + @unittest.skipIf( + not is_apex_available, + reason='`apex` is not available, Please install apex from ' + 'https://www.github.com/nvidia/apex') + def test_step(self): + optimizer = MagicMock(spec=Optimizer) + apex_optim_wrapper = ApexOptimWrapper( + optimizer=optimizer, opt_level='O1', loss_scale=1) + with apex_optim_wrapper.optim_context(self.model): + loss = self.model(torch.Tensor(1, 1, 1, 1).cuda()) + apex_optim_wrapper.backward(loss) + apex_optim_wrapper.step() + + @unittest.skipIf( + not is_apex_available, + reason='`apex` is not available, Please install apex from ' + 'https://www.github.com/nvidia/apex') + def test_backward(self): + apex_optim_wrapper = ApexOptimWrapper( + optimizer=self.optimizer, opt_level='O1', loss_scale=1) + with apex_optim_wrapper.optim_context(self.model): + loss = self.model(torch.Tensor(1, 1, 1, 1).cuda()) + apex_optim_wrapper.backward(loss) + + @unittest.skipIf( + not is_apex_available, + reason='`apex` is 
not available, Please install apex from ' + 'https://www.github.com/nvidia/apex') + def test_state_dict(self): + apex_optim_wrapper = ApexOptimWrapper( + optimizer=self.optimizer, opt_level='O1', loss_scale=1) + with apex_optim_wrapper.optim_context(self.model): + loss = self.model(torch.Tensor(1, 1, 1, 1).cuda()) + apex_optim_wrapper.update_params(loss) + state_dict = apex_optim_wrapper.state_dict() + amp_state_dict = state_dict.pop('apex_amp') + optim_state_dict = state_dict + + self.assertDictEqual(optim_state_dict, + apex_optim_wrapper.optimizer.state_dict()) + self.assertDictEqual(amp_state_dict, apex_amp.state_dict()) + + @unittest.skipIf( + not is_apex_available, + reason='`apex` is not available, Please install apex from ' + 'https://www.github.com/nvidia/apex') + def test_load_state_dict(self): + apex_optim_wrapper = ApexOptimWrapper( + optimizer=self.optimizer, opt_level='O1', loss_scale=1) + with apex_optim_wrapper.optim_context(self.model): + # Test load from optimizer + optimizer = SGD(self.model.parameters(), lr=0.1) + apex_optim_wrapper.load_state_dict(optimizer.state_dict()) + + self.assertDictEqual(optimizer.state_dict(), + apex_optim_wrapper.optimizer.state_dict()) + # Test load from optim_wrapper + apex_optim_wrapper = ApexOptimWrapper(optimizer=self.optimizer) + apex_optim_wrapper_ = ApexOptimWrapper( + optimizer=SGD(self.model.parameters(), lr=0.1)) + apex_optim_wrapper_.load_state_dict( + apex_optim_wrapper.state_dict()) + self.assertDictEqual(apex_optim_wrapper.optimizer.state_dict(), + apex_optim_wrapper_.optimizer.state_dict()) + + @unittest.skipIf( + not is_apex_available, + reason='`apex` is not available, Please install apex from ' + 'https://www.github.com/nvidia/apex') + def test_optim_context(self): + apex_optim_wrapper = ApexOptimWrapper( + optimizer=self.optimizer, opt_level='O1', loss_scale=1) + with apex_optim_wrapper.optim_context(self.model): + x = torch.randn(1, 1, 1, 1).cuda() + y = nn.Conv2d(1, 1, 1).cuda()(x) + self.assertEqual(y.dtype, torch.float16) + + class TestAmpOptimWrapper(TestCase): def setUp(self) -> None: