Mirror of https://github.com/open-mmlab/mmengine.git, synced 2025-06-03 21:54:44 +08:00
[Feature] Add ApexOptimWrapper (#742)
* add ApexOptimWrapper
* typo fix
* add apex amp.initialize in optim_context
* assert apex_amp
* polish code
* add parameters of apex_amp.initialize
* add docs
* polish code
* polish code
* polish code
* fix calling of apex amp load_state_dict
* polish
* add comments
* Update apex_optimizer_wrapper.py
* Update apex_optimizer_wrapper.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent bc49e0c0a1
commit e35ed5fd2e
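For orientation, a minimal config-style sketch of how the wrapper added in this commit is intended to be selected, following the same registry pattern as AmpOptimWrapper. The optimizer choice and hyperparameters below are placeholders for illustration, not part of the diff:

# Hypothetical training-config sketch; optimizer settings are placeholders.
optim_wrapper = dict(
    type='ApexOptimWrapper',  # resolved through the OPTIM_WRAPPERS registry
    optimizer=dict(type='SGD', lr=0.01, momentum=0.9),
    opt_level='O1',           # apex.amp mixed precision level
    loss_scale='dynamic')     # or a fixed float such as 512.0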
@@ -20,6 +20,7 @@ Optimizer
    :template: classtemplate.rst
 
    AmpOptimWrapper
+   ApexOptimWrapper
    OptimWrapper
    OptimWrapperDict
    DefaultOptimWrapperConstructor
@@ -20,6 +20,7 @@ Optimizer
    :template: classtemplate.rst
 
    AmpOptimWrapper
+   ApexOptimWrapper
    OptimWrapper
    OptimWrapperDict
    DefaultOptimWrapperConstructor
@@ -1,7 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .optimizer import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
-                        AmpOptimWrapper, DefaultOptimWrapperConstructor,
-                        OptimWrapper, OptimWrapperDict, build_optim_wrapper)
+                        AmpOptimWrapper, ApexOptimWrapper,
+                        DefaultOptimWrapperConstructor, OptimWrapper,
+                        OptimWrapperDict, build_optim_wrapper)
 # yapf: disable
 from .scheduler import (ConstantLR, ConstantMomentum, ConstantParamScheduler,
                         CosineAnnealingLR, CosineAnnealingMomentum,
@@ -25,8 +26,8 @@ __all__ = [
     'MultiStepMomentum', 'StepMomentum', 'ConstantParamScheduler',
     'CosineAnnealingParamScheduler', 'ExponentialParamScheduler',
     'LinearParamScheduler', 'MultiStepParamScheduler', 'StepParamScheduler',
-    '_ParamScheduler', 'OptimWrapper', 'AmpOptimWrapper', 'OptimWrapperDict',
-    'OneCycleParamScheduler', 'OneCycleLR', 'PolyLR', 'PolyMomentum',
-    'PolyParamScheduler', 'ReduceOnPlateauLR', 'ReduceOnPlateauMomentum',
-    'ReduceOnPlateauParamScheduler'
+    '_ParamScheduler', 'OptimWrapper', 'AmpOptimWrapper', 'ApexOptimWrapper',
+    'OptimWrapperDict', 'OneCycleParamScheduler', 'OneCycleLR', 'PolyLR',
+    'PolyMomentum', 'PolyParamScheduler', 'ReduceOnPlateauLR',
+    'ReduceOnPlateauMomentum', 'ReduceOnPlateauParamScheduler'
 ]
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .amp_optimizer_wrapper import AmpOptimWrapper
+from .apex_optimizer_wrapper import ApexOptimWrapper
 from .builder import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
                       build_optim_wrapper)
 from .default_constructor import DefaultOptimWrapperConstructor
@@ -10,5 +11,6 @@ from .zero_optimizer import ZeroRedundancyOptimizer
 __all__ = [
     'OPTIM_WRAPPER_CONSTRUCTORS', 'OPTIMIZERS',
     'DefaultOptimWrapperConstructor', 'build_optim_wrapper', 'OptimWrapper',
-    'AmpOptimWrapper', 'OptimWrapperDict', 'ZeroRedundancyOptimizer'
+    'AmpOptimWrapper', 'ApexOptimWrapper', 'OptimWrapperDict',
+    'ZeroRedundancyOptimizer'
 ]
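With ApexOptimWrapper exported above, it can be imported directly from mmengine.optim. A minimal construction sketch, assuming apex and a CUDA device are available; the model and learning rate are placeholders:

import torch.nn as nn
from torch.optim import SGD
from mmengine.optim import ApexOptimWrapper

# any module works; apex mixed precision requires a CUDA device
model = nn.Conv2d(3, 8, 3).cuda()
optim_wrapper = ApexOptimWrapper(
    optimizer=SGD(model.parameters(), lr=0.01),
    opt_level='O1',
    loss_scale='dynamic')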
mmengine/optim/optimizer/apex_optimizer_wrapper.py (new file, 200 lines)
# Copyright (c) OpenMMLab. All rights reserved.
from contextlib import contextmanager
from typing import Optional, Union

import torch
import torch.nn as nn

# a circular import will be caused by
# from mmengine.model.wrappers import is_model_wrapper
import mmengine
from mmengine.registry import OPTIM_WRAPPERS
from .optimizer_wrapper import OptimWrapper

try:
    import apex.amp as apex_amp
except ImportError:
    apex_amp = None


@OPTIM_WRAPPERS.register_module()
class ApexOptimWrapper(OptimWrapper):
    """A subclass of :class:`OptimWrapper` that supports automatic mixed
    precision training based on apex.amp.

    ``ApexOptimWrapper`` provides a unified interface with
    ``OptimWrapper``, so it can be used in the same way as ``OptimWrapper``.

    Warning:
        ``ApexOptimWrapper`` requires `nvidia apex <https://github.com/NVIDIA/apex>`_ to be installed.

    Args:
        opt_level (str): Pure or mixed precision optimization level. Accepted
            values are "O0", "O1", "O2", and "O3". Defaults to "O1".
        loss_scale (float or str, optional): If passed as a string, must be a
            string representing a number, e.g., "128.0", or the string
            "dynamic". Defaults to "dynamic".
        enabled (bool): If False, renders all Amp calls no-ops, so your script
            should run as if Amp were not present. Defaults to True.
        cast_model_type (torch.dtype, optional): Casts the model's parameters
            and buffers to the desired type. Defaults to None.
        patch_torch_functions (bool, optional): Patch all Torch functions
            and Tensor methods to perform Tensor Core-friendly ops like GEMMs
            and convolutions in FP16, and any ops that benefit from FP32
            precision in FP32. Defaults to None.
        keep_batchnorm_fp32 (bool or str, optional): To enhance precision
            and enable cudnn batchnorm (which improves performance),
            it's often beneficial to keep batchnorm weights in FP32
            even if the rest of the model is FP16.
            If passed as a string, must be the string "True" or "False".
            Defaults to None.
        master_weights (bool, optional): Maintain FP32 master weights to
            accompany any FP16 model weights. FP32 master weights are stepped
            by the optimizer to enhance precision and capture small gradients.
            Defaults to None.
        cast_model_outputs (torch.dtype, optional): Option to ensure that
            the outputs of your model(s) are always cast to a particular type
            regardless of ``opt_level``. Defaults to None.
        num_losses (int): Option to tell Amp in advance how many
            losses/backward passes you plan to use. Defaults to 1.
        verbosity (int): Set to 0 to suppress Amp-related output.
            Defaults to 1.
        min_loss_scale (float, optional): Sets a floor for the loss scale
            values that can be chosen by dynamic loss scaling.
            The default value of None means that no floor is imposed.
            If dynamic loss scaling is not used, `min_loss_scale` is ignored.
            Defaults to None.
        max_loss_scale (float, optional): Sets a ceiling for the loss scale
            values that can be chosen by dynamic loss scaling. If dynamic
            loss scaling is not used, `max_loss_scale` is ignored.
            Defaults to 2.**24.
        **kwargs: Keyword arguments passed to OptimWrapper.

    Note:
        If you use ``IterBasedRunner`` and enable gradient accumulation,
        the original `max_iters` should be multiplied by
        ``accumulative_counts``.

    Note:
        `New in version 0.6.0.`
    """  # noqa: E501

    def __init__(self,
                 opt_level: str = 'O1',
                 loss_scale: Union[float, str, None] = 'dynamic',
                 enabled: Optional[bool] = True,
                 cast_model_type: Optional[torch.dtype] = None,
                 patch_torch_functions: Optional[bool] = None,
                 keep_batchnorm_fp32: Union[bool, str, None] = None,
                 master_weights: Optional[bool] = None,
                 cast_model_outputs: Optional[torch.dtype] = None,
                 num_losses: int = 1,
                 verbosity: int = 1,
                 min_loss_scale: Optional[float] = None,
                 max_loss_scale: Optional[float] = 2.**24,
                 **kwargs):
        assert apex_amp is not None, \
            'Apex is not installed. Please check ' \
            'https://github.com/NVIDIA/apex#linux.'
        super().__init__(**kwargs)
        self.opt_level = opt_level
        self.loss_scale = loss_scale
        self.enabled = enabled
        self.cast_model_type = cast_model_type
        self.patch_torch_functions = patch_torch_functions
        self.keep_batchnorm_fp32 = keep_batchnorm_fp32
        self.master_weights = master_weights
        self.cast_model_outputs = cast_model_outputs
        self.num_losses = num_losses
        self.verbosity = verbosity
        self.min_loss_scale = min_loss_scale
        self.max_loss_scale = max_loss_scale
        self._apex_amp_state_dict = None

    def backward(self, loss: torch.Tensor, **kwargs) -> None:
        """Perform gradient back propagation with :attr:`loss_scaler`.

        Args:
            loss (torch.Tensor): The loss of the current iteration.
            kwargs: Keyword arguments passed to :meth:`torch.Tensor.backward`.
        """
        with apex_amp.scale_loss(loss, self.optimizer) as scaled_loss:
            scaled_loss.backward(**kwargs)
        self._inner_count += 1

    def state_dict(self) -> dict:
        """Get the state dictionary of :attr:`optimizer` and
        :attr:`apex_amp`.

        Based on the state dictionary of the optimizer, the returned state
        dictionary will add a key named "apex_amp".

        Returns:
            dict: The merged state dict of :attr:`apex_amp` and
            :attr:`optimizer`.
        """
        state_dict = self.optimizer.state_dict()
        state_dict['apex_amp'] = apex_amp.state_dict()
        return state_dict

    def load_state_dict(self, state_dict: dict) -> None:
        """Load and parse the state dictionary of :attr:`optimizer` and
        :attr:`apex_amp`.

        If state_dict contains "apex_amp", :attr:`apex_amp` will load the
        corresponding keys. Otherwise, only the :attr:`optimizer` will load
        the state dictionary.

        Note:
            :meth:`load_state_dict` should be called after
            `apex_amp.initialize` is called.

        Args:
            state_dict (dict): The state dict of :attr:`optimizer` and
                :attr:`apex_amp`.
        """
        if 'apex_amp' in state_dict:
            # when `apex_amp` is not initialized, calling `load_state_dict`
            # will raise an error, so we temporarily cache the apex_amp
            # part, and then load it into `apex_amp` after completing
            # the `apex_amp` initialization in the `optim_context` method
            if hasattr(self.optimizer, '_amp_stash'):
                apex_amp.load_state_dict(state_dict.pop('apex_amp'))
            else:
                self._apex_amp_state_dict = state_dict.pop('apex_amp')
        self.optimizer.load_state_dict(state_dict)

    @contextmanager
    def optim_context(self, model: nn.Module):
        """Enables the context for mixed precision training, and enables the
        context for disabling gradient synchronization during gradient
        accumulation context.

        Args:
            model (nn.Module): The training model.
        """
        with super().optim_context(model):
            # when a given optimizer is passed through apex_amp.initialize,
            # the "_amp_stash" attribute will be added
            if not hasattr(self.optimizer, '_amp_stash'):
                if mmengine.model.wrappers.is_model_wrapper(model):
                    model = model.module
                model, self.optimizer = apex_amp.initialize(
                    model,
                    self.optimizer,
                    opt_level=self.opt_level,
                    loss_scale=self.loss_scale,
                    enabled=self.enabled,
                    cast_model_type=self.cast_model_type,
                    patch_torch_functions=self.patch_torch_functions,
                    keep_batchnorm_fp32=self.keep_batchnorm_fp32,
                    master_weights=self.master_weights,
                    cast_model_outputs=self.cast_model_outputs,
                    num_losses=self.num_losses,
                    verbosity=self.verbosity,
                    min_loss_scale=self.min_loss_scale,
                    max_loss_scale=self.max_loss_scale)
                # load the cached apex_amp state dict after apex_amp
                # has been initialized
                if self._apex_amp_state_dict is not None:
                    apex_amp.load_state_dict(self._apex_amp_state_dict)
                    self._apex_amp_state_dict = None
            yield
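To tie the pieces of the new file together, here is a minimal manual-loop sketch, mirroring the tests below; apex and a CUDA device are assumed, and the model, data, and learning rate are placeholders. apex_amp.initialize runs lazily inside optim_context on the first iteration, and update_params performs the scaled backward pass, the optimizer step, and zero_grad.

import torch
import torch.nn as nn
from torch.optim import SGD
from mmengine.optim import ApexOptimWrapper

model = nn.Conv2d(1, 1, 1).cuda()
optim_wrapper = ApexOptimWrapper(
    optimizer=SGD(model.parameters(), lr=0.1), opt_level='O1')

for _ in range(2):
    # the first pass through optim_context triggers apex_amp.initialize,
    # which patches the model and optimizer in place
    with optim_wrapper.optim_context(model):
        loss = model(torch.randn(1, 1, 1, 1).cuda()).sum()
    # scaled backward (ApexOptimWrapper.backward), then step and zero_grad
    optim_wrapper.update_params(loss)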
@@ -14,12 +14,19 @@ from torch.optim import SGD, Adam, Optimizer
 
 from mmengine.dist import all_gather
 from mmengine.logging import MessageHub, MMLogger
-from mmengine.optim import AmpOptimWrapper, OptimWrapper
+from mmengine.optim import AmpOptimWrapper, ApexOptimWrapper, OptimWrapper
 from mmengine.testing import assert_allclose
 from mmengine.testing._internal import MultiProcessTestCase
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION
 
+is_apex_available = False
+try:
+    import apex.amp as apex_amp
+    is_apex_available = True
+except ImportError:
+    pass
+
 
 class ToyModel(nn.Module):
 
@@ -283,6 +290,101 @@ class TestOptimWrapper(MultiProcessTestCase):
         optim_wrapper.zero_grad = MagicMock()
 
 
+@unittest.skipIf(not torch.cuda.is_available(), reason='need gpu to test Apex')
+class TestApexOptimWrapper(TestCase):
+
+    def setUp(self) -> None:
+        self.model = ToyModel().cuda()
+        self.optimizer = SGD(self.model.parameters(), lr=0.1)
+
+    @unittest.skipIf(
+        not is_apex_available,
+        reason='`apex` is not available, Please install apex from '
+        'https://www.github.com/nvidia/apex')
+    def test_init(self):
+        apex_optim_wrapper = ApexOptimWrapper(
+            optimizer=self.optimizer, opt_level='O1', loss_scale=1)
+        with apex_optim_wrapper.optim_context(self.model):
+            pass
+
+    @unittest.skipIf(
+        not is_apex_available,
+        reason='`apex` is not available, Please install apex from '
+        'https://www.github.com/nvidia/apex')
+    def test_step(self):
+        optimizer = MagicMock(spec=Optimizer)
+        apex_optim_wrapper = ApexOptimWrapper(
+            optimizer=optimizer, opt_level='O1', loss_scale=1)
+        with apex_optim_wrapper.optim_context(self.model):
+            loss = self.model(torch.Tensor(1, 1, 1, 1).cuda())
+            apex_optim_wrapper.backward(loss)
+            apex_optim_wrapper.step()
+
+    @unittest.skipIf(
+        not is_apex_available,
+        reason='`apex` is not available, Please install apex from '
+        'https://www.github.com/nvidia/apex')
+    def test_backward(self):
+        apex_optim_wrapper = ApexOptimWrapper(
+            optimizer=self.optimizer, opt_level='O1', loss_scale=1)
+        with apex_optim_wrapper.optim_context(self.model):
+            loss = self.model(torch.Tensor(1, 1, 1, 1).cuda())
+            apex_optim_wrapper.backward(loss)
+
+    @unittest.skipIf(
+        not is_apex_available,
+        reason='`apex` is not available, Please install apex from '
+        'https://www.github.com/nvidia/apex')
+    def test_state_dict(self):
+        apex_optim_wrapper = ApexOptimWrapper(
+            optimizer=self.optimizer, opt_level='O1', loss_scale=1)
+        with apex_optim_wrapper.optim_context(self.model):
+            loss = self.model(torch.Tensor(1, 1, 1, 1).cuda())
+            apex_optim_wrapper.update_params(loss)
+            state_dict = apex_optim_wrapper.state_dict()
+            amp_state_dict = state_dict.pop('apex_amp')
+            optim_state_dict = state_dict
+
+            self.assertDictEqual(optim_state_dict,
+                                 apex_optim_wrapper.optimizer.state_dict())
+            self.assertDictEqual(amp_state_dict, apex_amp.state_dict())
+
+    @unittest.skipIf(
+        not is_apex_available,
+        reason='`apex` is not available, Please install apex from '
+        'https://www.github.com/nvidia/apex')
+    def test_load_state_dict(self):
+        apex_optim_wrapper = ApexOptimWrapper(
+            optimizer=self.optimizer, opt_level='O1', loss_scale=1)
+        with apex_optim_wrapper.optim_context(self.model):
+            # Test load from optimizer
+            optimizer = SGD(self.model.parameters(), lr=0.1)
+            apex_optim_wrapper.load_state_dict(optimizer.state_dict())
+
+            self.assertDictEqual(optimizer.state_dict(),
+                                 apex_optim_wrapper.optimizer.state_dict())
+            # Test load from optim_wrapper
+            apex_optim_wrapper = ApexOptimWrapper(optimizer=self.optimizer)
+            apex_optim_wrapper_ = ApexOptimWrapper(
+                optimizer=SGD(self.model.parameters(), lr=0.1))
+            apex_optim_wrapper_.load_state_dict(
+                apex_optim_wrapper.state_dict())
+            self.assertDictEqual(apex_optim_wrapper.optimizer.state_dict(),
+                                 apex_optim_wrapper_.optimizer.state_dict())
+
+    @unittest.skipIf(
+        not is_apex_available,
+        reason='`apex` is not available, Please install apex from '
+        'https://www.github.com/nvidia/apex')
+    def test_optim_context(self):
+        apex_optim_wrapper = ApexOptimWrapper(
+            optimizer=self.optimizer, opt_level='O1', loss_scale=1)
+        with apex_optim_wrapper.optim_context(self.model):
+            x = torch.randn(1, 1, 1, 1).cuda()
+            y = nn.Conv2d(1, 1, 1).cuda()(x)
+            self.assertEqual(y.dtype, torch.float16)
+
+
 class TestAmpOptimWrapper(TestCase):
 
     def setUp(self) -> None: