84 lines
3.0 KiB
Python
84 lines
3.0 KiB
Python
|
from __future__ import absolute_import
|
||
|
|
||
|
import torch
|
||
|
from torch import nn
|
||
|
from torch.autograd import Variable
|
||
|
|
||
|
class CrossEntropyLabelSmooth(nn.Module):
    """Cross entropy loss with label smoothing regularizer.

    Reference:
    Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
    Equation: y = (1 - epsilon) * y + epsilon / K.

    Args:
        num_classes (int): number of classes (K in the equation above).
        epsilon (float): smoothing weight; 0 gives plain cross entropy.
        use_gpu (bool): kept for backward compatibility and stored on the
            module, but no longer consulted: the smoothed targets are created
            on whatever device ``inputs`` lives on.
    """

    def __init__(self, num_classes, epsilon=0.1, use_gpu=True):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.use_gpu = use_gpu
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        """Compute the label-smoothed cross entropy.

        Args:
            inputs: prediction matrix (before softmax) with shape
                (batch_size, num_classes).
            targets: ground truth class indices with shape (batch_size).

        Returns:
            Scalar tensor: the smoothed cross entropy averaged over the batch.
        """
        log_probs = self.logsoftmax(inputs)

        # Build the one-hot matrix directly on the device/dtype of the
        # predictions instead of round-tripping through the CPU and then
        # unconditionally calling .cuda(), which fails on CPU-only runs.
        index = targets.detach().to(device=log_probs.device, dtype=torch.long)
        one_hot = torch.zeros_like(log_probs).scatter_(1, index.unsqueeze(1), 1)

        # y = (1 - epsilon) * y + epsilon / K
        smoothed = (1 - self.epsilon) * one_hot + self.epsilon / self.num_classes

        # Mean over the batch dimension, summed over classes.
        return (-smoothed * log_probs).mean(0).sum()
|
||
|
|
||
|
class TripletLoss(nn.Module):
    """Triplet loss with hard positive/negative mining.

    Reference:
    Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.

    Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py.

    Args:
        margin (float): margin for triplet.
    """

    def __init__(self, margin=0.3):
        super(TripletLoss, self).__init__()
        self.margin = margin
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)

    def forward(self, inputs, targets):
        """Compute the batch-hard triplet loss.

        Args:
            inputs: feature matrix with shape (batch_size, feat_dim).
            targets: ground truth labels with shape (batch_size).

        Returns:
            Scalar tensor: margin ranking loss between each anchor's hardest
            negative and hardest positive distance. An anchor that has no
            negative in the batch contributes zero to the loss.
        """
        # Pairwise Euclidean distance via ||a||^2 + ||b||^2 - 2 * a.b
        sq_norms = inputs.pow(2).sum(dim=1, keepdim=True)
        dist = sq_norms + sq_norms.t() - 2 * inputs.mm(inputs.t())
        dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability

        # same_id[i, j] is True when samples i and j share a label
        # (the diagonal is always True: every anchor is its own positive).
        same_id = targets.unsqueeze(0).eq(targets.unsqueeze(1))

        # For each anchor, find the hardest positive and negative. Entries
        # that must not be selected are pushed to -inf / +inf so that the
        # row-wise max / min ignores them.
        dist_ap = dist.masked_fill(same_id == 0, float('-inf')).max(dim=1)[0]
        dist_an = dist.masked_fill(same_id, float('inf')).min(dim=1)[0]

        # Compute ranking hinge loss: y = 1 asks for dist_an > dist_ap + margin.
        y = torch.ones_like(dist_an)
        return self.ranking_loss(dist_an, dist_ap, y)
|
||
|
|
||
|
if __name__ == '__main__':
    # Running this module as a script intentionally does nothing; the guard
    # body is a placeholder.
    pass
|