From 3c35bab67c460c86114457bd998c4ac2a4041f7a Mon Sep 17 00:00:00 2001 From: ChaimZhu Date: Wed, 26 Oct 2022 14:06:28 +0800 Subject: [PATCH] [Fix] Fix the potential NaN bug in calc_square_dist() (#2356) --- mmcv/ops/points_sampler.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/mmcv/ops/points_sampler.py b/mmcv/ops/points_sampler.py index e1fd37605..1cff84620 100644 --- a/mmcv/ops/points_sampler.py +++ b/mmcv/ops/points_sampler.py @@ -24,16 +24,11 @@ def calc_square_dist(point_feat_a: Tensor, torch.Tensor: (B, N, M) Square distance between each point pair. """ num_channel = point_feat_a.shape[-1] - # [bs, n, 1] - a_square = torch.sum(point_feat_a.unsqueeze(dim=2).pow(2), dim=-1) - # [bs, 1, m] - b_square = torch.sum(point_feat_b.unsqueeze(dim=1).pow(2), dim=-1) - - corr_matrix = torch.matmul(point_feat_a, point_feat_b.transpose(1, 2)) - - dist = a_square + b_square - 2 * corr_matrix + dist = torch.cdist(point_feat_a, point_feat_b) if norm: - dist = torch.sqrt(dist) / num_channel + dist = dist / num_channel + else: + dist = torch.square(dist) return dist