From 93fe4829f7f7c70b2fed22521b1c533cae86595a Mon Sep 17 00:00:00 2001 From: VVsssssk <88368822+VVsssssk@users.noreply.github.com> Date: Fri, 28 Oct 2022 00:42:10 +0800 Subject: [PATCH] [Features] Add stack ball query and stack group points ops (#2292) * add stack sa model ops * fix lint * fix lint * fix comments * fix bug * fix lint * fix comments * fix lint * fix lint * fix --- mmcv/ops/ball_query.py | 71 +++++--- .../cuda/stack_ball_query_cuda_kernel.cuh | 68 ++++++++ .../cuda/stack_group_points_cuda_kernel.cuh | 97 +++++++++++ mmcv/ops/csrc/common/pytorch_cuda_helper.hpp | 1 + mmcv/ops/csrc/pytorch/ball_query.cpp | 18 ++ mmcv/ops/csrc/pytorch/cuda/cudabind.cpp | 74 ++++++++ .../pytorch/cuda/stack_ball_query_cuda.cu | 45 +++++ .../pytorch/cuda/stack_group_points_cuda.cu | 62 +++++++ mmcv/ops/csrc/pytorch/group_points.cpp | 42 +++++ mmcv/ops/csrc/pytorch/pybind.cpp | 32 ++++ mmcv/ops/group_points.py | 128 ++++++++++---- tests/test_ops/output.pkl | Bin 0 -> 2168 bytes tests/test_ops/test_ball_query.py | 47 +++++ tests/test_ops/test_group_points.py | 162 +++++++++++++++++- 14 files changed, 787 insertions(+), 60 deletions(-) create mode 100644 mmcv/ops/csrc/common/cuda/stack_ball_query_cuda_kernel.cuh create mode 100644 mmcv/ops/csrc/common/cuda/stack_group_points_cuda_kernel.cuh create mode 100644 mmcv/ops/csrc/pytorch/cuda/stack_ball_query_cuda.cu create mode 100644 mmcv/ops/csrc/pytorch/cuda/stack_group_points_cuda.cu create mode 100644 tests/test_ops/output.pkl diff --git a/mmcv/ops/ball_query.py b/mmcv/ops/ball_query.py index d24e0446c..a89b36b52 100644 --- a/mmcv/ops/ball_query.py +++ b/mmcv/ops/ball_query.py @@ -1,28 +1,44 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Tuple +from typing import Optional, Tuple import torch from torch.autograd import Function from ..utils import ext_loader -ext_module = ext_loader.load_ext('_ext', ['ball_query_forward']) +ext_module = ext_loader.load_ext( + '_ext', ['ball_query_forward', 'stack_ball_query_forward']) class BallQuery(Function): """Find nearby points in spherical space.""" @staticmethod - def forward(ctx, min_radius: float, max_radius: float, sample_num: int, - xyz: torch.Tensor, center_xyz: torch.Tensor) -> torch.Tensor: + def forward( + ctx, + min_radius: float, + max_radius: float, + sample_num: int, + xyz: torch.Tensor, + center_xyz: torch.Tensor, + xyz_batch_cnt: Optional[torch.Tensor] = None, + center_xyz_batch_cnt: Optional[torch.Tensor] = None + ) -> torch.Tensor: """ Args: min_radius (float): minimum radius of the balls. max_radius (float): maximum radius of the balls. sample_num (int): maximum number of features in the balls. - xyz (torch.Tensor): (B, N, 3) xyz coordinates of the features. + xyz (torch.Tensor): (B, N, 3) xyz coordinates of the features, + or staked input (N1 + N2 ..., 3). center_xyz (torch.Tensor): (B, npoint, 3) centers of the ball - query. + query, or staked input (M1 + M2 ..., 3). + xyz_batch_cnt: (batch_size): Stacked input xyz coordinates nums in + each batch, just like (N1, N2, ...). Defaults to None. + New in version 1.7.0. + center_xyz_batch_cnt: (batch_size): Stacked centers coordinates + nums in each batch, just line (M1, M2, ...). Defaults to None. + New in version 1.7.0. Returns: torch.Tensor: (B, npoint, nsample) tensor with the indices of the @@ -31,21 +47,34 @@ class BallQuery(Function): assert center_xyz.is_contiguous() assert xyz.is_contiguous() assert min_radius < max_radius - - B, N, _ = xyz.size() - npoint = center_xyz.size(1) - idx = xyz.new_zeros(B, npoint, sample_num, dtype=torch.int) - - ext_module.ball_query_forward( - center_xyz, - xyz, - idx, - b=B, - n=N, - m=npoint, - min_radius=min_radius, - max_radius=max_radius, - nsample=sample_num) + if xyz_batch_cnt is not None and center_xyz_batch_cnt is not None: + assert xyz_batch_cnt.dtype == torch.int + assert center_xyz_batch_cnt.dtype == torch.int + idx = center_xyz.new_zeros((center_xyz.shape[0], sample_num), + dtype=torch.int32) + ext_module.stack_ball_query_forward( + center_xyz, + center_xyz_batch_cnt, + xyz, + xyz_batch_cnt, + idx, + max_radius=max_radius, + nsample=sample_num, + ) + else: + B, N, _ = xyz.size() + npoint = center_xyz.size(1) + idx = xyz.new_zeros(B, npoint, sample_num, dtype=torch.int32) + ext_module.ball_query_forward( + center_xyz, + xyz, + idx, + b=B, + n=N, + m=npoint, + min_radius=min_radius, + max_radius=max_radius, + nsample=sample_num) if torch.__version__ != 'parrots': ctx.mark_non_differentiable(idx) return idx diff --git a/mmcv/ops/csrc/common/cuda/stack_ball_query_cuda_kernel.cuh b/mmcv/ops/csrc/common/cuda/stack_ball_query_cuda_kernel.cuh new file mode 100644 index 000000000..06caefa18 --- /dev/null +++ b/mmcv/ops/csrc/common/cuda/stack_ball_query_cuda_kernel.cuh @@ -0,0 +1,68 @@ +// Copyright (c) OpenMMLab. All rights reserved +// Modified from +// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/ball_query_gpu.cu +#ifndef STACK_BALL_QUERY_CUDA_KERNEL_CUH +#define STACK_BALL_QUERY_CUDA_KERNEL_CUH + +#ifdef MMCV_USE_PARROTS +#include "parrots_cuda_helper.hpp" +#else +#include "pytorch_cuda_helper.hpp" +#endif + +template +__global__ void stack_ball_query_forward_cuda_kernel( + int B, int M, float radius, int nsample, const T *new_xyz, + const int *new_xyz_batch_cnt, const T *xyz, const int *xyz_batch_cnt, + int *idx) { + // :param xyz: (N1 + N2 ..., 3) xyz coordinates of the features + // :param xyz_batch_cnt: (batch_size), [N1, N2, ...] + // :param new_xyz: (M1 + M2 ..., 3) centers of the ball query + // :param new_xyz_batch_cnt: (batch_size), [M1, M2, ...] + // output: + // idx: (M, nsample) + const T *cur_xyz = xyz; + int *cur_idx = idx; + CUDA_1D_KERNEL_LOOP(pt_idx, M) { + int bs_idx = 0; + for (int pt_cnt = 0; bs_idx < B; bs_idx++) { + pt_cnt += new_xyz_batch_cnt[bs_idx]; + if (pt_idx < pt_cnt) break; + } + + int xyz_batch_start_idx = 0; + for (int k = 0; k < bs_idx; k++) xyz_batch_start_idx += xyz_batch_cnt[k]; + + const T *new_xyz_p = new_xyz + pt_idx * 3; + cur_xyz += xyz_batch_start_idx * 3; + cur_idx += pt_idx * nsample; + + float radius2 = radius * radius; + T new_x = new_xyz_p[0]; + T new_y = new_xyz_p[1]; + T new_z = new_xyz_p[2]; + int n = xyz_batch_cnt[bs_idx]; + + int cnt = 0; + for (int k = 0; k < n; ++k) { + T x = cur_xyz[k * 3 + 0]; + T y = cur_xyz[k * 3 + 1]; + T z = cur_xyz[k * 3 + 2]; + T d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + + (new_z - z) * (new_z - z); + if (d2 < radius2) { + if (cnt == 0) { + for (int l = 0; l < nsample; ++l) { + cur_idx[l] = k; + } + } + cur_idx[cnt] = k; + ++cnt; + if (cnt >= nsample) break; + } + } + if (cnt == 0) cur_idx[0] = -1; + } +} + +#endif // STACK_BALL_QUERY_CUDA_KERNEL_CUH diff --git a/mmcv/ops/csrc/common/cuda/stack_group_points_cuda_kernel.cuh b/mmcv/ops/csrc/common/cuda/stack_group_points_cuda_kernel.cuh new file mode 100644 index 000000000..4ef3663d0 --- /dev/null +++ b/mmcv/ops/csrc/common/cuda/stack_group_points_cuda_kernel.cuh @@ -0,0 +1,97 @@ +// Copyright (c) OpenMMLab. All rights reserved. +// Modified from +// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/group_points_gpu.cu +#ifndef STACK_GROUP_POINTS_CUDA_KERNEL_CUH +#define STACK_GROUP_POINTS_CUDA_KERNEL_CUH +#ifdef MMCV_USE_PARROTS +#include "parrots_cuda_helper.hpp" +#else +#include "pytorch_cuda_helper.hpp" +#endif +#include +template +__global__ void stack_group_points_forward_cuda_kernel( + int b, int c, int m, int nsample, const T *features, + const int *features_batch_cnt, const int *idx, const int *idx_batch_cnt, + T *out) { + // :param features: (N1 + N2 ..., C) tensor of features to group + // :param features_batch_cnt: (batch_size) [N1 + N2 ...] tensor containing the + // indices of features to group with :param idx: (M1 + M2 ..., nsample) tensor + // containing the indices of features to group with :param idx_batch_cnt: + // (batch_size) [M1 + M2 ...] tensor containing the indices of features to + // group with :return: + // output: (M1 + M2, C, nsample) tensor + CUDA_1D_KERNEL_LOOP(index, m * c * nsample) { + const T *cur_features = features; + const int *cur_idx = idx; + int sample_idx = index % nsample; + int c_idx = (index / nsample) % c; + int pt_idx = (index / nsample / c); + + if (pt_idx >= m || c_idx >= c || sample_idx >= nsample) return; + int bs_idx = 0, pt_cnt = idx_batch_cnt[0]; + for (int k = 1; k < b; k++) { + if (pt_idx < pt_cnt) break; + pt_cnt += idx_batch_cnt[k]; + bs_idx = k; + } + + int features_batch_start_idx = 0; + int features_batch_end_idx = features_batch_cnt[0]; + for (int k = 0; k < bs_idx; k++) { + features_batch_start_idx += features_batch_cnt[k]; + features_batch_end_idx = + features_batch_start_idx + features_batch_cnt[k + 1]; + } + cur_features += features_batch_start_idx * c; + + cur_idx += pt_idx * nsample + sample_idx; + int in_idx = cur_idx[0] * c + c_idx; + int out_idx = pt_idx * c * nsample + c_idx * nsample + sample_idx; + if (in_idx < features_batch_end_idx * c) { + out[out_idx] = cur_features[in_idx]; + } + } +} + +template +__global__ void stack_group_points_backward_cuda_kernel( + int b, int c, int m, int n, int nsample, const T *grad_out, const int *idx, + const int *idx_batch_cnt, const int *features_batch_cnt, T *grad_features) { + // :param grad_out: (M1 + M2 ..., C, nsample) tensor of the gradients of the + // output from forward :param idx: (M1 + M2 ..., nsample) tensor containing + // the indices of features to group with :param idx_batch_cnt: (batch_size) + // [M1 + M2 ...] tensor containing the indices of features to group with + // :param features_batch_cnt: (batch_size) [N1 + N2 ...] tensor containing the + // indices of features to group with :return: + // grad_features: (N1 + N2 ..., C) gradient of the features + CUDA_1D_KERNEL_LOOP(index, m * c * nsample) { + const T *cur_grad_out = grad_out; + const int *cur_idx = idx; + T *cur_grad_features = grad_features; + int sample_idx = index % nsample; + int c_idx = (index / nsample) % c; + int pt_idx = (index / nsample / c); + + if (pt_idx >= m || c_idx >= c || sample_idx >= nsample) return; + + int bs_idx = 0, pt_cnt = idx_batch_cnt[0]; + for (int k = 1; k < b; k++) { + if (pt_idx < pt_cnt) break; + pt_cnt += idx_batch_cnt[k]; + bs_idx = k; + } + + int features_batch_start_idx = 0; + for (int k = 0; k < bs_idx; k++) + features_batch_start_idx += features_batch_cnt[k]; + + cur_grad_out += pt_idx * c * nsample + c_idx * nsample + sample_idx; + cur_idx += pt_idx * nsample + sample_idx; + cur_grad_features += (features_batch_start_idx + cur_idx[0]) * c + c_idx; + + atomicAdd(cur_grad_features, cur_grad_out[0]); + } +} + +#endif // GROUP_POINTS_CUDA_KERNEL_CUH diff --git a/mmcv/ops/csrc/common/pytorch_cuda_helper.hpp b/mmcv/ops/csrc/common/pytorch_cuda_helper.hpp index 9869b535f..52e512695 100644 --- a/mmcv/ops/csrc/common/pytorch_cuda_helper.hpp +++ b/mmcv/ops/csrc/common/pytorch_cuda_helper.hpp @@ -15,5 +15,6 @@ using at::Tensor; using phalf = at::Half; #define __PHALF(x) (x) +#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0)) #endif // PYTORCH_CUDA_HELPER diff --git a/mmcv/ops/csrc/pytorch/ball_query.cpp b/mmcv/ops/csrc/pytorch/ball_query.cpp index 1c9e7a207..b0534db5c 100644 --- a/mmcv/ops/csrc/pytorch/ball_query.cpp +++ b/mmcv/ops/csrc/pytorch/ball_query.cpp @@ -18,3 +18,21 @@ void ball_query_forward(Tensor new_xyz_tensor, Tensor xyz_tensor, ball_query_forward_impl(b, n, m, min_radius, max_radius, nsample, new_xyz_tensor, xyz_tensor, idx_tensor); } + +void stack_ball_query_forward_impl(float max_radius, int nsample, + const Tensor new_xyz, + const Tensor new_xyz_batch_cnt, + const Tensor xyz, const Tensor xyz_batch_cnt, + Tensor idx) { + DISPATCH_DEVICE_IMPL(stack_ball_query_forward_impl, max_radius, nsample, + new_xyz, new_xyz_batch_cnt, xyz, xyz_batch_cnt, idx); +} + +void stack_ball_query_forward(Tensor new_xyz_tensor, Tensor new_xyz_batch_cnt, + Tensor xyz_tensor, Tensor xyz_batch_cnt, + Tensor idx_tensor, float max_radius, + int nsample) { + stack_ball_query_forward_impl(max_radius, nsample, new_xyz_tensor, + new_xyz_batch_cnt, xyz_tensor, xyz_batch_cnt, + idx_tensor); +} diff --git a/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp b/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp index ade111d14..e55863406 100644 --- a/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp +++ b/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp @@ -67,6 +67,30 @@ void ball_query_forward_impl(int b, int n, int m, float min_radius, Tensor idx); REGISTER_DEVICE_IMPL(ball_query_forward_impl, CUDA, ball_query_forward_cuda); +void StackBallQueryForwardCUDAKernelLauncher(float max_radius, int nsample, + const Tensor new_xyz, + const Tensor new_xyz_batch_cnt, + const Tensor xyz, + const Tensor xyz_batch_cnt, + Tensor idx); + +void stack_ball_query_forward_cuda(float max_radius, int nsample, + const Tensor new_xyz, + const Tensor new_xyz_batch_cnt, + const Tensor xyz, const Tensor xyz_batch_cnt, + Tensor idx) { + StackBallQueryForwardCUDAKernelLauncher( + max_radius, nsample, new_xyz, new_xyz_batch_cnt, xyz, xyz_batch_cnt, idx); +}; + +void stack_ball_query_forward_impl(float max_radius, int nsample, + const Tensor new_xyz, + const Tensor new_xyz_batch_cnt, + const Tensor xyz, const Tensor xyz_batch_cnt, + Tensor idx); +REGISTER_DEVICE_IMPL(stack_ball_query_forward_impl, CUDA, + stack_ball_query_forward_cuda); + void BBoxOverlapsCUDAKernelLauncher(const Tensor bboxes1, const Tensor bboxes2, Tensor ious, const int mode, const bool aligned, const int offset); @@ -571,6 +595,56 @@ REGISTER_DEVICE_IMPL(group_points_forward_impl, CUDA, REGISTER_DEVICE_IMPL(group_points_backward_impl, CUDA, group_points_backward_cuda); +void StackGroupPointsForwardCUDAKernelLauncher( + int b, int c, int m, int nsample, const Tensor features_tensor, + const Tensor features_batch_cnt_tensor, const Tensor idx_tensor, + const Tensor idx_batch_cnt_tensor, Tensor out_tensor); +void StackGroupPointsBackwardCUDAKernelLauncher( + int b, int c, int m, int n, int nsample, const Tensor grad_out_tensor, + const Tensor idx_tensor, const Tensor idx_batch_cnt_tensor, + const Tensor features_batch_cnt_tensor, Tensor grad_features_tensor); + +void stack_group_points_forward_cuda(int b, int c, int m, int nsample, + const Tensor features_tensor, + const Tensor features_batch_cnt_tensor, + const Tensor idx_tensor, + const Tensor idx_batch_cnt_tensor, + Tensor out_tensor) { + StackGroupPointsForwardCUDAKernelLauncher( + b, c, m, nsample, features_tensor, features_batch_cnt_tensor, idx_tensor, + idx_batch_cnt_tensor, out_tensor); +}; + +void stack_group_points_backward_cuda(int b, int c, int m, int n, int nsample, + const Tensor grad_out_tensor, + const Tensor idx_tensor, + const Tensor idx_batch_cnt_tensor, + const Tensor features_batch_cnt_tensor, + Tensor grad_features_tensor) { + StackGroupPointsBackwardCUDAKernelLauncher( + b, c, m, n, nsample, grad_out_tensor, idx_tensor, idx_batch_cnt_tensor, + features_batch_cnt_tensor, grad_features_tensor); +}; + +void stack_group_points_forward_impl(int b, int c, int m, int nsample, + const Tensor features_tensor, + const Tensor features_batch_cnt_tensor, + const Tensor idx_tensor, + const Tensor idx_batch_cnt_tensor, + Tensor out_tensor); + +void stack_group_points_backward_impl(int b, int c, int m, int n, int nsample, + const Tensor grad_out_tensor, + const Tensor idx_tensor, + const Tensor idx_batch_cnt_tensor, + const Tensor features_batch_cnt_tensor, + Tensor grad_features_tensor); + +REGISTER_DEVICE_IMPL(stack_group_points_forward_impl, CUDA, + stack_group_points_forward_cuda); +REGISTER_DEVICE_IMPL(stack_group_points_backward_impl, CUDA, + stack_group_points_backward_cuda); + void IoU3DBoxesOverlapBevForwardCUDAKernelLauncher(const int num_a, const Tensor boxes_a, const int num_b, diff --git a/mmcv/ops/csrc/pytorch/cuda/stack_ball_query_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/stack_ball_query_cuda.cu new file mode 100644 index 000000000..3095df5ee --- /dev/null +++ b/mmcv/ops/csrc/pytorch/cuda/stack_ball_query_cuda.cu @@ -0,0 +1,45 @@ +// Copyright (c) OpenMMLab. All rights reserved +// Modified from +// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/ball_query_gpu.cu + +#include +#include +#include + +#include "pytorch_cuda_helper.hpp" +#include "stack_ball_query_cuda_kernel.cuh" +#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0)) + +void StackBallQueryForwardCUDAKernelLauncher(float max_radius, int nsample, + const Tensor new_xyz, + const Tensor new_xyz_batch_cnt, + const Tensor xyz, + const Tensor xyz_batch_cnt, + Tensor idx) { + at::cuda::CUDAGuard device_guard(new_xyz.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // const float *new_xyz_ptr = new_xyz.data_ptr(); + // const float *xyz_ptr = xyz.data_ptr(); + // const int *new_xyz_batch_cnt_ptr = new_xyz_batch_cnt.data_ptr(); + // const int *xyz_batch_cnt_ptr = xyz_batch_cnt.data_ptr(); + // int *idx_ptr = idx.data_ptr(); + + int B = xyz_batch_cnt.size(0); + int M = new_xyz.size(0); + + // blockIdx.x(col), blockIdx.y(row) + dim3 blocks(DIVUP(M, THREADS_PER_BLOCK)); + dim3 threads(THREADS_PER_BLOCK); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + new_xyz.scalar_type(), "stack_ball_query_forward_cuda_kernel", [&] { + stack_ball_query_forward_cuda_kernel + <<>>( + B, M, max_radius, nsample, new_xyz.data_ptr(), + new_xyz_batch_cnt.data_ptr(), xyz.data_ptr(), + xyz_batch_cnt.data_ptr(), idx.data_ptr()); + }); + + AT_CUDA_CHECK(cudaGetLastError()); +} diff --git a/mmcv/ops/csrc/pytorch/cuda/stack_group_points_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/stack_group_points_cuda.cu new file mode 100644 index 000000000..9f903b02a --- /dev/null +++ b/mmcv/ops/csrc/pytorch/cuda/stack_group_points_cuda.cu @@ -0,0 +1,62 @@ +// Copyright (c) OpenMMLab. All rights reserved. +// Modified from +// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/group_points_gpu.cu +#include +#include + +#include "pytorch_cuda_helper.hpp" +#include "stack_group_points_cuda_kernel.cuh" + +void StackGroupPointsForwardCUDAKernelLauncher( + int b, int c, int m, int nsample, const Tensor features_tensor, + const Tensor features_batch_cnt_tensor, const Tensor idx_tensor, + const Tensor idx_batch_cnt_tensor, Tensor out_tensor) { + // points: (B, C, N) + // idx: (B, npoints, nsample) + // output: + // out: (B, C, npoints, nsample) + at::cuda::CUDAGuard device_guard(features_tensor.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 blocks(DIVUP(m * c * nsample, THREADS_PER_BLOCK)); + dim3 threads(THREADS_PER_BLOCK); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + features_tensor.scalar_type(), "stack_group_points_forward_cuda_kernel", + [&] { + stack_group_points_forward_cuda_kernel + <<>>( + b, c, m, nsample, features_tensor.data_ptr(), + features_batch_cnt_tensor.data_ptr(), + idx_tensor.data_ptr(), + idx_batch_cnt_tensor.data_ptr(), + out_tensor.data_ptr()); + }); + + AT_CUDA_CHECK(cudaGetLastError()); +} + +void StackGroupPointsBackwardCUDAKernelLauncher( + int b, int c, int m, int n, int nsample, const Tensor grad_out_tensor, + const Tensor idx_tensor, const Tensor idx_batch_cnt_tensor, + const Tensor features_batch_cnt_tensor, Tensor grad_features_tensor) { + at::cuda::CUDAGuard device_guard(grad_features_tensor.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 blocks(DIVUP(m * c * nsample, THREADS_PER_BLOCK)); + dim3 threads(THREADS_PER_BLOCK); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_features_tensor.scalar_type(), + "stack_group_points_backward_cuda_kernel", [&] { + stack_group_points_backward_cuda_kernel + <<>>( + b, c, m, n, nsample, grad_out_tensor.data_ptr(), + idx_tensor.data_ptr(), + idx_batch_cnt_tensor.data_ptr(), + features_batch_cnt_tensor.data_ptr(), + grad_features_tensor.data_ptr()); + }); + + AT_CUDA_CHECK(cudaGetLastError()); +} diff --git a/mmcv/ops/csrc/pytorch/group_points.cpp b/mmcv/ops/csrc/pytorch/group_points.cpp index cdd190d40..850deed98 100644 --- a/mmcv/ops/csrc/pytorch/group_points.cpp +++ b/mmcv/ops/csrc/pytorch/group_points.cpp @@ -32,3 +32,45 @@ void group_points_backward(Tensor grad_out_tensor, Tensor idx_tensor, group_points_backward_impl(b, c, n, npoints, nsample, grad_out_tensor, idx_tensor, grad_points_tensor); } + +void stack_group_points_backward_impl(int b, int c, int m, int n, int nsample, + const Tensor grad_out_tensor, + const Tensor idx_tensor, + const Tensor idx_batch_cnt_tensor, + const Tensor features_batch_cnt_tensor, + Tensor grad_features_tensor) { + DISPATCH_DEVICE_IMPL(stack_group_points_backward_impl, b, c, m, n, nsample, + grad_out_tensor, idx_tensor, idx_batch_cnt_tensor, + features_batch_cnt_tensor, grad_features_tensor); +} + +void stack_group_points_backward(Tensor grad_out_tensor, Tensor idx_tensor, + Tensor idx_batch_cnt_tensor, + Tensor features_batch_cnt_tensor, + Tensor grad_features_tensor, int b, int c, + int m, int n, int nsample) { + stack_group_points_backward_impl( + b, c, m, n, nsample, grad_out_tensor, idx_tensor, idx_batch_cnt_tensor, + features_batch_cnt_tensor, grad_features_tensor); +} + +void stack_group_points_forward_impl(int b, int c, int m, int nsample, + const Tensor features_tensor, + const Tensor features_batch_cnt_tensor, + const Tensor idx_tensor, + const Tensor idx_batch_cnt_tensor, + Tensor out_tensor) { + DISPATCH_DEVICE_IMPL(stack_group_points_forward_impl, b, c, m, nsample, + features_tensor, features_batch_cnt_tensor, idx_tensor, + idx_batch_cnt_tensor, out_tensor); +} + +void stack_group_points_forward(Tensor features_tensor, + Tensor features_batch_cnt_tensor, + Tensor idx_tensor, Tensor idx_batch_cnt_tensor, + Tensor out_tensor, int b, int c, int m, + int nsample) { + DISPATCH_DEVICE_IMPL(stack_group_points_forward_impl, b, c, m, nsample, + features_tensor, features_batch_cnt_tensor, idx_tensor, + idx_batch_cnt_tensor, out_tensor); +} diff --git a/mmcv/ops/csrc/pytorch/pybind.cpp b/mmcv/ops/csrc/pytorch/pybind.cpp index 22ff0db44..4947b7215 100644 --- a/mmcv/ops/csrc/pytorch/pybind.cpp +++ b/mmcv/ops/csrc/pytorch/pybind.cpp @@ -75,6 +75,18 @@ void group_points_backward(Tensor grad_out_tensor, Tensor idx_tensor, Tensor grad_points_tensor, int b, int c, int n, int npoints, int nsample); +void stack_group_points_forward(Tensor features_tensor, + Tensor features_batch_cnt_tensor, + Tensor idx_tensor, Tensor idx_batch_cnt_tensor, + Tensor out_tensor, int b, int c, int m, + int nsample); + +void stack_group_points_backward(Tensor grad_out_tensor, Tensor idx_tensor, + Tensor idx_batch_cnt_tensor, + Tensor features_batch_cnt_tensor, + Tensor grad_features_tensor, int b, int c, + int m, int n, int nsample); + void roipoint_pool3d_forward(Tensor xyz, Tensor boxes3d, Tensor pts_feature, Tensor pooled_features, Tensor pooled_empty_flag); @@ -240,6 +252,10 @@ void ball_query_forward(Tensor new_xyz_tensor, Tensor xyz_tensor, Tensor idx_tensor, int b, int n, int m, float min_radius, float max_radius, int nsample); +void stack_ball_query_forward(Tensor new_xyz_tensor, Tensor new_xyz_batch_cnt, + Tensor xyz_tensor, Tensor xyz_batch_cnt, + Tensor idx_tensor, float max_radius, int nsample); + void prroi_pool_forward(Tensor input, Tensor rois, Tensor output, int pooled_height, int pooled_width, float spatial_scale); @@ -557,6 +573,17 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "group_points_backward", py::arg("grad_out_tensor"), py::arg("idx_tensor"), py::arg("grad_points_tensor"), py::arg("b"), py::arg("c"), py::arg("n"), py::arg("npoints"), py::arg("nsample")); + m.def("stack_group_points_forward", &stack_group_points_forward, + "stack_group_points_forward", py::arg("features_tensor"), + py::arg("features_batch_cnt_tensor"), py::arg("idx_tensor"), + py::arg("idx_batch_cnt_tensor"), py::arg("out_tensor"), py::arg("b"), + py::arg("c"), py::arg("m"), py::arg("nsample")); + m.def("stack_group_points_backward", &stack_group_points_backward, + "stack_group_points_backward", py::arg("grad_out_tensor"), + py::arg("idx_tensor"), py::arg("idx_batch_cnt_tensor"), + py::arg("features_batch_cnt_tensor"), py::arg("grad_features_tensor"), + py::arg("b"), py::arg("c"), py::arg("m"), py::arg("n"), + py::arg("nsample")); m.def("knn_forward", &knn_forward, "knn_forward", py::arg("b"), py::arg("n"), py::arg("m"), py::arg("nsample"), py::arg("xyz_tensor"), py::arg("new_xyz_tensor"), py::arg("idx_tensor"), @@ -726,6 +753,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("new_xyz_tensor"), py::arg("xyz_tensor"), py::arg("idx_tensor"), py::arg("b"), py::arg("n"), py::arg("m"), py::arg("min_radius"), py::arg("max_radius"), py::arg("nsample")); + m.def("stack_ball_query_forward", &stack_ball_query_forward, + "stack_ball_query_forward", py::arg("new_xyz_tensor"), + py::arg("new_xyz_batch_cnt"), py::arg("xyz_tensor"), + py::arg("xyz_batch_cnt"), py::arg("idx_tensor"), py::arg("max_radius"), + py::arg("nsample")); m.def("roi_align_rotated_forward", &roi_align_rotated_forward, "roi_align_rotated forward", py::arg("input"), py::arg("rois"), py::arg("output"), py::arg("pooled_height"), py::arg("pooled_width"), diff --git a/mmcv/ops/group_points.py b/mmcv/ops/group_points.py index 5268a265f..95f359e86 100644 --- a/mmcv/ops/group_points.py +++ b/mmcv/ops/group_points.py @@ -9,8 +9,10 @@ from ..utils import ext_loader from .ball_query import ball_query from .knn import knn -ext_module = ext_loader.load_ext( - '_ext', ['group_points_forward', 'group_points_backward']) +ext_module = ext_loader.load_ext('_ext', [ + 'group_points_forward', 'group_points_backward', + 'stack_group_points_forward', 'stack_group_points_backward' +]) class QueryAndGroup(nn.Module): @@ -183,39 +185,71 @@ class GroupingOperation(Function): """Group feature with given index.""" @staticmethod - def forward(ctx, features: torch.Tensor, - indices: torch.Tensor) -> torch.Tensor: + def forward( + ctx, + features: torch.Tensor, + indices: torch.Tensor, + features_batch_cnt: Optional[torch.Tensor] = None, + indices_batch_cnt: Optional[torch.Tensor] = None) -> torch.Tensor: """ Args: - features (Tensor): (B, C, N) tensor of features to group. - indices (Tensor): (B, npoint, nsample) the indices of - features to group with. + features (Tensor): Tensor of features to group, input shape is + (B, C, N) or stacked inputs (N1 + N2 ..., C). + indices (Tensor): The indices of features to group with, input + shape is (B, npoint, nsample) or stacked inputs + (M1 + M2 ..., nsample). + features_batch_cnt (Tensor, optional): Input features nums in + each batch, just like (N1, N2, ...). Defaults to None. + New in version 1.7.0. + indices_batch_cnt (Tensor, optional): Input indices nums in + each batch, just like (M1, M2, ...). Defaults to None. + New in version 1.7.0. Returns: - Tensor: (B, C, npoint, nsample) Grouped features. + Tensor: Grouped features, the shape is (B, C, npoint, nsample) + or (M1 + M2 ..., C, nsample). """ features = features.contiguous() indices = indices.contiguous() + if features_batch_cnt is not None and indices_batch_cnt is not None: + assert features_batch_cnt.dtype == torch.int + assert indices_batch_cnt.dtype == torch.int + M, nsample = indices.size() + N, C = features.size() + B = indices_batch_cnt.shape[0] + output = features.new_zeros((M, C, nsample)) + ext_module.stack_group_points_forward( + features, + features_batch_cnt, + indices, + indices_batch_cnt, + output, + b=B, + m=M, + c=C, + nsample=nsample) + ctx.for_backwards = (B, N, indices, features_batch_cnt, + indices_batch_cnt) + else: + B, nfeatures, nsample = indices.size() + _, C, N = features.size() + output = torch.cuda.FloatTensor(B, C, nfeatures, nsample) - B, nfeatures, nsample = indices.size() - _, C, N = features.size() - output = torch.cuda.FloatTensor(B, C, nfeatures, nsample) + ext_module.group_points_forward( + features, + indices, + output, + b=B, + c=C, + n=N, + npoints=nfeatures, + nsample=nsample) - ext_module.group_points_forward( - features, - indices, - output, - b=B, - c=C, - n=N, - npoints=nfeatures, - nsample=nsample) - - ctx.for_backwards = (indices, N) + ctx.for_backwards = (indices, N) return output @staticmethod - def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, None]: + def backward(ctx, grad_out: torch.Tensor) -> Tuple: """ Args: grad_out (Tensor): (B, C, npoint, nsample) tensor of the gradients @@ -224,22 +258,42 @@ class GroupingOperation(Function): Returns: Tensor: (B, C, N) gradient of the features. """ - idx, N = ctx.for_backwards + if len(ctx.for_backwards) != 5: + idx, N = ctx.for_backwards - B, C, npoint, nsample = grad_out.size() - grad_features = torch.cuda.FloatTensor(B, C, N).zero_() + B, C, npoint, nsample = grad_out.size() + grad_features = torch.cuda.FloatTensor(B, C, N).zero_() - grad_out_data = grad_out.data.contiguous() - ext_module.group_points_backward( - grad_out_data, - idx, - grad_features.data, - b=B, - c=C, - n=N, - npoints=npoint, - nsample=nsample) - return grad_features, None + grad_out_data = grad_out.data.contiguous() + ext_module.group_points_backward( + grad_out_data, + idx, + grad_features.data, + b=B, + c=C, + n=N, + npoints=npoint, + nsample=nsample) + return grad_features, None + else: + B, N, idx, features_batch_cnt, idx_batch_cnt = ctx.for_backwards + + M, C, nsample = grad_out.size() + grad_features = torch.cuda.FloatTensor(N, C).zero_() + + grad_out_data = grad_out.data.contiguous() + ext_module.stack_group_points_backward( + grad_out_data, + idx, + idx_batch_cnt, + features_batch_cnt, + grad_features.data, + b=B, + c=C, + m=M, + n=N, + nsample=nsample) + return grad_features, None, None, None grouping_operation = GroupingOperation.apply diff --git a/tests/test_ops/output.pkl b/tests/test_ops/output.pkl new file mode 100644 index 0000000000000000000000000000000000000000..bcb7b2dd606930522b102d3a59fef70d6f3eb885 GIT binary patch literal 2168 zcmd^BO=}ZT6n)9$%dxbj6hQ)YAwmUB(`ibn3r9kU!c!bm#NZ}OCqojPB)-WcD+R$t zai<$sA{FVzO@Bc%E<(Yj3zx2Rp$oyKf~fB%IU$qUty%QK;pDyC`#w$%@5bOtgt0_| z9g0~t$4u9%RNMAa$@I+B{d-O>JI(F};!)W08Zs+YYezQ7e8+7|IAmep_^+w!W7dQ-jWmTcE9ZB#8!6^ZkCal#X7 zUYtxBJf8SCwVT|Ns}hVOub*Uz!1b4cC(C6cJtYom^EzCK=7#b5pV`6Ab42_AQF)=hIhQ`FlZC`kb z7@i`Ar-c#0UFBA$q;{==q<+~Z%El+MR(UwZHle!Z>lL>VI- z{ov2Av%?3!ZM#j`NOIXTW9=@``)IJD(hl!mmT!mUFHJCbh-lbTN88OTeG!Q94m(~w zdiG?X@{b&iR*yBP@r6c@I1^atyKJ#oXmD|Z$6^--NejxwVLCaP0^IEnS)SUv3|ZIv VbZYQ_BGj9UQV*9k3Zwjf?q5M}yi@=H literal 0 HcmV?d00001 diff --git a/tests/test_ops/test_ball_query.py b/tests/test_ops/test_ball_query.py index 4c78dc660..d3fc7912c 100644 --- a/tests/test_ops/test_ball_query.py +++ b/tests/test_ops/test_ball_query.py @@ -53,3 +53,50 @@ def test_ball_query(): [7, 7, 7, 7, 7], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]).cuda() assert torch.all(idx == expected_idx) + + +@pytest.mark.skipif( + not torch.cuda.is_available(), reason='requires CUDA support') +def test_stack_ball_query(): + new_xyz = torch.tensor([[-0.0740, 1.3147, -1.3625], + [-2.2769, 2.7817, -0.2334], + [-0.4003, 2.4666, -0.5116], + [-0.0740, 1.3147, -1.3625], + [-0.0740, 1.3147, -1.3625], + [-2.0289, 2.4952, -0.1708], + [-2.0668, 6.0278, -0.4875], + [0.4066, 1.4211, -0.2947], + [-2.0289, 2.4952, -0.1708], + [-2.0289, 2.4952, -0.1708]]).cuda() + new_xyz_batch_cnt = torch.tensor([5, 5], dtype=torch.int32).cuda() + xyz = torch.tensor([[-0.0740, 1.3147, -1.3625], [0.5555, 1.0399, -1.3634], + [-0.4003, 2.4666, -0.5116], [-0.5251, 2.4379, -0.8466], + [-0.9691, 1.1418, -1.3733], [-0.2232, 0.9561, -1.3626], + [-2.2769, 2.7817, -0.2334], [-0.2822, 1.3192, -1.3645], + [0.1533, 1.5024, -1.0432], [0.4917, 1.1529, -1.3496], + [-2.0289, 2.4952, -0.1708], [-0.7188, 0.9956, -0.5096], + [-2.0668, 6.0278, -0.4875], [-1.9304, 3.3092, 0.6610], + [0.0949, 1.4332, 0.3140], [-1.2879, 2.0008, -0.7791], + [-0.7252, 0.9611, -0.6371], [0.4066, 1.4211, -0.2947], + [0.3220, 1.4447, 0.3548], [-0.9744, 2.3856, + -1.2000]]).cuda() + xyz_batch_cnt = torch.tensor([10, 10], dtype=torch.int32).cuda() + idx = ball_query(0, 0.2, 5, xyz, new_xyz, xyz_batch_cnt, new_xyz_batch_cnt) + expected_idx = torch.tensor([[0, 0, 0, 0, 0], [6, 6, 6, 6, 6], + [2, 2, 2, 2, 2], [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], + [2, 2, 2, 2, 2], [7, 7, 7, 7, 7], + [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]).cuda() + assert torch.all(idx == expected_idx) + + xyz = xyz.double() + new_xyz = new_xyz.double() + expected_idx = expected_idx.double() + idx = ball_query(0, 0.2, 5, xyz, new_xyz, xyz_batch_cnt, new_xyz_batch_cnt) + assert torch.all(idx == expected_idx) + + xyz = xyz.half() + new_xyz = new_xyz.half() + expected_idx = expected_idx.half() + idx = ball_query(0, 0.2, 5, xyz, new_xyz, xyz_batch_cnt, new_xyz_batch_cnt) + assert torch.all(idx == expected_idx) diff --git a/tests/test_ops/test_group_points.py b/tests/test_ops/test_group_points.py index b295437fb..48c0161ba 100644 --- a/tests/test_ops/test_group_points.py +++ b/tests/test_ops/test_group_points.py @@ -12,7 +12,7 @@ def test_grouping_points(): [0, 0, 0]], [[0, 0, 0], [6, 6, 6], [9, 9, 9], [0, 0, 0], [0, 0, 0], [0, 0, 0]]]).int().cuda() - festures = torch.tensor([[[ + features = torch.tensor([[[ 0.5798, -0.7981, -0.9280, -1.3311, 1.3687, 0.9277, -0.4164, -1.8274, 0.9268, 0.8414 ], @@ -37,7 +37,7 @@ def test_grouping_points(): -1.4049, 0.4990, -0.7037, -0.9924, 0.0386 ]]]).cuda() - output = grouping_operation(festures, idx) + output = grouping_operation(features, idx) expected_output = torch.tensor([[[[0.5798, 0.5798, 0.5798], [-1.3311, -1.3311, -1.3311], [0.9268, 0.9268, 0.9268], @@ -75,3 +75,161 @@ def test_grouping_points(): [-0.6646, -0.6646, -0.6646], [-0.6646, -0.6646, -0.6646]]]]).cuda() assert torch.allclose(output, expected_output) + + +@pytest.mark.skipif( + not torch.cuda.is_available(), reason='requires CUDA support') +def test_stack_grouping_points(): + idx = torch.tensor([[0, 0, 0], [3, 3, 3], [8, 8, 8], [1, 1, 1], [0, 0, 0], + [2, 2, 2], [0, 0, 0], [6, 6, 6], [9, 9, 9], [0, 0, 0], + [1, 1, 1], [0, 0, 0]]).int().cuda() + features = torch.tensor([[ + 0.5798, -0.7981, -0.9280, -1.3311, 1.3687, 0.9277, -0.4164, -1.8274, + 0.9268, 0.8414 + ], + [ + 5.4247, 1.5113, 2.3944, 1.4740, 5.0300, + 5.1030, 1.9360, 2.1939, 2.1581, 3.4666 + ], + [ + -1.6266, -1.0281, -1.0393, -1.6931, -1.3982, + -0.5732, -1.0830, -1.7561, -1.6786, -1.6967 + ], + [ + -0.0380, -0.1880, -1.5724, 0.6905, -0.3190, + 0.7798, -0.3693, -0.9457, -0.2942, -1.8527 + ], + [ + 1.1773, 1.5009, 2.6399, 5.9242, 1.0962, + 2.7346, 6.0865, 1.5555, 4.3303, 2.8229 + ], + [ + -0.6646, -0.6870, -0.1125, -0.2224, -0.3445, + -1.4049, 0.4990, -0.7037, -0.9924, 0.0386 + ]]).float().cuda() + features_batch_cnt = torch.tensor([3, 3]).int().cuda() + indices_batch_cnt = torch.tensor([6, 6]).int().cuda() + output = grouping_operation(features, idx, features_batch_cnt, + indices_batch_cnt) + expected_output = torch.Tensor([[[0.5798, 0.5798, 0.5798], + [-0.7981, -0.7981, -0.7981], + [-0.9280, -0.9280, -0.9280], + [-1.3311, -1.3311, -1.3311], + [1.3687, 1.3687, 1.3687], + [0.9277, 0.9277, 0.9277], + [-0.4164, -0.4164, -0.4164], + [-1.8274, -1.8274, -1.8274], + [0.9268, 0.9268, 0.9268], + [0.8414, 0.8414, 0.8414]], + [[0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000]], + [[0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000]], + [[5.4247, 5.4247, 5.4247], + [1.5113, 1.5113, 1.5113], + [2.3944, 2.3944, 2.3944], + [1.4740, 1.4740, 1.4740], + [5.0300, 5.0300, 5.0300], + [5.1030, 5.1030, 5.1030], + [1.9360, 1.9360, 1.9360], + [2.1939, 2.1939, 2.1939], + [2.1581, 2.1581, 2.1581], + [3.4666, 3.4666, 3.4666]], + [[0.5798, 0.5798, 0.5798], + [-0.7981, -0.7981, -0.7981], + [-0.9280, -0.9280, -0.9280], + [-1.3311, -1.3311, -1.3311], + [1.3687, 1.3687, 1.3687], + [0.9277, 0.9277, 0.9277], + [-0.4164, -0.4164, -0.4164], + [-1.8274, -1.8274, -1.8274], + [0.9268, 0.9268, 0.9268], + [0.8414, 0.8414, 0.8414]], + [[-1.6266, -1.6266, -1.6266], + [-1.0281, -1.0281, -1.0281], + [-1.0393, -1.0393, -1.0393], + [-1.6931, -1.6931, -1.6931], + [-1.3982, -1.3982, -1.3982], + [-0.5732, -0.5732, -0.5732], + [-1.0830, -1.0830, -1.0830], + [-1.7561, -1.7561, -1.7561], + [-1.6786, -1.6786, -1.6786], + [-1.6967, -1.6967, -1.6967]], + [[-0.0380, -0.0380, -0.0380], + [-0.1880, -0.1880, -0.1880], + [-1.5724, -1.5724, -1.5724], + [0.6905, 0.6905, 0.6905], + [-0.3190, -0.3190, -0.3190], + [0.7798, 0.7798, 0.7798], + [-0.3693, -0.3693, -0.3693], + [-0.9457, -0.9457, -0.9457], + [-0.2942, -0.2942, -0.2942], + [-1.8527, -1.8527, -1.8527]], + [[0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000]], + [[0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000]], + [[-0.0380, -0.0380, -0.0380], + [-0.1880, -0.1880, -0.1880], + [-1.5724, -1.5724, -1.5724], + [0.6905, 0.6905, 0.6905], + [-0.3190, -0.3190, -0.3190], + [0.7798, 0.7798, 0.7798], + [-0.3693, -0.3693, -0.3693], + [-0.9457, -0.9457, -0.9457], + [-0.2942, -0.2942, -0.2942], + [-1.8527, -1.8527, -1.8527]], + [[1.1773, 1.1773, 1.1773], + [1.5009, 1.5009, 1.5009], + [2.6399, 2.6399, 2.6399], + [5.9242, 5.9242, 5.9242], + [1.0962, 1.0962, 1.0962], + [2.7346, 2.7346, 2.7346], + [6.0865, 6.0865, 6.0865], + [1.5555, 1.5555, 1.5555], + [4.3303, 4.3303, 4.3303], + [2.8229, 2.8229, 2.8229]], + [[-0.0380, -0.0380, -0.0380], + [-0.1880, -0.1880, -0.1880], + [-1.5724, -1.5724, -1.5724], + [0.6905, 0.6905, 0.6905], + [-0.3190, -0.3190, -0.3190], + [0.7798, 0.7798, 0.7798], + [-0.3693, -0.3693, -0.3693], + [-0.9457, -0.9457, -0.9457], + [-0.2942, -0.2942, -0.2942], + [-1.8527, -1.8527, + -1.8527]]]).cuda().float() + assert torch.allclose(output, expected_output)