[Feature] Add knn op from mmdet3d (#1354)

* add ops (knn) in mmdet3d

* refactor code

* refactor code

* fix typo

* fix typo

* fix typo

* refactor code

* refactor code

* fix spelling typo

* fix spelling typo
dingchang 2021-10-14 13:33:13 +08:00 committed by GitHub
parent 8016d88067
commit 42c9e71120
10 changed files with 306 additions and 9 deletions


@@ -13,6 +13,7 @@ We implement common CUDA ops used in detection, segmentation, etc.
- FurthestPointSample
- FurthestPointSampleWithDist
- GeneralizedAttention
- KNN
- MaskedConv
- NMS
- PSAMask


@@ -13,6 +13,7 @@ MMCV provides common CUDA ops used in detection, segmentation, etc.
- FurthestPointSample
- FurthestPointSampleWithDist
- GeneralizedAttention
- KNN
- MaskedConv
- NMS
- PSAMask


@@ -22,6 +22,7 @@ from .furthest_point_sample import (furthest_point_sample,
from .fused_bias_leakyrelu import FusedBiasLeakyReLU, fused_bias_leakyrelu
from .info import (get_compiler_version, get_compiling_cuda_version,
get_onnxruntime_op_path)
from .knn import knn
from .masked_conv import MaskedConv2d, masked_conv2d
from .modulated_deform_conv import (ModulatedDeformConv2d,
ModulatedDeformConv2dPack,
@@ -55,9 +56,9 @@ __all__ = [
'ConvTranspose2d', 'Linear', 'MaxPool2d', 'CrissCrossAttention', 'PSAMask',
'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign',
'SAConv2d', 'TINShift', 'tin_shift', 'box_iou_rotated', 'nms_rotated',
-    'ball_query', 'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu',
-    'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'contour_expand',
-    'MultiScaleDeformableAttention', 'BorderAlign', 'border_align',
-    'furthest_point_sample', 'furthest_point_sample_with_dist',
-    'PointsSampler', 'Correlation'
+    'knn', 'ball_query', 'upfirdn2d', 'FusedBiasLeakyReLU',
+    'fused_bias_leakyrelu', 'RoIAlignRotated', 'roi_align_rotated',
+    'pixel_group', 'contour_expand', 'MultiScaleDeformableAttention',
+    'BorderAlign', 'border_align', 'furthest_point_sample',
+    'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation'
]


@@ -0,0 +1,91 @@
// Copyright (c) OpenMMLab. All rights reserved
// Modified from
// https://github.com/CVMI-Lab/PAConv/tree/main/scene_seg/lib/pointops/src/knnquery_heap
#ifndef KNN_CUDA_KERNEL_CUH
#define KNN_CUDA_KERNEL_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
inline __device__ void swap_float(float *x, float *y) {
float tmp = *x;
*x = *y;
*y = tmp;
}
inline __device__ void swap_int(int *x, int *y) {
int tmp = *x;
*x = *y;
*y = tmp;
}
__device__ void reheap(float *dist, int *idx, int k) {
int root = 0;
int child = root * 2 + 1;
while (child < k) {
if (child + 1 < k && dist[child + 1] > dist[child]) child++;
if (dist[root] > dist[child]) return;
swap_float(&dist[root], &dist[child]);
swap_int(&idx[root], &idx[child]);
root = child;
child = root * 2 + 1;
}
}
__device__ void heap_sort(float *dist, int *idx, int k) {
int i;
for (i = k - 1; i > 0; i--) {
swap_float(&dist[0], &dist[i]);
swap_int(&idx[0], &idx[i]);
reheap(dist, idx, i);
}
}
// input: xyz (b, n, 3) new_xyz (b, m, 3)
// output: idx (b, m, nsample) dist2 (b, m, nsample)
template <typename T>
__global__ void knn_forward_cuda_kernel(int b, int n, int m, int nsample,
const T *xyz, const T *new_xyz,
int *__restrict__ idx, T *dist2) {
int bs_idx = blockIdx.y;
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (bs_idx >= b || pt_idx >= m) return;
new_xyz += bs_idx * m * 3 + pt_idx * 3;
xyz += bs_idx * n * 3;
idx += bs_idx * m * nsample + pt_idx * nsample;
dist2 += bs_idx * m * nsample + pt_idx * nsample;
T new_x = new_xyz[0];
T new_y = new_xyz[1];
T new_z = new_xyz[2];
float best_dist[100];
int best_idx[100];
for (int i = 0; i < nsample; i++) {
best_dist[i] = 1e10;
best_idx[i] = 0;
}
for (int i = 0; i < n; i++) {
T x = xyz[i * 3 + 0];
T y = xyz[i * 3 + 1];
T z = xyz[i * 3 + 2];
T d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) +
(new_z - z) * (new_z - z);
if (d2 < best_dist[0]) {
best_dist[0] = d2;
best_idx[0] = i;
reheap(best_dist, best_idx, nsample);
}
}
heap_sort(best_dist, best_idx, nsample);
for (int i = 0; i < nsample; i++) {
idx[i] = best_idx[i];
dist2[i] = best_dist[i];
}
}
#endif // KNN_CUDA_KERNEL_CUH
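
The kernel above keeps the current best ``nsample`` squared distances in a small max-heap, so each candidate point only has to beat the current worst neighbour (``best_dist[0]``), and ``heap_sort`` then emits the results in ascending distance order. A minimal pure-Python sketch of the same selection logic, for illustration only (hypothetical helper names, not part of this commit):

def reheap(dist, idx, k):
    # Sift the root down so that dist[0] is again the largest of the k entries.
    root = 0
    child = 2 * root + 1
    while child < k:
        if child + 1 < k and dist[child + 1] > dist[child]:
            child += 1
        if dist[root] > dist[child]:
            return
        dist[root], dist[child] = dist[child], dist[root]
        idx[root], idx[child] = idx[child], idx[root]
        root = child
        child = 2 * root + 1

def knn_one_query(query, points, nsample):
    # Keep the nsample smallest squared distances seen so far in a max-heap.
    best_dist = [1e10] * nsample
    best_idx = [0] * nsample
    for i, (x, y, z) in enumerate(points):
        d2 = (query[0] - x) ** 2 + (query[1] - y) ** 2 + (query[2] - z) ** 2
        if d2 < best_dist[0]:  # better than the current worst neighbour
            best_dist[0], best_idx[0] = d2, i
            reheap(best_dist, best_idx, nsample)
    # The kernel finishes with heap_sort; sorting by distance is equivalent.
    order = sorted(range(nsample), key=lambda j: best_dist[j])
    return [best_idx[j] for j in order], [best_dist[j] for j in order]

# e.g. the two nearest of four points to (0.1, 0.0, 0.0):
pts = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 2.0, 0.0), (3.0, 0.0, 0.0)]
print(knn_one_query((0.1, 0.0, 0.0), pts, 2))  # indices [0, 1]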


@@ -99,10 +99,10 @@ void modulated_deform_conv_forward(
const int kernel_w_ = weight.size(3);
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).",
kernel_h_, kernel_w, kernel_h_, kernel_w_);
if (channels != channels_kernel * group)
AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).",
channels, channels_kernel * group);
const int height_out =
@@ -220,10 +220,10 @@ void modulated_deform_conv_backward(
const int kernel_h_ = weight.size(2);
const int kernel_w_ = weight.size(3);
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).",
kernel_h_, kernel_w, kernel_h_, kernel_w_);
if (channels != channels_kernel * group)
AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).",
channels, channels_kernel * group);
const int height_out =


@@ -0,0 +1,34 @@
// Copyright (c) OpenMMLab. All rights reserved
// Modified from
// https://github.com/CVMI-Lab/PAConv/tree/main/scene_seg/lib/pointops/src/knnquery_heap
#include <cmath>
#include <cstdio>
#include "knn_cuda_kernel.cuh"
#include "pytorch_cuda_helper.hpp"
void KNNForwardCUDAKernelLauncher(int b, int n, int m, int nsample,
const Tensor xyz, const Tensor new_xyz,
Tensor idx, Tensor dist2) {
// param new_xyz: (B, m, 3)
// param xyz: (B, n, 3)
// param idx: (B, m, nsample)
at::cuda::CUDAGuard device_guard(new_xyz.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// blockIdx.x(col), blockIdx.y(row)
dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b);
dim3 threads(THREADS_PER_BLOCK);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
new_xyz.scalar_type(), "knn_forward_cuda_kernel", [&] {
knn_forward_cuda_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
b, n, m, nsample, xyz.data_ptr<scalar_t>(),
new_xyz.data_ptr<scalar_t>(), idx.data_ptr<int>(),
dist2.data_ptr<scalar_t>());
});
AT_CUDA_CHECK(cudaGetLastError());
}
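
The launcher assigns one CUDA thread per query point: ``blockIdx.y`` indexes the batch and ``blockIdx.x * blockDim.x + threadIdx.x`` indexes the query. A small sketch of that geometry in Python (``THREADS_PER_BLOCK = 512`` is an assumption taken from mmcv's common CUDA helpers; it is not defined in this hunk):

THREADS_PER_BLOCK = 512  # assumed value from mmcv's common CUDA helpers

def knn_launch_config(b, m):
    # DIVUP(m, THREADS_PER_BLOCK) blocks along x, one row of blocks per batch
    # along y, so every (batch, query point) pair gets its own thread.
    grid = ((m + THREADS_PER_BLOCK - 1) // THREADS_PER_BLOCK, b)
    block = (THREADS_PER_BLOCK,)
    return grid, block

print(knn_launch_config(b=2, m=5))  # ((1, 2), (512,))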


@@ -0,0 +1,33 @@
// Modified from
// https://github.com/CVMI-Lab/PAConv/tree/main/scene_seg/lib/pointops/src/knnquery_heap
#include "pytorch_cpp_helper.hpp"
#ifdef MMCV_WITH_CUDA
void KNNForwardCUDAKernelLauncher(int b, int n, int m, int nsample,
const Tensor xyz, const Tensor new_xyz,
Tensor idx, Tensor dist2);
void knn_forward_cuda(int b, int n, int m, int nsample, const Tensor xyz,
const Tensor new_xyz, Tensor idx, Tensor dist2) {
KNNForwardCUDAKernelLauncher(b, n, m, nsample, xyz, new_xyz, idx, dist2);
}
#endif
void knn_forward(int b, int n, int m, int nsample, Tensor xyz_tensor,
Tensor new_xyz_tensor, Tensor idx_tensor,
Tensor dist2_tensor) {
if (new_xyz_tensor.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(new_xyz_tensor);
CHECK_CUDA_INPUT(xyz_tensor);
knn_forward_cuda(b, n, m, nsample, xyz_tensor, new_xyz_tensor, idx_tensor,
dist2_tensor);
#else
AT_ERROR("knn is not compiled with GPU support");
#endif
} else {
AT_ERROR("knn is not implemented on CPU");
}
}


@@ -69,6 +69,9 @@ void softmax_focal_loss_backward(Tensor input, Tensor target, Tensor weight,
void bbox_overlaps(const Tensor bboxes1, const Tensor bboxes2, Tensor ious,
const int mode, const bool aligned, const int offset);
void knn_forward(int b, int n, int m, int nsample, Tensor xyz_tensor,
Tensor new_xyz_tensor, Tensor idx_tensor, Tensor dist2_tensor);
void furthest_point_sampling_forward(int b, int n, int m, Tensor points_tensor,
Tensor temp_tensor, Tensor idx_tensor);
@@ -256,6 +259,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("get_compiler_version", &get_compiler_version, "get_compiler_version");
m.def("get_compiling_cuda_version", &get_compiling_cuda_version,
"get_compiling_cuda_version");
m.def("knn_forward", &knn_forward, "knn_forward", py::arg("b"), py::arg("n"),
py::arg("m"), py::arg("nsample"), py::arg("xyz_tensor"),
py::arg("new_xyz_tensor"), py::arg("idx_tensor"),
py::arg("dist2_tensor"));
m.def("carafe_naive_forward", &carafe_naive_forward, "carafe_naive_forward",
py::arg("features"), py::arg("masks"), py::arg("output"),
py::arg("kernel_size"), py::arg("group_size"), py::arg("scale_factor"));

mmcv/ops/knn.py 100644

@@ -0,0 +1,75 @@
import torch
from torch.autograd import Function
from ..utils import ext_loader
ext_module = ext_loader.load_ext('_ext', ['knn_forward'])
class KNN(Function):
r"""KNN (CUDA) based on heap data structure.
Modified from `PAConv <https://github.com/CVMI-Lab/PAConv/tree/main/
scene_seg/lib/pointops/src/knnquery_heap>`_.
Find k-nearest points.
"""
@staticmethod
def forward(ctx,
k: int,
xyz: torch.Tensor,
center_xyz: torch.Tensor = None,
transposed: bool = False) -> torch.Tensor:
"""
Args:
k (int): number of nearest neighbors.
xyz (Tensor): (B, N, 3) if transposed == False, else (B, 3, N).
xyz coordinates of the features.
center_xyz (Tensor, optional): (B, npoint, 3) if transposed ==
False, else (B, 3, npoint). centers of the knn query.
Default: None.
            transposed (bool, optional): whether the input tensors are
                transposed. ``KNN.apply`` does not accept keyword arguments,
                so do not pass this as a keyword when calling ``knn``;
                supply it positionally as the fourth argument. Default: False.
Returns:
Tensor: (B, k, npoint) tensor with the indices of
the features that form k-nearest neighbours.
"""
assert (k > 0) & (k < 100), 'k should be in range(0, 100)'
if center_xyz is None:
center_xyz = xyz
if transposed:
xyz = xyz.transpose(2, 1).contiguous()
center_xyz = center_xyz.transpose(2, 1).contiguous()
assert xyz.is_contiguous() # [B, N, 3]
assert center_xyz.is_contiguous() # [B, npoint, 3]
center_xyz_device = center_xyz.get_device()
assert center_xyz_device == xyz.get_device(), \
'center_xyz and xyz should be put on the same device'
if torch.cuda.current_device() != center_xyz_device:
torch.cuda.set_device(center_xyz_device)
B, npoint, _ = center_xyz.shape
N = xyz.shape[1]
idx = center_xyz.new_zeros((B, npoint, k)).int()
dist2 = center_xyz.new_zeros((B, npoint, k)).float()
ext_module.knn_forward(B, N, npoint, k, xyz, center_xyz, idx, dist2)
# idx shape to [B, k, npoint]
idx = idx.transpose(2, 1).contiguous()
ctx.mark_non_differentiable(idx)
return idx
@staticmethod
def backward(ctx, a=None):
return None, None, None
knn = KNN.apply
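
A short usage sketch of the new op (shapes chosen arbitrarily; a CUDA device is required since there is no CPU kernel):

import torch
from mmcv.ops import knn

xyz = torch.rand(2, 128, 3).cuda()         # (B, N, 3) point coordinates
center_xyz = torch.rand(2, 32, 3).cuda()   # (B, npoint, 3) query centers

idx = knn(8, xyz, center_xyz)              # (B, 8, npoint) int32 indices
# With transposed inputs (B, 3, N) / (B, 3, npoint), pass True as the
# fourth positional argument instead of a keyword.
idx_t = knn(8, xyz.transpose(1, 2).contiguous(),
            center_xyz.transpose(1, 2).contiguous(), True)
assert torch.equal(idx, idx_t)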


@@ -0,0 +1,54 @@
import pytest
import torch
from mmcv.ops import knn
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_knn():
new_xyz = torch.tensor([[[-0.0740, 1.3147, -1.3625],
[-2.2769, 2.7817, -0.2334],
[-0.4003, 2.4666, -0.5116],
[-0.0740, 1.3147, -1.3625],
[-0.0740, 1.3147, -1.3625]],
[[-2.0289, 2.4952, -0.1708],
[-2.0668, 6.0278, -0.4875],
[0.4066, 1.4211, -0.2947],
[-2.0289, 2.4952, -0.1708],
[-2.0289, 2.4952, -0.1708]]]).cuda()
    xyz = torch.tensor(
        [[[-0.0740, 1.3147, -1.3625], [0.5555, 1.0399, -1.3634],
          [-0.4003, 2.4666, -0.5116], [-0.5251, 2.4379, -0.8466],
          [-0.9691, 1.1418, -1.3733], [-0.2232, 0.9561, -1.3626],
          [-2.2769, 2.7817, -0.2334], [-0.2822, 1.3192, -1.3645],
          [0.1533, 1.5024, -1.0432], [0.4917, 1.1529, -1.3496]],
         [[-2.0289, 2.4952, -0.1708], [-0.7188, 0.9956, -0.5096],
          [-2.0668, 6.0278, -0.4875], [-1.9304, 3.3092, 0.6610],
          [0.0949, 1.4332, 0.3140], [-1.2879, 2.0008, -0.7791],
          [-0.7252, 0.9611, -0.6371], [0.4066, 1.4211, -0.2947],
          [0.3220, 1.4447, 0.3548], [-0.9744, 2.3856, -1.2000]]]).cuda()
idx = knn(5, xyz, new_xyz)
new_xyz_ = new_xyz.unsqueeze(2).repeat(1, 1, xyz.shape[1], 1)
xyz_ = xyz.unsqueeze(1).repeat(1, new_xyz.shape[1], 1, 1)
dist = ((new_xyz_ - xyz_) * (new_xyz_ - xyz_)).sum(-1)
expected_idx = dist.topk(k=5, dim=2, largest=False)[1].transpose(2, 1)
assert torch.all(idx == expected_idx)
idx = knn(5,
xyz.transpose(1, 2).contiguous(),
new_xyz.transpose(1, 2).contiguous(), True)
assert torch.all(idx == expected_idx)
idx = knn(5, xyz, xyz)
xyz_ = xyz.unsqueeze(2).repeat(1, 1, xyz.shape[1], 1)
xyz__ = xyz.unsqueeze(1).repeat(1, xyz.shape[1], 1, 1)
dist = ((xyz_ - xyz__) * (xyz_ - xyz__)).sum(-1)
expected_idx = dist.topk(k=5, dim=2, largest=False)[1].transpose(2, 1)
assert torch.all(idx == expected_idx)
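
The brute-force expected indices above could equivalently be built with ``torch.cdist``; a minimal sketch reusing ``new_xyz`` and ``xyz`` from the test (not part of this commit; plain Euclidean distances give the same ordering as squared distances):

# (B, npoint, N) pairwise distances between query centers and points
dist = torch.cdist(new_xyz, xyz)
expected_idx = dist.topk(k=5, dim=2, largest=False)[1].transpose(2, 1)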