# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
"""
3-Augment implementation.

Data augmentation (DA) based on the DINO DA
(https://github.com/facebookresearch/dino) and the timm DA
(https://github.com/rwightman/pytorch-image-models).
"""
import torch
from torchvision import transforms

from timm.data.transforms import _pil_interp, RandomResizedCropAndInterpolation, ToNumpy, ToTensor

import numpy as np
from torchvision import datasets, transforms
import random

from PIL import ImageFilter, ImageOps
import torchvision.transforms.functional as TF


class GaussianBlur(object):
    """
    Apply Gaussian Blur to the PIL image.

    With probability ``p`` the image is blurred using a radius drawn uniformly
    from ``[radius_min, radius_max]``; otherwise it is returned unchanged.
    """
    def __init__(self, p=0.1, radius_min=0.1, radius_max=2.):
        self.prob = p
        self.radius_min = radius_min
        self.radius_max = radius_max

    def __call__(self, img):
        # Strict ``<`` (as in the other transforms of this module) so that
        # p=0 can never trigger the blur.
        do_it = random.random() < self.prob
        if not do_it:
            return img

        img = img.filter(
            ImageFilter.GaussianBlur(
                radius=random.uniform(self.radius_min, self.radius_max)
            )
        )
        return img


class Solarization(object):
    """
    Apply Solarization to the PIL image.

    With probability ``p`` the image is passed through ``ImageOps.solarize``
    (with PIL's default threshold); otherwise it is returned unchanged.
    """
    def __init__(self, p=0.2):
        self.p = p

    def __call__(self, img):
        if random.random() < self.p:
            return ImageOps.solarize(img)
        else:
            return img


class gray_scale(object):
    """
    Convert the PIL image to grayscale.

    With probability ``p`` the image is converted with
    ``transforms.Grayscale(3)``, i.e. the gray values are replicated over 3
    output channels so downstream transforms still see a 3-channel image.
    Otherwise the image is returned unchanged.
    """
    def __init__(self, p=0.2):
        self.p = p
        self.transf = transforms.Grayscale(3)

    def __call__(self, img):
        if random.random() < self.p:
            return self.transf(img)
        else:
            return img


class horizontal_flip(object):
    """
    Horizontally flip the image.

    With probability ``p`` the image is flipped (the wrapped
    ``RandomHorizontalFlip`` uses ``p=1.0``, so the gate here is the only
    source of randomness); otherwise it is returned unchanged.

    ``activate_pred`` is accepted but not used; it is kept only so existing
    call sites that pass it keep working.
    """
    def __init__(self, p=0.2, activate_pred=False):
        self.p = p
        self.transf = transforms.RandomHorizontalFlip(p=1.0)

    def __call__(self, img):
        if random.random() < self.p:
            return self.transf(img)
        else:
            return img


def new_data_aug_generator(args = None):
    """
    Build the 3-Augment training transform pipeline.

    Args:
        args: configuration object (required despite the ``None`` default).
            The following attributes are read:
            - ``input_size``: target crop size.
            - ``src``: if truthy, use Resize + reflect-padded RandomCrop
              instead of RandomResizedCrop (presumably "simple random crop";
              confirm against the argument parser).
            - ``color_jitter``: strength passed as brightness, contrast and
              saturation to ``ColorJitter``; ``None`` or 0 disables it.

    Returns:
        A ``transforms.Compose`` that applies, in order:
        1. geometric transforms (crop + random horizontal flip),
        2. exactly one of grayscale / solarization / Gaussian blur, chosen
           uniformly by ``RandomChoice``, then an optional ``ColorJitter``,
        3. ``ToTensor`` and ``Normalize`` with the standard ImageNet
           channel statistics.
    """
    img_size = args.input_size
    remove_random_resized_crop = args.src
    mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

    primary_tfl = []
    scale = (0.08, 1.0)
    interpolation = 'bicubic'
    if remove_random_resized_crop:
        primary_tfl = [
            # interpolation=3 is PIL's BICUBIC constant; newer torchvision
            # versions may warn and prefer InterpolationMode.BICUBIC.
            transforms.Resize(img_size, interpolation=3),
            transforms.RandomCrop(img_size, padding=4, padding_mode='reflect'),
            transforms.RandomHorizontalFlip()
        ]
    else:
        primary_tfl = [
            RandomResizedCropAndInterpolation(
                img_size, scale=scale, interpolation=interpolation),
            transforms.RandomHorizontalFlip()
        ]

    # Each candidate uses p=1.0 so that the randomly chosen one is always
    # applied; RandomChoice provides the randomness.
    secondary_tfl = [transforms.RandomChoice([gray_scale(p=1.0),
                                              Solarization(p=1.0),
                                              GaussianBlur(p=1.0)])]

    if args.color_jitter is not None and args.color_jitter != 0:
        secondary_tfl.append(transforms.ColorJitter(args.color_jitter, args.color_jitter, args.color_jitter))
    final_tfl = [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=torch.tensor(mean),
            std=torch.tensor(std))
    ]
    return transforms.Compose(primary_tfl + secondary_tfl + final_tfl)