fast-reid/layers/loss.py
liaoxingyu 3d5f7d24aa 1. fix minor bug
2. update experiment results in readme
2019-08-21 09:35:34 +08:00

29 lines
685 B
Python

# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from torch import nn
from .triplet_loss import TripletLoss
# Public API of this module: only the combined loss is exported by ``import *``.
__all__ = ['reidLoss']
class reidLoss(nn.Module):
    """Combined re-ID training loss: the sum of the losses selected in ``lossType``.

    Args:
        lossType: Container of loss names, queried with ``in``. ``'softmax'``
            enables cross-entropy on the classifier scores and ``'triplet'``
            enables the triplet loss on the features. Because only ``in`` is
            used, a string such as ``'softmax_triplet'`` also works (substring
            match). At least one of the two losses must be requested.
        margin: Forwarded unchanged to ``TripletLoss``.

    Raises:
        ValueError: If ``lossType`` requests neither ``'softmax'`` nor
            ``'triplet'``. Without this check ``forward`` would return the
            Python int ``0``, which has no ``.backward()``, and the
            misconfiguration would only surface later in the training loop.
    """

    _SUPPORTED = ('softmax', 'triplet')

    def __init__(self, lossType: list, margin: float):
        super().__init__()
        # Validate before building any sub-loss so a bad config fails fast.
        if not any(name in lossType for name in self._SUPPORTED):
            raise ValueError(
                f"lossType must request at least one of {self._SUPPORTED}, "
                f"got {lossType!r}"
            )
        self.lossType = lossType
        self.ce_loss = nn.CrossEntropyLoss()
        # Always constructed (even if unused) so the module's attributes stay
        # the same regardless of configuration.
        self.triplet_loss = TripletLoss(margin)

    def forward(self, out, target):
        """Return the sum of the enabled loss terms for one batch.

        Args:
            out: Pair ``(scores, feats)``. ``scores`` is passed to
                ``nn.CrossEntropyLoss`` (presumably ``(N, num_classes)``
                logits; confirm against the model head). ``feats`` is passed to
                the triplet loss.
            target: Labels passed to both enabled losses.

        Returns:
            The sum of the enabled loss terms. ``__init__`` guarantees at least
            one term, so the result is the tensor returned by the loss(es)
            rather than the initial int ``0``.
        """
        scores, feats = out
        loss = 0
        if 'softmax' in self.lossType:
            loss += self.ce_loss(scores, target)
        if 'triplet' in self.lossType:
            # NOTE(review): assumes TripletLoss returns a tuple whose first
            # element is the loss value (only index 0 is used here); confirm.
            loss += self.triplet_loss(feats, target)[0]
        return loss