mirror of https://github.com/JDAI-CV/fast-reid.git
114 lines
3.9 KiB
Python
114 lines
3.9 KiB
Python
# encoding: utf-8
|
|
"""
|
|
@author: liaoxingyu
|
|
@contact: sherlockliao01@gmail.com
|
|
"""
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
from .utils import euclidean_dist, cosine_dist
|
|
|
|
|
|
def softmax_weights(dist, mask):
    """Masked, numerically stable softmax over each row of ``dist``.

    Args:
        dist: 2-D tensor of scores; the softmax runs along ``dim=1``.
        mask: tensor with the same shape as ``dist``. Entries equal to zero
            are excluded from the softmax; non-zero entries scale the
            exponentials (a 0/1 mask gives a plain masked softmax).

    Returns:
        Tensor shaped like ``dist``. Excluded entries are 0. Rows that have
        at least one included entry sum to ~1 (up to the ``1e-6`` epsilon);
        rows whose mask is entirely zero are all 0.
    """
    keep = mask != 0

    # Take the stabilising maximum over the *included* entries only.  Using
    # ``dist * mask`` lets the zeros of excluded entries win whenever all
    # included scores are negative, which leaves the exponentials unshifted;
    # they can then underflow and the epsilon below dominates the normaliser.
    max_v = dist.masked_fill(~keep, float("-inf")).max(dim=1, keepdim=True)[0]
    # A row without any included entry has max == -inf; use 0 to avoid NaNs.
    max_v = max_v.masked_fill(~keep.any(dim=1, keepdim=True), 0.0)

    # Zero the shift on excluded entries so exp() cannot overflow there; those
    # entries are removed by the multiplication with ``mask`` anyway.
    diff = (dist - max_v).masked_fill(~keep, 0.0)
    exp_diff = torch.exp(diff) * mask

    Z = torch.sum(exp_diff, dim=1, keepdim=True) + 1e-6  # avoid division by zero
    return exp_diff / Z
|
|
|
|
|
|
def hard_example_mining(dist_mat, is_pos, is_neg):
    """Pick the hardest positive and hardest negative distance per anchor.

    Args:
        dist_mat: pairwise distance between samples, shape [N, M]
        is_pos: positive mask with shape [N, M] (non-zero marks a positive pair)
        is_neg: negative mask with shape [N, M] (non-zero marks a negative pair)

    Returns:
        dist_ap: largest masked anchor-positive distance per row; shape [N]
        dist_an: smallest anchor-negative distance per row; shape [N]

    NOTE: only the two distance tensors are returned; the indices of the
    selected samples are discarded.
    """
    assert dist_mat.dim() == 2

    # Hardest positive: the farthest sample among the positives of each row
    # (non-positive entries are zeroed before taking the maximum).
    pos_dist = dist_mat * is_pos
    dist_ap = pos_dist.max(dim=1)[0]

    # Hardest negative: the closest sample among the negatives of each row.
    # Positives are pushed to 1e9 so they can never be picked by the minimum.
    neg_dist = dist_mat * is_neg + is_pos * 1e9
    dist_an = neg_dist.min(dim=1)[0]

    return dist_ap, dist_an
|
|
|
|
|
|
def weighted_example_mining(dist_mat, is_pos, is_neg):
    """Collapse each anchor's positives and negatives into weighted distances.

    Args:
        dist_mat: pairwise distance between samples, 2-D tensor
        is_pos: mask shaped like ``dist_mat``; non-zero marks a positive pair
        is_neg: mask shaped like ``dist_mat``; non-zero marks a negative pair

    Returns:
        dist_ap: weighted anchor-positive distance per row
        dist_an: weighted anchor-negative distance per row
    """
    assert dist_mat.dim() == 2

    pos_dist = dist_mat * is_pos
    neg_dist = dist_mat * is_neg

    # Positives are weighted by their distance (farther -> heavier); negatives
    # by the negated distance (closer -> heavier).
    pos_w = softmax_weights(pos_dist, is_pos)
    neg_w = softmax_weights(-neg_dist, is_neg)

    dist_ap = (pos_dist * pos_w).sum(dim=1)
    dist_an = (neg_dist * neg_w).sum(dim=1)

    return dist_ap, dist_an
|
|
|
|
|
|
def triplet_loss(embedding, targets, margin, norm_feat, hard_mining):
    r"""Triplet loss computed over every anchor of the batch.

    Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid).
    Related Triplet Loss theory can be found in paper 'In Defense of the Triplet
    Loss for Person Re-Identification'.

    Args:
        embedding: batch features; pairwise distances are computed between the
            batch and itself (presumably shape [N, D] -- confirm against
            ``cosine_dist`` / ``euclidean_dist``).
        targets: labels holding N elements; equal labels form positive pairs,
            different labels form negative pairs.
        margin: when > 0, a margin ranking loss with this margin is used;
            otherwise a soft-margin loss is used.
        norm_feat: truthy selects ``cosine_dist``, falsy ``euclidean_dist``.
        hard_mining: truthy selects ``hard_example_mining``, falsy
            ``weighted_example_mining``.

    Returns:
        The loss tensor produced by ``F.margin_ranking_loss`` or
        ``F.soft_margin_loss``.
    """
    pairwise_dist = cosine_dist if norm_feat else euclidean_dist
    dist_mat = pairwise_dist(embedding, embedding)

    # NOTE: features are not gathered across processes here; only the batch
    # passed in contributes pairs.

    N = dist_mat.size(0)
    # label_grid[i, j] == targets[i]; its transpose holds targets[j].
    label_grid = targets.view(N, 1).expand(N, N)
    label_grid_t = label_grid.t()
    is_pos = label_grid.eq(label_grid_t).float()
    is_neg = label_grid.ne(label_grid_t).float()

    mine = hard_example_mining if hard_mining else weighted_example_mining
    dist_ap, dist_an = mine(dist_mat, is_pos, is_neg)

    # Ranking target of +1: dist_an should be ranked above dist_ap.
    y = torch.ones_like(dist_an)

    if margin > 0:
        loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=margin)
    else:
        loss = F.soft_margin_loss(dist_an - dist_ap, y)
        if loss == float('Inf'):
            # Soft-margin loss overflowed; fall back to a fixed-margin loss.
            loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=0.3)

    return loss
|