diff --git a/timm/data/transforms_factory.py b/timm/data/transforms_factory.py index 504d1199..59754891 100644 --- a/timm/data/transforms_factory.py +++ b/timm/data/transforms_factory.py @@ -19,9 +19,10 @@ from timm.data.random_erasing import RandomErasing def transforms_noaug_train( img_size: Union[int, Tuple[int, int]] = 224, interpolation: str = 'bilinear', - use_prefetcher: bool = False, mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN, std: Tuple[float, ...] = IMAGENET_DEFAULT_STD, + use_prefetcher: bool = False, + normalize: bool = True, ): """ No-augmentation image transforms for training. @@ -31,6 +32,7 @@ def transforms_noaug_train( mean: Image normalization mean. std: Image normalization standard deviation. use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize. + normalize: Normalization tensor output w/ provided mean/std (if prefetcher not used). Returns: @@ -45,6 +47,9 @@ def transforms_noaug_train( if use_prefetcher: # prefetcher and collate will handle tensor conversion and norm tfl += [ToNumpy()] + elif not normalize: + # when normalize disabled, converted to tensor without scaling, keep original dtype + tfl += [transforms.PILToTensor()] else: tfl += [ transforms.ToTensor(), @@ -77,6 +82,7 @@ def transforms_imagenet_train( re_count: int = 1, re_num_splits: int = 0, use_prefetcher: bool = False, + normalize: bool = True, separate: bool = False, ): """ ImageNet-oriented image transforms for training. @@ -103,6 +109,7 @@ def transforms_imagenet_train( re_count: Number of random erasing regions. re_num_splits: Control split of random erasing across batch size. use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize. + normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used). separate: Output transforms in 3-stage tuple. Returns: @@ -209,12 +216,15 @@ def transforms_imagenet_train( if use_prefetcher: # prefetcher and collate will handle tensor conversion and norm final_tfl += [ToNumpy()] + elif not normalize: + # when normalize disable, converted to tensor without scaling, keeps original dtype + final_tfl += [transforms.PILToTensor()] else: final_tfl += [ transforms.ToTensor(), transforms.Normalize( mean=torch.tensor(mean), - std=torch.tensor(std) + std=torch.tensor(std), ), ] if re_prob > 0.: @@ -243,6 +253,7 @@ def transforms_imagenet_eval( mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN, std: Tuple[float, ...] = IMAGENET_DEFAULT_STD, use_prefetcher: bool = False, + normalize: bool = True, ): """ ImageNet-oriented image transform for evaluation and inference. @@ -255,6 +266,7 @@ def transforms_imagenet_eval( mean: Image normalization mean. std: Image normalization standard deviation. use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize. + normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used). Returns: Composed transform pipeline @@ -304,13 +316,16 @@ def transforms_imagenet_eval( if use_prefetcher: # prefetcher and collate will handle tensor conversion and norm tfl += [ToNumpy()] + elif not normalize: + # when normalize disabled, converted to tensor without scaling, keeps original dtype + tfl += [transforms.PILToTensor()] else: tfl += [ transforms.ToTensor(), transforms.Normalize( mean=torch.tensor(mean), std=torch.tensor(std), - ) + ), ] return transforms.Compose(tfl) @@ -342,6 +357,7 @@ def create_transform( crop_border_pixels: Optional[int] = None, tf_preprocessing: bool = False, use_prefetcher: bool = False, + normalize: bool = True, separate: bool = False, ): """ @@ -373,6 +389,7 @@ def create_transform( crop_border_pixels: Inference crop border of specified # pixels around edge of original image. tf_preprocessing: Use TF 1.0 inference preprocessing for testing model ports use_prefetcher: Pre-fetcher enabled. Do not convert image to tensor or normalize. + normalize: Normalization tensor output w/ provided mean/std (if prefetcher not used). separate: Output transforms in 3-stage tuple. Returns: @@ -397,9 +414,10 @@ def create_transform( transform = transforms_noaug_train( img_size, interpolation=interpolation, - use_prefetcher=use_prefetcher, mean=mean, std=std, + use_prefetcher=use_prefetcher, + normalize=normalize, ) elif is_training: transform = transforms_imagenet_train( @@ -415,13 +433,14 @@ def create_transform( gaussian_blur_prob=gaussian_blur_prob, auto_augment=auto_augment, interpolation=interpolation, - use_prefetcher=use_prefetcher, mean=mean, std=std, re_prob=re_prob, re_mode=re_mode, re_count=re_count, re_num_splits=re_num_splits, + use_prefetcher=use_prefetcher, + normalize=normalize, separate=separate, ) else: @@ -429,12 +448,13 @@ def create_transform( transform = transforms_imagenet_eval( img_size, interpolation=interpolation, - use_prefetcher=use_prefetcher, mean=mean, std=std, crop_pct=crop_pct, crop_mode=crop_mode, crop_border_pixels=crop_border_pixels, + use_prefetcher=use_prefetcher, + normalize=normalize, ) return transform