# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn


class SoftSupConLoss(nn.Layer):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.

    It also supports the unsupervised contrastive loss used in SimCLR.
    """

    def __init__(self, temperature=0.07, contrast_mode='all', base_temperature=0.07):
        super(SoftSupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    def __call__(self, feat, batch, max_probs=None, labels=None, mask=None, reduction="mean", select_matrix=None):
        """Compute the loss for the model. If both `labels` and `mask` are None,
        this degenerates to the SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf

        Args:
            feat: hidden vector of shape [batch_size, n_views, ...].
            batch: dict providing `max_probs` (pseudo-label confidence scores)
                and `p_targets_u_w` (pseudo-labels); these override the
                `max_probs` and `labels` keyword arguments.
            labels: ground truth of shape [batch_size].
            mask: contrastive mask of shape [bsz, bsz], where mask_{i,j}=1 if
                sample j has the same class as sample i. Can be asymmetric.

        Returns:
            A dict mapping the loss name to a scalar loss.
        """
        # Confidence scores and pseudo-labels are taken from the batch dict,
        # overriding the corresponding keyword arguments.
        max_probs = batch['max_probs']
        labels = batch['p_targets_u_w']
        # reduction = batch['reduction']
        batch_size = feat.shape[0]
        if labels is not None:
            labels = labels.reshape((-1, 1))
            # Hard positive-pair mask: 1 where two samples share a label.
            mask = paddle.equal(labels, labels.T).astype('float32')
            # Soften the mask by the product of the paired confidence scores.
            max_probs = max_probs.reshape((-1, 1))
            score_mask = paddle.matmul(max_probs, max_probs.T)
            mask = paddle.multiply(mask, score_mask)

        contrast_count = feat.shape[1]
        # Flatten the view axis: [n, n_views, d] -> (n_views * n, d).
        contrast_feat = paddle.concat(paddle.unbind(feat, axis=1), axis=0)
        if self.contrast_mode == 'all':
            anchor_feat = contrast_feat
            anchor_count = contrast_count
        else:
            raise ValueError('Unsupported contrast_mode: {}'.format(self.contrast_mode))
        # Temperature-scaled pairwise similarities; subtracting the row-wise
        # max keeps exp() numerically stable and cancels out in the softmax.
        anchor_dot_contrast = paddle.matmul(anchor_feat, contrast_feat.T) / self.temperature
        logits_max = anchor_dot_contrast.max(axis=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()
        # Tile the mask across views (assumes contrast_count == 2).
        mask = paddle.concat([mask, mask], axis=0)
        mask = paddle.concat([mask, mask], axis=1)

        # Exclude self-contrast: an anchor is never its own positive.
        logits_mask = 1 - paddle.eye(batch_size * contrast_count, dtype='float32')
        mask = mask * logits_mask
        exp_logits = paddle.exp(logits) * logits_mask
        log_prob = logits - paddle.log(exp_logits.sum(axis=1, keepdim=True))

        # Average log-likelihood over positives, weighted by the soft mask.
        mean_log_prob_pos = (mask * log_prob).sum(axis=1) / mask.sum(axis=1)
        loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.reshape((anchor_count, batch_size))
        if reduction == 'mean':
            loss = loss.mean()

        return {"SoftSupConLoss": loss}
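

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): all shapes, class
# counts, and values below are illustrative assumptions. It feeds two
# L2-normalized views per sample plus a batch dict carrying pseudo-label
# confidences (`max_probs`) and pseudo-labels (`p_targets_u_w`), the two keys
# read by __call__ above.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import paddle.nn.functional as F

    n, d = 8, 128
    feat = F.normalize(paddle.randn((n, 2, d)), axis=-1)  # [batch, n_views=2, dim]
    batch = {
        'max_probs': paddle.rand((n, )),                   # assumed confidence scores
        'p_targets_u_w': paddle.randint(0, 4, (n, )),      # assumed pseudo-labels
    }
    criterion = SoftSupConLoss(temperature=0.07)
    out = criterion(feat, batch)
    print(out['SoftSupConLoss'])  # scalar loss tensor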