mirror of https://github.com/sthalles/SimCLR.git
47 lines
1.7 KiB
Python
47 lines
1.7 KiB
Python
import torchvision.transforms as transforms
|
|
import cv2
|
|
import numpy as np
|
|
|
|
|
|
class DataTransform(object):
    """Wrap a transform so that each call yields a pair of outputs.

    The wrapped ``transform`` is invoked twice on the very same input. When
    ``transform`` is stochastic, the two results are two different augmented
    views of one sample.

    Args:
        transform: Any callable taking a single sample.
    """

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, sample):
        """Return ``(transform(sample), transform(sample))``, evaluated left to right."""
        augment = self.transform
        first_view, second_view = augment(sample), augment(sample)
        return first_view, second_view
|
|
|
|
|
|
class GaussianBlur(object):
    """Randomly apply a Gaussian blur, as described in the SimCLR paper.

    Each call first converts the input with ``np.array``. With probability 0.5
    it then blurs the array with ``cv2.GaussianBlur``, using a square kernel of
    side ``kernel_size`` and a sigma drawn uniformly from ``[min, max)``.
    Otherwise the array is returned unblurred. The return value is therefore
    always a NumPy array.

    Args:
        kernel_size: Side length of the square blur kernel. ``cv2.GaussianBlur``
            accepts only positive odd sizes (or 0, meaning "derive from sigma"),
            so a positive even value is rounded up to the next odd integer. The
            pipeline builder in this module passes ``int(0.1 * crop_size)``,
            which is frequently even (e.g. 22 for a 224-pixel crop).
        min: Lower bound of the sigma range.
        max: Upper bound of the sigma range.
    """

    def __init__(self, kernel_size, min=0.1, max=2.0):
        # ``min``/``max`` shadow builtins, but the names are kept so existing
        # keyword callers keep working.
        self.min = min
        self.max = max
        # cv2.GaussianBlur raises on even kernel sizes, so round positive even
        # sizes up to the next odd value. A size of 0 is passed through as-is.
        if kernel_size > 0 and kernel_size % 2 == 0:
            kernel_size += 1
        self.kernel_size = kernel_size

    def __call__(self, sample):
        sample = np.array(sample)

        # blur the image with a 50% chance
        prob = np.random.random_sample()

        if prob < 0.5:
            # sigma is uniform in [self.min, self.max)
            sigma = (self.max - self.min) * np.random.random_sample() + self.min
            sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma)

        return sample
|
|
|
|
|
|
def get_data_transform_opes(s, crop_size):
    """Build the SimCLR data-augmentation pipeline.

    Args:
        s: Color-jitter strength. Brightness, contrast and saturation use
            ``0.8 * s``, and hue uses ``0.2 * s``.
        crop_size: Output size passed to ``RandomResizedCrop``. The blur
            kernel size is derived from it as ``int(0.1 * crop_size)``.

    Returns:
        A ``transforms.Compose`` that applies, in order: random resized crop,
        random horizontal flip, color jitter (wrapped in ``RandomApply`` with
        p=0.8), random grayscale (p=0.2), ``GaussianBlur``, and ``ToTensor``.
    """
    strength = 0.8 * s
    jitter = transforms.ColorJitter(strength, strength, strength, 0.2 * s)
    blur_kernel = int(0.1 * crop_size)

    pipeline = [
        transforms.RandomResizedCrop(size=crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([jitter], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        GaussianBlur(kernel_size=blur_kernel),
        transforms.ToTensor(),
    ]
    return transforms.Compose(pipeline)
|