[Feature] Add nondeterministic voxelization op from mmdet3d (#1783)

* add nondeterministic voxelization op

* fix lint

* fix lint

* resolve comments

* fix lint
pull/1806/head^2
Wenhao Wu 2022-03-15 14:21:34 +08:00 committed by GitHub
parent 33e14deaea
commit b5d550f090
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 311 additions and 10 deletions

View File

@ -166,4 +166,51 @@ __global__ void determin_voxel_num(
}
}
/**
 * Reserves, for every point that maps to a voxel, a slot inside that voxel
 * and -- for the first point that reaches a voxel -- an output row for it.
 *
 * @param nthreads      number of entries of coors_map / pts_id to process
 * @param coors_map     per-point voxel index; negative entries are skipped
 * @param pts_id        out: per-point slot, i.e. the value the voxel counter
 *                      held just before this point incremented it
 * @param coors_count   single global counter handing out voxel rows
 * @param reduce_count  per-voxel point counters, incremented once per point;
 *                      a returned value of 0 identifies the first arrival,
 *                      so the counters are expected to start at zero
 * @param coors_order   out: per-voxel row taken from coors_count by the point
 *                      that obtained slot 0
 *
 * Slots and rows are handed out in whatever order the atomics are served,
 * which is what makes this voxelization variant non-deterministic.
 */
__global__ void nondeterministic_get_assign_pos(
    const int nthreads, const int32_t* coors_map, int32_t* pts_id,
    int32_t* coors_count, int32_t* reduce_count, int32_t* coors_order) {
  CUDA_1D_KERNEL_LOOP(thread_idx, nthreads) {
    const int voxel_idx = coors_map[thread_idx];
    if (voxel_idx >= 0) {
      // Claim the next free slot of this voxel.
      int32_t* voxel_counter = reduce_count + voxel_idx;
      const int32_t slot = atomicAdd(voxel_counter, 1);
      pts_id[thread_idx] = slot;
      if (slot == 0) {
        // Exactly one point per voxel sees slot 0; it claims the voxel's row.
        const int32_t row = atomicAdd(coors_count, 1);
        coors_order[voxel_idx] = row;
      }
    }
  }
}
/**
 * Copies each accepted point into its reserved (voxel row, slot) position and
 * lets the slot-0 point of every voxel publish the voxel's metadata.
 *
 * @param nthreads      number of points to process
 * @param points        input features, num_features values per point
 * @param coors_map     per-point voxel index; negative entries are skipped
 * @param pts_id        per-point slot inside its voxel
 * @param coors_in      per-voxel coordinates, NDim values per voxel
 * @param reduce_count  per-voxel number of points that mapped to the voxel
 * @param coors_order   per-voxel output row
 * @param voxels        out: features laid out as [row][slot][feature] with
 *                      max_points slots per row
 * @param coors         out: coordinates of each output row, NDim values each
 * @param pts_count     out: per-row point count, clamped to max_points
 * @param max_voxels    rows at or beyond this index are dropped
 * @param max_points    slots at or beyond this index are dropped
 * @param num_features  number of values per point
 * @param NDim          number of values per voxel coordinate
 *
 * All flat offsets are computed in 64 bits so that large point clouds or
 * voxel buffers cannot overflow 32-bit index arithmetic.
 */
template <typename T>
__global__ void nondeterministic_assign_point_voxel(
    const int nthreads, const T* points, const int32_t* coors_map,
    const int32_t* pts_id, const int32_t* coors_in, const int32_t* reduce_count,
    const int32_t* coors_order, T* voxels, int32_t* coors, int32_t* pts_count,
    const int max_voxels, const int max_points, const int num_features,
    const int NDim) {
  CUDA_1D_KERNEL_LOOP(thread_idx, nthreads) {
    int coors_idx = coors_map[thread_idx];
    int coors_pts_pos = pts_id[thread_idx];
    if (coors_idx > -1 && coors_pts_pos < max_points) {
      int coors_pos = coors_order[coors_idx];
      if (coors_pos < max_voxels) {
        // Widen before multiplying: the products can exceed INT_MAX.
        const int64_t voxels_index =
            (static_cast<int64_t>(coors_pos) * max_points + coors_pts_pos) *
            num_features;
        const int64_t points_index =
            static_cast<int64_t>(thread_idx) * num_features;
        auto voxels_offset = voxels + voxels_index;
        auto points_offset = points + points_index;
        for (int k = 0; k < num_features; k++) {
          voxels_offset[k] = points_offset[k];
        }
        if (coors_pts_pos == 0) {
          // Only the slot-0 point writes the per-voxel outputs, so each row
          // is written by exactly one thread.
          pts_count[coors_pos] = min(reduce_count[coors_idx], max_points);
          auto coors_offset = coors + static_cast<int64_t>(coors_pos) * NDim;
          auto coors_in_offset =
              coors_in + static_cast<int64_t>(coors_idx) * NDim;
          for (int k = 0; k < NDim; k++) {
            coors_offset[k] = coors_in_offset[k];
          }
        }
      }
    }
  }
}
#endif // VOXELIZATION_CUDA_KERNEL_CUH

View File

@ -1396,6 +1396,12 @@ int HardVoxelizeForwardCUDAKernelLauncher(
const std::vector<float> coors_range, const int max_points,
const int max_voxels, const int NDim = 3);
// Forward declaration of the CUDA launcher for the non-deterministic hard
// voxelization. It takes the same argument list as
// HardVoxelizeForwardCUDAKernelLauncher and returns an int, which
// nondeterministic_hard_voxelize_forward_cuda passes straight back to its
// caller. NDim defaults to 3.
int NondeterministicHardVoxelizeForwardCUDAKernelLauncher(
const at::Tensor& points, at::Tensor& voxels, at::Tensor& coors,
at::Tensor& num_points_per_voxel, const std::vector<float> voxel_size,
const std::vector<float> coors_range, const int max_points,
const int max_voxels, const int NDim = 3);
void DynamicVoxelizeForwardCUDAKernelLauncher(
const at::Tensor& points, at::Tensor& coors,
const std::vector<float> voxel_size, const std::vector<float> coors_range,
@ -1413,6 +1419,16 @@ int hard_voxelize_forward_cuda(const at::Tensor& points, at::Tensor& voxels,
max_points, max_voxels, NDim);
};
// CUDA backend of the non-deterministic hard voxelization: forwards every
// argument unchanged to the kernel launcher and hands back its result.
int nondeterministic_hard_voxelize_forward_cuda(
    const at::Tensor& points, at::Tensor& voxels, at::Tensor& coors,
    at::Tensor& num_points_per_voxel, const std::vector<float> voxel_size,
    const std::vector<float> coors_range, const int max_points,
    const int max_voxels, const int NDim) {
  const int voxel_num = NondeterministicHardVoxelizeForwardCUDAKernelLauncher(
      points, voxels, coors, num_points_per_voxel, voxel_size, coors_range,
      max_points, max_voxels, NDim);
  return voxel_num;
};
void dynamic_voxelize_forward_cuda(const at::Tensor& points, at::Tensor& coors,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
@ -1429,6 +1445,12 @@ int hard_voxelize_forward_impl(const at::Tensor& points, at::Tensor& voxels,
const int max_points, const int max_voxels,
const int NDim);
// Declaration of the entry point for the non-deterministic hard voxelization.
// The CUDA implementation is attached to this name through
// REGISTER_DEVICE_IMPL further down in this file.
int nondeterministic_hard_voxelize_forward_impl(
const at::Tensor& points, at::Tensor& voxels, at::Tensor& coors,
at::Tensor& num_points_per_voxel, const std::vector<float> voxel_size,
const std::vector<float> coors_range, const int max_points,
const int max_voxels, const int NDim);
void dynamic_voxelize_forward_impl(const at::Tensor& points, at::Tensor& coors,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
@ -1436,6 +1458,8 @@ void dynamic_voxelize_forward_impl(const at::Tensor& points, at::Tensor& coors,
REGISTER_DEVICE_IMPL(hard_voxelize_forward_impl, CUDA,
hard_voxelize_forward_cuda);
// Associate nondeterministic_hard_voxelize_forward_cuda with the
// nondeterministic_hard_voxelize_forward_impl entry point for CUDA.
REGISTER_DEVICE_IMPL(nondeterministic_hard_voxelize_forward_impl, CUDA,
nondeterministic_hard_voxelize_forward_cuda);
REGISTER_DEVICE_IMPL(dynamic_voxelize_forward_impl, CUDA,
dynamic_voxelize_forward_cuda);

View File

@ -145,6 +145,104 @@ int HardVoxelizeForwardCUDAKernelLauncher(
return voxel_num_int;
}
/**
 * Non-deterministic hard voxelization on the current CUDA stream.
 *
 * Steps:
 *   1. dynamic_voxelize_kernel computes one voxel coordinate per point.
 *   2. Points with any negative coordinate component are marked invalid and
 *      the remaining coordinates are deduplicated with at::unique_dim, which
 *      also yields the point -> voxel index map.
 *   3. nondeterministic_get_assign_pos hands out, via atomics, a slot per
 *      point and an output row per voxel.
 *   4. nondeterministic_assign_point_voxel copies the accepted points and the
 *      per-voxel metadata into the output tensors.
 *
 * @param points                input points, indexed as (num_points,
 *                              num_features)
 * @param voxels                out: voxel features
 * @param coors                 out: voxel coordinates (accessed as int32)
 * @param num_points_per_voxel  out: points stored per voxel (accessed as
 *                              int32)
 * @param voxel_size            voxel extent; elements 0..2 are read
 * @param coors_range           range bounds; elements 0..2 are the minima and
 *                              3..5 the maxima
 * @param max_points            maximum number of points kept per voxel
 * @param max_voxels            maximum number of voxels kept
 * @param NDim                  number of coordinate components per voxel
 * @return min(max_voxels, number of distinct valid voxels); 0 when there are
 *         no points or no valid voxel.
 *
 * NOTE: the outputs are accessed through `.contiguous()`; if a caller passes
 * a non-contiguous output tensor the kernel writes into a temporary copy and
 * the results are lost, so callers are expected to pass contiguous outputs.
 */
int NondeterministicHardVoxelizeForwardCUDAKernelLauncher(
    const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors,
    at::Tensor &num_points_per_voxel, const std::vector<float> voxel_size,
    const std::vector<float> coors_range, const int max_points,
    const int max_voxels, const int NDim = 3) {
  at::cuda::CUDAGuard device_guard(points.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  const int num_points = points.size(0);
  const int num_features = points.size(1);

  if (num_points == 0) return 0;

  dim3 blocks(
      std::min(at::cuda::ATenCeilDiv(num_points, THREADS_PER_BLOCK), 4096));
  dim3 threads(THREADS_PER_BLOCK);

  const float voxel_x = voxel_size[0];
  const float voxel_y = voxel_size[1];
  const float voxel_z = voxel_size[2];
  const float coors_x_min = coors_range[0];
  const float coors_y_min = coors_range[1];
  const float coors_z_min = coors_range[2];
  const float coors_x_max = coors_range[3];
  const float coors_y_max = coors_range[4];
  const float coors_z_max = coors_range[5];

  const int grid_x = round((coors_x_max - coors_x_min) / voxel_x);
  const int grid_y = round((coors_y_max - coors_y_min) / voxel_y);
  const int grid_z = round((coors_z_max - coors_z_min) / voxel_z);

  // Materialise the contiguous view once; it is read by two kernels.
  const at::Tensor points_contig = points.contiguous();

  // map points to voxel coors (freshly allocated, hence contiguous)
  at::Tensor temp_coors =
      at::zeros({num_points, NDim}, points.options().dtype(at::kInt));

  // 1. link point to corresponding voxel coors
  AT_DISPATCH_ALL_TYPES(
      points.scalar_type(), "hard_voxelize_kernel", ([&] {
        dynamic_voxelize_kernel<scalar_t, int><<<blocks, threads, 0, stream>>>(
            points_contig.data_ptr<scalar_t>(), temp_coors.data_ptr<int>(),
            voxel_x, voxel_y, voxel_z, coors_x_min, coors_y_min, coors_z_min,
            coors_x_max, coors_y_max, coors_z_max, grid_x, grid_y, grid_z,
            num_points, num_features, NDim);
      }));
  // Check every launch where it happens so a failure is attributed correctly.
  AT_CUDA_CHECK(cudaGetLastError());

  // 2. deduplicate coordinates; rows with any negative component become
  //    all -1 so that they collapse into a single leading row.
  at::Tensor coors_map;
  at::Tensor reduce_count;
  auto coors_clean = temp_coors.masked_fill(temp_coors.lt(0).any(-1, true), -1);
  std::tie(temp_coors, coors_map, reduce_count) =
      at::unique_dim(coors_clean, 0, true, true, false);

  if (temp_coors[0][0].lt(0).item<bool>()) {
    // the first element of temp_coors is (-1,-1,-1) and should be removed;
    // shifting the map sends the invalid points to index -1
    temp_coors = temp_coors.slice(0, 1);
    coors_map = coors_map - 1;
  }

  const int num_coors = temp_coors.size(0);
  // Nothing to assign when every point was invalid.
  if (num_coors == 0) return 0;

  // Keep the converted tensors alive in named variables so the raw pointers
  // handed to the kernels below stay valid.
  temp_coors = temp_coors.to(at::kInt).contiguous();
  coors_map = coors_map.to(at::kInt).contiguous();

  at::Tensor coors_count = at::zeros({1}, coors_map.options());
  at::Tensor coors_order = at::empty({num_coors}, coors_map.options());
  at::Tensor pts_id = at::zeros({num_points}, coors_map.options());
  reduce_count = at::zeros({num_coors}, coors_map.options());

  // 3. reserve a slot per point and a row per voxel (not templated on the
  //    point type, so no dtype dispatch is needed)
  nondeterministic_get_assign_pos<<<blocks, threads, 0, stream>>>(
      num_points, coors_map.data_ptr<int32_t>(), pts_id.data_ptr<int32_t>(),
      coors_count.data_ptr<int32_t>(), reduce_count.data_ptr<int32_t>(),
      coors_order.data_ptr<int32_t>());
  AT_CUDA_CHECK(cudaGetLastError());

  // 4. scatter points and voxel metadata into the outputs
  AT_DISPATCH_ALL_TYPES(
      points.scalar_type(), "assign_point_to_voxel", ([&] {
        nondeterministic_assign_point_voxel<scalar_t>
            <<<blocks, threads, 0, stream>>>(
                num_points, points_contig.data_ptr<scalar_t>(),
                coors_map.data_ptr<int32_t>(), pts_id.data_ptr<int32_t>(),
                temp_coors.data_ptr<int32_t>(),
                reduce_count.data_ptr<int32_t>(),
                coors_order.data_ptr<int32_t>(),
                voxels.contiguous().data_ptr<scalar_t>(),
                coors.contiguous().data_ptr<int32_t>(),
                num_points_per_voxel.contiguous().data_ptr<int32_t>(),
                max_voxels, max_points, num_features, NDim);
      }));
  AT_CUDA_CHECK(cudaGetLastError());

  return max_voxels < num_coors ? max_voxels : num_coors;
}
void DynamicVoxelizeForwardCUDAKernelLauncher(
const at::Tensor &points, at::Tensor &coors,
const std::vector<float> voxel_size, const std::vector<float> coors_range,

View File

@ -338,7 +338,8 @@ void hard_voxelize_forward(const at::Tensor &points,
const at::Tensor &coors_range, at::Tensor &voxels,
at::Tensor &coors, at::Tensor &num_points_per_voxel,
at::Tensor &voxel_num, const int max_points,
const int max_voxels, const int NDim);
const int max_voxels, const int NDim,
const bool deterministic);
void dynamic_voxelize_forward(const at::Tensor &points,
const at::Tensor &voxel_size,
@ -756,7 +757,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"hard_voxelize_forward", py::arg("points"), py::arg("voxel_size"),
py::arg("coors_range"), py::arg("voxels"), py::arg("coors"),
py::arg("num_points_per_voxel"), py::arg("voxel_num"),
py::arg("max_points"), py::arg("max_voxels"), py::arg("NDim"));
py::arg("max_points"), py::arg("max_voxels"), py::arg("NDim"),
py::arg("deterministic"));
m.def("dynamic_voxelize_forward", &dynamic_voxelize_forward,
"dynamic_voxelize_forward", py::arg("points"), py::arg("voxel_size"),
py::arg("coors_range"), py::arg("coors"), py::arg("NDim"));

View File

@ -14,6 +14,17 @@ int hard_voxelize_forward_impl(const at::Tensor &points, at::Tensor &voxels,
max_points, max_voxels, NDim);
}
// Entry point of the non-deterministic hard voxelization: hands all arguments
// to DISPATCH_DEVICE_IMPL under this function's own name and returns the
// value it produces.
int nondeterministic_hard_voxelize_forward_impl(
    const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors,
    at::Tensor &num_points_per_voxel, const std::vector<float> voxel_size,
    const std::vector<float> coors_range, const int max_points,
    const int max_voxels, const int NDim = 3) {
  const int voxel_num = DISPATCH_DEVICE_IMPL(
      nondeterministic_hard_voxelize_forward_impl, points, voxels, coors,
      num_points_per_voxel, voxel_size, coors_range, max_points, max_voxels,
      NDim);
  return voxel_num;
}
void dynamic_voxelize_forward_impl(const at::Tensor &points, at::Tensor &coors,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
@ -27,7 +38,8 @@ void hard_voxelize_forward(const at::Tensor &points,
const at::Tensor &coors_range, at::Tensor &voxels,
at::Tensor &coors, at::Tensor &num_points_per_voxel,
at::Tensor &voxel_num, const int max_points,
const int max_voxels, const int NDim = 3) {
const int max_voxels, const int NDim = 3,
const bool deterministic = true) {
int64_t *voxel_num_data = voxel_num.data_ptr<int64_t>();
std::vector<float> voxel_size_v(
voxel_size.data_ptr<float>(),
@ -36,9 +48,15 @@ void hard_voxelize_forward(const at::Tensor &points,
coors_range.data_ptr<float>(),
coors_range.data_ptr<float>() + coors_range.numel());
*voxel_num_data = hard_voxelize_forward_impl(
points, voxels, coors, num_points_per_voxel, voxel_size_v, coors_range_v,
max_points, max_voxels, NDim);
if (deterministic) {
*voxel_num_data = hard_voxelize_forward_impl(
points, voxels, coors, num_points_per_voxel, voxel_size_v,
coors_range_v, max_points, max_voxels, NDim);
} else {
*voxel_num_data = nondeterministic_hard_voxelize_forward_impl(
points, voxels, coors, num_points_per_voxel, voxel_size_v,
coors_range_v, max_points, max_voxels, NDim);
}
}
void dynamic_voxelize_forward(const at::Tensor &points,

View File

@ -18,7 +18,8 @@ class _Voxelization(Function):
voxel_size,
coors_range,
max_points=35,
max_voxels=20000):
max_voxels=20000,
deterministic=True):
"""Convert kitti points(N, >=3) to voxels.
Args:
@ -34,6 +35,16 @@ class _Voxelization(Function):
for second, 20000 is a good choice. Users should shuffle points
before call this function because max_voxels may drop points.
Default: 20000.
deterministic: bool. whether to invoke the deterministic
version of hard-voxelization implementations. the non-deterministic
version is considerably faster but is not deterministic. only
affects hard voxelization. default True. for more information
on this argument and the implementation insights, please refer
to the following links:
https://github.com/open-mmlab/mmdetection3d/issues/894
https://github.com/open-mmlab/mmdetection3d/pull/904
the non-deterministic version is an experimental feature and we
will appreciate it if you could share the failing cases with us.
Returns:
tuple[torch.Tensor]: tuple[torch.Tensor]: A tuple contains three
@ -69,7 +80,8 @@ class _Voxelization(Function):
voxel_num,
max_points=max_points,
max_voxels=max_voxels,
NDim=3)
NDim=3,
deterministic=deterministic)
# select the valid voxels
voxels_out = voxels[:voxel_num]
coors_out = coors[:voxel_num]
@ -102,7 +114,27 @@ class Voxelization(nn.Module):
voxel_size,
point_cloud_range,
max_num_points,
max_voxels=20000):
max_voxels=20000,
deterministic=True):
"""
Args:
voxel_size (list): list [x, y, z] size of three dimension
point_cloud_range (list):
[x_min, y_min, z_min, x_max, y_max, z_max]
max_num_points (int): max number of points per voxel
max_voxels (tuple or int): max number of voxels in
(training, testing) time
deterministic: bool. whether to invoke the deterministic
version of hard-voxelization implementations. the non-deterministic
version is considerably faster but is not deterministic. only
affects hard voxelization. default True. for more information
on this argument and the implementation insights, please refer
to the following links:
https://github.com/open-mmlab/mmdetection3d/issues/894
https://github.com/open-mmlab/mmdetection3d/pull/904
the non-deterministic version is an experimental feature and we
will appreciate it if you could share the failing cases with us.
"""
super().__init__()
self.voxel_size = voxel_size
@ -112,6 +144,7 @@ class Voxelization(nn.Module):
self.max_voxels = max_voxels
else:
self.max_voxels = _pair(max_voxels)
self.deterministic = deterministic
point_cloud_range = torch.tensor(
point_cloud_range, dtype=torch.float32)
@ -132,7 +165,8 @@ class Voxelization(nn.Module):
max_voxels = self.max_voxels[1]
return voxelization(input, self.voxel_size, self.point_cloud_range,
self.max_num_points, max_voxels)
self.max_num_points, max_voxels,
self.deterministic)
def __repr__(self):
s = self.__class__.__name__ + '('
@ -140,5 +174,6 @@ class Voxelization(nn.Module):
s += ', point_cloud_range=' + str(self.point_cloud_range)
s += ', max_num_points=' + str(self.max_num_points)
s += ', max_voxels=' + str(self.max_voxels)
s += ', deterministic=' + str(self.deterministic)
s += ')'
return s

View File

@ -60,3 +60,80 @@ def test_voxelization(device_type):
assert np.all(
points[indices] == expected_coors[i][:num_points_current_voxel])
assert num_points_current_voxel == expected_num_points_per_voxel[i]
@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='requires CUDA support')
def test_voxelization_nondeterministic():
    """Check the non-deterministic hard voxelization against dynamic
    voxelization on the GPU.

    The output order is not reproducible, so the test only verifies
    order-independent properties: returned voxel coordinates are unique and
    are a subset of the per-point coordinates, every stored point belongs to
    its voxel, voxels that are not full contain exactly their points followed
    by zero padding, and (when every input point is in range and the number
    of points does not exceed ``max_voxels``) no voxel is dropped.
    """
    voxel_size = [0.5, 0.5, 0.5]
    point_cloud_range = [0, -40, -3, 70.4, 40, 1]

    voxel_dict = np.load(
        'tests/data/for_3d_ops/test_voxel.npy', allow_pickle=True).item()
    points = torch.tensor(voxel_dict['points'])

    # max_num_points == -1 is the configuration used for the dynamic
    # (per-point coordinate) reference.
    max_num_points = -1
    dynamic_voxelization = Voxelization(voxel_size, point_cloud_range,
                                        max_num_points)

    max_num_points = 10
    max_voxels = 50
    hard_voxelization = Voxelization(
        voxel_size,
        point_cloud_range,
        max_num_points,
        max_voxels,
        deterministic=False)

    # test hard_voxelization (non-deterministic version) on gpu
    # `points` is already a tensor here; re-wrapping it in torch.tensor()
    # would only make a redundant copy and emit a copy-construct warning.
    points = points.contiguous().to(device='cuda:0')
    voxels, coors, num_points_per_voxel = hard_voxelization.forward(points)
    coors = coors.cpu().detach().numpy().tolist()
    voxels = voxels.cpu().detach().numpy().tolist()
    num_points_per_voxel = num_points_per_voxel.cpu().detach().numpy().tolist()

    coors_all = dynamic_voxelization.forward(points)
    coors_all = coors_all.cpu().detach().numpy().tolist()

    coors_set = {tuple(c) for c in coors}
    coors_all_set = {tuple(c) for c in coors_all}

    # every returned voxel is unique and exists in the reference
    assert len(coors_set) == len(coors)
    assert len(coors_set - coors_all_set) == 0

    points = points.cpu().detach().numpy().tolist()

    # reference: voxel coordinate -> set of points that fall into it
    coors_points_dict = {}
    for c, ps in zip(coors_all, points):
        coors_points_dict.setdefault(tuple(c), set()).add(tuple(ps))

    for c, ps, n in zip(coors, voxels, num_points_per_voxel):
        ideal_voxel_points_set = coors_points_dict[tuple(c)]
        voxel_points_set = {tuple(p) for p in ps[:n]}
        assert len(voxel_points_set) == n
        if n < max_num_points:
            # voxel not full: it must hold exactly its points, zero padded
            assert voxel_points_set == ideal_voxel_points_set
            for p in ps[n:]:
                assert max(p) == min(p) == 0
        else:
            # voxel full: any subset of its points is acceptable
            assert len(voxel_points_set - ideal_voxel_points_set) == 0

    # test hard_voxelization (non-deterministic version) on gpu
    # with all input point in range
    # `points` is a nested list at this point, so torch.tensor() is required.
    points = torch.tensor(points).contiguous().to(device='cuda:0')[:max_voxels]
    coors_all = dynamic_voxelization.forward(points)
    valid_mask = coors_all.ge(0).all(-1)
    points = points[valid_mask]
    coors_all = coors_all[valid_mask]
    coors_all = coors_all.cpu().detach().numpy().tolist()

    voxels, coors, num_points_per_voxel = hard_voxelization.forward(points)
    coors = coors.cpu().detach().numpy().tolist()

    coors_set = {tuple(c) for c in coors}
    coors_all_set = {tuple(c) for c in coors_all}

    assert len(coors_set) == len(coors) == len(coors_all_set)