mirror of https://github.com/open-mmlab/mmengine.git
[Feature] Support MLU Devices (#288)
* support mlu
* add ut and refine docstring
parent e1ed5669d5
commit d0d7174274
mmengine/device/__init__.py

@@ -1,4 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .utils import get_max_cuda_memory
+from .utils import (get_device, get_max_cuda_memory, is_cuda_available,
+                    is_mlu_available)
 
-__all__ = ['get_max_cuda_memory']
+__all__ = [
+    'get_max_cuda_memory', 'get_device', 'is_cuda_available',
+    'is_mlu_available'
+]
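With these re-exports in place, the new helpers are importable straight from the package root. A minimal usage sketch (the print is only an illustration, not part of the patch):

from mmengine.device import (get_device, get_max_cuda_memory,
                             is_cuda_available, is_mlu_available)

print(get_device())  # one of 'cuda', 'mlu', or 'cpu'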
mmengine/device/utils.py

@@ -25,3 +25,27 @@ def get_max_cuda_memory(device: Optional[torch.device] = None) -> int:
                           device=device)
     torch.cuda.reset_peak_memory_stats()
     return int(mem_mb.item())
+
+
+def is_cuda_available() -> bool:
+    """Returns True if cuda devices exist."""
+    return torch.cuda.is_available()
+
+
+def is_mlu_available() -> bool:
+    """Returns True if Cambricon PyTorch and mlu devices exist."""
+    return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available()
+
+
+def get_device() -> str:
+    """Returns the currently existing device type.
+
+    Returns:
+        str: cuda | mlu | cpu.
+    """
+    if is_cuda_available():
+        return 'cuda'
+    elif is_mlu_available():
+        return 'mlu'
+    else:
+        return 'cpu'
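The intended call pattern is to probe the device type once and place modules and tensors with `.to()`, which works uniformly across CUDA, MLU (once `torch_mlu` is loaded), and CPU. A minimal sketch; the linear layer and input tensor are placeholder examples, not part of the patch:

import torch

from mmengine.device import get_device

device = get_device()  # 'cuda', then 'mlu', then 'cpu', probed in that order
model = torch.nn.Linear(8, 2).to(device)
x = torch.randn(4, 8).to(device)
out = model(x)  # runs on whichever backend is actually present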
mmengine/dist/utils.py
@@ -10,6 +10,7 @@ import torch.multiprocessing as mp
 from torch import Tensor
 from torch import distributed as torch_dist
 from torch.distributed import ProcessGroup
+from mmengine.device import is_mlu_available
 
 try:
     # for python < 3.10
@@ -76,9 +77,18 @@ def _init_dist_pytorch(backend, **kwargs) -> None:
     """
     # TODO: use local_rank instead of rank % num_gpus
     rank = int(os.environ['RANK'])
-    num_gpus = torch.cuda.device_count()
-    torch.cuda.set_device(rank % num_gpus)
-    torch_dist.init_process_group(backend=backend, **kwargs)
+    if is_mlu_available():
+        import torch_mlu  # noqa: F401
+        torch.mlu.set_device(rank)
+        torch_dist.init_process_group(
+            backend='cncl',
+            rank=rank,
+            world_size=int(os.environ['WORLD_SIZE']),
+            **kwargs)
+    else:
+        num_gpus = torch.cuda.device_count()
+        torch.cuda.set_device(rank % num_gpus)
+        torch_dist.init_process_group(backend=backend, **kwargs)
 
 
 def _init_dist_mpi(backend, **kwargs) -> None:
@@ -425,6 +435,9 @@ def get_comm_device(group: Optional[ProcessGroup] = None) -> torch.device:
     backend = get_backend(group)
     if backend == torch_dist.Backend.NCCL:
         return torch.device('cuda', torch.cuda.current_device())
+    elif backend == 'cncl':
+        import torch_mlu  # noqa: F401
+        return torch.device('mlu', torch.mlu.current_device())
     else:
         # GLOO and MPI backends use cpu device by default
         return torch.device('cpu')
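A sketch of how the new branch is exercised end to end; it assumes the public `init_dist(launcher='pytorch')` entry point and a `torchrun`-style launcher that sets `RANK`, `WORLD_SIZE`, `MASTER_ADDR`, and `MASTER_PORT` for each process:

import torch

from mmengine.dist import get_comm_device, init_dist

# On a GPU host this initializes NCCL; on an MLU host the patched
# _init_dist_pytorch ignores the requested backend and uses 'cncl'.
init_dist(launcher='pytorch', backend='nccl')

device = get_comm_device()  # cuda / mlu / cpu, matching the active backend
t = torch.zeros(1, device=device)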
mmengine/runner/runner.py

@@ -21,6 +21,7 @@ from torch.utils.data import DataLoader
 import mmengine
 from mmengine.config import Config, ConfigDict
 from mmengine.data import pseudo_collate, worker_init_fn
+from mmengine.device import get_device
 from mmengine.dist import (broadcast, get_dist_info, get_rank, init_dist,
                            master_only, sync_random_seed)
 from mmengine.evaluator import Evaluator
@@ -821,8 +822,7 @@ class Runner:
             return model
 
         # Set `export CUDA_VISIBLE_DEVICES=-1` to enable CPU training.
-        if torch.cuda.is_available():
-            model = model.cuda()
+        model = model.to(get_device())
 
         if not self.distributed:
             return model
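Note that `model.to(get_device())` subsumes the old CUDA-only branch: with `CUDA_VISIBLE_DEVICES=-1` on a GPU host, `get_device()` falls through to 'cpu', so the escape hatch mentioned in the comment still works, while MLU hosts get device placement without any extra case in the Runner.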
tests/test_device/test_device.py (new file)
@@ -0,0 +1,12 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmengine.device import get_device, is_cuda_available, is_mlu_available
+
+
+def test_get_device():
+    device = get_device()
+    if is_cuda_available():
+        assert device == 'cuda'
+    elif is_mlu_available():
+        assert device == 'mlu'
+    else:
+        assert device == 'cpu'
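On CI machines without MLU hardware only the first or last branch is ever reached, so the `hasattr` guard in `is_mlu_available` could additionally be exercised by monkeypatching `torch`. A hypothetical pytest sketch, not part of the committed test:

import torch

from mmengine.device import is_mlu_available


def test_is_mlu_available_monkeypatched(monkeypatch):
    # Attach the probe that the Cambricon fork would provide; raising=False
    # tolerates the attribute being absent in stock PyTorch.
    monkeypatch.setattr(torch, 'is_mlu_available', lambda: True, raising=False)
    assert is_mlu_available()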