mirror of https://github.com/open-mmlab/mmcv.git
[Fix] Fix deform conv by adding an extra argument im2col_step (#1459)
* [Fix] fix deform conv by adding an extra argument
* [Fix] replace a useless function with np.repeat and add a necessary comment
* [Fix] reduce batch_size, remove useless lines and rename some variables
* [Fix] change the position of comments
* [Fix] add im2col_step to the docstring and add two test cases
* [Fix] fix the docstring and add comments for the test cases
* [Fix] fix the docstring
* [Fix] add a note, fix the issue link and other details
* [Fix] fix docstring details
* [Fix] fix links in the docstring
* [Fix] fix docstring details

pull/1489/head
parent 0633f91139
commit e3c63f34bc
--- a/mmcv/ops/deform_conv.py
+++ b/mmcv/ops/deform_conv.py
@@ -117,8 +117,8 @@ class DeformConv2dFunction(Function):
         grad_input = grad_offset = grad_weight = None

         cur_im2col_step = min(ctx.im2col_step, input.size(0))
-        assert (input.size(0) %
-                cur_im2col_step) == 0, 'im2col step must divide batchsize'
+        assert (input.size(0) % cur_im2col_step
+                ) == 0, 'batch size must be divisible by im2col_step'

         grad_output = grad_output.contiguous()
         if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
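This hunk only rewords the assertion, but the contract it enforces is worth spelling out: the step is first clamped to the batch size, so batches smaller than ``im2col_step`` remain valid, and only a batch that the effective step does not divide evenly is rejected. A minimal sketch of that logic (``effective_im2col_step`` is a hypothetical helper, not an mmcv function):

```python
def effective_im2col_step(batch_size: int, im2col_step: int) -> int:
    # Clamp the step so batches smaller than im2col_step remain valid.
    step = min(im2col_step, batch_size)
    # Mirrors the assert in DeformConv2dFunction: the kernel walks the
    # batch in chunks of `step`, so the batch must split evenly.
    assert batch_size % step == 0, \
        'batch size must be divisible by im2col_step'
    return step

assert effective_im2col_step(1, 2) == 1   # batch_size < im2col_step is fine
assert effective_im2col_step(10, 2) == 2  # 5 chunks of 2 samples each
# effective_im2col_step(10, 3) would raise: 10 % 3 != 0
```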
@@ -197,6 +197,13 @@ class DeformConv2d(nn.Module):
     `Deformable Convolutional Networks
     <https://arxiv.org/pdf/1703.06211.pdf>`_

+    Note:
+        The argument ``im2col_step`` was added in version 1.3.17. It
+        specifies the number of samples processed by the
+        ``im2col_cuda_kernel`` per call, enables users to choose
+        ``batch_size`` and ``im2col_step`` more flexibly, and resolves
+        `issue mmcv#1440 <https://github.com/open-mmlab/mmcv/issues/1440>`_.
+
     Args:
         in_channels (int): Number of channels in the input image.
         out_channels (int): Number of channels produced by the convolution.
@@ -210,7 +217,10 @@ class DeformConv2d(nn.Module):
         deform_groups (int): Number of deformable group partitions.
         bias (bool): If True, adds a learnable bias to the output.
             Default: False.
+        im2col_step (int): Number of samples processed by im2col_cuda_kernel
+            per call. It takes effect when ``batch_size`` > ``im2col_step``,
+            and ``batch_size`` must be divisible by ``im2col_step``.
+            Default: 32. `New in version 1.3.17.`
     """

     @deprecated_api_warning({'deformable_groups': 'deform_groups'},
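A hedged usage sketch of the new argument (channel counts and shapes are illustrative, and the op requires CUDA):

```python
import torch
from mmcv.ops import DeformConv2dPack

# batch of 8 samples; the im2col kernel processes 4 samples per call
conv = DeformConv2dPack(
    in_channels=3,
    out_channels=16,
    kernel_size=3,
    padding=1,
    im2col_step=4).cuda()
x = torch.randn(8, 3, 32, 32).cuda()  # 8 % 4 == 0, so the batch is valid
out = conv(x)                         # -> (8, 16, 32, 32)
```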
@@ -224,7 +234,8 @@ class DeformConv2d(nn.Module):
                  dilation: Union[int, Tuple[int, ...]] = 1,
                  groups: int = 1,
                  deform_groups: int = 1,
-                 bias: bool = False) -> None:
+                 bias: bool = False,
+                 im2col_step: int = 32) -> None:
         super(DeformConv2d, self).__init__()

         assert not bias, \
@@ -243,6 +254,7 @@ class DeformConv2d(nn.Module):
         self.dilation = _pair(dilation)
         self.groups = groups
         self.deform_groups = deform_groups
+        self.im2col_step = im2col_step
         # enable compatibility with nn.Conv2d
         self.transposed = False
         self.output_padding = _single(0)
@@ -293,7 +305,8 @@ class DeformConv2d(nn.Module):
             offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant', 0)
         offset = offset.contiguous()
         out = deform_conv2d(x, offset, self.weight, self.stride, self.padding,
-                            self.dilation, self.groups, self.deform_groups)
+                            self.dilation, self.groups, self.deform_groups,
+                            False, self.im2col_step)
         if input_pad:
             out = out[:, :, :out.size(2) - pad_h, :out.size(3) -
                       pad_w].contiguous()
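Judging from this call pattern, the positional ``False`` is the functional ``deform_conv2d``'s ``bias`` flag, which sits just before ``im2col_step``. A hedged sketch of calling the functional interface directly under that assumption (shapes illustrative, CUDA required):

```python
import torch
from mmcv.ops import deform_conv2d

x = torch.randn(4, 3, 16, 16).cuda()
weight = torch.randn(8, 3, 3, 3).cuda()
# offset: (N, 2 * deform_groups * kh * kw, out_h, out_w)
offset = torch.randn(4, 18, 14, 14).cuda()
out = deform_conv2d(x, offset, weight, 1, 0, 1, 1, 1, False, 2)
# positional args: stride, padding, dilation, groups, deform_groups,
#                  bias=False, im2col_step=2 (batch 4 % 2 == 0)
```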
@@ -361,7 +374,8 @@ class DeformConv2dPack(DeformConv2d):
     def forward(self, x):
         offset = self.conv_offset(x)
         return deform_conv2d(x, offset, self.weight, self.stride, self.padding,
-                             self.dilation, self.groups, self.deform_groups)
+                             self.dilation, self.groups, self.deform_groups,
+                             False, self.im2col_step)

     def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                               missing_keys, unexpected_keys, error_msgs):

--- a/tests/test_ops/test_deform_conv.py
+++ b/tests/test_ops/test_deform_conv.py
@@ -39,15 +39,27 @@ class TestDeformconv(object):
     def _test_deformconv(self,
                          dtype=torch.float,
                          threshold=1e-3,
-                         device='cuda'):
+                         device='cuda',
+                         batch_size=10,
+                         im2col_step=2):
         if not torch.cuda.is_available() and device == 'cuda':
             pytest.skip('test requires GPU')
         from mmcv.ops import DeformConv2dPack
         c_in = 1
         c_out = 1
-        x = torch.tensor(input, device=device, dtype=dtype)
-        batch_size = 10
+        repeated_input = np.repeat(input, batch_size, axis=0)
+        repeated_gt_out = np.repeat(gt_out, batch_size, axis=0)
+        repeated_gt_x_grad = np.repeat(gt_x_grad, batch_size, axis=0)
+        x = torch.tensor(repeated_input, device=device, dtype=dtype)
         x.requires_grad = True
-        model = DeformConv2dPack(c_in, c_out, 2, stride=1, padding=0)
+        model = DeformConv2dPack(
+            in_channels=c_in,
+            out_channels=c_out,
+            kernel_size=2,
+            stride=1,
+            padding=0,
+            im2col_step=im2col_step)
         model.conv_offset.weight.data = torch.nn.Parameter(
             torch.Tensor(offset_weight).reshape(8, 1, 2, 2))
         model.conv_offset.bias.data = torch.nn.Parameter(
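``np.repeat(..., axis=0)`` simply stacks copies of the single-sample fixture along the batch dimension, so every sample in the enlarged batch is identical; a tiny illustration (``sample`` stands in for the test's ``input`` fixture):

```python
import numpy as np

sample = np.ones((1, 1, 4, 4))            # stand-in for the `input` fixture
repeated = np.repeat(sample, 10, axis=0)  # batch_size = 10
assert repeated.shape == (10, 1, 4, 4)
assert (repeated == sample[0]).all()      # every sample is an exact copy
```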
@@ -61,15 +73,21 @@ class TestDeformconv(object):
         out = model(x)
         out.backward(torch.ones_like(out))

-        assert np.allclose(out.data.detach().cpu().numpy(), gt_out, threshold)
-        assert np.allclose(x.grad.detach().cpu().numpy(), gt_x_grad, threshold)
+        assert np.allclose(out.data.detach().cpu().numpy(), repeated_gt_out,
+                           threshold)
+        assert np.allclose(x.grad.detach().cpu().numpy(), repeated_gt_x_grad,
+                           threshold)
+        # the batch size of the input is increased which results in
+        # a larger gradient so we need to divide by the batch_size
         assert np.allclose(
-            model.conv_offset.weight.grad.detach().cpu().numpy(),
+            model.conv_offset.weight.grad.detach().cpu().numpy() / batch_size,
             gt_offset_weight_grad, threshold)
-        assert np.allclose(model.conv_offset.bias.grad.detach().cpu().numpy(),
-                           gt_offset_bias_grad, threshold)
-        assert np.allclose(model.weight.grad.detach().cpu().numpy(),
-                           gt_deform_weight_grad, threshold)
+        assert np.allclose(
+            model.conv_offset.bias.grad.detach().cpu().numpy() / batch_size,
+            gt_offset_bias_grad, threshold)
+        assert np.allclose(
+            model.weight.grad.detach().cpu().numpy() / batch_size,
+            gt_deform_weight_grad, threshold)

         from mmcv.ops import DeformConv2d
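The division by ``batch_size`` follows from linearity of the gradient: with ``torch.ones_like(out)`` as the backward seed, each repeated sample contributes an identical term to every parameter gradient, so repeating the input N times scales those gradients by N. A self-contained check of that fact with a plain ``nn.Conv2d`` stand-in (not the mmcv op):

```python
import torch

torch.manual_seed(0)
conv = torch.nn.Conv2d(1, 1, 2, bias=False)
x = torch.randn(1, 1, 4, 4)

conv(x).backward(torch.ones(1, 1, 3, 3))
single_grad = conv.weight.grad.clone()

conv.weight.grad = None
conv(x.repeat(10, 1, 1, 1)).backward(torch.ones(10, 1, 3, 3))
# Gradient from the repeated batch is exactly 10x the single-sample one.
assert torch.allclose(conv.weight.grad / 10, single_grad, atol=1e-6)
```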
@@ -86,7 +104,11 @@ class TestDeformconv(object):
         with pytest.raises(AssertionError):
             model = DeformConv2d(3, 4, 3, groups=3)

-    def _test_amp_deformconv(self, input_dtype, threshold=1e-3):
+    def _test_amp_deformconv(self,
+                             input_dtype,
+                             threshold=1e-3,
+                             batch_size=10,
+                             im2col_step=2):
         """The function to test amp released on pytorch 1.6.0.

         The type of input data might be torch.float or torch.half,
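For context, a hedged sketch of exercising the op under amp on torch >= 1.6.0, in the spirit of this test (GPU required; constructor arguments follow the signature in this diff):

```python
import torch
from mmcv.ops import DeformConv2dPack

conv = DeformConv2dPack(1, 1, kernel_size=2, im2col_step=2).cuda()
x = torch.randn(10, 1, 8, 8).cuda()  # batch 10 is divisible by im2col_step 2
with torch.cuda.amp.autocast():
    out = conv(x)  # runs under mixed precision, as the amp test does
out.float().sum().backward()
```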
@@ -102,9 +124,18 @@ class TestDeformconv(object):
         from mmcv.ops import DeformConv2dPack
         c_in = 1
         c_out = 1
-        x = torch.Tensor(input).cuda().type(input_dtype)
+        repeated_input = np.repeat(input, batch_size, axis=0)
+        repeated_gt_out = np.repeat(gt_out, batch_size, axis=0)
+        repeated_gt_x_grad = np.repeat(gt_x_grad, batch_size, axis=0)
+        x = torch.Tensor(repeated_input).cuda().type(input_dtype)
         x.requires_grad = True
-        model = DeformConv2dPack(c_in, c_out, 2, stride=1, padding=0)
+        model = DeformConv2dPack(
+            in_channels=c_in,
+            out_channels=c_out,
+            kernel_size=2,
+            stride=1,
+            padding=0,
+            im2col_step=im2col_step)
         model.conv_offset.weight.data = torch.nn.Parameter(
             torch.Tensor(offset_weight).reshape(8, 1, 2, 2))
         model.conv_offset.bias.data = torch.nn.Parameter(
@@ -116,15 +147,19 @@ class TestDeformconv(object):
         out = model(x)
         out.backward(torch.ones_like(out))

-        assert np.allclose(out.data.detach().cpu().numpy(), gt_out, threshold)
-        assert np.allclose(x.grad.detach().cpu().numpy(), gt_x_grad, threshold)
+        assert np.allclose(out.data.detach().cpu().numpy(), repeated_gt_out,
+                           threshold)
+        assert np.allclose(x.grad.detach().cpu().numpy(), repeated_gt_x_grad,
+                           threshold)
         assert np.allclose(
-            model.conv_offset.weight.grad.detach().cpu().numpy(),
+            model.conv_offset.weight.grad.detach().cpu().numpy() / batch_size,
             gt_offset_weight_grad, threshold)
-        assert np.allclose(model.conv_offset.bias.grad.detach().cpu().numpy(),
-                           gt_offset_bias_grad, threshold)
-        assert np.allclose(model.weight.grad.detach().cpu().numpy(),
-                           gt_deform_weight_grad, threshold)
+        assert np.allclose(
+            model.conv_offset.bias.grad.detach().cpu().numpy() / batch_size,
+            gt_offset_bias_grad, threshold)
+        assert np.allclose(
+            model.weight.grad.detach().cpu().numpy() / batch_size,
+            gt_deform_weight_grad, threshold)

         from mmcv.ops import DeformConv2d
@@ -147,6 +182,13 @@ class TestDeformconv(object):
         self._test_deformconv(torch.double)
         self._test_deformconv(torch.float)
         self._test_deformconv(torch.half, threshold=1e-1)
+        # test batch_size < im2col_step
+        self._test_deformconv(torch.float, batch_size=1, im2col_step=2)
+        # test batch_size % im2col_step != 0
+        with pytest.raises(
+                AssertionError,
+                match='batch size must be divisible by im2col_step'):
+            self._test_deformconv(torch.float, batch_size=10, im2col_step=3)

         # test amp when torch version >= '1.6.0', the type of
         # input data for deformconv might be torch.float or torch.half
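At the user level, the renamed assertion fires as soon as the op runs with an incompatible batch; a hedged sketch (CUDA required, and the printed message assumes the new assert text from this diff):

```python
import torch
from mmcv.ops import DeformConv2dPack

conv = DeformConv2dPack(3, 8, kernel_size=3, padding=1, im2col_step=4).cuda()
x = torch.randn(6, 3, 16, 16).cuda()  # 6 % 4 != 0, and 6 > 4, so no clamping
try:
    conv(x)
except AssertionError as e:
    print(e)  # batch size must be divisible by im2col_step
```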