From c78b5597d896a333b60283adaa4724b471c1aee8 Mon Sep 17 00:00:00 2001 From: yingfhu Date: Thu, 2 Jun 2022 09:52:59 +0000 Subject: [PATCH] =?UTF-8?q?[Refactor]=20refactor=20randomCrop,=20randomRes?= =?UTF-8?q?izeCrop=20and=20CenterCrop=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- configs/_base_/datasets/cifar100_bs16.py | 2 +- configs/_base_/datasets/cifar10_bs16.py | 2 +- configs/_base_/datasets/cub_bs8_384.py | 2 +- configs/_base_/datasets/cub_bs8_448.py | 2 +- mmcls/datasets/pipelines/__init__.py | 17 +- mmcls/datasets/pipelines/processing.py | 537 ++++++------------ .../test_pipelines/test_processing.py | 161 ++++++ 7 files changed, 352 insertions(+), 371 deletions(-) diff --git a/configs/_base_/datasets/cifar100_bs16.py b/configs/_base_/datasets/cifar100_bs16.py index 528c45b0..7bdc55f9 100644 --- a/configs/_base_/datasets/cifar100_bs16.py +++ b/configs/_base_/datasets/cifar100_bs16.py @@ -8,7 +8,7 @@ preprocess_cfg = dict( to_rgb=False) train_pipeline = [ - dict(type='RandomCrop', size=32, padding=4), + dict(type='RandomCrop', crop_size=32, padding=4), dict(type='RandomFlip', prob=0.5, direction='horizontal'), dict(type='PackClsInputs'), ] diff --git a/configs/_base_/datasets/cifar10_bs16.py b/configs/_base_/datasets/cifar10_bs16.py index 65bf95d0..edc5fc13 100644 --- a/configs/_base_/datasets/cifar10_bs16.py +++ b/configs/_base_/datasets/cifar10_bs16.py @@ -8,7 +8,7 @@ preprocess_cfg = dict( to_rgb=False) train_pipeline = [ - dict(type='RandomCrop', size=32, padding=4), + dict(type='RandomCrop', crop_size=32, padding=4), dict(type='RandomFlip', prob=0.5, direction='horizontal'), dict(type='PackClsInputs'), ] diff --git a/configs/_base_/datasets/cub_bs8_384.py b/configs/_base_/datasets/cub_bs8_384.py index d5c3564d..94380f54 100644 --- a/configs/_base_/datasets/cub_bs8_384.py +++ b/configs/_base_/datasets/cub_bs8_384.py @@ -11,7 +11,7 @@ preprocess_cfg = dict( train_pipeline = [ 
dict(type='LoadImageFromFile'), dict(type='Resize', scale=510), - dict(type='RandomCrop', size=384), + dict(type='RandomCrop', crop_size=384), dict(type='RandomFlip', prob=0.5, direction='horizontal'), dict(type='PackClsInputs'), ] diff --git a/configs/_base_/datasets/cub_bs8_448.py b/configs/_base_/datasets/cub_bs8_448.py index c0b74094..f1ef3469 100644 --- a/configs/_base_/datasets/cub_bs8_448.py +++ b/configs/_base_/datasets/cub_bs8_448.py @@ -10,7 +10,7 @@ preprocess_cfg = dict( train_pipeline = [ dict(type='LoadImageFromFile'), dict(type='Resize', scale=600), - dict(type='RandomCrop', size=448), + dict(type='RandomCrop', crop_size=448), dict(type='RandomFlip', prob=0.5, direction='horizontal'), dict(type='PackClsInputs'), ] diff --git a/mmcls/datasets/pipelines/__init__.py b/mmcls/datasets/pipelines/__init__.py index f346fa7b..4c7f10bc 100644 --- a/mmcls/datasets/pipelines/__init__.py +++ b/mmcls/datasets/pipelines/__init__.py @@ -6,16 +6,15 @@ from .auto_augment import (AutoAugment, AutoContrast, Brightness, from .compose import Compose from .formatting import (Collect, ImageToTensor, PackClsInputs, ToNumpy, ToPIL, ToTensor, Transpose, to_tensor) -from .processing import (CenterCrop, ColorJitter, Lighting, Normalize, Pad, - RandomCrop, RandomErasing, RandomGrayscale, - RandomResizedCrop) +from .processing import (ColorJitter, Lighting, Normalize, Pad, RandomCrop, + RandomErasing, RandomGrayscale, RandomResizedCrop) __all__ = [ 'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToPIL', 'ToNumpy', - 'Transpose', 'Collect', 'CenterCrop', 'Normalize', 'RandomCrop', - 'RandomResizedCrop', 'RandomGrayscale', 'Shear', 'Translate', 'Rotate', - 'Invert', 'ColorTransform', 'Solarize', 'Posterize', 'AutoContrast', - 'Equalize', 'Contrast', 'Brightness', 'Sharpness', 'AutoAugment', - 'SolarizeAdd', 'Cutout', 'RandAugment', 'Lighting', 'ColorJitter', - 'RandomErasing', 'Pad', 'PackClsInputs' + 'Transpose', 'Collect', 'Normalize', 'RandomCrop', 'RandomResizedCrop', + 
'RandomGrayscale', 'Shear', 'Translate', 'Rotate', 'Invert', + 'ColorTransform', 'Solarize', 'Posterize', 'AutoContrast', 'Equalize', + 'Contrast', 'Brightness', 'Sharpness', 'AutoAugment', 'SolarizeAdd', + 'Cutout', 'RandAugment', 'Lighting', 'ColorJitter', 'RandomErasing', 'Pad', + 'PackClsInputs' ] diff --git a/mmcls/datasets/pipelines/processing.py b/mmcls/datasets/pipelines/processing.py index ed276f37..ba6533e8 100644 --- a/mmcls/datasets/pipelines/processing.py +++ b/mmcls/datasets/pipelines/processing.py @@ -3,11 +3,11 @@ import inspect import math import random from numbers import Number -from typing import Dict, Sequence +from typing import Dict, Optional, Sequence, Tuple, Union import mmcv import numpy as np -from mmcv import BaseTransform +from mmcv.transforms import BaseTransform from mmcv.transforms.utils import cache_randomness from mmcls.registry import TRANSFORMS @@ -20,13 +20,22 @@ except ImportError: @TRANSFORMS.register_module() -class RandomCrop(object): +class RandomCrop(BaseTransform): """Crop the given Image at a random location. + Required Keys: + + - img + + Modified Keys: + + - img + - img_shape + Args: - size (sequence or int): Desired output size of the crop. If size is an - int instead of sequence like (h, w), a square crop (size, size) is - made. + crop_size (sequence or int): Desired output size of the crop. If + crop_size is an int instead of sequence like (h, w), a square crop + (crop_size, crop_size) is made. padding (int or sequence, optional): Optional padding on each border of the image. If a sequence of length 4 is provided, it is used to pad left, top, right, bottom borders respectively. 
If a sequence @@ -56,15 +65,18 @@ class RandomCrop(object): """ def __init__(self, - size, - padding=None, - pad_if_needed=False, - pad_val=0, - padding_mode='constant'): - if isinstance(size, (tuple, list)): - self.size = size + crop_size: Union[Sequence, int], + padding: Optional[Union[Sequence, int]] = None, + pad_if_needed: bool = False, + pad_val: Union[Number, Sequence[Number]] = 0, + padding_mode: str = 'constant') -> None: + if isinstance(crop_size, Sequence): + assert len(crop_size) == 2 + assert crop_size[0] > 0 and crop_size[1] > 0 + self.crop_size = crop_size else: - self.size = (size, size) + assert crop_size > 0 + self.crop_size = (crop_size, crop_size) # check padding mode assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'] self.padding = padding @@ -72,98 +84,111 @@ class RandomCrop(object): self.pad_val = pad_val self.padding_mode = padding_mode - @staticmethod - def get_params(img, output_size): + @cache_randomness + def rand_crop_params(self, img: np.ndarray): """Get parameters for ``crop`` for a random crop. Args: img (ndarray): Image to be cropped. - output_size (tuple): Expected output size of the crop. Returns: - tuple: Params (xmin, ymin, target_height, target_width) to be + tuple: Params (offset_h, offset_w, target_h, target_w) to be passed to ``crop`` for random crop. 
""" - height = img.shape[0] - width = img.shape[1] - target_height, target_width = output_size - if width == target_width and height == target_height: - return 0, 0, height, width + h, w = img.shape[:2] + target_h, target_w = self.crop_size + if w == target_w and h == target_h: + return 0, 0, h, w + elif w < target_w or h < target_h: + target_w = min(w, target_w) + target_h = min(w, target_h) - ymin = random.randint(0, height - target_height) - xmin = random.randint(0, width - target_width) - return ymin, xmin, target_height, target_width + offset_h = np.random.randint(0, h - target_h + 1) + offset_w = np.random.randint(0, w - target_w + 1) + + return offset_h, offset_w, target_h, target_w + + def transform(self, results: dict) -> dict: + """Transform function to randomly crop images. - def __call__(self, results): - """ Args: - img (ndarray): Image to be cropped. + results (dict): Result dict from loading pipeline. + + Returns: + dict: Randomly cropped results, 'img_shape' + key in result dict is updated according to crop size. 
""" - for key in results.get('img_fields', ['img']): - img = results[key] - if self.padding is not None: - img = mmcv.impad( - img, padding=self.padding, pad_val=self.pad_val) + img = results['img'] + if self.padding is not None: + img = mmcv.impad(img, padding=self.padding, pad_val=self.pad_val) - # pad the height if needed - if self.pad_if_needed and img.shape[0] < self.size[0]: - img = mmcv.impad( - img, - padding=(0, self.size[0] - img.shape[0], 0, - self.size[0] - img.shape[0]), - pad_val=self.pad_val, - padding_mode=self.padding_mode) + # pad img if needed + if self.pad_if_needed: + h_pad = math.ceil(max(0, self.crop_size[0] - img.shape[0]) / 2) + w_pad = math.ceil(max(0, self.crop_size[1] - img.shape[1]) / 2) - # pad the width if needed - if self.pad_if_needed and img.shape[1] < self.size[1]: - img = mmcv.impad( - img, - padding=(self.size[1] - img.shape[1], 0, - self.size[1] - img.shape[1], 0), - pad_val=self.pad_val, - padding_mode=self.padding_mode) - - ymin, xmin, height, width = self.get_params(img, self.size) - results[key] = mmcv.imcrop( + img = mmcv.impad( img, - np.array([ - xmin, - ymin, - xmin + width - 1, - ymin + height - 1, - ])) + padding=(w_pad, h_pad, w_pad, h_pad), + pad_val=self.pad_val, + padding_mode=self.padding_mode) + + offset_h, offset_w, target_h, target_w = self.rand_crop_params(img) + img = mmcv.imcrop( + img, + np.array([ + offset_w, + offset_h, + offset_w + target_w - 1, + offset_h + target_h - 1, + ])) + results['img'] = img + results['img_shape'] = img.shape + return results def __repr__(self): - return (self.__class__.__name__ + - f'(size={self.size}, padding={self.padding})') + """Print the basic information of the transform. + + Returns: + str: Formatted string. 
+ """ + repr_str = self.__class__.__name__ + f'(crop_size={self.crop_size}' + repr_str += f', padding={self.padding}' + repr_str += f', pad_if_needed={self.pad_if_needed}' + repr_str += f', pad_val={self.pad_val}' + repr_str += f', padding_mode={self.padding_mode})' + return repr_str @TRANSFORMS.register_module() -class RandomResizedCrop(object): - """Crop the given image to random size and aspect ratio. +class RandomResizedCrop(BaseTransform): + """Crop the given image to random scale and aspect ratio. A crop of random size (default: of 0.08 to 1.0) of the original size and a random aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop is finally resized to given size. + Required Keys: + + - img + + Modified Keys: + + - img + - img_shape + Args: - size (sequence | int): Desired output size of the crop. If size is an + scale (sequence | int): Desired output scale of the crop. If size is an int instead of sequence like (h, w), a square crop (size, size) is made. - scale (tuple): Range of the random size of the cropped image compared - to the original image. Defaults to (0.08, 1.0). - ratio (tuple): Range of the random aspect ratio of the cropped image - compared to the original image. Defaults to (3. / 4., 4. / 3.). + crop_ratio_range (tuple): Range of the random size of the cropped + image compared to the original image. Defaults to (0.08, 1.0). + aspect_ratio_range (tuple): Range of the random aspect ratio of the + cropped image compared to the original image. + Defaults to (3. / 4., 4. / 3.). max_attempts (int): Maximum number of attempts before falling back to Central Crop. Defaults to 10. - efficientnet_style (bool): Whether to use efficientnet style Random - ResizedCrop. Defaults to False. - min_covered (Number): Minimum ratio of the cropped area to the original - area. Only valid if efficientnet_style is true. Defaults to 0.1. - crop_padding (int): The crop padding parameter in efficientnet style - center crop. 
Only valid if efficientnet_style is true. - Defaults to 32. interpolation (str): Interpolation method, accepted values are 'nearest', 'bilinear', 'bicubic', 'area', 'lanczos'. Defaults to 'bilinear'. @@ -172,217 +197,119 @@ class RandomResizedCrop(object): """ def __init__(self, - size, - scale=(0.08, 1.0), - ratio=(3. / 4., 4. / 3.), - max_attempts=10, - efficientnet_style=False, - min_covered=0.1, - crop_padding=32, - interpolation='bilinear', - backend='cv2'): - if efficientnet_style: - assert isinstance(size, int) - self.size = (size, size) - assert crop_padding >= 0 + scale: Union[Sequence, int], + crop_ratio_range: Tuple[float, float] = (0.08, 1.0), + aspect_ratio_range: Tuple[float, float] = (3. / 4., 4. / 3.), + max_attempts: int = 10, + interpolation: str = 'bilinear', + backend: str = 'cv2') -> None: + if isinstance(scale, Sequence): + assert len(scale) == 2 + assert scale[0] > 0 and scale[1] > 0 + self.scale = scale else: - if isinstance(size, (tuple, list)): - self.size = size - else: - self.size = (size, size) - if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): - raise ValueError('range should be of kind (min, max). ' - f'But received scale {scale} and rato {ratio}.') - assert min_covered >= 0, 'min_covered should be no less than 0.' + assert scale > 0 + self.scale = (scale, scale) + if (crop_ratio_range[0] > crop_ratio_range[1]) or ( + aspect_ratio_range[0] > aspect_ratio_range[1]): + raise ValueError( + 'range should be of kind (min, max). ' + f'But received crop_ratio_range {crop_ratio_range} ' + f'and aspect_ratio_range {aspect_ratio_range}.') assert isinstance(max_attempts, int) and max_attempts >= 0, \ 'max_attempts mush be int and no less than 0.' assert interpolation in ('nearest', 'bilinear', 'bicubic', 'area', 'lanczos') - if backend not in ['cv2', 'pillow']: - raise ValueError(f'backend: {backend} is not supported for resize.' 
- 'Supported backends are "cv2", "pillow"') - self.scale = scale - self.ratio = ratio + self.crop_ratio_range = crop_ratio_range + self.aspect_ratio_range = aspect_ratio_range self.max_attempts = max_attempts - self.efficientnet_style = efficientnet_style - self.min_covered = min_covered - self.crop_padding = crop_padding self.interpolation = interpolation self.backend = backend - @staticmethod - def get_params(img, scale, ratio, max_attempts=10): + @cache_randomness + def rand_crop_params(self, img: np.ndarray) -> Tuple[int, int, int, int]: """Get parameters for ``crop`` for a random sized crop. Args: img (ndarray): Image to be cropped. - scale (tuple): Range of the random size of the cropped image - compared to the original image size. - ratio (tuple): Range of the random aspect ratio of the cropped - image compared to the original image area. - max_attempts (int): Maximum number of attempts before falling back - to central crop. Defaults to 10. Returns: - tuple: Params (ymin, xmin, ymax, xmax) to be passed to `crop` for - a random sized crop. + tuple: Params (offset_h, offset_w, target_h, target_w) to be + passed to `crop` for a random sized crop. 
""" - height = img.shape[0] - width = img.shape[1] - area = height * width + h, w = img.shape[:2] + area = h * w - for _ in range(max_attempts): - target_area = random.uniform(*scale) * area - log_ratio = (math.log(ratio[0]), math.log(ratio[1])) - aspect_ratio = math.exp(random.uniform(*log_ratio)) + for _ in range(self.max_attempts): + target_area = np.random.uniform(*self.crop_ratio_range) * area + log_ratio = (math.log(self.aspect_ratio_range[0]), + math.log(self.aspect_ratio_range[1])) + aspect_ratio = math.exp(np.random.uniform(*log_ratio)) + target_w = int(round(math.sqrt(target_area * aspect_ratio))) + target_h = int(round(math.sqrt(target_area / aspect_ratio))) - target_width = int(round(math.sqrt(target_area * aspect_ratio))) - target_height = int(round(math.sqrt(target_area / aspect_ratio))) + if 0 < target_w <= w and 0 < target_h <= h: + offset_h = np.random.randint(0, h - target_h + 1) + offset_w = np.random.randint(0, w - target_w + 1) - if 0 < target_width <= width and 0 < target_height <= height: - ymin = random.randint(0, height - target_height) - xmin = random.randint(0, width - target_width) - ymax = ymin + target_height - 1 - xmax = xmin + target_width - 1 - return ymin, xmin, ymax, xmax + return offset_h, offset_w, target_h, target_w # Fallback to central crop - in_ratio = float(width) / float(height) - if in_ratio < min(ratio): - target_width = width - target_height = int(round(target_width / min(ratio))) - elif in_ratio > max(ratio): - target_height = height - target_width = int(round(target_height * max(ratio))) + in_ratio = float(w) / float(h) + if in_ratio < min(self.aspect_ratio_range): + target_w = w + target_h = int(round(target_w / min(self.aspect_ratio_range))) + elif in_ratio > max(self.aspect_ratio_range): + target_h = h + target_w = int(round(target_h * max(self.aspect_ratio_range))) else: # whole image - target_width = width - target_height = height - ymin = (height - target_height) // 2 - xmin = (width - target_width) // 2 - ymax 
= ymin + target_height - 1 - xmax = xmin + target_width - 1 - return ymin, xmin, ymax, xmax + target_w = w + target_h = h + offset_h = (h - target_h) // 2 + offset_w = (w - target_w) // 2 + return offset_h, offset_w, target_h, target_w - # https://github.com/kakaobrain/fast-autoaugment/blob/master/FastAutoAugment/data.py # noqa - @staticmethod - def get_params_efficientnet_style(img, - size, - scale, - ratio, - max_attempts=10, - min_covered=0.1, - crop_padding=32): - """Get parameters for ``crop`` for a random sized crop in efficientnet - style. + def transform(self, results: dict) -> dict: + """Transform function to randomly resized crop images. Args: - img (ndarray): Image to be cropped. - size (sequence): Desired output size of the crop. - scale (tuple): Range of the random size of the cropped image - compared to the original image size. - ratio (tuple): Range of the random aspect ratio of the cropped - image compared to the original image area. - max_attempts (int): Maximum number of attempts before falling back - to central crop. Defaults to 10. - min_covered (Number): Minimum ratio of the cropped area to the - original area. Only valid if efficientnet_style is true. - Defaults to 0.1. - crop_padding (int): The crop padding parameter in efficientnet - style center crop. Defaults to 32. + results (dict): Result dict from loading pipeline. Returns: - tuple: Params (ymin, xmin, ymax, xmax) to be passed to `crop` for - a random sized crop. + dict: Randomly resized cropped results, 'img_shape' + key in result dict is updated according to crop size. 
""" - height, width = img.shape[:2] - area = height * width - min_target_area = scale[0] * area - max_target_area = scale[1] * area + img = results['img'] + offset_h, offset_w, target_h, target_w = self.rand_crop_params(img) + img = mmcv.imcrop( + img, + bboxes=np.array([ + offset_w, offset_h, offset_w + target_w - 1, + offset_h + target_h - 1 + ])) + img = mmcv.imresize( + img, + tuple(self.scale[::-1]), + interpolation=self.interpolation, + backend=self.backend) + results['img'] = img + results['img_shape'] = img.shape - for _ in range(max_attempts): - aspect_ratio = random.uniform(*ratio) - min_target_height = int( - round(math.sqrt(min_target_area / aspect_ratio))) - max_target_height = int( - round(math.sqrt(max_target_area / aspect_ratio))) - - if max_target_height * aspect_ratio > width: - max_target_height = int((width + 0.5 - 1e-7) / aspect_ratio) - if max_target_height * aspect_ratio > width: - max_target_height -= 1 - - max_target_height = min(max_target_height, height) - min_target_height = min(max_target_height, min_target_height) - - # slightly differs from tf implementation - target_height = int( - round(random.uniform(min_target_height, max_target_height))) - target_width = int(round(target_height * aspect_ratio)) - target_area = target_height * target_width - - # slight differs from tf. 
In tf, if target_area > max_target_area, - # area will be recalculated - if (target_area < min_target_area or target_area > max_target_area - or target_width > width or target_height > height - or target_area < min_covered * area): - continue - - ymin = random.randint(0, height - target_height) - xmin = random.randint(0, width - target_width) - ymax = ymin + target_height - 1 - xmax = xmin + target_width - 1 - - return ymin, xmin, ymax, xmax - - # Fallback to central crop - img_short = min(height, width) - crop_size = size[0] / (size[0] + crop_padding) * img_short - - ymin = max(0, int(round((height - crop_size) / 2.))) - xmin = max(0, int(round((width - crop_size) / 2.))) - ymax = min(height, ymin + crop_size) - 1 - xmax = min(width, xmin + crop_size) - 1 - - return ymin, xmin, ymax, xmax - - def __call__(self, results): - for key in results.get('img_fields', ['img']): - img = results[key] - if self.efficientnet_style: - get_params_func = self.get_params_efficientnet_style - get_params_args = dict( - img=img, - size=self.size, - scale=self.scale, - ratio=self.ratio, - max_attempts=self.max_attempts, - min_covered=self.min_covered, - crop_padding=self.crop_padding) - else: - get_params_func = self.get_params - get_params_args = dict( - img=img, - scale=self.scale, - ratio=self.ratio, - max_attempts=self.max_attempts) - ymin, xmin, ymax, xmax = get_params_func(**get_params_args) - img = mmcv.imcrop(img, bboxes=np.array([xmin, ymin, xmax, ymax])) - results[key] = mmcv.imresize( - img, - tuple(self.size[::-1]), - interpolation=self.interpolation, - backend=self.backend) return results def __repr__(self): - repr_str = self.__class__.__name__ + f'(size={self.size}' - repr_str += f', scale={tuple(round(s, 4) for s in self.scale)}' - repr_str += f', ratio={tuple(round(r, 4) for r in self.ratio)}' + """Print the basic information of the transform. + + Returns: + str: Formatted string. 
+ """ + repr_str = self.__class__.__name__ + f'(scale={self.scale}' + repr_str += ', crop_ratio_range=' + repr_str += f'{tuple(round(s, 4) for s in self.crop_ratio_range)}' + repr_str += ', aspect_ratio_range=' + repr_str += f'{tuple(round(r, 4) for r in self.aspect_ratio_range)}' repr_str += f', max_attempts={self.max_attempts}' - repr_str += f', efficientnet_style={self.efficientnet_style}' - repr_str += f', min_covered={self.min_covered}' - repr_str += f', crop_padding={self.crop_padding}' repr_str += f', interpolation={self.interpolation}' repr_str += f', backend={self.backend})' return repr_str @@ -737,112 +664,6 @@ class ResizeEdge(BaseTransform): return repr_str -@TRANSFORMS.register_module() -class CenterCrop(object): - r"""Center crop the image. - - Args: - crop_size (int | tuple): Expected size after cropping with the format - of (h, w). - efficientnet_style (bool): Whether to use efficientnet style center - crop. Defaults to False. - crop_padding (int): The crop padding parameter in efficientnet style - center crop. Only valid if efficientnet style is True. Defaults to - 32. - interpolation (str): Interpolation method, accepted values are - 'nearest', 'bilinear', 'bicubic', 'area', 'lanczos'. Only valid if - ``efficientnet_style`` is True. Defaults to 'bilinear'. - backend (str): The image resize backend type, accepted values are - `cv2` and `pillow`. Only valid if efficientnet style is True. - Defaults to `cv2`. - - - Notes: - - If the image is smaller than the crop size, return the original - image. - - If efficientnet_style is set to False, the pipeline would be a simple - center crop using the crop_size. - - If efficientnet_style is set to True, the pipeline will be to first - to perform the center crop with the ``crop_size_`` as: - - .. math:: - \text{crop_size_} = \frac{\text{crop_size}}{\text{crop_size} + - \text{crop_padding}} \times \text{short_edge} - - And then the pipeline resizes the img to the input crop size. 
- """ - - def __init__(self, - crop_size, - efficientnet_style=False, - crop_padding=32, - interpolation='bilinear', - backend='cv2'): - if efficientnet_style: - assert isinstance(crop_size, int) - assert crop_padding >= 0 - assert interpolation in ('nearest', 'bilinear', 'bicubic', 'area', - 'lanczos') - if backend not in ['cv2', 'pillow']: - raise ValueError( - f'backend: {backend} is not supported for ' - 'resize. Supported backends are "cv2", "pillow"') - else: - assert isinstance(crop_size, int) or (isinstance(crop_size, tuple) - and len(crop_size) == 2) - if isinstance(crop_size, int): - crop_size = (crop_size, crop_size) - assert crop_size[0] > 0 and crop_size[1] > 0 - self.crop_size = crop_size - self.efficientnet_style = efficientnet_style - self.crop_padding = crop_padding - self.interpolation = interpolation - self.backend = backend - - def __call__(self, results): - crop_height, crop_width = self.crop_size[0], self.crop_size[1] - for key in results.get('img_fields', ['img']): - img = results[key] - # img.shape has length 2 for grayscale, length 3 for color - img_height, img_width = img.shape[:2] - - # https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/preprocessing.py#L118 # noqa - if self.efficientnet_style: - img_short = min(img_height, img_width) - crop_height = crop_height / (crop_height + - self.crop_padding) * img_short - crop_width = crop_width / (crop_width + - self.crop_padding) * img_short - - y1 = max(0, int(round((img_height - crop_height) / 2.))) - x1 = max(0, int(round((img_width - crop_width) / 2.))) - y2 = min(img_height, y1 + crop_height) - 1 - x2 = min(img_width, x1 + crop_width) - 1 - - # crop the image - img = mmcv.imcrop(img, bboxes=np.array([x1, y1, x2, y2])) - - if self.efficientnet_style: - img = mmcv.imresize( - img, - tuple(self.crop_size[::-1]), - interpolation=self.interpolation, - backend=self.backend) - img_shape = img.shape - results[key] = img - results['img_shape'] = img_shape - - return results - 
- def __repr__(self): - repr_str = self.__class__.__name__ + f'(crop_size={self.crop_size}' - repr_str += f', efficientnet_style={self.efficientnet_style}' - repr_str += f', crop_padding={self.crop_padding}' - repr_str += f', interpolation={self.interpolation}' - repr_str += f', backend={self.backend})' - return repr_str - - @TRANSFORMS.register_module() class Normalize(object): """Normalize the image. diff --git a/tests/test_data/test_pipelines/test_processing.py b/tests/test_data/test_pipelines/test_processing.py index bc590a7e..18c80d95 100644 --- a/tests/test_data/test_pipelines/test_processing.py +++ b/tests/test_data/test_pipelines/test_processing.py @@ -21,6 +21,167 @@ def construct_toy_data(): results['img_shape'] = img.shape return results +class TestRandomCrop(TestCase): + + def test_assertion(self): + with self.assertRaises(AssertionError): + cfg = dict(type='RandomCrop', crop_size=-1) + TRANSFORMS.build(cfg) + + with self.assertRaises(AssertionError): + cfg = dict(type='RandomCrop', crop_size=(1, 2, 3)) + TRANSFORMS.build(cfg) + + with self.assertRaises(AssertionError): + cfg = dict(type='RandomCrop', crop_size=(1, -2)) + TRANSFORMS.build(cfg) + + with self.assertRaises(AssertionError): + cfg = dict(type='RandomCrop', crop_size=224, padding_mode='co') + TRANSFORMS.build(cfg) + + def test_transform(self): + results = dict(img=np.random.randint(0, 256, (256, 256, 3), np.uint8)) + + # test random crop by default. + cfg = dict(type='RandomCrop', crop_size=224) + transform = TRANSFORMS.build(cfg) + results = transform(results) + self.assertTupleEqual(results['img'].shape, (224, 224, 3)) + + # test int padding and int pad_val. + cfg = dict( + type='RandomCrop', crop_size=(224, 224), padding=2, pad_val=1) + transform = TRANSFORMS.build(cfg) + results = transform(results) + self.assertTupleEqual(results['img'].shape, (224, 224, 3)) + + # test int padding and sequence pad_val. 
+ cfg = dict( + type='RandomCrop', crop_size=224, padding=2, pad_val=(0, 50, 0)) + transform = TRANSFORMS.build(cfg) + results = transform(results) + self.assertTupleEqual(results['img'].shape, (224, 224, 3)) + + # test sequence padding. + cfg = dict(type='RandomCrop', crop_size=224, padding=(2, 3, 4, 5)) + transform = TRANSFORMS.build(cfg) + results = transform(results) + self.assertTupleEqual(results['img'].shape, (224, 224, 3)) + + # test pad_if_needed. + cfg = dict( + type='RandomCrop', + crop_size=300, + pad_if_needed=True, + padding_mode='edge') + transform = TRANSFORMS.build(cfg) + results = transform(results) + self.assertTupleEqual(results['img'].shape, (300, 300, 3)) + + # test large crop size. + results = dict(img=np.random.randint(0, 256, (256, 256, 3), np.uint8)) + cfg = dict(type='RandomCrop', crop_size=300) + transform = TRANSFORMS.build(cfg) + results = transform(results) + self.assertTupleEqual(results['img'].shape, (256, 256, 3)) + + # test equal size. + results = dict(img=np.random.randint(0, 256, (256, 256, 3), np.uint8)) + cfg = dict(type='RandomCrop', crop_size=256) + transform = TRANSFORMS.build(cfg) + results = transform(results) + self.assertTupleEqual(results['img'].shape, (256, 256, 3)) + + def test_repr(self): + cfg = dict(type='RandomCrop', crop_size=224) + transform = TRANSFORMS.build(cfg) + self.assertEqual( + repr(transform), 'RandomCrop(crop_size=(224, 224), padding=None, ' + 'pad_if_needed=False, pad_val=0, padding_mode=constant)') + + +class TestRandomResizedCrop(TestCase): + + def test_assertion(self): + with self.assertRaises(AssertionError): + cfg = dict(type='RandomResizedCrop', scale=-1) + TRANSFORMS.build(cfg) + + with self.assertRaises(AssertionError): + cfg = dict(type='RandomResizedCrop', scale=(1, 2, 3)) + TRANSFORMS.build(cfg) + + with self.assertRaises(AssertionError): + cfg = dict(type='RandomResizedCrop', scale=(1, -2)) + TRANSFORMS.build(cfg) + + with self.assertRaises(ValueError): + cfg = dict( 
type='RandomResizedCrop', scale=224, crop_ratio_range=(1, 0.1)) + TRANSFORMS.build(cfg) + + with self.assertRaises(ValueError): + cfg = dict( + type='RandomResizedCrop', + scale=224, + aspect_ratio_range=(1, 0.1)) + TRANSFORMS.build(cfg) + + with self.assertRaises(AssertionError): + cfg = dict(type='RandomResizedCrop', scale=224, max_attempts=-1) + TRANSFORMS.build(cfg) + + with self.assertRaises(AssertionError): + cfg = dict(type='RandomResizedCrop', scale=224, interpolation='ne') + TRANSFORMS.build(cfg) + + def test_transform(self): + results = dict(img=np.random.randint(0, 256, (256, 256, 3), np.uint8)) + + # test random crop by default. + cfg = dict(type='RandomResizedCrop', scale=224) + transform = TRANSFORMS.build(cfg) + results = transform(results) + self.assertTupleEqual(results['img'].shape, (224, 224, 3)) + + # test crop_ratio_range. + cfg = dict( + type='RandomResizedCrop', + scale=(224, 224), + crop_ratio_range=(0.5, 0.8)) + transform = TRANSFORMS.build(cfg) + results = transform(results) + self.assertTupleEqual(results['img'].shape, (224, 224, 3)) + + # test aspect_ratio_range. + cfg = dict( + type='RandomResizedCrop', scale=224, aspect_ratio_range=(0.5, 0.8)) + transform = TRANSFORMS.build(cfg) + results = transform(results) + self.assertTupleEqual(results['img'].shape, (224, 224, 3)) + + # test max_attempts. + cfg = dict(type='RandomResizedCrop', scale=224, max_attempts=0) + transform = TRANSFORMS.build(cfg) + results = transform(results) + self.assertTupleEqual(results['img'].shape, (224, 224, 3)) + + # test large crop size. 
+ results = dict(img=np.random.randint(0, 256, (256, 256, 3), np.uint8)) + cfg = dict(type='RandomResizedCrop', scale=300) + transform = TRANSFORMS.build(cfg) + results = transform(results) + self.assertTupleEqual(results['img'].shape, (300, 300, 3)) + + def test_repr(self): + cfg = dict(type='RandomResizedCrop', scale=224) + transform = TRANSFORMS.build(cfg) + self.assertEqual( + repr(transform), 'RandomResizedCrop(scale=(224, 224), ' + 'crop_ratio_range=(0.08, 1.0), aspect_ratio_range=(0.75, 1.3333), ' + 'max_attempts=10, interpolation=bilinear, backend=cv2)') + class TestResizeEdge(TestCase):