[Fix] Fix a dilation bug of MLU-DCNv2 and add limitation of torchvision (#2519)

pull/2544/head^2
mengpenghui 2023-01-12 13:06:22 +08:00 committed by GitHub
parent 2810718a99
commit c9d477bb27
1 changed file with 13 additions and 35 deletions

@@ -355,10 +355,15 @@ class ModulatedDeformConv2dPack(ModulatedDeformConv2d):
 
 if IS_MLU_AVAILABLE:
+    import torchvision
+    from mmcv.utils import digit_version
+    assert digit_version(torchvision.__version__) >= digit_version(
+        '0.10.0a0'), 'the version of torchvision should be >= 0.10.0'
     from torchvision.ops import deform_conv2d as tv_deform_conv2d
 
     @CONV_LAYERS.register_module('DCNv2', force=True)
-    class ModulatedDeformConv2dPack_MLU(nn.modules.Module):
+    class ModulatedDeformConv2dPack_MLU(ModulatedDeformConv2d):
         """This class is the DCNv2 implementation of the MLU device. The MLU
         backend support of the operator has been implemented in torchvision.
         The mmcv registration mechanism is used for multiplexing here. The
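Context for the version guard added above: the class's forward method is not touched by this commit and so does not appear in the diff. A minimal sketch of how such a forward presumably dispatches to torchvision (inferred from the conv_offset channel layout, not shown in this commit): conv_offset predicts 3 * deform_groups * kh * kw channels, split into x/y offsets plus a modulation mask, and the mask keyword of torchvision.ops.deform_conv2d only exists from torchvision 0.10.0, which is exactly what the new assert enforces.

def forward(self, x):
    # Split the offset-conv output into x-offsets, y-offsets and a mask.
    out = self.conv_offset(x)
    o1, o2, mask = torch.chunk(out, 3, dim=1)
    offset = torch.cat((o1, o2), dim=1)
    mask = torch.sigmoid(mask)
    # The `mask` keyword requires torchvision >= 0.10.0 (the assert above).
    return tv_deform_conv2d(
        x,
        offset,
        self.weight,
        bias=self.bias,
        stride=self.stride,
        padding=self.padding,
        dilation=self.dilation,
        mask=mask)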
@@ -377,31 +382,8 @@ if IS_MLU_AVAILABLE:
             otherwise False.
         """
 
-        def __init__(self,
-                     in_channels: int,
-                     out_channels: int,
-                     kernel_size: Union[int, Tuple[int]],
-                     stride: int = 1,
-                     padding: int = 0,
-                     dilation: int = 1,
-                     groups: int = 1,
-                     deform_groups: int = 1,
-                     bias: Union[bool, str] = True):
-            super().__init__()
-            self.in_channels = in_channels
-            self.out_channels = out_channels
-            self.kernel_size = _pair(kernel_size)
-            self.stride = _pair(stride)
-            self.padding = _pair(padding)
-            self.dilation = _pair(dilation)
-            self.groups = groups
-            self.deform_groups = deform_groups
-            self.weight = nn.Parameter(
-                torch.Tensor(out_channels, in_channels, *self.kernel_size))
-            if bias:
-                self.bias = nn.Parameter(torch.Tensor(out_channels))
-            else:
-                self.register_parameter('bias', None)
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
             self.conv_offset = nn.Conv2d(
                 self.in_channels,
                 self.deform_groups * 3 * self.kernel_size[0] *
@@ -409,17 +391,13 @@ if IS_MLU_AVAILABLE:
                 kernel_size=self.kernel_size,
                 stride=self.stride,
                 padding=self.padding,
+                dilation=self.dilation,
                 bias=True)
             self.init_weights()
 
         def init_weights(self):
-            n = self.in_channels
-            for k in self.kernel_size:
-                n *= k
-            stdv = 1. / math.sqrt(n)
-            self.weight.data.uniform_(-stdv, stdv)
-            if self.bias is not None:
-                self.bias.data.zero_()
-            self.conv_offset.weight.data.zero_()
-            self.conv_offset.bias.data.zero_()
+            super().init_weights()
+            if hasattr(self, 'conv_offset'):
+                self.conv_offset.weight.data.zero_()
+                self.conv_offset.bias.data.zero_()
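Why forwarding dilation to conv_offset matters: the offset/mask tensor must have exactly the same spatial size as the output of the deformable convolution, and that size depends on dilation. Before the fix, conv_offset was built with the default dilation of 1, so any DCNv2 configured with dilation > 1 produced a mismatched offset grid. A self-contained sketch of the mismatch (plain nn.Conv2d stand-ins, no MLU required; the channel count 27 = deform_groups(1) * 3 * kh(3) * kw(3)):

import torch
import torch.nn as nn

x = torch.randn(1, 3, 32, 32)

# conv_offset as built before the fix: dilation silently defaulted to 1.
buggy = nn.Conv2d(3, 27, kernel_size=3, stride=1, padding=2)
# conv_offset as built after the fix: dilation forwarded (here 2).
fixed = nn.Conv2d(3, 27, kernel_size=3, stride=1, padding=2, dilation=2)

# A deformable conv with kernel 3, padding 2, dilation 2 keeps 32x32:
# H_out = (32 + 2*2 - 2*(3 - 1) - 1) + 1 = 32
print(buggy(x).shape)  # torch.Size([1, 27, 34, 34]) -> offset grid mismatch
print(fixed(x).shape)  # torch.Size([1, 27, 32, 32]) -> matches the output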