diff --git a/mmcv/ops/modulated_deform_conv.py b/mmcv/ops/modulated_deform_conv.py index 5f56a0da3..7b99e4db9 100644 --- a/mmcv/ops/modulated_deform_conv.py +++ b/mmcv/ops/modulated_deform_conv.py @@ -406,10 +406,13 @@ if IS_MLU_AVAILABLE: o1, o2, mask = torch.chunk(out, 3, dim=1) offset = torch.cat((o1, o2), dim=1) mask = torch.sigmoid(mask) + x = x.type_as(offset) + weight = self.weight.type_as(x) + mask = mask.type_as(x) return tv_deform_conv2d( x, offset, - self.weight, + weight, bias=self.bias, stride=self.stride, padding=self.padding, diff --git a/tests/test_ops/test_modulated_deform_conv.py b/tests/test_ops/test_modulated_deform_conv.py index 927489df6..703f68a23 100644 --- a/tests/test_ops/test_modulated_deform_conv.py +++ b/tests/test_ops/test_modulated_deform_conv.py @@ -74,7 +74,7 @@ class TestMdconv: assert numpy.allclose(dcn.conv_offset.bias.grad.cpu().detach().numpy(), dcn_offset_b_grad, 1e-2) - def _test_amp_mdconv(self, input_dtype=torch.float): + def _test_amp_mdconv(self, input_dtype=torch.float, device='cuda'): """The function to test amp released on pytorch 1.6.0. The type of input data might be torch.float or torch.half, @@ -84,10 +84,15 @@ class TestMdconv: Args: input_dtype: torch.float or torch.half. """ - if not torch.cuda.is_available(): + if not torch.cuda.is_available() and device == 'cuda': return - from mmcv.ops import ModulatedDeformConv2dPack - input = torch.tensor(input_t).cuda().type(input_dtype) + if device == 'mlu': + from mmcv.ops import \ + ModulatedDeformConv2dPack_MLU as ModulatedDeformConv2dPack + else: + from mmcv.ops import ModulatedDeformConv2dPack + + input = torch.tensor(input_t).to(device).type(input_dtype) input.requires_grad = True dcn = ModulatedDeformConv2dPack( @@ -97,7 +102,7 @@ class TestMdconv: stride=1, padding=1, deform_groups=1, - bias=False).cuda() + bias=False).to(device) dcn.weight.data.fill_(1.) output = dcn(input) output.sum().backward() @@ -126,5 +131,5 @@ class TestMdconv: if (TORCH_VERSION != 'parrots' and digit_version(TORCH_VERSION) >= digit_version('1.6.0')): with autocast(enabled=True): - self._test_amp_mdconv(torch.float) - self._test_amp_mdconv(torch.half) + self._test_amp_mdconv(torch.float, device=device) + self._test_amp_mdconv(torch.half, device=device)