SimCLR/utils.py

63 lines
2.1 KiB
Python
Raw Normal View History

2020-02-17 16:05:44 -03:00
import numpy as np
2020-02-24 18:23:44 -03:00
import torch
2020-03-10 10:35:11 -03:00
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
2020-02-17 16:05:44 -03:00
# Seed NumPy's global RNG so that the index shuffle performed in
# get_train_validation_data_loaders is reproducible across runs.
np.random.seed(0)
2020-03-10 20:41:59 -03:00
# Shared cosine-similarity module that reduces over the last dimension;
# _cosine_simililarity_dim2 applies it to broadcast operands.
cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
2020-02-17 16:05:44 -03:00
2020-03-12 22:34:21 -03:00
def get_train_validation_data_loaders(train_dataset, batch_size, num_workers, valid_size, **ignored):
    """Split one dataset into a training loader and a validation loader.

    The dataset indices are shuffled with NumPy's global RNG; the first
    ``floor(valid_size * len(train_dataset))`` shuffled indices form the
    validation subset and the remainder form the training subset. Both
    loaders read from ``train_dataset`` through a ``SubsetRandomSampler`` and
    drop the last incomplete batch.

    Args:
        train_dataset: Dataset to split; must support ``len()``.
        batch_size: Number of samples per batch for both loaders.
        num_workers: Worker process count forwarded to both loaders.
        valid_size: Fraction of the dataset reserved for validation.
        **ignored: Extra keyword arguments are accepted and discarded.

    Returns:
        A ``(train_loader, valid_loader)`` tuple.
    """
    total = len(train_dataset)
    shuffled = list(range(total))
    np.random.shuffle(shuffled)

    # The leading slice of the shuffled indices is held out for validation.
    n_valid = int(np.floor(valid_size * total))
    valid_idx = shuffled[:n_valid]
    train_idx = shuffled[n_valid:]

    # Settings shared by both loaders.
    shared = dict(batch_size=batch_size, num_workers=num_workers, drop_last=True)

    # Ordering comes from the sampler, so loader-level shuffling stays off.
    train_loader = DataLoader(
        train_dataset,
        sampler=SubsetRandomSampler(train_idx),
        shuffle=False,
        **shared,
    )
    valid_loader = DataLoader(
        train_dataset,
        sampler=SubsetRandomSampler(valid_idx),
        **shared,
    )
    return train_loader, valid_loader
2020-02-24 18:23:44 -03:00
def get_negative_mask(batch_size):
    """Return a boolean mask that keeps only negative-pair similarity scores.

    The mask has shape ``(batch_size, 2 * batch_size)``. For every row ``i``
    the entries at columns ``i`` and ``i + batch_size`` are ``False`` so the
    similarity scores of equal/similar images are removed; every other entry
    is ``True``, meaning only distinct pairs of images are passed on as
    negative examples.

    Args:
        batch_size: Number of rows of the mask (half its number of columns).

    Returns:
        A ``torch.bool`` tensor of shape ``(batch_size, 2 * batch_size)``.
    """
    # Two identity blocks side by side mark columns i and i + batch_size in
    # row i; inverting them yields the negatives. This replaces a Python loop
    # of per-element tensor writes with a single vectorized construction.
    same_image = torch.eye(batch_size, dtype=torch.bool).repeat(1, 2)
    return ~same_image
2020-02-29 07:53:14 -03:00
def _dot_simililarity_dim2(x, y):
    """Return the matrix of pairwise dot products between rows of x and y.

    For ``x`` of shape (N, C) and ``y`` of shape (2N, C), entry ``(i, j)`` of
    the result is the dot product of ``x[i]`` and ``y[j]``.
    """
    lhs = x.unsqueeze(1)    # (N, 1, C)
    rhs = y.T.unsqueeze(0)  # (1, C, 2N)
    # Contract the last two axes of lhs with the first two axes of rhs,
    # leaving a result of shape (N, 2N).
    return torch.tensordot(lhs, rhs, dims=2)
def _cosine_simililarity_dim2(x, y):
    """Return the matrix of pairwise cosine similarities between rows of x and y.

    For ``x`` of shape (N, C) and ``y`` of shape (2N, C) the result has shape
    (N, 2N).
    """
    # Insert singleton axes so the operands broadcast over every (row of x,
    # row of y) pair: (N, 1, C) against (1, 2N, C).
    lhs = x.unsqueeze(1)
    rhs = y.unsqueeze(0)
    # The module-level operator reduces over the last (C) axis.
    return cosine_similarity(lhs, rhs)
def get_similarity_function(use_cosine_similarity):
    """Select the pairwise similarity function.

    Args:
        use_cosine_similarity: When truthy, pick the cosine-similarity
            implementation; otherwise pick the dot-product implementation.

    Returns:
        The chosen function, which takes ``(x, y)`` and returns the pairwise
        similarity matrix.
    """
    return _cosine_simililarity_dim2 if use_cosine_similarity else _dot_simililarity_dim2