From 3b181b78d1f6fc1f35848d0f26bb37efb1b12798 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 18 Dec 2024 12:24:04 -0800 Subject: [PATCH] Updating augmentations, esp randaug to support full torch.Tensor pipeline --- timm/data/auto_augment.py | 202 ++++++++++++++++++-------------- timm/data/loader.py | 3 +- timm/data/transforms.py | 28 ++++- timm/data/transforms_factory.py | 24 +++- 4 files changed, 159 insertions(+), 98 deletions(-) diff --git a/timm/data/auto_augment.py b/timm/data/auto_augment.py index 94438a0e..f1feada3 100644 --- a/timm/data/auto_augment.py +++ b/timm/data/auto_augment.py @@ -24,12 +24,18 @@ import random import math import re from functools import partial -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union -from PIL import Image, ImageOps, ImageEnhance, ImageChops, ImageFilter +import torch import PIL import numpy as np - +from PIL import Image, ImageFilter +from torchvision.transforms import InterpolationMode +import torchvision.transforms.functional as TF +try: + import torchvision.transforms.v2.functional as TF2 +except ImportError: + TF2 = None _PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]]) @@ -42,118 +48,111 @@ _HPARAMS_DEFAULT = dict( img_mean=_FILL, ) -if hasattr(Image, "Resampling"): - _RANDOM_INTERPOLATION = (Image.Resampling.BILINEAR, Image.Resampling.BICUBIC) - _DEFAULT_INTERPOLATION = Image.Resampling.BICUBIC -else: - _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) - _DEFAULT_INTERPOLATION = Image.BICUBIC + +_RANDOM_INTERPOLATION = (InterpolationMode.BILINEAR, InterpolationMode.BICUBIC) +_DEFAULT_INTERPOLATION = InterpolationMode.BICUBIC -def _interpolation(kwargs): - interpolation = kwargs.pop('resample', _DEFAULT_INTERPOLATION) +def _interpolation(kwargs, basic_only=False): + interpolation = kwargs.pop('interpolation', _DEFAULT_INTERPOLATION) if isinstance(interpolation, (list, tuple)): - return random.choice(interpolation) + interpolation = random.choice(interpolation) + if basic_only: + if interpolation not in (InterpolationMode.NEAREST, InterpolationMode.BILINEAR): + interpolation = InterpolationMode.BILINEAR return interpolation def _check_args_tf(kwargs): - if 'fillcolor' in kwargs and _PIL_VER < (5, 0): - kwargs.pop('fillcolor') - kwargs['resample'] = _interpolation(kwargs) + kwargs['interpolation'] = _interpolation(kwargs) + + +def _check_args_affine(img, kwargs): + if isinstance(img, torch.Tensor): + kwargs['interpolation'] = _interpolation(kwargs, basic_only=True) + else: + kwargs['interpolation'] = _interpolation(kwargs) def shear_x(img, factor, **kwargs): - _check_args_tf(kwargs) - return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs) + _check_args_affine(img, kwargs) + return TF.affine(img, angle=0, translate=[0, 0], scale=1, shear=[math.degrees(math.atan(factor)), 0], **kwargs) def shear_y(img, factor, **kwargs): - _check_args_tf(kwargs) - return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs) - - -def translate_x_rel(img, pct, **kwargs): - pixels = pct * img.size[0] - _check_args_tf(kwargs) - return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) - - -def translate_y_rel(img, pct, **kwargs): - pixels = pct * img.size[1] - _check_args_tf(kwargs) - return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) + _check_args_affine(img, kwargs) + return TF.affine(img, angle=0, translate=[0, 0], scale=1, shear=[0, math.degrees(math.atan(factor))], **kwargs) def translate_x_abs(img, pixels, **kwargs): - _check_args_tf(kwargs) - return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) + _check_args_affine(img, kwargs) + return TF.affine(img, angle=0, translate=[pixels, 0], scale=1, shear=[0, 0], **kwargs) def translate_y_abs(img, pixels, **kwargs): - _check_args_tf(kwargs) - return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) + _check_args_affine(img, kwargs) + return TF.affine(img, angle=0, translate=[0, pixels], scale=1, shear=[0, 0], **kwargs) + + +def translate_x_rel(img, pct, **kwargs): + pixels = pct * TF.get_image_size(img)[0] + return translate_x_abs(img, pixels, **kwargs) + + +def translate_y_rel(img, pct, **kwargs): + pixels = pct * TF.get_image_size(img)[1] + return translate_y_abs(img, pixels, **kwargs) def rotate(img, degrees, **kwargs): - _check_args_tf(kwargs) - if _PIL_VER >= (5, 2): - return img.rotate(degrees, **kwargs) - if _PIL_VER >= (5, 0): - w, h = img.size - post_trans = (0, 0) - rotn_center = (w / 2.0, h / 2.0) - angle = -math.radians(degrees) - matrix = [ - round(math.cos(angle), 15), - round(math.sin(angle), 15), - 0.0, - round(-math.sin(angle), 15), - round(math.cos(angle), 15), - 0.0, - ] - - def transform(x, y, matrix): - (a, b, c, d, e, f) = matrix - return a * x + b * y + c, d * x + e * y + f - - matrix[2], matrix[5] = transform( - -rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix - ) - matrix[2] += rotn_center[0] - matrix[5] += rotn_center[1] - return img.transform(img.size, Image.AFFINE, matrix, **kwargs) - return img.rotate(degrees, resample=kwargs['resample']) + _check_args_affine(img, kwargs) + return TF.rotate(img, degrees, **kwargs) def auto_contrast(img, **__): - return ImageOps.autocontrast(img) + return TF.autocontrast(img) def invert(img, **__): - return ImageOps.invert(img) + return TF.invert(img) def equalize(img, **__): - return ImageOps.equalize(img) + if isinstance(img, torch.Tensor) and img.is_floating_point(): + if TF2 is None: + # FIXME warn / assert? + return img + return TF2.equalize(img) + return TF.equalize(img) def solarize(img, thresh, **__): - return ImageOps.solarize(img, thresh) + if isinstance(img, torch.Tensor) and img.is_floating_point(): + thresh = min(thresh / 255, 1.0) + return TF.solarize(img, thresh) def solarize_add(img, add, thresh=128, **__): - lut = [] - for i in range(256): - if i < thresh: - lut.append(min(255, i + add)) + if isinstance(img, torch.Tensor): + if img.is_floating_point(): + thresh = thresh / 255 + add = add / 255 + img_sum = (img + add).clamp_(max=1.0) else: - lut.append(i) + img_sum = (img + add).clamp_(max=255) + return torch.where(img >= thresh, img_sum, img) + else: + lut = [] + for i in range(256): + if i < thresh: + lut.append(min(255, i + add)) + else: + lut.append(i) - if img.mode in ("L", "RGB"): - if img.mode == "RGB" and len(lut) == 256: - lut = lut + lut + lut - return img.point(lut) + if img.mode in ("L", "RGB"): + if img.mode == "RGB" and len(lut) == 256: + lut = lut + lut + lut + return img.point(lut) return img @@ -161,41 +160,50 @@ def solarize_add(img, add, thresh=128, **__): def posterize(img, bits_to_keep, **__): if bits_to_keep >= 8: return img - return ImageOps.posterize(img, bits_to_keep) + if isinstance(img, torch.Tensor) and img.is_floating_point(): + if TF2 is None: + # FIXME warn / assert? + return img + return TF2.posterize(img, bits_to_keep) + return TF.posterize(img, bits_to_keep) def contrast(img, factor, **__): - return ImageEnhance.Contrast(img).enhance(factor) + return TF.adjust_contrast(img, factor) def color(img, factor, **__): - return ImageEnhance.Color(img).enhance(factor) + return TF.adjust_saturation(img, factor) def brightness(img, factor, **__): - return ImageEnhance.Brightness(img).enhance(factor) + return TF.adjust_brightness(img, factor) def sharpness(img, factor, **__): - return ImageEnhance.Sharpness(img).enhance(factor) + return TF.adjust_sharpness(img, factor) def gaussian_blur(img, factor, **__): - img = img.filter(ImageFilter.GaussianBlur(radius=factor)) + if isinstance(img, torch.Tensor): + kernel_size = 2 * int(3 * factor) + 1 # could be bigger, but more expensive + img = TF.gaussian_blur(img, kernel_size=kernel_size, sigma=factor) + else: + img = img.filter(ImageFilter.GaussianBlur(radius=factor)) return img def gaussian_blur_rand(img, factor, **__): radius_min = 0.1 radius_max = 2.0 - img = img.filter(ImageFilter.GaussianBlur(radius=random.uniform(radius_min, radius_max * factor))) - return img + radius = random.uniform(radius_min, radius_max * factor) + return gaussian_blur(img, radius) def desaturate(img, factor, **_): factor = min(1., max(0., 1. - factor)) # enhance factor 0 = grayscale, 1.0 = no-change - return ImageEnhance.Color(img).enhance(factor) + return TF.adjust_saturation(img, factor) def _randomly_negate(v): @@ -356,7 +364,13 @@ NAME_TO_OP = { class AugmentOp: - def __init__(self, name, prob=0.5, magnitude=10, hparams=None): + def __init__( + self, + name: str, + prob: float = 0.5, + magnitude: float = 10, + hparams: Optional[Dict[str, Any]] = None + ): hparams = hparams or _HPARAMS_DEFAULT self.name = name self.aug_fn = NAME_TO_OP[name] @@ -365,8 +379,8 @@ class AugmentOp: self.magnitude = magnitude self.hparams = hparams.copy() self.kwargs = dict( - fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL, - resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION, + fill=hparams['img_mean'] if 'img_mean' in hparams else _FILL, + interpolation=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION, ) # If magnitude_std is > 0, we introduce some randomness @@ -564,7 +578,7 @@ def auto_augment_policy(name='v0', hparams=None): class AutoAugment: - def __init__(self, policy): + def __init__(self, policy: List): self.policy = policy def __call__(self, img): @@ -729,8 +743,14 @@ def rand_augment_ops( ): hparams = hparams or _HPARAMS_DEFAULT transforms = transforms or _RAND_TRANSFORMS - return [AugmentOp( - name, prob=prob, magnitude=magnitude, hparams=hparams) for name in transforms] + return [ + AugmentOp( + name, + prob=prob, + magnitude=magnitude, + hparams=hparams + ) for name in transforms + ] class RandAugment: diff --git a/timm/data/loader.py b/timm/data/loader.py index 3b4a6d0e..8253b2e0 100644 --- a/timm/data/loader.py +++ b/timm/data/loader.py @@ -87,7 +87,8 @@ class PrefetchLoader: re_prob=0., re_mode='const', re_count=1, - re_num_splits=0): + re_num_splits=0, + ): mean = adapt_to_chs(mean, channels) std = adapt_to_chs(std, channels) diff --git a/timm/data/transforms.py b/timm/data/transforms.py index 215b7b5b..e0c7e7f9 100644 --- a/timm/data/transforms.py +++ b/timm/data/transforms.py @@ -12,7 +12,8 @@ try: has_interpolation_mode = True except ImportError: has_interpolation_mode = False -from PIL import Image +from PIL import Image, ImageCms + import numpy as np __all__ = [ @@ -89,6 +90,31 @@ class MaybePILToTensor: return f"{self.__class__.__name__}()" +class ToLab(transforms.ToTensor): + + def __init__(self) -> None: + super().__init__() + rgb_profile = ImageCms.createProfile(colorSpace='sRGB') + lab_profile = ImageCms.createProfile(colorSpace='LAB') + # Create a transform object from the input and output profiles + self.rgb_to_lab_transform = ImageCms.buildTransform( + inputProfile=rgb_profile, + outputProfile=lab_profile, + inMode='RGB', + outMode='LAB' + ) + + def __call__(self, pic) -> torch.Tensor: + lab_image = ImageCms.applyTransform( + im=pic, + transform=self.rgb_to_lab_transform + ) + return lab_image + + def __repr__(self) -> str: + return f"{self.__class__.__name__}()" + + # Pillow is deprecating the top-level resampling attributes (e.g., Image.BILINEAR) in # favor of the Image.Resampling enum. The top-level resampling attributes will be # removed in Pillow 10. diff --git a/timm/data/transforms_factory.py b/timm/data/transforms_factory.py index 9be0e3bf..5653109f 100644 --- a/timm/data/transforms_factory.py +++ b/timm/data/transforms_factory.py @@ -4,6 +4,7 @@ Factory methods for building image transforms for use with TIMM (PyTorch Image M Hacked together by / Copyright 2019, Ross Wightman """ import math +from copy import deepcopy from typing import Optional, Tuple, Union import torch @@ -84,6 +85,7 @@ def transforms_imagenet_train( use_prefetcher: bool = False, normalize: bool = True, separate: bool = False, + use_tensor: Optional[bool] = True, # FIXME forced True for testing ): """ ImageNet-oriented image transforms for training. @@ -111,6 +113,7 @@ def transforms_imagenet_train( use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize. normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used). separate: Output transforms in 3-stage tuple. + use_tensor: Use of float [0, 1.0) tensors for image transforms Returns: If separate==True, the transforms are returned as a tuple of 3 separate transforms @@ -119,13 +122,18 @@ def transforms_imagenet_train( * a portion of the data through the secondary transform * normalizes and converts the branches above with the third, final transform """ + if use_tensor: + primary_tfl = [MaybeToTensor()] + else: + primary_tfl = [] + train_crop_mode = train_crop_mode or 'rrc' assert train_crop_mode in {'rrc', 'rkrc', 'rkrr'} if train_crop_mode in ('rkrc', 'rkrr'): # FIXME integration of RKR is a WIP scale = tuple(scale or (0.8, 1.00)) ratio = tuple(ratio or (0.9, 1/.9)) - primary_tfl = [ + primary_tfl += [ ResizeKeepRatio( img_size, interpolation=interpolation, @@ -142,7 +150,7 @@ def transforms_imagenet_train( else: scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range ratio = tuple(ratio or (3. / 4., 4. / 3.)) # default imagenet ratio range - primary_tfl = [ + primary_tfl += [ RandomResizedCropAndInterpolation( img_size, scale=scale, @@ -166,9 +174,13 @@ def transforms_imagenet_train( img_size_min = min(img_size) else: img_size_min = img_size + if use_tensor: + aa_mean = deepcopy(mean) + else: + aa_mean = tuple([min(255, round(255 * x)) for x in mean]) aa_params = dict( translate_const=int(img_size_min * 0.45), - img_mean=tuple([min(255, round(255 * x)) for x in mean]), + img_mean=aa_mean, ) if interpolation and interpolation != 'random': aa_params['interpolation'] = str_to_pil_interp(interpolation) @@ -218,10 +230,12 @@ def transforms_imagenet_train( final_tfl += [ToNumpy()] elif not normalize: # when normalize disable, converted to tensor without scaling, keeps original dtype - final_tfl += [MaybePILToTensor()] + if not use_tensor: + final_tfl += [MaybePILToTensor()] else: + if not use_tensor: + final_tfl += [MaybeToTensor()] final_tfl += [ - MaybeToTensor(), transforms.Normalize( mean=torch.tensor(mean), std=torch.tensor(std),