[Feature] Add support for MPS ()

* [Feature] Add support for MPS

* fix import error

* update ut

* fix error

* trigger CI

* use a unique basename for test file modules

* avoid bc-breaking
pull/2107/head
Zaida Zhou 2022-07-07 16:05:49 +08:00 committed by GitHub
parent 357b484d37
commit 6a03918f55
15 changed files with 315 additions and 72 deletions

@@ -61,7 +61,7 @@ jobs:
--ignore=tests/test_utils/test_parrots_jit.py \
--ignore=tests/test_utils/test_trace.py \
--ignore=tests/test_utils/test_hub.py \
--ignore=tests/test_device/test_mlu/test_mlu_parallel.py \
--ignore=tests/test_device \
--ignore=tests/test_utils/test_torch_ops.py
build_without_ops:

@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from . import ipu, mlu
from . import ipu, mlu, mps
from .scatter_gather import scatter, scatter_kwargs
from .utils import get_device
__all__ = ['mlu', 'ipu']
__all__ = ['mlu', 'ipu', 'mps', 'get_device', 'scatter', 'scatter_kwargs']
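
With this change the scatter helpers and get_device become part of the mmcv.device public API. A minimal import sketch (illustrative only, not part of the diff):

from mmcv.device import get_device, scatter, scatter_kwargs

device = get_device()  # 'cuda' | 'mlu' | 'mps' | 'cpu', depending on the build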

@@ -0,0 +1,30 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Union
import torch
from mmcv.utils import deprecated_api_warning
from .utils import get_device
def scatter(input: Union[List, torch.Tensor], devices: List) -> List:
"""scatter copies tensor to devices directly."""
current_device = get_device()
if isinstance(input, list):
outputs = [scatter(_input, devices) for _input in input]
return outputs
elif isinstance(input, torch.Tensor):
output = input.contiguous()
return output.to(current_device) if devices != [-1] else output
else:
raise Exception(f'Unknown type {type(input)}.')
class Scatter:
@staticmethod
@deprecated_api_warning({'target_mlus': 'target_devices'},
cls_name='Scatter')
def forward(target_devices, input):
outputs = scatter(input, target_devices)
return tuple(outputs) if isinstance(outputs, list) else (outputs, )
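
A short usage sketch of the new helpers, with arbitrary tensor shapes; devices=[-1] keeps the data on the CPU, while any other device list copies it to the device reported by get_device():

import torch

from mmcv.device._functions import Scatter, scatter

x = torch.zeros(1, 3, 3, 3)
out = scatter(input=x, devices=[-1])          # stays on the CPU
outs = Scatter.forward([-1], [x, x.clone()])  # always returns a tuple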

@@ -1,9 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .data_parallel import MLUDataParallel
from .distributed import MLUDistributedDataParallel
from .scatter_gather import scatter, scatter_kwargs
__all__ = [
'MLUDataParallel', 'MLUDistributedDataParallel', 'scatter',
'scatter_kwargs'
]
__all__ = ['MLUDataParallel', 'MLUDistributedDataParallel']

@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .data_parallel import MPSDataParallel
__all__ = ['MPSDataParallel']

@@ -0,0 +1,34 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.parallel import MMDataParallel
from ..scatter_gather import scatter_kwargs
class MPSDataParallel(MMDataParallel):
"""The MPSDataParallel module that supports DataContainer.
MPSDataParallel is a class inherited from MMDataParallel, which supports
MPS training and inference only.
The main differences with MMDataParallel:
- It supports only a single MPS device and uses only the first device to
run training and inference.
- It uses direct host-to-device copy instead of stream-background
scatter.
Args:
module (:class:`nn.Module`): Module to be encapsulated.
dim (int): Dimension used to scatter the data. Defaults to 0.
"""
def __init__(self, *args, dim=0, **kwargs):
super().__init__(*args, dim=dim, **kwargs)
self.device_ids = [0]
self.src_device_obj = torch.device('mps:0')
def scatter(self, inputs, kwargs, device_ids):
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
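
A usage sketch, assuming an MPS-enabled PyTorch build (>= 1.12) and a hypothetical module; the wrapper pins device_ids to [0] and relies on the host-to-device scatter defined above:

import torch.nn as nn

from mmcv.device.mps import MPSDataParallel

model = nn.Conv2d(3, 8, 3).to('mps')  # hypothetical single-card model
mps_model = MPSDataParallel(model)
# mps_model(batch) now scatters `batch` to mps:0 before the forward pass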

@@ -0,0 +1,64 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.parallel.data_container import DataContainer
from mmcv.utils import deprecated_api_warning
from ._functions import Scatter
from .utils import get_device
@deprecated_api_warning({'target_mlus': 'target_devices'})
def scatter(inputs, target_devices, dim=0):
"""Scatter inputs to target devices.
The only difference from the original :func:`scatter` is the added support
for :type:`~mmcv.parallel.DataContainer`.
"""
current_device = get_device()
def scatter_map(obj):
if isinstance(obj, torch.Tensor):
if target_devices != [-1]:
obj = obj.to(current_device)
return [obj]
else:
# for CPU inference we use self-implemented scatter
return Scatter.forward(target_devices, obj)
if isinstance(obj, DataContainer):
if obj.cpu_only:
return obj.data
else:
return Scatter.forward(target_devices, obj.data)
if isinstance(obj, tuple) and len(obj) > 0:
return list(zip(*map(scatter_map, obj)))
if isinstance(obj, list) and len(obj) > 0:
out = list(map(list, zip(*map(scatter_map, obj))))
return out
if isinstance(obj, dict) and len(obj) > 0:
out = list(map(type(obj), zip(*map(scatter_map, obj.items()))))
return out
return [obj for _ in target_devices]
# After scatter_map is called, a scatter_map cell will exist. This cell
# has a reference to the actual function scatter_map, which has references
# to a closure that has a reference to the scatter_map cell (because the
# fn is recursive). To avoid this reference cycle, we set the function to
# None, clearing the cell
try:
return scatter_map(inputs)
finally:
scatter_map = None
@deprecated_api_warning({'target_mlus': 'target_devices'})
def scatter_kwargs(inputs, kwargs, target_devices, dim=0):
"""Scatter with support for kwargs dictionary."""
inputs = scatter(inputs, target_devices, dim) if inputs else []
kwargs = scatter(kwargs, target_devices, dim) if kwargs else []
if len(inputs) < len(kwargs):
inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
elif len(kwargs) < len(inputs):
kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
inputs = tuple(inputs)
kwargs = tuple(kwargs)
return inputs, kwargs
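
A sketch of how scatter_kwargs behaves on the CPU path (target_devices=[-1]); the inputs tuple and the return_loss flag are hypothetical:

import torch

from mmcv.device.scatter_gather import scatter_kwargs

inputs = (torch.zeros(1, 3, 4, 4),)
kwargs = {'return_loss': True}
inputs, kwargs = scatter_kwargs(inputs, kwargs, target_devices=[-1])
# inputs -> ((tensor,),) and kwargs -> ({'return_loss': True},), one entry per device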

@@ -0,0 +1,18 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MPS_AVAILABLE
def get_device() -> str:
"""Returns the currently existing device type.
Returns:
str: cuda | mlu | mps | cpu.
"""
if IS_CUDA_AVAILABLE:
return 'cuda'
elif IS_MLU_AVAILABLE:
return 'mlu'
elif IS_MPS_AVAILABLE:
return 'mps'
else:
return 'cpu'
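
get_device makes device-agnostic code straightforward; a minimal sketch (the tensor is illustrative, and moving to 'mlu' additionally requires torch_mlu to be installed):

import torch

from mmcv.device import get_device

device = get_device()  # 'cuda' | 'mlu' | 'mps' | 'cpu'
x = torch.zeros(2, 3)
if device != 'cpu':
    x = x.to(device)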

@@ -14,7 +14,7 @@ class MMDataParallel(DataParallel):
- It supports a custom type :class:`DataContainer` which allows more
flexible control of input data during both GPU and CPU inference.
- It implement two more APIs ``train_step()`` and ``val_step()``.
- It implements two more APIs ``train_step()`` and ``val_step()``.
.. warning::
MMDataParallel only supports single GPU training, if you need to

@@ -36,7 +36,8 @@ except ImportError:
'is_method_overridden', 'has_method'
]
else:
from .device_type import IS_IPU_AVAILABLE, IS_MLU_AVAILABLE
from .device_type import (IS_IPU_AVAILABLE, IS_MLU_AVAILABLE,
IS_MPS_AVAILABLE)
from .env import collect_env
from .hub import load_url
from .logging import get_logger, print_log
@@ -76,5 +77,5 @@ else:
'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch',
'_get_cuda_home', 'load_url', 'has_method', 'IS_CUDA_AVAILABLE',
'worker_init_fn', 'IS_MLU_AVAILABLE', 'IS_IPU_AVAILABLE',
'torch_meshgrid'
'IS_MPS_AVAILABLE', 'torch_meshgrid'
]

@@ -22,3 +22,19 @@ def is_mlu_available() -> bool:
IS_MLU_AVAILABLE = is_mlu_available()
def is_mps_available() -> bool:
"""Return True if mps devices exist.
It's specialized for mac m1 chips and require torch version 1.12 or higher.
"""
try:
import torch
return hasattr(torch.backends,
'mps') and torch.backends.mps.is_available()
except Exception:
return False
IS_MPS_AVAILABLE = is_mps_available()
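
The hasattr guard keeps the check safe on torch versions that predate torch.backends.mps. A sketch of how the flag is typically consumed (the helper name is hypothetical):

import torch

from mmcv.utils import IS_MPS_AVAILABLE

def pick_device() -> str:
    # hypothetical helper: prefer MPS when the flag reports it as usable
    return 'mps' if IS_MPS_AVAILABLE else 'cpu'

x = torch.ones(2, 2, device=pick_device())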

@@ -0,0 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.device import get_device
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MPS_AVAILABLE
def test_get_device():
current_device = get_device()
if IS_CUDA_AVAILABLE:
assert current_device == 'cuda'
elif IS_MLU_AVAILABLE:
assert current_device == 'mlu'
elif IS_MPS_AVAILABLE:
assert current_device == 'mps'
else:
assert current_device == 'cpu'

@@ -0,0 +1,90 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmcv.device._functions import Scatter, scatter
from mmcv.utils import IS_MLU_AVAILABLE, IS_MPS_AVAILABLE
def test_scatter():
# if the device is CPU, just return the input
input = torch.zeros([1, 3, 3, 3])
output = scatter(input=input, devices=[-1])
assert torch.allclose(input, output)
inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
outputs = scatter(input=inputs, devices=[-1])
for input, output in zip(inputs, outputs):
assert torch.allclose(input, output)
# if the device is MLU, copy the input from CPU to MLU
if IS_MLU_AVAILABLE:
input = torch.zeros([1, 3, 3, 3])
output = scatter(input=input, devices=[0])
assert torch.allclose(input.to('mlu'), output)
inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
outputs = scatter(input=inputs, devices=[0])
for input, output in zip(inputs, outputs):
assert torch.allclose(input.to('mlu'), output)
# if the device is MPS, copy the input from CPU to MPS
if IS_MPS_AVAILABLE:
input = torch.zeros([1, 3, 3, 3])
output = scatter(input=input, devices=[0])
assert torch.allclose(input.to('mps'), output)
inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
outputs = scatter(input=inputs, devices=[0])
for input, output in zip(inputs, outputs):
assert torch.allclose(input.to('mps'), output)
# input should be a tensor or list of tensor
with pytest.raises(Exception):
scatter(5, [-1])
def test_Scatter():
# if the device is CPU, just return the input
target_devices = [-1]
input = torch.zeros([1, 3, 3, 3])
outputs = Scatter.forward(target_devices, input)
assert isinstance(outputs, tuple)
assert torch.allclose(input, outputs[0])
target_devices = [-1]
inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
outputs = Scatter.forward(target_devices, inputs)
assert isinstance(outputs, tuple)
for input, output in zip(inputs, outputs):
assert torch.allclose(input, output)
# if the device is MLU, copy the input from CPU to MLU
if IS_MLU_AVAILABLE:
target_devices = [0]
input = torch.zeros([1, 3, 3, 3])
outputs = Scatter.forward(target_devices, input)
assert isinstance(outputs, tuple)
assert torch.allclose(input.to('mlu'), outputs[0])
target_devices = [0]
inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
outputs = Scatter.forward(target_devices, inputs)
assert isinstance(outputs, tuple)
for input, output in zip(inputs, outputs):
assert torch.allclose(input.to('mlu'), output[0])
# if the device is MPS, copy the input from CPU to MPS
if IS_MPS_AVAILABLE:
target_devices = [0]
input = torch.zeros([1, 3, 3, 3])
outputs = Scatter.forward(target_devices, input)
assert isinstance(outputs, tuple)
assert torch.allclose(input.to('mps'), outputs[0])
target_devices = [0]
inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
outputs = Scatter.forward(target_devices, inputs)
assert isinstance(outputs, tuple)
for input, output in zip(inputs, outputs):
assert torch.allclose(input.to('mps'), output[0])

@@ -1,12 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import MagicMock, patch
import pytest
import torch
import torch.nn as nn
from mmcv.device.mlu import MLUDataParallel, MLUDistributedDataParallel
from mmcv.device.mlu._functions import Scatter, scatter
from mmcv.parallel import is_module_wrapper
from mmcv.utils import IS_MLU_AVAILABLE
@@ -38,61 +35,3 @@ def test_is_module_wrapper():
mluddp = MLUDistributedDataParallel(model, process_group=MagicMock())
assert is_module_wrapper(mluddp)
def test_scatter():
# if the device is CPU, just return the input
input = torch.zeros([1, 3, 3, 3])
output = scatter(input=input, devices=[-1])
assert torch.allclose(input, output)
inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
outputs = scatter(input=inputs, devices=[-1])
for input, output in zip(inputs, outputs):
assert torch.allclose(input, output)
# if the device is MLU, copy the input from CPU to MLU
if IS_MLU_AVAILABLE:
input = torch.zeros([1, 3, 3, 3])
output = scatter(input=input, devices=[0])
assert torch.allclose(input.to('mlu'), output)
inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
outputs = scatter(input=inputs, devices=[0])
for input, output in zip(inputs, outputs):
assert torch.allclose(input.to('mlu'), output)
# input should be a tensor or list of tensor
with pytest.raises(Exception):
scatter(5, [-1])
def test_Scatter():
# if the device is CPU, just return the input
target_mlus = [-1]
input = torch.zeros([1, 3, 3, 3])
outputs = Scatter.forward(target_mlus, input)
assert isinstance(outputs, tuple)
assert torch.allclose(input, outputs[0])
target_mlus = [-1]
inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
outputs = Scatter.forward(target_mlus, inputs)
assert isinstance(outputs, tuple)
for input, output in zip(inputs, outputs):
assert torch.allclose(input, output)
# if the device is MLU, copy the input from CPU to MLU
if IS_MLU_AVAILABLE:
target_mlus = [0]
input = torch.zeros([1, 3, 3, 3])
outputs = Scatter.forward(target_mlus, input)
assert isinstance(outputs, tuple)
assert torch.allclose(input.to('mlu'), outputs[0])
target_mlus = [0]
inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
outputs = Scatter.forward(target_mlus, inputs)
assert isinstance(outputs, tuple)
for input, output in zip(inputs, outputs):
assert torch.allclose(input.to('mlu'), output[0])

@@ -0,0 +1,34 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import patch
import torch.nn as nn
from mmcv.device.mps import MPSDataParallel
from mmcv.parallel import is_module_wrapper
from mmcv.utils import IS_MPS_AVAILABLE
def mock(*args, **kwargs):
pass
@patch('torch.distributed._broadcast_coalesced', mock)
@patch('torch.distributed.broadcast', mock)
@patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', mock)
def test_is_module_wrapper():
class Model(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(2, 2, 1)
def forward(self, x):
return self.conv(x)
model = Model()
assert not is_module_wrapper(model)
if IS_MPS_AVAILABLE:
mpsdp = MPSDataParallel(model)
assert is_module_wrapper(mpsdp)