mirror of https://github.com/open-mmlab/mmcv.git
[Feature] Add interpolate ops from mmdet3d (#1355)
* add ops (interpolate) in mmdet3d * refactor code * fix typo * fix typo * fix typo * refactor codepull/1357/head
parent
97e5bada4c
commit
be5841e45d
|
@ -25,5 +25,7 @@ We implement common CUDA ops used in detection, segmentation, etc.
|
|||
- SoftmaxFocalLoss
|
||||
- SoftNMS
|
||||
- Synchronized BatchNorm
|
||||
- ThreeInterpolate
|
||||
- ThreeNN
|
||||
- Weight standardization
|
||||
- Correlation
|
||||
|
|
|
@ -25,5 +25,7 @@ MMCV 提供了检测、分割等任务中常用的 CUDA 算子
|
|||
- SoftmaxFocalLoss
|
||||
- SoftNMS
|
||||
- Synchronized BatchNorm
|
||||
- ThreeInterpolate
|
||||
- ThreeNN
|
||||
- Weight standardization
|
||||
- Correlation
|
||||
|
|
|
@ -40,6 +40,8 @@ from .roi_align_rotated import RoIAlignRotated, roi_align_rotated
|
|||
from .roi_pool import RoIPool, roi_pool
|
||||
from .saconv import SAConv2d
|
||||
from .sync_bn import SyncBatchNorm
|
||||
from .three_interpolate import three_interpolate
|
||||
from .three_nn import three_nn
|
||||
from .tin_shift import TINShift, tin_shift
|
||||
from .upfirdn2d import upfirdn2d
|
||||
|
||||
|
@ -59,7 +61,8 @@ __all__ = [
|
|||
'SAConv2d', 'TINShift', 'tin_shift', 'box_iou_rotated', 'nms_rotated',
|
||||
'knn', 'ball_query', 'upfirdn2d', 'FusedBiasLeakyReLU',
|
||||
'fused_bias_leakyrelu', 'RoIAlignRotated', 'roi_align_rotated',
|
||||
'pixel_group', 'contour_expand', 'MultiScaleDeformableAttention',
|
||||
'BorderAlign', 'border_align', 'gather_points', 'furthest_point_sample',
|
||||
'pixel_group', 'contour_expand', 'three_nn', 'three_interpolate',
|
||||
'MultiScaleDeformableAttention', 'BorderAlign', 'border_align',
|
||||
'gather_points', 'furthest_point_sample',
|
||||
'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,61 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved
|
||||
#ifndef THREE_INTERPOLATE_CUDA_KERNEL_CUH
|
||||
#define THREE_INTERPOLATE_CUDA_KERNEL_CUH
|
||||
|
||||
#ifdef MMCV_USE_PARROTS
|
||||
#include "parrots_cuda_helper.hpp"
|
||||
#else
|
||||
#include "pytorch_cuda_helper.hpp"
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
__global__ void three_interpolate_forward_cuda_kernel(
|
||||
int b, int c, int m, int n, const T *points, const int *__restrict__ idx,
|
||||
const T *weight, T *out) {
|
||||
// points: (B, C, M)
|
||||
// idx: (B, N, 3)
|
||||
// weight: (B, N, 3)
|
||||
// output:
|
||||
// out: (B, C, N)
|
||||
|
||||
int bs_idx = blockIdx.z;
|
||||
int c_idx = blockIdx.y;
|
||||
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (bs_idx >= b || c_idx >= c || pt_idx >= n) return;
|
||||
|
||||
weight += bs_idx * n * 3 + pt_idx * 3;
|
||||
points += bs_idx * c * m + c_idx * m;
|
||||
idx += bs_idx * n * 3 + pt_idx * 3;
|
||||
out += bs_idx * c * n + c_idx * n;
|
||||
|
||||
out[pt_idx] = weight[0] * points[idx[0]] + weight[1] * points[idx[1]] +
|
||||
weight[2] * points[idx[2]];
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void three_interpolate_backward_cuda_kernel(
|
||||
int b, int c, int n, int m, const T *grad_out, const int *__restrict__ idx,
|
||||
const T *weight, T *grad_points) {
|
||||
// grad_out: (B, C, N)
|
||||
// weight: (B, N, 3)
|
||||
// output:
|
||||
// grad_points: (B, C, M)
|
||||
|
||||
int bs_idx = blockIdx.z;
|
||||
int c_idx = blockIdx.y;
|
||||
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (bs_idx >= b || c_idx >= c || pt_idx >= n) return;
|
||||
|
||||
grad_out += bs_idx * c * n + c_idx * n + pt_idx;
|
||||
weight += bs_idx * n * 3 + pt_idx * 3;
|
||||
grad_points += bs_idx * c * m + c_idx * m;
|
||||
idx += bs_idx * n * 3 + pt_idx * 3;
|
||||
|
||||
atomicAdd(grad_points + idx[0], grad_out[0] * weight[0]);
|
||||
atomicAdd(grad_points + idx[1], grad_out[0] * weight[1]);
|
||||
atomicAdd(grad_points + idx[2], grad_out[0] * weight[2]);
|
||||
}
|
||||
|
||||
#endif // THREE_INTERPOLATE_CUDA_KERNEL_CUH
|
|
@ -0,0 +1,66 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved
|
||||
#ifndef THREE_NN_CUDA_KERNEL_CUH
|
||||
#define THREE_NN_CUDA_KERNEL_CUH
|
||||
|
||||
#ifdef MMCV_USE_PARROTS
|
||||
#include "parrots_cuda_helper.hpp"
|
||||
#else
|
||||
#include "pytorch_cuda_helper.hpp"
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
__global__ void three_nn_forward_cuda_kernel(int b, int n, int m,
|
||||
const T *unknown, const T *known,
|
||||
T *dist2, int *__restrict__ idx) {
|
||||
// unknown: (B, N, 3)
|
||||
// known: (B, M, 3)
|
||||
// output:
|
||||
// dist2: (B, N, 3)
|
||||
// idx: (B, N, 3)
|
||||
|
||||
int bs_idx = blockIdx.y;
|
||||
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (bs_idx >= b || pt_idx >= n) return;
|
||||
|
||||
unknown += bs_idx * n * 3 + pt_idx * 3;
|
||||
known += bs_idx * m * 3;
|
||||
dist2 += bs_idx * n * 3 + pt_idx * 3;
|
||||
idx += bs_idx * n * 3 + pt_idx * 3;
|
||||
|
||||
T ux = unknown[0];
|
||||
T uy = unknown[1];
|
||||
T uz = unknown[2];
|
||||
|
||||
double best1 = 1e40, best2 = 1e40, best3 = 1e40;
|
||||
int besti1 = 0, besti2 = 0, besti3 = 0;
|
||||
for (int k = 0; k < m; ++k) {
|
||||
T x = known[k * 3 + 0];
|
||||
T y = known[k * 3 + 1];
|
||||
T z = known[k * 3 + 2];
|
||||
T d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
|
||||
if (d < best1) {
|
||||
best3 = best2;
|
||||
besti3 = besti2;
|
||||
best2 = best1;
|
||||
besti2 = besti1;
|
||||
best1 = d;
|
||||
besti1 = k;
|
||||
} else if (d < best2) {
|
||||
best3 = best2;
|
||||
besti3 = besti2;
|
||||
best2 = d;
|
||||
besti2 = k;
|
||||
} else if (d < best3) {
|
||||
best3 = d;
|
||||
besti3 = k;
|
||||
}
|
||||
}
|
||||
dist2[0] = best1;
|
||||
dist2[1] = best2;
|
||||
dist2[2] = best3;
|
||||
idx[0] = besti1;
|
||||
idx[1] = besti2;
|
||||
idx[2] = besti3;
|
||||
}
|
||||
|
||||
#endif // THREE_NN_CUDA_KERNEL_CUH
|
|
@ -0,0 +1,66 @@
|
|||
// Modified from
|
||||
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/interpolate_gpu.cu
|
||||
|
||||
#include <math.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "pytorch_cuda_helper.hpp"
|
||||
#include "three_interpolate_cuda_kernel.cuh"
|
||||
|
||||
void ThreeInterpolateForwardCUDAKernelLauncher(int b, int c, int m, int n,
|
||||
const Tensor points,
|
||||
const Tensor idx,
|
||||
const Tensor weight,
|
||||
Tensor out) {
|
||||
// points: (B, C, M)
|
||||
// idx: (B, N, 3)
|
||||
// weight: (B, N, 3)
|
||||
// output:
|
||||
// out: (B, C, N)
|
||||
|
||||
at::cuda::CUDAGuard device_guard(points.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
// blockIdx.x(col), blockIdx.y(row)
|
||||
dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b);
|
||||
dim3 threads(THREADS_PER_BLOCK);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
points.scalar_type(), "three_interpolate_forward_cuda_kernel", [&] {
|
||||
three_interpolate_forward_cuda_kernel<scalar_t>
|
||||
<<<blocks, threads, 0, stream>>>(
|
||||
b, c, m, n, points.data_ptr<scalar_t>(), idx.data_ptr<int>(),
|
||||
weight.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
|
||||
});
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
void ThreeInterpolateBackwardCUDAKernelLauncher(int b, int c, int n, int m,
|
||||
const Tensor grad_out,
|
||||
const Tensor idx,
|
||||
const Tensor weight,
|
||||
Tensor grad_points) {
|
||||
// grad_out: (B, C, N)
|
||||
// weight: (B, N, 3)
|
||||
// output:
|
||||
// grad_points: (B, C, M)
|
||||
|
||||
at::cuda::CUDAGuard device_guard(grad_out.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
// blockIdx.x(col), blockIdx.y(row)
|
||||
dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b);
|
||||
dim3 threads(THREADS_PER_BLOCK);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
grad_out.scalar_type(), "three_interpolate_backward_cuda_kernel", [&] {
|
||||
three_interpolate_backward_cuda_kernel<scalar_t>
|
||||
<<<blocks, threads, 0, stream>>>(
|
||||
b, c, n, m, grad_out.data_ptr<scalar_t>(), idx.data_ptr<int>(),
|
||||
weight.data_ptr<scalar_t>(), grad_points.data_ptr<scalar_t>());
|
||||
});
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
}
|
|
@ -0,0 +1,35 @@
|
|||
// Modified from
|
||||
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/interpolate_gpu.cu
|
||||
|
||||
#include <math.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "pytorch_cuda_helper.hpp"
|
||||
#include "three_nn_cuda_kernel.cuh"
|
||||
|
||||
void ThreeNNForwardCUDAKernelLauncher(int b, int n, int m, const Tensor unknown,
|
||||
const Tensor known, Tensor dist2,
|
||||
Tensor idx) {
|
||||
// unknown: (B, N, 3)
|
||||
// known: (B, M, 3)
|
||||
// output:
|
||||
// dist2: (B, N, 3)
|
||||
// idx: (B, N, 3)
|
||||
|
||||
at::cuda::CUDAGuard device_guard(unknown.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
// blockIdx.x(col), blockIdx.y(row)
|
||||
dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), b);
|
||||
dim3 threads(THREADS_PER_BLOCK);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
unknown.scalar_type(), "three_nn_forward_cuda_kernel", [&] {
|
||||
three_nn_forward_cuda_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
|
||||
b, n, m, unknown.data_ptr<scalar_t>(), known.data_ptr<scalar_t>(),
|
||||
dist2.data_ptr<scalar_t>(), idx.data_ptr<int>());
|
||||
});
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
}
|
|
@ -74,6 +74,19 @@ void softmax_focal_loss_backward(Tensor input, Tensor target, Tensor weight,
|
|||
Tensor buff, Tensor grad_input, float gamma,
|
||||
float alpha);
|
||||
|
||||
void three_interpolate_forward(int b, int c, int m, int n, Tensor points_tensor,
|
||||
Tensor idx_tensor, Tensor weight_tensor,
|
||||
Tensor out_tensor);
|
||||
|
||||
void three_interpolate_backward(int b, int c, int n, int m,
|
||||
Tensor grad_out_tensor, Tensor idx_tensor,
|
||||
Tensor weight_tensor,
|
||||
Tensor grad_points_tensor);
|
||||
|
||||
void three_nn_forward(int b, int n, int m, Tensor unknown_tensor,
|
||||
Tensor known_tensor, Tensor dist2_tensor,
|
||||
Tensor idx_tensor);
|
||||
|
||||
void bbox_overlaps(const Tensor bboxes1, const Tensor bboxes2, Tensor ious,
|
||||
const int mode, const bool aligned, const int offset);
|
||||
|
||||
|
@ -343,6 +356,18 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|||
"softmax_focal_loss_backward", py::arg("input"), py::arg("target"),
|
||||
py::arg("weight"), py::arg("buff"), py::arg("grad_input"),
|
||||
py::arg("gamma"), py::arg("alpha"));
|
||||
m.def("three_interpolate_forward", &three_interpolate_forward,
|
||||
"three_interpolate_forward", py::arg("b"), py::arg("c"), py::arg("m"),
|
||||
py::arg("n"), py::arg("points_tensor"), py::arg("idx_tensor"),
|
||||
py::arg("weight_tensor"), py::arg("out_tensor"));
|
||||
m.def("three_interpolate_backward", &three_interpolate_backward,
|
||||
"three_interpolate_backward", py::arg("b"), py::arg("c"), py::arg("n"),
|
||||
py::arg("m"), py::arg("grad_out_tensor"), py::arg("idx_tensor"),
|
||||
py::arg("weight_tensor"), py::arg("grad_points_tensor"));
|
||||
m.def("three_nn_forward", &three_nn_forward, "three_nn_forward", py::arg("b"),
|
||||
py::arg("n"), py::arg("m"), py::arg("unknown_tensor"),
|
||||
py::arg("known_tensor"), py::arg("dist2_tensor"),
|
||||
py::arg("idx_tensor"));
|
||||
m.def("bbox_overlaps", &bbox_overlaps, "bbox_overlaps", py::arg("bboxes1"),
|
||||
py::arg("bboxes2"), py::arg("ious"), py::arg("mode"),
|
||||
py::arg("aligned"), py::arg("offset"));
|
||||
|
|
|
@ -0,0 +1,62 @@
|
|||
// Modified from
|
||||
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/interpolate.cpp
|
||||
|
||||
#include "pytorch_cpp_helper.hpp"
|
||||
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
void ThreeInterpolateForwardCUDAKernelLauncher(int b, int c, int m, int n,
|
||||
const Tensor points,
|
||||
const Tensor idx,
|
||||
const Tensor weight, Tensor out);
|
||||
|
||||
void three_interpolate_forward_cuda(int b, int c, int m, int n,
|
||||
const Tensor points, const Tensor idx,
|
||||
const Tensor weight, Tensor out) {
|
||||
ThreeInterpolateForwardCUDAKernelLauncher(b, c, m, n, points, idx, weight,
|
||||
out);
|
||||
};
|
||||
|
||||
void ThreeInterpolateBackwardCUDAKernelLauncher(int b, int c, int n, int m,
|
||||
const Tensor grad_out,
|
||||
const Tensor idx,
|
||||
const Tensor weight,
|
||||
Tensor grad_points);
|
||||
|
||||
void three_interpolate_backward_cuda(int b, int c, int n, int m,
|
||||
const Tensor grad_out, const Tensor idx,
|
||||
const Tensor weight, Tensor grad_points) {
|
||||
ThreeInterpolateBackwardCUDAKernelLauncher(b, c, n, m, grad_out, idx, weight,
|
||||
grad_points);
|
||||
};
|
||||
#endif
|
||||
|
||||
void three_interpolate_forward(int b, int c, int m, int n, Tensor points_tensor,
|
||||
Tensor idx_tensor, Tensor weight_tensor,
|
||||
Tensor out_tensor) {
|
||||
if (points_tensor.device().is_cuda()) {
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
three_interpolate_forward_cuda(b, c, m, n, points_tensor, idx_tensor,
|
||||
weight_tensor, out_tensor);
|
||||
#else
|
||||
AT_ERROR("three_interpolate is not compiled with GPU support");
|
||||
#endif
|
||||
} else {
|
||||
AT_ERROR("three_interpolate is not implemented on CPU");
|
||||
}
|
||||
}
|
||||
|
||||
void three_interpolate_backward(int b, int c, int n, int m,
|
||||
Tensor grad_out_tensor, Tensor idx_tensor,
|
||||
Tensor weight_tensor,
|
||||
Tensor grad_points_tensor) {
|
||||
if (grad_out_tensor.device().is_cuda()) {
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
three_interpolate_backward_cuda(b, c, n, m, grad_out_tensor, idx_tensor,
|
||||
weight_tensor, grad_points_tensor);
|
||||
#else
|
||||
AT_ERROR("three_interpolate is not compiled with GPU support");
|
||||
#endif
|
||||
} else {
|
||||
AT_ERROR("three_interpolate is not implemented on CPU");
|
||||
}
|
||||
}
|
|
@ -0,0 +1,30 @@
|
|||
// Modified from
|
||||
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/interpolate.cpp
|
||||
|
||||
#include "pytorch_cpp_helper.hpp"
|
||||
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
void ThreeNNForwardCUDAKernelLauncher(int b, int n, int m, const Tensor unknown,
|
||||
const Tensor known, Tensor dist2,
|
||||
Tensor idx);
|
||||
|
||||
void three_nn_forward_cuda(int b, int n, int m, const Tensor unknown,
|
||||
const Tensor known, Tensor dist2, Tensor idx) {
|
||||
ThreeNNForwardCUDAKernelLauncher(b, n, m, unknown, known, dist2, idx);
|
||||
};
|
||||
#endif
|
||||
|
||||
void three_nn_forward(int b, int n, int m, Tensor unknown_tensor,
|
||||
Tensor known_tensor, Tensor dist2_tensor,
|
||||
Tensor idx_tensor) {
|
||||
if (unknown_tensor.device().is_cuda()) {
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
three_nn_forward_cuda(b, n, m, unknown_tensor, known_tensor, dist2_tensor,
|
||||
idx_tensor);
|
||||
#else
|
||||
AT_ERROR("three_nn is not compiled with GPU support");
|
||||
#endif
|
||||
} else {
|
||||
AT_ERROR("three_nn is not implemented on CPU");
|
||||
}
|
||||
}
|
|
@ -0,0 +1,68 @@
|
|||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch.autograd import Function
|
||||
|
||||
from ..utils import ext_loader
|
||||
|
||||
ext_module = ext_loader.load_ext(
|
||||
'_ext', ['three_interpolate_forward', 'three_interpolate_backward'])
|
||||
|
||||
|
||||
class ThreeInterpolate(Function):
|
||||
"""Performs weighted linear interpolation on 3 features.
|
||||
|
||||
Please refer to `Paper of PointNet++ <https://arxiv.org/abs/1706.02413>`_
|
||||
for more details.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, features: torch.Tensor, indices: torch.Tensor,
|
||||
weight: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
features (Tensor): (B, C, M) Features descriptors to be
|
||||
interpolated
|
||||
indices (Tensor): (B, n, 3) index three nearest neighbors
|
||||
of the target features in features
|
||||
weight (Tensor): (B, n, 3) weights of interpolation
|
||||
|
||||
Returns:
|
||||
Tensor: (B, C, N) tensor of the interpolated features
|
||||
"""
|
||||
assert features.is_contiguous()
|
||||
assert indices.is_contiguous()
|
||||
assert weight.is_contiguous()
|
||||
|
||||
B, c, m = features.size()
|
||||
n = indices.size(1)
|
||||
ctx.three_interpolate_for_backward = (indices, weight, m)
|
||||
output = torch.cuda.FloatTensor(B, c, n)
|
||||
|
||||
ext_module.three_interpolate_forward(B, c, m, n, features, indices,
|
||||
weight, output)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(
|
||||
ctx, grad_out: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
grad_out (Tensor): (B, C, N) tensor with gradients of outputs
|
||||
|
||||
Returns:
|
||||
Tensor: (B, C, M) tensor with gradients of features
|
||||
"""
|
||||
idx, weight, m = ctx.three_interpolate_for_backward
|
||||
B, c, n = grad_out.size()
|
||||
|
||||
grad_features = torch.cuda.FloatTensor(B, c, m).zero_()
|
||||
grad_out_data = grad_out.data.contiguous()
|
||||
|
||||
ext_module.three_interpolate_backward(B, c, n, m, grad_out_data, idx,
|
||||
weight, grad_features.data)
|
||||
return grad_features, None, None
|
||||
|
||||
|
||||
three_interpolate = ThreeInterpolate.apply
|
|
@ -0,0 +1,51 @@
|
|||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch.autograd import Function
|
||||
|
||||
from ..utils import ext_loader
|
||||
|
||||
ext_module = ext_loader.load_ext('_ext', ['three_nn_forward'])
|
||||
|
||||
|
||||
class ThreeNN(Function):
|
||||
"""Find the top-3 nearest neighbors of the target set from the source set.
|
||||
|
||||
Please refer to `Paper of PointNet++ <https://arxiv.org/abs/1706.02413>`_
|
||||
for more details.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, target: torch.Tensor,
|
||||
source: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
target (Tensor): shape (B, N, 3), points set that needs to
|
||||
find the nearest neighbors.
|
||||
source (Tensor): shape (B, M, 3), points set that is used
|
||||
to find the nearest neighbors of points in target set.
|
||||
|
||||
Returns:
|
||||
Tensor: shape (B, N, 3), L2 distance of each point in target
|
||||
set to their corresponding nearest neighbors.
|
||||
"""
|
||||
target = target.contiguous()
|
||||
source = source.contiguous()
|
||||
|
||||
B, N, _ = target.size()
|
||||
m = source.size(1)
|
||||
dist2 = torch.cuda.FloatTensor(B, N, 3)
|
||||
idx = torch.cuda.IntTensor(B, N, 3)
|
||||
|
||||
ext_module.three_nn_forward(B, N, m, target, source, dist2, idx)
|
||||
|
||||
ctx.mark_non_differentiable(idx)
|
||||
|
||||
return torch.sqrt(dist2), idx
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, a=None, b=None):
|
||||
return None, None
|
||||
|
||||
|
||||
three_nn = ThreeNN.apply
|
|
@ -0,0 +1,74 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from mmcv.ops import three_interpolate
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason='requires CUDA support')
|
||||
def test_three_interpolate():
|
||||
features = torch.tensor([[[2.4350, 4.7516, 4.4995, 2.4350, 2.4350, 2.4350],
|
||||
[3.1236, 2.6278, 3.0447, 3.1236, 3.1236, 3.1236],
|
||||
[2.6732, 2.8677, 2.6436, 2.6732, 2.6732, 2.6732],
|
||||
[0.0124, 7.0150, 7.0199, 0.0124, 0.0124, 0.0124],
|
||||
[0.3207, 0.0000, 0.3411, 0.3207, 0.3207,
|
||||
0.3207]],
|
||||
[[0.0000, 0.9544, 2.4532, 0.0000, 0.0000, 0.0000],
|
||||
[0.5346, 1.9176, 1.4715, 0.5346, 0.5346, 0.5346],
|
||||
[0.0000, 0.2744, 2.0842, 0.0000, 0.0000, 0.0000],
|
||||
[0.3414, 1.5063, 1.6209, 0.3414, 0.3414, 0.3414],
|
||||
[0.5814, 0.0103, 0.0000, 0.5814, 0.5814,
|
||||
0.5814]]]).cuda()
|
||||
|
||||
idx = torch.tensor([[[0, 1, 2], [2, 3, 4], [2, 3, 4], [0, 1, 2], [0, 1, 2],
|
||||
[0, 1, 3]],
|
||||
[[0, 2, 3], [1, 3, 4], [2, 1, 4], [0, 2, 4], [0, 2, 4],
|
||||
[0, 1, 2]]]).int().cuda()
|
||||
|
||||
weight = torch.tensor([[[3.3333e-01, 3.3333e-01, 3.3333e-01],
|
||||
[1.0000e+00, 5.8155e-08, 2.2373e-08],
|
||||
[1.0000e+00, 1.7737e-08, 1.7356e-08],
|
||||
[3.3333e-01, 3.3333e-01, 3.3333e-01],
|
||||
[3.3333e-01, 3.3333e-01, 3.3333e-01],
|
||||
[3.3333e-01, 3.3333e-01, 3.3333e-01]],
|
||||
[[3.3333e-01, 3.3333e-01, 3.3333e-01],
|
||||
[1.0000e+00, 1.3651e-08, 7.7312e-09],
|
||||
[1.0000e+00, 1.7148e-08, 1.4070e-08],
|
||||
[3.3333e-01, 3.3333e-01, 3.3333e-01],
|
||||
[3.3333e-01, 3.3333e-01, 3.3333e-01],
|
||||
[3.3333e-01, 3.3333e-01, 3.3333e-01]]]).cuda()
|
||||
|
||||
output = three_interpolate(features, idx, weight)
|
||||
expected_output = torch.tensor([[[
|
||||
3.8953e+00, 4.4995e+00, 4.4995e+00, 3.8953e+00, 3.8953e+00, 3.2072e+00
|
||||
], [
|
||||
2.9320e+00, 3.0447e+00, 3.0447e+00, 2.9320e+00, 2.9320e+00, 2.9583e+00
|
||||
], [
|
||||
2.7281e+00, 2.6436e+00, 2.6436e+00, 2.7281e+00, 2.7281e+00, 2.7380e+00
|
||||
], [
|
||||
4.6824e+00, 7.0199e+00, 7.0199e+00, 4.6824e+00, 4.6824e+00, 2.3466e+00
|
||||
], [
|
||||
2.2060e-01, 3.4110e-01, 3.4110e-01, 2.2060e-01, 2.2060e-01, 2.1380e-01
|
||||
]],
|
||||
[[
|
||||
8.1773e-01, 9.5440e-01, 2.4532e+00,
|
||||
8.1773e-01, 8.1773e-01, 1.1359e+00
|
||||
],
|
||||
[
|
||||
8.4689e-01, 1.9176e+00, 1.4715e+00,
|
||||
8.4689e-01, 8.4689e-01, 1.3079e+00
|
||||
],
|
||||
[
|
||||
6.9473e-01, 2.7440e-01, 2.0842e+00,
|
||||
6.9473e-01, 6.9473e-01, 7.8619e-01
|
||||
],
|
||||
[
|
||||
7.6789e-01, 1.5063e+00, 1.6209e+00,
|
||||
7.6789e-01, 7.6789e-01, 1.1562e+00
|
||||
],
|
||||
[
|
||||
3.8760e-01, 1.0300e-02, 8.3569e-09,
|
||||
3.8760e-01, 3.8760e-01, 1.9723e-01
|
||||
]]]).cuda()
|
||||
|
||||
assert torch.allclose(output, expected_output, 1e-4)
|
|
@ -0,0 +1,71 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from mmcv.ops import three_nn
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason='requires CUDA support')
|
||||
def test_three_nn():
|
||||
known = torch.tensor([[[-1.8373, 3.5605,
|
||||
-0.7867], [0.7615, 2.9420, 0.2314],
|
||||
[-0.6503, 3.6637, -1.0622],
|
||||
[-1.8373, 3.5605, -0.7867],
|
||||
[-1.8373, 3.5605, -0.7867]],
|
||||
[[-1.3399, 1.9991, -0.3698],
|
||||
[-0.0799, 0.9698,
|
||||
-0.8457], [0.0858, 2.4721, -0.1928],
|
||||
[-1.3399, 1.9991, -0.3698],
|
||||
[-1.3399, 1.9991, -0.3698]]]).cuda()
|
||||
|
||||
unknown = torch.tensor([[[-1.8373, 3.5605, -0.7867],
|
||||
[0.7615, 2.9420, 0.2314],
|
||||
[-0.6503, 3.6637, -1.0622],
|
||||
[-1.5237, 2.3976, -0.8097],
|
||||
[-0.0722, 3.4017, -0.2880],
|
||||
[0.5198, 3.0661, -0.4605],
|
||||
[-2.0185, 3.5019, -0.3236],
|
||||
[0.5098, 3.1020, 0.5799],
|
||||
[-1.6137, 3.8443, -0.5269],
|
||||
[0.7341, 2.9626, -0.3189]],
|
||||
[[-1.3399, 1.9991, -0.3698],
|
||||
[-0.0799, 0.9698, -0.8457],
|
||||
[0.0858, 2.4721, -0.1928],
|
||||
[-0.9022, 1.6560, -1.3090],
|
||||
[0.1156, 1.6901, -0.4366],
|
||||
[-0.6477, 2.3576, -0.1563],
|
||||
[-0.8482, 1.1466, -1.2704],
|
||||
[-0.8753, 2.0845, -0.3460],
|
||||
[-0.5621, 1.4233, -1.2858],
|
||||
[-0.5883, 1.3114, -1.2899]]]).cuda()
|
||||
|
||||
dist, idx = three_nn(unknown, known)
|
||||
expected_dist = torch.tensor([[[0.0000, 0.0000, 0.0000],
|
||||
[0.0000, 2.0463, 2.8588],
|
||||
[0.0000, 1.2229, 1.2229],
|
||||
[1.2047, 1.2047, 1.2047],
|
||||
[1.0011, 1.0845, 1.8411],
|
||||
[0.7433, 1.4451, 2.4304],
|
||||
[0.5007, 0.5007, 0.5007],
|
||||
[0.4587, 2.0875, 2.7544],
|
||||
[0.4450, 0.4450, 0.4450],
|
||||
[0.5514, 1.7206, 2.6811]],
|
||||
[[0.0000, 0.0000, 0.0000],
|
||||
[0.0000, 1.6464, 1.6952],
|
||||
[0.0000, 1.5125, 1.5125],
|
||||
[1.0915, 1.0915, 1.0915],
|
||||
[0.8197, 0.8511, 1.4894],
|
||||
[0.7433, 0.8082, 0.8082],
|
||||
[0.8955, 1.3340, 1.3340],
|
||||
[0.4730, 0.4730, 0.4730],
|
||||
[0.7949, 1.3325, 1.3325],
|
||||
[0.7566, 1.3727, 1.3727]]]).cuda()
|
||||
expected_idx = torch.tensor([[[0, 3, 4], [1, 2, 0], [2, 0, 3], [0, 3, 4],
|
||||
[2, 1, 0], [1, 2, 0], [0, 3, 4], [1, 2, 0],
|
||||
[0, 3, 4], [1, 2, 0]],
|
||||
[[0, 3, 4], [1, 2, 0], [2, 0, 3], [0, 3, 4],
|
||||
[2, 1, 0], [2, 0, 3], [1, 0, 3], [0, 3, 4],
|
||||
[1, 0, 3], [1, 0, 3]]]).cuda()
|
||||
|
||||
assert torch.allclose(dist, expected_dist, 1e-4)
|
||||
assert torch.all(idx == expected_idx)
|
Loading…
Reference in New Issue