Add normalize flag to transforms factory, allow return of non-normalized native dtype torch.Tensors

This commit is contained in:
Ross Wightman 2024-05-13 15:23:25 -07:00
parent a69863ad61
commit 3bfd036b58

View File

@ -19,9 +19,10 @@ from timm.data.random_erasing import RandomErasing
def transforms_noaug_train(
img_size: Union[int, Tuple[int, int]] = 224,
interpolation: str = 'bilinear',
use_prefetcher: bool = False,
mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
use_prefetcher: bool = False,
normalize: bool = True,
):
""" No-augmentation image transforms for training.
@ -31,6 +32,7 @@ def transforms_noaug_train(
mean: Image normalization mean.
std: Image normalization standard deviation.
use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used).
Returns:
@ -45,6 +47,9 @@ def transforms_noaug_train(
if use_prefetcher:
# prefetcher and collate will handle tensor conversion and norm
tfl += [ToNumpy()]
elif not normalize:
# when normalize is disabled, convert to tensor without scaling, keeps original dtype
tfl += [transforms.PILToTensor()]
else:
tfl += [
transforms.ToTensor(),
@ -77,6 +82,7 @@ def transforms_imagenet_train(
re_count: int = 1,
re_num_splits: int = 0,
use_prefetcher: bool = False,
normalize: bool = True,
separate: bool = False,
):
""" ImageNet-oriented image transforms for training.
@ -103,6 +109,7 @@ def transforms_imagenet_train(
re_count: Number of random erasing regions.
re_num_splits: Control split of random erasing across batch size.
use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used).
separate: Output transforms in 3-stage tuple.
Returns:
@ -209,12 +216,15 @@ def transforms_imagenet_train(
if use_prefetcher:
# prefetcher and collate will handle tensor conversion and norm
final_tfl += [ToNumpy()]
elif not normalize:
# when normalize is disabled, convert to tensor without scaling, keeps original dtype
final_tfl += [transforms.PILToTensor()]
else:
final_tfl += [
transforms.ToTensor(),
transforms.Normalize(
mean=torch.tensor(mean),
std=torch.tensor(std)
std=torch.tensor(std),
),
]
if re_prob > 0.:
@ -243,6 +253,7 @@ def transforms_imagenet_eval(
mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
use_prefetcher: bool = False,
normalize: bool = True,
):
""" ImageNet-oriented image transform for evaluation and inference.
@ -255,6 +266,7 @@ def transforms_imagenet_eval(
mean: Image normalization mean.
std: Image normalization standard deviation.
use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used).
Returns:
Composed transform pipeline
@ -304,13 +316,16 @@ def transforms_imagenet_eval(
if use_prefetcher:
# prefetcher and collate will handle tensor conversion and norm
tfl += [ToNumpy()]
elif not normalize:
# when normalize disabled, converted to tensor without scaling, keeps original dtype
tfl += [transforms.PILToTensor()]
else:
tfl += [
transforms.ToTensor(),
transforms.Normalize(
mean=torch.tensor(mean),
std=torch.tensor(std),
)
),
]
return transforms.Compose(tfl)
@ -342,6 +357,7 @@ def create_transform(
crop_border_pixels: Optional[int] = None,
tf_preprocessing: bool = False,
use_prefetcher: bool = False,
normalize: bool = True,
separate: bool = False,
):
"""
@ -373,6 +389,7 @@ def create_transform(
crop_border_pixels: Inference crop border of specified # pixels around edge of original image.
tf_preprocessing: Use TF 1.0 inference preprocessing for testing model ports
use_prefetcher: Pre-fetcher enabled. Do not convert image to tensor or normalize.
normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used).
separate: Output transforms in 3-stage tuple.
Returns:
@ -397,9 +414,10 @@ def create_transform(
transform = transforms_noaug_train(
img_size,
interpolation=interpolation,
use_prefetcher=use_prefetcher,
mean=mean,
std=std,
use_prefetcher=use_prefetcher,
normalize=normalize,
)
elif is_training:
transform = transforms_imagenet_train(
@ -415,13 +433,14 @@ def create_transform(
gaussian_blur_prob=gaussian_blur_prob,
auto_augment=auto_augment,
interpolation=interpolation,
use_prefetcher=use_prefetcher,
mean=mean,
std=std,
re_prob=re_prob,
re_mode=re_mode,
re_count=re_count,
re_num_splits=re_num_splits,
use_prefetcher=use_prefetcher,
normalize=normalize,
separate=separate,
)
else:
@ -429,12 +448,13 @@ def create_transform(
transform = transforms_imagenet_eval(
img_size,
interpolation=interpolation,
use_prefetcher=use_prefetcher,
mean=mean,
std=std,
crop_pct=crop_pct,
crop_mode=crop_mode,
crop_border_pixels=crop_border_pixels,
use_prefetcher=use_prefetcher,
normalize=normalize,
)
return transform