From dc8691e889843e5f81ac9000757e9dfa887c9ce1 Mon Sep 17 00:00:00 2001
From: TangYueran <65158769+Qiza-lyhm@users.noreply.github.com>
Date: Tue, 15 Nov 2022 17:02:16 +0800
Subject: [PATCH] [Feature] Support MLU backend. (#1159)

* Training on MLU is available
---
 mmcls/apis/train.py                            |  9 +++++----
 mmcls/core/utils/dist_utils.py                 |  6 +++++-
 mmcls/datasets/samplers/distributed_sampler.py |  3 +--
 mmcls/utils/distribution.py                    | 12 ++++++++++++
 4 files changed, 23 insertions(+), 7 deletions(-)

diff --git a/mmcls/apis/train.py b/mmcls/apis/train.py
index 909b116d..c240632c 100644
--- a/mmcls/apis/train.py
+++ b/mmcls/apis/train.py
@@ -10,11 +10,11 @@ from mmcv.runner import (DistSamplerSeedHook, Fp16OptimizerHook,
 
 from mmcls.core import DistEvalHook, DistOptimizerHook, EvalHook
 from mmcls.datasets import build_dataloader, build_dataset
-from mmcls.utils import (get_root_logger, wrap_distributed_model,
-                         wrap_non_distributed_model)
+from mmcls.utils import (auto_select_device, get_root_logger,
+                         wrap_distributed_model, wrap_non_distributed_model)
 
 
-def init_random_seed(seed=None, device='cuda'):
+def init_random_seed(seed=None, device=None):
     """Initialize random seed.
 
     If the seed is not set, the seed will be automatically randomized,
@@ -30,7 +30,8 @@ def init_random_seed(seed=None, device='cuda'):
     """
     if seed is not None:
         return seed
-
+    if device is None:
+        device = auto_select_device()
     # Make sure all ranks share the same random seed to prevent
     # some potential bugs. Please refer to
     # https://github.com/open-mmlab/mmdetection/issues/6339
diff --git a/mmcls/core/utils/dist_utils.py b/mmcls/core/utils/dist_utils.py
index 8912cea4..15cf13c4 100644
--- a/mmcls/core/utils/dist_utils.py
+++ b/mmcls/core/utils/dist_utils.py
@@ -8,6 +8,8 @@ from mmcv.runner import OptimizerHook, get_dist_info
 from torch._utils import (_flatten_dense_tensors, _take_tensors,
                           _unflatten_dense_tensors)
 
+from mmcls.utils import auto_select_device
+
 
 def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
     if bucket_size_mb > 0:
@@ -59,7 +61,7 @@ class DistOptimizerHook(OptimizerHook):
         runner.optimizer.step()
 
 
-def sync_random_seed(seed=None, device='cuda'):
+def sync_random_seed(seed=None, device=None):
     """Make sure different ranks share the same seed.
 
     All workers must call this function, otherwise it will deadlock.
@@ -81,6 +83,8 @@ def sync_random_seed(seed=None, device='cuda'):
     Returns:
         int: Seed to be used.
     """
+    if device is None:
+        device = auto_select_device()
     if seed is None:
         seed = np.random.randint(2**31)
     assert isinstance(seed, int)
diff --git a/mmcls/datasets/samplers/distributed_sampler.py b/mmcls/datasets/samplers/distributed_sampler.py
index 9e78c400..a38c5ac1 100644
--- a/mmcls/datasets/samplers/distributed_sampler.py
+++ b/mmcls/datasets/samplers/distributed_sampler.py
@@ -4,7 +4,6 @@ from torch.utils.data import DistributedSampler as _DistributedSampler
 
 from mmcls.core.utils import sync_random_seed
 from mmcls.datasets import SAMPLERS
-from mmcls.utils import auto_select_device
 
 
 @SAMPLERS.register_module()
@@ -31,7 +30,7 @@ class DistributedSampler(_DistributedSampler):
         # in the same order based on the same seed. Then different ranks
         # could use different indices to select non-overlapped data from the
         # same data list.
-        self.seed = sync_random_seed(seed, device=auto_select_device())
+        self.seed = sync_random_seed(seed)
 
     def __iter__(self):
         # deterministically shuffle based on epoch
diff --git a/mmcls/utils/distribution.py b/mmcls/utils/distribution.py
index d57bd2b5..c6e4c724 100644
--- a/mmcls/utils/distribution.py
+++ b/mmcls/utils/distribution.py
@@ -19,6 +19,9 @@ def wrap_non_distributed_model(model, device='cuda', dim=0, *args, **kwargs):
     if device == 'npu':
        from mmcv.device.npu import NPUDataParallel
        model = NPUDataParallel(model.npu(), dim=dim, *args, **kwargs)
+    elif device == 'mlu':
+        from mmcv.device.mlu import MLUDataParallel
+        model = MLUDataParallel(model.mlu(), dim=dim, *args, **kwargs)
     elif device == 'cuda':
         from mmcv.parallel import MMDataParallel
         model = MMDataParallel(model.cuda(), dim=dim, *args, **kwargs)
@@ -57,6 +60,15 @@ def wrap_distributed_model(model, device='cuda', *args, **kwargs):
         from torch.npu import current_device
         model = NPUDistributedDataParallel(
             model.npu(), *args, device_ids=[current_device()], **kwargs)
+    elif device == 'mlu':
+        import os
+
+        from mmcv.device.mlu import MLUDistributedDataParallel
+        model = MLUDistributedDataParallel(
+            model.mlu(),
+            *args,
+            device_ids=[int(os.environ['LOCAL_RANK'])],
+            **kwargs)
     elif device == 'cuda':
         from mmcv.parallel import MMDistributedDataParallel
         from torch.cuda import current_device
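
Note for reviewers: below is a minimal sketch (not part of the patch) of how
the pieces above fit together after this change. It assumes an mmcv build
that provides mmcv.device.mlu and a PyTorch installation with the Cambricon
torch_mlu extension; on such a host auto_select_device() picks 'mlu' and the
new branches are taken, while other hosts fall back to the existing
'npu'/'cuda'/'cpu' paths.

    import torch.nn as nn

    from mmcls.apis import init_random_seed
    from mmcls.utils import auto_select_device, wrap_non_distributed_model

    # Resolve the backend once: 'mlu' on a Cambricon host, otherwise
    # whatever is available ('npu', 'cuda' or 'cpu').
    device = auto_select_device()

    # init_random_seed()/sync_random_seed() now resolve the device the same
    # way when none is passed, so callers no longer select it themselves.
    seed = init_random_seed(device=device)

    # wrap_non_distributed_model() moves the (toy) model with .mlu()/.cuda()
    # and wraps it in the matching DataParallel class for that backend.
    model = wrap_non_distributed_model(nn.Linear(8, 2), device=device)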