# Mirror of https://github.com/open-mmlab/mmcv.git (72 lines, 3.6 KiB, Python).
import pytest
import torch
from mmcv.ops import three_nn


@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='requires CUDA support')
def test_three_nn():
    """Check ``three_nn`` against hand-verified neighbours on a CUDA batch.

    ``unknown`` holds 2 batches of 10 query points and ``known`` holds 2
    batches of 5 reference points, all in 3-D. For every query point the
    op is expected to return the distances to its three nearest reference
    points in ascending order, together with their indices into ``known``.
    The expected distances are Euclidean (not squared): e.g. query
    ``unknown[0, 1]`` to ``known[0, 2]`` is ``sqrt(4.1874) ~= 2.0463``.

    Rows 0, 3 and 4 of each ``known`` batch are the same point, so many
    queries have tied distances; ``expected_idx`` therefore also pins the
    order in which the kernel reports tied neighbours.
    """
    known = torch.tensor([[[-1.8373, 3.5605, -0.7867],
                           [0.7615, 2.9420, 0.2314],
                           [-0.6503, 3.6637, -1.0622],
                           [-1.8373, 3.5605, -0.7867],
                           [-1.8373, 3.5605, -0.7867]],
                          [[-1.3399, 1.9991, -0.3698],
                           [-0.0799, 0.9698, -0.8457],
                           [0.0858, 2.4721, -0.1928],
                           [-1.3399, 1.9991, -0.3698],
                           [-1.3399, 1.9991, -0.3698]]]).cuda()

    unknown = torch.tensor([[[-1.8373, 3.5605, -0.7867],
                             [0.7615, 2.9420, 0.2314],
                             [-0.6503, 3.6637, -1.0622],
                             [-1.5237, 2.3976, -0.8097],
                             [-0.0722, 3.4017, -0.2880],
                             [0.5198, 3.0661, -0.4605],
                             [-2.0185, 3.5019, -0.3236],
                             [0.5098, 3.1020, 0.5799],
                             [-1.6137, 3.8443, -0.5269],
                             [0.7341, 2.9626, -0.3189]],
                            [[-1.3399, 1.9991, -0.3698],
                             [-0.0799, 0.9698, -0.8457],
                             [0.0858, 2.4721, -0.1928],
                             [-0.9022, 1.6560, -1.3090],
                             [0.1156, 1.6901, -0.4366],
                             [-0.6477, 2.3576, -0.1563],
                             [-0.8482, 1.1466, -1.2704],
                             [-0.8753, 2.0845, -0.3460],
                             [-0.5621, 1.4233, -1.2858],
                             [-0.5883, 1.3114, -1.2899]]]).cuda()

    dist, idx = three_nn(unknown, known)

    expected_dist = torch.tensor([[[0.0000, 0.0000, 0.0000],
                                   [0.0000, 2.0463, 2.8588],
                                   [0.0000, 1.2229, 1.2229],
                                   [1.2047, 1.2047, 1.2047],
                                   [1.0011, 1.0845, 1.8411],
                                   [0.7433, 1.4451, 2.4304],
                                   [0.5007, 0.5007, 0.5007],
                                   [0.4587, 2.0875, 2.7544],
                                   [0.4450, 0.4450, 0.4450],
                                   [0.5514, 1.7206, 2.6811]],
                                  [[0.0000, 0.0000, 0.0000],
                                   [0.0000, 1.6464, 1.6952],
                                   [0.0000, 1.5125, 1.5125],
                                   [1.0915, 1.0915, 1.0915],
                                   [0.8197, 0.8511, 1.4894],
                                   [0.7433, 0.8082, 0.8082],
                                   [0.8955, 1.3340, 1.3340],
                                   [0.4730, 0.4730, 0.4730],
                                   [0.7949, 1.3325, 1.3325],
                                   [0.7566, 1.3727, 1.3727]]]).cuda()
    expected_idx = torch.tensor([[[0, 3, 4], [1, 2, 0], [2, 0, 3], [0, 3, 4],
                                  [2, 1, 0], [1, 2, 0], [0, 3, 4], [1, 2, 0],
                                  [0, 3, 4], [1, 2, 0]],
                                 [[0, 3, 4], [1, 2, 0], [2, 0, 3], [0, 3, 4],
                                  [2, 1, 0], [2, 0, 3], [1, 0, 3], [0, 3, 4],
                                  [1, 0, 3], [1, 0, 3]]]).cuda()

    # Both comparisons below broadcast, so check the output layout
    # explicitly rather than letting a mis-shaped result slip through.
    assert dist.shape == expected_dist.shape
    assert idx.shape == expected_idx.shape

    # ``expected_dist`` is rounded to 4 decimals, so an absolute tolerance is
    # needed. The third positional argument of ``torch.allclose`` is
    # ``rtol``; on its own (default ``atol=1e-8``) it scales with magnitude
    # and is tightest on the small / zero entries. Name both explicitly.
    assert torch.allclose(dist, expected_dist, rtol=1e-4, atol=1e-4)
    # ``idx`` may come back with a different integer dtype than
    # ``expected_idx``; elementwise ``==`` compares values regardless.
    assert torch.all(idx == expected_idx)