mirror of
https://github.com/JDAI-CV/fast-reid.git
synced 2025-06-03 14:50:47 +08:00
switch between soft and hard margin when inf
Summary: Add a mechnism to automatic switch triplet loss with soft margin to hard margin when loss becomes inf.
This commit is contained in:
parent
5982f90920
commit
d4b71de3aa
@ -5,7 +5,6 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -103,7 +102,6 @@ def weighted_example_mining(dist_mat, is_pos, is_neg):
|
|||||||
dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
|
dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
|
||||||
dist_an: pytorch Variable, distance(anchor, negative); shape [N]
|
dist_an: pytorch Variable, distance(anchor, negative); shape [N]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert len(dist_mat.size()) == 2
|
assert len(dist_mat.size()) == 2
|
||||||
assert dist_mat.size(0) == dist_mat.size(1)
|
assert dist_mat.size(0) == dist_mat.size(1)
|
||||||
|
|
||||||
@ -133,11 +131,6 @@ class TripletLoss(object):
|
|||||||
self._hard_mining = cfg.MODEL.LOSSES.TRI.HARD_MINING
|
self._hard_mining = cfg.MODEL.LOSSES.TRI.HARD_MINING
|
||||||
self._use_cosine_dist = cfg.MODEL.LOSSES.TRI.USE_COSINE_DIST
|
self._use_cosine_dist = cfg.MODEL.LOSSES.TRI.USE_COSINE_DIST
|
||||||
|
|
||||||
if self._margin > 0:
|
|
||||||
self.ranking_loss = nn.MarginRankingLoss(margin=self._margin)
|
|
||||||
else:
|
|
||||||
self.ranking_loss = nn.SoftMarginLoss()
|
|
||||||
|
|
||||||
def __call__(self, _, global_features, targets):
|
def __call__(self, _, global_features, targets):
|
||||||
if self._normalize_feature:
|
if self._normalize_feature:
|
||||||
global_features = normalize(global_features, axis=-1)
|
global_features = normalize(global_features, axis=-1)
|
||||||
@ -159,9 +152,11 @@ class TripletLoss(object):
|
|||||||
y = dist_an.new().resize_as_(dist_an).fill_(1)
|
y = dist_an.new().resize_as_(dist_an).fill_(1)
|
||||||
|
|
||||||
if self._margin > 0:
|
if self._margin > 0:
|
||||||
loss = self.ranking_loss(dist_an, dist_ap, y)
|
loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=self._margin)
|
||||||
else:
|
else:
|
||||||
loss = self.ranking_loss(dist_an - dist_ap, y)
|
loss = F.soft_margin_loss(dist_an - dist_ap, y)
|
||||||
|
if loss == float('Inf'): loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=0.3)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"loss_triplet": loss * self._scale,
|
"loss_triplet": loss * self._scale,
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user