[Feature] Support dipu device (#1127)

This commit is contained in:
CokeDong 2023-05-25 14:10:45 +08:00 committed by GitHub
parent 7451216259
commit 49613414b2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 15 additions and 3 deletions

View File

@ -1,10 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .utils import (get_device, get_max_cuda_memory, is_cuda_available,
is_mlu_available, is_mps_available, is_npu_available,
is_npu_support_full_precision)
is_dipu_available, is_mlu_available, is_mps_available,
is_npu_available, is_npu_support_full_precision)
__all__ = [
'get_max_cuda_memory', 'get_device', 'is_cuda_available',
'is_mlu_available', 'is_mps_available', 'is_npu_available',
'is_npu_support_full_precision'
'is_dipu_available', 'is_npu_support_full_precision'
]

View File

@ -16,6 +16,12 @@ try:
except Exception:
IS_NPU_AVAILABLE = False
try:
import torch_dipu # noqa: F401
IS_DIPU_AVAILABLE = True
except Exception:
IS_DIPU_AVAILABLE = False
def get_max_cuda_memory(device: Optional[torch.device] = None) -> int:
"""Returns the maximum GPU memory occupied by tensors in megabytes (MB) for
@ -63,6 +69,10 @@ def is_mps_available() -> bool:
return hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
def is_dipu_available() -> bool:
return IS_DIPU_AVAILABLE
def is_npu_support_full_precision() -> bool:
"""Returns True if npu devices support full precision training."""
version_of_support_full_precision = 220
@ -79,6 +89,8 @@ elif is_mlu_available():
DEVICE = 'mlu'
elif is_mps_available():
DEVICE = 'mps'
elif is_dipu_available():
DEVICE = 'dipu'
def get_device() -> str: