# [Feature] Support mmseg with NPU backend. (#2768)
## Motivation

Added Ascend NPU device support in mmseg.

## Modification

The main modification points are as follows: we add NPU device support in both the DP and DDP scenarios when training on an NPU.

## BC-breaking (Optional)

None

## Use cases (Optional)

We tested [fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes.py](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/unet/fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes.py).
Parent: 49f2a71953
Commit: ae78cb9d53
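To make the diff below easier to follow, here is a rough sketch of how the new device plumbing fits together. Only `get_device`, `build_dp`, and `build_ddp` come from this change (assumed to be exported from `mmseg.utils` as in the 0.x series); the `wrap_model` helper, the `distributed` flag, `cfg.gpu_ids`, and the `LOCAL_RANK` handling are illustrative stand-ins for the call sites in `tools/train.py`, not the verbatim mmseg code path.

```python
# Hedged sketch of how the new utilities are meant to be used together.
import os

from mmseg.utils import build_ddp, build_dp, get_device


def wrap_model(model, cfg, distributed):
    """Wrap a segmentor for DP or DDP training on whatever device is found."""
    cfg.device = get_device()  # returns 'npu', 'cuda', 'mlu' or 'cpu'
    if distributed:
        # DDP scenario: NPUDistributedDataParallel is selected for 'npu'
        return build_ddp(
            model,
            cfg.device,
            device_ids=[int(os.environ['LOCAL_RANK'])],
            broadcast_buffers=False)
    # DP scenario: NPUDataParallel is selected for 'npu'
    return build_dp(model, cfg.device, device_ids=cfg.gpu_ids)
```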
```diff
@@ -136,6 +136,11 @@ def train_segmentor(model,
             logger=logger,
             meta=meta))
 
+    if cfg.device == 'npu':
+        optimiter_config = dict(type='Fp16OptimizerHook', loss_scale='dynamic')
+        cfg.optimizer_config = optimiter_config if \
+            not cfg.optimizer_config else cfg.optimizer_config
+
     # register hooks
     runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
                                    cfg.checkpoint_config, cfg.log_config,
```
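The hunk above defaults NPU runs to mixed precision: when `cfg.device == 'npu'` and the config does not already define `optimizer_config`, a dynamic-loss-scale `Fp16OptimizerHook` is injected. Writing the equivalent setting out by hand in a config would look roughly like the snippet below; the hook type and `loss_scale` value are taken from the diff, while the `optimizer` line is ordinary mmseg config boilerplate added only for context.

```python
# Hedged config sketch: the explicit form of the hook injected on NPU.
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict(type='Fp16OptimizerHook', loss_scale='dynamic')
```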
```diff
@@ -187,8 +192,12 @@ def train_segmentor(model,
         resume_from = find_latest_checkpoint(cfg.work_dir)
         if resume_from is not None:
             cfg.resume_from = resume_from
 
     if cfg.resume_from:
-        runner.resume(cfg.resume_from)
+        if cfg.device == 'npu':
+            runner.resume(cfg.resume_from, map_location='npu')
+        else:
+            runner.resume(cfg.resume_from)
     elif cfg.load_from:
         runner.load_checkpoint(cfg.load_from)
     runner.run(data_loaders, cfg.workflow)
```
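On NPU, `runner.resume` is called with `map_location='npu'` so that checkpoint tensors saved from CPU or CUDA processes are remapped onto the Ascend device; MMCV forwards this argument down to `torch.load`. A minimal hedged sketch of the same remapping done by hand follows; the checkpoint path is a placeholder, and the `torch_npu` adapter is assumed to be installed.

```python
# Hedged sketch: manually re-mapping a checkpoint onto the NPU, which is what
# map_location='npu' asks the runner to do.
import torch
import torch_npu  # noqa: F401  (registers the 'npu' device type with PyTorch)

checkpoint = torch.load('latest.pth', map_location='npu')  # placeholder path
```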
```diff
@@ -33,6 +33,14 @@ def build_dp(model, device='cuda', dim=0, *args, **kwargs):
         dp_factory['mlu'] = MLUDataParallel
         model = model.mlu()
 
+    elif device == 'npu':
+        assert digit_version(mmcv.__version__) >= digit_version('1.7.0'), \
+            'Please use MMCV >= 1.7.0 for NPU training!'
+        from mmcv.device.npu import NPUDataParallel
+        torch.npu.set_compile_mode(jit_compile=False)
+        dp_factory['npu'] = NPUDataParallel
+        model = model.npu()
+
     return dp_factory[device](model, dim=dim, *args, **kwargs)
 
 
```
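With the new branch, requesting an `'npu'` device from `build_dp` moves the model to the NPU and wraps it in MMCV's `NPUDataParallel`. A minimal hedged usage sketch: any `nn.Module` stands in for a real segmentor, and an Ascend machine with MMCV >= 1.7.0 and `torch_npu` is assumed.

```python
# Hedged usage sketch for the DP scenario on a single Ascend NPU.
import torch.nn as nn

from mmseg.utils import build_dp

model = nn.Conv2d(3, 8, kernel_size=3)       # stand-in for a segmentor
npu_model = build_dp(model, 'npu', device_ids=[0])
# npu_model is an mmcv NPUDataParallel wrapper; its parameters now live on NPU 0
```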
```diff
@@ -53,7 +61,8 @@ def build_ddp(model, device='cuda', *args, **kwargs):
     .. [1] https://pytorch.org/docs/stable/generated/torch.nn.parallel.
              DistributedDataParallel.html
     """
-    assert device in ['cuda', 'mlu'], 'Only available for cuda or mlu devices.'
+    assert device in ['cuda', 'mlu', 'npu'], 'Only available for cuda, '\
+        'npu or mlu devices.'
     if device == 'cuda':
         model = model.cuda()
     elif device == 'mlu':
```
```diff
@@ -63,6 +72,14 @@ def build_ddp(model, device='cuda', *args, **kwargs):
         ddp_factory['mlu'] = MLUDistributedDataParallel
         model = model.mlu()
 
+    elif device == 'npu':
+        assert digit_version(mmcv.__version__) >= digit_version('1.7.0'), \
+            'Please use MMCV >= 1.7.0 for NPU training!'
+        from mmcv.device.npu import NPUDistributedDataParallel
+        torch.npu.set_compile_mode(jit_compile=False)
+        ddp_factory['npu'] = NPUDistributedDataParallel
+        model = model.npu()
+
     return ddp_factory[device](model, *args, **kwargs)
 
 
```
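The DDP branch mirrors the DP one but returns `NPUDistributedDataParallel`, so it needs an initialized process group. A hedged single-process sketch follows; the `hccl` backend name and the master address/port environment variables are assumptions about a typical Ascend setup, not part of this change.

```python
# Hedged sketch of the DDP scenario with world_size=1; assumes torch_npu and
# the Ascend 'hccl' collective backend are available.
import os

import torch.distributed as dist
import torch.nn as nn

from mmseg.utils import build_ddp

os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group(backend='hccl', rank=0, world_size=1)

model = nn.Conv2d(3, 8, kernel_size=3)
ddp_model = build_ddp(model, 'npu', device_ids=[0], broadcast_buffers=False)
```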
```diff
@@ -71,11 +88,17 @@ def is_mlu_available():
     return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available()
 
 
+def is_npu_available():
+    """Returns a bool indicating if NPU is currently available."""
+    return hasattr(torch, 'npu') and torch.npu.is_available()
+
+
 def get_device():
-    """Returns an available device, cpu, cuda or mlu."""
+    """Returns an available device, cpu, npu, cuda or mlu."""
     is_device_available = {
+        'npu': is_npu_available(),
         'cuda': torch.cuda.is_available(),
         'mlu': is_mlu_available()
     }
     device_list = [k for k, v in is_device_available.items() if v]
-    return device_list[0] if len(device_list) == 1 else 'cpu'
+    return device_list[0] if len(device_list) >= 1 else 'cpu'
```
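Because Python dicts keep insertion order, `get_device` now prefers `npu` over `cuda` over `mlu`, and the relaxed `>= 1` check returns the first available backend even when several are present, falling back to `cpu` only when none is found. A standalone re-implementation of just that selection rule (not the mmseg function itself) makes the behaviour concrete:

```python
# Standalone illustration of the selection rule used by get_device(): the dict
# is ordered, so the first backend that reports itself available wins.
def pick_device(npu=False, cuda=False, mlu=False):
    is_device_available = {'npu': npu, 'cuda': cuda, 'mlu': mlu}
    device_list = [k for k, v in is_device_available.items() if v]
    return device_list[0] if len(device_list) >= 1 else 'cpu'


assert pick_device(npu=True, cuda=True) == 'npu'  # NPU outranks CUDA
assert pick_device(cuda=True, mlu=True) == 'cuda'
assert pick_device() == 'cpu'                     # nothing available
```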
```diff
@@ -46,6 +46,13 @@ def test_build_dp():
         mludp = build_dp(model, 'mlu')
         assert isinstance(mludp, MLUDataParallel)
 
+    if digit_version(mmcv.__version__) >= digit_version('1.7.0'):
+        from mmcv.device.npu import NPUDataParallel
+        from mmcv.utils import IS_NPU_AVAILABLE
+        if IS_NPU_AVAILABLE:
+            npu_dp = build_dp(model, 'npu')
+            assert isinstance(npu_dp, NPUDataParallel)
+
 
 @patch('torch.distributed._broadcast_coalesced', mock)
 @patch('torch.distributed.broadcast', mock)
```
```diff
@@ -66,3 +73,11 @@ def test_build_ddp():
         mluddp = build_ddp(
             model, 'mlu', device_ids=[0], process_group=MagicMock())
         assert isinstance(mluddp, MLUDistributedDataParallel)
+
+    if digit_version(mmcv.__version__) >= digit_version('1.7.0'):
+        from mmcv.device.npu import NPUDistributedDataParallel
+        from mmcv.utils import IS_NPU_AVAILABLE
+        if IS_NPU_AVAILABLE:
+            npu_ddp = build_ddp(
+                model, 'npu', device_ids=[0], process_group=MagicMock())
+            assert isinstance(npu_ddp, NPUDistributedDataParallel)
```