feat(mlu): Support PyTorch backend on MLU. (#1515)

* feat(mlu): Support PyTorch backend on MLU.

* fix redundant device variable.

* Update mmseg/apis/train.py

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>

* Update comments.

* Update mmseg/apis/train.py

* Update is_mlu_available flag.

* align util_distribution.py to mmdet.

* align util_distribution.py to mmdet.

* add build_dp, build_ddp testcase.

* Update mmseg/utils/util_distribution.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmseg/utils/util_distribution.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmseg/utils/util_distribution.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update tests/test_utils/test_util_distribution.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update tests/test_utils/test_util_distribution.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update tests/test_utils/test_util_distribution.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update tests/test_utils/test_util_distribution.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* add mmcv version check for mlu device.

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Aston.He committed 2022-05-25 18:11:42 +08:00 (via GitHub)
commit 7628a61f92, parent aa50358c71
7 changed files with 176 additions and 18 deletions
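
Taken together, the diffs below replace the hard-coded MMDataParallel / MMDistributedDataParallel wrapping with a small device-keyed factory (build_dp / build_ddp) plus a get_device() helper, so the same training and test scripts can run on CUDA, MLU or CPU. A minimal, non-distributed sketch of the new call pattern (the toy Conv2d model and the hand-built Config are illustrative only, not part of this commit):

    import torch.nn as nn
    from mmcv import Config

    from mmseg.utils import build_dp, get_device

    cfg = Config(dict(gpu_ids=[0]))
    cfg.device = get_device()  # 'cuda', 'mlu' or 'cpu', depending on the host

    # Stand-in for a real segmentor; build_dp moves the module to the chosen
    # device and wraps it in the matching (MM|MLU)DataParallel class.
    model = nn.Conv2d(3, 19, kernel_size=1)
    model = build_dp(model, cfg.device, device_ids=cfg.gpu_ids)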

mmseg/apis/train.py

@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import os
 import random
 import warnings
@@ -6,7 +7,6 @@ import mmcv
 import numpy as np
 import torch
 import torch.distributed as dist
-from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
 from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
                          build_runner, get_dist_info)
 from mmcv.utils import build_from_cfg
@@ -14,7 +14,8 @@ from mmcv.utils import build_from_cfg
 from mmseg import digit_version
 from mmseg.core import DistEvalHook, EvalHook, build_optimizer
 from mmseg.datasets import build_dataloader, build_dataset
-from mmseg.utils import find_latest_checkpoint, get_root_logger
+from mmseg.utils import (build_ddp, build_dp, find_latest_checkpoint,
+                         get_root_logger)


 def init_random_seed(seed=None, device='cuda'):
@@ -99,21 +100,23 @@ def train_segmentor(model,
     train_loader_cfg = {**loader_cfg, **cfg.data.get('train_dataloader', {})}
     data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]

-    # put model on gpus
+    # put model on devices
     if distributed:
         find_unused_parameters = cfg.get('find_unused_parameters', False)
         # Sets the `find_unused_parameters` parameter in
         # torch.nn.parallel.DistributedDataParallel
-        model = MMDistributedDataParallel(
-            model.cuda(),
-            device_ids=[torch.cuda.current_device()],
+        # DDP wrapper
+        model = build_ddp(
+            model,
+            cfg.device,
+            device_ids=[int(os.environ['LOCAL_RANK'])],
             broadcast_buffers=False,
             find_unused_parameters=find_unused_parameters)
     else:
         if not torch.cuda.is_available():
             assert digit_version(mmcv.__version__) >= digit_version('1.4.4'), \
                 'Please use MMCV >= 1.4.4 for CPU training!'
-        model = MMDataParallel(model, device_ids=cfg.gpu_ids)
+        model = build_dp(model, cfg.device, device_ids=cfg.gpu_ids)

     # build runner
     optimizer = build_optimizer(model, cfg.optimizer)
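
Note that the DDP branch above now takes the device index from the LOCAL_RANK environment variable instead of torch.cuda.current_device(), so it assumes a distributed launcher (e.g. torch.distributed.launch or torchrun) has exported that variable. A hedged sketch of how a caller could verify that assumption before entering the distributed branch (not part of this commit; the error message is illustrative):

    import os

    # LOCAL_RANK is exported by torch.distributed.launch / torchrun; fail early
    # with a clear message if the script was started without such a launcher.
    if 'LOCAL_RANK' not in os.environ:
        raise RuntimeError(
            'Distributed training expects LOCAL_RANK to be set by the '
            'launcher, e.g. torchrun --nproc_per_node=N tools/train.py ...')
    local_rank = int(os.environ['LOCAL_RANK'])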

mmseg/datasets/samplers/distributed_sampler.py

@@ -7,6 +7,7 @@ from torch.utils.data import Dataset
 from torch.utils.data import DistributedSampler as _DistributedSampler

 from mmseg.core.utils import sync_random_seed
+from mmseg.utils import get_device


 class DistributedSampler(_DistributedSampler):
@@ -41,7 +42,8 @@ class DistributedSampler(_DistributedSampler):
         # in the same order based on the same seed. Then different ranks
         # could use different indices to select non-overlapped data from the
         # same data list.
-        self.seed = sync_random_seed(seed)
+        device = get_device()
+        self.seed = sync_random_seed(seed, device)

     def __iter__(self) -> Iterator:
         """

mmseg/utils/__init__.py

@@ -3,8 +3,9 @@ from .collect_env import collect_env
 from .logger import get_root_logger
 from .misc import find_latest_checkpoint
 from .set_env import setup_multi_processes
+from .util_distribution import build_ddp, build_dp, get_device

 __all__ = [
     'get_root_logger', 'collect_env', 'find_latest_checkpoint',
-    'setup_multi_processes'
+    'setup_multi_processes', 'build_ddp', 'build_dp', 'get_device'
 ]

mmseg/utils/util_distribution.py

@@ -0,0 +1,81 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import mmcv
+import torch
+from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
+
+from mmseg import digit_version
+
+dp_factory = {'cuda': MMDataParallel, 'cpu': MMDataParallel}
+
+ddp_factory = {'cuda': MMDistributedDataParallel}
+
+
+def build_dp(model, device='cuda', dim=0, *args, **kwargs):
+    """build DataParallel module by device type.
+
+    if device is cuda, return a MMDataParallel module; if device is mlu,
+    return a MLUDataParallel module.
+
+    Args:
+        model (:class:`nn.Module`): module to be parallelized.
+        device (str): device type, cuda, cpu or mlu. Defaults to cuda.
+        dim (int): Dimension used to scatter the data. Defaults to 0.
+
+    Returns:
+        :class:`nn.Module`: parallelized module.
+    """
+    if device == 'cuda':
+        model = model.cuda()
+    elif device == 'mlu':
+        assert digit_version(mmcv.__version__) >= digit_version('1.5.0'), \
+            'Please use MMCV >= 1.5.0 for MLU training!'
+        from mmcv.device.mlu import MLUDataParallel
+        dp_factory['mlu'] = MLUDataParallel
+        model = model.mlu()
+
+    return dp_factory[device](model, dim=dim, *args, **kwargs)
+
+
+def build_ddp(model, device='cuda', *args, **kwargs):
+    """Build DistributedDataParallel module by device type.
+
+    If device is cuda, return a MMDistributedDataParallel module;
+    if device is mlu, return a MLUDistributedDataParallel module.
+
+    Args:
+        model (:class:`nn.Module`): module to be parallelized.
+        device (str): device type, mlu or cuda.
+
+    Returns:
+        :class:`nn.Module`: parallelized module.
+
+    References:
+        .. [1] https://pytorch.org/docs/stable/generated/torch.nn.parallel.
+                     DistributedDataParallel.html
+    """
+    assert device in ['cuda', 'mlu'], 'Only available for cuda or mlu devices.'
+    if device == 'cuda':
+        model = model.cuda()
+    elif device == 'mlu':
+        assert digit_version(mmcv.__version__) >= digit_version('1.5.0'), \
+            'Please use MMCV >= 1.5.0 for MLU training!'
+        from mmcv.device.mlu import MLUDistributedDataParallel
+        ddp_factory['mlu'] = MLUDistributedDataParallel
+        model = model.mlu()
+
+    return ddp_factory[device](model, *args, **kwargs)
+
+
+def is_mlu_available():
+    """Returns a bool indicating if MLU is currently available."""
+    return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available()
+
+
+def get_device():
+    """Returns an available device, cpu, cuda or mlu."""
+    is_device_available = {
+        'cuda': torch.cuda.is_available(),
+        'mlu': is_mlu_available()
+    }
+    device_list = [k for k, v in is_device_available.items() if v]
+    return device_list[0] if len(device_list) == 1 else 'cpu'
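
One behavior worth spelling out: get_device() only returns an accelerator when exactly one of the probed device types reports itself available, and falls back to 'cpu' otherwise (including the unusual case where both CUDA and MLU are visible). A standalone sketch of that selection rule with stubbed availability flags (illustrative re-implementation, not part of the commit):

    # Mirrors the list-comprehension logic in get_device() above, with the
    # availability checks replaced by plain booleans for demonstration.
    def pick_device(cuda_ok, mlu_ok):
        available = [n for n, ok in (('cuda', cuda_ok), ('mlu', mlu_ok)) if ok]
        return available[0] if len(available) == 1 else 'cpu'

    assert pick_device(cuda_ok=True, mlu_ok=False) == 'cuda'
    assert pick_device(cuda_ok=False, mlu_ok=True) == 'mlu'
    assert pick_device(cuda_ok=False, mlu_ok=False) == 'cpu'
    assert pick_device(cuda_ok=True, mlu_ok=True) == 'cpu'  # both -> cpu fallback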

tests/test_utils/test_util_distribution.py

@@ -0,0 +1,68 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest.mock import MagicMock, patch
+
+import mmcv
+import torch
+import torch.nn as nn
+from mmcv.parallel import (MMDataParallel, MMDistributedDataParallel,
+                           is_module_wrapper)
+
+from mmseg import digit_version
+from mmseg.utils import build_ddp, build_dp
+
+
+def mock(*args, **kwargs):
+    pass
+
+
+class Model(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.conv = nn.Conv2d(2, 2, 1)
+
+    def forward(self, x):
+        return self.conv(x)
+
+
+@patch('torch.distributed._broadcast_coalesced', mock)
+@patch('torch.distributed.broadcast', mock)
+@patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', mock)
+def test_build_dp():
+    model = Model()
+    assert not is_module_wrapper(model)
+
+    mmdp = build_dp(model, 'cpu')
+    assert isinstance(mmdp, MMDataParallel)
+
+    if torch.cuda.is_available():
+        mmdp = build_dp(model, 'cuda')
+        assert isinstance(mmdp, MMDataParallel)
+
+    if digit_version(mmcv.__version__) >= digit_version('1.5.0'):
+        from mmcv.device.mlu import MLUDataParallel
+        from mmcv.utils import IS_MLU_AVAILABLE
+        if IS_MLU_AVAILABLE:
+            mludp = build_dp(model, 'mlu')
+            assert isinstance(mludp, MLUDataParallel)
+
+
+@patch('torch.distributed._broadcast_coalesced', mock)
+@patch('torch.distributed.broadcast', mock)
+@patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', mock)
+def test_build_ddp():
+    model = Model()
+    assert not is_module_wrapper(model)
+
+    if torch.cuda.is_available():
+        mmddp = build_ddp(
+            model, 'cuda', device_ids=[0], process_group=MagicMock())
+        assert isinstance(mmddp, MMDistributedDataParallel)
+
+    if digit_version(mmcv.__version__) >= digit_version('1.5.0'):
+        from mmcv.device.mlu import MLUDistributedDataParallel
+        from mmcv.utils import IS_MLU_AVAILABLE
+        if IS_MLU_AVAILABLE:
+            mluddp = build_ddp(
+                model, 'mlu', device_ids=[0], process_group=MagicMock())
+            assert isinstance(mluddp, MLUDistributedDataParallel)

tools/test.py

@@ -9,7 +9,6 @@ import warnings
 import mmcv
 import torch
 from mmcv.cnn.utils import revert_sync_batchnorm
-from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
 from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
                          wrap_fp16_model)
 from mmcv.utils import DictAction
@@ -18,7 +17,7 @@ from mmseg import digit_version
 from mmseg.apis import multi_gpu_test, single_gpu_test
 from mmseg.datasets import build_dataloader, build_dataset
 from mmseg.models import build_segmentor
-from mmseg.utils import setup_multi_processes
+from mmseg.utils import build_ddp, build_dp, get_device, setup_multi_processes


 def parse_args():
@@ -260,6 +259,7 @@ def main():
     else:
         tmpdir = None

+    cfg.device = get_device()
     if not distributed:
         warnings.warn(
             'SyncBN is only supported with DDP. To be compatible with DP, '
@@ -269,7 +269,7 @@ def main():
             assert digit_version(mmcv.__version__) >= digit_version('1.4.4'), \
                 'Please use MMCV >= 1.4.4 for CPU training!'
         model = revert_sync_batchnorm(model)
-        model = MMDataParallel(model, device_ids=cfg.gpu_ids)
+        model = build_dp(model, cfg.device, device_ids=cfg.gpu_ids)
         results = single_gpu_test(
             model,
             data_loader,
@@ -281,9 +281,10 @@ def main():
             format_only=args.format_only or eval_on_format_results,
             format_args=eval_kwargs)
     else:
-        model = MMDistributedDataParallel(
-            model.cuda(),
-            device_ids=[torch.cuda.current_device()],
+        model = build_ddp(
+            model,
+            cfg.device,
+            device_ids=[int(os.environ['LOCAL_RANK'])],
             broadcast_buffers=False)
         results = multi_gpu_test(
             model,

tools/train.py

@@ -17,7 +17,8 @@ from mmseg import __version__
 from mmseg.apis import init_random_seed, set_random_seed, train_segmentor
 from mmseg.datasets import build_dataset
 from mmseg.models import build_segmentor
-from mmseg.utils import collect_env, get_root_logger, setup_multi_processes
+from mmseg.utils import (collect_env, get_device, get_root_logger,
+                         setup_multi_processes)


 def parse_args():
@@ -184,7 +185,8 @@ def main():
     logger.info(f'Config:\n{cfg.pretty_text}')

     # set random seeds
-    seed = init_random_seed(args.seed)
+    cfg.device = get_device()
+    seed = init_random_seed(args.seed, device=cfg.device)
     seed = seed + dist.get_rank() if args.diff_seed else seed
     logger.info(f'Set random seed to {seed}, '
                 f'deterministic: {args.deterministic}')