diff --git a/fastreid/modeling/losses/metric_loss.py b/fastreid/modeling/losses/metric_loss.py
index 9480525..6719f49 100644
--- a/fastreid/modeling/losses/metric_loss.py
+++ b/fastreid/modeling/losses/metric_loss.py
@@ -5,7 +5,6 @@
 """
 
 import torch
-from torch import nn
 import torch.nn.functional as F
 
 __all__ = [
@@ -103,7 +102,6 @@ def weighted_example_mining(dist_mat, is_pos, is_neg):
     dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
     dist_an: pytorch Variable, distance(anchor, negative); shape [N]
     """
-
     assert len(dist_mat.size()) == 2
     assert dist_mat.size(0) == dist_mat.size(1)
 
@@ -133,11 +131,6 @@ class TripletLoss(object):
         self._hard_mining = cfg.MODEL.LOSSES.TRI.HARD_MINING
         self._use_cosine_dist = cfg.MODEL.LOSSES.TRI.USE_COSINE_DIST
 
-        if self._margin > 0:
-            self.ranking_loss = nn.MarginRankingLoss(margin=self._margin)
-        else:
-            self.ranking_loss = nn.SoftMarginLoss()
-
     def __call__(self, _, global_features, targets):
         if self._normalize_feature:
             global_features = normalize(global_features, axis=-1)
@@ -159,9 +152,11 @@ class TripletLoss(object):
         y = dist_an.new().resize_as_(dist_an).fill_(1)
 
         if self._margin > 0:
-            loss = self.ranking_loss(dist_an, dist_ap, y)
+            loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=self._margin)
         else:
-            loss = self.ranking_loss(dist_an - dist_ap, y)
+            loss = F.soft_margin_loss(dist_an - dist_ap, y)
+            if loss == float('Inf'): loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=0.3)
+
         return {
             "loss_triplet": loss * self._scale,
         }