[Feature] Support mmcls with NPU backend. (#1072)
* init npu

* Avoid importing the latest MMCV code to stay compatible with old versions.

Co-authored-by: mzr1996 <mzr1996@163.com>

parent 38040d5e05
commit 17ed870fd1
@@ -131,7 +131,6 @@ def train_model(model,
         model = wrap_distributed_model(
             model,
             cfg.device,
-            device_ids=[torch.cuda.current_device()],
             broadcast_buffers=False,
             find_unused_parameters=find_unused_parameters)
     else:

@@ -173,6 +172,10 @@ def train_model(model,
 
     # fp16 setting
     fp16_cfg = cfg.get('fp16', None)
+
+    if fp16_cfg is None and device == 'npu':
+        fp16_cfg = {'loss_scale': 'dynamic'}
+
     if fp16_cfg is not None:
         if device == 'ipu':
             from mmcv.device.ipu import IPUFp16OptimizerHook
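
With this change, training on NPU defaults to mixed precision with dynamic loss scaling whenever the config leaves fp16 unset. A minimal sketch of what that default amounts to, assuming mmcv.runner's Fp16OptimizerHook; the grad_clip/distributed arguments below are illustrative placeholders, not the exact wiring inside train_model:

from mmcv.runner import Fp16OptimizerHook

# Equivalent explicit entry a user could have written in the config file:
#   fp16 = dict(loss_scale='dynamic')
fp16_cfg = {'loss_scale': 'dynamic'}

# train_model then builds its FP16 optimizer hook from this config;
# simplified construction for illustration only:
optimizer_config = Fp16OptimizerHook(
    grad_clip=None, loss_scale=fp16_cfg['loss_scale'], distributed=True)
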
@@ -4,6 +4,7 @@ from torch.utils.data import DistributedSampler as _DistributedSampler
 
 from mmcls.core.utils import sync_random_seed
 from mmcls.datasets import SAMPLERS
+from mmcls.utils import auto_select_device
 
 
 @SAMPLERS.register_module()

@@ -30,7 +31,7 @@ class DistributedSampler(_DistributedSampler):
         # in the same order based on the same seed. Then different ranks
         # could use different indices to select non-overlapped data from the
         # same data list.
-        self.seed = sync_random_seed(seed)
+        self.seed = sync_random_seed(seed, device=auto_select_device())
 
     def __iter__(self):
         # deterministically shuffle based on epoch
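
Passing device=auto_select_device() matters because sync_random_seed broadcasts the seed chosen on rank 0 across the process group, and the broadcast tensor has to live on the active backend's device (NPU, CUDA or CPU) rather than being hard-coded to CUDA. A simplified sketch of that behaviour, not the exact mmcls implementation:

import numpy as np
import torch
import torch.distributed as dist


def sync_random_seed_sketch(seed=None, device='cuda'):
    """Rank 0 draws a seed and broadcasts it so that every rank shuffles
    with the same value; `device` (e.g. 'npu' from auto_select_device())
    decides where the broadcast tensor is allocated."""
    if seed is None:
        seed = np.random.randint(2**31)
    if not (dist.is_available() and dist.is_initialized()):
        return seed
    random_num = torch.tensor(seed, dtype=torch.int32, device=device)
    dist.broadcast(random_num, src=0)
    return random_num.item()
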
@@ -16,7 +16,10 @@ def wrap_non_distributed_model(model, device='cuda', dim=0, *args, **kwargs):
     Returns:
         model(nn.Module): the model to be parallelized.
     """
-    if device == 'cuda':
+    if device == 'npu':
+        from mmcv.device.npu import NPUDataParallel
+        model = NPUDataParallel(model.npu(), dim=dim, *args, **kwargs)
+    elif device == 'cuda':
         from mmcv.parallel import MMDataParallel
         model = MMDataParallel(model.cuda(), dim=dim, *args, **kwargs)
     elif device == 'cpu':
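
A hypothetical single-device usage sketch, assuming torch_npu is installed, the mmcv build ships mmcv.device.npu, and wrap_non_distributed_model is exported from mmcls.utils as in the 0.x codebase:

import torch.nn as nn

from mmcls.utils import wrap_non_distributed_model

model = nn.Linear(8, 2)                                  # stand-in for a classifier
model = wrap_non_distributed_model(model, device='npu')  # returns an NPUDataParallel wrapper
# device='cuda' keeps the previous behaviour and returns an MMDataParallel wrapper.
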
@@ -49,9 +52,16 @@ def wrap_distributed_model(model, device='cuda', *args, **kwargs):
     .. [1] https://pytorch.org/docs/stable/generated/torch.nn.parallel.
        DistributedDataParallel.html
     """
-    if device == 'cuda':
+    if device == 'npu':
+        from mmcv.device.npu import NPUDistributedDataParallel
+        from torch.npu import current_device
+        model = NPUDistributedDataParallel(
+            model.npu(), *args, device_ids=[current_device()], **kwargs)
+    elif device == 'cuda':
         from mmcv.parallel import MMDistributedDataParallel
-        model = MMDistributedDataParallel(model.cuda(), *args, **kwargs)
+        from torch.cuda import current_device
+        model = MMDistributedDataParallel(
+            model.cuda(), *args, device_ids=[current_device()], **kwargs)
     else:
         raise RuntimeError(f'Unavailable device "{device}"')
 
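
For the distributed path, a hypothetical launch sketch; it assumes torch_npu is installed (registering the torch.npu namespace), the process group uses the HCCL backend (Ascend's counterpart to NCCL), and wrap_distributed_model is exported from mmcls.utils:

import torch
import torch_npu  # noqa: F401  # assumption: Ascend adapter that provides torch.npu
import torch.distributed as dist
import torch.nn as nn

from mmcls.utils import wrap_distributed_model

# One process per NPU, launched e.g. with torch.distributed.launch.
dist.init_process_group(backend='hccl')
torch.npu.set_device(dist.get_rank() % torch.npu.device_count())

model = wrap_distributed_model(
    nn.Linear(8, 2),          # stand-in for a classifier
    device='npu',
    broadcast_buffers=False)
# Both branches now fill in device_ids from current_device(), which is why
# the caller in train_model no longer passes device_ids itself.
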