# encoding: utf-8
"""
@author:  liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import torch
import torch.nn.functional as F

from fastreid.utils import comm
from .utils import concat_all_gather, euclidean_dist, normalize


def softmax_weights(dist, mask):
    """Row-wise softmax of ``dist`` restricted to the entries where ``mask`` is 1.

    Args:
        dist: score matrix, shape [N, M].
        mask: float 0/1 mask with the same shape as ``dist``.

    Returns:
        Weights with the same shape as ``dist``. Masked-out entries get weight 0;
        the masked-in entries of each row sum to (almost) 1. The 1e-6 in the
        normaliser keeps an all-masked-out row at 0 instead of dividing by zero.
    """
    # Shift by the row max for numerical stability. Every masked-in value takes
    # part in the max, so exp(diff) <= 1 on the masked-in entries.
    max_v = torch.max(dist * mask, dim=1, keepdim=True)[0]
    # Zero the shifted scores outside the mask. exp() there is then exactly 1
    # (and multiplied by 0 below), so it can never overflow to inf and turn
    # into nan via inf * 0.
    diff = (dist - max_v) * mask
    exp_diff = torch.exp(diff) * mask
    Z = torch.sum(exp_diff, dim=1, keepdim=True) + 1e-6  # avoid division by zero
    return exp_diff / Z


def hard_example_mining(dist_mat, is_pos, is_neg):
    """For each anchor, find the hardest positive and the hardest negative distance.

    Args:
        dist_mat: pairwise distances between anchors (rows) and candidates
            (columns), shape [N, M].
        is_pos: boolean (or 0/1) mask, shape [N, M]; set where candidate j has
            the same label as anchor i.
        is_neg: boolean (or 0/1) mask, shape [N, M]; set where candidate j has
            a different label from anchor i.

    Returns:
        dist_ap: shape [N]; distance from each anchor to its farthest positive
            (-inf if the row has no positive).
        dist_an: shape [N]; distance from each anchor to its closest negative
            (+inf if the row has no negative).

    Anchors may have different numbers of positives/negatives. The previous
    boolean-index + ``view(N, -1)`` formulation required equal counts: it
    raised otherwise or, when the totals happened to divide evenly by N,
    silently mixed candidates from different rows.
    """
    assert len(dist_mat.size()) == 2
    is_pos = is_pos.bool()
    is_neg = is_neg.bool()

    # Exclude non-candidates from the row-wise max/min by filling them with
    # the neutral element of the reduction.
    dist_ap = dist_mat.masked_fill(~is_pos, float('-inf')).max(dim=1)[0]
    dist_an = dist_mat.masked_fill(~is_neg, float('inf')).min(dim=1)[0]
    return dist_ap, dist_an


def weighted_example_mining(dist_mat, is_pos, is_neg):
    """For each anchor, compute softmax-weighted positive and negative distances.

    Positive distances are weighted by softmax(d), so farther positives count
    more. Negative distances are weighted by softmax(-d), so closer negatives
    count more.

    Args:
        dist_mat: pairwise distances between anchors and candidates, shape [N, M].
        is_pos: boolean (or 0/1) mask of positives, shape [N, M].
        is_neg: boolean (or 0/1) mask of negatives, shape [N, M].

    Returns:
        dist_ap: shape [N]; weighted anchor-positive distance.
        dist_an: shape [N]; weighted anchor-negative distance.
    """
    assert len(dist_mat.size()) == 2
    is_pos = is_pos.float()
    is_neg = is_neg.float()
    dist_ap = dist_mat * is_pos
    dist_an = dist_mat * is_neg

    weights_ap = softmax_weights(dist_ap, is_pos)
    weights_an = softmax_weights(-dist_an, is_neg)

    dist_ap = torch.sum(dist_ap * weights_ap, dim=1)
    dist_an = torch.sum(dist_an * weights_an, dim=1)
    return dist_ap, dist_an


class TripletLoss(object):
    """Batch triplet loss, modified from Tong Xiao's open-reid
    (https://github.com/Cysu/open-reid).

    The theory behind it is described in the paper 'In Defense of the Triplet
    Loss for Person Re-Identification'.
    """

    def __init__(self, cfg):
        # Hinge margin; a value <= 0 selects the soft-margin variant.
        self._margin = cfg.MODEL.LOSSES.TRI.MARGIN
        # Whether to L2-normalise embeddings before computing distances.
        self._normalize_feature = cfg.MODEL.LOSSES.TRI.NORM_FEAT
        # Multiplier applied to the final loss value.
        self._scale = cfg.MODEL.LOSSES.TRI.SCALE
        # True: hardest positive/negative; False: softmax-weighted mining.
        self._hard_mining = cfg.MODEL.LOSSES.TRI.HARD_MINING

    def __call__(self, embedding, targets):
        """Compute the scaled triplet loss.

        Args:
            embedding: local features; row i belongs to ``targets[i]``.
            targets: local labels, one per row of ``embedding``.

        Returns:
            Scalar loss tensor multiplied by the configured scale.
        """
        if self._normalize_feature:
            embedding = normalize(embedding, axis=-1)

        # For distributed training, gather features from all processes so every
        # local anchor is mined against the global batch. The gathered tensors
        # are used only on the candidate side: torch.distributed.all_gather does
        # not propagate gradients (concat_all_gather is presumably built on
        # it -- verify in .utils), so the local `embedding` must stay on the
        # anchor side for gradients to reach this process's model.
        if comm.get_world_size() > 1:
            all_embedding = concat_all_gather(embedding)
            all_targets = concat_all_gather(targets)
        else:
            all_embedding = embedding
            all_targets = targets

        # [N, M]: local anchors vs. all candidates (M == N on a single process).
        # NOTE(review): assumes euclidean_dist(x, y) returns the
        # [x.size(0), y.size(0)] distance matrix, as its two-argument form implies.
        dist_mat = euclidean_dist(embedding, all_embedding)

        N, M = dist_mat.size()
        is_pos = targets.view(N, 1).eq(all_targets.view(1, M))
        is_neg = targets.view(N, 1).ne(all_targets.view(1, M))

        if self._hard_mining:
            dist_ap, dist_an = hard_example_mining(dist_mat, is_pos, is_neg)
        else:
            dist_ap, dist_an = weighted_example_mining(dist_mat, is_pos, is_neg)

        # Target +1: dist_an should be ranked above dist_ap.
        y = torch.ones_like(dist_an)

        if self._margin > 0:
            loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=self._margin)
        else:
            loss = F.soft_margin_loss(dist_an - dist_ap, y)
            # The soft-margin loss can overflow to inf; fall back to a hinge
            # loss with a fixed margin of 0.3 in that case.
            if loss == float('Inf'):
                loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=0.3)

        return loss * self._scale