diff --git a/mmcv/ops/csrc/common/cuda/voxelization_cuda_kernel.cuh b/mmcv/ops/csrc/common/cuda/voxelization_cuda_kernel.cuh index a5562d5d7..b7bca23eb 100644 --- a/mmcv/ops/csrc/common/cuda/voxelization_cuda_kernel.cuh +++ b/mmcv/ops/csrc/common/cuda/voxelization_cuda_kernel.cuh @@ -166,4 +166,51 @@ __global__ void determin_voxel_num( } } +__global__ void nondeterministic_get_assign_pos( + const int nthreads, const int32_t* coors_map, int32_t* pts_id, + int32_t* coors_count, int32_t* reduce_count, int32_t* coors_order) { + CUDA_1D_KERNEL_LOOP(thread_idx, nthreads) { + int coors_idx = coors_map[thread_idx]; + if (coors_idx > -1) { + int32_t coors_pts_pos = atomicAdd(&reduce_count[coors_idx], 1); + pts_id[thread_idx] = coors_pts_pos; + if (coors_pts_pos == 0) { + coors_order[coors_idx] = atomicAdd(coors_count, 1); + } + } + } +} + +template +__global__ void nondeterministic_assign_point_voxel( + const int nthreads, const T* points, const int32_t* coors_map, + const int32_t* pts_id, const int32_t* coors_in, const int32_t* reduce_count, + const int32_t* coors_order, T* voxels, int32_t* coors, int32_t* pts_count, + const int max_voxels, const int max_points, const int num_features, + const int NDim) { + CUDA_1D_KERNEL_LOOP(thread_idx, nthreads) { + int coors_idx = coors_map[thread_idx]; + int coors_pts_pos = pts_id[thread_idx]; + if (coors_idx > -1 && coors_pts_pos < max_points) { + int coors_pos = coors_order[coors_idx]; + if (coors_pos < max_voxels) { + auto voxels_offset = + voxels + (coors_pos * max_points + coors_pts_pos) * num_features; + auto points_offset = points + thread_idx * num_features; + for (int k = 0; k < num_features; k++) { + voxels_offset[k] = points_offset[k]; + } + if (coors_pts_pos == 0) { + pts_count[coors_pos] = min(reduce_count[coors_idx], max_points); + auto coors_offset = coors + coors_pos * NDim; + auto coors_in_offset = coors_in + coors_idx * NDim; + for (int k = 0; k < NDim; k++) { + coors_offset[k] = coors_in_offset[k]; + } + } + } + } + 
} +} + #endif // VOXELIZATION_CUDA_KERNEL_CUH diff --git a/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp b/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp index 92b0a556d..b92ad6791 100644 --- a/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp +++ b/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp @@ -1396,6 +1396,12 @@ int HardVoxelizeForwardCUDAKernelLauncher( const std::vector coors_range, const int max_points, const int max_voxels, const int NDim = 3); +int NondeterministicHardVoxelizeForwardCUDAKernelLauncher( + const at::Tensor& points, at::Tensor& voxels, at::Tensor& coors, + at::Tensor& num_points_per_voxel, const std::vector voxel_size, + const std::vector coors_range, const int max_points, + const int max_voxels, const int NDim = 3); + void DynamicVoxelizeForwardCUDAKernelLauncher( const at::Tensor& points, at::Tensor& coors, const std::vector voxel_size, const std::vector coors_range, @@ -1413,6 +1419,16 @@ int hard_voxelize_forward_cuda(const at::Tensor& points, at::Tensor& voxels, max_points, max_voxels, NDim); }; +int nondeterministic_hard_voxelize_forward_cuda( + const at::Tensor& points, at::Tensor& voxels, at::Tensor& coors, + at::Tensor& num_points_per_voxel, const std::vector voxel_size, + const std::vector coors_range, const int max_points, + const int max_voxels, const int NDim) { + return NondeterministicHardVoxelizeForwardCUDAKernelLauncher( + points, voxels, coors, num_points_per_voxel, voxel_size, coors_range, + max_points, max_voxels, NDim); +}; + void dynamic_voxelize_forward_cuda(const at::Tensor& points, at::Tensor& coors, const std::vector voxel_size, const std::vector coors_range, @@ -1429,6 +1445,12 @@ int hard_voxelize_forward_impl(const at::Tensor& points, at::Tensor& voxels, const int max_points, const int max_voxels, const int NDim); +int nondeterministic_hard_voxelize_forward_impl( + const at::Tensor& points, at::Tensor& voxels, at::Tensor& coors, + at::Tensor& num_points_per_voxel, const std::vector voxel_size, + const std::vector coors_range, const int 
max_points, + const int max_voxels, const int NDim); + void dynamic_voxelize_forward_impl(const at::Tensor& points, at::Tensor& coors, const std::vector voxel_size, const std::vector coors_range, @@ -1436,6 +1458,8 @@ void dynamic_voxelize_forward_impl(const at::Tensor& points, at::Tensor& coors, REGISTER_DEVICE_IMPL(hard_voxelize_forward_impl, CUDA, hard_voxelize_forward_cuda); +REGISTER_DEVICE_IMPL(nondeterministic_hard_voxelize_forward_impl, CUDA, + nondeterministic_hard_voxelize_forward_cuda); REGISTER_DEVICE_IMPL(dynamic_voxelize_forward_impl, CUDA, dynamic_voxelize_forward_cuda); diff --git a/mmcv/ops/csrc/pytorch/cuda/voxelization_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/voxelization_cuda.cu index bcb7da338..f4166b7b7 100644 --- a/mmcv/ops/csrc/pytorch/cuda/voxelization_cuda.cu +++ b/mmcv/ops/csrc/pytorch/cuda/voxelization_cuda.cu @@ -145,6 +145,104 @@ int HardVoxelizeForwardCUDAKernelLauncher( return voxel_num_int; } +int NondeterministicHardVoxelizeForwardCUDAKernelLauncher( + const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors, + at::Tensor &num_points_per_voxel, const std::vector voxel_size, + const std::vector coors_range, const int max_points, + const int max_voxels, const int NDim = 3) { + at::cuda::CUDAGuard device_guard(points.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + const int num_points = points.size(0); + const int num_features = points.size(1); + + if (num_points == 0) return 0; + + dim3 blocks( + std::min(at::cuda::ATenCeilDiv(num_points, THREADS_PER_BLOCK), 4096)); + dim3 threads(THREADS_PER_BLOCK); + + const float voxel_x = voxel_size[0]; + const float voxel_y = voxel_size[1]; + const float voxel_z = voxel_size[2]; + const float coors_x_min = coors_range[0]; + const float coors_y_min = coors_range[1]; + const float coors_z_min = coors_range[2]; + const float coors_x_max = coors_range[3]; + const float coors_y_max = coors_range[4]; + const float coors_z_max = coors_range[5]; + + const int grid_x = 
round((coors_x_max - coors_x_min) / voxel_x); + const int grid_y = round((coors_y_max - coors_y_min) / voxel_y); + const int grid_z = round((coors_z_max - coors_z_min) / voxel_z); + + // map points to voxel coors + at::Tensor temp_coors = + at::zeros({num_points, NDim}, points.options().dtype(at::kInt)); + + // 1. link point to corresponding voxel coors + AT_DISPATCH_ALL_TYPES( + points.scalar_type(), "hard_voxelize_kernel", ([&] { + dynamic_voxelize_kernel<<>>( + points.contiguous().data_ptr(), + temp_coors.contiguous().data_ptr(), voxel_x, voxel_y, voxel_z, + coors_x_min, coors_y_min, coors_z_min, coors_x_max, coors_y_max, + coors_z_max, grid_x, grid_y, grid_z, num_points, num_features, + NDim); + })); + + at::Tensor coors_map; + at::Tensor reduce_count; + + auto coors_clean = temp_coors.masked_fill(temp_coors.lt(0).any(-1, true), -1); + + std::tie(temp_coors, coors_map, reduce_count) = + at::unique_dim(coors_clean, 0, true, true, false); + + if (temp_coors[0][0].lt(0).item()) { + // the first element of temp_coors is (-1,-1,-1) and should be removed + temp_coors = temp_coors.slice(0, 1); + coors_map = coors_map - 1; + } + + int num_coors = temp_coors.size(0); + temp_coors = temp_coors.to(at::kInt); + coors_map = coors_map.to(at::kInt); + + at::Tensor coors_count = at::zeros({1}, coors_map.options()); + at::Tensor coors_order = at::empty({num_coors}, coors_map.options()); + at::Tensor pts_id = at::zeros({num_points}, coors_map.options()); + reduce_count = at::zeros({num_coors}, coors_map.options()); + + AT_DISPATCH_ALL_TYPES( + points.scalar_type(), "get_assign_pos", ([&] { + nondeterministic_get_assign_pos<<>>( + num_points, coors_map.contiguous().data_ptr(), + pts_id.contiguous().data_ptr(), + coors_count.contiguous().data_ptr(), + reduce_count.contiguous().data_ptr(), + coors_order.contiguous().data_ptr()); + })); + + AT_DISPATCH_ALL_TYPES( + points.scalar_type(), "assign_point_to_voxel", ([&] { + nondeterministic_assign_point_voxel + <<>>( + num_points, 
points.contiguous().data_ptr(), + coors_map.contiguous().data_ptr(), + pts_id.contiguous().data_ptr(), + temp_coors.contiguous().data_ptr(), + reduce_count.contiguous().data_ptr(), + coors_order.contiguous().data_ptr(), + voxels.contiguous().data_ptr(), + coors.contiguous().data_ptr(), + num_points_per_voxel.contiguous().data_ptr(), + max_voxels, max_points, num_features, NDim); + })); + AT_CUDA_CHECK(cudaGetLastError()); + return max_voxels < num_coors ? max_voxels : num_coors; +} + void DynamicVoxelizeForwardCUDAKernelLauncher( const at::Tensor &points, at::Tensor &coors, const std::vector voxel_size, const std::vector coors_range, diff --git a/mmcv/ops/csrc/pytorch/pybind.cpp b/mmcv/ops/csrc/pytorch/pybind.cpp index 064aabc07..1dc55cea5 100644 --- a/mmcv/ops/csrc/pytorch/pybind.cpp +++ b/mmcv/ops/csrc/pytorch/pybind.cpp @@ -338,7 +338,8 @@ void hard_voxelize_forward(const at::Tensor &points, const at::Tensor &coors_range, at::Tensor &voxels, at::Tensor &coors, at::Tensor &num_points_per_voxel, at::Tensor &voxel_num, const int max_points, - const int max_voxels, const int NDim); + const int max_voxels, const int NDim, + const bool deterministic); void dynamic_voxelize_forward(const at::Tensor &points, const at::Tensor &voxel_size, @@ -756,7 +757,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "hard_voxelize_forward", py::arg("points"), py::arg("voxel_size"), py::arg("coors_range"), py::arg("voxels"), py::arg("coors"), py::arg("num_points_per_voxel"), py::arg("voxel_num"), - py::arg("max_points"), py::arg("max_voxels"), py::arg("NDim")); + py::arg("max_points"), py::arg("max_voxels"), py::arg("NDim"), + py::arg("deterministic")); m.def("dynamic_voxelize_forward", &dynamic_voxelize_forward, "dynamic_voxelize_forward", py::arg("points"), py::arg("voxel_size"), py::arg("coors_range"), py::arg("coors"), py::arg("NDim")); diff --git a/mmcv/ops/csrc/pytorch/voxelization.cpp b/mmcv/ops/csrc/pytorch/voxelization.cpp index 1d1c229c1..7946be617 100644 --- 
a/mmcv/ops/csrc/pytorch/voxelization.cpp +++ b/mmcv/ops/csrc/pytorch/voxelization.cpp @@ -14,6 +14,17 @@ int hard_voxelize_forward_impl(const at::Tensor &points, at::Tensor &voxels, max_points, max_voxels, NDim); } +int nondeterministic_hard_voxelize_forward_impl( + const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors, + at::Tensor &num_points_per_voxel, const std::vector voxel_size, + const std::vector coors_range, const int max_points, + const int max_voxels, const int NDim = 3) { + return DISPATCH_DEVICE_IMPL(nondeterministic_hard_voxelize_forward_impl, + points, voxels, coors, num_points_per_voxel, + voxel_size, coors_range, max_points, max_voxels, + NDim); +} + void dynamic_voxelize_forward_impl(const at::Tensor &points, at::Tensor &coors, const std::vector voxel_size, const std::vector coors_range, @@ -27,7 +38,8 @@ void hard_voxelize_forward(const at::Tensor &points, const at::Tensor &coors_range, at::Tensor &voxels, at::Tensor &coors, at::Tensor &num_points_per_voxel, at::Tensor &voxel_num, const int max_points, - const int max_voxels, const int NDim = 3) { + const int max_voxels, const int NDim = 3, + const bool deterministic = true) { int64_t *voxel_num_data = voxel_num.data_ptr(); std::vector voxel_size_v( voxel_size.data_ptr(), @@ -36,9 +48,15 @@ void hard_voxelize_forward(const at::Tensor &points, coors_range.data_ptr(), coors_range.data_ptr() + coors_range.numel()); - *voxel_num_data = hard_voxelize_forward_impl( - points, voxels, coors, num_points_per_voxel, voxel_size_v, coors_range_v, - max_points, max_voxels, NDim); + if (deterministic) { + *voxel_num_data = hard_voxelize_forward_impl( + points, voxels, coors, num_points_per_voxel, voxel_size_v, + coors_range_v, max_points, max_voxels, NDim); + } else { + *voxel_num_data = nondeterministic_hard_voxelize_forward_impl( + points, voxels, coors, num_points_per_voxel, voxel_size_v, + coors_range_v, max_points, max_voxels, NDim); + } } void dynamic_voxelize_forward(const at::Tensor &points, 
diff --git a/mmcv/ops/voxelize.py b/mmcv/ops/voxelize.py index 2dc2f19e9..ee4e0ae8f 100644 --- a/mmcv/ops/voxelize.py +++ b/mmcv/ops/voxelize.py @@ -18,7 +18,8 @@ class _Voxelization(Function): voxel_size, coors_range, max_points=35, - max_voxels=20000): + max_voxels=20000, + deterministic=True): """Convert kitti points(N, >=3) to voxels. Args: @@ -34,6 +35,16 @@ class _Voxelization(Function): for second, 20000 is a good choice. Users should shuffle points before call this function because max_voxels may drop points. Default: 20000. + deterministic: bool. whether to invoke the deterministic + version of hard-voxelization implementations. non-deterministic + version is considerably faster but is not deterministic. only + affects hard voxelization. default True. for more information + of this argument and the implementation insights, please refer + to the following links: + https://github.com/open-mmlab/mmdetection3d/issues/894 + https://github.com/open-mmlab/mmdetection3d/pull/904 + it is an experimental feature and we will appreciate it if + you could share with us the failing cases. Returns: tuple[torch.Tensor]: tuple[torch.Tensor]: A tuple contains three @@ -69,7 +80,8 @@ class _Voxelization(Function): voxel_num, max_points=max_points, max_voxels=max_voxels, - NDim=3) + NDim=3, + deterministic=deterministic) # select the valid voxels voxels_out = voxels[:voxel_num] coors_out = coors[:voxel_num] @@ -102,7 +114,27 @@ class Voxelization(nn.Module): voxel_size, point_cloud_range, max_num_points, - max_voxels=20000): + max_voxels=20000, + deterministic=True): + """ + Args: + voxel_size (list): list [x, y, z] size of three dimension + point_cloud_range (list): + [x_min, y_min, z_min, x_max, y_max, z_max] + max_num_points (int): max number of points per voxel + max_voxels (tuple or int): max number of voxels in + (training, testing) time + deterministic: bool. whether to invoke the deterministic + version of hard-voxelization implementations.
non-deterministic + version is considerably faster but is not deterministic. only + affects hard voxelization. default True. for more information + of this argument and the implementation insights, please refer + to the following links: + https://github.com/open-mmlab/mmdetection3d/issues/894 + https://github.com/open-mmlab/mmdetection3d/pull/904 + it is an experimental feature and we will appreciate it if + you could share with us the failing cases. + """ super().__init__() self.voxel_size = voxel_size @@ -112,6 +144,7 @@ class Voxelization(nn.Module): self.max_voxels = max_voxels else: self.max_voxels = _pair(max_voxels) + self.deterministic = deterministic point_cloud_range = torch.tensor( point_cloud_range, dtype=torch.float32) @@ -132,7 +165,8 @@ class Voxelization(nn.Module): max_voxels = self.max_voxels[1] return voxelization(input, self.voxel_size, self.point_cloud_range, - self.max_num_points, max_voxels) + self.max_num_points, max_voxels, + self.deterministic) def __repr__(self): s = self.__class__.__name__ + '(' @@ -140,5 +174,6 @@ class Voxelization(nn.Module): s += ', point_cloud_range=' + str(self.point_cloud_range) s += ', max_num_points=' + str(self.max_num_points) s += ', max_voxels=' + str(self.max_voxels) + s += ', deterministic=' + str(self.deterministic) s += ')' return s diff --git a/tests/test_ops/test_voxelization.py b/tests/test_ops/test_voxelization.py index db956da41..6f0fa9aef 100644 --- a/tests/test_ops/test_voxelization.py +++ b/tests/test_ops/test_voxelization.py @@ -60,3 +60,80 @@ def test_voxelization(device_type): assert np.all( points[indices] == expected_coors[i][:num_points_current_voxel]) assert num_points_current_voxel == expected_num_points_per_voxel[i] + + +@pytest.mark.skipif( + not torch.cuda.is_available(), reason='requires CUDA support') +def test_voxelization_nondeterministic(): + voxel_size = [0.5, 0.5, 0.5] + point_cloud_range = [0, -40, -3, 70.4, 40, 1] + + voxel_dict = np.load( +
'tests/data/for_3d_ops/test_voxel.npy', allow_pickle=True).item() + points = voxel_dict['points'] + + points = torch.tensor(points) + max_num_points = -1 + dynamic_voxelization = Voxelization(voxel_size, point_cloud_range, + max_num_points) + + max_num_points = 10 + max_voxels = 50 + hard_voxelization = Voxelization( + voxel_size, + point_cloud_range, + max_num_points, + max_voxels, + deterministic=False) + + # test hard_voxelization (non-deterministic version) on gpu + points = torch.tensor(points).contiguous().to(device='cuda:0') + voxels, coors, num_points_per_voxel = hard_voxelization.forward(points) + coors = coors.cpu().detach().numpy().tolist() + voxels = voxels.cpu().detach().numpy().tolist() + num_points_per_voxel = num_points_per_voxel.cpu().detach().numpy().tolist() + + coors_all = dynamic_voxelization.forward(points) + coors_all = coors_all.cpu().detach().numpy().tolist() + + coors_set = set([tuple(c) for c in coors]) + coors_all_set = set([tuple(c) for c in coors_all]) + + assert len(coors_set) == len(coors) + assert len(coors_set - coors_all_set) == 0 + + points = points.cpu().detach().numpy().tolist() + + coors_points_dict = {} + for c, ps in zip(coors_all, points): + if tuple(c) not in coors_points_dict: + coors_points_dict[tuple(c)] = set() + coors_points_dict[tuple(c)].add(tuple(ps)) + + for c, ps, n in zip(coors, voxels, num_points_per_voxel): + ideal_voxel_points_set = coors_points_dict[tuple(c)] + voxel_points_set = set([tuple(p) for p in ps[:n]]) + assert len(voxel_points_set) == n + if n < max_num_points: + assert voxel_points_set == ideal_voxel_points_set + for p in ps[n:]: + assert max(p) == min(p) == 0 + else: + assert len(voxel_points_set - ideal_voxel_points_set) == 0 + + # test hard_voxelization (non-deterministic version) on gpu + # with all input point in range + points = torch.tensor(points).contiguous().to(device='cuda:0')[:max_voxels] + coors_all = dynamic_voxelization.forward(points) + valid_mask = coors_all.ge(0).all(-1) + points 
= points[valid_mask] + coors_all = coors_all[valid_mask] + coors_all = coors_all.cpu().detach().numpy().tolist() + + voxels, coors, num_points_per_voxel = hard_voxelization.forward(points) + coors = coors.cpu().detach().numpy().tolist() + + coors_set = set([tuple(c) for c in coors]) + coors_all_set = set([tuple(c) for c in coors_all]) + + assert len(coors_set) == len(coors) == len(coors_all_set)