mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add interpolation mode handling to transforms. Removes InterpolationMode warning. Works for torchvision versions w/ and w/o InterpolationMode enum. Fix #738.
This commit is contained in:
parent
ed41d32637
commit
a41de1f666
@ -1,5 +1,10 @@
|
|||||||
import torch
|
import torch
|
||||||
import torchvision.transforms.functional as F
|
import torchvision.transforms.functional as F
|
||||||
|
try:
|
||||||
|
from torchvision.transforms.functional import InterpolationMode
|
||||||
|
has_interpolation_mode = True
|
||||||
|
except ImportError:
|
||||||
|
has_interpolation_mode = False
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import warnings
|
import warnings
|
||||||
import math
|
import math
|
||||||
@ -31,28 +36,50 @@ class ToTensor:
|
|||||||
|
|
||||||
|
|
||||||
_pil_interpolation_to_str = {
|
_pil_interpolation_to_str = {
|
||||||
Image.NEAREST: 'PIL.Image.NEAREST',
|
Image.NEAREST: 'nearest',
|
||||||
Image.BILINEAR: 'PIL.Image.BILINEAR',
|
Image.BILINEAR: 'bilinear',
|
||||||
Image.BICUBIC: 'PIL.Image.BICUBIC',
|
Image.BICUBIC: 'bicubic',
|
||||||
Image.LANCZOS: 'PIL.Image.LANCZOS',
|
Image.BOX: 'box',
|
||||||
Image.HAMMING: 'PIL.Image.HAMMING',
|
Image.HAMMING: 'hamming',
|
||||||
Image.BOX: 'PIL.Image.BOX',
|
Image.LANCZOS: 'lanczos',
|
||||||
}
|
}
|
||||||
|
_str_to_pil_interpolation = {b: a for a, b in _pil_interpolation_to_str.items()}
|
||||||
|
|
||||||
|
|
||||||
def _pil_interp(method):
|
if has_interpolation_mode:
|
||||||
if method == 'bicubic':
|
_torch_interpolation_to_str = {
|
||||||
return Image.BICUBIC
|
InterpolationMode.NEAREST: 'nearest',
|
||||||
elif method == 'lanczos':
|
InterpolationMode.BILINEAR: 'bilinear',
|
||||||
return Image.LANCZOS
|
InterpolationMode.BICUBIC: 'bicubic',
|
||||||
elif method == 'hamming':
|
InterpolationMode.BOX: 'box',
|
||||||
return Image.HAMMING
|
InterpolationMode.HAMMING: 'hamming',
|
||||||
|
InterpolationMode.LANCZOS: 'lanczos',
|
||||||
|
}
|
||||||
|
_str_to_torch_interpolation = {b: a for a, b in _torch_interpolation_to_str.items()}
|
||||||
else:
|
else:
|
||||||
# default bilinear, do we want to allow nearest?
|
_pil_interpolation_to_torch = {}
|
||||||
return Image.BILINEAR
|
_torch_interpolation_to_str = {}
|
||||||
|
|
||||||
|
|
||||||
_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
|
def str_to_pil_interp(mode_str):
|
||||||
|
return _str_to_pil_interpolation[mode_str]
|
||||||
|
|
||||||
|
|
||||||
|
def str_to_interp_mode(mode_str):
|
||||||
|
if has_interpolation_mode:
|
||||||
|
return _str_to_torch_interpolation[mode_str]
|
||||||
|
else:
|
||||||
|
return _str_to_pil_interpolation[mode_str]
|
||||||
|
|
||||||
|
|
||||||
|
def interp_mode_to_str(mode):
|
||||||
|
if has_interpolation_mode:
|
||||||
|
return _torch_interpolation_to_str[mode]
|
||||||
|
else:
|
||||||
|
return _pil_interpolation_to_str[mode]
|
||||||
|
|
||||||
|
|
||||||
|
_RANDOM_INTERPOLATION = (str_to_interp_mode('bilinear'), str_to_interp_mode('bicubic'))
|
||||||
|
|
||||||
|
|
||||||
class RandomResizedCropAndInterpolation:
|
class RandomResizedCropAndInterpolation:
|
||||||
@ -82,7 +109,7 @@ class RandomResizedCropAndInterpolation:
|
|||||||
if interpolation == 'random':
|
if interpolation == 'random':
|
||||||
self.interpolation = _RANDOM_INTERPOLATION
|
self.interpolation = _RANDOM_INTERPOLATION
|
||||||
else:
|
else:
|
||||||
self.interpolation = _pil_interp(interpolation)
|
self.interpolation = str_to_interp_mode(interpolation)
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
self.ratio = ratio
|
self.ratio = ratio
|
||||||
|
|
||||||
@ -146,9 +173,9 @@ class RandomResizedCropAndInterpolation:
|
|||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
if isinstance(self.interpolation, (tuple, list)):
|
if isinstance(self.interpolation, (tuple, list)):
|
||||||
interpolate_str = ' '.join([_pil_interpolation_to_str[x] for x in self.interpolation])
|
interpolate_str = ' '.join([interp_mode_to_str(x) for x in self.interpolation])
|
||||||
else:
|
else:
|
||||||
interpolate_str = _pil_interpolation_to_str[self.interpolation]
|
interpolate_str = interp_mode_to_str(self.interpolation)
|
||||||
format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
|
format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
|
||||||
format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
|
format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
|
||||||
format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
|
format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
|
||||||
|
@ -10,7 +10,7 @@ from torchvision import transforms
|
|||||||
|
|
||||||
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
|
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
|
||||||
from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform
|
from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform
|
||||||
from timm.data.transforms import _pil_interp, RandomResizedCropAndInterpolation, ToNumpy, ToTensor
|
from timm.data.transforms import str_to_interp_mode, str_to_pil_interp, RandomResizedCropAndInterpolation, ToNumpy
|
||||||
from timm.data.random_erasing import RandomErasing
|
from timm.data.random_erasing import RandomErasing
|
||||||
|
|
||||||
|
|
||||||
@ -25,7 +25,7 @@ def transforms_noaug_train(
|
|||||||
# random interpolation not supported with no-aug
|
# random interpolation not supported with no-aug
|
||||||
interpolation = 'bilinear'
|
interpolation = 'bilinear'
|
||||||
tfl = [
|
tfl = [
|
||||||
transforms.Resize(img_size, _pil_interp(interpolation)),
|
transforms.Resize(img_size, interpolation=str_to_interp_mode(interpolation)),
|
||||||
transforms.CenterCrop(img_size)
|
transforms.CenterCrop(img_size)
|
||||||
]
|
]
|
||||||
if use_prefetcher:
|
if use_prefetcher:
|
||||||
@ -87,7 +87,7 @@ def transforms_imagenet_train(
|
|||||||
img_mean=tuple([min(255, round(255 * x)) for x in mean]),
|
img_mean=tuple([min(255, round(255 * x)) for x in mean]),
|
||||||
)
|
)
|
||||||
if interpolation and interpolation != 'random':
|
if interpolation and interpolation != 'random':
|
||||||
aa_params['interpolation'] = _pil_interp(interpolation)
|
aa_params['interpolation'] = str_to_pil_interp(interpolation)
|
||||||
if auto_augment.startswith('rand'):
|
if auto_augment.startswith('rand'):
|
||||||
secondary_tfl += [rand_augment_transform(auto_augment, aa_params)]
|
secondary_tfl += [rand_augment_transform(auto_augment, aa_params)]
|
||||||
elif auto_augment.startswith('augmix'):
|
elif auto_augment.startswith('augmix'):
|
||||||
@ -147,7 +147,7 @@ def transforms_imagenet_eval(
|
|||||||
scale_size = int(math.floor(img_size / crop_pct))
|
scale_size = int(math.floor(img_size / crop_pct))
|
||||||
|
|
||||||
tfl = [
|
tfl = [
|
||||||
transforms.Resize(scale_size, _pil_interp(interpolation)),
|
transforms.Resize(scale_size, interpolation=str_to_interp_mode(interpolation)),
|
||||||
transforms.CenterCrop(img_size),
|
transforms.CenterCrop(img_size),
|
||||||
]
|
]
|
||||||
if use_prefetcher:
|
if use_prefetcher:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user