parent
e877862d5b
commit
2994195be2
|
@ -1,8 +1,8 @@
|
||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from .utils import (get_device, get_max_cuda_memory, is_cuda_available,
|
from .utils import (get_device, get_max_cuda_memory, is_cuda_available,
|
||||||
is_mlu_available)
|
is_mlu_available, is_mps_available)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'get_max_cuda_memory', 'get_device', 'is_cuda_available',
|
'get_max_cuda_memory', 'get_device', 'is_cuda_available',
|
||||||
'is_mlu_available'
|
'is_mlu_available', 'is_mps_available'
|
||||||
]
|
]
|
||||||
|
|
|
@ -37,15 +37,25 @@ def is_mlu_available() -> bool:
|
||||||
return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available()
|
return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available()
|
||||||
|
|
||||||
|
|
||||||
|
def is_mps_available() -> bool:
|
||||||
|
"""Return True if mps devices exist.
|
||||||
|
|
||||||
|
It's specialized for mac m1 chips and require torch version 1.12 or higher.
|
||||||
|
"""
|
||||||
|
return hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
|
||||||
|
|
||||||
|
|
||||||
def get_device() -> str:
|
def get_device() -> str:
|
||||||
"""Returns the currently existing device type.
|
"""Returns the currently existing device type.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: cuda | mlu | cpu.
|
str: cuda | mlu | mps | cpu.
|
||||||
"""
|
"""
|
||||||
if is_cuda_available():
|
if is_cuda_available():
|
||||||
return 'cuda'
|
return 'cuda'
|
||||||
elif is_mlu_available():
|
elif is_mlu_available():
|
||||||
return 'mlu'
|
return 'mlu'
|
||||||
|
elif is_mps_available():
|
||||||
|
return 'mps'
|
||||||
else:
|
else:
|
||||||
return 'cpu'
|
return 'cpu'
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from mmengine.device import get_device, is_cuda_available, is_mlu_available
|
from mmengine.device import (get_device, is_cuda_available, is_mlu_available,
|
||||||
|
is_mps_available)
|
||||||
|
|
||||||
|
|
||||||
def test_get_device():
|
def test_get_device():
|
||||||
|
@ -8,5 +9,7 @@ def test_get_device():
|
||||||
assert device == 'cuda'
|
assert device == 'cuda'
|
||||||
elif is_mlu_available():
|
elif is_mlu_available():
|
||||||
assert device == 'mlu'
|
assert device == 'mlu'
|
||||||
|
elif is_mps_available():
|
||||||
|
assert device == 'mps'
|
||||||
else:
|
else:
|
||||||
assert device == 'cpu'
|
assert device == 'cpu'
|
||||||
|
|
Loading…
Reference in New Issue