[Refactor] refactor randomCrop, randomResizeCrop and CenterCrop。
parent
6563f5f448
commit
c78b5597d8
|
@ -8,7 +8,7 @@ preprocess_cfg = dict(
|
|||
to_rgb=False)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='RandomCrop', size=32, padding=4),
|
||||
dict(type='RandomCrop', crop_size=32, padding=4),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(type='PackClsInputs'),
|
||||
]
|
||||
|
|
|
@ -8,7 +8,7 @@ preprocess_cfg = dict(
|
|||
to_rgb=False)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='RandomCrop', size=32, padding=4),
|
||||
dict(type='RandomCrop', crop_size=32, padding=4),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(type='PackClsInputs'),
|
||||
]
|
||||
|
|
|
@ -11,7 +11,7 @@ preprocess_cfg = dict(
|
|||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', scale=510),
|
||||
dict(type='RandomCrop', size=384),
|
||||
dict(type='RandomCrop', crop_size=384),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(type='PackClsInputs'),
|
||||
]
|
||||
|
|
|
@ -10,7 +10,7 @@ preprocess_cfg = dict(
|
|||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', scale=600),
|
||||
dict(type='RandomCrop', size=448),
|
||||
dict(type='RandomCrop', crop_size=448),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(type='PackClsInputs'),
|
||||
]
|
||||
|
|
|
@ -6,16 +6,15 @@ from .auto_augment import (AutoAugment, AutoContrast, Brightness,
|
|||
from .compose import Compose
|
||||
from .formatting import (Collect, ImageToTensor, PackClsInputs, ToNumpy, ToPIL,
|
||||
ToTensor, Transpose, to_tensor)
|
||||
from .processing import (CenterCrop, ColorJitter, Lighting, Normalize, Pad,
|
||||
RandomCrop, RandomErasing, RandomGrayscale,
|
||||
RandomResizedCrop)
|
||||
from .processing import (ColorJitter, Lighting, Normalize, Pad, RandomCrop,
|
||||
RandomErasing, RandomGrayscale, RandomResizedCrop)
|
||||
|
||||
__all__ = [
|
||||
'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToPIL', 'ToNumpy',
|
||||
'Transpose', 'Collect', 'CenterCrop', 'Normalize', 'RandomCrop',
|
||||
'RandomResizedCrop', 'RandomGrayscale', 'Shear', 'Translate', 'Rotate',
|
||||
'Invert', 'ColorTransform', 'Solarize', 'Posterize', 'AutoContrast',
|
||||
'Equalize', 'Contrast', 'Brightness', 'Sharpness', 'AutoAugment',
|
||||
'SolarizeAdd', 'Cutout', 'RandAugment', 'Lighting', 'ColorJitter',
|
||||
'RandomErasing', 'Pad', 'PackClsInputs'
|
||||
'Transpose', 'Collect', 'Normalize', 'RandomCrop', 'RandomResizedCrop',
|
||||
'RandomGrayscale', 'Shear', 'Translate', 'Rotate', 'Invert',
|
||||
'ColorTransform', 'Solarize', 'Posterize', 'AutoContrast', 'Equalize',
|
||||
'Contrast', 'Brightness', 'Sharpness', 'AutoAugment', 'SolarizeAdd',
|
||||
'Cutout', 'RandAugment', 'Lighting', 'ColorJitter', 'RandomErasing', 'Pad',
|
||||
'PackClsInputs'
|
||||
]
|
||||
|
|
|
@ -3,11 +3,11 @@ import inspect
|
|||
import math
|
||||
import random
|
||||
from numbers import Number
|
||||
from typing import Dict, Sequence
|
||||
from typing import Dict, Optional, Sequence, Tuple, Union
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
from mmcv import BaseTransform
|
||||
from mmcv.transforms import BaseTransform
|
||||
from mmcv.transforms.utils import cache_randomness
|
||||
|
||||
from mmcls.registry import TRANSFORMS
|
||||
|
@ -20,13 +20,22 @@ except ImportError:
|
|||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class RandomCrop(object):
|
||||
class RandomCrop(BaseTransform):
|
||||
"""Crop the given Image at a random location.
|
||||
|
||||
Required Keys:
|
||||
|
||||
- img
|
||||
|
||||
Modified Keys:
|
||||
|
||||
- img
|
||||
- img_shape
|
||||
|
||||
Args:
|
||||
size (sequence or int): Desired output size of the crop. If size is an
|
||||
int instead of sequence like (h, w), a square crop (size, size) is
|
||||
made.
|
||||
crop_size (sequence or int): Desired output size of the crop. If
|
||||
crop_size is an int instead of sequence like (h, w), a square crop
|
||||
(crop_size, crop_size) is made.
|
||||
padding (int or sequence, optional): Optional padding on each border
|
||||
of the image. If a sequence of length 4 is provided, it is used to
|
||||
pad left, top, right, bottom borders respectively. If a sequence
|
||||
|
@ -56,15 +65,18 @@ class RandomCrop(object):
|
|||
"""
|
||||
|
||||
def __init__(self,
|
||||
size,
|
||||
padding=None,
|
||||
pad_if_needed=False,
|
||||
pad_val=0,
|
||||
padding_mode='constant'):
|
||||
if isinstance(size, (tuple, list)):
|
||||
self.size = size
|
||||
crop_size: Union[Sequence, int],
|
||||
padding: Optional[Union[Sequence, int]] = None,
|
||||
pad_if_needed: bool = False,
|
||||
pad_val: Union[Number, Sequence[Number]] = 0,
|
||||
padding_mode: str = 'constant') -> None:
|
||||
if isinstance(crop_size, Sequence):
|
||||
assert len(crop_size) == 2
|
||||
assert crop_size[0] > 0 and crop_size[1] > 0
|
||||
self.crop_size = crop_size
|
||||
else:
|
||||
self.size = (size, size)
|
||||
assert crop_size > 0
|
||||
self.crop_size = (crop_size, crop_size)
|
||||
# check padding mode
|
||||
assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
|
||||
self.padding = padding
|
||||
|
@ -72,98 +84,111 @@ class RandomCrop(object):
|
|||
self.pad_val = pad_val
|
||||
self.padding_mode = padding_mode
|
||||
|
||||
@staticmethod
|
||||
def get_params(img, output_size):
|
||||
@cache_randomness
|
||||
def rand_crop_params(self, img: np.ndarray):
|
||||
"""Get parameters for ``crop`` for a random crop.
|
||||
|
||||
Args:
|
||||
img (ndarray): Image to be cropped.
|
||||
output_size (tuple): Expected output size of the crop.
|
||||
|
||||
Returns:
|
||||
tuple: Params (xmin, ymin, target_height, target_width) to be
|
||||
tuple: Params (offset_h, offset_w, target_h, target_w) to be
|
||||
passed to ``crop`` for random crop.
|
||||
"""
|
||||
height = img.shape[0]
|
||||
width = img.shape[1]
|
||||
target_height, target_width = output_size
|
||||
if width == target_width and height == target_height:
|
||||
return 0, 0, height, width
|
||||
h, w = img.shape[:2]
|
||||
target_h, target_w = self.crop_size
|
||||
if w == target_w and h == target_h:
|
||||
return 0, 0, h, w
|
||||
elif w < target_w or h < target_h:
|
||||
target_w = min(w, target_w)
|
||||
target_h = min(w, target_h)
|
||||
|
||||
ymin = random.randint(0, height - target_height)
|
||||
xmin = random.randint(0, width - target_width)
|
||||
return ymin, xmin, target_height, target_width
|
||||
offset_h = np.random.randint(0, h - target_h + 1)
|
||||
offset_w = np.random.randint(0, w - target_w + 1)
|
||||
|
||||
return offset_h, offset_w, target_h, target_w
|
||||
|
||||
def transform(self, results: dict) -> dict:
|
||||
"""Transform function to randomly crop images.
|
||||
|
||||
def __call__(self, results):
|
||||
"""
|
||||
Args:
|
||||
img (ndarray): Image to be cropped.
|
||||
results (dict): Result dict from loading pipeline.
|
||||
|
||||
Returns:
|
||||
dict: Randomly cropped results, 'img_shape'
|
||||
key in result dict is updated according to crop size.
|
||||
"""
|
||||
for key in results.get('img_fields', ['img']):
|
||||
img = results[key]
|
||||
if self.padding is not None:
|
||||
img = mmcv.impad(
|
||||
img, padding=self.padding, pad_val=self.pad_val)
|
||||
img = results['img']
|
||||
if self.padding is not None:
|
||||
img = mmcv.impad(img, padding=self.padding, pad_val=self.pad_val)
|
||||
|
||||
# pad the height if needed
|
||||
if self.pad_if_needed and img.shape[0] < self.size[0]:
|
||||
img = mmcv.impad(
|
||||
img,
|
||||
padding=(0, self.size[0] - img.shape[0], 0,
|
||||
self.size[0] - img.shape[0]),
|
||||
pad_val=self.pad_val,
|
||||
padding_mode=self.padding_mode)
|
||||
# pad img if needed
|
||||
if self.pad_if_needed:
|
||||
h_pad = math.ceil(max(0, self.crop_size[0] - img.shape[0]) / 2)
|
||||
w_pad = math.ceil(max(0, self.crop_size[1] - img.shape[1]) / 2)
|
||||
|
||||
# pad the width if needed
|
||||
if self.pad_if_needed and img.shape[1] < self.size[1]:
|
||||
img = mmcv.impad(
|
||||
img,
|
||||
padding=(self.size[1] - img.shape[1], 0,
|
||||
self.size[1] - img.shape[1], 0),
|
||||
pad_val=self.pad_val,
|
||||
padding_mode=self.padding_mode)
|
||||
|
||||
ymin, xmin, height, width = self.get_params(img, self.size)
|
||||
results[key] = mmcv.imcrop(
|
||||
img = mmcv.impad(
|
||||
img,
|
||||
np.array([
|
||||
xmin,
|
||||
ymin,
|
||||
xmin + width - 1,
|
||||
ymin + height - 1,
|
||||
]))
|
||||
padding=(w_pad, h_pad, w_pad, h_pad),
|
||||
pad_val=self.pad_val,
|
||||
padding_mode=self.padding_mode)
|
||||
|
||||
offset_h, offset_w, target_h, target_w = self.rand_crop_params(img)
|
||||
img = mmcv.imcrop(
|
||||
img,
|
||||
np.array([
|
||||
offset_w,
|
||||
offset_h,
|
||||
offset_w + target_w - 1,
|
||||
offset_h + target_h - 1,
|
||||
]))
|
||||
results['img'] = img
|
||||
results['img_shape'] = img.shape
|
||||
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
return (self.__class__.__name__ +
|
||||
f'(size={self.size}, padding={self.padding})')
|
||||
"""Print the basic information of the transform.
|
||||
|
||||
Returns:
|
||||
str: Formatted string.
|
||||
"""
|
||||
repr_str = self.__class__.__name__ + f'(crop_size={self.crop_size}'
|
||||
repr_str += f', padding={self.padding}'
|
||||
repr_str += f', pad_if_needed={self.pad_if_needed}'
|
||||
repr_str += f', pad_val={self.pad_val}'
|
||||
repr_str += f', padding_mode={self.padding_mode})'
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class RandomResizedCrop(object):
|
||||
"""Crop the given image to random size and aspect ratio.
|
||||
class RandomResizedCrop(BaseTransform):
|
||||
"""Crop the given image to random scale and aspect ratio.
|
||||
|
||||
A crop of random size (default: of 0.08 to 1.0) of the original size and a
|
||||
random aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio
|
||||
is made. This crop is finally resized to given size.
|
||||
|
||||
Required Keys:
|
||||
|
||||
- img
|
||||
|
||||
Modified Keys:
|
||||
|
||||
- img
|
||||
- img_shape
|
||||
|
||||
Args:
|
||||
size (sequence | int): Desired output size of the crop. If size is an
|
||||
scale (sequence | int): Desired output scale of the crop. If size is an
|
||||
int instead of sequence like (h, w), a square crop (size, size) is
|
||||
made.
|
||||
scale (tuple): Range of the random size of the cropped image compared
|
||||
to the original image. Defaults to (0.08, 1.0).
|
||||
ratio (tuple): Range of the random aspect ratio of the cropped image
|
||||
compared to the original image. Defaults to (3. / 4., 4. / 3.).
|
||||
crop_ratio_range (tuple): Range of the random size of the cropped
|
||||
image compared to the original image. Defaults to (0.08, 1.0).
|
||||
aspect_ratio_range (tuple): Range of the random aspect ratio of the
|
||||
cropped image compared to the original image.
|
||||
Defaults to (3. / 4., 4. / 3.).
|
||||
max_attempts (int): Maximum number of attempts before falling back to
|
||||
Central Crop. Defaults to 10.
|
||||
efficientnet_style (bool): Whether to use efficientnet style Random
|
||||
ResizedCrop. Defaults to False.
|
||||
min_covered (Number): Minimum ratio of the cropped area to the original
|
||||
area. Only valid if efficientnet_style is true. Defaults to 0.1.
|
||||
crop_padding (int): The crop padding parameter in efficientnet style
|
||||
center crop. Only valid if efficientnet_style is true.
|
||||
Defaults to 32.
|
||||
interpolation (str): Interpolation method, accepted values are
|
||||
'nearest', 'bilinear', 'bicubic', 'area', 'lanczos'. Defaults to
|
||||
'bilinear'.
|
||||
|
@ -172,217 +197,119 @@ class RandomResizedCrop(object):
|
|||
"""
|
||||
|
||||
def __init__(self,
|
||||
size,
|
||||
scale=(0.08, 1.0),
|
||||
ratio=(3. / 4., 4. / 3.),
|
||||
max_attempts=10,
|
||||
efficientnet_style=False,
|
||||
min_covered=0.1,
|
||||
crop_padding=32,
|
||||
interpolation='bilinear',
|
||||
backend='cv2'):
|
||||
if efficientnet_style:
|
||||
assert isinstance(size, int)
|
||||
self.size = (size, size)
|
||||
assert crop_padding >= 0
|
||||
scale: Union[Sequence, int],
|
||||
crop_ratio_range: Tuple[float, float] = (0.08, 1.0),
|
||||
aspect_ratio_range: Tuple[float, float] = (3. / 4., 4. / 3.),
|
||||
max_attempts: int = 10,
|
||||
interpolation: str = 'bilinear',
|
||||
backend: str = 'cv2') -> None:
|
||||
if isinstance(scale, Sequence):
|
||||
assert len(scale) == 2
|
||||
assert scale[0] > 0 and scale[1] > 0
|
||||
self.scale = scale
|
||||
else:
|
||||
if isinstance(size, (tuple, list)):
|
||||
self.size = size
|
||||
else:
|
||||
self.size = (size, size)
|
||||
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
|
||||
raise ValueError('range should be of kind (min, max). '
|
||||
f'But received scale {scale} and rato {ratio}.')
|
||||
assert min_covered >= 0, 'min_covered should be no less than 0.'
|
||||
assert scale > 0
|
||||
self.scale = (scale, scale)
|
||||
if (crop_ratio_range[0] > crop_ratio_range[1]) or (
|
||||
aspect_ratio_range[0] > aspect_ratio_range[1]):
|
||||
raise ValueError(
|
||||
'range should be of kind (min, max). '
|
||||
f'But received crop_ratio_range {crop_ratio_range} '
|
||||
f'and aspect_ratio_range {aspect_ratio_range}.')
|
||||
assert isinstance(max_attempts, int) and max_attempts >= 0, \
|
||||
'max_attempts mush be int and no less than 0.'
|
||||
assert interpolation in ('nearest', 'bilinear', 'bicubic', 'area',
|
||||
'lanczos')
|
||||
if backend not in ['cv2', 'pillow']:
|
||||
raise ValueError(f'backend: {backend} is not supported for resize.'
|
||||
'Supported backends are "cv2", "pillow"')
|
||||
|
||||
self.scale = scale
|
||||
self.ratio = ratio
|
||||
self.crop_ratio_range = crop_ratio_range
|
||||
self.aspect_ratio_range = aspect_ratio_range
|
||||
self.max_attempts = max_attempts
|
||||
self.efficientnet_style = efficientnet_style
|
||||
self.min_covered = min_covered
|
||||
self.crop_padding = crop_padding
|
||||
self.interpolation = interpolation
|
||||
self.backend = backend
|
||||
|
||||
@staticmethod
|
||||
def get_params(img, scale, ratio, max_attempts=10):
|
||||
@cache_randomness
|
||||
def rand_crop_params(self, img: np.ndarray) -> Tuple[int, int, int, int]:
|
||||
"""Get parameters for ``crop`` for a random sized crop.
|
||||
|
||||
Args:
|
||||
img (ndarray): Image to be cropped.
|
||||
scale (tuple): Range of the random size of the cropped image
|
||||
compared to the original image size.
|
||||
ratio (tuple): Range of the random aspect ratio of the cropped
|
||||
image compared to the original image area.
|
||||
max_attempts (int): Maximum number of attempts before falling back
|
||||
to central crop. Defaults to 10.
|
||||
|
||||
Returns:
|
||||
tuple: Params (ymin, xmin, ymax, xmax) to be passed to `crop` for
|
||||
a random sized crop.
|
||||
tuple: Params (offset_h, offset_w, target_h, target_w) to be
|
||||
passed to `crop` for a random sized crop.
|
||||
"""
|
||||
height = img.shape[0]
|
||||
width = img.shape[1]
|
||||
area = height * width
|
||||
h, w = img.shape[:2]
|
||||
area = h * w
|
||||
|
||||
for _ in range(max_attempts):
|
||||
target_area = random.uniform(*scale) * area
|
||||
log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
|
||||
aspect_ratio = math.exp(random.uniform(*log_ratio))
|
||||
for _ in range(self.max_attempts):
|
||||
target_area = np.random.uniform(*self.crop_ratio_range) * area
|
||||
log_ratio = (math.log(self.aspect_ratio_range[0]),
|
||||
math.log(self.aspect_ratio_range[1]))
|
||||
aspect_ratio = math.exp(np.random.uniform(*log_ratio))
|
||||
target_w = int(round(math.sqrt(target_area * aspect_ratio)))
|
||||
target_h = int(round(math.sqrt(target_area / aspect_ratio)))
|
||||
|
||||
target_width = int(round(math.sqrt(target_area * aspect_ratio)))
|
||||
target_height = int(round(math.sqrt(target_area / aspect_ratio)))
|
||||
if 0 < target_w <= w and 0 < target_h <= h:
|
||||
offset_h = np.random.randint(0, h - target_h + 1)
|
||||
offset_w = np.random.randint(0, w - target_w + 1)
|
||||
|
||||
if 0 < target_width <= width and 0 < target_height <= height:
|
||||
ymin = random.randint(0, height - target_height)
|
||||
xmin = random.randint(0, width - target_width)
|
||||
ymax = ymin + target_height - 1
|
||||
xmax = xmin + target_width - 1
|
||||
return ymin, xmin, ymax, xmax
|
||||
return offset_h, offset_w, target_h, target_w
|
||||
|
||||
# Fallback to central crop
|
||||
in_ratio = float(width) / float(height)
|
||||
if in_ratio < min(ratio):
|
||||
target_width = width
|
||||
target_height = int(round(target_width / min(ratio)))
|
||||
elif in_ratio > max(ratio):
|
||||
target_height = height
|
||||
target_width = int(round(target_height * max(ratio)))
|
||||
in_ratio = float(w) / float(h)
|
||||
if in_ratio < min(self.aspect_ratio_range):
|
||||
target_w = w
|
||||
target_h = int(round(target_w / min(self.aspect_ratio_range)))
|
||||
elif in_ratio > max(self.aspect_ratio_range):
|
||||
target_h = h
|
||||
target_w = int(round(target_h * max(self.aspect_ratio_range)))
|
||||
else: # whole image
|
||||
target_width = width
|
||||
target_height = height
|
||||
ymin = (height - target_height) // 2
|
||||
xmin = (width - target_width) // 2
|
||||
ymax = ymin + target_height - 1
|
||||
xmax = xmin + target_width - 1
|
||||
return ymin, xmin, ymax, xmax
|
||||
target_w = w
|
||||
target_h = h
|
||||
offset_h = (h - target_h) // 2
|
||||
offset_w = (w - target_w) // 2
|
||||
return offset_h, offset_w, target_h, target_w
|
||||
|
||||
# https://github.com/kakaobrain/fast-autoaugment/blob/master/FastAutoAugment/data.py # noqa
|
||||
@staticmethod
|
||||
def get_params_efficientnet_style(img,
|
||||
size,
|
||||
scale,
|
||||
ratio,
|
||||
max_attempts=10,
|
||||
min_covered=0.1,
|
||||
crop_padding=32):
|
||||
"""Get parameters for ``crop`` for a random sized crop in efficientnet
|
||||
style.
|
||||
def transform(self, results: dict) -> dict:
|
||||
"""Transform function to randomly resized crop images.
|
||||
|
||||
Args:
|
||||
img (ndarray): Image to be cropped.
|
||||
size (sequence): Desired output size of the crop.
|
||||
scale (tuple): Range of the random size of the cropped image
|
||||
compared to the original image size.
|
||||
ratio (tuple): Range of the random aspect ratio of the cropped
|
||||
image compared to the original image area.
|
||||
max_attempts (int): Maximum number of attempts before falling back
|
||||
to central crop. Defaults to 10.
|
||||
min_covered (Number): Minimum ratio of the cropped area to the
|
||||
original area. Only valid if efficientnet_style is true.
|
||||
Defaults to 0.1.
|
||||
crop_padding (int): The crop padding parameter in efficientnet
|
||||
style center crop. Defaults to 32.
|
||||
results (dict): Result dict from loading pipeline.
|
||||
|
||||
Returns:
|
||||
tuple: Params (ymin, xmin, ymax, xmax) to be passed to `crop` for
|
||||
a random sized crop.
|
||||
dict: Randomly resized cropped results, 'img_shape'
|
||||
key in result dict is updated according to crop size.
|
||||
"""
|
||||
height, width = img.shape[:2]
|
||||
area = height * width
|
||||
min_target_area = scale[0] * area
|
||||
max_target_area = scale[1] * area
|
||||
img = results['img']
|
||||
offset_h, offset_w, target_h, target_w = self.rand_crop_params(img)
|
||||
img = mmcv.imcrop(
|
||||
img,
|
||||
bboxes=np.array([
|
||||
offset_w, offset_h, offset_w + target_w - 1,
|
||||
offset_h + target_h - 1
|
||||
]))
|
||||
img = mmcv.imresize(
|
||||
img,
|
||||
tuple(self.scale[::-1]),
|
||||
interpolation=self.interpolation,
|
||||
backend=self.backend)
|
||||
results['img'] = img
|
||||
results['img_shape'] = img.shape
|
||||
|
||||
for _ in range(max_attempts):
|
||||
aspect_ratio = random.uniform(*ratio)
|
||||
min_target_height = int(
|
||||
round(math.sqrt(min_target_area / aspect_ratio)))
|
||||
max_target_height = int(
|
||||
round(math.sqrt(max_target_area / aspect_ratio)))
|
||||
|
||||
if max_target_height * aspect_ratio > width:
|
||||
max_target_height = int((width + 0.5 - 1e-7) / aspect_ratio)
|
||||
if max_target_height * aspect_ratio > width:
|
||||
max_target_height -= 1
|
||||
|
||||
max_target_height = min(max_target_height, height)
|
||||
min_target_height = min(max_target_height, min_target_height)
|
||||
|
||||
# slightly differs from tf implementation
|
||||
target_height = int(
|
||||
round(random.uniform(min_target_height, max_target_height)))
|
||||
target_width = int(round(target_height * aspect_ratio))
|
||||
target_area = target_height * target_width
|
||||
|
||||
# slight differs from tf. In tf, if target_area > max_target_area,
|
||||
# area will be recalculated
|
||||
if (target_area < min_target_area or target_area > max_target_area
|
||||
or target_width > width or target_height > height
|
||||
or target_area < min_covered * area):
|
||||
continue
|
||||
|
||||
ymin = random.randint(0, height - target_height)
|
||||
xmin = random.randint(0, width - target_width)
|
||||
ymax = ymin + target_height - 1
|
||||
xmax = xmin + target_width - 1
|
||||
|
||||
return ymin, xmin, ymax, xmax
|
||||
|
||||
# Fallback to central crop
|
||||
img_short = min(height, width)
|
||||
crop_size = size[0] / (size[0] + crop_padding) * img_short
|
||||
|
||||
ymin = max(0, int(round((height - crop_size) / 2.)))
|
||||
xmin = max(0, int(round((width - crop_size) / 2.)))
|
||||
ymax = min(height, ymin + crop_size) - 1
|
||||
xmax = min(width, xmin + crop_size) - 1
|
||||
|
||||
return ymin, xmin, ymax, xmax
|
||||
|
||||
def __call__(self, results):
|
||||
for key in results.get('img_fields', ['img']):
|
||||
img = results[key]
|
||||
if self.efficientnet_style:
|
||||
get_params_func = self.get_params_efficientnet_style
|
||||
get_params_args = dict(
|
||||
img=img,
|
||||
size=self.size,
|
||||
scale=self.scale,
|
||||
ratio=self.ratio,
|
||||
max_attempts=self.max_attempts,
|
||||
min_covered=self.min_covered,
|
||||
crop_padding=self.crop_padding)
|
||||
else:
|
||||
get_params_func = self.get_params
|
||||
get_params_args = dict(
|
||||
img=img,
|
||||
scale=self.scale,
|
||||
ratio=self.ratio,
|
||||
max_attempts=self.max_attempts)
|
||||
ymin, xmin, ymax, xmax = get_params_func(**get_params_args)
|
||||
img = mmcv.imcrop(img, bboxes=np.array([xmin, ymin, xmax, ymax]))
|
||||
results[key] = mmcv.imresize(
|
||||
img,
|
||||
tuple(self.size[::-1]),
|
||||
interpolation=self.interpolation,
|
||||
backend=self.backend)
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__ + f'(size={self.size}'
|
||||
repr_str += f', scale={tuple(round(s, 4) for s in self.scale)}'
|
||||
repr_str += f', ratio={tuple(round(r, 4) for r in self.ratio)}'
|
||||
"""Print the basic information of the transform.
|
||||
|
||||
Returns:
|
||||
str: Formatted string.
|
||||
"""
|
||||
repr_str = self.__class__.__name__ + f'(scale={self.scale}'
|
||||
repr_str += ', crop_ratio_range='
|
||||
repr_str += f'{tuple(round(s, 4) for s in self.crop_ratio_range)}'
|
||||
repr_str += ', aspect_ratio_range='
|
||||
repr_str += f'{tuple(round(r, 4) for r in self.aspect_ratio_range)}'
|
||||
repr_str += f', max_attempts={self.max_attempts}'
|
||||
repr_str += f', efficientnet_style={self.efficientnet_style}'
|
||||
repr_str += f', min_covered={self.min_covered}'
|
||||
repr_str += f', crop_padding={self.crop_padding}'
|
||||
repr_str += f', interpolation={self.interpolation}'
|
||||
repr_str += f', backend={self.backend})'
|
||||
return repr_str
|
||||
|
@ -737,112 +664,6 @@ class ResizeEdge(BaseTransform):
|
|||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class CenterCrop(object):
|
||||
r"""Center crop the image.
|
||||
|
||||
Args:
|
||||
crop_size (int | tuple): Expected size after cropping with the format
|
||||
of (h, w).
|
||||
efficientnet_style (bool): Whether to use efficientnet style center
|
||||
crop. Defaults to False.
|
||||
crop_padding (int): The crop padding parameter in efficientnet style
|
||||
center crop. Only valid if efficientnet style is True. Defaults to
|
||||
32.
|
||||
interpolation (str): Interpolation method, accepted values are
|
||||
'nearest', 'bilinear', 'bicubic', 'area', 'lanczos'. Only valid if
|
||||
``efficientnet_style`` is True. Defaults to 'bilinear'.
|
||||
backend (str): The image resize backend type, accepted values are
|
||||
`cv2` and `pillow`. Only valid if efficientnet style is True.
|
||||
Defaults to `cv2`.
|
||||
|
||||
|
||||
Notes:
|
||||
- If the image is smaller than the crop size, return the original
|
||||
image.
|
||||
- If efficientnet_style is set to False, the pipeline would be a simple
|
||||
center crop using the crop_size.
|
||||
- If efficientnet_style is set to True, the pipeline will be to first
|
||||
to perform the center crop with the ``crop_size_`` as:
|
||||
|
||||
.. math::
|
||||
\text{crop_size_} = \frac{\text{crop_size}}{\text{crop_size} +
|
||||
\text{crop_padding}} \times \text{short_edge}
|
||||
|
||||
And then the pipeline resizes the img to the input crop size.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
crop_size,
|
||||
efficientnet_style=False,
|
||||
crop_padding=32,
|
||||
interpolation='bilinear',
|
||||
backend='cv2'):
|
||||
if efficientnet_style:
|
||||
assert isinstance(crop_size, int)
|
||||
assert crop_padding >= 0
|
||||
assert interpolation in ('nearest', 'bilinear', 'bicubic', 'area',
|
||||
'lanczos')
|
||||
if backend not in ['cv2', 'pillow']:
|
||||
raise ValueError(
|
||||
f'backend: {backend} is not supported for '
|
||||
'resize. Supported backends are "cv2", "pillow"')
|
||||
else:
|
||||
assert isinstance(crop_size, int) or (isinstance(crop_size, tuple)
|
||||
and len(crop_size) == 2)
|
||||
if isinstance(crop_size, int):
|
||||
crop_size = (crop_size, crop_size)
|
||||
assert crop_size[0] > 0 and crop_size[1] > 0
|
||||
self.crop_size = crop_size
|
||||
self.efficientnet_style = efficientnet_style
|
||||
self.crop_padding = crop_padding
|
||||
self.interpolation = interpolation
|
||||
self.backend = backend
|
||||
|
||||
def __call__(self, results):
|
||||
crop_height, crop_width = self.crop_size[0], self.crop_size[1]
|
||||
for key in results.get('img_fields', ['img']):
|
||||
img = results[key]
|
||||
# img.shape has length 2 for grayscale, length 3 for color
|
||||
img_height, img_width = img.shape[:2]
|
||||
|
||||
# https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/preprocessing.py#L118 # noqa
|
||||
if self.efficientnet_style:
|
||||
img_short = min(img_height, img_width)
|
||||
crop_height = crop_height / (crop_height +
|
||||
self.crop_padding) * img_short
|
||||
crop_width = crop_width / (crop_width +
|
||||
self.crop_padding) * img_short
|
||||
|
||||
y1 = max(0, int(round((img_height - crop_height) / 2.)))
|
||||
x1 = max(0, int(round((img_width - crop_width) / 2.)))
|
||||
y2 = min(img_height, y1 + crop_height) - 1
|
||||
x2 = min(img_width, x1 + crop_width) - 1
|
||||
|
||||
# crop the image
|
||||
img = mmcv.imcrop(img, bboxes=np.array([x1, y1, x2, y2]))
|
||||
|
||||
if self.efficientnet_style:
|
||||
img = mmcv.imresize(
|
||||
img,
|
||||
tuple(self.crop_size[::-1]),
|
||||
interpolation=self.interpolation,
|
||||
backend=self.backend)
|
||||
img_shape = img.shape
|
||||
results[key] = img
|
||||
results['img_shape'] = img_shape
|
||||
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__ + f'(crop_size={self.crop_size}'
|
||||
repr_str += f', efficientnet_style={self.efficientnet_style}'
|
||||
repr_str += f', crop_padding={self.crop_padding}'
|
||||
repr_str += f', interpolation={self.interpolation}'
|
||||
repr_str += f', backend={self.backend})'
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class Normalize(object):
|
||||
"""Normalize the image.
|
||||
|
|
|
@ -21,6 +21,167 @@ def construct_toy_data():
|
|||
results['img_shape'] = img.shape
|
||||
return results
|
||||
|
||||
class TestRandomCrop(TestCase):
|
||||
|
||||
def test_assertion(self):
|
||||
with self.assertRaises(AssertionError):
|
||||
cfg = dict(type='RandomCrop', crop_size=-1)
|
||||
TRANSFORMS.build(cfg)
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
cfg = dict(type='RandomCrop', crop_size=(1, 2, 3))
|
||||
TRANSFORMS.build(cfg)
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
cfg = dict(type='RandomCrop', crop_size=(1, -2))
|
||||
TRANSFORMS.build(cfg)
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
cfg = dict(type='RandomCrop', crop_size=224, padding_mode='co')
|
||||
TRANSFORMS.build(cfg)
|
||||
|
||||
def test_transform(self):
|
||||
results = dict(img=np.random.randint(0, 256, (256, 256, 3), np.uint8))
|
||||
|
||||
# test random crop by default.
|
||||
cfg = dict(type='RandomCrop', crop_size=224)
|
||||
transform = TRANSFORMS.build(cfg)
|
||||
results = transform(results)
|
||||
self.assertTupleEqual(results['img'].shape, (224, 224, 3))
|
||||
|
||||
# test int padding and int pad_val.
|
||||
cfg = dict(
|
||||
type='RandomCrop', crop_size=(224, 224), padding=2, pad_val=1)
|
||||
transform = TRANSFORMS.build(cfg)
|
||||
results = transform(results)
|
||||
self.assertTupleEqual(results['img'].shape, (224, 224, 3))
|
||||
|
||||
# test int padding and sequence pad_val.
|
||||
cfg = dict(
|
||||
type='RandomCrop', crop_size=224, padding=2, pad_val=(0, 50, 0))
|
||||
transform = TRANSFORMS.build(cfg)
|
||||
results = transform(results)
|
||||
self.assertTupleEqual(results['img'].shape, (224, 224, 3))
|
||||
|
||||
# test sequence padding.
|
||||
cfg = dict(type='RandomCrop', crop_size=224, padding=(2, 3, 4, 5))
|
||||
transform = TRANSFORMS.build(cfg)
|
||||
results = transform(results)
|
||||
self.assertTupleEqual(results['img'].shape, (224, 224, 3))
|
||||
|
||||
# test pad_if_needed.
|
||||
cfg = dict(
|
||||
type='RandomCrop',
|
||||
crop_size=300,
|
||||
pad_if_needed=True,
|
||||
padding_mode='edge')
|
||||
transform = TRANSFORMS.build(cfg)
|
||||
results = transform(results)
|
||||
self.assertTupleEqual(results['img'].shape, (300, 300, 3))
|
||||
|
||||
# test large crop size.
|
||||
results = dict(img=np.random.randint(0, 256, (256, 256, 3), np.uint8))
|
||||
cfg = dict(type='RandomCrop', crop_size=300)
|
||||
transform = TRANSFORMS.build(cfg)
|
||||
results = transform(results)
|
||||
self.assertTupleEqual(results['img'].shape, (256, 256, 3))
|
||||
|
||||
# test equal size.
|
||||
results = dict(img=np.random.randint(0, 256, (256, 256, 3), np.uint8))
|
||||
cfg = dict(type='RandomCrop', crop_size=256)
|
||||
transform = TRANSFORMS.build(cfg)
|
||||
results = transform(results)
|
||||
self.assertTupleEqual(results['img'].shape, (256, 256, 3))
|
||||
|
||||
def test_repr(self):
|
||||
cfg = dict(type='RandomCrop', crop_size=224)
|
||||
transform = TRANSFORMS.build(cfg)
|
||||
self.assertEqual(
|
||||
repr(transform), 'RandomCrop(crop_size=(224, 224), padding=None, '
|
||||
'pad_if_needed=False, pad_val=0, padding_mode=constant)')
|
||||
|
||||
|
||||
class RandomResizedCrop(TestCase):
|
||||
|
||||
def test_assertion(self):
|
||||
with self.assertRaises(AssertionError):
|
||||
cfg = dict(type='RandomResizedCrop', scale=-1)
|
||||
TRANSFORMS.build(cfg)
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
cfg = dict(type='RandomResizedCrop', scale=(1, 2, 3))
|
||||
TRANSFORMS.build(cfg)
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
cfg = dict(type='RandomResizedCrop', scale=(1, -2))
|
||||
TRANSFORMS.build(cfg)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
cfg = dict(
|
||||
type='RandomResizedCrop', scale=224, crop_ratio_range=(1, 0.1))
|
||||
TRANSFORMS.build(cfg)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
cfg = dict(
|
||||
type='RandomResizedCrop',
|
||||
scale=224,
|
||||
aspect_ratio_range=(1, 0.1))
|
||||
TRANSFORMS.build(cfg)
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
cfg = dict(type='RandomResizedCrop', scale=224, max_attempts=-1)
|
||||
TRANSFORMS.build(cfg)
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
cfg = dict(type='RandomResizedCrop', scale=224, interpolation='ne')
|
||||
TRANSFORMS.build(cfg)
|
||||
|
||||
def test_transform(self):
|
||||
results = dict(img=np.random.randint(0, 256, (256, 256, 3), np.uint8))
|
||||
|
||||
# test random crop by default.
|
||||
cfg = dict(type='RandomResizedCrop', scale=224)
|
||||
transform = TRANSFORMS.build(cfg)
|
||||
results = transform(results)
|
||||
self.assertTupleEqual(results['img'].shape, (224, 224, 3))
|
||||
|
||||
# test crop_ratio_range.
|
||||
cfg = dict(
|
||||
type='RandomResizedCrop',
|
||||
scale=(224, 224),
|
||||
crop_ratio_range=(0.5, 0.8))
|
||||
transform = TRANSFORMS.build(cfg)
|
||||
results = transform(results)
|
||||
self.assertTupleEqual(results['img'].shape, (224, 224, 3))
|
||||
|
||||
# test aspect_ratio_range.
|
||||
cfg = dict(
|
||||
type='RandomResizedCrop', scale=224, aspect_ratio_range=(0.5, 0.8))
|
||||
transform = TRANSFORMS.build(cfg)
|
||||
results = transform(results)
|
||||
self.assertTupleEqual(results['img'].shape, (224, 224, 3))
|
||||
|
||||
# test max_attempts.
|
||||
cfg = dict(type='RandomResizedCrop', scale=224, max_attempts=0)
|
||||
transform = TRANSFORMS.build(cfg)
|
||||
results = transform(results)
|
||||
self.assertTupleEqual(results['img'].shape, (224, 224, 3))
|
||||
|
||||
# test large crop size.
|
||||
results = dict(img=np.random.randint(0, 256, (256, 256, 3), np.uint8))
|
||||
cfg = dict(type='RandomResizedCrop', scale=300)
|
||||
transform = TRANSFORMS.build(cfg)
|
||||
results = transform(results)
|
||||
self.assertTupleEqual(results['img'].shape, (300, 300, 3))
|
||||
|
||||
def test_repr(self):
|
||||
cfg = dict(type='RandomResizedCrop', scale=224)
|
||||
transform = TRANSFORMS.build(cfg)
|
||||
self.assertEqual(
|
||||
repr(transform), 'RandomResizedCrop(scale=(224, 224), '
|
||||
'crop_ratio_range=(0.08, 1.0), aspect_ratio_range=(0.75, 1.3333), '
|
||||
'max_attempts=10, interpolation=bilinear, backend=cv2)')
|
||||
|
||||
|
||||
class TestResizeEdge(TestCase):
|
||||
|
||||
|
|
Loading…
Reference in New Issue