mmsegmentation/mmseg/utils/util_distribution.py
luomaoling ae78cb9d53
[Feature] Support mmseg with NPU backend. (#2768)
## Motivation

Added Ascend (NPU) device support in mmseg.

## Modification

The main modification points are as follows:
We added NPU support to both the DP and DDP scenarios: `build_dp` and
`build_ddp` now accept `device='npu'`, and `get_device` reports an
available NPU. A usage sketch is given under Use cases below.

## BC-breaking (Optional)

None

## Use cases (Optional)

We tested
[fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes.py](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/unet/fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes.py).
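
For reference, a minimal sketch of how the new device path is exercised (an illustration, not from the PR itself; it assumes an Ascend environment with `torch_npu` installed, a built segmentor bound to `model`, and a launcher that sets `LOCAL_RANK` and initializes the default process group for the DDP case):

```python
import os

from mmseg.utils import build_ddp, build_dp, get_device

device = get_device()  # returns 'npu' when an Ascend NPU is available

# single-process data-parallel training
model = build_dp(model, device=device, device_ids=[0])

# distributed data-parallel training (process group already initialized)
model = build_ddp(
    model,
    device=device,
    device_ids=[int(os.environ['LOCAL_RANK'])],
    broadcast_buffers=False)
```
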
2023-03-23 19:42:49 +08:00


# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel

from mmseg import digit_version

dp_factory = {'cuda': MMDataParallel, 'cpu': MMDataParallel}

ddp_factory = {'cuda': MMDistributedDataParallel}


def build_dp(model, device='cuda', dim=0, *args, **kwargs):
    """Build DataParallel module by device type.

    If device is cuda, return a MMDataParallel module; if device is mlu,
    return a MLUDataParallel module; if device is npu, return a
    NPUDataParallel module.

    Args:
        model (:class:`nn.Module`): module to be parallelized.
        device (str): device type, cuda, cpu, mlu or npu. Defaults to cuda.
        dim (int): Dimension used to scatter the data. Defaults to 0.

    Returns:
        :class:`nn.Module`: parallelized module.
    """
    if device == 'cuda':
        model = model.cuda()
    elif device == 'mlu':
        assert digit_version(mmcv.__version__) >= digit_version('1.5.0'), \
            'Please use MMCV >= 1.5.0 for MLU training!'
        # Import lazily so mmcv's MLU support is only required when used.
        from mmcv.device.mlu import MLUDataParallel
        dp_factory['mlu'] = MLUDataParallel
        model = model.mlu()
    elif device == 'npu':
        assert digit_version(mmcv.__version__) >= digit_version('1.7.0'), \
            'Please use MMCV >= 1.7.0 for NPU training!'
        from mmcv.device.npu import NPUDataParallel
        # Run NPU operators eagerly instead of JIT-compiling them.
        torch.npu.set_compile_mode(jit_compile=False)
        dp_factory['npu'] = NPUDataParallel
        model = model.npu()

    return dp_factory[device](model, dim=dim, *args, **kwargs)


def build_ddp(model, device='cuda', *args, **kwargs):
    """Build DistributedDataParallel module by device type.

    If device is cuda, return a MMDistributedDataParallel module;
    if device is mlu, return a MLUDistributedDataParallel module;
    if device is npu, return a NPUDistributedDataParallel module.

    Args:
        model (:class:`nn.Module`): module to be parallelized.
        device (str): device type, cuda, mlu or npu.

    Returns:
        :class:`nn.Module`: parallelized module.

    References:
        .. [1] https://pytorch.org/docs/stable/generated/torch.nn.parallel.
           DistributedDataParallel.html
    """
    assert device in ['cuda', 'mlu', 'npu'], 'Only available for cuda, '\
        'npu or mlu devices.'
    if device == 'cuda':
        model = model.cuda()
    elif device == 'mlu':
        assert digit_version(mmcv.__version__) >= digit_version('1.5.0'), \
            'Please use MMCV >= 1.5.0 for MLU training!'
        from mmcv.device.mlu import MLUDistributedDataParallel
        ddp_factory['mlu'] = MLUDistributedDataParallel
        model = model.mlu()
    elif device == 'npu':
        assert digit_version(mmcv.__version__) >= digit_version('1.7.0'), \
            'Please use MMCV >= 1.7.0 for NPU training!'
        from mmcv.device.npu import NPUDistributedDataParallel
        # Run NPU operators eagerly instead of JIT-compiling them.
        torch.npu.set_compile_mode(jit_compile=False)
        ddp_factory['npu'] = NPUDistributedDataParallel
        model = model.npu()

    return ddp_factory[device](model, *args, **kwargs)


def is_mlu_available():
    """Returns a bool indicating if MLU is currently available."""
    return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available()


def is_npu_available():
    """Returns a bool indicating if NPU is currently available."""
    return hasattr(torch, 'npu') and torch.npu.is_available()


def get_device():
    """Returns an available device, cpu, npu, cuda or mlu."""
    is_device_available = {
        'npu': is_npu_available(),
        'cuda': torch.cuda.is_available(),
        'mlu': is_mlu_available()
    }
    # Dict insertion order gives the priority: npu > cuda > mlu, with 'cpu'
    # as the fallback when none of them is available.
    device_list = [k for k, v in is_device_available.items() if v]
    return device_list[0] if len(device_list) >= 1 else 'cpu'
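

if __name__ == '__main__':
    # Minimal smoke test (an illustrative sketch, not part of the original
    # module): report the selected device and wrap a trivial model with the
    # matching DataParallel class. On a CPU-only machine this prints 'cpu'
    # and wraps the model with MMDataParallel.
    import torch.nn as nn

    device = get_device()
    print(f'selected device: {device}')
    wrapped = build_dp(nn.Conv2d(3, 8, 3), device=device)
    print(type(wrapped).__name__)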