# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import torch
from torch import nn

from .build import LOSS_REGISTRY


def normalize(x, axis=-1):
    """Normalize to unit length along the specified dimension.

    Args:
        x: torch.Tensor
    Returns:
        x: torch.Tensor, same shape as the input
    """
    x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12)
    return x


def euclidean_dist(x, y):
    """Compute the pairwise euclidean distance matrix between rows of `x` and `y`.

    Args:
        x: torch.Tensor with shape [m, d]
        y: torch.Tensor with shape [n, d]
    Returns:
        dist: torch.Tensor with shape [m, n]
    """
    m, n = x.size(0), y.size(0)
    # ||x - y||^2 = ||x||^2 + ||y||^2 - 2 * x.y
    xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
    yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
    dist = xx + yy
    dist.addmm_(x, y.t(), beta=1, alpha=-2)
    dist = dist.clamp(min=1e-12).sqrt()  # small floor for numerical stability
    return dist


def cosine_dist(x, y):
    """Compute the pairwise cosine distance (1 - cosine similarity) matrix."""
    bs1, bs2 = x.size(0), y.size(0)
    frac_up = torch.matmul(x, y.transpose(0, 1))
    frac_down = (torch.sqrt(torch.sum(torch.pow(x, 2), 1))).view(bs1, 1).repeat(1, bs2) * \
                (torch.sqrt(torch.sum(torch.pow(y, 2), 1))).view(1, bs2).repeat(bs1, 1)
    cosine = frac_up / frac_down
    return 1 - cosine


def hard_example_mining(dist_mat, labels, return_inds=False):
    """For each anchor, find the hardest positive and negative sample.

    Args:
        dist_mat: torch.Tensor, pairwise distances between samples, shape [N, N]
        labels: torch.LongTensor, with shape [N]
        return_inds: whether to also return the indices of the selected hard
            examples. Leaving it `False` saves time.
    Returns:
        dist_ap: torch.Tensor, distance(anchor, positive); shape [N]
        dist_an: torch.Tensor, distance(anchor, negative); shape [N]
        p_inds: torch.LongTensor, with shape [N];
            indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1
        n_inds: torch.LongTensor, with shape [N];
            indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1

    NOTE: Only the case in which all labels have the same number of samples is
    considered, so all anchors can be processed in parallel.
    """
    assert len(dist_mat.size()) == 2
    assert dist_mat.size(0) == dist_mat.size(1)
    N = dist_mat.size(0)

    # shape [N, N]
    is_pos = labels.expand(N, N).eq(labels.expand(N, N).t())
    is_neg = labels.expand(N, N).ne(labels.expand(N, N).t())

    # `dist_ap` means distance(anchor, positive);
    # both `dist_ap` and `relative_p_inds` have shape [N, 1]
    # pos_dist = dist_mat[is_pos].contiguous().view(N, -1)
    # ap_weight = F.softmax(pos_dist, dim=1)
    # dist_ap = torch.sum(ap_weight * pos_dist, dim=1)
    dist_ap, relative_p_inds = torch.max(
        dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True)
    # `dist_an` means distance(anchor, negative);
    # both `dist_an` and `relative_n_inds` have shape [N, 1]
    dist_an, relative_n_inds = torch.min(
        dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True)
    # neg_dist = dist_mat[is_neg].contiguous().view(N, -1)
    # an_weight = F.softmax(-neg_dist, dim=1)
    # dist_an = torch.sum(an_weight * neg_dist, dim=1)

    # shape [N]
    dist_ap = dist_ap.squeeze(1)
    dist_an = dist_an.squeeze(1)

    if return_inds:
        # shape [N, N]; ind[i, j] == j, used to recover absolute indices
        ind = (torch.arange(0, N, device=labels.device).long()
               .unsqueeze(0).expand(N, N))
        # shape [N, 1]
        p_inds = torch.gather(
            ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds)
        n_inds = torch.gather(
            ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds)
        # shape [N]
        p_inds = p_inds.squeeze(1)
        n_inds = n_inds.squeeze(1)
        return dist_ap, dist_an, p_inds, n_inds

    return dist_ap, dist_an
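

# A minimal sanity-check sketch (illustrative only; feature dimension and batch
# layout are made up): with a batch of 4 samples from 2 identities,
# `hard_example_mining` picks, for each anchor, the farthest same-label sample
# and the nearest different-label sample.
#
#   feats = torch.randn(4, 128)
#   labels = torch.tensor([0, 0, 1, 1])
#   dist_mat = euclidean_dist(feats, feats)
#   dist_ap, dist_an = hard_example_mining(dist_mat, labels)  # both shape [4]
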
@LOSS_REGISTRY.register()
class TripletLoss(object):
    """Triplet loss with hard positive/negative mining.

    Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid).
    The underlying theory can be found in the paper
    'In Defense of the Triplet Loss for Person Re-Identification'.
    """

    def __init__(self, cfg):
        self._margin = cfg.MODEL.LOSSES.MARGIN
        self._normalize_feature = cfg.MODEL.LOSSES.NORM_FEAT
        self._scale = cfg.MODEL.LOSSES.SCALE_TRI

        # With a positive margin, use the standard margin ranking formulation;
        # otherwise fall back to the soft-margin variant.
        if self._margin > 0:
            self.ranking_loss = nn.MarginRankingLoss(margin=self._margin)
        else:
            self.ranking_loss = nn.SoftMarginLoss()

    def __call__(self, pred_class_logits, global_features, targets):
        if self._normalize_feature:
            global_features = normalize(global_features, axis=-1)

        dist_mat = euclidean_dist(global_features, global_features)
        dist_ap, dist_an = hard_example_mining(dist_mat, targets)
        # A target of 1 tells MarginRankingLoss that dist_an should be ranked
        # higher (larger) than dist_ap.
        y = torch.ones_like(dist_an)

        if self._margin > 0:
            loss = self.ranking_loss(dist_an, dist_ap, y)
        else:
            loss = self.ranking_loss(dist_an - dist_ap, y)

        return {
            "loss_triplet": loss * self._scale,
        }
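

# Usage sketch (an assumption-laden example, not part of this module's API):
# the cfg keys below match exactly what `__init__` reads, but the yacs-style
# CfgNode construction and all tensor shapes are hypothetical.
#
#   from yacs.config import CfgNode as CN
#   cfg = CN()
#   cfg.MODEL = CN()
#   cfg.MODEL.LOSSES = CN()
#   cfg.MODEL.LOSSES.MARGIN = 0.3       # > 0 selects nn.MarginRankingLoss
#   cfg.MODEL.LOSSES.NORM_FEAT = False  # L2-normalize features before mining
#   cfg.MODEL.LOSSES.SCALE_TRI = 1.0    # multiplier on the returned loss
#
#   criterion = TripletLoss(cfg)
#   feats = torch.randn(64, 2048, requires_grad=True)
#   targets = torch.randint(0, 16, (64,))
#   loss_dict = criterion(None, feats, targets)  # pred_class_logits is unused
#   loss_dict["loss_triplet"].backward()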