From 6254ebef8d1d106e9cca8d4097786c6172cf70ce Mon Sep 17 00:00:00 2001 From: "q.yao" <yaoqian@sensetime.com> Date: Tue, 8 Nov 2022 15:54:34 +0800 Subject: [PATCH] [Fix] Create Tensor with new_* method to support amp (#2389) --- mmcv/ops/diff_iou_rotated.py | 4 +- mmcv/ops/group_points.py | 6 +- mmcv/ops/three_interpolate.py | 4 +- tests/test_ops/test_group_points.py | 251 ++++++++--------------- tests/test_ops/test_three_interpolate.py | 35 ++-- 5 files changed, 116 insertions(+), 184 deletions(-) diff --git a/mmcv/ops/diff_iou_rotated.py b/mmcv/ops/diff_iou_rotated.py index cdc6c72f8..ddcf4b4fc 100644 --- a/mmcv/ops/diff_iou_rotated.py +++ b/mmcv/ops/diff_iou_rotated.py @@ -235,9 +235,9 @@ def box2corners(box: Tensor) -> Tensor: """ B = box.size()[0] x, y, w, h, alpha = box.split([1, 1, 1, 1, 1], dim=-1) - x4 = torch.FloatTensor([0.5, -0.5, -0.5, 0.5]).to(box.device) + x4 = box.new_tensor([0.5, -0.5, -0.5, 0.5]).to(box.device) x4 = x4 * w # (B, N, 4) - y4 = torch.FloatTensor([0.5, 0.5, -0.5, -0.5]).to(box.device) + y4 = box.new_tensor([0.5, 0.5, -0.5, -0.5]).to(box.device) y4 = y4 * h # (B, N, 4) corners = torch.stack([x4, y4], dim=-1) # (B, N, 4, 2) sin = torch.sin(alpha) diff --git a/mmcv/ops/group_points.py b/mmcv/ops/group_points.py index 95f359e86..999728c22 100644 --- a/mmcv/ops/group_points.py +++ b/mmcv/ops/group_points.py @@ -233,7 +233,7 @@ class GroupingOperation(Function): else: B, nfeatures, nsample = indices.size() _, C, N = features.size() - output = torch.cuda.FloatTensor(B, C, nfeatures, nsample) + output = features.new_zeros(B, C, nfeatures, nsample) ext_module.group_points_forward( features, @@ -262,7 +262,7 @@ class GroupingOperation(Function): idx, N = ctx.for_backwards B, C, npoint, nsample = grad_out.size() - grad_features = torch.cuda.FloatTensor(B, C, N).zero_() + grad_features = grad_out.new_zeros(B, C, N) grad_out_data = grad_out.data.contiguous() ext_module.group_points_backward( @@ -279,7 +279,7 @@ class GroupingOperation(Function): B, N, idx, features_batch_cnt, idx_batch_cnt = ctx.for_backwards M, C, nsample = grad_out.size() - grad_features = torch.cuda.FloatTensor(N, C).zero_() + grad_features = grad_out.new_zeros(N, C) grad_out_data = grad_out.data.contiguous() ext_module.stack_group_points_backward( diff --git a/mmcv/ops/three_interpolate.py b/mmcv/ops/three_interpolate.py index 12b2f7611..286bd0472 100644 --- a/mmcv/ops/three_interpolate.py +++ b/mmcv/ops/three_interpolate.py @@ -38,7 +38,7 @@ class ThreeInterpolate(Function): B, c, m = features.size() n = indices.size(1) ctx.three_interpolate_for_backward = (indices, weight, m) - output = torch.cuda.FloatTensor(B, c, n) + output = features.new_empty(B, c, n) ext_module.three_interpolate_forward( features, indices, weight, output, b=B, c=c, m=m, n=n) @@ -58,7 +58,7 @@ class ThreeInterpolate(Function): idx, weight, m = ctx.three_interpolate_for_backward B, c, n = grad_out.size() - grad_features = torch.cuda.FloatTensor(B, c, m).zero_() + grad_features = grad_out.new_zeros(B, c, m) grad_out_data = grad_out.data.contiguous() ext_module.three_interpolate_backward( diff --git a/tests/test_ops/test_group_points.py b/tests/test_ops/test_group_points.py index 48c0161ba..8109540ce 100644 --- a/tests/test_ops/test_group_points.py +++ b/tests/test_ops/test_group_points.py @@ -7,7 +7,8 @@ from mmcv.ops import grouping_operation @pytest.mark.skipif( not torch.cuda.is_available(), reason='requires CUDA support') -def test_grouping_points(): +@pytest.mark.parametrize('dtype', [torch.half, torch.float, torch.double]) +def test_grouping_points(dtype): idx = torch.tensor([[[0, 0, 0], [3, 3, 3], [8, 8, 8], [0, 0, 0], [0, 0, 0], [0, 0, 0]], [[0, 0, 0], [6, 6, 6], [9, 9, 9], [0, 0, 0], [0, 0, 0], @@ -35,51 +36,37 @@ def test_grouping_points(): [ -0.6646, -0.6870, -0.1125, -0.2224, -0.3445, -1.4049, 0.4990, -0.7037, -0.9924, 0.0386 - ]]]).cuda() + ]]], + dtype=dtype).cuda() output = grouping_operation(features, idx) - expected_output = torch.tensor([[[[0.5798, 0.5798, 0.5798], - [-1.3311, -1.3311, -1.3311], - [0.9268, 0.9268, 0.9268], - [0.5798, 0.5798, 0.5798], - [0.5798, 0.5798, 0.5798], - [0.5798, 0.5798, 0.5798]], - [[5.4247, 5.4247, 5.4247], - [1.4740, 1.4740, 1.4740], - [2.1581, 2.1581, 2.1581], - [5.4247, 5.4247, 5.4247], - [5.4247, 5.4247, 5.4247], - [5.4247, 5.4247, 5.4247]], - [[-1.6266, -1.6266, -1.6266], - [-1.6931, -1.6931, -1.6931], - [-1.6786, -1.6786, -1.6786], - [-1.6266, -1.6266, -1.6266], - [-1.6266, -1.6266, -1.6266], - [-1.6266, -1.6266, -1.6266]]], - [[[-0.0380, -0.0380, -0.0380], - [-0.3693, -0.3693, -0.3693], - [-1.8527, -1.8527, -1.8527], - [-0.0380, -0.0380, -0.0380], - [-0.0380, -0.0380, -0.0380], - [-0.0380, -0.0380, -0.0380]], - [[1.1773, 1.1773, 1.1773], - [6.0865, 6.0865, 6.0865], - [2.8229, 2.8229, 2.8229], - [1.1773, 1.1773, 1.1773], - [1.1773, 1.1773, 1.1773], - [1.1773, 1.1773, 1.1773]], - [[-0.6646, -0.6646, -0.6646], - [0.4990, 0.4990, 0.4990], - [0.0386, 0.0386, 0.0386], - [-0.6646, -0.6646, -0.6646], - [-0.6646, -0.6646, -0.6646], - [-0.6646, -0.6646, -0.6646]]]]).cuda() + expected_output = torch.tensor( + [[[[0.5798, 0.5798, 0.5798], [-1.3311, -1.3311, -1.3311], + [0.9268, 0.9268, 0.9268], [0.5798, 0.5798, 0.5798], + [0.5798, 0.5798, 0.5798], [0.5798, 0.5798, 0.5798]], + [[5.4247, 5.4247, 5.4247], [1.4740, 1.4740, 1.4740], + [2.1581, 2.1581, 2.1581], [5.4247, 5.4247, 5.4247], + [5.4247, 5.4247, 5.4247], [5.4247, 5.4247, 5.4247]], + [[-1.6266, -1.6266, -1.6266], [-1.6931, -1.6931, -1.6931], + [-1.6786, -1.6786, -1.6786], [-1.6266, -1.6266, -1.6266], + [-1.6266, -1.6266, -1.6266], [-1.6266, -1.6266, -1.6266]]], + [[[-0.0380, -0.0380, -0.0380], [-0.3693, -0.3693, -0.3693], + [-1.8527, -1.8527, -1.8527], [-0.0380, -0.0380, -0.0380], + [-0.0380, -0.0380, -0.0380], [-0.0380, -0.0380, -0.0380]], + [[1.1773, 1.1773, 1.1773], [6.0865, 6.0865, 6.0865], + [2.8229, 2.8229, 2.8229], [1.1773, 1.1773, 1.1773], + [1.1773, 1.1773, 1.1773], [1.1773, 1.1773, 1.1773]], + [[-0.6646, -0.6646, -0.6646], [0.4990, 0.4990, 0.4990], + [0.0386, 0.0386, 0.0386], [-0.6646, -0.6646, -0.6646], + [-0.6646, -0.6646, -0.6646], [-0.6646, -0.6646, -0.6646]]]], + dtype=dtype).cuda() assert torch.allclose(output, expected_output) @pytest.mark.skipif( not torch.cuda.is_available(), reason='requires CUDA support') -def test_stack_grouping_points(): +@pytest.mark.parametrize('dtype', [torch.half, torch.float, torch.double]) +def test_stack_grouping_points(dtype): idx = torch.tensor([[0, 0, 0], [3, 3, 3], [8, 8, 8], [1, 1, 1], [0, 0, 0], [2, 2, 2], [0, 0, 0], [6, 6, 6], [9, 9, 9], [0, 0, 0], [1, 1, 1], [0, 0, 0]]).int().cuda() @@ -106,130 +93,72 @@ def test_stack_grouping_points(): [ -0.6646, -0.6870, -0.1125, -0.2224, -0.3445, -1.4049, 0.4990, -0.7037, -0.9924, 0.0386 - ]]).float().cuda() + ]], + dtype=dtype).cuda() features_batch_cnt = torch.tensor([3, 3]).int().cuda() indices_batch_cnt = torch.tensor([6, 6]).int().cuda() output = grouping_operation(features, idx, features_batch_cnt, indices_batch_cnt) - expected_output = torch.Tensor([[[0.5798, 0.5798, 0.5798], - [-0.7981, -0.7981, -0.7981], - [-0.9280, -0.9280, -0.9280], - [-1.3311, -1.3311, -1.3311], - [1.3687, 1.3687, 1.3687], - [0.9277, 0.9277, 0.9277], - [-0.4164, -0.4164, -0.4164], - [-1.8274, -1.8274, -1.8274], - [0.9268, 0.9268, 0.9268], - [0.8414, 0.8414, 0.8414]], - [[0.0000, 0.0000, 0.0000], - [0.0000, 0.0000, 0.0000], - [0.0000, 0.0000, 0.0000], - [0.0000, 0.0000, 0.0000], - [0.0000, 0.0000, 0.0000], - [0.0000, 0.0000, 0.0000], - [0.0000, 0.0000, 0.0000], - [0.0000, 0.0000, 0.0000], - [0.0000, 0.0000, 0.0000], - [0.0000, 0.0000, 0.0000]], - [[0.0000, 0.0000, 0.0000], - [0.0000, 0.0000, 0.0000], - [0.0000, 0.0000, 0.0000], - [0.0000, 0.0000, 0.0000], - [0.0000, 0.0000, 0.0000], - [0.0000, 0.0000, 0.0000], - [0.0000, 0.0000, 0.0000], - [0.0000, 0.0000, 0.0000], - [0.0000, 0.0000, 0.0000], - [0.0000, 0.0000, 0.0000]], - [[5.4247, 5.4247, 5.4247], - [1.5113, 1.5113, 1.5113], - [2.3944, 2.3944, 2.3944], - [1.4740, 1.4740, 1.4740], - [5.0300, 5.0300, 5.0300], - [5.1030, 5.1030, 5.1030], - [1.9360, 1.9360, 1.9360], - [2.1939, 2.1939, 2.1939], - [2.1581, 2.1581, 2.1581], - [3.4666, 3.4666, 3.4666]], - [[0.5798, 0.5798, 0.5798], - [-0.7981, -0.7981, -0.7981], - [-0.9280, -0.9280, -0.9280], - [-1.3311, -1.3311, -1.3311], - [1.3687, 1.3687, 1.3687], - [0.9277, 0.9277, 0.9277], - [-0.4164, -0.4164, -0.4164], - [-1.8274, -1.8274, -1.8274], - [0.9268, 0.9268, 0.9268], - [0.8414, 0.8414, 0.8414]], - [[-1.6266, -1.6266, -1.6266], - [-1.0281, -1.0281, -1.0281], - [-1.0393, -1.0393, -1.0393], - [-1.6931, -1.6931, -1.6931], - [-1.3982, -1.3982, -1.3982], - [-0.5732, -0.5732, -0.5732], - [-1.0830, -1.0830, -1.0830], - [-1.7561, -1.7561, -1.7561], - [-1.6786, -1.6786, -1.6786], - [-1.6967, -1.6967, -1.6967]], - [[-0.0380, -0.0380, -0.0380], - [-0.1880, -0.1880, -0.1880], - [-1.5724, -1.5724, -1.5724], - [0.6905, 0.6905, 0.6905], - [-0.3190, -0.3190, -0.3190], - [0.7798, 0.7798, 0.7798], - [-0.3693, -0.3693, -0.3693], - [-0.9457, -0.9457, -0.9457], - [-0.2942, -0.2942, -0.2942], - [-1.8527, -1.8527, -1.8527]], - [[0.0000, 0.0000, 0.0000], - [0.0000, 0.0000, 0.0000], - [0.0000, 0.0000, 0.0000], - [0.0000, 0.0000, 0.0000], - [0.0000, 0.0000, 0.0000], - [0.0000, 0.0000, 0.0000], - [0.0000, 0.0000, 0.0000], - [0.0000, 0.0000, 0.0000], - [0.0000, 0.0000, 0.0000], - [0.0000, 0.0000, 0.0000]], - [[0.0000, 0.0000, 0.0000], - [0.0000, 0.0000, 0.0000], - [0.0000, 0.0000, 0.0000], - [0.0000, 0.0000, 0.0000], - [0.0000, 0.0000, 0.0000], - [0.0000, 0.0000, 0.0000], - [0.0000, 0.0000, 0.0000], - [0.0000, 0.0000, 0.0000], - [0.0000, 0.0000, 0.0000], - [0.0000, 0.0000, 0.0000]], - [[-0.0380, -0.0380, -0.0380], - [-0.1880, -0.1880, -0.1880], - [-1.5724, -1.5724, -1.5724], - [0.6905, 0.6905, 0.6905], - [-0.3190, -0.3190, -0.3190], - [0.7798, 0.7798, 0.7798], - [-0.3693, -0.3693, -0.3693], - [-0.9457, -0.9457, -0.9457], - [-0.2942, -0.2942, -0.2942], - [-1.8527, -1.8527, -1.8527]], - [[1.1773, 1.1773, 1.1773], - [1.5009, 1.5009, 1.5009], - [2.6399, 2.6399, 2.6399], - [5.9242, 5.9242, 5.9242], - [1.0962, 1.0962, 1.0962], - [2.7346, 2.7346, 2.7346], - [6.0865, 6.0865, 6.0865], - [1.5555, 1.5555, 1.5555], - [4.3303, 4.3303, 4.3303], - [2.8229, 2.8229, 2.8229]], - [[-0.0380, -0.0380, -0.0380], - [-0.1880, -0.1880, -0.1880], - [-1.5724, -1.5724, -1.5724], - [0.6905, 0.6905, 0.6905], - [-0.3190, -0.3190, -0.3190], - [0.7798, 0.7798, 0.7798], - [-0.3693, -0.3693, -0.3693], - [-0.9457, -0.9457, -0.9457], - [-0.2942, -0.2942, -0.2942], - [-1.8527, -1.8527, - -1.8527]]]).cuda().float() + expected_output = torch.tensor( + [[[0.5798, 0.5798, 0.5798], [-0.7981, -0.7981, -0.7981], + [-0.9280, -0.9280, -0.9280], [-1.3311, -1.3311, -1.3311], + [1.3687, 1.3687, 1.3687], [0.9277, 0.9277, 0.9277], + [-0.4164, -0.4164, -0.4164], [-1.8274, -1.8274, -1.8274], + [0.9268, 0.9268, 0.9268], [0.8414, 0.8414, 0.8414]], + [[0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000]], + [[0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000]], + [[5.4247, 5.4247, 5.4247], [1.5113, 1.5113, 1.5113], + [2.3944, 2.3944, 2.3944], [1.4740, 1.4740, 1.4740], + [5.0300, 5.0300, 5.0300], [5.1030, 5.1030, 5.1030], + [1.9360, 1.9360, 1.9360], [2.1939, 2.1939, 2.1939], + [2.1581, 2.1581, 2.1581], [3.4666, 3.4666, 3.4666]], + [[0.5798, 0.5798, 0.5798], [-0.7981, -0.7981, -0.7981], + [-0.9280, -0.9280, -0.9280], [-1.3311, -1.3311, -1.3311], + [1.3687, 1.3687, 1.3687], [0.9277, 0.9277, 0.9277], + [-0.4164, -0.4164, -0.4164], [-1.8274, -1.8274, -1.8274], + [0.9268, 0.9268, 0.9268], [0.8414, 0.8414, 0.8414]], + [[-1.6266, -1.6266, -1.6266], [-1.0281, -1.0281, -1.0281], + [-1.0393, -1.0393, -1.0393], [-1.6931, -1.6931, -1.6931], + [-1.3982, -1.3982, -1.3982], [-0.5732, -0.5732, -0.5732], + [-1.0830, -1.0830, -1.0830], [-1.7561, -1.7561, -1.7561], + [-1.6786, -1.6786, -1.6786], [-1.6967, -1.6967, -1.6967]], + [[-0.0380, -0.0380, -0.0380], [-0.1880, -0.1880, -0.1880], + [-1.5724, -1.5724, -1.5724], [0.6905, 0.6905, 0.6905], + [-0.3190, -0.3190, -0.3190], [0.7798, 0.7798, 0.7798], + [-0.3693, -0.3693, -0.3693], [-0.9457, -0.9457, -0.9457], + [-0.2942, -0.2942, -0.2942], [-1.8527, -1.8527, -1.8527]], + [[0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000]], + [[0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000]], + [[-0.0380, -0.0380, -0.0380], [-0.1880, -0.1880, -0.1880], + [-1.5724, -1.5724, -1.5724], [0.6905, 0.6905, 0.6905], + [-0.3190, -0.3190, -0.3190], [0.7798, 0.7798, 0.7798], + [-0.3693, -0.3693, -0.3693], [-0.9457, -0.9457, -0.9457], + [-0.2942, -0.2942, -0.2942], [-1.8527, -1.8527, -1.8527]], + [[1.1773, 1.1773, 1.1773], [1.5009, 1.5009, 1.5009], + [2.6399, 2.6399, 2.6399], [5.9242, 5.9242, 5.9242], + [1.0962, 1.0962, 1.0962], [2.7346, 2.7346, 2.7346], + [6.0865, 6.0865, 6.0865], [1.5555, 1.5555, 1.5555], + [4.3303, 4.3303, 4.3303], [2.8229, 2.8229, 2.8229]], + [[-0.0380, -0.0380, -0.0380], [-0.1880, -0.1880, -0.1880], + [-1.5724, -1.5724, -1.5724], [0.6905, 0.6905, 0.6905], + [-0.3190, -0.3190, -0.3190], [0.7798, 0.7798, 0.7798], + [-0.3693, -0.3693, -0.3693], [-0.9457, -0.9457, -0.9457], + [-0.2942, -0.2942, -0.2942], [-1.8527, -1.8527, -1.8527]]], + dtype=dtype).cuda() assert torch.allclose(output, expected_output) diff --git a/tests/test_ops/test_three_interpolate.py b/tests/test_ops/test_three_interpolate.py index 900f451ff..51a6b8732 100644 --- a/tests/test_ops/test_three_interpolate.py +++ b/tests/test_ops/test_three_interpolate.py @@ -7,19 +7,20 @@ from mmcv.ops import three_interpolate @pytest.mark.skipif( not torch.cuda.is_available(), reason='requires CUDA support') -def test_three_interpolate(): - features = torch.tensor([[[2.4350, 4.7516, 4.4995, 2.4350, 2.4350, 2.4350], - [3.1236, 2.6278, 3.0447, 3.1236, 3.1236, 3.1236], - [2.6732, 2.8677, 2.6436, 2.6732, 2.6732, 2.6732], - [0.0124, 7.0150, 7.0199, 0.0124, 0.0124, 0.0124], - [0.3207, 0.0000, 0.3411, 0.3207, 0.3207, - 0.3207]], - [[0.0000, 0.9544, 2.4532, 0.0000, 0.0000, 0.0000], - [0.5346, 1.9176, 1.4715, 0.5346, 0.5346, 0.5346], - [0.0000, 0.2744, 2.0842, 0.0000, 0.0000, 0.0000], - [0.3414, 1.5063, 1.6209, 0.3414, 0.3414, 0.3414], - [0.5814, 0.0103, 0.0000, 0.5814, 0.5814, - 0.5814]]]).cuda() +@pytest.mark.parametrize('dtype', [torch.half, torch.float, torch.double]) +def test_three_interpolate(dtype): + features = torch.tensor( + [[[2.4350, 4.7516, 4.4995, 2.4350, 2.4350, 2.4350], + [3.1236, 2.6278, 3.0447, 3.1236, 3.1236, 3.1236], + [2.6732, 2.8677, 2.6436, 2.6732, 2.6732, 2.6732], + [0.0124, 7.0150, 7.0199, 0.0124, 0.0124, 0.0124], + [0.3207, 0.0000, 0.3411, 0.3207, 0.3207, 0.3207]], + [[0.0000, 0.9544, 2.4532, 0.0000, 0.0000, 0.0000], + [0.5346, 1.9176, 1.4715, 0.5346, 0.5346, 0.5346], + [0.0000, 0.2744, 2.0842, 0.0000, 0.0000, 0.0000], + [0.3414, 1.5063, 1.6209, 0.3414, 0.3414, 0.3414], + [0.5814, 0.0103, 0.0000, 0.5814, 0.5814, 0.5814]]], + dtype=dtype).cuda() idx = torch.tensor([[[0, 1, 2], [2, 3, 4], [2, 3, 4], [0, 1, 2], [0, 1, 2], [0, 1, 3]], @@ -37,7 +38,8 @@ def test_three_interpolate(): [1.0000e+00, 1.7148e-08, 1.4070e-08], [3.3333e-01, 3.3333e-01, 3.3333e-01], [3.3333e-01, 3.3333e-01, 3.3333e-01], - [3.3333e-01, 3.3333e-01, 3.3333e-01]]]).cuda() + [3.3333e-01, 3.3333e-01, 3.3333e-01]]], + dtype=dtype).cuda() output = three_interpolate(features, idx, weight) expected_output = torch.tensor([[[ @@ -70,6 +72,7 @@ def test_three_interpolate(): [ 3.8760e-01, 1.0300e-02, 8.3569e-09, 3.8760e-01, 3.8760e-01, 1.9723e-01 - ]]]).cuda() + ]]], + dtype=dtype).cuda() - assert torch.allclose(output, expected_output, 1e-4) + assert torch.allclose(output, expected_output, 1e-3, 1e-4)