switch between soft and hard margin when inf

Summary: Add a mechnism to automatic switch triplet loss with soft margin to hard margin when loss becomes inf.
This commit is contained in:
liaoxingyu 2020-05-26 14:36:33 +08:00
parent 5982f90920
commit d4b71de3aa

View File

@ -5,7 +5,6 @@
"""
import torch
from torch import nn
import torch.nn.functional as F
__all__ = [
@ -103,7 +102,6 @@ def weighted_example_mining(dist_mat, is_pos, is_neg):
dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
dist_an: pytorch Variable, distance(anchor, negative); shape [N]
"""
assert len(dist_mat.size()) == 2
assert dist_mat.size(0) == dist_mat.size(1)
@ -133,11 +131,6 @@ class TripletLoss(object):
self._hard_mining = cfg.MODEL.LOSSES.TRI.HARD_MINING
self._use_cosine_dist = cfg.MODEL.LOSSES.TRI.USE_COSINE_DIST
if self._margin > 0:
self.ranking_loss = nn.MarginRankingLoss(margin=self._margin)
else:
self.ranking_loss = nn.SoftMarginLoss()
def __call__(self, _, global_features, targets):
if self._normalize_feature:
global_features = normalize(global_features, axis=-1)
@ -159,9 +152,11 @@ class TripletLoss(object):
y = dist_an.new().resize_as_(dist_an).fill_(1)
if self._margin > 0:
loss = self.ranking_loss(dist_an, dist_ap, y)
loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=self._margin)
else:
loss = self.ranking_loss(dist_an - dist_ap, y)
loss = F.soft_margin_loss(dist_an - dist_ap, y)
if loss == float('Inf'): loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=0.3)
return {
"loss_triplet": loss * self._scale,
}