[Fix] Fix deform conv by adding an extra argument im2col_step (#1459)

* [Fix] fix deform conv by adding an argument

* [Fix] replace useless func with np.repeat and add necessary comment

* [Fix] reduce batch_size and remove useless lines and modify some var name

* [Fix] change position of comments

* [Fix] add im2col_step in the docstring and add two test cases

* [Fix] fix docstring and add comments for test cases

* [Fix] fix docstring

* [Fix] add note, fix issue link and other details

* [Fix] fix docstring details

* [Fix] fix links in docstring

* [Fix] fix docstring details
WangJiaZhen 2021-11-10 19:05:20 +08:00 committed by GitHub
parent 0633f91139
commit e3c63f34bc
2 changed files with 82 additions and 26 deletions

mmcv/ops/deform_conv.py

@@ -117,8 +117,8 @@ class DeformConv2dFunction(Function):
grad_input = grad_offset = grad_weight = None
cur_im2col_step = min(ctx.im2col_step, input.size(0))
assert (input.size(0) %
cur_im2col_step) == 0, 'im2col step must divide batchsize'
assert (input.size(0) % cur_im2col_step
) == 0, 'batch size must be divisible by im2col_step'
grad_output = grad_output.contiguous()
if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
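
A standalone sketch of the constraint enforced by the assertion above (plain Python, not part of the patch): the effective step is capped at the batch size, and the batch size must then be a multiple of it.

for batch_size, im2col_step in [(10, 2), (1, 2), (10, 3)]:
    # The effective step never exceeds the batch size.
    cur_im2col_step = min(im2col_step, batch_size)
    divisible = batch_size % cur_im2col_step == 0
    print(batch_size, im2col_step, 'ok' if divisible else 'fails the assert')
# 10 2 ok
# 1 2 ok    (step capped to the batch size)
# 10 3 fails the assert
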
@@ -197,6 +197,13 @@ class DeformConv2d(nn.Module):
`Deformable Convolutional Networks
<https://arxiv.org/pdf/1703.06211.pdf>`_
Note:
The argument ``im2col_step`` was added in version 1.3.17, which denotes
the number of samples processed by ``im2col_cuda_kernel`` per call.
It enables users to define ``batch_size`` and ``im2col_step`` more
flexibly and resolves `issue mmcv#1440
<https://github.com/open-mmlab/mmcv/issues/1440>`_.
Args:
in_channels (int): Number of channels in the input image.
out_channels (int): Number of channels produced by the convolution.
@@ -210,7 +217,10 @@ class DeformConv2d(nn.Module):
deform_groups (int): Number of deformable group partitions.
bias (bool): If True, adds a learnable bias to the output.
Default: False.
im2col_step (int): Number of samples processed by im2col_cuda_kernel
per call. It only takes effect when ``batch_size`` > ``im2col_step``,
in which case ``batch_size`` must be divisible by ``im2col_step``.
Default: 32. `New in version 1.3.17.`
"""
@deprecated_api_warning({'deformable_groups': 'deform_groups'},
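
A minimal usage sketch of the new argument (assuming a CUDA-enabled build of mmcv >= 1.3.17; the channel and spatial sizes below are illustrative and not taken from the patch):

import torch
from mmcv.ops import DeformConv2dPack

# batch_size (8) is larger than im2col_step (2) and divisible by it, so the
# im2col CUDA kernel is launched on 2 samples per call.
conv = DeformConv2dPack(
    in_channels=3,
    out_channels=16,
    kernel_size=3,
    padding=1,
    im2col_step=2).cuda()

x = torch.randn(8, 3, 32, 32, device='cuda', requires_grad=True)
out = conv(x)          # (8, 16, 32, 32)
out.sum().backward()

# With im2col_step=3 the same input would trigger
# 'batch size must be divisible by im2col_step'.
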
@@ -224,7 +234,8 @@ class DeformConv2d(nn.Module):
dilation: Union[int, Tuple[int, ...]] = 1,
groups: int = 1,
deform_groups: int = 1,
bias: bool = False) -> None:
bias: bool = False,
im2col_step: int = 32) -> None:
super(DeformConv2d, self).__init__()
assert not bias, \
@@ -243,6 +254,7 @@ class DeformConv2d(nn.Module):
self.dilation = _pair(dilation)
self.groups = groups
self.deform_groups = deform_groups
self.im2col_step = im2col_step
# enable compatibility with nn.Conv2d
self.transposed = False
self.output_padding = _single(0)
@@ -293,7 +305,8 @@ class DeformConv2d(nn.Module):
offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant', 0)
offset = offset.contiguous()
out = deform_conv2d(x, offset, self.weight, self.stride, self.padding,
self.dilation, self.groups, self.deform_groups)
self.dilation, self.groups, self.deform_groups,
False, self.im2col_step)
if input_pad:
out = out[:, :, :out.size(2) - pad_h, :out.size(3) -
pad_w].contiguous()
@@ -361,7 +374,8 @@ class DeformConv2dPack(DeformConv2d):
def forward(self, x):
offset = self.conv_offset(x)
return deform_conv2d(x, offset, self.weight, self.stride, self.padding,
self.dilation, self.groups, self.deform_groups)
self.dilation, self.groups, self.deform_groups,
False, self.im2col_step)
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):

tests/test_ops/test_deform_conv.py

@@ -39,15 +39,27 @@ class TestDeformconv(object):
def _test_deformconv(self,
dtype=torch.float,
threshold=1e-3,
device='cuda'):
device='cuda',
batch_size=10,
im2col_step=2):
if not torch.cuda.is_available() and device == 'cuda':
pytest.skip('test requires GPU')
from mmcv.ops import DeformConv2dPack
c_in = 1
c_out = 1
x = torch.tensor(input, device=device, dtype=dtype)
batch_size = 10
repeated_input = np.repeat(input, batch_size, axis=0)
repeated_gt_out = np.repeat(gt_out, batch_size, axis=0)
repeated_gt_x_grad = np.repeat(gt_x_grad, batch_size, axis=0)
x = torch.tensor(repeated_input, device=device, dtype=dtype)
x.requires_grad = True
model = DeformConv2dPack(c_in, c_out, 2, stride=1, padding=0)
model = DeformConv2dPack(
in_channels=c_in,
out_channels=c_out,
kernel_size=2,
stride=1,
padding=0,
im2col_step=im2col_step)
model.conv_offset.weight.data = torch.nn.Parameter(
torch.Tensor(offset_weight).reshape(8, 1, 2, 2))
model.conv_offset.bias.data = torch.nn.Parameter(
@@ -61,15 +73,21 @@ class TestDeformconv(object):
out = model(x)
out.backward(torch.ones_like(out))
assert np.allclose(out.data.detach().cpu().numpy(), gt_out, threshold)
assert np.allclose(x.grad.detach().cpu().numpy(), gt_x_grad, threshold)
assert np.allclose(out.data.detach().cpu().numpy(), repeated_gt_out,
threshold)
assert np.allclose(x.grad.detach().cpu().numpy(), repeated_gt_x_grad,
threshold)
# the input is repeated along the batch dimension, so the accumulated
# parameter gradients are batch_size times larger and must be divided
# by batch_size (see the standalone sketch after this hunk)
assert np.allclose(
model.conv_offset.weight.grad.detach().cpu().numpy(),
model.conv_offset.weight.grad.detach().cpu().numpy() / batch_size,
gt_offset_weight_grad, threshold)
assert np.allclose(model.conv_offset.bias.grad.detach().cpu().numpy(),
gt_offset_bias_grad, threshold)
assert np.allclose(model.weight.grad.detach().cpu().numpy(),
gt_deform_weight_grad, threshold)
assert np.allclose(
model.conv_offset.bias.grad.detach().cpu().numpy() / batch_size,
gt_offset_bias_grad, threshold)
assert np.allclose(
model.weight.grad.detach().cpu().numpy() / batch_size,
gt_deform_weight_grad, threshold)
from mmcv.ops import DeformConv2d
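
The division by batch_size above reflects that the repeated samples are identical, so the summed parameter gradients grow linearly with the batch size. A standalone sketch with a plain nn.Conv2d (independent of mmcv) shows the same effect:

import torch

conv = torch.nn.Conv2d(1, 1, kernel_size=2, bias=False)
x = torch.randn(1, 1, 4, 4)

conv(x).sum().backward()
single_grad = conv.weight.grad.clone()

conv.zero_grad()
conv(x.repeat(10, 1, 1, 1)).sum().backward()
repeated_grad = conv.weight.grad

# Gradients are summed over the batch, so 10 identical samples give 10x grad.
assert torch.allclose(repeated_grad, 10 * single_grad, atol=1e-5)
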
@@ -86,7 +104,11 @@ class TestDeformconv(object):
with pytest.raises(AssertionError):
model = DeformConv2d(3, 4, 3, groups=3)
def _test_amp_deformconv(self, input_dtype, threshold=1e-3):
def _test_amp_deformconv(self,
input_dtype,
threshold=1e-3,
batch_size=10,
im2col_step=2):
"""The function to test amp released on pytorch 1.6.0.
The type of input data might be torch.float or torch.half,
@@ -102,9 +124,18 @@ class TestDeformconv(object):
from mmcv.ops import DeformConv2dPack
c_in = 1
c_out = 1
x = torch.Tensor(input).cuda().type(input_dtype)
repeated_input = np.repeat(input, batch_size, axis=0)
repeated_gt_out = np.repeat(gt_out, batch_size, axis=0)
repeated_gt_x_grad = np.repeat(gt_x_grad, batch_size, axis=0)
x = torch.Tensor(repeated_input).cuda().type(input_dtype)
x.requires_grad = True
model = DeformConv2dPack(c_in, c_out, 2, stride=1, padding=0)
model = DeformConv2dPack(
in_channels=c_in,
out_channels=c_out,
kernel_size=2,
stride=1,
padding=0,
im2col_step=im2col_step)
model.conv_offset.weight.data = torch.nn.Parameter(
torch.Tensor(offset_weight).reshape(8, 1, 2, 2))
model.conv_offset.bias.data = torch.nn.Parameter(
@@ -116,15 +147,19 @@ class TestDeformconv(object):
out = model(x)
out.backward(torch.ones_like(out))
assert np.allclose(out.data.detach().cpu().numpy(), gt_out, threshold)
assert np.allclose(x.grad.detach().cpu().numpy(), gt_x_grad, threshold)
assert np.allclose(out.data.detach().cpu().numpy(), repeated_gt_out,
threshold)
assert np.allclose(x.grad.detach().cpu().numpy(), repeated_gt_x_grad,
threshold)
assert np.allclose(
model.conv_offset.weight.grad.detach().cpu().numpy(),
model.conv_offset.weight.grad.detach().cpu().numpy() / batch_size,
gt_offset_weight_grad, threshold)
assert np.allclose(model.conv_offset.bias.grad.detach().cpu().numpy(),
gt_offset_bias_grad, threshold)
assert np.allclose(model.weight.grad.detach().cpu().numpy(),
gt_deform_weight_grad, threshold)
assert np.allclose(
model.conv_offset.bias.grad.detach().cpu().numpy() / batch_size,
gt_offset_bias_grad, threshold)
assert np.allclose(
model.weight.grad.detach().cpu().numpy() / batch_size,
gt_deform_weight_grad, threshold)
from mmcv.ops import DeformConv2d
@@ -147,6 +182,13 @@ class TestDeformconv(object):
self._test_deformconv(torch.double)
self._test_deformconv(torch.float)
self._test_deformconv(torch.half, threshold=1e-1)
# test batch_size < im2col_step
self._test_deformconv(torch.float, batch_size=1, im2col_step=2)
# test batch_size % im2col_step != 0
with pytest.raises(
AssertionError,
match='batch size must be divisible by im2col_step'):
self._test_deformconv(torch.float, batch_size=10, im2col_step=3)
# test amp when torch version >= '1.6.0', the type of
# input data for deformconv might be torch.float or torch.half