mirror of https://github.com/open-mmlab/mmcv.git
[Feature] Add MLU support for DCN (#2540)
parent
c9d477bb27
commit
c310d28c8f
|
@ -18,7 +18,7 @@ We implement common ops used in detection, segmentation, etc.
|
|||
| ConvexIoU | | √ | | | |
|
||||
| CornerPool | | √ | | | |
|
||||
| Correlation | | √ | | | |
|
||||
| Deformable Convolution v1/v2 | √ | √ | | | √ |
|
||||
| Deformable Convolution v1/v2 | √ | √ | √ | | √ |
|
||||
| Deformable RoIPool | | √ | √ | | √ |
|
||||
| DiffIoURotated | | √ | | | |
|
||||
| DynamicScatter | | √ | | | |
|
||||
|
|
|
@ -18,7 +18,7 @@ MMCV 提供了检测、分割等任务中常用的算子
|
|||
| ConvexIoU | | √ | | | |
|
||||
| CornerPool | | √ | | | |
|
||||
| Correlation | | √ | | | |
|
||||
| Deformable Convolution v1/v2 | √ | √ | | | √ |
|
||||
| Deformable Convolution v1/v2 | √ | √ | √ | | √ |
|
||||
| Deformable RoIPool | | √ | √ | | √ |
|
||||
| DiffIoURotated | | √ | | | |
|
||||
| DynamicScatter | | √ | | | |
|
||||
|
|
|
@ -109,6 +109,7 @@ __all__ = [
|
|||
]
|
||||
|
||||
if IS_MLU_AVAILABLE:
|
||||
from .deform_conv import DeformConv2dPack_MLU # noqa:F401
|
||||
from .modulated_deform_conv import \
|
||||
ModulatedDeformConv2dPack_MLU # noqa:F401
|
||||
__all__.append('ModulatedDeformConv2dPack_MLU')
|
||||
__all__.extend(['ModulatedDeformConv2dPack_MLU', 'DeformConv2dPack_MLU'])
|
||||
|
|
|
@ -9,7 +9,7 @@ from torch.autograd import Function
|
|||
from torch.autograd.function import once_differentiable
|
||||
from torch.nn.modules.utils import _pair, _single
|
||||
|
||||
from mmcv.utils import deprecated_api_warning
|
||||
from mmcv.utils import IS_MLU_AVAILABLE, deprecated_api_warning
|
||||
from ..cnn import CONV_LAYERS
|
||||
from ..utils import ext_loader, print_log
|
||||
from .modulated_deform_conv import ModulatedDeformConv2dFunction
|
||||
|
@ -434,3 +434,67 @@ class DeformConv2dPack(DeformConv2d):
|
|||
super()._load_from_state_dict(state_dict, prefix, local_metadata,
|
||||
strict, missing_keys, unexpected_keys,
|
||||
error_msgs)
|
||||
|
||||
|
||||
if IS_MLU_AVAILABLE:
|
||||
import torchvision
|
||||
|
||||
from mmcv.utils import digit_version
|
||||
assert digit_version(torchvision.__version__) >= digit_version(
|
||||
'0.10.0a0'), 'the version of torchvision should be >= 0.10.0'
|
||||
|
||||
from torchvision.ops import deform_conv2d as tv_deform_conv2d
|
||||
|
||||
@CONV_LAYERS.register_module('DCN', force=True)
|
||||
class DeformConv2dPack_MLU(DeformConv2d):
|
||||
"""This class is the DCN implementation of the MLU device. The MLU
|
||||
backend support of the operator has been implemented in torchvision.
|
||||
The mmcv registration mechanism is used for multiplexing here. The
|
||||
torchvision implementation of DCN is called.
|
||||
|
||||
Args:
|
||||
in_channels (int): Same as nn.Conv2d.
|
||||
out_channels (int): Same as nn.Conv2d.
|
||||
kernel_size (int or tuple[int]): Same as nn.Conv2d.
|
||||
stride (int): Same as nn.Conv2d, while tuple is not supported.
|
||||
padding (int): Same as nn.Conv2d, while tuple is not supported.
|
||||
dilation (int): Same as nn.Conv2d, while tuple is not supported.
|
||||
groups (int): Same as nn.Conv2d.
|
||||
bias (bool or str): If specified as `auto`, it will be decided by
|
||||
the norm_cfg. Bias will be set as True if norm_cfg is None,
|
||||
otherwise False.
|
||||
im2col_step (int): Number of samples processed by
|
||||
im2col_cuda_kernel per call. It will work when ``batch_size``
|
||||
> ``im2col_step``, but ``batch_size`` must be divisible by
|
||||
``im2col_step``. Default: 32. `New in version 1.7.2.
|
||||
Currently not supported on MLU devices.`
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.conv_offset = nn.Conv2d(
|
||||
self.in_channels,
|
||||
self.deform_groups * 2 * self.kernel_size[0] *
|
||||
self.kernel_size[1],
|
||||
kernel_size=self.kernel_size,
|
||||
stride=_pair(self.stride),
|
||||
padding=_pair(self.padding),
|
||||
dilation=_pair(self.dilation),
|
||||
bias=True)
|
||||
self.init_offset()
|
||||
|
||||
def init_offset(self):
|
||||
self.conv_offset.weight.data.zero_()
|
||||
self.conv_offset.bias.data.zero_()
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor: # type: ignore
|
||||
cur_im2col_step = min(self.im2col_step, x.size(0))
|
||||
assert (x.size(0) % cur_im2col_step
|
||||
) == 0, 'batch size must be divisible by im2col_step'
|
||||
offset = self.conv_offset(x)
|
||||
x = x.type_as(offset)
|
||||
weight = self.weight
|
||||
weight = weight.type_as(x)
|
||||
return tv_deform_conv2d(x, offset, weight, None, self.stride,
|
||||
self.padding, self.dilation)
|
||||
|
|
|
@ -3,7 +3,7 @@ import numpy as np
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from mmcv.utils import TORCH_VERSION, digit_version
|
||||
from mmcv.utils import IS_MLU_AVAILABLE, TORCH_VERSION, digit_version
|
||||
|
||||
try:
|
||||
# If PyTorch version >= 1.6.0 and fp16 is enabled, torch.cuda.amp.autocast
|
||||
|
@ -45,7 +45,10 @@ class TestDeformconv:
|
|||
im2col_step=2):
|
||||
if not torch.cuda.is_available() and device == 'cuda':
|
||||
pytest.skip('test requires GPU')
|
||||
from mmcv.ops import DeformConv2dPack
|
||||
if device == 'mlu':
|
||||
from mmcv.ops import DeformConv2dPack_MLU as DeformConv2dPack
|
||||
else:
|
||||
from mmcv.ops import DeformConv2dPack
|
||||
c_in = 1
|
||||
c_out = 1
|
||||
batch_size = 10
|
||||
|
@ -69,6 +72,8 @@ class TestDeformconv:
|
|||
torch.Tensor(deform_weight).reshape(1, 1, 2, 2))
|
||||
if device == 'cuda':
|
||||
model.cuda()
|
||||
elif device == 'mlu':
|
||||
model.mlu()
|
||||
model.type(dtype)
|
||||
|
||||
out = model(x)
|
||||
|
@ -108,6 +113,7 @@ class TestDeformconv:
|
|||
def _test_amp_deformconv(self,
|
||||
input_dtype,
|
||||
threshold=1e-3,
|
||||
device='cuda',
|
||||
batch_size=10,
|
||||
im2col_step=2):
|
||||
"""The function to test amp released on pytorch 1.6.0.
|
||||
|
@ -120,15 +126,18 @@ class TestDeformconv:
|
|||
input_dtype: torch.float or torch.half.
|
||||
threshold: the same as above function.
|
||||
"""
|
||||
if not torch.cuda.is_available():
|
||||
if not torch.cuda.is_available() and device == 'cuda':
|
||||
return
|
||||
from mmcv.ops import DeformConv2dPack
|
||||
if device == 'mlu':
|
||||
from mmcv.ops import DeformConv2dPack_MLU as DeformConv2dPack
|
||||
else:
|
||||
from mmcv.ops import DeformConv2dPack
|
||||
c_in = 1
|
||||
c_out = 1
|
||||
repeated_input = np.repeat(input, batch_size, axis=0)
|
||||
repeated_gt_out = np.repeat(gt_out, batch_size, axis=0)
|
||||
repeated_gt_x_grad = np.repeat(gt_x_grad, batch_size, axis=0)
|
||||
x = torch.Tensor(repeated_input).cuda().type(input_dtype)
|
||||
x = torch.Tensor(repeated_input).to(device).type(input_dtype)
|
||||
x.requires_grad = True
|
||||
model = DeformConv2dPack(
|
||||
in_channels=c_in,
|
||||
|
@ -143,7 +152,10 @@ class TestDeformconv:
|
|||
torch.Tensor(offset_bias).reshape(8))
|
||||
model.weight.data = torch.nn.Parameter(
|
||||
torch.Tensor(deform_weight).reshape(1, 1, 2, 2))
|
||||
model.cuda()
|
||||
if device == 'cuda':
|
||||
model.cuda()
|
||||
elif device == 'mlu':
|
||||
model.mlu()
|
||||
|
||||
out = model(x)
|
||||
out.backward(torch.ones_like(out))
|
||||
|
@ -180,21 +192,25 @@ class TestDeformconv:
|
|||
def test_deformconv(self):
|
||||
self._test_deformconv(torch.double, device='cpu')
|
||||
self._test_deformconv(torch.float, device='cpu', threshold=1e-1)
|
||||
self._test_deformconv(torch.double)
|
||||
self._test_deformconv(torch.float)
|
||||
self._test_deformconv(torch.half, threshold=1e-1)
|
||||
|
||||
device = 'mlu' if IS_MLU_AVAILABLE else 'cuda'
|
||||
self._test_deformconv(torch.double, device=device)
|
||||
self._test_deformconv(torch.float, device=device)
|
||||
self._test_deformconv(torch.half, threshold=1e-1, device=device)
|
||||
# test batch_size < im2col_step
|
||||
self._test_deformconv(torch.float, batch_size=1, im2col_step=2)
|
||||
self._test_deformconv(
|
||||
torch.float, batch_size=1, im2col_step=2, device=device)
|
||||
# test bach_size % im2col_step != 0
|
||||
with pytest.raises(
|
||||
AssertionError,
|
||||
match='batch size must be divisible by im2col_step'):
|
||||
self._test_deformconv(torch.float, batch_size=10, im2col_step=3)
|
||||
self._test_deformconv(
|
||||
torch.float, batch_size=10, im2col_step=3, device=device)
|
||||
|
||||
# test amp when torch version >= '1.6.0', the type of
|
||||
# input data for deformconv might be torch.float or torch.half
|
||||
if (TORCH_VERSION != 'parrots'
|
||||
and digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
|
||||
with autocast(enabled=True):
|
||||
self._test_amp_deformconv(torch.float, 1e-1)
|
||||
self._test_amp_deformconv(torch.half, 1e-1)
|
||||
self._test_amp_deformconv(torch.float, 1e-1, device)
|
||||
self._test_amp_deformconv(torch.half, 1e-1, device)
|
||||
|
|
Loading…
Reference in New Issue