[Fix] Support amp (pytorch >= 1.6.0) on DCN and DCNv2 / Add unit tests on DCN/DCNv2 amp (#1029)

* fix fp16 bug on DCNv2

* support fp16 on DCN/DCNv2 when pytorch >= '1.6.0'

* add comment

* Modified the comments

* Unified the usages of '.to()' and '.type_as()'
Guangchen Lin 2021-05-23 17:42:59 +08:00 committed by GitHub
parent e9f2a02b47
commit 4bd3b5027a
4 changed files with 141 additions and 2 deletions


@@ -70,8 +70,14 @@ class DeformConv2dFunction(Function):
ctx.deform_groups = deform_groups
ctx.im2col_step = im2col_step
# When the PyTorch version is >= 1.6.0, amp is adopted for fp16 mode;
# amp will not cast the model's parameters (they stay float32), but
# "offset" is cast to float16 automatically by nn.Conv2d, leading to a
# dtype mismatch with the input (when it is float32) or the weight.
# The dtype of "offset" therefore signals whether plain fp16 or amp is
# in use, so we cast the input and weight to match it, temporarily
# supporting both fp16 and amp whatever the PyTorch version is.
input = input.type_as(offset)
weight = weight.type_as(input)
ctx.save_for_backward(input, offset, weight)
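
The mismatch this comment describes is easy to reproduce. A minimal sketch (mine, not from the diff; assumes a CUDA device and PyTorch >= 1.6.0, with illustrative shapes):

```python
import torch
import torch.nn as nn

# Stand-in for the conv layer that produces the deformable offsets.
conv_offset = nn.Conv2d(3, 18, kernel_size=3, padding=1).cuda()
x = torch.randn(1, 3, 8, 8, device='cuda')  # float32 input

with torch.cuda.amp.autocast(enabled=True):
    offset = conv_offset(x)       # autocast runs Conv2d in float16
    print(x.dtype, offset.dtype)  # torch.float32 torch.float16
    # The fix above aligns every operand with the offset dtype:
    x = x.type_as(offset)
    assert x.dtype == offset.dtype
```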


@@ -57,6 +57,15 @@ class ModulatedDeformConv2dFunction(Function):
ctx.with_bias = bias is not None
if not ctx.with_bias:
bias = input.new_empty(0) # fake tensor
# When the PyTorch version is >= 1.6.0, amp is adopted for fp16 mode;
# amp will not cast the model's parameters (they stay float32), but
# "offset" is cast to float16 automatically by nn.Conv2d, leading to a
# dtype mismatch with the input (when it is float32) or the weight.
# The dtype of "offset" therefore signals whether plain fp16 or amp is
# in use, so we cast the input and weight to match it, temporarily
# supporting both fp16 and amp whatever the PyTorch version is.
input = input.type_as(offset)
weight = weight.type_as(input)
ctx.save_for_backward(input, offset, mask, weight, bias)
output = input.new_empty(
ModulatedDeformConv2dFunction._output_size(ctx, input, weight))
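
One of the commit bullets mentions unifying '.to()' and '.type_as()'. For a dtype-only cast the two are equivalent, which is why the fix can standardize on type_as; a small CPU-only illustration (my sketch, not part of the diff):

```python
import torch

w = torch.randn(2, 2)          # float32, stands in for the weight
o = torch.randn(2, 2).half()   # float16, stands in for the offset

# type_as(o) matches both dtype and device of `o`; to(o.dtype) matches
# only the dtype. For the cast itself both yield float16 here.
assert w.type_as(o).dtype == w.to(o.dtype).dtype == torch.float16
```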


@@ -2,6 +2,15 @@ import numpy as np
import pytest
import torch
from mmcv.utils import TORCH_VERSION
try:
# On PyTorch >= 1.6.0, torch.cuda.amp.autocast is available and is
# used for fp16 mode; we should test whether our modules support it.
# On older versions the import fails and the amp tests are skipped.
from torch.cuda.amp import autocast
except ImportError:
pass
input = [[[[1., 2., 3.], [0., 1., 2.], [3., 5., 2.]]]]
offset_weight = [[[0.1, 0.4, 0.6, 0.1]], [[0.3, 0.2, 0.1, 0.3]],
[[0.5, 0.5, 0.2, 0.8]], [[0.8, 0.3, 0.9, 0.1]],
@@ -71,7 +80,68 @@ class TestDeformconv(object):
with pytest.raises(AssertionError):
model = DeformConv2d(3, 4, 3, groups=3)
def _test_amp_deformconv(self, input_dtype, threshold=1e-3):
"""The function to test amp released on pytorch 1.6.0.
The type of input data might be torch.float or torch.half,
so we should test deform_conv in both cases. With amp, the
data type of model will NOT be set manually.
Args:
input_dtype: torch.float or torch.half.
threshold: the same as above function.
"""
if not torch.cuda.is_available():
return
from mmcv.ops import DeformConv2dPack
c_in = 1
c_out = 1
x = torch.Tensor(input).cuda().type(input_dtype)
x.requires_grad = True
model = DeformConv2dPack(c_in, c_out, 2, stride=1, padding=0)
model.conv_offset.weight.data = torch.nn.Parameter(
torch.Tensor(offset_weight).reshape(8, 1, 2, 2))
model.conv_offset.bias.data = torch.nn.Parameter(
torch.Tensor(offset_bias).reshape(8))
model.weight.data = torch.nn.Parameter(
torch.Tensor(deform_weight).reshape(1, 1, 2, 2))
model.cuda()
out = model(x)
out.backward(torch.ones_like(out))
assert np.allclose(out.data.detach().cpu().numpy(), gt_out, threshold)
assert np.allclose(x.grad.detach().cpu().numpy(), gt_x_grad, threshold)
assert np.allclose(
model.conv_offset.weight.grad.detach().cpu().numpy(),
gt_offset_weight_grad, threshold)
assert np.allclose(model.conv_offset.bias.grad.detach().cpu().numpy(),
gt_offset_bias_grad, threshold)
assert np.allclose(model.weight.grad.detach().cpu().numpy(),
gt_deform_weight_grad, threshold)
from mmcv.ops import DeformConv2d
# test bias
model = DeformConv2d(1, 1, 2, stride=1, padding=0)
assert not hasattr(model, 'bias')
# test bias=True
with pytest.raises(AssertionError):
model = DeformConv2d(1, 1, 2, stride=1, padding=0, bias=True)
# test in_channels % group != 0
with pytest.raises(AssertionError):
model = DeformConv2d(3, 2, 3, groups=2)
# test out_channels % group != 0
with pytest.raises(AssertionError):
model = DeformConv2d(3, 4, 3, groups=3)
def test_deformconv(self):
self._test_deformconv(torch.double)
self._test_deformconv(torch.float)
self._test_deformconv(torch.half, 1e-1)
# Test amp when the torch version is >= '1.6.0'; the input data
# type for deform_conv may be torch.float or torch.half.
if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0':
with autocast(enabled=True):
self._test_amp_deformconv(torch.float, 1e-1)
self._test_amp_deformconv(torch.half, 1e-1)
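
Outside the unit tests, amp training normally pairs autocast with a gradient scaler. A hedged end-to-end sketch of how the patched DeformConv2dPack would be exercised (standard torch.cuda.amp usage, not part of this commit; assumes a CUDA device):

```python
import torch
from torch.cuda.amp import GradScaler, autocast
from mmcv.ops import DeformConv2dPack

model = DeformConv2dPack(3, 8, kernel_size=3, padding=1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = GradScaler()

for _ in range(2):
    x = torch.randn(2, 3, 16, 16, device='cuda')
    optimizer.zero_grad()
    with autocast(enabled=True):
        # conv_offset emits a float16 offset; the Function casts the
        # float32 input and weight to match, so the forward succeeds.
        loss = model(x).float().mean()
    scaler.scale(loss).backward()  # scale to avoid fp16 underflow
    scaler.step(optimizer)
    scaler.update()
```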


@@ -3,6 +3,15 @@ import os
import numpy
import torch
from mmcv.utils import TORCH_VERSION
try:
# On PyTorch >= 1.6.0, torch.cuda.amp.autocast is available and is
# used for fp16 mode; we should test whether our modules support it.
# On older versions the import fails and the amp tests are skipped.
from torch.cuda.amp import autocast
except ImportError:
pass
cur_dir = os.path.dirname(os.path.abspath(__file__))
input_t = [[[[1., 2., 3.], [1., 2., 3.], [1., 2., 3.]]]]
@@ -58,7 +67,52 @@ class TestMdconv(object):
assert numpy.allclose(dcn.conv_offset.bias.grad.cpu().detach().numpy(),
dcn_offset_b_grad, 1e-2)
def _test_amp_mdconv(self, input_dtype=torch.float):
"""The function to test amp released on pytorch 1.6.0.
The type of input data might be torch.float or torch.half,
so we should test mdconv in both cases. With amp, the data
type of model will NOT be set manually.
Args:
input_dtype: torch.float or torch.half.
"""
if not torch.cuda.is_available():
return
from mmcv.ops import ModulatedDeformConv2dPack
input = torch.tensor(input_t).cuda().type(input_dtype)
input.requires_grad = True
dcn = ModulatedDeformConv2dPack(
1,
1,
kernel_size=(2, 2),
stride=1,
padding=1,
deform_groups=1,
bias=False).cuda()
dcn.weight.data.fill_(1.)
output = dcn(input)
output.sum().backward()
assert numpy.allclose(output.cpu().detach().numpy(), output_t, 1e-2)
assert numpy.allclose(input.grad.cpu().detach().numpy(), input_grad,
1e-2)
assert numpy.allclose(dcn.weight.grad.cpu().detach().numpy(),
dcn_w_grad, 1e-2)
assert numpy.allclose(
dcn.conv_offset.weight.grad.cpu().detach().numpy(),
dcn_offset_w_grad, 1e-2)
assert numpy.allclose(dcn.conv_offset.bias.grad.cpu().detach().numpy(),
dcn_offset_b_grad, 1e-2)
def test_mdconv(self):
self._test_mdconv(torch.double)
self._test_mdconv(torch.float)
self._test_mdconv(torch.half)
# Test amp when the torch version is >= '1.6.0'; the input data
# type for mdconv may be torch.float or torch.half.
if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0':
with autocast(enabled=True):
self._test_amp_mdconv(torch.float)
self._test_amp_mdconv(torch.half)
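
For completeness, a quick interactive check that mirrors what `_test_amp_mdconv` exercises (my sketch; assumes a CUDA device and PyTorch >= 1.6.0):

```python
import torch
from torch.cuda.amp import autocast
from mmcv.ops import ModulatedDeformConv2dPack

dcn = ModulatedDeformConv2dPack(
    1, 1, kernel_size=(2, 2), stride=1, padding=1,
    deform_groups=1, bias=False).cuda()
x = torch.randn(1, 1, 3, 3, device='cuda')  # float32 input

with autocast(enabled=True):
    out = dcn(x)  # no dtype-mismatch error thanks to the casts above
print(out.dtype)
```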