import cv2
import numpy as np
import torch
import torchvision.transforms as transforms

# Fix the NumPy seed so the random augmentation choices below are reproducible.
np.random.seed(0)

# Cosine similarity modules reused by the similarity helpers at the bottom of this file.
cos1d = torch.nn.CosineSimilarity(dim=1)
cos2d = torch.nn.CosineSimilarity(dim=2)


def get_negative_mask(batch_size):
    # Return a mask that removes the similarity scores of equal/similar images, ensuring
    # that only distinct pairs of images get their similarity scores passed as negatives.
    negative_mask = torch.ones((batch_size, 2 * batch_size), dtype=torch.bool)
    for i in range(batch_size):
        negative_mask[i, i] = 0
        negative_mask[i, i + batch_size] = 0
    return negative_mask
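

# Illustrative sanity check, not part of the original module: for batch_size=2 the mask has
# shape (2, 4) and, for each anchor i, zeroes out columns i and i + batch_size (the image
# itself and its paired view), so only true negatives remain True.
def _demo_negative_mask():
    mask = get_negative_mask(2)
    print(mask)
    # Expected (up to formatting):
    # tensor([[False,  True, False,  True],
    #         [ True, False,  True, False]])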
class GaussianBlur(object):
    # Implements Gaussian blur as described in the SimCLR paper.
    def __init__(self, min=0.1, max=2.0, kernel_size=9):
        self.min = min  # lower bound for the blur sigma
        self.max = max  # upper bound for the blur sigma
        self.kernel_size = kernel_size

    def __call__(self, sample):
        sample = np.array(sample)

        # Blur the image with a 50% chance.
        prob = np.random.random_sample()

        if prob < 0.5:
            sigma = (self.max - self.min) * np.random.random_sample() + self.min
            sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma)

        return sample
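

# Illustrative usage sketch, not part of the original module: GaussianBlur accepts an
# image-like input (e.g. a PIL image or an HxWxC uint8 array) and returns a NumPy array,
# blurred with probability 0.5 and a sigma drawn uniformly from [min, max].
def _demo_gaussian_blur():
    image = (np.random.rand(96, 96, 3) * 255).astype(np.uint8)  # dummy RGB image
    blurred = GaussianBlur(kernel_size=9)(image)
    print(blurred.shape)  # (96, 96, 3)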
def get_augmentation_transform(s, crop_size):
    # Get a set of data augmentation transformations as described in the SimCLR paper:
    # random crop, horizontal flip, color jitter of strength s, grayscale and Gaussian blur.
    color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
    data_aug_ope = transforms.Compose([transforms.ToPILImage(),
                                       transforms.RandomResizedCrop(crop_size),
                                       transforms.RandomHorizontalFlip(),
                                       transforms.RandomApply([color_jitter], p=0.8),
                                       transforms.RandomGrayscale(p=0.2),
                                       GaussianBlur(),
                                       transforms.ToTensor()])
    return data_aug_ope
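

# Illustrative usage sketch, not part of the original module: applying the transform twice
# to the same image yields the two correlated views (the "positive pair") used by SimCLR.
# The jitter strength s=1 and crop_size=32 below are arbitrary example values.
def _demo_augmentation_pair():
    image = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)  # dummy RGB image
    transform = get_augmentation_transform(s=1, crop_size=32)
    xi, xj = transform(image), transform(image)  # two independently augmented views
    print(xi.shape, xj.shape)  # torch.Size([3, 32, 32]) torch.Size([3, 32, 32])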
def _dot_simililarity_dim1(x, y):
    # (N, C) x (N, C) -> (N, 1, 1): dot product between corresponding rows of x and y.
    v = torch.bmm(x.unsqueeze(1), y.unsqueeze(2))
    return v
def _dot_simililarity_dim2(x, y):
    # (N, C) x (M, C) -> (N, M): dot products between every row of x and every row of y.
    v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
    return v
def _cosine_simililarity_dim1(x, y):
    # (N, C) x (N, C) -> (N,): cosine similarity between corresponding rows of x and y.
    v = cos1d(x, y)
    return v
def _cosine_simililarity_dim2(x, y):
    # (N, C) x (M, C) -> (N, M): cosine similarity between every row of x and every row of y.
    v = cos2d(x.unsqueeze(1), y.unsqueeze(0))
    return v
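

# Illustrative sanity check, not part of the original module: for a batch of N feature
# vectors of dimension C, the *_dim1 helpers score corresponding rows of x and y, while
# the *_dim2 helpers score every row of x against every row of y.
def _demo_similarity_shapes():
    x, y = torch.randn(4, 8), torch.randn(4, 8)
    print(_dot_simililarity_dim1(x, y).shape)     # torch.Size([4, 1, 1])
    print(_dot_simililarity_dim2(x, y).shape)     # torch.Size([4, 4])
    print(_cosine_simililarity_dim1(x, y).shape)  # torch.Size([4])
    print(_cosine_simililarity_dim2(x, y).shape)  # torch.Size([4, 4])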
def get_similarity_function(use_cosine_similarity):
    # Return a pair of similarity functions: one for matched rows (dim1), one for all pairs (dim2).
    if use_cosine_similarity:
        return _cosine_simililarity_dim1, _cosine_simililarity_dim2
    else:
        return _dot_simililarity_dim1, _dot_simililarity_dim2
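

# Illustrative usage sketch, not part of the original module: one way these helpers could be
# combined when building a contrastive (NT-Xent style) loss. zis and zjs are hypothetical
# placeholder names for the projections of the two augmented views of a batch.
def _demo_contrastive_setup(batch_size=4, feature_dim=8):
    zis = torch.randn(batch_size, feature_dim)
    zjs = torch.randn(batch_size, feature_dim)
    _, pairwise_sim = get_similarity_function(use_cosine_similarity=True)
    representations = torch.cat([zjs, zis], dim=0)           # (2N, C)
    similarity_matrix = pairwise_sim(zis, representations)   # (N, 2N)
    # Keep only the entries marked True by the negative mask: 2N - 2 negatives per anchor.
    negatives = similarity_matrix[get_negative_mask(batch_size)].view(batch_size, -1)
    print(negatives.shape)  # torch.Size([4, 6])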