PaddleClas/ppcls/loss/npairsloss.py

44 lines
1.6 KiB
Python
Raw Normal View History

2021-05-31 14:20:48 +08:00
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
2021-06-03 15:17:49 +08:00
2021-05-31 14:20:48 +08:00
class NpairsLoss(paddle.nn.Layer):
    """Multi-class N-pair loss.

    Paper: Improved deep metric learning with multi-class N-pair loss
    objective (https://dl.acm.org/doi/10.5555/3157096.3157304).
    Code reference:
    https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/contrib/losses/metric_learning/npairs_loss

    The batch is expected to be laid out as consecutive (anchor, positive)
    pairs: rows ``2 * i`` and ``2 * i + 1`` form pair ``i``. Every pair is
    treated as its own class, i.e. anchor ``i`` should be most similar to
    positive ``i`` and dissimilar to every other positive in the batch.

    Args:
        reg_lambda (float): weight of the L2 penalty on the embeddings.
            Defaults to 0.01.
    """

    def __init__(self, reg_lambda=0.01):
        super(NpairsLoss, self).__init__()
        self.reg_lambda = reg_lambda

    def forward(self, input, target=None):
        """Compute the N-pair loss for a batch of (anchor, positive) pairs.

        Args:
            input (dict): model output; ``input["features"]`` is expected to
                be a 2-D tensor of shape ``[batch_size, feature_dim]`` whose
                rows alternate anchor, positive, anchor, positive, ...
            target: not used. Pairs are labelled by their position in the
                batch rather than by ``target``; presumably the sampler
                guarantees that pairs come from distinct classes
                (TODO confirm).

        Returns:
            dict: ``{"npairsloss": loss}``, where ``loss`` is the mean
            softmax cross-entropy of each anchor against all positives plus
            the L2 penalty on the features.

        Raises:
            ValueError: if the batch size is known and odd, so the batch
                cannot be split into (anchor, positive) pairs.
        """
        features = input["features"]
        batch_size = features.shape[0]
        fea_dim = features.shape[1]
        # shape[0] may be -1 (unknown) while building a static graph, so only
        # a known, positive batch size is validated here.
        if batch_size > 0 and batch_size % 2 != 0:
            raise ValueError(
                "NpairsLoss expects an even batch size made of (anchor, "
                "positive) pairs, but got batch size {}".format(batch_size))
        num_class = batch_size // 2

        # [B, D] -> [B / 2, 2, D]; index 0 on axis 1 is the anchor and
        # index 1 is the positive of each pair.
        out_feas = paddle.reshape(features, shape=[-1, 2, fea_dim])
        anc_feas, pos_feas = paddle.split(out_feas, num_or_sections=2, axis=1)
        anc_feas = paddle.squeeze(anc_feas, axis=1)
        pos_feas = paddle.squeeze(pos_feas, axis=1)

        # Row i holds the similarity of anchor i to every positive, so the
        # correct "class" for row i is column i.
        similarity_matrix = paddle.matmul(
            anc_feas, pos_feas, transpose_y=True)
        sparse_labels = paddle.arange(0, num_class, dtype='int64')
        # Functional form avoids building a new CrossEntropyLoss layer on
        # every call; its default reduction is 'mean', as before.
        xentloss = paddle.nn.functional.cross_entropy(similarity_matrix,
                                                      sparse_labels)

        # L2 penalty: mean squared L2 norm over all anchors and positives.
        # Because both halves have the same size, this equals
        # 0.25 * reg_lambda * (mean anchor sq-norm + mean positive sq-norm).
        reg = paddle.mean(paddle.sum(paddle.square(features), axis=1))
        l2loss = 0.5 * self.reg_lambda * reg
        return {"npairsloss": xentloss + l2loss}