From 6254ebef8d1d106e9cca8d4097786c6172cf70ce Mon Sep 17 00:00:00 2001
From: "q.yao" <yaoqian@sensetime.com>
Date: Tue, 8 Nov 2022 15:54:34 +0800
Subject: [PATCH] [Fix] Create tensors with new_* methods to support AMP (#2389)

---
 mmcv/ops/diff_iou_rotated.py             |   4 +-
 mmcv/ops/group_points.py                 |   6 +-
 mmcv/ops/three_interpolate.py            |   4 +-
 tests/test_ops/test_group_points.py      | 251 ++++++++---------------
 tests/test_ops/test_three_interpolate.py |  35 ++--
 5 files changed, 116 insertions(+), 184 deletions(-)

diff --git a/mmcv/ops/diff_iou_rotated.py b/mmcv/ops/diff_iou_rotated.py
index cdc6c72f8..ddcf4b4fc 100644
--- a/mmcv/ops/diff_iou_rotated.py
+++ b/mmcv/ops/diff_iou_rotated.py
@@ -235,9 +235,9 @@ def box2corners(box: Tensor) -> Tensor:
     """
     B = box.size()[0]
     x, y, w, h, alpha = box.split([1, 1, 1, 1, 1], dim=-1)
-    x4 = torch.FloatTensor([0.5, -0.5, -0.5, 0.5]).to(box.device)
+    x4 = box.new_tensor([0.5, -0.5, -0.5, 0.5]).to(box.device)
     x4 = x4 * w  # (B, N, 4)
-    y4 = torch.FloatTensor([0.5, 0.5, -0.5, -0.5]).to(box.device)
+    y4 = box.new_tensor([0.5, 0.5, -0.5, -0.5]).to(box.device)
     y4 = y4 * h  # (B, N, 4)
     corners = torch.stack([x4, y4], dim=-1)  # (B, N, 4, 2)
     sin = torch.sin(alpha)
diff --git a/mmcv/ops/group_points.py b/mmcv/ops/group_points.py
index 95f359e86..999728c22 100644
--- a/mmcv/ops/group_points.py
+++ b/mmcv/ops/group_points.py
@@ -233,7 +233,7 @@ class GroupingOperation(Function):
         else:
             B, nfeatures, nsample = indices.size()
             _, C, N = features.size()
-            output = torch.cuda.FloatTensor(B, C, nfeatures, nsample)
+            output = features.new_zeros(B, C, nfeatures, nsample)
 
             ext_module.group_points_forward(
                 features,
@@ -262,7 +262,7 @@ class GroupingOperation(Function):
             idx, N = ctx.for_backwards
 
             B, C, npoint, nsample = grad_out.size()
-            grad_features = torch.cuda.FloatTensor(B, C, N).zero_()
+            grad_features = grad_out.new_zeros(B, C, N)
 
             grad_out_data = grad_out.data.contiguous()
             ext_module.group_points_backward(
@@ -279,7 +279,7 @@ class GroupingOperation(Function):
             B, N, idx, features_batch_cnt, idx_batch_cnt = ctx.for_backwards
 
             M, C, nsample = grad_out.size()
-            grad_features = torch.cuda.FloatTensor(N, C).zero_()
+            grad_features = grad_out.new_zeros(N, C)
 
             grad_out_data = grad_out.data.contiguous()
             ext_module.stack_group_points_backward(
diff --git a/mmcv/ops/three_interpolate.py b/mmcv/ops/three_interpolate.py
index 12b2f7611..286bd0472 100644
--- a/mmcv/ops/three_interpolate.py
+++ b/mmcv/ops/three_interpolate.py
@@ -38,7 +38,7 @@ class ThreeInterpolate(Function):
         B, c, m = features.size()
         n = indices.size(1)
         ctx.three_interpolate_for_backward = (indices, weight, m)
-        output = torch.cuda.FloatTensor(B, c, n)
+        output = features.new_empty(B, c, n)
 
         ext_module.three_interpolate_forward(
             features, indices, weight, output, b=B, c=c, m=m, n=n)
@@ -58,7 +58,7 @@ class ThreeInterpolate(Function):
         idx, weight, m = ctx.three_interpolate_for_backward
         B, c, n = grad_out.size()
 
-        grad_features = torch.cuda.FloatTensor(B, c, m).zero_()
+        grad_features = grad_out.new_zeros(B, c, m)
         grad_out_data = grad_out.data.contiguous()
 
         ext_module.three_interpolate_backward(
diff --git a/tests/test_ops/test_group_points.py b/tests/test_ops/test_group_points.py
index 48c0161ba..8109540ce 100644
--- a/tests/test_ops/test_group_points.py
+++ b/tests/test_ops/test_group_points.py
@@ -7,7 +7,8 @@ from mmcv.ops import grouping_operation
 
 @pytest.mark.skipif(
     not torch.cuda.is_available(), reason='requires CUDA support')
-def test_grouping_points():
+@pytest.mark.parametrize('dtype', [torch.half, torch.float, torch.double])
+def test_grouping_points(dtype):
     idx = torch.tensor([[[0, 0, 0], [3, 3, 3], [8, 8, 8], [0, 0, 0], [0, 0, 0],
                          [0, 0, 0]],
                         [[0, 0, 0], [6, 6, 6], [9, 9, 9], [0, 0, 0], [0, 0, 0],
@@ -35,51 +36,37 @@ def test_grouping_points():
                               [
                                   -0.6646, -0.6870, -0.1125, -0.2224, -0.3445,
                                   -1.4049, 0.4990, -0.7037, -0.9924, 0.0386
-                              ]]]).cuda()
+                              ]]],
+                            dtype=dtype).cuda()
 
     output = grouping_operation(features, idx)
-    expected_output = torch.tensor([[[[0.5798, 0.5798, 0.5798],
-                                      [-1.3311, -1.3311, -1.3311],
-                                      [0.9268, 0.9268, 0.9268],
-                                      [0.5798, 0.5798, 0.5798],
-                                      [0.5798, 0.5798, 0.5798],
-                                      [0.5798, 0.5798, 0.5798]],
-                                     [[5.4247, 5.4247, 5.4247],
-                                      [1.4740, 1.4740, 1.4740],
-                                      [2.1581, 2.1581, 2.1581],
-                                      [5.4247, 5.4247, 5.4247],
-                                      [5.4247, 5.4247, 5.4247],
-                                      [5.4247, 5.4247, 5.4247]],
-                                     [[-1.6266, -1.6266, -1.6266],
-                                      [-1.6931, -1.6931, -1.6931],
-                                      [-1.6786, -1.6786, -1.6786],
-                                      [-1.6266, -1.6266, -1.6266],
-                                      [-1.6266, -1.6266, -1.6266],
-                                      [-1.6266, -1.6266, -1.6266]]],
-                                    [[[-0.0380, -0.0380, -0.0380],
-                                      [-0.3693, -0.3693, -0.3693],
-                                      [-1.8527, -1.8527, -1.8527],
-                                      [-0.0380, -0.0380, -0.0380],
-                                      [-0.0380, -0.0380, -0.0380],
-                                      [-0.0380, -0.0380, -0.0380]],
-                                     [[1.1773, 1.1773, 1.1773],
-                                      [6.0865, 6.0865, 6.0865],
-                                      [2.8229, 2.8229, 2.8229],
-                                      [1.1773, 1.1773, 1.1773],
-                                      [1.1773, 1.1773, 1.1773],
-                                      [1.1773, 1.1773, 1.1773]],
-                                     [[-0.6646, -0.6646, -0.6646],
-                                      [0.4990, 0.4990, 0.4990],
-                                      [0.0386, 0.0386, 0.0386],
-                                      [-0.6646, -0.6646, -0.6646],
-                                      [-0.6646, -0.6646, -0.6646],
-                                      [-0.6646, -0.6646, -0.6646]]]]).cuda()
+    expected_output = torch.tensor(
+        [[[[0.5798, 0.5798, 0.5798], [-1.3311, -1.3311, -1.3311],
+           [0.9268, 0.9268, 0.9268], [0.5798, 0.5798, 0.5798],
+           [0.5798, 0.5798, 0.5798], [0.5798, 0.5798, 0.5798]],
+          [[5.4247, 5.4247, 5.4247], [1.4740, 1.4740, 1.4740],
+           [2.1581, 2.1581, 2.1581], [5.4247, 5.4247, 5.4247],
+           [5.4247, 5.4247, 5.4247], [5.4247, 5.4247, 5.4247]],
+          [[-1.6266, -1.6266, -1.6266], [-1.6931, -1.6931, -1.6931],
+           [-1.6786, -1.6786, -1.6786], [-1.6266, -1.6266, -1.6266],
+           [-1.6266, -1.6266, -1.6266], [-1.6266, -1.6266, -1.6266]]],
+         [[[-0.0380, -0.0380, -0.0380], [-0.3693, -0.3693, -0.3693],
+           [-1.8527, -1.8527, -1.8527], [-0.0380, -0.0380, -0.0380],
+           [-0.0380, -0.0380, -0.0380], [-0.0380, -0.0380, -0.0380]],
+          [[1.1773, 1.1773, 1.1773], [6.0865, 6.0865, 6.0865],
+           [2.8229, 2.8229, 2.8229], [1.1773, 1.1773, 1.1773],
+           [1.1773, 1.1773, 1.1773], [1.1773, 1.1773, 1.1773]],
+          [[-0.6646, -0.6646, -0.6646], [0.4990, 0.4990, 0.4990],
+           [0.0386, 0.0386, 0.0386], [-0.6646, -0.6646, -0.6646],
+           [-0.6646, -0.6646, -0.6646], [-0.6646, -0.6646, -0.6646]]]],
+        dtype=dtype).cuda()
     assert torch.allclose(output, expected_output)
 
 
 @pytest.mark.skipif(
     not torch.cuda.is_available(), reason='requires CUDA support')
-def test_stack_grouping_points():
+@pytest.mark.parametrize('dtype', [torch.half, torch.float, torch.double])
+def test_stack_grouping_points(dtype):
     idx = torch.tensor([[0, 0, 0], [3, 3, 3], [8, 8, 8], [1, 1, 1], [0, 0, 0],
                         [2, 2, 2], [0, 0, 0], [6, 6, 6], [9, 9, 9], [0, 0, 0],
                         [1, 1, 1], [0, 0, 0]]).int().cuda()
@@ -106,130 +93,72 @@ def test_stack_grouping_points():
                              [
                                  -0.6646, -0.6870, -0.1125, -0.2224, -0.3445,
                                  -1.4049, 0.4990, -0.7037, -0.9924, 0.0386
-                             ]]).float().cuda()
+                             ]],
+                            dtype=dtype).cuda()
     features_batch_cnt = torch.tensor([3, 3]).int().cuda()
     indices_batch_cnt = torch.tensor([6, 6]).int().cuda()
     output = grouping_operation(features, idx, features_batch_cnt,
                                 indices_batch_cnt)
-    expected_output = torch.Tensor([[[0.5798, 0.5798, 0.5798],
-                                     [-0.7981, -0.7981, -0.7981],
-                                     [-0.9280, -0.9280, -0.9280],
-                                     [-1.3311, -1.3311, -1.3311],
-                                     [1.3687, 1.3687, 1.3687],
-                                     [0.9277, 0.9277, 0.9277],
-                                     [-0.4164, -0.4164, -0.4164],
-                                     [-1.8274, -1.8274, -1.8274],
-                                     [0.9268, 0.9268, 0.9268],
-                                     [0.8414, 0.8414, 0.8414]],
-                                    [[0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000]],
-                                    [[0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000]],
-                                    [[5.4247, 5.4247, 5.4247],
-                                     [1.5113, 1.5113, 1.5113],
-                                     [2.3944, 2.3944, 2.3944],
-                                     [1.4740, 1.4740, 1.4740],
-                                     [5.0300, 5.0300, 5.0300],
-                                     [5.1030, 5.1030, 5.1030],
-                                     [1.9360, 1.9360, 1.9360],
-                                     [2.1939, 2.1939, 2.1939],
-                                     [2.1581, 2.1581, 2.1581],
-                                     [3.4666, 3.4666, 3.4666]],
-                                    [[0.5798, 0.5798, 0.5798],
-                                     [-0.7981, -0.7981, -0.7981],
-                                     [-0.9280, -0.9280, -0.9280],
-                                     [-1.3311, -1.3311, -1.3311],
-                                     [1.3687, 1.3687, 1.3687],
-                                     [0.9277, 0.9277, 0.9277],
-                                     [-0.4164, -0.4164, -0.4164],
-                                     [-1.8274, -1.8274, -1.8274],
-                                     [0.9268, 0.9268, 0.9268],
-                                     [0.8414, 0.8414, 0.8414]],
-                                    [[-1.6266, -1.6266, -1.6266],
-                                     [-1.0281, -1.0281, -1.0281],
-                                     [-1.0393, -1.0393, -1.0393],
-                                     [-1.6931, -1.6931, -1.6931],
-                                     [-1.3982, -1.3982, -1.3982],
-                                     [-0.5732, -0.5732, -0.5732],
-                                     [-1.0830, -1.0830, -1.0830],
-                                     [-1.7561, -1.7561, -1.7561],
-                                     [-1.6786, -1.6786, -1.6786],
-                                     [-1.6967, -1.6967, -1.6967]],
-                                    [[-0.0380, -0.0380, -0.0380],
-                                     [-0.1880, -0.1880, -0.1880],
-                                     [-1.5724, -1.5724, -1.5724],
-                                     [0.6905, 0.6905, 0.6905],
-                                     [-0.3190, -0.3190, -0.3190],
-                                     [0.7798, 0.7798, 0.7798],
-                                     [-0.3693, -0.3693, -0.3693],
-                                     [-0.9457, -0.9457, -0.9457],
-                                     [-0.2942, -0.2942, -0.2942],
-                                     [-1.8527, -1.8527, -1.8527]],
-                                    [[0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000]],
-                                    [[0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000]],
-                                    [[-0.0380, -0.0380, -0.0380],
-                                     [-0.1880, -0.1880, -0.1880],
-                                     [-1.5724, -1.5724, -1.5724],
-                                     [0.6905, 0.6905, 0.6905],
-                                     [-0.3190, -0.3190, -0.3190],
-                                     [0.7798, 0.7798, 0.7798],
-                                     [-0.3693, -0.3693, -0.3693],
-                                     [-0.9457, -0.9457, -0.9457],
-                                     [-0.2942, -0.2942, -0.2942],
-                                     [-1.8527, -1.8527, -1.8527]],
-                                    [[1.1773, 1.1773, 1.1773],
-                                     [1.5009, 1.5009, 1.5009],
-                                     [2.6399, 2.6399, 2.6399],
-                                     [5.9242, 5.9242, 5.9242],
-                                     [1.0962, 1.0962, 1.0962],
-                                     [2.7346, 2.7346, 2.7346],
-                                     [6.0865, 6.0865, 6.0865],
-                                     [1.5555, 1.5555, 1.5555],
-                                     [4.3303, 4.3303, 4.3303],
-                                     [2.8229, 2.8229, 2.8229]],
-                                    [[-0.0380, -0.0380, -0.0380],
-                                     [-0.1880, -0.1880, -0.1880],
-                                     [-1.5724, -1.5724, -1.5724],
-                                     [0.6905, 0.6905, 0.6905],
-                                     [-0.3190, -0.3190, -0.3190],
-                                     [0.7798, 0.7798, 0.7798],
-                                     [-0.3693, -0.3693, -0.3693],
-                                     [-0.9457, -0.9457, -0.9457],
-                                     [-0.2942, -0.2942, -0.2942],
-                                     [-1.8527, -1.8527,
-                                      -1.8527]]]).cuda().float()
+    expected_output = torch.tensor(
+        [[[0.5798, 0.5798, 0.5798], [-0.7981, -0.7981, -0.7981],
+          [-0.9280, -0.9280, -0.9280], [-1.3311, -1.3311, -1.3311],
+          [1.3687, 1.3687, 1.3687], [0.9277, 0.9277, 0.9277],
+          [-0.4164, -0.4164, -0.4164], [-1.8274, -1.8274, -1.8274],
+          [0.9268, 0.9268, 0.9268], [0.8414, 0.8414, 0.8414]],
+         [[0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
+          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
+          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
+          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
+          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000]],
+         [[0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
+          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
+          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
+          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
+          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000]],
+         [[5.4247, 5.4247, 5.4247], [1.5113, 1.5113, 1.5113],
+          [2.3944, 2.3944, 2.3944], [1.4740, 1.4740, 1.4740],
+          [5.0300, 5.0300, 5.0300], [5.1030, 5.1030, 5.1030],
+          [1.9360, 1.9360, 1.9360], [2.1939, 2.1939, 2.1939],
+          [2.1581, 2.1581, 2.1581], [3.4666, 3.4666, 3.4666]],
+         [[0.5798, 0.5798, 0.5798], [-0.7981, -0.7981, -0.7981],
+          [-0.9280, -0.9280, -0.9280], [-1.3311, -1.3311, -1.3311],
+          [1.3687, 1.3687, 1.3687], [0.9277, 0.9277, 0.9277],
+          [-0.4164, -0.4164, -0.4164], [-1.8274, -1.8274, -1.8274],
+          [0.9268, 0.9268, 0.9268], [0.8414, 0.8414, 0.8414]],
+         [[-1.6266, -1.6266, -1.6266], [-1.0281, -1.0281, -1.0281],
+          [-1.0393, -1.0393, -1.0393], [-1.6931, -1.6931, -1.6931],
+          [-1.3982, -1.3982, -1.3982], [-0.5732, -0.5732, -0.5732],
+          [-1.0830, -1.0830, -1.0830], [-1.7561, -1.7561, -1.7561],
+          [-1.6786, -1.6786, -1.6786], [-1.6967, -1.6967, -1.6967]],
+         [[-0.0380, -0.0380, -0.0380], [-0.1880, -0.1880, -0.1880],
+          [-1.5724, -1.5724, -1.5724], [0.6905, 0.6905, 0.6905],
+          [-0.3190, -0.3190, -0.3190], [0.7798, 0.7798, 0.7798],
+          [-0.3693, -0.3693, -0.3693], [-0.9457, -0.9457, -0.9457],
+          [-0.2942, -0.2942, -0.2942], [-1.8527, -1.8527, -1.8527]],
+         [[0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
+          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
+          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
+          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
+          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000]],
+         [[0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
+          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
+          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
+          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
+          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000]],
+         [[-0.0380, -0.0380, -0.0380], [-0.1880, -0.1880, -0.1880],
+          [-1.5724, -1.5724, -1.5724], [0.6905, 0.6905, 0.6905],
+          [-0.3190, -0.3190, -0.3190], [0.7798, 0.7798, 0.7798],
+          [-0.3693, -0.3693, -0.3693], [-0.9457, -0.9457, -0.9457],
+          [-0.2942, -0.2942, -0.2942], [-1.8527, -1.8527, -1.8527]],
+         [[1.1773, 1.1773, 1.1773], [1.5009, 1.5009, 1.5009],
+          [2.6399, 2.6399, 2.6399], [5.9242, 5.9242, 5.9242],
+          [1.0962, 1.0962, 1.0962], [2.7346, 2.7346, 2.7346],
+          [6.0865, 6.0865, 6.0865], [1.5555, 1.5555, 1.5555],
+          [4.3303, 4.3303, 4.3303], [2.8229, 2.8229, 2.8229]],
+         [[-0.0380, -0.0380, -0.0380], [-0.1880, -0.1880, -0.1880],
+          [-1.5724, -1.5724, -1.5724], [0.6905, 0.6905, 0.6905],
+          [-0.3190, -0.3190, -0.3190], [0.7798, 0.7798, 0.7798],
+          [-0.3693, -0.3693, -0.3693], [-0.9457, -0.9457, -0.9457],
+          [-0.2942, -0.2942, -0.2942], [-1.8527, -1.8527, -1.8527]]],
+        dtype=dtype).cuda()
     assert torch.allclose(output, expected_output)
diff --git a/tests/test_ops/test_three_interpolate.py b/tests/test_ops/test_three_interpolate.py
index 900f451ff..51a6b8732 100644
--- a/tests/test_ops/test_three_interpolate.py
+++ b/tests/test_ops/test_three_interpolate.py
@@ -7,19 +7,20 @@ from mmcv.ops import three_interpolate
 
 @pytest.mark.skipif(
     not torch.cuda.is_available(), reason='requires CUDA support')
-def test_three_interpolate():
-    features = torch.tensor([[[2.4350, 4.7516, 4.4995, 2.4350, 2.4350, 2.4350],
-                              [3.1236, 2.6278, 3.0447, 3.1236, 3.1236, 3.1236],
-                              [2.6732, 2.8677, 2.6436, 2.6732, 2.6732, 2.6732],
-                              [0.0124, 7.0150, 7.0199, 0.0124, 0.0124, 0.0124],
-                              [0.3207, 0.0000, 0.3411, 0.3207, 0.3207,
-                               0.3207]],
-                             [[0.0000, 0.9544, 2.4532, 0.0000, 0.0000, 0.0000],
-                              [0.5346, 1.9176, 1.4715, 0.5346, 0.5346, 0.5346],
-                              [0.0000, 0.2744, 2.0842, 0.0000, 0.0000, 0.0000],
-                              [0.3414, 1.5063, 1.6209, 0.3414, 0.3414, 0.3414],
-                              [0.5814, 0.0103, 0.0000, 0.5814, 0.5814,
-                               0.5814]]]).cuda()
+@pytest.mark.parametrize('dtype', [torch.half, torch.float, torch.double])
+def test_three_interpolate(dtype):
+    features = torch.tensor(
+        [[[2.4350, 4.7516, 4.4995, 2.4350, 2.4350, 2.4350],
+          [3.1236, 2.6278, 3.0447, 3.1236, 3.1236, 3.1236],
+          [2.6732, 2.8677, 2.6436, 2.6732, 2.6732, 2.6732],
+          [0.0124, 7.0150, 7.0199, 0.0124, 0.0124, 0.0124],
+          [0.3207, 0.0000, 0.3411, 0.3207, 0.3207, 0.3207]],
+         [[0.0000, 0.9544, 2.4532, 0.0000, 0.0000, 0.0000],
+          [0.5346, 1.9176, 1.4715, 0.5346, 0.5346, 0.5346],
+          [0.0000, 0.2744, 2.0842, 0.0000, 0.0000, 0.0000],
+          [0.3414, 1.5063, 1.6209, 0.3414, 0.3414, 0.3414],
+          [0.5814, 0.0103, 0.0000, 0.5814, 0.5814, 0.5814]]],
+        dtype=dtype).cuda()
 
     idx = torch.tensor([[[0, 1, 2], [2, 3, 4], [2, 3, 4], [0, 1, 2], [0, 1, 2],
                          [0, 1, 3]],
@@ -37,7 +38,8 @@ def test_three_interpolate():
                             [1.0000e+00, 1.7148e-08, 1.4070e-08],
                             [3.3333e-01, 3.3333e-01, 3.3333e-01],
                             [3.3333e-01, 3.3333e-01, 3.3333e-01],
-                            [3.3333e-01, 3.3333e-01, 3.3333e-01]]]).cuda()
+                            [3.3333e-01, 3.3333e-01, 3.3333e-01]]],
+                          dtype=dtype).cuda()
 
     output = three_interpolate(features, idx, weight)
     expected_output = torch.tensor([[[
@@ -70,6 +72,7 @@ def test_three_interpolate():
                                      [
                                          3.8760e-01, 1.0300e-02, 8.3569e-09,
                                          3.8760e-01, 3.8760e-01, 1.9723e-01
-                                     ]]]).cuda()
+                                     ]]],
+                                   dtype=dtype).cuda()
 
-    assert torch.allclose(output, expected_output, 1e-4)
+    assert torch.allclose(output, expected_output, 1e-3, 1e-4)