[Fix] Make autocast compatible with mps (#587)
* [Fix] Make autocast compatible with mps
* Enhance unit test
* fix unit test
* clean the code
* fix unit test
This commit is contained in:
parent 6073d9ebd8
commit abe56651db
mmengine/runner/amp.py
@@ -121,9 +121,18 @@ def autocast(device_type: Optional[str] = None,
         assert dtype == torch.bfloat16, (
             'In CPU autocast, only support `torch.bfloat16` dtype')
 
+    elif device_type == 'mlu':
+        pass
     else:
-        raise ValueError('User specified autocast device_type must be '
-                         f'cuda or cpu, but got {device_type}')
+        # Device like MPS does not support fp16 training or testing.
+        # If an inappropriate device is set and fp16 is enabled, an error
+        # will be thrown.
+        if enabled is False:
+            yield
+            return
+        else:
+            raise ValueError('User specified autocast device_type must be '
+                             f'cuda or cpu, but got {device_type}')
 
     with torch.autocast(
             device_type=device_type,
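The new comments describe the intent: on a backend such as MPS that cannot run fp16, `autocast` now becomes a no-op when fp16 is disabled and only raises when it is explicitly enabled. A minimal usage sketch of that behavior (not part of this commit; it assumes PyTorch >= 1.12 with a working MPS build and that `mmengine.device.get_device()` reports 'mps' on such a machine):

import torch
import torch.nn as nn

from mmengine.runner import autocast

layer = nn.Conv2d(1, 1, 1).to('mps')

# With fp16 disabled, the context manager simply yields, so the forward
# pass runs in the default float32 precision instead of raising.
with autocast(enabled=False):
    out = layer(torch.randn(1, 1, 1, 1, device='mps'))
assert out.dtype == torch.float32

# Requesting fp16 on a device that cannot support it is still rejected.
try:
    with autocast(enabled=True):
        pass
except ValueError as err:
    print(err)  # User specified autocast device_type must be cuda or cpu, ...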
tests/test_runner/test_amp.py
@@ -4,6 +4,8 @@ import unittest
 import torch
 import torch.nn as nn
 
+import mmengine
+from mmengine.device import get_device
 from mmengine.runner import autocast
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION
@@ -56,3 +58,24 @@ class TestAmp(unittest.TestCase):
             layer = nn.Conv2d(1, 1, 1).to(device)
             res = layer(torch.randn(1, 1, 1, 1).to(device))
             self.assertEqual(res.dtype, torch.float32)
+
+        # Test mps
+        if digit_version(TORCH_VERSION) >= digit_version('1.12.0'):
+            mmengine.runner.amp.get_device = lambda: 'mps'
+            with autocast(enabled=False):
+                layer = nn.Conv2d(1, 1, 1)
+                res = layer(torch.randn(1, 1, 1, 1))
+                self.assertEqual(res.dtype, torch.float32)
+
+            with self.assertRaisesRegex(ValueError,
+                                        'User specified autocast device_type'):
+                with autocast(enabled=True):
+                    pass
+        # Native pytorch does not support mlu, here we simply test autocast
+        # will call `torch.autocast`, which will be overridden by mlu version
+        # pytorch
+        mmengine.runner.amp.get_device = lambda: 'mlu'
+        with self.assertRaises(RuntimeError):
+            with autocast(enabled=False):
+                pass
+        mmengine.runner.amp.get_device = get_device
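The test swaps `mmengine.runner.amp.get_device` for a lambda and restores it by hand at the end. A small, hypothetical helper (not in the repository) that captures the same patch-and-restore pattern, shown only as a sketch:

import contextlib

import mmengine.runner.amp


@contextlib.contextmanager
def fake_device(name):
    """Temporarily make mmengine.runner.amp.get_device report ``name``."""
    original = mmengine.runner.amp.get_device
    mmengine.runner.amp.get_device = lambda: name
    try:
        yield
    finally:
        # Restore the real get_device even if the body raised.
        mmengine.runner.amp.get_device = original

With such a helper, the MPS ValueError check above could be written as `with fake_device('mps'), self.assertRaisesRegex(ValueError, 'User specified autocast device_type'): ...` without having to remember the final reassignment.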