parent
e877862d5b
commit
2994195be2
|
@ -1,8 +1,8 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .utils import (get_device, get_max_cuda_memory, is_cuda_available,
|
||||
is_mlu_available)
|
||||
is_mlu_available, is_mps_available)
|
||||
|
||||
__all__ = [
|
||||
'get_max_cuda_memory', 'get_device', 'is_cuda_available',
|
||||
'is_mlu_available'
|
||||
'is_mlu_available', 'is_mps_available'
|
||||
]
|
||||
|
|
|
@ -37,15 +37,25 @@ def is_mlu_available() -> bool:
|
|||
return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available()
|
||||
|
||||
|
||||
def is_mps_available() -> bool:
|
||||
"""Return True if mps devices exist.
|
||||
|
||||
It's specialized for mac m1 chips and require torch version 1.12 or higher.
|
||||
"""
|
||||
return hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
|
||||
|
||||
|
||||
def get_device() -> str:
|
||||
"""Returns the currently existing device type.
|
||||
|
||||
Returns:
|
||||
str: cuda | mlu | cpu.
|
||||
str: cuda | mlu | mps | cpu.
|
||||
"""
|
||||
if is_cuda_available():
|
||||
return 'cuda'
|
||||
elif is_mlu_available():
|
||||
return 'mlu'
|
||||
elif is_mps_available():
|
||||
return 'mps'
|
||||
else:
|
||||
return 'cpu'
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmengine.device import get_device, is_cuda_available, is_mlu_available
|
||||
from mmengine.device import (get_device, is_cuda_available, is_mlu_available,
|
||||
is_mps_available)
|
||||
|
||||
|
||||
def test_get_device():
|
||||
|
@ -8,5 +9,7 @@ def test_get_device():
|
|||
assert device == 'cuda'
|
||||
elif is_mlu_available():
|
||||
assert device == 'mlu'
|
||||
elif is_mps_available():
|
||||
assert device == 'mps'
|
||||
else:
|
||||
assert device == 'cpu'
|
||||
|
|
Loading…
Reference in New Issue