mirror of https://github.com/open-mmlab/mmcv.git
[Enhance] Add AMP support for MLU_DCNv2 (#2548)
parent c310d28c8f
commit 71ee2a61f2
@@ -406,10 +406,13 @@ if IS_MLU_AVAILABLE:
             o1, o2, mask = torch.chunk(out, 3, dim=1)
             offset = torch.cat((o1, o2), dim=1)
             mask = torch.sigmoid(mask)
+            x = x.type_as(offset)
+            weight = self.weight.type_as(x)
+            mask = mask.type_as(x)
             return tv_deform_conv2d(
                 x,
                 offset,
-                self.weight,
+                weight,
                 bias=self.bias,
                 stride=self.stride,
                 padding=self.padding,
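The three new type_as casts are the substance of the AMP fix: under autocast, conv_offset can emit half-precision offsets and masks while x and self.weight are still float32, and torchvision's deform_conv2d expects all tensor arguments to share a dtype. A minimal sketch of that alignment pattern, outside of mmcv and with an illustrative function name, assuming nothing beyond plain PyTorch:

# Illustrative sketch (not the mmcv source): align every tensor to the dtype
# of the offset branch before calling torchvision's deform_conv2d, mirroring
# the type_as calls added in the hunk above.
import torch

def align_dtypes(x: torch.Tensor, offset: torch.Tensor,
                 mask: torch.Tensor, weight: torch.Tensor):
    x = x.type_as(offset)       # follow the dtype autocast gave the offsets
    weight = weight.type_as(x)  # cast the conv weight to match
    mask = mask.type_as(x)      # and the modulation mask as well
    return x, offset, mask, weight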
@@ -74,7 +74,7 @@ class TestMdconv:
         assert numpy.allclose(dcn.conv_offset.bias.grad.cpu().detach().numpy(),
                               dcn_offset_b_grad, 1e-2)

-    def _test_amp_mdconv(self, input_dtype=torch.float):
+    def _test_amp_mdconv(self, input_dtype=torch.float, device='cuda'):
         """The function to test amp released on pytorch 1.6.0.

         The type of input data might be torch.float or torch.half,
@@ -84,10 +84,15 @@ class TestMdconv:
         Args:
             input_dtype: torch.float or torch.half.
         """
-        if not torch.cuda.is_available():
+        if not torch.cuda.is_available() and device == 'cuda':
             return
-        from mmcv.ops import ModulatedDeformConv2dPack
-        input = torch.tensor(input_t).cuda().type(input_dtype)
+        if device == 'mlu':
+            from mmcv.ops import \
+                ModulatedDeformConv2dPack_MLU as ModulatedDeformConv2dPack
+        else:
+            from mmcv.ops import ModulatedDeformConv2dPack
+        input = torch.tensor(input_t).to(device).type(input_dtype)
         input.requires_grad = True

         dcn = ModulatedDeformConv2dPack(
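With the helper parameterized on device, the MLU path imports the ModulatedDeformConv2dPack_MLU wrapper under the familiar name and builds the input with .to(device) rather than a hard-coded .cuda(). A hedged sketch of how a caller might choose the device, assuming IS_MLU_AVAILABLE is importable from mmcv.utils as suggested by the first hunk's context:

# Hedged sketch: pick the device to hand to _test_amp_mdconv.
# The fallback mirrors the early return above; names here are illustrative.
import torch
from mmcv.utils import IS_MLU_AVAILABLE

if IS_MLU_AVAILABLE:
    device = 'mlu'
else:
    device = 'cuda'  # the helper returns early if CUDA is not actually available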
@@ -97,7 +102,7 @@ class TestMdconv:
             stride=1,
             padding=1,
             deform_groups=1,
-            bias=False).cuda()
+            bias=False).to(device)
         dcn.weight.data.fill_(1.)
         output = dcn(input)
         output.sum().backward()
@@ -126,5 +131,5 @@ class TestMdconv:
         if (TORCH_VERSION != 'parrots'
                 and digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
             with autocast(enabled=True):
-                self._test_amp_mdconv(torch.float)
-                self._test_amp_mdconv(torch.half)
+                self._test_amp_mdconv(torch.float, device=device)
+                self._test_amp_mdconv(torch.half, device=device)
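Put together, the call site simply forwards the chosen device into both dtype variants under autocast. A sketch of that flow as a standalone driver; the wrapping function is an assumption for illustration, not the verbatim mmcv test:

# Assumed sketch of driving the updated helper under AMP on either backend;
# only torch.float and torch.half are exercised, as in the hunk above.
import torch
from torch.cuda.amp import autocast

def run_amp_mdconv(test_case, device='cuda'):
    with autocast(enabled=True):
        test_case._test_amp_mdconv(torch.float, device=device)
        test_case._test_amp_mdconv(torch.half, device=device)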