mirror of https://github.com/sthalles/SimCLR.git
import numpy as np
import torch


class NTXentLoss(torch.nn.Module):
    """Normalized temperature-scaled cross-entropy (NT-Xent) loss from SimCLR."""

    def __init__(self, device, batch_size, temperature, use_cosine_similarity):
        super(NTXentLoss, self).__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.device = device
        self.softmax = torch.nn.Softmax(dim=-1)
        self.mask_samples_from_same_repr = self._get_correlated_mask()
        self.similarity_function = self._get_similarity_function(use_cosine_similarity)
        self.labels = self._get_labels()

    def _get_similarity_function(self, use_cosine_similarity):
        if use_cosine_similarity:
            self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
            # Note: the method name `_cosine_simililarity` is spelled
            # differently from the `_cosine_similarity` module assigned
            # above, so the two do not clobber each other.
            return self._cosine_simililarity
        else:
            return self._dot_simililarity

    def _get_labels(self):
        # One-hot targets of shape (2N, 2N - 1): for each of the 2N rows of the
        # masked similarity matrix, mark the column holding its positive pair.
        l1 = np.eye((2 * self.batch_size), 2 * self.batch_size - 1, k=-self.batch_size)
        l2 = np.eye((2 * self.batch_size), 2 * self.batch_size - 1, k=self.batch_size - 1)
        labels = torch.from_numpy((l1 + l2).astype(int))  # `np.int` was removed in NumPy 1.24
        return labels.to(self.device)
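    # For illustration (worked out by hand, not in the upstream file): with
    # batch_size = 2 the labels above are the 4 x 3 matrix
    #     [[0, 1, 0],
    #      [0, 0, 1],
    #      [1, 0, 0],
    #      [0, 1, 0]]
    # i.e. each view is paired with the other augmented view of the same
    # image, with column indices accounting for the removed diagonal entry.
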
    def _get_correlated_mask(self):
        # Boolean mask that is False on the diagonal, used to drop each
        # sample's similarity with itself.
        mask = (1 - torch.eye(2 * self.batch_size)).type(torch.bool)
        return mask.to(self.device)  # keep the mask on the same device as the logits

    @staticmethod
    def _dot_simililarity(x, y):
        # x shape: (N, 1, C); y.T shape: (1, C, 2N); v shape: (N, 2N).
        # Equivalent to the plain matrix product x @ y.T.
        v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
        return v

    def _cosine_simililarity(self, x, y):
        # x shape: (N, 1, C); y shape: (1, 2N, C); v shape: (N, 2N)
        v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
        return v

    def forward(self, zis, zjs):
        # Stack the two sets of projections into one (2N, C) matrix.
        representations = torch.cat([zjs, zis], dim=0)

        # Pairwise similarity matrix, shape (2N, 2N).
        logits = self.similarity_function(representations, representations)
        # Drop the self-similarities on the diagonal, leaving (2N, 2N - 1).
        logits = logits[self.mask_samples_from_same_repr].view(2 * self.batch_size, -1)
        logits /= self.temperature
        # assert logits.shape == (2 * self.batch_size, 2 * self.batch_size - 1), \
        #     "Shape of negatives not expected." + str(logits.shape)

        # Cross-entropy against the one-hot positive-pair labels.
        probs = self.softmax(logits)
        loss = torch.mean(-torch.sum(self.labels * torch.log(probs), dim=-1))
        return loss
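
# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the upstream file. The 128-dim embeddings
# and the hyperparameter values below are illustrative assumptions only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = torch.device("cpu")
    batch_size = 4
    loss_fn = NTXentLoss(device, batch_size, temperature=0.5, use_cosine_similarity=True)

    # zis / zjs stand in for the projection-head outputs of the two augmented
    # views of the same image batch.
    zis = torch.randn(batch_size, 128)
    zjs = torch.randn(batch_size, 128)
    print(loss_fn(zis, zjs))  # scalar NT-Xent loss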