fix: Interpolation Warning

pull/608/head
zuchen.wang 2021-11-10 17:20:52 +08:00
parent 3be6d2c439
commit d543b80dcf
2 changed files with 9 additions and 5 deletions

View File

@ -25,6 +25,7 @@ import re
import PIL
import numpy as np
from PIL import Image, ImageOps, ImageEnhance
from torchvision.transforms import InterpolationMode
_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]])
@ -39,11 +40,11 @@ _HPARAMS_DEFAULT = dict(
img_mean=_FILL,
)
_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
_RANDOM_INTERPOLATION = (InterpolationMode.BILINEAR, InterpolationMode.BICUBIC)
def _interpolation(kwargs):
interpolation = kwargs.pop('resample', Image.BILINEAR)
interpolation = kwargs.pop('resample', InterpolationMode.BILINEAR)
if isinstance(interpolation, (list, tuple)):
return random.choice(interpolation)
else:

View File

@ -5,6 +5,7 @@
"""
import torchvision.transforms as T
from torchvision.transforms import InterpolationMode
from .transforms import *
from .autoaugment import AutoAugment
@ -63,11 +64,12 @@ def build_transforms(cfg, is_train=True):
res.append(T.RandomApply([AutoAugment()], p=autoaug_prob))
if size_train[0] > 0:
res.append(T.Resize(size_train[0] if len(size_train) == 1 else size_train, interpolation=3))
res.append(T.Resize(size_train[0] if len(size_train) == 1 else size_train,
interpolation=InterpolationMode.BICUBIC))
if do_crop:
res.append(T.RandomResizedCrop(size=crop_size[0] if len(crop_size) == 1 else crop_size,
interpolation=3,
interpolation=InterpolationMode.BICUBIC,
scale=crop_scale, ratio=crop_ratio))
if do_pad:
res.extend([T.Pad(padding_size, padding_mode=padding_mode),
@ -93,7 +95,8 @@ def build_transforms(cfg, is_train=True):
crop_size = cfg.INPUT.CROP.SIZE
if size_test[0] > 0:
res.append(T.Resize(size_test[0] if len(size_test) == 1 else size_test, interpolation=3))
res.append(T.Resize(size_test[0] if len(size_test) == 1 else size_test,
interpolation=InterpolationMode.BICUBIC))
if do_crop:
res.append(T.CenterCrop(size=crop_size[0] if len(crop_size) == 1 else crop_size))
res.append(ToTensor())