add RandomResizedCrop

This commit is contained in:
fangyixiao18 2022-05-17 11:51:03 +08:00
parent be1dd2f5c2
commit c119c4677c
2 changed files with 356 additions and 1 deletions

View File

@ -10,7 +10,8 @@ import numpy as np
import torch
import torchvision.transforms.functional as F
from mmcv.image import (adjust_brightness, adjust_color, adjust_contrast,
adjust_hue, adjust_lighting, solarize)
adjust_hue, adjust_lighting, imcrop, imresize,
solarize)
from mmcv.transforms import BaseTransform
from PIL import Image, ImageFilter
from timm.data import create_transform
@ -795,3 +796,149 @@ class ColorJitter(BaseTransform):
repr_str += f'saturation={self.saturation},'
repr_str += f'saturation={self.hue})'
return repr_str
@TRANSFORMS.register_module()
class RandomResizedCrop(BaseTransform):
    """Crop the given image to random size and aspect ratio.

    A crop of random size (default: of 0.08 to 1.0) of the original size and a
    random aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio
    is made. This crop is finally resized to given size.

    Required Keys:

    - img

    Modified Keys:

    - img

    Args:
        size (Sequence | int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        scale (Tuple): Range of the random size of the cropped image compared
            to the original image. Defaults to (0.08, 1.0).
        ratio (Tuple): Range of the random aspect ratio of the cropped image
            compared to the original image. Defaults to (3. / 4., 4. / 3.).
        max_attempts (int): Maximum number of attempts before falling back to
            Central Crop. Defaults to 10.
        interpolation (str): Interpolation method, accepted values are
            'nearest', 'bilinear', 'bicubic', 'area', 'lanczos'. Defaults to
            'bilinear'.
        backend (str): The image resize backend type, accepted values are
            `cv2` and `pillow`. Defaults to `cv2`.

    Raises:
        ValueError: If ``scale`` or ``ratio`` is not ordered as (min, max),
            or if ``backend`` is not one of the supported backends.
        AssertionError: If ``max_attempts`` is not a non-negative int or
            ``interpolation`` is not one of the accepted values.
    """

    def __init__(self,
                 size: Union[int, Sequence[int]],
                 scale: Optional[Tuple] = (0.08, 1.0),
                 ratio: Optional[Tuple] = (3. / 4., 4. / 3.),
                 max_attempts: Optional[int] = 10,
                 interpolation: Optional[str] = 'bilinear',
                 backend: Optional[str] = 'cv2') -> None:
        super().__init__()
        # A bare int means a square output of (size, size).
        if isinstance(size, (tuple, list)):
            self.size = size
        else:
            self.size = (size, size)
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            raise ValueError('range should be of kind (min, max). '
                             f'But received scale {scale} and ratio {ratio}.')
        assert isinstance(max_attempts, int) and max_attempts >= 0, \
            'max_attempts must be int and no less than 0.'
        assert interpolation in ('nearest', 'bilinear', 'bicubic', 'area',
                                 'lanczos'), \
            f'interpolation: {interpolation} is not supported.'
        if backend not in ['cv2', 'pillow']:
            raise ValueError(
                f'backend: {backend} is not supported for resize. '
                'Supported backends are "cv2", "pillow"')

        self.scale = scale
        self.ratio = ratio
        self.max_attempts = max_attempts
        self.interpolation = interpolation
        self.backend = backend

    @staticmethod
    def get_params(
            img: np.ndarray,
            scale: Tuple,
            ratio: Tuple,
            max_attempts: Optional[int] = 10) -> Tuple[int, int, int, int]:
        """Get parameters for ``crop`` for a random sized crop.

        Args:
            img (np.ndarray): Image to be cropped.
            scale (Tuple): Range of the random size of the cropped image
                compared to the original image size.
            ratio (Tuple): Range of the random aspect ratio of the cropped
                image compared to the original image aspect ratio.
            max_attempts (int): Maximum number of attempts before falling back
                to central crop. Defaults to 10.

        Returns:
            tuple: Params (ymin, xmin, ymax, xmax) to be passed to `crop` for
            a random sized crop. ``ymax`` and ``xmax`` are inclusive, i.e.
            the crop spans ``ymax - ymin + 1`` rows.
        """
        height = img.shape[0]
        width = img.shape[1]
        area = height * width

        # The log-range does not depend on the attempt, so compute it once.
        # Sampling uniformly in log space makes a ratio r and its inverse
        # 1 / r equally likely.
        log_ratio = (math.log(ratio[0]), math.log(ratio[1]))

        for _ in range(max_attempts):
            target_area = random.uniform(*scale) * area
            aspect_ratio = math.exp(random.uniform(*log_ratio))
            target_width = int(round(math.sqrt(target_area * aspect_ratio)))
            target_height = int(round(math.sqrt(target_area / aspect_ratio)))

            # Accept the candidate only if it fits inside the image.
            if 0 < target_width <= width and 0 < target_height <= height:
                ymin = random.randint(0, height - target_height)
                xmin = random.randint(0, width - target_width)
                ymax = ymin + target_height - 1
                xmax = xmin + target_width - 1
                return ymin, xmin, ymax, xmax

        # Fallback to central crop: clamp the image's own aspect ratio into
        # the requested range and take the largest centered crop.
        in_ratio = float(width) / float(height)
        if in_ratio < min(ratio):
            target_width = width
            target_height = int(round(target_width / min(ratio)))
        elif in_ratio > max(ratio):
            target_height = height
            target_width = int(round(target_height * max(ratio)))
        else:  # whole image
            target_width = width
            target_height = height
        ymin = (height - target_height) // 2
        xmin = (width - target_width) // 2
        ymax = ymin + target_height - 1
        xmax = xmin + target_width - 1
        return ymin, xmin, ymax, xmax

    def transform(self, results: Dict) -> Dict:
        """Randomly crop ``results['img']`` and resize it to ``self.size``.

        Args:
            results (dict): Result dict that holds the image under ``img``.

        Returns:
            dict: The same dict with ``img`` replaced by the cropped and
            resized image.
        """
        img = results['img']
        ymin, xmin, ymax, xmax = self.get_params(
            img=img,
            scale=self.scale,
            ratio=self.ratio,
            max_attempts=self.max_attempts)
        # The crop box is passed in (xmin, ymin, xmax, ymax) order.
        img = imcrop(img, bboxes=np.array([xmin, ymin, xmax, ymax]))
        # ``self.size`` is stored as (h, w); it is reversed to (w, h) for the
        # resize call.
        results['img'] = imresize(
            img,
            tuple(self.size[::-1]),
            interpolation=self.interpolation,
            backend=self.backend)
        return results

    def __repr__(self) -> str:
        """Return a string listing the transform's configuration."""
        repr_str = self.__class__.__name__ + f'(size={self.size}'
        repr_str += f', scale={tuple(round(s, 4) for s in self.scale)}'
        repr_str += f', ratio={tuple(round(r, 4) for r in self.ratio)}'
        repr_str += f', max_attempts={self.max_attempts}'
        repr_str += f', interpolation={self.interpolation}'
        repr_str += f', backend={self.backend})'
        return repr_str

View File

@ -1,8 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import random
import numpy as np
import pytest
import torch
import torchvision
from mmcv import imread
from mmcv.transforms import Compose
from PIL import Image
import mmselfsup.datasets.pipelines.transforms as mmselfsup_transforms
from mmselfsup.datasets.pipelines import (
BEiTMaskGenerator, ColorJitter, Lighting, RandomGaussianBlur,
RandomPatchWithLabels, RandomResizedCropAndInterpolationWithTwoPic,
@ -128,6 +136,8 @@ def test_random_rotation():
assert list(results['img'].shape) == [4, 3, 224, 224]
assert list(results['rot_label'].shape) == [4]
assert isinstance(str(module), str)
def test_random_patch():
transform = dict()
@ -140,6 +150,8 @@ def test_random_patch():
assert list(results['img'].shape) == [8, 6, 53, 53]
assert list(results['patch_label'].shape) == [8]
assert isinstance(str(module), str)
def test_color_jitter():
with pytest.raises(ValueError):
@ -163,3 +175,199 @@ def test_color_jitter():
assert results['img'].shape == original_img.shape
assert isinstance(str(transform), str)
def test_randomresizedcrop():
    """Exercise RandomResizedCrop validation, output shapes and repr.

    Each shape scenario runs torchvision's RandomResizedCrop on the PIL image
    and the mmselfsup implementation on the array image, re-seeding ``random``
    and ``np.random`` with the same seed before each, then checks both output
    shapes.
    """
    img_path = osp.join(osp.dirname(__file__), '../../data/color.jpg')
    ori_img = imread(img_path, 'color')
    ori_img_pil = Image.open(img_path)
    seed = random.randint(0, 100)

    def build(transform_cls, cfg):
        # Wrap a single transform instance in a Compose pipeline.
        return Compose([transform_cls(**cfg)])

    def reseed():
        random.seed(seed)
        np.random.seed(seed)

    # scale / ratio that are not of kind (min, max) must be rejected
    reversed_ranges = (
        dict(scale=(1.0, 0.08), ratio=(3. / 4., 4. / 3.)),
        dict(scale=(0.08, 1.0), ratio=(4. / 3., 3. / 4.)),
    )
    for bad_range in reversed_ranges:
        with pytest.raises(ValueError):
            pipeline = build(mmselfsup_transforms.RandomResizedCrop,
                             dict(size=(200, 300), **bad_range))
            pipeline(dict(img=ori_img))['img']

    default_range = dict(scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.))
    full_size = (ori_img.shape[0], ori_img.shape[1])
    full_shape = (ori_img.shape[0], ori_img.shape[1], 3)
    scenarios = [
        # crop size is int
        (dict(size=200, **default_range), (200, 200, 3)),
        # crop size < image size
        (dict(size=(200, 300), **default_range), (200, 300, 3)),
        # crop size > image size
        (dict(size=(600, 700), **default_range), (600, 700, 3)),
        # cropping the whole image
        (dict(size=full_size, scale=(1.0, 2.0), ratio=(1.0, 2.0)),
         full_shape),
        # central crop when in_ratio < min(ratio)
        (dict(size=full_size, scale=(1.0, 2.0), ratio=(2., 3.)), full_shape),
        # central crop when in_ratio > max(ratio)
        (dict(size=full_size, scale=(1.0, 2.0), ratio=(3. / 4., 1)),
         full_shape),
    ]
    for index, (cfg, expected_shape) in enumerate(scenarios):
        reseed()
        reference = build(torchvision.transforms.RandomResizedCrop, cfg)
        baseline = reference(ori_img_pil)

        reseed()
        pipeline = build(mmselfsup_transforms.RandomResizedCrop, cfg)
        if index == 0:
            # test __repr__()
            print(pipeline)
        img = pipeline(dict(img=ori_img))['img']

        assert np.array(img).shape == expected_shape
        assert np.array(baseline).shape == expected_shape
        # NOTE(review): ``ndarray.nonzero()`` returns one index array per
        # dimension, so ``len(...)`` here is the number of dimensions, not
        # the number of differing pixels. This check therefore only compares
        # dimensionality - confirm whether a pixel-level comparison was
        # intended.
        source_diff = ori_img - np.array(ori_img_pil)[:, :, ::-1]
        output_diff = img - np.array(baseline)[:, :, ::-1]
        nonzero = len(source_diff.nonzero())
        nonzero_transform = len(output_diff.nonzero())
        assert nonzero == nonzero_transform

    # test different interpolation types
    for mode in ['nearest', 'bilinear', 'bicubic', 'area', 'lanczos']:
        cfg = dict(
            size=(600, 700),
            scale=(0.08, 1.0),
            ratio=(3. / 4., 4. / 3.),
            interpolation=mode)
        pipeline = build(mmselfsup_transforms.RandomResizedCrop, cfg)
        img = pipeline(dict(img=ori_img))['img']
        assert img.shape == (600, 700, 3)