[Enhance] Add AMP support for MLU_DCNv2 (#2548)

pull/2560/head
mengpenghui 2023-01-13 17:51:02 +08:00 committed by GitHub
parent c310d28c8f
commit 71ee2a61f2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 8 deletions

View File

@ -406,10 +406,13 @@ if IS_MLU_AVAILABLE:
o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1)
mask = torch.sigmoid(mask)
x = x.type_as(offset)
weight = self.weight.type_as(x)
mask = mask.type_as(x)
return tv_deform_conv2d(
x,
offset,
self.weight,
weight,
bias=self.bias,
stride=self.stride,
padding=self.padding,

View File

@ -74,7 +74,7 @@ class TestMdconv:
assert numpy.allclose(dcn.conv_offset.bias.grad.cpu().detach().numpy(),
dcn_offset_b_grad, 1e-2)
def _test_amp_mdconv(self, input_dtype=torch.float):
def _test_amp_mdconv(self, input_dtype=torch.float, device='cuda'):
"""The function to test amp released on pytorch 1.6.0.
The type of input data might be torch.float or torch.half,
@ -84,10 +84,15 @@ class TestMdconv:
Args:
input_dtype: torch.float or torch.half.
"""
if not torch.cuda.is_available():
if not torch.cuda.is_available() and device == 'cuda':
return
if device == 'mlu':
from mmcv.ops import \
ModulatedDeformConv2dPack_MLU as ModulatedDeformConv2dPack
else:
from mmcv.ops import ModulatedDeformConv2dPack
input = torch.tensor(input_t).cuda().type(input_dtype)
input = torch.tensor(input_t).to(device).type(input_dtype)
input.requires_grad = True
dcn = ModulatedDeformConv2dPack(
@ -97,7 +102,7 @@ class TestMdconv:
stride=1,
padding=1,
deform_groups=1,
bias=False).cuda()
bias=False).to(device)
dcn.weight.data.fill_(1.)
output = dcn(input)
output.sum().backward()
@ -126,5 +131,5 @@ class TestMdconv:
if (TORCH_VERSION != 'parrots'
and digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
with autocast(enabled=True):
self._test_amp_mdconv(torch.float)
self._test_amp_mdconv(torch.half)
self._test_amp_mdconv(torch.float, device=device)
self._test_amp_mdconv(torch.half, device=device)