diff --git a/mmengine/runner/amp.py b/mmengine/runner/amp.py
index f09437b1..2684efd6 100644
--- a/mmengine/runner/amp.py
+++ b/mmengine/runner/amp.py
@@ -121,9 +121,18 @@ def autocast(device_type: Optional[str] = None,
         assert dtype == torch.bfloat16, (
             'In CPU autocast, only support `torch.bfloat16` dtype')
+    elif device_type == 'mlu':
+        pass
     else:
-        raise ValueError('User specified autocast device_type must be '
-                         F'cuda or cpu, but got {device_type}')
+        # Device like MPS does not support fp16 training or testing.
+        # If an inappropriate device is set and fp16 is enabled, an error
+        # will be thrown.
+        if enabled is False:
+            yield
+            return
+        else:
+            raise ValueError('User specified autocast device_type must be '
+                             f'cuda or cpu, but got {device_type}')
 
     with torch.autocast(
             device_type=device_type,
diff --git a/tests/test_runner/test_amp.py b/tests/test_runner/test_amp.py
index 7d7a8ca8..7ef60563 100644
--- a/tests/test_runner/test_amp.py
+++ b/tests/test_runner/test_amp.py
@@ -4,6 +4,8 @@ import unittest
 import torch
 import torch.nn as nn
 
+import mmengine
+from mmengine.device import get_device
 from mmengine.runner import autocast
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION
@@ -56,3 +58,24 @@ class TestAmp(unittest.TestCase):
             layer = nn.Conv2d(1, 1, 1).to(device)
             res = layer(torch.randn(1, 1, 1, 1).to(device))
             self.assertEqual(res.dtype, torch.float32)
+
+        # Test mps
+        if digit_version(TORCH_VERSION) >= digit_version('1.12.0'):
+            mmengine.runner.amp.get_device = lambda: 'mps'
+            with autocast(enabled=False):
+                layer = nn.Conv2d(1, 1, 1)
+                res = layer(torch.randn(1, 1, 1, 1))
+                self.assertEqual(res.dtype, torch.float32)
+
+            with self.assertRaisesRegex(ValueError,
+                                        'User specified autocast device_type'):
+                with autocast(enabled=True):
+                    pass
+        # Native pytorch does not support mlu, here we simply test autocast
+        # will call `torch.autocast`, which will be overridden by mlu version
+        # pytorch
+        mmengine.runner.amp.get_device = lambda: 'mlu'
+        with self.assertRaises(RuntimeError):
+            with autocast(enabled=False):
+                pass
+        mmengine.runner.amp.get_device = get_device