mirror of https://github.com/open-mmlab/mmcv.git
[Feature] Support modulated_deform_conv with Cambricon MLU backend (#2411)

* [Feature] Support modulated_deform_conv with Cambricon MLU backend
* fix error of torch_mlu
* modify with commit suggestion
* Update modulated_deform_conv.py
* Update mmcv/ops/modulated_deform_conv.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

parent c1712ee290
commit 49d891754b
@@ -32,7 +32,7 @@ We implement common ops used in detection, segmentation, etc.
 | MaskedConv | | √ | √ | | √ |
 | MergeCells | | √ | | | |
 | MinAreaPolygon | | √ | | | |
-| ModulatedDeformConv2d | √ | √ | | | √ |
+| ModulatedDeformConv2d | √ | √ | √ | | √ |
 | MultiScaleDeformableAttn | | √ | √ | | |
 | NMS | √ | √ | √ | | √ |
 | NMSRotated | √ | √ | | | |
@@ -32,7 +32,7 @@ MMCV provides common operators used in detection, segmentation, and other tasks
 | MaskedConv | | √ | √ | | √ |
 | MergeCells | | √ | | | |
 | MinAreaPolygon | | √ | | | |
-| ModulatedDeformConv2d | √ | √ | | | √ |
+| ModulatedDeformConv2d | √ | √ | √ | | √ |
 | MultiScaleDeformableAttn | | √ | √ | | |
 | NMS | √ | √ | √ | | √ |
 | NMSRotated | √ | √ | | | |
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from mmcv.utils import IS_MLU_AVAILABLE
 from .active_rotated_filter import active_rotated_filter
 from .assign_score_withk import assign_score_withk
 from .ball_query import ball_query
@@ -106,3 +107,8 @@ __all__ = [
     'diff_iou_rotated_2d', 'diff_iou_rotated_3d', 'chamfer_distance',
     'PrRoIPool', 'prroi_pool'
 ]
+
+if IS_MLU_AVAILABLE:
+    from .modulated_deform_conv import \
+        ModulatedDeformConv2dPack_MLU  # noqa:F401
+    __all__.append('ModulatedDeformConv2dPack_MLU')
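The guarded export above lets downstream code pick the right op at import time, which is also how the test module below selects its implementation. A minimal sketch (assuming a torch_mlu build where IS_MLU_AVAILABLE is True; the DCNPack alias is ours for illustration):

import torch  # noqa: F401  (torch_mlu patches torch when installed)
from mmcv.utils import IS_MLU_AVAILABLE

if IS_MLU_AVAILABLE:
    # torchvision-backed MLU implementation exported by this commit
    from mmcv.ops import ModulatedDeformConv2dPack_MLU as DCNPack
else:
    # generic CPU/CUDA implementation that mmcv already ships
    from mmcv.ops import ModulatedDeformConv2dPack as DCNPack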
@@ -8,7 +8,7 @@ from torch.autograd import Function
 from torch.autograd.function import once_differentiable
 from torch.nn.modules.utils import _pair, _single

-from mmcv.utils import deprecated_api_warning
+from mmcv.utils import IS_MLU_AVAILABLE, deprecated_api_warning
 from ..cnn import CONV_LAYERS
 from ..utils import ext_loader, print_log

@@ -352,3 +352,88 @@ class ModulatedDeformConv2dPack(ModulatedDeformConv2d):
         super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                       strict, missing_keys, unexpected_keys,
                                       error_msgs)
+
+
+if IS_MLU_AVAILABLE:
+    from torchvision.ops import deform_conv2d as tv_deform_conv2d
+
+    @CONV_LAYERS.register_module('DCNv2', force=True)
+    class ModulatedDeformConv2dPack_MLU(nn.modules.Module):
+        """This class is the DCNv2 implementation for MLU devices.
+
+        The MLU backend support of the operator has been implemented in
+        torchvision. The mmcv registration mechanism is used for
+        multiplexing here: the torchvision implementation of DCNv2 is
+        called.
+
+        Args:
+            in_channels (int): Same as nn.Conv2d.
+            out_channels (int): Same as nn.Conv2d.
+            kernel_size (int or tuple[int]): Same as nn.Conv2d.
+            stride (int): Same as nn.Conv2d; tuple is not supported.
+            padding (int): Same as nn.Conv2d; tuple is not supported.
+            dilation (int): Same as nn.Conv2d; tuple is not supported.
+            groups (int): Same as nn.Conv2d.
+            deform_groups (int): Deformable group partitions.
+            bias (bool or str): If specified as `auto`, it will be decided
+                by the norm_cfg. Bias will be set as True if norm_cfg is
+                None, otherwise False.
+        """
+
+        def __init__(self,
+                     in_channels: int,
+                     out_channels: int,
+                     kernel_size: Union[int, Tuple[int]],
+                     stride: int = 1,
+                     padding: int = 0,
+                     dilation: int = 1,
+                     groups: int = 1,
+                     deform_groups: int = 1,
+                     bias: Union[bool, str] = True):
+            super().__init__()
+            self.in_channels = in_channels
+            self.out_channels = out_channels
+            self.kernel_size = _pair(kernel_size)
+            self.stride = _pair(stride)
+            self.padding = _pair(padding)
+            self.dilation = _pair(dilation)
+            self.groups = groups
+            self.deform_groups = deform_groups
+            self.weight = nn.Parameter(
+                torch.Tensor(out_channels, in_channels, *self.kernel_size))
+            if bias:
+                self.bias = nn.Parameter(torch.Tensor(out_channels))
+            else:
+                self.register_parameter('bias', None)
+            self.conv_offset = nn.Conv2d(
+                self.in_channels,
+                self.deform_groups * 3 * self.kernel_size[0] *
+                self.kernel_size[1],
+                kernel_size=self.kernel_size,
+                stride=self.stride,
+                padding=self.padding,
+                bias=True)
+            self.init_weights()
+
+        def init_weights(self):
+            n = self.in_channels
+            for k in self.kernel_size:
+                n *= k
+            stdv = 1. / math.sqrt(n)
+            self.weight.data.uniform_(-stdv, stdv)
+            if self.bias is not None:
+                self.bias.data.zero_()
+            self.conv_offset.weight.data.zero_()
+            self.conv_offset.bias.data.zero_()
+
+        def forward(self, x):
+            out = self.conv_offset(x)
+            o1, o2, mask = torch.chunk(out, 3, dim=1)
+            offset = torch.cat((o1, o2), dim=1)
+            mask = torch.sigmoid(mask)
+            return tv_deform_conv2d(
+                x,
+                offset,
+                self.weight,
+                bias=self.bias,
+                stride=self.stride,
+                padding=self.padding,
+                dilation=self.dilation,
+                mask=mask)
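Because the class re-registers 'DCNv2' with force=True, existing configs keep working unchanged: on an MLU machine, config-driven builders resolve dict(type='DCNv2') to this implementation. A hedged usage sketch (assumes an MLU build with torch_mlu and torchvision installed; channel counts and shapes are illustrative only):

import torch
from mmcv.cnn import build_conv_layer

# 'DCNv2' resolves to ModulatedDeformConv2dPack_MLU when IS_MLU_AVAILABLE
dcn = build_conv_layer(
    dict(type='DCNv2'),
    in_channels=3,
    out_channels=8,
    kernel_size=3,
    stride=1,
    padding=1,
    deform_groups=1,
    bias=False).to('mlu')

x = torch.randn(2, 3, 16, 16).to('mlu')
out = dcn(x)  # conv_offset yields offsets (o1, o2) and a sigmoid-gated mask
assert out.shape == (2, 8, 16, 16)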
@@ -5,7 +5,7 @@ import numpy
 import pytest
 import torch

-from mmcv.utils import TORCH_VERSION, digit_version
+from mmcv.utils import IS_MLU_AVAILABLE, TORCH_VERSION, digit_version

 try:
     # If PyTorch version >= 1.6.0 and fp16 is enabled, torch.cuda.amp.autocast
@@ -42,10 +42,14 @@ class TestMdconv:
     def _test_mdconv(self, dtype=torch.float, device='cuda'):
         if not torch.cuda.is_available() and device == 'cuda':
             pytest.skip('test requires GPU')
-
+        if device == 'mlu':
+            from mmcv.ops import \
+                ModulatedDeformConv2dPack_MLU as ModulatedDeformConv2dPack
+        else:
+            from mmcv.ops import ModulatedDeformConv2dPack
         input = torch.tensor(input_t, dtype=dtype, device=device)
         input.requires_grad = True

         dcn = ModulatedDeformConv2dPack(
             1,
             1,
@@ -53,10 +57,7 @@ class TestMdconv:
             stride=1,
             padding=1,
             deform_groups=1,
-            bias=False)
-
-        if device == 'cuda':
-            dcn.cuda()
+            bias=False).to(device)

         dcn.weight.data.fill_(1.)
         dcn.type(dtype)
@@ -114,9 +115,11 @@ class TestMdconv:
     def test_mdconv(self):
         self._test_mdconv(torch.double, device='cpu')
         self._test_mdconv(torch.float, device='cpu')
-        self._test_mdconv(torch.double)
-        self._test_mdconv(torch.float)
-        self._test_mdconv(torch.half)
+
+        device = 'mlu' if IS_MLU_AVAILABLE else 'cuda'
+        self._test_mdconv(torch.double, device=device)
+        self._test_mdconv(torch.float, device=device)
+        self._test_mdconv(torch.half, device=device)

         # test amp when torch version >= '1.6.0', the type of
         # input data for mdconv might be torch.float or torch.half
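For reference, the amp path that comment describes looks roughly like this (a sketch, assuming CUDA and torch >= 1.6.0; the shapes are illustrative, not the test fixture):

import torch
from mmcv.ops import ModulatedDeformConv2dPack

dcn = ModulatedDeformConv2dPack(
    1, 1, kernel_size=3, stride=1, padding=1, deform_groups=1,
    bias=False).cuda()
x = torch.randn(1, 1, 4, 4, device='cuda', requires_grad=True)

with torch.cuda.amp.autocast(enabled=True):
    out = dcn(x)  # input to mdconv may be float or half under autocast
out.float().sum().backward()  # gradients flow through offsets and mask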