mirror of https://github.com/open-mmlab/mmcv.git
parent
75cae78c55
commit
2d73eafec2
|
@ -57,12 +57,19 @@ class AssignScoreWithK(Function):
|
|||
_, npoint, K, _ = scores.size()
|
||||
|
||||
output = point_features.new_zeros((B, out_dim, npoint, K))
|
||||
ext_module.assign_score_withk_forward(B, N, npoint, M, K, out_dim,
|
||||
agg[aggregate],
|
||||
point_features.contiguous(),
|
||||
center_features.contiguous(),
|
||||
scores.contiguous(),
|
||||
knn_idx.contiguous(), output)
|
||||
ext_module.assign_score_withk_forward(
|
||||
point_features.contiguous(),
|
||||
center_features.contiguous(),
|
||||
scores.contiguous(),
|
||||
knn_idx.contiguous(),
|
||||
output,
|
||||
B=B,
|
||||
N0=N,
|
||||
N1=npoint,
|
||||
M=M,
|
||||
K=K,
|
||||
O=out_dim,
|
||||
aggregate=agg[aggregate])
|
||||
|
||||
ctx.save_for_backward(output, point_features, center_features, scores,
|
||||
knn_idx)
|
||||
|
@ -92,15 +99,22 @@ class AssignScoreWithK(Function):
|
|||
grad_center_features = center_features.new_zeros(center_features.shape)
|
||||
grad_scores = scores.new_zeros(scores.shape)
|
||||
|
||||
ext_module.assign_score_withk_backward(B, N, npoint, M, K, out_dim,
|
||||
agg, grad_out.contiguous(),
|
||||
point_features.contiguous(),
|
||||
center_features.contiguous(),
|
||||
scores.contiguous(),
|
||||
knn_idx.contiguous(),
|
||||
grad_point_features,
|
||||
grad_center_features,
|
||||
grad_scores)
|
||||
ext_module.assign_score_withk_backward(
|
||||
grad_out.contiguous(),
|
||||
point_features.contiguous(),
|
||||
center_features.contiguous(),
|
||||
scores.contiguous(),
|
||||
knn_idx.contiguous(),
|
||||
grad_point_features,
|
||||
grad_center_features,
|
||||
grad_scores,
|
||||
B=B,
|
||||
N0=N,
|
||||
N1=npoint,
|
||||
M=M,
|
||||
K=K,
|
||||
O=out_dim,
|
||||
aggregate=agg)
|
||||
|
||||
return grad_scores, grad_point_features, \
|
||||
grad_center_features, None, None
|
||||
|
|
|
@ -33,9 +33,18 @@ class BallQuery(Function):
|
|||
npoint = center_xyz.size(1)
|
||||
idx = xyz.new_zeros(B, npoint, sample_num, dtype=torch.int)
|
||||
|
||||
ext_module.ball_query_forward(B, N, npoint, min_radius, max_radius,
|
||||
sample_num, center_xyz, xyz, idx)
|
||||
ctx.mark_non_differentiable(idx)
|
||||
ext_module.ball_query_forward(
|
||||
center_xyz,
|
||||
xyz,
|
||||
idx,
|
||||
b=B,
|
||||
n=N,
|
||||
m=npoint,
|
||||
min_radius=min_radius,
|
||||
max_radius=max_radius,
|
||||
nsample=sample_num)
|
||||
if torch.__version__ != 'parrots':
|
||||
ctx.mark_non_differentiable(idx)
|
||||
return idx
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -39,10 +39,22 @@ class CorrelationFunction(Function):
|
|||
|
||||
output = input1.new_zeros(output_size)
|
||||
|
||||
ext_module.correlation_forward(input1, input2, output, kH, kW,
|
||||
patch_size, patch_size, padH, padW,
|
||||
dilationH, dilationW, dilation_patchH,
|
||||
dilation_patchW, dH, dW)
|
||||
ext_module.correlation_forward(
|
||||
input1,
|
||||
input2,
|
||||
output,
|
||||
kH=kH,
|
||||
kW=kW,
|
||||
patchH=patch_size,
|
||||
patchW=patch_size,
|
||||
padH=padH,
|
||||
padW=padW,
|
||||
dilationH=dilationH,
|
||||
dilationW=dilationW,
|
||||
dilation_patchH=dilation_patchH,
|
||||
dilation_patchW=dilation_patchW,
|
||||
dH=dH,
|
||||
dW=dW)
|
||||
|
||||
return output
|
||||
|
||||
|
@ -60,11 +72,24 @@ class CorrelationFunction(Function):
|
|||
grad_input1 = torch.zeros_like(input1)
|
||||
grad_input2 = torch.zeros_like(input2)
|
||||
|
||||
ext_module.correlation_backward(grad_output, input1, input2,
|
||||
grad_input1, grad_input2, kH, kW,
|
||||
patch_size, patch_size, padH, padW,
|
||||
dilationH, dilationW, dilation_patchH,
|
||||
dilation_patchW, dH, dW)
|
||||
ext_module.correlation_backward(
|
||||
grad_output,
|
||||
input1,
|
||||
input2,
|
||||
grad_input1,
|
||||
grad_input2,
|
||||
kH=kH,
|
||||
kW=kW,
|
||||
patchH=patch_size,
|
||||
patchW=patch_size,
|
||||
padH=padH,
|
||||
padW=padW,
|
||||
dilationH=dilationH,
|
||||
dilationW=dilationW,
|
||||
dilation_patchH=dilation_patchH,
|
||||
dilation_patchW=dilation_patchW,
|
||||
dH=dH,
|
||||
dW=dW)
|
||||
return grad_input1, grad_input2, None, None, None, None, None, None
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -0,0 +1,85 @@
|
|||
// Modified from
|
||||
// https://github.com/CVMI-Lab/PAConv/tree/main/scene_seg/lib/paconv_lib/src/gpu
|
||||
#include "pytorch_cpp_helper.hpp"
|
||||
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
void AssignScoreWithKForwardCUDAKernelLauncher(
|
||||
int B, int N0, int N1, int M, int K, int O, int aggregate,
|
||||
const Tensor& points, const Tensor& centers, const Tensor& scores,
|
||||
const Tensor& knn_idx, Tensor& output);
|
||||
|
||||
void assign_score_withk_forward_cuda(int B, int N0, int N1, int M, int K, int O,
|
||||
int aggregate, const Tensor& points,
|
||||
const Tensor& centers,
|
||||
const Tensor& scores,
|
||||
const Tensor& knn_idx, Tensor& output) {
|
||||
AssignScoreWithKForwardCUDAKernelLauncher(
|
||||
B, N0, N1, M, K, O, aggregate, points, centers, scores, knn_idx, output);
|
||||
};
|
||||
|
||||
void AssignScoreWithKBackwardCUDAKernelLauncher(
|
||||
int B, int N0, int N1, int M, int K, int O, int aggregate,
|
||||
const Tensor& grad_out, const Tensor& points, const Tensor& centers,
|
||||
const Tensor& scores, const Tensor& knn_idx, Tensor& grad_points,
|
||||
Tensor& grad_centers, Tensor& grad_scores);
|
||||
|
||||
void assign_score_withk_backward_cuda(
|
||||
int B, int N0, int N1, int M, int K, int O, int aggregate,
|
||||
const Tensor& grad_out, const Tensor& points, const Tensor& centers,
|
||||
const Tensor& scores, const Tensor& knn_idx, Tensor& grad_points,
|
||||
Tensor& grad_centers, Tensor& grad_scores) {
|
||||
AssignScoreWithKBackwardCUDAKernelLauncher(
|
||||
B, N0, N1, M, K, O, aggregate, grad_out, points, centers, scores, knn_idx,
|
||||
grad_points, grad_centers, grad_scores);
|
||||
};
|
||||
#endif
|
||||
|
||||
void assign_score_withk_forward(const Tensor& points, const Tensor& centers,
|
||||
const Tensor& scores, const Tensor& knn_idx,
|
||||
Tensor& output, int B, int N0, int N1, int M,
|
||||
int K, int O, int aggregate) {
|
||||
if (points.device().is_cuda()) {
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
CHECK_CONTIGUOUS(points);
|
||||
CHECK_CONTIGUOUS(centers);
|
||||
CHECK_CONTIGUOUS(scores);
|
||||
CHECK_CONTIGUOUS(knn_idx);
|
||||
CHECK_CONTIGUOUS(output);
|
||||
|
||||
assign_score_withk_forward_cuda(B, N0, N1, M, K, O, aggregate, points,
|
||||
centers, scores, knn_idx, output);
|
||||
#else
|
||||
AT_ERROR("assign_score_withk is not compiled with GPU support");
|
||||
#endif
|
||||
} else {
|
||||
AT_ERROR("assign_score_withk is not implemented on CPU");
|
||||
}
|
||||
}
|
||||
|
||||
void assign_score_withk_backward(const Tensor& grad_out, const Tensor& points,
|
||||
const Tensor& centers, const Tensor& scores,
|
||||
const Tensor& knn_idx, Tensor& grad_points,
|
||||
Tensor& grad_centers, Tensor& grad_scores,
|
||||
int B, int N0, int N1, int M, int K, int O,
|
||||
int aggregate) {
|
||||
if (grad_points.device().is_cuda()) {
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
CHECK_CONTIGUOUS(grad_out);
|
||||
CHECK_CONTIGUOUS(scores);
|
||||
CHECK_CONTIGUOUS(points);
|
||||
CHECK_CONTIGUOUS(centers);
|
||||
CHECK_CONTIGUOUS(knn_idx);
|
||||
CHECK_CONTIGUOUS(grad_scores);
|
||||
CHECK_CONTIGUOUS(grad_points);
|
||||
CHECK_CONTIGUOUS(grad_centers);
|
||||
|
||||
assign_score_withk_backward_cuda(B, N0, N1, M, K, O, aggregate, grad_out,
|
||||
points, centers, scores, knn_idx,
|
||||
grad_points, grad_centers, grad_scores);
|
||||
#else
|
||||
AT_ERROR("assign_score_withk is not compiled with GPU support");
|
||||
#endif
|
||||
} else {
|
||||
AT_ERROR("assign_score_withk is not implemented on CPU");
|
||||
}
|
||||
}
|
|
@ -0,0 +1,89 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved
|
||||
#include <parrots/compute/aten.hpp>
|
||||
#include <parrots/extension.hpp>
|
||||
#include <parrots/foundation/ssattrs.hpp>
|
||||
|
||||
#include "assign_score_withk_pytorch.h"
|
||||
|
||||
using namespace parrots;
|
||||
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
// Parrots CUDA entry point for assign_score_withk_forward.
//
// Reads the seven integer attributes from `attr`, wraps ins[0..3] and outs[0]
// as ATen tensors with buildATensor, then calls assign_score_withk_forward.
void assign_score_withk_forward_cuda_parrots(CudaContext& ctx,
                                             const SSElement& attr,
                                             const OperatorBase::in_list_t& ins,
                                             OperatorBase::out_list_t& outs) {
  int B;
  int N0;
  int N1;
  int M;
  int K;
  int O;
  int aggregate;
  SSAttrs(attr)
      .get<int>("B", B)
      .get<int>("N0", N0)
      .get<int>("N1", N1)
      .get<int>("M", M)
      .get<int>("K", K)
      .get<int>("O", O)
      .get<int>("aggregate", aggregate)
      .done();

  // Inputs, in registration order.
  const auto points_t = buildATensor(ctx, ins[0]);
  const auto centers_t = buildATensor(ctx, ins[1]);
  const auto scores_t = buildATensor(ctx, ins[2]);
  const auto knn_idx_t = buildATensor(ctx, ins[3]);

  // Single output, passed on by non-const reference.
  auto output_t = buildATensor(ctx, outs[0]);

  assign_score_withk_forward(points_t, centers_t, scores_t, knn_idx_t,
                             output_t, B, N0, N1, M, K, O, aggregate);
}
|
||||
|
||||
// Parrots CUDA entry point for assign_score_withk_backward.
//
// Reads the seven integer attributes from `attr`, wraps ins[0..4] and
// outs[0..2] as ATen tensors with buildATensor, then calls
// assign_score_withk_backward.
void assign_score_withk_backward_cuda_parrots(
    CudaContext& ctx, const SSElement& attr, const OperatorBase::in_list_t& ins,
    OperatorBase::out_list_t& outs) {
  int B;
  int N0;
  int N1;
  int M;
  int K;
  int O;
  int aggregate;
  SSAttrs(attr)
      .get<int>("B", B)
      .get<int>("N0", N0)
      .get<int>("N1", N1)
      .get<int>("M", M)
      .get<int>("K", K)
      .get<int>("O", O)
      .get<int>("aggregate", aggregate)
      .done();

  // Inputs, in registration order.
  const auto grad_out_t = buildATensor(ctx, ins[0]);
  const auto points_t = buildATensor(ctx, ins[1]);
  const auto centers_t = buildATensor(ctx, ins[2]);
  const auto scores_t = buildATensor(ctx, ins[3]);
  const auto knn_idx_t = buildATensor(ctx, ins[4]);

  // Gradient outputs, passed on by non-const reference.
  auto grad_points_t = buildATensor(ctx, outs[0]);
  auto grad_centers_t = buildATensor(ctx, outs[1]);
  auto grad_scores_t = buildATensor(ctx, outs[2]);

  assign_score_withk_backward(grad_out_t, points_t, centers_t, scores_t,
                              knn_idx_t, grad_points_t, grad_centers_t,
                              grad_scores_t, B, N0, N1, M, K, O, aggregate);
}
|
||||
|
||||
// Forward op: seven attributes, 4 input tensors, 1 output tensor.
PARROTS_EXTENSION_REGISTER(assign_score_withk_forward)
    .attr("B")
    .attr("N0")
    .attr("N1")
    .attr("M")
    .attr("K")
    .attr("O")
    .attr("aggregate")
    .input(4)
    .output(1)
    .apply(assign_score_withk_forward_cuda_parrots)
    .done();

// Backward op: same attributes, 5 input tensors, 3 output tensors.
PARROTS_EXTENSION_REGISTER(assign_score_withk_backward)
    .attr("B")
    .attr("N0")
    .attr("N1")
    .attr("M")
    .attr("K")
    .attr("O")
    .attr("aggregate")
    .input(5)
    .output(3)
    .apply(assign_score_withk_backward_cuda_parrots)
    .done();
|
||||
#endif
|
|
@ -0,0 +1,19 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved
|
||||
#ifndef ASSIGN_SCORE_WITHK_PYTORCH_H
|
||||
#define ASSIGN_SCORE_WITHK_PYTORCH_H
|
||||
#include <torch/extension.h>
|
||||
using namespace at;
|
||||
|
||||
void assign_score_withk_forward(const Tensor& points, const Tensor& centers,
|
||||
const Tensor& scores, const Tensor& knn_idx,
|
||||
Tensor& output, int B, int N0, int N1, int M,
|
||||
int K, int O, int aggregate);
|
||||
|
||||
void assign_score_withk_backward(const Tensor& grad_out, const Tensor& points,
|
||||
const Tensor& centers, const Tensor& scores,
|
||||
const Tensor& knn_idx, Tensor& grad_points,
|
||||
Tensor& grad_centers, Tensor& grad_scores,
|
||||
int B, int N0, int N1, int M, int K, int O,
|
||||
int aggregate);
|
||||
|
||||
#endif // ASSIGN_SCORE_WITHK_PYTORCH_H
|
|
@ -0,0 +1,43 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved
|
||||
#include <parrots/compute/aten.hpp>
|
||||
#include <parrots/extension.hpp>
|
||||
#include <parrots/foundation/ssattrs.hpp>
|
||||
|
||||
#include "ball_query_pytorch.h"
|
||||
|
||||
using namespace parrots;
|
||||
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
// Parrots CUDA entry point for ball_query_forward.
//
// Reads four int and two float attributes from `attr`, wraps ins[0], ins[1]
// and outs[0] as ATen tensors with buildATensor, then calls
// ball_query_forward.
void ball_query_parrots(CudaContext& ctx, const SSElement& attr,
                        const OperatorBase::in_list_t& ins,
                        OperatorBase::out_list_t& outs) {
  int b;
  int n;
  int m;
  int nsample;
  float min_radius;
  float max_radius;
  SSAttrs(attr)
      .get<int>("b", b)
      .get<int>("n", n)
      .get<int>("m", m)
      .get<int>("nsample", nsample)
      .get<float>("min_radius", min_radius)
      .get<float>("max_radius", max_radius)
      .done();

  const auto center_xyz_t = buildATensor(ctx, ins[0]);
  const auto xyz_t = buildATensor(ctx, ins[1]);
  auto idx_t = buildATensor(ctx, outs[0]);

  ball_query_forward(center_xyz_t, xyz_t, idx_t, b, n, m, min_radius,
                     max_radius, nsample);
}
|
||||
|
||||
// ball_query_forward op: six attributes, 2 input tensors, 1 output tensor.
PARROTS_EXTENSION_REGISTER(ball_query_forward)
    .attr("b")
    .attr("n")
    .attr("m")
    .attr("nsample")
    .attr("min_radius")
    .attr("max_radius")
    .input(2)
    .output(1)
    .apply(ball_query_parrots)
    .done();
|
||||
#endif
|
|
@ -0,0 +1,37 @@
|
|||
// Modified from
|
||||
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/ball_query.cpp
|
||||
|
||||
#include "pytorch_cpp_helper.hpp"
|
||||
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
void BallQueryForwardCUDAKernelLauncher(int b, int n, int m, float min_radius,
|
||||
float max_radius, int nsample,
|
||||
const Tensor new_xyz, const Tensor xyz,
|
||||
Tensor idx);
|
||||
|
||||
void ball_query_forward_cuda(int b, int n, int m, float min_radius,
|
||||
float max_radius, int nsample,
|
||||
const Tensor new_xyz, const Tensor xyz,
|
||||
Tensor idx) {
|
||||
BallQueryForwardCUDAKernelLauncher(b, n, m, min_radius, max_radius, nsample,
|
||||
new_xyz, xyz, idx);
|
||||
};
|
||||
#endif
|
||||
|
||||
// Device dispatcher for ball_query.
//
// The device of `new_xyz_tensor` selects the path. On CUDA the two coordinate
// tensors are run through CHECK_CUDA_INPUT before ball_query_forward_cuda is
// called (idx_tensor is not checked here). Any other device, or a build
// without MMCV_WITH_CUDA, raises through AT_ERROR.
void ball_query_forward(Tensor new_xyz_tensor, Tensor xyz_tensor,
                        Tensor idx_tensor, int b, int n, int m,
                        float min_radius, float max_radius, int nsample) {
  const bool on_gpu = new_xyz_tensor.device().is_cuda();
  if (!on_gpu) {
    AT_ERROR("ball_query is not implemented on CPU");
  } else {
#ifdef MMCV_WITH_CUDA
    CHECK_CUDA_INPUT(new_xyz_tensor);
    CHECK_CUDA_INPUT(xyz_tensor);

    ball_query_forward_cuda(b, n, m, min_radius, max_radius, nsample,
                            new_xyz_tensor, xyz_tensor, idx_tensor);
#else
    AT_ERROR("ball_query is not compiled with GPU support");
#endif
  }
}
|
|
@ -0,0 +1,11 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved
|
||||
#ifndef BALL_QUERY_PYTORCH_H
|
||||
#define BALL_QUERY_PYTORCH_H
|
||||
#include <torch/extension.h>
|
||||
using namespace at;
|
||||
|
||||
void ball_query_forward(const Tensor new_xyz, const Tensor xyz, Tensor idx,
|
||||
int b, int n, int m, float min_radius, float max_radius,
|
||||
int nsample);
|
||||
|
||||
#endif // BALL_QUERY_PYTORCH_H
|
|
@ -0,0 +1,87 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
#include <iostream>
|
||||
|
||||
#include "pytorch_cpp_helper.hpp"
|
||||
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
|
||||
void CorrelationForwardCUDAKernelLauncher(Tensor input1, Tensor input2,
|
||||
Tensor output, int kH, int kW,
|
||||
int patchH, int patchW, int padH,
|
||||
int padW, int dilationH,
|
||||
int dilationW, int dilation_patchH,
|
||||
int dilation_patchW, int dH, int dW);
|
||||
|
||||
void CorrelationBackwardCUDAKernelLauncher(Tensor grad_output, Tensor input1,
|
||||
Tensor input2, Tensor grad_input1,
|
||||
Tensor grad_input2, int kH, int kW,
|
||||
int patchH, int patchW, int padH,
|
||||
int padW, int dilationH,
|
||||
int dilationW, int dilation_patchH,
|
||||
int dilation_patchW, int dH, int dW);
|
||||
|
||||
// Forwarding shim: passes every argument, unchanged and in the same order,
// to CorrelationForwardCUDAKernelLauncher (declared above).
void correlation_cuda_forward(Tensor input1, Tensor input2, Tensor output,
                              int kH, int kW, int patchH, int patchW, int padH,
                              int padW, int dilationH, int dilationW,
                              int dilation_patchH, int dilation_patchW, int dH,
                              int dW) {
  CorrelationForwardCUDAKernelLauncher(input1, input2, output,
                                       kH, kW,
                                       patchH, patchW,
                                       padH, padW,
                                       dilationH, dilationW,
                                       dilation_patchH, dilation_patchW,
                                       dH, dW);
}
|
||||
|
||||
// Forwarding shim: passes every argument, unchanged and in the same order,
// to CorrelationBackwardCUDAKernelLauncher (declared above).
void correlation_cuda_backward(Tensor grad_output, Tensor input1, Tensor input2,
                               Tensor grad_input1, Tensor grad_input2, int kH,
                               int kW, int patchH, int patchW, int padH,
                               int padW, int dilationH, int dilationW,
                               int dilation_patchH, int dilation_patchW, int dH,
                               int dW) {
  CorrelationBackwardCUDAKernelLauncher(grad_output, input1, input2,
                                        grad_input1, grad_input2,
                                        kH, kW,
                                        patchH, patchW,
                                        padH, padW,
                                        dilationH, dilationW,
                                        dilation_patchH, dilation_patchW,
                                        dH, dW);
}
|
||||
|
||||
#endif
|
||||
|
||||
// Device dispatcher for the correlation forward pass.
//
// The CUDA path is taken only when both inputs report a CUDA device. It runs
// CHECK_CUDA_INPUT on both inputs and then calls correlation_cuda_forward.
// Any other device combination, or a build without MMCV_WITH_CUDA, raises
// through AT_ERROR.
void correlation_forward(Tensor input1, Tensor input2, Tensor output, int kH,
                         int kW, int patchH, int patchW, int padH, int padW,
                         int dilationH, int dilationW, int dilation_patchH,
                         int dilation_patchW, int dH, int dW) {
  const bool both_on_gpu =
      input1.device().is_cuda() && input2.device().is_cuda();
  if (!both_on_gpu) {
    AT_ERROR("Correlation is not implemented on CPU");
  } else {
#ifdef MMCV_WITH_CUDA
    CHECK_CUDA_INPUT(input1);
    CHECK_CUDA_INPUT(input2);
    correlation_cuda_forward(input1, input2, output, kH, kW, patchH, patchW,
                             padH, padW, dilationH, dilationW, dilation_patchH,
                             dilation_patchW, dH, dW);
#else
    AT_ERROR("Correlation is not compiled with GPU support");
#endif
  }
}
|
||||
|
||||
// Device dispatcher for the correlation backward pass.
//
// The CUDA path is taken only when both inputs report a CUDA device. It runs
// CHECK_CUDA_INPUT on grad_output and both inputs, then calls
// correlation_cuda_backward. Any other device combination, or a build without
// MMCV_WITH_CUDA, raises through AT_ERROR.
void correlation_backward(Tensor grad_output, Tensor input1, Tensor input2,
                          Tensor grad_input1, Tensor grad_input2, int kH,
                          int kW, int patchH, int patchW, int padH, int padW,
                          int dilationH, int dilationW, int dilation_patchH,
                          int dilation_patchW, int dH, int dW) {
  const bool both_on_gpu =
      input1.device().is_cuda() && input2.device().is_cuda();
  if (!both_on_gpu) {
    AT_ERROR("Correlation is not implemented on CPU");
  } else {
#ifdef MMCV_WITH_CUDA
    CHECK_CUDA_INPUT(grad_output);
    CHECK_CUDA_INPUT(input1);
    CHECK_CUDA_INPUT(input2);
    correlation_cuda_backward(grad_output, input1, input2, grad_input1,
                              grad_input2, kH, kW, patchH, patchW, padH, padW,
                              dilationH, dilationW, dilation_patchH,
                              dilation_patchW, dH, dW);
#else
    AT_ERROR("Correlation is not compiled with GPU support");
#endif
  }
}
|
|
@ -0,0 +1,176 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved
|
||||
#include <parrots/compute/aten.hpp>
|
||||
#include <parrots/extension.hpp>
|
||||
#include <parrots/foundation/ssattrs.hpp>
|
||||
|
||||
#include "correlation_pytorch.h"
|
||||
|
||||
using namespace parrots;
|
||||
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
// Parrots CUDA entry point for correlation_forward.
//
// Reads the twelve integer attributes from `attr`, wraps ins[0], ins[1] and
// outs[0] as ATen tensors with buildATensor, then calls correlation_forward.
void correlation_forward_cuda_parrots(CudaContext& ctx, const SSElement& attr,
                                      const OperatorBase::in_list_t& ins,
                                      OperatorBase::out_list_t& outs) {
  int kH, kW;
  int patchH, patchW;
  int padH, padW;
  int dilationH, dilationW;
  int dilation_patchH, dilation_patchW;
  int dH, dW;
  SSAttrs(attr)
      .get<int>("kH", kH)
      .get<int>("kW", kW)
      .get<int>("patchH", patchH)
      .get<int>("patchW", patchW)
      .get<int>("padH", padH)
      .get<int>("padW", padW)
      .get<int>("dilationH", dilationH)
      .get<int>("dilationW", dilationW)
      .get<int>("dilation_patchH", dilation_patchH)
      .get<int>("dilation_patchW", dilation_patchW)
      .get<int>("dH", dH)
      .get<int>("dW", dW)
      .done();

  auto in1 = buildATensor(ctx, ins[0]);
  auto in2 = buildATensor(ctx, ins[1]);
  auto out = buildATensor(ctx, outs[0]);

  correlation_forward(in1, in2, out, kH, kW, patchH, patchW, padH, padW,
                      dilationH, dilationW, dilation_patchH, dilation_patchW,
                      dH, dW);
}
|
||||
|
||||
// Parrots CUDA entry point for correlation_backward.
//
// Reads the twelve integer attributes from `attr`, wraps ins[0..2] and
// outs[0..1] as ATen tensors with buildATensor, then calls
// correlation_backward.
void correlation_backward_cuda_parrots(CudaContext& ctx, const SSElement& attr,
                                       const OperatorBase::in_list_t& ins,
                                       OperatorBase::out_list_t& outs) {
  int kH, kW;
  int patchH, patchW;
  int padH, padW;
  int dilationH, dilationW;
  int dilation_patchH, dilation_patchW;
  int dH, dW;
  SSAttrs(attr)
      .get<int>("kH", kH)
      .get<int>("kW", kW)
      .get<int>("patchH", patchH)
      .get<int>("patchW", patchW)
      .get<int>("padH", padH)
      .get<int>("padW", padW)
      .get<int>("dilationH", dilationH)
      .get<int>("dilationW", dilationW)
      .get<int>("dilation_patchH", dilation_patchH)
      .get<int>("dilation_patchW", dilation_patchW)
      .get<int>("dH", dH)
      .get<int>("dW", dW)
      .done();

  auto grad_out = buildATensor(ctx, ins[0]);
  auto in1 = buildATensor(ctx, ins[1]);
  auto in2 = buildATensor(ctx, ins[2]);

  auto grad_in1 = buildATensor(ctx, outs[0]);
  auto grad_in2 = buildATensor(ctx, outs[1]);

  correlation_backward(grad_out, in1, in2, grad_in1, grad_in2, kH, kW, patchH,
                       patchW, padH, padW, dilationH, dilationW,
                       dilation_patchH, dilation_patchW, dH, dW);
}
|
||||
#endif
|
||||
|
||||
// Parrots host-context entry point for correlation_forward.
//
// Same steps as the CUDA entry point, but with a HostContext: reads the twelve
// integer attributes, wraps ins[0], ins[1] and outs[0] with buildATensor, then
// calls correlation_forward.
void correlation_forward_cpu_parrots(HostContext& ctx, const SSElement& attr,
                                     const OperatorBase::in_list_t& ins,
                                     OperatorBase::out_list_t& outs) {
  int kH, kW;
  int patchH, patchW;
  int padH, padW;
  int dilationH, dilationW;
  int dilation_patchH, dilation_patchW;
  int dH, dW;
  SSAttrs(attr)
      .get<int>("kH", kH)
      .get<int>("kW", kW)
      .get<int>("patchH", patchH)
      .get<int>("patchW", patchW)
      .get<int>("padH", padH)
      .get<int>("padW", padW)
      .get<int>("dilationH", dilationH)
      .get<int>("dilationW", dilationW)
      .get<int>("dilation_patchH", dilation_patchH)
      .get<int>("dilation_patchW", dilation_patchW)
      .get<int>("dH", dH)
      .get<int>("dW", dW)
      .done();

  auto in1 = buildATensor(ctx, ins[0]);
  auto in2 = buildATensor(ctx, ins[1]);
  auto out = buildATensor(ctx, outs[0]);

  correlation_forward(in1, in2, out, kH, kW, patchH, patchW, padH, padW,
                      dilationH, dilationW, dilation_patchH, dilation_patchW,
                      dH, dW);
}
|
||||
|
||||
// Parrots host-context entry point for correlation_backward.
//
// Same steps as the CUDA entry point, but with a HostContext: reads the twelve
// integer attributes, wraps ins[0..2] and outs[0..1] with buildATensor, then
// calls correlation_backward.
void correlation_backward_cpu_parrots(HostContext& ctx, const SSElement& attr,
                                      const OperatorBase::in_list_t& ins,
                                      OperatorBase::out_list_t& outs) {
  int kH, kW;
  int patchH, patchW;
  int padH, padW;
  int dilationH, dilationW;
  int dilation_patchH, dilation_patchW;
  int dH, dW;
  SSAttrs(attr)
      .get<int>("kH", kH)
      .get<int>("kW", kW)
      .get<int>("patchH", patchH)
      .get<int>("patchW", patchW)
      .get<int>("padH", padH)
      .get<int>("padW", padW)
      .get<int>("dilationH", dilationH)
      .get<int>("dilationW", dilationW)
      .get<int>("dilation_patchH", dilation_patchH)
      .get<int>("dilation_patchW", dilation_patchW)
      .get<int>("dH", dH)
      .get<int>("dW", dW)
      .done();

  auto grad_out = buildATensor(ctx, ins[0]);
  auto in1 = buildATensor(ctx, ins[1]);
  auto in2 = buildATensor(ctx, ins[2]);

  auto grad_in1 = buildATensor(ctx, outs[0]);
  auto grad_in2 = buildATensor(ctx, outs[1]);

  correlation_backward(grad_out, in1, in2, grad_in1, grad_in2, kH, kW, patchH,
                       patchW, padH, padW, dilationH, dilationW,
                       dilation_patchH, dilation_patchW, dH, dW);
}
|
||||
|
||||
// Forward op: twelve attributes, 2 input tensors, 1 output tensor. The host
// handler is always attached; the CUDA handler only with MMCV_WITH_CUDA.
PARROTS_EXTENSION_REGISTER(correlation_forward)
    .attr("kH")
    .attr("kW")
    .attr("patchH")
    .attr("patchW")
    .attr("padH")
    .attr("padW")
    .attr("dilationH")
    .attr("dilationW")
    .attr("dilation_patchH")
    .attr("dilation_patchW")
    .attr("dH")
    .attr("dW")
    .input(2)
    .output(1)
    .apply(correlation_forward_cpu_parrots)
#ifdef MMCV_WITH_CUDA
    .apply(correlation_forward_cuda_parrots)
#endif
    .done();

// Backward op: same attributes, 3 input tensors, 2 output tensors.
PARROTS_EXTENSION_REGISTER(correlation_backward)
    .attr("kH")
    .attr("kW")
    .attr("patchH")
    .attr("patchW")
    .attr("padH")
    .attr("padW")
    .attr("dilationH")
    .attr("dilationW")
    .attr("dilation_patchH")
    .attr("dilation_patchW")
    .attr("dH")
    .attr("dW")
    .input(3)
    .output(2)
    .apply(correlation_backward_cpu_parrots)
#ifdef MMCV_WITH_CUDA
    .apply(correlation_backward_cuda_parrots)
#endif
    .done();
|
|
@ -0,0 +1,18 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved
|
||||
#ifndef CORRELATION_PYTORCH_H
|
||||
#define CORRELATION_PYTORCH_H
|
||||
#include <torch/extension.h>
|
||||
using namespace at;
|
||||
|
||||
void correlation_forward(Tensor input1, Tensor input2, Tensor output, int kH,
|
||||
int kW, int patchH, int patchW, int padH, int padW,
|
||||
int dilationH, int dilationW, int dilation_patchH,
|
||||
int dilation_patchW, int dH, int dW);
|
||||
|
||||
void correlation_backward(Tensor grad_output, Tensor input1, Tensor input2,
|
||||
Tensor grad_input1, Tensor grad_input2, int kH,
|
||||
int kW, int patchH, int patchW, int padH, int padW,
|
||||
int dilationH, int dilationW, int dilation_patchH,
|
||||
int dilation_patchW, int dH, int dW);
|
||||
|
||||
#endif // CORRELATION_PYTORCH_H
|
|
@ -0,0 +1,62 @@
|
|||
// Modified from
|
||||
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/sampling.cpp
|
||||
|
||||
#include "pytorch_cpp_helper.hpp"
|
||||
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
void FurthestPointSamplingForwardCUDAKernelLauncher(int b, int n, int m,
|
||||
const float *dataset,
|
||||
float *temp, int *idxs);
|
||||
|
||||
// Forwarding shim: passes the sizes and raw pointers, unchanged and in the
// same order, to FurthestPointSamplingForwardCUDAKernelLauncher (declared
// above).
void furthest_point_sampling_forward_cuda(int b, int n, int m,
                                          const float *dataset, float *temp,
                                          int *idxs) {
  FurthestPointSamplingForwardCUDAKernelLauncher(b, n, m,
                                                 dataset, temp, idxs);
}
|
||||
|
||||
void FurthestPointSamplingWithDistForwardCUDAKernelLauncher(
|
||||
int b, int n, int m, const float *dataset, float *temp, int *idxs);
|
||||
|
||||
// Forwarding shim: passes the sizes and raw pointers, unchanged and in the
// same order, to FurthestPointSamplingWithDistForwardCUDAKernelLauncher
// (declared above).
void furthest_point_sampling_with_dist_forward_cuda(int b, int n, int m,
                                                    const float *dataset,
                                                    float *temp, int *idxs) {
  FurthestPointSamplingWithDistForwardCUDAKernelLauncher(
      b, n, m, dataset, temp, idxs);
}
|
||||
#endif
|
||||
|
||||
// Device dispatcher for furthest point sampling.
//
// The device of `points_tensor` selects the path. On CUDA the three tensors
// are reduced to raw pointers (float, float, int) and handed to
// furthest_point_sampling_forward_cuda. Any other device, or a build without
// MMCV_WITH_CUDA, raises through AT_ERROR.
void furthest_point_sampling_forward(Tensor points_tensor, Tensor temp_tensor,
                                     Tensor idx_tensor, int b, int n, int m) {
  const bool on_gpu = points_tensor.device().is_cuda();
  if (!on_gpu) {
    AT_ERROR("furthest_point_sampling is not implemented on CPU");
  } else {
#ifdef MMCV_WITH_CUDA
    const float *points_ptr = points_tensor.data_ptr<float>();
    float *temp_ptr = temp_tensor.data_ptr<float>();
    int *idx_ptr = idx_tensor.data_ptr<int>();
    furthest_point_sampling_forward_cuda(b, n, m, points_ptr, temp_ptr,
                                         idx_ptr);
#else
    AT_ERROR("furthest_point_sampling is not compiled with GPU support");
#endif
  }
}
|
||||
|
||||
// Device dispatcher for furthest point sampling (with_dist variant).
//
// The device of `points_tensor` selects the path. On CUDA the three tensors
// are reduced to raw pointers (float, float, int) and handed to
// furthest_point_sampling_with_dist_forward_cuda. Any other device, or a
// build without MMCV_WITH_CUDA, raises through AT_ERROR.
void furthest_point_sampling_with_dist_forward(Tensor points_tensor,
                                               Tensor temp_tensor,
                                               Tensor idx_tensor, int b, int n,
                                               int m) {
  if (points_tensor.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
    // data_ptr<T>() replaces the deprecated Tensor::data<T>() accessor and is
    // what furthest_point_sampling_forward in this file already uses.
    const float *points = points_tensor.data_ptr<float>();
    float *temp = temp_tensor.data_ptr<float>();
    int *idx = idx_tensor.data_ptr<int>();

    furthest_point_sampling_with_dist_forward_cuda(b, n, m, points, temp, idx);
#else
    AT_ERROR(
        "furthest_point_sampling_with_dist is not compiled with GPU support");
#endif
  } else {
    AT_ERROR("furthest_point_sampling_with_dist is not implemented on CPU");
  }
}
|
|
@ -0,0 +1,57 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved
|
||||
#include <parrots/compute/aten.hpp>
|
||||
#include <parrots/extension.hpp>
|
||||
#include <parrots/foundation/ssattrs.hpp>
|
||||
|
||||
#include "furthest_point_sample_pytorch.h"
|
||||
|
||||
using namespace parrots;
|
||||
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
// Parrots CUDA entry point for furthest_point_sampling_forward.
//
// Reads the integer attributes b, n and m from `attr`, wraps ins[0], ins[1]
// and outs[0] as ATen tensors with buildATensor, then calls
// furthest_point_sampling_forward.
void furthest_point_sample_forward_cuda_parrots(
    CudaContext& ctx, const SSElement& attr, const OperatorBase::in_list_t& ins,
    OperatorBase::out_list_t& outs) {
  int b;
  int n;
  int m;
  SSAttrs(attr)
      .get<int>("b", b)
      .get<int>("n", n)
      .get<int>("m", m)
      .done();

  auto points_t = buildATensor(ctx, ins[0]);
  auto temp_t = buildATensor(ctx, ins[1]);
  auto idx_t = buildATensor(ctx, outs[0]);

  furthest_point_sampling_forward(points_t, temp_t, idx_t, b, n, m);
}
|
||||
|
||||
// Parrots CUDA entry point for furthest_point_sampling_with_dist_forward.
//
// Reads the integer attributes b, n and m from `attr`, wraps ins[0], ins[1]
// and outs[0] as ATen tensors with buildATensor, then calls
// furthest_point_sampling_with_dist_forward.
void furthest_point_sampling_with_dist_forward_cuda_parrots(
    CudaContext& ctx, const SSElement& attr, const OperatorBase::in_list_t& ins,
    OperatorBase::out_list_t& outs) {
  int b;
  int n;
  int m;
  SSAttrs(attr)
      .get<int>("b", b)
      .get<int>("n", n)
      .get<int>("m", m)
      .done();

  auto points_t = buildATensor(ctx, ins[0]);
  auto temp_t = buildATensor(ctx, ins[1]);
  auto idx_t = buildATensor(ctx, outs[0]);

  furthest_point_sampling_with_dist_forward(points_t, temp_t, idx_t, b, n, m);
}
|
||||
// Plain variant: attributes b/n/m, 2 input tensors, 1 output tensor.
PARROTS_EXTENSION_REGISTER(furthest_point_sampling_forward)
    .attr("b")
    .attr("n")
    .attr("m")
    .input(2)
    .output(1)
    .apply(furthest_point_sample_forward_cuda_parrots)
    .done();

// with_dist variant: same attribute and tensor counts.
PARROTS_EXTENSION_REGISTER(furthest_point_sampling_with_dist_forward)
    .attr("b")
    .attr("n")
    .attr("m")
    .input(2)
    .output(1)
    .apply(furthest_point_sampling_with_dist_forward_cuda_parrots)
    .done();
|
||||
#endif
|
|
@ -0,0 +1,14 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved
|
||||
#ifndef FURTHEST_POINT_SAMPLE_PYTORCH_H
|
||||
#define FURTHEST_POINT_SAMPLE_PYTORCH_H
|
||||
#include <torch/extension.h>
|
||||
using namespace at;
|
||||
|
||||
void furthest_point_sampling_forward(Tensor points_tensor, Tensor temp_tensor,
|
||||
Tensor idx_tensor, int b, int n, int m);
|
||||
|
||||
void furthest_point_sampling_with_dist_forward(Tensor points_tensor,
|
||||
Tensor temp_tensor,
|
||||
Tensor idx_tensor, int b, int n,
|
||||
int m);
|
||||
#endif // FURTHEST_POINT_SAMPLE_PYTORCH_H
|
|
@ -0,0 +1,55 @@
|
|||
#include "pytorch_cpp_helper.hpp"
|
||||
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
void GatherPointsForwardCUDAKernelLauncher(int b, int c, int n, int npoints,
|
||||
const Tensor points,
|
||||
const Tensor idx, Tensor out);
|
||||
|
||||
void gather_points_forward_cuda(int b, int c, int n, int npoints,
|
||||
const Tensor points, const Tensor idx,
|
||||
Tensor out) {
|
||||
GatherPointsForwardCUDAKernelLauncher(b, c, n, npoints, points, idx, out);
|
||||
};
|
||||
|
||||
void GatherPointsBackwardCUDAKernelLauncher(int b, int c, int n, int npoints,
|
||||
const Tensor grad_out,
|
||||
const Tensor idx,
|
||||
Tensor grad_points);
|
||||
|
||||
void gather_points_backward_cuda(int b, int c, int n, int npoints,
|
||||
const Tensor grad_out, const Tensor idx,
|
||||
Tensor grad_points) {
|
||||
GatherPointsBackwardCUDAKernelLauncher(b, c, n, npoints, grad_out, idx,
|
||||
grad_points);
|
||||
};
|
||||
#endif
|
||||
|
||||
// Device dispatcher for the gather_points forward pass.
//
// The device of `points_tensor` selects the path. On CUDA the tensors go
// straight to gather_points_forward_cuda with no extra checks. Any other
// device, or a build without MMCV_WITH_CUDA, raises through AT_ERROR.
void gather_points_forward(Tensor points_tensor, Tensor idx_tensor,
                           Tensor out_tensor, int b, int c, int n,
                           int npoints) {
  const bool on_gpu = points_tensor.device().is_cuda();
  if (!on_gpu) {
    AT_ERROR("gather_points is not implemented on CPU");
  } else {
#ifdef MMCV_WITH_CUDA
    gather_points_forward_cuda(b, c, n, npoints, points_tensor, idx_tensor,
                               out_tensor);
#else
    AT_ERROR("gather_points is not compiled with GPU support");
#endif
  }
}
|
||||
|
||||
// Device dispatcher for the gather_points backward pass.
//
// The device of `grad_out_tensor` selects the path. On CUDA the tensors go
// straight to gather_points_backward_cuda with no extra checks. Any other
// device, or a build without MMCV_WITH_CUDA, raises through AT_ERROR.
void gather_points_backward(Tensor grad_out_tensor, Tensor idx_tensor,
                            Tensor grad_points_tensor, int b, int c, int n,
                            int npoints) {
  const bool on_gpu = grad_out_tensor.device().is_cuda();
  if (!on_gpu) {
    AT_ERROR("gather_points is not implemented on CPU");
  } else {
#ifdef MMCV_WITH_CUDA
    gather_points_backward_cuda(b, c, n, npoints, grad_out_tensor, idx_tensor,
                                grad_points_tensor);
#else
    AT_ERROR("gather_points is not compiled with GPU support");
#endif
  }
}
|
|
@ -0,0 +1,71 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved
|
||||
#include <parrots/compute/aten.hpp>
|
||||
#include <parrots/extension.hpp>
|
||||
#include <parrots/foundation/ssattrs.hpp>
|
||||
|
||||
#include "gather_points_pytorch.h"
|
||||
|
||||
using namespace parrots;
|
||||
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
// Parrots handler for gather_points_forward.
//
// Reads the integer attributes "b", "c", "n" and "npoints", converts the two
// inputs and the single output to ATen tensors, and calls the shared
// gather_points_forward() implementation.
void gather_points_forward_cuda_parrots(CudaContext& ctx, const SSElement& attr,
                                        const OperatorBase::in_list_t& ins,
                                        OperatorBase::out_list_t& outs) {
  // Scalar sizes are delivered as named operator attributes.
  int b;
  int c;
  int n;
  int npoints;
  SSAttrs(attr)
      .get<int>("b", b)
      .get<int>("c", c)
      .get<int>("n", n)
      .get<int>("npoints", npoints)
      .done();

  // Inputs: [0] points, [1] indices. Output: [0] result buffer.
  auto points = buildATensor(ctx, ins[0]);
  auto indices = buildATensor(ctx, ins[1]);
  auto result = buildATensor(ctx, outs[0]);

  gather_points_forward(points, indices, result, b, c, n, npoints);
}
|
||||
|
||||
// Parrots handler for gather_points_backward.
//
// Reads the integer attributes "b", "c", "n" and "npoints", converts the two
// inputs and the single output to ATen tensors, and calls the shared
// gather_points_backward() implementation.
void gather_points_backward_cuda_parrots(CudaContext& ctx,
                                         const SSElement& attr,
                                         const OperatorBase::in_list_t& ins,
                                         OperatorBase::out_list_t& outs) {
  // Scalar sizes are delivered as named operator attributes.
  int b;
  int c;
  int n;
  int npoints;
  SSAttrs(attr)
      .get<int>("b", b)
      .get<int>("c", c)
      .get<int>("n", n)
      .get<int>("npoints", npoints)
      .done();

  // Inputs: [0] upstream gradient, [1] indices. Output: [0] gradient buffer.
  auto grad_out = buildATensor(ctx, ins[0]);
  auto indices = buildATensor(ctx, ins[1]);
  auto grad_points = buildATensor(ctx, outs[0]);

  gather_points_backward(grad_out, indices, grad_points, b, c, n, npoints);
}
|
||||
|
||||
// Registers the Parrots operator "gather_points_forward": four attributes
// (b, c, n, npoints -- the same keys the handler reads through SSAttrs),
// two inputs, one output, executed by gather_points_forward_cuda_parrots.
PARROTS_EXTENSION_REGISTER(gather_points_forward)
|
||||
.attr("b")
|
||||
.attr("c")
|
||||
.attr("n")
|
||||
.attr("npoints")
|
||||
.input(2)
|
||||
.output(1)
|
||||
.apply(gather_points_forward_cuda_parrots)
|
||||
.done();
|
||||
|
||||
// Registers the Parrots operator "gather_points_backward": four attributes
// (b, c, n, npoints -- the same keys the handler reads through SSAttrs),
// two inputs, one output, executed by gather_points_backward_cuda_parrots.
PARROTS_EXTENSION_REGISTER(gather_points_backward)
|
||||
.attr("b")
|
||||
.attr("c")
|
||||
.attr("n")
|
||||
.attr("npoints")
|
||||
.input(2)
|
||||
.output(1)
|
||||
.apply(gather_points_backward_cuda_parrots)
|
||||
.done();
|
||||
#endif
|
|
@ -0,0 +1,13 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved
|
||||
#ifndef GATHER_POINTS_PYTORCH_H
|
||||
#define GATHER_POINTS_PYTORCH_H
|
||||
#include <torch/extension.h>
|
||||
using namespace at;
|
||||
|
||||
// Forward gather-points op shared by the PyTorch and Parrots bindings.
// Tensors come first, followed by the integer sizes b, c, n and npoints.
// NOTE(review): out_tensor is presumably written in place -- confirm.
void gather_points_forward(Tensor points_tensor, Tensor idx_tensor,
|
||||
Tensor out_tensor, int b, int c, int n, int npoints);
|
||||
|
||||
// Backward gather-points op shared by the PyTorch and Parrots bindings.
// Tensors come first, followed by the integer sizes b, c, n and npoints.
// NOTE(review): grad_points_tensor is presumably written in place -- confirm.
void gather_points_backward(Tensor grad_out_tensor, Tensor idx_tensor,
|
||||
Tensor grad_points_tensor, int b, int c, int n,
|
||||
int npoints);
|
||||
#endif // GATHER_POINTS_PYTORCH_H
|
|
@ -0,0 +1,32 @@
|
|||
// Modified from
|
||||
// https://github.com/CVMI-Lab/PAConv/tree/main/scene_seg/lib/pointops/src/knnquery_heap
|
||||
|
||||
#include "pytorch_cpp_helper.hpp"
|
||||
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
void KNNForwardCUDAKernelLauncher(int b, int n, int m, int nsample,
|
||||
const Tensor xyz, const Tensor new_xyz,
|
||||
Tensor idx, Tensor dist2);
|
||||
|
||||
void knn_forward_cuda(int b, int n, int m, int nsample, const Tensor xyz,
|
||||
const Tensor new_xyz, Tensor idx, Tensor dist2) {
|
||||
KNNForwardCUDAKernelLauncher(b, n, m, nsample, xyz, new_xyz, idx, dist2);
|
||||
}
|
||||
#endif
|
||||
|
||||
// Device dispatcher for the KNN forward op.
//
// The device of new_xyz_tensor selects the backend. Only a CUDA backend
// exists, so the call raises (via AT_ERROR) when new_xyz_tensor is not on a
// CUDA device, or when the extension was built without MMCV_WITH_CUDA.
void knn_forward(Tensor xyz_tensor, Tensor new_xyz_tensor, Tensor idx_tensor,
                 Tensor dist2_tensor, int b, int n, int m, int nsample) {
  // Reject host tensors first: there is no CPU kernel to fall back on.
  if (!new_xyz_tensor.device().is_cuda()) {
    AT_ERROR("knn is not implemented on CPU");
  }
#ifdef MMCV_WITH_CUDA
  // Only the two coordinate tensors are validated here; CHECK_CUDA_INPUT is
  // a helper macro whose definition is not visible in this file.
  CHECK_CUDA_INPUT(new_xyz_tensor);
  CHECK_CUDA_INPUT(xyz_tensor);

  // The CUDA wrapper takes the scalar sizes first, then the tensors.
  knn_forward_cuda(b, n, m, nsample, xyz_tensor, new_xyz_tensor, idx_tensor,
                   dist2_tensor);
#else
  AT_ERROR("knn is not compiled with GPU support");
#endif
}
|
|
@ -0,0 +1,41 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved
|
||||
#include <parrots/compute/aten.hpp>
|
||||
#include <parrots/extension.hpp>
|
||||
#include <parrots/foundation/ssattrs.hpp>
|
||||
|
||||
#include "knn_pytorch.h"
|
||||
|
||||
using namespace parrots;
|
||||
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
// Parrots handler for knn_forward.
//
// Reads the integer attributes "b", "n", "m" and "nsample", converts the two
// inputs and the two outputs to ATen tensors, and calls the shared
// knn_forward() implementation.
void knn_forward_cuda_parrots(CudaContext& ctx, const SSElement& attr,
                              const OperatorBase::in_list_t& ins,
                              OperatorBase::out_list_t& outs) {
  // Scalar sizes are delivered as named operator attributes.
  int b;
  int n;
  int m;
  int nsample;
  SSAttrs(attr)
      .get<int>("b", b)
      .get<int>("n", n)
      .get<int>("m", m)
      .get<int>("nsample", nsample)
      .done();

  // Inputs: [0] xyz, [1] new_xyz. Outputs: [0] idx, [1] dist2.
  auto xyz = buildATensor(ctx, ins[0]);
  auto new_xyz = buildATensor(ctx, ins[1]);
  auto idx = buildATensor(ctx, outs[0]);
  auto dist2 = buildATensor(ctx, outs[1]);

  knn_forward(xyz, new_xyz, idx, dist2, b, n, m, nsample);
}
|
||||
|
||||
// Registers the Parrots operator "knn_forward": four attributes
// (b, n, m, nsample -- the same keys the handler reads through SSAttrs),
// two inputs, two outputs, executed by knn_forward_cuda_parrots.
PARROTS_EXTENSION_REGISTER(knn_forward)
|
||||
.attr("b")
|
||||
.attr("n")
|
||||
.attr("m")
|
||||
.attr("nsample")
|
||||
.input(2)
|
||||
.output(2)
|
||||
.apply(knn_forward_cuda_parrots)
|
||||
.done();
|
||||
#endif
|
|
@ -0,0 +1,9 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved
|
||||
#ifndef KNN_PYTORCH_H
|
||||
#define KNN_PYTORCH_H
|
||||
#include <torch/extension.h>
|
||||
using namespace at;
|
||||
|
||||
// KNN forward op shared by the PyTorch and Parrots bindings.
// Tensors come first, followed by the integer sizes b, n, m and nsample.
// NOTE(review): idx_tensor and dist2_tensor are presumably the result
// buffers -- confirm against the kernel.
void knn_forward(Tensor xyz_tensor, Tensor new_xyz_tensor, Tensor idx_tensor,
|
||||
Tensor dist2_tensor, int b, int n, int m, int nsample);
|
||||
#endif // KNN_PYTORCH_H
|
|
@ -0,0 +1,60 @@
|
|||
/*
|
||||
Modified from
|
||||
https://github.com/open-mmlab/OpenPCDet/blob/master/pcdet/ops/roipoint_pool3d/src/roipoint_pool3d.cpp
|
||||
Point cloud feature pooling
|
||||
Written by Shaoshuai Shi
|
||||
All Rights Reserved 2018.
|
||||
*/
|
||||
|
||||
#include "pytorch_cpp_helper.hpp"
|
||||
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
void RoIPointPool3dForwardCUDAKernelLauncher(
|
||||
int batch_size, int pts_num, int boxes_num, int feature_in_len,
|
||||
int sampled_pts_num, const Tensor xyz, const Tensor boxes3d,
|
||||
const Tensor pts_feature, Tensor pooled_features, Tensor pooled_empty_flag);
|
||||
|
||||
void roipoint_pool3d_forward_cuda(int batch_size, int pts_num, int boxes_num,
|
||||
int feature_in_len, int sampled_pts_num,
|
||||
const Tensor xyz, const Tensor boxes3d,
|
||||
const Tensor pts_feature,
|
||||
Tensor pooled_features,
|
||||
Tensor pooled_empty_flag) {
|
||||
RoIPointPool3dForwardCUDAKernelLauncher(
|
||||
batch_size, pts_num, boxes_num, feature_in_len, sampled_pts_num, xyz,
|
||||
boxes3d, pts_feature, pooled_features, pooled_empty_flag);
|
||||
};
|
||||
#endif
|
||||
|
||||
// Device dispatcher for the RoI point pooling forward op.
//
// Tensor layouts, as stated by the original notes:
//   xyz:               (B, N, 3)
//   boxes3d:           (B, M, 7)
//   pts_feature:       (B, N, C)
//   pooled_features:   (B, M, 512, 3+C)
//   pooled_empty_flag: (B, M)
// The 512 above is the documented example; the code reads the actual sample
// count from pooled_features.size(2).
//
// Raises (via AT_ERROR) when xyz is not on a CUDA device, or when the
// extension was built without MMCV_WITH_CUDA.
void roipoint_pool3d_forward(Tensor xyz, Tensor boxes3d, Tensor pts_feature,
                             Tensor pooled_features, Tensor pooled_empty_flag) {
  // Reject host tensors first: there is no CPU kernel to fall back on.
  if (!xyz.device().is_cuda()) {
    AT_ERROR("roipoint_pool3d is not implemented on CPU");
  }
#ifdef MMCV_WITH_CUDA
  CHECK_CUDA_INPUT(xyz);
  CHECK_CUDA_INPUT(boxes3d);
  CHECK_CUDA_INPUT(pts_feature);
  CHECK_CUDA_INPUT(pooled_features);
  CHECK_CUDA_INPUT(pooled_empty_flag);

  // Every size is derived from the tensor shapes rather than passed in.
  const int num_batches = static_cast<int>(xyz.size(0));
  const int num_points = static_cast<int>(xyz.size(1));
  const int num_boxes = static_cast<int>(boxes3d.size(1));
  const int feature_len = static_cast<int>(pts_feature.size(2));
  const int num_sampled = static_cast<int>(pooled_features.size(2));

  roipoint_pool3d_forward_cuda(num_batches, num_points, num_boxes, feature_len,
                               num_sampled, xyz, boxes3d, pts_feature,
                               pooled_features, pooled_empty_flag);
#else
  AT_ERROR("roipoint_pool3d is not compiled with GPU support");
#endif
}
|
|
@ -0,0 +1,31 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved
|
||||
#include <parrots/compute/aten.hpp>
|
||||
#include <parrots/extension.hpp>
|
||||
#include <parrots/foundation/ssattrs.hpp>
|
||||
|
||||
#include "roipoint_pool3d_pytorch.h"
|
||||
|
||||
using namespace parrots;
|
||||
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
// Parrots handler for roipoint_pool3d_forward.
//
// The op has no scalar attributes, so `attr` is unused. The three inputs and
// two outputs are converted to ATen tensors and passed to the shared
// roipoint_pool3d_forward() implementation.
void roipoint_pool3d_forward_cuda_parrots(CudaContext& ctx,
                                          const SSElement& attr,
                                          const OperatorBase::in_list_t& ins,
                                          OperatorBase::out_list_t& outs) {
  // Inputs: [0] xyz, [1] boxes3d, [2] pts_feature.
  auto points_xyz = buildATensor(ctx, ins[0]);
  auto rois = buildATensor(ctx, ins[1]);
  auto point_features = buildATensor(ctx, ins[2]);

  // Outputs: [0] pooled_features, [1] pooled_empty_flag.
  auto pooled = buildATensor(ctx, outs[0]);
  auto empty_flag = buildATensor(ctx, outs[1]);

  roipoint_pool3d_forward(points_xyz, rois, point_features, pooled,
                          empty_flag);
}
|
||||
|
||||
// Registers the Parrots operator "roipoint_pool3d_forward": no attributes,
// three inputs, two outputs, executed by roipoint_pool3d_forward_cuda_parrots.
PARROTS_EXTENSION_REGISTER(roipoint_pool3d_forward)
|
||||
.input(3)
|
||||
.output(2)
|
||||
.apply(roipoint_pool3d_forward_cuda_parrots)
|
||||
.done();
|
||||
#endif
|
|
@ -0,0 +1,10 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved
|
||||
#ifndef ROIPOINT_POOL3D_PYTORCH_H
|
||||
#define ROIPOINT_POOL3D_PYTORCH_H
|
||||
#include <torch/extension.h>
|
||||
using namespace at;
|
||||
|
||||
// RoI point pooling forward op shared by the PyTorch and Parrots bindings.
// Takes tensors only; no scalar size arguments are passed.
// NOTE(review): pooled_features and pooled_empty_flag are presumably the
// result buffers -- confirm against the kernel.
void roipoint_pool3d_forward(Tensor xyz, Tensor boxes3d, Tensor pts_feature,
|
||||
Tensor pooled_features, Tensor pooled_empty_flag);
|
||||
|
||||
#endif // ROIPOINT_POOL3D_PYTORCH_H
|
|
@ -0,0 +1,61 @@
|
|||
// Modified from
|
||||
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/interpolate.cpp
|
||||
|
||||
#include "pytorch_cpp_helper.hpp"
|
||||
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
void ThreeInterpolateForwardCUDAKernelLauncher(int b, int c, int m, int n,
|
||||
const Tensor points,
|
||||
const Tensor idx,
|
||||
const Tensor weight, Tensor out);
|
||||
|
||||
void three_interpolate_forward_cuda(int b, int c, int m, int n,
|
||||
const Tensor points, const Tensor idx,
|
||||
const Tensor weight, Tensor out) {
|
||||
ThreeInterpolateForwardCUDAKernelLauncher(b, c, m, n, points, idx, weight,
|
||||
out);
|
||||
};
|
||||
|
||||
void ThreeInterpolateBackwardCUDAKernelLauncher(int b, int c, int n, int m,
|
||||
const Tensor grad_out,
|
||||
const Tensor idx,
|
||||
const Tensor weight,
|
||||
Tensor grad_points);
|
||||
|
||||
void three_interpolate_backward_cuda(int b, int c, int n, int m,
|
||||
const Tensor grad_out, const Tensor idx,
|
||||
const Tensor weight, Tensor grad_points) {
|
||||
ThreeInterpolateBackwardCUDAKernelLauncher(b, c, n, m, grad_out, idx, weight,
|
||||
grad_points);
|
||||
};
|
||||
#endif
|
||||
|
||||
// Device dispatcher for the three-interpolate forward op.
//
// The device of points_tensor selects the backend. Only a CUDA backend
// exists, so the call raises (via AT_ERROR) when points_tensor is not on a
// CUDA device, or when the extension was built without MMCV_WITH_CUDA.
void three_interpolate_forward(Tensor points_tensor, Tensor idx_tensor,
                               Tensor weight_tensor, Tensor out_tensor, int b,
                               int c, int m, int n) {
  // Reject host tensors first: there is no CPU kernel to fall back on.
  if (!points_tensor.device().is_cuda()) {
    AT_ERROR("three_interpolate is not implemented on CPU");
  }
#ifdef MMCV_WITH_CUDA
  // The CUDA wrapper takes the scalar sizes first, then the tensors.
  three_interpolate_forward_cuda(b, c, m, n, points_tensor, idx_tensor,
                                 weight_tensor, out_tensor);
#else
  AT_ERROR("three_interpolate is not compiled with GPU support");
#endif
}
|
||||
|
||||
// Device dispatcher for the three-interpolate backward op.
//
// The device of grad_out_tensor selects the backend. Only a CUDA backend
// exists, so the call raises (via AT_ERROR) when grad_out_tensor is not on a
// CUDA device, or when the extension was built without MMCV_WITH_CUDA.
void three_interpolate_backward(Tensor grad_out_tensor, Tensor idx_tensor,
                                Tensor weight_tensor, Tensor grad_points_tensor,
                                int b, int c, int n, int m) {
  // Reject host tensors first: there is no CPU kernel to fall back on.
  if (!grad_out_tensor.device().is_cuda()) {
    AT_ERROR("three_interpolate is not implemented on CPU");
  }
#ifdef MMCV_WITH_CUDA
  // The CUDA wrapper takes the scalar sizes first, then the tensors.
  three_interpolate_backward_cuda(b, c, n, m, grad_out_tensor, idx_tensor,
                                  weight_tensor, grad_points_tensor);
#else
  AT_ERROR("three_interpolate is not compiled with GPU support");
#endif
}
|
|
@ -0,0 +1,74 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved
|
||||
#include <parrots/compute/aten.hpp>
|
||||
#include <parrots/extension.hpp>
|
||||
#include <parrots/foundation/ssattrs.hpp>
|
||||
|
||||
#include "three_interpolate_pytorch.h"
|
||||
|
||||
using namespace parrots;
|
||||
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
// Parrots handler for three_interpolate_forward.
//
// Reads the integer attributes "b", "c", "m" and "n", converts the three
// inputs and the single output to ATen tensors, and calls the shared
// three_interpolate_forward() implementation.
void three_interpolate_forward_cuda_parrots(CudaContext& ctx,
                                            const SSElement& attr,
                                            const OperatorBase::in_list_t& ins,
                                            OperatorBase::out_list_t& outs) {
  // Scalar sizes are delivered as named operator attributes.
  int b;
  int c;
  int m;
  int n;
  SSAttrs(attr)
      .get<int>("b", b)
      .get<int>("c", c)
      .get<int>("m", m)
      .get<int>("n", n)
      .done();

  // Inputs: [0] points, [1] indices, [2] weights. Output: [0] result buffer.
  auto points = buildATensor(ctx, ins[0]);
  auto indices = buildATensor(ctx, ins[1]);
  auto weights = buildATensor(ctx, ins[2]);
  auto result = buildATensor(ctx, outs[0]);

  three_interpolate_forward(points, indices, weights, result, b, c, m, n);
}
|
||||
|
||||
// Parrots handler for three_interpolate_backward.
//
// Reads the integer attributes "b", "c", "n" and "m", converts the three
// inputs and the single output to ATen tensors, and calls the shared
// three_interpolate_backward() implementation.
void three_interpolate_backward_cuda_parrots(CudaContext& ctx,
                                             const SSElement& attr,
                                             const OperatorBase::in_list_t& ins,
                                             OperatorBase::out_list_t& outs) {
  // Scalar sizes are delivered as named operator attributes.
  int b;
  int c;
  int n;
  int m;
  SSAttrs(attr)
      .get<int>("b", b)
      .get<int>("c", c)
      .get<int>("n", n)
      .get<int>("m", m)
      .done();

  // Inputs: [0] upstream gradient, [1] indices, [2] weights.
  // Output: [0] gradient buffer.
  auto grad_out = buildATensor(ctx, ins[0]);
  auto indices = buildATensor(ctx, ins[1]);
  auto weights = buildATensor(ctx, ins[2]);
  auto grad_points = buildATensor(ctx, outs[0]);

  three_interpolate_backward(grad_out, indices, weights, grad_points, b, c, n,
                             m);
}
|
||||
|
||||
// Registers the Parrots operator "three_interpolate_forward": four attributes
// (b, c, m, n -- the same keys the handler reads through SSAttrs),
// three inputs, one output, executed by three_interpolate_forward_cuda_parrots.
PARROTS_EXTENSION_REGISTER(three_interpolate_forward)
|
||||
.attr("b")
|
||||
.attr("c")
|
||||
.attr("m")
|
||||
.attr("n")
|
||||
.input(3)
|
||||
.output(1)
|
||||
.apply(three_interpolate_forward_cuda_parrots)
|
||||
.done();
|
||||
|
||||
// Registers the Parrots operator "three_interpolate_backward": four
// attributes (b, c, n, m -- the same keys the handler reads through SSAttrs),
// three inputs, one output, executed by
// three_interpolate_backward_cuda_parrots.
PARROTS_EXTENSION_REGISTER(three_interpolate_backward)
|
||||
.attr("b")
|
||||
.attr("c")
|
||||
.attr("n")
|
||||
.attr("m")
|
||||
.input(3)
|
||||
.output(1)
|
||||
.apply(three_interpolate_backward_cuda_parrots)
|
||||
.done();
|
||||
#endif
|
|
@ -0,0 +1,14 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved
|
||||
#ifndef THREE_INTERPOLATE_PYTORCH_H
|
||||
#define THREE_INTERPOLATE_PYTORCH_H
|
||||
#include <torch/extension.h>
|
||||
using namespace at;
|
||||
|
||||
// Three-interpolate forward op shared by the PyTorch and Parrots bindings.
// Tensors come first, followed by the integer sizes in the order b, c, m, n.
// NOTE(review): out_tensor is presumably written in place -- confirm.
void three_interpolate_forward(Tensor points_tensor, Tensor idx_tensor,
|
||||
Tensor weight_tensor, Tensor out_tensor, int b,
|
||||
int c, int m, int n);
|
||||
|
||||
// Three-interpolate backward op shared by the PyTorch and Parrots bindings.
// Tensors come first, followed by the integer sizes in the order b, c, n, m
// (n and m are swapped relative to the forward declaration).
// NOTE(review): grad_points_tensor is presumably written in place -- confirm.
void three_interpolate_backward(Tensor grad_out_tensor, Tensor idx_tensor,
|
||||
Tensor weight_tensor, Tensor grad_points_tensor,
|
||||
int b, int c, int n, int m);
|
||||
#endif // THREE_INTERPOLATE_PYTORCH_H
|
|
@ -0,0 +1,30 @@
|
|||
// Modified from
|
||||
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/interpolate.cpp
|
||||
|
||||
#include "pytorch_cpp_helper.hpp"
|
||||
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
void ThreeNNForwardCUDAKernelLauncher(int b, int n, int m, const Tensor unknown,
|
||||
const Tensor known, Tensor dist2,
|
||||
Tensor idx);
|
||||
|
||||
void three_nn_forward_cuda(int b, int n, int m, const Tensor unknown,
|
||||
const Tensor known, Tensor dist2, Tensor idx) {
|
||||
ThreeNNForwardCUDAKernelLauncher(b, n, m, unknown, known, dist2, idx);
|
||||
};
|
||||
#endif
|
||||
|
||||
// Device dispatcher for the three-NN forward op.
//
// The device of unknown_tensor selects the backend. Only a CUDA backend
// exists, so the call raises (via AT_ERROR) when unknown_tensor is not on a
// CUDA device, or when the extension was built without MMCV_WITH_CUDA.
void three_nn_forward(Tensor unknown_tensor, Tensor known_tensor,
                      Tensor dist2_tensor, Tensor idx_tensor, int b, int n,
                      int m) {
  // Reject host tensors first: there is no CPU kernel to fall back on.
  if (!unknown_tensor.device().is_cuda()) {
    AT_ERROR("three_nn is not implemented on CPU");
  }
#ifdef MMCV_WITH_CUDA
  // The CUDA wrapper takes the scalar sizes first, then the tensors.
  three_nn_forward_cuda(b, n, m, unknown_tensor, known_tensor, dist2_tensor,
                        idx_tensor);
#else
  AT_ERROR("three_nn is not compiled with GPU support");
#endif
}
|
|
@ -0,0 +1,35 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved
|
||||
#include <parrots/compute/aten.hpp>
|
||||
#include <parrots/extension.hpp>
|
||||
#include <parrots/foundation/ssattrs.hpp>
|
||||
|
||||
#include "three_nn_pytorch.h"
|
||||
|
||||
using namespace parrots;
|
||||
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
// Parrots handler for three_nn_forward.
//
// Reads the integer attributes "b", "n" and "m", converts the two inputs and
// the two outputs to ATen tensors, and calls the shared three_nn_forward()
// implementation.
void three_nn_forward_cuda_parrots(CudaContext& ctx, const SSElement& attr,
                                   const OperatorBase::in_list_t& ins,
                                   OperatorBase::out_list_t& outs) {
  // Scalar sizes are delivered as named operator attributes.
  int b;
  int n;
  int m;
  SSAttrs(attr)
      .get<int>("b", b)
      .get<int>("n", n)
      .get<int>("m", m)
      .done();

  // Inputs: [0] unknown, [1] known. Outputs: [0] dist2, [1] idx.
  auto unknown = buildATensor(ctx, ins[0]);
  auto known = buildATensor(ctx, ins[1]);
  auto dist2 = buildATensor(ctx, outs[0]);
  auto idx = buildATensor(ctx, outs[1]);

  three_nn_forward(unknown, known, dist2, idx, b, n, m);
}
|
||||
|
||||
// Registers the Parrots operator "three_nn_forward": three attributes
// (b, n, m -- the same keys the handler reads through SSAttrs),
// two inputs, two outputs, executed by three_nn_forward_cuda_parrots.
PARROTS_EXTENSION_REGISTER(three_nn_forward)
|
||||
.attr("b")
|
||||
.attr("n")
|
||||
.attr("m")
|
||||
.input(2)
|
||||
.output(2)
|
||||
.apply(three_nn_forward_cuda_parrots)
|
||||
.done();
|
||||
#endif
|
|
@ -0,0 +1,10 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved
|
||||
#ifndef THREE_NN_PYTORCH_H
|
||||
#define THREE_NN_PYTORCH_H
|
||||
#include <torch/extension.h>
|
||||
using namespace at;
|
||||
|
||||
// Three-NN forward op shared by the PyTorch and Parrots bindings.
// Tensors come first, followed by the integer sizes b, n and m.
// NOTE(review): dist2_tensor and idx_tensor are presumably the result
// buffers -- confirm against the kernel.
void three_nn_forward(Tensor unknown_tensor, Tensor known_tensor,
|
||||
Tensor dist2_tensor, Tensor idx_tensor, int b, int n,
|
||||
int m);
|
||||
#endif // THREE_NN_PYTORCH_H
|
|
@ -34,10 +34,10 @@ void assign_score_withk_backward_cuda(
|
|||
};
|
||||
#endif
|
||||
|
||||
void assign_score_withk_forward(int B, int N0, int N1, int M, int K, int O,
|
||||
int aggregate, const Tensor& points,
|
||||
const Tensor& centers, const Tensor& scores,
|
||||
const Tensor& knn_idx, Tensor& output) {
|
||||
void assign_score_withk_forward(const Tensor& points, const Tensor& centers,
|
||||
const Tensor& scores, const Tensor& knn_idx,
|
||||
Tensor& output, int B, int N0, int N1, int M,
|
||||
int K, int O, int aggregate) {
|
||||
if (points.device().is_cuda()) {
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
CHECK_CONTIGUOUS(points);
|
||||
|
@ -56,12 +56,12 @@ void assign_score_withk_forward(int B, int N0, int N1, int M, int K, int O,
|
|||
}
|
||||
}
|
||||
|
||||
void assign_score_withk_backward(int B, int N0, int N1, int M, int K, int O,
|
||||
int aggregate, const Tensor& grad_out,
|
||||
const Tensor& points, const Tensor& centers,
|
||||
const Tensor& scores, const Tensor& knn_idx,
|
||||
Tensor& grad_points, Tensor& grad_centers,
|
||||
Tensor& grad_scores) {
|
||||
void assign_score_withk_backward(const Tensor& grad_out, const Tensor& points,
|
||||
const Tensor& centers, const Tensor& scores,
|
||||
const Tensor& knn_idx, Tensor& grad_points,
|
||||
Tensor& grad_centers, Tensor& grad_scores,
|
||||
int B, int N0, int N1, int M, int K, int O,
|
||||
int aggregate) {
|
||||
if (grad_points.device().is_cuda()) {
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
CHECK_CONTIGUOUS(grad_out);
|
||||
|
|
|
@ -18,9 +18,9 @@ void ball_query_forward_cuda(int b, int n, int m, float min_radius,
|
|||
};
|
||||
#endif
|
||||
|
||||
void ball_query_forward(int b, int n, int m, float min_radius, float max_radius,
|
||||
int nsample, Tensor new_xyz_tensor, Tensor xyz_tensor,
|
||||
Tensor idx_tensor) {
|
||||
void ball_query_forward(Tensor new_xyz_tensor, Tensor xyz_tensor,
|
||||
Tensor idx_tensor, int b, int n, int m,
|
||||
float min_radius, float max_radius, int nsample) {
|
||||
if (new_xyz_tensor.device().is_cuda()) {
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
CHECK_CUDA_INPUT(new_xyz_tensor);
|
||||
|
|
|
@ -25,8 +25,8 @@ void furthest_point_sampling_with_dist_forward_cuda(int b, int n, int m,
|
|||
}
|
||||
#endif
|
||||
|
||||
void furthest_point_sampling_forward(int b, int n, int m, Tensor points_tensor,
|
||||
Tensor temp_tensor, Tensor idx_tensor) {
|
||||
void furthest_point_sampling_forward(Tensor points_tensor, Tensor temp_tensor,
|
||||
Tensor idx_tensor, int b, int n, int m) {
|
||||
if (points_tensor.device().is_cuda()) {
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
const float *points = points_tensor.data_ptr<float>();
|
||||
|
@ -41,10 +41,10 @@ void furthest_point_sampling_forward(int b, int n, int m, Tensor points_tensor,
|
|||
}
|
||||
}
|
||||
|
||||
void furthest_point_sampling_with_dist_forward(int b, int n, int m,
|
||||
Tensor points_tensor,
|
||||
void furthest_point_sampling_with_dist_forward(Tensor points_tensor,
|
||||
Tensor temp_tensor,
|
||||
Tensor idx_tensor) {
|
||||
Tensor idx_tensor, int b, int n,
|
||||
int m) {
|
||||
if (points_tensor.device().is_cuda()) {
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
const float *points = points_tensor.data<float>();
|
||||
|
|
|
@ -24,9 +24,9 @@ void gather_points_backward_cuda(int b, int c, int n, int npoints,
|
|||
};
|
||||
#endif
|
||||
|
||||
void gather_points_forward(int b, int c, int n, int npoints,
|
||||
Tensor points_tensor, Tensor idx_tensor,
|
||||
Tensor out_tensor) {
|
||||
void gather_points_forward(Tensor points_tensor, Tensor idx_tensor,
|
||||
Tensor out_tensor, int b, int c, int n,
|
||||
int npoints) {
|
||||
if (points_tensor.device().is_cuda()) {
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
gather_points_forward_cuda(b, c, n, npoints, points_tensor, idx_tensor,
|
||||
|
@ -39,9 +39,9 @@ void gather_points_forward(int b, int c, int n, int npoints,
|
|||
}
|
||||
}
|
||||
|
||||
void gather_points_backward(int b, int c, int n, int npoints,
|
||||
Tensor grad_out_tensor, Tensor idx_tensor,
|
||||
Tensor grad_points_tensor) {
|
||||
void gather_points_backward(Tensor grad_out_tensor, Tensor idx_tensor,
|
||||
Tensor grad_points_tensor, int b, int c, int n,
|
||||
int npoints) {
|
||||
if (grad_out_tensor.device().is_cuda()) {
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
gather_points_backward_cuda(b, c, n, npoints, grad_out_tensor, idx_tensor,
|
||||
|
|
|
@ -14,9 +14,8 @@ void knn_forward_cuda(int b, int n, int m, int nsample, const Tensor xyz,
|
|||
}
|
||||
#endif
|
||||
|
||||
void knn_forward(int b, int n, int m, int nsample, Tensor xyz_tensor,
|
||||
Tensor new_xyz_tensor, Tensor idx_tensor,
|
||||
Tensor dist2_tensor) {
|
||||
void knn_forward(Tensor xyz_tensor, Tensor new_xyz_tensor, Tensor idx_tensor,
|
||||
Tensor dist2_tensor, int b, int n, int m, int nsample) {
|
||||
if (new_xyz_tensor.device().is_cuda()) {
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
CHECK_CUDA_INPUT(new_xyz_tensor);
|
||||
|
|
|
@ -4,17 +4,17 @@
|
|||
std::string get_compiler_version();
|
||||
std::string get_compiling_cuda_version();
|
||||
|
||||
void assign_score_withk_forward(int B, int N0, int N1, int M, int K, int O,
|
||||
int aggregate, const Tensor &points,
|
||||
const Tensor &centers, const Tensor &scores,
|
||||
const Tensor &knn_idx, Tensor &output);
|
||||
void assign_score_withk_forward(const Tensor &points, const Tensor &centers,
|
||||
const Tensor &scores, const Tensor &knn_idx,
|
||||
Tensor &output, int B, int N0, int N1, int M,
|
||||
int K, int O, int aggregate);
|
||||
|
||||
void assign_score_withk_backward(int B, int N0, int N1, int M, int K, int O,
|
||||
int aggregate, const Tensor &grad_out,
|
||||
const Tensor &points, const Tensor &centers,
|
||||
const Tensor &scores, const Tensor &knn_idx,
|
||||
Tensor &grad_points, Tensor &grad_centers,
|
||||
Tensor &grad_scores);
|
||||
void assign_score_withk_backward(const Tensor &grad_out, const Tensor &points,
|
||||
const Tensor &centers, const Tensor &scores,
|
||||
const Tensor &knn_idx, Tensor &grad_points,
|
||||
Tensor &grad_centers, Tensor &grad_scores,
|
||||
int B, int N0, int N1, int M, int K, int O,
|
||||
int aggregate);
|
||||
|
||||
void carafe_naive_forward(Tensor features, Tensor masks, Tensor output,
|
||||
int kernel_size, int group_size, int scale_factor);
|
||||
|
@ -76,13 +76,12 @@ void group_points_backward(int b, int c, int n, int npoints, int nsample,
|
|||
void roipoint_pool3d_forward(Tensor xyz, Tensor boxes3d, Tensor pts_feature,
|
||||
Tensor pooled_features, Tensor pooled_empty_flag);
|
||||
|
||||
void gather_points_forward(int b, int c, int n, int npoints,
|
||||
Tensor points_tensor, Tensor idx_tensor,
|
||||
Tensor out_tensor);
|
||||
void gather_points_forward(Tensor points_tensor, Tensor idx_tensor,
|
||||
Tensor out_tensor, int b, int c, int n, int npoints);
|
||||
|
||||
void gather_points_backward(int b, int c, int n, int npoints,
|
||||
Tensor grad_out_tensor, Tensor idx_tensor,
|
||||
Tensor grad_points_tensor);
|
||||
void gather_points_backward(Tensor grad_out_tensor, Tensor idx_tensor,
|
||||
Tensor grad_points_tensor, int b, int c, int n,
|
||||
int npoints);
|
||||
|
||||
void sigmoid_focal_loss_forward(Tensor input, Tensor target, Tensor weight,
|
||||
Tensor output, float gamma, float alpha);
|
||||
|
@ -97,22 +96,23 @@ void softmax_focal_loss_backward(Tensor input, Tensor target, Tensor weight,
|
|||
Tensor buff, Tensor grad_input, float gamma,
|
||||
float alpha);
|
||||
|
||||
void three_interpolate_forward(int b, int c, int m, int n, Tensor points_tensor,
|
||||
Tensor idx_tensor, Tensor weight_tensor,
|
||||
Tensor out_tensor);
|
||||
void three_interpolate_forward(Tensor points_tensor, Tensor idx_tensor,
|
||||
Tensor weight_tensor, Tensor out_tensor, int b,
|
||||
int c, int m, int n);
|
||||
|
||||
void three_interpolate_backward(int b, int c, int n, int m,
|
||||
Tensor grad_out_tensor, Tensor idx_tensor,
|
||||
Tensor weight_tensor,
|
||||
Tensor grad_points_tensor);
|
||||
void three_interpolate_backward(Tensor grad_out_tensor, Tensor idx_tensor,
|
||||
Tensor weight_tensor, Tensor grad_points_tensor,
|
||||
int b, int c, int n, int m);
|
||||
|
||||
void three_nn_forward(int b, int n, int m, Tensor unknown_tensor,
|
||||
Tensor known_tensor, Tensor dist2_tensor,
|
||||
Tensor idx_tensor);
|
||||
void three_nn_forward(Tensor unknown_tensor, Tensor known_tensor,
|
||||
Tensor dist2_tensor, Tensor idx_tensor, int b, int n,
|
||||
int m);
|
||||
|
||||
void bbox_overlaps(const Tensor bboxes1, const Tensor bboxes2, Tensor ious,
|
||||
const int mode, const bool aligned, const int offset);
|
||||
|
||||
void knn_forward(Tensor xyz_tensor, Tensor new_xyz_tensor, Tensor idx_tensor,
|
||||
Tensor dist2_tensor, int b, int n, int m, int nsample);
|
||||
void iou3d_boxes_overlap_bev_forward(Tensor boxes_a, Tensor boxes_b,
|
||||
Tensor ans_overlap);
|
||||
|
||||
|
@ -124,16 +124,13 @@ int iou3d_nms_forward(Tensor boxes, Tensor keep, float nms_overlap_thresh);
|
|||
int iou3d_nms_normal_forward(Tensor boxes, Tensor keep,
|
||||
float nms_overlap_thresh);
|
||||
|
||||
void knn_forward(int b, int n, int m, int nsample, Tensor xyz_tensor,
|
||||
Tensor new_xyz_tensor, Tensor idx_tensor, Tensor dist2_tensor);
|
||||
void furthest_point_sampling_forward(Tensor points_tensor, Tensor temp_tensor,
|
||||
Tensor idx_tensor, int b, int n, int m);
|
||||
|
||||
void furthest_point_sampling_forward(int b, int n, int m, Tensor points_tensor,
|
||||
Tensor temp_tensor, Tensor idx_tensor);
|
||||
|
||||
void furthest_point_sampling_with_dist_forward(int b, int n, int m,
|
||||
Tensor points_tensor,
|
||||
void furthest_point_sampling_with_dist_forward(Tensor points_tensor,
|
||||
Tensor temp_tensor,
|
||||
Tensor idx_tensor);
|
||||
Tensor idx_tensor, int b, int n,
|
||||
int m);
|
||||
|
||||
void masked_im2col_forward(const Tensor im, const Tensor mask_h_idx,
|
||||
const Tensor mask_w_idx, Tensor col,
|
||||
|
@ -238,9 +235,9 @@ void tin_shift_forward(Tensor input, Tensor shift, Tensor output);
|
|||
|
||||
void tin_shift_backward(Tensor grad_output, Tensor shift, Tensor grad_input);
|
||||
|
||||
void ball_query_forward(int b, int n, int m, float min_radius, float max_radius,
|
||||
int nsample, Tensor new_xyz_tensor, Tensor xyz_tensor,
|
||||
Tensor idx_tensor);
|
||||
void ball_query_forward(Tensor new_xyz_tensor, Tensor xyz_tensor,
|
||||
Tensor idx_tensor, int b, int n, int m,
|
||||
float min_radius, float max_radius, int nsample);
|
||||
|
||||
Tensor bottom_pool_forward(Tensor input);
|
||||
|
||||
|
@ -352,32 +349,31 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|||
py::arg("empty"), py::arg("act"), py::arg("grad"), py::arg("alpha"),
|
||||
py::arg("scale"));
|
||||
m.def("gather_points_forward", &gather_points_forward,
|
||||
"gather_points_forward", py::arg("b"), py::arg("c"), py::arg("n"),
|
||||
py::arg("npoints"), py::arg("points_tensor"), py::arg("idx_tensor"),
|
||||
py::arg("out_tensor"));
|
||||
"gather_points_forward", py::arg("points_tensor"),
|
||||
py::arg("idx_tensor"), py::arg("out_tensor"), py::arg("b"),
|
||||
py::arg("c"), py::arg("n"), py::arg("npoints"));
|
||||
m.def("gather_points_backward", &gather_points_backward,
|
||||
"gather_points_backward", py::arg("b"), py::arg("c"), py::arg("n"),
|
||||
py::arg("npoints"), py::arg("grad_out_tensor"), py::arg("idx_tensor"),
|
||||
py::arg("grad_points_tensor"));
|
||||
"gather_points_backward", py::arg("grad_out_tensor"),
|
||||
py::arg("idx_tensor"), py::arg("grad_points_tensor"), py::arg("b"),
|
||||
py::arg("c"), py::arg("n"), py::arg("npoints"));
|
||||
m.def("get_compiler_version", &get_compiler_version, "get_compiler_version");
|
||||
m.def("get_compiling_cuda_version", &get_compiling_cuda_version,
|
||||
"get_compiling_cuda_version");
|
||||
m.def("assign_score_withk_forward", &assign_score_withk_forward,
|
||||
"assign_score_withk_forward", py::arg("B"), py::arg("N0"),
|
||||
py::arg("N1"), py::arg("M"), py::arg("K"), py::arg("O"),
|
||||
py::arg("aggregate"), py::arg("points"), py::arg("centers"),
|
||||
py::arg("scores"), py::arg("knn_idx"), py::arg("output"));
|
||||
"assign_score_withk_forward", py::arg("points"), py::arg("centers"),
|
||||
py::arg("scores"), py::arg("knn_idx"), py::arg("output"), py::arg("B"),
|
||||
py::arg("N0"), py::arg("N1"), py::arg("M"), py::arg("K"), py::arg("O"),
|
||||
py::arg("aggregate"));
|
||||
m.def("assign_score_withk_backward", &assign_score_withk_backward,
|
||||
"assign_score_withk_backward", py::arg("B"), py::arg("N0"),
|
||||
py::arg("N1"), py::arg("M"), py::arg("K"), py::arg("O"),
|
||||
py::arg("aggregate"), py::arg("grad_out"), py::arg("points"),
|
||||
"assign_score_withk_backward", py::arg("grad_out"), py::arg("points"),
|
||||
py::arg("centers"), py::arg("scores"), py::arg("knn_idx"),
|
||||
py::arg("grad_points"), py::arg("grad_centers"),
|
||||
py::arg("grad_scores"));
|
||||
m.def("knn_forward", &knn_forward, "knn_forward", py::arg("b"), py::arg("n"),
|
||||
py::arg("m"), py::arg("nsample"), py::arg("xyz_tensor"),
|
||||
py::arg("grad_points"), py::arg("grad_centers"), py::arg("grad_scores"),
|
||||
py::arg("B"), py::arg("N0"), py::arg("N1"), py::arg("M"), py::arg("K"),
|
||||
py::arg("O"), py::arg("aggregate"));
|
||||
m.def("knn_forward", &knn_forward, "knn_forward", py::arg("xyz_tensor"),
|
||||
py::arg("new_xyz_tensor"), py::arg("idx_tensor"),
|
||||
py::arg("dist2_tensor"));
|
||||
py::arg("dist2_tensor"), py::arg("b"), py::arg("n"), py::arg("m"),
|
||||
py::arg("nsample"));
|
||||
m.def("carafe_naive_forward", &carafe_naive_forward, "carafe_naive_forward",
|
||||
py::arg("features"), py::arg("masks"), py::arg("output"),
|
||||
py::arg("kernel_size"), py::arg("group_size"), py::arg("scale_factor"));
|
||||
|
@ -447,17 +443,18 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|||
py::arg("weight"), py::arg("buff"), py::arg("grad_input"),
|
||||
py::arg("gamma"), py::arg("alpha"));
|
||||
m.def("three_interpolate_forward", &three_interpolate_forward,
|
||||
"three_interpolate_forward", py::arg("b"), py::arg("c"), py::arg("m"),
|
||||
py::arg("n"), py::arg("points_tensor"), py::arg("idx_tensor"),
|
||||
py::arg("weight_tensor"), py::arg("out_tensor"));
|
||||
"three_interpolate_forward", py::arg("points_tensor"),
|
||||
py::arg("idx_tensor"), py::arg("weight_tensor"), py::arg("out_tensor"),
|
||||
py::arg("b"), py::arg("c"), py::arg("m"), py::arg("n"));
|
||||
m.def("three_interpolate_backward", &three_interpolate_backward,
|
||||
"three_interpolate_backward", py::arg("b"), py::arg("c"), py::arg("n"),
|
||||
py::arg("m"), py::arg("grad_out_tensor"), py::arg("idx_tensor"),
|
||||
py::arg("weight_tensor"), py::arg("grad_points_tensor"));
|
||||
m.def("three_nn_forward", &three_nn_forward, "three_nn_forward", py::arg("b"),
|
||||
py::arg("n"), py::arg("m"), py::arg("unknown_tensor"),
|
||||
py::arg("known_tensor"), py::arg("dist2_tensor"),
|
||||
py::arg("idx_tensor"));
|
||||
"three_interpolate_backward", py::arg("grad_out_tensor"),
|
||||
py::arg("idx_tensor"), py::arg("weight_tensor"),
|
||||
py::arg("grad_points_tensor"), py::arg("b"), py::arg("c"), py::arg("n"),
|
||||
py::arg("m"));
|
||||
m.def("three_nn_forward", &three_nn_forward, "three_nn_forward",
|
||||
py::arg("unknown_tensor"), py::arg("known_tensor"),
|
||||
py::arg("dist2_tensor"), py::arg("idx_tensor"), py::arg("b"),
|
||||
py::arg("n"), py::arg("m"));
|
||||
m.def("bbox_overlaps", &bbox_overlaps, "bbox_overlaps", py::arg("bboxes1"),
|
||||
py::arg("bboxes2"), py::arg("ious"), py::arg("mode"),
|
||||
py::arg("aligned"), py::arg("offset"));
|
||||
|
@ -485,14 +482,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|||
"iou3d_nms_normal_forward", py::arg("boxes"), py::arg("keep"),
|
||||
py::arg("nms_overlap_thresh"));
|
||||
m.def("furthest_point_sampling_forward", &furthest_point_sampling_forward,
|
||||
"furthest_point_sampling_forward", py::arg("b"), py::arg("n"),
|
||||
py::arg("m"), py::arg("points_tensor"), py::arg("temp_tensor"),
|
||||
py::arg("idx_tensor"));
|
||||
"furthest_point_sampling_forward", py::arg("points_tensor"),
|
||||
py::arg("temp_tensor"), py::arg("idx_tensor"), py::arg("b"),
|
||||
py::arg("n"), py::arg("m"));
|
||||
m.def("furthest_point_sampling_with_dist_forward",
|
||||
&furthest_point_sampling_with_dist_forward,
|
||||
"furthest_point_sampling_with_dist_forward", py::arg("b"), py::arg("n"),
|
||||
py::arg("m"), py::arg("points_tensor"), py::arg("temp_tensor"),
|
||||
py::arg("idx_tensor"));
|
||||
"furthest_point_sampling_with_dist_forward", py::arg("points_tensor"),
|
||||
py::arg("temp_tensor"), py::arg("idx_tensor"), py::arg("b"),
|
||||
py::arg("n"), py::arg("m"));
|
||||
m.def("masked_im2col_forward", &masked_im2col_forward,
|
||||
"masked_im2col_forward", py::arg("im"), py::arg("mask_h_idx"),
|
||||
py::arg("mask_w_idx"), py::arg("col"), py::arg("kernel_h"),
|
||||
|
@ -609,9 +606,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|||
py::arg("scores"), py::arg("order"), py::arg("dets_sorted"),
|
||||
py::arg("iou_threshold"), py::arg("multi_label"));
|
||||
m.def("ball_query_forward", &ball_query_forward, "ball_query_forward",
|
||||
py::arg("new_xyz_tensor"), py::arg("xyz_tensor"), py::arg("idx_tensor"),
|
||||
py::arg("b"), py::arg("n"), py::arg("m"), py::arg("min_radius"),
|
||||
py::arg("max_radius"), py::arg("nsample"), py::arg("new_xyz_tensor"),
|
||||
py::arg("xyz_tensor"), py::arg("idx_tensor"));
|
||||
py::arg("max_radius"), py::arg("nsample"));
|
||||
m.def("roi_align_rotated_forward", &roi_align_rotated_forward,
|
||||
"roi_align_rotated forward", py::arg("input"), py::arg("rois"),
|
||||
py::arg("output"), py::arg("pooled_height"), py::arg("pooled_width"),
|
||||
|
@ -657,6 +654,19 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|||
"backward function of border_align", py::arg("grad_output"),
|
||||
py::arg("boxes"), py::arg("argmax_idx"), py::arg("grad_input"),
|
||||
py::arg("pool_size"));
|
||||
m.def("correlation_forward", &correlation_forward, "Correlation forward",
|
||||
py::arg("input1"), py::arg("input2"), py::arg("output"), py::arg("kH"),
|
||||
py::arg("kW"), py::arg("patchH"), py::arg("patchW"), py::arg("padH"),
|
||||
py::arg("padW"), py::arg("dilationH"), py::arg("dilationW"),
|
||||
py::arg("dilation_patchH"), py::arg("dilation_patchW"), py::arg("dH"),
|
||||
py::arg("dW"));
|
||||
m.def("correlation_backward", &correlation_backward, "Correlation backward",
|
||||
py::arg("grad_output"), py::arg("input1"), py::arg("input2"),
|
||||
py::arg("grad_input1"), py::arg("grad_input2"), py::arg("kH"),
|
||||
py::arg("kW"), py::arg("patchH"), py::arg("patchW"), py::arg("padH"),
|
||||
py::arg("padW"), py::arg("dilationH"), py::arg("dilationW"),
|
||||
py::arg("dilation_patchH"), py::arg("dilation_patchW"), py::arg("dH"),
|
||||
py::arg("dW"));
|
||||
m.def("points_in_boxes_cpu_forward", &points_in_boxes_cpu_forward,
|
||||
"points_in_boxes_cpu_forward", py::arg("boxes_tensor"),
|
||||
py::arg("pts_tensor"), py::arg("pts_indices_tensor"));
|
||||
|
@ -674,6 +684,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|||
"roiaware_pool3d_backward", py::arg("pts_idx_of_voxels"),
|
||||
py::arg("argmax"), py::arg("grad_out"), py::arg("grad_in"),
|
||||
py::arg("pool_method"));
|
||||
m.def("correlation_forward", &correlation_forward, "Correlation forward");
|
||||
m.def("correlation_backward", &correlation_backward, "Correlation backward");
|
||||
}
|
||||
|
|
|
@ -30,9 +30,9 @@ void three_interpolate_backward_cuda(int b, int c, int n, int m,
|
|||
};
|
||||
#endif
|
||||
|
||||
void three_interpolate_forward(int b, int c, int m, int n, Tensor points_tensor,
|
||||
Tensor idx_tensor, Tensor weight_tensor,
|
||||
Tensor out_tensor) {
|
||||
void three_interpolate_forward(Tensor points_tensor, Tensor idx_tensor,
|
||||
Tensor weight_tensor, Tensor out_tensor, int b,
|
||||
int c, int m, int n) {
|
||||
if (points_tensor.device().is_cuda()) {
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
three_interpolate_forward_cuda(b, c, m, n, points_tensor, idx_tensor,
|
||||
|
@ -45,10 +45,9 @@ void three_interpolate_forward(int b, int c, int m, int n, Tensor points_tensor,
|
|||
}
|
||||
}
|
||||
|
||||
void three_interpolate_backward(int b, int c, int n, int m,
|
||||
Tensor grad_out_tensor, Tensor idx_tensor,
|
||||
Tensor weight_tensor,
|
||||
Tensor grad_points_tensor) {
|
||||
void three_interpolate_backward(Tensor grad_out_tensor, Tensor idx_tensor,
|
||||
Tensor weight_tensor, Tensor grad_points_tensor,
|
||||
int b, int c, int n, int m) {
|
||||
if (grad_out_tensor.device().is_cuda()) {
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
three_interpolate_backward_cuda(b, c, n, m, grad_out_tensor, idx_tensor,
|
||||
|
|
|
@ -14,9 +14,9 @@ void three_nn_forward_cuda(int b, int n, int m, const Tensor unknown,
|
|||
};
|
||||
#endif
|
||||
|
||||
void three_nn_forward(int b, int n, int m, Tensor unknown_tensor,
|
||||
Tensor known_tensor, Tensor dist2_tensor,
|
||||
Tensor idx_tensor) {
|
||||
void three_nn_forward(Tensor unknown_tensor, Tensor known_tensor,
|
||||
Tensor dist2_tensor, Tensor idx_tensor, int b, int n,
|
||||
int m) {
|
||||
if (unknown_tensor.device().is_cuda()) {
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
three_nn_forward_cuda(b, n, m, unknown_tensor, known_tensor, dist2_tensor,
|
||||
|
|
|
@ -30,9 +30,16 @@ class FurthestPointSampling(Function):
|
|||
output = torch.cuda.IntTensor(B, num_points)
|
||||
temp = torch.cuda.FloatTensor(B, N).fill_(1e10)
|
||||
|
||||
ext_module.furthest_point_sampling_forward(B, N, num_points,
|
||||
points_xyz, temp, output)
|
||||
ctx.mark_non_differentiable(output)
|
||||
ext_module.furthest_point_sampling_forward(
|
||||
points_xyz,
|
||||
temp,
|
||||
output,
|
||||
b=B,
|
||||
n=N,
|
||||
m=num_points,
|
||||
)
|
||||
if torch.__version__ != 'parrots':
|
||||
ctx.mark_non_differentiable(output)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
|
@ -62,8 +69,9 @@ class FurthestPointSamplingWithDist(Function):
|
|||
temp = points_dist.new_zeros([B, N]).fill_(1e10)
|
||||
|
||||
ext_module.furthest_point_sampling_with_dist_forward(
|
||||
B, N, num_points, points_dist, temp, output)
|
||||
ctx.mark_non_differentiable(output)
|
||||
points_dist, temp, output, b=B, n=N, m=num_points)
|
||||
if torch.__version__ != 'parrots':
|
||||
ctx.mark_non_differentiable(output)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -28,11 +28,12 @@ class GatherPoints(Function):
|
|||
_, C, N = features.size()
|
||||
output = torch.cuda.FloatTensor(B, C, npoint)
|
||||
|
||||
ext_module.gather_points_forward(B, C, N, npoint, features, indices,
|
||||
output)
|
||||
ext_module.gather_points_forward(
|
||||
features, indices, output, b=B, c=C, n=N, npoints=npoint)
|
||||
|
||||
ctx.for_backwards = (indices, C, N)
|
||||
ctx.mark_non_differentiable(indices)
|
||||
if torch.__version__ != 'parrots':
|
||||
ctx.mark_non_differentiable(indices)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
|
@ -42,8 +43,14 @@ class GatherPoints(Function):
|
|||
|
||||
grad_features = torch.cuda.FloatTensor(B, C, N).zero_()
|
||||
grad_out_data = grad_out.data.contiguous()
|
||||
ext_module.gather_points_backward(B, C, N, npoint, grad_out_data, idx,
|
||||
grad_features.data)
|
||||
ext_module.gather_points_backward(
|
||||
grad_out_data,
|
||||
idx,
|
||||
grad_features.data,
|
||||
b=B,
|
||||
c=C,
|
||||
n=N,
|
||||
npoints=npoint)
|
||||
return grad_features, None
|
||||
|
||||
|
||||
|
|
|
@ -61,10 +61,12 @@ class KNN(Function):
|
|||
idx = center_xyz.new_zeros((B, npoint, k)).int()
|
||||
dist2 = center_xyz.new_zeros((B, npoint, k)).float()
|
||||
|
||||
ext_module.knn_forward(B, N, npoint, k, xyz, center_xyz, idx, dist2)
|
||||
ext_module.knn_forward(
|
||||
xyz, center_xyz, idx, dist2, b=B, n=N, m=npoint, nsample=k)
|
||||
# idx shape to [B, k, npoint]
|
||||
idx = idx.transpose(2, 1).contiguous()
|
||||
ctx.mark_non_differentiable(idx)
|
||||
if torch.__version__ != 'parrots':
|
||||
ctx.mark_non_differentiable(idx)
|
||||
return idx
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -39,8 +39,8 @@ class ThreeInterpolate(Function):
|
|||
ctx.three_interpolate_for_backward = (indices, weight, m)
|
||||
output = torch.cuda.FloatTensor(B, c, n)
|
||||
|
||||
ext_module.three_interpolate_forward(B, c, m, n, features, indices,
|
||||
weight, output)
|
||||
ext_module.three_interpolate_forward(
|
||||
features, indices, weight, output, b=B, c=c, m=m, n=n)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
|
@ -60,8 +60,8 @@ class ThreeInterpolate(Function):
|
|||
grad_features = torch.cuda.FloatTensor(B, c, m).zero_()
|
||||
grad_out_data = grad_out.data.contiguous()
|
||||
|
||||
ext_module.three_interpolate_backward(B, c, n, m, grad_out_data, idx,
|
||||
weight, grad_features.data)
|
||||
ext_module.three_interpolate_backward(
|
||||
grad_out_data, idx, weight, grad_features.data, b=B, c=c, n=n, m=m)
|
||||
return grad_features, None, None
|
||||
|
||||
|
||||
|
|
|
@ -37,9 +37,9 @@ class ThreeNN(Function):
|
|||
dist2 = torch.cuda.FloatTensor(B, N, 3)
|
||||
idx = torch.cuda.IntTensor(B, N, 3)
|
||||
|
||||
ext_module.three_nn_forward(B, N, m, target, source, dist2, idx)
|
||||
|
||||
ctx.mark_non_differentiable(idx)
|
||||
ext_module.three_nn_forward(target, source, dist2, idx, b=B, n=N, m=m)
|
||||
if torch.__version__ != 'parrots':
|
||||
ctx.mark_non_differentiable(idx)
|
||||
|
||||
return torch.sqrt(dist2), idx
|
||||
|
||||
|
|
Loading…
Reference in New Issue