mirror of
https://github.com/open-mmlab/mmselfsup.git
synced 2025-06-03 14:59:38 +08:00
add RandomResizedCrop
This commit is contained in:
parent
be1dd2f5c2
commit
c119c4677c
@ -10,7 +10,8 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torchvision.transforms.functional as F
|
import torchvision.transforms.functional as F
|
||||||
from mmcv.image import (adjust_brightness, adjust_color, adjust_contrast,
|
from mmcv.image import (adjust_brightness, adjust_color, adjust_contrast,
|
||||||
adjust_hue, adjust_lighting, solarize)
|
adjust_hue, adjust_lighting, imcrop, imresize,
|
||||||
|
solarize)
|
||||||
from mmcv.transforms import BaseTransform
|
from mmcv.transforms import BaseTransform
|
||||||
from PIL import Image, ImageFilter
|
from PIL import Image, ImageFilter
|
||||||
from timm.data import create_transform
|
from timm.data import create_transform
|
||||||
@ -795,3 +796,149 @@ class ColorJitter(BaseTransform):
|
|||||||
repr_str += f'saturation={self.saturation},'
|
repr_str += f'saturation={self.saturation},'
|
||||||
repr_str += f'saturation={self.hue})'
|
repr_str += f'saturation={self.hue})'
|
||||||
return repr_str
|
return repr_str
|
||||||
|
|
||||||
|
|
||||||
|
@TRANSFORMS.register_module()
class RandomResizedCrop(BaseTransform):
    """Crop a random region of the image and resize it to a fixed size.

    A region covering a random fraction of the image area (default: 0.08 to
    1.0) with a random width / height aspect ratio (default: 3/4 to 4/3) is
    cropped, then resized to ``size``.

    Required Keys:

    - img

    Modified Keys:

    - img

    Args:
        size (Sequence | int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        scale (Tuple): Range (min, max) of the crop area, expressed as a
            fraction of the original image area. Defaults to (0.08, 1.0).
        ratio (Tuple): Range (min, max) of the width / height aspect ratio
            of the crop. Values must be positive because they are sampled in
            log space. Defaults to (3. / 4., 4. / 3.).
        max_attempts (int): Maximum number of attempts before falling back to
            Central Crop. Defaults to 10.
        interpolation (str): Interpolation method, accepted values are
            'nearest', 'bilinear', 'bicubic', 'area', 'lanczos'. Defaults to
            'bilinear'.
        backend (str): The image resize backend type, accepted values are
            `cv2` and `pillow`. Defaults to `cv2`.

    Raises:
        ValueError: If ``scale`` or ``ratio`` is not ordered as (min, max),
            or if ``backend`` is not supported.
        AssertionError: If ``max_attempts`` is not a non-negative int or
            ``interpolation`` is not one of the accepted values.
    """

    def __init__(self,
                 size: Union[int, Sequence[int]],
                 scale: Optional[Tuple] = (0.08, 1.0),
                 ratio: Optional[Tuple] = (3. / 4., 4. / 3.),
                 max_attempts: Optional[int] = 10,
                 interpolation: Optional[str] = 'bilinear',
                 backend: Optional[str] = 'cv2') -> None:
        super().__init__()
        # Normalise ``size`` to an (h, w) pair.
        if isinstance(size, (tuple, list)):
            self.size = size
        else:
            self.size = (size, size)

        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            raise ValueError('range should be of kind (min, max). '
                             f'But received scale {scale} and ratio {ratio}.')
        assert isinstance(max_attempts, int) and max_attempts >= 0, \
            'max_attempts must be int and no less than 0.'
        assert interpolation in ('nearest', 'bilinear', 'bicubic', 'area',
                                 'lanczos'), \
            f'interpolation: {interpolation} is not supported.'
        if backend not in ['cv2', 'pillow']:
            raise ValueError(
                f'backend: {backend} is not supported for resize. '
                'Supported backends are "cv2", "pillow"')

        self.scale = scale
        self.ratio = ratio
        self.max_attempts = max_attempts
        self.interpolation = interpolation
        self.backend = backend

    @staticmethod
    def get_params(
            img: np.ndarray,
            scale: Tuple,
            ratio: Tuple,
            max_attempts: Optional[int] = 10) -> Tuple[int, int, int, int]:
        """Get parameters for ``crop`` for a random sized crop.

        Args:
            img (np.ndarray): Image to be cropped. Only its first two
                dimensions (height, width) are read.
            scale (Tuple): Range of the crop area as a fraction of the
                original image area.
            ratio (Tuple): Range of the width / height aspect ratio of the
                crop.
            max_attempts (int): Maximum number of attempts before falling back
                to central crop. Defaults to 10.

        Returns:
            tuple: Params (ymin, xmin, ymax, xmax) to be passed to `crop` for
                a random sized crop. ``ymax`` and ``xmax`` are inclusive.
        """
        height = img.shape[0]
        width = img.shape[1]
        area = height * width

        # The log bounds do not change between attempts, so compute them
        # once. Sampling in log space makes ratios r and 1 / r equally likely.
        log_ratio = (math.log(ratio[0]), math.log(ratio[1]))

        for _ in range(max_attempts):
            target_area = random.uniform(*scale) * area
            aspect_ratio = math.exp(random.uniform(*log_ratio))

            target_width = int(round(math.sqrt(target_area * aspect_ratio)))
            target_height = int(round(math.sqrt(target_area / aspect_ratio)))

            # Accept the proposal only if it fits inside the image.
            if 0 < target_width <= width and 0 < target_height <= height:
                ymin = random.randint(0, height - target_height)
                xmin = random.randint(0, width - target_width)
                ymax = ymin + target_height - 1
                xmax = xmin + target_width - 1
                return ymin, xmin, ymax, xmax

        # Fallback to central crop: clamp the image aspect ratio into the
        # allowed range and take the largest centred box with that ratio.
        in_ratio = float(width) / float(height)
        if in_ratio < min(ratio):
            target_width = width
            target_height = int(round(target_width / min(ratio)))
        elif in_ratio > max(ratio):
            target_height = height
            target_width = int(round(target_height * max(ratio)))
        else:  # whole image
            target_width = width
            target_height = height
        ymin = (height - target_height) // 2
        xmin = (width - target_width) // 2
        ymax = ymin + target_height - 1
        xmax = xmin + target_width - 1
        return ymin, xmin, ymax, xmax

    def transform(self, results: Dict) -> Dict:
        """Randomly crop ``results['img']`` and resize it to ``self.size``.

        Args:
            results (Dict): Result dict holding the image under ``'img'``.

        Returns:
            Dict: The same dict with ``'img'`` replaced by the resized crop.
        """
        img = results['img']
        ymin, xmin, ymax, xmax = self.get_params(
            img=img,
            scale=self.scale,
            ratio=self.ratio,
            max_attempts=self.max_attempts)
        # The box is passed to ``imcrop`` in (x1, y1, x2, y2) order.
        img = imcrop(img, bboxes=np.array([xmin, ymin, xmax, ymax]))
        # ``self.size`` is (h, w); it is reversed to (w, h) for ``imresize``.
        results['img'] = imresize(
            img,
            tuple(self.size[::-1]),
            interpolation=self.interpolation,
            backend=self.backend)
        return results

    def __repr__(self) -> str:
        """Return a string listing the configured crop parameters."""
        repr_str = self.__class__.__name__ + f'(size={self.size}'
        repr_str += f', scale={tuple(round(s, 4) for s in self.scale)}'
        repr_str += f', ratio={tuple(round(r, 4) for r in self.ratio)}'
        repr_str += f', max_attempts={self.max_attempts}'
        repr_str += f', interpolation={self.interpolation}'
        repr_str += f', backend={self.backend})'
        return repr_str
|
||||||
|
@ -1,8 +1,16 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import os.path as osp
|
||||||
|
import random
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
import torchvision
|
||||||
|
from mmcv import imread
|
||||||
|
from mmcv.transforms import Compose
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
import mmselfsup.datasets.pipelines.transforms as mmselfsup_transforms
|
||||||
from mmselfsup.datasets.pipelines import (
|
from mmselfsup.datasets.pipelines import (
|
||||||
BEiTMaskGenerator, ColorJitter, Lighting, RandomGaussianBlur,
|
BEiTMaskGenerator, ColorJitter, Lighting, RandomGaussianBlur,
|
||||||
RandomPatchWithLabels, RandomResizedCropAndInterpolationWithTwoPic,
|
RandomPatchWithLabels, RandomResizedCropAndInterpolationWithTwoPic,
|
||||||
@ -128,6 +136,8 @@ def test_random_rotation():
|
|||||||
assert list(results['img'].shape) == [4, 3, 224, 224]
|
assert list(results['img'].shape) == [4, 3, 224, 224]
|
||||||
assert list(results['rot_label'].shape) == [4]
|
assert list(results['rot_label'].shape) == [4]
|
||||||
|
|
||||||
|
assert isinstance(str(module), str)
|
||||||
|
|
||||||
|
|
||||||
def test_random_patch():
|
def test_random_patch():
|
||||||
transform = dict()
|
transform = dict()
|
||||||
@ -140,6 +150,8 @@ def test_random_patch():
|
|||||||
assert list(results['img'].shape) == [8, 6, 53, 53]
|
assert list(results['img'].shape) == [8, 6, 53, 53]
|
||||||
assert list(results['patch_label'].shape) == [8]
|
assert list(results['patch_label'].shape) == [8]
|
||||||
|
|
||||||
|
assert isinstance(str(module), str)
|
||||||
|
|
||||||
|
|
||||||
def test_color_jitter():
|
def test_color_jitter():
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
@ -163,3 +175,199 @@ def test_color_jitter():
|
|||||||
assert results['img'].shape == original_img.shape
|
assert results['img'].shape == original_img.shape
|
||||||
|
|
||||||
assert isinstance(str(transform), str)
|
assert isinstance(str(transform), str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_randomresizedcrop():
    """Check RandomResizedCrop argument validation, output and fallback."""
    img_path = osp.join(osp.dirname(__file__), '../../data/color.jpg')
    ori_img = imread(img_path, 'color')
    ori_img_pil = Image.open(img_path)
    img_h, img_w = ori_img.shape[0], ori_img.shape[1]

    # A fixed seed keeps any failure reproducible between runs.
    seed = 0

    def reseed():
        random.seed(seed)
        np.random.seed(seed)

    def run_baseline(**kwargs):
        # torchvision reference implementation applied to the PIL image.
        pipeline = Compose([torchvision.transforms.RandomResizedCrop(**kwargs)])
        return np.array(pipeline(ori_img_pil))

    def run_ours(**kwargs):
        pipeline = Compose([mmselfsup_transforms.RandomResizedCrop(**kwargs)])
        return pipeline(dict(img=ori_img))['img']

    # test when scale is not of kind (min, max)
    with pytest.raises(ValueError):
        mmselfsup_transforms.RandomResizedCrop(
            size=(200, 300), scale=(1.0, 0.08), ratio=(3. / 4., 4. / 3.))

    # test when ratio is not of kind (min, max)
    with pytest.raises(ValueError):
        mmselfsup_transforms.RandomResizedCrop(
            size=(200, 300), scale=(0.08, 1.0), ratio=(4. / 3., 3. / 4.))

    # test __repr__()
    repr_str = repr(mmselfsup_transforms.RandomResizedCrop(size=200))
    assert isinstance(repr_str, str)
    assert repr_str.startswith('RandomResizedCrop(size=(200, 200)')

    default_scale = (0.08, 1.0)
    default_ratio = (3. / 4., 4. / 3.)
    cases = [
        # crop size is int
        (dict(size=200, scale=default_scale, ratio=default_ratio),
         (200, 200, 3)),
        # crop size < image size
        (dict(size=(200, 300), scale=default_scale, ratio=default_ratio),
         (200, 300, 3)),
        # crop size > image size
        (dict(size=(600, 700), scale=default_scale, ratio=default_ratio),
         (600, 700, 3)),
        # cropping the whole image
        (dict(size=(img_h, img_w), scale=(1.0, 2.0), ratio=(1.0, 2.0)),
         (img_h, img_w, 3)),
        # central crop when in_ratio < min(ratio)
        (dict(size=(img_h, img_w), scale=(1.0, 2.0), ratio=(2., 3.)),
         (img_h, img_w, 3)),
        # central crop when in_ratio > max(ratio)
        (dict(size=(img_h, img_w), scale=(1.0, 2.0), ratio=(3. / 4., 1)),
         (img_h, img_w, 3)),
    ]
    for kwargs, expected_shape in cases:
        reseed()
        baseline = run_baseline(**kwargs)
        reseed()
        img = run_ours(**kwargs)
        # The two libraries use different resize kernels, so only the
        # output geometry and dtype are compared, not pixel values.
        assert np.array(img).shape == expected_shape
        assert baseline.shape == expected_shape
        assert img.dtype == ori_img.dtype

    # With max_attempts=0 the random search is skipped, so the central-crop
    # fallback is deterministic and its box can be checked exactly.
    canvas = np.zeros((300, 400, 3), dtype=np.uint8)
    get_params = mmselfsup_transforms.RandomResizedCrop.get_params
    # in_ratio (4/3) inside the range -> whole image
    assert get_params(
        canvas, scale=(1.0, 2.0), ratio=(1.0, 2.0),
        max_attempts=0) == (0, 0, 299, 399)
    # in_ratio < min(ratio) -> full width, reduced height
    assert get_params(
        canvas, scale=(1.0, 2.0), ratio=(2., 3.),
        max_attempts=0) == (50, 0, 249, 399)
    # in_ratio > max(ratio) -> full height, reduced width
    assert get_params(
        canvas, scale=(1.0, 2.0), ratio=(3. / 4., 1),
        max_attempts=0) == (0, 50, 299, 349)

    # test different interpolation types
    for mode in ['nearest', 'bilinear', 'bicubic', 'area', 'lanczos']:
        img = run_ours(
            size=(600, 700),
            scale=(0.08, 1.0),
            ratio=(3. / 4., 4. / 3.),
            interpolation=mode)
        assert img.shape == (600, 700, 3)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user