From 71ee2a61f294c839688a907699759f59fea53de2 Mon Sep 17 00:00:00 2001
From: mengpenghui <116254103+mengpenghui@users.noreply.github.com>
Date: Fri, 13 Jan 2023 17:51:02 +0800
Subject: [PATCH] [Enhance] Add AMP support for MLU_DCNv2 (#2548)

---
 mmcv/ops/modulated_deform_conv.py            |  5 ++++-
 tests/test_ops/test_modulated_deform_conv.py | 19 ++++++++++++-------
 2 files changed, 16 insertions(+), 8 deletions(-)

diff --git a/mmcv/ops/modulated_deform_conv.py b/mmcv/ops/modulated_deform_conv.py
index 5f56a0da3..7b99e4db9 100644
--- a/mmcv/ops/modulated_deform_conv.py
+++ b/mmcv/ops/modulated_deform_conv.py
@@ -406,10 +406,13 @@ if IS_MLU_AVAILABLE:
             o1, o2, mask = torch.chunk(out, 3, dim=1)
             offset = torch.cat((o1, o2), dim=1)
             mask = torch.sigmoid(mask)
+            x = x.type_as(offset)
+            weight = self.weight.type_as(x)
+            mask = mask.type_as(x)
             return tv_deform_conv2d(
                 x,
                 offset,
-                self.weight,
+                weight,
                 bias=self.bias,
                 stride=self.stride,
                 padding=self.padding,
diff --git a/tests/test_ops/test_modulated_deform_conv.py b/tests/test_ops/test_modulated_deform_conv.py
index 927489df6..703f68a23 100644
--- a/tests/test_ops/test_modulated_deform_conv.py
+++ b/tests/test_ops/test_modulated_deform_conv.py
@@ -74,7 +74,7 @@ class TestMdconv:
         assert numpy.allclose(dcn.conv_offset.bias.grad.cpu().detach().numpy(),
                               dcn_offset_b_grad, 1e-2)
 
-    def _test_amp_mdconv(self, input_dtype=torch.float):
+    def _test_amp_mdconv(self, input_dtype=torch.float, device='cuda'):
         """The function to test amp released on pytorch 1.6.0.
 
         The type of input data might be torch.float or torch.half,
@@ -84,10 +84,15 @@ class TestMdconv:
         Args:
             input_dtype: torch.float or torch.half.
         """
-        if not torch.cuda.is_available():
+        if not torch.cuda.is_available() and device == 'cuda':
             return
-        from mmcv.ops import ModulatedDeformConv2dPack
-        input = torch.tensor(input_t).cuda().type(input_dtype)
+        if device == 'mlu':
+            from mmcv.ops import \
+                ModulatedDeformConv2dPack_MLU as ModulatedDeformConv2dPack
+        else:
+            from mmcv.ops import ModulatedDeformConv2dPack
+
+        input = torch.tensor(input_t).to(device).type(input_dtype)
         input.requires_grad = True
 
         dcn = ModulatedDeformConv2dPack(
@@ -97,7 +102,7 @@ class TestMdconv:
             stride=1,
             padding=1,
             deform_groups=1,
-            bias=False).cuda()
+            bias=False).to(device)
         dcn.weight.data.fill_(1.)
         output = dcn(input)
         output.sum().backward()
@@ -126,5 +131,5 @@ class TestMdconv:
         if (TORCH_VERSION != 'parrots'
                 and digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
             with autocast(enabled=True):
-                self._test_amp_mdconv(torch.float)
-                self._test_amp_mdconv(torch.half)
+                self._test_amp_mdconv(torch.float, device=device)
+                self._test_amp_mdconv(torch.half, device=device)