2019-12-01 10:35:44 +08:00
|
|
|
from __future__ import division, print_function, absolute_import
|
2019-03-20 01:26:08 +08:00
|
|
|
import torch
|
|
|
|
from torch.nn import functional as F
|
|
|
|
|
|
|
|
|
|
|
|
def compute_distance_matrix(input1, input2, metric='euclidean'):
    """A wrapper function for computing distance matrix.

    Args:
        input1 (torch.Tensor): 2-D feature matrix of shape (m, d).
        input2 (torch.Tensor): 2-D feature matrix of shape (n, d).
        metric (str, optional): "euclidean" or "cosine".
            Default is "euclidean". Note that "euclidean" yields
            *squared* euclidean distances (no square root is taken).

    Returns:
        torch.Tensor: distance matrix of shape (m, n).

    Raises:
        AssertionError: if an input is not a ``torch.Tensor``, is not 2-D,
            or the two inputs have different feature dimensions. (These are
            ``assert`` checks, so they are skipped under ``python -O``.)
        ValueError: if ``metric`` is neither "euclidean" nor "cosine".

    Examples::
        >>> from torchreid import metrics
        >>> input1 = torch.rand(10, 2048)
        >>> input2 = torch.rand(100, 2048)
        >>> distmat = metrics.compute_distance_matrix(input1, input2)
        >>> distmat.size()
        torch.Size([10, 100])
    """
    # check input
    assert isinstance(input1, torch.Tensor), (
        'Expected torch.Tensor, but got {}'.format(type(input1))
    )
    assert isinstance(input2, torch.Tensor), (
        'Expected torch.Tensor, but got {}'.format(type(input2))
    )
    assert input1.dim() == 2, 'Expected 2-D tensor, but got {}-D'.format(
        input1.dim()
    )
    assert input2.dim() == 2, 'Expected 2-D tensor, but got {}-D'.format(
        input2.dim()
    )
    # Both matrices must share the feature (column) dimension; report the
    # offending sizes so the mismatch is diagnosable from the error alone.
    assert input1.size(1) == input2.size(1), (
        'Feature dimensions do not match: {} vs {}'.format(
            input1.size(1), input2.size(1)
        )
    )

    if metric == 'euclidean':
        distmat = euclidean_squared_distance(input1, input2)
    elif metric == 'cosine':
        distmat = cosine_distance(input1, input2)
    else:
        raise ValueError(
            'Unknown distance metric: {}. '
            'Please choose either "euclidean" or "cosine"'.format(metric)
        )

    return distmat
|
|
|
|
|
|
|
|
|
|
|
|
def euclidean_squared_distance(input1, input2):
    """Computes euclidean squared distance.

    Uses the expansion ``||a - b||^2 = ||a||^2 + ||b||^2 - 2 * a.b`` so the
    whole matrix is obtained with a single matrix multiplication.

    Args:
        input1 (torch.Tensor): 2-D feature matrix of shape (m, d).
        input2 (torch.Tensor): 2-D feature matrix of shape (n, d).

    Returns:
        torch.Tensor: distance matrix of shape (m, n). Entries are clamped
        to be non-negative.
    """
    # Squared row norms: (m, 1) and (1, n); their broadcast sum is (m, n).
    sq_norm1 = torch.pow(input1, 2).sum(dim=1, keepdim=True)
    sq_norm2 = torch.pow(input2, 2).sum(dim=1, keepdim=True).t()
    distmat = sq_norm1 + sq_norm2
    # distmat <- 1 * distmat + (-2) * (input1 @ input2^T)
    distmat.addmm_(input1, input2.t(), beta=1, alpha=-2)
    # Floating-point cancellation in the expansion can leave tiny negative
    # values (e.g. for identical rows); a squared distance is never negative.
    # Out-of-place clamp keeps the autograd graph intact.
    return distmat.clamp(min=0)
|
2019-03-20 01:26:08 +08:00
|
|
|
|
|
|
|
|
|
|
|
def cosine_distance(input1, input2):
    """Computes cosine distance.

    Every row of both inputs is L2-normalised along dim 1, so the matrix
    product of the normalised features gives pairwise cosine similarities;
    the returned distance is one minus that similarity.

    Args:
        input1 (torch.Tensor): 2-D feature matrix.
        input2 (torch.Tensor): 2-D feature matrix.

    Returns:
        torch.Tensor: distance matrix.
    """
    unit_rows = F.normalize(input1, p=2, dim=1)
    unit_cols = F.normalize(input2, p=2, dim=1).t()
    similarity = torch.mm(unit_rows, unit_cols)
    return 1 - similarity
|