put losses in separate files (to ease extension)
parent
bd18a9489e
commit
366784361d
161
losses.py
161
losses.py
|
@ -1,161 +0,0 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
def DeepSupervision(criterion, xs, y):
|
||||
"""
|
||||
Args:
|
||||
- criterion: loss function
|
||||
- xs: tuple of inputs
|
||||
- y: ground truth
|
||||
"""
|
||||
loss = 0.
|
||||
for x in xs:
|
||||
loss += criterion(x, y)
|
||||
loss /= len(xs)
|
||||
return loss
|
||||
|
||||
|
||||
class CrossEntropyLabelSmooth(nn.Module):
|
||||
"""Cross entropy loss with label smoothing regularizer.
|
||||
|
||||
Reference:
|
||||
Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
|
||||
Equation: y = (1 - epsilon) * y + epsilon / K.
|
||||
|
||||
Args:
|
||||
- num_classes (int): number of classes.
|
||||
- epsilon (float): weight.
|
||||
"""
|
||||
def __init__(self, num_classes, epsilon=0.1, use_gpu=True):
|
||||
super(CrossEntropyLabelSmooth, self).__init__()
|
||||
self.num_classes = num_classes
|
||||
self.epsilon = epsilon
|
||||
self.use_gpu = use_gpu
|
||||
self.logsoftmax = nn.LogSoftmax(dim=1)
|
||||
|
||||
def forward(self, inputs, targets):
|
||||
"""
|
||||
Args:
|
||||
- inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
|
||||
- targets: ground truth labels with shape (num_classes)
|
||||
"""
|
||||
log_probs = self.logsoftmax(inputs)
|
||||
targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1)
|
||||
if self.use_gpu: targets = targets.cuda()
|
||||
targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
|
||||
loss = (- targets * log_probs).mean(0).sum()
|
||||
return loss
|
||||
|
||||
|
||||
class TripletLoss(nn.Module):
|
||||
"""Triplet loss with hard positive/negative mining.
|
||||
|
||||
Reference:
|
||||
Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.
|
||||
|
||||
Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py.
|
||||
|
||||
Args:
|
||||
- margin (float): margin for triplet.
|
||||
"""
|
||||
def __init__(self, margin=0.3):
|
||||
super(TripletLoss, self).__init__()
|
||||
self.margin = margin
|
||||
self.ranking_loss = nn.MarginRankingLoss(margin=margin)
|
||||
|
||||
def forward(self, inputs, targets):
|
||||
"""
|
||||
Args:
|
||||
- inputs: feature matrix with shape (batch_size, feat_dim)
|
||||
- targets: ground truth labels with shape (num_classes)
|
||||
"""
|
||||
n = inputs.size(0)
|
||||
|
||||
# Compute pairwise distance, replace by the official when merged
|
||||
dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
|
||||
dist = dist + dist.t()
|
||||
dist.addmm_(1, -2, inputs, inputs.t())
|
||||
dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
|
||||
|
||||
# For each anchor, find the hardest positive and negative
|
||||
mask = targets.expand(n, n).eq(targets.expand(n, n).t())
|
||||
dist_ap, dist_an = [], []
|
||||
for i in range(n):
|
||||
dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))
|
||||
dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0))
|
||||
dist_ap = torch.cat(dist_ap)
|
||||
dist_an = torch.cat(dist_an)
|
||||
|
||||
# Compute ranking hinge loss
|
||||
y = torch.ones_like(dist_an)
|
||||
loss = self.ranking_loss(dist_an, dist_ap, y)
|
||||
return loss
|
||||
|
||||
|
||||
class CenterLoss(nn.Module):
|
||||
"""Center loss.
|
||||
|
||||
Reference:
|
||||
Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
|
||||
|
||||
Args:
|
||||
- num_classes (int): number of classes.
|
||||
- feat_dim (int): feature dimension.
|
||||
"""
|
||||
def __init__(self, num_classes=10, feat_dim=2, use_gpu=True):
|
||||
super(CenterLoss, self).__init__()
|
||||
self.num_classes = num_classes
|
||||
self.feat_dim = feat_dim
|
||||
self.use_gpu = use_gpu
|
||||
|
||||
if self.use_gpu:
|
||||
self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
|
||||
else:
|
||||
self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))
|
||||
|
||||
def forward(self, x, labels):
|
||||
"""
|
||||
Args:
|
||||
- x: feature matrix with shape (batch_size, feat_dim).
|
||||
- labels: ground truth labels with shape (num_classes).
|
||||
"""
|
||||
batch_size = x.size(0)
|
||||
distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
|
||||
torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
|
||||
distmat.addmm_(1, -2, x, self.centers.t())
|
||||
|
||||
classes = torch.arange(self.num_classes).long()
|
||||
if self.use_gpu: classes = classes.cuda()
|
||||
labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
|
||||
mask = labels.eq(classes.expand(batch_size, self.num_classes))
|
||||
|
||||
dist = []
|
||||
for i in range(batch_size):
|
||||
value = distmat[i][mask[i]]
|
||||
value = value.clamp(min=1e-12, max=1e+12) # for numerical stability
|
||||
dist.append(value)
|
||||
dist = torch.cat(dist)
|
||||
loss = dist.mean()
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
class RingLoss(nn.Module):
|
||||
"""Ring loss.
|
||||
|
||||
Reference:
|
||||
Zheng et al. Ring loss: Convex Feature Normalization for Face Recognition. CVPR 2018.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(RingLoss, self).__init__()
|
||||
self.radius = nn.Parameter(torch.ones(1, dtype=torch.float))
|
||||
|
||||
def forward(self, x):
|
||||
loss = ((x.norm(p=2, dim=1) - self.radius)**2).mean()
|
||||
return loss
|
|
@ -0,0 +1,22 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from .cross_entropy_loss import CrossEntropyLabelSmooth
|
||||
from .hard_mine_triplet_loss import TripletLoss
|
||||
from .center_loss import CenterLoss
|
||||
from .ring_loss import RingLoss
|
||||
|
||||
|
||||
def DeepSupervision(criterion, xs, y):
|
||||
"""
|
||||
Args:
|
||||
- criterion: loss function
|
||||
- xs: tuple of inputs
|
||||
- y: ground truth
|
||||
"""
|
||||
loss = 0.
|
||||
for x in xs:
|
||||
loss += criterion(x, y)
|
||||
loss /= len(xs)
|
||||
return loss
|
|
@ -0,0 +1,53 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class CenterLoss(nn.Module):
|
||||
"""Center loss.
|
||||
|
||||
Reference:
|
||||
Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
|
||||
|
||||
Args:
|
||||
- num_classes (int): number of classes.
|
||||
- feat_dim (int): feature dimension.
|
||||
"""
|
||||
def __init__(self, num_classes=10, feat_dim=2, use_gpu=True):
|
||||
super(CenterLoss, self).__init__()
|
||||
self.num_classes = num_classes
|
||||
self.feat_dim = feat_dim
|
||||
self.use_gpu = use_gpu
|
||||
|
||||
if self.use_gpu:
|
||||
self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
|
||||
else:
|
||||
self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))
|
||||
|
||||
def forward(self, x, labels):
|
||||
"""
|
||||
Args:
|
||||
- x: feature matrix with shape (batch_size, feat_dim).
|
||||
- labels: ground truth labels with shape (num_classes).
|
||||
"""
|
||||
batch_size = x.size(0)
|
||||
distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
|
||||
torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
|
||||
distmat.addmm_(1, -2, x, self.centers.t())
|
||||
|
||||
classes = torch.arange(self.num_classes).long()
|
||||
if self.use_gpu: classes = classes.cuda()
|
||||
labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
|
||||
mask = labels.eq(classes.expand(batch_size, self.num_classes))
|
||||
|
||||
dist = []
|
||||
for i in range(batch_size):
|
||||
value = distmat[i][mask[i]]
|
||||
value = value.clamp(min=1e-12, max=1e+12) # for numerical stability
|
||||
dist.append(value)
|
||||
dist = torch.cat(dist)
|
||||
loss = dist.mean()
|
||||
|
||||
return loss
|
|
@ -0,0 +1,37 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class CrossEntropyLabelSmooth(nn.Module):
|
||||
"""Cross entropy loss with label smoothing regularizer.
|
||||
|
||||
Reference:
|
||||
Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
|
||||
Equation: y = (1 - epsilon) * y + epsilon / K.
|
||||
|
||||
Args:
|
||||
- num_classes (int): number of classes.
|
||||
- epsilon (float): weight.
|
||||
"""
|
||||
def __init__(self, num_classes, epsilon=0.1, use_gpu=True):
|
||||
super(CrossEntropyLabelSmooth, self).__init__()
|
||||
self.num_classes = num_classes
|
||||
self.epsilon = epsilon
|
||||
self.use_gpu = use_gpu
|
||||
self.logsoftmax = nn.LogSoftmax(dim=1)
|
||||
|
||||
def forward(self, inputs, targets):
|
||||
"""
|
||||
Args:
|
||||
- inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
|
||||
- targets: ground truth labels with shape (num_classes)
|
||||
"""
|
||||
log_probs = self.logsoftmax(inputs)
|
||||
targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1)
|
||||
if self.use_gpu: targets = targets.cuda()
|
||||
targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
|
||||
loss = (- targets * log_probs).mean(0).sum()
|
||||
return loss
|
|
@ -0,0 +1,49 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class TripletLoss(nn.Module):
|
||||
"""Triplet loss with hard positive/negative mining.
|
||||
|
||||
Reference:
|
||||
Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.
|
||||
Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py.
|
||||
|
||||
Args:
|
||||
- margin (float): margin for triplet.
|
||||
"""
|
||||
def __init__(self, margin=0.3):
|
||||
super(TripletLoss, self).__init__()
|
||||
self.margin = margin
|
||||
self.ranking_loss = nn.MarginRankingLoss(margin=margin)
|
||||
|
||||
def forward(self, inputs, targets):
|
||||
"""
|
||||
Args:
|
||||
- inputs: feature matrix with shape (batch_size, feat_dim)
|
||||
- targets: ground truth labels with shape (num_classes)
|
||||
"""
|
||||
n = inputs.size(0)
|
||||
|
||||
# Compute pairwise distance, replace by the official when merged
|
||||
dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
|
||||
dist = dist + dist.t()
|
||||
dist.addmm_(1, -2, inputs, inputs.t())
|
||||
dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
|
||||
|
||||
# For each anchor, find the hardest positive and negative
|
||||
mask = targets.expand(n, n).eq(targets.expand(n, n).t())
|
||||
dist_ap, dist_an = [], []
|
||||
for i in range(n):
|
||||
dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))
|
||||
dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0))
|
||||
dist_ap = torch.cat(dist_ap)
|
||||
dist_an = torch.cat(dist_an)
|
||||
|
||||
# Compute ranking hinge loss
|
||||
y = torch.ones_like(dist_an)
|
||||
loss = self.ranking_loss(dist_an, dist_ap, y)
|
||||
return loss
|
|
@ -0,0 +1,20 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class RingLoss(nn.Module):
|
||||
"""Ring loss.
|
||||
|
||||
Reference:
|
||||
Zheng et al. Ring loss: Convex Feature Normalization for Face Recognition. CVPR 2018.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(RingLoss, self).__init__()
|
||||
self.radius = nn.Parameter(torch.ones(1, dtype=torch.float))
|
||||
|
||||
def forward(self, x):
|
||||
loss = ((x.norm(p=2, dim=1) - self.radius)**2).mean()
|
||||
return loss
|
Loading…
Reference in New Issue