From 2d73eafec2fc405bd29a790c4cc6d0ef11b230ff Mon Sep 17 00:00:00 2001 From: pc Date: Sat, 23 Oct 2021 14:18:56 +0800 Subject: [PATCH] add mmdet3d op (#1425) Co-authored-by: zhouzaida --- mmcv/ops/assign_score_withk.py | 44 +++-- mmcv/ops/ball_query.py | 15 +- mmcv/ops/correlation.py | 43 ++++- mmcv/ops/csrc/parrots/assign_score_withk.cpp | 85 +++++++++ .../parrots/assign_score_withk_parrots.cpp | 89 +++++++++ .../csrc/parrots/assign_score_withk_pytorch.h | 19 ++ mmcv/ops/csrc/parrots/ball_query._parrots.cpp | 43 +++++ mmcv/ops/csrc/parrots/ball_query.cpp | 37 ++++ mmcv/ops/csrc/parrots/ball_query_pytorch.h | 11 ++ mmcv/ops/csrc/parrots/correlation.cpp | 87 +++++++++ mmcv/ops/csrc/parrots/correlation_parrots.cpp | 176 ++++++++++++++++++ mmcv/ops/csrc/parrots/correlation_pytorch.h | 18 ++ .../csrc/parrots/furthest_point_sample.cpp | 62 ++++++ .../parrots/furthest_point_sample_parrots.cpp | 57 ++++++ .../parrots/furthest_point_sample_pytorch.h | 14 ++ mmcv/ops/csrc/parrots/gather_points.cpp | 55 ++++++ .../csrc/parrots/gather_points_parrots.cpp | 71 +++++++ mmcv/ops/csrc/parrots/gather_points_pytorch.h | 13 ++ mmcv/ops/csrc/parrots/knn.cpp | 32 ++++ mmcv/ops/csrc/parrots/knn_parrots.cpp | 41 ++++ mmcv/ops/csrc/parrots/knn_pytorch.h | 9 + mmcv/ops/csrc/parrots/roipoint_pool3d.cpp | 60 ++++++ .../csrc/parrots/roipoint_pool3d_parrots.cpp | 31 +++ .../csrc/parrots/roipoint_pool3d_pytorch.h | 10 + mmcv/ops/csrc/parrots/three_interpolate.cpp | 61 ++++++ .../parrots/three_interpolate_parrots.cpp | 74 ++++++++ .../csrc/parrots/three_interpolate_pytorch.h | 14 ++ mmcv/ops/csrc/parrots/three_nn.cpp | 30 +++ mmcv/ops/csrc/parrots/three_nn_parrots.cpp | 35 ++++ mmcv/ops/csrc/parrots/three_nn_pytorch.h | 10 + mmcv/ops/csrc/pytorch/assign_score_withk.cpp | 20 +- mmcv/ops/csrc/pytorch/ball_query.cpp | 6 +- .../csrc/pytorch/furthest_point_sample.cpp | 10 +- mmcv/ops/csrc/pytorch/gather_points.cpp | 12 +- mmcv/ops/csrc/pytorch/knn.cpp | 5 +- mmcv/ops/csrc/pytorch/pybind.cpp | 158 
++++++++-------- mmcv/ops/csrc/pytorch/three_interpolate.cpp | 13 +- mmcv/ops/csrc/pytorch/three_nn.cpp | 6 +- mmcv/ops/furthest_point_sample.py | 18 +- mmcv/ops/gather_points.py | 17 +- mmcv/ops/knn.py | 6 +- mmcv/ops/three_interpolate.py | 8 +- mmcv/ops/three_nn.py | 6 +- 43 files changed, 1473 insertions(+), 158 deletions(-) create mode 100644 mmcv/ops/csrc/parrots/assign_score_withk.cpp create mode 100644 mmcv/ops/csrc/parrots/assign_score_withk_parrots.cpp create mode 100644 mmcv/ops/csrc/parrots/assign_score_withk_pytorch.h create mode 100644 mmcv/ops/csrc/parrots/ball_query._parrots.cpp create mode 100644 mmcv/ops/csrc/parrots/ball_query.cpp create mode 100644 mmcv/ops/csrc/parrots/ball_query_pytorch.h create mode 100644 mmcv/ops/csrc/parrots/correlation.cpp create mode 100644 mmcv/ops/csrc/parrots/correlation_parrots.cpp create mode 100644 mmcv/ops/csrc/parrots/correlation_pytorch.h create mode 100644 mmcv/ops/csrc/parrots/furthest_point_sample.cpp create mode 100644 mmcv/ops/csrc/parrots/furthest_point_sample_parrots.cpp create mode 100644 mmcv/ops/csrc/parrots/furthest_point_sample_pytorch.h create mode 100644 mmcv/ops/csrc/parrots/gather_points.cpp create mode 100644 mmcv/ops/csrc/parrots/gather_points_parrots.cpp create mode 100644 mmcv/ops/csrc/parrots/gather_points_pytorch.h create mode 100644 mmcv/ops/csrc/parrots/knn.cpp create mode 100644 mmcv/ops/csrc/parrots/knn_parrots.cpp create mode 100644 mmcv/ops/csrc/parrots/knn_pytorch.h create mode 100644 mmcv/ops/csrc/parrots/roipoint_pool3d.cpp create mode 100644 mmcv/ops/csrc/parrots/roipoint_pool3d_parrots.cpp create mode 100644 mmcv/ops/csrc/parrots/roipoint_pool3d_pytorch.h create mode 100644 mmcv/ops/csrc/parrots/three_interpolate.cpp create mode 100644 mmcv/ops/csrc/parrots/three_interpolate_parrots.cpp create mode 100644 mmcv/ops/csrc/parrots/three_interpolate_pytorch.h create mode 100644 mmcv/ops/csrc/parrots/three_nn.cpp create mode 100644 mmcv/ops/csrc/parrots/three_nn_parrots.cpp create mode 
100644 mmcv/ops/csrc/parrots/three_nn_pytorch.h diff --git a/mmcv/ops/assign_score_withk.py b/mmcv/ops/assign_score_withk.py index 6cca1cb36..4906adaa2 100644 --- a/mmcv/ops/assign_score_withk.py +++ b/mmcv/ops/assign_score_withk.py @@ -57,12 +57,19 @@ class AssignScoreWithK(Function): _, npoint, K, _ = scores.size() output = point_features.new_zeros((B, out_dim, npoint, K)) - ext_module.assign_score_withk_forward(B, N, npoint, M, K, out_dim, - agg[aggregate], - point_features.contiguous(), - center_features.contiguous(), - scores.contiguous(), - knn_idx.contiguous(), output) + ext_module.assign_score_withk_forward( + point_features.contiguous(), + center_features.contiguous(), + scores.contiguous(), + knn_idx.contiguous(), + output, + B=B, + N0=N, + N1=npoint, + M=M, + K=K, + O=out_dim, + aggregate=agg[aggregate]) ctx.save_for_backward(output, point_features, center_features, scores, knn_idx) @@ -92,15 +99,22 @@ class AssignScoreWithK(Function): grad_center_features = center_features.new_zeros(center_features.shape) grad_scores = scores.new_zeros(scores.shape) - ext_module.assign_score_withk_backward(B, N, npoint, M, K, out_dim, - agg, grad_out.contiguous(), - point_features.contiguous(), - center_features.contiguous(), - scores.contiguous(), - knn_idx.contiguous(), - grad_point_features, - grad_center_features, - grad_scores) + ext_module.assign_score_withk_backward( + grad_out.contiguous(), + point_features.contiguous(), + center_features.contiguous(), + scores.contiguous(), + knn_idx.contiguous(), + grad_point_features, + grad_center_features, + grad_scores, + B=B, + N0=N, + N1=npoint, + M=M, + K=K, + O=out_dim, + aggregate=agg) return grad_scores, grad_point_features, \ grad_center_features, None, None diff --git a/mmcv/ops/ball_query.py b/mmcv/ops/ball_query.py index f77bdc8bf..d0466847c 100644 --- a/mmcv/ops/ball_query.py +++ b/mmcv/ops/ball_query.py @@ -33,9 +33,18 @@ class BallQuery(Function): npoint = center_xyz.size(1) idx = xyz.new_zeros(B, npoint, 
sample_num, dtype=torch.int) - ext_module.ball_query_forward(B, N, npoint, min_radius, max_radius, - sample_num, center_xyz, xyz, idx) - ctx.mark_non_differentiable(idx) + ext_module.ball_query_forward( + center_xyz, + xyz, + idx, + b=B, + n=N, + m=npoint, + min_radius=min_radius, + max_radius=max_radius, + nsample=sample_num) + if torch.__version__ != 'parrots': + ctx.mark_non_differentiable(idx) return idx @staticmethod diff --git a/mmcv/ops/correlation.py b/mmcv/ops/correlation.py index f6ddedacf..86dab432c 100644 --- a/mmcv/ops/correlation.py +++ b/mmcv/ops/correlation.py @@ -39,10 +39,22 @@ class CorrelationFunction(Function): output = input1.new_zeros(output_size) - ext_module.correlation_forward(input1, input2, output, kH, kW, - patch_size, patch_size, padH, padW, - dilationH, dilationW, dilation_patchH, - dilation_patchW, dH, dW) + ext_module.correlation_forward( + input1, + input2, + output, + kH=kH, + kW=kW, + patchH=patch_size, + patchW=patch_size, + padH=padH, + padW=padW, + dilationH=dilationH, + dilationW=dilationW, + dilation_patchH=dilation_patchH, + dilation_patchW=dilation_patchW, + dH=dH, + dW=dW) return output @@ -60,11 +72,24 @@ class CorrelationFunction(Function): grad_input1 = torch.zeros_like(input1) grad_input2 = torch.zeros_like(input2) - ext_module.correlation_backward(grad_output, input1, input2, - grad_input1, grad_input2, kH, kW, - patch_size, patch_size, padH, padW, - dilationH, dilationW, dilation_patchH, - dilation_patchW, dH, dW) + ext_module.correlation_backward( + grad_output, + input1, + input2, + grad_input1, + grad_input2, + kH=kH, + kW=kW, + patchH=patch_size, + patchW=patch_size, + padH=padH, + padW=padW, + dilationH=dilationH, + dilationW=dilationW, + dilation_patchH=dilation_patchH, + dilation_patchW=dilation_patchW, + dH=dH, + dW=dW) return grad_input1, grad_input2, None, None, None, None, None, None @staticmethod diff --git a/mmcv/ops/csrc/parrots/assign_score_withk.cpp b/mmcv/ops/csrc/parrots/assign_score_withk.cpp new 
file mode 100644 index 000000000..d35fd2479 --- /dev/null +++ b/mmcv/ops/csrc/parrots/assign_score_withk.cpp @@ -0,0 +1,85 @@ +// Modified from +// https://github.com/CVMI-Lab/PAConv/tree/main/scene_seg/lib/paconv_lib/src/gpu +#include "pytorch_cpp_helper.hpp" + +#ifdef MMCV_WITH_CUDA +void AssignScoreWithKForwardCUDAKernelLauncher( + int B, int N0, int N1, int M, int K, int O, int aggregate, + const Tensor& points, const Tensor& centers, const Tensor& scores, + const Tensor& knn_idx, Tensor& output); + +void assign_score_withk_forward_cuda(int B, int N0, int N1, int M, int K, int O, + int aggregate, const Tensor& points, + const Tensor& centers, + const Tensor& scores, + const Tensor& knn_idx, Tensor& output) { + AssignScoreWithKForwardCUDAKernelLauncher( + B, N0, N1, M, K, O, aggregate, points, centers, scores, knn_idx, output); +}; + +void AssignScoreWithKBackwardCUDAKernelLauncher( + int B, int N0, int N1, int M, int K, int O, int aggregate, + const Tensor& grad_out, const Tensor& points, const Tensor& centers, + const Tensor& scores, const Tensor& knn_idx, Tensor& grad_points, + Tensor& grad_centers, Tensor& grad_scores); + +void assign_score_withk_backward_cuda( + int B, int N0, int N1, int M, int K, int O, int aggregate, + const Tensor& grad_out, const Tensor& points, const Tensor& centers, + const Tensor& scores, const Tensor& knn_idx, Tensor& grad_points, + Tensor& grad_centers, Tensor& grad_scores) { + AssignScoreWithKBackwardCUDAKernelLauncher( + B, N0, N1, M, K, O, aggregate, grad_out, points, centers, scores, knn_idx, + grad_points, grad_centers, grad_scores); +}; +#endif + +void assign_score_withk_forward(const Tensor& points, const Tensor& centers, + const Tensor& scores, const Tensor& knn_idx, + Tensor& output, int B, int N0, int N1, int M, + int K, int O, int aggregate) { + if (points.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + CHECK_CONTIGUOUS(points); + CHECK_CONTIGUOUS(centers); + CHECK_CONTIGUOUS(scores); + CHECK_CONTIGUOUS(knn_idx); + 
CHECK_CONTIGUOUS(output); + + assign_score_withk_forward_cuda(B, N0, N1, M, K, O, aggregate, points, + centers, scores, knn_idx, output); +#else + AT_ERROR("assign_score_withk is not compiled with GPU support"); +#endif + } else { + AT_ERROR("assign_score_withk is not implemented on CPU"); + } +} + +void assign_score_withk_backward(const Tensor& grad_out, const Tensor& points, + const Tensor& centers, const Tensor& scores, + const Tensor& knn_idx, Tensor& grad_points, + Tensor& grad_centers, Tensor& grad_scores, + int B, int N0, int N1, int M, int K, int O, + int aggregate) { + if (grad_points.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + CHECK_CONTIGUOUS(grad_out); + CHECK_CONTIGUOUS(scores); + CHECK_CONTIGUOUS(points); + CHECK_CONTIGUOUS(centers); + CHECK_CONTIGUOUS(knn_idx); + CHECK_CONTIGUOUS(grad_scores); + CHECK_CONTIGUOUS(grad_points); + CHECK_CONTIGUOUS(grad_centers); + + assign_score_withk_backward_cuda(B, N0, N1, M, K, O, aggregate, grad_out, + points, centers, scores, knn_idx, + grad_points, grad_centers, grad_scores); +#else + AT_ERROR("assign_score_withk is not compiled with GPU support"); +#endif + } else { + AT_ERROR("assign_score_withk is not implemented on CPU"); + } +} diff --git a/mmcv/ops/csrc/parrots/assign_score_withk_parrots.cpp b/mmcv/ops/csrc/parrots/assign_score_withk_parrots.cpp new file mode 100644 index 000000000..5729c7163 --- /dev/null +++ b/mmcv/ops/csrc/parrots/assign_score_withk_parrots.cpp @@ -0,0 +1,89 @@ +// Copyright (c) OpenMMLab. 
All rights reserved +#include +#include +#include + +#include "assign_score_withk_pytorch.h" + +using namespace parrots; + +#ifdef MMCV_WITH_CUDA +void assign_score_withk_forward_cuda_parrots(CudaContext& ctx, + const SSElement& attr, + const OperatorBase::in_list_t& ins, + OperatorBase::out_list_t& outs) { + int B, N0, N1, M, K, O, aggregate; + SSAttrs(attr) + .get("B", B) + .get("N0", N0) + .get("N1", N1) + .get("M", M) + .get("K", K) + .get("O", O) + .get("aggregate", aggregate) + .done(); + + const auto& points = buildATensor(ctx, ins[0]); + const auto& centers = buildATensor(ctx, ins[1]); + const auto& scores = buildATensor(ctx, ins[2]); + const auto& knn_idx = buildATensor(ctx, ins[3]); + + auto output = buildATensor(ctx, outs[0]); + assign_score_withk_forward(points, centers, scores, knn_idx, output, B, N0, + N1, M, K, O, aggregate); +} + +void assign_score_withk_backward_cuda_parrots( + CudaContext& ctx, const SSElement& attr, const OperatorBase::in_list_t& ins, + OperatorBase::out_list_t& outs) { + int B, N0, N1, M, K, O, aggregate; + SSAttrs(attr) + .get("B", B) + .get("N0", N0) + .get("N1", N1) + .get("M", M) + .get("K", K) + .get("O", O) + .get("aggregate", aggregate) + .done(); + + const auto& grad_out = buildATensor(ctx, ins[0]); + const auto& points = buildATensor(ctx, ins[1]); + const auto& centers = buildATensor(ctx, ins[2]); + const auto& scores = buildATensor(ctx, ins[3]); + const auto& knn_idx = buildATensor(ctx, ins[4]); + + auto grad_points = buildATensor(ctx, outs[0]); + auto grad_centers = buildATensor(ctx, outs[1]); + auto grad_scores = buildATensor(ctx, outs[2]); + assign_score_withk_backward(grad_out, points, centers, scores, knn_idx, + grad_points, grad_centers, grad_scores, B, N0, N1, + M, K, O, aggregate); +} + +PARROTS_EXTENSION_REGISTER(assign_score_withk_forward) + .attr("B") + .attr("N0") + .attr("N1") + .attr("M") + .attr("K") + .attr("O") + .attr("aggregate") + .input(4) + .output(1) + 
.apply(assign_score_withk_forward_cuda_parrots) + .done(); + +PARROTS_EXTENSION_REGISTER(assign_score_withk_backward) + .attr("B") + .attr("N0") + .attr("N1") + .attr("M") + .attr("K") + .attr("O") + .attr("aggregate") + .input(5) + .output(3) + .apply(assign_score_withk_backward_cuda_parrots) + .done(); +#endif diff --git a/mmcv/ops/csrc/parrots/assign_score_withk_pytorch.h b/mmcv/ops/csrc/parrots/assign_score_withk_pytorch.h new file mode 100644 index 000000000..660594fee --- /dev/null +++ b/mmcv/ops/csrc/parrots/assign_score_withk_pytorch.h @@ -0,0 +1,19 @@ +// Copyright (c) OpenMMLab. All rights reserved +#ifndef ASSIGN_SCORE_WITHK_PYTORCH_H +#define ASSIGN_SCORE_WITHK_PYTORCH_H +#include +using namespace at; + +void assign_score_withk_forward(const Tensor& points, const Tensor& centers, + const Tensor& scores, const Tensor& knn_idx, + Tensor& output, int B, int N0, int N1, int M, + int K, int O, int aggregate); + +void assign_score_withk_backward(const Tensor& grad_out, const Tensor& points, + const Tensor& centers, const Tensor& scores, + const Tensor& knn_idx, Tensor& grad_points, + Tensor& grad_centers, Tensor& grad_scores, + int B, int N0, int N1, int M, int K, int O, + int aggregate); + +#endif // ASSIGN_SCORE_WITHK_PYTORCH_H diff --git a/mmcv/ops/csrc/parrots/ball_query._parrots.cpp b/mmcv/ops/csrc/parrots/ball_query._parrots.cpp new file mode 100644 index 000000000..01ab9739b --- /dev/null +++ b/mmcv/ops/csrc/parrots/ball_query._parrots.cpp @@ -0,0 +1,43 @@ +// Copyright (c) OpenMMLab. 
All rights reserved +#include +#include +#include + +#include "ball_query_pytorch.h" + +using namespace parrots; + +#ifdef MMCV_WITH_CUDA +void ball_query_parrots(CudaContext& ctx, const SSElement& attr, + const OperatorBase::in_list_t& ins, + OperatorBase::out_list_t& outs) { + int b, n, m, nsample; + float min_radius, max_radius; + SSAttrs(attr) + .get("b", b) + .get("n", n) + .get("m", m) + .get("nsample", nsample) + .get("min_radius", min_radius) + .get("max_radius", max_radius) + .done(); + + const auto& center_xyz = buildATensor(ctx, ins[0]); + const auto& xyz = buildATensor(ctx, ins[1]); + auto idx = buildATensor(ctx, outs[0]); + ball_query_forward(center_xyz, xyz, idx, b, n, m, min_radius, max_radius, + nsample); +} + +PARROTS_EXTENSION_REGISTER(ball_query_forward) + .attr("b") + .attr("n") + .attr("m") + .attr("nsample") + .attr("min_radius") + .attr("max_radius") + .input(2) + .output(1) + .apply(ball_query_parrots) + .done(); +#endif diff --git a/mmcv/ops/csrc/parrots/ball_query.cpp b/mmcv/ops/csrc/parrots/ball_query.cpp new file mode 100644 index 000000000..fc2709f0d --- /dev/null +++ b/mmcv/ops/csrc/parrots/ball_query.cpp @@ -0,0 +1,37 @@ +// Modified from +// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/ball_query.cpp + +#include "pytorch_cpp_helper.hpp" + +#ifdef MMCV_WITH_CUDA +void BallQueryForwardCUDAKernelLauncher(int b, int n, int m, float min_radius, + float max_radius, int nsample, + const Tensor new_xyz, const Tensor xyz, + Tensor idx); + +void ball_query_forward_cuda(int b, int n, int m, float min_radius, + float max_radius, int nsample, + const Tensor new_xyz, const Tensor xyz, + Tensor idx) { + BallQueryForwardCUDAKernelLauncher(b, n, m, min_radius, max_radius, nsample, + new_xyz, xyz, idx); +}; +#endif + +void ball_query_forward(Tensor new_xyz_tensor, Tensor xyz_tensor, + Tensor idx_tensor, int b, int n, int m, + float min_radius, float max_radius, int nsample) { + if (new_xyz_tensor.device().is_cuda()) { 
+#ifdef MMCV_WITH_CUDA + CHECK_CUDA_INPUT(new_xyz_tensor); + CHECK_CUDA_INPUT(xyz_tensor); + + ball_query_forward_cuda(b, n, m, min_radius, max_radius, nsample, + new_xyz_tensor, xyz_tensor, idx_tensor); +#else + AT_ERROR("ball_query is not compiled with GPU support"); +#endif + } else { + AT_ERROR("ball_query is not implemented on CPU"); + } +} diff --git a/mmcv/ops/csrc/parrots/ball_query_pytorch.h b/mmcv/ops/csrc/parrots/ball_query_pytorch.h new file mode 100644 index 000000000..70026f315 --- /dev/null +++ b/mmcv/ops/csrc/parrots/ball_query_pytorch.h @@ -0,0 +1,11 @@ +// Copyright (c) OpenMMLab. All rights reserved +#ifndef BALL_QUERY_PYTORCH_H +#define BALL_QUERY_PYTORCH_H +#include +using namespace at; + +void ball_query_forward(const Tensor new_xyz, const Tensor xyz, Tensor idx, + int b, int n, int m, float min_radius, float max_radius, + int nsample); + +#endif // BALL_QUERY_PYTORCH_H diff --git a/mmcv/ops/csrc/parrots/correlation.cpp b/mmcv/ops/csrc/parrots/correlation.cpp new file mode 100644 index 000000000..c3614a500 --- /dev/null +++ b/mmcv/ops/csrc/parrots/correlation.cpp @@ -0,0 +1,87 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+#include + +#include "pytorch_cpp_helper.hpp" + +#ifdef MMCV_WITH_CUDA + +void CorrelationForwardCUDAKernelLauncher(Tensor input1, Tensor input2, + Tensor output, int kH, int kW, + int patchH, int patchW, int padH, + int padW, int dilationH, + int dilationW, int dilation_patchH, + int dilation_patchW, int dH, int dW); + +void CorrelationBackwardCUDAKernelLauncher(Tensor grad_output, Tensor input1, + Tensor input2, Tensor grad_input1, + Tensor grad_input2, int kH, int kW, + int patchH, int patchW, int padH, + int padW, int dilationH, + int dilationW, int dilation_patchH, + int dilation_patchW, int dH, int dW); + +void correlation_cuda_forward(Tensor input1, Tensor input2, Tensor output, + int kH, int kW, int patchH, int patchW, int padH, + int padW, int dilationH, int dilationW, + int dilation_patchH, int dilation_patchW, int dH, + int dW) { + CorrelationForwardCUDAKernelLauncher( + input1, input2, output, kH, kW, patchH, patchW, padH, padW, dilationH, + dilationW, dilation_patchH, dilation_patchW, dH, dW); +} + +void correlation_cuda_backward(Tensor grad_output, Tensor input1, Tensor input2, + Tensor grad_input1, Tensor grad_input2, int kH, + int kW, int patchH, int patchW, int padH, + int padW, int dilationH, int dilationW, + int dilation_patchH, int dilation_patchW, int dH, + int dW) { + CorrelationBackwardCUDAKernelLauncher( + grad_output, input1, input2, grad_input1, grad_input2, kH, kW, patchH, + patchW, padH, padW, dilationH, dilationW, dilation_patchH, + dilation_patchW, dH, dW); +} + +#endif + +void correlation_forward(Tensor input1, Tensor input2, Tensor output, int kH, + int kW, int patchH, int patchW, int padH, int padW, + int dilationH, int dilationW, int dilation_patchH, + int dilation_patchW, int dH, int dW) { + if (input1.device().is_cuda() && input2.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + CHECK_CUDA_INPUT(input1); + CHECK_CUDA_INPUT(input2); + correlation_cuda_forward(input1, input2, output, kH, kW, patchH, patchW, + padH, padW, dilationH, 
dilationW, dilation_patchH, + dilation_patchW, dH, dW); +#else + AT_ERROR("Correlation is not compiled with GPU support"); +#endif + } else { + AT_ERROR("Correlation is not implemented on CPU"); + } +} + +void correlation_backward(Tensor grad_output, Tensor input1, Tensor input2, + Tensor grad_input1, Tensor grad_input2, int kH, + int kW, int patchH, int patchW, int padH, int padW, + int dilationH, int dilationW, int dilation_patchH, + int dilation_patchW, int dH, int dW) { + if (input1.device().is_cuda() && input2.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + CHECK_CUDA_INPUT(grad_output); + CHECK_CUDA_INPUT(input1); + CHECK_CUDA_INPUT(input2); + correlation_cuda_backward(grad_output, input1, input2, grad_input1, + grad_input2, kH, kW, patchH, patchW, padH, padW, + dilationH, dilationW, dilation_patchH, + dilation_patchW, dH, dW); + +#else + AT_ERROR("Correlation is not compiled with GPU support"); +#endif + } else { + AT_ERROR("Correlation is not implemented on CPU"); + } +} diff --git a/mmcv/ops/csrc/parrots/correlation_parrots.cpp b/mmcv/ops/csrc/parrots/correlation_parrots.cpp new file mode 100644 index 000000000..b1e287d06 --- /dev/null +++ b/mmcv/ops/csrc/parrots/correlation_parrots.cpp @@ -0,0 +1,176 @@ +// Copyright (c) OpenMMLab. 
All rights reserved +#include +#include +#include + +#include "correlation_pytorch.h" + +using namespace parrots; + +#ifdef MMCV_WITH_CUDA +void correlation_forward_cuda_parrots(CudaContext& ctx, const SSElement& attr, + const OperatorBase::in_list_t& ins, + OperatorBase::out_list_t& outs) { + int kH, kW, patchH, patchW, padH, padW, dilationH, dilationW, dilation_patchH, + dilation_patchW, dH, dW; + SSAttrs(attr) + .get("kH", kH) + .get("kW", kW) + .get("patchH", patchH) + .get("patchW", patchW) + .get("padH", padH) + .get("padW", padW) + .get("dilationH", dilationH) + .get("dilationW", dilationW) + .get("dilation_patchH", dilation_patchH) + .get("dilation_patchW", dilation_patchW) + .get("dH", dH) + .get("dW", dW) + .done(); + + auto input1 = buildATensor(ctx, ins[0]); + auto input2 = buildATensor(ctx, ins[1]); + + auto output = buildATensor(ctx, outs[0]); + + correlation_forward(input1, input2, output, kH, kW, patchH, patchW, padH, + padW, dilationH, dilationW, dilation_patchH, + dilation_patchW, dH, dW); +} + +void correlation_backward_cuda_parrots(CudaContext& ctx, const SSElement& attr, + const OperatorBase::in_list_t& ins, + OperatorBase::out_list_t& outs) { + int kH, kW, patchH, patchW, padH, padW, dilationH, dilationW, dilation_patchH, + dilation_patchW, dH, dW; + SSAttrs(attr) + .get("kH", kH) + .get("kW", kW) + .get("patchH", patchH) + .get("patchW", patchW) + .get("padH", padH) + .get("padW", padW) + .get("dilationH", dilationH) + .get("dilationW", dilationW) + .get("dilation_patchH", dilation_patchH) + .get("dilation_patchW", dilation_patchW) + .get("dH", dH) + .get("dW", dW) + .done(); + + auto grad_output = buildATensor(ctx, ins[0]); + auto input1 = buildATensor(ctx, ins[1]); + auto input2 = buildATensor(ctx, ins[2]); + + auto grad_input1 = buildATensor(ctx, outs[0]); + auto grad_input2 = buildATensor(ctx, outs[1]); + + correlation_backward(grad_output, input1, input2, grad_input1, grad_input2, + kH, kW, patchH, patchW, padH, padW, dilationH, 
dilationW, + dilation_patchH, dilation_patchW, dH, dW); +} +#endif + +void correlation_forward_cpu_parrots(HostContext& ctx, const SSElement& attr, + const OperatorBase::in_list_t& ins, + OperatorBase::out_list_t& outs) { + int kH, kW, patchH, patchW, padH, padW, dilationH, dilationW, dilation_patchH, + dilation_patchW, dH, dW; + SSAttrs(attr) + .get("kH", kH) + .get("kW", kW) + .get("patchH", patchH) + .get("patchW", patchW) + .get("padH", padH) + .get("padW", padW) + .get("dilationH", dilationH) + .get("dilationW", dilationW) + .get("dilation_patchH", dilation_patchH) + .get("dilation_patchW", dilation_patchW) + .get("dH", dH) + .get("dW", dW) + .done(); + + auto input1 = buildATensor(ctx, ins[0]); + auto input2 = buildATensor(ctx, ins[1]); + + auto output = buildATensor(ctx, outs[0]); + + correlation_forward(input1, input2, output, kH, kW, patchH, patchW, padH, + padW, dilationH, dilationW, dilation_patchH, + dilation_patchW, dH, dW); +} + +void correlation_backward_cpu_parrots(HostContext& ctx, const SSElement& attr, + const OperatorBase::in_list_t& ins, + OperatorBase::out_list_t& outs) { + int kH, kW, patchH, patchW, padH, padW, dilationH, dilationW, dilation_patchH, + dilation_patchW, dH, dW; + SSAttrs(attr) + .get("kH", kH) + .get("kW", kW) + .get("patchH", patchH) + .get("patchW", patchW) + .get("padH", padH) + .get("padW", padW) + .get("dilationH", dilationH) + .get("dilationW", dilationW) + .get("dilation_patchH", dilation_patchH) + .get("dilation_patchW", dilation_patchW) + .get("dH", dH) + .get("dW", dW) + .done(); + + auto grad_output = buildATensor(ctx, ins[0]); + auto input1 = buildATensor(ctx, ins[1]); + auto input2 = buildATensor(ctx, ins[2]); + + auto grad_input1 = buildATensor(ctx, outs[0]); + auto grad_input2 = buildATensor(ctx, outs[1]); + + correlation_backward(grad_output, input1, input2, grad_input1, grad_input2, + kH, kW, patchH, patchW, padH, padW, dilationH, dilationW, + dilation_patchH, dilation_patchW, dH, dW); +} + 
+PARROTS_EXTENSION_REGISTER(correlation_forward) + .attr("kH") + .attr("kW") + .attr("patchH") + .attr("patchW") + .attr("padH") + .attr("padW") + .attr("dilationH") + .attr("dilationW") + .attr("dilation_patchH") + .attr("dilation_patchW") + .attr("dH") + .attr("dW") + .input(2) + .output(1) + .apply(correlation_forward_cpu_parrots) +#ifdef MMCV_WITH_CUDA + .apply(correlation_forward_cuda_parrots) +#endif + .done(); + +PARROTS_EXTENSION_REGISTER(correlation_backward) + .attr("kH") + .attr("kW") + .attr("patchH") + .attr("patchW") + .attr("padH") + .attr("padW") + .attr("dilationH") + .attr("dilationW") + .attr("dilation_patchH") + .attr("dilation_patchW") + .attr("dH") + .attr("dW") + .input(3) + .output(2) + .apply(correlation_backward_cpu_parrots) +#ifdef MMCV_WITH_CUDA + .apply(correlation_backward_cuda_parrots) +#endif + .done(); diff --git a/mmcv/ops/csrc/parrots/correlation_pytorch.h b/mmcv/ops/csrc/parrots/correlation_pytorch.h new file mode 100644 index 000000000..806fcaa71 --- /dev/null +++ b/mmcv/ops/csrc/parrots/correlation_pytorch.h @@ -0,0 +1,18 @@ +// Copyright (c) OpenMMLab. 
All rights reserved +#ifndef CORRELATION_PYTORCH_H +#define CORRELATION_PYTORCH_H +#include +using namespace at; + +void correlation_forward(Tensor input1, Tensor input2, Tensor output, int kH, + int kW, int patchH, int patchW, int padH, int padW, + int dilationH, int dilationW, int dilation_patchH, + int dilation_patchW, int dH, int dW); + +void correlation_backward(Tensor grad_output, Tensor input1, Tensor input2, + Tensor grad_input1, Tensor grad_input2, int kH, + int kW, int patchH, int patchW, int padH, int padW, + int dilationH, int dilationW, int dilation_patchH, + int dilation_patchW, int dH, int dW); + +#endif // CORRELATION_PYTORCH_H diff --git a/mmcv/ops/csrc/parrots/furthest_point_sample.cpp b/mmcv/ops/csrc/parrots/furthest_point_sample.cpp new file mode 100644 index 000000000..e3ec99a82 --- /dev/null +++ b/mmcv/ops/csrc/parrots/furthest_point_sample.cpp @@ -0,0 +1,62 @@ +// Modified from +// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/sampling.cpp + +#include "pytorch_cpp_helper.hpp" + +#ifdef MMCV_WITH_CUDA +void FurthestPointSamplingForwardCUDAKernelLauncher(int b, int n, int m, + const float *dataset, + float *temp, int *idxs); + +void furthest_point_sampling_forward_cuda(int b, int n, int m, + const float *dataset, float *temp, + int *idxs) { + FurthestPointSamplingForwardCUDAKernelLauncher(b, n, m, dataset, temp, idxs); +} + +void FurthestPointSamplingWithDistForwardCUDAKernelLauncher( + int b, int n, int m, const float *dataset, float *temp, int *idxs); + +void furthest_point_sampling_with_dist_forward_cuda(int b, int n, int m, + const float *dataset, + float *temp, int *idxs) { + FurthestPointSamplingWithDistForwardCUDAKernelLauncher(b, n, m, dataset, temp, + idxs); +} +#endif + +void furthest_point_sampling_forward(Tensor points_tensor, Tensor temp_tensor, + Tensor idx_tensor, int b, int n, int m) { + if (points_tensor.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + const float *points = points_tensor.data_ptr(); + 
float *temp = temp_tensor.data_ptr(); + int *idx = idx_tensor.data_ptr(); + furthest_point_sampling_forward_cuda(b, n, m, points, temp, idx); +#else + AT_ERROR("furthest_point_sampling is not compiled with GPU support"); +#endif + } else { + AT_ERROR("furthest_point_sampling is not implemented on CPU"); + } +} + +void furthest_point_sampling_with_dist_forward(Tensor points_tensor, + Tensor temp_tensor, + Tensor idx_tensor, int b, int n, + int m) { + if (points_tensor.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + const float *points = points_tensor.data(); + float *temp = temp_tensor.data(); + int *idx = idx_tensor.data(); + + furthest_point_sampling_with_dist_forward_cuda(b, n, m, points, temp, idx); +#else + AT_ERROR( + "furthest_point_sampling_with_dist is not compiled with GPU support"); +#endif + } else { + AT_ERROR("furthest_point_sampling_with_dist is not implemented on CPU"); + } +} diff --git a/mmcv/ops/csrc/parrots/furthest_point_sample_parrots.cpp b/mmcv/ops/csrc/parrots/furthest_point_sample_parrots.cpp new file mode 100644 index 000000000..483bfb243 --- /dev/null +++ b/mmcv/ops/csrc/parrots/furthest_point_sample_parrots.cpp @@ -0,0 +1,57 @@ +// Copyright (c) OpenMMLab. 
All rights reserved +#include +#include +#include + +#include "furthest_point_sample_pytorch.h" + +using namespace parrots; + +#ifdef MMCV_WITH_CUDA +void furthest_point_sample_forward_cuda_parrots( + CudaContext& ctx, const SSElement& attr, const OperatorBase::in_list_t& ins, + OperatorBase::out_list_t& outs) { + int b, n, m; + SSAttrs(attr).get("b", b).get("n", n).get("m", m).done(); + + auto points_tensor = buildATensor(ctx, ins[0]); + auto temp_tensor = buildATensor(ctx, ins[1]); + + auto idx_tensor = buildATensor(ctx, outs[0]); + + furthest_point_sampling_forward(points_tensor, temp_tensor, idx_tensor, b, n, + m); +} + +void furthest_point_sampling_with_dist_forward_cuda_parrots( + CudaContext& ctx, const SSElement& attr, const OperatorBase::in_list_t& ins, + OperatorBase::out_list_t& outs) { + int b, n, m; + SSAttrs(attr).get("b", b).get("n", n).get("m", m).done(); + + auto points_tensor = buildATensor(ctx, ins[0]); + auto temp_tensor = buildATensor(ctx, ins[1]); + + auto idx_tensor = buildATensor(ctx, outs[0]); + + furthest_point_sampling_with_dist_forward(points_tensor, temp_tensor, + idx_tensor, b, n, m); +} +PARROTS_EXTENSION_REGISTER(furthest_point_sampling_forward) + .attr("b") + .attr("n") + .attr("m") + .input(2) + .output(1) + .apply(furthest_point_sample_forward_cuda_parrots) + .done(); + +PARROTS_EXTENSION_REGISTER(furthest_point_sampling_with_dist_forward) + .attr("b") + .attr("n") + .attr("m") + .input(2) + .output(1) + .apply(furthest_point_sampling_with_dist_forward_cuda_parrots) + .done(); +#endif diff --git a/mmcv/ops/csrc/parrots/furthest_point_sample_pytorch.h b/mmcv/ops/csrc/parrots/furthest_point_sample_pytorch.h new file mode 100644 index 000000000..0325cd66e --- /dev/null +++ b/mmcv/ops/csrc/parrots/furthest_point_sample_pytorch.h @@ -0,0 +1,14 @@ +// Copyright (c) OpenMMLab. 
All rights reserved +#ifndef FURTHEST_POINT_SAMPLE_PYTORCH_H +#define FURTHEST_POINT_SAMPLE_PYTORCH_H +#include +using namespace at; + +void furthest_point_sampling_forward(Tensor points_tensor, Tensor temp_tensor, + Tensor idx_tensor, int b, int n, int m); + +void furthest_point_sampling_with_dist_forward(Tensor points_tensor, + Tensor temp_tensor, + Tensor idx_tensor, int b, int n, + int m); +#endif // FURTHEST_POINT_SAMPLE_PYTORCH_H diff --git a/mmcv/ops/csrc/parrots/gather_points.cpp b/mmcv/ops/csrc/parrots/gather_points.cpp new file mode 100644 index 000000000..3ab93b600 --- /dev/null +++ b/mmcv/ops/csrc/parrots/gather_points.cpp @@ -0,0 +1,55 @@ +#include "pytorch_cpp_helper.hpp" + +#ifdef MMCV_WITH_CUDA +void GatherPointsForwardCUDAKernelLauncher(int b, int c, int n, int npoints, + const Tensor points, + const Tensor idx, Tensor out); + +void gather_points_forward_cuda(int b, int c, int n, int npoints, + const Tensor points, const Tensor idx, + Tensor out) { + GatherPointsForwardCUDAKernelLauncher(b, c, n, npoints, points, idx, out); +}; + +void GatherPointsBackwardCUDAKernelLauncher(int b, int c, int n, int npoints, + const Tensor grad_out, + const Tensor idx, + Tensor grad_points); + +void gather_points_backward_cuda(int b, int c, int n, int npoints, + const Tensor grad_out, const Tensor idx, + Tensor grad_points) { + GatherPointsBackwardCUDAKernelLauncher(b, c, n, npoints, grad_out, idx, + grad_points); +}; +#endif + +void gather_points_forward(Tensor points_tensor, Tensor idx_tensor, + Tensor out_tensor, int b, int c, int n, + int npoints) { + if (points_tensor.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + gather_points_forward_cuda(b, c, n, npoints, points_tensor, idx_tensor, + out_tensor); +#else + AT_ERROR("gather_points is not compiled with GPU support"); +#endif + } else { + AT_ERROR("gather_points is not implemented on CPU"); + } +} + +void gather_points_backward(Tensor grad_out_tensor, Tensor idx_tensor, + Tensor grad_points_tensor, int b, int c, 
int n, + int npoints) { + if (grad_out_tensor.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + gather_points_backward_cuda(b, c, n, npoints, grad_out_tensor, idx_tensor, + grad_points_tensor); +#else + AT_ERROR("gather_points is not compiled with GPU support"); +#endif + } else { + AT_ERROR("gather_points is not implemented on CPU"); + } +} diff --git a/mmcv/ops/csrc/parrots/gather_points_parrots.cpp b/mmcv/ops/csrc/parrots/gather_points_parrots.cpp new file mode 100644 index 000000000..1d2d9e129 --- /dev/null +++ b/mmcv/ops/csrc/parrots/gather_points_parrots.cpp @@ -0,0 +1,71 @@ +// Copyright (c) OpenMMLab. All rights reserved +#include +#include +#include + +#include "gather_points_pytorch.h" + +using namespace parrots; + +#ifdef MMCV_WITH_CUDA +void gather_points_forward_cuda_parrots(CudaContext& ctx, const SSElement& attr, + const OperatorBase::in_list_t& ins, + OperatorBase::out_list_t& outs) { + int b, c, n, npoints; + SSAttrs(attr) + .get("b", b) + .get("c", c) + .get("n", n) + .get("npoints", npoints) + .done(); + + auto points_tensor = buildATensor(ctx, ins[0]); + auto idx_tensor = buildATensor(ctx, ins[1]); + + auto out_tensor = buildATensor(ctx, outs[0]); + + gather_points_forward(points_tensor, idx_tensor, out_tensor, b, c, n, + npoints); +} + +void gather_points_backward_cuda_parrots(CudaContext& ctx, + const SSElement& attr, + const OperatorBase::in_list_t& ins, + OperatorBase::out_list_t& outs) { + int b, c, n, npoints; + SSAttrs(attr) + .get("b", b) + .get("c", c) + .get("n", n) + .get("npoints", npoints) + .done(); + + auto grad_out_tensor = buildATensor(ctx, ins[0]); + auto idx_tensor = buildATensor(ctx, ins[1]); + + auto grad_points_tensor = buildATensor(ctx, outs[0]); + + gather_points_backward(grad_out_tensor, idx_tensor, grad_points_tensor, b, c, + n, npoints); +} + +PARROTS_EXTENSION_REGISTER(gather_points_forward) + .attr("b") + .attr("c") + .attr("n") + .attr("npoints") + .input(2) + .output(1) + .apply(gather_points_forward_cuda_parrots) + 
.done(); + +PARROTS_EXTENSION_REGISTER(gather_points_backward) + .attr("b") + .attr("c") + .attr("n") + .attr("npoints") + .input(2) + .output(1) + .apply(gather_points_backward_cuda_parrots) + .done(); +#endif diff --git a/mmcv/ops/csrc/parrots/gather_points_pytorch.h b/mmcv/ops/csrc/parrots/gather_points_pytorch.h new file mode 100644 index 000000000..1689ae6ad --- /dev/null +++ b/mmcv/ops/csrc/parrots/gather_points_pytorch.h @@ -0,0 +1,13 @@ +// Copyright (c) OpenMMLab. All rights reserved +#ifndef GATHER_POINTS_PYTORCH_H +#define GATHER_POINTS_PYTORCH_H +#include +using namespace at; + +void gather_points_forward(Tensor points_tensor, Tensor idx_tensor, + Tensor out_tensor, int b, int c, int n, int npoints); + +void gather_points_backward(Tensor grad_out_tensor, Tensor idx_tensor, + Tensor grad_points_tensor, int b, int c, int n, + int npoints); +#endif // GATHER_POINTS_PYTORCH_H diff --git a/mmcv/ops/csrc/parrots/knn.cpp b/mmcv/ops/csrc/parrots/knn.cpp new file mode 100644 index 000000000..55105eb01 --- /dev/null +++ b/mmcv/ops/csrc/parrots/knn.cpp @@ -0,0 +1,32 @@ +// Modified from +// https://github.com/CVMI-Lab/PAConv/tree/main/scene_seg/lib/pointops/src/knnquery_heap + +#include "pytorch_cpp_helper.hpp" + +#ifdef MMCV_WITH_CUDA +void KNNForwardCUDAKernelLauncher(int b, int n, int m, int nsample, + const Tensor xyz, const Tensor new_xyz, + Tensor idx, Tensor dist2); + +void knn_forward_cuda(int b, int n, int m, int nsample, const Tensor xyz, + const Tensor new_xyz, Tensor idx, Tensor dist2) { + KNNForwardCUDAKernelLauncher(b, n, m, nsample, xyz, new_xyz, idx, dist2); +} +#endif + +void knn_forward(Tensor xyz_tensor, Tensor new_xyz_tensor, Tensor idx_tensor, + Tensor dist2_tensor, int b, int n, int m, int nsample) { + if (new_xyz_tensor.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + CHECK_CUDA_INPUT(new_xyz_tensor); + CHECK_CUDA_INPUT(xyz_tensor); + + knn_forward_cuda(b, n, m, nsample, xyz_tensor, new_xyz_tensor, idx_tensor, + dist2_tensor); +#else + 
AT_ERROR("knn is not compiled with GPU support"); +#endif + } else { + AT_ERROR("knn is not implemented on CPU"); + } +} diff --git a/mmcv/ops/csrc/parrots/knn_parrots.cpp b/mmcv/ops/csrc/parrots/knn_parrots.cpp new file mode 100644 index 000000000..585b84644 --- /dev/null +++ b/mmcv/ops/csrc/parrots/knn_parrots.cpp @@ -0,0 +1,41 @@ +// Copyright (c) OpenMMLab. All rights reserved +#include +#include +#include + +#include "knn_pytorch.h" + +using namespace parrots; + +#ifdef MMCV_WITH_CUDA +void knn_forward_cuda_parrots(CudaContext& ctx, const SSElement& attr, + const OperatorBase::in_list_t& ins, + OperatorBase::out_list_t& outs) { + int b, n, m, nsample; + SSAttrs(attr) + .get("b", b) + .get("n", n) + .get("m", m) + .get("nsample", nsample) + .done(); + + auto xyz_tensor = buildATensor(ctx, ins[0]); + auto new_xyz_tensor = buildATensor(ctx, ins[1]); + + auto idx_tensor = buildATensor(ctx, outs[0]); + auto dist2_tensor = buildATensor(ctx, outs[1]); + + knn_forward(xyz_tensor, new_xyz_tensor, idx_tensor, dist2_tensor, b, n, m, + nsample); +} + +PARROTS_EXTENSION_REGISTER(knn_forward) + .attr("b") + .attr("n") + .attr("m") + .attr("nsample") + .input(2) + .output(2) + .apply(knn_forward_cuda_parrots) + .done(); +#endif diff --git a/mmcv/ops/csrc/parrots/knn_pytorch.h b/mmcv/ops/csrc/parrots/knn_pytorch.h new file mode 100644 index 000000000..b0875f838 --- /dev/null +++ b/mmcv/ops/csrc/parrots/knn_pytorch.h @@ -0,0 +1,9 @@ +// Copyright (c) OpenMMLab. 
All rights reserved +#ifndef KNN_PYTORCH_H +#define KNN_PYTORCH_H +#include +using namespace at; + +void knn_forward(Tensor xyz_tensor, Tensor new_xyz_tensor, Tensor idx_tensor, + Tensor dist2_tensor, int b, int n, int m, int nsample); +#endif // KNN_PYTORCH_H diff --git a/mmcv/ops/csrc/parrots/roipoint_pool3d.cpp b/mmcv/ops/csrc/parrots/roipoint_pool3d.cpp new file mode 100644 index 000000000..e9b5054e7 --- /dev/null +++ b/mmcv/ops/csrc/parrots/roipoint_pool3d.cpp @@ -0,0 +1,60 @@ +/* +Modified from +https://github.com/open-mmlab/OpenPCDet/blob/master/pcdet/ops/roipoint_pool3d/src/roipoint_pool3d.cpp +Point cloud feature pooling +Written by Shaoshuai Shi +All Rights Reserved 2018. +*/ + +#include "pytorch_cpp_helper.hpp" + +#ifdef MMCV_WITH_CUDA +void RoIPointPool3dForwardCUDAKernelLauncher( + int batch_size, int pts_num, int boxes_num, int feature_in_len, + int sampled_pts_num, const Tensor xyz, const Tensor boxes3d, + const Tensor pts_feature, Tensor pooled_features, Tensor pooled_empty_flag); + +void roipoint_pool3d_forward_cuda(int batch_size, int pts_num, int boxes_num, + int feature_in_len, int sampled_pts_num, + const Tensor xyz, const Tensor boxes3d, + const Tensor pts_feature, + Tensor pooled_features, + Tensor pooled_empty_flag) { + RoIPointPool3dForwardCUDAKernelLauncher( + batch_size, pts_num, boxes_num, feature_in_len, sampled_pts_num, xyz, + boxes3d, pts_feature, pooled_features, pooled_empty_flag); +}; +#endif + +void roipoint_pool3d_forward(Tensor xyz, Tensor boxes3d, Tensor pts_feature, + Tensor pooled_features, Tensor pooled_empty_flag) { + // params xyz: (B, N, 3) + // params boxes3d: (B, M, 7) + // params pts_feature: (B, N, C) + // params pooled_features: (B, M, 512, 3+C) + // params pooled_empty_flag: (B, M) + + if (xyz.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + CHECK_CUDA_INPUT(xyz); + CHECK_CUDA_INPUT(boxes3d); + CHECK_CUDA_INPUT(pts_feature); + CHECK_CUDA_INPUT(pooled_features); + CHECK_CUDA_INPUT(pooled_empty_flag); + + int batch_size 
= xyz.size(0); + int pts_num = xyz.size(1); + int boxes_num = boxes3d.size(1); + int feature_in_len = pts_feature.size(2); + int sampled_pts_num = pooled_features.size(2); + + roipoint_pool3d_forward_cuda(batch_size, pts_num, boxes_num, feature_in_len, + sampled_pts_num, xyz, boxes3d, pts_feature, + pooled_features, pooled_empty_flag); +#else + AT_ERROR("roipoint_pool3d is not compiled with GPU support"); +#endif + } else { + AT_ERROR("roipoint_pool3d is not implemented on CPU"); + } +} diff --git a/mmcv/ops/csrc/parrots/roipoint_pool3d_parrots.cpp b/mmcv/ops/csrc/parrots/roipoint_pool3d_parrots.cpp new file mode 100644 index 000000000..17f549849 --- /dev/null +++ b/mmcv/ops/csrc/parrots/roipoint_pool3d_parrots.cpp @@ -0,0 +1,31 @@ +// Copyright (c) OpenMMLab. All rights reserved +#include +#include +#include + +#include "roipoint_pool3d_pytorch.h" + +using namespace parrots; + +#ifdef MMCV_WITH_CUDA +void roipoint_pool3d_forward_cuda_parrots(CudaContext& ctx, + const SSElement& attr, + const OperatorBase::in_list_t& ins, + OperatorBase::out_list_t& outs) { + auto xyz = buildATensor(ctx, ins[0]); + auto boxes3d = buildATensor(ctx, ins[1]); + auto pts_feature = buildATensor(ctx, ins[2]); + + auto pooled_features = buildATensor(ctx, outs[0]); + auto pooled_empty_flag = buildATensor(ctx, outs[1]); + + roipoint_pool3d_forward(xyz, boxes3d, pts_feature, pooled_features, + pooled_empty_flag); +} + +PARROTS_EXTENSION_REGISTER(roipoint_pool3d_forward) + .input(3) + .output(2) + .apply(roipoint_pool3d_forward_cuda_parrots) + .done(); +#endif diff --git a/mmcv/ops/csrc/parrots/roipoint_pool3d_pytorch.h b/mmcv/ops/csrc/parrots/roipoint_pool3d_pytorch.h new file mode 100644 index 000000000..e5b61b0d9 --- /dev/null +++ b/mmcv/ops/csrc/parrots/roipoint_pool3d_pytorch.h @@ -0,0 +1,10 @@ +// Copyright (c) OpenMMLab. 
All rights reserved +#ifndef ROIPOINT_POOL3D_PYTORCH_H +#define ROIPOINT_POOL3D_PYTORCH_H +#include +using namespace at; + +void roipoint_pool3d_forward(Tensor xyz, Tensor boxes3d, Tensor pts_feature, + Tensor pooled_features, Tensor pooled_empty_flag); + +#endif // ROIPOINT_POOL3D_PYTORCH_H diff --git a/mmcv/ops/csrc/parrots/three_interpolate.cpp b/mmcv/ops/csrc/parrots/three_interpolate.cpp new file mode 100644 index 000000000..dbbcd995d --- /dev/null +++ b/mmcv/ops/csrc/parrots/three_interpolate.cpp @@ -0,0 +1,61 @@ +// Modified from +// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/interpolate.cpp + +#include "pytorch_cpp_helper.hpp" + +#ifdef MMCV_WITH_CUDA +void ThreeInterpolateForwardCUDAKernelLauncher(int b, int c, int m, int n, + const Tensor points, + const Tensor idx, + const Tensor weight, Tensor out); + +void three_interpolate_forward_cuda(int b, int c, int m, int n, + const Tensor points, const Tensor idx, + const Tensor weight, Tensor out) { + ThreeInterpolateForwardCUDAKernelLauncher(b, c, m, n, points, idx, weight, + out); +}; + +void ThreeInterpolateBackwardCUDAKernelLauncher(int b, int c, int n, int m, + const Tensor grad_out, + const Tensor idx, + const Tensor weight, + Tensor grad_points); + +void three_interpolate_backward_cuda(int b, int c, int n, int m, + const Tensor grad_out, const Tensor idx, + const Tensor weight, Tensor grad_points) { + ThreeInterpolateBackwardCUDAKernelLauncher(b, c, n, m, grad_out, idx, weight, + grad_points); +}; +#endif + +void three_interpolate_forward(Tensor points_tensor, Tensor idx_tensor, + Tensor weight_tensor, Tensor out_tensor, int b, + int c, int m, int n) { + if (points_tensor.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + three_interpolate_forward_cuda(b, c, m, n, points_tensor, idx_tensor, + weight_tensor, out_tensor); +#else + AT_ERROR("three_interpolate is not compiled with GPU support"); +#endif + } else { + AT_ERROR("three_interpolate is not implemented on CPU"); + } +} + 
+void three_interpolate_backward(Tensor grad_out_tensor, Tensor idx_tensor, + Tensor weight_tensor, Tensor grad_points_tensor, + int b, int c, int n, int m) { + if (grad_out_tensor.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + three_interpolate_backward_cuda(b, c, n, m, grad_out_tensor, idx_tensor, + weight_tensor, grad_points_tensor); +#else + AT_ERROR("three_interpolate is not compiled with GPU support"); +#endif + } else { + AT_ERROR("three_interpolate is not implemented on CPU"); + } +} diff --git a/mmcv/ops/csrc/parrots/three_interpolate_parrots.cpp b/mmcv/ops/csrc/parrots/three_interpolate_parrots.cpp new file mode 100644 index 000000000..a71a90fd1 --- /dev/null +++ b/mmcv/ops/csrc/parrots/three_interpolate_parrots.cpp @@ -0,0 +1,74 @@ +// Copyright (c) OpenMMLab. All rights reserved +#include +#include +#include + +#include "three_interpolate_pytorch.h" + +using namespace parrots; + +#ifdef MMCV_WITH_CUDA +void three_interpolate_forward_cuda_parrots(CudaContext& ctx, + const SSElement& attr, + const OperatorBase::in_list_t& ins, + OperatorBase::out_list_t& outs) { + int b, c, m, n; + SSAttrs(attr) + .get("b", b) + .get("c", c) + .get("m", m) + .get("n", n) + .done(); + + auto points_tensor = buildATensor(ctx, ins[0]); + auto idx_tensor = buildATensor(ctx, ins[1]); + auto weight_tensor = buildATensor(ctx, ins[2]); + + auto out_tensor = buildATensor(ctx, outs[0]); + + three_interpolate_forward(points_tensor, idx_tensor, weight_tensor, + out_tensor, b, c, m, n); +} + +void three_interpolate_backward_cuda_parrots(CudaContext& ctx, + const SSElement& attr, + const OperatorBase::in_list_t& ins, + OperatorBase::out_list_t& outs) { + int b, c, n, m; + SSAttrs(attr) + .get("b", b) + .get("c", c) + .get("n", n) + .get("m", m) + .done(); + + auto grad_out_tensor = buildATensor(ctx, ins[0]); + auto idx_tensor = buildATensor(ctx, ins[1]); + auto weight_tensor = buildATensor(ctx, ins[2]); + + auto grad_points_tensor = buildATensor(ctx, outs[0]); + + 
three_interpolate_backward(grad_out_tensor, idx_tensor, weight_tensor, + grad_points_tensor, b, c, n, m); +} + +PARROTS_EXTENSION_REGISTER(three_interpolate_forward) + .attr("b") + .attr("c") + .attr("m") + .attr("n") + .input(3) + .output(1) + .apply(three_interpolate_forward_cuda_parrots) + .done(); + +PARROTS_EXTENSION_REGISTER(three_interpolate_backward) + .attr("b") + .attr("c") + .attr("n") + .attr("m") + .input(3) + .output(1) + .apply(three_interpolate_backward_cuda_parrots) + .done(); +#endif diff --git a/mmcv/ops/csrc/parrots/three_interpolate_pytorch.h b/mmcv/ops/csrc/parrots/three_interpolate_pytorch.h new file mode 100644 index 000000000..464c6d900 --- /dev/null +++ b/mmcv/ops/csrc/parrots/three_interpolate_pytorch.h @@ -0,0 +1,14 @@ +// Copyright (c) OpenMMLab. All rights reserved +#ifndef THREE_INTERPOLATE_PYTORCH_H +#define THREE_INTERPOLATE_PYTORCH_H +#include +using namespace at; + +void three_interpolate_forward(Tensor points_tensor, Tensor idx_tensor, + Tensor weight_tensor, Tensor out_tensor, int b, + int c, int m, int n); + +void three_interpolate_backward(Tensor grad_out_tensor, Tensor idx_tensor, + Tensor weight_tensor, Tensor grad_points_tensor, + int b, int c, int n, int m); +#endif // THREE_INTERPOLATE_PYTORCH_H diff --git a/mmcv/ops/csrc/parrots/three_nn.cpp b/mmcv/ops/csrc/parrots/three_nn.cpp new file mode 100644 index 000000000..158ac0023 --- /dev/null +++ b/mmcv/ops/csrc/parrots/three_nn.cpp @@ -0,0 +1,30 @@ +// Modified from +// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/interpolate.cpp + +#include "pytorch_cpp_helper.hpp" + +#ifdef MMCV_WITH_CUDA +void ThreeNNForwardCUDAKernelLauncher(int b, int n, int m, const Tensor unknown, + const Tensor known, Tensor dist2, + Tensor idx); + +void three_nn_forward_cuda(int b, int n, int m, const Tensor unknown, + const Tensor known, Tensor dist2, Tensor idx) { + ThreeNNForwardCUDAKernelLauncher(b, n, m, unknown, known, dist2, idx); +}; +#endif + +void 
three_nn_forward(Tensor unknown_tensor, Tensor known_tensor, + Tensor dist2_tensor, Tensor idx_tensor, int b, int n, + int m) { + if (unknown_tensor.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + three_nn_forward_cuda(b, n, m, unknown_tensor, known_tensor, dist2_tensor, + idx_tensor); +#else + AT_ERROR("three_nn is not compiled with GPU support"); +#endif + } else { + AT_ERROR("three_nn is not implemented on CPU"); + } +} diff --git a/mmcv/ops/csrc/parrots/three_nn_parrots.cpp b/mmcv/ops/csrc/parrots/three_nn_parrots.cpp new file mode 100644 index 000000000..c28c7d216 --- /dev/null +++ b/mmcv/ops/csrc/parrots/three_nn_parrots.cpp @@ -0,0 +1,35 @@ +// Copyright (c) OpenMMLab. All rights reserved +#include +#include +#include + +#include "three_nn_pytorch.h" + +using namespace parrots; + +#ifdef MMCV_WITH_CUDA +void three_nn_forward_cuda_parrots(CudaContext& ctx, const SSElement& attr, + const OperatorBase::in_list_t& ins, + OperatorBase::out_list_t& outs) { + int b, n, m; + SSAttrs(attr).get("b", b).get("n", n).get("m", m).done(); + + auto unknown_tensor = buildATensor(ctx, ins[0]); + auto known_tensor = buildATensor(ctx, ins[1]); + + auto dist2_tensor = buildATensor(ctx, outs[0]); + auto idx_tensor = buildATensor(ctx, outs[1]); + + three_nn_forward(unknown_tensor, known_tensor, dist2_tensor, idx_tensor, b, n, + m); +} + +PARROTS_EXTENSION_REGISTER(three_nn_forward) + .attr("b") + .attr("n") + .attr("m") + .input(2) + .output(2) + .apply(three_nn_forward_cuda_parrots) + .done(); +#endif diff --git a/mmcv/ops/csrc/parrots/three_nn_pytorch.h b/mmcv/ops/csrc/parrots/three_nn_pytorch.h new file mode 100644 index 000000000..6574fba09 --- /dev/null +++ b/mmcv/ops/csrc/parrots/three_nn_pytorch.h @@ -0,0 +1,10 @@ +// Copyright (c) OpenMMLab. 
All rights reserved +#ifndef THREE_NN_PYTORCH_H +#define THREE_NN_PYTORCH_H +#include +using namespace at; + +void three_nn_forward(Tensor unknown_tensor, Tensor known_tensor, + Tensor dist2_tensor, Tensor idx_tensor, int b, int n, + int m); +#endif // THREE_NN_PYTORCH_H diff --git a/mmcv/ops/csrc/pytorch/assign_score_withk.cpp b/mmcv/ops/csrc/pytorch/assign_score_withk.cpp index 36bd16432..d35fd2479 100644 --- a/mmcv/ops/csrc/pytorch/assign_score_withk.cpp +++ b/mmcv/ops/csrc/pytorch/assign_score_withk.cpp @@ -34,10 +34,10 @@ void assign_score_withk_backward_cuda( }; #endif -void assign_score_withk_forward(int B, int N0, int N1, int M, int K, int O, - int aggregate, const Tensor& points, - const Tensor& centers, const Tensor& scores, - const Tensor& knn_idx, Tensor& output) { +void assign_score_withk_forward(const Tensor& points, const Tensor& centers, + const Tensor& scores, const Tensor& knn_idx, + Tensor& output, int B, int N0, int N1, int M, + int K, int O, int aggregate) { if (points.device().is_cuda()) { #ifdef MMCV_WITH_CUDA CHECK_CONTIGUOUS(points); @@ -56,12 +56,12 @@ void assign_score_withk_forward(int B, int N0, int N1, int M, int K, int O, } } -void assign_score_withk_backward(int B, int N0, int N1, int M, int K, int O, - int aggregate, const Tensor& grad_out, - const Tensor& points, const Tensor& centers, - const Tensor& scores, const Tensor& knn_idx, - Tensor& grad_points, Tensor& grad_centers, - Tensor& grad_scores) { +void assign_score_withk_backward(const Tensor& grad_out, const Tensor& points, + const Tensor& centers, const Tensor& scores, + const Tensor& knn_idx, Tensor& grad_points, + Tensor& grad_centers, Tensor& grad_scores, + int B, int N0, int N1, int M, int K, int O, + int aggregate) { if (grad_points.device().is_cuda()) { #ifdef MMCV_WITH_CUDA CHECK_CONTIGUOUS(grad_out); diff --git a/mmcv/ops/csrc/pytorch/ball_query.cpp b/mmcv/ops/csrc/pytorch/ball_query.cpp index 0a0892ba1..fc2709f0d 100644 --- a/mmcv/ops/csrc/pytorch/ball_query.cpp +++ 
b/mmcv/ops/csrc/pytorch/ball_query.cpp @@ -18,9 +18,9 @@ void ball_query_forward_cuda(int b, int n, int m, float min_radius, }; #endif -void ball_query_forward(int b, int n, int m, float min_radius, float max_radius, - int nsample, Tensor new_xyz_tensor, Tensor xyz_tensor, - Tensor idx_tensor) { +void ball_query_forward(Tensor new_xyz_tensor, Tensor xyz_tensor, + Tensor idx_tensor, int b, int n, int m, + float min_radius, float max_radius, int nsample) { if (new_xyz_tensor.device().is_cuda()) { #ifdef MMCV_WITH_CUDA CHECK_CUDA_INPUT(new_xyz_tensor); diff --git a/mmcv/ops/csrc/pytorch/furthest_point_sample.cpp b/mmcv/ops/csrc/pytorch/furthest_point_sample.cpp index a7bc060a8..e3ec99a82 100644 --- a/mmcv/ops/csrc/pytorch/furthest_point_sample.cpp +++ b/mmcv/ops/csrc/pytorch/furthest_point_sample.cpp @@ -25,8 +25,8 @@ void furthest_point_sampling_with_dist_forward_cuda(int b, int n, int m, } #endif -void furthest_point_sampling_forward(int b, int n, int m, Tensor points_tensor, - Tensor temp_tensor, Tensor idx_tensor) { +void furthest_point_sampling_forward(Tensor points_tensor, Tensor temp_tensor, + Tensor idx_tensor, int b, int n, int m) { if (points_tensor.device().is_cuda()) { #ifdef MMCV_WITH_CUDA const float *points = points_tensor.data_ptr(); @@ -41,10 +41,10 @@ void furthest_point_sampling_forward(int b, int n, int m, Tensor points_tensor, } } -void furthest_point_sampling_with_dist_forward(int b, int n, int m, - Tensor points_tensor, +void furthest_point_sampling_with_dist_forward(Tensor points_tensor, Tensor temp_tensor, - Tensor idx_tensor) { + Tensor idx_tensor, int b, int n, + int m) { if (points_tensor.device().is_cuda()) { #ifdef MMCV_WITH_CUDA const float *points = points_tensor.data(); diff --git a/mmcv/ops/csrc/pytorch/gather_points.cpp b/mmcv/ops/csrc/pytorch/gather_points.cpp index a56e933c8..3ab93b600 100644 --- a/mmcv/ops/csrc/pytorch/gather_points.cpp +++ b/mmcv/ops/csrc/pytorch/gather_points.cpp @@ -24,9 +24,9 @@ void 
gather_points_backward_cuda(int b, int c, int n, int npoints, }; #endif -void gather_points_forward(int b, int c, int n, int npoints, - Tensor points_tensor, Tensor idx_tensor, - Tensor out_tensor) { +void gather_points_forward(Tensor points_tensor, Tensor idx_tensor, + Tensor out_tensor, int b, int c, int n, + int npoints) { if (points_tensor.device().is_cuda()) { #ifdef MMCV_WITH_CUDA gather_points_forward_cuda(b, c, n, npoints, points_tensor, idx_tensor, @@ -39,9 +39,9 @@ void gather_points_forward(int b, int c, int n, int npoints, } } -void gather_points_backward(int b, int c, int n, int npoints, - Tensor grad_out_tensor, Tensor idx_tensor, - Tensor grad_points_tensor) { +void gather_points_backward(Tensor grad_out_tensor, Tensor idx_tensor, + Tensor grad_points_tensor, int b, int c, int n, + int npoints) { if (grad_out_tensor.device().is_cuda()) { #ifdef MMCV_WITH_CUDA gather_points_backward_cuda(b, c, n, npoints, grad_out_tensor, idx_tensor, diff --git a/mmcv/ops/csrc/pytorch/knn.cpp b/mmcv/ops/csrc/pytorch/knn.cpp index fbbbfc8f2..55105eb01 100644 --- a/mmcv/ops/csrc/pytorch/knn.cpp +++ b/mmcv/ops/csrc/pytorch/knn.cpp @@ -14,9 +14,8 @@ void knn_forward_cuda(int b, int n, int m, int nsample, const Tensor xyz, } #endif -void knn_forward(int b, int n, int m, int nsample, Tensor xyz_tensor, - Tensor new_xyz_tensor, Tensor idx_tensor, - Tensor dist2_tensor) { +void knn_forward(Tensor xyz_tensor, Tensor new_xyz_tensor, Tensor idx_tensor, + Tensor dist2_tensor, int b, int n, int m, int nsample) { if (new_xyz_tensor.device().is_cuda()) { #ifdef MMCV_WITH_CUDA CHECK_CUDA_INPUT(new_xyz_tensor); diff --git a/mmcv/ops/csrc/pytorch/pybind.cpp b/mmcv/ops/csrc/pytorch/pybind.cpp index 8f52e26e8..1845737f3 100644 --- a/mmcv/ops/csrc/pytorch/pybind.cpp +++ b/mmcv/ops/csrc/pytorch/pybind.cpp @@ -4,17 +4,17 @@ std::string get_compiler_version(); std::string get_compiling_cuda_version(); -void assign_score_withk_forward(int B, int N0, int N1, int M, int K, int O, - int 
aggregate, const Tensor &points, - const Tensor &centers, const Tensor &scores, - const Tensor &knn_idx, Tensor &output); +void assign_score_withk_forward(const Tensor &points, const Tensor &centers, + const Tensor &scores, const Tensor &knn_idx, + Tensor &output, int B, int N0, int N1, int M, + int K, int O, int aggregate); -void assign_score_withk_backward(int B, int N0, int N1, int M, int K, int O, - int aggregate, const Tensor &grad_out, - const Tensor &points, const Tensor &centers, - const Tensor &scores, const Tensor &knn_idx, - Tensor &grad_points, Tensor &grad_centers, - Tensor &grad_scores); +void assign_score_withk_backward(const Tensor &grad_out, const Tensor &points, + const Tensor &centers, const Tensor &scores, + const Tensor &knn_idx, Tensor &grad_points, + Tensor &grad_centers, Tensor &grad_scores, + int B, int N0, int N1, int M, int K, int O, + int aggregate); void carafe_naive_forward(Tensor features, Tensor masks, Tensor output, int kernel_size, int group_size, int scale_factor); @@ -76,13 +76,12 @@ void group_points_backward(int b, int c, int n, int npoints, int nsample, void roipoint_pool3d_forward(Tensor xyz, Tensor boxes3d, Tensor pts_feature, Tensor pooled_features, Tensor pooled_empty_flag); -void gather_points_forward(int b, int c, int n, int npoints, - Tensor points_tensor, Tensor idx_tensor, - Tensor out_tensor); +void gather_points_forward(Tensor points_tensor, Tensor idx_tensor, + Tensor out_tensor, int b, int c, int n, int npoints); -void gather_points_backward(int b, int c, int n, int npoints, - Tensor grad_out_tensor, Tensor idx_tensor, - Tensor grad_points_tensor); +void gather_points_backward(Tensor grad_out_tensor, Tensor idx_tensor, + Tensor grad_points_tensor, int b, int c, int n, + int npoints); void sigmoid_focal_loss_forward(Tensor input, Tensor target, Tensor weight, Tensor output, float gamma, float alpha); @@ -97,22 +96,23 @@ void softmax_focal_loss_backward(Tensor input, Tensor target, Tensor weight, Tensor buff, Tensor grad_input,
float gamma, float alpha); -void three_interpolate_forward(int b, int c, int m, int n, Tensor points_tensor, - Tensor idx_tensor, Tensor weight_tensor, - Tensor out_tensor); +void three_interpolate_forward(Tensor points_tensor, Tensor idx_tensor, + Tensor weight_tensor, Tensor out_tensor, int b, + int c, int m, int n); -void three_interpolate_backward(int b, int c, int n, int m, - Tensor grad_out_tensor, Tensor idx_tensor, - Tensor weight_tensor, - Tensor grad_points_tensor); +void three_interpolate_backward(Tensor grad_out_tensor, Tensor idx_tensor, + Tensor weight_tensor, Tensor grad_points_tensor, + int b, int c, int n, int m); -void three_nn_forward(int b, int n, int m, Tensor unknown_tensor, - Tensor known_tensor, Tensor dist2_tensor, - Tensor idx_tensor); +void three_nn_forward(Tensor unknown_tensor, Tensor known_tensor, + Tensor dist2_tensor, Tensor idx_tensor, int b, int n, + int m); void bbox_overlaps(const Tensor bboxes1, const Tensor bboxes2, Tensor ious, const int mode, const bool aligned, const int offset); +void knn_forward(Tensor xyz_tensor, Tensor new_xyz_tensor, Tensor idx_tensor, + Tensor dist2_tensor, int b, int n, int m, int nsample); void iou3d_boxes_overlap_bev_forward(Tensor boxes_a, Tensor boxes_b, Tensor ans_overlap); @@ -124,16 +124,13 @@ int iou3d_nms_forward(Tensor boxes, Tensor keep, float nms_overlap_thresh); int iou3d_nms_normal_forward(Tensor boxes, Tensor keep, float nms_overlap_thresh); -void knn_forward(int b, int n, int m, int nsample, Tensor xyz_tensor, - Tensor new_xyz_tensor, Tensor idx_tensor, Tensor dist2_tensor); +void furthest_point_sampling_forward(Tensor points_tensor, Tensor temp_tensor, + Tensor idx_tensor, int b, int n, int m); -void furthest_point_sampling_forward(int b, int n, int m, Tensor points_tensor, - Tensor temp_tensor, Tensor idx_tensor); - -void furthest_point_sampling_with_dist_forward(int b, int n, int m, - Tensor points_tensor, +void furthest_point_sampling_with_dist_forward(Tensor points_tensor, Tensor 
temp_tensor, - Tensor idx_tensor); + Tensor idx_tensor, int b, int n, + int m); void masked_im2col_forward(const Tensor im, const Tensor mask_h_idx, const Tensor mask_w_idx, Tensor col, @@ -238,9 +235,9 @@ void tin_shift_forward(Tensor input, Tensor shift, Tensor output); void tin_shift_backward(Tensor grad_output, Tensor shift, Tensor grad_input); -void ball_query_forward(int b, int n, int m, float min_radius, float max_radius, - int nsample, Tensor new_xyz_tensor, Tensor xyz_tensor, - Tensor idx_tensor); +void ball_query_forward(Tensor new_xyz_tensor, Tensor xyz_tensor, + Tensor idx_tensor, int b, int n, int m, + float min_radius, float max_radius, int nsample); Tensor bottom_pool_forward(Tensor input); @@ -352,32 +349,31 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("empty"), py::arg("act"), py::arg("grad"), py::arg("alpha"), py::arg("scale")); m.def("gather_points_forward", &gather_points_forward, - "gather_points_forward", py::arg("b"), py::arg("c"), py::arg("n"), - py::arg("npoints"), py::arg("points_tensor"), py::arg("idx_tensor"), - py::arg("out_tensor")); + "gather_points_forward", py::arg("points_tensor"), + py::arg("idx_tensor"), py::arg("out_tensor"), py::arg("b"), + py::arg("c"), py::arg("n"), py::arg("npoints")); m.def("gather_points_backward", &gather_points_backward, - "gather_points_backward", py::arg("b"), py::arg("c"), py::arg("n"), - py::arg("npoints"), py::arg("grad_out_tensor"), py::arg("idx_tensor"), - py::arg("grad_points_tensor")); + "gather_points_backward", py::arg("grad_out_tensor"), + py::arg("idx_tensor"), py::arg("grad_points_tensor"), py::arg("b"), + py::arg("c"), py::arg("n"), py::arg("npoints")); m.def("get_compiler_version", &get_compiler_version, "get_compiler_version"); m.def("get_compiling_cuda_version", &get_compiling_cuda_version, "get_compiling_cuda_version"); m.def("assign_score_withk_forward", &assign_score_withk_forward, - "assign_score_withk_forward", py::arg("B"), py::arg("N0"), - py::arg("N1"), py::arg("M"), 
py::arg("K"), py::arg("O"), - py::arg("aggregate"), py::arg("points"), py::arg("centers"), - py::arg("scores"), py::arg("knn_idx"), py::arg("output")); + "assign_score_withk_forward", py::arg("points"), py::arg("centers"), + py::arg("scores"), py::arg("knn_idx"), py::arg("output"), py::arg("B"), + py::arg("N0"), py::arg("N1"), py::arg("M"), py::arg("K"), py::arg("O"), + py::arg("aggregate")); m.def("assign_score_withk_backward", &assign_score_withk_backward, - "assign_score_withk_backward", py::arg("B"), py::arg("N0"), - py::arg("N1"), py::arg("M"), py::arg("K"), py::arg("O"), - py::arg("aggregate"), py::arg("grad_out"), py::arg("points"), + "assign_score_withk_backward", py::arg("grad_out"), py::arg("points"), py::arg("centers"), py::arg("scores"), py::arg("knn_idx"), - py::arg("grad_points"), py::arg("grad_centers"), - py::arg("grad_scores")); - m.def("knn_forward", &knn_forward, "knn_forward", py::arg("b"), py::arg("n"), - py::arg("m"), py::arg("nsample"), py::arg("xyz_tensor"), + py::arg("grad_points"), py::arg("grad_centers"), py::arg("grad_scores"), + py::arg("B"), py::arg("N0"), py::arg("N1"), py::arg("M"), py::arg("K"), + py::arg("O"), py::arg("aggregate")); + m.def("knn_forward", &knn_forward, "knn_forward", py::arg("xyz_tensor"), py::arg("new_xyz_tensor"), py::arg("idx_tensor"), - py::arg("dist2_tensor")); + py::arg("dist2_tensor"), py::arg("b"), py::arg("n"), py::arg("m"), + py::arg("nsample")); m.def("carafe_naive_forward", &carafe_naive_forward, "carafe_naive_forward", py::arg("features"), py::arg("masks"), py::arg("output"), py::arg("kernel_size"), py::arg("group_size"), py::arg("scale_factor")); @@ -447,17 +443,18 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("weight"), py::arg("buff"), py::arg("grad_input"), py::arg("gamma"), py::arg("alpha")); m.def("three_interpolate_forward", &three_interpolate_forward, - "three_interpolate_forward", py::arg("b"), py::arg("c"), py::arg("m"), - py::arg("n"), py::arg("points_tensor"), py::arg("idx_tensor"), 
- py::arg("weight_tensor"), py::arg("out_tensor")); + "three_interpolate_forward", py::arg("points_tensor"), + py::arg("idx_tensor"), py::arg("weight_tensor"), py::arg("out_tensor"), + py::arg("b"), py::arg("c"), py::arg("m"), py::arg("n")); m.def("three_interpolate_backward", &three_interpolate_backward, - "three_interpolate_backward", py::arg("b"), py::arg("c"), py::arg("n"), - py::arg("m"), py::arg("grad_out_tensor"), py::arg("idx_tensor"), - py::arg("weight_tensor"), py::arg("grad_points_tensor")); - m.def("three_nn_forward", &three_nn_forward, "three_nn_forward", py::arg("b"), - py::arg("n"), py::arg("m"), py::arg("unknown_tensor"), - py::arg("known_tensor"), py::arg("dist2_tensor"), - py::arg("idx_tensor")); + "three_interpolate_backward", py::arg("grad_out_tensor"), + py::arg("idx_tensor"), py::arg("weight_tensor"), + py::arg("grad_points_tensor"), py::arg("b"), py::arg("c"), py::arg("n"), + py::arg("m")); + m.def("three_nn_forward", &three_nn_forward, "three_nn_forward", + py::arg("unknown_tensor"), py::arg("known_tensor"), + py::arg("dist2_tensor"), py::arg("idx_tensor"), py::arg("b"), + py::arg("n"), py::arg("m")); m.def("bbox_overlaps", &bbox_overlaps, "bbox_overlaps", py::arg("bboxes1"), py::arg("bboxes2"), py::arg("ious"), py::arg("mode"), py::arg("aligned"), py::arg("offset")); @@ -485,14 +482,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "iou3d_nms_normal_forward", py::arg("boxes"), py::arg("keep"), py::arg("nms_overlap_thresh")); m.def("furthest_point_sampling_forward", &furthest_point_sampling_forward, - "furthest_point_sampling_forward", py::arg("b"), py::arg("n"), - py::arg("m"), py::arg("points_tensor"), py::arg("temp_tensor"), - py::arg("idx_tensor")); + "furthest_point_sampling_forward", py::arg("points_tensor"), + py::arg("temp_tensor"), py::arg("idx_tensor"), py::arg("b"), + py::arg("n"), py::arg("m")); m.def("furthest_point_sampling_with_dist_forward", &furthest_point_sampling_with_dist_forward, - 
"furthest_point_sampling_with_dist_forward", py::arg("b"), py::arg("n"), - py::arg("m"), py::arg("points_tensor"), py::arg("temp_tensor"), - py::arg("idx_tensor")); + "furthest_point_sampling_with_dist_forward", py::arg("points_tensor"), + py::arg("temp_tensor"), py::arg("idx_tensor"), py::arg("b"), + py::arg("n"), py::arg("m")); m.def("masked_im2col_forward", &masked_im2col_forward, "masked_im2col_forward", py::arg("im"), py::arg("mask_h_idx"), py::arg("mask_w_idx"), py::arg("col"), py::arg("kernel_h"), @@ -609,9 +606,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("scores"), py::arg("order"), py::arg("dets_sorted"), py::arg("iou_threshold"), py::arg("multi_label")); m.def("ball_query_forward", &ball_query_forward, "ball_query_forward", + py::arg("new_xyz_tensor"), py::arg("xyz_tensor"), py::arg("idx_tensor"), py::arg("b"), py::arg("n"), py::arg("m"), py::arg("min_radius"), - py::arg("max_radius"), py::arg("nsample"), py::arg("new_xyz_tensor"), - py::arg("xyz_tensor"), py::arg("idx_tensor")); + py::arg("max_radius"), py::arg("nsample")); m.def("roi_align_rotated_forward", &roi_align_rotated_forward, "roi_align_rotated forward", py::arg("input"), py::arg("rois"), py::arg("output"), py::arg("pooled_height"), py::arg("pooled_width"), @@ -657,6 +654,19 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "backward function of border_align", py::arg("grad_output"), py::arg("boxes"), py::arg("argmax_idx"), py::arg("grad_input"), py::arg("pool_size")); + m.def("correlation_forward", &correlation_forward, "Correlation forward", + py::arg("input1"), py::arg("input2"), py::arg("output"), py::arg("kH"), + py::arg("kW"), py::arg("patchH"), py::arg("patchW"), py::arg("padH"), + py::arg("padW"), py::arg("dilationH"), py::arg("dilationW"), + py::arg("dilation_patchH"), py::arg("dilation_patchW"), py::arg("dH"), + py::arg("dW")); + m.def("correlation_backward", &correlation_backward, "Correlation backward", + py::arg("grad_output"), py::arg("input1"), py::arg("input2"), + 
py::arg("grad_input1"), py::arg("grad_input2"), py::arg("kH"), + py::arg("kW"), py::arg("patchH"), py::arg("patchW"), py::arg("padH"), + py::arg("padW"), py::arg("dilationH"), py::arg("dilationW"), + py::arg("dilation_patchH"), py::arg("dilation_patchW"), py::arg("dH"), + py::arg("dW")); m.def("points_in_boxes_cpu_forward", &points_in_boxes_cpu_forward, "points_in_boxes_cpu_forward", py::arg("boxes_tensor"), py::arg("pts_tensor"), py::arg("pts_indices_tensor")); @@ -674,6 +684,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "roiaware_pool3d_backward", py::arg("pts_idx_of_voxels"), py::arg("argmax"), py::arg("grad_out"), py::arg("grad_in"), py::arg("pool_method")); - m.def("correlation_forward", &correlation_forward, "Correlation forward"); - m.def("correlation_backward", &correlation_backward, "Correlation backward"); } diff --git a/mmcv/ops/csrc/pytorch/three_interpolate.cpp b/mmcv/ops/csrc/pytorch/three_interpolate.cpp index 71a7e09ce..dbbcd995d 100644 --- a/mmcv/ops/csrc/pytorch/three_interpolate.cpp +++ b/mmcv/ops/csrc/pytorch/three_interpolate.cpp @@ -30,9 +30,9 @@ void three_interpolate_backward_cuda(int b, int c, int n, int m, }; #endif -void three_interpolate_forward(int b, int c, int m, int n, Tensor points_tensor, - Tensor idx_tensor, Tensor weight_tensor, - Tensor out_tensor) { +void three_interpolate_forward(Tensor points_tensor, Tensor idx_tensor, + Tensor weight_tensor, Tensor out_tensor, int b, + int c, int m, int n) { if (points_tensor.device().is_cuda()) { #ifdef MMCV_WITH_CUDA three_interpolate_forward_cuda(b, c, m, n, points_tensor, idx_tensor, @@ -45,10 +45,9 @@ void three_interpolate_forward(int b, int c, int m, int n, Tensor points_tensor, } } -void three_interpolate_backward(int b, int c, int n, int m, - Tensor grad_out_tensor, Tensor idx_tensor, - Tensor weight_tensor, - Tensor grad_points_tensor) { +void three_interpolate_backward(Tensor grad_out_tensor, Tensor idx_tensor, + Tensor weight_tensor, Tensor grad_points_tensor, + int b, int c, 
int n, int m) { if (grad_out_tensor.device().is_cuda()) { #ifdef MMCV_WITH_CUDA three_interpolate_backward_cuda(b, c, n, m, grad_out_tensor, idx_tensor, diff --git a/mmcv/ops/csrc/pytorch/three_nn.cpp b/mmcv/ops/csrc/pytorch/three_nn.cpp index cba70746c..158ac0023 100644 --- a/mmcv/ops/csrc/pytorch/three_nn.cpp +++ b/mmcv/ops/csrc/pytorch/three_nn.cpp @@ -14,9 +14,9 @@ void three_nn_forward_cuda(int b, int n, int m, const Tensor unknown, }; #endif -void three_nn_forward(int b, int n, int m, Tensor unknown_tensor, - Tensor known_tensor, Tensor dist2_tensor, - Tensor idx_tensor) { +void three_nn_forward(Tensor unknown_tensor, Tensor known_tensor, + Tensor dist2_tensor, Tensor idx_tensor, int b, int n, + int m) { if (unknown_tensor.device().is_cuda()) { #ifdef MMCV_WITH_CUDA three_nn_forward_cuda(b, n, m, unknown_tensor, known_tensor, dist2_tensor, diff --git a/mmcv/ops/furthest_point_sample.py b/mmcv/ops/furthest_point_sample.py index 11cf5fbf5..374b7a878 100644 --- a/mmcv/ops/furthest_point_sample.py +++ b/mmcv/ops/furthest_point_sample.py @@ -30,9 +30,16 @@ class FurthestPointSampling(Function): output = torch.cuda.IntTensor(B, num_points) temp = torch.cuda.FloatTensor(B, N).fill_(1e10) - ext_module.furthest_point_sampling_forward(B, N, num_points, - points_xyz, temp, output) - ctx.mark_non_differentiable(output) + ext_module.furthest_point_sampling_forward( + points_xyz, + temp, + output, + b=B, + n=N, + m=num_points, + ) + if torch.__version__ != 'parrots': + ctx.mark_non_differentiable(output) return output @staticmethod @@ -62,8 +69,9 @@ class FurthestPointSamplingWithDist(Function): temp = points_dist.new_zeros([B, N]).fill_(1e10) ext_module.furthest_point_sampling_with_dist_forward( - B, N, num_points, points_dist, temp, output) - ctx.mark_non_differentiable(output) + points_dist, temp, output, b=B, n=N, m=num_points) + if torch.__version__ != 'parrots': + ctx.mark_non_differentiable(output) return output @staticmethod diff --git a/mmcv/ops/gather_points.py 
b/mmcv/ops/gather_points.py index 67c6c3f8e..f52f1677d 100644 --- a/mmcv/ops/gather_points.py +++ b/mmcv/ops/gather_points.py @@ -28,11 +28,12 @@ class GatherPoints(Function): _, C, N = features.size() output = torch.cuda.FloatTensor(B, C, npoint) - ext_module.gather_points_forward(B, C, N, npoint, features, indices, - output) + ext_module.gather_points_forward( + features, indices, output, b=B, c=C, n=N, npoints=npoint) ctx.for_backwards = (indices, C, N) - ctx.mark_non_differentiable(indices) + if torch.__version__ != 'parrots': + ctx.mark_non_differentiable(indices) return output @staticmethod @@ -42,8 +43,14 @@ class GatherPoints(Function): grad_features = torch.cuda.FloatTensor(B, C, N).zero_() grad_out_data = grad_out.data.contiguous() - ext_module.gather_points_backward(B, C, N, npoint, grad_out_data, idx, - grad_features.data) + ext_module.gather_points_backward( + grad_out_data, + idx, + grad_features.data, + b=B, + c=C, + n=N, + npoints=npoint) return grad_features, None diff --git a/mmcv/ops/knn.py b/mmcv/ops/knn.py index e3428a91c..f33578503 100644 --- a/mmcv/ops/knn.py +++ b/mmcv/ops/knn.py @@ -61,10 +61,12 @@ class KNN(Function): idx = center_xyz.new_zeros((B, npoint, k)).int() dist2 = center_xyz.new_zeros((B, npoint, k)).float() - ext_module.knn_forward(B, N, npoint, k, xyz, center_xyz, idx, dist2) + ext_module.knn_forward( + xyz, center_xyz, idx, dist2, b=B, n=N, m=npoint, nsample=k) # idx shape to [B, k, npoint] idx = idx.transpose(2, 1).contiguous() - ctx.mark_non_differentiable(idx) + if torch.__version__ != 'parrots': + ctx.mark_non_differentiable(idx) return idx @staticmethod diff --git a/mmcv/ops/three_interpolate.py b/mmcv/ops/three_interpolate.py index 48aa8c5a3..203f47f05 100644 --- a/mmcv/ops/three_interpolate.py +++ b/mmcv/ops/three_interpolate.py @@ -39,8 +39,8 @@ class ThreeInterpolate(Function): ctx.three_interpolate_for_backward = (indices, weight, m) output = torch.cuda.FloatTensor(B, c, n) - ext_module.three_interpolate_forward(B, 
c, m, n, features, indices, - weight, output) + ext_module.three_interpolate_forward( + features, indices, weight, output, b=B, c=c, m=m, n=n) return output @staticmethod @@ -60,8 +60,8 @@ class ThreeInterpolate(Function): grad_features = torch.cuda.FloatTensor(B, c, m).zero_() grad_out_data = grad_out.data.contiguous() - ext_module.three_interpolate_backward(B, c, n, m, grad_out_data, idx, - weight, grad_features.data) + ext_module.three_interpolate_backward( + grad_out_data, idx, weight, grad_features.data, b=B, c=c, n=n, m=m) return grad_features, None, None diff --git a/mmcv/ops/three_nn.py b/mmcv/ops/three_nn.py index 459f611a4..2b01047a1 100644 --- a/mmcv/ops/three_nn.py +++ b/mmcv/ops/three_nn.py @@ -37,9 +37,9 @@ class ThreeNN(Function): dist2 = torch.cuda.FloatTensor(B, N, 3) idx = torch.cuda.IntTensor(B, N, 3) - ext_module.three_nn_forward(B, N, m, target, source, dist2, idx) - - ctx.mark_non_differentiable(idx) + ext_module.three_nn_forward(target, source, dist2, idx, b=B, n=N, m=m) + if torch.__version__ != 'parrots': + ctx.mark_non_differentiable(idx) return torch.sqrt(dist2), idx