feat(loss): add TripletAngularMarginLoss_XBM and refine XBM queue

HydrogenSulfate 2022-12-13 07:33:28 +00:00
parent 4060957669
commit 0288285c91
3 changed files with 143 additions and 4 deletions

ppcls/loss/__init__.py View File

@@ -14,7 +14,7 @@ from .msmloss import MSMLoss
from .npairsloss import NpairsLoss
from .trihardloss import TriHardLoss
from .triplet import TripletLoss, TripletLossV2
from .tripletangularmarginloss import TripletAngularMarginLoss
from .tripletangularmarginloss import TripletAngularMarginLoss, TripletAngularMarginLoss_XBM
from .supconloss import SupConLoss
from .pairwisecosface import PairwiseCosface
from .dmlloss import DMLLoss

ppcls/loss/tripletangularmarginloss.py View File

@@ -18,6 +18,7 @@ from __future__ import print_function
import paddle
import paddle.nn as nn
from ppcls.loss.xbm import CrossBatchMemory
class TripletAngularMarginLoss(nn.Layer):
@@ -113,3 +114,128 @@ class TripletAngularMarginLoss(nn.Layer):
) * self.absolute_loss_weight + loss.mean()
return {"TripletAngularMarginLoss": loss}
class TripletAngularMarginLoss_XBM(TripletAngularMarginLoss):
"""TripletAngularMarginLoss combined with CrossBatchMemory
Args:
start_iter (int): the training step after which CrossBatchMemory is enabled
xbm_size (int): size of the CrossBatchMemory queue
xbm_weight (float): weight of the CrossBatchMemory loss term
feat_dim (int): channel dimension of the features stored in CrossBatchMemory
margin (float, optional): angular margin. Defaults to 0.5.
normalize_feature (bool, optional): whether to apply L2 normalization to features before computing the distance (cosine similarity). Defaults to True.
reduction (str, optional): reduction option within a batch. Defaults to "mean".
add_absolute (bool, optional): whether to add an absolute loss term on d(a,p) and d(a,n). Defaults to False.
absolute_loss_weight (float, optional): weight for the absolute loss. Defaults to 1.0.
ap_value (float, optional): target value for d(a, p) in the absolute loss. Defaults to 0.9.
an_value (float, optional): target value for d(a, n) in the absolute loss. Defaults to 0.5.
feature_from (str, optional): key of the input dict to take features from. Defaults to "features".
"""
def __init__(self,
start_iter: int,
xbm_size: int,
xbm_weight: float,
feat_dim: int,
margin=0.5,
normalize_feature=True,
reduction="mean",
add_absolute=False,
absolute_loss_weight=1.0,
ap_value=0.9,
an_value=0.5,
feature_from="features"):
super(TripletAngularMarginLoss_XBM, self).__init__(
margin, normalize_feature, reduction, add_absolute,
absolute_loss_weight, ap_value, an_value, feature_from)
self.start_iter = start_iter
self.xbm = CrossBatchMemory(xbm_size, feat_dim)
self.xbm_weight = xbm_weight
self.inf = 10 # large enough to act as +/- inf, since cosine similarity is bounded in [-1, 1]
self.register_buffer("iter", paddle.to_tensor(0, dtype="int64"))
def forward(self, input, target):
"""
Args:
input (dict): dict holding the feature matrix with shape (batch_size, feat_dim) under key `self.feature_from`
target: ground truth labels with shape (batch_size, 1) or (batch_size, )
"""
feats = input[self.feature_from]
if self.normalize_feature:
feats = nn.functional.normalize(feats, p=2, axis=1)
labels = target
if labels.ndim >= 2 and labels.shape[-1] == 1:
labels = paddle.squeeze(labels, axis=[-1])
loss = self._compute_loss(feats, labels, feats, labels)
# XBM loss below
self.iter += 1
if self.iter.item() > self.start_iter:
self.xbm.enqueue_dequeue(feats.detach(), labels.detach())
xbm_feats, xbm_labels = self.xbm.get()
xbm_loss = self._compute_loss(feats, labels, xbm_feats, xbm_labels)
loss = loss + self.xbm_weight * xbm_loss
return {"TripletAngularMarginLoss_XBM": loss}
def _masked_max(self, tensor, mask, axis):
masked = paddle.multiply(tensor, mask.astype(tensor.dtype))
neg_inf = paddle.zeros_like(tensor)
neg_inf.stop_gradient = True
neg_inf[paddle.logical_not(mask)] = -self.inf
return paddle.max(masked + neg_inf, axis=axis, keepdim=True)
def _masked_min(self, tensor, mask, axis):
masked = paddle.multiply(tensor, mask.astype(tensor.dtype))
pos_inf = paddle.zeros_like(tensor)
pos_inf.stop_gradient = True
pos_inf[paddle.logical_not(mask)] = self.inf
return paddle.min(masked + pos_inf, axis=axis, keepdim=True)
def _compute_loss(self,
inputs_q: paddle.Tensor,
targets_q: paddle.Tensor,
inputs_k: paddle.Tensor,
targets_k: paddle.Tensor) -> paddle.Tensor:
Q = inputs_q.shape[0]
K = inputs_k.shape[0]
# compute distance(cos-similarity)
dist = paddle.matmul(inputs_q, inputs_k.t()) # [Q, K]
# build positive / negative pair masks for hard example mining
is_pos = paddle.expand(paddle.unsqueeze(targets_q, 1), (Q, K)).equal(
paddle.expand(paddle.unsqueeze(targets_k, 1), (K, Q)).t()) # [Q, K]
is_neg = paddle.expand(paddle.unsqueeze(targets_q, 1), (Q, K)).not_equal(
paddle.expand(paddle.unsqueeze(targets_k, 1), (K, Q)).t()) # [Q, K]
# hardest positive: the least similar positive; hardest negative: the most similar negative
dist_ap = self._masked_min(dist, is_pos, axis=1) # [Q, 1]
dist_an = self._masked_max(dist, is_neg, axis=1) # [Q, 1]
# Compute ranking hinge loss
y = paddle.ones_like(dist_an)
loss = self.ranking_loss(dist_ap, dist_an, y)
if self.add_absolute:
absolut_loss_ap = self.ap_value - dist_ap
absolut_loss_ap = paddle.where(absolut_loss_ap > 0,
absolut_loss_ap,
paddle.zeros_like(absolut_loss_ap))
absolut_loss_an = dist_an - self.an_value
absolut_loss_an = paddle.where(absolut_loss_an > 0,
absolut_loss_an,
paddle.zeros_like(absolut_loss_an))
loss = (absolut_loss_an.mean() + absolut_loss_ap.mean()
) * self.absolute_loss_weight + loss.mean()
return loss
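
For orientation, a minimal usage sketch of the new loss follows. Only the class name, its constructor arguments, the {"features": ...} input layout, and the returned dict key come from the diff above; the hyper-parameter values and tensor shapes are illustrative, and in PaddleClas the loss would normally be built from the Loss section of a training config rather than instantiated by hand.

import paddle
from ppcls.loss import TripletAngularMarginLoss_XBM

# illustrative hyper-parameters; feat_dim must match the embedding size of the model
loss_fn = TripletAngularMarginLoss_XBM(
    start_iter=100,   # enable the memory bank after 100 steps
    xbm_size=1024,    # number of embeddings kept in the queue
    xbm_weight=1.0,   # weight of the memory-bank loss term
    feat_dim=128,
    margin=0.5,
    add_absolute=True)

# fake batch: 32 embeddings of dimension 128 with integer labels
feats = paddle.randn([32, 128])
labels = paddle.randint(0, 10, shape=[32, 1])

out = loss_fn({"features": feats}, labels)
print(out["TripletAngularMarginLoss_XBM"])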

ppcls/loss/xbm.py View File

@@ -21,7 +21,7 @@ from typing import Tuple
import paddle
import paddle.nn as nn
class CrossBatchMemory(object):
class CrossBatchMemory(nn.Layer):
"""
CrossBatchMemory implementation. Refer to "Cross-Batch Memory for Embedding Learning".
@@ -33,10 +33,18 @@ class CrossBatchMemory(object):
"""
def __init__(self, size: int, embedding_size: int):
super().__init__()
self.size = size
self.embedding_size = embedding_size
self.feats = paddle.zeros([self.size, self.embedding_size])
self.targets = paddle.zeros([self.size, ], dtype="int64")
# initialize the feature queue and register it as a buffer so it is saved and restored when training resumes
feats = paddle.zeros([self.size, self.embedding_size])
self.register_buffer("feats", feats)
# initialize the label queue and register it as a buffer so it is saved and restored when training resumes
targets = paddle.zeros([self.size, ], dtype="int64")
self.register_buffer("targets", targets)
self.ptr = 0
# self.accumulated_size = 0
@@ -74,3 +82,8 @@ class CrossBatchMemory(object):
self.targets[self.ptr:self.ptr + input_size] = targets
self.ptr += input_size
# self.accumulated_size += input_size
def forward(self, *kargs, **kwargs):
raise NotImplementedError(
"CrossBatchMemory is a memory bank module; calling forward is not supported"
)
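
To make the queue semantics concrete, here is a small sketch of how the refactored memory bank can be exercised on its own. Only CrossBatchMemory(size, embedding_size), enqueue_dequeue() and get() appear in the diffs above; the concrete sizes, the number of steps, and the label range are illustrative.

import paddle
from ppcls.loss.xbm import CrossBatchMemory

xbm = CrossBatchMemory(size=1024, embedding_size=128)

for step in range(10):
    feats = paddle.randn([32, 128])             # embeddings from the current batch
    labels = paddle.randint(0, 10, shape=[32])  # 1-D labels, matching the int64 target queue
    xbm.enqueue_dequeue(feats.detach(), labels.detach())

    # fetch everything accumulated in the bank so far
    bank_feats, bank_labels = xbm.get()

# because feats/targets are registered buffers, they appear in state_dict()
# and are restored when training is resumed from a checkpoint
print(list(xbm.state_dict().keys()))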