mirror of https://github.com/open-mmlab/mmcv.git
[Feature] Add support for mps (#2092)
* [Feature] Add support for MPS
* fix import error
* update ut
* fix error
* trigger CI
* use a unique basename for test file modules
* avoid bc-breaking

parent 357b484d37
commit 6a03918f55

.github/workflows
tests/test_device
@@ -61,7 +61,7 @@ jobs:
           --ignore=tests/test_utils/test_parrots_jit.py \
           --ignore=tests/test_utils/test_trace.py \
           --ignore=tests/test_utils/test_hub.py \
-          --ignore=tests/test_device/test_mlu/test_mlu_parallel.py \
+          --ignore=tests/test_device \
           --ignore=tests/test_utils/test_torch_ops.py

   build_without_ops:
@@ -1,4 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from . import ipu, mlu
+from . import ipu, mlu, mps
+from .scatter_gather import scatter, scatter_kwargs
+from .utils import get_device

-__all__ = ['mlu', 'ipu']
+__all__ = ['mlu', 'ipu', 'mps', 'get_device', 'scatter', 'scatter_kwargs']
@@ -0,0 +1,30 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Union

import torch

from mmcv.utils import deprecated_api_warning
from .utils import get_device


def scatter(input: Union[List, torch.Tensor], devices: List) -> List:
    """scatter copies tensor to devices directly."""
    current_device = get_device()
    if isinstance(input, list):
        outputs = [scatter(_input, devices) for _input in input]
        return outputs
    elif isinstance(input, torch.Tensor):
        output = input.contiguous()
        return output.to(current_device) if devices != [-1] else output
    else:
        raise Exception(f'Unknown type {type(input)}.')


class Scatter:

    @staticmethod
    @deprecated_api_warning({'target_mlus': 'target_devices'},
                            cls_name='Scatter')
    def forward(target_devices, input):
        outputs = scatter(input, target_devices)
        return tuple(outputs) if isinstance(outputs, list) else (outputs, )
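A quick illustrative sketch of the low-level helper above (not part of the diff; the tensor names are placeholders). With devices=[-1] the input stays on the CPU, otherwise it is copied to whatever get_device() reports:

import torch
from mmcv.device._functions import scatter

x = torch.zeros(1, 3, 3, 3)
same = scatter(input=x, devices=[-1])       # CPU path: returned unchanged (made contiguous)
moved = scatter(input=[x, x], devices=[0])  # each tensor copied to the current device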
@@ -1,9 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .data_parallel import MLUDataParallel
 from .distributed import MLUDistributedDataParallel
-from .scatter_gather import scatter, scatter_kwargs

-__all__ = [
-    'MLUDataParallel', 'MLUDistributedDataParallel', 'scatter',
-    'scatter_kwargs'
-]
+__all__ = ['MLUDataParallel', 'MLUDistributedDataParallel']
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .data_parallel import MPSDataParallel

__all__ = ['MPSDataParallel']
@@ -0,0 +1,34 @@
# Copyright (c) OpenMMLab. All rights reserved.

import torch

from mmcv.parallel import MMDataParallel
from ..scatter_gather import scatter_kwargs


class MPSDataParallel(MMDataParallel):
    """The MPSDataParallel module that supports DataContainer.

    MPSDataParallel is a class inherited from MMDataParallel, which supports
    MPS training and inference only.

    The main differences with MMDataParallel:

    - It only supports single-card MPS and only uses the first card to
      run training and inference.

    - It uses direct host-to-device copy instead of stream-background
      scatter.

    Args:
        module (:class:`nn.Module`): Module to be encapsulated.
        dim (int): Dimension used to scatter the data. Defaults to 0.
    """

    def __init__(self, *args, dim=0, **kwargs):
        super().__init__(*args, dim=dim, **kwargs)
        self.device_ids = [0]
        self.src_device_obj = torch.device('mps:0')

    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
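A minimal usage sketch of the wrapper defined above, not part of the diff. It assumes an MPS-only machine (Apple silicon, PyTorch 1.12+); the module and tensor are placeholders:

import torch
import torch.nn as nn
from mmcv.device.mps import MPSDataParallel

model = nn.Conv2d(3, 8, 3).to('mps')       # module must live on the single MPS device
mps_model = MPSDataParallel(model, dim=0)  # device_ids is forced to [0], source device to 'mps:0'

x = torch.randn(1, 3, 32, 32)              # CPU tensor; scatter_kwargs copies it to MPS
out = mps_model(x)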
@@ -0,0 +1,64 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmcv.parallel.data_container import DataContainer
from mmcv.utils import deprecated_api_warning
from ._functions import Scatter
from .utils import get_device


@deprecated_api_warning({'target_mlus': 'target_devices'})
def scatter(inputs, target_devices, dim=0):
    """Scatter inputs to target devices.

    The only difference from original :func:`scatter` is to add support for
    :type:`~mmcv.parallel.DataContainer`.
    """
    current_device = get_device()

    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            if target_devices != [-1]:
                obj = obj.to(current_device)
                return [obj]
            else:
                # for CPU inference we use self-implemented scatter
                return Scatter.forward(target_devices, obj)
        if isinstance(obj, DataContainer):
            if obj.cpu_only:
                return obj.data
            else:
                return Scatter.forward(target_devices, obj.data)
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            out = list(map(list, zip(*map(scatter_map, obj))))
            return out
        if isinstance(obj, dict) and len(obj) > 0:
            out = list(map(type(obj), zip(*map(scatter_map, obj.items()))))
            return out
        return [obj for _ in target_devices]

    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell
    try:
        return scatter_map(inputs)
    finally:
        scatter_map = None


@deprecated_api_warning({'target_mlus': 'target_devices'})
def scatter_kwargs(inputs, kwargs, target_devices, dim=0):
    """Scatter with support for kwargs dictionary."""
    inputs = scatter(inputs, target_devices, dim) if inputs else []
    kwargs = scatter(kwargs, target_devices, dim) if kwargs else []
    if len(inputs) < len(kwargs):
        inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
    elif len(kwargs) < len(inputs):
        kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
    inputs = tuple(inputs)
    kwargs = tuple(kwargs)
    return inputs, kwargs
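A hedged sketch (not in the diff) of how the device-agnostic scatter_kwargs above treats tensors versus DataContainer objects; the argument names and sample data are illustrative:

import torch
from mmcv.device import scatter_kwargs
from mmcv.parallel import DataContainer

img = torch.zeros(1, 3, 4, 4)
meta = DataContainer([{'filename': 'demo.jpg'}], cpu_only=True)

# Tensors are copied to the device reported by get_device();
# cpu_only DataContainers are unwrapped and stay on the host.
inputs, kwargs = scatter_kwargs((img,), {'img_metas': meta}, target_devices=[0])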
@@ -0,0 +1,18 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MPS_AVAILABLE


def get_device() -> str:
    """Returns the currently existing device type.

    Returns:
        str: cuda | mlu | mps | cpu.
    """
    if IS_CUDA_AVAILABLE:
        return 'cuda'
    elif IS_MLU_AVAILABLE:
        return 'mlu'
    elif IS_MPS_AVAILABLE:
        return 'mps'
    else:
        return 'cpu'
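For context, a small sketch (not part of the diff; the model is a placeholder) of how downstream code can use this helper to stay device-agnostic:

import torch
import torch.nn as nn
from mmcv.device import get_device

device = get_device()               # 'cuda', 'mlu', 'mps' or 'cpu'
model = nn.Linear(4, 2).to(device)  # same code path on every backend
out = model(torch.randn(1, 4).to(device))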
@@ -14,7 +14,7 @@ class MMDataParallel(DataParallel):

     - It supports a custom type :class:`DataContainer` which allows more
       flexible control of input data during both GPU and CPU inference.
-    - It implement two more APIs ``train_step()`` and ``val_step()``.
+    - It implements two more APIs ``train_step()`` and ``val_step()``.

     .. warning::
         MMDataParallel only supports single GPU training, if you need to
@@ -36,7 +36,8 @@ except ImportError:
         'is_method_overridden', 'has_method'
     ]
 else:
-    from .device_type import IS_IPU_AVAILABLE, IS_MLU_AVAILABLE
+    from .device_type import (IS_IPU_AVAILABLE, IS_MLU_AVAILABLE,
+                              IS_MPS_AVAILABLE)
     from .env import collect_env
     from .hub import load_url
     from .logging import get_logger, print_log
@@ -76,5 +77,5 @@ else:
         'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch',
         '_get_cuda_home', 'load_url', 'has_method', 'IS_CUDA_AVAILABLE',
         'worker_init_fn', 'IS_MLU_AVAILABLE', 'IS_IPU_AVAILABLE',
-        'torch_meshgrid'
+        'IS_MPS_AVAILABLE', 'torch_meshgrid'
     ]
@@ -22,3 +22,19 @@ def is_mlu_available() -> bool:


 IS_MLU_AVAILABLE = is_mlu_available()
+
+
+def is_mps_available() -> bool:
+    """Return True if mps devices exist.
+
+    It's specialized for Mac M1 chips and requires torch version 1.12 or
+    higher.
+    """
+    try:
+        import torch
+        return hasattr(torch.backends,
+                       'mps') and torch.backends.mps.is_available()
+    except Exception:
+        return False
+
+
+IS_MPS_AVAILABLE = is_mps_available()
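A brief illustrative guard (not part of the diff) showing how the new flag can gate MPS-only code; it assumes an MPS-enabled PyTorch 1.12+ build when the branch is taken:

import torch
from mmcv.utils import IS_MPS_AVAILABLE

if IS_MPS_AVAILABLE:
    t = torch.ones(2, 2, device='mps')  # allocated on the Apple GPU
else:
    t = torch.ones(2, 2)                # CPU fallback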
@@ -0,0 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.device import get_device
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MPS_AVAILABLE


def test_get_device():
    current_device = get_device()
    if IS_CUDA_AVAILABLE:
        assert current_device == 'cuda'
    elif IS_MLU_AVAILABLE:
        assert current_device == 'mlu'
    elif IS_MPS_AVAILABLE:
        assert current_device == 'mps'
    else:
        assert current_device == 'cpu'
@@ -0,0 +1,90 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmcv.device._functions import Scatter, scatter
from mmcv.utils import IS_MLU_AVAILABLE, IS_MPS_AVAILABLE


def test_scatter():
    # if the device is CPU, just return the input
    input = torch.zeros([1, 3, 3, 3])
    output = scatter(input=input, devices=[-1])
    assert torch.allclose(input, output)

    inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
    outputs = scatter(input=inputs, devices=[-1])
    for input, output in zip(inputs, outputs):
        assert torch.allclose(input, output)

    # if the device is MLU, copy the input from CPU to MLU
    if IS_MLU_AVAILABLE:
        input = torch.zeros([1, 3, 3, 3])
        output = scatter(input=input, devices=[0])
        assert torch.allclose(input.to('mlu'), output)

        inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
        outputs = scatter(input=inputs, devices=[0])
        for input, output in zip(inputs, outputs):
            assert torch.allclose(input.to('mlu'), output)

    # if the device is MPS, copy the input from CPU to MPS
    if IS_MPS_AVAILABLE:
        input = torch.zeros([1, 3, 3, 3])
        output = scatter(input=input, devices=[0])
        assert torch.allclose(input.to('mps'), output)

        inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
        outputs = scatter(input=inputs, devices=[0])
        for input, output in zip(inputs, outputs):
            assert torch.allclose(input.to('mps'), output)

    # input should be a tensor or list of tensor
    with pytest.raises(Exception):
        scatter(5, [-1])


def test_Scatter():
    # if the device is CPU, just return the input
    target_devices = [-1]
    input = torch.zeros([1, 3, 3, 3])
    outputs = Scatter.forward(target_devices, input)
    assert isinstance(outputs, tuple)
    assert torch.allclose(input, outputs[0])

    target_devices = [-1]
    inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
    outputs = Scatter.forward(target_devices, inputs)
    assert isinstance(outputs, tuple)
    for input, output in zip(inputs, outputs):
        assert torch.allclose(input, output)

    # if the device is MLU, copy the input from CPU to MLU
    if IS_MLU_AVAILABLE:
        target_devices = [0]
        input = torch.zeros([1, 3, 3, 3])
        outputs = Scatter.forward(target_devices, input)
        assert isinstance(outputs, tuple)
        assert torch.allclose(input.to('mlu'), outputs[0])

        target_devices = [0]
        inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
        outputs = Scatter.forward(target_devices, inputs)
        assert isinstance(outputs, tuple)
        for input, output in zip(inputs, outputs):
            assert torch.allclose(input.to('mlu'), output[0])

    # if the device is MPS, copy the input from CPU to MPS
    if IS_MPS_AVAILABLE:
        target_devices = [0]
        input = torch.zeros([1, 3, 3, 3])
        outputs = Scatter.forward(target_devices, input)
        assert isinstance(outputs, tuple)
        assert torch.allclose(input.to('mps'), outputs[0])

        target_devices = [0]
        inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
        outputs = Scatter.forward(target_devices, inputs)
        assert isinstance(outputs, tuple)
        for input, output in zip(inputs, outputs):
            assert torch.allclose(input.to('mps'), output[0])
@@ -1,12 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from unittest.mock import MagicMock, patch

-import pytest
-import torch
 import torch.nn as nn

 from mmcv.device.mlu import MLUDataParallel, MLUDistributedDataParallel
-from mmcv.device.mlu._functions import Scatter, scatter
 from mmcv.parallel import is_module_wrapper
 from mmcv.utils import IS_MLU_AVAILABLE

@@ -38,61 +35,3 @@ def test_is_module_wrapper():

         mluddp = MLUDistributedDataParallel(model, process_group=MagicMock())
         assert is_module_wrapper(mluddp)
-
-
-def test_scatter():
-    # if the device is CPU, just return the input
-    input = torch.zeros([1, 3, 3, 3])
-    output = scatter(input=input, devices=[-1])
-    assert torch.allclose(input, output)
-
-    inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
-    outputs = scatter(input=inputs, devices=[-1])
-    for input, output in zip(inputs, outputs):
-        assert torch.allclose(input, output)
-
-    # if the device is MLU, copy the input from CPU to MLU
-    if IS_MLU_AVAILABLE:
-        input = torch.zeros([1, 3, 3, 3])
-        output = scatter(input=input, devices=[0])
-        assert torch.allclose(input.to('mlu'), output)
-
-        inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
-        outputs = scatter(input=inputs, devices=[0])
-        for input, output in zip(inputs, outputs):
-            assert torch.allclose(input.to('mlu'), output)
-
-    # input should be a tensor or list of tensor
-    with pytest.raises(Exception):
-        scatter(5, [-1])
-
-
-def test_Scatter():
-    # if the device is CPU, just return the input
-    target_mlus = [-1]
-    input = torch.zeros([1, 3, 3, 3])
-    outputs = Scatter.forward(target_mlus, input)
-    assert isinstance(outputs, tuple)
-    assert torch.allclose(input, outputs[0])
-
-    target_mlus = [-1]
-    inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
-    outputs = Scatter.forward(target_mlus, inputs)
-    assert isinstance(outputs, tuple)
-    for input, output in zip(inputs, outputs):
-        assert torch.allclose(input, output)
-
-    # if the device is MLU, copy the input from CPU to MLU
-    if IS_MLU_AVAILABLE:
-        target_mlus = [0]
-        input = torch.zeros([1, 3, 3, 3])
-        outputs = Scatter.forward(target_mlus, input)
-        assert isinstance(outputs, tuple)
-        assert torch.allclose(input.to('mlu'), outputs[0])
-
-        target_mlus = [0]
-        inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
-        outputs = Scatter.forward(target_mlus, inputs)
-        assert isinstance(outputs, tuple)
-        for input, output in zip(inputs, outputs):
-            assert torch.allclose(input.to('mlu'), output[0])
@@ -0,0 +1,34 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import patch

import torch.nn as nn

from mmcv.device.mps import MPSDataParallel
from mmcv.parallel import is_module_wrapper
from mmcv.utils import IS_MPS_AVAILABLE


def mock(*args, **kwargs):
    pass


@patch('torch.distributed._broadcast_coalesced', mock)
@patch('torch.distributed.broadcast', mock)
@patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', mock)
def test_is_module_wrapper():

    class Model(nn.Module):

        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(2, 2, 1)

        def forward(self, x):
            return self.conv(x)

    model = Model()
    assert not is_module_wrapper(model)

    if IS_MPS_AVAILABLE:
        mpsdp = MPSDataParallel(model)
        assert is_module_wrapper(mpsdp)