2020-03-13 22:56:04 -03:00
|
|
|
import torch
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
class NTXentLoss(torch.nn.Module):
|
|
|
|
|
|
|
|
def __init__(self, device, batch_size, temperature, use_cosine_similarity):
|
|
|
|
super(NTXentLoss, self).__init__()
|
|
|
|
self.batch_size = batch_size
|
|
|
|
self.temperature = temperature
|
|
|
|
self.device = device
|
|
|
|
self.softmax = torch.nn.Softmax(dim=-1)
|
2020-03-15 21:55:00 -03:00
|
|
|
self.mask_samples_from_same_repr = self._get_correlated_mask().type(torch.bool)
|
2020-03-13 22:56:04 -03:00
|
|
|
self.similarity_function = self._get_similarity_function(use_cosine_similarity)
|
2020-03-15 21:55:00 -03:00
|
|
|
self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")
|
2020-03-13 22:56:04 -03:00
|
|
|
|
|
|
|
def _get_similarity_function(self, use_cosine_similarity):
|
|
|
|
if use_cosine_similarity:
|
|
|
|
self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
|
|
|
|
return self._cosine_simililarity
|
|
|
|
else:
|
|
|
|
return self._dot_simililarity
|
|
|
|
|
|
|
|
def _get_correlated_mask(self):
|
2020-03-15 21:55:00 -03:00
|
|
|
diag = np.eye(2 * self.batch_size)
|
|
|
|
l1 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=-self.batch_size)
|
|
|
|
l2 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=self.batch_size)
|
|
|
|
mask = torch.from_numpy((diag + l1 + l2))
|
|
|
|
mask = (1 - mask).type(torch.bool)
|
|
|
|
return mask.to(self.device)
|
2020-03-13 22:56:04 -03:00
|
|
|
|
2020-03-14 07:01:49 -03:00
|
|
|
@staticmethod
|
|
|
|
def _dot_simililarity(x, y):
|
2020-03-13 22:56:04 -03:00
|
|
|
v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
|
|
|
|
# x shape: (N, 1, C)
|
|
|
|
# y shape: (1, C, 2N)
|
|
|
|
# v shape: (N, 2N)
|
|
|
|
return v
|
|
|
|
|
|
|
|
def _cosine_simililarity(self, x, y):
|
|
|
|
# x shape: (N, 1, C)
|
|
|
|
# y shape: (1, 2N, C)
|
|
|
|
# v shape: (N, 2N)
|
|
|
|
v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
|
|
|
|
return v
|
|
|
|
|
|
|
|
def forward(self, zis, zjs):
|
2020-03-15 21:55:00 -03:00
|
|
|
representations = torch.cat([zjs, zis], dim=0)
|
|
|
|
|
|
|
|
similarity_matrix = self.similarity_function(representations, representations)
|
|
|
|
|
|
|
|
# filter out the scores from the positive samples
|
|
|
|
l_pos = torch.diag(similarity_matrix, self.batch_size)
|
|
|
|
r_pos = torch.diag(similarity_matrix, -self.batch_size)
|
|
|
|
positives = torch.cat([l_pos, r_pos]).view(2 * self.batch_size, 1)
|
2020-03-13 22:56:04 -03:00
|
|
|
|
2020-03-15 21:55:00 -03:00
|
|
|
negatives = similarity_matrix[self.mask_samples_from_same_repr].view(2 * self.batch_size, -1)
|
|
|
|
|
|
|
|
logits = torch.cat((positives, negatives), dim=1)
|
2020-03-13 22:56:04 -03:00
|
|
|
logits /= self.temperature
|
|
|
|
|
2020-03-15 21:55:00 -03:00
|
|
|
labels = torch.zeros(2 * self.batch_size).to(self.device).long()
|
|
|
|
loss = self.criterion(logits, labels)
|
|
|
|
|
|
|
|
return loss / (2 * self.batch_size)
|