mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Feature] Support dipu device (#1127)
This commit is contained in:
parent
7451216259
commit
49613414b2
@ -1,10 +1,10 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from .utils import (get_device, get_max_cuda_memory, is_cuda_available,
|
from .utils import (get_device, get_max_cuda_memory, is_cuda_available,
|
||||||
is_mlu_available, is_mps_available, is_npu_available,
|
is_dipu_available, is_mlu_available, is_mps_available,
|
||||||
is_npu_support_full_precision)
|
is_npu_available, is_npu_support_full_precision)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'get_max_cuda_memory', 'get_device', 'is_cuda_available',
|
'get_max_cuda_memory', 'get_device', 'is_cuda_available',
|
||||||
'is_mlu_available', 'is_mps_available', 'is_npu_available',
|
'is_mlu_available', 'is_mps_available', 'is_npu_available',
|
||||||
'is_npu_support_full_precision'
|
'is_dipu_available', 'is_npu_support_full_precision'
|
||||||
]
|
]
|
||||||
|
@ -16,6 +16,12 @@ try:
|
|||||||
except Exception:
|
except Exception:
|
||||||
IS_NPU_AVAILABLE = False
|
IS_NPU_AVAILABLE = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
import torch_dipu # noqa: F401
|
||||||
|
IS_DIPU_AVAILABLE = True
|
||||||
|
except Exception:
|
||||||
|
IS_DIPU_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
def get_max_cuda_memory(device: Optional[torch.device] = None) -> int:
|
def get_max_cuda_memory(device: Optional[torch.device] = None) -> int:
|
||||||
"""Returns the maximum GPU memory occupied by tensors in megabytes (MB) for
|
"""Returns the maximum GPU memory occupied by tensors in megabytes (MB) for
|
||||||
@ -63,6 +69,10 @@ def is_mps_available() -> bool:
|
|||||||
return hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
|
return hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
|
||||||
|
|
||||||
|
|
||||||
|
def is_dipu_available() -> bool:
|
||||||
|
return IS_DIPU_AVAILABLE
|
||||||
|
|
||||||
|
|
||||||
def is_npu_support_full_precision() -> bool:
|
def is_npu_support_full_precision() -> bool:
|
||||||
"""Returns True if npu devices support full precision training."""
|
"""Returns True if npu devices support full precision training."""
|
||||||
version_of_support_full_precision = 220
|
version_of_support_full_precision = 220
|
||||||
@ -79,6 +89,8 @@ elif is_mlu_available():
|
|||||||
DEVICE = 'mlu'
|
DEVICE = 'mlu'
|
||||||
elif is_mps_available():
|
elif is_mps_available():
|
||||||
DEVICE = 'mps'
|
DEVICE = 'mps'
|
||||||
|
elif is_dipu_available():
|
||||||
|
DEVICE = 'dipu'
|
||||||
|
|
||||||
|
|
||||||
def get_device() -> str:
|
def get_device() -> str:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user