# mirror of https://github.com/JDAI-CV/fast-reid.git
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""

import torch
import torch.nn.functional as F


def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [torch.ones_like(tensor)
                      for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output

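# Usage sketch (added; not part of the original file). Assumes a process group
# has already been set up via torch.distributed.init_process_group and that
# each rank computes a (B, D) feature batch; `model` and `images` are
# hypothetical names. A typical pattern is gathering features from all ranks
# to build a global similarity matrix; gradients flow only through the local
# `feat`, per the warning above.
#
#     feat = F.normalize(model(images), dim=1)  # local (B, D) features
#     all_feat = concat_all_gather(feat)        # (world_size * B, D), no grad
#     sim = torch.mm(feat, all_feat.t())        # local-vs-global similarities

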
def normalize(x, axis=-1):
    """Normalize x to unit length along the specified dimension.

    Args:
        x: torch.Tensor

    Returns:
        x: torch.Tensor, same shape as input
    """
    x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12)
    return x

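# Note (added): up to epsilon handling, this matches the built-in
# F.normalize(x, p=2, dim=axis); both divide by the L2 norm along `axis`,
# this version adding 1e-12 to the norm while F.normalize clamps it from
# below instead.

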
def euclidean_dist(x, y):
    # Batched expansion: ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 x_i . y_j
    m, n = x.size(0), y.size(0)
    xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
    yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
    dist = xx + yy - 2 * torch.matmul(x, y.t())
    dist = dist.clamp(min=1e-12).sqrt()  # clamp for numerical stability
    return dist

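# Note (added): this expansion agrees with torch.cdist(x, y) up to the 1e-12
# clamp; the sanity check at the bottom of this file exercises it numerically.

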
def cosine_dist(x, y):
    # Distance on L2-normalized features: d_ij = 2 - 2 * cos(x_i, y_j), in [0, 4].
    x = F.normalize(x, dim=1)
    y = F.normalize(y, dim=1)
    dist = 2 - 2 * torch.mm(x, y.t())
    return dist
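

# Minimal sanity checks (added; not part of the original file), runnable on a
# single CPU process with no distributed setup. They verify that
# euclidean_dist matches torch.cdist, and that on unit-normalized inputs the
# squared Euclidean distance equals cosine_dist (||a - b||^2 = 2 - 2 a . b).
if __name__ == "__main__":
    x = torch.randn(4, 8)
    y = torch.randn(5, 8)

    # euclidean_dist should agree with torch.cdist (p=2 by default).
    assert torch.allclose(euclidean_dist(x, y), torch.cdist(x, y), atol=1e-5)

    # On L2-normalized vectors, squared Euclidean distance equals cosine_dist.
    xn, yn = normalize(x), normalize(y)
    assert torch.allclose(euclidean_dist(xn, yn).pow(2), cosine_dist(x, y), atol=1e-5)

    print("sanity checks passed")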