[Feature] Support MLU backend. (#1159)

* Training on MLU is available
pull/1215/head
TangYueran 2022-11-15 17:02:16 +08:00 committed by GitHub
parent 05e4bc17b2
commit dc8691e889
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 23 additions and 7 deletions

View File

@ -10,11 +10,11 @@ from mmcv.runner import (DistSamplerSeedHook, Fp16OptimizerHook,
from mmcls.core import DistEvalHook, DistOptimizerHook, EvalHook
from mmcls.datasets import build_dataloader, build_dataset
from mmcls.utils import (get_root_logger, wrap_distributed_model,
wrap_non_distributed_model)
from mmcls.utils import (auto_select_device, get_root_logger,
wrap_distributed_model, wrap_non_distributed_model)
def init_random_seed(seed=None, device='cuda'):
def init_random_seed(seed=None, device=None):
"""Initialize random seed.
If the seed is not set, the seed will be automatically randomized,
@ -30,7 +30,8 @@ def init_random_seed(seed=None, device='cuda'):
"""
if seed is not None:
return seed
if device is None:
device = auto_select_device()
# Make sure all ranks share the same random seed to prevent
# some potential bugs. Please refer to
# https://github.com/open-mmlab/mmdetection/issues/6339

View File

@ -8,6 +8,8 @@ from mmcv.runner import OptimizerHook, get_dist_info
from torch._utils import (_flatten_dense_tensors, _take_tensors,
_unflatten_dense_tensors)
from mmcls.utils import auto_select_device
def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
if bucket_size_mb > 0:
@ -59,7 +61,7 @@ class DistOptimizerHook(OptimizerHook):
runner.optimizer.step()
def sync_random_seed(seed=None, device='cuda'):
def sync_random_seed(seed=None, device=None):
"""Make sure different ranks share the same seed.
All workers must call this function, otherwise it will deadlock.
@ -81,6 +83,8 @@ def sync_random_seed(seed=None, device='cuda'):
Returns:
int: Seed to be used.
"""
if device is None:
device = auto_select_device()
if seed is None:
seed = np.random.randint(2**31)
assert isinstance(seed, int)

View File

@ -4,7 +4,6 @@ from torch.utils.data import DistributedSampler as _DistributedSampler
from mmcls.core.utils import sync_random_seed
from mmcls.datasets import SAMPLERS
from mmcls.utils import auto_select_device
@SAMPLERS.register_module()
@ -31,7 +30,7 @@ class DistributedSampler(_DistributedSampler):
# in the same order based on the same seed. Then different ranks
# could use different indices to select non-overlapped data from the
# same data list.
self.seed = sync_random_seed(seed, device=auto_select_device())
self.seed = sync_random_seed(seed)
def __iter__(self):
# deterministically shuffle based on epoch

View File

@ -19,6 +19,9 @@ def wrap_non_distributed_model(model, device='cuda', dim=0, *args, **kwargs):
if device == 'npu':
from mmcv.device.npu import NPUDataParallel
model = NPUDataParallel(model.npu(), dim=dim, *args, **kwargs)
elif device == 'mlu':
from mmcv.device.mlu import MLUDataParallel
model = MLUDataParallel(model.mlu(), dim=dim, *args, **kwargs)
elif device == 'cuda':
from mmcv.parallel import MMDataParallel
model = MMDataParallel(model.cuda(), dim=dim, *args, **kwargs)
@ -57,6 +60,15 @@ def wrap_distributed_model(model, device='cuda', *args, **kwargs):
from torch.npu import current_device
model = NPUDistributedDataParallel(
model.npu(), *args, device_ids=[current_device()], **kwargs)
elif device == 'mlu':
import os
from mmcv.device.mlu import MLUDistributedDataParallel
model = MLUDistributedDataParallel(
model.mlu(),
*args,
device_ids=[int(os.environ['LOCAL_RANK'])],
**kwargs)
elif device == 'cuda':
from mmcv.parallel import MMDistributedDataParallel
from torch.cuda import current_device