mirror of https://github.com/open-mmlab/mmcv.git
[Feature] Support PyTorch backend on MLU (#1770)
* feat(MLU): Support PyTorch backend on MLU
* MMCV support PyTorch backend on MLU
* Add MLUDataParallel and MLUDistributedDataParallel
* Add MLU operator support
* [Fix]: Fix PR comments and add IS_MLU to get device available check
* [Fix]: fix PR comments of dist_utils.py
* [Doc] Rewrite annotations of functions.
* [Docs] Rewrite annotation in distributed.py
* [Docs] Fix lint

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent cd9dcc19da
commit 4826a9b7e4
@@ -45,7 +45,7 @@ jobs:
       - name: Run unittests and generate coverage report
         run: |
           pip install -r requirements/test.txt
-          pytest tests/ --ignore=tests/test_runner --ignore=tests/test_optimizer.py --ignore=tests/test_cnn --ignore=tests/test_parallel.py --ignore=tests/test_ops --ignore=tests/test_load_model_zoo.py --ignore=tests/test_utils/test_logging.py --ignore=tests/test_image/test_io.py --ignore=tests/test_utils/test_registry.py --ignore=tests/test_utils/test_parrots_jit.py --ignore=tests/test_utils/test_trace.py --ignore=tests/test_utils/test_hub.py
+          pytest tests/ --ignore=tests/test_runner --ignore=tests/test_optimizer.py --ignore=tests/test_cnn --ignore=tests/test_parallel.py --ignore=tests/test_ops --ignore=tests/test_load_model_zoo.py --ignore=tests/test_utils/test_logging.py --ignore=tests/test_image/test_io.py --ignore=tests/test_utils/test_registry.py --ignore=tests/test_utils/test_parrots_jit.py --ignore=tests/test_utils/test_trace.py --ignore=tests/test_utils/test_hub.py --ignore=tests/test_device/test_mlu/test_mlu_parallel.py
 
   build_without_ops:
     runs-on: ubuntu-18.04
@@ -13,3 +13,4 @@ from .visualization import *
 # - runner
 # - parallel
 # - op
+# - device
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from . import mlu

__all__ = ['mlu']
@@ -0,0 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .data_parallel import MLUDataParallel
from .distributed import MLUDistributedDataParallel
from .scatter_gather import scatter, scatter_kwargs
from .utils import IS_MLU

__all__ = [
    'MLUDataParallel', 'MLUDistributedDataParallel', 'scatter',
    'scatter_kwargs', 'IS_MLU'
]
@@ -0,0 +1,22 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch


def scatter(input, devices):
    """scatter copies tensor to MLU directly."""
    if isinstance(input, list):
        outputs = [scatter(_input, devices) for _input in input]
        return outputs
    elif isinstance(input, torch.Tensor):
        output = input.contiguous()
        return output.to('mlu') if devices != [-1] else output
    else:
        raise Exception(f'Unknown type {type(input)}.')


class Scatter:

    @staticmethod
    def forward(target_mlus, input):
        outputs = scatter(input, target_mlus)
        return tuple(outputs) if isinstance(outputs, list) else (outputs, )
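Not part of the diff: a minimal sketch of how the scatter helper above behaves on the CPU path (devices=[-1]); the MLU path assumes a torch_mlu-patched torch and a real MLU card.

# Sketch only (not in this commit): exercising mmcv.device.mlu._functions.scatter.
import torch
from mmcv.device.mlu._functions import scatter

x = torch.randn(2, 3)
y = scatter(x, devices=[-1])   # devices == [-1]: tensor stays on the CPU
# z = scatter(x, devices=[0])  # with an MLU present, x would be copied to 'mlu'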
@@ -0,0 +1,41 @@
# Copyright (c) OpenMMLab. All rights reserved.

import torch

from mmcv.parallel import MMDataParallel
from .scatter_gather import scatter_kwargs


class MLUDataParallel(MMDataParallel):
    """The MLUDataParallel module that supports DataContainer.

    MLUDataParallel is a class inherited from MMDataParallel, which
    supports MLU training and inference only.

    The main differences with MMDataParallel:

    - It supports only a single MLU card and uses the first card to
      run training and inference.

    - It uses direct host-to-device copy instead of stream-background
      scatter.

    .. warning::
        MLUDataParallel supports only single-MLU training. If you need to
        train with multiple MLUs, please use MLUDistributedDataParallel
        instead. If you have multiple MLUs, you can set the environment
        variable ``MLU_VISIBLE_DEVICES=0`` (or any other card number(s))
        to specify the running device.

    Args:
        module (:class:`nn.Module`): Module to be encapsulated.
        dim (int): Dimension used to scatter the data. Defaults to 0.
    """

    def __init__(self, *args, dim=0, **kwargs):
        super(MLUDataParallel, self).__init__(*args, dim=dim, **kwargs)
        self.device_ids = [0]
        self.src_device_obj = torch.device('mlu:0')

    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
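Not part of the diff: a hypothetical single-MLU training-side sketch built on the class above; the model and input shapes are illustrative and assume a torch_mlu-patched torch with one visible MLU.

# Hypothetical usage sketch for MLUDataParallel (illustrative names only).
import torch
import torch.nn as nn
from mmcv.device.mlu import MLUDataParallel

model = MLUDataParallel(nn.Conv2d(3, 8, 3).to('mlu'))  # first MLU card only
x = torch.randn(1, 3, 32, 32)
out = model(x)  # inputs are copied host-to-device by scatter_kwargs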
@@ -0,0 +1,20 @@
# Copyright (c) OpenMMLab. All rights reserved.

from mmcv.parallel import MMDistributedDataParallel
from .scatter_gather import scatter_kwargs


class MLUDistributedDataParallel(MMDistributedDataParallel):
    """The DDP module that supports DataContainer.

    MLUDDP has one difference from MMDDP: it moves data to MLU by copying
    instead of scattering.
    """

    def to_kwargs(self, inputs, kwargs, device_id):
        # Use `self.to_kwargs` instead of `self.scatter` in PyTorch 1.8
        # to move all tensors to device_id
        return scatter_kwargs(inputs, kwargs, [device_id], dim=self.dim)

    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
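Not part of the diff: a hypothetical multi-MLU sketch; the launcher, process-group setup, and device placement are simplified assumptions (init_dist comes from mmcv.runner, patched later in this commit to pick the 'cncl' backend).

# Hypothetical usage sketch for MLUDistributedDataParallel (simplified).
import torch.nn as nn
from mmcv.runner import init_dist
from mmcv.device.mlu import MLUDistributedDataParallel

init_dist(launcher='pytorch')                  # 'cncl' backend on MLU
model = nn.Conv2d(3, 8, 3).to('mlu')           # one module per process
ddp_model = MLUDistributedDataParallel(model)  # data is copied, not scattered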
@@ -0,0 +1,59 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmcv.parallel.data_container import DataContainer
from ._functions import Scatter


def scatter(inputs, target_mlus, dim=0):
    """Scatter inputs to target MLUs.

    The only difference from the original :func:`scatter` is that it adds
    support for :type:`~mmcv.parallel.DataContainer`.
    """

    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            if target_mlus != [-1]:
                obj = obj.to('mlu')
                return obj
            else:
                # for CPU inference we use self-implemented scatter
                return Scatter.forward(target_mlus, obj)
        if isinstance(obj, DataContainer):
            if obj.cpu_only:
                return obj.data
            else:
                return Scatter.forward(target_mlus, obj.data)
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            out = list(map(list, zip(*map(scatter_map, obj))))
            return out
        if isinstance(obj, dict) and len(obj) > 0:
            out = list(map(type(obj), zip(*map(scatter_map, obj.items()))))
            return out
        return [obj for targets in target_mlus]

    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell
    try:
        return scatter_map(inputs)
    finally:
        scatter_map = None


def scatter_kwargs(inputs, kwargs, target_mlus, dim=0):
    """Scatter with support for kwargs dictionary."""
    inputs = scatter(inputs, target_mlus, dim) if inputs else []
    kwargs = scatter(kwargs, target_mlus, dim) if kwargs else []
    if len(inputs) < len(kwargs):
        inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
    elif len(kwargs) < len(inputs):
        kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
    inputs = tuple(inputs)
    kwargs = tuple(kwargs)
    return inputs, kwargs
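Not part of the diff: a small CPU-path illustration of scatter_kwargs above (target_mlus=[-1]), showing how inputs and kwargs are grouped per target device.

# Sketch only: scatter_kwargs on the CPU path (target_mlus=[-1]).
import torch
from mmcv.device.mlu.scatter_gather import scatter_kwargs

ins, kws = scatter_kwargs((torch.zeros(2, 2),), {'flag': True}, [-1])
# ins -> ((tensor,),) and kws -> ({'flag': True},): one entry per target device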
@@ -0,0 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
def is_mlu_available():
    try:
        import torch
        return (hasattr(torch, 'is_mlu_available')
                and torch.is_mlu_available())
    except Exception:
        return False


IS_MLU = is_mlu_available()
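Not part of the diff: IS_MLU is evaluated once at import time; a hypothetical device-selection helper using it (pick_device is an illustrative name, not an mmcv API).

# Hypothetical helper built on the IS_MLU flag above.
import torch
from mmcv.device.mlu import IS_MLU


def pick_device():
    # 'mlu' is only a valid device type when torch_mlu has patched torch,
    # which is exactly the case IS_MLU reports.
    if IS_MLU:
        return torch.device('mlu')
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')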
@@ -8,6 +8,8 @@ using namespace at;
 
 #define CHECK_CUDA(x) \
   TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_MLU(x) \
+  TORCH_CHECK(x.device().type() == at::kMLU, #x " must be a MLU tensor")
 #define CHECK_CPU(x) \
   TORCH_CHECK(!x.device().is_cuda(), #x " must be a CPU tensor")
 #define CHECK_CONTIGUOUS(x) \
@@ -15,6 +17,9 @@ using namespace at;
 #define CHECK_CUDA_INPUT(x) \
   CHECK_CUDA(x);            \
   CHECK_CONTIGUOUS(x)
+#define CHECK_MLU_INPUT(x) \
+  CHECK_MLU(x);            \
+  CHECK_CONTIGUOUS(x)
 #define CHECK_CPU_INPUT(x) \
   CHECK_CPU(x);            \
   CHECK_CONTIGUOUS(x)
@@ -0,0 +1,26 @@
/*************************************************************************
 * Copyright (C) 2021 Cambricon.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#ifndef PYTORCH_MLU_HELPER_HPP_
#define PYTORCH_MLU_HELPER_HPP_

#ifdef MMCV_WITH_MLU
#include "aten.h"

#define NFU_ALIGN_SIZE 128

#define PAD_UP(x, y) (((x) / (y) + (int)((x) % (y) > 0)) * (y))

#define PAD_DOWN(x, y) (((x) / (y)) * (y))

#endif

#endif  // PYTORCH_MLU_HELPER_HPP_
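PAD_UP rounds a size up to the next multiple of the alignment and PAD_DOWN rounds it down; the same arithmetic checked in Python below (values are illustrative only, not from this commit).

# Same arithmetic as the PAD_UP / PAD_DOWN macros above, for illustration.
def pad_up(x, y):
    return (x // y + (x % y > 0)) * y


def pad_down(x, y):
    return (x // y) * y


assert pad_up(130, 128) == 256    # round up to the next multiple of 128
assert pad_down(130, 128) == 128  # round down to the previous multiple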
@@ -12,6 +12,8 @@ from torch import distributed as dist
 from torch._utils import (_flatten_dense_tensors, _take_tensors,
                           _unflatten_dense_tensors)
 
+from mmcv.device.mlu import IS_MLU
+
 
 def _find_free_port():
     # Copied from https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501
@@ -47,9 +49,18 @@ def init_dist(launcher, backend='nccl', **kwargs):
 def _init_dist_pytorch(backend, **kwargs):
     # TODO: use local_rank instead of rank % num_gpus
     rank = int(os.environ['RANK'])
-    num_gpus = torch.cuda.device_count()
-    torch.cuda.set_device(rank % num_gpus)
-    dist.init_process_group(backend=backend, **kwargs)
+    if IS_MLU:
+        import torch_mlu  # noqa: F401
+        torch.mlu.set_device(rank)
+        dist.init_process_group(
+            backend='cncl',
+            rank=rank,
+            world_size=int(os.environ['WORLD_SIZE']),
+            **kwargs)
+    else:
+        num_gpus = torch.cuda.device_count()
+        torch.cuda.set_device(rank % num_gpus)
+        dist.init_process_group(backend=backend, **kwargs)
 
 
 def _init_dist_mpi(backend, **kwargs):
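Not part of the diff: a hypothetical launch-side sketch for the MLU branch above; it assumes each process is started by a launcher (e.g. torch.distributed.launch) that exports RANK and WORLD_SIZE, and the MASTER_* defaults below are illustrative.

# Hypothetical sketch: environment the MLU branch of _init_dist_pytorch expects.
import os
from mmcv.runner import init_dist

os.environ.setdefault('MASTER_ADDR', '127.0.0.1')  # illustrative defaults
os.environ.setdefault('MASTER_PORT', '29500')
# RANK and WORLD_SIZE are exported by the launcher in real runs.
init_dist(launcher='pytorch')  # selects the 'cncl' backend when IS_MLU is True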
setup.py
@@ -12,6 +12,10 @@ try:
     if torch.__version__ == 'parrots':
         from parrots.utils.build_extension import BuildExtension
         EXT_TYPE = 'parrots'
+    elif (hasattr(torch, 'is_mlu_available') and torch.is_mlu_available()) or \
+            os.getenv('FORCE_MLU', '0') == '1':
+        from torch_mlu.utils.cpp_extension import BuildExtension
+        EXT_TYPE = 'pytorch'
     else:
         from torch.utils.cpp_extension import BuildExtension
         EXT_TYPE = 'pytorch'
@@ -288,6 +292,20 @@ def get_extensions():
             extension = CUDAExtension
             include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common'))
             include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/cuda'))
+        elif (hasattr(torch, 'is_mlu_available') and
+                torch.is_mlu_available()) or \
+                os.getenv('FORCE_MLU', '0') == '1':
+            from torch_mlu.utils.cpp_extension import MLUExtension
+            define_macros += [('MMCV_WITH_MLU', None)]
+            mlu_args = os.getenv('MMCV_MLU_ARGS')
+            extra_compile_args['cncc'] = [mlu_args] if mlu_args else []
+            op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + \
+                glob.glob('./mmcv/ops/csrc/pytorch/cpu/*.cpp') + \
+                glob.glob('./mmcv/ops/csrc/pytorch/mlu/*.cpp') + \
+                glob.glob('./mmcv/ops/csrc/pytorch/mlu/*.mlu')
+            extension = MLUExtension
+            include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common'))
+            include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/mlu'))
         else:
             print(f'Compiling {ext_name} without CUDA')
             op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + \
@@ -0,0 +1,97 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import MagicMock, patch

import pytest
import torch
import torch.nn as nn

from mmcv.device.mlu import IS_MLU, MLUDataParallel, MLUDistributedDataParallel
from mmcv.device.mlu._functions import Scatter, scatter
from mmcv.parallel import is_module_wrapper


def mock(*args, **kwargs):
    pass


@patch('torch.distributed._broadcast_coalesced', mock)
@patch('torch.distributed.broadcast', mock)
@patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', mock)
def test_is_module_wrapper():

    class Model(nn.Module):

        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(2, 2, 1)

        def forward(self, x):
            return self.conv(x)

    model = Model()
    assert not is_module_wrapper(model)

    if IS_MLU:
        mludp = MLUDataParallel(model)
        assert is_module_wrapper(mludp)

        mluddp = MLUDistributedDataParallel(model, process_group=MagicMock())
        assert is_module_wrapper(mluddp)


def test_scatter():
    # if the device is CPU, just return the input
    input = torch.zeros([1, 3, 3, 3])
    output = scatter(input=input, devices=[-1])
    assert torch.allclose(input, output)

    inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
    outputs = scatter(input=inputs, devices=[-1])
    for input, output in zip(inputs, outputs):
        assert torch.allclose(input, output)

    # if the device is MLU, copy the input from CPU to MLU
    if IS_MLU:
        input = torch.zeros([1, 3, 3, 3])
        output = scatter(input=input, devices=[0])
        assert torch.allclose(input.to('mlu'), output)

        inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
        outputs = scatter(input=inputs, devices=[0])
        for input, output in zip(inputs, outputs):
            assert torch.allclose(input.to('mlu'), output)

    # input should be a tensor or a list of tensors
    with pytest.raises(Exception):
        scatter(5, [-1])


def test_Scatter():
    # if the device is CPU, just return the input
    target_mlus = [-1]
    input = torch.zeros([1, 3, 3, 3])
    outputs = Scatter.forward(target_mlus, input)
    assert isinstance(outputs, tuple)
    assert torch.allclose(input, outputs[0])

    target_mlus = [-1]
    inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
    outputs = Scatter.forward(target_mlus, inputs)
    assert isinstance(outputs, tuple)
    for input, output in zip(inputs, outputs):
        assert torch.allclose(input, output)

    # if the device is MLU, copy the input from CPU to MLU
    if IS_MLU:
        target_mlus = [0]
        input = torch.zeros([1, 3, 3, 3])
        outputs = Scatter.forward(target_mlus, input)
        assert isinstance(outputs, tuple)
        assert torch.allclose(input.to('mlu'), outputs[0])

        target_mlus = [0]
        inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
        outputs = Scatter.forward(target_mlus, inputs)
        assert isinstance(outputs, tuple)
        for input, output in zip(inputs, outputs):
            assert torch.allclose(input.to('mlu'), output[0])