[Feature] Support dipu device (#1127)
parent 7451216259
commit 49613414b2
mmengine/device/__init__.py
@@ -1,10 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .utils import (get_device, get_max_cuda_memory, is_cuda_available,
-                    is_mlu_available, is_mps_available, is_npu_available,
-                    is_npu_support_full_precision)
+                    is_dipu_available, is_mlu_available, is_mps_available,
+                    is_npu_available, is_npu_support_full_precision)

 __all__ = [
     'get_max_cuda_memory', 'get_device', 'is_cuda_available',
     'is_mlu_available', 'is_mps_available', 'is_npu_available',
-    'is_npu_support_full_precision'
+    'is_dipu_available', 'is_npu_support_full_precision'
 ]

mmengine/device/utils.py
@@ -16,6 +16,12 @@ try:
 except Exception:
     IS_NPU_AVAILABLE = False

+try:
+    import torch_dipu  # noqa: F401
+    IS_DIPU_AVAILABLE = True
+except Exception:
+    IS_DIPU_AVAILABLE = False
+

 def get_max_cuda_memory(device: Optional[torch.device] = None) -> int:
     """Returns the maximum GPU memory occupied by tensors in megabytes (MB) for
@@ -63,6 +69,10 @@ def is_mps_available() -> bool:
     return hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()


+def is_dipu_available() -> bool:
+    return IS_DIPU_AVAILABLE
+
+
 def is_npu_support_full_precision() -> bool:
     """Returns True if npu devices support full precision training."""
     version_of_support_full_precision = 220
@@ -79,6 +89,8 @@ elif is_mlu_available():
     DEVICE = 'mlu'
 elif is_mps_available():
     DEVICE = 'mps'
+elif is_dipu_available():
+    DEVICE = 'dipu'


 def get_device() -> str:
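For context, a minimal usage sketch of the helpers this patch adds (the driver snippet below is illustrative, not part of the commit): is_dipu_available() reports whether the optional torch_dipu package imported successfully, and the module-level DEVICE probe now falls through to 'dipu' after the cuda/npu/mlu/mps checks.

from mmengine.device import get_device, is_dipu_available

# True only when `import torch_dipu` succeeded at module import time.
print(is_dipu_available())

# Returns 'dipu' only if none of the earlier backends (cuda, npu,
# mlu, mps) were detected, since those branches are checked first.
print(get_device())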