Mirror of https://github.com/open-mmlab/mmengine.git, synced 2025-06-03 21:54:44 +08:00
[Refactor] Refactor the get_device and is_npu_available (#1004)

* refactor get_device and is_npu_available

* Update mmengine/device/utils.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

---------

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent ff27b723db
commit cbb671403f
mmengine/device/utils.py

@@ -4,6 +4,17 @@ from typing import Optional
 
 import torch
 
+try:
+    import torch_npu  # noqa: F401
+
+    # Enable operator support for dynamic shape and
+    # binary operator support on the NPU.
+    npu_jit_compile = bool(os.getenv('NPUJITCompile', False))
+    torch.npu.set_compile_mode(jit_compile=npu_jit_compile)
+    IS_NPU_AVAILABLE = hasattr(torch, 'npu') and torch.npu.is_available()
+except Exception:
+    IS_NPU_AVAILABLE = False
+
 
 def get_max_cuda_memory(device: Optional[torch.device] = None) -> int:
     """Returns the maximum GPU memory occupied by tensors in megabytes (MB) for
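A note on the hoisted probe: the NPUJITCompile flag is read with bool(os.getenv('NPUJITCompile', False)), and bool() of any non-empty string is True in Python, so setting the variable to '0' or 'false' still enables JIT compilation; only an unset (or empty) variable disables it. A quick standalone check of this parsing behavior, plain Python with no Ascend dependencies:

import os

os.environ['NPUJITCompile'] = '0'
print(bool(os.getenv('NPUJITCompile', False)))  # True: '0' is a non-empty string
del os.environ['NPUJITCompile']
print(bool(os.getenv('NPUJITCompile', False)))  # False: the default False is used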
@@ -35,16 +46,7 @@ def is_cuda_available() -> bool:
 
 def is_npu_available() -> bool:
     """Returns True if Ascend PyTorch and npu devices exist."""
-    try:
-        import torch_npu  # noqa: F401
-
-        # Enable operator support for dynamic shape and
-        # binary operator support on the NPU.
-        npu_jit_compile = bool(os.getenv('NPUJITCompile', False))
-        torch.npu.set_compile_mode(jit_compile=npu_jit_compile)
-    except Exception:
-        return False
-    return hasattr(torch, 'npu') and torch.npu.is_available()
+    return IS_NPU_AVAILABLE
 
 
 def is_mlu_available() -> bool:
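With the probe run once at import time, is_npu_available() reduces to returning the cached IS_NPU_AVAILABLE, so repeated calls no longer re-import torch_npu or re-run set_compile_mode. A minimal usage sketch, assuming the function is re-exported from mmengine.device as elsewhere in the package:

from mmengine.device import is_npu_available

# Cheap to call in a hot path: it only reads a module-level constant
# computed once when mmengine.device.utils was first imported.
if is_npu_available():
    print('Ascend NPU detected')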
@@ -60,19 +62,21 @@ def is_mps_available() -> bool:
     return hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
 
 
+DEVICE = 'cpu'
+if is_npu_available():
+    DEVICE = 'npu'
+elif is_cuda_available():
+    DEVICE = 'cuda'
+elif is_mlu_available():
+    DEVICE = 'mlu'
+elif is_mps_available():
+    DEVICE = 'mps'
+
+
 def get_device() -> str:
     """Returns the currently existing device type.
 
     Returns:
         str: cuda | npu | mlu | mps | cpu.
     """
-    if is_npu_available():
-        return 'npu'
-    elif is_cuda_available():
-        return 'cuda'
-    elif is_mlu_available():
-        return 'mlu'
-    elif is_mps_available():
-        return 'mps'
-    else:
-        return 'cpu'
+    return DEVICE
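get_device() likewise becomes a constant lookup: DEVICE is resolved once at import time with the npu > cuda > mlu > mps > cpu priority shown above, and every later call returns that cached value, so the reported device type is fixed for the life of the process. A short usage sketch (the mmengine.device import path is assumed from the package layout; on a machine with no accelerator this simply yields 'cpu'):

import torch

from mmengine.device import get_device

device = get_device()  # one of 'npu', 'cuda', 'mlu', 'mps', 'cpu'
x = torch.ones(2, 3).to(device)  # move a tensor to the detected device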