mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add normalize flag to transforms factory, allow return of non-normalized native dtype torch.Tensors
This commit is contained in:
parent
a69863ad61
commit
3bfd036b58
@ -19,9 +19,10 @@ from timm.data.random_erasing import RandomErasing
|
||||
def transforms_noaug_train(
|
||||
img_size: Union[int, Tuple[int, int]] = 224,
|
||||
interpolation: str = 'bilinear',
|
||||
use_prefetcher: bool = False,
|
||||
mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
|
||||
std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
|
||||
use_prefetcher: bool = False,
|
||||
normalize: bool = True,
|
||||
):
|
||||
""" No-augmentation image transforms for training.
|
||||
|
||||
@ -31,6 +32,7 @@ def transforms_noaug_train(
|
||||
mean: Image normalization mean.
|
||||
std: Image normalization standard deviation.
|
||||
use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
|
||||
normalize: Normalization tensor output w/ provided mean/std (if prefetcher not used).
|
||||
|
||||
Returns:
|
||||
|
||||
@ -45,6 +47,9 @@ def transforms_noaug_train(
|
||||
if use_prefetcher:
|
||||
# prefetcher and collate will handle tensor conversion and norm
|
||||
tfl += [ToNumpy()]
|
||||
elif not normalize:
|
||||
# when normalize disabled, converted to tensor without scaling, keep original dtype
|
||||
tfl += [transforms.PILToTensor()]
|
||||
else:
|
||||
tfl += [
|
||||
transforms.ToTensor(),
|
||||
@ -77,6 +82,7 @@ def transforms_imagenet_train(
|
||||
re_count: int = 1,
|
||||
re_num_splits: int = 0,
|
||||
use_prefetcher: bool = False,
|
||||
normalize: bool = True,
|
||||
separate: bool = False,
|
||||
):
|
||||
""" ImageNet-oriented image transforms for training.
|
||||
@ -103,6 +109,7 @@ def transforms_imagenet_train(
|
||||
re_count: Number of random erasing regions.
|
||||
re_num_splits: Control split of random erasing across batch size.
|
||||
use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
|
||||
normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used).
|
||||
separate: Output transforms in 3-stage tuple.
|
||||
|
||||
Returns:
|
||||
@ -209,12 +216,15 @@ def transforms_imagenet_train(
|
||||
if use_prefetcher:
|
||||
# prefetcher and collate will handle tensor conversion and norm
|
||||
final_tfl += [ToNumpy()]
|
||||
elif not normalize:
|
||||
# when normalize disable, converted to tensor without scaling, keeps original dtype
|
||||
final_tfl += [transforms.PILToTensor()]
|
||||
else:
|
||||
final_tfl += [
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(
|
||||
mean=torch.tensor(mean),
|
||||
std=torch.tensor(std)
|
||||
std=torch.tensor(std),
|
||||
),
|
||||
]
|
||||
if re_prob > 0.:
|
||||
@ -243,6 +253,7 @@ def transforms_imagenet_eval(
|
||||
mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
|
||||
std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
|
||||
use_prefetcher: bool = False,
|
||||
normalize: bool = True,
|
||||
):
|
||||
""" ImageNet-oriented image transform for evaluation and inference.
|
||||
|
||||
@ -255,6 +266,7 @@ def transforms_imagenet_eval(
|
||||
mean: Image normalization mean.
|
||||
std: Image normalization standard deviation.
|
||||
use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
|
||||
normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used).
|
||||
|
||||
Returns:
|
||||
Composed transform pipeline
|
||||
@ -304,13 +316,16 @@ def transforms_imagenet_eval(
|
||||
if use_prefetcher:
|
||||
# prefetcher and collate will handle tensor conversion and norm
|
||||
tfl += [ToNumpy()]
|
||||
elif not normalize:
|
||||
# when normalize disabled, converted to tensor without scaling, keeps original dtype
|
||||
tfl += [transforms.PILToTensor()]
|
||||
else:
|
||||
tfl += [
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(
|
||||
mean=torch.tensor(mean),
|
||||
std=torch.tensor(std),
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
return transforms.Compose(tfl)
|
||||
@ -342,6 +357,7 @@ def create_transform(
|
||||
crop_border_pixels: Optional[int] = None,
|
||||
tf_preprocessing: bool = False,
|
||||
use_prefetcher: bool = False,
|
||||
normalize: bool = True,
|
||||
separate: bool = False,
|
||||
):
|
||||
"""
|
||||
@ -373,6 +389,7 @@ def create_transform(
|
||||
crop_border_pixels: Inference crop border of specified # pixels around edge of original image.
|
||||
tf_preprocessing: Use TF 1.0 inference preprocessing for testing model ports
|
||||
use_prefetcher: Pre-fetcher enabled. Do not convert image to tensor or normalize.
|
||||
normalize: Normalization tensor output w/ provided mean/std (if prefetcher not used).
|
||||
separate: Output transforms in 3-stage tuple.
|
||||
|
||||
Returns:
|
||||
@ -397,9 +414,10 @@ def create_transform(
|
||||
transform = transforms_noaug_train(
|
||||
img_size,
|
||||
interpolation=interpolation,
|
||||
use_prefetcher=use_prefetcher,
|
||||
mean=mean,
|
||||
std=std,
|
||||
use_prefetcher=use_prefetcher,
|
||||
normalize=normalize,
|
||||
)
|
||||
elif is_training:
|
||||
transform = transforms_imagenet_train(
|
||||
@ -415,13 +433,14 @@ def create_transform(
|
||||
gaussian_blur_prob=gaussian_blur_prob,
|
||||
auto_augment=auto_augment,
|
||||
interpolation=interpolation,
|
||||
use_prefetcher=use_prefetcher,
|
||||
mean=mean,
|
||||
std=std,
|
||||
re_prob=re_prob,
|
||||
re_mode=re_mode,
|
||||
re_count=re_count,
|
||||
re_num_splits=re_num_splits,
|
||||
use_prefetcher=use_prefetcher,
|
||||
normalize=normalize,
|
||||
separate=separate,
|
||||
)
|
||||
else:
|
||||
@ -429,12 +448,13 @@ def create_transform(
|
||||
transform = transforms_imagenet_eval(
|
||||
img_size,
|
||||
interpolation=interpolation,
|
||||
use_prefetcher=use_prefetcher,
|
||||
mean=mean,
|
||||
std=std,
|
||||
crop_pct=crop_pct,
|
||||
crop_mode=crop_mode,
|
||||
crop_border_pixels=crop_border_pixels,
|
||||
use_prefetcher=use_prefetcher,
|
||||
normalize=normalize,
|
||||
)
|
||||
|
||||
return transform
|
||||
|
Loading…
x
Reference in New Issue
Block a user