fast-reid/layers/loss.py

29 lines
685 B
Python

# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from torch import nn
from .triplet_loss import TripletLoss
__all__ = ['reidLoss']
class reidLoss(nn.Module):
def __init__(self, lossType:list, margin:float):
super().__init__()
self.lossType = lossType
self.ce_loss = nn.CrossEntropyLoss()
self.triplet_loss = TripletLoss(margin)
def forward(self, out, target):
scores, feats = out
loss = 0
if 'softmax' in self.lossType: loss += self.ce_loss(scores, target)
if 'triplet' in self.lossType: loss += self.triplet_loss(feats, target)[0]
return loss