mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Enhance] Add unit tests for autocast with Ascend device (#1363)
This commit is contained in:
parent
88dc1e98b1
commit
e9e08dbb65
@ -5,7 +5,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import mmengine
|
||||
from mmengine.device import get_device, is_mlu_available
|
||||
from mmengine.device import get_device, is_mlu_available, is_npu_available
|
||||
from mmengine.runner import autocast
|
||||
from mmengine.utils import digit_version
|
||||
from mmengine.utils.dl_utils import TORCH_VERSION
|
||||
@ -14,7 +14,22 @@ from mmengine.utils.dl_utils import TORCH_VERSION
|
||||
class TestAmp(unittest.TestCase):
|
||||
|
||||
def test_autocast(self):
|
||||
if is_mlu_available():
|
||||
if is_npu_available():
|
||||
device = 'npu'
|
||||
with autocast(device_type=device):
|
||||
# torch.autocast support npu mode.
|
||||
layer = nn.Conv2d(1, 1, 1).to(device)
|
||||
res = layer(torch.randn(1, 1, 1, 1).to(device))
|
||||
self.assertIn(res.dtype, (torch.bfloat16, torch.float16))
|
||||
with autocast(enabled=False, device_type=device):
|
||||
res = layer(torch.randn(1, 1, 1, 1).to(device))
|
||||
self.assertEqual(res.dtype, torch.float32)
|
||||
# Test with fp32_enabled
|
||||
with autocast(enabled=False, device_type=device):
|
||||
layer = nn.Conv2d(1, 1, 1).to(device)
|
||||
res = layer(torch.randn(1, 1, 1, 1).to(device))
|
||||
self.assertEqual(res.dtype, torch.float32)
|
||||
elif is_mlu_available():
|
||||
device = 'mlu'
|
||||
with autocast(device_type=device):
|
||||
# torch.autocast support mlu mode.
|
||||
|
Loading…
x
Reference in New Issue
Block a user