From 75cae78c55a9bf9a57d60a64d9cce0b61e1ac700 Mon Sep 17 00:00:00 2001 From: dingchang Date: Sat, 23 Oct 2021 14:01:31 +0800 Subject: [PATCH] [Feature] Add group points ops from mmdet3d (#1415) * add op (group points) and its related ops (ball query and knn) in mmdet3d * refactor code * fix typo * refactor code * fix typo * refactor code * make input contiguous Co-authored-by: zhouzaida --- docs/understand_mmcv/ops.md | 1 + mmcv/ops/__init__.py | 13 +- .../common/cuda/group_points_cuda_kernel.cuh | 63 +++++ .../csrc/pytorch/cuda/group_points_cuda.cu | 61 +++++ mmcv/ops/csrc/pytorch/group_points.cpp | 58 +++++ mmcv/ops/csrc/pytorch/pybind.cpp | 20 ++ mmcv/ops/group_points.py | 224 ++++++++++++++++++ tests/test_ops/test_group_points.py | 76 ++++++ 8 files changed, 510 insertions(+), 6 deletions(-) create mode 100644 mmcv/ops/csrc/common/cuda/group_points_cuda_kernel.cuh create mode 100644 mmcv/ops/csrc/pytorch/cuda/group_points_cuda.cu create mode 100644 mmcv/ops/csrc/pytorch/group_points.cpp create mode 100644 mmcv/ops/group_points.py create mode 100644 tests/test_ops/test_group_points.py diff --git a/docs/understand_mmcv/ops.md b/docs/understand_mmcv/ops.md index 900705afa..2729e441c 100644 --- a/docs/understand_mmcv/ops.md +++ b/docs/understand_mmcv/ops.md @@ -16,6 +16,7 @@ We implement common CUDA ops used in detection, segmentation, etc. - FurthestPointSample - FurthestPointSampleWithDist - GeneralizedAttention +- GroupPoints - KNN - MaskedConv - NMS diff --git a/mmcv/ops/__init__.py b/mmcv/ops/__init__.py index b5a06c761..999e090a4 100644 --- a/mmcv/ops/__init__.py +++ b/mmcv/ops/__init__.py @@ -22,6 +22,7 @@ from .furthest_point_sample import (furthest_point_sample, furthest_point_sample_with_dist) from .fused_bias_leakyrelu import FusedBiasLeakyReLU, fused_bias_leakyrelu from .gather_points import gather_points +from .group_points import GroupAll, QueryAndGroup, grouping_operation from .info import (get_compiler_version, get_compiling_cuda_version, get_onnxruntime_op_path) from .iou3d import boxes_iou_bev, nms_bev, nms_normal_bev @@ -68,13 +69,13 @@ __all__ = [ 'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign', 'SAConv2d', 'TINShift', 'tin_shift', 'assign_score_withk', 'box_iou_rotated', 'RoIPointPool3d', 'nms_rotated', 'knn', 'ball_query', - 'upfirdn2d', 'FusedBiasLeakyReLU', 'boxes_iou_bev', 'nms_bev', - 'nms_normal_bev', 'fused_bias_leakyrelu', 'RoIAlignRotated', - 'roi_align_rotated', 'pixel_group', 'contour_expand', 'three_nn', + 'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu', + 'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'QueryAndGroup', + 'GroupAll', 'grouping_operation', 'contour_expand', 'three_nn', 'three_interpolate', 'MultiScaleDeformableAttention', 'BorderAlign', 'border_align', 'gather_points', 'furthest_point_sample', 'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation', - 'Voxelization', 'voxelization', 'dynamic_scatter', 'DynamicScatter', - 'RoIAwarePool3d', 'points_in_boxes_part', 'points_in_boxes_cpu', - 'points_in_boxes_all' + 'boxes_iou_bev', 'nms_bev', 'nms_normal_bev', 'Voxelization', + 'voxelization', 'dynamic_scatter', 'DynamicScatter', 'RoIAwarePool3d', + 'points_in_boxes_part', 'points_in_boxes_cpu', 'points_in_boxes_all' ] diff --git a/mmcv/ops/csrc/common/cuda/group_points_cuda_kernel.cuh b/mmcv/ops/csrc/common/cuda/group_points_cuda_kernel.cuh new file mode 100644 index 000000000..9cfc2dc86 --- /dev/null +++ b/mmcv/ops/csrc/common/cuda/group_points_cuda_kernel.cuh @@ -0,0 +1,63 @@ +// Copyright (c) OpenMMLab. All rights reserved. +// Modified from +// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/group_points_gpu.cu +#ifndef GROUP_POINTS_CUDA_KERNEL_CUH +#define GROUP_POINTS_CUDA_KERNEL_CUH + +#ifdef MMCV_USE_PARROTS +#include "parrots_cuda_helper.hpp" +#else +#include "pytorch_cuda_helper.hpp" +#endif + +template +__global__ void group_points_forward_cuda_kernel(int b, int c, int n, + int npoints, int nsample, + const T *points, + const int *__restrict__ idx, + T *out) { + // points: (B, C, N) + // idx: (B, npoints, nsample) + // output: + // out: (B, C, npoints, nsample) + int bs_idx = blockIdx.z; + int c_idx = blockIdx.y; + int index = blockIdx.x * blockDim.x + threadIdx.x; + int pt_idx = index / nsample; + if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return; + + int sample_idx = index % nsample; + + idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx; + int in_idx = bs_idx * c * n + c_idx * n + idx[0]; + int out_idx = bs_idx * c * npoints * nsample + c_idx * npoints * nsample + + pt_idx * nsample + sample_idx; + + out[out_idx] = points[in_idx]; +} + +template +__global__ void group_points_backward_cuda_kernel(int b, int c, int n, + int npoints, int nsample, + const T *grad_out, + const int *__restrict__ idx, + T *grad_points) { + // grad_out: (B, C, npoints, nsample) + // idx: (B, npoints, nsample) + // output: + // grad_points: (B, C, N) + int bs_idx = blockIdx.z; + int c_idx = blockIdx.y; + int index = blockIdx.x * blockDim.x + threadIdx.x; + int pt_idx = index / nsample; + if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return; + + int sample_idx = index % nsample; + grad_out += bs_idx * c * npoints * nsample + c_idx * npoints * nsample + + pt_idx * nsample + sample_idx; + idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx; + + atomicAdd(grad_points + bs_idx * c * n + c_idx * n + idx[0], grad_out[0]); +} + +#endif // GROUP_POINTS_CUDA_KERNEL_CUH diff --git a/mmcv/ops/csrc/pytorch/cuda/group_points_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/group_points_cuda.cu new file mode 100644 index 000000000..e7c57b018 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/cuda/group_points_cuda.cu @@ -0,0 +1,61 @@ +// Copyright (c) OpenMMLab. All rights reserved. +// Modified from +// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/group_points_gpu.cu +#include +#include + +#include "group_points_cuda_kernel.cuh" +#include "pytorch_cuda_helper.hpp" + +void GroupPointsForwardCUDAKernelLauncher(int b, int c, int n, int npoints, + int nsample, const Tensor points, + const Tensor idx, Tensor out) { + // points: (B, C, N) + // idx: (B, npoints, nsample) + // output: + // out: (B, C, npoints, nsample) + + at::cuda::CUDAGuard device_guard(points.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // blockIdx.x(col), blockIdx.y(row) + dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b); + dim3 threads(THREADS_PER_BLOCK); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + points.scalar_type(), "group_points_forward_cuda_kernel", [&] { + group_points_forward_cuda_kernel + <<>>( + b, c, n, npoints, nsample, points.data_ptr(), + idx.data_ptr(), out.data_ptr()); + }); + + AT_CUDA_CHECK(cudaGetLastError()); +} + +void GroupPointsBackwardCUDAKernelLauncher(int b, int c, int n, int npoints, + int nsample, const Tensor grad_out, + const Tensor idx, + Tensor grad_points) { + // grad_out: (B, C, npoints, nsample) + // idx: (B, npoints, nsample) + // output: + // grad_points: (B, C, N) + + at::cuda::CUDAGuard device_guard(grad_out.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // blockIdx.x(col), blockIdx.y(row) + dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b); + dim3 threads(THREADS_PER_BLOCK); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_out.scalar_type(), "group_points_backward_cuda_kernel", [&] { + group_points_backward_cuda_kernel + <<>>( + b, c, n, npoints, nsample, grad_out.data_ptr(), + idx.data_ptr(), grad_points.data_ptr()); + }); + + AT_CUDA_CHECK(cudaGetLastError()); +} diff --git a/mmcv/ops/csrc/pytorch/group_points.cpp b/mmcv/ops/csrc/pytorch/group_points.cpp new file mode 100644 index 000000000..1ebc947a1 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/group_points.cpp @@ -0,0 +1,58 @@ +// Copyright (c) OpenMMLab. All rights reserved. +// Modified from +// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/group_points.cpp + +#include "pytorch_cpp_helper.hpp" + +#ifdef MMCV_WITH_CUDA +void GroupPointsForwardCUDAKernelLauncher(int b, int c, int n, int npoints, + int nsample, const Tensor points, + const Tensor idx, Tensor out); +void group_points_forward_cuda(int b, int c, int n, int npoints, int nsample, + const Tensor points, const Tensor idx, + Tensor out) { + GroupPointsForwardCUDAKernelLauncher(b, c, n, npoints, nsample, points, idx, + out); +}; + +void GroupPointsBackwardCUDAKernelLauncher(int b, int c, int n, int npoints, + int nsample, const Tensor grad_out, + const Tensor idx, + Tensor grad_points); +void group_points_backward_cuda(int b, int c, int n, int npoints, int nsample, + const Tensor grad_out, const Tensor idx, + Tensor grad_points) { + GroupPointsBackwardCUDAKernelLauncher(b, c, n, npoints, nsample, grad_out, + idx, grad_points); +}; +#endif + +void group_points_forward(int b, int c, int n, int npoints, int nsample, + Tensor points_tensor, Tensor idx_tensor, + Tensor out_tensor) { + if (points_tensor.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + group_points_forward_cuda(b, c, n, npoints, nsample, points_tensor, + idx_tensor, out_tensor); +#else + AT_ERROR("group_points is not compiled with GPU support"); +#endif + } else { + AT_ERROR("group_points is not implemented on CPU"); + } +} + +void group_points_backward(int b, int c, int n, int npoints, int nsample, + Tensor grad_out_tensor, Tensor idx_tensor, + Tensor grad_points_tensor) { + if (grad_out_tensor.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + group_points_backward_cuda(b, c, n, npoints, nsample, grad_out_tensor, + idx_tensor, grad_points_tensor); +#else + AT_ERROR("group_points is not compiled with GPU support"); +#endif + } else { + AT_ERROR("group_points is not implemented on CPU"); + } +} diff --git a/mmcv/ops/csrc/pytorch/pybind.cpp b/mmcv/ops/csrc/pytorch/pybind.cpp index 7b39a5e44..8f52e26e8 100644 --- a/mmcv/ops/csrc/pytorch/pybind.cpp +++ b/mmcv/ops/csrc/pytorch/pybind.cpp @@ -65,6 +65,14 @@ void deform_roi_pool_backward(Tensor grad_output, Tensor input, Tensor rois, int pooled_width, float spatial_scale, int sampling_ratio, float gamma); +void group_points_forward(int b, int c, int n, int npoints, int nsample, + Tensor points_tensor, Tensor idx_tensor, + Tensor out_tensor); + +void group_points_backward(int b, int c, int n, int npoints, int nsample, + Tensor grad_out_tensor, Tensor idx_tensor, + Tensor grad_points_tensor); + void roipoint_pool3d_forward(Tensor xyz, Tensor boxes3d, Tensor pts_feature, Tensor pooled_features, Tensor pooled_empty_flag); @@ -453,6 +461,18 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("bbox_overlaps", &bbox_overlaps, "bbox_overlaps", py::arg("bboxes1"), py::arg("bboxes2"), py::arg("ious"), py::arg("mode"), py::arg("aligned"), py::arg("offset")); + m.def("group_points_forward", &group_points_forward, "group_points_forward", + py::arg("b"), py::arg("c"), py::arg("n"), py::arg("npoints"), + py::arg("nsample"), py::arg("points_tensor"), py::arg("idx_tensor"), + py::arg("out_tensor")); + m.def("group_points_backward", &group_points_backward, + "group_points_backward", py::arg("b"), py::arg("c"), py::arg("n"), + py::arg("npoints"), py::arg("nsample"), py::arg("grad_out_tensor"), + py::arg("idx_tensor"), py::arg("grad_points_tensor")); + m.def("knn_forward", &knn_forward, "knn_forward", py::arg("b"), py::arg("n"), + py::arg("m"), py::arg("nsample"), py::arg("xyz_tensor"), + py::arg("new_xyz_tensor"), py::arg("idx_tensor"), + py::arg("dist2_tensor")); m.def("iou3d_boxes_overlap_bev_forward", &iou3d_boxes_overlap_bev_forward, "iou3d_boxes_overlap_bev_forward", py::arg("boxes_a"), py::arg("boxes_b"), py::arg("ans_overlap")); diff --git a/mmcv/ops/group_points.py b/mmcv/ops/group_points.py new file mode 100644 index 000000000..5afd22794 --- /dev/null +++ b/mmcv/ops/group_points.py @@ -0,0 +1,224 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple + +import torch +from torch import nn as nn +from torch.autograd import Function + +from ..utils import ext_loader +from .ball_query import ball_query +from .knn import knn + +ext_module = ext_loader.load_ext( + '_ext', ['group_points_forward', 'group_points_backward']) + + +class QueryAndGroup(nn.Module): + """Groups points with a ball query of radius. + + Args: + max_radius (float): The maximum radius of the balls. + If None is given, we will use kNN sampling instead of ball query. + sample_num (int): Maximum number of features to gather in the ball. + min_radius (float, optional): The minimum radius of the balls. + Default: 0. + use_xyz (bool, optional): Whether to use xyz. + Default: True. + return_grouped_xyz (bool, optional): Whether to return grouped xyz. + Default: False. + normalize_xyz (bool, optional): Whether to normalize xyz. + Default: False. + uniform_sample (bool, optional): Whether to sample uniformly. + Default: False + return_unique_cnt (bool, optional): Whether to return the count of + unique samples. Default: False. + return_grouped_idx (bool, optional): Whether to return grouped idx. + Default: False. + """ + + def __init__(self, + max_radius, + sample_num, + min_radius=0, + use_xyz=True, + return_grouped_xyz=False, + normalize_xyz=False, + uniform_sample=False, + return_unique_cnt=False, + return_grouped_idx=False): + super().__init__() + self.max_radius = max_radius + self.min_radius = min_radius + self.sample_num = sample_num + self.use_xyz = use_xyz + self.return_grouped_xyz = return_grouped_xyz + self.normalize_xyz = normalize_xyz + self.uniform_sample = uniform_sample + self.return_unique_cnt = return_unique_cnt + self.return_grouped_idx = return_grouped_idx + if self.return_unique_cnt: + assert self.uniform_sample, \ + 'uniform_sample should be True when ' \ + 'returning the count of unique samples' + if self.max_radius is None: + assert not self.normalize_xyz, \ + 'can not normalize grouped xyz when max_radius is None' + + def forward(self, points_xyz, center_xyz, features=None): + """ + Args: + points_xyz (Tensor): (B, N, 3) xyz coordinates of the features. + center_xyz (Tensor): (B, npoint, 3) coordinates of the centriods. + features (Tensor): (B, C, N) Descriptors of the features. + + Return: + Tensor: (B, 3 + C, npoint, sample_num) Grouped feature. + """ + # if self.max_radius is None, we will perform kNN instead of ball query + # idx is of shape [B, npoint, sample_num] + if self.max_radius is None: + idx = knn(self.sample_num, points_xyz, center_xyz, False) + idx = idx.transpose(1, 2).contiguous() + else: + idx = ball_query(self.min_radius, self.max_radius, self.sample_num, + points_xyz, center_xyz) + + if self.uniform_sample: + unique_cnt = torch.zeros((idx.shape[0], idx.shape[1])) + for i_batch in range(idx.shape[0]): + for i_region in range(idx.shape[1]): + unique_ind = torch.unique(idx[i_batch, i_region, :]) + num_unique = unique_ind.shape[0] + unique_cnt[i_batch, i_region] = num_unique + sample_ind = torch.randint( + 0, + num_unique, (self.sample_num - num_unique, ), + dtype=torch.long) + all_ind = torch.cat((unique_ind, unique_ind[sample_ind])) + idx[i_batch, i_region, :] = all_ind + + xyz_trans = points_xyz.transpose(1, 2).contiguous() + # (B, 3, npoint, sample_num) + grouped_xyz = grouping_operation(xyz_trans, idx) + grouped_xyz_diff = grouped_xyz - \ + center_xyz.transpose(1, 2).unsqueeze(-1) # relative offsets + if self.normalize_xyz: + grouped_xyz_diff /= self.max_radius + + if features is not None: + grouped_features = grouping_operation(features, idx) + if self.use_xyz: + # (B, C + 3, npoint, sample_num) + new_features = torch.cat([grouped_xyz_diff, grouped_features], + dim=1) + else: + new_features = grouped_features + else: + assert (self.use_xyz + ), 'Cannot have not features and not use xyz as a feature!' + new_features = grouped_xyz_diff + + ret = [new_features] + if self.return_grouped_xyz: + ret.append(grouped_xyz) + if self.return_unique_cnt: + ret.append(unique_cnt) + if self.return_grouped_idx: + ret.append(idx) + if len(ret) == 1: + return ret[0] + else: + return tuple(ret) + + +class GroupAll(nn.Module): + """Group xyz with feature. + + Args: + use_xyz (bool): Whether to use xyz. + """ + + def __init__(self, use_xyz: bool = True): + super().__init__() + self.use_xyz = use_xyz + + def forward(self, + xyz: torch.Tensor, + new_xyz: torch.Tensor, + features: torch.Tensor = None): + """ + Args: + xyz (Tensor): (B, N, 3) xyz coordinates of the features. + new_xyz (Tensor): new xyz coordinates of the features. + features (Tensor): (B, C, N) features to group. + + Return: + Tensor: (B, C + 3, 1, N) Grouped feature. + """ + grouped_xyz = xyz.transpose(1, 2).unsqueeze(2) + if features is not None: + grouped_features = features.unsqueeze(2) + if self.use_xyz: + # (B, 3 + C, 1, N) + new_features = torch.cat([grouped_xyz, grouped_features], + dim=1) + else: + new_features = grouped_features + else: + new_features = grouped_xyz + + return new_features + + +class GroupingOperation(Function): + """Group feature with given index.""" + + @staticmethod + def forward(ctx, features: torch.Tensor, + indices: torch.Tensor) -> torch.Tensor: + """ + Args: + features (Tensor): (B, C, N) tensor of features to group. + indices (Tensor): (B, npoint, nsample) the indices of + features to group with. + + Returns: + Tensor: (B, C, npoint, nsample) Grouped features. + """ + features = features.contiguous() + indices = indices.contiguous() + + B, nfeatures, nsample = indices.size() + _, C, N = features.size() + output = torch.cuda.FloatTensor(B, C, nfeatures, nsample) + + ext_module.group_points_forward(B, C, N, nfeatures, nsample, features, + indices, output) + + ctx.for_backwards = (indices, N) + return output + + @staticmethod + def backward(ctx, + grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + grad_out (Tensor): (B, C, npoint, nsample) tensor of the gradients + of the output from forward. + + Returns: + Tensor: (B, C, N) gradient of the features. + """ + idx, N = ctx.for_backwards + + B, C, npoint, nsample = grad_out.size() + grad_features = torch.cuda.FloatTensor(B, C, N).zero_() + + grad_out_data = grad_out.data.contiguous() + ext_module.group_points_backward(B, C, N, npoint, nsample, + grad_out_data, idx, + grad_features.data) + return grad_features, None + + +grouping_operation = GroupingOperation.apply diff --git a/tests/test_ops/test_group_points.py b/tests/test_ops/test_group_points.py new file mode 100644 index 000000000..1b495c285 --- /dev/null +++ b/tests/test_ops/test_group_points.py @@ -0,0 +1,76 @@ +import pytest +import torch + +from mmcv.ops import grouping_operation + + +@pytest.mark.skipif( + not torch.cuda.is_available(), reason='requires CUDA support') +def test_grouping_points(): + idx = torch.tensor([[[0, 0, 0], [3, 3, 3], [8, 8, 8], [0, 0, 0], [0, 0, 0], + [0, 0, 0]], + [[0, 0, 0], [6, 6, 6], [9, 9, 9], [0, 0, 0], [0, 0, 0], + [0, 0, 0]]]).int().cuda() + festures = torch.tensor([[[ + 0.5798, -0.7981, -0.9280, -1.3311, 1.3687, 0.9277, -0.4164, -1.8274, + 0.9268, 0.8414 + ], + [ + 5.4247, 1.5113, 2.3944, 1.4740, 5.0300, + 5.1030, 1.9360, 2.1939, 2.1581, 3.4666 + ], + [ + -1.6266, -1.0281, -1.0393, -1.6931, -1.3982, + -0.5732, -1.0830, -1.7561, -1.6786, -1.6967 + ]], + [[ + -0.0380, -0.1880, -1.5724, 0.6905, -0.3190, + 0.7798, -0.3693, -0.9457, -0.2942, -1.8527 + ], + [ + 1.1773, 1.5009, 2.6399, 5.9242, 1.0962, + 2.7346, 6.0865, 1.5555, 4.3303, 2.8229 + ], + [ + -0.6646, -0.6870, -0.1125, -0.2224, -0.3445, + -1.4049, 0.4990, -0.7037, -0.9924, 0.0386 + ]]]).cuda() + + output = grouping_operation(festures, idx) + expected_output = torch.tensor([[[[0.5798, 0.5798, 0.5798], + [-1.3311, -1.3311, -1.3311], + [0.9268, 0.9268, 0.9268], + [0.5798, 0.5798, 0.5798], + [0.5798, 0.5798, 0.5798], + [0.5798, 0.5798, 0.5798]], + [[5.4247, 5.4247, 5.4247], + [1.4740, 1.4740, 1.4740], + [2.1581, 2.1581, 2.1581], + [5.4247, 5.4247, 5.4247], + [5.4247, 5.4247, 5.4247], + [5.4247, 5.4247, 5.4247]], + [[-1.6266, -1.6266, -1.6266], + [-1.6931, -1.6931, -1.6931], + [-1.6786, -1.6786, -1.6786], + [-1.6266, -1.6266, -1.6266], + [-1.6266, -1.6266, -1.6266], + [-1.6266, -1.6266, -1.6266]]], + [[[-0.0380, -0.0380, -0.0380], + [-0.3693, -0.3693, -0.3693], + [-1.8527, -1.8527, -1.8527], + [-0.0380, -0.0380, -0.0380], + [-0.0380, -0.0380, -0.0380], + [-0.0380, -0.0380, -0.0380]], + [[1.1773, 1.1773, 1.1773], + [6.0865, 6.0865, 6.0865], + [2.8229, 2.8229, 2.8229], + [1.1773, 1.1773, 1.1773], + [1.1773, 1.1773, 1.1773], + [1.1773, 1.1773, 1.1773]], + [[-0.6646, -0.6646, -0.6646], + [0.4990, 0.4990, 0.4990], + [0.0386, 0.0386, 0.0386], + [-0.6646, -0.6646, -0.6646], + [-0.6646, -0.6646, -0.6646], + [-0.6646, -0.6646, -0.6646]]]]).cuda() + assert torch.allclose(output, expected_output)