diff --git a/mmcv/ops/csrc/pytorch/cuda/scatter_points_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/scatter_points_cuda.cu index 4939fe40a..cbc44651f 100644 --- a/mmcv/ops/csrc/pytorch/cuda/scatter_points_cuda.cu +++ b/mmcv/ops/csrc/pytorch/cuda/scatter_points_cuda.cu @@ -26,10 +26,15 @@ std::vector DynamicPointToVoxelForwardCUDAKernelLauncher( std::tie(out_coors, coors_map, reduce_count) = at::unique_dim(coors_clean, 0, true, true, true); - // the first element of out_coors is always (-1,-1,-1) and should be removed - out_coors = out_coors.slice(0, 1); - reduce_count = reduce_count.slice(0, 1).to(torch::kInt32); - coors_map = coors_map.to(torch::kInt32) - 1; + if (out_coors[0][0].lt(0).item()) { + // the first element of out_coors (-1,-1,-1) and should be removed + out_coors = out_coors.slice(0, 1); + reduce_count = reduce_count.slice(0, 1); + coors_map = coors_map - 1; + } + + coors_map = coors_map.to(torch::kInt32); + reduce_count = reduce_count.to(torch::kInt32); auto reduced_feats = at::empty({out_coors.size(0), num_feats}, feats.options()); diff --git a/tests/test_ops/test_scatter_points.py b/tests/test_ops/test_scatter_points.py index 8610124e3..aeedba931 100644 --- a/tests/test_ops/test_scatter_points.py +++ b/tests/test_ops/test_scatter_points.py @@ -76,6 +76,42 @@ def test_dynamic_scatter(): feats_out_max = feats_out_max[seq_max] coors_cout_max = coors_out_max[seq_max] + assert (coors_out_mean == ref_voxel_coors).all() + assert torch.allclose( + feats_out_mean, ref_voxel_feats_mean, atol=1e-2, rtol=1e-5) + assert (coors_cout_max == ref_voxel_coors).all() + assert torch.allclose( + feats_out_max, ref_voxel_feats_max, atol=1e-2, rtol=1e-5) + + # test non-empty input without any point out of bound + feats = torch.rand( + size=(200000, 3), dtype=torch.float32, device='cuda') * 100 - 50 + coors = torch.randint( + low=0, high=20, size=(200000, 3), dtype=torch.int32, device='cuda') + + ref_voxel_coors = coors.unique(dim=0, sorted=True) + ref_voxel_coors = ref_voxel_coors[ref_voxel_coors.min(dim=-1).values >= 0] + ref_voxel_feats_mean = [] + ref_voxel_feats_max = [] + for ref_voxel_coor in ref_voxel_coors: + voxel_mask = (coors == ref_voxel_coor).all(dim=-1) + ref_voxel_feats_mean.append(feats[voxel_mask].mean(dim=0)) + ref_voxel_feats_max.append(feats[voxel_mask].max(dim=0).values) + ref_voxel_feats_mean = torch.stack(ref_voxel_feats_mean) + ref_voxel_feats_max = torch.stack(ref_voxel_feats_max) + + feats_out_mean, coors_out_mean = dsmean(feats, coors) + seq_mean = (coors_out_mean[:, 0] * 400 + coors_out_mean[:, 1] * 20 + + coors_out_mean[:, 2]).argsort() + feats_out_mean = feats_out_mean[seq_mean] + coors_out_mean = coors_out_mean[seq_mean] + + feats_out_max, coors_out_max = dsmax(feats, coors) + seq_max = (coors_out_max[:, 0] * 400 + coors_out_max[:, 1] * 20 + + coors_out_max[:, 2]).argsort() + feats_out_max = feats_out_max[seq_max] + coors_cout_max = coors_out_max[seq_max] + assert (coors_out_mean == ref_voxel_coors).all() assert torch.allclose( feats_out_mean, ref_voxel_feats_mean, atol=1e-2, rtol=1e-5)