mirror of https://github.com/open-mmlab/mmcv.git
[Bug] Fix DeformConv2d bias error and add tests (#940)
* [Bug] Fix DeformConv2d bias error and add tests * fix repr * revise tests * lintpull/942/head
parent
3bcc796d38
commit
79f8cbd661
|
@ -200,7 +200,7 @@ class DeformConv2d(nn.Module):
|
|||
channels to output channels. Default: 1.
|
||||
deform_groups (int): Number of deformable group partitions.
|
||||
bias (bool): If True, adds a learnable bias to the output.
|
||||
Default: True.
|
||||
Default: False.
|
||||
|
||||
"""
|
||||
|
||||
|
@ -234,7 +234,6 @@ class DeformConv2d(nn.Module):
|
|||
self.dilation = _pair(dilation)
|
||||
self.groups = groups
|
||||
self.deform_groups = deform_groups
|
||||
self.bias = bias
|
||||
# enable compatibility with nn.Conv2d
|
||||
self.transposed = False
|
||||
self.output_padding = _single(0)
|
||||
|
@ -301,7 +300,8 @@ class DeformConv2d(nn.Module):
|
|||
s += f'dilation={self.dilation},\n'
|
||||
s += f'groups={self.groups},\n'
|
||||
s += f'deform_groups={self.deform_groups},\n'
|
||||
s += f'bias={self.bias})'
|
||||
# bias is not supported in DeformConv2d.
|
||||
s += 'deform_groups=False)'
|
||||
return s
|
||||
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
input = [[[[1., 2., 3.], [0., 1., 2.], [3., 5., 2.]]]]
|
||||
|
@ -56,6 +57,20 @@ class TestDeformconv(object):
|
|||
assert np.allclose(model.weight.grad.detach().cpu().numpy(),
|
||||
gt_deform_weight_grad, threshold)
|
||||
|
||||
from mmcv.ops import DeformConv2d
|
||||
# test bias
|
||||
model = DeformConv2d(1, 1, 2, stride=1, padding=0)
|
||||
assert not hasattr(model, 'bias')
|
||||
# test bias=True
|
||||
with pytest.raises(AssertionError):
|
||||
model = DeformConv2d(1, 1, 2, stride=1, padding=0, bias=True)
|
||||
# test in_channels % group != 0
|
||||
with pytest.raises(AssertionError):
|
||||
model = DeformConv2d(3, 2, 3, groups=2)
|
||||
# test out_channels % group != 0
|
||||
with pytest.raises(AssertionError):
|
||||
model = DeformConv2d(3, 4, 3, groups=3)
|
||||
|
||||
def test_deformconv(self):
|
||||
self._test_deformconv(torch.double)
|
||||
self._test_deformconv(torch.float)
|
||||
|
|
Loading…
Reference in New Issue