[Feat] Support training on MPS (#331)

* [Feat] Support mps

* fix docstring
pull/341/head
Alex Yang 2022-06-23 16:53:19 +08:00 committed by GitHub
parent e877862d5b
commit 2994195be2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 17 additions and 4 deletions

View File

@ -1,8 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .utils import (get_device, get_max_cuda_memory, is_cuda_available, from .utils import (get_device, get_max_cuda_memory, is_cuda_available,
is_mlu_available) is_mlu_available, is_mps_available)
__all__ = [ __all__ = [
'get_max_cuda_memory', 'get_device', 'is_cuda_available', 'get_max_cuda_memory', 'get_device', 'is_cuda_available',
'is_mlu_available' 'is_mlu_available', 'is_mps_available'
] ]

View File

@ -37,15 +37,25 @@ def is_mlu_available() -> bool:
return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available() return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available()
def is_mps_available() -> bool:
"""Return True if mps devices exist.
It's specialized for mac m1 chips and require torch version 1.12 or higher.
"""
return hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
def get_device() -> str: def get_device() -> str:
"""Returns the currently existing device type. """Returns the currently existing device type.
Returns: Returns:
str: cuda | mlu | cpu. str: cuda | mlu | mps | cpu.
""" """
if is_cuda_available(): if is_cuda_available():
return 'cuda' return 'cuda'
elif is_mlu_available(): elif is_mlu_available():
return 'mlu' return 'mlu'
elif is_mps_available():
return 'mps'
else: else:
return 'cpu' return 'cpu'

View File

@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmengine.device import get_device, is_cuda_available, is_mlu_available from mmengine.device import (get_device, is_cuda_available, is_mlu_available,
is_mps_available)
def test_get_device(): def test_get_device():
@ -8,5 +9,7 @@ def test_get_device():
assert device == 'cuda' assert device == 'cuda'
elif is_mlu_available(): elif is_mlu_available():
assert device == 'mlu' assert device == 'mlu'
elif is_mps_available():
assert device == 'mps'
else: else:
assert device == 'cpu' assert device == 'cpu'