mirror of https://github.com/JDAI-CV/fast-reid.git
200 lines
6.8 KiB
Python
200 lines
6.8 KiB
Python
# encoding: utf-8
|
|
"""
|
|
@author: liaoxingyu
|
|
@contact: sherlockliao01@gmail.com
|
|
"""
|
|
|
|
import torch
|
|
from torch import nn
|
|
import torch.nn.functional as F
|
|
|
|
|
|
__all__ = ["TripletLoss", "CircleLoss"]
|
|
|
|
def normalize(x, axis=-1):
|
|
"""Normalizing to unit length along the specified dimension.
|
|
Args:
|
|
x: pytorch Variable
|
|
Returns:
|
|
x: pytorch Variable, same shape as input
|
|
"""
|
|
x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12)
|
|
return x
|
|
|
|
|
|
def euclidean_dist(x, y):
|
|
m, n = x.size(0), y.size(0)
|
|
xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
|
|
yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
|
|
dist = xx + yy
|
|
dist.addmm_(1, -2, x, y.t())
|
|
dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
|
|
return dist
|
|
|
|
|
|
def cosine_dist(x, y):
|
|
bs1, bs2 = x.size(0), y.size(0)
|
|
frac_up = torch.matmul(x, y.transpose(0, 1))
|
|
frac_down = (torch.sqrt(torch.sum(torch.pow(x, 2), 1))).view(bs1, 1).repeat(1, bs2) * \
|
|
(torch.sqrt(torch.sum(torch.pow(y, 2), 1))).view(1, bs2).repeat(bs1, 1)
|
|
cosine = frac_up / frac_down
|
|
return 1 - cosine
|
|
|
|
|
|
def softmax_weights(dist, mask):
|
|
max_v = torch.max(dist * mask, dim=1, keepdim=True)[0]
|
|
diff = dist - max_v
|
|
Z = torch.sum(torch.exp(diff) * mask, dim=1, keepdim=True) + 1e-6 # avoid division by zero
|
|
W = torch.exp(diff) * mask / Z
|
|
return W
|
|
|
|
|
|
def hard_example_mining(dist_mat, is_pos, is_neg):
|
|
"""For each anchor, find the hardest positive and negative sample.
|
|
Args:
|
|
dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N]
|
|
labels: pytorch LongTensor, with shape [N]
|
|
return_inds: whether to return the indices. Save time if `False`(?)
|
|
Returns:
|
|
dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
|
|
dist_an: pytorch Variable, distance(anchor, negative); shape [N]
|
|
p_inds: pytorch LongTensor, with shape [N];
|
|
indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1
|
|
n_inds: pytorch LongTensor, with shape [N];
|
|
indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1
|
|
NOTE: Only consider the case in which all labels have same num of samples,
|
|
thus we can cope with all anchors in parallel.
|
|
"""
|
|
|
|
assert len(dist_mat.size()) == 2
|
|
assert dist_mat.size(0) == dist_mat.size(1)
|
|
N = dist_mat.size(0)
|
|
|
|
# `dist_ap` means distance(anchor, positive)
|
|
# both `dist_ap` and `relative_p_inds` with shape [N, 1]
|
|
# pos_dist = dist_mat[is_pos].contiguous().view(N, -1)
|
|
# ap_weight = F.softmax(pos_dist, dim=1)
|
|
# dist_ap = torch.sum(ap_weight * pos_dist, dim=1)
|
|
dist_ap, relative_p_inds = torch.max(
|
|
dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True)
|
|
# `dist_an` means distance(anchor, negative)
|
|
# both `dist_an` and `relative_n_inds` with shape [N, 1]
|
|
dist_an, relative_n_inds = torch.min(
|
|
dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True)
|
|
# neg_dist = dist_mat[is_neg].contiguous().view(N, -1)
|
|
# an_weight = F.softmax(-neg_dist, dim=1)
|
|
# dist_an = torch.sum(an_weight * neg_dist, dim=1)
|
|
|
|
# shape [N]
|
|
dist_ap = dist_ap.squeeze(1)
|
|
dist_an = dist_an.squeeze(1)
|
|
|
|
return dist_ap, dist_an
|
|
|
|
|
|
def weighted_example_mining(dist_mat, is_pos, is_neg):
|
|
"""For each anchor, find the weighted positive and negative sample.
|
|
Args:
|
|
dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N]
|
|
Returns:
|
|
dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
|
|
dist_an: pytorch Variable, distance(anchor, negative); shape [N]
|
|
"""
|
|
|
|
assert len(dist_mat.size()) == 2
|
|
assert dist_mat.size(0) == dist_mat.size(1)
|
|
|
|
is_pos = is_pos.float()
|
|
is_neg = is_neg.float()
|
|
dist_ap = dist_mat * is_pos
|
|
dist_an = dist_mat * is_neg
|
|
|
|
weights_ap = softmax_weights(dist_ap, is_pos)
|
|
weights_an = softmax_weights(-dist_an, is_neg)
|
|
|
|
dist_ap = torch.sum(dist_ap * weights_ap, dim=1)
|
|
dist_an = torch.sum(dist_an * weights_an, dim=1)
|
|
|
|
return dist_ap, dist_an
|
|
|
|
|
|
class TripletLoss(object):
|
|
"""Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid).
|
|
Related Triplet Loss theory can be found in paper 'In Defense of the Triplet
|
|
Loss for Person Re-Identification'."""
|
|
|
|
def __init__(self, cfg):
|
|
self._margin = cfg.MODEL.LOSSES.MARGIN
|
|
self._normalize_feature = cfg.MODEL.LOSSES.NORM_FEAT
|
|
self._scale = cfg.MODEL.LOSSES.SCALE_TRI
|
|
self._hard_mining = cfg.MODEL.LOSSES.HARD_MINING
|
|
self._use_cosine_dist = cfg.MODEL.LOSSES.USE_COSINE_DIST
|
|
|
|
if self._margin > 0:
|
|
self.ranking_loss = nn.MarginRankingLoss(margin=self._margin)
|
|
else:
|
|
self.ranking_loss = nn.SoftMarginLoss()
|
|
|
|
def __call__(self, _, global_features, targets):
|
|
if self._normalize_feature:
|
|
global_features = normalize(global_features, axis=-1)
|
|
|
|
if self._use_cosine_dist:
|
|
dist_mat = cosine_dist(global_features, global_features)
|
|
else:
|
|
dist_mat = euclidean_dist(global_features, global_features)
|
|
|
|
N = dist_mat.size(0)
|
|
is_pos = targets.expand(N, N).eq(targets.expand(N, N).t())
|
|
is_neg = targets.expand(N, N).ne(targets.expand(N, N).t())
|
|
|
|
if self._hard_mining:
|
|
dist_ap, dist_an = hard_example_mining(dist_mat, is_pos, is_neg)
|
|
else:
|
|
dist_ap, dist_an = weighted_example_mining(dist_mat, is_pos, is_neg)
|
|
|
|
y = dist_an.new().resize_as_(dist_an).fill_(1)
|
|
|
|
if self._margin > 0:
|
|
loss = self.ranking_loss(dist_an, dist_ap, y)
|
|
else:
|
|
loss = self.ranking_loss(dist_an - dist_ap, y)
|
|
return {
|
|
"loss_triplet": loss * self._scale,
|
|
}
|
|
|
|
|
|
class CircleLoss(object):
|
|
def __init__(self, cfg):
|
|
self._scale = cfg.MODEL.LOSSES.SCALE_TRI
|
|
|
|
self.m = 0.25
|
|
self.s = 128
|
|
|
|
def __call__(self, _, global_features, targets):
|
|
global_features = normalize(global_features, axis=-1)
|
|
|
|
sim_mat = torch.matmul(global_features, global_features.t())
|
|
|
|
N = sim_mat.size(0)
|
|
is_pos = targets.expand(N, N).eq(targets.expand(N, N).t()).float() - torch.eye(N).to(sim_mat.device)
|
|
is_pos = is_pos.bool()
|
|
is_neg = targets.expand(N, N).ne(targets.expand(N, N).t())
|
|
|
|
s_p = sim_mat[is_pos].contiguous().view(N, -1)
|
|
s_n = sim_mat[is_neg].contiguous().view(N, -1)
|
|
|
|
alpha_p = F.relu(-s_p.detach() + 1 + self.m)
|
|
alpha_n = F.relu(s_n.detach() + self.m)
|
|
delta_p = 1 - self.m
|
|
delta_n = self.m
|
|
|
|
logit_p = - self.s * alpha_p * (s_p - delta_p)
|
|
logit_n = self.s * alpha_n * (s_n - delta_n)
|
|
|
|
loss = F.softplus(torch.logsumexp(logit_p, dim=1) + torch.logsumexp(logit_n, dim=1)).mean()
|
|
|
|
return {
|
|
"loss_circle": loss * self._scale,
|
|
}
|