# encoding: utf-8
"""
@author:  xingyu liao
@contact: sherlockliao01@gmail.com
"""

import torch


def concat_all_gather(tensor: torch.Tensor) -> torch.Tensor:
    """Gather ``tensor`` from every process and concatenate along dim 0.

    *** Warning ***: torch.distributed.all_gather has no gradient.

    Args:
        tensor: the local tensor contributed by the calling process.

    Returns:
        The gathered tensors (one per process, ``get_world_size()`` in
        total) concatenated along dimension 0.
    """
    world_size = torch.distributed.get_world_size()
    # Receive buffers only; all_gather overwrites each of them, so their
    # initial contents do not matter.
    gathered = [torch.empty_like(tensor) for _ in range(world_size)]
    torch.distributed.all_gather(gathered, tensor, async_op=False)
    return torch.cat(gathered, dim=0)


def normalize(x: torch.Tensor, axis: int = -1) -> torch.Tensor:
    """Scale ``x`` to unit L2 length along ``axis``.

    Args:
        x: input tensor.
        axis: dimension along which the L2 norm is computed.

    Returns:
        A tensor with the same shape as ``x``.
    """
    # keepdim=True lets the norm broadcast against x; the 1e-12 offset
    # avoids a division by zero for all-zero slices.
    norm = torch.norm(x, 2, axis, keepdim=True)
    return x / (norm + 1e-12)


def euclidean_dist(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Compute the pairwise Euclidean distance between rows of ``x`` and ``y``.

    Uses the expansion ``|a - b|^2 = |a|^2 + |b|^2 - 2 * a.b``.

    Args:
        x: 2-D tensor of shape (m, d).
        y: 2-D tensor of shape (n, d).

    Returns:
        Tensor of shape (m, n) whose entry (i, j) is the distance between
        ``x[i]`` and ``y[j]``.
    """
    x_sq = torch.pow(x, 2).sum(1, keepdim=True)      # (m, 1)
    y_sq = torch.pow(y, 2).sum(1, keepdim=True).t()  # (1, n)
    # (x_sq + y_sq) broadcasts to (m, n); addmm then adds -2 * x @ y^T.
    sq_dist = torch.addmm(x_sq + y_sq, x, y.t(), beta=1, alpha=-2)
    # Clamp before sqrt for numerical stability: rounding can push the
    # squared distance of (near-)identical rows slightly below zero.
    return sq_dist.clamp(min=1e-12).sqrt()


def cosine_dist(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Compute the pairwise cosine distance (1 - cosine similarity).

    Args:
        x: 2-D tensor of shape (bs1, d).
        y: 2-D tensor of shape (bs2, d).

    Returns:
        Tensor of shape (bs1, bs2) whose entry (i, j) is
        ``1 - cos(x[i], y[j])``.
    """
    dot = torch.matmul(x, y.transpose(0, 1))
    x_norm = torch.sqrt(torch.sum(torch.pow(x, 2), 1, keepdim=True))      # (bs1, 1)
    y_norm = torch.sqrt(torch.sum(torch.pow(y, 2), 1, keepdim=True)).t()  # (1, bs2)
    # Broadcasting builds the (bs1, bs2) matrix of norm products. The clamp
    # keeps the denominator at or above 1e-12 (the same epsilon normalize
    # uses) so an all-zero row gives distance 1 rather than 0/0 = NaN.
    denom = (x_norm * y_norm).clamp(min=1e-12)
    return 1 - dot / denom