[Features] Add stack ball query and stack group points ops ()

* add stack sa model ops

* fix lint

* fix lint

* fix comments

* fix bug

* fix lint

* fix comments

* fix lint

* fix lint

* fix
pull/2371/head
VVsssssk 2022-10-28 00:42:10 +08:00 committed by Zaida Zhou
parent a4f304a5f5
commit 93fe4829f7
14 changed files with 787 additions and 60 deletions

View File

@ -1,28 +1,44 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
from typing import Optional, Tuple
import torch
from torch.autograd import Function
from ..utils import ext_loader
ext_module = ext_loader.load_ext('_ext', ['ball_query_forward'])
ext_module = ext_loader.load_ext(
'_ext', ['ball_query_forward', 'stack_ball_query_forward'])
class BallQuery(Function):
"""Find nearby points in spherical space."""
@staticmethod
def forward(ctx, min_radius: float, max_radius: float, sample_num: int,
xyz: torch.Tensor, center_xyz: torch.Tensor) -> torch.Tensor:
def forward(
ctx,
min_radius: float,
max_radius: float,
sample_num: int,
xyz: torch.Tensor,
center_xyz: torch.Tensor,
xyz_batch_cnt: Optional[torch.Tensor] = None,
center_xyz_batch_cnt: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Args:
min_radius (float): minimum radius of the balls.
max_radius (float): maximum radius of the balls.
sample_num (int): maximum number of features in the balls.
xyz (torch.Tensor): (B, N, 3) xyz coordinates of the features.
xyz (torch.Tensor): (B, N, 3) xyz coordinates of the features,
or staked input (N1 + N2 ..., 3).
center_xyz (torch.Tensor): (B, npoint, 3) centers of the ball
query.
query, or staked input (M1 + M2 ..., 3).
xyz_batch_cnt: (batch_size): Stacked input xyz coordinates nums in
each batch, just like (N1, N2, ...). Defaults to None.
New in version 1.7.0.
center_xyz_batch_cnt: (batch_size): Stacked centers coordinates
nums in each batch, just line (M1, M2, ...). Defaults to None.
New in version 1.7.0.
Returns:
torch.Tensor: (B, npoint, nsample) tensor with the indices of the
@ -31,21 +47,34 @@ class BallQuery(Function):
assert center_xyz.is_contiguous()
assert xyz.is_contiguous()
assert min_radius < max_radius
B, N, _ = xyz.size()
npoint = center_xyz.size(1)
idx = xyz.new_zeros(B, npoint, sample_num, dtype=torch.int)
ext_module.ball_query_forward(
center_xyz,
xyz,
idx,
b=B,
n=N,
m=npoint,
min_radius=min_radius,
max_radius=max_radius,
nsample=sample_num)
if xyz_batch_cnt is not None and center_xyz_batch_cnt is not None:
assert xyz_batch_cnt.dtype == torch.int
assert center_xyz_batch_cnt.dtype == torch.int
idx = center_xyz.new_zeros((center_xyz.shape[0], sample_num),
dtype=torch.int32)
ext_module.stack_ball_query_forward(
center_xyz,
center_xyz_batch_cnt,
xyz,
xyz_batch_cnt,
idx,
max_radius=max_radius,
nsample=sample_num,
)
else:
B, N, _ = xyz.size()
npoint = center_xyz.size(1)
idx = xyz.new_zeros(B, npoint, sample_num, dtype=torch.int32)
ext_module.ball_query_forward(
center_xyz,
xyz,
idx,
b=B,
n=N,
m=npoint,
min_radius=min_radius,
max_radius=max_radius,
nsample=sample_num)
if torch.__version__ != 'parrots':
ctx.mark_non_differentiable(idx)
return idx

View File

@ -0,0 +1,68 @@
// Copyright (c) OpenMMLab. All rights reserved
// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/ball_query_gpu.cu
#ifndef STACK_BALL_QUERY_CUDA_KERNEL_CUH
#define STACK_BALL_QUERY_CUDA_KERNEL_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
template <typename T>
__global__ void stack_ball_query_forward_cuda_kernel(
int B, int M, float radius, int nsample, const T *new_xyz,
const int *new_xyz_batch_cnt, const T *xyz, const int *xyz_batch_cnt,
int *idx) {
// :param xyz: (N1 + N2 ..., 3) xyz coordinates of the features
// :param xyz_batch_cnt: (batch_size), [N1, N2, ...]
// :param new_xyz: (M1 + M2 ..., 3) centers of the ball query
// :param new_xyz_batch_cnt: (batch_size), [M1, M2, ...]
// output:
// idx: (M, nsample)
const T *cur_xyz = xyz;
int *cur_idx = idx;
CUDA_1D_KERNEL_LOOP(pt_idx, M) {
int bs_idx = 0;
for (int pt_cnt = 0; bs_idx < B; bs_idx++) {
pt_cnt += new_xyz_batch_cnt[bs_idx];
if (pt_idx < pt_cnt) break;
}
int xyz_batch_start_idx = 0;
for (int k = 0; k < bs_idx; k++) xyz_batch_start_idx += xyz_batch_cnt[k];
const T *new_xyz_p = new_xyz + pt_idx * 3;
cur_xyz += xyz_batch_start_idx * 3;
cur_idx += pt_idx * nsample;
float radius2 = radius * radius;
T new_x = new_xyz_p[0];
T new_y = new_xyz_p[1];
T new_z = new_xyz_p[2];
int n = xyz_batch_cnt[bs_idx];
int cnt = 0;
for (int k = 0; k < n; ++k) {
T x = cur_xyz[k * 3 + 0];
T y = cur_xyz[k * 3 + 1];
T z = cur_xyz[k * 3 + 2];
T d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) +
(new_z - z) * (new_z - z);
if (d2 < radius2) {
if (cnt == 0) {
for (int l = 0; l < nsample; ++l) {
cur_idx[l] = k;
}
}
cur_idx[cnt] = k;
++cnt;
if (cnt >= nsample) break;
}
}
if (cnt == 0) cur_idx[0] = -1;
}
}
#endif // STACK_BALL_QUERY_CUDA_KERNEL_CUH

View File

@ -0,0 +1,97 @@
// Copyright (c) OpenMMLab. All rights reserved.
// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/group_points_gpu.cu
#ifndef STACK_GROUP_POINTS_CUDA_KERNEL_CUH
#define STACK_GROUP_POINTS_CUDA_KERNEL_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
#include <stdio.h>
template <typename T>
__global__ void stack_group_points_forward_cuda_kernel(
int b, int c, int m, int nsample, const T *features,
const int *features_batch_cnt, const int *idx, const int *idx_batch_cnt,
T *out) {
// :param features: (N1 + N2 ..., C) tensor of features to group
// :param features_batch_cnt: (batch_size) [N1 + N2 ...] tensor containing the
// indices of features to group with :param idx: (M1 + M2 ..., nsample) tensor
// containing the indices of features to group with :param idx_batch_cnt:
// (batch_size) [M1 + M2 ...] tensor containing the indices of features to
// group with :return:
// output: (M1 + M2, C, nsample) tensor
CUDA_1D_KERNEL_LOOP(index, m * c * nsample) {
const T *cur_features = features;
const int *cur_idx = idx;
int sample_idx = index % nsample;
int c_idx = (index / nsample) % c;
int pt_idx = (index / nsample / c);
if (pt_idx >= m || c_idx >= c || sample_idx >= nsample) return;
int bs_idx = 0, pt_cnt = idx_batch_cnt[0];
for (int k = 1; k < b; k++) {
if (pt_idx < pt_cnt) break;
pt_cnt += idx_batch_cnt[k];
bs_idx = k;
}
int features_batch_start_idx = 0;
int features_batch_end_idx = features_batch_cnt[0];
for (int k = 0; k < bs_idx; k++) {
features_batch_start_idx += features_batch_cnt[k];
features_batch_end_idx =
features_batch_start_idx + features_batch_cnt[k + 1];
}
cur_features += features_batch_start_idx * c;
cur_idx += pt_idx * nsample + sample_idx;
int in_idx = cur_idx[0] * c + c_idx;
int out_idx = pt_idx * c * nsample + c_idx * nsample + sample_idx;
if (in_idx < features_batch_end_idx * c) {
out[out_idx] = cur_features[in_idx];
}
}
}
template <typename T>
__global__ void stack_group_points_backward_cuda_kernel(
int b, int c, int m, int n, int nsample, const T *grad_out, const int *idx,
const int *idx_batch_cnt, const int *features_batch_cnt, T *grad_features) {
// :param grad_out: (M1 + M2 ..., C, nsample) tensor of the gradients of the
// output from forward :param idx: (M1 + M2 ..., nsample) tensor containing
// the indices of features to group with :param idx_batch_cnt: (batch_size)
// [M1 + M2 ...] tensor containing the indices of features to group with
// :param features_batch_cnt: (batch_size) [N1 + N2 ...] tensor containing the
// indices of features to group with :return:
// grad_features: (N1 + N2 ..., C) gradient of the features
CUDA_1D_KERNEL_LOOP(index, m * c * nsample) {
const T *cur_grad_out = grad_out;
const int *cur_idx = idx;
T *cur_grad_features = grad_features;
int sample_idx = index % nsample;
int c_idx = (index / nsample) % c;
int pt_idx = (index / nsample / c);
if (pt_idx >= m || c_idx >= c || sample_idx >= nsample) return;
int bs_idx = 0, pt_cnt = idx_batch_cnt[0];
for (int k = 1; k < b; k++) {
if (pt_idx < pt_cnt) break;
pt_cnt += idx_batch_cnt[k];
bs_idx = k;
}
int features_batch_start_idx = 0;
for (int k = 0; k < bs_idx; k++)
features_batch_start_idx += features_batch_cnt[k];
cur_grad_out += pt_idx * c * nsample + c_idx * nsample + sample_idx;
cur_idx += pt_idx * nsample + sample_idx;
cur_grad_features += (features_batch_start_idx + cur_idx[0]) * c + c_idx;
atomicAdd(cur_grad_features, cur_grad_out[0]);
}
}
#endif // GROUP_POINTS_CUDA_KERNEL_CUH

View File

@ -15,5 +15,6 @@ using at::Tensor;
using phalf = at::Half;
#define __PHALF(x) (x)
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
#endif // PYTORCH_CUDA_HELPER

View File

@ -18,3 +18,21 @@ void ball_query_forward(Tensor new_xyz_tensor, Tensor xyz_tensor,
ball_query_forward_impl(b, n, m, min_radius, max_radius, nsample,
new_xyz_tensor, xyz_tensor, idx_tensor);
}
void stack_ball_query_forward_impl(float max_radius, int nsample,
const Tensor new_xyz,
const Tensor new_xyz_batch_cnt,
const Tensor xyz, const Tensor xyz_batch_cnt,
Tensor idx) {
DISPATCH_DEVICE_IMPL(stack_ball_query_forward_impl, max_radius, nsample,
new_xyz, new_xyz_batch_cnt, xyz, xyz_batch_cnt, idx);
}
void stack_ball_query_forward(Tensor new_xyz_tensor, Tensor new_xyz_batch_cnt,
Tensor xyz_tensor, Tensor xyz_batch_cnt,
Tensor idx_tensor, float max_radius,
int nsample) {
stack_ball_query_forward_impl(max_radius, nsample, new_xyz_tensor,
new_xyz_batch_cnt, xyz_tensor, xyz_batch_cnt,
idx_tensor);
}

View File

@ -67,6 +67,30 @@ void ball_query_forward_impl(int b, int n, int m, float min_radius,
Tensor idx);
REGISTER_DEVICE_IMPL(ball_query_forward_impl, CUDA, ball_query_forward_cuda);
void StackBallQueryForwardCUDAKernelLauncher(float max_radius, int nsample,
const Tensor new_xyz,
const Tensor new_xyz_batch_cnt,
const Tensor xyz,
const Tensor xyz_batch_cnt,
Tensor idx);
void stack_ball_query_forward_cuda(float max_radius, int nsample,
const Tensor new_xyz,
const Tensor new_xyz_batch_cnt,
const Tensor xyz, const Tensor xyz_batch_cnt,
Tensor idx) {
StackBallQueryForwardCUDAKernelLauncher(
max_radius, nsample, new_xyz, new_xyz_batch_cnt, xyz, xyz_batch_cnt, idx);
};
void stack_ball_query_forward_impl(float max_radius, int nsample,
const Tensor new_xyz,
const Tensor new_xyz_batch_cnt,
const Tensor xyz, const Tensor xyz_batch_cnt,
Tensor idx);
REGISTER_DEVICE_IMPL(stack_ball_query_forward_impl, CUDA,
stack_ball_query_forward_cuda);
void BBoxOverlapsCUDAKernelLauncher(const Tensor bboxes1, const Tensor bboxes2,
Tensor ious, const int mode,
const bool aligned, const int offset);
@ -571,6 +595,56 @@ REGISTER_DEVICE_IMPL(group_points_forward_impl, CUDA,
REGISTER_DEVICE_IMPL(group_points_backward_impl, CUDA,
group_points_backward_cuda);
void StackGroupPointsForwardCUDAKernelLauncher(
int b, int c, int m, int nsample, const Tensor features_tensor,
const Tensor features_batch_cnt_tensor, const Tensor idx_tensor,
const Tensor idx_batch_cnt_tensor, Tensor out_tensor);
void StackGroupPointsBackwardCUDAKernelLauncher(
int b, int c, int m, int n, int nsample, const Tensor grad_out_tensor,
const Tensor idx_tensor, const Tensor idx_batch_cnt_tensor,
const Tensor features_batch_cnt_tensor, Tensor grad_features_tensor);
void stack_group_points_forward_cuda(int b, int c, int m, int nsample,
const Tensor features_tensor,
const Tensor features_batch_cnt_tensor,
const Tensor idx_tensor,
const Tensor idx_batch_cnt_tensor,
Tensor out_tensor) {
StackGroupPointsForwardCUDAKernelLauncher(
b, c, m, nsample, features_tensor, features_batch_cnt_tensor, idx_tensor,
idx_batch_cnt_tensor, out_tensor);
};
void stack_group_points_backward_cuda(int b, int c, int m, int n, int nsample,
const Tensor grad_out_tensor,
const Tensor idx_tensor,
const Tensor idx_batch_cnt_tensor,
const Tensor features_batch_cnt_tensor,
Tensor grad_features_tensor) {
StackGroupPointsBackwardCUDAKernelLauncher(
b, c, m, n, nsample, grad_out_tensor, idx_tensor, idx_batch_cnt_tensor,
features_batch_cnt_tensor, grad_features_tensor);
};
void stack_group_points_forward_impl(int b, int c, int m, int nsample,
const Tensor features_tensor,
const Tensor features_batch_cnt_tensor,
const Tensor idx_tensor,
const Tensor idx_batch_cnt_tensor,
Tensor out_tensor);
void stack_group_points_backward_impl(int b, int c, int m, int n, int nsample,
const Tensor grad_out_tensor,
const Tensor idx_tensor,
const Tensor idx_batch_cnt_tensor,
const Tensor features_batch_cnt_tensor,
Tensor grad_features_tensor);
REGISTER_DEVICE_IMPL(stack_group_points_forward_impl, CUDA,
stack_group_points_forward_cuda);
REGISTER_DEVICE_IMPL(stack_group_points_backward_impl, CUDA,
stack_group_points_backward_cuda);
void IoU3DBoxesOverlapBevForwardCUDAKernelLauncher(const int num_a,
const Tensor boxes_a,
const int num_b,

View File

@ -0,0 +1,45 @@
// Copyright (c) OpenMMLab. All rights reserved
// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/ball_query_gpu.cu
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "pytorch_cuda_helper.hpp"
#include "stack_ball_query_cuda_kernel.cuh"
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
void StackBallQueryForwardCUDAKernelLauncher(float max_radius, int nsample,
const Tensor new_xyz,
const Tensor new_xyz_batch_cnt,
const Tensor xyz,
const Tensor xyz_batch_cnt,
Tensor idx) {
at::cuda::CUDAGuard device_guard(new_xyz.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// const float *new_xyz_ptr = new_xyz.data_ptr<float>();
// const float *xyz_ptr = xyz.data_ptr<float>();
// const int *new_xyz_batch_cnt_ptr = new_xyz_batch_cnt.data_ptr<int>();
// const int *xyz_batch_cnt_ptr = xyz_batch_cnt.data_ptr<int>();
// int *idx_ptr = idx.data_ptr<int>();
int B = xyz_batch_cnt.size(0);
int M = new_xyz.size(0);
// blockIdx.x(col), blockIdx.y(row)
dim3 blocks(DIVUP(M, THREADS_PER_BLOCK));
dim3 threads(THREADS_PER_BLOCK);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
new_xyz.scalar_type(), "stack_ball_query_forward_cuda_kernel", [&] {
stack_ball_query_forward_cuda_kernel<scalar_t>
<<<blocks, threads, 0, stream>>>(
B, M, max_radius, nsample, new_xyz.data_ptr<scalar_t>(),
new_xyz_batch_cnt.data_ptr<int>(), xyz.data_ptr<scalar_t>(),
xyz_batch_cnt.data_ptr<int>(), idx.data_ptr<int>());
});
AT_CUDA_CHECK(cudaGetLastError());
}

View File

@ -0,0 +1,62 @@
// Copyright (c) OpenMMLab. All rights reserved.
// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/group_points_gpu.cu
#include <stdio.h>
#include <stdlib.h>
#include "pytorch_cuda_helper.hpp"
#include "stack_group_points_cuda_kernel.cuh"
void StackGroupPointsForwardCUDAKernelLauncher(
int b, int c, int m, int nsample, const Tensor features_tensor,
const Tensor features_batch_cnt_tensor, const Tensor idx_tensor,
const Tensor idx_batch_cnt_tensor, Tensor out_tensor) {
// points: (B, C, N)
// idx: (B, npoints, nsample)
// output:
// out: (B, C, npoints, nsample)
at::cuda::CUDAGuard device_guard(features_tensor.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 blocks(DIVUP(m * c * nsample, THREADS_PER_BLOCK));
dim3 threads(THREADS_PER_BLOCK);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
features_tensor.scalar_type(), "stack_group_points_forward_cuda_kernel",
[&] {
stack_group_points_forward_cuda_kernel<scalar_t>
<<<blocks, threads, 0, stream>>>(
b, c, m, nsample, features_tensor.data_ptr<scalar_t>(),
features_batch_cnt_tensor.data_ptr<int>(),
idx_tensor.data_ptr<int>(),
idx_batch_cnt_tensor.data_ptr<int>(),
out_tensor.data_ptr<scalar_t>());
});
AT_CUDA_CHECK(cudaGetLastError());
}
void StackGroupPointsBackwardCUDAKernelLauncher(
int b, int c, int m, int n, int nsample, const Tensor grad_out_tensor,
const Tensor idx_tensor, const Tensor idx_batch_cnt_tensor,
const Tensor features_batch_cnt_tensor, Tensor grad_features_tensor) {
at::cuda::CUDAGuard device_guard(grad_features_tensor.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 blocks(DIVUP(m * c * nsample, THREADS_PER_BLOCK));
dim3 threads(THREADS_PER_BLOCK);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_features_tensor.scalar_type(),
"stack_group_points_backward_cuda_kernel", [&] {
stack_group_points_backward_cuda_kernel<scalar_t>
<<<blocks, threads, 0, stream>>>(
b, c, m, n, nsample, grad_out_tensor.data_ptr<scalar_t>(),
idx_tensor.data_ptr<int>(),
idx_batch_cnt_tensor.data_ptr<int>(),
features_batch_cnt_tensor.data_ptr<int>(),
grad_features_tensor.data_ptr<scalar_t>());
});
AT_CUDA_CHECK(cudaGetLastError());
}

View File

@ -32,3 +32,45 @@ void group_points_backward(Tensor grad_out_tensor, Tensor idx_tensor,
group_points_backward_impl(b, c, n, npoints, nsample, grad_out_tensor,
idx_tensor, grad_points_tensor);
}
void stack_group_points_backward_impl(int b, int c, int m, int n, int nsample,
const Tensor grad_out_tensor,
const Tensor idx_tensor,
const Tensor idx_batch_cnt_tensor,
const Tensor features_batch_cnt_tensor,
Tensor grad_features_tensor) {
DISPATCH_DEVICE_IMPL(stack_group_points_backward_impl, b, c, m, n, nsample,
grad_out_tensor, idx_tensor, idx_batch_cnt_tensor,
features_batch_cnt_tensor, grad_features_tensor);
}
void stack_group_points_backward(Tensor grad_out_tensor, Tensor idx_tensor,
Tensor idx_batch_cnt_tensor,
Tensor features_batch_cnt_tensor,
Tensor grad_features_tensor, int b, int c,
int m, int n, int nsample) {
stack_group_points_backward_impl(
b, c, m, n, nsample, grad_out_tensor, idx_tensor, idx_batch_cnt_tensor,
features_batch_cnt_tensor, grad_features_tensor);
}
void stack_group_points_forward_impl(int b, int c, int m, int nsample,
const Tensor features_tensor,
const Tensor features_batch_cnt_tensor,
const Tensor idx_tensor,
const Tensor idx_batch_cnt_tensor,
Tensor out_tensor) {
DISPATCH_DEVICE_IMPL(stack_group_points_forward_impl, b, c, m, nsample,
features_tensor, features_batch_cnt_tensor, idx_tensor,
idx_batch_cnt_tensor, out_tensor);
}
void stack_group_points_forward(Tensor features_tensor,
Tensor features_batch_cnt_tensor,
Tensor idx_tensor, Tensor idx_batch_cnt_tensor,
Tensor out_tensor, int b, int c, int m,
int nsample) {
DISPATCH_DEVICE_IMPL(stack_group_points_forward_impl, b, c, m, nsample,
features_tensor, features_batch_cnt_tensor, idx_tensor,
idx_batch_cnt_tensor, out_tensor);
}

View File

@ -75,6 +75,18 @@ void group_points_backward(Tensor grad_out_tensor, Tensor idx_tensor,
Tensor grad_points_tensor, int b, int c, int n,
int npoints, int nsample);
void stack_group_points_forward(Tensor features_tensor,
Tensor features_batch_cnt_tensor,
Tensor idx_tensor, Tensor idx_batch_cnt_tensor,
Tensor out_tensor, int b, int c, int m,
int nsample);
void stack_group_points_backward(Tensor grad_out_tensor, Tensor idx_tensor,
Tensor idx_batch_cnt_tensor,
Tensor features_batch_cnt_tensor,
Tensor grad_features_tensor, int b, int c,
int m, int n, int nsample);
void roipoint_pool3d_forward(Tensor xyz, Tensor boxes3d, Tensor pts_feature,
Tensor pooled_features, Tensor pooled_empty_flag);
@ -240,6 +252,10 @@ void ball_query_forward(Tensor new_xyz_tensor, Tensor xyz_tensor,
Tensor idx_tensor, int b, int n, int m,
float min_radius, float max_radius, int nsample);
void stack_ball_query_forward(Tensor new_xyz_tensor, Tensor new_xyz_batch_cnt,
Tensor xyz_tensor, Tensor xyz_batch_cnt,
Tensor idx_tensor, float max_radius, int nsample);
void prroi_pool_forward(Tensor input, Tensor rois, Tensor output,
int pooled_height, int pooled_width,
float spatial_scale);
@ -557,6 +573,17 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"group_points_backward", py::arg("grad_out_tensor"),
py::arg("idx_tensor"), py::arg("grad_points_tensor"), py::arg("b"),
py::arg("c"), py::arg("n"), py::arg("npoints"), py::arg("nsample"));
m.def("stack_group_points_forward", &stack_group_points_forward,
"stack_group_points_forward", py::arg("features_tensor"),
py::arg("features_batch_cnt_tensor"), py::arg("idx_tensor"),
py::arg("idx_batch_cnt_tensor"), py::arg("out_tensor"), py::arg("b"),
py::arg("c"), py::arg("m"), py::arg("nsample"));
m.def("stack_group_points_backward", &stack_group_points_backward,
"stack_group_points_backward", py::arg("grad_out_tensor"),
py::arg("idx_tensor"), py::arg("idx_batch_cnt_tensor"),
py::arg("features_batch_cnt_tensor"), py::arg("grad_features_tensor"),
py::arg("b"), py::arg("c"), py::arg("m"), py::arg("n"),
py::arg("nsample"));
m.def("knn_forward", &knn_forward, "knn_forward", py::arg("b"), py::arg("n"),
py::arg("m"), py::arg("nsample"), py::arg("xyz_tensor"),
py::arg("new_xyz_tensor"), py::arg("idx_tensor"),
@ -726,6 +753,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("new_xyz_tensor"), py::arg("xyz_tensor"), py::arg("idx_tensor"),
py::arg("b"), py::arg("n"), py::arg("m"), py::arg("min_radius"),
py::arg("max_radius"), py::arg("nsample"));
m.def("stack_ball_query_forward", &stack_ball_query_forward,
"stack_ball_query_forward", py::arg("new_xyz_tensor"),
py::arg("new_xyz_batch_cnt"), py::arg("xyz_tensor"),
py::arg("xyz_batch_cnt"), py::arg("idx_tensor"), py::arg("max_radius"),
py::arg("nsample"));
m.def("roi_align_rotated_forward", &roi_align_rotated_forward,
"roi_align_rotated forward", py::arg("input"), py::arg("rois"),
py::arg("output"), py::arg("pooled_height"), py::arg("pooled_width"),

View File

@ -9,8 +9,10 @@ from ..utils import ext_loader
from .ball_query import ball_query
from .knn import knn
ext_module = ext_loader.load_ext(
'_ext', ['group_points_forward', 'group_points_backward'])
ext_module = ext_loader.load_ext('_ext', [
'group_points_forward', 'group_points_backward',
'stack_group_points_forward', 'stack_group_points_backward'
])
class QueryAndGroup(nn.Module):
@ -183,39 +185,71 @@ class GroupingOperation(Function):
"""Group feature with given index."""
@staticmethod
def forward(ctx, features: torch.Tensor,
indices: torch.Tensor) -> torch.Tensor:
def forward(
ctx,
features: torch.Tensor,
indices: torch.Tensor,
features_batch_cnt: Optional[torch.Tensor] = None,
indices_batch_cnt: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Args:
features (Tensor): (B, C, N) tensor of features to group.
indices (Tensor): (B, npoint, nsample) the indices of
features to group with.
features (Tensor): Tensor of features to group, input shape is
(B, C, N) or stacked inputs (N1 + N2 ..., C).
indices (Tensor): The indices of features to group with, input
shape is (B, npoint, nsample) or stacked inputs
(M1 + M2 ..., nsample).
features_batch_cnt (Tensor, optional): Input features nums in
each batch, just like (N1, N2, ...). Defaults to None.
New in version 1.7.0.
indices_batch_cnt (Tensor, optional): Input indices nums in
each batch, just like (M1, M2, ...). Defaults to None.
New in version 1.7.0.
Returns:
Tensor: (B, C, npoint, nsample) Grouped features.
Tensor: Grouped features, the shape is (B, C, npoint, nsample)
or (M1 + M2 ..., C, nsample).
"""
features = features.contiguous()
indices = indices.contiguous()
if features_batch_cnt is not None and indices_batch_cnt is not None:
assert features_batch_cnt.dtype == torch.int
assert indices_batch_cnt.dtype == torch.int
M, nsample = indices.size()
N, C = features.size()
B = indices_batch_cnt.shape[0]
output = features.new_zeros((M, C, nsample))
ext_module.stack_group_points_forward(
features,
features_batch_cnt,
indices,
indices_batch_cnt,
output,
b=B,
m=M,
c=C,
nsample=nsample)
ctx.for_backwards = (B, N, indices, features_batch_cnt,
indices_batch_cnt)
else:
B, nfeatures, nsample = indices.size()
_, C, N = features.size()
output = torch.cuda.FloatTensor(B, C, nfeatures, nsample)
B, nfeatures, nsample = indices.size()
_, C, N = features.size()
output = torch.cuda.FloatTensor(B, C, nfeatures, nsample)
ext_module.group_points_forward(
features,
indices,
output,
b=B,
c=C,
n=N,
npoints=nfeatures,
nsample=nsample)
ext_module.group_points_forward(
features,
indices,
output,
b=B,
c=C,
n=N,
npoints=nfeatures,
nsample=nsample)
ctx.for_backwards = (indices, N)
ctx.for_backwards = (indices, N)
return output
@staticmethod
def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, None]:
def backward(ctx, grad_out: torch.Tensor) -> Tuple:
"""
Args:
grad_out (Tensor): (B, C, npoint, nsample) tensor of the gradients
@ -224,22 +258,42 @@ class GroupingOperation(Function):
Returns:
Tensor: (B, C, N) gradient of the features.
"""
idx, N = ctx.for_backwards
if len(ctx.for_backwards) != 5:
idx, N = ctx.for_backwards
B, C, npoint, nsample = grad_out.size()
grad_features = torch.cuda.FloatTensor(B, C, N).zero_()
B, C, npoint, nsample = grad_out.size()
grad_features = torch.cuda.FloatTensor(B, C, N).zero_()
grad_out_data = grad_out.data.contiguous()
ext_module.group_points_backward(
grad_out_data,
idx,
grad_features.data,
b=B,
c=C,
n=N,
npoints=npoint,
nsample=nsample)
return grad_features, None
grad_out_data = grad_out.data.contiguous()
ext_module.group_points_backward(
grad_out_data,
idx,
grad_features.data,
b=B,
c=C,
n=N,
npoints=npoint,
nsample=nsample)
return grad_features, None
else:
B, N, idx, features_batch_cnt, idx_batch_cnt = ctx.for_backwards
M, C, nsample = grad_out.size()
grad_features = torch.cuda.FloatTensor(N, C).zero_()
grad_out_data = grad_out.data.contiguous()
ext_module.stack_group_points_backward(
grad_out_data,
idx,
idx_batch_cnt,
features_batch_cnt,
grad_features.data,
b=B,
c=C,
m=M,
n=N,
nsample=nsample)
return grad_features, None, None, None
grouping_operation = GroupingOperation.apply

Binary file not shown.

View File

@ -53,3 +53,50 @@ def test_ball_query():
[7, 7, 7, 7, 7], [0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]]]).cuda()
assert torch.all(idx == expected_idx)
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_stack_ball_query():
new_xyz = torch.tensor([[-0.0740, 1.3147, -1.3625],
[-2.2769, 2.7817, -0.2334],
[-0.4003, 2.4666, -0.5116],
[-0.0740, 1.3147, -1.3625],
[-0.0740, 1.3147, -1.3625],
[-2.0289, 2.4952, -0.1708],
[-2.0668, 6.0278, -0.4875],
[0.4066, 1.4211, -0.2947],
[-2.0289, 2.4952, -0.1708],
[-2.0289, 2.4952, -0.1708]]).cuda()
new_xyz_batch_cnt = torch.tensor([5, 5], dtype=torch.int32).cuda()
xyz = torch.tensor([[-0.0740, 1.3147, -1.3625], [0.5555, 1.0399, -1.3634],
[-0.4003, 2.4666, -0.5116], [-0.5251, 2.4379, -0.8466],
[-0.9691, 1.1418, -1.3733], [-0.2232, 0.9561, -1.3626],
[-2.2769, 2.7817, -0.2334], [-0.2822, 1.3192, -1.3645],
[0.1533, 1.5024, -1.0432], [0.4917, 1.1529, -1.3496],
[-2.0289, 2.4952, -0.1708], [-0.7188, 0.9956, -0.5096],
[-2.0668, 6.0278, -0.4875], [-1.9304, 3.3092, 0.6610],
[0.0949, 1.4332, 0.3140], [-1.2879, 2.0008, -0.7791],
[-0.7252, 0.9611, -0.6371], [0.4066, 1.4211, -0.2947],
[0.3220, 1.4447, 0.3548], [-0.9744, 2.3856,
-1.2000]]).cuda()
xyz_batch_cnt = torch.tensor([10, 10], dtype=torch.int32).cuda()
idx = ball_query(0, 0.2, 5, xyz, new_xyz, xyz_batch_cnt, new_xyz_batch_cnt)
expected_idx = torch.tensor([[0, 0, 0, 0, 0], [6, 6, 6, 6, 6],
[2, 2, 2, 2, 2], [0, 0, 0, 0, 0],
[0, 0, 0, 0, 0], [0, 0, 0, 0, 0],
[2, 2, 2, 2, 2], [7, 7, 7, 7, 7],
[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]).cuda()
assert torch.all(idx == expected_idx)
xyz = xyz.double()
new_xyz = new_xyz.double()
expected_idx = expected_idx.double()
idx = ball_query(0, 0.2, 5, xyz, new_xyz, xyz_batch_cnt, new_xyz_batch_cnt)
assert torch.all(idx == expected_idx)
xyz = xyz.half()
new_xyz = new_xyz.half()
expected_idx = expected_idx.half()
idx = ball_query(0, 0.2, 5, xyz, new_xyz, xyz_batch_cnt, new_xyz_batch_cnt)
assert torch.all(idx == expected_idx)

View File

@ -12,7 +12,7 @@ def test_grouping_points():
[0, 0, 0]],
[[0, 0, 0], [6, 6, 6], [9, 9, 9], [0, 0, 0], [0, 0, 0],
[0, 0, 0]]]).int().cuda()
festures = torch.tensor([[[
features = torch.tensor([[[
0.5798, -0.7981, -0.9280, -1.3311, 1.3687, 0.9277, -0.4164, -1.8274,
0.9268, 0.8414
],
@ -37,7 +37,7 @@ def test_grouping_points():
-1.4049, 0.4990, -0.7037, -0.9924, 0.0386
]]]).cuda()
output = grouping_operation(festures, idx)
output = grouping_operation(features, idx)
expected_output = torch.tensor([[[[0.5798, 0.5798, 0.5798],
[-1.3311, -1.3311, -1.3311],
[0.9268, 0.9268, 0.9268],
@ -75,3 +75,161 @@ def test_grouping_points():
[-0.6646, -0.6646, -0.6646],
[-0.6646, -0.6646, -0.6646]]]]).cuda()
assert torch.allclose(output, expected_output)
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_stack_grouping_points():
idx = torch.tensor([[0, 0, 0], [3, 3, 3], [8, 8, 8], [1, 1, 1], [0, 0, 0],
[2, 2, 2], [0, 0, 0], [6, 6, 6], [9, 9, 9], [0, 0, 0],
[1, 1, 1], [0, 0, 0]]).int().cuda()
features = torch.tensor([[
0.5798, -0.7981, -0.9280, -1.3311, 1.3687, 0.9277, -0.4164, -1.8274,
0.9268, 0.8414
],
[
5.4247, 1.5113, 2.3944, 1.4740, 5.0300,
5.1030, 1.9360, 2.1939, 2.1581, 3.4666
],
[
-1.6266, -1.0281, -1.0393, -1.6931, -1.3982,
-0.5732, -1.0830, -1.7561, -1.6786, -1.6967
],
[
-0.0380, -0.1880, -1.5724, 0.6905, -0.3190,
0.7798, -0.3693, -0.9457, -0.2942, -1.8527
],
[
1.1773, 1.5009, 2.6399, 5.9242, 1.0962,
2.7346, 6.0865, 1.5555, 4.3303, 2.8229
],
[
-0.6646, -0.6870, -0.1125, -0.2224, -0.3445,
-1.4049, 0.4990, -0.7037, -0.9924, 0.0386
]]).float().cuda()
features_batch_cnt = torch.tensor([3, 3]).int().cuda()
indices_batch_cnt = torch.tensor([6, 6]).int().cuda()
output = grouping_operation(features, idx, features_batch_cnt,
indices_batch_cnt)
expected_output = torch.Tensor([[[0.5798, 0.5798, 0.5798],
[-0.7981, -0.7981, -0.7981],
[-0.9280, -0.9280, -0.9280],
[-1.3311, -1.3311, -1.3311],
[1.3687, 1.3687, 1.3687],
[0.9277, 0.9277, 0.9277],
[-0.4164, -0.4164, -0.4164],
[-1.8274, -1.8274, -1.8274],
[0.9268, 0.9268, 0.9268],
[0.8414, 0.8414, 0.8414]],
[[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000]],
[[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000]],
[[5.4247, 5.4247, 5.4247],
[1.5113, 1.5113, 1.5113],
[2.3944, 2.3944, 2.3944],
[1.4740, 1.4740, 1.4740],
[5.0300, 5.0300, 5.0300],
[5.1030, 5.1030, 5.1030],
[1.9360, 1.9360, 1.9360],
[2.1939, 2.1939, 2.1939],
[2.1581, 2.1581, 2.1581],
[3.4666, 3.4666, 3.4666]],
[[0.5798, 0.5798, 0.5798],
[-0.7981, -0.7981, -0.7981],
[-0.9280, -0.9280, -0.9280],
[-1.3311, -1.3311, -1.3311],
[1.3687, 1.3687, 1.3687],
[0.9277, 0.9277, 0.9277],
[-0.4164, -0.4164, -0.4164],
[-1.8274, -1.8274, -1.8274],
[0.9268, 0.9268, 0.9268],
[0.8414, 0.8414, 0.8414]],
[[-1.6266, -1.6266, -1.6266],
[-1.0281, -1.0281, -1.0281],
[-1.0393, -1.0393, -1.0393],
[-1.6931, -1.6931, -1.6931],
[-1.3982, -1.3982, -1.3982],
[-0.5732, -0.5732, -0.5732],
[-1.0830, -1.0830, -1.0830],
[-1.7561, -1.7561, -1.7561],
[-1.6786, -1.6786, -1.6786],
[-1.6967, -1.6967, -1.6967]],
[[-0.0380, -0.0380, -0.0380],
[-0.1880, -0.1880, -0.1880],
[-1.5724, -1.5724, -1.5724],
[0.6905, 0.6905, 0.6905],
[-0.3190, -0.3190, -0.3190],
[0.7798, 0.7798, 0.7798],
[-0.3693, -0.3693, -0.3693],
[-0.9457, -0.9457, -0.9457],
[-0.2942, -0.2942, -0.2942],
[-1.8527, -1.8527, -1.8527]],
[[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000]],
[[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000]],
[[-0.0380, -0.0380, -0.0380],
[-0.1880, -0.1880, -0.1880],
[-1.5724, -1.5724, -1.5724],
[0.6905, 0.6905, 0.6905],
[-0.3190, -0.3190, -0.3190],
[0.7798, 0.7798, 0.7798],
[-0.3693, -0.3693, -0.3693],
[-0.9457, -0.9457, -0.9457],
[-0.2942, -0.2942, -0.2942],
[-1.8527, -1.8527, -1.8527]],
[[1.1773, 1.1773, 1.1773],
[1.5009, 1.5009, 1.5009],
[2.6399, 2.6399, 2.6399],
[5.9242, 5.9242, 5.9242],
[1.0962, 1.0962, 1.0962],
[2.7346, 2.7346, 2.7346],
[6.0865, 6.0865, 6.0865],
[1.5555, 1.5555, 1.5555],
[4.3303, 4.3303, 4.3303],
[2.8229, 2.8229, 2.8229]],
[[-0.0380, -0.0380, -0.0380],
[-0.1880, -0.1880, -0.1880],
[-1.5724, -1.5724, -1.5724],
[0.6905, 0.6905, 0.6905],
[-0.3190, -0.3190, -0.3190],
[0.7798, 0.7798, 0.7798],
[-0.3693, -0.3693, -0.3693],
[-0.9457, -0.9457, -0.9457],
[-0.2942, -0.2942, -0.2942],
[-1.8527, -1.8527,
-1.8527]]]).cuda().float()
assert torch.allclose(output, expected_output)