import cv2
import numpy as np
import torchvision.transforms as transforms


class DataTransform(object):
    """Wrap a transform so that each sample yields two transformed views.

    The wrapped callable is invoked twice on the same input; when it is a
    random augmentation, the two results are generally different.
    """

    def __init__(self, transform):
        """Store the callable that is applied to every sample."""
        self.transform = transform

    def __call__(self, sample):
        """Return a ``(xi, xj)`` tuple of two transformed copies of *sample*."""
        xi = self.transform(sample)
        xj = self.transform(sample)
        return xi, xj


class GaussianBlur(object):
    """Randomly apply an OpenCV Gaussian blur to a sample.

    Implements Gaussian blur as described in the SimCLR paper: the blur is
    applied with a 50% chance and sigma is drawn uniformly from
    ``[min, max)``.
    """

    def __init__(self, kernel_size, min=0.1, max=2.0):
        """Configure the blur.

        Args:
            kernel_size: Side length of the square blur kernel in pixels.
                Even positive values are rounded up to the next odd value
                because ``cv2.GaussianBlur`` rejects even kernel sizes. Zero
                is kept as-is; OpenCV then derives the kernel from sigma.
            min: Lower bound of the sigma range. (Shadows the builtin, but
                the name is part of the public signature.)
            max: Upper bound of the sigma range. (Same note as ``min``.)

        Raises:
            ValueError: If ``kernel_size`` is negative.
        """
        self.min = min
        self.max = max

        kernel_size = int(kernel_size)
        if kernel_size < 0:
            raise ValueError(
                "kernel_size must be non-negative, got {}".format(kernel_size)
            )
        # cv2.GaussianBlur only accepts odd (or zero) kernel dimensions, while
        # callers typically derive the size as a fraction of the image size
        # (e.g. 10% of the crop), which is frequently even.
        if kernel_size > 0 and kernel_size % 2 == 0:
            kernel_size += 1
        self.kernel_size = kernel_size

    def __call__(self, sample):
        """Return *sample* as a NumPy array, blurred with probability 0.5."""
        sample = np.array(sample)

        # blur the image with a 50% chance
        prob = np.random.random_sample()
        if prob < 0.5:
            # Draw sigma uniformly from [self.min, self.max).
            sigma = (self.max - self.min) * np.random.random_sample() + self.min
            sample = cv2.GaussianBlur(
                sample, (self.kernel_size, self.kernel_size), sigma
            )

        return sample


def get_data_transform_opes(s, crop_size):
    """Build the set of data augmentations described in the SimCLR paper.

    Args:
        s: Multiplier applied to every color-jitter argument
            (``0.8 * s`` for the first three, ``0.2 * s`` for the fourth).
        crop_size: Output size passed to ``RandomResizedCrop``; the blur
            kernel size is derived from it as ``int(0.1 * crop_size)``.

    Returns:
        A ``transforms.Compose`` applying, in order: random resized crop,
        random horizontal flip, color jitter (applied with p=0.8), random
        grayscale (p=0.2), random Gaussian blur, and ``ToTensor``.
    """
    color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
    data_transforms = transforms.Compose(
        [
            transforms.RandomResizedCrop(size=crop_size),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([color_jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            # GaussianBlur normalizes the kernel size to an odd value, so any
            # crop_size is safe here.
            GaussianBlur(kernel_size=int(0.1 * crop_size)),
            transforms.ToTensor(),
        ]
    )
    return data_transforms