PaddleClas/ppcls/loss/softsuploss.py

# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn


class SoftSupConLoss(nn.Layer):
    """
    Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR.
    """

    def __init__(self,
                 temperature=0.07,
                 contrast_mode='all',
                 base_temperature=0.07):
        super(SoftSupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature
    def __call__(self,
                 feat,
                 batch,
                 max_probs=None,
                 labels=None,
                 mask=None,
                 reduction="mean",
                 select_matrix=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf

        Args:
            feat: hidden vector of shape [batch_size, n_views, ...].
            batch: dict carrying the pseudo-label confidences ('max_probs')
                and pseudo labels ('p_targets_u_w'); these override the
                `max_probs` and `labels` arguments below.
            labels: ground truth of shape [batch_size].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
            select_matrix: unused in this implementation.

        Returns:
            A dict with the loss scalar under the key "SoftSupConLoss".
        """
        max_probs = batch['max_probs']
        labels = batch['p_targets_u_w']
        # reduction = batch['reduction']
        batch_size = feat.shape[0]

        if labels is not None:
            labels = labels.reshape((-1, 1))
            # Hard positive mask: mask[i, j] = 1 iff samples i and j share a label.
            mask = paddle.equal(labels, labels.T).astype('float32')
            # Soften it with the pseudo-label confidences: each positive pair
            # (i, j) is weighted by max_probs[i] * max_probs[j].
            max_probs = max_probs.reshape((-1, 1))
            score_mask = paddle.matmul(max_probs, max_probs.T)
            mask = paddle.multiply(mask, score_mask)
        # Flatten the views: [batch_size, n_views, d] -> [n_views * batch_size, d].
        contrast_count = feat.shape[1]
        contrast_feat = paddle.concat(paddle.unbind(feat, axis=1), axis=0)
        if self.contrast_mode == 'all':
            anchor_feat = contrast_feat
            anchor_count = contrast_count
        else:
            raise ValueError(
                "Unknown contrast_mode: {} (only 'all' is implemented)".format(
                    self.contrast_mode))

        # Similarity logits; subtract the row max for numerical stability.
        anchor_dot_contrast = paddle.matmul(
            anchor_feat, contrast_feat.T) / self.temperature
        logits_max = anchor_dot_contrast.max(axis=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # Tile the [bsz, bsz] mask to cover all anchor/contrast view pairs.
        mask = paddle.tile(mask, [anchor_count, contrast_count])
        # Zero out self-contrast on the diagonal (float32 to match `mask`).
        logits_mask = 1 - paddle.eye(
            batch_size * contrast_count, dtype='float32')
        mask = mask * logits_mask

        # Log-softmax over all non-self contrast pairs.
        exp_logits = paddle.exp(logits) * logits_mask
        log_prob = logits - paddle.log(exp_logits.sum(axis=1, keepdim=True))

        # Confidence-weighted mean of log-likelihood over positive pairs;
        # assumes every anchor has at least one positive (mask row sum > 0).
        mean_log_prob_pos = (mask * log_prob).sum(axis=1) / mask.sum(axis=1)

        loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.reshape((anchor_count, batch_size))
        if reduction == 'mean':
            loss = loss.mean()

        return {"SoftSupConLoss": loss}