mirror of https://github.com/sthalles/SimCLR.git
fix loss function labels
parent
cfced8c7c4
commit
4c056cb919
31
simclr.py
31
simclr.py
|
@ -34,19 +34,38 @@ class SimCLR(object):
|
|||
self.criterion = torch.nn.CrossEntropyLoss().to(self.args.device)
|
||||
|
||||
def info_nce_loss(self, features):
|
||||
batch_targets = torch.arange(self.args.batch_size, dtype=torch.long).to(self.args.device)
|
||||
batch_targets = torch.cat(self.args.n_views * [batch_targets])
|
||||
|
||||
labels = torch.cat([torch.arange(self.args.batch_size) for i in range(self.args.n_views)], dim=0)
|
||||
labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
|
||||
labels = labels.to(self.args.device)
|
||||
|
||||
features = F.normalize(features, dim=1)
|
||||
|
||||
similarity_matrix = torch.matmul(features, features.T)
|
||||
# assert similarity_matrix.shape == (
|
||||
# self.args.n_views * self.args.batch_size, self.args.n_views * self.args.batch_size)
|
||||
# assert similarity_matrix.shape == labels.shape
|
||||
|
||||
mask = torch.eye(len(batch_targets)).to(self.args.device)
|
||||
similarities = similarity_matrix[~mask.bool()].view(similarity_matrix.shape[0], -1)
|
||||
similarities = similarities / self.args.temperature
|
||||
return similarities, batch_targets
|
||||
# discard the main diagonal from both the labels and the similarity matrix
|
||||
mask = torch.eye(labels.shape[0], dtype=torch.bool).to(self.args.device)
|
||||
labels = labels[~mask].view(labels.shape[0], -1)
|
||||
similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)
|
||||
# assert similarity_matrix.shape == labels.shape
|
||||
|
||||
# select and combine multiple positives
|
||||
positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)
|
||||
|
||||
# if there is more than one positive per row (n_views > 2), combine the multiple positives
|
||||
positives = positives.mean(dim=1).unsqueeze(1)
|
||||
|
||||
# select only the negatives
|
||||
negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)
|
||||
|
||||
logits = torch.cat([positives, negatives], dim=1)
|
||||
labels = torch.zeros(logits.shape[0]).to(self.args.device)
|
||||
|
||||
logits = logits / self.args.temperature
|
||||
return logits, labels
|
||||
|
||||
def train(self, train_loader):
|
||||
|
||||
|
|
Loading…
Reference in New Issue