mirror of https://github.com/JDAI-CV/fast-reid.git
29 lines
685 B
Python
29 lines
685 B
Python
|
# encoding: utf-8
|
||
|
"""
|
||
|
@author: liaoxingyu
|
||
|
@contact: sherlockliao01@gmail.com
|
||
|
"""
|
||
|
from torch import nn
|
||
|
|
||
|
from .triplet_loss import TripletLoss
|
||
|
|
||
|
|
||
|
# Names exported by ``from <this module> import *``: only the combined loss.
__all__ = [
    "reidLoss",
]
|
||
|
|
||
|
|
||
|
class reidLoss(nn.Module):
    """Sum the re-ID training losses selected by ``lossType``.

    Args:
        lossType: container of loss names, checked with ``in``. ``'softmax'``
            enables cross-entropy on the classification scores and
            ``'triplet'`` enables ``TripletLoss`` on the features. When both
            are present, the two terms are added together.
        margin: margin forwarded to ``TripletLoss``.

    Raises:
        ValueError: if ``lossType`` enables neither ``'softmax'`` nor
            ``'triplet'``. Without this check, ``forward`` would return the
            plain int ``0`` and the error would only show up later (e.g. on
            ``loss.backward()``).
    """

    def __init__(self, lossType: list, margin: float):
        # Fail fast, before any submodule is built, so a misconfigured
        # loss list is reported at construction time with a clear message.
        if 'softmax' not in lossType and 'triplet' not in lossType:
            raise ValueError(
                f"lossType must include 'softmax' and/or 'triplet', got {lossType!r}"
            )
        super().__init__()
        self.lossType = lossType

        self.ce_loss = nn.CrossEntropyLoss()
        self.triplet_loss = TripletLoss(margin)

    def forward(self, out, target):
        """Return the sum of the enabled loss terms.

        Args:
            out: pair ``(scores, feats)``. ``scores`` go to cross-entropy and
                ``feats`` go to the triplet loss.
            target: labels passed to every enabled loss.
        """
        scores, feats = out
        loss = 0
        if 'softmax' in self.lossType:
            loss += self.ce_loss(scores, target)
        if 'triplet' in self.lossType:
            # NOTE(review): only element [0] of the TripletLoss result is used;
            # presumably it returns (loss, ...) -- confirm in triplet_loss.py.
            loss += self.triplet_loss(feats, target)[0]

        return loss
|