mirror of https://github.com/JDAI-CV/fast-reid.git
32 lines
944 B
Python
32 lines
944 B
Python
# encoding: utf-8
|
|
"""
|
|
@author: liaoxingyu
|
|
@contact: sherlockliao01@gmail.com
|
|
"""
|
|
|
|
import torch.nn.functional as F
|
|
|
|
from .triplet_loss import TripletLoss
|
|
|
|
|
|
def make_loss(cfg):
    """Build the training loss function selected by ``cfg.DATALOADER.SAMPLER``.

    Args:
        cfg: config node. Reads ``cfg.DATALOADER.SAMPLER`` and, for the
            ``'triplet'`` / ``'softmax_triplet'`` modes, ``cfg.SOLVER.MARGIN``.

    Returns:
        A callable ``loss_func(out, target)`` where ``out`` is a
        ``(score, feat)`` pair:

        * ``'softmax'``: ``F.cross_entropy(score, target)``
        * ``'triplet'``: ``triplet(feat, target)[0]``
        * ``'softmax_triplet'``: the sum of the two terms above

    Raises:
        ValueError: if the sampler is not one of ``'softmax'``,
            ``'triplet'`` or ``'softmax_triplet'``.
    """
    sampler = cfg.DATALOADER.SAMPLER

    if sampler == 'softmax':
        def loss_func(out, target):
            score, _feat = out
            return F.cross_entropy(score, target)

        return loss_func

    if sampler not in ('triplet', 'softmax_triplet'):
        # Fail loudly: merely printing and falling through would reach the
        # return with no loss_func bound and crash with an unrelated error.
        raise ValueError(
            'expected sampler should be softmax, triplet or softmax_triplet, '
            'but got {}'.format(sampler))

    # Only the triplet-based modes need the triplet criterion (and hence
    # cfg.SOLVER.MARGIN), so it is built only after the sampler is validated.
    triplet = TripletLoss(cfg.SOLVER.MARGIN)

    if sampler == 'triplet':
        def loss_func(out, target):
            _score, feat = out
            # Element [0] of TripletLoss's return value is used as the loss
            # (see .triplet_loss for what the remaining elements hold).
            return triplet(feat, target)[0]
    else:  # 'softmax_triplet'
        def loss_func(out, target):
            score, feat = out
            return F.cross_entropy(score, target) + triplet(feat, target)[0]

    return loss_func
|