[Enhance] Add unit tests for autocast with Ascend device (#1363)

This commit is contained in:
6V 2023-09-27 10:20:13 +08:00 committed by GitHub
parent 88dc1e98b1
commit e9e08dbb65
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -5,7 +5,7 @@ import torch
import torch.nn as nn
import mmengine
from mmengine.device import get_device, is_mlu_available
from mmengine.device import get_device, is_mlu_available, is_npu_available
from mmengine.runner import autocast
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION
@ -14,7 +14,22 @@ from mmengine.utils.dl_utils import TORCH_VERSION
class TestAmp(unittest.TestCase):
def test_autocast(self):
if is_mlu_available():
if is_npu_available():
device = 'npu'
with autocast(device_type=device):
# torch.autocast support npu mode.
layer = nn.Conv2d(1, 1, 1).to(device)
res = layer(torch.randn(1, 1, 1, 1).to(device))
self.assertIn(res.dtype, (torch.bfloat16, torch.float16))
with autocast(enabled=False, device_type=device):
res = layer(torch.randn(1, 1, 1, 1).to(device))
self.assertEqual(res.dtype, torch.float32)
# Test with fp32_enabled
with autocast(enabled=False, device_type=device):
layer = nn.Conv2d(1, 1, 1).to(device)
res = layer(torch.randn(1, 1, 1, 1).to(device))
self.assertEqual(res.dtype, torch.float32)
elif is_mlu_available():
device = 'mlu'
with autocast(device_type=device):
# torch.autocast support mlu mode.