mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add 'Maybe' PIL / image tensor conversions in case image is already in tensor format
This commit is contained in:
parent
648aaa4123
commit
83c2c2f0c5
@ -5,6 +5,7 @@ import warnings
|
|||||||
from typing import List, Sequence, Tuple, Union
|
from typing import List, Sequence, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torchvision.transforms as transforms
|
||||||
import torchvision.transforms.functional as F
|
import torchvision.transforms.functional as F
|
||||||
try:
|
try:
|
||||||
from torchvision.transforms.functional import InterpolationMode
|
from torchvision.transforms.functional import InterpolationMode
|
||||||
@ -17,7 +18,7 @@ import numpy as np
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"ToNumpy", "ToTensor", "str_to_interp_mode", "str_to_pil_interp", "interp_mode_to_str",
|
"ToNumpy", "ToTensor", "str_to_interp_mode", "str_to_pil_interp", "interp_mode_to_str",
|
||||||
"RandomResizedCropAndInterpolation", "CenterCropOrPad", "center_crop_or_pad", "crop_or_pad",
|
"RandomResizedCropAndInterpolation", "CenterCropOrPad", "center_crop_or_pad", "crop_or_pad",
|
||||||
"RandomCropOrPad", "RandomPad", "ResizeKeepRatio", "TrimBorder"
|
"RandomCropOrPad", "RandomPad", "ResizeKeepRatio", "TrimBorder", "MaybeToTensor", "MaybePILToTensor"
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -40,6 +41,54 @@ class ToTensor:
|
|||||||
return F.pil_to_tensor(pil_img).to(dtype=self.dtype)
|
return F.pil_to_tensor(pil_img).to(dtype=self.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class MaybeToTensor(transforms.ToTensor):
    """Tensor conversion that leaves inputs which are already tensors untouched.

    Any input that is not a ``torch.Tensor`` is handed to ``F.to_tensor``.
    """

    def __init__(self) -> None:
        super().__init__()

    def __call__(self, pic) -> torch.Tensor:
        """Return ``pic`` as a tensor.

        Args:
            pic (PIL Image, numpy.ndarray or torch.Tensor): Image to convert.
                A ``torch.Tensor`` is passed through unchanged.

        Returns:
            Tensor: ``pic`` itself when it is already a tensor, otherwise the
            result of ``F.to_tensor(pic)``.
        """
        already_tensor = isinstance(pic, torch.Tensor)
        return pic if already_tensor else F.to_tensor(pic)

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}()"
|
||||||
|
|
||||||
|
|
||||||
|
class MaybePILToTensor:
    """Convert a PIL Image to a tensor of the same type, unless it already is a tensor.

    Values are not scaled. Inputs that are already ``torch.Tensor`` instances
    are returned as-is.
    """

    def __call__(self, pic) -> torch.Tensor:
        """Return ``pic`` as a tensor without scaling its values.

        Note: on the PIL conversion path a deep copy of the underlying array
        is performed. A tensor input is returned as the very same object, so
        no copy is made in that case.

        Args:
            pic (PIL Image or torch.Tensor): Image to be converted to tensor.

        Returns:
            Tensor: ``pic`` itself when it is already a tensor, otherwise the
            result of ``F.pil_to_tensor(pic)``.
        """
        # Pass-through keeps pipelines working when an earlier stage (e.g. a
        # tensor-producing decoder) has already emitted a tensor.
        if isinstance(pic, torch.Tensor):
            return pic
        return F.pil_to_tensor(pic)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"
|
||||||
|
|
||||||
|
|
||||||
# Pillow is deprecating the top-level resampling attributes (e.g., Image.BILINEAR) in
|
# Pillow is deprecating the top-level resampling attributes (e.g., Image.BILINEAR) in
|
||||||
# favor of the Image.Resampling enum. The top-level resampling attributes will be
|
# favor of the Image.Resampling enum. The top-level resampling attributes will be
|
||||||
# removed in Pillow 10.
|
# removed in Pillow 10.
|
||||||
|
@ -11,8 +11,8 @@ from torchvision import transforms
|
|||||||
|
|
||||||
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
|
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
|
||||||
from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform
|
from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform
|
||||||
from timm.data.transforms import str_to_interp_mode, str_to_pil_interp, RandomResizedCropAndInterpolation,\
|
from timm.data.transforms import str_to_interp_mode, str_to_pil_interp, RandomResizedCropAndInterpolation, \
|
||||||
ResizeKeepRatio, CenterCropOrPad, RandomCropOrPad, TrimBorder, ToNumpy
|
ResizeKeepRatio, CenterCropOrPad, RandomCropOrPad, TrimBorder, ToNumpy, MaybeToTensor, MaybePILToTensor
|
||||||
from timm.data.random_erasing import RandomErasing
|
from timm.data.random_erasing import RandomErasing
|
||||||
|
|
||||||
|
|
||||||
@ -49,10 +49,10 @@ def transforms_noaug_train(
|
|||||||
tfl += [ToNumpy()]
|
tfl += [ToNumpy()]
|
||||||
elif not normalize:
|
elif not normalize:
|
||||||
# when normalize disabled, converted to tensor without scaling, keep original dtype
|
# when normalize disabled, converted to tensor without scaling, keep original dtype
|
||||||
tfl += [transforms.PILToTensor()]
|
tfl += [MaybePILToTensor()]
|
||||||
else:
|
else:
|
||||||
tfl += [
|
tfl += [
|
||||||
transforms.ToTensor(),
|
MaybeToTensor(),
|
||||||
transforms.Normalize(
|
transforms.Normalize(
|
||||||
mean=torch.tensor(mean),
|
mean=torch.tensor(mean),
|
||||||
std=torch.tensor(std)
|
std=torch.tensor(std)
|
||||||
@ -218,10 +218,10 @@ def transforms_imagenet_train(
|
|||||||
final_tfl += [ToNumpy()]
|
final_tfl += [ToNumpy()]
|
||||||
elif not normalize:
|
elif not normalize:
|
||||||
# when normalize disable, converted to tensor without scaling, keeps original dtype
|
# when normalize disable, converted to tensor without scaling, keeps original dtype
|
||||||
final_tfl += [transforms.PILToTensor()]
|
final_tfl += [MaybePILToTensor()]
|
||||||
else:
|
else:
|
||||||
final_tfl += [
|
final_tfl += [
|
||||||
transforms.ToTensor(),
|
MaybeToTensor(),
|
||||||
transforms.Normalize(
|
transforms.Normalize(
|
||||||
mean=torch.tensor(mean),
|
mean=torch.tensor(mean),
|
||||||
std=torch.tensor(std),
|
std=torch.tensor(std),
|
||||||
@ -318,10 +318,10 @@ def transforms_imagenet_eval(
|
|||||||
tfl += [ToNumpy()]
|
tfl += [ToNumpy()]
|
||||||
elif not normalize:
|
elif not normalize:
|
||||||
# when normalize disabled, converted to tensor without scaling, keeps original dtype
|
# when normalize disabled, converted to tensor without scaling, keeps original dtype
|
||||||
tfl += [transforms.PILToTensor()]
|
tfl += [MaybePILToTensor()]
|
||||||
else:
|
else:
|
||||||
tfl += [
|
tfl += [
|
||||||
transforms.ToTensor(),
|
MaybeToTensor(),
|
||||||
transforms.Normalize(
|
transforms.Normalize(
|
||||||
mean=torch.tensor(mean),
|
mean=torch.tensor(mean),
|
||||||
std=torch.tensor(std),
|
std=torch.tensor(std),
|
||||||
|
Loading…
x
Reference in New Issue
Block a user