[Feature] Support mmcls with NPU backend. (#1072)

* init npu

* Avoid importing the latest MMCV code to stay compatible with old versions.

Co-authored-by: mzr1996 <mzr1996@163.com>
wangjiangben-hw 2022-10-24 11:45:14 +08:00 committed by GitHub
parent 38040d5e05
commit 17ed870fd1
3 changed files with 19 additions and 5 deletions


@@ -131,7 +131,6 @@ def train_model(model,
         model = wrap_distributed_model(
             model,
             cfg.device,
-            device_ids=[torch.cuda.current_device()],
             broadcast_buffers=False,
             find_unused_parameters=find_unused_parameters)
     else:
@@ -173,6 +172,10 @@ def train_model(model,
     # fp16 setting
     fp16_cfg = cfg.get('fp16', None)
+    if fp16_cfg is None and device == 'npu':
+        fp16_cfg = {'loss_scale': 'dynamic'}
     if fp16_cfg is not None:
         if device == 'ipu':
             from mmcv.device.ipu import IPUFp16OptimizerHook
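
The fp16 default added above feeds the existing mixed-precision path in train_model. The snippet below is only an illustrative sketch (the build_optimizer_config helper is made up for the example, not mmcls API): it shows how a missing fp16 config on an NPU run would typically end up as an mmcv Fp16OptimizerHook with a dynamic loss scale.

# Illustrative sketch; `build_optimizer_config` is a hypothetical helper,
# not part of mmcls. It mirrors the defaulting logic added in the hunk above.
from mmcv.runner import Fp16OptimizerHook


def build_optimizer_config(cfg, device, distributed=False):
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is None and device == 'npu':
        # Same default the patch introduces: NPU runs train in fp16
        # with a dynamically adjusted loss scale.
        fp16_cfg = {'loss_scale': 'dynamic'}
    if fp16_cfg is not None:
        return Fp16OptimizerHook(
            **cfg.get('optimizer_config', {}),
            **fp16_cfg,
            distributed=distributed)
    return cfg.get('optimizer_config', {})


# No explicit fp16 config + NPU device -> fp16 hook with dynamic loss scaling.
hook = build_optimizer_config({'optimizer_config': {'grad_clip': None}},
                              device='npu')
print(type(hook).__name__)  # Fp16OptimizerHook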


@@ -4,6 +4,7 @@ from torch.utils.data import DistributedSampler as _DistributedSampler
 from mmcls.core.utils import sync_random_seed
 from mmcls.datasets import SAMPLERS
+from mmcls.utils import auto_select_device

 @SAMPLERS.register_module()
@@ -30,7 +31,7 @@ class DistributedSampler(_DistributedSampler):
         # in the same order based on the same seed. Then different ranks
         # could use different indices to select non-overlapped data from the
         # same data list.
-        self.seed = sync_random_seed(seed)
+        self.seed = sync_random_seed(seed, device=auto_select_device())

     def __iter__(self):
         # deterministically shuffle based on epoch
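
The extra device argument matters because sync_random_seed shares rank 0's seed by broadcasting it as a tensor, and that tensor has to live on the backend the process group communicates over (an NPU tensor under HCCL rather than a CUDA tensor under NCCL). A simplified sketch of that logic, assuming mmcls follows the usual mmcv-style implementation:

# Simplified sketch of what sync_random_seed is assumed to do; with
# device=auto_select_device() the broadcast tensor is allocated on 'npu'
# when an Ascend device is present.
import numpy as np
import torch
import torch.distributed as dist


def sync_random_seed_sketch(seed=None, device='cuda'):
    if seed is None:
        seed = np.random.randint(2**31)
    if not (dist.is_available() and dist.is_initialized()):
        # Single-process run: nothing to synchronize.
        return seed
    if dist.get_rank() == 0:
        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
    else:
        random_num = torch.tensor(0, dtype=torch.int32, device=device)
    # Every rank ends up with rank 0's seed, so all DistributedSampler
    # replicas shuffle in the same order.
    dist.broadcast(random_num, src=0)
    return random_num.item()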


@@ -16,7 +16,10 @@ def wrap_non_distributed_model(model, device='cuda', dim=0, *args, **kwargs):
     Returns:
         model(nn.Module): the model to be parallelized.
     """
-    if device == 'cuda':
+    if device == 'npu':
+        from mmcv.device.npu import NPUDataParallel
+        model = NPUDataParallel(model.npu(), dim=dim, *args, **kwargs)
+    elif device == 'cuda':
         from mmcv.parallel import MMDataParallel
         model = MMDataParallel(model.cuda(), dim=dim, *args, **kwargs)
     elif device == 'cpu':
@@ -49,9 +52,16 @@ def wrap_distributed_model(model, device='cuda', *args, **kwargs):
     .. [1] https://pytorch.org/docs/stable/generated/torch.nn.parallel.
        DistributedDataParallel.html
     """
-    if device == 'cuda':
+    if device == 'npu':
+        from mmcv.device.npu import NPUDistributedDataParallel
+        from torch.npu import current_device
+        model = NPUDistributedDataParallel(
+            model.npu(), *args, device_ids=[current_device()], **kwargs)
+    elif device == 'cuda':
         from mmcv.parallel import MMDistributedDataParallel
-        model = MMDistributedDataParallel(model.cuda(), *args, **kwargs)
+        from torch.cuda import current_device
+        model = MMDistributedDataParallel(
+            model.cuda(), *args, device_ids=[current_device()], **kwargs)
     else:
         raise RuntimeError(f'Unavailable device "{device}"')
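
Taken together, both wrappers now dispatch on the device string, and the distributed wrapper fills in device_ids itself, which is why the explicit torch.cuda.current_device() argument could be dropped from train_model in the first hunk. A hypothetical usage sketch, assuming the helpers are exported from mmcls.utils as in the training code and that torch_npu with an Ascend device is available (the toy module stands in for a real classifier):

# Usage sketch; requires torch_npu and an Ascend NPU at runtime.
import torch.nn as nn
from mmcls.utils import wrap_distributed_model, wrap_non_distributed_model

model = nn.Linear(8, 2)

# Single-device NPU training: model.npu() wrapped in NPUDataParallel.
model_sp = wrap_non_distributed_model(model, device='npu', dim=0)

# Distributed NPU training: NPUDistributedDataParallel with
# device_ids=[torch.npu.current_device()] supplied by the wrapper,
# so call sites no longer pass device_ids themselves.
model_ddp = wrap_distributed_model(model, device='npu', broadcast_buffers=False)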