mirror of https://github.com/open-mmlab/mmcv.git
[Fix] Fix the potential NaN bug in calc_square_dist() (#2356)
parent
9709ff3f8c
commit
3c35bab67c
|
@ -24,16 +24,11 @@ def calc_square_dist(point_feat_a: Tensor,
|
||||||
torch.Tensor: (B, N, M) Square distance between each point pair.
|
torch.Tensor: (B, N, M) Square distance between each point pair.
|
||||||
"""
|
"""
|
||||||
num_channel = point_feat_a.shape[-1]
|
num_channel = point_feat_a.shape[-1]
|
||||||
# [bs, n, 1]
|
dist = torch.cdist(point_feat_a, point_feat_b)
|
||||||
a_square = torch.sum(point_feat_a.unsqueeze(dim=2).pow(2), dim=-1)
|
|
||||||
# [bs, 1, m]
|
|
||||||
b_square = torch.sum(point_feat_b.unsqueeze(dim=1).pow(2), dim=-1)
|
|
||||||
|
|
||||||
corr_matrix = torch.matmul(point_feat_a, point_feat_b.transpose(1, 2))
|
|
||||||
|
|
||||||
dist = a_square + b_square - 2 * corr_matrix
|
|
||||||
if norm:
|
if norm:
|
||||||
dist = torch.sqrt(dist) / num_channel
|
dist = dist / num_channel
|
||||||
|
else:
|
||||||
|
dist = torch.square(dist)
|
||||||
return dist
|
return dist
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue