mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)

Updating augmentations, esp randaug to support full torch.Tensor pipeline

This commit is contained in:
  parent ea231079f5
  commit 3b181b78d1
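For orientation before the diff: the change makes each augmentation op accept either a PIL Image or a torch.Tensor and route through torchvision's functional API. A minimal sketch of that idea (illustrative only, not code from this commit; assumes torchvision >= 0.9):

import torch
import torchvision.transforms.functional as TF
from PIL import Image

def autocontrast_any(img):
    # TF.autocontrast handles both PIL Images and (float or uint8) tensors,
    # which is what lets the ops in the diff below collapse to a single TF call.
    return TF.autocontrast(img)

out_pil = autocontrast_any(Image.new('RGB', (224, 224)))
out_tensor = autocontrast_any(torch.rand(3, 224, 224))  # float CHW tensor in [0, 1)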
@@ -24,12 +24,18 @@ import random
 import math
 import re
 from functools import partial
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
-from PIL import Image, ImageOps, ImageEnhance, ImageChops, ImageFilter
+import torch
 import PIL
 import numpy as np
+from PIL import Image, ImageFilter
+from torchvision.transforms import InterpolationMode
+import torchvision.transforms.functional as TF
+try:
+    import torchvision.transforms.v2.functional as TF2
+except ImportError:
+    TF2 = None
 
 _PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]])
 
@@ -42,118 +48,111 @@ _HPARAMS_DEFAULT = dict(
     img_mean=_FILL,
 )
 
-if hasattr(Image, "Resampling"):
-    _RANDOM_INTERPOLATION = (Image.Resampling.BILINEAR, Image.Resampling.BICUBIC)
-    _DEFAULT_INTERPOLATION = Image.Resampling.BICUBIC
-else:
-    _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
-    _DEFAULT_INTERPOLATION = Image.BICUBIC
+_RANDOM_INTERPOLATION = (InterpolationMode.BILINEAR, InterpolationMode.BICUBIC)
+_DEFAULT_INTERPOLATION = InterpolationMode.BICUBIC
 
 
-def _interpolation(kwargs):
-    interpolation = kwargs.pop('resample', _DEFAULT_INTERPOLATION)
+def _interpolation(kwargs, basic_only=False):
+    interpolation = kwargs.pop('interpolation', _DEFAULT_INTERPOLATION)
     if isinstance(interpolation, (list, tuple)):
-        return random.choice(interpolation)
+        interpolation = random.choice(interpolation)
+    if basic_only:
+        if interpolation not in (InterpolationMode.NEAREST, InterpolationMode.BILINEAR):
+            interpolation = InterpolationMode.BILINEAR
     return interpolation
 
 
 def _check_args_tf(kwargs):
-    if 'fillcolor' in kwargs and _PIL_VER < (5, 0):
-        kwargs.pop('fillcolor')
-    kwargs['resample'] = _interpolation(kwargs)
+    kwargs['interpolation'] = _interpolation(kwargs)
+
+
+def _check_args_affine(img, kwargs):
+    if isinstance(img, torch.Tensor):
+        kwargs['interpolation'] = _interpolation(kwargs, basic_only=True)
+    else:
+        kwargs['interpolation'] = _interpolation(kwargs)
 
 
 def shear_x(img, factor, **kwargs):
-    _check_args_tf(kwargs)
-    return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs)
+    _check_args_affine(img, kwargs)
+    return TF.affine(img, angle=0, translate=[0, 0], scale=1, shear=[math.degrees(math.atan(factor)), 0], **kwargs)
 
 
 def shear_y(img, factor, **kwargs):
-    _check_args_tf(kwargs)
-    return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs)
+    _check_args_affine(img, kwargs)
+    return TF.affine(img, angle=0, translate=[0, 0], scale=1, shear=[0, math.degrees(math.atan(factor))], **kwargs)
 
 
-def translate_x_rel(img, pct, **kwargs):
-    pixels = pct * img.size[0]
-    _check_args_tf(kwargs)
-    return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
-
-
-def translate_y_rel(img, pct, **kwargs):
-    pixels = pct * img.size[1]
-    _check_args_tf(kwargs)
-    return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
-
-
 def translate_x_abs(img, pixels, **kwargs):
-    _check_args_tf(kwargs)
-    return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
+    _check_args_affine(img, kwargs)
+    return TF.affine(img, angle=0, translate=[pixels, 0], scale=1, shear=[0, 0], **kwargs)
 
 
 def translate_y_abs(img, pixels, **kwargs):
-    _check_args_tf(kwargs)
-    return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
+    _check_args_affine(img, kwargs)
+    return TF.affine(img, angle=0, translate=[0, pixels], scale=1, shear=[0, 0], **kwargs)
+
+
+def translate_x_rel(img, pct, **kwargs):
+    pixels = pct * TF.get_image_size(img)[0]
+    return translate_x_abs(img, pixels, **kwargs)
+
+
+def translate_y_rel(img, pct, **kwargs):
+    pixels = pct * TF.get_image_size(img)[1]
+    return translate_y_abs(img, pixels, **kwargs)
 
 
 def rotate(img, degrees, **kwargs):
-    _check_args_tf(kwargs)
-    if _PIL_VER >= (5, 2):
-        return img.rotate(degrees, **kwargs)
-    if _PIL_VER >= (5, 0):
-        w, h = img.size
-        post_trans = (0, 0)
-        rotn_center = (w / 2.0, h / 2.0)
-        angle = -math.radians(degrees)
-        matrix = [
-            round(math.cos(angle), 15),
-            round(math.sin(angle), 15),
-            0.0,
-            round(-math.sin(angle), 15),
-            round(math.cos(angle), 15),
-            0.0,
-        ]
-
-        def transform(x, y, matrix):
-            (a, b, c, d, e, f) = matrix
-            return a * x + b * y + c, d * x + e * y + f
-
-        matrix[2], matrix[5] = transform(
-            -rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix
-        )
-        matrix[2] += rotn_center[0]
-        matrix[5] += rotn_center[1]
-        return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
-    return img.rotate(degrees, resample=kwargs['resample'])
+    _check_args_affine(img, kwargs)
+    return TF.rotate(img, degrees, **kwargs)
 
 
 def auto_contrast(img, **__):
-    return ImageOps.autocontrast(img)
+    return TF.autocontrast(img)
 
 
 def invert(img, **__):
-    return ImageOps.invert(img)
+    return TF.invert(img)
 
 
 def equalize(img, **__):
-    return ImageOps.equalize(img)
+    if isinstance(img, torch.Tensor) and img.is_floating_point():
+        if TF2 is None:
+            # FIXME warn / assert?
+            return img
+        return TF2.equalize(img)
+    return TF.equalize(img)
 
 
 def solarize(img, thresh, **__):
-    return ImageOps.solarize(img, thresh)
+    if isinstance(img, torch.Tensor) and img.is_floating_point():
+        thresh = min(thresh / 255, 1.0)
+    return TF.solarize(img, thresh)
 
 
 def solarize_add(img, add, thresh=128, **__):
-    lut = []
-    for i in range(256):
-        if i < thresh:
-            lut.append(min(255, i + add))
+    if isinstance(img, torch.Tensor):
+        if img.is_floating_point():
+            thresh = thresh / 255
+            add = add / 255
+            img_sum = (img + add).clamp_(max=1.0)
         else:
-            lut.append(i)
+            img_sum = (img + add).clamp_(max=255)
+        return torch.where(img >= thresh, img_sum, img)
+    else:
+        lut = []
+        for i in range(256):
+            if i < thresh:
+                lut.append(min(255, i + add))
+            else:
+                lut.append(i)
 
-    if img.mode in ("L", "RGB"):
-        if img.mode == "RGB" and len(lut) == 256:
-            lut = lut + lut + lut
-        return img.point(lut)
+        if img.mode in ("L", "RGB"):
+            if img.mode == "RGB" and len(lut) == 256:
+                lut = lut + lut + lut
+            return img.point(lut)
 
-    return img
+        return img
 
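A side note on the geometric ops above (a sketch for intuition, not part of the commit): PIL's affine shear coefficient `factor` corresponds to a shear angle of atan(factor), so passing math.degrees(math.atan(factor)) to TF.affine expresses the same shear in the degrees that torchvision expects. Assuming torchvision >= 0.9:

import math
import torch
import torchvision.transforms.functional as TF

factor = 0.3
shear_deg = math.degrees(math.atan(factor))  # PIL shear coefficient -> degrees

img = torch.rand(3, 64, 64)  # float CHW tensor
sheared = TF.affine(
    img,
    angle=0,
    translate=[0, 0],
    scale=1,
    shear=[shear_deg, 0],  # x-axis shear only, mirroring shear_x above
)
print(sheared.shape)  # torch.Size([3, 64, 64])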
@@ -161,41 +160,50 @@ def solarize_add(img, add, thresh=128, **__):
 def posterize(img, bits_to_keep, **__):
     if bits_to_keep >= 8:
         return img
-    return ImageOps.posterize(img, bits_to_keep)
+    if isinstance(img, torch.Tensor) and img.is_floating_point():
+        if TF2 is None:
+            # FIXME warn / assert?
+            return img
+        return TF2.posterize(img, bits_to_keep)
+    return TF.posterize(img, bits_to_keep)
 
 
 def contrast(img, factor, **__):
-    return ImageEnhance.Contrast(img).enhance(factor)
+    return TF.adjust_contrast(img, factor)
 
 
 def color(img, factor, **__):
-    return ImageEnhance.Color(img).enhance(factor)
+    return TF.adjust_saturation(img, factor)
 
 
 def brightness(img, factor, **__):
-    return ImageEnhance.Brightness(img).enhance(factor)
+    return TF.adjust_brightness(img, factor)
 
 
 def sharpness(img, factor, **__):
-    return ImageEnhance.Sharpness(img).enhance(factor)
+    return TF.adjust_sharpness(img, factor)
 
 
 def gaussian_blur(img, factor, **__):
-    img = img.filter(ImageFilter.GaussianBlur(radius=factor))
+    if isinstance(img, torch.Tensor):
+        kernel_size = 2 * int(3 * factor) + 1  # could be bigger, but more expensive
+        img = TF.gaussian_blur(img, kernel_size=kernel_size, sigma=factor)
+    else:
+        img = img.filter(ImageFilter.GaussianBlur(radius=factor))
     return img
 
 
 def gaussian_blur_rand(img, factor, **__):
     radius_min = 0.1
     radius_max = 2.0
-    img = img.filter(ImageFilter.GaussianBlur(radius=random.uniform(radius_min, radius_max * factor)))
-    return img
+    radius = random.uniform(radius_min, radius_max * factor)
+    return gaussian_blur(img, radius)
 
 
 def desaturate(img, factor, **_):
     factor = min(1., max(0., 1. - factor))
     # enhance factor 0 = grayscale, 1.0 = no-change
-    return ImageEnhance.Color(img).enhance(factor)
+    return TF.adjust_saturation(img, factor)
 
 
 def _randomly_negate(v):
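The magnitudes in this file stay defined on a 0-255 scale, so the tensor branches above rescale thresholds on the fly and derive blur kernels from sigma. A small sketch of both patterns (assumptions: torchvision >= 0.9, float image tensors in [0, 1]):

import torch
import torchvision.transforms.functional as TF

img = torch.rand(3, 32, 32)  # float tensor in [0, 1]

# solarize() path: rescale a 0-255 threshold for float tensors.
thresh = 128
if img.is_floating_point():
    thresh = min(thresh / 255, 1.0)
out = TF.solarize(img, thresh)

# gaussian_blur() path: derive an odd kernel size covering roughly 3 sigma per side.
sigma = 1.5
kernel_size = 2 * int(3 * sigma) + 1
blurred = TF.gaussian_blur(img, kernel_size=kernel_size, sigma=sigma)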
@@ -356,7 +364,13 @@ NAME_TO_OP = {
 
 class AugmentOp:
 
-    def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
+    def __init__(
+            self,
+            name: str,
+            prob: float = 0.5,
+            magnitude: float = 10,
+            hparams: Optional[Dict[str, Any]] = None
+    ):
         hparams = hparams or _HPARAMS_DEFAULT
         self.name = name
         self.aug_fn = NAME_TO_OP[name]
@@ -365,8 +379,8 @@ class AugmentOp:
         self.magnitude = magnitude
         self.hparams = hparams.copy()
         self.kwargs = dict(
-            fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL,
-            resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
+            fill=hparams['img_mean'] if 'img_mean' in hparams else _FILL,
+            interpolation=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
         )
 
         # If magnitude_std is > 0, we introduce some randomness
@@ -564,7 +578,7 @@ def auto_augment_policy(name='v0', hparams=None):
 
 class AutoAugment:
 
-    def __init__(self, policy):
+    def __init__(self, policy: List):
         self.policy = policy
 
     def __call__(self, img):
@@ -729,8 +743,14 @@ def rand_augment_ops(
 ):
     hparams = hparams or _HPARAMS_DEFAULT
     transforms = transforms or _RAND_TRANSFORMS
-    return [AugmentOp(
-        name, prob=prob, magnitude=magnitude, hparams=hparams) for name in transforms]
+    return [
+        AugmentOp(
+            name,
+            prob=prob,
+            magnitude=magnitude,
+            hparams=hparams
+        ) for name in transforms
+    ]
 
 
 class RandAugment:
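To exercise the reworked AugmentOp constructor, a hedged usage sketch (assumes a timm build that already contains this commit; op list and defaults come from rand_augment_ops as in the hunk above):

import torch
from timm.data.auto_augment import rand_augment_ops

ops = rand_augment_ops(magnitude=9, prob=1.0)  # per-op callables built via AugmentOp

img = torch.rand(3, 224, 224)  # float CHW tensor; a PIL Image would also be accepted
out = ops[0](img)              # first default op is AutoContrast; prob=1.0 always applies
print(type(out), out.shape)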
@@ -87,7 +87,8 @@ class PrefetchLoader:
            re_prob=0.,
            re_mode='const',
            re_count=1,
-           re_num_splits=0):
+           re_num_splits=0,
+    ):
 
         mean = adapt_to_chs(mean, channels)
         std = adapt_to_chs(std, channels)
@@ -12,7 +12,8 @@ try:
     has_interpolation_mode = True
 except ImportError:
     has_interpolation_mode = False
-from PIL import Image
+from PIL import Image, ImageCms
 
 import numpy as np
 
 __all__ = [
@@ -89,6 +90,31 @@ class MaybePILToTensor:
         return f"{self.__class__.__name__}()"
 
 
+class ToLab(transforms.ToTensor):
+
+    def __init__(self) -> None:
+        super().__init__()
+        rgb_profile = ImageCms.createProfile(colorSpace='sRGB')
+        lab_profile = ImageCms.createProfile(colorSpace='LAB')
+        # Create a transform object from the input and output profiles
+        self.rgb_to_lab_transform = ImageCms.buildTransform(
+            inputProfile=rgb_profile,
+            outputProfile=lab_profile,
+            inMode='RGB',
+            outMode='LAB'
+        )
+
+    def __call__(self, pic) -> torch.Tensor:
+        lab_image = ImageCms.applyTransform(
+            im=pic,
+            transform=self.rgb_to_lab_transform
+        )
+        return lab_image
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}()"
+
+
 # Pillow is deprecating the top-level resampling attributes (e.g., Image.BILINEAR) in
 # favor of the Image.Resampling enum. The top-level resampling attributes will be
 # removed in Pillow 10.
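A brief usage sketch for the new ToLab transform (hedged: assumes a timm build containing this commit and a Pillow with LittleCMS support; ImageCms.applyTransform returns a PIL image in 'LAB' mode, so a tensor conversion typically follows):

from PIL import Image
from torchvision import transforms
from timm.data.transforms import ToLab  # assumed present in this branch

to_lab = transforms.Compose([
    ToLab(),                   # sRGB PIL image -> PIL image in 'LAB' mode
    transforms.PILToTensor(),  # -> uint8 tensor with L, a, b channels
])

lab = to_lab(Image.new('RGB', (64, 64), color=(200, 30, 30)))
print(lab.shape)  # torch.Size([3, 64, 64])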
@@ -4,6 +4,7 @@ Factory methods for building image transforms for use with TIMM (PyTorch Image M
 Hacked together by / Copyright 2019, Ross Wightman
 """
 import math
+from copy import deepcopy
 from typing import Optional, Tuple, Union
 
 import torch
@@ -84,6 +85,7 @@ def transforms_imagenet_train(
         use_prefetcher: bool = False,
         normalize: bool = True,
         separate: bool = False,
+        use_tensor: Optional[bool] = True,  # FIXME forced True for testing
 ):
     """ ImageNet-oriented image transforms for training.
 
@@ -111,6 +113,7 @@ def transforms_imagenet_train(
         use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
         normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used).
         separate: Output transforms in 3-stage tuple.
+        use_tensor: Use of float [0, 1.0) tensors for image transforms
 
     Returns:
         If separate==True, the transforms are returned as a tuple of 3 separate transforms
@@ -119,13 +122,18 @@ def transforms_imagenet_train(
         * a portion of the data through the secondary transform
         * normalizes and converts the branches above with the third, final transform
     """
+    if use_tensor:
+        primary_tfl = [MaybeToTensor()]
+    else:
+        primary_tfl = []
+
     train_crop_mode = train_crop_mode or 'rrc'
     assert train_crop_mode in {'rrc', 'rkrc', 'rkrr'}
     if train_crop_mode in ('rkrc', 'rkrr'):
         # FIXME integration of RKR is a WIP
         scale = tuple(scale or (0.8, 1.00))
         ratio = tuple(ratio or (0.9, 1/.9))
-        primary_tfl = [
+        primary_tfl += [
             ResizeKeepRatio(
                 img_size,
                 interpolation=interpolation,
@@ -142,7 +150,7 @@ def transforms_imagenet_train(
     else:
         scale = tuple(scale or (0.08, 1.0))  # default imagenet scale range
         ratio = tuple(ratio or (3. / 4., 4. / 3.))  # default imagenet ratio range
-        primary_tfl = [
+        primary_tfl += [
             RandomResizedCropAndInterpolation(
                 img_size,
                 scale=scale,
@@ -166,9 +174,13 @@ def transforms_imagenet_train(
             img_size_min = min(img_size)
         else:
             img_size_min = img_size
+        if use_tensor:
+            aa_mean = deepcopy(mean)
+        else:
+            aa_mean = tuple([min(255, round(255 * x)) for x in mean])
         aa_params = dict(
             translate_const=int(img_size_min * 0.45),
-            img_mean=tuple([min(255, round(255 * x)) for x in mean]),
+            img_mean=aa_mean,
         )
         if interpolation and interpolation != 'random':
             aa_params['interpolation'] = str_to_pil_interp(interpolation)
@@ -218,10 +230,12 @@ def transforms_imagenet_train(
         final_tfl += [ToNumpy()]
     elif not normalize:
         # when normalize disable, converted to tensor without scaling, keeps original dtype
-        final_tfl += [MaybePILToTensor()]
+        if not use_tensor:
+            final_tfl += [MaybePILToTensor()]
     else:
+        if not use_tensor:
+            final_tfl += [MaybeToTensor()]
         final_tfl += [
-            MaybeToTensor(),
             transforms.Normalize(
                 mean=torch.tensor(mean),
                 std=torch.tensor(std),
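Finally, a sketch of how the new use_tensor flag changes the assembled train pipeline (hedged: the flag is forced on for testing per the FIXME above and defaults may still move; assumes a timm build containing this commit):

from timm.data.transforms_factory import transforms_imagenet_train

# With use_tensor=True the pipeline leads with MaybeToTensor(), so RandAugment
# and the other ops above receive float [0, 1) tensors instead of PIL Images.
train_tf = transforms_imagenet_train(
    img_size=224,
    auto_augment='rand-m9-mstd0.5',
    use_tensor=True,  # parameter added in this commit
)
print(train_tf)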