mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Refactor] Refactor the get_device
and is_npu_available
(#1004)
* refactor get_device and is_npu_available * Update mmengine/device/utils.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> --------- Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
This commit is contained in:
parent
ff27b723db
commit
cbb671403f
@ -4,6 +4,17 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
try:
|
||||
import torch_npu # noqa: F401
|
||||
|
||||
# Enable operator support for dynamic shape and
|
||||
# binary operator support on the NPU.
|
||||
npu_jit_compile = bool(os.getenv('NPUJITCompile', False))
|
||||
torch.npu.set_compile_mode(jit_compile=npu_jit_compile)
|
||||
IS_NPU_AVAILABLE = hasattr(torch, 'npu') and torch.npu.is_available()
|
||||
except Exception:
|
||||
IS_NPU_AVAILABLE = False
|
||||
|
||||
|
||||
def get_max_cuda_memory(device: Optional[torch.device] = None) -> int:
|
||||
"""Returns the maximum GPU memory occupied by tensors in megabytes (MB) for
|
||||
@ -35,16 +46,7 @@ def is_cuda_available() -> bool:
|
||||
|
||||
def is_npu_available() -> bool:
|
||||
"""Returns True if Ascend PyTorch and npu devices exist."""
|
||||
try:
|
||||
import torch_npu # noqa: F401
|
||||
|
||||
# Enable operator support for dynamic shape and
|
||||
# binary operator support on the NPU.
|
||||
npu_jit_compile = bool(os.getenv('NPUJITCompile', False))
|
||||
torch.npu.set_compile_mode(jit_compile=npu_jit_compile)
|
||||
except Exception:
|
||||
return False
|
||||
return hasattr(torch, 'npu') and torch.npu.is_available()
|
||||
return IS_NPU_AVAILABLE
|
||||
|
||||
|
||||
def is_mlu_available() -> bool:
|
||||
@ -60,19 +62,21 @@ def is_mps_available() -> bool:
|
||||
return hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
|
||||
|
||||
|
||||
DEVICE = 'cpu'
|
||||
if is_npu_available():
|
||||
DEVICE = 'npu'
|
||||
elif is_cuda_available():
|
||||
DEVICE = 'cuda'
|
||||
elif is_mlu_available():
|
||||
DEVICE = 'mlu'
|
||||
elif is_mps_available():
|
||||
DEVICE = 'mps'
|
||||
|
||||
|
||||
def get_device() -> str:
|
||||
"""Returns the currently existing device type.
|
||||
|
||||
Returns:
|
||||
str: cuda | npu | mlu | mps | cpu.
|
||||
"""
|
||||
if is_npu_available():
|
||||
return 'npu'
|
||||
elif is_cuda_available():
|
||||
return 'cuda'
|
||||
elif is_mlu_available():
|
||||
return 'mlu'
|
||||
elif is_mps_available():
|
||||
return 'mps'
|
||||
else:
|
||||
return 'cpu'
|
||||
return DEVICE
|
||||
|
Loading…
x
Reference in New Issue
Block a user