mirror of
https://github.com/open-mmlab/mmcv.git
synced 2025-06-03 21:54:52 +08:00
[Feature] Add group points ops from mmdet3d (#1415)
* add op (group points) and its related ops (ball query and knn) in mmdet3d * refactor code * fix typo * refactor code * fix typo * refactor code * make input contiguous Co-authored-by: zhouzaida <zhouzaida@163.com>
This commit is contained in:
parent
6935a818ec
commit
75cae78c55
@ -16,6 +16,7 @@ We implement common CUDA ops used in detection, segmentation, etc.
|
||||
- FurthestPointSample
|
||||
- FurthestPointSampleWithDist
|
||||
- GeneralizedAttention
|
||||
- GroupPoints
|
||||
- KNN
|
||||
- MaskedConv
|
||||
- NMS
|
||||
|
@ -22,6 +22,7 @@ from .furthest_point_sample import (furthest_point_sample,
|
||||
furthest_point_sample_with_dist)
|
||||
from .fused_bias_leakyrelu import FusedBiasLeakyReLU, fused_bias_leakyrelu
|
||||
from .gather_points import gather_points
|
||||
from .group_points import GroupAll, QueryAndGroup, grouping_operation
|
||||
from .info import (get_compiler_version, get_compiling_cuda_version,
|
||||
get_onnxruntime_op_path)
|
||||
from .iou3d import boxes_iou_bev, nms_bev, nms_normal_bev
|
||||
@ -68,13 +69,13 @@ __all__ = [
|
||||
'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign',
|
||||
'SAConv2d', 'TINShift', 'tin_shift', 'assign_score_withk',
|
||||
'box_iou_rotated', 'RoIPointPool3d', 'nms_rotated', 'knn', 'ball_query',
|
||||
'upfirdn2d', 'FusedBiasLeakyReLU', 'boxes_iou_bev', 'nms_bev',
|
||||
'nms_normal_bev', 'fused_bias_leakyrelu', 'RoIAlignRotated',
|
||||
'roi_align_rotated', 'pixel_group', 'contour_expand', 'three_nn',
|
||||
'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu',
|
||||
'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'QueryAndGroup',
|
||||
'GroupAll', 'grouping_operation', 'contour_expand', 'three_nn',
|
||||
'three_interpolate', 'MultiScaleDeformableAttention', 'BorderAlign',
|
||||
'border_align', 'gather_points', 'furthest_point_sample',
|
||||
'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation',
|
||||
'Voxelization', 'voxelization', 'dynamic_scatter', 'DynamicScatter',
|
||||
'RoIAwarePool3d', 'points_in_boxes_part', 'points_in_boxes_cpu',
|
||||
'points_in_boxes_all'
|
||||
'boxes_iou_bev', 'nms_bev', 'nms_normal_bev', 'Voxelization',
|
||||
'voxelization', 'dynamic_scatter', 'DynamicScatter', 'RoIAwarePool3d',
|
||||
'points_in_boxes_part', 'points_in_boxes_cpu', 'points_in_boxes_all'
|
||||
]
|
||||
|
63
mmcv/ops/csrc/common/cuda/group_points_cuda_kernel.cuh
Normal file
63
mmcv/ops/csrc/common/cuda/group_points_cuda_kernel.cuh
Normal file
@ -0,0 +1,63 @@
|
||||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
// Modified from
|
||||
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/group_points_gpu.cu
|
||||
#ifndef GROUP_POINTS_CUDA_KERNEL_CUH
|
||||
#define GROUP_POINTS_CUDA_KERNEL_CUH
|
||||
|
||||
#ifdef MMCV_USE_PARROTS
|
||||
#include "parrots_cuda_helper.hpp"
|
||||
#else
|
||||
#include "pytorch_cuda_helper.hpp"
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
__global__ void group_points_forward_cuda_kernel(int b, int c, int n,
|
||||
int npoints, int nsample,
|
||||
const T *points,
|
||||
const int *__restrict__ idx,
|
||||
T *out) {
|
||||
// points: (B, C, N)
|
||||
// idx: (B, npoints, nsample)
|
||||
// output:
|
||||
// out: (B, C, npoints, nsample)
|
||||
int bs_idx = blockIdx.z;
|
||||
int c_idx = blockIdx.y;
|
||||
int index = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int pt_idx = index / nsample;
|
||||
if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return;
|
||||
|
||||
int sample_idx = index % nsample;
|
||||
|
||||
idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx;
|
||||
int in_idx = bs_idx * c * n + c_idx * n + idx[0];
|
||||
int out_idx = bs_idx * c * npoints * nsample + c_idx * npoints * nsample +
|
||||
pt_idx * nsample + sample_idx;
|
||||
|
||||
out[out_idx] = points[in_idx];
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void group_points_backward_cuda_kernel(int b, int c, int n,
|
||||
int npoints, int nsample,
|
||||
const T *grad_out,
|
||||
const int *__restrict__ idx,
|
||||
T *grad_points) {
|
||||
// grad_out: (B, C, npoints, nsample)
|
||||
// idx: (B, npoints, nsample)
|
||||
// output:
|
||||
// grad_points: (B, C, N)
|
||||
int bs_idx = blockIdx.z;
|
||||
int c_idx = blockIdx.y;
|
||||
int index = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int pt_idx = index / nsample;
|
||||
if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return;
|
||||
|
||||
int sample_idx = index % nsample;
|
||||
grad_out += bs_idx * c * npoints * nsample + c_idx * npoints * nsample +
|
||||
pt_idx * nsample + sample_idx;
|
||||
idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx;
|
||||
|
||||
atomicAdd(grad_points + bs_idx * c * n + c_idx * n + idx[0], grad_out[0]);
|
||||
}
|
||||
|
||||
#endif // GROUP_POINTS_CUDA_KERNEL_CUH
|
61
mmcv/ops/csrc/pytorch/cuda/group_points_cuda.cu
Normal file
61
mmcv/ops/csrc/pytorch/cuda/group_points_cuda.cu
Normal file
@ -0,0 +1,61 @@
|
||||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
// Modified from
|
||||
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/group_points_gpu.cu
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "group_points_cuda_kernel.cuh"
|
||||
#include "pytorch_cuda_helper.hpp"
|
||||
|
||||
void GroupPointsForwardCUDAKernelLauncher(int b, int c, int n, int npoints,
|
||||
int nsample, const Tensor points,
|
||||
const Tensor idx, Tensor out) {
|
||||
// points: (B, C, N)
|
||||
// idx: (B, npoints, nsample)
|
||||
// output:
|
||||
// out: (B, C, npoints, nsample)
|
||||
|
||||
at::cuda::CUDAGuard device_guard(points.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
// blockIdx.x(col), blockIdx.y(row)
|
||||
dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b);
|
||||
dim3 threads(THREADS_PER_BLOCK);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
points.scalar_type(), "group_points_forward_cuda_kernel", [&] {
|
||||
group_points_forward_cuda_kernel<scalar_t>
|
||||
<<<blocks, threads, 0, stream>>>(
|
||||
b, c, n, npoints, nsample, points.data_ptr<scalar_t>(),
|
||||
idx.data_ptr<int>(), out.data_ptr<scalar_t>());
|
||||
});
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
void GroupPointsBackwardCUDAKernelLauncher(int b, int c, int n, int npoints,
|
||||
int nsample, const Tensor grad_out,
|
||||
const Tensor idx,
|
||||
Tensor grad_points) {
|
||||
// grad_out: (B, C, npoints, nsample)
|
||||
// idx: (B, npoints, nsample)
|
||||
// output:
|
||||
// grad_points: (B, C, N)
|
||||
|
||||
at::cuda::CUDAGuard device_guard(grad_out.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
// blockIdx.x(col), blockIdx.y(row)
|
||||
dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b);
|
||||
dim3 threads(THREADS_PER_BLOCK);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
grad_out.scalar_type(), "group_points_backward_cuda_kernel", [&] {
|
||||
group_points_backward_cuda_kernel<scalar_t>
|
||||
<<<blocks, threads, 0, stream>>>(
|
||||
b, c, n, npoints, nsample, grad_out.data_ptr<scalar_t>(),
|
||||
idx.data_ptr<int>(), grad_points.data_ptr<scalar_t>());
|
||||
});
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
}
|
58
mmcv/ops/csrc/pytorch/group_points.cpp
Normal file
58
mmcv/ops/csrc/pytorch/group_points.cpp
Normal file
@ -0,0 +1,58 @@
|
||||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
// Modified from
|
||||
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/group_points.cpp
|
||||
|
||||
#include "pytorch_cpp_helper.hpp"
|
||||
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
void GroupPointsForwardCUDAKernelLauncher(int b, int c, int n, int npoints,
|
||||
int nsample, const Tensor points,
|
||||
const Tensor idx, Tensor out);
|
||||
void group_points_forward_cuda(int b, int c, int n, int npoints, int nsample,
|
||||
const Tensor points, const Tensor idx,
|
||||
Tensor out) {
|
||||
GroupPointsForwardCUDAKernelLauncher(b, c, n, npoints, nsample, points, idx,
|
||||
out);
|
||||
};
|
||||
|
||||
void GroupPointsBackwardCUDAKernelLauncher(int b, int c, int n, int npoints,
|
||||
int nsample, const Tensor grad_out,
|
||||
const Tensor idx,
|
||||
Tensor grad_points);
|
||||
void group_points_backward_cuda(int b, int c, int n, int npoints, int nsample,
|
||||
const Tensor grad_out, const Tensor idx,
|
||||
Tensor grad_points) {
|
||||
GroupPointsBackwardCUDAKernelLauncher(b, c, n, npoints, nsample, grad_out,
|
||||
idx, grad_points);
|
||||
};
|
||||
#endif
|
||||
|
||||
void group_points_forward(int b, int c, int n, int npoints, int nsample,
|
||||
Tensor points_tensor, Tensor idx_tensor,
|
||||
Tensor out_tensor) {
|
||||
if (points_tensor.device().is_cuda()) {
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
group_points_forward_cuda(b, c, n, npoints, nsample, points_tensor,
|
||||
idx_tensor, out_tensor);
|
||||
#else
|
||||
AT_ERROR("group_points is not compiled with GPU support");
|
||||
#endif
|
||||
} else {
|
||||
AT_ERROR("group_points is not implemented on CPU");
|
||||
}
|
||||
}
|
||||
|
||||
void group_points_backward(int b, int c, int n, int npoints, int nsample,
|
||||
Tensor grad_out_tensor, Tensor idx_tensor,
|
||||
Tensor grad_points_tensor) {
|
||||
if (grad_out_tensor.device().is_cuda()) {
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
group_points_backward_cuda(b, c, n, npoints, nsample, grad_out_tensor,
|
||||
idx_tensor, grad_points_tensor);
|
||||
#else
|
||||
AT_ERROR("group_points is not compiled with GPU support");
|
||||
#endif
|
||||
} else {
|
||||
AT_ERROR("group_points is not implemented on CPU");
|
||||
}
|
||||
}
|
@ -65,6 +65,14 @@ void deform_roi_pool_backward(Tensor grad_output, Tensor input, Tensor rois,
|
||||
int pooled_width, float spatial_scale,
|
||||
int sampling_ratio, float gamma);
|
||||
|
||||
void group_points_forward(int b, int c, int n, int npoints, int nsample,
|
||||
Tensor points_tensor, Tensor idx_tensor,
|
||||
Tensor out_tensor);
|
||||
|
||||
void group_points_backward(int b, int c, int n, int npoints, int nsample,
|
||||
Tensor grad_out_tensor, Tensor idx_tensor,
|
||||
Tensor grad_points_tensor);
|
||||
|
||||
void roipoint_pool3d_forward(Tensor xyz, Tensor boxes3d, Tensor pts_feature,
|
||||
Tensor pooled_features, Tensor pooled_empty_flag);
|
||||
|
||||
@ -453,6 +461,18 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("bbox_overlaps", &bbox_overlaps, "bbox_overlaps", py::arg("bboxes1"),
|
||||
py::arg("bboxes2"), py::arg("ious"), py::arg("mode"),
|
||||
py::arg("aligned"), py::arg("offset"));
|
||||
m.def("group_points_forward", &group_points_forward, "group_points_forward",
|
||||
py::arg("b"), py::arg("c"), py::arg("n"), py::arg("npoints"),
|
||||
py::arg("nsample"), py::arg("points_tensor"), py::arg("idx_tensor"),
|
||||
py::arg("out_tensor"));
|
||||
m.def("group_points_backward", &group_points_backward,
|
||||
"group_points_backward", py::arg("b"), py::arg("c"), py::arg("n"),
|
||||
py::arg("npoints"), py::arg("nsample"), py::arg("grad_out_tensor"),
|
||||
py::arg("idx_tensor"), py::arg("grad_points_tensor"));
|
||||
m.def("knn_forward", &knn_forward, "knn_forward", py::arg("b"), py::arg("n"),
|
||||
py::arg("m"), py::arg("nsample"), py::arg("xyz_tensor"),
|
||||
py::arg("new_xyz_tensor"), py::arg("idx_tensor"),
|
||||
py::arg("dist2_tensor"));
|
||||
m.def("iou3d_boxes_overlap_bev_forward", &iou3d_boxes_overlap_bev_forward,
|
||||
"iou3d_boxes_overlap_bev_forward", py::arg("boxes_a"),
|
||||
py::arg("boxes_b"), py::arg("ans_overlap"));
|
||||
|
224
mmcv/ops/group_points.py
Normal file
224
mmcv/ops/group_points.py
Normal file
@ -0,0 +1,224 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
from torch.autograd import Function
|
||||
|
||||
from ..utils import ext_loader
|
||||
from .ball_query import ball_query
|
||||
from .knn import knn
|
||||
|
||||
ext_module = ext_loader.load_ext(
|
||||
'_ext', ['group_points_forward', 'group_points_backward'])
|
||||
|
||||
|
||||
class QueryAndGroup(nn.Module):
|
||||
"""Groups points with a ball query of radius.
|
||||
|
||||
Args:
|
||||
max_radius (float): The maximum radius of the balls.
|
||||
If None is given, we will use kNN sampling instead of ball query.
|
||||
sample_num (int): Maximum number of features to gather in the ball.
|
||||
min_radius (float, optional): The minimum radius of the balls.
|
||||
Default: 0.
|
||||
use_xyz (bool, optional): Whether to use xyz.
|
||||
Default: True.
|
||||
return_grouped_xyz (bool, optional): Whether to return grouped xyz.
|
||||
Default: False.
|
||||
normalize_xyz (bool, optional): Whether to normalize xyz.
|
||||
Default: False.
|
||||
uniform_sample (bool, optional): Whether to sample uniformly.
|
||||
Default: False
|
||||
return_unique_cnt (bool, optional): Whether to return the count of
|
||||
unique samples. Default: False.
|
||||
return_grouped_idx (bool, optional): Whether to return grouped idx.
|
||||
Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
max_radius,
|
||||
sample_num,
|
||||
min_radius=0,
|
||||
use_xyz=True,
|
||||
return_grouped_xyz=False,
|
||||
normalize_xyz=False,
|
||||
uniform_sample=False,
|
||||
return_unique_cnt=False,
|
||||
return_grouped_idx=False):
|
||||
super().__init__()
|
||||
self.max_radius = max_radius
|
||||
self.min_radius = min_radius
|
||||
self.sample_num = sample_num
|
||||
self.use_xyz = use_xyz
|
||||
self.return_grouped_xyz = return_grouped_xyz
|
||||
self.normalize_xyz = normalize_xyz
|
||||
self.uniform_sample = uniform_sample
|
||||
self.return_unique_cnt = return_unique_cnt
|
||||
self.return_grouped_idx = return_grouped_idx
|
||||
if self.return_unique_cnt:
|
||||
assert self.uniform_sample, \
|
||||
'uniform_sample should be True when ' \
|
||||
'returning the count of unique samples'
|
||||
if self.max_radius is None:
|
||||
assert not self.normalize_xyz, \
|
||||
'can not normalize grouped xyz when max_radius is None'
|
||||
|
||||
def forward(self, points_xyz, center_xyz, features=None):
|
||||
"""
|
||||
Args:
|
||||
points_xyz (Tensor): (B, N, 3) xyz coordinates of the features.
|
||||
center_xyz (Tensor): (B, npoint, 3) coordinates of the centriods.
|
||||
features (Tensor): (B, C, N) Descriptors of the features.
|
||||
|
||||
Return:
|
||||
Tensor: (B, 3 + C, npoint, sample_num) Grouped feature.
|
||||
"""
|
||||
# if self.max_radius is None, we will perform kNN instead of ball query
|
||||
# idx is of shape [B, npoint, sample_num]
|
||||
if self.max_radius is None:
|
||||
idx = knn(self.sample_num, points_xyz, center_xyz, False)
|
||||
idx = idx.transpose(1, 2).contiguous()
|
||||
else:
|
||||
idx = ball_query(self.min_radius, self.max_radius, self.sample_num,
|
||||
points_xyz, center_xyz)
|
||||
|
||||
if self.uniform_sample:
|
||||
unique_cnt = torch.zeros((idx.shape[0], idx.shape[1]))
|
||||
for i_batch in range(idx.shape[0]):
|
||||
for i_region in range(idx.shape[1]):
|
||||
unique_ind = torch.unique(idx[i_batch, i_region, :])
|
||||
num_unique = unique_ind.shape[0]
|
||||
unique_cnt[i_batch, i_region] = num_unique
|
||||
sample_ind = torch.randint(
|
||||
0,
|
||||
num_unique, (self.sample_num - num_unique, ),
|
||||
dtype=torch.long)
|
||||
all_ind = torch.cat((unique_ind, unique_ind[sample_ind]))
|
||||
idx[i_batch, i_region, :] = all_ind
|
||||
|
||||
xyz_trans = points_xyz.transpose(1, 2).contiguous()
|
||||
# (B, 3, npoint, sample_num)
|
||||
grouped_xyz = grouping_operation(xyz_trans, idx)
|
||||
grouped_xyz_diff = grouped_xyz - \
|
||||
center_xyz.transpose(1, 2).unsqueeze(-1) # relative offsets
|
||||
if self.normalize_xyz:
|
||||
grouped_xyz_diff /= self.max_radius
|
||||
|
||||
if features is not None:
|
||||
grouped_features = grouping_operation(features, idx)
|
||||
if self.use_xyz:
|
||||
# (B, C + 3, npoint, sample_num)
|
||||
new_features = torch.cat([grouped_xyz_diff, grouped_features],
|
||||
dim=1)
|
||||
else:
|
||||
new_features = grouped_features
|
||||
else:
|
||||
assert (self.use_xyz
|
||||
), 'Cannot have not features and not use xyz as a feature!'
|
||||
new_features = grouped_xyz_diff
|
||||
|
||||
ret = [new_features]
|
||||
if self.return_grouped_xyz:
|
||||
ret.append(grouped_xyz)
|
||||
if self.return_unique_cnt:
|
||||
ret.append(unique_cnt)
|
||||
if self.return_grouped_idx:
|
||||
ret.append(idx)
|
||||
if len(ret) == 1:
|
||||
return ret[0]
|
||||
else:
|
||||
return tuple(ret)
|
||||
|
||||
|
||||
class GroupAll(nn.Module):
|
||||
"""Group xyz with feature.
|
||||
|
||||
Args:
|
||||
use_xyz (bool): Whether to use xyz.
|
||||
"""
|
||||
|
||||
def __init__(self, use_xyz: bool = True):
|
||||
super().__init__()
|
||||
self.use_xyz = use_xyz
|
||||
|
||||
def forward(self,
|
||||
xyz: torch.Tensor,
|
||||
new_xyz: torch.Tensor,
|
||||
features: torch.Tensor = None):
|
||||
"""
|
||||
Args:
|
||||
xyz (Tensor): (B, N, 3) xyz coordinates of the features.
|
||||
new_xyz (Tensor): new xyz coordinates of the features.
|
||||
features (Tensor): (B, C, N) features to group.
|
||||
|
||||
Return:
|
||||
Tensor: (B, C + 3, 1, N) Grouped feature.
|
||||
"""
|
||||
grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
|
||||
if features is not None:
|
||||
grouped_features = features.unsqueeze(2)
|
||||
if self.use_xyz:
|
||||
# (B, 3 + C, 1, N)
|
||||
new_features = torch.cat([grouped_xyz, grouped_features],
|
||||
dim=1)
|
||||
else:
|
||||
new_features = grouped_features
|
||||
else:
|
||||
new_features = grouped_xyz
|
||||
|
||||
return new_features
|
||||
|
||||
|
||||
class GroupingOperation(Function):
|
||||
"""Group feature with given index."""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, features: torch.Tensor,
|
||||
indices: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
features (Tensor): (B, C, N) tensor of features to group.
|
||||
indices (Tensor): (B, npoint, nsample) the indices of
|
||||
features to group with.
|
||||
|
||||
Returns:
|
||||
Tensor: (B, C, npoint, nsample) Grouped features.
|
||||
"""
|
||||
features = features.contiguous()
|
||||
indices = indices.contiguous()
|
||||
|
||||
B, nfeatures, nsample = indices.size()
|
||||
_, C, N = features.size()
|
||||
output = torch.cuda.FloatTensor(B, C, nfeatures, nsample)
|
||||
|
||||
ext_module.group_points_forward(B, C, N, nfeatures, nsample, features,
|
||||
indices, output)
|
||||
|
||||
ctx.for_backwards = (indices, N)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx,
|
||||
grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
grad_out (Tensor): (B, C, npoint, nsample) tensor of the gradients
|
||||
of the output from forward.
|
||||
|
||||
Returns:
|
||||
Tensor: (B, C, N) gradient of the features.
|
||||
"""
|
||||
idx, N = ctx.for_backwards
|
||||
|
||||
B, C, npoint, nsample = grad_out.size()
|
||||
grad_features = torch.cuda.FloatTensor(B, C, N).zero_()
|
||||
|
||||
grad_out_data = grad_out.data.contiguous()
|
||||
ext_module.group_points_backward(B, C, N, npoint, nsample,
|
||||
grad_out_data, idx,
|
||||
grad_features.data)
|
||||
return grad_features, None
|
||||
|
||||
|
||||
grouping_operation = GroupingOperation.apply
|
76
tests/test_ops/test_group_points.py
Normal file
76
tests/test_ops/test_group_points.py
Normal file
@ -0,0 +1,76 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmcv.ops import grouping_operation
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason='requires CUDA support')
|
||||
def test_grouping_points():
|
||||
idx = torch.tensor([[[0, 0, 0], [3, 3, 3], [8, 8, 8], [0, 0, 0], [0, 0, 0],
|
||||
[0, 0, 0]],
|
||||
[[0, 0, 0], [6, 6, 6], [9, 9, 9], [0, 0, 0], [0, 0, 0],
|
||||
[0, 0, 0]]]).int().cuda()
|
||||
festures = torch.tensor([[[
|
||||
0.5798, -0.7981, -0.9280, -1.3311, 1.3687, 0.9277, -0.4164, -1.8274,
|
||||
0.9268, 0.8414
|
||||
],
|
||||
[
|
||||
5.4247, 1.5113, 2.3944, 1.4740, 5.0300,
|
||||
5.1030, 1.9360, 2.1939, 2.1581, 3.4666
|
||||
],
|
||||
[
|
||||
-1.6266, -1.0281, -1.0393, -1.6931, -1.3982,
|
||||
-0.5732, -1.0830, -1.7561, -1.6786, -1.6967
|
||||
]],
|
||||
[[
|
||||
-0.0380, -0.1880, -1.5724, 0.6905, -0.3190,
|
||||
0.7798, -0.3693, -0.9457, -0.2942, -1.8527
|
||||
],
|
||||
[
|
||||
1.1773, 1.5009, 2.6399, 5.9242, 1.0962,
|
||||
2.7346, 6.0865, 1.5555, 4.3303, 2.8229
|
||||
],
|
||||
[
|
||||
-0.6646, -0.6870, -0.1125, -0.2224, -0.3445,
|
||||
-1.4049, 0.4990, -0.7037, -0.9924, 0.0386
|
||||
]]]).cuda()
|
||||
|
||||
output = grouping_operation(festures, idx)
|
||||
expected_output = torch.tensor([[[[0.5798, 0.5798, 0.5798],
|
||||
[-1.3311, -1.3311, -1.3311],
|
||||
[0.9268, 0.9268, 0.9268],
|
||||
[0.5798, 0.5798, 0.5798],
|
||||
[0.5798, 0.5798, 0.5798],
|
||||
[0.5798, 0.5798, 0.5798]],
|
||||
[[5.4247, 5.4247, 5.4247],
|
||||
[1.4740, 1.4740, 1.4740],
|
||||
[2.1581, 2.1581, 2.1581],
|
||||
[5.4247, 5.4247, 5.4247],
|
||||
[5.4247, 5.4247, 5.4247],
|
||||
[5.4247, 5.4247, 5.4247]],
|
||||
[[-1.6266, -1.6266, -1.6266],
|
||||
[-1.6931, -1.6931, -1.6931],
|
||||
[-1.6786, -1.6786, -1.6786],
|
||||
[-1.6266, -1.6266, -1.6266],
|
||||
[-1.6266, -1.6266, -1.6266],
|
||||
[-1.6266, -1.6266, -1.6266]]],
|
||||
[[[-0.0380, -0.0380, -0.0380],
|
||||
[-0.3693, -0.3693, -0.3693],
|
||||
[-1.8527, -1.8527, -1.8527],
|
||||
[-0.0380, -0.0380, -0.0380],
|
||||
[-0.0380, -0.0380, -0.0380],
|
||||
[-0.0380, -0.0380, -0.0380]],
|
||||
[[1.1773, 1.1773, 1.1773],
|
||||
[6.0865, 6.0865, 6.0865],
|
||||
[2.8229, 2.8229, 2.8229],
|
||||
[1.1773, 1.1773, 1.1773],
|
||||
[1.1773, 1.1773, 1.1773],
|
||||
[1.1773, 1.1773, 1.1773]],
|
||||
[[-0.6646, -0.6646, -0.6646],
|
||||
[0.4990, 0.4990, 0.4990],
|
||||
[0.0386, 0.0386, 0.0386],
|
||||
[-0.6646, -0.6646, -0.6646],
|
||||
[-0.6646, -0.6646, -0.6646],
|
||||
[-0.6646, -0.6646, -0.6646]]]]).cuda()
|
||||
assert torch.allclose(output, expected_output)
|
Loading…
x
Reference in New Issue
Block a user