remove code that is irrelevant to the paper
parent
1edc11736d
commit
d426692b95
|
@ -7,22 +7,15 @@
|
|||
import torch.nn.functional as F
|
||||
|
||||
from .triplet_loss import TripletLoss, CrossEntropyLabelSmooth
|
||||
from .cluster_loss import ClusterLoss
|
||||
from .center_loss import CenterLoss
|
||||
from .range_loss import RangeLoss
|
||||
|
||||
|
||||
def make_loss(cfg, num_classes): # modified by gu
|
||||
sampler = cfg.DATALOADER.SAMPLER
|
||||
if cfg.MODEL.METRIC_LOSS_TYPE == 'triplet':
|
||||
triplet = TripletLoss(cfg.SOLVER.MARGIN) # triplet loss
|
||||
elif cfg.MODEL.METRIC_LOSS_TYPE == 'cluster':
|
||||
cluster = ClusterLoss(cfg.SOLVER.CLUSTER_MARGIN, True, True, cfg.SOLVER.IMS_PER_BATCH // cfg.DATALOADER.NUM_INSTANCE, cfg.DATALOADER.NUM_INSTANCE)
|
||||
elif cfg.MODEL.METRIC_LOSS_TYPE == 'triplet_cluster':
|
||||
triplet = TripletLoss(cfg.SOLVER.MARGIN) # triplet loss
|
||||
cluster = ClusterLoss(cfg.SOLVER.CLUSTER_MARGIN, True, True, cfg.SOLVER.IMS_PER_BATCH // cfg.DATALOADER.NUM_INSTANCE, cfg.DATALOADER.NUM_INSTANCE)
|
||||
else:
|
||||
print('expected METRIC_LOSS_TYPE should be triplet, cluster, triplet_cluster'
|
||||
print('expected METRIC_LOSS_TYPE should be triplet'
|
||||
'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE))
|
||||
|
||||
if cfg.MODEL.IF_LABELSMOOTH == 'on':
|
||||
|
@ -39,23 +32,11 @@ def make_loss(cfg, num_classes): # modified by gu
|
|||
def loss_func(score, feat, target):
|
||||
if cfg.MODEL.METRIC_LOSS_TYPE == 'triplet':
|
||||
if cfg.MODEL.IF_LABELSMOOTH == 'on':
|
||||
return xent(score, target) + triplet(feat, target)[0] # new add by luo, open label smooth
|
||||
return xent(score, target) + triplet(feat, target)[0]
|
||||
else:
|
||||
return F.cross_entropy(score, target) + triplet(feat, target)[0] # new add by luo, no label smooth
|
||||
|
||||
elif cfg.MODEL.METRIC_LOSS_TYPE == 'cluster':
|
||||
if cfg.MODEL.IF_LABELSMOOTH == 'on':
|
||||
return xent(score, target) + cluster(feat, target)[0] # new add by luo, open label smooth
|
||||
else:
|
||||
return F.cross_entropy(score, target) + cluster(feat, target)[0] # new add by luo, no label smooth
|
||||
|
||||
elif cfg.MODEL.METRIC_LOSS_TYPE == 'triplet_cluster':
|
||||
if cfg.MODEL.IF_LABELSMOOTH == 'on':
|
||||
return xent(score, target) + triplet(feat, target)[0] + cluster(feat, target)[0] # new add by luo, open label smooth
|
||||
else:
|
||||
return F.cross_entropy(score, target) + triplet(feat, target)[0] + cluster(feat, target)[0] # new add by luo, no label smooth
|
||||
return F.cross_entropy(score, target) + triplet(feat, target)[0]
|
||||
else:
|
||||
print('expected METRIC_LOSS_TYPE should be triplet, cluster, triplet_cluster,'
|
||||
print('expected METRIC_LOSS_TYPE should be triplet'
|
||||
'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE))
|
||||
else:
|
||||
print('expected sampler should be softmax, triplet or softmax_triplet, '
|
||||
|
@ -72,27 +53,12 @@ def make_loss_with_center(cfg, num_classes): # modified by gu
|
|||
if cfg.MODEL.METRIC_LOSS_TYPE == 'center':
|
||||
center_criterion = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=True) # center loss
|
||||
|
||||
elif cfg.MODEL.METRIC_LOSS_TYPE == 'range_center':
|
||||
center_criterion = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=True) # center_range loss
|
||||
range_criterion = RangeLoss(k=cfg.SOLVER.RANGE_K, margin=cfg.SOLVER.RANGE_MARGIN, alpha=cfg.SOLVER.RANGE_ALPHA,
|
||||
beta=cfg.SOLVER.RANGE_BETA, ordered=True, use_gpu=True,
|
||||
ids_per_batch=cfg.SOLVER.IMS_PER_BATCH // cfg.DATALOADER.NUM_INSTANCE,
|
||||
imgs_per_id=cfg.DATALOADER.NUM_INSTANCE)
|
||||
|
||||
elif cfg.MODEL.METRIC_LOSS_TYPE == 'triplet_center':
|
||||
triplet = TripletLoss(cfg.SOLVER.MARGIN) # triplet loss
|
||||
center_criterion = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=True) # center loss
|
||||
|
||||
elif cfg.MODEL.METRIC_LOSS_TYPE == 'triplet_range_center':
|
||||
triplet = TripletLoss(cfg.SOLVER.MARGIN) # triplet loss
|
||||
center_criterion = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=True) # center_range loss
|
||||
range_criterion = RangeLoss(k=cfg.SOLVER.RANGE_K, margin=cfg.SOLVER.RANGE_MARGIN, alpha=cfg.SOLVER.RANGE_ALPHA,
|
||||
beta=cfg.SOLVER.RANGE_BETA, ordered=True, use_gpu=True,
|
||||
ids_per_batch=cfg.SOLVER.IMS_PER_BATCH // cfg.DATALOADER.NUM_INSTANCE,
|
||||
imgs_per_id=cfg.DATALOADER.NUM_INSTANCE)
|
||||
else:
|
||||
print('expected METRIC_LOSS_TYPE with center should be center, '
|
||||
'range_center,triplet_center, triplet_range_center '
|
||||
print('expected METRIC_LOSS_TYPE with center should be center, triplet_center'
|
||||
'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE))
|
||||
|
||||
if cfg.MODEL.IF_LABELSMOOTH == 'on':
|
||||
|
@ -103,45 +69,22 @@ def make_loss_with_center(cfg, num_classes): # modified by gu
|
|||
if cfg.MODEL.METRIC_LOSS_TYPE == 'center':
|
||||
if cfg.MODEL.IF_LABELSMOOTH == 'on':
|
||||
return xent(score, target) + \
|
||||
cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target) # new add by luo, open label smooth
|
||||
cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target)
|
||||
else:
|
||||
return F.cross_entropy(score, target) + \
|
||||
cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target) # new add by luo, no label smooth
|
||||
|
||||
elif cfg.MODEL.METRIC_LOSS_TYPE == 'range_center':
|
||||
if cfg.MODEL.IF_LABELSMOOTH == 'on':
|
||||
return xent(score, target) + \
|
||||
cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target) + \
|
||||
cfg.SOLVER.RANGE_LOSS_WEIGHT * range_criterion(feat, target)[0] # new add by luo, open label smooth
|
||||
else:
|
||||
return F.cross_entropy(score, target) + \
|
||||
cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target) + \
|
||||
cfg.SOLVER.RANGE_LOSS_WEIGHT * range_criterion(feat, target)[0] # new add by luo, no label smooth
|
||||
cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target)
|
||||
|
||||
elif cfg.MODEL.METRIC_LOSS_TYPE == 'triplet_center':
|
||||
if cfg.MODEL.IF_LABELSMOOTH == 'on':
|
||||
return xent(score, target) + \
|
||||
triplet(feat, target)[0] + \
|
||||
cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target) # new add by luo, open label smooth
|
||||
cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target)
|
||||
else:
|
||||
return F.cross_entropy(score, target) + \
|
||||
triplet(feat, target)[0] + \
|
||||
cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target) # new add by luo, no label smooth
|
||||
|
||||
elif cfg.MODEL.METRIC_LOSS_TYPE == 'triplet_range_center':
|
||||
if cfg.MODEL.IF_LABELSMOOTH == 'on':
|
||||
return xent(score, target) + \
|
||||
triplet(feat, target)[0] + \
|
||||
cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target) + \
|
||||
cfg.SOLVER.RANGE_LOSS_WEIGHT * range_criterion(feat, target)[0] # new add by luo, open label smooth
|
||||
else:
|
||||
return F.cross_entropy(score, target) + \
|
||||
triplet(feat, target)[0] + \
|
||||
cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target) + \
|
||||
cfg.SOLVER.RANGE_LOSS_WEIGHT * range_criterion(feat, target)[0] # new add by luo, no label smooth
|
||||
cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target)
|
||||
|
||||
else:
|
||||
print('expected METRIC_LOSS_TYPE with center should be center,'
|
||||
' range_center, triplet_center, triplet_range_center '
|
||||
print('expected METRIC_LOSS_TYPE with center should be center, triplet_center'
|
||||
'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE))
|
||||
return loss_func, center_criterion
|
|
@ -1,269 +0,0 @@
|
|||
from __future__ import absolute_import
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class ClusterLoss(nn.Module):
|
||||
def __init__(self, margin=10, use_gpu=True, ordered=True, ids_per_batch=16, imgs_per_id=4):
|
||||
super(ClusterLoss, self).__init__()
|
||||
self.use_gpu = use_gpu
|
||||
self.margin = margin
|
||||
self.ordered = ordered
|
||||
self.ids_per_batch = ids_per_batch
|
||||
self.imgs_per_id = imgs_per_id
|
||||
|
||||
def _euclidean_dist(self, x, y):
|
||||
"""
|
||||
Args:
|
||||
x: pytorch Variable, with shape [m, d]
|
||||
y: pytorch Variable, with shape [n, d]
|
||||
Returns:
|
||||
dist: pytorch Variable, with shape [m, n]
|
||||
"""
|
||||
m, n = x.size(0), y.size(0)
|
||||
xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
|
||||
yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
|
||||
dist = xx + yy
|
||||
dist.addmm_(1, -2, x, y.t())
|
||||
dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
|
||||
return dist
|
||||
|
||||
def _cluster_loss(self, features, targets, ordered=True, ids_per_batch=16, imgs_per_id=4):
|
||||
"""
|
||||
Args:
|
||||
features: prediction matrix (before softmax) with shape (batch_size, feature_dim)
|
||||
targets: ground truth labels with shape (batch_size)
|
||||
ordered: bool type. If the train data per batch are formed as p*k, where p is the num of ids per batch and k is the num of images per id.
|
||||
ids_per_batch: num of different ids per batch
|
||||
imgs_per_id: num of images per id
|
||||
Return:
|
||||
cluster_loss
|
||||
"""
|
||||
if self.use_gpu:
|
||||
if ordered:
|
||||
if targets.size(0) == ids_per_batch * imgs_per_id:
|
||||
unique_labels = targets[0:targets.size(0):imgs_per_id]
|
||||
else:
|
||||
unique_labels = targets.cpu().unique().cuda()
|
||||
else:
|
||||
unique_labels = targets.cpu().unique().cuda()
|
||||
else:
|
||||
if ordered:
|
||||
if targets.size(0) == ids_per_batch * imgs_per_id:
|
||||
unique_labels = targets[0:targets.size(0):imgs_per_id]
|
||||
else:
|
||||
unique_labels = targets.unique()
|
||||
else:
|
||||
unique_labels = targets.unique()
|
||||
|
||||
inter_min_distance = torch.zeros(unique_labels.size(0))
|
||||
intra_max_distance = torch.zeros(unique_labels.size(0))
|
||||
center_features = torch.zeros(unique_labels.size(0), features.size(1))
|
||||
|
||||
if self.use_gpu:
|
||||
inter_min_distance = inter_min_distance.cuda()
|
||||
intra_max_distance = intra_max_distance.cuda()
|
||||
center_features = center_features.cuda()
|
||||
|
||||
index = torch.range(0, unique_labels.size(0) - 1)
|
||||
for i in range(unique_labels.size(0)):
|
||||
label = unique_labels[i]
|
||||
same_class_features = features[targets == label]
|
||||
center_features[i] = same_class_features.mean(dim=0)
|
||||
intra_class_distance = self._euclidean_dist(center_features[index==i], same_class_features)
|
||||
# print('intra_class_distance', intra_class_distance)
|
||||
intra_max_distance[i] = intra_class_distance.max()
|
||||
# print('intra_max_distance:', intra_max_distance)
|
||||
|
||||
for i in range(unique_labels.size(0)):
|
||||
inter_class_distance = self._euclidean_dist(center_features[index==i], center_features[index != i])
|
||||
# print('inter_class_distance', inter_class_distance)
|
||||
inter_min_distance[i] = inter_class_distance.min()
|
||||
# print('inter_min_distance:', inter_min_distance)
|
||||
cluster_loss = torch.mean(torch.relu(intra_max_distance - inter_min_distance + self.margin))
|
||||
return cluster_loss, intra_max_distance, inter_min_distance
|
||||
|
||||
def forward(self, features, targets):
|
||||
"""
|
||||
Args:
|
||||
features: prediction matrix (before softmax) with shape (batch_size, feature_dim)
|
||||
targets: ground truth labels with shape (batch_size)
|
||||
ordered: bool type. If the train data per batch are formed as p*k, where p is the num of ids per batch and k is the num of images per id.
|
||||
ids_per_batch: num of different ids per batch
|
||||
imgs_per_id: num of images per id
|
||||
Return:
|
||||
cluster_loss
|
||||
"""
|
||||
assert features.size(0) == targets.size(0), "features.size(0) is not equal to targets.size(0)"
|
||||
cluster_loss, cluster_dist_ap, cluster_dist_an = self._cluster_loss(features, targets, self.ordered, self.ids_per_batch, self.imgs_per_id)
|
||||
return cluster_loss, cluster_dist_ap, cluster_dist_an
|
||||
|
||||
|
||||
class ClusterLoss_local(nn.Module):
|
||||
def __init__(self, margin=10, use_gpu=True, ordered=True, ids_per_batch=32, imgs_per_id=4):
|
||||
super(ClusterLoss_local, self).__init__()
|
||||
self.use_gpu = use_gpu
|
||||
self.margin = margin
|
||||
self.ordered = ordered
|
||||
self.ids_per_batch = ids_per_batch
|
||||
self.imgs_per_id = imgs_per_id
|
||||
|
||||
def _euclidean_dist(self, x, y):
|
||||
"""
|
||||
Args:
|
||||
x: pytorch Variable, with shape [m, d]
|
||||
y: pytorch Variable, with shape [n, d]
|
||||
Returns:
|
||||
dist: pytorch Variable, with shape [m, n]
|
||||
"""
|
||||
m, n = x.size(0), y.size(0)
|
||||
xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
|
||||
yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
|
||||
dist = xx + yy
|
||||
dist.addmm_(1, -2, x, y.t())
|
||||
dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
|
||||
return dist
|
||||
|
||||
def _shortest_dist(self, dist_mat):
|
||||
"""Parallel version.
|
||||
Args:
|
||||
dist_mat: pytorch Variable, available shape:
|
||||
1) [m, n]
|
||||
2) [m, n, N], N is batch size
|
||||
3) [m, n, *], * can be arbitrary additional dimensions
|
||||
Returns:
|
||||
dist: three cases corresponding to `dist_mat`:
|
||||
1) scalar
|
||||
2) pytorch Variable, with shape [N]
|
||||
3) pytorch Variable, with shape [*]
|
||||
"""
|
||||
m, n = dist_mat.size()[:2]
|
||||
# Just offering some reference for accessing intermediate distance.
|
||||
dist = [[0 for _ in range(n)] for _ in range(m)]
|
||||
for i in range(m):
|
||||
for j in range(n):
|
||||
if (i == 0) and (j == 0):
|
||||
dist[i][j] = dist_mat[i, j]
|
||||
elif (i == 0) and (j > 0):
|
||||
dist[i][j] = dist[i][j - 1] + dist_mat[i, j]
|
||||
elif (i > 0) and (j == 0):
|
||||
dist[i][j] = dist[i - 1][j] + dist_mat[i, j]
|
||||
else:
|
||||
dist[i][j] = torch.min(dist[i - 1][j], dist[i][j - 1]) + dist_mat[i, j]
|
||||
dist = dist[-1][-1]
|
||||
return dist
|
||||
|
||||
def _local_dist(self, x, y):
|
||||
"""
|
||||
Args:
|
||||
x: pytorch Variable, with shape [M, m, d]
|
||||
y: pytorch Variable, with shape [N, n, d]
|
||||
Returns:
|
||||
dist: pytorch Variable, with shape [M, N]
|
||||
"""
|
||||
M, m, d = x.size()
|
||||
N, n, d = y.size()
|
||||
x = x.contiguous().view(M * m, d)
|
||||
y = y.contiguous().view(N * n, d)
|
||||
# shape [M * m, N * n]
|
||||
dist_mat = self._euclidean_dist(x, y)
|
||||
dist_mat = (torch.exp(dist_mat) - 1.) / (torch.exp(dist_mat) + 1.)
|
||||
# shape [M * m, N * n] -> [M, m, N, n] -> [m, n, M, N]
|
||||
dist_mat = dist_mat.contiguous().view(M, m, N, n).permute(1, 3, 0, 2)
|
||||
# shape [M, N]
|
||||
dist_mat = self._shortest_dist(dist_mat)
|
||||
return dist_mat
|
||||
|
||||
def _cluster_loss(self, features, targets,ordered=True, ids_per_batch=32, imgs_per_id=4):
|
||||
"""
|
||||
Args:
|
||||
features: prediction matrix (before softmax) with shape (batch_size, H, feature_dim)
|
||||
targets: ground truth labels with shape (batch_size)
|
||||
ordered: bool type. If the train data per batch are formed as p*k, where p is the num of ids per batch and k is the num of images per id.
|
||||
ids_per_batch: num of different ids per batch
|
||||
imgs_per_id: num of images per id
|
||||
Return:
|
||||
cluster_loss
|
||||
"""
|
||||
if self.use_gpu:
|
||||
if ordered:
|
||||
if targets.size(0) == ids_per_batch * imgs_per_id:
|
||||
unique_labels = targets[0:targets.size(0):imgs_per_id]
|
||||
else:
|
||||
unique_labels = targets.cpu().unique().cuda()
|
||||
else:
|
||||
unique_labels = targets.cpu().unique().cuda()
|
||||
else:
|
||||
if ordered:
|
||||
if targets.size(0) == ids_per_batch * imgs_per_id:
|
||||
unique_labels = targets[0:targets.size(0):imgs_per_id]
|
||||
else:
|
||||
unique_labels = targets.unique()
|
||||
else:
|
||||
unique_labels = targets.unique()
|
||||
|
||||
inter_min_distance = torch.zeros(unique_labels.size(0))
|
||||
intra_max_distance = torch.zeros(unique_labels.size(0))
|
||||
center_features = torch.zeros(unique_labels.size(0), features.size(1), features.size(2))
|
||||
|
||||
if self.use_gpu:
|
||||
inter_min_distance = inter_min_distance.cuda()
|
||||
intra_max_distance = intra_max_distance.cuda()
|
||||
center_features = center_features.cuda()
|
||||
|
||||
index = torch.range(0, unique_labels.size(0) - 1)
|
||||
for i in range(unique_labels.size(0)):
|
||||
label = unique_labels[i]
|
||||
same_class_features = features[targets == label]
|
||||
center_features[i] = same_class_features.mean(dim=0)
|
||||
intra_class_distance = self._local_dist(center_features[index==i], same_class_features)
|
||||
# print('intra_class_distance', intra_class_distance)
|
||||
intra_max_distance[i] = intra_class_distance.max()
|
||||
# print('intra_max_distance:', intra_max_distance)
|
||||
|
||||
for i in range(unique_labels.size(0)):
|
||||
inter_class_distance = self._local_dist(center_features[index==i], center_features[index != i])
|
||||
# print('inter_class_distance', inter_class_distance)
|
||||
inter_min_distance[i] = inter_class_distance.min()
|
||||
# print('inter_min_distance:', inter_min_distance)
|
||||
|
||||
cluster_loss = torch.mean(torch.relu(intra_max_distance - inter_min_distance + self.margin))
|
||||
return cluster_loss, intra_max_distance, inter_min_distance
|
||||
|
||||
def forward(self, features, targets):
|
||||
"""
|
||||
Args:
|
||||
features: prediction matrix (before softmax) with shape (batch_size, H, feature_dim)
|
||||
targets: ground truth labels with shape (batch_size)
|
||||
ordered: bool type. If the train data per batch are formed as p*k, where p is the num of ids per batch and k is the num of images per id.
|
||||
ids_per_batch: num of different ids per batch
|
||||
imgs_per_id: num of images per id
|
||||
Return:
|
||||
cluster_loss
|
||||
"""
|
||||
assert features.size(0) == targets.size(0), "features.size(0) is not equal to targets.size(0)"
|
||||
cluster_loss, cluster_dist_ap, cluster_dist_an = self._cluster_loss(features, targets, self.ordered, self.ids_per_batch, self.imgs_per_id)
|
||||
return cluster_loss, cluster_dist_ap, cluster_dist_an
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
use_gpu = True
|
||||
cluster_loss = ClusterLoss(use_gpu=use_gpu, ids_per_batch=4, imgs_per_id=4)
|
||||
features = torch.rand(16, 2048)
|
||||
targets = torch.Tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3])
|
||||
if use_gpu:
|
||||
features = torch.rand(16, 2048).cuda()
|
||||
targets = torch.Tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]).cuda()
|
||||
loss = cluster_loss(features, targets)
|
||||
print(loss)
|
||||
|
||||
cluster_loss_local = ClusterLoss_local(use_gpu=use_gpu, ids_per_batch=4, imgs_per_id=4)
|
||||
features = torch.rand(16, 8, 2048)
|
||||
targets = torch.Tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3])
|
||||
if use_gpu:
|
||||
features = torch.rand(16, 8, 2048).cuda()
|
||||
targets = torch.Tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]).cuda()
|
||||
loss = cluster_loss_local(features, targets)
|
||||
print(loss)
|
|
@ -1,232 +0,0 @@
|
|||
from __future__ import absolute_import
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class RangeLoss(nn.Module):
|
||||
"""
|
||||
Range_loss = alpha * intra_class_loss + beta * inter_class_loss
|
||||
intra_class_loss is the harmonic mean value of the top_k largest distances beturn intra_class_pairs
|
||||
inter_class_loss is the shortest distance between different class centers
|
||||
"""
|
||||
def __init__(self, k=2, margin=0.1, alpha=0.5, beta=0.5, use_gpu=True, ordered=True, ids_per_batch=32, imgs_per_id=4):
|
||||
super(RangeLoss, self).__init__()
|
||||
self.use_gpu = use_gpu
|
||||
self.margin = margin
|
||||
self.k = k
|
||||
self.alpha = alpha
|
||||
self.beta = beta
|
||||
self.ordered = ordered
|
||||
self.ids_per_batch = ids_per_batch
|
||||
self.imgs_per_id = imgs_per_id
|
||||
|
||||
def _pairwise_distance(self, features):
|
||||
"""
|
||||
Args:
|
||||
features: prediction matrix (before softmax) with shape (batch_size, feature_dim)
|
||||
Return:
|
||||
pairwise distance matrix with shape(batch_size, batch_size)
|
||||
"""
|
||||
n = features.size(0)
|
||||
dist = torch.pow(features, 2).sum(dim=1, keepdim=True).expand(n, n)
|
||||
dist = dist + dist.t()
|
||||
dist.addmm_(1, -2, features, features.t())
|
||||
dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
|
||||
return dist
|
||||
|
||||
def _compute_top_k(self, features):
|
||||
"""
|
||||
Args:
|
||||
features: prediction matrix (before softmax) with shape (batch_size, feature_dim)
|
||||
Return:
|
||||
top_k largest distances
|
||||
"""
|
||||
# reading the codes below can help understand better
|
||||
'''
|
||||
dist_array_2 = self._pairwise_distance(features)
|
||||
n = features.size(0)
|
||||
mask = torch.zeros(n, n)
|
||||
if self.use_gpu: mask=mask.cuda()
|
||||
for i in range(0, n):
|
||||
for j in range(i+1, n):
|
||||
mask[i, j] += 1
|
||||
dist_array_2 = dist_array_2 * mask
|
||||
dist_array_2 = dist_array_2.view(1, -1)
|
||||
dist_array_2 = dist_array_2[torch.gt(dist_array_2, 0)]
|
||||
top_k_2 = dist_array_2.sort()[0][-self.k:]
|
||||
print(top_k_2)
|
||||
'''
|
||||
dist_array = self._pairwise_distance(features)
|
||||
dist_array = dist_array.view(1, -1)
|
||||
top_k = dist_array.sort()[0][0, -self.k * 2::2] # Because there are 2 same value of same feature pair in the dist_array
|
||||
# print('top k intra class dist:', top_k)
|
||||
return top_k
|
||||
|
||||
def _compute_min_dist(self, center_features):
|
||||
"""
|
||||
Args:
|
||||
center_features: center matrix (before softmax) with shape (center_number, center_dim)
|
||||
Return:
|
||||
minimum center distance
|
||||
"""
|
||||
'''
|
||||
# reading codes below can help understand better
|
||||
dist_array = self._pairwise_distance(center_features)
|
||||
n = center_features.size(0)
|
||||
mask = torch.zeros(n, n)
|
||||
if self.use_gpu: mask=mask.cuda()
|
||||
for i in range(0, n):
|
||||
for j in range(i + 1, n):
|
||||
mask[i, j] += 1
|
||||
dist_array *= mask
|
||||
dist_array = dist_array.view(1, -1)
|
||||
dist_array = dist_array[torch.gt(dist_array, 0)]
|
||||
min_inter_class_dist = dist_array.min()
|
||||
print(min_inter_class_dist)
|
||||
'''
|
||||
n = center_features.size(0)
|
||||
dist_array2 = self._pairwise_distance(center_features)
|
||||
min_inter_class_dist2 = dist_array2.view(1, -1).sort()[0][0][n] # exclude self compare, the first one is the min_inter_class_dist
|
||||
return min_inter_class_dist2
|
||||
|
||||
def _calculate_centers(self, features, targets, ordered, ids_per_batch, imgs_per_id):
|
||||
"""
|
||||
Args:
|
||||
features: prediction matrix (before softmax) with shape (batch_size, feature_dim)
|
||||
targets: ground truth labels with shape (batch_size)
|
||||
ordered: bool type. If the train data per batch are formed as p*k, where p is the num of ids per batch and k is the num of images per id.
|
||||
ids_per_batch: num of different ids per batch
|
||||
imgs_per_id: num of images per id
|
||||
Return:
|
||||
center_features: center matrix (before softmax) with shape (center_number, center_dim)
|
||||
"""
|
||||
if self.use_gpu:
|
||||
if ordered:
|
||||
if targets.size(0) == ids_per_batch * imgs_per_id:
|
||||
unique_labels = targets[0:targets.size(0):imgs_per_id]
|
||||
else:
|
||||
unique_labels = targets.cpu().unique().cuda()
|
||||
else:
|
||||
unique_labels = targets.cpu().unique().cuda()
|
||||
else:
|
||||
if ordered:
|
||||
if targets.size(0) == ids_per_batch * imgs_per_id:
|
||||
unique_labels = targets[0:targets.size(0):imgs_per_id]
|
||||
else:
|
||||
unique_labels = targets.unique()
|
||||
else:
|
||||
unique_labels = targets.unique()
|
||||
|
||||
center_features = torch.zeros(unique_labels.size(0), features.size(1))
|
||||
if self.use_gpu:
|
||||
center_features = center_features.cuda()
|
||||
|
||||
for i in range(unique_labels.size(0)):
|
||||
label = unique_labels[i]
|
||||
same_class_features = features[targets == label]
|
||||
center_features[i] = same_class_features.mean(dim=0)
|
||||
return center_features
|
||||
|
||||
def _inter_class_loss(self, features, targets, ordered, ids_per_batch, imgs_per_id):
|
||||
"""
|
||||
Args:
|
||||
features: prediction matrix (before softmax) with shape (batch_size, feature_dim)
|
||||
targets: ground truth labels with shape (batch_size)
|
||||
margin: inter class ringe loss margin
|
||||
ordered: bool type. If the train data per batch are formed as p*k, where p is the num of ids per batch and k is the num of images per id.
|
||||
ids_per_batch: num of different ids per batch
|
||||
imgs_per_id: num of images per id
|
||||
Return:
|
||||
inter_class_loss
|
||||
"""
|
||||
center_features = self._calculate_centers(features, targets, ordered, ids_per_batch, imgs_per_id)
|
||||
min_inter_class_center_distance = self._compute_min_dist(center_features)
|
||||
# print('min_inter_class_center_dist:', min_inter_class_center_distance)
|
||||
return torch.relu(self.margin - min_inter_class_center_distance)
|
||||
|
||||
def _intra_class_loss(self, features, targets, ordered, ids_per_batch, imgs_per_id):
|
||||
"""
|
||||
Args:
|
||||
features: prediction matrix (before softmax) with shape (batch_size, feature_dim)
|
||||
targets: ground truth labels with shape (batch_size)
|
||||
ordered: bool type. If the train data per batch are formed as p*k, where p is the num of ids per batch and k is the num of images per id.
|
||||
ids_per_batch: num of different ids per batch
|
||||
imgs_per_id: num of images per id
|
||||
Return:
|
||||
intra_class_loss
|
||||
"""
|
||||
if self.use_gpu:
|
||||
if ordered:
|
||||
if targets.size(0) == ids_per_batch * imgs_per_id:
|
||||
unique_labels = targets[0:targets.size(0):imgs_per_id]
|
||||
else:
|
||||
unique_labels = targets.cpu().unique().cuda()
|
||||
else:
|
||||
unique_labels = targets.cpu().unique().cuda()
|
||||
else:
|
||||
if ordered:
|
||||
if targets.size(0) == ids_per_batch * imgs_per_id:
|
||||
unique_labels = targets[0:targets.size(0):imgs_per_id]
|
||||
else:
|
||||
unique_labels = targets.unique()
|
||||
else:
|
||||
unique_labels = targets.unique()
|
||||
|
||||
intra_distance = torch.zeros(unique_labels.size(0))
|
||||
if self.use_gpu:
|
||||
intra_distance = intra_distance.cuda()
|
||||
|
||||
for i in range(unique_labels.size(0)):
|
||||
label = unique_labels[i]
|
||||
same_class_distances = 1.0 / self._compute_top_k(features[targets == label])
|
||||
intra_distance[i] = self.k / torch.sum(same_class_distances)
|
||||
# print('intra_distace:', intra_distance)
|
||||
return torch.sum(intra_distance)
|
||||
|
||||
def _range_loss(self, features, targets, ordered, ids_per_batch, imgs_per_id):
|
||||
"""
|
||||
Args:
|
||||
features: prediction matrix (before softmax) with shape (batch_size, feature_dim)
|
||||
targets: ground truth labels with shape (batch_size)
|
||||
ordered: bool type. If the train data per batch are formed as p*k, where p is the num of ids per batch and k is the num of images per id.
|
||||
ids_per_batch: num of different ids per batch
|
||||
imgs_per_id: num of images per id
|
||||
Return:
|
||||
range_loss
|
||||
"""
|
||||
inter_class_loss = self._inter_class_loss(features, targets, ordered, ids_per_batch, imgs_per_id)
|
||||
intra_class_loss = self._intra_class_loss(features, targets, ordered, ids_per_batch, imgs_per_id)
|
||||
range_loss = self.alpha * intra_class_loss + self.beta * inter_class_loss
|
||||
return range_loss, intra_class_loss, inter_class_loss
|
||||
|
||||
def forward(self, features, targets):
|
||||
"""
|
||||
Args:
|
||||
features: prediction matrix (before softmax) with shape (batch_size, feature_dim)
|
||||
targets: ground truth labels with shape (batch_size)
|
||||
ordered: bool type. If the train data per batch are formed as p*k, where p is the num of ids per batch and k is the num of images per id.
|
||||
ids_per_batch: num of different ids per batch
|
||||
imgs_per_id: num of images per id
|
||||
Return:
|
||||
range_loss
|
||||
"""
|
||||
assert features.size(0) == targets.size(0), "features.size(0) is not equal to targets.size(0)"
|
||||
if self.use_gpu:
|
||||
features = features.cuda()
|
||||
targets = targets.cuda()
|
||||
|
||||
range_loss, intra_class_loss, inter_class_loss = self._range_loss(features, targets, self.ordered, self.ids_per_batch, self.imgs_per_id)
|
||||
return range_loss, intra_class_loss, inter_class_loss
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
use_gpu = False
|
||||
range_loss = RangeLoss(use_gpu=use_gpu, ids_per_batch=4, imgs_per_id=4)
|
||||
features = torch.rand(16, 2048)
|
||||
targets = torch.Tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3])
|
||||
if use_gpu:
|
||||
features = torch.rand(16, 2048).cuda()
|
||||
targets = torch.Tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]).cuda()
|
||||
loss = range_loss(features, targets)
|
||||
print(loss)
|
Loading…
Reference in New Issue