mirror of https://github.com/JDAI-CV/fast-reid.git
29 lines
685 B
Python
29 lines
685 B
Python
# encoding: utf-8
|
|
"""
|
|
@author: liaoxingyu
|
|
@contact: sherlockliao01@gmail.com
|
|
"""
|
|
from torch import nn
|
|
|
|
from .triplet_loss import TripletLoss
|
|
|
|
|
|
__all__ = ['reidLoss']
|
|
|
|
|
|
class reidLoss(nn.Module):
|
|
def __init__(self, lossType:list, margin:float):
|
|
super().__init__()
|
|
self.lossType = lossType
|
|
|
|
self.ce_loss = nn.CrossEntropyLoss()
|
|
self.triplet_loss = TripletLoss(margin)
|
|
|
|
def forward(self, out, target):
|
|
scores, feats = out
|
|
loss = 0
|
|
if 'softmax' in self.lossType: loss += self.ce_loss(scores, target)
|
|
if 'triplet' in self.lossType: loss += self.triplet_loss(feats, target)[0]
|
|
|
|
return loss
|