[Fix] Fix optim_wrapper unittest for pytorch <= 1.10.0 (#975)

Qian Zhao 2023-03-02 14:14:23 +08:00 committed by GitHub
parent 2ed8e343a0
commit b1b1f53db2
1 changed file with 14 additions and 0 deletions

@@ -17,6 +17,8 @@ from mmengine.logging import MessageHub, MMLogger
 from mmengine.optim import AmpOptimWrapper, ApexOptimWrapper, OptimWrapper
 from mmengine.testing import assert_allclose
 from mmengine.testing._internal import MultiProcessTestCase
+from mmengine.utils.dl_utils import TORCH_VERSION
+from mmengine.utils.version_utils import digit_version
 
 is_apex_available = False
 try:
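
The two new imports power the version gate added in the hunks below: digit_version parses a version string into a tuple of integers, so versions compare correctly under ordinary tuple ordering. A minimal sketch of the check the tests now perform (standalone, outside the test class):

from mmengine.utils.dl_utils import TORCH_VERSION
from mmengine.utils.version_utils import digit_version

# Tuple comparison orders versions correctly: '1.9.0' < '1.10.0',
# whereas plain string comparison would sort '1.9' after '1.10'.
too_old = digit_version(TORCH_VERSION) < digit_version('1.10.0')
print('autocast `dtype` argument supported:', not too_old)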
@@ -438,6 +440,10 @@ class TestAmpOptimWrapper(TestCase):
         not torch.cuda.is_available(),
         reason='`torch.cuda.amp` is only available when pytorch-gpu installed')
     def test_step(self, dtype):
+        if dtype is not None and (digit_version(TORCH_VERSION) <
+                                  digit_version('1.10.0')):
+            raise unittest.SkipTest('Require PyTorch version >= 1.10.0 to '
+                                    'support `dtype` argument in autocast')
         if dtype == 'bfloat16' and not bf16_supported():
             raise unittest.SkipTest('bfloat16 not supported by device')
         optimizer = MagicMock(spec=Optimizer)
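
The new guard is needed because the dtype argument of torch.cuda.amp.autocast only exists from PyTorch 1.10.0 onward; on older builds the parameterized float16/bfloat16 cases would fail with a TypeError instead of exercising the wrapper. A simplified sketch of the call that breaks on old versions (illustrative, not the actual AmpOptimWrapper internals):

import torch

if torch.cuda.is_available():
    # On torch >= 1.10 this pins the autocast compute dtype; on older
    # versions the `dtype` keyword does not exist and the call raises
    # TypeError, which is exactly what the SkipTest above avoids.
    with torch.cuda.amp.autocast(dtype=torch.float16):
        x = torch.randn(2, 2, device='cuda')
        y = x @ x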
@@ -454,6 +460,10 @@ class TestAmpOptimWrapper(TestCase):
         not torch.cuda.is_available(),
         reason='`torch.cuda.amp` is only available when pytorch-gpu installed')
     def test_backward(self, dtype):
+        if dtype is not None and (digit_version(TORCH_VERSION) <
+                                  digit_version('1.10.0')):
+            raise unittest.SkipTest('Require PyTorch version >= 1.10.0 to '
+                                    'support `dtype` argument in autocast')
         if dtype == 'bfloat16' and not bf16_supported():
             raise unittest.SkipTest('bfloat16 not supported by device')
         amp_optim_wrapper = AmpOptimWrapper(
@@ -512,6 +522,10 @@ class TestAmpOptimWrapper(TestCase):
         not torch.cuda.is_available(),
         reason='`torch.cuda.amp` is only available when pytorch-gpu installed')
     def test_optim_context(self, dtype, target_dtype):
+        if dtype is not None and (digit_version(TORCH_VERSION) <
+                                  digit_version('1.10.0')):
+            raise unittest.SkipTest('Require PyTorch version >= 1.10.0 to '
+                                    'support `dtype` argument in autocast')
         if dtype == 'bfloat16' and not bf16_supported():
             raise unittest.SkipTest('bfloat16 not supported by device')
         amp_optim_wrapper = AmpOptimWrapper(
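
The same four-line guard now appears in test_step, test_backward, and test_optim_context. Because the skip depends on the parameterized dtype, it cannot be a plain skipIf decorator, but the repeated body could be factored into a small helper. A hypothetical sketch (require_torch is an illustrative name, not part of mmengine):

import unittest

from mmengine.utils.dl_utils import TORCH_VERSION
from mmengine.utils.version_utils import digit_version


def require_torch(min_version, reason):
    """Raise unittest.SkipTest if the installed PyTorch is older than
    `min_version`. Called inside a test body, so the check can depend
    on parameterized arguments such as `dtype`."""
    if digit_version(TORCH_VERSION) < digit_version(min_version):
        raise unittest.SkipTest(reason)

Each test body would then begin with a one-liner such as: if dtype is not None: require_torch('1.10.0', 'dtype argument in autocast').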