[Feature] Add Part3 of data transform (#1735)

* update data transform part3

* update init

* rename flip funcs

* fix comments

* update comments

* fix lint

* Update mmcv/transforms/processing.py

* fix docs format

* fix comments

* add test pad_val and fix bugs in class Pad

* merge updated pad

* fix lint

* Update tests/test_transforms/test_transforms_processing.py
Yifei Yang 2022-03-01 19:06:53 +08:00 committed by zhouzaida
parent 5af6c12b81
commit 2619aa9c8e
4 changed files with 976 additions and 49 deletions

mmcv/transforms/__init__.py

@@ -1,7 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import TRANSFORMS
from .loading import LoadAnnotation, LoadImageFromFile
from .processing import Normalize, Pad, RandomFlip, RandomResize, Resize
from .processing import (CenterCrop, MultiScaleFlipAug, Normalize, Pad,
RandomFlip, RandomGrayscale, RandomMultiscaleResize,
RandomResize, Resize)
from .wrappers import ApplyToMultiple, Compose, RandomChoice, Remap
try:
@@ -10,7 +12,8 @@ except ImportError:
__all__ = [
'TRANSFORMS', 'ApplyToMultiple', 'Compose', 'RandomChoice', 'Remap',
'LoadImageFromFile', 'LoadAnnotation', 'Normalize', 'Resize', 'Pad',
'RandomFlip', 'RandomResize'
'RandomFlip', 'RandomMultiscaleResize', 'CenterCrop',
'RandomGrayscale', 'MultiScaleFlipAug', 'RandomResize'
]
else:
from .formatting import ImageToTensor, ToTensor, to_tensor
@@ -18,5 +21,7 @@ else:
__all__ = [
'TRANSFORMS', 'ApplyToMultiple', 'Compose', 'RandomChoice', 'Remap',
'LoadImageFromFile', 'LoadAnnotation', 'Normalize', 'Resize', 'Pad',
'ToTensor', 'to_tensor', 'ImageToTensor', 'RandomFlip', 'RandomResize'
'ToTensor', 'to_tensor', 'ImageToTensor', 'RandomFlip',
'RandomMultiscaleResize', 'CenterCrop', 'RandomGrayscale',
'MultiScaleFlipAug', 'RandomResize'
]

mmcv/transforms/formatting.py

@@ -16,9 +16,11 @@ def to_tensor(
Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
:class:`Sequence`, :class:`int` and :class:`float`.
Args:
data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to
be converted.
Returns:
torch.Tensor: the converted data.
"""

mmcv/transforms/processing.py

@@ -1,5 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Iterable, List, Optional, Sequence, Tuple, Union
import random
import warnings
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
import numpy as np
@@ -7,6 +9,9 @@ import mmcv
from mmcv.image.geometric import _scale_size
from .base import BaseTransform
from .builder import TRANSFORMS
from .wrappers import Compose
Number = Union[int, float]
@TRANSFORMS.register_module()
@@ -104,7 +109,7 @@ class Resize(BaseTransform):
Defaults to None.
keep_ratio (bool): Whether to keep the aspect ratio when resizing the
image. Defaults to False.
clip_object_border (bool, optional): Whether to clip the objects
clip_object_border (bool): Whether to clip the objects
outside the border of the image. In some datasets like MOT17, the gt
bboxes are allowed to cross the border of images. Therefore, we
don't need to clip the gt bboxes in these cases. Defaults to True.
@@ -292,7 +297,7 @@ class Pad(BaseTransform):
None.
pad_to_square (bool): Whether to pad the image into a square.
Currently only used for YOLOX. Defaults to False.
pad_val (int or dict, optional): A dict for padding value.
pad_val (int or dict): A dict for padding value.
If ``pad_val`` is a single int, it is used to pad the image while the segmentation map is padded with 255. Defaults to
``dict(img=0, seg=255)``.
padding_mode (str): Type of padding. Should be: constant, edge,
@@ -374,9 +379,11 @@ class Pad(BaseTransform):
``results['pad_shape']``."""
if results.get('gt_semantic_seg', None) is not None:
pad_val = self.pad_val.get('seg', 255)
if isinstance(pad_val, int) and results['gt_semantic_seg'].ndim == 3:
pad_val = tuple(
[pad_val for _ in range(results['gt_semantic_seg'].shape[2])])
if isinstance(pad_val,
int) and results['gt_semantic_seg'].ndim == 3:
pad_val = tuple([
pad_val for _ in range(results['gt_semantic_seg'].shape[2])
])
results['gt_semantic_seg'] = mmcv.impad(
results['gt_semantic_seg'],
shape=results['pad_shape'][:2],
@@ -407,59 +414,551 @@ class Pad(BaseTransform):
@TRANSFORMS.register_module()
class RandomFlip(BaseTransform):
"""Flip the image & bbox & keypoints & segmentation map.
There are 3 flip modes:
- ``prob`` is float, ``direction`` is string
the image will be flipped on the given direction with probability
of ``prob`` . E.g., ``prob=0.5``, ``direction='horizontal'``,
then image will be horizontally flipped with probability of 0.5.
- ``prob`` is float, ``direction`` is list of string
the image will be flipped on the given direction with probability of
``prob/len(direction)``. E.g., ``prob=0.5``,
``direction=['horizontal', 'vertical']``, then image will be
horizontally flipped with probability of 0.25, vertically with
probability of 0.25.
- ``prob`` is list of float, ``direction`` is list of string
given ``len(prob) == len(direction)``, the image will
be ``direction[i]`` ly flipped with probability of ``prob[i]``.
E.g., ``prob=[0.3, 0.5]``, ``direction=['horizontal',
'vertical']``, then image will be horizontally flipped with
probability of 0.3, vertically with probability of 0.5.
class CenterCrop(BaseTransform):
"""Crop the center of the image and segmentation masks. If the crop area
exceeds the original image and ``pad_mode`` is not None, the original image
will be padded before cropping.
Required Keys:
- img
- gt_bboxes
- gt_semantic_seg
- gt_keypoints
- gt_semantic_seg (optional)
Modified Keys:
- img
- gt_bboxes
- gt_semantic_seg
- gt_keypoints
- height
- width
- gt_semantic_seg (optional)
Added Key:
- pad_shape
Args:
crop_size (Union[int, Tuple[int, int]]): Expected size after cropping
with the format of (h, w). If set to an integer, then cropping
height and width are equal to this integer.
pad_val (Union[Number, Dict[str, Number]]): A dict for
padding value. To specify how to set this argument, please see
the docstring of class ``Pad``. Defaults to
``dict(img=0, seg=255)``.
pad_mode (str, optional): Type of padding. Should be: 'constant',
'edge', 'reflect' or 'symmetric'. For details, please see the
docstring of class ``Pad``. Defaults to None.
pad_cfg (dict): Base config for padding. Defaults to
``dict(type='Pad')``.
"""
def __init__(
self,
crop_size: Union[int, Tuple[int, int]],
pad_val: Union[Number, Dict[str, Number]] = dict(img=0, seg=255),
pad_mode: Optional[str] = None,
pad_cfg: dict = dict(type='Pad')
) -> None: # flake8: noqa
super().__init__()
assert isinstance(crop_size, int) or (
isinstance(crop_size, tuple) and len(crop_size) == 2
), ('The expected crop_size is an integer, or a tuple containing two '
'integers')
if isinstance(crop_size, int):
crop_size = (crop_size, crop_size)
assert crop_size[0] > 0 and crop_size[1] > 0
self.crop_size = crop_size
self.pad_val = pad_val
self.pad_mode = pad_mode
self.pad_cfg = pad_cfg
def _crop_img(self, results: dict, bboxes: np.ndarray) -> None:
"""Crop image.
Args:
results (dict): Result dict contains the data to transform.
bboxes (np.ndarray): Shape (4, ), the (x1, y1, x2, y2) coordinates of the crop box.
"""
if results.get('img', None) is not None:
img = mmcv.imcrop(results['img'], bboxes=bboxes)
img_shape = img.shape
results['img'] = img
results['height'] = img_shape[0]
results['width'] = img_shape[1]
results['pad_shape'] = img_shape
def _crop_seg_map(self, results: dict, bboxes: np.ndarray) -> None:
"""Crop semantic segmentation map.
Args:
results (dict): Result dict contains the data to transform.
bboxes (np.ndarray): Shape (4, ), the (x1, y1, x2, y2) coordinates of the crop box.
"""
if results.get('gt_semantic_seg', None) is not None:
img = mmcv.imcrop(results['gt_semantic_seg'], bboxes=bboxes)
results['gt_semantic_seg'] = img
def transform(self, results: dict) -> dict:
"""Apply center crop on results.
Args:
results (dict): Result dict contains the data to transform.
Returns:
dict: Results with CenterCropped image and semantic segmentation
map.
"""
crop_height, crop_width = self.crop_size[0], self.crop_size[1]
assert 'img' in results, '`img` is not found in results'
img = results['img']
# img.shape has length 2 for grayscale, length 3 for color
img_height, img_width = img.shape[:2]
if crop_height > img_height or crop_width > img_width:
if self.pad_mode is not None:
# pad the area
img_height = max(img_height, crop_height)
img_width = max(img_width, crop_width)
pad_size = (img_width, img_height)
_pad_cfg = self.pad_cfg.copy()
_pad_cfg.update(
dict(
size=pad_size,
pad_val=self.pad_val,
padding_mode=self.pad_mode))
pad_transform = TRANSFORMS.build(_pad_cfg)
results = pad_transform(results)
else:
crop_height = min(crop_height, img_height)
crop_width = min(crop_width, img_width)
y1 = max(0, int(round((img_height - crop_height) / 2.)))
x1 = max(0, int(round((img_width - crop_width) / 2.)))
y2 = min(img_height, y1 + crop_height) - 1
x2 = min(img_width, x1 + crop_width) - 1
bboxes = np.array([x1, y1, x2, y2])
# crop the image
self._crop_img(results, bboxes)
# crop the gt_semantic_seg
self._crop_seg_map(results, bboxes)
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f', crop_size = {self.crop_size}'
repr_str += f', pad_val = {self.pad_val}'
repr_str += f', pad_mode = {self.pad_mode}'
return repr_str
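For reference, a minimal usage sketch of ``CenterCrop`` (illustrative only, not part of this commit); the input is a dummy image and the crop/pad values are arbitrary:

import numpy as np
from mmcv.transforms import TRANSFORMS

results = dict(img=np.zeros((300, 400, 3), dtype=np.uint8))
crop = TRANSFORMS.build(dict(type='CenterCrop', crop_size=224))
results = crop(results)
# results['img'].shape == (224, 224, 3); results['height'] == results['width'] == 224

# If crop_size exceeds the image and pad_mode is set, the image is padded first.
results = dict(img=np.zeros((300, 400, 3), dtype=np.uint8))
crop_pad = TRANSFORMS.build(dict(type='CenterCrop', crop_size=(600, 200), pad_mode='constant', pad_val=12))
results = crop_pad(results)
# results['img'].shape == (600, 200, 3)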
@TRANSFORMS.register_module()
class RandomGrayscale(BaseTransform):
"""Randomly convert image to grayscale with a probability.
Required Key:
- img
Modified Key:
- img
Added Keys:
- flip
- flip_direction
- grayscale
- grayscale_weights
Args:
prob (float): Probability that image should be converted to
grayscale. Defaults to 0.1.
keep_channel (bool): Whether to keep the number of channels the same as the
input. Defaults to False.
channel_weights (tuple): Channel weights to compute gray
image. Defaults to (1., 1., 1.).
color_format (str): Color format set to be any of 'bgr',
'rgb', 'hsv'. Note: an 'hsv' image will be converted to 'bgr'
format regardless of whether it is grayscaled. Defaults to 'bgr'.
"""
def __init__(self,
prob: float = 0.1,
keep_channel: bool = False,
channel_weights: Sequence[float] = (1., 1., 1.),
color_format: str = 'bgr') -> None:
super().__init__()
assert 0. <= prob <= 1., ('The range of ``prob`` value is [0., 1.],' +
f' but got {prob} instead')
self.prob = prob
self.keep_channel = keep_channel
self.channel_weights = channel_weights
assert color_format in ['bgr', 'rgb', 'hsv']
self.color_format = color_format
def transform(self, results: dict) -> dict:
"""Apply random grayscale on results.
Args:
results (dict): Result dict contains the data to transform.
Returns:
dict: Results with grayscale image.
"""
img = results['img']
# convert hsv to bgr
if self.color_format == 'hsv':
img = mmcv.hsv2bgr(img)
img = img[..., None] if img.ndim == 2 else img
num_output_channels = img.shape[2]
if random.random() < self.prob:
if num_output_channels > 1:
assert num_output_channels == len(
self.channel_weights
), ('The length of ``channel_weights`` is supposed to be '
f'num_output_channels, but got {len(self.channel_weights)}'
' instead.')
normalized_weights = (
np.array(self.channel_weights) / sum(self.channel_weights))
img = (normalized_weights * img).sum(axis=2)
if self.keep_channel:
img = img[:, :, None]
results['img'] = np.dstack(
[img for _ in range(num_output_channels)])
else:
results['img'] = img
return results
results['img'] = img
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f', prob = {self.prob}'
repr_str += f', keep_channel = {self.keep_channel}'
repr_str += f', channel_weights = {self.channel_weights}'
repr_str += f', color_format = {self.color_format}'
return repr_str
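A hedged usage sketch (not from this commit); the weights below are the common BT.601 luminance weights and the input is a random dummy image:

import numpy as np
from mmcv.transforms import TRANSFORMS

gray = TRANSFORMS.build(dict(
    type='RandomGrayscale', prob=1., keep_channel=True,
    channel_weights=(0.299, 0.587, 0.114)))
results = gray(dict(img=np.random.rand(10, 10, 3).astype(np.float32)))
# results['img'] keeps shape (10, 10, 3); all three channels hold the same gray values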
@TRANSFORMS.register_module()
class MultiScaleFlipAug(BaseTransform):
"""Test-time augmentation with multiple scales and flipping.
An example configuration is as follows:
.. code-block::
dict(
type='MultiScaleFlipAug',
img_scale=[(1333, 400), (1333, 800)],
flip=True,
transforms=[
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=1),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
])
``results`` will be resized using all the sizes in ``img_scale``.
If ``flip`` is True, then flipped results will also be added into output
list.
For the above configuration, there are four combinations of resize
and flip:
- Resize to (1333, 400) + no flip
- Resize to (1333, 400) + flip
- Resize to (1333, 800) + no flip
- Resize to (1333, 800) + flip
The four results are then transformed with the ``transforms`` argument.
After that, the results are wrapped into lists of the same length, as follows:
.. code-block::
dict(
img=[...],
img_shape=[...],
scale=[(1333, 400), (1333, 400), (1333, 800), (1333, 800)]
flip=[False, True, False, True]
...
)
Required Keys:
- Depending on the requirements of the ``transforms`` parameter.
Modified Keys:
- All output keys of each transform.
Args:
transforms (list[dict]): Transforms to be applied to each resized
and flipped data.
img_scale (tuple | list[tuple] | None): Image scales for resizing.
flip (bool): Whether to apply flip augmentation. Defaults to False.
flip_direction (str | list[str]): Flip augmentation directions,
options are "horizontal", "vertical" and "diagonal". If
flip_direction is a list, multiple flip augmentations will be
applied. It has no effect when flip == False. Defaults to
"horizontal".
resize_cfg (dict): Base config for resizing. Defaults to
``dict(type='Resize', keep_ratio=True)``.
flip_cfg (dict): Base config for flipping. Defaults to
``dict(type='RandomFlip')``.
"""
def __init__(
self,
transforms: List[dict],
img_scale: Optional[Union[Tuple, List[Tuple]]] = None,
flip: bool = False,
flip_direction: Union[str, List[str]] = 'horizontal',
resize_cfg: dict = dict(type='Resize', keep_ratio=True),
flip_cfg: dict = dict(type='RandomFlip')
) -> None:
super().__init__()
self.transforms = Compose(transforms) # type: ignore
assert img_scale is not None
self.img_scale = img_scale if isinstance(img_scale,
list) else [img_scale]
self.scale_key = 'scale'
assert mmcv.is_list_of(self.img_scale, tuple)
self.flip = flip
self.flip_direction = flip_direction if isinstance(
flip_direction, list) else [flip_direction]
assert mmcv.is_list_of(self.flip_direction, str)
if not self.flip and self.flip_direction != ['horizontal']:
warnings.warn(
'flip_direction has no effect when flip is set to False')
self.resize_cfg = resize_cfg
self.flip_cfg = flip_cfg
def transform(self, results: dict) -> dict:
"""Apply test time augment transforms on results.
Args:
results (dict): Result dict contains the data to transform.
Returns:
dict: The augmented data, where each value is wrapped
into a list.
"""
aug_data = []
flip_args = [(False, '')]
if self.flip:
flip_args += [(True, direction)
for direction in self.flip_direction]
for scale in self.img_scale:
for flip, direction in flip_args:
_resize_cfg = self.resize_cfg.copy()
_resize_cfg.update(scale=scale)
_resize_flip = [_resize_cfg]
if flip:
_flip_cfg = self.flip_cfg.copy()
_flip_cfg.update(prob=1.0, direction=direction)
_resize_flip.append(_flip_cfg)
else:
results['flip'] = False
results['flip_direction'] = None
resize_flip = Compose(_resize_flip)
_results = results.copy()
_results = resize_flip(_results)
data = self.transforms(_results)
aug_data.append(data)
# list of dict to dict of list
aug_data_dict = {key: [] for key in aug_data[0]}
for data in aug_data:
for key, val in data.items():
aug_data_dict[key].append(val)
return aug_data_dict
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f', transforms={self.transforms}'
repr_str += f', img_scale={self.img_scale}'
repr_str += f', flip={self.flip}'
repr_str += f', flip_direction={self.flip_direction}'
return repr_str
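A rough usage sketch (illustrative, not part of the diff), assuming a full mmcv install; with three scales and one flip direction, each output list has 3 x 2 = 6 entries:

import numpy as np
from mmcv.transforms import TRANSFORMS

tta = TRANSFORMS.build(dict(
    type='MultiScaleFlipAug',
    transforms=[],
    img_scale=[(1333, 800), (800, 600), (640, 480)],
    flip=True,
    flip_direction=['horizontal']))
results = tta(dict(img=np.zeros((300, 400, 3), dtype=np.uint8)))
# len(results['img']) == 6; each entry is one resized (and possibly flipped) copy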
@TRANSFORMS.register_module()
class RandomMultiscaleResize(BaseTransform):
"""Resize images & bbox & mask from a list of multiple scales.
This transform resizes the input image to some scale. Bboxes and masks are
then resized with the same scale factor. Resize scale will be randomly
selected from ``scales``.
The target scale for resizing is chosen according to the rules below:
- if ``scales`` is a list of tuples, the target scale is sampled from the
list uniformly.
- if ``scales`` is a tuple, the target scale will be set to this tuple.
Required Keys:
- img
- gt_bboxes (optional)
- gt_semantic_seg (optional)
- gt_keypoints (optional)
Modified Keys:
- img
- height
- width
- gt_bboxes (optional)
- gt_semantic_seg (optional)
- gt_keypoints (optional)
Added Keys:
- scale
- scale_factor
- scale_idx
- keep_ratio
Args:
scales (Union[list, Tuple]): Image scales for resizing.
keep_ratio (bool): Whether to keep the aspect ratio when
resizing the image. Defaults to False.
clip_object_border (bool): Whether to clip the objects outside
the border of the image. Defaults to True.
backend (str): Image resize backend, choices are 'cv2' and
'pillow'. These two backends generate slightly different results.
Defaults to 'cv2'.
interpolation (str): The mode of interpolation; supported modes are
"bilinear", "bicubic" and "nearest". Defaults to "bilinear".
resize_cfg (dict): Base config for resizing. Defaults to
``dict(type='Resize')``.
"""
def __init__(
self,
scales: Union[list, Tuple],
keep_ratio: bool = False,
clip_object_border: bool = True,
backend: str = 'cv2',
interpolation: str = 'bilinear',
resize_cfg: dict = dict(type='Resize')
) -> None:
super().__init__()
if isinstance(scales, list):
self.scales = scales
else:
self.scales = [scales]
assert mmcv.is_list_of(self.scales, tuple)
self.keep_ratio = keep_ratio
self.clip_object_border = clip_object_border
self.backend = backend
self.interpolation = interpolation
self.resize_cfg = resize_cfg
@staticmethod
def random_select(scales: List[Tuple]) -> Tuple[Number, int]:
"""Randomly select an img_scale from given candidates.
Args:
scales (list[tuple]): Image scales for selection.
Returns:
(tuple, int): Returns a tuple ``(img_scale, scale_idx)``,
where ``img_scale`` is the selected image scale and
``scale_idx`` is the selected index in the given candidates.
"""
assert mmcv.is_list_of(scales, tuple)
scale_idx = np.random.randint(len(scales))
scale = scales[scale_idx]
return scale, scale_idx
def transform(self, results: dict) -> dict:
"""Apply resize transforms on results from a list of scales.
Args:
results (dict): Result dict contains the data to transform.
Returns:
dict: Resized results, 'img', 'gt_bboxes', 'gt_semantic_seg',
'gt_keypoints', 'scale', 'scale_factor', 'height', 'width',
and 'keep_ratio' keys are updated in result dict.
"""
target_scale, scale_idx = self.random_select(self.scales)
_resize_cfg = self.resize_cfg.copy()
_resize_cfg.update(
dict(
scale=target_scale,
keep_ratio=self.keep_ratio,
clip_object_border=self.clip_object_border,
backend=self.backend,
interpolation=self.interpolation))
resize_transform = TRANSFORMS.build(_resize_cfg)
results = resize_transform(results)
results['scale_idx'] = scale_idx
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f', scales={self.scales}'
repr_str += f', keep_ratio={self.keep_ratio}'
repr_str += f', clip_object_border={self.clip_object_border}'
repr_str += f', backend={self.backend}'
repr_str += f', interpolation={self.interpolation}'
return repr_str
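A short usage sketch (not from this commit); the scale candidates and the dummy image are arbitrary:

import numpy as np
from mmcv.transforms import TRANSFORMS

resize = TRANSFORMS.build(dict(
    type='RandomMultiscaleResize', scales=[(1333, 800), (1333, 600)]))
results = resize(dict(img=np.zeros((300, 400, 3), dtype=np.uint8)))
# results['img'].shape is (800, 1333, 3) or (600, 1333, 3);
# results['scale_idx'] records which candidate was sampled.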
@TRANSFORMS.register_module()
class RandomFlip(BaseTransform):
"""Flip the image & bbox & keypoints & segmentation map. Added or Updated
keys: flip, flip_direction, img, gt_bboxes, gt_semantic_seg, and
gt_keypoints. There are 3 flip modes:
- ``prob`` is float, ``direction`` is string: the image will be
``direction``ly flipped with probability of ``prob`` .
E.g., ``prob=0.5``, ``direction='horizontal'``,
then image will be horizontally flipped with probability of 0.5.
- ``prob`` is float, ``direction`` is list of string: the image will
be ``direction[i]``ly flipped with probability of
``prob/len(direction)``.
E.g., ``prob=0.5``, ``direction=['horizontal', 'vertical']``,
then image will be horizontally flipped with probability of 0.25,
vertically with probability of 0.25.
- ``prob`` is list of float, ``direction`` is list of string:
given ``len(prob) == len(direction)``, the image will
be ``direction[i]``ly flipped with probability of ``prob[i]``.
E.g., ``prob=[0.3, 0.5]``, ``direction=['horizontal',
'vertical']``, then image will be horizontally flipped with
probability of 0.3, vertically with probability of 0.5.
Required Keys:
- img
- gt_bboxes (optional)
- gt_semantic_seg (optional)
- gt_keypoints (optional)
Modified Keys:
- img
- gt_bboxes (optional)
- gt_semantic_seg (optional)
- gt_keypoints (optional)
Added Keys:
- flip
- flip_direction
Args:
prob (float | list[float], optional): The flipping probability.
Defaults to None.
direction (str | list[str]): The flipping direction. Options are
'horizontal', 'vertical' and 'diagonal'.
If input is a list, the length must equal ``prob``. Each
element in ``prob`` indicates the flip probability of
corresponding direction. Defaults to horizontal.
corresponding direction. Defaults to 'horizontal'.
"""
def __init__(
@@ -473,8 +972,8 @@ class RandomFlip(BaseTransform):
elif isinstance(prob, float):
assert 0 <= prob <= 1
else:
raise ValueError(f"probs must be float or list of float, but \
got '{type(prob)}'.")
raise ValueError(f'probs must be float or list of float, but \
got `{type(prob)}`.')
self.prob = prob
valid_directions = ['horizontal', 'vertical', 'diagonal']
@@ -484,8 +983,8 @@ class RandomFlip(BaseTransform):
assert mmcv.is_list_of(direction, str)
assert set(direction).issubset(set(valid_directions))
else:
raise ValueError(f"direction must be either str or list of str, \
but got '{type(direction)}'.")
raise ValueError(f'direction must be either str or list of str, \
but got `{type(direction)}`.')
self.direction = direction
if isinstance(prob, list):
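To make the three flip modes described above concrete, a hedged sketch (not part of this commit; probabilities and directions are arbitrary):

from mmcv.transforms import TRANSFORMS

# prob is float, direction is str: horizontal flip with probability 0.5
f1 = TRANSFORMS.build(dict(type='RandomFlip', prob=0.5, direction='horizontal'))
# prob is float, direction is list: 0.25 horizontal, 0.25 vertical
f2 = TRANSFORMS.build(dict(type='RandomFlip', prob=0.5, direction=['horizontal', 'vertical']))
# prob is list, direction is list: 0.3 horizontal, 0.5 vertical
f3 = TRANSFORMS.build(dict(type='RandomFlip', prob=[0.3, 0.5], direction=['horizontal', 'vertical']))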

tests/test_transforms/test_transforms_processing.py

@@ -6,7 +6,18 @@ import numpy as np
import pytest
import mmcv
from mmcv.transforms import Normalize, Pad, RandomFlip, RandomResize, Resize
from mmcv.transforms import (TRANSFORMS, Normalize, Pad, RandomFlip,
RandomResize, Resize)
try:
import torch
except ModuleNotFoundError:
torch = None
else:
import torchvision
from numpy.testing import assert_array_almost_equal, assert_array_equal
from PIL import Image
class TestNormalize:
@@ -223,6 +234,416 @@ class TestPad:
"pad_val={'img': 0, 'seg': 255}), padding_mode=edge)")
class TestCenterCrop:
@classmethod
def setup_class(cls):
img = mmcv.imread(
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
cls.original_img = copy.deepcopy(img)
seg = np.random.randint(0, 19, (300, 400)).astype(np.uint8)
cls.gt_semantic_map = copy.deepcopy(seg)
@staticmethod
def reset_results(results, original_img, gt_semantic_map):
results['img'] = copy.deepcopy(original_img)
results['gt_semantic_seg'] = copy.deepcopy(gt_semantic_map)
return results
@pytest.mark.skipif(
condition=torch is None, reason='No torch in current env')
def test_error(self):
# test assertion if size is smaller than 0
with pytest.raises(AssertionError):
transform = dict(type='CenterCrop', crop_size=-1)
TRANSFORMS.build(transform)
# test assertion if size is tuple but one value is smaller than 0
with pytest.raises(AssertionError):
transform = dict(type='CenterCrop', crop_size=(224, -1))
TRANSFORMS.build(transform)
# test assertion if size is tuple and len(size) < 2
with pytest.raises(AssertionError):
transform = dict(type='CenterCrop', crop_size=(224, ))
TRANSFORMS.build(transform)
# test assertion if size is tuple len(size) > 2
with pytest.raises(AssertionError):
transform = dict(type='CenterCrop', crop_size=(224, 224, 3))
TRANSFORMS.build(transform)
def test_repr(self):
# test repr
transform = dict(type='CenterCrop', crop_size=224)
center_crop_module = TRANSFORMS.build(transform)
assert isinstance(repr(center_crop_module), str)
def test_transform(self):
results = {}
self.reset_results(results, self.original_img, self.gt_semantic_map)
# test CenterCrop when size is int
transform = dict(type='CenterCrop', crop_size=224)
center_crop_module = TRANSFORMS.build(transform)
results = center_crop_module(results)
assert results['height'] == 224
assert results['width'] == 224
assert (results['img'] == self.original_img[38:262, 88:312, ...]).all()
assert (
results['gt_semantic_seg'] == self.gt_semantic_map[38:262,
88:312]).all()
# test CenterCrop when size is tuple
transform = dict(type='CenterCrop', crop_size=(224, 224))
center_crop_module = TRANSFORMS.build(transform)
results = self.reset_results(results, self.original_img,
self.gt_semantic_map)
results = center_crop_module(results)
assert results['height'] == 224
assert results['width'] == 224
assert (results['img'] == self.original_img[38:262, 88:312, ...]).all()
assert (
results['gt_semantic_seg'] == self.gt_semantic_map[38:262,
88:312]).all()
# test CenterCrop when crop_height != crop_width
transform = dict(type='CenterCrop', crop_size=(256, 224))
center_crop_module = TRANSFORMS.build(transform)
results = self.reset_results(results, self.original_img,
self.gt_semantic_map)
results = center_crop_module(results)
assert results['height'] == 256
assert results['width'] == 224
assert (results['img'] == self.original_img[22:278, 88:312, ...]).all()
assert (
results['gt_semantic_seg'] == self.gt_semantic_map[22:278,
88:312]).all()
# test CenterCrop when crop_size is equal to img.shape
img_height, img_width, _ = self.original_img.shape
transform = dict(type='CenterCrop', crop_size=(img_height, img_width))
center_crop_module = TRANSFORMS.build(transform)
results = self.reset_results(results, self.original_img,
self.gt_semantic_map)
results = center_crop_module(results)
assert results['height'] == 300
assert results['width'] == 400
assert (results['img'] == self.original_img).all()
assert (results['gt_semantic_seg'] == self.gt_semantic_map).all()
# test CenterCrop when crop_size is larger than img.shape
transform = dict(
type='CenterCrop', crop_size=(img_height * 2, img_width * 2))
center_crop_module = TRANSFORMS.build(transform)
results = self.reset_results(results, self.original_img,
self.gt_semantic_map)
results = center_crop_module(results)
assert results['height'] == 300
assert results['width'] == 400
assert (results['img'] == self.original_img).all()
assert (results['gt_semantic_seg'] == self.gt_semantic_map).all()
# test with padding
transform = dict(
type='CenterCrop',
crop_size=(img_height * 2, img_width // 2),
pad_mode='constant',
pad_val=12)
center_crop_module = TRANSFORMS.build(transform)
results = self.reset_results(results, self.original_img,
self.gt_semantic_map)
results = center_crop_module(results)
assert results['height'] == 600
assert results['width'] == 200
assert results['img'].shape[:2] == results['gt_semantic_seg'].shape
assert (results['img'][300:600, 100:300, ...] == 12).all()
assert (results['gt_semantic_seg'][300:600, 100:300] == 255).all()
transform = dict(
type='CenterCrop',
crop_size=(img_height * 2, img_width // 2),
pad_mode='constant',
pad_val=dict(img=13, seg=33))
center_crop_module = TRANSFORMS.build(transform)
results = self.reset_results(results, self.original_img,
self.gt_semantic_map)
results = center_crop_module(results)
assert results['height'] == 600
assert results['width'] == 200
assert (results['img'][300:600, 100:300, ...] == 13).all()
assert (results['gt_semantic_seg'][300:600, 100:300] == 33).all()
# test CenterCrop when crop_width is smaller than img_width
transform = dict(
type='CenterCrop', crop_size=(img_height, img_width // 2))
center_crop_module = TRANSFORMS.build(transform)
results = self.reset_results(results, self.original_img,
self.gt_semantic_map)
results = center_crop_module(results)
assert results['height'] == img_height
assert results['width'] == img_width // 2
assert (results['img'] == self.original_img[:, 100:300, ...]).all()
assert (
results['gt_semantic_seg'] == self.gt_semantic_map[:,
100:300]).all()
# test CenterCrop when crop_height is smaller than img_height
transform = dict(
type='CenterCrop', crop_size=(img_height // 2, img_width))
center_crop_module = TRANSFORMS.build(transform)
results = self.reset_results(results, self.original_img,
self.gt_semantic_map)
results = center_crop_module(results)
assert results['height'] == img_height // 2
assert results['width'] == img_width
assert (results['img'] == self.original_img[75:225, ...]).all()
assert (results['gt_semantic_seg'] == self.gt_semantic_map[75:225,
...]).all()
@pytest.mark.skipif(
condition=torch is None, reason='No torch in current env')
def test_torchvision_compare(self):
# compare results with torchvision
results = {}
transform = dict(type='CenterCrop', crop_size=224)
center_crop_module = TRANSFORMS.build(transform)
results = self.reset_results(results, self.original_img,
self.gt_semantic_map)
results = center_crop_module(results)
center_crop_module = torchvision.transforms.CenterCrop(size=224)
pil_img = Image.fromarray(self.original_img)
pil_seg = Image.fromarray(self.gt_semantic_map)
cropped_img = center_crop_module(pil_img)
cropped_img = np.array(cropped_img)
cropped_seg = center_crop_module(pil_seg)
cropped_seg = np.array(cropped_seg)
assert np.equal(results['img'], cropped_img).all()
assert np.equal(results['gt_semantic_seg'], cropped_seg).all()
class TestRandomGrayscale:
@classmethod
def setup_class(cls):
cls.img = np.random.rand(10, 10, 3).astype(np.float32)
def test_repr(self):
# test repr
transform = dict(
type='RandomGrayscale',
prob=1.,
channel_weights=(0.299, 0.587, 0.114),
keep_channel=True)
random_gray_scale_module = TRANSFORMS.build(transform)
assert isinstance(repr(random_gray_scale_module), str)
def test_error(self):
# test invalid argument
transform = dict(type='RandomGrayscale', prob=2)
with pytest.raises(AssertionError):
TRANSFORMS.build(transform)
def test_transform(self):
results = dict()
# test rgb2gray, return the grayscale image with prob = 1.
transform = dict(
type='RandomGrayscale',
prob=1.,
channel_weights=(0.299, 0.587, 0.114),
keep_channel=True)
random_gray_scale_module = TRANSFORMS.build(transform)
results['img'] = copy.deepcopy(self.img)
img = random_gray_scale_module(results)['img']
computed_gray = (
self.img[:, :, 0] * 0.299 + self.img[:, :, 1] * 0.587 +
self.img[:, :, 2] * 0.114)
for i in range(img.shape[2]):
assert_array_almost_equal(img[:, :, i], computed_gray, decimal=4)
assert img.shape == (10, 10, 3)
# test rgb2gray, return the original image with p=0.
transform = dict(type='RandomGrayscale', prob=0.)
random_gray_scale_module = TRANSFORMS.build(transform)
results['img'] = copy.deepcopy(self.img)
img = random_gray_scale_module(results)['img']
assert_array_equal(img, self.img)
assert img.shape == (10, 10, 3)
# test image with one channel
transform = dict(type='RandomGrayscale', prob=1.)
results['img'] = self.img[:, :, 0:1]
random_gray_scale_module = TRANSFORMS.build(transform)
img = random_gray_scale_module(results)['img']
assert_array_equal(img, self.img[:, :, 0:1])
assert img.shape == (10, 10, 1)
class TestMultiScaleFlipAug:
@classmethod
def setup_class(cls):
cls.img = mmcv.imread(
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
cls.original_img = copy.deepcopy(cls.img)
def test_error(self):
# test assertion if img_scale is None
with pytest.raises(AssertionError):
transform = dict(
type='MultiScaleFlipAug', img_scale=None, transforms=[])
TRANSFORMS.build(transform)
# test assertion if img_scale is not tuple or list of tuple
with pytest.raises(AssertionError):
transform = dict(
type='MultiScaleFlipAug', img_scale=[1333, 800], transforms=[])
TRANSFORMS.build(transform)
# test assertion if flip_direction is not str or list of str
with pytest.raises(AssertionError):
transform = dict(
type='MultiScaleFlipAug',
img_scale=[(1333, 800)],
flip_direction=1,
transforms=[])
TRANSFORMS.build(transform)
@pytest.mark.skipif(
condition=torch is None, reason='No torch in current env')
def test_multi_scale_flip_aug(self):
# test with empty transforms
transform = dict(
type='MultiScaleFlipAug',
transforms=[],
img_scale=[(1333, 800), (800, 600), (640, 480)],
flip=True,
flip_direction=['horizontal', 'vertical', 'diagonal'])
multi_scale_flip_aug_module = TRANSFORMS.build(transform)
results = dict()
results['img'] = copy.deepcopy(self.original_img)
results = multi_scale_flip_aug_module(results)
assert len(results['img']) == 12
# test with flip=False
transform = dict(
type='MultiScaleFlipAug',
transforms=[],
img_scale=[(1333, 800), (800, 600), (640, 480)],
flip=False,
flip_direction=['horizontal', 'vertical', 'diagonal'])
multi_scale_flip_aug_module = TRANSFORMS.build(transform)
results = dict()
results['img'] = copy.deepcopy(self.original_img)
results = multi_scale_flip_aug_module(results)
assert len(results['img']) == 3
# test with transforms
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True)
transforms_cfg = [
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
]
transform = dict(
type='MultiScaleFlipAug',
transforms=transforms_cfg,
img_scale=[(1333, 800), (800, 600), (640, 480)],
flip=True,
flip_direction=['horizontal', 'vertical', 'diagonal'])
multi_scale_flip_aug_module = TRANSFORMS.build(transform)
results = dict()
results['img'] = copy.deepcopy(self.original_img)
results = multi_scale_flip_aug_module(results)
assert len(results['img']) == 12
class TestRandomMultiscaleResize:
@classmethod
def setup_class(cls):
cls.img = mmcv.imread(
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
cls.original_img = copy.deepcopy(cls.img)
def reset_results(self, results):
results['img'] = copy.deepcopy(self.original_img)
results['gt_semantic_seg'] = copy.deepcopy(self.original_img)
def test_repr(self):
# test repr
transform = dict(
type='RandomMultiscaleResize', scales=[(1333, 800), (1333, 600)])
random_multiscale_resize = TRANSFORMS.build(transform)
assert isinstance(repr(random_multiscale_resize), str)
def test_error(self):
# test assertion if size is smaller than 0
with pytest.raises(AssertionError):
transform = dict(type='RandomMultiscaleResize', scales=[0.5, 1, 2])
TRANSFORMS.build(transform)
def test_random_multiscale_resize(self):
results = dict()
# test with one scale
transform = dict(type='RandomMultiscaleResize', scales=[(1333, 800)])
random_multiscale_resize = TRANSFORMS.build(transform)
self.reset_results(results)
results = random_multiscale_resize(results)
assert results['img'].shape == (800, 1333, 3)
# test with multi scales
_scale_choice = [(1333, 800), (1333, 600)]
transform = dict(type='RandomMultiscaleResize', scales=_scale_choice)
random_multiscale_resize = TRANSFORMS.build(transform)
self.reset_results(results)
results = random_multiscale_resize(results)
assert (results['img'].shape[1],
results['img'].shape[0]) in _scale_choice
# test keep_ratio
transform = dict(
type='RandomMultiscaleResize',
scales=[(900, 600)],
keep_ratio=True)
random_multiscale_resize = TRANSFORMS.build(transform)
self.reset_results(results)
_input_ratio = results['img'].shape[0] / results['img'].shape[1]
results = random_multiscale_resize(results)
_output_ratio = results['img'].shape[0] / results['img'].shape[1]
assert_array_almost_equal(_input_ratio, _output_ratio)
# test clip_object_border
gt_bboxes = [[200, 150, 600, 450]]
transform = dict(
type='RandomMultiscaleResize',
scales=[(200, 150)],
clip_object_border=True)
random_multiscale_resize = TRANSFORMS.build(transform)
self.reset_results(results)
results['gt_bboxes'] = np.array(gt_bboxes)
results = random_multiscale_resize(results)
assert results['img'].shape == (150, 200, 3)
assert np.equal(results['gt_bboxes'], np.array([[100, 75, 200,
150]])).all()
transform = dict(
type='RandomMultiscaleResize',
scales=[(200, 150)],
clip_object_border=False)
random_multiscale_resize = TRANSFORMS.build(transform)
self.reset_results(results)
results['gt_bboxes'] = np.array(gt_bboxes)
results = random_multiscale_resize(results)
assert results['img'].shape == (150, 200, 3)
assert np.equal(results['gt_bboxes'], np.array([[100, 75, 300,
225]])).all()
class TestRandomFlip:
def test_init(self):