mirror of https://github.com/open-mmlab/mmcv.git
[Features] Add stack ball query and stack group points ops (#2292)
* add stack sa model ops * fix lint * fix lint * fix comments * fix bug * fix lint * fix comments * fix lint * fix lint * fix
pull/2371/head
parent a4f304a5f5
commit 93fe4829f7
@@ -1,28 +1,44 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
from typing import Optional, Tuple

import torch
from torch.autograd import Function

from ..utils import ext_loader

ext_module = ext_loader.load_ext('_ext', ['ball_query_forward'])
ext_module = ext_loader.load_ext(
    '_ext', ['ball_query_forward', 'stack_ball_query_forward'])


class BallQuery(Function):
    """Find nearby points in spherical space."""

    @staticmethod
    def forward(ctx, min_radius: float, max_radius: float, sample_num: int,
                xyz: torch.Tensor, center_xyz: torch.Tensor) -> torch.Tensor:
    def forward(
            ctx,
            min_radius: float,
            max_radius: float,
            sample_num: int,
            xyz: torch.Tensor,
            center_xyz: torch.Tensor,
            xyz_batch_cnt: Optional[torch.Tensor] = None,
            center_xyz_batch_cnt: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            min_radius (float): minimum radius of the balls.
            max_radius (float): maximum radius of the balls.
            sample_num (int): maximum number of features in the balls.
            xyz (torch.Tensor): (B, N, 3) xyz coordinates of the features.
            xyz (torch.Tensor): (B, N, 3) xyz coordinates of the features,
                or stacked input (N1 + N2 ..., 3).
            center_xyz (torch.Tensor): (B, npoint, 3) centers of the ball
                query.
                query, or stacked input (M1 + M2 ..., 3).
            xyz_batch_cnt: (batch_size): Stacked input xyz coordinates nums in
                each batch, just like (N1, N2, ...). Defaults to None.
                New in version 1.7.0.
            center_xyz_batch_cnt: (batch_size): Stacked centers coordinates
                nums in each batch, just like (M1, M2, ...). Defaults to None.
                New in version 1.7.0.

        Returns:
            torch.Tensor: (B, npoint, nsample) tensor with the indices of the
@@ -31,21 +47,34 @@ class BallQuery(Function):
        assert center_xyz.is_contiguous()
        assert xyz.is_contiguous()
        assert min_radius < max_radius

        B, N, _ = xyz.size()
        npoint = center_xyz.size(1)
        idx = xyz.new_zeros(B, npoint, sample_num, dtype=torch.int)

        ext_module.ball_query_forward(
            center_xyz,
            xyz,
            idx,
            b=B,
            n=N,
            m=npoint,
            min_radius=min_radius,
            max_radius=max_radius,
            nsample=sample_num)
        if xyz_batch_cnt is not None and center_xyz_batch_cnt is not None:
            assert xyz_batch_cnt.dtype == torch.int
            assert center_xyz_batch_cnt.dtype == torch.int
            idx = center_xyz.new_zeros((center_xyz.shape[0], sample_num),
                                       dtype=torch.int32)
            ext_module.stack_ball_query_forward(
                center_xyz,
                center_xyz_batch_cnt,
                xyz,
                xyz_batch_cnt,
                idx,
                max_radius=max_radius,
                nsample=sample_num,
            )
        else:
            B, N, _ = xyz.size()
            npoint = center_xyz.size(1)
            idx = xyz.new_zeros(B, npoint, sample_num, dtype=torch.int32)
            ext_module.ball_query_forward(
                center_xyz,
                xyz,
                idx,
                b=B,
                n=N,
                m=npoint,
                min_radius=min_radius,
                max_radius=max_radius,
                nsample=sample_num)
        if torch.__version__ != 'parrots':
            ctx.mark_non_differentiable(idx)
        return idx
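For orientation, a minimal usage sketch of the two calling conventions handled above; the stacked call mirrors the new unit test further down, while the tensor sizes here are made up for illustration:

import torch
from mmcv.ops import ball_query

# Batched mode: xyz is (B, N, 3), center_xyz is (B, npoint, 3).
xyz = torch.rand(2, 256, 3).cuda()
center_xyz = torch.rand(2, 32, 3).cuda()
idx = ball_query(0, 0.2, 16, xyz, center_xyz)  # -> (2, 32, 16)

# Stacked mode (new in this PR): all samples are concatenated along dim 0
# and per-sample point counts are passed explicitly as int32 tensors.
xyz = torch.rand(512, 3).cuda()  # N1 + N2 = 256 + 256
xyz_batch_cnt = torch.tensor([256, 256], dtype=torch.int32).cuda()
center_xyz = torch.rand(64, 3).cuda()  # M1 + M2 = 32 + 32
center_xyz_batch_cnt = torch.tensor([32, 32], dtype=torch.int32).cuda()
idx = ball_query(0, 0.2, 16, xyz, center_xyz, xyz_batch_cnt,
                 center_xyz_batch_cnt)  # -> (64, 16)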
@@ -0,0 +1,68 @@
// Copyright (c) OpenMMLab. All rights reserved
// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/ball_query_gpu.cu
#ifndef STACK_BALL_QUERY_CUDA_KERNEL_CUH
#define STACK_BALL_QUERY_CUDA_KERNEL_CUH

#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif

template <typename T>
__global__ void stack_ball_query_forward_cuda_kernel(
    int B, int M, float radius, int nsample, const T *new_xyz,
    const int *new_xyz_batch_cnt, const T *xyz, const int *xyz_batch_cnt,
    int *idx) {
  // :param xyz: (N1 + N2 ..., 3) xyz coordinates of the features
  // :param xyz_batch_cnt: (batch_size), [N1, N2, ...]
  // :param new_xyz: (M1 + M2 ..., 3) centers of the ball query
  // :param new_xyz_batch_cnt: (batch_size), [M1, M2, ...]
  // output:
  //      idx: (M, nsample)
  const T *cur_xyz = xyz;
  int *cur_idx = idx;
  CUDA_1D_KERNEL_LOOP(pt_idx, M) {
    int bs_idx = 0;
    for (int pt_cnt = 0; bs_idx < B; bs_idx++) {
      pt_cnt += new_xyz_batch_cnt[bs_idx];
      if (pt_idx < pt_cnt) break;
    }

    int xyz_batch_start_idx = 0;
    for (int k = 0; k < bs_idx; k++) xyz_batch_start_idx += xyz_batch_cnt[k];

    const T *new_xyz_p = new_xyz + pt_idx * 3;
    cur_xyz += xyz_batch_start_idx * 3;
    cur_idx += pt_idx * nsample;

    float radius2 = radius * radius;
    T new_x = new_xyz_p[0];
    T new_y = new_xyz_p[1];
    T new_z = new_xyz_p[2];
    int n = xyz_batch_cnt[bs_idx];

    int cnt = 0;
    for (int k = 0; k < n; ++k) {
      T x = cur_xyz[k * 3 + 0];
      T y = cur_xyz[k * 3 + 1];
      T z = cur_xyz[k * 3 + 2];
      T d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) +
             (new_z - z) * (new_z - z);
      if (d2 < radius2) {
        if (cnt == 0) {
          for (int l = 0; l < nsample; ++l) {
            cur_idx[l] = k;
          }
        }
        cur_idx[cnt] = k;
        ++cnt;
        if (cnt >= nsample) break;
      }
    }
    if (cnt == 0) cur_idx[0] = -1;
  }
}

#endif  // STACK_BALL_QUERY_CUDA_KERNEL_CUH
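For readability, a rough pure-Python mirror of the per-thread logic in the kernel above (an illustrative sketch only, not part of the patch; xyz_sample stands for the slice of stacked points that belongs to one sample):

def stack_ball_query_single(new_xyz_p, xyz_sample, radius, nsample):
    """Indices of up to `nsample` points of one sample inside the ball."""
    radius2 = radius * radius
    out = [0] * nsample  # the idx tensor is zero-initialized by the wrapper
    cnt = 0
    for k, (x, y, z) in enumerate(xyz_sample):
        d2 = ((new_xyz_p[0] - x)**2 + (new_xyz_p[1] - y)**2 +
              (new_xyz_p[2] - z)**2)
        if d2 < radius2:
            if cnt == 0:
                out = [k] * nsample  # the first hit pre-fills every slot
            out[cnt] = k
            cnt += 1
            if cnt >= nsample:
                break
    if cnt == 0:
        out[0] = -1  # no neighbor found within max_radius
    return out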
@@ -0,0 +1,97 @@
// Copyright (c) OpenMMLab. All rights reserved.
// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/group_points_gpu.cu
#ifndef STACK_GROUP_POINTS_CUDA_KERNEL_CUH
#define STACK_GROUP_POINTS_CUDA_KERNEL_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
#include <stdio.h>
template <typename T>
__global__ void stack_group_points_forward_cuda_kernel(
    int b, int c, int m, int nsample, const T *features,
    const int *features_batch_cnt, const int *idx, const int *idx_batch_cnt,
    T *out) {
  // :param features: (N1 + N2 ..., C) tensor of features to group
  // :param features_batch_cnt: (batch_size), [N1, N2, ...] point counts per batch
  // :param idx: (M1 + M2 ..., nsample) tensor containing the indices of
  //     features to group with
  // :param idx_batch_cnt: (batch_size), [M1, M2, ...] index counts per batch
  // output:
  //      out: (M1 + M2 ..., C, nsample) tensor
  CUDA_1D_KERNEL_LOOP(index, m * c * nsample) {
    const T *cur_features = features;
    const int *cur_idx = idx;
    int sample_idx = index % nsample;
    int c_idx = (index / nsample) % c;
    int pt_idx = (index / nsample / c);

    if (pt_idx >= m || c_idx >= c || sample_idx >= nsample) return;
    int bs_idx = 0, pt_cnt = idx_batch_cnt[0];
    for (int k = 1; k < b; k++) {
      if (pt_idx < pt_cnt) break;
      pt_cnt += idx_batch_cnt[k];
      bs_idx = k;
    }

    int features_batch_start_idx = 0;
    int features_batch_end_idx = features_batch_cnt[0];
    for (int k = 0; k < bs_idx; k++) {
      features_batch_start_idx += features_batch_cnt[k];
      features_batch_end_idx =
          features_batch_start_idx + features_batch_cnt[k + 1];
    }
    cur_features += features_batch_start_idx * c;

    cur_idx += pt_idx * nsample + sample_idx;
    int in_idx = cur_idx[0] * c + c_idx;
    int out_idx = pt_idx * c * nsample + c_idx * nsample + sample_idx;
    if (in_idx < features_batch_end_idx * c) {
      out[out_idx] = cur_features[in_idx];
    }
  }
}

template <typename T>
__global__ void stack_group_points_backward_cuda_kernel(
    int b, int c, int m, int n, int nsample, const T *grad_out, const int *idx,
    const int *idx_batch_cnt, const int *features_batch_cnt, T *grad_features) {
  // :param grad_out: (M1 + M2 ..., C, nsample) tensor of the gradients of the
  //     output from forward
  // :param idx: (M1 + M2 ..., nsample) tensor containing the indices of
  //     features to group with
  // :param idx_batch_cnt: (batch_size), [M1, M2, ...] index counts per batch
  // :param features_batch_cnt: (batch_size), [N1, N2, ...] point counts per batch
  // output:
  //      grad_features: (N1 + N2 ..., C) gradient of the features
  CUDA_1D_KERNEL_LOOP(index, m * c * nsample) {
    const T *cur_grad_out = grad_out;
    const int *cur_idx = idx;
    T *cur_grad_features = grad_features;
    int sample_idx = index % nsample;
    int c_idx = (index / nsample) % c;
    int pt_idx = (index / nsample / c);

    if (pt_idx >= m || c_idx >= c || sample_idx >= nsample) return;

    int bs_idx = 0, pt_cnt = idx_batch_cnt[0];
    for (int k = 1; k < b; k++) {
      if (pt_idx < pt_cnt) break;
      pt_cnt += idx_batch_cnt[k];
      bs_idx = k;
    }

    int features_batch_start_idx = 0;
    for (int k = 0; k < bs_idx; k++)
      features_batch_start_idx += features_batch_cnt[k];

    cur_grad_out += pt_idx * c * nsample + c_idx * nsample + sample_idx;
    cur_idx += pt_idx * nsample + sample_idx;
    cur_grad_features += (features_batch_start_idx + cur_idx[0]) * c + c_idx;

    atomicAdd(cur_grad_features, cur_grad_out[0]);
  }
}

#endif  // STACK_GROUP_POINTS_CUDA_KERNEL_CUH
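The forward kernel is essentially a batched gather; as a hedged reference (illustration only, not shipped code), the index arithmetic it implements can be written in PyTorch as:

import torch

def stack_group_points_forward_ref(features, features_batch_cnt, idx,
                                   idx_batch_cnt):
    """features: (N1 + N2 ..., C); idx: (M1 + M2 ..., nsample);
    returns (M1 + M2 ..., C, nsample)."""
    M, nsample = idx.shape
    C = features.shape[1]
    out = features.new_zeros((M, C, nsample))
    # prefix sums give each sample's start offset in the stacked arrays
    idx_start = torch.cumsum(idx_batch_cnt, 0) - idx_batch_cnt
    feat_start = torch.cumsum(features_batch_cnt, 0) - features_batch_cnt
    for b in range(len(idx_batch_cnt)):
        rows = slice(int(idx_start[b]), int(idx_start[b] + idx_batch_cnt[b]))
        local = idx[rows].long() + int(feat_start[b])  # indices into features
        # out[pt, c, s] = features[local[pt, s], c]
        out[rows] = features[local].permute(0, 2, 1)
    return out

Note that the CUDA kernel additionally skips indices that fall outside the current sample's feature range, which this sketch does not reproduce.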
@@ -15,5 +15,6 @@ using at::Tensor;
using phalf = at::Half;

#define __PHALF(x) (x)
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))

#endif  // PYTORCH_CUDA_HELPER
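The added DIVUP(m, n) macro is integer ceiling division; the new .cu files below use it to compute how many thread blocks are needed to cover m work items. Equivalent arithmetic in Python, as a quick sanity check:

def divup(m: int, n: int) -> int:
    """Ceiling division: number of blocks of size n needed to cover m items."""
    return m // n + int(m % n > 0)

assert divup(1024, 256) == 4
assert divup(1025, 256) == 5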
@@ -18,3 +18,21 @@ void ball_query_forward(Tensor new_xyz_tensor, Tensor xyz_tensor,
  ball_query_forward_impl(b, n, m, min_radius, max_radius, nsample,
                          new_xyz_tensor, xyz_tensor, idx_tensor);
}

void stack_ball_query_forward_impl(float max_radius, int nsample,
                                   const Tensor new_xyz,
                                   const Tensor new_xyz_batch_cnt,
                                   const Tensor xyz, const Tensor xyz_batch_cnt,
                                   Tensor idx) {
  DISPATCH_DEVICE_IMPL(stack_ball_query_forward_impl, max_radius, nsample,
                       new_xyz, new_xyz_batch_cnt, xyz, xyz_batch_cnt, idx);
}

void stack_ball_query_forward(Tensor new_xyz_tensor, Tensor new_xyz_batch_cnt,
                              Tensor xyz_tensor, Tensor xyz_batch_cnt,
                              Tensor idx_tensor, float max_radius,
                              int nsample) {
  stack_ball_query_forward_impl(max_radius, nsample, new_xyz_tensor,
                                new_xyz_batch_cnt, xyz_tensor, xyz_batch_cnt,
                                idx_tensor);
}
@@ -67,6 +67,30 @@ void ball_query_forward_impl(int b, int n, int m, float min_radius,
                             Tensor idx);
REGISTER_DEVICE_IMPL(ball_query_forward_impl, CUDA, ball_query_forward_cuda);

void StackBallQueryForwardCUDAKernelLauncher(float max_radius, int nsample,
                                             const Tensor new_xyz,
                                             const Tensor new_xyz_batch_cnt,
                                             const Tensor xyz,
                                             const Tensor xyz_batch_cnt,
                                             Tensor idx);

void stack_ball_query_forward_cuda(float max_radius, int nsample,
                                   const Tensor new_xyz,
                                   const Tensor new_xyz_batch_cnt,
                                   const Tensor xyz, const Tensor xyz_batch_cnt,
                                   Tensor idx) {
  StackBallQueryForwardCUDAKernelLauncher(
      max_radius, nsample, new_xyz, new_xyz_batch_cnt, xyz, xyz_batch_cnt, idx);
};

void stack_ball_query_forward_impl(float max_radius, int nsample,
                                   const Tensor new_xyz,
                                   const Tensor new_xyz_batch_cnt,
                                   const Tensor xyz, const Tensor xyz_batch_cnt,
                                   Tensor idx);
REGISTER_DEVICE_IMPL(stack_ball_query_forward_impl, CUDA,
                     stack_ball_query_forward_cuda);

void BBoxOverlapsCUDAKernelLauncher(const Tensor bboxes1, const Tensor bboxes2,
                                    Tensor ious, const int mode,
                                    const bool aligned, const int offset);
@@ -571,6 +595,56 @@ REGISTER_DEVICE_IMPL(group_points_forward_impl, CUDA,
REGISTER_DEVICE_IMPL(group_points_backward_impl, CUDA,
                     group_points_backward_cuda);

void StackGroupPointsForwardCUDAKernelLauncher(
    int b, int c, int m, int nsample, const Tensor features_tensor,
    const Tensor features_batch_cnt_tensor, const Tensor idx_tensor,
    const Tensor idx_batch_cnt_tensor, Tensor out_tensor);
void StackGroupPointsBackwardCUDAKernelLauncher(
    int b, int c, int m, int n, int nsample, const Tensor grad_out_tensor,
    const Tensor idx_tensor, const Tensor idx_batch_cnt_tensor,
    const Tensor features_batch_cnt_tensor, Tensor grad_features_tensor);

void stack_group_points_forward_cuda(int b, int c, int m, int nsample,
                                     const Tensor features_tensor,
                                     const Tensor features_batch_cnt_tensor,
                                     const Tensor idx_tensor,
                                     const Tensor idx_batch_cnt_tensor,
                                     Tensor out_tensor) {
  StackGroupPointsForwardCUDAKernelLauncher(
      b, c, m, nsample, features_tensor, features_batch_cnt_tensor, idx_tensor,
      idx_batch_cnt_tensor, out_tensor);
};

void stack_group_points_backward_cuda(int b, int c, int m, int n, int nsample,
                                      const Tensor grad_out_tensor,
                                      const Tensor idx_tensor,
                                      const Tensor idx_batch_cnt_tensor,
                                      const Tensor features_batch_cnt_tensor,
                                      Tensor grad_features_tensor) {
  StackGroupPointsBackwardCUDAKernelLauncher(
      b, c, m, n, nsample, grad_out_tensor, idx_tensor, idx_batch_cnt_tensor,
      features_batch_cnt_tensor, grad_features_tensor);
};

void stack_group_points_forward_impl(int b, int c, int m, int nsample,
                                     const Tensor features_tensor,
                                     const Tensor features_batch_cnt_tensor,
                                     const Tensor idx_tensor,
                                     const Tensor idx_batch_cnt_tensor,
                                     Tensor out_tensor);

void stack_group_points_backward_impl(int b, int c, int m, int n, int nsample,
                                      const Tensor grad_out_tensor,
                                      const Tensor idx_tensor,
                                      const Tensor idx_batch_cnt_tensor,
                                      const Tensor features_batch_cnt_tensor,
                                      Tensor grad_features_tensor);

REGISTER_DEVICE_IMPL(stack_group_points_forward_impl, CUDA,
                     stack_group_points_forward_cuda);
REGISTER_DEVICE_IMPL(stack_group_points_backward_impl, CUDA,
                     stack_group_points_backward_cuda);

void IoU3DBoxesOverlapBevForwardCUDAKernelLauncher(const int num_a,
                                                   const Tensor boxes_a,
                                                   const int num_b,
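REGISTER_DEVICE_IMPL and DISPATCH_DEVICE_IMPL are mmcv's device-dispatch macros: the device-agnostic *_impl function forwards to whatever implementation a backend has registered for it. A loose Python analogy of the pattern, with names invented purely for illustration (this is not mmcv's actual machinery):

_device_impls = {}  # (op_name, device) -> callable

def register_device_impl(op_name, device, fn):
    _device_impls[(op_name, device)] = fn

def dispatch_device_impl(op_name, device, *args):
    # route the call to the implementation registered for this device
    return _device_impls[(op_name, device)](*args)

# the CUDA binding registers its launcher once at load time ...
register_device_impl('stack_ball_query_forward_impl', 'cuda',
                     lambda *args: print('launching CUDA kernel'))
# ... and the generic op simply forwards to it
dispatch_device_impl('stack_ball_query_forward_impl', 'cuda')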
@@ -0,0 +1,45 @@
// Copyright (c) OpenMMLab. All rights reserved
// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/ball_query_gpu.cu

#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include "pytorch_cuda_helper.hpp"
#include "stack_ball_query_cuda_kernel.cuh"
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))

void StackBallQueryForwardCUDAKernelLauncher(float max_radius, int nsample,
                                             const Tensor new_xyz,
                                             const Tensor new_xyz_batch_cnt,
                                             const Tensor xyz,
                                             const Tensor xyz_batch_cnt,
                                             Tensor idx) {
  at::cuda::CUDAGuard device_guard(new_xyz.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // const float *new_xyz_ptr = new_xyz.data_ptr<float>();
  // const float *xyz_ptr = xyz.data_ptr<float>();
  // const int *new_xyz_batch_cnt_ptr = new_xyz_batch_cnt.data_ptr<int>();
  // const int *xyz_batch_cnt_ptr = xyz_batch_cnt.data_ptr<int>();
  // int *idx_ptr = idx.data_ptr<int>();

  int B = xyz_batch_cnt.size(0);
  int M = new_xyz.size(0);

  // blockIdx.x(col), blockIdx.y(row)
  dim3 blocks(DIVUP(M, THREADS_PER_BLOCK));
  dim3 threads(THREADS_PER_BLOCK);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      new_xyz.scalar_type(), "stack_ball_query_forward_cuda_kernel", [&] {
        stack_ball_query_forward_cuda_kernel<scalar_t>
            <<<blocks, threads, 0, stream>>>(
                B, M, max_radius, nsample, new_xyz.data_ptr<scalar_t>(),
                new_xyz_batch_cnt.data_ptr<int>(), xyz.data_ptr<scalar_t>(),
                xyz_batch_cnt.data_ptr<int>(), idx.data_ptr<int>());
      });

  AT_CUDA_CHECK(cudaGetLastError());
}
@@ -0,0 +1,62 @@
// Copyright (c) OpenMMLab. All rights reserved.
// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/group_points_gpu.cu
#include <stdio.h>
#include <stdlib.h>

#include "pytorch_cuda_helper.hpp"
#include "stack_group_points_cuda_kernel.cuh"

void StackGroupPointsForwardCUDAKernelLauncher(
    int b, int c, int m, int nsample, const Tensor features_tensor,
    const Tensor features_batch_cnt_tensor, const Tensor idx_tensor,
    const Tensor idx_batch_cnt_tensor, Tensor out_tensor) {
  // features: (N1 + N2 ..., C)
  // idx: (M1 + M2 ..., nsample)
  // output:
  //      out: (M1 + M2 ..., C, nsample)
  at::cuda::CUDAGuard device_guard(features_tensor.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  dim3 blocks(DIVUP(m * c * nsample, THREADS_PER_BLOCK));
  dim3 threads(THREADS_PER_BLOCK);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      features_tensor.scalar_type(), "stack_group_points_forward_cuda_kernel",
      [&] {
        stack_group_points_forward_cuda_kernel<scalar_t>
            <<<blocks, threads, 0, stream>>>(
                b, c, m, nsample, features_tensor.data_ptr<scalar_t>(),
                features_batch_cnt_tensor.data_ptr<int>(),
                idx_tensor.data_ptr<int>(),
                idx_batch_cnt_tensor.data_ptr<int>(),
                out_tensor.data_ptr<scalar_t>());
      });

  AT_CUDA_CHECK(cudaGetLastError());
}

void StackGroupPointsBackwardCUDAKernelLauncher(
    int b, int c, int m, int n, int nsample, const Tensor grad_out_tensor,
    const Tensor idx_tensor, const Tensor idx_batch_cnt_tensor,
    const Tensor features_batch_cnt_tensor, Tensor grad_features_tensor) {
  at::cuda::CUDAGuard device_guard(grad_features_tensor.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  dim3 blocks(DIVUP(m * c * nsample, THREADS_PER_BLOCK));
  dim3 threads(THREADS_PER_BLOCK);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad_features_tensor.scalar_type(),
      "stack_group_points_backward_cuda_kernel", [&] {
        stack_group_points_backward_cuda_kernel<scalar_t>
            <<<blocks, threads, 0, stream>>>(
                b, c, m, n, nsample, grad_out_tensor.data_ptr<scalar_t>(),
                idx_tensor.data_ptr<int>(),
                idx_batch_cnt_tensor.data_ptr<int>(),
                features_batch_cnt_tensor.data_ptr<int>(),
                grad_features_tensor.data_ptr<scalar_t>());
      });

  AT_CUDA_CHECK(cudaGetLastError());
}
@@ -32,3 +32,45 @@ void group_points_backward(Tensor grad_out_tensor, Tensor idx_tensor,
  group_points_backward_impl(b, c, n, npoints, nsample, grad_out_tensor,
                             idx_tensor, grad_points_tensor);
}

void stack_group_points_backward_impl(int b, int c, int m, int n, int nsample,
                                      const Tensor grad_out_tensor,
                                      const Tensor idx_tensor,
                                      const Tensor idx_batch_cnt_tensor,
                                      const Tensor features_batch_cnt_tensor,
                                      Tensor grad_features_tensor) {
  DISPATCH_DEVICE_IMPL(stack_group_points_backward_impl, b, c, m, n, nsample,
                       grad_out_tensor, idx_tensor, idx_batch_cnt_tensor,
                       features_batch_cnt_tensor, grad_features_tensor);
}

void stack_group_points_backward(Tensor grad_out_tensor, Tensor idx_tensor,
                                 Tensor idx_batch_cnt_tensor,
                                 Tensor features_batch_cnt_tensor,
                                 Tensor grad_features_tensor, int b, int c,
                                 int m, int n, int nsample) {
  stack_group_points_backward_impl(
      b, c, m, n, nsample, grad_out_tensor, idx_tensor, idx_batch_cnt_tensor,
      features_batch_cnt_tensor, grad_features_tensor);
}

void stack_group_points_forward_impl(int b, int c, int m, int nsample,
                                     const Tensor features_tensor,
                                     const Tensor features_batch_cnt_tensor,
                                     const Tensor idx_tensor,
                                     const Tensor idx_batch_cnt_tensor,
                                     Tensor out_tensor) {
  DISPATCH_DEVICE_IMPL(stack_group_points_forward_impl, b, c, m, nsample,
                       features_tensor, features_batch_cnt_tensor, idx_tensor,
                       idx_batch_cnt_tensor, out_tensor);
}

void stack_group_points_forward(Tensor features_tensor,
                                Tensor features_batch_cnt_tensor,
                                Tensor idx_tensor, Tensor idx_batch_cnt_tensor,
                                Tensor out_tensor, int b, int c, int m,
                                int nsample) {
  DISPATCH_DEVICE_IMPL(stack_group_points_forward_impl, b, c, m, nsample,
                       features_tensor, features_batch_cnt_tensor, idx_tensor,
                       idx_batch_cnt_tensor, out_tensor);
}
@@ -75,6 +75,18 @@ void group_points_backward(Tensor grad_out_tensor, Tensor idx_tensor,
                           Tensor grad_points_tensor, int b, int c, int n,
                           int npoints, int nsample);

void stack_group_points_forward(Tensor features_tensor,
                                Tensor features_batch_cnt_tensor,
                                Tensor idx_tensor, Tensor idx_batch_cnt_tensor,
                                Tensor out_tensor, int b, int c, int m,
                                int nsample);

void stack_group_points_backward(Tensor grad_out_tensor, Tensor idx_tensor,
                                 Tensor idx_batch_cnt_tensor,
                                 Tensor features_batch_cnt_tensor,
                                 Tensor grad_features_tensor, int b, int c,
                                 int m, int n, int nsample);

void roipoint_pool3d_forward(Tensor xyz, Tensor boxes3d, Tensor pts_feature,
                             Tensor pooled_features, Tensor pooled_empty_flag);
@@ -240,6 +252,10 @@ void ball_query_forward(Tensor new_xyz_tensor, Tensor xyz_tensor,
                        Tensor idx_tensor, int b, int n, int m,
                        float min_radius, float max_radius, int nsample);

void stack_ball_query_forward(Tensor new_xyz_tensor, Tensor new_xyz_batch_cnt,
                              Tensor xyz_tensor, Tensor xyz_batch_cnt,
                              Tensor idx_tensor, float max_radius, int nsample);

void prroi_pool_forward(Tensor input, Tensor rois, Tensor output,
                        int pooled_height, int pooled_width,
                        float spatial_scale);
@@ -557,6 +573,17 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        "group_points_backward", py::arg("grad_out_tensor"),
        py::arg("idx_tensor"), py::arg("grad_points_tensor"), py::arg("b"),
        py::arg("c"), py::arg("n"), py::arg("npoints"), py::arg("nsample"));
  m.def("stack_group_points_forward", &stack_group_points_forward,
        "stack_group_points_forward", py::arg("features_tensor"),
        py::arg("features_batch_cnt_tensor"), py::arg("idx_tensor"),
        py::arg("idx_batch_cnt_tensor"), py::arg("out_tensor"), py::arg("b"),
        py::arg("c"), py::arg("m"), py::arg("nsample"));
  m.def("stack_group_points_backward", &stack_group_points_backward,
        "stack_group_points_backward", py::arg("grad_out_tensor"),
        py::arg("idx_tensor"), py::arg("idx_batch_cnt_tensor"),
        py::arg("features_batch_cnt_tensor"), py::arg("grad_features_tensor"),
        py::arg("b"), py::arg("c"), py::arg("m"), py::arg("n"),
        py::arg("nsample"));
  m.def("knn_forward", &knn_forward, "knn_forward", py::arg("b"), py::arg("n"),
        py::arg("m"), py::arg("nsample"), py::arg("xyz_tensor"),
        py::arg("new_xyz_tensor"), py::arg("idx_tensor"),
@@ -726,6 +753,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        py::arg("new_xyz_tensor"), py::arg("xyz_tensor"), py::arg("idx_tensor"),
        py::arg("b"), py::arg("n"), py::arg("m"), py::arg("min_radius"),
        py::arg("max_radius"), py::arg("nsample"));
  m.def("stack_ball_query_forward", &stack_ball_query_forward,
        "stack_ball_query_forward", py::arg("new_xyz_tensor"),
        py::arg("new_xyz_batch_cnt"), py::arg("xyz_tensor"),
        py::arg("xyz_batch_cnt"), py::arg("idx_tensor"), py::arg("max_radius"),
        py::arg("nsample"));
  m.def("roi_align_rotated_forward", &roi_align_rotated_forward,
        "roi_align_rotated forward", py::arg("input"), py::arg("rois"),
        py::arg("output"), py::arg("pooled_height"), py::arg("pooled_width"),
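Because each m.def above declares py::arg names, the new extension entry points can be called from Python with keyword arguments, which is how the wrappers in the ops modules invoke them. A hedged sketch (the ext_loader import path mirrors what those modules use internally; shapes are arbitrary):

import torch
from mmcv.utils import ext_loader

ext_module = ext_loader.load_ext(
    '_ext', ['ball_query_forward', 'stack_ball_query_forward'])

xyz = torch.rand(20, 3).cuda()
xyz_batch_cnt = torch.tensor([10, 10], dtype=torch.int32).cuda()
center_xyz = torch.rand(10, 3).cuda()
center_xyz_batch_cnt = torch.tensor([5, 5], dtype=torch.int32).cuda()
idx = center_xyz.new_zeros((10, 16), dtype=torch.int32)

# tensors positionally, scalars by keyword, matching the py::arg declarations
ext_module.stack_ball_query_forward(
    center_xyz, center_xyz_batch_cnt, xyz, xyz_batch_cnt, idx,
    max_radius=0.2, nsample=16)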
@@ -9,8 +9,10 @@ from ..utils import ext_loader
from .ball_query import ball_query
from .knn import knn

ext_module = ext_loader.load_ext(
    '_ext', ['group_points_forward', 'group_points_backward'])
ext_module = ext_loader.load_ext('_ext', [
    'group_points_forward', 'group_points_backward',
    'stack_group_points_forward', 'stack_group_points_backward'
])


class QueryAndGroup(nn.Module):
@@ -183,39 +185,71 @@ class GroupingOperation(Function):
    """Group feature with given index."""

    @staticmethod
    def forward(ctx, features: torch.Tensor,
                indices: torch.Tensor) -> torch.Tensor:
    def forward(
            ctx,
            features: torch.Tensor,
            indices: torch.Tensor,
            features_batch_cnt: Optional[torch.Tensor] = None,
            indices_batch_cnt: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            features (Tensor): (B, C, N) tensor of features to group.
            indices (Tensor): (B, npoint, nsample) the indices of
                features to group with.
            features (Tensor): Tensor of features to group, input shape is
                (B, C, N) or stacked inputs (N1 + N2 ..., C).
            indices (Tensor): The indices of features to group with, input
                shape is (B, npoint, nsample) or stacked inputs
                (M1 + M2 ..., nsample).
            features_batch_cnt (Tensor, optional): Input features nums in
                each batch, just like (N1, N2, ...). Defaults to None.
                New in version 1.7.0.
            indices_batch_cnt (Tensor, optional): Input indices nums in
                each batch, just like (M1, M2, ...). Defaults to None.
                New in version 1.7.0.

        Returns:
            Tensor: (B, C, npoint, nsample) Grouped features.
            Tensor: Grouped features, the shape is (B, C, npoint, nsample)
                or (M1 + M2 ..., C, nsample).
        """
        features = features.contiguous()
        indices = indices.contiguous()
        if features_batch_cnt is not None and indices_batch_cnt is not None:
            assert features_batch_cnt.dtype == torch.int
            assert indices_batch_cnt.dtype == torch.int
            M, nsample = indices.size()
            N, C = features.size()
            B = indices_batch_cnt.shape[0]
            output = features.new_zeros((M, C, nsample))
            ext_module.stack_group_points_forward(
                features,
                features_batch_cnt,
                indices,
                indices_batch_cnt,
                output,
                b=B,
                m=M,
                c=C,
                nsample=nsample)
            ctx.for_backwards = (B, N, indices, features_batch_cnt,
                                 indices_batch_cnt)
        else:
        B, nfeatures, nsample = indices.size()
        _, C, N = features.size()
        output = torch.cuda.FloatTensor(B, C, nfeatures, nsample)

            B, nfeatures, nsample = indices.size()
            _, C, N = features.size()
            output = torch.cuda.FloatTensor(B, C, nfeatures, nsample)
        ext_module.group_points_forward(
            features,
            indices,
            output,
            b=B,
            c=C,
            n=N,
            npoints=nfeatures,
            nsample=nsample)

            ext_module.group_points_forward(
                features,
                indices,
                output,
                b=B,
                c=C,
                n=N,
                npoints=nfeatures,
                nsample=nsample)

        ctx.for_backwards = (indices, N)
            ctx.for_backwards = (indices, N)
        return output

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, None]:
    def backward(ctx, grad_out: torch.Tensor) -> Tuple:
        """
        Args:
            grad_out (Tensor): (B, C, npoint, nsample) tensor of the gradients
@@ -224,22 +258,42 @@ class GroupingOperation(Function):
        Returns:
            Tensor: (B, C, N) gradient of the features.
        """
        idx, N = ctx.for_backwards
        if len(ctx.for_backwards) != 5:
            idx, N = ctx.for_backwards

        B, C, npoint, nsample = grad_out.size()
        grad_features = torch.cuda.FloatTensor(B, C, N).zero_()
            B, C, npoint, nsample = grad_out.size()
            grad_features = torch.cuda.FloatTensor(B, C, N).zero_()

        grad_out_data = grad_out.data.contiguous()
        ext_module.group_points_backward(
            grad_out_data,
            idx,
            grad_features.data,
            b=B,
            c=C,
            n=N,
            npoints=npoint,
            nsample=nsample)
        return grad_features, None
            grad_out_data = grad_out.data.contiguous()
            ext_module.group_points_backward(
                grad_out_data,
                idx,
                grad_features.data,
                b=B,
                c=C,
                n=N,
                npoints=npoint,
                nsample=nsample)
            return grad_features, None
        else:
            B, N, idx, features_batch_cnt, idx_batch_cnt = ctx.for_backwards

            M, C, nsample = grad_out.size()
            grad_features = torch.cuda.FloatTensor(N, C).zero_()

            grad_out_data = grad_out.data.contiguous()
            ext_module.stack_group_points_backward(
                grad_out_data,
                idx,
                idx_batch_cnt,
                features_batch_cnt,
                grad_features.data,
                b=B,
                c=C,
                m=M,
                n=N,
                nsample=nsample)
            return grad_features, None, None, None


grouping_operation = GroupingOperation.apply
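A brief usage sketch of the stacked grouping path added above (it mirrors the new unit test at the end of this diff; the zero idx tensor is only for shape illustration):

import torch
from mmcv.ops import grouping_operation

features = torch.rand(6, 10).cuda()  # stacked (N1 + N2, C) with N1 = N2 = 3
features_batch_cnt = torch.tensor([3, 3], dtype=torch.int32).cuda()
idx = torch.zeros(12, 3, dtype=torch.int32).cuda()  # (M1 + M2, nsample)
idx_batch_cnt = torch.tensor([6, 6], dtype=torch.int32).cuda()

out = grouping_operation(features, idx, features_batch_cnt, idx_batch_cnt)
print(out.shape)  # torch.Size([12, 10, 3])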
Binary file not shown.
@@ -53,3 +53,50 @@ def test_ball_query():
                                  [7, 7, 7, 7, 7], [0, 0, 0, 0, 0],
                                  [0, 0, 0, 0, 0]]]).cuda()
    assert torch.all(idx == expected_idx)


@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='requires CUDA support')
def test_stack_ball_query():
    new_xyz = torch.tensor([[-0.0740, 1.3147, -1.3625],
                            [-2.2769, 2.7817, -0.2334],
                            [-0.4003, 2.4666, -0.5116],
                            [-0.0740, 1.3147, -1.3625],
                            [-0.0740, 1.3147, -1.3625],
                            [-2.0289, 2.4952, -0.1708],
                            [-2.0668, 6.0278, -0.4875],
                            [0.4066, 1.4211, -0.2947],
                            [-2.0289, 2.4952, -0.1708],
                            [-2.0289, 2.4952, -0.1708]]).cuda()
    new_xyz_batch_cnt = torch.tensor([5, 5], dtype=torch.int32).cuda()
    xyz = torch.tensor([[-0.0740, 1.3147, -1.3625], [0.5555, 1.0399, -1.3634],
                        [-0.4003, 2.4666, -0.5116], [-0.5251, 2.4379, -0.8466],
                        [-0.9691, 1.1418, -1.3733], [-0.2232, 0.9561, -1.3626],
                        [-2.2769, 2.7817, -0.2334], [-0.2822, 1.3192, -1.3645],
                        [0.1533, 1.5024, -1.0432], [0.4917, 1.1529, -1.3496],
                        [-2.0289, 2.4952, -0.1708], [-0.7188, 0.9956, -0.5096],
                        [-2.0668, 6.0278, -0.4875], [-1.9304, 3.3092, 0.6610],
                        [0.0949, 1.4332, 0.3140], [-1.2879, 2.0008, -0.7791],
                        [-0.7252, 0.9611, -0.6371], [0.4066, 1.4211, -0.2947],
                        [0.3220, 1.4447, 0.3548], [-0.9744, 2.3856,
                                                   -1.2000]]).cuda()
    xyz_batch_cnt = torch.tensor([10, 10], dtype=torch.int32).cuda()
    idx = ball_query(0, 0.2, 5, xyz, new_xyz, xyz_batch_cnt, new_xyz_batch_cnt)
    expected_idx = torch.tensor([[0, 0, 0, 0, 0], [6, 6, 6, 6, 6],
                                 [2, 2, 2, 2, 2], [0, 0, 0, 0, 0],
                                 [0, 0, 0, 0, 0], [0, 0, 0, 0, 0],
                                 [2, 2, 2, 2, 2], [7, 7, 7, 7, 7],
                                 [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]).cuda()
    assert torch.all(idx == expected_idx)

    xyz = xyz.double()
    new_xyz = new_xyz.double()
    expected_idx = expected_idx.double()
    idx = ball_query(0, 0.2, 5, xyz, new_xyz, xyz_batch_cnt, new_xyz_batch_cnt)
    assert torch.all(idx == expected_idx)

    xyz = xyz.half()
    new_xyz = new_xyz.half()
    expected_idx = expected_idx.half()
    idx = ball_query(0, 0.2, 5, xyz, new_xyz, xyz_batch_cnt, new_xyz_batch_cnt)
    assert torch.all(idx == expected_idx)
@@ -12,7 +12,7 @@ def test_grouping_points():
                         [0, 0, 0]],
                        [[0, 0, 0], [6, 6, 6], [9, 9, 9], [0, 0, 0], [0, 0, 0],
                         [0, 0, 0]]]).int().cuda()
    festures = torch.tensor([[[
    features = torch.tensor([[[
        0.5798, -0.7981, -0.9280, -1.3311, 1.3687, 0.9277, -0.4164, -1.8274,
        0.9268, 0.8414
    ],
@@ -37,7 +37,7 @@ def test_grouping_points():
                                 -1.4049, 0.4990, -0.7037, -0.9924, 0.0386
                             ]]]).cuda()

    output = grouping_operation(festures, idx)
    output = grouping_operation(features, idx)
    expected_output = torch.tensor([[[[0.5798, 0.5798, 0.5798],
                                      [-1.3311, -1.3311, -1.3311],
                                      [0.9268, 0.9268, 0.9268],
@@ -75,3 +75,161 @@ def test_grouping_points():
                                      [-0.6646, -0.6646, -0.6646],
                                      [-0.6646, -0.6646, -0.6646]]]]).cuda()
    assert torch.allclose(output, expected_output)


@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='requires CUDA support')
def test_stack_grouping_points():
    idx = torch.tensor([[0, 0, 0], [3, 3, 3], [8, 8, 8], [1, 1, 1], [0, 0, 0],
                        [2, 2, 2], [0, 0, 0], [6, 6, 6], [9, 9, 9], [0, 0, 0],
                        [1, 1, 1], [0, 0, 0]]).int().cuda()
    features = torch.tensor([[
        0.5798, -0.7981, -0.9280, -1.3311, 1.3687, 0.9277, -0.4164, -1.8274,
        0.9268, 0.8414
    ],
                             [
                                 5.4247, 1.5113, 2.3944, 1.4740, 5.0300,
                                 5.1030, 1.9360, 2.1939, 2.1581, 3.4666
                             ],
                             [
                                 -1.6266, -1.0281, -1.0393, -1.6931, -1.3982,
                                 -0.5732, -1.0830, -1.7561, -1.6786, -1.6967
                             ],
                             [
                                 -0.0380, -0.1880, -1.5724, 0.6905, -0.3190,
                                 0.7798, -0.3693, -0.9457, -0.2942, -1.8527
                             ],
                             [
                                 1.1773, 1.5009, 2.6399, 5.9242, 1.0962,
                                 2.7346, 6.0865, 1.5555, 4.3303, 2.8229
                             ],
                             [
                                 -0.6646, -0.6870, -0.1125, -0.2224, -0.3445,
                                 -1.4049, 0.4990, -0.7037, -0.9924, 0.0386
                             ]]).float().cuda()
    features_batch_cnt = torch.tensor([3, 3]).int().cuda()
    indices_batch_cnt = torch.tensor([6, 6]).int().cuda()
    output = grouping_operation(features, idx, features_batch_cnt,
                                indices_batch_cnt)
    expected_output = torch.Tensor([[[0.5798, 0.5798, 0.5798],
                                     [-0.7981, -0.7981, -0.7981],
                                     [-0.9280, -0.9280, -0.9280],
                                     [-1.3311, -1.3311, -1.3311],
                                     [1.3687, 1.3687, 1.3687],
                                     [0.9277, 0.9277, 0.9277],
                                     [-0.4164, -0.4164, -0.4164],
                                     [-1.8274, -1.8274, -1.8274],
                                     [0.9268, 0.9268, 0.9268],
                                     [0.8414, 0.8414, 0.8414]],
                                    [[0.0000, 0.0000, 0.0000],
                                     [0.0000, 0.0000, 0.0000],
                                     [0.0000, 0.0000, 0.0000],
                                     [0.0000, 0.0000, 0.0000],
                                     [0.0000, 0.0000, 0.0000],
                                     [0.0000, 0.0000, 0.0000],
                                     [0.0000, 0.0000, 0.0000],
                                     [0.0000, 0.0000, 0.0000],
                                     [0.0000, 0.0000, 0.0000],
                                     [0.0000, 0.0000, 0.0000]],
                                    [[0.0000, 0.0000, 0.0000],
                                     [0.0000, 0.0000, 0.0000],
                                     [0.0000, 0.0000, 0.0000],
                                     [0.0000, 0.0000, 0.0000],
                                     [0.0000, 0.0000, 0.0000],
                                     [0.0000, 0.0000, 0.0000],
                                     [0.0000, 0.0000, 0.0000],
                                     [0.0000, 0.0000, 0.0000],
                                     [0.0000, 0.0000, 0.0000],
                                     [0.0000, 0.0000, 0.0000]],
                                    [[5.4247, 5.4247, 5.4247],
                                     [1.5113, 1.5113, 1.5113],
                                     [2.3944, 2.3944, 2.3944],
                                     [1.4740, 1.4740, 1.4740],
                                     [5.0300, 5.0300, 5.0300],
                                     [5.1030, 5.1030, 5.1030],
                                     [1.9360, 1.9360, 1.9360],
                                     [2.1939, 2.1939, 2.1939],
                                     [2.1581, 2.1581, 2.1581],
                                     [3.4666, 3.4666, 3.4666]],
                                    [[0.5798, 0.5798, 0.5798],
                                     [-0.7981, -0.7981, -0.7981],
                                     [-0.9280, -0.9280, -0.9280],
                                     [-1.3311, -1.3311, -1.3311],
                                     [1.3687, 1.3687, 1.3687],
                                     [0.9277, 0.9277, 0.9277],
                                     [-0.4164, -0.4164, -0.4164],
                                     [-1.8274, -1.8274, -1.8274],
                                     [0.9268, 0.9268, 0.9268],
                                     [0.8414, 0.8414, 0.8414]],
                                    [[-1.6266, -1.6266, -1.6266],
                                     [-1.0281, -1.0281, -1.0281],
                                     [-1.0393, -1.0393, -1.0393],
                                     [-1.6931, -1.6931, -1.6931],
                                     [-1.3982, -1.3982, -1.3982],
                                     [-0.5732, -0.5732, -0.5732],
                                     [-1.0830, -1.0830, -1.0830],
                                     [-1.7561, -1.7561, -1.7561],
                                     [-1.6786, -1.6786, -1.6786],
                                     [-1.6967, -1.6967, -1.6967]],
                                    [[-0.0380, -0.0380, -0.0380],
                                     [-0.1880, -0.1880, -0.1880],
                                     [-1.5724, -1.5724, -1.5724],
                                     [0.6905, 0.6905, 0.6905],
                                     [-0.3190, -0.3190, -0.3190],
                                     [0.7798, 0.7798, 0.7798],
                                     [-0.3693, -0.3693, -0.3693],
                                     [-0.9457, -0.9457, -0.9457],
                                     [-0.2942, -0.2942, -0.2942],
                                     [-1.8527, -1.8527, -1.8527]],
                                    [[0.0000, 0.0000, 0.0000],
                                     [0.0000, 0.0000, 0.0000],
                                     [0.0000, 0.0000, 0.0000],
                                     [0.0000, 0.0000, 0.0000],
                                     [0.0000, 0.0000, 0.0000],
                                     [0.0000, 0.0000, 0.0000],
                                     [0.0000, 0.0000, 0.0000],
                                     [0.0000, 0.0000, 0.0000],
                                     [0.0000, 0.0000, 0.0000],
                                     [0.0000, 0.0000, 0.0000]],
                                    [[0.0000, 0.0000, 0.0000],
                                     [0.0000, 0.0000, 0.0000],
                                     [0.0000, 0.0000, 0.0000],
                                     [0.0000, 0.0000, 0.0000],
                                     [0.0000, 0.0000, 0.0000],
                                     [0.0000, 0.0000, 0.0000],
                                     [0.0000, 0.0000, 0.0000],
                                     [0.0000, 0.0000, 0.0000],
                                     [0.0000, 0.0000, 0.0000],
                                     [0.0000, 0.0000, 0.0000]],
                                    [[-0.0380, -0.0380, -0.0380],
                                     [-0.1880, -0.1880, -0.1880],
                                     [-1.5724, -1.5724, -1.5724],
                                     [0.6905, 0.6905, 0.6905],
                                     [-0.3190, -0.3190, -0.3190],
                                     [0.7798, 0.7798, 0.7798],
                                     [-0.3693, -0.3693, -0.3693],
                                     [-0.9457, -0.9457, -0.9457],
                                     [-0.2942, -0.2942, -0.2942],
                                     [-1.8527, -1.8527, -1.8527]],
                                    [[1.1773, 1.1773, 1.1773],
                                     [1.5009, 1.5009, 1.5009],
                                     [2.6399, 2.6399, 2.6399],
                                     [5.9242, 5.9242, 5.9242],
                                     [1.0962, 1.0962, 1.0962],
                                     [2.7346, 2.7346, 2.7346],
                                     [6.0865, 6.0865, 6.0865],
                                     [1.5555, 1.5555, 1.5555],
                                     [4.3303, 4.3303, 4.3303],
                                     [2.8229, 2.8229, 2.8229]],
                                    [[-0.0380, -0.0380, -0.0380],
                                     [-0.1880, -0.1880, -0.1880],
                                     [-1.5724, -1.5724, -1.5724],
                                     [0.6905, 0.6905, 0.6905],
                                     [-0.3190, -0.3190, -0.3190],
                                     [0.7798, 0.7798, 0.7798],
                                     [-0.3693, -0.3693, -0.3693],
                                     [-0.9457, -0.9457, -0.9457],
                                     [-0.2942, -0.2942, -0.2942],
                                     [-1.8527, -1.8527,
                                      -1.8527]]]).cuda().float()
    assert torch.allclose(output, expected_output)