[Feature] Support mmseg with NPU backend. (#2768)
## Motivation

Add Ascend NPU device support to mmseg.

## Modification

Register the NPU device in both the DP and DDP code paths: `build_dp`/`build_ddp` gain an `npu` branch, `get_device` can now report `npu`, and `train_segmentor` enables a dynamic-loss-scale FP16 optimizer hook and NPU-aware checkpoint resume when `cfg.device == 'npu'`.

## BC-breaking (Optional)

None.

## Use cases (Optional)

Tested with [fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes.py](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/unet/fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes.py).
Commit `ae78cb9d53` (parent `49f2a71953`)
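As a quick usage sketch (not part of this commit, and assuming `get_device` and `build_dp` are exported from `mmseg.utils` as in recent 0.x releases), the non-distributed DP path now picks up the NPU automatically; `ToySegHead` below is a hypothetical stand-in for a real segmentor:

```python
# Minimal DP sketch; ToySegHead is a hypothetical toy module, used only to
# keep the example self-contained.
import torch
import torch.nn as nn

from mmseg.utils import build_dp, get_device


class ToySegHead(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 2, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)


device = get_device()  # 'npu' when torch_npu is available, else cuda/mlu/cpu
model = build_dp(ToySegHead(), device, dim=0)
out = model(torch.rand(1, 3, 64, 64))
print(device, out.shape)
```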
@@ -136,6 +136,11 @@ def train_segmentor(model,
             logger=logger,
             meta=meta))
 
+    if cfg.device == 'npu':
+        optimizer_config = dict(type='Fp16OptimizerHook', loss_scale='dynamic')
+        cfg.optimizer_config = optimizer_config if \
+            not cfg.optimizer_config else cfg.optimizer_config
+
     # register hooks
     runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
                                    cfg.checkpoint_config, cfg.log_config,
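The branch above only fills `cfg.optimizer_config` when the user left it empty; what gets injected is the standard MMCV FP16 hook with dynamic loss scaling, i.e. the same as writing this explicitly in a config (sketch):

```python
# Explicit equivalent of the hook the NPU branch injects by default.
optimizer_config = dict(type='Fp16OptimizerHook', loss_scale='dynamic')
```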
@@ -187,7 +192,11 @@ def train_segmentor(model,
         resume_from = find_latest_checkpoint(cfg.work_dir)
         if resume_from is not None:
             cfg.resume_from = resume_from
     if cfg.resume_from:
-        runner.resume(cfg.resume_from)
+        if cfg.device == 'npu':
+            runner.resume(cfg.resume_from, map_location='npu')
+        else:
+            runner.resume(cfg.resume_from)
     elif cfg.load_from:
         runner.load_checkpoint(cfg.load_from)
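Outside the runner, `map_location='npu'` simply remaps checkpoint tensors onto the NPU at load time. A rough sketch, assuming `torch_npu` is installed (so `'npu'` is a valid device string) and using a placeholder checkpoint path:

```python
# Sketch: load a checkpoint straight onto the NPU, mirroring what
# runner.resume(..., map_location='npu') does internally. Placeholder path.
import torch

ckpt = torch.load('work_dirs/example/latest.pth', map_location='npu')
```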
@@ -33,6 +33,14 @@ def build_dp(model, device='cuda', dim=0, *args, **kwargs):
         dp_factory['mlu'] = MLUDataParallel
         model = model.mlu()
 
+    elif device == 'npu':
+        assert digit_version(mmcv.__version__) >= digit_version('1.7.0'), \
+            'Please use MMCV >= 1.7.0 for NPU training!'
+        from mmcv.device.npu import NPUDataParallel
+        torch.npu.set_compile_mode(jit_compile=False)
+        dp_factory['npu'] = NPUDataParallel
+        model = model.npu()
+
     return dp_factory[device](model, dim=dim, *args, **kwargs)
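A caller-side sketch of the same precondition `build_dp` asserts (NPU needs MMCV >= 1.7.0 and an available `torch.npu` backend); the helper name is illustrative, not part of the PR:

```python
# Fail-soft device pick with the same preconditions build_dp enforces for NPU.
import mmcv
import torch
from mmcv.utils import digit_version


def pick_dp_device():
    npu_ok = (hasattr(torch, 'npu') and torch.npu.is_available()
              and digit_version(mmcv.__version__) >= digit_version('1.7.0'))
    if npu_ok:
        return 'npu'
    return 'cuda' if torch.cuda.is_available() else 'cpu'


print(pick_dp_device())
```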
@@ -53,7 +61,8 @@ def build_ddp(model, device='cuda', *args, **kwargs):
     .. [1] https://pytorch.org/docs/stable/generated/torch.nn.parallel.
              DistributedDataParallel.html
     """
-    assert device in ['cuda', 'mlu'], 'Only available for cuda or mlu devices.'
+    assert device in ['cuda', 'mlu', 'npu'], 'Only available for cuda, '\
+        'npu or mlu devices.'
     if device == 'cuda':
         model = model.cuda()
     elif device == 'mlu':
@@ -63,6 +72,14 @@ def build_ddp(model, device='cuda', *args, **kwargs):
         ddp_factory['mlu'] = MLUDistributedDataParallel
         model = model.mlu()
 
+    elif device == 'npu':
+        assert digit_version(mmcv.__version__) >= digit_version('1.7.0'), \
+            'Please use MMCV >= 1.7.0 for NPU training!'
+        from mmcv.device.npu import NPUDistributedDataParallel
+        torch.npu.set_compile_mode(jit_compile=False)
+        ddp_factory['npu'] = NPUDistributedDataParallel
+        model = model.npu()
+
     return ddp_factory[device](model, *args, **kwargs)
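For the distributed path, a rough sketch of how a launch script might wrap a model once the default process group is initialized; the helper name and the LOCAL_RANK handling are illustrative assumptions, not part of the PR:

```python
# Sketch: wrap a model with build_ddp on whichever device get_device() reports.
# Assumes the launcher already initialized the default process group and set
# LOCAL_RANK; only the cuda/npu cases are covered here.
import os

import torch

from mmseg.utils import build_ddp, get_device


def wrap_for_ddp(model):
    device = get_device()
    assert device in ('cuda', 'npu'), 'sketch only covers cuda and npu'
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    if device == 'npu':
        torch.npu.set_device(local_rank)  # provided by torch_npu
    else:
        torch.cuda.set_device(local_rank)
    return build_ddp(
        model, device, device_ids=[local_rank], broadcast_buffers=False)
```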
@@ -71,11 +88,17 @@ def is_mlu_available():
     return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available()
 
 
+def is_npu_available():
+    """Returns a bool indicating if NPU is currently available."""
+    return hasattr(torch, 'npu') and torch.npu.is_available()
+
+
 def get_device():
-    """Returns an available device, cpu, cuda or mlu."""
+    """Returns an available device, cpu, npu, cuda or mlu."""
     is_device_available = {
+        'npu': is_npu_available(),
         'cuda': torch.cuda.is_available(),
         'mlu': is_mlu_available()
     }
     device_list = [k for k, v in is_device_available.items() if v]
-    return device_list[0] if len(device_list) == 1 else 'cpu'
+    return device_list[0] if len(device_list) >= 1 else 'cpu'
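The switch from `== 1` to `>= 1` also gives a deterministic priority when several backends are visible: dict insertion order makes `npu` win over `cuda`, which wins over `mlu`. A small sketch of that selection logic with hypothetical availability flags:

```python
# Hypothetical availability flags; mirrors the selection logic in get_device().
is_device_available = {'npu': True, 'cuda': True, 'mlu': False}
device_list = [k for k, v in is_device_available.items() if v]
device = device_list[0] if len(device_list) >= 1 else 'cpu'
assert device == 'npu'  # npu is preferred when more than one backend is present
```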
@@ -46,6 +46,13 @@ def test_build_dp():
         mludp = build_dp(model, 'mlu')
         assert isinstance(mludp, MLUDataParallel)
 
+    if digit_version(mmcv.__version__) >= digit_version('1.7.0'):
+        from mmcv.device.npu import NPUDataParallel
+        from mmcv.utils import IS_NPU_AVAILABLE
+        if IS_NPU_AVAILABLE:
+            npu_dp = build_dp(model, 'npu')
+            assert isinstance(npu_dp, NPUDataParallel)
+
 
 @patch('torch.distributed._broadcast_coalesced', mock)
 @patch('torch.distributed.broadcast', mock)
@@ -66,3 +73,11 @@ def test_build_ddp():
         mluddp = build_ddp(
             model, 'mlu', device_ids=[0], process_group=MagicMock())
         assert isinstance(mluddp, MLUDistributedDataParallel)
+
+    if digit_version(mmcv.__version__) >= digit_version('1.7.0'):
+        from mmcv.device.npu import NPUDistributedDataParallel
+        from mmcv.utils import IS_NPU_AVAILABLE
+        if IS_NPU_AVAILABLE:
+            npu_ddp = build_ddp(
+                model, 'npu', device_ids=[0], process_group=MagicMock())
+            assert isinstance(npu_ddp, NPUDistributedDataParallel)
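The new test branches stay inert unless MMCV >= 1.7.0 reports an available NPU. To run just this module locally, something like the following works; the test file path is an assumption based on the function names, it is not stated in the diff:

```python
# Sketch: run the distribution-utility tests in-process (path is assumed).
import pytest

raise SystemExit(pytest.main(['tests/test_utils/test_util_distribution.py', '-v']))
```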