mirror of https://github.com/facebookresearch/deit
124 lines
3.4 KiB
Python
124 lines
3.4 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
|
|
"""
|
|
3Augment implementation
|
|
Data-augmentation (DA) based on dino DA (https://github.com/facebookresearch/dino)
|
|
and timm DA (https://github.com/rwightman/pytorch-image-models)
|
|
"""
|
|
import torch
|
|
from torchvision import transforms
|
|
|
|
from timm.data.transforms import _pil_interp, RandomResizedCropAndInterpolation, ToNumpy, ToTensor
|
|
|
|
import numpy as np
|
|
from torchvision import datasets, transforms
|
|
import random
|
|
|
|
|
|
|
|
from PIL import ImageFilter, ImageOps
|
|
import torchvision.transforms.functional as TF
|
|
|
|
|
|
class GaussianBlur(object):
    """
    Randomly blur a PIL image with a Gaussian filter.

    On each call a uniform sample in [0, 1) is compared against ``p``; when
    the sample is less than or equal to ``p`` the image is filtered with
    ``ImageFilter.GaussianBlur`` using a radius drawn uniformly between
    ``radius_min`` and ``radius_max``. Otherwise the input is returned as is.
    """

    def __init__(self, p=0.1, radius_min=0.1, radius_max=2.):
        # Probability of applying the blur (stored under the name ``prob``).
        self.prob = p
        # Bounds of the uniform distribution the blur radius is sampled from.
        self.radius_min = radius_min
        self.radius_max = radius_max

    def __call__(self, img):
        if random.random() <= self.prob:
            # The radius is only sampled when the blur is actually applied.
            radius = random.uniform(self.radius_min, self.radius_max)
            return img.filter(ImageFilter.GaussianBlur(radius=radius))
        return img
|
|
|
|
class Solarization(object):
    """
    Randomly solarize a PIL image.

    On each call the image is passed through ``ImageOps.solarize`` when a
    uniform sample in [0, 1) is strictly below ``p``; otherwise the input is
    returned unchanged.
    """

    def __init__(self, p=0.2):
        # Probability of applying the solarization.
        self.p = p

    def __call__(self, img):
        return ImageOps.solarize(img) if random.random() < self.p else img
|
|
|
|
class gray_scale(object):
    """
    Randomly convert an image to grayscale.

    The conversion is delegated to ``transforms.Grayscale(3)``, built once in
    the constructor. On each call it is applied when a uniform sample in
    [0, 1) is strictly below ``p``; otherwise the input is returned unchanged.
    """

    def __init__(self, p=0.2):
        # Probability of applying the grayscale conversion.
        self.p = p
        # Reused for every call instead of being rebuilt per image.
        self.transf = transforms.Grayscale(3)

    def __call__(self, img):
        apply_it = random.random() < self.p
        return self.transf(img) if apply_it else img
|
|
|
|
|
|
|
|
class horizontal_flip(object):
    """
    Randomly flip an image horizontally.

    The flip itself is performed by ``transforms.RandomHorizontalFlip(p=1.0)``
    (i.e. an unconditional flip); the randomness lives in this wrapper, which
    applies it when a uniform sample in [0, 1) is strictly below ``p`` and
    returns the input unchanged otherwise.

    ``activate_pred`` is accepted by the constructor but is not used by it.
    """

    def __init__(self, p=0.2,activate_pred=False):
        # Probability of applying the flip.
        self.p = p
        # Always-flip transform; the probability gate is in ``__call__``.
        self.transf = transforms.RandomHorizontalFlip(p=1.0)

    def __call__(self, img):
        flip_it = random.random() < self.p
        return self.transf(img) if flip_it else img
|
|
|
|
|
|
|
|
def new_data_aug_generator(args = None):
    """
    Build the 3Augment training transform pipeline.

    The returned pipeline applies, in order:

    1. A cropping stage followed by ``transforms.RandomHorizontalFlip()``.
       When ``args.src`` is truthy the crop is ``Resize`` +
       ``RandomCrop(padding=4, padding_mode='reflect')``; otherwise it is
       ``RandomResizedCropAndInterpolation`` with scale ``(0.08, 1.0)`` and
       bicubic interpolation.
    2. Exactly one of grayscale, solarization or Gaussian blur, selected by
       ``transforms.RandomChoice`` (each built with ``p=1.0`` so the chosen
       one always fires).
    3. ``transforms.ColorJitter`` with the same strength for its first three
       arguments, only when ``args.color_jitter`` is neither ``None`` nor 0.
    4. ``transforms.ToTensor()`` and ``transforms.Normalize`` with a
       hard-coded per-channel mean/std.

    Args:
        args: object exposing ``input_size``, ``src`` and ``color_jitter``
            attributes. Required in practice even though the signature
            defaults it to ``None``.

    Returns:
        A ``transforms.Compose`` holding the stages above.

    Raises:
        AttributeError: if ``args`` is ``None`` or lacks a required attribute.
    """
    if args is None:
        # Dereferencing None would raise AttributeError anyway; raise the
        # same type up front with a message that names the real problem.
        raise AttributeError(
            "new_data_aug_generator requires an args object with "
            "'input_size', 'src' and 'color_jitter' attributes; got None"
        )

    img_size = args.input_size
    remove_random_resized_crop = args.src
    mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

    if remove_random_resized_crop:
        primary_tfl = [
            # 3 is the integer code for bicubic resampling.
            transforms.Resize(img_size, interpolation=3),
            transforms.RandomCrop(img_size, padding=4, padding_mode='reflect'),
            transforms.RandomHorizontalFlip(),
        ]
    else:
        primary_tfl = [
            RandomResizedCropAndInterpolation(
                img_size, scale=(0.08, 1.0), interpolation='bicubic'),
            transforms.RandomHorizontalFlip(),
        ]

    # One of the three augmentations is picked per sample; p=1.0 ensures the
    # picked one is always applied rather than gated a second time.
    secondary_tfl = [
        transforms.RandomChoice([
            gray_scale(p=1.0),
            Solarization(p=1.0),
            GaussianBlur(p=1.0),
        ])
    ]

    jitter = args.color_jitter
    if jitter is not None and jitter != 0:
        secondary_tfl.append(transforms.ColorJitter(jitter, jitter, jitter))

    final_tfl = [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=torch.tensor(mean),
            std=torch.tensor(std)),
    ]

    return transforms.Compose(primary_tfl + secondary_tfl + final_tfl)
|