[Fix] Fix bugs in DynamicScatter op (#1748)

* Fix bugs in DynamicScatter op

* recover unittest

* add a comment as a reminder

* compatible to torch with lower version
This commit is contained in:
Wenhao Wu 2022-03-07 21:02:38 +08:00 committed by GitHub
parent 0394990a47
commit 09b64a60b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 45 additions and 4 deletions

View File

@ -26,10 +26,15 @@ std::vector<at::Tensor> DynamicPointToVoxelForwardCUDAKernelLauncher(
std::tie(out_coors, coors_map, reduce_count) =
at::unique_dim(coors_clean, 0, true, true, true);
// the first element of out_coors is always (-1,-1,-1) and should be removed
out_coors = out_coors.slice(0, 1);
reduce_count = reduce_count.slice(0, 1).to(torch::kInt32);
coors_map = coors_map.to(torch::kInt32) - 1;
if (out_coors[0][0].lt(0).item<bool>()) {
// the first element of out_coors (-1,-1,-1) and should be removed
out_coors = out_coors.slice(0, 1);
reduce_count = reduce_count.slice(0, 1);
coors_map = coors_map - 1;
}
coors_map = coors_map.to(torch::kInt32);
reduce_count = reduce_count.to(torch::kInt32);
auto reduced_feats =
at::empty({out_coors.size(0), num_feats}, feats.options());

View File

@ -76,6 +76,42 @@ def test_dynamic_scatter():
feats_out_max = feats_out_max[seq_max]
coors_cout_max = coors_out_max[seq_max]
assert (coors_out_mean == ref_voxel_coors).all()
assert torch.allclose(
feats_out_mean, ref_voxel_feats_mean, atol=1e-2, rtol=1e-5)
assert (coors_cout_max == ref_voxel_coors).all()
assert torch.allclose(
feats_out_max, ref_voxel_feats_max, atol=1e-2, rtol=1e-5)
# test non-empty input without any point out of bound
feats = torch.rand(
size=(200000, 3), dtype=torch.float32, device='cuda') * 100 - 50
coors = torch.randint(
low=0, high=20, size=(200000, 3), dtype=torch.int32, device='cuda')
ref_voxel_coors = coors.unique(dim=0, sorted=True)
ref_voxel_coors = ref_voxel_coors[ref_voxel_coors.min(dim=-1).values >= 0]
ref_voxel_feats_mean = []
ref_voxel_feats_max = []
for ref_voxel_coor in ref_voxel_coors:
voxel_mask = (coors == ref_voxel_coor).all(dim=-1)
ref_voxel_feats_mean.append(feats[voxel_mask].mean(dim=0))
ref_voxel_feats_max.append(feats[voxel_mask].max(dim=0).values)
ref_voxel_feats_mean = torch.stack(ref_voxel_feats_mean)
ref_voxel_feats_max = torch.stack(ref_voxel_feats_max)
feats_out_mean, coors_out_mean = dsmean(feats, coors)
seq_mean = (coors_out_mean[:, 0] * 400 + coors_out_mean[:, 1] * 20 +
coors_out_mean[:, 2]).argsort()
feats_out_mean = feats_out_mean[seq_mean]
coors_out_mean = coors_out_mean[seq_mean]
feats_out_max, coors_out_max = dsmax(feats, coors)
seq_max = (coors_out_max[:, 0] * 400 + coors_out_max[:, 1] * 20 +
coors_out_max[:, 2]).argsort()
feats_out_max = feats_out_max[seq_max]
coors_cout_max = coors_out_max[seq_max]
assert (coors_out_mean == ref_voxel_coors).all()
assert torch.allclose(
feats_out_mean, ref_voxel_feats_mean, atol=1e-2, rtol=1e-5)