rename losses -> loss

parent 51f0b78bd4
commit 69f563d234
@@ -30,7 +30,7 @@ from ppcls.utils.misc import AverageMeter
 from ppcls.utils import logger
 from ppcls.data import build_dataloader
 from ppcls.arch import build_model
-from ppcls.losses import build_loss
+from ppcls.loss import build_loss
 from ppcls.arch.loss_metrics import build_metrics
 from ppcls.optimizer import build_optimizer
 from ppcls.utils.save_load import load_dygraph_pretrain

@@ -5,12 +5,15 @@ import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F

+
 class CenterLoss(nn.Layer):
     def __init__(self, num_classes=5013, feat_dim=2048):
         super(CenterLoss, self).__init__()
         self.num_classes = num_classes
         self.feat_dim = feat_dim
-        self.centers = paddle.randn(shape=[self.num_classes, self.feat_dim]).astype("float64") #random center
+        self.centers = paddle.randn(
+            shape=[self.num_classes, self.feat_dim]).astype(
+                "float64") #random center

     def __call__(self, input, target):
         """
@@ -23,25 +26,29 @@ class CenterLoss(nn.Layer):

         #calc feat * feat
         dist1 = paddle.sum(paddle.square(feats), axis=1, keepdim=True)
-        dist1 = paddle.expand(dist1, [batch_size, self.num_classes])
+        dist1 = paddle.expand(dist1, [batch_size, self.num_classes])

         #dist2 of centers
-        dist2 = paddle.sum(paddle.square(self.centers), axis=1, keepdim=True) #num_classes
-        dist2 = paddle.expand(dist2, [self.num_classes, batch_size]).astype("float64")
+        dist2 = paddle.sum(paddle.square(self.centers), axis=1,
+                           keepdim=True) #num_classes
+        dist2 = paddle.expand(dist2,
+                              [self.num_classes, batch_size]).astype("float64")
         dist2 = paddle.transpose(dist2, [1, 0])

         #first x * x + y * y
         distmat = paddle.add(dist1, dist2)
-        tmp = paddle.matmul(feats, paddle.transpose(self.centers, [1, 0]))
-        distmat = distmat - 2.0 * tmp
+        tmp = paddle.matmul(feats, paddle.transpose(self.centers, [1, 0]))
+        distmat = distmat - 2.0 * tmp

         #generate the mask
         classes = paddle.arange(self.num_classes).astype("int64")
-        labels = paddle.expand(paddle.unsqueeze(labels, 1), (batch_size, self.num_classes))
-        mask = paddle.equal(paddle.expand(classes, [batch_size, self.num_classes]), labels).astype("float64") #get mask
+        labels = paddle.expand(
+            paddle.unsqueeze(labels, 1), (batch_size, self.num_classes))
+        mask = paddle.equal(
+            paddle.expand(classes, [batch_size, self.num_classes]),
+            labels).astype("float64") #get mask

-        dist = paddle.multiply(distmat, mask)
+        dist = paddle.multiply(distmat, mask)
         loss = paddle.sum(paddle.clip(dist, min=1e-12, max=1e+12)) / batch_size

         return {'CenterLoss': loss}

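For a quick sanity check of the renamed module, a minimal sketch of calling CenterLoss is shown below. The "features" dict key, the float64 cast, and the shapes are assumptions for illustration (chosen to be consistent with the other losses in this diff), not part of the commit itself.

    import paddle

    # Hedged smoke test: small num_classes/feat_dim instead of the 5013/2048 defaults.
    loss_fn = CenterLoss(num_classes=10, feat_dim=16)
    feats = paddle.randn([4, 16]).astype("float64")         # centers are float64
    labels = paddle.to_tensor([0, 3, 3, 7], dtype="int64")  # assumed label format
    out = loss_fn({"features": feats}, labels)              # assumed input dict layout
    print(out["CenterLoss"])                                # scalar loss tensor
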
@@ -18,26 +18,27 @@ from __future__ import print_function

 import numpy as np

+
 def rerange_index(batch_size, samples_each_class):
-    tmp = np.arange(0, batch_size * batch_size)
-    tmp = tmp.reshape(-1, batch_size)
+    tmp = np.arange(0, batch_size * batch_size)
+    tmp = tmp.reshape(-1, batch_size)
     rerange_index = []

     for i in range(batch_size):
         step = i // samples_each_class
         start = step * samples_each_class
-        end = (step + 1) * samples_each_class
+        end = (step + 1) * samples_each_class

-        pos_idx = []
-        neg_idx = []
+        pos_idx = []
+        neg_idx = []
         for j, k in enumerate(tmp[i]):
             if j >= start and j < end:
                 if j == i:
                     pos_idx.insert(0, k)
                 else:
-                    pos_idx.append(k)
+                    pos_idx.append(k)
             else:
-                neg_idx.append(k)
+                neg_idx.append(k)
         rerange_index += (pos_idx + neg_idx)

     rerange_index = np.array(rerange_index).astype(np.int32)
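To make the reordering concrete, here is a small worked example computed by hand from the function above (batch_size=4 and samples_each_class=2 are illustrative values):

    import numpy as np

    # Each row of the flattened 4x4 index grid is reordered so the anchor's own
    # index comes first, then the other positive of its class, then all negatives.
    idx = rerange_index(batch_size=4, samples_each_class=2)
    print(idx)
    # [ 0  1  2  3  5  4  6  7 10 11  8  9 15 14 12 13]
    # e.g. row i=1: 5 (self-distance slot) first, 4 (its positive) second, then 6, 7.

The first entry of every reordered row is the sample's own index, which is why the losses below split off a width-1 "ignore" column after gathering.
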
@@ -21,56 +21,64 @@ import paddle
 import numpy as np
 from .comfunc import rerange_index

+
 class EmlLoss(paddle.nn.Layer):
-    def __init__(self, batch_size = 40, samples_each_class = 2):
+    def __init__(self, batch_size=40, samples_each_class=2):
         super(EmlLoss, self).__init__()
-        assert(batch_size % samples_each_class == 0)
+        assert (batch_size % samples_each_class == 0)
         self.samples_each_class = samples_each_class
-        self.batch_size = batch_size
-        self.rerange_index = rerange_index(batch_size, samples_each_class)
+        self.batch_size = batch_size
+        self.rerange_index = rerange_index(batch_size, samples_each_class)
         self.thresh = 20.0
-        self.beta = 100000
-
+        self.beta = 100000
+
     def surrogate_function(self, beta, theta, bias):
-        x = theta * paddle.exp(bias)
+        x = theta * paddle.exp(bias)
         output = paddle.log(1 + beta * x) / math.log(1 + beta)
         return output

     def surrogate_function_approximate(self, beta, theta, bias):
-        output = (paddle.log(theta) + bias + math.log(beta)) / math.log(1+beta)
+        output = (
+            paddle.log(theta) + bias + math.log(beta)) / math.log(1 + beta)
         return output

     def surrogate_function_stable(self, beta, theta, target, thresh):
         max_gap = paddle.to_tensor(thresh, dtype='float32')
         max_gap.stop_gradient = True

         target_max = paddle.maximum(target, max_gap)
         target_min = paddle.minimum(target, max_gap)

         loss1 = self.surrogate_function(beta, theta, target_min)
         loss2 = self.surrogate_function_approximate(beta, theta, target_max)
-        bias = self.surrogate_function(beta, theta, max_gap)
-        loss = loss1 + loss2 - bias
+        bias = self.surrogate_function(beta, theta, max_gap)
+        loss = loss1 + loss2 - bias
         return loss

     def forward(self, input, target=None):
         features = input["features"]
         samples_each_class = self.samples_each_class
-        batch_size = self.batch_size
-        rerange_index = self.rerange_index
-
+        batch_size = self.batch_size
+        rerange_index = self.rerange_index
+
         #calc distance
-        diffs = paddle.unsqueeze(features, axis=1) - paddle.unsqueeze(features, axis=0)
-        similary_matrix = paddle.sum(paddle.square(diffs), axis=-1)
-
-        tmp = paddle.reshape(similary_matrix, shape = [-1, 1])
+        diffs = paddle.unsqueeze(
+            features, axis=1) - paddle.unsqueeze(
+                features, axis=0)
+        similary_matrix = paddle.sum(paddle.square(diffs), axis=-1)
+
+        tmp = paddle.reshape(similary_matrix, shape=[-1, 1])
         rerange_index = paddle.to_tensor(rerange_index)
-        tmp = paddle.gather(tmp, index=rerange_index)
-        similary_matrix = paddle.reshape(tmp, shape=[-1, batch_size])
-
-        ignore, pos, neg = paddle.split(similary_matrix, num_or_sections= [1,
-            samples_each_class - 1, batch_size - samples_each_class], axis = 1)
-        ignore.stop_gradient = True
+        tmp = paddle.gather(tmp, index=rerange_index)
+        similary_matrix = paddle.reshape(tmp, shape=[-1, batch_size])
+
+        ignore, pos, neg = paddle.split(
+            similary_matrix,
+            num_or_sections=[
+                1, samples_each_class - 1, batch_size - samples_each_class
+            ],
+            axis=1)
+        ignore.stop_gradient = True

         pos_max = paddle.max(pos, axis=1, keepdim=True)
         pos = paddle.exp(pos - pos_max)
@@ -79,11 +87,11 @@ class EmlLoss(paddle.nn.Layer):
         neg_min = paddle.min(neg, axis=1, keepdim=True)
         neg = paddle.exp(neg_min - neg)
         neg_mean = paddle.mean(neg, axis=1, keepdim=True)

         bias = pos_max - neg_min
         theta = paddle.multiply(neg_mean, pos_mean)

-        loss = self.surrogate_function_stable(self.beta, theta, bias, self.thresh)
+        loss = self.surrogate_function_stable(self.beta, theta, bias,
+                                              self.thresh)
         loss = paddle.mean(loss)
         return {"emlloss": loss}

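A minimal smoke test for EmlLoss might look like the sketch below; the batch and feature sizes are arbitrary, and it assumes the parts of the file outside this diff (e.g. the import of math and the pos_mean computation between the two hunks) are unchanged.

    import paddle

    # batch_size must be divisible by samples_each_class (asserted in __init__).
    loss_fn = EmlLoss(batch_size=4, samples_each_class=2)
    features = paddle.randn([4, 8])          # 2 classes x 2 samples, feature dim 8
    out = loss_fn({"features": features})
    print(out["emlloss"])
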
@@ -18,6 +18,7 @@ from __future__ import print_function
 import paddle
 from .comfunc import rerange_index

+
 class MSMLoss(paddle.nn.Layer):
     """
     MSMLoss Loss, based on triplet loss. USE P * K samples.
@@ -31,42 +32,47 @@ class MSMLoss(paddle.nn.Layer):
     ]
     only consider samples_each_class = 2
     """
-    def __init__(self, batch_size = 120, samples_each_class=2, margin=0.1):
+
+    def __init__(self, batch_size=120, samples_each_class=2, margin=0.1):
         super(MSMLoss, self).__init__()
         self.margin = margin
         self.samples_each_class = samples_each_class
-        self.batch_size = batch_size
-        self.rerange_index = rerange_index(batch_size, samples_each_class)
+        self.batch_size = batch_size
+        self.rerange_index = rerange_index(batch_size, samples_each_class)

     def forward(self, input, target=None):
         #normalization
         features = input["features"]
         features = self._nomalize(features)
         samples_each_class = self.samples_each_class
-        rerange_index = paddle.to_tensor(self.rerange_index)
+        rerange_index = paddle.to_tensor(self.rerange_index)

         #calc sm
-        diffs = paddle.unsqueeze(features, axis=1) - paddle.unsqueeze(features, axis=0)
-        similary_matrix = paddle.sum(paddle.square(diffs), axis=-1)
-
-        #rerange
-        tmp = paddle.reshape(similary_matrix, shape = [-1, 1])
-        tmp = paddle.gather(tmp, index=rerange_index)
-        similary_matrix = paddle.reshape(tmp, shape=[-1, self.batch_size])
-
-        #split
-        ignore, pos, neg = paddle.split(similary_matrix, num_or_sections= [1,
-            samples_each_class - 1, -1], axis = 1)
-        ignore.stop_gradient = True
-
-        hard_pos = paddle.max(pos)
+        diffs = paddle.unsqueeze(
+            features, axis=1) - paddle.unsqueeze(
+                features, axis=0)
+        similary_matrix = paddle.sum(paddle.square(diffs), axis=-1)
+
+        #rerange
+        tmp = paddle.reshape(similary_matrix, shape=[-1, 1])
+        tmp = paddle.gather(tmp, index=rerange_index)
+        similary_matrix = paddle.reshape(tmp, shape=[-1, self.batch_size])
+
+        #split
+        ignore, pos, neg = paddle.split(
+            similary_matrix,
+            num_or_sections=[1, samples_each_class - 1, -1],
+            axis=1)
+        ignore.stop_gradient = True
+
+        hard_pos = paddle.max(pos)
         hard_neg = paddle.min(neg)

         loss = hard_pos + self.margin - hard_neg
-        loss = paddle.nn.ReLU()(loss)
+        loss = paddle.nn.ReLU()(loss)
         return {"msmloss": loss}

     def _nomalize(self, input):
-        input_norm = paddle.sqrt(paddle.sum(paddle.square(input), axis=1, keepdim=True))
+        input_norm = paddle.sqrt(
+            paddle.sum(paddle.square(input), axis=1, keepdim=True))
         return paddle.divide(input, input_norm)

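A comparable hedged sketch for MSMLoss, under the same assumptions about the input dict; the reshape to [-1, self.batch_size] requires the configured batch_size to match the incoming batch.

    import paddle

    loss_fn = MSMLoss(batch_size=4, samples_each_class=2, margin=0.1)
    features = paddle.randn([4, 8])          # P*K layout: 2 classes x 2 samples
    out = loss_fn({"features": features})
    print(out["msmloss"])                    # hinge between hardest positive and hardest negative
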
@@ -3,12 +3,12 @@ from __future__ import division
 from __future__ import print_function
 import paddle


 class NpairsLoss(paddle.nn.Layer):
+
     def __init__(self, reg_lambda=0.01):
         super(NpairsLoss, self).__init__()
         self.reg_lambda = reg_lambda

-
     def forward(self, input, target=None):
         """
         anchor and positive(should include label)
@@ -16,22 +16,23 @@ class NpairsLoss(paddle.nn.Layer):
         features = input["features"]
         reg_lambda = self.reg_lambda
         batch_size = features.shape[0]
-        fea_dim = features.shape[1]
+        fea_dim = features.shape[1]
         num_class = batch_size // 2

         #reshape
         out_feas = paddle.reshape(features, shape=[-1, 2, fea_dim])
-        anc_feas, pos_feas = paddle.split(out_feas, num_or_sections = 2, axis = 1)
-        anc_feas = paddle.squeeze(anc_feas, axis=1)
+        anc_feas, pos_feas = paddle.split(out_feas, num_or_sections=2, axis=1)
+        anc_feas = paddle.squeeze(anc_feas, axis=1)
         pos_feas = paddle.squeeze(pos_feas, axis=1)

         #get simi matrix
-        similarity_matrix = paddle.matmul(anc_feas, pos_feas, transpose_y=True) #get similarity matrix
+        similarity_matrix = paddle.matmul(
+            anc_feas, pos_feas, transpose_y=True) #get similarity matrix
         sparse_labels = paddle.arange(0, num_class, dtype='int64')
-        xentloss = paddle.nn.CrossEntropyLoss()(similarity_matrix, sparse_labels) #by default: mean
+        xentloss = paddle.nn.CrossEntropyLoss()(
+            similarity_matrix, sparse_labels) #by default: mean

         #l2 norm
         reg = paddle.mean(paddle.sum(paddle.square(features), axis=1))
         l2loss = 0.5 * reg_lambda * reg
         return {"npairsloss": xentloss + l2loss}

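NpairsLoss expects the batch laid out as interleaved (anchor, positive) rows, so the batch size must be even; a hedged sketch:

    import paddle

    loss_fn = NpairsLoss(reg_lambda=0.01)
    features = paddle.randn([6, 8])          # 3 classes, each an (anchor, positive) pair
    out = loss_fn({"features": features})    # cross entropy over the 3x3 similarity matrix + L2 reg
    print(out["npairsloss"])
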
@@ -19,6 +19,7 @@ from __future__ import print_function
 import paddle
 from .comfunc import rerange_index

+
 class TriHardLoss(paddle.nn.Layer):
     """
     TriHard Loss, based on triplet loss. USE P * K samples.
@@ -32,45 +33,50 @@ class TriHardLoss(paddle.nn.Layer):
     ]
     only consider samples_each_class = 2
     """
-    def __init__(self, batch_size = 120, samples_each_class=2, margin=0.1):
+
+    def __init__(self, batch_size=120, samples_each_class=2, margin=0.1):
         super(TriHardLoss, self).__init__()
         self.margin = margin
         self.samples_each_class = samples_each_class
-        self.batch_size = batch_size
-        self.rerange_index = rerange_index(batch_size, samples_each_class)
+        self.batch_size = batch_size
+        self.rerange_index = rerange_index(batch_size, samples_each_class)

     def forward(self, input, target=None):
         features = input["features"]
         assert (self.batch_size == features.shape[0])

         #normalization
         features = self._nomalize(features)
         samples_each_class = self.samples_each_class
-        rerange_index = paddle.to_tensor(self.rerange_index)
+        rerange_index = paddle.to_tensor(self.rerange_index)

         #calc sm
-        diffs = paddle.unsqueeze(features, axis=1) - paddle.unsqueeze(features, axis=0)
-        similary_matrix = paddle.sum(paddle.square(diffs), axis=-1)
-
+        diffs = paddle.unsqueeze(
+            features, axis=1) - paddle.unsqueeze(
+                features, axis=0)
+        similary_matrix = paddle.sum(paddle.square(diffs), axis=-1)
+
         #rerange
-        tmp = paddle.reshape(similary_matrix, shape = [-1, 1])
-        tmp = paddle.gather(tmp, index=rerange_index)
-        similary_matrix = paddle.reshape(tmp, shape=[-1, self.batch_size])
-
+        tmp = paddle.reshape(similary_matrix, shape=[-1, 1])
+        tmp = paddle.gather(tmp, index=rerange_index)
+        similary_matrix = paddle.reshape(tmp, shape=[-1, self.batch_size])
+
         #split
-        ignore, pos, neg = paddle.split(similary_matrix, num_or_sections= [1,
-            samples_each_class - 1, -1], axis = 1)
-
-        ignore.stop_gradient = True
-        hard_pos = paddle.max(pos, axis=1)
+        ignore, pos, neg = paddle.split(
+            similary_matrix,
+            num_or_sections=[1, samples_each_class - 1, -1],
+            axis=1)
+
+        ignore.stop_gradient = True
+        hard_pos = paddle.max(pos, axis=1)
         hard_neg = paddle.min(neg, axis=1)

         loss = hard_pos + self.margin - hard_neg
-        loss = paddle.nn.ReLU()(loss)
+        loss = paddle.nn.ReLU()(loss)
         loss = paddle.mean(loss)
         return {"trihardloss": loss}

     def _nomalize(self, input):
-        input_norm = paddle.sqrt(paddle.sum(paddle.square(input), axis=1, keepdim=True))
+        input_norm = paddle.sqrt(
+            paddle.sum(paddle.square(input), axis=1, keepdim=True))
         return paddle.divide(input, input_norm)

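Finally, a hedged sketch for TriHardLoss; unlike MSMLoss it asserts that the configured batch_size equals the incoming batch and averages the per-anchor hinge terms.

    import paddle

    loss_fn = TriHardLoss(batch_size=4, samples_each_class=2, margin=0.1)
    features = paddle.randn([4, 8])
    out = loss_fn({"features": features})
    print(out["trihardloss"])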