PaddleClas/ppcls/loss/contrasiveloss.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import Dict
import paddle
import paddle.nn as nn
from ppcls.loss.xbm import CrossBatchMemory


class ContrastiveLoss(nn.Layer):
"""ContrastiveLoss
Args:
margin (float): margin
embedding_size (int): number of embedding's dimension
normalize_feature (bool, optional): whether to normalize embedding. Defaults to True.
epsilon (float, optional): epsilon. Defaults to 1e-5.
feature_from (str, optional): which key embedding from input dict. Defaults to "features".
"""
def __init__(self,
margin: float,
embedding_size: int,
normalize_feature=True,
epsilon: float=1e-5,
feature_from: str="features"):
super(ContrastiveLoss, self).__init__()
self.margin = margin
self.embedding_size = embedding_size
self.normalize_feature = normalize_feature
self.epsilon = epsilon
self.feature_from = feature_from

    def forward(self, input: Dict[str, paddle.Tensor],
target: paddle.Tensor) -> Dict[str, paddle.Tensor]:
feats = input[self.feature_from]
labels = target
# normalize along feature dim
if self.normalize_feature:
feats = nn.functional.normalize(feats, p=2, axis=1)
# squeeze labels to shape (batch_size, )
if labels.ndim >= 2 and labels.shape[-1] == 1:
labels = paddle.squeeze(labels, axis=[-1])
        loss = self._compute_loss(feats, labels, feats, labels)
return {'ContrastiveLoss': loss}
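
    # _compute_loss implements, per query i and averaged over the batch:
    #     L_i = sum over positives j (y_j == y_i, s_ij < 1 - epsilon) of (1 - s_ij)
    #         + sum over negatives j (y_j != y_i, s_ij > margin) of s_ij
    # where s_ij is the (cosine) similarity between query i and key j.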
def _compute_loss(self,
inputs_q: paddle.Tensor,
targets_q: paddle.Tensor,
inputs_k: paddle.Tensor,
targets_k: paddle.Tensor) -> paddle.Tensor:
        batch_size = inputs_q.shape[0]
        # pairwise similarity between query and key embeddings (cosine
        # similarity when the features are L2-normalized)
        sim_mat = paddle.matmul(inputs_q, inputs_k.t())

        loss = []
        for i in range(batch_size):
            # positive pairs: same label; drop near-1 similarities, which
            # correspond to an embedding paired with itself
            pos_pair_ = paddle.masked_select(sim_mat[i],
                                             targets_q[i] == targets_k)
            pos_pair_ = paddle.masked_select(pos_pair_,
                                             pos_pair_ < 1 - self.epsilon)

            # negative pairs: different label; only those whose similarity
            # exceeds the margin contribute
            neg_pair_ = paddle.masked_select(sim_mat[i],
                                             targets_q[i] != targets_k)
            neg_pair = paddle.masked_select(neg_pair_, neg_pair_ > self.margin)

            # pull positives toward similarity 1; penalize hard negatives
            pos_loss = paddle.sum(-pos_pair_ + 1)
            if len(neg_pair) > 0:
                neg_loss = paddle.sum(neg_pair)
            else:
                neg_loss = 0
            loss.append(pos_loss + neg_loss)

        loss = sum(loss) / batch_size
        return loss
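
# A minimal usage sketch for ContrastiveLoss (illustrative only: the batch
# size, feature dimension, margin, and label values below are arbitrary
# assumptions, not part of the API):
#
#     loss_fn = ContrastiveLoss(margin=0.5, embedding_size=128)
#     feats = paddle.randn([8, 128])          # 8 embeddings of size 128
#     labels = paddle.randint(0, 4, [8, 1])   # integer class ids, shape (8, 1)
#     out = loss_fn({"features": feats}, labels)
#     print(out["ContrastiveLoss"])           # scalar loss tensor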


class ContrastiveLoss_XBM(ContrastiveLoss):
"""ContrastiveLoss with CrossBatchMemory
Args:
xbm_size (int): size of memory bank
xbm_weight (int): weight of CrossBatchMemory's loss
start_iter (int): store embeddings after start_iter
margin (float): margin
embedding_size (int): number of embedding's dimension
epsilon (float, optional): epsilon. Defaults to 1e-5.
normalize_feature (bool, optional): whether to normalize embedding. Defaults to True.
feature_from (str, optional): which key embedding from input dict. Defaults to "features".
"""

    def __init__(self,
                 xbm_size: int,
                 xbm_weight: float,
start_iter: int,
margin: float,
embedding_size: int,
epsilon: float=1e-5,
normalize_feature=True,
feature_from: str="features"):
super(ContrastiveLoss_XBM, self).__init__(
margin, embedding_size, normalize_feature, epsilon, feature_from)
self.xbm = CrossBatchMemory(xbm_size, embedding_size)
self.xbm_weight = xbm_weight
self.start_iter = start_iter
self.iter = 0

    def __call__(self, input: Dict[str, paddle.Tensor],
target: paddle.Tensor) -> Dict[str, paddle.Tensor]:
feats = input[self.feature_from]
labels = target
# normalize along feature dim
if self.normalize_feature:
feats = nn.functional.normalize(feats, p=2, axis=1)
# squeeze labels to shape (batch_size, )
if labels.ndim >= 2 and labels.shape[-1] == 1:
labels = paddle.squeeze(labels, axis=[-1])
loss = self._compute_loss(feats, labels, feats, labels)
# compute contrastive loss from memory bank
        self.iter += 1
        if self.iter > self.start_iter:
            # store the current (detached) embeddings, then contrast the live
            # batch against everything currently held in the memory bank
            self.xbm.enqueue_dequeue(feats.detach(), labels.detach())
            xbm_feats, xbm_labels = self.xbm.get()
            xbm_loss = self._compute_loss(feats, labels, xbm_feats, xbm_labels)
            loss = loss + self.xbm_weight * xbm_loss
return {'ContrastiveLoss_XBM': loss}
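

if __name__ == "__main__":
    # Minimal smoke test, a sketch only: the sizes, margin, weight, and label
    # values below are arbitrary assumptions, and it presumes paddle and
    # ppcls.loss.xbm.CrossBatchMemory are importable.
    loss_fn = ContrastiveLoss_XBM(
        xbm_size=1024,
        xbm_weight=1.0,
        start_iter=0,
        margin=0.5,
        embedding_size=128)
    feats = paddle.randn([8, 128])         # one batch of embeddings
    labels = paddle.randint(0, 4, [8, 1])  # integer class ids
    print(loss_fn({"features": feats}, labels))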