[Fix] Fix optim_wrapper unittest for pytorch <= 1.10.0 (#975)
parent
2ed8e343a0
commit
b1b1f53db2
|
@ -17,6 +17,8 @@ from mmengine.logging import MessageHub, MMLogger
|
|||
from mmengine.optim import AmpOptimWrapper, ApexOptimWrapper, OptimWrapper
|
||||
from mmengine.testing import assert_allclose
|
||||
from mmengine.testing._internal import MultiProcessTestCase
|
||||
from mmengine.utils.dl_utils import TORCH_VERSION
|
||||
from mmengine.utils.version_utils import digit_version
|
||||
|
||||
is_apex_available = False
|
||||
try:
|
||||
|
@ -438,6 +440,10 @@ class TestAmpOptimWrapper(TestCase):
|
|||
not torch.cuda.is_available(),
|
||||
reason='`torch.cuda.amp` is only available when pytorch-gpu installed')
|
||||
def test_step(self, dtype):
|
||||
if dtype is not None and (digit_version(TORCH_VERSION) <
|
||||
digit_version('1.10.0')):
|
||||
raise unittest.SkipTest('Require PyTorch version >= 1.10.0 to '
|
||||
'support `dtype` argument in autocast')
|
||||
if dtype == 'bfloat16' and not bf16_supported():
|
||||
raise unittest.SkipTest('bfloat16 not supported by device')
|
||||
optimizer = MagicMock(spec=Optimizer)
|
||||
|
@ -454,6 +460,10 @@ class TestAmpOptimWrapper(TestCase):
|
|||
not torch.cuda.is_available(),
|
||||
reason='`torch.cuda.amp` is only available when pytorch-gpu installed')
|
||||
def test_backward(self, dtype):
|
||||
if dtype is not None and (digit_version(TORCH_VERSION) <
|
||||
digit_version('1.10.0')):
|
||||
raise unittest.SkipTest('Require PyTorch version >= 1.10.0 to '
|
||||
'support `dtype` argument in autocast')
|
||||
if dtype == 'bfloat16' and not bf16_supported():
|
||||
raise unittest.SkipTest('bfloat16 not supported by device')
|
||||
amp_optim_wrapper = AmpOptimWrapper(
|
||||
|
@ -512,6 +522,10 @@ class TestAmpOptimWrapper(TestCase):
|
|||
not torch.cuda.is_available(),
|
||||
reason='`torch.cuda.amp` is only available when pytorch-gpu installed')
|
||||
def test_optim_context(self, dtype, target_dtype):
|
||||
if dtype is not None and (digit_version(TORCH_VERSION) <
|
||||
digit_version('1.10.0')):
|
||||
raise unittest.SkipTest('Require PyTorch version >= 1.10.0 to '
|
||||
'support `dtype` argument in autocast')
|
||||
if dtype == 'bfloat16' and not bf16_supported():
|
||||
raise unittest.SkipTest('bfloat16 not supported by device')
|
||||
amp_optim_wrapper = AmpOptimWrapper(
|
||||
|
|
Loading…
Reference in New Issue