reid-strong-baseline/layers/__init__.py

90 lines
3.7 KiB
Python

# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch.nn.functional as F
from .triplet_loss import TripletLoss, CrossEntropyLabelSmooth
from .center_loss import CenterLoss
def make_loss(cfg, num_classes):    # modified by gu
    """Build the training loss selected by ``cfg.DATALOADER.SAMPLER``.

    Supported samplers and the loss each one produces:

    * ``'softmax'``         -- ``F.cross_entropy(score, target)``
    * ``'triplet'``         -- ``triplet(feat, target)[0]``
    * ``'softmax_triplet'`` -- classification loss plus
      ``triplet(feat, target)[0]``; the classification loss is
      ``CrossEntropyLabelSmooth`` when ``cfg.MODEL.IF_LABELSMOOTH == 'on'``
      and ``F.cross_entropy`` otherwise.

    Args:
        cfg: configuration object. Reads ``DATALOADER.SAMPLER``,
            ``MODEL.METRIC_LOSS_TYPE``, ``MODEL.IF_LABELSMOOTH`` and
            ``SOLVER.MARGIN``.
        num_classes: forwarded to ``CrossEntropyLabelSmooth`` when label
            smoothing is on.

    Returns:
        A callable ``loss_func(score, feat, target)``.

    Raises:
        ValueError: if the sampler is not one of the three supported values,
            or if the sampler needs the triplet term while
            ``cfg.MODEL.METRIC_LOSS_TYPE`` is not ``'triplet'``. Previously
            these misconfigurations only printed a message and then failed
            later (unbound ``loss_func`` / ``triplet``, or a ``None`` loss).
    """
    sampler = cfg.DATALOADER.SAMPLER
    metric_loss_type = cfg.MODEL.METRIC_LOSS_TYPE
    use_triplet = metric_loss_type == 'triplet'
    label_smooth = cfg.MODEL.IF_LABELSMOOTH == 'on'

    if sampler not in ('softmax', 'triplet', 'softmax_triplet'):
        raise ValueError(
            'expected sampler should be softmax, triplet or softmax_triplet, '
            'but got {}'.format(sampler))

    if not use_triplet:
        message = ('expected METRIC_LOSS_TYPE should be triplet, '
                   'but got {}'.format(metric_loss_type))
        if sampler != 'softmax':
            # These samplers call the triplet criterion, which would not exist.
            raise ValueError(message)
        # The softmax-only loss never touches the triplet criterion, so keep
        # the historical behaviour of warning and carrying on.
        print(message)

    if use_triplet:
        triplet = TripletLoss(cfg.SOLVER.MARGIN)    # triplet loss
    if label_smooth:
        xent = CrossEntropyLabelSmooth(num_classes=num_classes)    # new add by luo
        print("label smooth on, numclasses:", num_classes)

    if sampler == 'softmax':
        # NOTE: label smoothing is not applied in this mode; plain
        # cross-entropy is used regardless of IF_LABELSMOOTH.
        def loss_func(score, feat, target):
            return F.cross_entropy(score, target)
    elif sampler == 'triplet':
        def loss_func(score, feat, target):
            # Only the first element of the TripletLoss result is the loss term.
            return triplet(feat, target)[0]
    else:    # 'softmax_triplet'
        id_loss = xent if label_smooth else F.cross_entropy

        def loss_func(score, feat, target):
            return id_loss(score, target) + triplet(feat, target)[0]

    return loss_func
def make_loss_with_center(cfg, num_classes):    # modified by gu
    """Build a loss that includes a center-loss term, plus its criterion.

    The returned loss is the sum, in this order, of:

    1. a classification loss -- ``CrossEntropyLabelSmooth`` when
       ``cfg.MODEL.IF_LABELSMOOTH == 'on'``, otherwise ``F.cross_entropy``;
    2. ``triplet(feat, target)[0]`` -- only when
       ``cfg.MODEL.METRIC_LOSS_TYPE == 'triplet_center'``;
    3. ``cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target)``.

    Args:
        cfg: configuration object. Reads ``MODEL.NAME``,
            ``MODEL.METRIC_LOSS_TYPE``, ``MODEL.IF_LABELSMOOTH``,
            ``SOLVER.MARGIN`` and ``SOLVER.CENTER_LOSS_WEIGHT``.
        num_classes: forwarded to ``CenterLoss`` and, when label smoothing
            is on, to ``CrossEntropyLabelSmooth``.

    Returns:
        A tuple ``(loss_func, center_criterion)`` where ``loss_func`` is a
        callable ``loss_func(score, feat, target)`` and ``center_criterion``
        is the ``CenterLoss`` instance used inside it.

    Raises:
        ValueError: if ``cfg.MODEL.METRIC_LOSS_TYPE`` is neither ``'center'``
            nor ``'triplet_center'``. Previously this only printed a message
            and then failed on the unbound ``center_criterion`` at return.
    """
    metric_loss_type = cfg.MODEL.METRIC_LOSS_TYPE
    if metric_loss_type not in ('center', 'triplet_center'):
        raise ValueError(
            'expected METRIC_LOSS_TYPE with center should be center, triplet_center, '
            'but got {}'.format(metric_loss_type))

    # Feature size passed to CenterLoss: 512 for resnet18/resnet34, else 2048.
    feat_dim = 512 if cfg.MODEL.NAME in ('resnet18', 'resnet34') else 2048

    use_triplet = metric_loss_type == 'triplet_center'
    if use_triplet:
        triplet = TripletLoss(cfg.SOLVER.MARGIN)    # triplet loss
    # NOTE(review): use_gpu=True is hard-coded as in the original; confirm
    # against CenterLoss before running this on a CPU-only machine.
    center_criterion = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=True)    # center loss

    if cfg.MODEL.IF_LABELSMOOTH == 'on':
        id_loss = CrossEntropyLabelSmooth(num_classes=num_classes)    # new add by luo
        print("label smooth on, numclasses:", num_classes)
    else:
        id_loss = F.cross_entropy

    def loss_func(score, feat, target):
        loss = id_loss(score, target)
        if use_triplet:
            # Only the first element of the TripletLoss result is the loss term.
            loss = loss + triplet(feat, target)[0]
        # The weight is read from cfg on every call, as the original did.
        return loss + cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target)

    return loss_func, center_criterion