[Refactor] Refactor the get_device and is_npu_available (#1004)

* refactor get_device and is_npu_available

* Update mmengine/device/utils.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

---------

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Qian Zhao 2023-03-19 15:55:27 +08:00 committed by GitHub
parent ff27b723db
commit cbb671403f

@@ -4,6 +4,17 @@ from typing import Optional
 
 import torch
 
+try:
+    import torch_npu  # noqa: F401
+
+    # Enable operator support for dynamic shape and
+    # binary operator support on the NPU.
+    npu_jit_compile = bool(os.getenv('NPUJITCompile', False))
+    torch.npu.set_compile_mode(jit_compile=npu_jit_compile)
+
+    IS_NPU_AVAILABLE = hasattr(torch, 'npu') and torch.npu.is_available()
+except Exception:
+    IS_NPU_AVAILABLE = False
 
 
 def get_max_cuda_memory(device: Optional[torch.device] = None) -> int:
     """Returns the maximum GPU memory occupied by tensors in megabytes (MB) for
@@ -35,16 +46,7 @@ def is_cuda_available() -> bool:
 
 def is_npu_available() -> bool:
     """Returns True if Ascend PyTorch and npu devices exist."""
-    try:
-        import torch_npu  # noqa: F401
-
-        # Enable operator support for dynamic shape and
-        # binary operator support on the NPU.
-        npu_jit_compile = bool(os.getenv('NPUJITCompile', False))
-        torch.npu.set_compile_mode(jit_compile=npu_jit_compile)
-    except Exception:
-        return False
-    return hasattr(torch, 'npu') and torch.npu.is_available()
+    return IS_NPU_AVAILABLE
 
 
 def is_mlu_available() -> bool:
@@ -60,19 +62,21 @@ def is_mps_available() -> bool:
     return hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
 
 
+DEVICE = 'cpu'
+if is_npu_available():
+    DEVICE = 'npu'
+elif is_cuda_available():
+    DEVICE = 'cuda'
+elif is_mlu_available():
+    DEVICE = 'mlu'
+elif is_mps_available():
+    DEVICE = 'mps'
+
+
 def get_device() -> str:
     """Returns the currently existing device type.
 
     Returns:
         str: cuda | npu | mlu | mps | cpu.
     """
-    if is_npu_available():
-        return 'npu'
-    elif is_cuda_available():
-        return 'cuda'
-    elif is_mlu_available():
-        return 'mlu'
-    elif is_mps_available():
-        return 'mps'
-    else:
-        return 'cpu'
+    return DEVICE
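
Taken together, the refactor moves the NPU probe and the device-type resolution to import time: is_npu_available() now returns the cached IS_NPU_AVAILABLE flag instead of re-importing torch_npu and resetting the compile mode on every call, and get_device() returns the DEVICE string computed once when the module is loaded. A small sketch of the public API after this change (get_device and is_npu_available are real mmengine.device exports; the printed values depend on the machine):

    from mmengine.device import get_device, is_npu_available

    # DEVICE is resolved once at import time, so repeated calls are cheap
    # and always return the same string within a process.
    device = get_device()  # one of 'npu', 'cuda', 'mlu', 'mps', 'cpu'
    print(f'running on {device}; NPU available: {is_npu_available()}')
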