[Fix] Create Tensor with new_* method to support amp ()

pull/2397/head
q.yao 2022-11-08 15:54:34 +08:00 committed by GitHub
parent b622fb2e29
commit 6254ebef8d
5 changed files with 116 additions and 184 deletions
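Context for the change: torch.FloatTensor(...) and torch.cuda.FloatTensor(...) always allocate float32 storage (the CUDA variant also pins the result to the current device), so inside a torch.autocast (amp) region they clash with half-precision inputs. The Tensor.new_tensor / new_zeros / new_empty methods instead inherit dtype and device from the tensor they are called on. A minimal standalone sketch of the difference (illustrative, not part of the commit):

import torch

def scale_row(x: torch.Tensor) -> torch.Tensor:
    # new_tensor copies the data but inherits x's dtype and device,
    # so the result stays half precision under amp.
    return x.new_tensor([0.5, -0.5, -0.5, 0.5])

x = torch.randn(4, dtype=torch.half)
assert scale_row(x).dtype == torch.half                  # follows the input
assert torch.FloatTensor([0.5]).dtype == torch.float32   # always float32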


@@ -235,9 +235,9 @@ def box2corners(box: Tensor) -> Tensor:
     """
     B = box.size()[0]
     x, y, w, h, alpha = box.split([1, 1, 1, 1, 1], dim=-1)
-    x4 = torch.FloatTensor([0.5, -0.5, -0.5, 0.5]).to(box.device)
+    x4 = box.new_tensor([0.5, -0.5, -0.5, 0.5]).to(box.device)
     x4 = x4 * w  # (B, N, 4)
-    y4 = torch.FloatTensor([0.5, 0.5, -0.5, -0.5]).to(box.device)
+    y4 = box.new_tensor([0.5, 0.5, -0.5, -0.5]).to(box.device)
     y4 = y4 * h  # (B, N, 4)
     corners = torch.stack([x4, y4], dim=-1)  # (B, N, 4, 2)
     sin = torch.sin(alpha)
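Note that new_tensor already allocates on box's device by default, so the trailing .to(box.device) retained above is now a no-op; the behavioural change is that x4 and y4 follow box's dtype instead of being forced to float32. A quick check (illustrative):

import torch

box = torch.randn(2, 5, dtype=torch.double)
x4 = box.new_tensor([0.5, -0.5, -0.5, 0.5])
assert x4.dtype == box.dtype and x4.device == box.device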


@@ -233,7 +233,7 @@ class GroupingOperation(Function):
         else:
             B, nfeatures, nsample = indices.size()
         _, C, N = features.size()
-        output = torch.cuda.FloatTensor(B, C, nfeatures, nsample)
+        output = features.new_zeros(B, C, nfeatures, nsample)

         ext_module.group_points_forward(
             features,
@@ -262,7 +262,7 @@ class GroupingOperation(Function):
         idx, N = ctx.for_backwards
         B, C, npoint, nsample = grad_out.size()

-        grad_features = torch.cuda.FloatTensor(B, C, N).zero_()
+        grad_features = grad_out.new_zeros(B, C, N)
         grad_out_data = grad_out.data.contiguous()

         ext_module.group_points_backward(
@@ -279,7 +279,7 @@ class GroupingOperation(Function):
         B, N, idx, features_batch_cnt, idx_batch_cnt = ctx.for_backwards
         M, C, nsample = grad_out.size()

-        grad_features = torch.cuda.FloatTensor(N, C).zero_()
+        grad_features = grad_out.new_zeros(N, C)
         grad_out_data = grad_out.data.contiguous()

         ext_module.stack_group_points_backward(
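The forward output and both backward buffers above switch from torch.cuda.FloatTensor(...) to new_zeros on the relevant input: the gradient buffer must match grad_out, which may be half precision under amp, and it must start at zero because the CUDA kernels accumulate into it. A hedged sketch of the pattern (names illustrative, not from the commit):

import torch

def alloc_grad_buffer(grad_out: torch.Tensor, *shape: int) -> torch.Tensor:
    # Same dtype and device as grad_out, zero-filled for accumulation.
    return grad_out.new_zeros(*shape)

grad_out = torch.randn(2, 3, 4, dtype=torch.half)
buf = alloc_grad_buffer(grad_out, 2, 3, 8)
assert buf.dtype == grad_out.dtype and bool((buf == 0).all())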


@@ -38,7 +38,7 @@ class ThreeInterpolate(Function):
         B, c, m = features.size()
         n = indices.size(1)
         ctx.three_interpolate_for_backward = (indices, weight, m)
-        output = torch.cuda.FloatTensor(B, c, n)
+        output = features.new_empty(B, c, n)

         ext_module.three_interpolate_forward(
             features, indices, weight, output, b=B, c=c, m=m, n=n)
@@ -58,7 +58,7 @@ class ThreeInterpolate(Function):
         idx, weight, m = ctx.three_interpolate_for_backward
         B, c, n = grad_out.size()

-        grad_features = torch.cuda.FloatTensor(B, c, m).zero_()
+        grad_features = grad_out.new_zeros(B, c, m)
         grad_out_data = grad_out.data.contiguous()

         ext_module.three_interpolate_backward(
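Note the asymmetry here: the forward output uses new_empty while every backward buffer in this commit uses new_zeros. The likely reason is that three_interpolate_forward writes every element of the output, so uninitialised memory is safe, whereas the backward kernels accumulate partial gradients and need a zeroed start. A small standalone illustration of why accumulation needs zeros (not the extension code):

import torch

grad = torch.zeros(8)                   # must start at zero
idx = torch.tensor([0, 0, 3])
grad.index_add_(0, idx, torch.ones(3))  # scatter-add style accumulation
assert grad[0] == 2 and grad[3] == 1    # garbage init would corrupt these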


@@ -7,7 +7,8 @@ from mmcv.ops import grouping_operation

 @pytest.mark.skipif(
     not torch.cuda.is_available(), reason='requires CUDA support')
-def test_grouping_points():
+@pytest.mark.parametrize('dtype', [torch.half, torch.float, torch.double])
+def test_grouping_points(dtype):
     idx = torch.tensor([[[0, 0, 0], [3, 3, 3], [8, 8, 8], [0, 0, 0], [0, 0, 0],
                          [0, 0, 0]],
                         [[0, 0, 0], [6, 6, 6], [9, 9, 9], [0, 0, 0], [0, 0, 0],
@@ -35,51 +36,37 @@ def test_grouping_points():
                             [
                                 -0.6646, -0.6870, -0.1125, -0.2224, -0.3445,
                                 -1.4049, 0.4990, -0.7037, -0.9924, 0.0386
-                            ]]]).cuda()
+                            ]]],
+                            dtype=dtype).cuda()
     output = grouping_operation(features, idx)
-    expected_output = torch.tensor([[[[0.5798, 0.5798, 0.5798],
-                                      [-1.3311, -1.3311, -1.3311],
-                                      [0.9268, 0.9268, 0.9268],
-                                      [0.5798, 0.5798, 0.5798],
-                                      [0.5798, 0.5798, 0.5798],
-                                      [0.5798, 0.5798, 0.5798]],
-                                     [[5.4247, 5.4247, 5.4247],
-                                      [1.4740, 1.4740, 1.4740],
-                                      [2.1581, 2.1581, 2.1581],
-                                      [5.4247, 5.4247, 5.4247],
-                                      [5.4247, 5.4247, 5.4247],
-                                      [5.4247, 5.4247, 5.4247]],
-                                     [[-1.6266, -1.6266, -1.6266],
-                                      [-1.6931, -1.6931, -1.6931],
-                                      [-1.6786, -1.6786, -1.6786],
-                                      [-1.6266, -1.6266, -1.6266],
-                                      [-1.6266, -1.6266, -1.6266],
-                                      [-1.6266, -1.6266, -1.6266]]],
-                                    [[[-0.0380, -0.0380, -0.0380],
-                                      [-0.3693, -0.3693, -0.3693],
-                                      [-1.8527, -1.8527, -1.8527],
-                                      [-0.0380, -0.0380, -0.0380],
-                                      [-0.0380, -0.0380, -0.0380],
-                                      [-0.0380, -0.0380, -0.0380]],
-                                     [[1.1773, 1.1773, 1.1773],
-                                      [6.0865, 6.0865, 6.0865],
-                                      [2.8229, 2.8229, 2.8229],
-                                      [1.1773, 1.1773, 1.1773],
-                                      [1.1773, 1.1773, 1.1773],
-                                      [1.1773, 1.1773, 1.1773]],
-                                     [[-0.6646, -0.6646, -0.6646],
-                                      [0.4990, 0.4990, 0.4990],
-                                      [0.0386, 0.0386, 0.0386],
-                                      [-0.6646, -0.6646, -0.6646],
-                                      [-0.6646, -0.6646, -0.6646],
-                                      [-0.6646, -0.6646, -0.6646]]]]).cuda()
+    expected_output = torch.tensor(
+        [[[[0.5798, 0.5798, 0.5798], [-1.3311, -1.3311, -1.3311],
+           [0.9268, 0.9268, 0.9268], [0.5798, 0.5798, 0.5798],
+           [0.5798, 0.5798, 0.5798], [0.5798, 0.5798, 0.5798]],
+          [[5.4247, 5.4247, 5.4247], [1.4740, 1.4740, 1.4740],
+           [2.1581, 2.1581, 2.1581], [5.4247, 5.4247, 5.4247],
+           [5.4247, 5.4247, 5.4247], [5.4247, 5.4247, 5.4247]],
+          [[-1.6266, -1.6266, -1.6266], [-1.6931, -1.6931, -1.6931],
+           [-1.6786, -1.6786, -1.6786], [-1.6266, -1.6266, -1.6266],
+           [-1.6266, -1.6266, -1.6266], [-1.6266, -1.6266, -1.6266]]],
+         [[[-0.0380, -0.0380, -0.0380], [-0.3693, -0.3693, -0.3693],
+           [-1.8527, -1.8527, -1.8527], [-0.0380, -0.0380, -0.0380],
+           [-0.0380, -0.0380, -0.0380], [-0.0380, -0.0380, -0.0380]],
+          [[1.1773, 1.1773, 1.1773], [6.0865, 6.0865, 6.0865],
+           [2.8229, 2.8229, 2.8229], [1.1773, 1.1773, 1.1773],
+           [1.1773, 1.1773, 1.1773], [1.1773, 1.1773, 1.1773]],
+          [[-0.6646, -0.6646, -0.6646], [0.4990, 0.4990, 0.4990],
+           [0.0386, 0.0386, 0.0386], [-0.6646, -0.6646, -0.6646],
+           [-0.6646, -0.6646, -0.6646], [-0.6646, -0.6646, -0.6646]]]],
+        dtype=dtype).cuda()
     assert torch.allclose(output, expected_output)


 @pytest.mark.skipif(
     not torch.cuda.is_available(), reason='requires CUDA support')
-def test_stack_grouping_points():
+@pytest.mark.parametrize('dtype', [torch.half, torch.float, torch.double])
+def test_stack_grouping_points(dtype):
     idx = torch.tensor([[0, 0, 0], [3, 3, 3], [8, 8, 8], [1, 1, 1], [0, 0, 0],
                         [2, 2, 2], [0, 0, 0], [6, 6, 6], [9, 9, 9], [0, 0, 0],
                         [1, 1, 1], [0, 0, 0]]).int().cuda()
@@ -106,130 +93,72 @@ def test_stack_grouping_points():
                        [
                            -0.6646, -0.6870, -0.1125, -0.2224, -0.3445,
                            -1.4049, 0.4990, -0.7037, -0.9924, 0.0386
-                       ]]).float().cuda()
+                       ]],
+                       dtype=dtype).cuda()
     features_batch_cnt = torch.tensor([3, 3]).int().cuda()
     indices_batch_cnt = torch.tensor([6, 6]).int().cuda()
     output = grouping_operation(features, idx, features_batch_cnt,
                                 indices_batch_cnt)
-    expected_output = torch.Tensor([[[0.5798, 0.5798, 0.5798],
-                                     [-0.7981, -0.7981, -0.7981],
-                                     [-0.9280, -0.9280, -0.9280],
-                                     [-1.3311, -1.3311, -1.3311],
-                                     [1.3687, 1.3687, 1.3687],
-                                     [0.9277, 0.9277, 0.9277],
-                                     [-0.4164, -0.4164, -0.4164],
-                                     [-1.8274, -1.8274, -1.8274],
-                                     [0.9268, 0.9268, 0.9268],
-                                     [0.8414, 0.8414, 0.8414]],
-                                    [[0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000]],
-                                    [[0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000]],
-                                    [[5.4247, 5.4247, 5.4247],
-                                     [1.5113, 1.5113, 1.5113],
-                                     [2.3944, 2.3944, 2.3944],
-                                     [1.4740, 1.4740, 1.4740],
-                                     [5.0300, 5.0300, 5.0300],
-                                     [5.1030, 5.1030, 5.1030],
-                                     [1.9360, 1.9360, 1.9360],
-                                     [2.1939, 2.1939, 2.1939],
-                                     [2.1581, 2.1581, 2.1581],
-                                     [3.4666, 3.4666, 3.4666]],
-                                    [[0.5798, 0.5798, 0.5798],
-                                     [-0.7981, -0.7981, -0.7981],
-                                     [-0.9280, -0.9280, -0.9280],
-                                     [-1.3311, -1.3311, -1.3311],
-                                     [1.3687, 1.3687, 1.3687],
-                                     [0.9277, 0.9277, 0.9277],
-                                     [-0.4164, -0.4164, -0.4164],
-                                     [-1.8274, -1.8274, -1.8274],
-                                     [0.9268, 0.9268, 0.9268],
-                                     [0.8414, 0.8414, 0.8414]],
-                                    [[-1.6266, -1.6266, -1.6266],
-                                     [-1.0281, -1.0281, -1.0281],
-                                     [-1.0393, -1.0393, -1.0393],
-                                     [-1.6931, -1.6931, -1.6931],
-                                     [-1.3982, -1.3982, -1.3982],
-                                     [-0.5732, -0.5732, -0.5732],
-                                     [-1.0830, -1.0830, -1.0830],
-                                     [-1.7561, -1.7561, -1.7561],
-                                     [-1.6786, -1.6786, -1.6786],
-                                     [-1.6967, -1.6967, -1.6967]],
-                                    [[-0.0380, -0.0380, -0.0380],
-                                     [-0.1880, -0.1880, -0.1880],
-                                     [-1.5724, -1.5724, -1.5724],
-                                     [0.6905, 0.6905, 0.6905],
-                                     [-0.3190, -0.3190, -0.3190],
-                                     [0.7798, 0.7798, 0.7798],
-                                     [-0.3693, -0.3693, -0.3693],
-                                     [-0.9457, -0.9457, -0.9457],
-                                     [-0.2942, -0.2942, -0.2942],
-                                     [-1.8527, -1.8527, -1.8527]],
-                                    [[0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000]],
-                                    [[0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000],
-                                     [0.0000, 0.0000, 0.0000]],
-                                    [[-0.0380, -0.0380, -0.0380],
-                                     [-0.1880, -0.1880, -0.1880],
-                                     [-1.5724, -1.5724, -1.5724],
-                                     [0.6905, 0.6905, 0.6905],
-                                     [-0.3190, -0.3190, -0.3190],
-                                     [0.7798, 0.7798, 0.7798],
-                                     [-0.3693, -0.3693, -0.3693],
-                                     [-0.9457, -0.9457, -0.9457],
-                                     [-0.2942, -0.2942, -0.2942],
-                                     [-1.8527, -1.8527, -1.8527]],
-                                    [[1.1773, 1.1773, 1.1773],
-                                     [1.5009, 1.5009, 1.5009],
-                                     [2.6399, 2.6399, 2.6399],
-                                     [5.9242, 5.9242, 5.9242],
-                                     [1.0962, 1.0962, 1.0962],
-                                     [2.7346, 2.7346, 2.7346],
-                                     [6.0865, 6.0865, 6.0865],
-                                     [1.5555, 1.5555, 1.5555],
-                                     [4.3303, 4.3303, 4.3303],
-                                     [2.8229, 2.8229, 2.8229]],
-                                    [[-0.0380, -0.0380, -0.0380],
-                                     [-0.1880, -0.1880, -0.1880],
-                                     [-1.5724, -1.5724, -1.5724],
-                                     [0.6905, 0.6905, 0.6905],
-                                     [-0.3190, -0.3190, -0.3190],
-                                     [0.7798, 0.7798, 0.7798],
-                                     [-0.3693, -0.3693, -0.3693],
-                                     [-0.9457, -0.9457, -0.9457],
-                                     [-0.2942, -0.2942, -0.2942],
-                                     [-1.8527, -1.8527,
-                                      -1.8527]]]).cuda().float()
+    expected_output = torch.tensor(
+        [[[0.5798, 0.5798, 0.5798], [-0.7981, -0.7981, -0.7981],
+          [-0.9280, -0.9280, -0.9280], [-1.3311, -1.3311, -1.3311],
+          [1.3687, 1.3687, 1.3687], [0.9277, 0.9277, 0.9277],
+          [-0.4164, -0.4164, -0.4164], [-1.8274, -1.8274, -1.8274],
+          [0.9268, 0.9268, 0.9268], [0.8414, 0.8414, 0.8414]],
+         [[0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
+          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
+          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
+          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
+          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000]],
+         [[0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
+          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
+          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
+          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
+          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000]],
+         [[5.4247, 5.4247, 5.4247], [1.5113, 1.5113, 1.5113],
+          [2.3944, 2.3944, 2.3944], [1.4740, 1.4740, 1.4740],
+          [5.0300, 5.0300, 5.0300], [5.1030, 5.1030, 5.1030],
+          [1.9360, 1.9360, 1.9360], [2.1939, 2.1939, 2.1939],
+          [2.1581, 2.1581, 2.1581], [3.4666, 3.4666, 3.4666]],
+         [[0.5798, 0.5798, 0.5798], [-0.7981, -0.7981, -0.7981],
+          [-0.9280, -0.9280, -0.9280], [-1.3311, -1.3311, -1.3311],
+          [1.3687, 1.3687, 1.3687], [0.9277, 0.9277, 0.9277],
+          [-0.4164, -0.4164, -0.4164], [-1.8274, -1.8274, -1.8274],
+          [0.9268, 0.9268, 0.9268], [0.8414, 0.8414, 0.8414]],
+         [[-1.6266, -1.6266, -1.6266], [-1.0281, -1.0281, -1.0281],
+          [-1.0393, -1.0393, -1.0393], [-1.6931, -1.6931, -1.6931],
+          [-1.3982, -1.3982, -1.3982], [-0.5732, -0.5732, -0.5732],
+          [-1.0830, -1.0830, -1.0830], [-1.7561, -1.7561, -1.7561],
+          [-1.6786, -1.6786, -1.6786], [-1.6967, -1.6967, -1.6967]],
+         [[-0.0380, -0.0380, -0.0380], [-0.1880, -0.1880, -0.1880],
+          [-1.5724, -1.5724, -1.5724], [0.6905, 0.6905, 0.6905],
+          [-0.3190, -0.3190, -0.3190], [0.7798, 0.7798, 0.7798],
+          [-0.3693, -0.3693, -0.3693], [-0.9457, -0.9457, -0.9457],
+          [-0.2942, -0.2942, -0.2942], [-1.8527, -1.8527, -1.8527]],
+         [[0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
+          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
+          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
+          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
+          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000]],
+         [[0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
+          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
+          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
+          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
+          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000]],
+         [[-0.0380, -0.0380, -0.0380], [-0.1880, -0.1880, -0.1880],
+          [-1.5724, -1.5724, -1.5724], [0.6905, 0.6905, 0.6905],
+          [-0.3190, -0.3190, -0.3190], [0.7798, 0.7798, 0.7798],
+          [-0.3693, -0.3693, -0.3693], [-0.9457, -0.9457, -0.9457],
+          [-0.2942, -0.2942, -0.2942], [-1.8527, -1.8527, -1.8527]],
+         [[1.1773, 1.1773, 1.1773], [1.5009, 1.5009, 1.5009],
+          [2.6399, 2.6399, 2.6399], [5.9242, 5.9242, 5.9242],
+          [1.0962, 1.0962, 1.0962], [2.7346, 2.7346, 2.7346],
+          [6.0865, 6.0865, 6.0865], [1.5555, 1.5555, 1.5555],
+          [4.3303, 4.3303, 4.3303], [2.8229, 2.8229, 2.8229]],
+         [[-0.0380, -0.0380, -0.0380], [-0.1880, -0.1880, -0.1880],
+          [-1.5724, -1.5724, -1.5724], [0.6905, 0.6905, 0.6905],
+          [-0.3190, -0.3190, -0.3190], [0.7798, 0.7798, 0.7798],
+          [-0.3693, -0.3693, -0.3693], [-0.9457, -0.9457, -0.9457],
+          [-0.2942, -0.2942, -0.2942], [-1.8527, -1.8527, -1.8527]]],
+        dtype=dtype).cuda()
     assert torch.allclose(output, expected_output)
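For reference, grouping_operation(features, idx) in the non-stacked case gathers feature columns by index: (B, C, N) features and (B, npoint, nsample) indices yield a (B, C, npoint, nsample) output. A plain-PyTorch equivalent that reproduces the first test's expectation (an illustrative reference, not mmcv's CUDA implementation):

import torch

def grouping_reference(features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    # features: (B, C, N), idx: (B, npoint, nsample)
    B, C, N = features.shape
    _, npoint, nsample = idx.shape
    flat = idx.reshape(B, 1, npoint * nsample).expand(B, C, -1).long()
    return features.gather(2, flat).reshape(B, C, npoint, nsample)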


@@ -7,19 +7,20 @@ from mmcv.ops import three_interpolate

 @pytest.mark.skipif(
     not torch.cuda.is_available(), reason='requires CUDA support')
-def test_three_interpolate():
-    features = torch.tensor([[[2.4350, 4.7516, 4.4995, 2.4350, 2.4350, 2.4350],
-                              [3.1236, 2.6278, 3.0447, 3.1236, 3.1236, 3.1236],
-                              [2.6732, 2.8677, 2.6436, 2.6732, 2.6732, 2.6732],
-                              [0.0124, 7.0150, 7.0199, 0.0124, 0.0124, 0.0124],
-                              [0.3207, 0.0000, 0.3411, 0.3207, 0.3207,
-                               0.3207]],
-                             [[0.0000, 0.9544, 2.4532, 0.0000, 0.0000, 0.0000],
-                              [0.5346, 1.9176, 1.4715, 0.5346, 0.5346, 0.5346],
-                              [0.0000, 0.2744, 2.0842, 0.0000, 0.0000, 0.0000],
-                              [0.3414, 1.5063, 1.6209, 0.3414, 0.3414, 0.3414],
-                              [0.5814, 0.0103, 0.0000, 0.5814, 0.5814,
-                               0.5814]]]).cuda()
+@pytest.mark.parametrize('dtype', [torch.half, torch.float, torch.double])
+def test_three_interpolate(dtype):
+    features = torch.tensor(
+        [[[2.4350, 4.7516, 4.4995, 2.4350, 2.4350, 2.4350],
+          [3.1236, 2.6278, 3.0447, 3.1236, 3.1236, 3.1236],
+          [2.6732, 2.8677, 2.6436, 2.6732, 2.6732, 2.6732],
+          [0.0124, 7.0150, 7.0199, 0.0124, 0.0124, 0.0124],
+          [0.3207, 0.0000, 0.3411, 0.3207, 0.3207, 0.3207]],
+         [[0.0000, 0.9544, 2.4532, 0.0000, 0.0000, 0.0000],
+          [0.5346, 1.9176, 1.4715, 0.5346, 0.5346, 0.5346],
+          [0.0000, 0.2744, 2.0842, 0.0000, 0.0000, 0.0000],
+          [0.3414, 1.5063, 1.6209, 0.3414, 0.3414, 0.3414],
+          [0.5814, 0.0103, 0.0000, 0.5814, 0.5814, 0.5814]]],
+        dtype=dtype).cuda()

     idx = torch.tensor([[[0, 1, 2], [2, 3, 4], [2, 3, 4], [0, 1, 2], [0, 1, 2],
                         [0, 1, 3]],
@@ -37,7 +38,8 @@ def test_three_interpolate():
                         [1.0000e+00, 1.7148e-08, 1.4070e-08],
                         [3.3333e-01, 3.3333e-01, 3.3333e-01],
                         [3.3333e-01, 3.3333e-01, 3.3333e-01],
-                        [3.3333e-01, 3.3333e-01, 3.3333e-01]]]).cuda()
+                        [3.3333e-01, 3.3333e-01, 3.3333e-01]]],
+                       dtype=dtype).cuda()
     output = three_interpolate(features, idx, weight)
     expected_output = torch.tensor([[[
@@ -70,6 +72,7 @@ def test_three_interpolate():
                         [
                             3.8760e-01, 1.0300e-02, 8.3569e-09,
                             3.8760e-01, 3.8760e-01, 1.9723e-01
-                        ]]]).cuda()
-    assert torch.allclose(output, expected_output, 1e-4)
+                        ]]],
+                       dtype=dtype).cuda()
+    assert torch.allclose(output, expected_output, 1e-3, 1e-4)
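The final assertion is also loosened: torch.allclose takes (input, other, rtol, atol), so the old call passed 1e-4 as rtol only, while the new one passes rtol=1e-3 and atol=1e-4 to leave headroom for the torch.half parametrization (fp16 carries only about three significant decimal digits). If per-dtype tolerances were ever wanted instead, a sketch (hypothetical helper, not in the commit):

import torch

TOLS = {torch.half: (1e-3, 1e-4),
        torch.float: (1e-5, 1e-8),
        torch.double: (1e-8, 1e-12)}

def assert_close(a: torch.Tensor, b: torch.Tensor) -> None:
    # Pick rtol/atol from the dtype under test.
    rtol, atol = TOLS[a.dtype]
    assert torch.allclose(a, b, rtol=rtol, atol=atol)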