mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Updating augmentations, esp randaug to support full torch.Tensor pipeline
This commit is contained in:
parent
ea231079f5
commit
3b181b78d1
@ -24,12 +24,18 @@ import random
|
||||
import math
|
||||
import re
|
||||
from functools import partial
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from PIL import Image, ImageOps, ImageEnhance, ImageChops, ImageFilter
|
||||
import torch
|
||||
import PIL
|
||||
import numpy as np
|
||||
|
||||
from PIL import Image, ImageFilter
|
||||
from torchvision.transforms import InterpolationMode
|
||||
import torchvision.transforms.functional as TF
|
||||
try:
|
||||
import torchvision.transforms.v2.functional as TF2
|
||||
except ImportError:
|
||||
TF2 = None
|
||||
|
||||
_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]])
|
||||
|
||||
@ -42,118 +48,111 @@ _HPARAMS_DEFAULT = dict(
|
||||
img_mean=_FILL,
|
||||
)
|
||||
|
||||
if hasattr(Image, "Resampling"):
|
||||
_RANDOM_INTERPOLATION = (Image.Resampling.BILINEAR, Image.Resampling.BICUBIC)
|
||||
_DEFAULT_INTERPOLATION = Image.Resampling.BICUBIC
|
||||
else:
|
||||
_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
|
||||
_DEFAULT_INTERPOLATION = Image.BICUBIC
|
||||
|
||||
_RANDOM_INTERPOLATION = (InterpolationMode.BILINEAR, InterpolationMode.BICUBIC)
|
||||
_DEFAULT_INTERPOLATION = InterpolationMode.BICUBIC
|
||||
|
||||
|
||||
def _interpolation(kwargs):
|
||||
interpolation = kwargs.pop('resample', _DEFAULT_INTERPOLATION)
|
||||
def _interpolation(kwargs, basic_only=False):
|
||||
interpolation = kwargs.pop('interpolation', _DEFAULT_INTERPOLATION)
|
||||
if isinstance(interpolation, (list, tuple)):
|
||||
return random.choice(interpolation)
|
||||
interpolation = random.choice(interpolation)
|
||||
if basic_only:
|
||||
if interpolation not in (InterpolationMode.NEAREST, InterpolationMode.BILINEAR):
|
||||
interpolation = InterpolationMode.BILINEAR
|
||||
return interpolation
|
||||
|
||||
|
||||
def _check_args_tf(kwargs):
|
||||
if 'fillcolor' in kwargs and _PIL_VER < (5, 0):
|
||||
kwargs.pop('fillcolor')
|
||||
kwargs['resample'] = _interpolation(kwargs)
|
||||
kwargs['interpolation'] = _interpolation(kwargs)
|
||||
|
||||
|
||||
def _check_args_affine(img, kwargs):
|
||||
if isinstance(img, torch.Tensor):
|
||||
kwargs['interpolation'] = _interpolation(kwargs, basic_only=True)
|
||||
else:
|
||||
kwargs['interpolation'] = _interpolation(kwargs)
|
||||
|
||||
|
||||
def shear_x(img, factor, **kwargs):
|
||||
_check_args_tf(kwargs)
|
||||
return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs)
|
||||
_check_args_affine(img, kwargs)
|
||||
return TF.affine(img, angle=0, translate=[0, 0], scale=1, shear=[math.degrees(math.atan(factor)), 0], **kwargs)
|
||||
|
||||
|
||||
def shear_y(img, factor, **kwargs):
|
||||
_check_args_tf(kwargs)
|
||||
return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs)
|
||||
|
||||
|
||||
def translate_x_rel(img, pct, **kwargs):
|
||||
pixels = pct * img.size[0]
|
||||
_check_args_tf(kwargs)
|
||||
return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
|
||||
|
||||
|
||||
def translate_y_rel(img, pct, **kwargs):
|
||||
pixels = pct * img.size[1]
|
||||
_check_args_tf(kwargs)
|
||||
return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
|
||||
_check_args_affine(img, kwargs)
|
||||
return TF.affine(img, angle=0, translate=[0, 0], scale=1, shear=[0, math.degrees(math.atan(factor))], **kwargs)
|
||||
|
||||
|
||||
def translate_x_abs(img, pixels, **kwargs):
|
||||
_check_args_tf(kwargs)
|
||||
return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
|
||||
_check_args_affine(img, kwargs)
|
||||
return TF.affine(img, angle=0, translate=[pixels, 0], scale=1, shear=[0, 0], **kwargs)
|
||||
|
||||
|
||||
def translate_y_abs(img, pixels, **kwargs):
|
||||
_check_args_tf(kwargs)
|
||||
return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
|
||||
_check_args_affine(img, kwargs)
|
||||
return TF.affine(img, angle=0, translate=[0, pixels], scale=1, shear=[0, 0], **kwargs)
|
||||
|
||||
|
||||
def translate_x_rel(img, pct, **kwargs):
|
||||
pixels = pct * TF.get_image_size(img)[0]
|
||||
return translate_x_abs(img, pixels, **kwargs)
|
||||
|
||||
|
||||
def translate_y_rel(img, pct, **kwargs):
|
||||
pixels = pct * TF.get_image_size(img)[1]
|
||||
return translate_y_abs(img, pixels, **kwargs)
|
||||
|
||||
|
||||
def rotate(img, degrees, **kwargs):
|
||||
_check_args_tf(kwargs)
|
||||
if _PIL_VER >= (5, 2):
|
||||
return img.rotate(degrees, **kwargs)
|
||||
if _PIL_VER >= (5, 0):
|
||||
w, h = img.size
|
||||
post_trans = (0, 0)
|
||||
rotn_center = (w / 2.0, h / 2.0)
|
||||
angle = -math.radians(degrees)
|
||||
matrix = [
|
||||
round(math.cos(angle), 15),
|
||||
round(math.sin(angle), 15),
|
||||
0.0,
|
||||
round(-math.sin(angle), 15),
|
||||
round(math.cos(angle), 15),
|
||||
0.0,
|
||||
]
|
||||
|
||||
def transform(x, y, matrix):
|
||||
(a, b, c, d, e, f) = matrix
|
||||
return a * x + b * y + c, d * x + e * y + f
|
||||
|
||||
matrix[2], matrix[5] = transform(
|
||||
-rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix
|
||||
)
|
||||
matrix[2] += rotn_center[0]
|
||||
matrix[5] += rotn_center[1]
|
||||
return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
|
||||
return img.rotate(degrees, resample=kwargs['resample'])
|
||||
_check_args_affine(img, kwargs)
|
||||
return TF.rotate(img, degrees, **kwargs)
|
||||
|
||||
|
||||
def auto_contrast(img, **__):
|
||||
return ImageOps.autocontrast(img)
|
||||
return TF.autocontrast(img)
|
||||
|
||||
|
||||
def invert(img, **__):
|
||||
return ImageOps.invert(img)
|
||||
return TF.invert(img)
|
||||
|
||||
|
||||
def equalize(img, **__):
|
||||
return ImageOps.equalize(img)
|
||||
if isinstance(img, torch.Tensor) and img.is_floating_point():
|
||||
if TF2 is None:
|
||||
# FIXME warn / assert?
|
||||
return img
|
||||
return TF2.equalize(img)
|
||||
return TF.equalize(img)
|
||||
|
||||
|
||||
def solarize(img, thresh, **__):
|
||||
return ImageOps.solarize(img, thresh)
|
||||
if isinstance(img, torch.Tensor) and img.is_floating_point():
|
||||
thresh = min(thresh / 255, 1.0)
|
||||
return TF.solarize(img, thresh)
|
||||
|
||||
|
||||
def solarize_add(img, add, thresh=128, **__):
|
||||
lut = []
|
||||
for i in range(256):
|
||||
if i < thresh:
|
||||
lut.append(min(255, i + add))
|
||||
if isinstance(img, torch.Tensor):
|
||||
if img.is_floating_point():
|
||||
thresh = thresh / 255
|
||||
add = add / 255
|
||||
img_sum = (img + add).clamp_(max=1.0)
|
||||
else:
|
||||
lut.append(i)
|
||||
img_sum = (img + add).clamp_(max=255)
|
||||
return torch.where(img >= thresh, img_sum, img)
|
||||
else:
|
||||
lut = []
|
||||
for i in range(256):
|
||||
if i < thresh:
|
||||
lut.append(min(255, i + add))
|
||||
else:
|
||||
lut.append(i)
|
||||
|
||||
if img.mode in ("L", "RGB"):
|
||||
if img.mode == "RGB" and len(lut) == 256:
|
||||
lut = lut + lut + lut
|
||||
return img.point(lut)
|
||||
if img.mode in ("L", "RGB"):
|
||||
if img.mode == "RGB" and len(lut) == 256:
|
||||
lut = lut + lut + lut
|
||||
return img.point(lut)
|
||||
|
||||
return img
|
||||
|
||||
@ -161,41 +160,50 @@ def solarize_add(img, add, thresh=128, **__):
|
||||
def posterize(img, bits_to_keep, **__):
|
||||
if bits_to_keep >= 8:
|
||||
return img
|
||||
return ImageOps.posterize(img, bits_to_keep)
|
||||
if isinstance(img, torch.Tensor) and img.is_floating_point():
|
||||
if TF2 is None:
|
||||
# FIXME warn / assert?
|
||||
return img
|
||||
return TF2.posterize(img, bits_to_keep)
|
||||
return TF.posterize(img, bits_to_keep)
|
||||
|
||||
|
||||
def contrast(img, factor, **__):
|
||||
return ImageEnhance.Contrast(img).enhance(factor)
|
||||
return TF.adjust_contrast(img, factor)
|
||||
|
||||
|
||||
def color(img, factor, **__):
|
||||
return ImageEnhance.Color(img).enhance(factor)
|
||||
return TF.adjust_saturation(img, factor)
|
||||
|
||||
|
||||
def brightness(img, factor, **__):
|
||||
return ImageEnhance.Brightness(img).enhance(factor)
|
||||
return TF.adjust_brightness(img, factor)
|
||||
|
||||
|
||||
def sharpness(img, factor, **__):
|
||||
return ImageEnhance.Sharpness(img).enhance(factor)
|
||||
return TF.adjust_sharpness(img, factor)
|
||||
|
||||
|
||||
def gaussian_blur(img, factor, **__):
|
||||
img = img.filter(ImageFilter.GaussianBlur(radius=factor))
|
||||
if isinstance(img, torch.Tensor):
|
||||
kernel_size = 2 * int(3 * factor) + 1 # could be bigger, but more expensive
|
||||
img = TF.gaussian_blur(img, kernel_size=kernel_size, sigma=factor)
|
||||
else:
|
||||
img = img.filter(ImageFilter.GaussianBlur(radius=factor))
|
||||
return img
|
||||
|
||||
|
||||
def gaussian_blur_rand(img, factor, **__):
|
||||
radius_min = 0.1
|
||||
radius_max = 2.0
|
||||
img = img.filter(ImageFilter.GaussianBlur(radius=random.uniform(radius_min, radius_max * factor)))
|
||||
return img
|
||||
radius = random.uniform(radius_min, radius_max * factor)
|
||||
return gaussian_blur(img, radius)
|
||||
|
||||
|
||||
def desaturate(img, factor, **_):
|
||||
factor = min(1., max(0., 1. - factor))
|
||||
# enhance factor 0 = grayscale, 1.0 = no-change
|
||||
return ImageEnhance.Color(img).enhance(factor)
|
||||
return TF.adjust_saturation(img, factor)
|
||||
|
||||
|
||||
def _randomly_negate(v):
|
||||
@ -356,7 +364,13 @@ NAME_TO_OP = {
|
||||
|
||||
class AugmentOp:
|
||||
|
||||
def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
prob: float = 0.5,
|
||||
magnitude: float = 10,
|
||||
hparams: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
hparams = hparams or _HPARAMS_DEFAULT
|
||||
self.name = name
|
||||
self.aug_fn = NAME_TO_OP[name]
|
||||
@ -365,8 +379,8 @@ class AugmentOp:
|
||||
self.magnitude = magnitude
|
||||
self.hparams = hparams.copy()
|
||||
self.kwargs = dict(
|
||||
fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL,
|
||||
resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
|
||||
fill=hparams['img_mean'] if 'img_mean' in hparams else _FILL,
|
||||
interpolation=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
|
||||
)
|
||||
|
||||
# If magnitude_std is > 0, we introduce some randomness
|
||||
@ -564,7 +578,7 @@ def auto_augment_policy(name='v0', hparams=None):
|
||||
|
||||
class AutoAugment:
|
||||
|
||||
def __init__(self, policy):
|
||||
def __init__(self, policy: List):
|
||||
self.policy = policy
|
||||
|
||||
def __call__(self, img):
|
||||
@ -729,8 +743,14 @@ def rand_augment_ops(
|
||||
):
|
||||
hparams = hparams or _HPARAMS_DEFAULT
|
||||
transforms = transforms or _RAND_TRANSFORMS
|
||||
return [AugmentOp(
|
||||
name, prob=prob, magnitude=magnitude, hparams=hparams) for name in transforms]
|
||||
return [
|
||||
AugmentOp(
|
||||
name,
|
||||
prob=prob,
|
||||
magnitude=magnitude,
|
||||
hparams=hparams
|
||||
) for name in transforms
|
||||
]
|
||||
|
||||
|
||||
class RandAugment:
|
||||
|
@ -87,7 +87,8 @@ class PrefetchLoader:
|
||||
re_prob=0.,
|
||||
re_mode='const',
|
||||
re_count=1,
|
||||
re_num_splits=0):
|
||||
re_num_splits=0,
|
||||
):
|
||||
|
||||
mean = adapt_to_chs(mean, channels)
|
||||
std = adapt_to_chs(std, channels)
|
||||
|
@ -12,7 +12,8 @@ try:
|
||||
has_interpolation_mode = True
|
||||
except ImportError:
|
||||
has_interpolation_mode = False
|
||||
from PIL import Image
|
||||
from PIL import Image, ImageCms
|
||||
|
||||
import numpy as np
|
||||
|
||||
__all__ = [
|
||||
@ -89,6 +90,31 @@ class MaybePILToTensor:
|
||||
return f"{self.__class__.__name__}()"
|
||||
|
||||
|
||||
class ToLab(transforms.ToTensor):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
rgb_profile = ImageCms.createProfile(colorSpace='sRGB')
|
||||
lab_profile = ImageCms.createProfile(colorSpace='LAB')
|
||||
# Create a transform object from the input and output profiles
|
||||
self.rgb_to_lab_transform = ImageCms.buildTransform(
|
||||
inputProfile=rgb_profile,
|
||||
outputProfile=lab_profile,
|
||||
inMode='RGB',
|
||||
outMode='LAB'
|
||||
)
|
||||
|
||||
def __call__(self, pic) -> torch.Tensor:
|
||||
lab_image = ImageCms.applyTransform(
|
||||
im=pic,
|
||||
transform=self.rgb_to_lab_transform
|
||||
)
|
||||
return lab_image
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}()"
|
||||
|
||||
|
||||
# Pillow is deprecating the top-level resampling attributes (e.g., Image.BILINEAR) in
|
||||
# favor of the Image.Resampling enum. The top-level resampling attributes will be
|
||||
# removed in Pillow 10.
|
||||
|
@ -4,6 +4,7 @@ Factory methods for building image transforms for use with TIMM (PyTorch Image M
|
||||
Hacked together by / Copyright 2019, Ross Wightman
|
||||
"""
|
||||
import math
|
||||
from copy import deepcopy
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
@ -84,6 +85,7 @@ def transforms_imagenet_train(
|
||||
use_prefetcher: bool = False,
|
||||
normalize: bool = True,
|
||||
separate: bool = False,
|
||||
use_tensor: Optional[bool] = True, # FIXME forced True for testing
|
||||
):
|
||||
""" ImageNet-oriented image transforms for training.
|
||||
|
||||
@ -111,6 +113,7 @@ def transforms_imagenet_train(
|
||||
use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
|
||||
normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used).
|
||||
separate: Output transforms in 3-stage tuple.
|
||||
use_tensor: Use of float [0, 1.0) tensors for image transforms
|
||||
|
||||
Returns:
|
||||
If separate==True, the transforms are returned as a tuple of 3 separate transforms
|
||||
@ -119,13 +122,18 @@ def transforms_imagenet_train(
|
||||
* a portion of the data through the secondary transform
|
||||
* normalizes and converts the branches above with the third, final transform
|
||||
"""
|
||||
if use_tensor:
|
||||
primary_tfl = [MaybeToTensor()]
|
||||
else:
|
||||
primary_tfl = []
|
||||
|
||||
train_crop_mode = train_crop_mode or 'rrc'
|
||||
assert train_crop_mode in {'rrc', 'rkrc', 'rkrr'}
|
||||
if train_crop_mode in ('rkrc', 'rkrr'):
|
||||
# FIXME integration of RKR is a WIP
|
||||
scale = tuple(scale or (0.8, 1.00))
|
||||
ratio = tuple(ratio or (0.9, 1/.9))
|
||||
primary_tfl = [
|
||||
primary_tfl += [
|
||||
ResizeKeepRatio(
|
||||
img_size,
|
||||
interpolation=interpolation,
|
||||
@ -142,7 +150,7 @@ def transforms_imagenet_train(
|
||||
else:
|
||||
scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range
|
||||
ratio = tuple(ratio or (3. / 4., 4. / 3.)) # default imagenet ratio range
|
||||
primary_tfl = [
|
||||
primary_tfl += [
|
||||
RandomResizedCropAndInterpolation(
|
||||
img_size,
|
||||
scale=scale,
|
||||
@ -166,9 +174,13 @@ def transforms_imagenet_train(
|
||||
img_size_min = min(img_size)
|
||||
else:
|
||||
img_size_min = img_size
|
||||
if use_tensor:
|
||||
aa_mean = deepcopy(mean)
|
||||
else:
|
||||
aa_mean = tuple([min(255, round(255 * x)) for x in mean])
|
||||
aa_params = dict(
|
||||
translate_const=int(img_size_min * 0.45),
|
||||
img_mean=tuple([min(255, round(255 * x)) for x in mean]),
|
||||
img_mean=aa_mean,
|
||||
)
|
||||
if interpolation and interpolation != 'random':
|
||||
aa_params['interpolation'] = str_to_pil_interp(interpolation)
|
||||
@ -218,10 +230,12 @@ def transforms_imagenet_train(
|
||||
final_tfl += [ToNumpy()]
|
||||
elif not normalize:
|
||||
# when normalize disable, converted to tensor without scaling, keeps original dtype
|
||||
final_tfl += [MaybePILToTensor()]
|
||||
if not use_tensor:
|
||||
final_tfl += [MaybePILToTensor()]
|
||||
else:
|
||||
if not use_tensor:
|
||||
final_tfl += [MaybeToTensor()]
|
||||
final_tfl += [
|
||||
MaybeToTensor(),
|
||||
transforms.Normalize(
|
||||
mean=torch.tensor(mean),
|
||||
std=torch.tensor(std),
|
||||
|
Loading…
x
Reference in New Issue
Block a user