Updating augmentations, especially RandAugment, to support a full torch.Tensor pipeline

This commit is contained in:
Ross Wightman 2024-12-18 12:24:04 -08:00
parent ea231079f5
commit 3b181b78d1
4 changed files with 159 additions and 98 deletions
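In practice, the change below lets the augmentation ops accept either a PIL.Image or a torch.Tensor. A minimal sketch of the intended tensor usage (the config string, hparams values, and float fill color are illustrative assumptions, not taken from this diff):

import torch
from timm.data.auto_augment import rand_augment_transform

# Float img_mean matches the tensor fill path introduced below; PIL paths keep 0-255 values.
aug = rand_augment_transform(
    'rand-m9-mstd0.5',
    hparams=dict(translate_const=100, img_mean=(0.5, 0.5, 0.5)),
)
img = torch.rand(3, 224, 224)  # CHW float image in [0, 1)
out = aug(img)                 # ops dispatch to torchvision.transforms.functional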

View File

@@ -24,12 +24,18 @@ import random
import math
import re
from functools import partial
from typing import Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union
from PIL import Image, ImageOps, ImageEnhance, ImageChops, ImageFilter
import torch
import PIL
import numpy as np
from PIL import Image, ImageFilter
from torchvision.transforms import InterpolationMode
import torchvision.transforms.functional as TF
try:
import torchvision.transforms.v2.functional as TF2
except ImportError:
TF2 = None
_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]])
@@ -42,118 +48,111 @@ _HPARAMS_DEFAULT = dict(
img_mean=_FILL,
)
if hasattr(Image, "Resampling"):
_RANDOM_INTERPOLATION = (Image.Resampling.BILINEAR, Image.Resampling.BICUBIC)
_DEFAULT_INTERPOLATION = Image.Resampling.BICUBIC
else:
_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
_DEFAULT_INTERPOLATION = Image.BICUBIC
_RANDOM_INTERPOLATION = (InterpolationMode.BILINEAR, InterpolationMode.BICUBIC)
_DEFAULT_INTERPOLATION = InterpolationMode.BICUBIC
def _interpolation(kwargs):
interpolation = kwargs.pop('resample', _DEFAULT_INTERPOLATION)
def _interpolation(kwargs, basic_only=False):
interpolation = kwargs.pop('interpolation', _DEFAULT_INTERPOLATION)
if isinstance(interpolation, (list, tuple)):
return random.choice(interpolation)
interpolation = random.choice(interpolation)
if basic_only:
if interpolation not in (InterpolationMode.NEAREST, InterpolationMode.BILINEAR):
interpolation = InterpolationMode.BILINEAR
return interpolation
def _check_args_tf(kwargs):
if 'fillcolor' in kwargs and _PIL_VER < (5, 0):
kwargs.pop('fillcolor')
kwargs['resample'] = _interpolation(kwargs)
kwargs['interpolation'] = _interpolation(kwargs)
def _check_args_affine(img, kwargs):
if isinstance(img, torch.Tensor):
kwargs['interpolation'] = _interpolation(kwargs, basic_only=True)
else:
kwargs['interpolation'] = _interpolation(kwargs)
def shear_x(img, factor, **kwargs):
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs)
_check_args_affine(img, kwargs)
return TF.affine(img, angle=0, translate=[0, 0], scale=1, shear=[math.degrees(math.atan(factor)), 0], **kwargs)
def shear_y(img, factor, **kwargs):
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs)
def translate_x_rel(img, pct, **kwargs):
pixels = pct * img.size[0]
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
def translate_y_rel(img, pct, **kwargs):
pixels = pct * img.size[1]
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
_check_args_affine(img, kwargs)
return TF.affine(img, angle=0, translate=[0, 0], scale=1, shear=[0, math.degrees(math.atan(factor))], **kwargs)
def translate_x_abs(img, pixels, **kwargs):
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
_check_args_affine(img, kwargs)
return TF.affine(img, angle=0, translate=[pixels, 0], scale=1, shear=[0, 0], **kwargs)
def translate_y_abs(img, pixels, **kwargs):
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
_check_args_affine(img, kwargs)
return TF.affine(img, angle=0, translate=[0, pixels], scale=1, shear=[0, 0], **kwargs)
def translate_x_rel(img, pct, **kwargs):
pixels = pct * TF.get_image_size(img)[0]
return translate_x_abs(img, pixels, **kwargs)
def translate_y_rel(img, pct, **kwargs):
pixels = pct * TF.get_image_size(img)[1]
return translate_y_abs(img, pixels, **kwargs)
def rotate(img, degrees, **kwargs):
_check_args_tf(kwargs)
if _PIL_VER >= (5, 2):
return img.rotate(degrees, **kwargs)
if _PIL_VER >= (5, 0):
w, h = img.size
post_trans = (0, 0)
rotn_center = (w / 2.0, h / 2.0)
angle = -math.radians(degrees)
matrix = [
round(math.cos(angle), 15),
round(math.sin(angle), 15),
0.0,
round(-math.sin(angle), 15),
round(math.cos(angle), 15),
0.0,
]
def transform(x, y, matrix):
(a, b, c, d, e, f) = matrix
return a * x + b * y + c, d * x + e * y + f
matrix[2], matrix[5] = transform(
-rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix
)
matrix[2] += rotn_center[0]
matrix[5] += rotn_center[1]
return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
return img.rotate(degrees, resample=kwargs['resample'])
_check_args_affine(img, kwargs)
return TF.rotate(img, degrees, **kwargs)
def auto_contrast(img, **__):
return ImageOps.autocontrast(img)
return TF.autocontrast(img)
def invert(img, **__):
return ImageOps.invert(img)
return TF.invert(img)
def equalize(img, **__):
return ImageOps.equalize(img)
if isinstance(img, torch.Tensor) and img.is_floating_point():
if TF2 is None:
# FIXME warn / assert?
return img
return TF2.equalize(img)
return TF.equalize(img)
def solarize(img, thresh, **__):
return ImageOps.solarize(img, thresh)
if isinstance(img, torch.Tensor) and img.is_floating_point():
thresh = min(thresh / 255, 1.0)
return TF.solarize(img, thresh)
def solarize_add(img, add, thresh=128, **__):
lut = []
for i in range(256):
if i < thresh:
lut.append(min(255, i + add))
if isinstance(img, torch.Tensor):
if img.is_floating_point():
thresh = thresh / 255
add = add / 255
img_sum = (img + add).clamp_(max=1.0)
else:
lut.append(i)
img_sum = (img + add).clamp_(max=255)
return torch.where(img >= thresh, img, img_sum)  # add the offset only below thresh, matching the LUT branch
else:
lut = []
for i in range(256):
if i < thresh:
lut.append(min(255, i + add))
else:
lut.append(i)
if img.mode in ("L", "RGB"):
if img.mode == "RGB" and len(lut) == 256:
lut = lut + lut + lut
return img.point(lut)
if img.mode in ("L", "RGB"):
if img.mode == "RGB" and len(lut) == 256:
lut = lut + lut + lut
return img.point(lut)
return img
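A quick numeric check of the float-tensor branch above (illustrative, not part of the commit): the 0-255 threshold and offset are rescaled to [0, 1], and only pixels below the threshold receive the clamped offset, mirroring the PIL LUT fallback.

import torch

x = torch.tensor([0.1, 0.4, 0.6, 0.9])   # float pixel values
thresh, add = 128 / 255, 30 / 255        # rescaled from the usual 0-255 range
shifted = (x + add).clamp(max=1.0)
print(torch.where(x >= thresh, x, shifted))
# tensor([0.2176, 0.5176, 0.6000, 0.9000]) -- values at/above ~0.502 stay unchanged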
@@ -161,41 +160,50 @@ def solarize_add(img, add, thresh=128, **__):
def posterize(img, bits_to_keep, **__):
if bits_to_keep >= 8:
return img
return ImageOps.posterize(img, bits_to_keep)
if isinstance(img, torch.Tensor) and img.is_floating_point():
if TF2 is None:
# FIXME warn / assert?
return img
return TF2.posterize(img, bits_to_keep)
return TF.posterize(img, bits_to_keep)
def contrast(img, factor, **__):
return ImageEnhance.Contrast(img).enhance(factor)
return TF.adjust_contrast(img, factor)
def color(img, factor, **__):
return ImageEnhance.Color(img).enhance(factor)
return TF.adjust_saturation(img, factor)
def brightness(img, factor, **__):
return ImageEnhance.Brightness(img).enhance(factor)
return TF.adjust_brightness(img, factor)
def sharpness(img, factor, **__):
return ImageEnhance.Sharpness(img).enhance(factor)
return TF.adjust_sharpness(img, factor)
def gaussian_blur(img, factor, **__):
img = img.filter(ImageFilter.GaussianBlur(radius=factor))
if isinstance(img, torch.Tensor):
kernel_size = 2 * int(3 * factor) + 1 # could be bigger, but more expensive
img = TF.gaussian_blur(img, kernel_size=kernel_size, sigma=factor)
else:
img = img.filter(ImageFilter.GaussianBlur(radius=factor))
return img
def gaussian_blur_rand(img, factor, **__):
radius_min = 0.1
radius_max = 2.0
img = img.filter(ImageFilter.GaussianBlur(radius=random.uniform(radius_min, radius_max * factor)))
return img
radius = random.uniform(radius_min, radius_max * factor)
return gaussian_blur(img, radius)
def desaturate(img, factor, **_):
factor = min(1., max(0., 1. - factor))
# enhance factor 0 = grayscale, 1.0 = no-change
return ImageEnhance.Color(img).enhance(factor)
return TF.adjust_saturation(img, factor)
def _randomly_negate(v):
@@ -356,7 +364,13 @@ NAME_TO_OP = {
class AugmentOp:
def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
def __init__(
self,
name: str,
prob: float = 0.5,
magnitude: float = 10,
hparams: Optional[Dict[str, Any]] = None
):
hparams = hparams or _HPARAMS_DEFAULT
self.name = name
self.aug_fn = NAME_TO_OP[name]
@@ -365,8 +379,8 @@ class AugmentOp:
self.magnitude = magnitude
self.hparams = hparams.copy()
self.kwargs = dict(
fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL,
resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
fill=hparams['img_mean'] if 'img_mean' in hparams else _FILL,
interpolation=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
)
# If magnitude_std is > 0, we introduce some randomness
@@ -564,7 +578,7 @@ def auto_augment_policy(name='v0', hparams=None):
class AutoAugment:
def __init__(self, policy):
def __init__(self, policy: List):
self.policy = policy
def __call__(self, img):
@@ -729,8 +743,14 @@ def rand_augment_ops(
):
hparams = hparams or _HPARAMS_DEFAULT
transforms = transforms or _RAND_TRANSFORMS
return [AugmentOp(
name, prob=prob, magnitude=magnitude, hparams=hparams) for name in transforms]
return [
AugmentOp(
name,
prob=prob,
magnitude=magnitude,
hparams=hparams
) for name in transforms
]
class RandAugment:
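A sketch of how the renamed kwargs flow through a single op after this change (the hparams values here are assumptions for illustration): AugmentOp now forwards fill= and interpolation= instead of fillcolor= and resample=, so the same op can be applied directly to a tensor.

import torch
from torchvision.transforms import InterpolationMode
from timm.data.auto_augment import AugmentOp

op = AugmentOp(
    'Rotate',
    prob=1.0,
    magnitude=9,
    hparams=dict(
        img_mean=(0.5, 0.5, 0.5),                  # passed on as fill=
        interpolation=InterpolationMode.BILINEAR,  # passed on as interpolation=
    ),
)
out = op(torch.rand(3, 224, 224))  # rotate dispatches to TF.rotate for tensors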

View File

@@ -87,7 +87,8 @@ class PrefetchLoader:
re_prob=0.,
re_mode='const',
re_count=1,
re_num_splits=0):
re_num_splits=0,
):
mean = adapt_to_chs(mean, channels)
std = adapt_to_chs(std, channels)

View File

@@ -12,7 +12,8 @@ try:
has_interpolation_mode = True
except ImportError:
has_interpolation_mode = False
from PIL import Image
from PIL import Image, ImageCms
import numpy as np
__all__ = [
@@ -89,6 +90,31 @@ class MaybePILToTensor:
return f"{self.__class__.__name__}()"
class ToLab(transforms.ToTensor):
def __init__(self) -> None:
super().__init__()
rgb_profile = ImageCms.createProfile(colorSpace='sRGB')
lab_profile = ImageCms.createProfile(colorSpace='LAB')
# Create a transform object from the input and output profiles
self.rgb_to_lab_transform = ImageCms.buildTransform(
inputProfile=rgb_profile,
outputProfile=lab_profile,
inMode='RGB',
outMode='LAB'
)
def __call__(self, pic) -> torch.Tensor:
lab_image = ImageCms.applyTransform(
im=pic,
transform=self.rgb_to_lab_transform
)
return lab_image
def __repr__(self) -> str:
return f"{self.__class__.__name__}()"
# Pillow is deprecating the top-level resampling attributes (e.g., Image.BILINEAR) in
# favor of the Image.Resampling enum. The top-level resampling attributes will be
# removed in Pillow 10.
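A usage sketch for the ToLab transform added above (assuming it lives alongside MaybePILToTensor in timm.data.transforms; the input path is a placeholder). Note that ImageCms.applyTransform returns a PIL image in 'LAB' mode, so a tensor conversion still has to follow it in a pipeline:

from PIL import Image
from timm.data.transforms import ToLab

rgb = Image.open('example.jpg').convert('RGB')  # placeholder input
lab = ToLab()(rgb)
print(lab.mode)  # 'LAB'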

View File

@@ -4,6 +4,7 @@ Factory methods for building image transforms for use with TIMM (PyTorch Image M
Hacked together by / Copyright 2019, Ross Wightman
"""
import math
from copy import deepcopy
from typing import Optional, Tuple, Union
import torch
@@ -84,6 +85,7 @@ def transforms_imagenet_train(
use_prefetcher: bool = False,
normalize: bool = True,
separate: bool = False,
use_tensor: Optional[bool] = True, # FIXME forced True for testing
):
""" ImageNet-oriented image transforms for training.
@@ -111,6 +113,7 @@ use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used).
separate: Output transforms in 3-stage tuple.
use_tensor: Use float [0, 1.0) tensors for image transforms instead of PIL images.
Returns:
If separate==True, the transforms are returned as a tuple of 3 separate transforms
@@ -119,13 +122,18 @@
* a portion of the data through the secondary transform
* normalizes and converts the branches above with the third, final transform
"""
if use_tensor:
primary_tfl = [MaybeToTensor()]
else:
primary_tfl = []
train_crop_mode = train_crop_mode or 'rrc'
assert train_crop_mode in {'rrc', 'rkrc', 'rkrr'}
if train_crop_mode in ('rkrc', 'rkrr'):
# FIXME integration of RKR is a WIP
scale = tuple(scale or (0.8, 1.00))
ratio = tuple(ratio or (0.9, 1/.9))
primary_tfl = [
primary_tfl += [
ResizeKeepRatio(
img_size,
interpolation=interpolation,
@@ -142,7 +150,7 @@
else:
scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range
ratio = tuple(ratio or (3. / 4., 4. / 3.)) # default imagenet ratio range
primary_tfl = [
primary_tfl += [
RandomResizedCropAndInterpolation(
img_size,
scale=scale,
@@ -166,9 +174,13 @@
img_size_min = min(img_size)
else:
img_size_min = img_size
if use_tensor:
aa_mean = deepcopy(mean)
else:
aa_mean = tuple([min(255, round(255 * x)) for x in mean])
aa_params = dict(
translate_const=int(img_size_min * 0.45),
img_mean=tuple([min(255, round(255 * x)) for x in mean]),
img_mean=aa_mean,
)
if interpolation and interpolation != 'random':
aa_params['interpolation'] = str_to_pil_interp(interpolation)
@@ -218,10 +230,12 @@ def transforms_imagenet_train(
final_tfl += [ToNumpy()]
elif not normalize:
# when normalize is disabled, convert to tensor without scaling and keep the original dtype
final_tfl += [MaybePILToTensor()]
if not use_tensor:
final_tfl += [MaybePILToTensor()]
else:
if not use_tensor:
final_tfl += [MaybeToTensor()]
final_tfl += [
MaybeToTensor(),
transforms.Normalize(
mean=torch.tensor(mean),
std=torch.tensor(std),
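Finally, a sketch of exercising the updated factory with the tensor path (argument values are illustrative; as the FIXME above notes, use_tensor is currently forced on for testing):

from timm.data.transforms_factory import transforms_imagenet_train

train_tfms = transforms_imagenet_train(
    img_size=224,
    auto_augment='rand-m9-mstd0.5',
    interpolation='bicubic',
    use_tensor=True,  # MaybeToTensor() runs first, so the aug ops see float tensors
)
print(train_tfms)  # Compose([...]) with tensor-based ops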