Updating augmentations, esp randaug to support full torch.Tensor pipeline

This commit is contained in:
Ross Wightman 2024-12-18 12:24:04 -08:00
parent ea231079f5
commit 3b181b78d1
4 changed files with 159 additions and 98 deletions

View File

@ -24,12 +24,18 @@ import random
import math import math
import re import re
from functools import partial from functools import partial
from typing import Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
from PIL import Image, ImageOps, ImageEnhance, ImageChops, ImageFilter import torch
import PIL import PIL
import numpy as np import numpy as np
from PIL import Image, ImageFilter
from torchvision.transforms import InterpolationMode
import torchvision.transforms.functional as TF
try:
import torchvision.transforms.v2.functional as TF2
except ImportError:
TF2 = None
_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]]) _PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]])
@ -42,118 +48,111 @@ _HPARAMS_DEFAULT = dict(
img_mean=_FILL, img_mean=_FILL,
) )
if hasattr(Image, "Resampling"):
_RANDOM_INTERPOLATION = (Image.Resampling.BILINEAR, Image.Resampling.BICUBIC) _RANDOM_INTERPOLATION = (InterpolationMode.BILINEAR, InterpolationMode.BICUBIC)
_DEFAULT_INTERPOLATION = Image.Resampling.BICUBIC _DEFAULT_INTERPOLATION = InterpolationMode.BICUBIC
else:
_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
_DEFAULT_INTERPOLATION = Image.BICUBIC
def _interpolation(kwargs): def _interpolation(kwargs, basic_only=False):
interpolation = kwargs.pop('resample', _DEFAULT_INTERPOLATION) interpolation = kwargs.pop('interpolation', _DEFAULT_INTERPOLATION)
if isinstance(interpolation, (list, tuple)): if isinstance(interpolation, (list, tuple)):
return random.choice(interpolation) interpolation = random.choice(interpolation)
if basic_only:
if interpolation not in (InterpolationMode.NEAREST, InterpolationMode.BILINEAR):
interpolation = InterpolationMode.BILINEAR
return interpolation return interpolation
def _check_args_tf(kwargs): def _check_args_tf(kwargs):
if 'fillcolor' in kwargs and _PIL_VER < (5, 0): kwargs['interpolation'] = _interpolation(kwargs)
kwargs.pop('fillcolor')
kwargs['resample'] = _interpolation(kwargs)
def _check_args_affine(img, kwargs):
if isinstance(img, torch.Tensor):
kwargs['interpolation'] = _interpolation(kwargs, basic_only=True)
else:
kwargs['interpolation'] = _interpolation(kwargs)
def shear_x(img, factor, **kwargs): def shear_x(img, factor, **kwargs):
_check_args_tf(kwargs) _check_args_affine(img, kwargs)
return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs) return TF.affine(img, angle=0, translate=[0, 0], scale=1, shear=[math.degrees(math.atan(factor)), 0], **kwargs)
def shear_y(img, factor, **kwargs): def shear_y(img, factor, **kwargs):
_check_args_tf(kwargs) _check_args_affine(img, kwargs)
return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs) return TF.affine(img, angle=0, translate=[0, 0], scale=1, shear=[0, math.degrees(math.atan(factor))], **kwargs)
def translate_x_rel(img, pct, **kwargs):
pixels = pct * img.size[0]
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
def translate_y_rel(img, pct, **kwargs):
pixels = pct * img.size[1]
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
def translate_x_abs(img, pixels, **kwargs): def translate_x_abs(img, pixels, **kwargs):
_check_args_tf(kwargs) _check_args_affine(img, kwargs)
return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) return TF.affine(img, angle=0, translate=[pixels, 0], scale=1, shear=[0, 0], **kwargs)
def translate_y_abs(img, pixels, **kwargs): def translate_y_abs(img, pixels, **kwargs):
_check_args_tf(kwargs) _check_args_affine(img, kwargs)
return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) return TF.affine(img, angle=0, translate=[0, pixels], scale=1, shear=[0, 0], **kwargs)
def translate_x_rel(img, pct, **kwargs):
pixels = pct * TF.get_image_size(img)[0]
return translate_x_abs(img, pixels, **kwargs)
def translate_y_rel(img, pct, **kwargs):
pixels = pct * TF.get_image_size(img)[1]
return translate_y_abs(img, pixels, **kwargs)
def rotate(img, degrees, **kwargs): def rotate(img, degrees, **kwargs):
_check_args_tf(kwargs) _check_args_affine(img, kwargs)
if _PIL_VER >= (5, 2): return TF.rotate(img, degrees, **kwargs)
return img.rotate(degrees, **kwargs)
if _PIL_VER >= (5, 0):
w, h = img.size
post_trans = (0, 0)
rotn_center = (w / 2.0, h / 2.0)
angle = -math.radians(degrees)
matrix = [
round(math.cos(angle), 15),
round(math.sin(angle), 15),
0.0,
round(-math.sin(angle), 15),
round(math.cos(angle), 15),
0.0,
]
def transform(x, y, matrix):
(a, b, c, d, e, f) = matrix
return a * x + b * y + c, d * x + e * y + f
matrix[2], matrix[5] = transform(
-rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix
)
matrix[2] += rotn_center[0]
matrix[5] += rotn_center[1]
return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
return img.rotate(degrees, resample=kwargs['resample'])
def auto_contrast(img, **__): def auto_contrast(img, **__):
return ImageOps.autocontrast(img) return TF.autocontrast(img)
def invert(img, **__): def invert(img, **__):
return ImageOps.invert(img) return TF.invert(img)
def equalize(img, **__): def equalize(img, **__):
return ImageOps.equalize(img) if isinstance(img, torch.Tensor) and img.is_floating_point():
if TF2 is None:
# FIXME warn / assert?
return img
return TF2.equalize(img)
return TF.equalize(img)
def solarize(img, thresh, **__): def solarize(img, thresh, **__):
return ImageOps.solarize(img, thresh) if isinstance(img, torch.Tensor) and img.is_floating_point():
thresh = min(thresh / 255, 1.0)
return TF.solarize(img, thresh)
def solarize_add(img, add, thresh=128, **__): def solarize_add(img, add, thresh=128, **__):
lut = [] if isinstance(img, torch.Tensor):
for i in range(256): if img.is_floating_point():
if i < thresh: thresh = thresh / 255
lut.append(min(255, i + add)) add = add / 255
img_sum = (img + add).clamp_(max=1.0)
else: else:
lut.append(i) img_sum = (img + add).clamp_(max=255)
return torch.where(img >= thresh, img_sum, img)
else:
lut = []
for i in range(256):
if i < thresh:
lut.append(min(255, i + add))
else:
lut.append(i)
if img.mode in ("L", "RGB"): if img.mode in ("L", "RGB"):
if img.mode == "RGB" and len(lut) == 256: if img.mode == "RGB" and len(lut) == 256:
lut = lut + lut + lut lut = lut + lut + lut
return img.point(lut) return img.point(lut)
return img return img
@ -161,41 +160,50 @@ def solarize_add(img, add, thresh=128, **__):
def posterize(img, bits_to_keep, **__): def posterize(img, bits_to_keep, **__):
if bits_to_keep >= 8: if bits_to_keep >= 8:
return img return img
return ImageOps.posterize(img, bits_to_keep) if isinstance(img, torch.Tensor) and img.is_floating_point():
if TF2 is None:
# FIXME warn / assert?
return img
return TF2.posterize(img, bits_to_keep)
return TF.posterize(img, bits_to_keep)
def contrast(img, factor, **__): def contrast(img, factor, **__):
return ImageEnhance.Contrast(img).enhance(factor) return TF.adjust_contrast(img, factor)
def color(img, factor, **__): def color(img, factor, **__):
return ImageEnhance.Color(img).enhance(factor) return TF.adjust_saturation(img, factor)
def brightness(img, factor, **__): def brightness(img, factor, **__):
return ImageEnhance.Brightness(img).enhance(factor) return TF.adjust_brightness(img, factor)
def sharpness(img, factor, **__): def sharpness(img, factor, **__):
return ImageEnhance.Sharpness(img).enhance(factor) return TF.adjust_sharpness(img, factor)
def gaussian_blur(img, factor, **__): def gaussian_blur(img, factor, **__):
img = img.filter(ImageFilter.GaussianBlur(radius=factor)) if isinstance(img, torch.Tensor):
kernel_size = 2 * int(3 * factor) + 1 # could be bigger, but more expensive
img = TF.gaussian_blur(img, kernel_size=kernel_size, sigma=factor)
else:
img = img.filter(ImageFilter.GaussianBlur(radius=factor))
return img return img
def gaussian_blur_rand(img, factor, **__): def gaussian_blur_rand(img, factor, **__):
radius_min = 0.1 radius_min = 0.1
radius_max = 2.0 radius_max = 2.0
img = img.filter(ImageFilter.GaussianBlur(radius=random.uniform(radius_min, radius_max * factor))) radius = random.uniform(radius_min, radius_max * factor)
return img return gaussian_blur(img, radius)
def desaturate(img, factor, **_): def desaturate(img, factor, **_):
factor = min(1., max(0., 1. - factor)) factor = min(1., max(0., 1. - factor))
# enhance factor 0 = grayscale, 1.0 = no-change # enhance factor 0 = grayscale, 1.0 = no-change
return ImageEnhance.Color(img).enhance(factor) return TF.adjust_saturation(img, factor)
def _randomly_negate(v): def _randomly_negate(v):
@ -356,7 +364,13 @@ NAME_TO_OP = {
class AugmentOp: class AugmentOp:
def __init__(self, name, prob=0.5, magnitude=10, hparams=None): def __init__(
self,
name: str,
prob: float = 0.5,
magnitude: float = 10,
hparams: Optional[Dict[str, Any]] = None
):
hparams = hparams or _HPARAMS_DEFAULT hparams = hparams or _HPARAMS_DEFAULT
self.name = name self.name = name
self.aug_fn = NAME_TO_OP[name] self.aug_fn = NAME_TO_OP[name]
@ -365,8 +379,8 @@ class AugmentOp:
self.magnitude = magnitude self.magnitude = magnitude
self.hparams = hparams.copy() self.hparams = hparams.copy()
self.kwargs = dict( self.kwargs = dict(
fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL, fill=hparams['img_mean'] if 'img_mean' in hparams else _FILL,
resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION, interpolation=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
) )
# If magnitude_std is > 0, we introduce some randomness # If magnitude_std is > 0, we introduce some randomness
@ -564,7 +578,7 @@ def auto_augment_policy(name='v0', hparams=None):
class AutoAugment: class AutoAugment:
def __init__(self, policy): def __init__(self, policy: List):
self.policy = policy self.policy = policy
def __call__(self, img): def __call__(self, img):
@ -729,8 +743,14 @@ def rand_augment_ops(
): ):
hparams = hparams or _HPARAMS_DEFAULT hparams = hparams or _HPARAMS_DEFAULT
transforms = transforms or _RAND_TRANSFORMS transforms = transforms or _RAND_TRANSFORMS
return [AugmentOp( return [
name, prob=prob, magnitude=magnitude, hparams=hparams) for name in transforms] AugmentOp(
name,
prob=prob,
magnitude=magnitude,
hparams=hparams
) for name in transforms
]
class RandAugment: class RandAugment:

View File

@ -87,7 +87,8 @@ class PrefetchLoader:
re_prob=0., re_prob=0.,
re_mode='const', re_mode='const',
re_count=1, re_count=1,
re_num_splits=0): re_num_splits=0,
):
mean = adapt_to_chs(mean, channels) mean = adapt_to_chs(mean, channels)
std = adapt_to_chs(std, channels) std = adapt_to_chs(std, channels)

View File

@ -12,7 +12,8 @@ try:
has_interpolation_mode = True has_interpolation_mode = True
except ImportError: except ImportError:
has_interpolation_mode = False has_interpolation_mode = False
from PIL import Image from PIL import Image, ImageCms
import numpy as np import numpy as np
__all__ = [ __all__ = [
@ -89,6 +90,31 @@ class MaybePILToTensor:
return f"{self.__class__.__name__}()" return f"{self.__class__.__name__}()"
class ToLab(transforms.ToTensor):
def __init__(self) -> None:
super().__init__()
rgb_profile = ImageCms.createProfile(colorSpace='sRGB')
lab_profile = ImageCms.createProfile(colorSpace='LAB')
# Create a transform object from the input and output profiles
self.rgb_to_lab_transform = ImageCms.buildTransform(
inputProfile=rgb_profile,
outputProfile=lab_profile,
inMode='RGB',
outMode='LAB'
)
def __call__(self, pic) -> torch.Tensor:
lab_image = ImageCms.applyTransform(
im=pic,
transform=self.rgb_to_lab_transform
)
return lab_image
def __repr__(self) -> str:
return f"{self.__class__.__name__}()"
# Pillow is deprecating the top-level resampling attributes (e.g., Image.BILINEAR) in # Pillow is deprecating the top-level resampling attributes (e.g., Image.BILINEAR) in
# favor of the Image.Resampling enum. The top-level resampling attributes will be # favor of the Image.Resampling enum. The top-level resampling attributes will be
# removed in Pillow 10. # removed in Pillow 10.

View File

@ -4,6 +4,7 @@ Factory methods for building image transforms for use with TIMM (PyTorch Image M
Hacked together by / Copyright 2019, Ross Wightman Hacked together by / Copyright 2019, Ross Wightman
""" """
import math import math
from copy import deepcopy
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
import torch import torch
@ -84,6 +85,7 @@ def transforms_imagenet_train(
use_prefetcher: bool = False, use_prefetcher: bool = False,
normalize: bool = True, normalize: bool = True,
separate: bool = False, separate: bool = False,
use_tensor: Optional[bool] = True, # FIXME forced True for testing
): ):
""" ImageNet-oriented image transforms for training. """ ImageNet-oriented image transforms for training.
@ -111,6 +113,7 @@ def transforms_imagenet_train(
use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize. use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used). normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used).
separate: Output transforms in 3-stage tuple. separate: Output transforms in 3-stage tuple.
use_tensor: Use of float [0, 1.0) tensors for image transforms
Returns: Returns:
If separate==True, the transforms are returned as a tuple of 3 separate transforms If separate==True, the transforms are returned as a tuple of 3 separate transforms
@ -119,13 +122,18 @@ def transforms_imagenet_train(
* a portion of the data through the secondary transform * a portion of the data through the secondary transform
* normalizes and converts the branches above with the third, final transform * normalizes and converts the branches above with the third, final transform
""" """
if use_tensor:
primary_tfl = [MaybeToTensor()]
else:
primary_tfl = []
train_crop_mode = train_crop_mode or 'rrc' train_crop_mode = train_crop_mode or 'rrc'
assert train_crop_mode in {'rrc', 'rkrc', 'rkrr'} assert train_crop_mode in {'rrc', 'rkrc', 'rkrr'}
if train_crop_mode in ('rkrc', 'rkrr'): if train_crop_mode in ('rkrc', 'rkrr'):
# FIXME integration of RKR is a WIP # FIXME integration of RKR is a WIP
scale = tuple(scale or (0.8, 1.00)) scale = tuple(scale or (0.8, 1.00))
ratio = tuple(ratio or (0.9, 1/.9)) ratio = tuple(ratio or (0.9, 1/.9))
primary_tfl = [ primary_tfl += [
ResizeKeepRatio( ResizeKeepRatio(
img_size, img_size,
interpolation=interpolation, interpolation=interpolation,
@ -142,7 +150,7 @@ def transforms_imagenet_train(
else: else:
scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range
ratio = tuple(ratio or (3. / 4., 4. / 3.)) # default imagenet ratio range ratio = tuple(ratio or (3. / 4., 4. / 3.)) # default imagenet ratio range
primary_tfl = [ primary_tfl += [
RandomResizedCropAndInterpolation( RandomResizedCropAndInterpolation(
img_size, img_size,
scale=scale, scale=scale,
@ -166,9 +174,13 @@ def transforms_imagenet_train(
img_size_min = min(img_size) img_size_min = min(img_size)
else: else:
img_size_min = img_size img_size_min = img_size
if use_tensor:
aa_mean = deepcopy(mean)
else:
aa_mean = tuple([min(255, round(255 * x)) for x in mean])
aa_params = dict( aa_params = dict(
translate_const=int(img_size_min * 0.45), translate_const=int(img_size_min * 0.45),
img_mean=tuple([min(255, round(255 * x)) for x in mean]), img_mean=aa_mean,
) )
if interpolation and interpolation != 'random': if interpolation and interpolation != 'random':
aa_params['interpolation'] = str_to_pil_interp(interpolation) aa_params['interpolation'] = str_to_pil_interp(interpolation)
@ -218,10 +230,12 @@ def transforms_imagenet_train(
final_tfl += [ToNumpy()] final_tfl += [ToNumpy()]
elif not normalize: elif not normalize:
# when normalize disable, converted to tensor without scaling, keeps original dtype # when normalize disable, converted to tensor without scaling, keeps original dtype
final_tfl += [MaybePILToTensor()] if not use_tensor:
final_tfl += [MaybePILToTensor()]
else: else:
if not use_tensor:
final_tfl += [MaybeToTensor()]
final_tfl += [ final_tfl += [
MaybeToTensor(),
transforms.Normalize( transforms.Normalize(
mean=torch.tensor(mean), mean=torch.tensor(mean),
std=torch.tensor(std), std=torch.tensor(std),