From 49613414b257bef028ff1d63bbebea32badfecfd Mon Sep 17 00:00:00 2001 From: CokeDong <408244909@qq.com> Date: Thu, 25 May 2023 14:10:45 +0800 Subject: [PATCH] [Feature] Support dipu device (#1127) --- mmengine/device/__init__.py | 6 +++--- mmengine/device/utils.py | 12 ++++++++++++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/mmengine/device/__init__.py b/mmengine/device/__init__.py index 623aa0b8..bfd82a85 100644 --- a/mmengine/device/__init__.py +++ b/mmengine/device/__init__.py @@ -1,10 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. from .utils import (get_device, get_max_cuda_memory, is_cuda_available, - is_mlu_available, is_mps_available, is_npu_available, - is_npu_support_full_precision) + is_dipu_available, is_mlu_available, is_mps_available, + is_npu_available, is_npu_support_full_precision) __all__ = [ 'get_max_cuda_memory', 'get_device', 'is_cuda_available', 'is_mlu_available', 'is_mps_available', 'is_npu_available', - 'is_npu_support_full_precision' + 'is_dipu_available', 'is_npu_support_full_precision' ] diff --git a/mmengine/device/utils.py b/mmengine/device/utils.py index 63c90633..0bb69d2e 100644 --- a/mmengine/device/utils.py +++ b/mmengine/device/utils.py @@ -16,6 +16,12 @@ try: except Exception: IS_NPU_AVAILABLE = False +try: + import torch_dipu # noqa: F401 + IS_DIPU_AVAILABLE = True +except Exception: + IS_DIPU_AVAILABLE = False + def get_max_cuda_memory(device: Optional[torch.device] = None) -> int: """Returns the maximum GPU memory occupied by tensors in megabytes (MB) for @@ -63,6 +69,10 @@ def is_mps_available() -> bool: return hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() +def is_dipu_available() -> bool: + return IS_DIPU_AVAILABLE + + def is_npu_support_full_precision() -> bool: """Returns True if npu devices support full precision training.""" version_of_support_full_precision = 220 @@ -79,6 +89,8 @@ elif is_mlu_available(): DEVICE = 'mlu' elif is_mps_available(): DEVICE = 'mps' +elif is_dipu_available(): + DEVICE = 'dipu' def get_device() -> str: