[Feat] Support training on MPS (#331)

* [Feat] Support mps

* fix docstring
pull/341/head
Alex Yang 2022-06-23 16:53:19 +08:00 committed by GitHub
parent e877862d5b
commit 2994195be2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 17 additions and 4 deletions

View File

@ -1,8 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .utils import (get_device, get_max_cuda_memory, is_cuda_available,
is_mlu_available)
is_mlu_available, is_mps_available)
__all__ = [
'get_max_cuda_memory', 'get_device', 'is_cuda_available',
'is_mlu_available'
'is_mlu_available', 'is_mps_available'
]

View File

@ -37,15 +37,25 @@ def is_mlu_available() -> bool:
return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available()
def is_mps_available() -> bool:
"""Return True if mps devices exist.
It's specialized for mac m1 chips and require torch version 1.12 or higher.
"""
return hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
def get_device() -> str:
"""Returns the currently existing device type.
Returns:
str: cuda | mlu | cpu.
str: cuda | mlu | mps | cpu.
"""
if is_cuda_available():
return 'cuda'
elif is_mlu_available():
return 'mlu'
elif is_mps_available():
return 'mps'
else:
return 'cpu'

View File

@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.device import get_device, is_cuda_available, is_mlu_available
from mmengine.device import (get_device, is_cuda_available, is_mlu_available,
is_mps_available)
def test_get_device():
@ -8,5 +9,7 @@ def test_get_device():
assert device == 'cuda'
elif is_mlu_available():
assert device == 'mlu'
elif is_mps_available():
assert device == 'mps'
else:
assert device == 'cpu'