[Feature] Add interpolate ops from mmdet3d (#1355)

* add ops (interpolate) in mmdet3d

* refactor code

* fix typo

* fix typo

* fix typo

* refactor code
pull/1357/head
dingchang 2021-10-14 20:51:02 +08:00 committed by GitHub
parent 97e5bada4c
commit be5841e45d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 618 additions and 2 deletions

View File

@ -25,5 +25,7 @@ We implement common CUDA ops used in detection, segmentation, etc.
- SoftmaxFocalLoss
- SoftNMS
- Synchronized BatchNorm
- ThreeInterpolate
- ThreeNN
- Weight standardization
- Correlation

View File

@ -25,5 +25,7 @@ MMCV 提供了检测、分割等任务中常用的 CUDA 算子
- SoftmaxFocalLoss
- SoftNMS
- Synchronized BatchNorm
- ThreeInterpolate
- ThreeNN
- Weight standardization
- Correlation

View File

@ -40,6 +40,8 @@ from .roi_align_rotated import RoIAlignRotated, roi_align_rotated
from .roi_pool import RoIPool, roi_pool
from .saconv import SAConv2d
from .sync_bn import SyncBatchNorm
from .three_interpolate import three_interpolate
from .three_nn import three_nn
from .tin_shift import TINShift, tin_shift
from .upfirdn2d import upfirdn2d
@ -59,7 +61,8 @@ __all__ = [
'SAConv2d', 'TINShift', 'tin_shift', 'box_iou_rotated', 'nms_rotated',
'knn', 'ball_query', 'upfirdn2d', 'FusedBiasLeakyReLU',
'fused_bias_leakyrelu', 'RoIAlignRotated', 'roi_align_rotated',
'pixel_group', 'contour_expand', 'MultiScaleDeformableAttention',
'BorderAlign', 'border_align', 'gather_points', 'furthest_point_sample',
'pixel_group', 'contour_expand', 'three_nn', 'three_interpolate',
'MultiScaleDeformableAttention', 'BorderAlign', 'border_align',
'gather_points', 'furthest_point_sample',
'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation'
]

View File

@ -0,0 +1,61 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef THREE_INTERPOLATE_CUDA_KERNEL_CUH
#define THREE_INTERPOLATE_CUDA_KERNEL_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
template <typename T>
__global__ void three_interpolate_forward_cuda_kernel(
int b, int c, int m, int n, const T *points, const int *__restrict__ idx,
const T *weight, T *out) {
// points: (B, C, M)
// idx: (B, N, 3)
// weight: (B, N, 3)
// output:
// out: (B, C, N)
int bs_idx = blockIdx.z;
int c_idx = blockIdx.y;
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (bs_idx >= b || c_idx >= c || pt_idx >= n) return;
weight += bs_idx * n * 3 + pt_idx * 3;
points += bs_idx * c * m + c_idx * m;
idx += bs_idx * n * 3 + pt_idx * 3;
out += bs_idx * c * n + c_idx * n;
out[pt_idx] = weight[0] * points[idx[0]] + weight[1] * points[idx[1]] +
weight[2] * points[idx[2]];
}
template <typename T>
__global__ void three_interpolate_backward_cuda_kernel(
int b, int c, int n, int m, const T *grad_out, const int *__restrict__ idx,
const T *weight, T *grad_points) {
// grad_out: (B, C, N)
// weight: (B, N, 3)
// output:
// grad_points: (B, C, M)
int bs_idx = blockIdx.z;
int c_idx = blockIdx.y;
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (bs_idx >= b || c_idx >= c || pt_idx >= n) return;
grad_out += bs_idx * c * n + c_idx * n + pt_idx;
weight += bs_idx * n * 3 + pt_idx * 3;
grad_points += bs_idx * c * m + c_idx * m;
idx += bs_idx * n * 3 + pt_idx * 3;
atomicAdd(grad_points + idx[0], grad_out[0] * weight[0]);
atomicAdd(grad_points + idx[1], grad_out[0] * weight[1]);
atomicAdd(grad_points + idx[2], grad_out[0] * weight[2]);
}
#endif // THREE_INTERPOLATE_CUDA_KERNEL_CUH

View File

@ -0,0 +1,66 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef THREE_NN_CUDA_KERNEL_CUH
#define THREE_NN_CUDA_KERNEL_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
template <typename T>
__global__ void three_nn_forward_cuda_kernel(int b, int n, int m,
const T *unknown, const T *known,
T *dist2, int *__restrict__ idx) {
// unknown: (B, N, 3)
// known: (B, M, 3)
// output:
// dist2: (B, N, 3)
// idx: (B, N, 3)
int bs_idx = blockIdx.y;
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (bs_idx >= b || pt_idx >= n) return;
unknown += bs_idx * n * 3 + pt_idx * 3;
known += bs_idx * m * 3;
dist2 += bs_idx * n * 3 + pt_idx * 3;
idx += bs_idx * n * 3 + pt_idx * 3;
T ux = unknown[0];
T uy = unknown[1];
T uz = unknown[2];
double best1 = 1e40, best2 = 1e40, best3 = 1e40;
int besti1 = 0, besti2 = 0, besti3 = 0;
for (int k = 0; k < m; ++k) {
T x = known[k * 3 + 0];
T y = known[k * 3 + 1];
T z = known[k * 3 + 2];
T d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
if (d < best1) {
best3 = best2;
besti3 = besti2;
best2 = best1;
besti2 = besti1;
best1 = d;
besti1 = k;
} else if (d < best2) {
best3 = best2;
besti3 = besti2;
best2 = d;
besti2 = k;
} else if (d < best3) {
best3 = d;
besti3 = k;
}
}
dist2[0] = best1;
dist2[1] = best2;
dist2[2] = best3;
idx[0] = besti1;
idx[1] = besti2;
idx[2] = besti3;
}
#endif // THREE_NN_CUDA_KERNEL_CUH

View File

@ -0,0 +1,66 @@
// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/interpolate_gpu.cu
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "pytorch_cuda_helper.hpp"
#include "three_interpolate_cuda_kernel.cuh"
void ThreeInterpolateForwardCUDAKernelLauncher(int b, int c, int m, int n,
const Tensor points,
const Tensor idx,
const Tensor weight,
Tensor out) {
// points: (B, C, M)
// idx: (B, N, 3)
// weight: (B, N, 3)
// output:
// out: (B, C, N)
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// blockIdx.x(col), blockIdx.y(row)
dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b);
dim3 threads(THREADS_PER_BLOCK);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
points.scalar_type(), "three_interpolate_forward_cuda_kernel", [&] {
three_interpolate_forward_cuda_kernel<scalar_t>
<<<blocks, threads, 0, stream>>>(
b, c, m, n, points.data_ptr<scalar_t>(), idx.data_ptr<int>(),
weight.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
});
AT_CUDA_CHECK(cudaGetLastError());
}
void ThreeInterpolateBackwardCUDAKernelLauncher(int b, int c, int n, int m,
const Tensor grad_out,
const Tensor idx,
const Tensor weight,
Tensor grad_points) {
// grad_out: (B, C, N)
// weight: (B, N, 3)
// output:
// grad_points: (B, C, M)
at::cuda::CUDAGuard device_guard(grad_out.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// blockIdx.x(col), blockIdx.y(row)
dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b);
dim3 threads(THREADS_PER_BLOCK);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_out.scalar_type(), "three_interpolate_backward_cuda_kernel", [&] {
three_interpolate_backward_cuda_kernel<scalar_t>
<<<blocks, threads, 0, stream>>>(
b, c, n, m, grad_out.data_ptr<scalar_t>(), idx.data_ptr<int>(),
weight.data_ptr<scalar_t>(), grad_points.data_ptr<scalar_t>());
});
AT_CUDA_CHECK(cudaGetLastError());
}

View File

@ -0,0 +1,35 @@
// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/interpolate_gpu.cu
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "pytorch_cuda_helper.hpp"
#include "three_nn_cuda_kernel.cuh"
void ThreeNNForwardCUDAKernelLauncher(int b, int n, int m, const Tensor unknown,
const Tensor known, Tensor dist2,
Tensor idx) {
// unknown: (B, N, 3)
// known: (B, M, 3)
// output:
// dist2: (B, N, 3)
// idx: (B, N, 3)
at::cuda::CUDAGuard device_guard(unknown.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// blockIdx.x(col), blockIdx.y(row)
dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), b);
dim3 threads(THREADS_PER_BLOCK);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
unknown.scalar_type(), "three_nn_forward_cuda_kernel", [&] {
three_nn_forward_cuda_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
b, n, m, unknown.data_ptr<scalar_t>(), known.data_ptr<scalar_t>(),
dist2.data_ptr<scalar_t>(), idx.data_ptr<int>());
});
AT_CUDA_CHECK(cudaGetLastError());
}

View File

@ -74,6 +74,19 @@ void softmax_focal_loss_backward(Tensor input, Tensor target, Tensor weight,
Tensor buff, Tensor grad_input, float gamma,
float alpha);
void three_interpolate_forward(int b, int c, int m, int n, Tensor points_tensor,
Tensor idx_tensor, Tensor weight_tensor,
Tensor out_tensor);
void three_interpolate_backward(int b, int c, int n, int m,
Tensor grad_out_tensor, Tensor idx_tensor,
Tensor weight_tensor,
Tensor grad_points_tensor);
void three_nn_forward(int b, int n, int m, Tensor unknown_tensor,
Tensor known_tensor, Tensor dist2_tensor,
Tensor idx_tensor);
void bbox_overlaps(const Tensor bboxes1, const Tensor bboxes2, Tensor ious,
const int mode, const bool aligned, const int offset);
@ -343,6 +356,18 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"softmax_focal_loss_backward", py::arg("input"), py::arg("target"),
py::arg("weight"), py::arg("buff"), py::arg("grad_input"),
py::arg("gamma"), py::arg("alpha"));
m.def("three_interpolate_forward", &three_interpolate_forward,
"three_interpolate_forward", py::arg("b"), py::arg("c"), py::arg("m"),
py::arg("n"), py::arg("points_tensor"), py::arg("idx_tensor"),
py::arg("weight_tensor"), py::arg("out_tensor"));
m.def("three_interpolate_backward", &three_interpolate_backward,
"three_interpolate_backward", py::arg("b"), py::arg("c"), py::arg("n"),
py::arg("m"), py::arg("grad_out_tensor"), py::arg("idx_tensor"),
py::arg("weight_tensor"), py::arg("grad_points_tensor"));
m.def("three_nn_forward", &three_nn_forward, "three_nn_forward", py::arg("b"),
py::arg("n"), py::arg("m"), py::arg("unknown_tensor"),
py::arg("known_tensor"), py::arg("dist2_tensor"),
py::arg("idx_tensor"));
m.def("bbox_overlaps", &bbox_overlaps, "bbox_overlaps", py::arg("bboxes1"),
py::arg("bboxes2"), py::arg("ious"), py::arg("mode"),
py::arg("aligned"), py::arg("offset"));

View File

@ -0,0 +1,62 @@
// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/interpolate.cpp
#include "pytorch_cpp_helper.hpp"
#ifdef MMCV_WITH_CUDA
void ThreeInterpolateForwardCUDAKernelLauncher(int b, int c, int m, int n,
const Tensor points,
const Tensor idx,
const Tensor weight, Tensor out);
void three_interpolate_forward_cuda(int b, int c, int m, int n,
const Tensor points, const Tensor idx,
const Tensor weight, Tensor out) {
ThreeInterpolateForwardCUDAKernelLauncher(b, c, m, n, points, idx, weight,
out);
};
void ThreeInterpolateBackwardCUDAKernelLauncher(int b, int c, int n, int m,
const Tensor grad_out,
const Tensor idx,
const Tensor weight,
Tensor grad_points);
void three_interpolate_backward_cuda(int b, int c, int n, int m,
const Tensor grad_out, const Tensor idx,
const Tensor weight, Tensor grad_points) {
ThreeInterpolateBackwardCUDAKernelLauncher(b, c, n, m, grad_out, idx, weight,
grad_points);
};
#endif
void three_interpolate_forward(int b, int c, int m, int n, Tensor points_tensor,
Tensor idx_tensor, Tensor weight_tensor,
Tensor out_tensor) {
if (points_tensor.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
three_interpolate_forward_cuda(b, c, m, n, points_tensor, idx_tensor,
weight_tensor, out_tensor);
#else
AT_ERROR("three_interpolate is not compiled with GPU support");
#endif
} else {
AT_ERROR("three_interpolate is not implemented on CPU");
}
}
void three_interpolate_backward(int b, int c, int n, int m,
Tensor grad_out_tensor, Tensor idx_tensor,
Tensor weight_tensor,
Tensor grad_points_tensor) {
if (grad_out_tensor.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
three_interpolate_backward_cuda(b, c, n, m, grad_out_tensor, idx_tensor,
weight_tensor, grad_points_tensor);
#else
AT_ERROR("three_interpolate is not compiled with GPU support");
#endif
} else {
AT_ERROR("three_interpolate is not implemented on CPU");
}
}

View File

@ -0,0 +1,30 @@
// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/interpolate.cpp
#include "pytorch_cpp_helper.hpp"
#ifdef MMCV_WITH_CUDA
void ThreeNNForwardCUDAKernelLauncher(int b, int n, int m, const Tensor unknown,
const Tensor known, Tensor dist2,
Tensor idx);
void three_nn_forward_cuda(int b, int n, int m, const Tensor unknown,
const Tensor known, Tensor dist2, Tensor idx) {
ThreeNNForwardCUDAKernelLauncher(b, n, m, unknown, known, dist2, idx);
};
#endif
void three_nn_forward(int b, int n, int m, Tensor unknown_tensor,
Tensor known_tensor, Tensor dist2_tensor,
Tensor idx_tensor) {
if (unknown_tensor.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
three_nn_forward_cuda(b, n, m, unknown_tensor, known_tensor, dist2_tensor,
idx_tensor);
#else
AT_ERROR("three_nn is not compiled with GPU support");
#endif
} else {
AT_ERROR("three_nn is not implemented on CPU");
}
}

View File

@ -0,0 +1,68 @@
from typing import Tuple
import torch
from torch.autograd import Function
from ..utils import ext_loader
ext_module = ext_loader.load_ext(
'_ext', ['three_interpolate_forward', 'three_interpolate_backward'])
class ThreeInterpolate(Function):
"""Performs weighted linear interpolation on 3 features.
Please refer to `Paper of PointNet++ <https://arxiv.org/abs/1706.02413>`_
for more details.
"""
@staticmethod
def forward(ctx, features: torch.Tensor, indices: torch.Tensor,
weight: torch.Tensor) -> torch.Tensor:
"""
Args:
features (Tensor): (B, C, M) Features descriptors to be
interpolated
indices (Tensor): (B, n, 3) index three nearest neighbors
of the target features in features
weight (Tensor): (B, n, 3) weights of interpolation
Returns:
Tensor: (B, C, N) tensor of the interpolated features
"""
assert features.is_contiguous()
assert indices.is_contiguous()
assert weight.is_contiguous()
B, c, m = features.size()
n = indices.size(1)
ctx.three_interpolate_for_backward = (indices, weight, m)
output = torch.cuda.FloatTensor(B, c, n)
ext_module.three_interpolate_forward(B, c, m, n, features, indices,
weight, output)
return output
@staticmethod
def backward(
ctx, grad_out: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
grad_out (Tensor): (B, C, N) tensor with gradients of outputs
Returns:
Tensor: (B, C, M) tensor with gradients of features
"""
idx, weight, m = ctx.three_interpolate_for_backward
B, c, n = grad_out.size()
grad_features = torch.cuda.FloatTensor(B, c, m).zero_()
grad_out_data = grad_out.data.contiguous()
ext_module.three_interpolate_backward(B, c, n, m, grad_out_data, idx,
weight, grad_features.data)
return grad_features, None, None
three_interpolate = ThreeInterpolate.apply

View File

@ -0,0 +1,51 @@
from typing import Tuple
import torch
from torch.autograd import Function
from ..utils import ext_loader
ext_module = ext_loader.load_ext('_ext', ['three_nn_forward'])
class ThreeNN(Function):
"""Find the top-3 nearest neighbors of the target set from the source set.
Please refer to `Paper of PointNet++ <https://arxiv.org/abs/1706.02413>`_
for more details.
"""
@staticmethod
def forward(ctx, target: torch.Tensor,
source: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
target (Tensor): shape (B, N, 3), points set that needs to
find the nearest neighbors.
source (Tensor): shape (B, M, 3), points set that is used
to find the nearest neighbors of points in target set.
Returns:
Tensor: shape (B, N, 3), L2 distance of each point in target
set to their corresponding nearest neighbors.
"""
target = target.contiguous()
source = source.contiguous()
B, N, _ = target.size()
m = source.size(1)
dist2 = torch.cuda.FloatTensor(B, N, 3)
idx = torch.cuda.IntTensor(B, N, 3)
ext_module.three_nn_forward(B, N, m, target, source, dist2, idx)
ctx.mark_non_differentiable(idx)
return torch.sqrt(dist2), idx
@staticmethod
def backward(ctx, a=None, b=None):
return None, None
three_nn = ThreeNN.apply

View File

@ -0,0 +1,74 @@
import pytest
import torch
from mmcv.ops import three_interpolate
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_three_interpolate():
features = torch.tensor([[[2.4350, 4.7516, 4.4995, 2.4350, 2.4350, 2.4350],
[3.1236, 2.6278, 3.0447, 3.1236, 3.1236, 3.1236],
[2.6732, 2.8677, 2.6436, 2.6732, 2.6732, 2.6732],
[0.0124, 7.0150, 7.0199, 0.0124, 0.0124, 0.0124],
[0.3207, 0.0000, 0.3411, 0.3207, 0.3207,
0.3207]],
[[0.0000, 0.9544, 2.4532, 0.0000, 0.0000, 0.0000],
[0.5346, 1.9176, 1.4715, 0.5346, 0.5346, 0.5346],
[0.0000, 0.2744, 2.0842, 0.0000, 0.0000, 0.0000],
[0.3414, 1.5063, 1.6209, 0.3414, 0.3414, 0.3414],
[0.5814, 0.0103, 0.0000, 0.5814, 0.5814,
0.5814]]]).cuda()
idx = torch.tensor([[[0, 1, 2], [2, 3, 4], [2, 3, 4], [0, 1, 2], [0, 1, 2],
[0, 1, 3]],
[[0, 2, 3], [1, 3, 4], [2, 1, 4], [0, 2, 4], [0, 2, 4],
[0, 1, 2]]]).int().cuda()
weight = torch.tensor([[[3.3333e-01, 3.3333e-01, 3.3333e-01],
[1.0000e+00, 5.8155e-08, 2.2373e-08],
[1.0000e+00, 1.7737e-08, 1.7356e-08],
[3.3333e-01, 3.3333e-01, 3.3333e-01],
[3.3333e-01, 3.3333e-01, 3.3333e-01],
[3.3333e-01, 3.3333e-01, 3.3333e-01]],
[[3.3333e-01, 3.3333e-01, 3.3333e-01],
[1.0000e+00, 1.3651e-08, 7.7312e-09],
[1.0000e+00, 1.7148e-08, 1.4070e-08],
[3.3333e-01, 3.3333e-01, 3.3333e-01],
[3.3333e-01, 3.3333e-01, 3.3333e-01],
[3.3333e-01, 3.3333e-01, 3.3333e-01]]]).cuda()
output = three_interpolate(features, idx, weight)
expected_output = torch.tensor([[[
3.8953e+00, 4.4995e+00, 4.4995e+00, 3.8953e+00, 3.8953e+00, 3.2072e+00
], [
2.9320e+00, 3.0447e+00, 3.0447e+00, 2.9320e+00, 2.9320e+00, 2.9583e+00
], [
2.7281e+00, 2.6436e+00, 2.6436e+00, 2.7281e+00, 2.7281e+00, 2.7380e+00
], [
4.6824e+00, 7.0199e+00, 7.0199e+00, 4.6824e+00, 4.6824e+00, 2.3466e+00
], [
2.2060e-01, 3.4110e-01, 3.4110e-01, 2.2060e-01, 2.2060e-01, 2.1380e-01
]],
[[
8.1773e-01, 9.5440e-01, 2.4532e+00,
8.1773e-01, 8.1773e-01, 1.1359e+00
],
[
8.4689e-01, 1.9176e+00, 1.4715e+00,
8.4689e-01, 8.4689e-01, 1.3079e+00
],
[
6.9473e-01, 2.7440e-01, 2.0842e+00,
6.9473e-01, 6.9473e-01, 7.8619e-01
],
[
7.6789e-01, 1.5063e+00, 1.6209e+00,
7.6789e-01, 7.6789e-01, 1.1562e+00
],
[
3.8760e-01, 1.0300e-02, 8.3569e-09,
3.8760e-01, 3.8760e-01, 1.9723e-01
]]]).cuda()
assert torch.allclose(output, expected_output, 1e-4)

View File

@ -0,0 +1,71 @@
import pytest
import torch
from mmcv.ops import three_nn
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_three_nn():
known = torch.tensor([[[-1.8373, 3.5605,
-0.7867], [0.7615, 2.9420, 0.2314],
[-0.6503, 3.6637, -1.0622],
[-1.8373, 3.5605, -0.7867],
[-1.8373, 3.5605, -0.7867]],
[[-1.3399, 1.9991, -0.3698],
[-0.0799, 0.9698,
-0.8457], [0.0858, 2.4721, -0.1928],
[-1.3399, 1.9991, -0.3698],
[-1.3399, 1.9991, -0.3698]]]).cuda()
unknown = torch.tensor([[[-1.8373, 3.5605, -0.7867],
[0.7615, 2.9420, 0.2314],
[-0.6503, 3.6637, -1.0622],
[-1.5237, 2.3976, -0.8097],
[-0.0722, 3.4017, -0.2880],
[0.5198, 3.0661, -0.4605],
[-2.0185, 3.5019, -0.3236],
[0.5098, 3.1020, 0.5799],
[-1.6137, 3.8443, -0.5269],
[0.7341, 2.9626, -0.3189]],
[[-1.3399, 1.9991, -0.3698],
[-0.0799, 0.9698, -0.8457],
[0.0858, 2.4721, -0.1928],
[-0.9022, 1.6560, -1.3090],
[0.1156, 1.6901, -0.4366],
[-0.6477, 2.3576, -0.1563],
[-0.8482, 1.1466, -1.2704],
[-0.8753, 2.0845, -0.3460],
[-0.5621, 1.4233, -1.2858],
[-0.5883, 1.3114, -1.2899]]]).cuda()
dist, idx = three_nn(unknown, known)
expected_dist = torch.tensor([[[0.0000, 0.0000, 0.0000],
[0.0000, 2.0463, 2.8588],
[0.0000, 1.2229, 1.2229],
[1.2047, 1.2047, 1.2047],
[1.0011, 1.0845, 1.8411],
[0.7433, 1.4451, 2.4304],
[0.5007, 0.5007, 0.5007],
[0.4587, 2.0875, 2.7544],
[0.4450, 0.4450, 0.4450],
[0.5514, 1.7206, 2.6811]],
[[0.0000, 0.0000, 0.0000],
[0.0000, 1.6464, 1.6952],
[0.0000, 1.5125, 1.5125],
[1.0915, 1.0915, 1.0915],
[0.8197, 0.8511, 1.4894],
[0.7433, 0.8082, 0.8082],
[0.8955, 1.3340, 1.3340],
[0.4730, 0.4730, 0.4730],
[0.7949, 1.3325, 1.3325],
[0.7566, 1.3727, 1.3727]]]).cuda()
expected_idx = torch.tensor([[[0, 3, 4], [1, 2, 0], [2, 0, 3], [0, 3, 4],
[2, 1, 0], [1, 2, 0], [0, 3, 4], [1, 2, 0],
[0, 3, 4], [1, 2, 0]],
[[0, 3, 4], [1, 2, 0], [2, 0, 3], [0, 3, 4],
[2, 1, 0], [2, 0, 3], [1, 0, 3], [0, 3, 4],
[1, 0, 3], [1, 0, 3]]]).cuda()
assert torch.allclose(dist, expected_dist, 1e-4)
assert torch.all(idx == expected_idx)