mirror of
https://github.com/JDAI-CV/fast-reid.git
synced 2025-06-03 14:50:47 +08:00
switch between soft and hard margin when inf
Summary: Add a mechnism to automatic switch triplet loss with soft margin to hard margin when loss becomes inf.
This commit is contained in:
parent
5982f90920
commit
d4b71de3aa
@ -5,7 +5,6 @@
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
__all__ = [
|
||||
@ -103,7 +102,6 @@ def weighted_example_mining(dist_mat, is_pos, is_neg):
|
||||
dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
|
||||
dist_an: pytorch Variable, distance(anchor, negative); shape [N]
|
||||
"""
|
||||
|
||||
assert len(dist_mat.size()) == 2
|
||||
assert dist_mat.size(0) == dist_mat.size(1)
|
||||
|
||||
@ -133,11 +131,6 @@ class TripletLoss(object):
|
||||
self._hard_mining = cfg.MODEL.LOSSES.TRI.HARD_MINING
|
||||
self._use_cosine_dist = cfg.MODEL.LOSSES.TRI.USE_COSINE_DIST
|
||||
|
||||
if self._margin > 0:
|
||||
self.ranking_loss = nn.MarginRankingLoss(margin=self._margin)
|
||||
else:
|
||||
self.ranking_loss = nn.SoftMarginLoss()
|
||||
|
||||
def __call__(self, _, global_features, targets):
|
||||
if self._normalize_feature:
|
||||
global_features = normalize(global_features, axis=-1)
|
||||
@ -159,9 +152,11 @@ class TripletLoss(object):
|
||||
y = dist_an.new().resize_as_(dist_an).fill_(1)
|
||||
|
||||
if self._margin > 0:
|
||||
loss = self.ranking_loss(dist_an, dist_ap, y)
|
||||
loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=self._margin)
|
||||
else:
|
||||
loss = self.ranking_loss(dist_an - dist_ap, y)
|
||||
loss = F.soft_margin_loss(dist_an - dist_ap, y)
|
||||
if loss == float('Inf'): loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=0.3)
|
||||
|
||||
return {
|
||||
"loss_triplet": loss * self._scale,
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user