From 7628a61f920b6bd77f0dfe5e3074193ea3d13221 Mon Sep 17 00:00:00 2001 From: "Aston.He" <48208796+alpha-baymax@users.noreply.github.com> Date: Wed, 25 May 2022 18:11:42 +0800 Subject: [PATCH] feat(mlu): Support PyTorch backend on MLU. (#1515) * feat(mlu): Support PyTorch backend on MLU. * fix redundant device variable. * Update mmseg/apis/train.py Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com> * Update comments. * Update mmseg/apis/train.py * Update is_mlu_available flag. * align util_distribution.py to mmdet. * align util_distribution.py to mmdet. * add build_dp, build_ddp testcase. * Update mmseg/utils/util_distribution.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmseg/utils/util_distribution.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmseg/utils/util_distribution.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update tests/test_utils/test_util_distribution.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update tests/test_utils/test_util_distribution.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update tests/test_utils/test_util_distribution.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update tests/test_utils/test_util_distribution.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * add mmcv version check for mlu device. Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com> Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> --- mmseg/apis/train.py | 19 +++-- .../datasets/samplers/distributed_sampler.py | 4 +- mmseg/utils/__init__.py | 3 +- mmseg/utils/util_distribution.py | 81 +++++++++++++++++++ tests/test_utils/test_util_distribution.py | 68 ++++++++++++++++ tools/test.py | 13 +-- tools/train.py | 6 +- 7 files changed, 176 insertions(+), 18 deletions(-) create mode 100644 mmseg/utils/util_distribution.py create mode 100644 tests/test_utils/test_util_distribution.py diff --git a/mmseg/apis/train.py b/mmseg/apis/train.py index 3563e3620..be8e422b3 100644 --- a/mmseg/apis/train.py +++ b/mmseg/apis/train.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +import os import random import warnings @@ -6,7 +7,6 @@ import mmcv import numpy as np import torch import torch.distributed as dist -from mmcv.parallel import MMDataParallel, MMDistributedDataParallel from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner, build_runner, get_dist_info) from mmcv.utils import build_from_cfg @@ -14,7 +14,8 @@ from mmcv.utils import build_from_cfg from mmseg import digit_version from mmseg.core import DistEvalHook, EvalHook, build_optimizer from mmseg.datasets import build_dataloader, build_dataset -from mmseg.utils import find_latest_checkpoint, get_root_logger +from mmseg.utils import (build_ddp, build_dp, find_latest_checkpoint, + get_root_logger) def init_random_seed(seed=None, device='cuda'): @@ -99,21 +100,23 @@ def train_segmentor(model, train_loader_cfg = {**loader_cfg, **cfg.data.get('train_dataloader', {})} data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset] - # put model on gpus + # put model on devices if distributed: find_unused_parameters = cfg.get('find_unused_parameters', False) # Sets the `find_unused_parameters` parameter in - # torch.nn.parallel.DistributedDataParallel - model = MMDistributedDataParallel( - model.cuda(), - device_ids=[torch.cuda.current_device()], + # DDP wrapper + model = build_ddp( + model, + cfg.device, + device_ids=[int(os.environ['LOCAL_RANK'])], broadcast_buffers=False, find_unused_parameters=find_unused_parameters) else: if not torch.cuda.is_available(): assert digit_version(mmcv.__version__) >= digit_version('1.4.4'), \ 'Please use MMCV >= 1.4.4 for CPU training!' - model = MMDataParallel(model, device_ids=cfg.gpu_ids) + model = build_dp(model, cfg.device, device_ids=cfg.gpu_ids) + # build runner optimizer = build_optimizer(model, cfg.optimizer) diff --git a/mmseg/datasets/samplers/distributed_sampler.py b/mmseg/datasets/samplers/distributed_sampler.py index d1a13c716..4f9bf3579 100644 --- a/mmseg/datasets/samplers/distributed_sampler.py +++ b/mmseg/datasets/samplers/distributed_sampler.py @@ -7,6 +7,7 @@ from torch.utils.data import Dataset from torch.utils.data import DistributedSampler as _DistributedSampler from mmseg.core.utils import sync_random_seed +from mmseg.utils import get_device class DistributedSampler(_DistributedSampler): @@ -41,7 +42,8 @@ class DistributedSampler(_DistributedSampler): # in the same order based on the same seed. Then different ranks # could use different indices to select non-overlapped data from the # same data list. - self.seed = sync_random_seed(seed) + device = get_device() + self.seed = sync_random_seed(seed, device) def __iter__(self) -> Iterator: """ diff --git a/mmseg/utils/__init__.py b/mmseg/utils/__init__.py index ed002c7de..e3ef4b355 100644 --- a/mmseg/utils/__init__.py +++ b/mmseg/utils/__init__.py @@ -3,8 +3,9 @@ from .collect_env import collect_env from .logger import get_root_logger from .misc import find_latest_checkpoint from .set_env import setup_multi_processes +from .util_distribution import build_ddp, build_dp, get_device __all__ = [ 'get_root_logger', 'collect_env', 'find_latest_checkpoint', - 'setup_multi_processes' + 'setup_multi_processes', 'build_ddp', 'build_dp', 'get_device' ] diff --git a/mmseg/utils/util_distribution.py b/mmseg/utils/util_distribution.py new file mode 100644 index 000000000..16651c225 --- /dev/null +++ b/mmseg/utils/util_distribution.py @@ -0,0 +1,81 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import mmcv +import torch +from mmcv.parallel import MMDataParallel, MMDistributedDataParallel + +from mmseg import digit_version + +dp_factory = {'cuda': MMDataParallel, 'cpu': MMDataParallel} + +ddp_factory = {'cuda': MMDistributedDataParallel} + + +def build_dp(model, device='cuda', dim=0, *args, **kwargs): + """build DataParallel module by device type. + + if device is cuda, return a MMDataParallel module; if device is mlu, + return a MLUDataParallel module. + + Args: + model (:class:`nn.Module`): module to be parallelized. + device (str): device type, cuda, cpu or mlu. Defaults to cuda. + dim (int): Dimension used to scatter the data. Defaults to 0. + + Returns: + :class:`nn.Module`: parallelized module. + """ + if device == 'cuda': + model = model.cuda() + elif device == 'mlu': + assert digit_version(mmcv.__version__) >= digit_version('1.5.0'), \ + 'Please use MMCV >= 1.5.0 for MLU training!' + from mmcv.device.mlu import MLUDataParallel + dp_factory['mlu'] = MLUDataParallel + model = model.mlu() + + return dp_factory[device](model, dim=dim, *args, **kwargs) + + +def build_ddp(model, device='cuda', *args, **kwargs): + """Build DistributedDataParallel module by device type. + + If device is cuda, return a MMDistributedDataParallel module; + if device is mlu, return a MLUDistributedDataParallel module. + + Args: + model (:class:`nn.Module`): module to be parallelized. + device (str): device type, mlu or cuda. + + Returns: + :class:`nn.Module`: parallelized module. + + References: + .. [1] https://pytorch.org/docs/stable/generated/torch.nn.parallel. + DistributedDataParallel.html + """ + assert device in ['cuda', 'mlu'], 'Only available for cuda or mlu devices.' + if device == 'cuda': + model = model.cuda() + elif device == 'mlu': + assert digit_version(mmcv.__version__) >= digit_version('1.5.0'), \ + 'Please use MMCV >= 1.5.0 for MLU training!' + from mmcv.device.mlu import MLUDistributedDataParallel + ddp_factory['mlu'] = MLUDistributedDataParallel + model = model.mlu() + + return ddp_factory[device](model, *args, **kwargs) + + +def is_mlu_available(): + """Returns a bool indicating if MLU is currently available.""" + return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available() + + +def get_device(): + """Returns an available device, cpu, cuda or mlu.""" + is_device_available = { + 'cuda': torch.cuda.is_available(), + 'mlu': is_mlu_available() + } + device_list = [k for k, v in is_device_available.items() if v] + return device_list[0] if len(device_list) == 1 else 'cpu' diff --git a/tests/test_utils/test_util_distribution.py b/tests/test_utils/test_util_distribution.py new file mode 100644 index 000000000..103d1d6ba --- /dev/null +++ b/tests/test_utils/test_util_distribution.py @@ -0,0 +1,68 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest.mock import MagicMock, patch + +import mmcv +import torch +import torch.nn as nn +from mmcv.parallel import (MMDataParallel, MMDistributedDataParallel, + is_module_wrapper) + +from mmseg import digit_version +from mmseg.utils import build_ddp, build_dp + + +def mock(*args, **kwargs): + pass + + +class Model(nn.Module): + + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(2, 2, 1) + + def forward(self, x): + return self.conv(x) + + +@patch('torch.distributed._broadcast_coalesced', mock) +@patch('torch.distributed.broadcast', mock) +@patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', mock) +def test_build_dp(): + model = Model() + assert not is_module_wrapper(model) + + mmdp = build_dp(model, 'cpu') + assert isinstance(mmdp, MMDataParallel) + + if torch.cuda.is_available(): + mmdp = build_dp(model, 'cuda') + assert isinstance(mmdp, MMDataParallel) + + if digit_version(mmcv.__version__) >= digit_version('1.5.0'): + from mmcv.device.mlu import MLUDataParallel + from mmcv.utils import IS_MLU_AVAILABLE + if IS_MLU_AVAILABLE: + mludp = build_dp(model, 'mlu') + assert isinstance(mludp, MLUDataParallel) + + +@patch('torch.distributed._broadcast_coalesced', mock) +@patch('torch.distributed.broadcast', mock) +@patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', mock) +def test_build_ddp(): + model = Model() + assert not is_module_wrapper(model) + + if torch.cuda.is_available(): + mmddp = build_ddp( + model, 'cuda', device_id=[0], process_group=MagicMock()) + assert isinstance(mmddp, MMDistributedDataParallel) + + if digit_version(mmcv.__version__) >= digit_version('1.5.0'): + from mmcv.device.mlu import MLUDistributedDataParallel + from mmcv.utils import IS_MLU_AVAILABLE + if IS_MLU_AVAILABLE: + mluddp = build_ddp( + model, 'mlu', device_ids=[0], process_group=MagicMock()) + assert isinstance(mluddp, MLUDistributedDataParallel) diff --git a/tools/test.py b/tools/test.py index 12892ec9b..a643b08be 100644 --- a/tools/test.py +++ b/tools/test.py @@ -9,7 +9,6 @@ import warnings import mmcv import torch from mmcv.cnn.utils import revert_sync_batchnorm -from mmcv.parallel import MMDataParallel, MMDistributedDataParallel from mmcv.runner import (get_dist_info, init_dist, load_checkpoint, wrap_fp16_model) from mmcv.utils import DictAction @@ -18,7 +17,7 @@ from mmseg import digit_version from mmseg.apis import multi_gpu_test, single_gpu_test from mmseg.datasets import build_dataloader, build_dataset from mmseg.models import build_segmentor -from mmseg.utils import setup_multi_processes +from mmseg.utils import build_ddp, build_dp, get_device, setup_multi_processes def parse_args(): @@ -260,6 +259,7 @@ def main(): else: tmpdir = None + cfg.device = get_device() if not distributed: warnings.warn( 'SyncBN is only supported with DDP. To be compatible with DP, ' @@ -269,7 +269,7 @@ def main(): assert digit_version(mmcv.__version__) >= digit_version('1.4.4'), \ 'Please use MMCV >= 1.4.4 for CPU training!' model = revert_sync_batchnorm(model) - model = MMDataParallel(model, device_ids=cfg.gpu_ids) + model = build_dp(model, cfg.device, device_ids=cfg.gpu_ids) results = single_gpu_test( model, data_loader, @@ -281,9 +281,10 @@ def main(): format_only=args.format_only or eval_on_format_results, format_args=eval_kwargs) else: - model = MMDistributedDataParallel( - model.cuda(), - device_ids=[torch.cuda.current_device()], + model = build_ddp( + model, + cfg.device, + device_ids=[int(os.environ['LOCAL_RANK'])], broadcast_buffers=False) results = multi_gpu_test( model, diff --git a/tools/train.py b/tools/train.py index 6e7adc8d6..c4219b04b 100644 --- a/tools/train.py +++ b/tools/train.py @@ -17,7 +17,8 @@ from mmseg import __version__ from mmseg.apis import init_random_seed, set_random_seed, train_segmentor from mmseg.datasets import build_dataset from mmseg.models import build_segmentor -from mmseg.utils import collect_env, get_root_logger, setup_multi_processes +from mmseg.utils import (collect_env, get_device, get_root_logger, + setup_multi_processes) def parse_args(): @@ -184,7 +185,8 @@ def main(): logger.info(f'Config:\n{cfg.pretty_text}') # set random seeds - seed = init_random_seed(args.seed) + cfg.device = get_device() + seed = init_random_seed(args.seed, device=cfg.device) seed = seed + dist.get_rank() if args.diff_seed else seed logger.info(f'Set random seed to {seed}, ' f'deterministic: {args.deterministic}')