diff --git a/mmengine/device/utils.py b/mmengine/device/utils.py
index a6575a42..5858857b 100644
--- a/mmengine/device/utils.py
+++ b/mmengine/device/utils.py
@@ -4,6 +4,17 @@ from typing import Optional
 
 import torch
 
+try:
+    import torch_npu  # noqa: F401
+
+    # Enable operator support for dynamic shape and
+    # binary operator support on the NPU.
+    npu_jit_compile = bool(os.getenv('NPUJITCompile', False))
+    torch.npu.set_compile_mode(jit_compile=npu_jit_compile)
+    IS_NPU_AVAILABLE = hasattr(torch, 'npu') and torch.npu.is_available()
+except Exception:
+    IS_NPU_AVAILABLE = False
+
 
 def get_max_cuda_memory(device: Optional[torch.device] = None) -> int:
     """Returns the maximum GPU memory occupied by tensors in megabytes (MB) for
@@ -35,16 +46,7 @@ def is_cuda_available() -> bool:
 
 def is_npu_available() -> bool:
     """Returns True if Ascend PyTorch and npu devices exist."""
-    try:
-        import torch_npu  # noqa: F401
-
-        # Enable operator support for dynamic shape and
-        # binary operator support on the NPU.
-        npu_jit_compile = bool(os.getenv('NPUJITCompile', False))
-        torch.npu.set_compile_mode(jit_compile=npu_jit_compile)
-    except Exception:
-        return False
-    return hasattr(torch, 'npu') and torch.npu.is_available()
+    return IS_NPU_AVAILABLE
 
 
 def is_mlu_available() -> bool:
@@ -60,19 +62,21 @@ def is_mps_available() -> bool:
     return hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
 
 
+DEVICE = 'cpu'
+if is_npu_available():
+    DEVICE = 'npu'
+elif is_cuda_available():
+    DEVICE = 'cuda'
+elif is_mlu_available():
+    DEVICE = 'mlu'
+elif is_mps_available():
+    DEVICE = 'mps'
+
+
 def get_device() -> str:
     """Returns the currently existing device type.
 
     Returns:
         str: cuda | npu | mlu | mps | cpu.
     """
-    if is_npu_available():
-        return 'npu'
-    elif is_cuda_available():
-        return 'cuda'
-    elif is_mlu_available():
-        return 'mlu'
-    elif is_mps_available():
-        return 'mps'
-    else:
-        return 'cpu'
+    return DEVICE
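
A minimal usage sketch of the patched API (assuming `mmengine.device` re-exports these helpers, as upstream mmengine does): after this change, device detection runs once at module import, so every later call to `get_device()` is a cheap lookup of the cached `DEVICE` string.

    import torch

    from mmengine.device import get_device, is_npu_available

    device = get_device()  # one of: 'npu', 'cuda', 'mlu', 'mps', 'cpu'
    # Moving tensors works out of the box for 'cuda', 'mps' and 'cpu';
    # 'npu' and 'mlu' additionally require the vendor torch plugins.
    x = torch.ones(2, 3).to(device)
    print(f'running on {device}, is_npu_available() -> {is_npu_available()}')

One trade-off worth noting: because `DEVICE` is frozen at import time, environment changes made afterwards (e.g. setting `CUDA_VISIBLE_DEVICES` from Python) no longer affect the value returned by `get_device()`.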