mirror of
https://github.com/JDAI-CV/fast-reid.git
synced 2025-06-03 14:50:47 +08:00
29 lines
878 B
Python
29 lines
878 B
Python
|
# encoding: utf-8
|
||
|
"""
|
||
|
@author: liaoxingyu
|
||
|
@contact: sherlockliao01@gmail.com
|
||
|
"""
|
||
|
|
||
|
import torch.nn.functional as F
|
||
|
|
||
|
from .triplet_loss import TripletLoss
|
||
|
|
||
|
|
||
|
def make_loss(cfg):
|
||
|
sampler = cfg.DATALOADER.SAMPLER
|
||
|
triplet = TripletLoss(cfg.SOLVER.MARGIN) # triplet loss
|
||
|
|
||
|
if sampler == 'softmax':
|
||
|
def loss_func(score, feat, target):
|
||
|
return F.cross_entropy(score, target)
|
||
|
elif cfg.DATALOADER.SAMPLER == 'triplet':
|
||
|
def loss_func(score, feat, target):
|
||
|
return triplet(feat, target)[0]
|
||
|
elif cfg.DATALOADER.SAMPLER == 'softmax_triplet':
|
||
|
def loss_func(score, feat, target):
|
||
|
return F.cross_entropy(score, target) + triplet(feat, target)[0]
|
||
|
else:
|
||
|
print('expected sampler should be softmax, triplet or softmax_triplet, '
|
||
|
'but got {}'.format(cfg.DATALOADER.SAMPLER))
|
||
|
return loss_func
|