Mirror of https://github.com/open-mmlab/mmsegmentation.git (synced 2025-06-03 22:03:48 +08:00)
feat(mlu): Support PyTorch backend on MLU. (#1515)
* feat(mlu): Support PyTorch backend on MLU.
* fix redundant device variable.
* Update mmseg/apis/train.py
  Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>
* Update comments.
* Update mmseg/apis/train.py
* Update is_mlu_available flag.
* align util_distribution.py to mmdet.
* align util_distribution.py to mmdet.
* add build_dp, build_ddp testcase.
* Update mmseg/utils/util_distribution.py
  Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
* Update mmseg/utils/util_distribution.py
  Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
* Update mmseg/utils/util_distribution.py
  Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
* Update tests/test_utils/test_util_distribution.py
  Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
* Update tests/test_utils/test_util_distribution.py
  Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
* Update tests/test_utils/test_util_distribution.py
  Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
* Update tests/test_utils/test_util_distribution.py
  Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
* add mmcv version check for mlu device.

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
commit 7628a61f92 (parent aa50358c71)
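To make the intent of the diff easier to follow, here is a minimal sketch of the device-agnostic pattern the commit introduces. build_dp and get_device are the helpers added in mmseg/utils/util_distribution.py below; ToyModel is a made-up placeholder, not part of the commit.

# Sketch only: pick whichever single accelerator is present and wrap the model
# with the matching DataParallel class (MMDataParallel, or MLUDataParallel on MLU).
import torch.nn as nn

from mmseg.utils import build_dp, get_device


class ToyModel(nn.Module):  # placeholder model, not from the commit

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 1, 1)

    def forward(self, x):
        return self.conv(x)


device = get_device()        # 'cuda' or 'mlu' if exactly one is available, else 'cpu'
model = build_dp(ToyModel(), device)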
mmseg/apis/train.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import os
 import random
 import warnings
@@ -6,7 +7,6 @@ import mmcv
 import numpy as np
 import torch
 import torch.distributed as dist
-from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
 from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
                          build_runner, get_dist_info)
 from mmcv.utils import build_from_cfg
@@ -14,7 +14,8 @@ from mmcv.utils import build_from_cfg
 from mmseg import digit_version
 from mmseg.core import DistEvalHook, EvalHook, build_optimizer
 from mmseg.datasets import build_dataloader, build_dataset
-from mmseg.utils import find_latest_checkpoint, get_root_logger
+from mmseg.utils import (build_ddp, build_dp, find_latest_checkpoint,
+                         get_root_logger)


 def init_random_seed(seed=None, device='cuda'):
@@ -99,21 +100,23 @@ def train_segmentor(model,
     train_loader_cfg = {**loader_cfg, **cfg.data.get('train_dataloader', {})}
     data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]

-    # put model on gpus
+    # put model on devices
     if distributed:
         find_unused_parameters = cfg.get('find_unused_parameters', False)
         # Sets the `find_unused_parameters` parameter in
         # torch.nn.parallel.DistributedDataParallel
-        model = MMDistributedDataParallel(
-            model.cuda(),
-            device_ids=[torch.cuda.current_device()],
+        # DDP wrapper
+        model = build_ddp(
+            model,
+            cfg.device,
+            device_ids=[int(os.environ['LOCAL_RANK'])],
             broadcast_buffers=False,
             find_unused_parameters=find_unused_parameters)
     else:
         if not torch.cuda.is_available():
             assert digit_version(mmcv.__version__) >= digit_version('1.4.4'), \
                 'Please use MMCV >= 1.4.4 for CPU training!'
-        model = MMDataParallel(model, device_ids=cfg.gpu_ids)
+        model = build_dp(model, cfg.device, device_ids=cfg.gpu_ids)

     # build runner
     optimizer = build_optimizer(model, cfg.optimizer)
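One detail worth calling out in the hunk above: the local device index now comes from the LOCAL_RANK environment variable instead of torch.cuda.current_device(). Launchers such as torch.distributed.launch and torchrun export LOCAL_RANK for every worker, so the same line works for CUDA and MLU backends. A small hedged sketch of that lookup:

# Sketch: read the per-process device index set by the distributed launcher.
import os

local_rank = int(os.environ.get('LOCAL_RANK', 0))  # defaults to 0 outside a launcher
# e.g. build_ddp(model, cfg.device, device_ids=[local_rank], broadcast_buffers=False)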
mmseg/datasets/samplers/distributed_sampler.py
@@ -7,6 +7,7 @@ from torch.utils.data import Dataset
 from torch.utils.data import DistributedSampler as _DistributedSampler

 from mmseg.core.utils import sync_random_seed
+from mmseg.utils import get_device


 class DistributedSampler(_DistributedSampler):
@@ -41,7 +42,8 @@ class DistributedSampler(_DistributedSampler):
         # in the same order based on the same seed. Then different ranks
         # could use different indices to select non-overlapped data from the
         # same data list.
-        self.seed = sync_random_seed(seed)
+        device = get_device()
+        self.seed = sync_random_seed(seed, device)

     def __iter__(self) -> Iterator:
         """
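The extra device argument matters because sync_random_seed agrees on a seed by broadcasting a tensor from rank 0, and that tensor has to live on the device used by the collective backend (CUDA for NCCL, MLU for CNCL). A rough sketch of the idea, not the actual mmseg.core.utils implementation:

# Rough sketch of seed synchronisation across ranks (illustrative only;
# mmseg.core.utils.sync_random_seed is the real implementation).
import numpy as np
import torch
import torch.distributed as dist


def sync_random_seed_sketch(seed=None, device='cuda'):
    if seed is None:
        seed = np.random.randint(2**31)
    if not (dist.is_available() and dist.is_initialized()):
        return seed
    rank = dist.get_rank()
    # Rank 0 picks the seed; everyone else receives it via broadcast, so the
    # tensor must be allocated on the device the collective backend expects.
    random_num = torch.tensor(
        seed if rank == 0 else 0, dtype=torch.int32, device=device)
    dist.broadcast(random_num, src=0)
    return random_num.item()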
mmseg/utils/__init__.py
@@ -3,8 +3,9 @@ from .collect_env import collect_env
 from .logger import get_root_logger
 from .misc import find_latest_checkpoint
 from .set_env import setup_multi_processes
+from .util_distribution import build_ddp, build_dp, get_device

 __all__ = [
     'get_root_logger', 'collect_env', 'find_latest_checkpoint',
-    'setup_multi_processes'
+    'setup_multi_processes', 'build_ddp', 'build_dp', 'get_device'
 ]
mmseg/utils/util_distribution.py (new file, 81 lines)
@@ -0,0 +1,81 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel

from mmseg import digit_version

dp_factory = {'cuda': MMDataParallel, 'cpu': MMDataParallel}

ddp_factory = {'cuda': MMDistributedDataParallel}


def build_dp(model, device='cuda', dim=0, *args, **kwargs):
    """build DataParallel module by device type.

    if device is cuda, return a MMDataParallel module; if device is mlu,
    return a MLUDataParallel module.

    Args:
        model (:class:`nn.Module`): module to be parallelized.
        device (str): device type, cuda, cpu or mlu. Defaults to cuda.
        dim (int): Dimension used to scatter the data. Defaults to 0.

    Returns:
        :class:`nn.Module`: parallelized module.
    """
    if device == 'cuda':
        model = model.cuda()
    elif device == 'mlu':
        assert digit_version(mmcv.__version__) >= digit_version('1.5.0'), \
            'Please use MMCV >= 1.5.0 for MLU training!'
        from mmcv.device.mlu import MLUDataParallel
        dp_factory['mlu'] = MLUDataParallel
        model = model.mlu()

    return dp_factory[device](model, dim=dim, *args, **kwargs)


def build_ddp(model, device='cuda', *args, **kwargs):
    """Build DistributedDataParallel module by device type.

    If device is cuda, return a MMDistributedDataParallel module;
    if device is mlu, return a MLUDistributedDataParallel module.

    Args:
        model (:class:`nn.Module`): module to be parallelized.
        device (str): device type, mlu or cuda.

    Returns:
        :class:`nn.Module`: parallelized module.

    References:
        .. [1] https://pytorch.org/docs/stable/generated/torch.nn.parallel.
               DistributedDataParallel.html
    """
    assert device in ['cuda', 'mlu'], 'Only available for cuda or mlu devices.'
    if device == 'cuda':
        model = model.cuda()
    elif device == 'mlu':
        assert digit_version(mmcv.__version__) >= digit_version('1.5.0'), \
            'Please use MMCV >= 1.5.0 for MLU training!'
        from mmcv.device.mlu import MLUDistributedDataParallel
        ddp_factory['mlu'] = MLUDistributedDataParallel
        model = model.mlu()

    return ddp_factory[device](model, *args, **kwargs)


def is_mlu_available():
    """Returns a bool indicating if MLU is currently available."""
    return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available()


def get_device():
    """Returns an available device, cpu, cuda or mlu."""
    is_device_available = {
        'cuda': torch.cuda.is_available(),
        'mlu': is_mlu_available()
    }
    device_list = [k for k, v in is_device_available.items() if v]
    return device_list[0] if len(device_list) == 1 else 'cpu'
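The dp_factory and ddp_factory dictionaries above double as a small extension point: the MLU wrappers are imported and registered only when an MLU device is actually requested, so CUDA-only installs never import mmcv.device.mlu. A purely hypothetical sketch of registering another wrapper the same way (the 'npu' key and DummyDataParallel do not exist in this commit):

# Hypothetical illustration of the factory pattern used above; 'npu' and
# DummyDataParallel are made up for this sketch.
from mmcv.parallel import MMDataParallel

from mmseg.utils import util_distribution


class DummyDataParallel(MMDataParallel):
    """Placeholder wrapper standing in for a vendor-specific one."""
    pass


util_distribution.dp_factory['npu'] = DummyDataParallel
# build_dp(model, 'npu') would now pick DummyDataParallel, provided the
# device-specific branches (model transfer, version checks) were added too.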
tests/test_utils/test_util_distribution.py (new file, 68 lines)
@@ -0,0 +1,68 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import MagicMock, patch

import mmcv
import torch
import torch.nn as nn
from mmcv.parallel import (MMDataParallel, MMDistributedDataParallel,
                           is_module_wrapper)

from mmseg import digit_version
from mmseg.utils import build_ddp, build_dp


def mock(*args, **kwargs):
    pass


class Model(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(2, 2, 1)

    def forward(self, x):
        return self.conv(x)


@patch('torch.distributed._broadcast_coalesced', mock)
@patch('torch.distributed.broadcast', mock)
@patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', mock)
def test_build_dp():
    model = Model()
    assert not is_module_wrapper(model)

    mmdp = build_dp(model, 'cpu')
    assert isinstance(mmdp, MMDataParallel)

    if torch.cuda.is_available():
        mmdp = build_dp(model, 'cuda')
        assert isinstance(mmdp, MMDataParallel)

    if digit_version(mmcv.__version__) >= digit_version('1.5.0'):
        from mmcv.device.mlu import MLUDataParallel
        from mmcv.utils import IS_MLU_AVAILABLE
        if IS_MLU_AVAILABLE:
            mludp = build_dp(model, 'mlu')
            assert isinstance(mludp, MLUDataParallel)


@patch('torch.distributed._broadcast_coalesced', mock)
@patch('torch.distributed.broadcast', mock)
@patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', mock)
def test_build_ddp():
    model = Model()
    assert not is_module_wrapper(model)

    if torch.cuda.is_available():
        mmddp = build_ddp(
            model, 'cuda', device_ids=[0], process_group=MagicMock())
        assert isinstance(mmddp, MMDistributedDataParallel)

    if digit_version(mmcv.__version__) >= digit_version('1.5.0'):
        from mmcv.device.mlu import MLUDistributedDataParallel
        from mmcv.utils import IS_MLU_AVAILABLE
        if IS_MLU_AVAILABLE:
            mluddp = build_ddp(
                model, 'mlu', device_ids=[0], process_group=MagicMock())
            assert isinstance(mluddp, MLUDistributedDataParallel)
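The patches on torch.distributed._broadcast_coalesced, torch.distributed.broadcast and DistributedDataParallel._ddp_init_helper stub out the parts of DDP construction that need a real process group, which is what lets build_ddp be exercised with a MagicMock process group and no launched workers. The file can be run on its own, e.g. with pytest tests/test_utils/test_util_distribution.py.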
tools/test.py
@@ -9,7 +9,6 @@ import warnings
 import mmcv
 import torch
 from mmcv.cnn.utils import revert_sync_batchnorm
-from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
 from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
                          wrap_fp16_model)
 from mmcv.utils import DictAction
@@ -18,7 +17,7 @@ from mmseg import digit_version
 from mmseg.apis import multi_gpu_test, single_gpu_test
 from mmseg.datasets import build_dataloader, build_dataset
 from mmseg.models import build_segmentor
-from mmseg.utils import setup_multi_processes
+from mmseg.utils import build_ddp, build_dp, get_device, setup_multi_processes


 def parse_args():
@@ -260,6 +259,7 @@ def main():
     else:
         tmpdir = None

+    cfg.device = get_device()
     if not distributed:
         warnings.warn(
             'SyncBN is only supported with DDP. To be compatible with DP, '
@@ -269,7 +269,7 @@ def main():
         assert digit_version(mmcv.__version__) >= digit_version('1.4.4'), \
             'Please use MMCV >= 1.4.4 for CPU training!'
         model = revert_sync_batchnorm(model)
-        model = MMDataParallel(model, device_ids=cfg.gpu_ids)
+        model = build_dp(model, cfg.device, device_ids=cfg.gpu_ids)
         results = single_gpu_test(
             model,
             data_loader,
@@ -281,9 +281,10 @@ def main():
             format_only=args.format_only or eval_on_format_results,
             format_args=eval_kwargs)
     else:
-        model = MMDistributedDataParallel(
-            model.cuda(),
-            device_ids=[torch.cuda.current_device()],
+        model = build_ddp(
+            model,
+            cfg.device,
+            device_ids=[int(os.environ['LOCAL_RANK'])],
             broadcast_buffers=False)
         results = multi_gpu_test(
             model,
tools/train.py
@@ -17,7 +17,8 @@ from mmseg import __version__
 from mmseg.apis import init_random_seed, set_random_seed, train_segmentor
 from mmseg.datasets import build_dataset
 from mmseg.models import build_segmentor
-from mmseg.utils import collect_env, get_root_logger, setup_multi_processes
+from mmseg.utils import (collect_env, get_device, get_root_logger,
+                         setup_multi_processes)


 def parse_args():
@@ -184,7 +185,8 @@ def main():
     logger.info(f'Config:\n{cfg.pretty_text}')

     # set random seeds
-    seed = init_random_seed(args.seed)
+    cfg.device = get_device()
+    seed = init_random_seed(args.seed, device=cfg.device)
     seed = seed + dist.get_rank() if args.diff_seed else seed
     logger.info(f'Set random seed to {seed}, '
                 f'deterministic: {args.deterministic}')