mirror of https://github.com/open-mmlab/mmcv.git
[Feature] Add Part3 of data transform (#1735)
* update data transform part3
* update init
* rename flip funcs
* fix comments
* update comments
* fix lint
* Update mmcv/transforms/processing.py
* fix docs format
* fix comments
* add test pad_val and fix bugs in class Pad
* merge updated pad
* fix lint
* Update tests/test_transforms/test_transforms_processing.py

pull/2133/head
parent 5af6c12b81
commit 2619aa9c8e
mmcv/transforms/__init__.py

@@ -1,7 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import TRANSFORMS
from .loading import LoadAnnotation, LoadImageFromFile
from .processing import Normalize, Pad, RandomFlip, RandomResize, Resize
from .processing import (CenterCrop, MultiScaleFlipAug, Normalize, Pad,
                         RandomFlip, RandomGrayscale, RandomMultiscaleResize,
                         RandomResize, Resize)
from .wrappers import ApplyToMultiple, Compose, RandomChoice, Remap

try:

@@ -10,7 +12,8 @@ except ImportError:
    __all__ = [
        'TRANSFORMS', 'ApplyToMultiple', 'Compose', 'RandomChoice', 'Remap',
        'LoadImageFromFile', 'LoadAnnotation', 'Normalize', 'Resize', 'Pad',
        'RandomFlip', 'RandomResize'
        'RandomFlip', 'RandomMultiscaleResize', 'CenterCrop',
        'RandomGrayscale', 'MultiScaleFlipAug', 'RandomResize'
    ]
else:
    from .formatting import ImageToTensor, ToTensor, to_tensor

@@ -18,5 +21,7 @@ else:
    __all__ = [
        'TRANSFORMS', 'ApplyToMultiple', 'Compose', 'RandomChoice', 'Remap',
        'LoadImageFromFile', 'LoadAnnotation', 'Normalize', 'Resize', 'Pad',
        'ToTensor', 'to_tensor', 'ImageToTensor', 'RandomFlip', 'RandomResize'
        'ToTensor', 'to_tensor', 'ImageToTensor', 'RandomFlip',
        'RandomMultiscaleResize', 'CenterCrop', 'RandomGrayscale',
        'MultiScaleFlipAug', 'RandomResize'
    ]
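The new transforms are registered in the ``TRANSFORMS`` registry, so they can be built from config dicts as well as imported directly. A minimal, illustrative sketch (the config values are arbitrary):

    from mmcv.transforms import TRANSFORMS

    # build two of the transforms added in this PR from config dicts
    center_crop = TRANSFORMS.build(dict(type='CenterCrop', crop_size=224))
    random_flip = TRANSFORMS.build(dict(type='RandomFlip', prob=0.5))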
mmcv/transforms/formatting.py

@@ -16,9 +16,11 @@ def to_tensor(

    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
    :class:`Sequence`, :class:`int` and :class:`float`.

    Args:
        data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to
            be converted.

    Returns:
        torch.Tensor: the converted data.
    """
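A minimal, illustrative sketch of ``to_tensor`` for the supported input types (assumes torch is installed, since the function is only exported in that case):

    import numpy as np
    from mmcv.transforms import to_tensor

    t1 = to_tensor(np.zeros((2, 3), dtype=np.float32))  # keeps the (2, 3) shape
    t2 = to_tensor([1, 2, 3])                           # sequence of numbers
    t3 = to_tensor(1.0)                                 # single float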
mmcv/transforms/processing.py

@@ -1,5 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Iterable, List, Optional, Sequence, Tuple, Union
import random
import warnings
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union

import numpy as np


@@ -7,6 +9,9 @@ import mmcv
from mmcv.image.geometric import _scale_size
from .base import BaseTransform
from .builder import TRANSFORMS
from .wrappers import Compose

Number = Union[int, float]


@TRANSFORMS.register_module()
@@ -104,7 +109,7 @@ class Resize(BaseTransform):
            Defaults to None.
        keep_ratio (bool): Whether to keep the aspect ratio when resizing the
            image. Defaults to False.
        clip_object_border (bool, optional): Whether to clip the objects
        clip_object_border (bool): Whether to clip the objects
            outside the border of the image. In some datasets like MOT17, the
            gt bboxes are allowed to cross the border of images. Therefore, we
            don't need to clip the gt bboxes in these cases. Defaults to True.
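An illustrative sketch of what ``clip_object_border`` controls (the image and box sizes are made up): with clipping enabled, gt bboxes that would fall outside the resized image are cut at the border; with it disabled they may extend past it.

    import numpy as np
    from mmcv.transforms import Resize

    results = dict(
        img=np.zeros((300, 400, 3), dtype=np.uint8),
        gt_bboxes=np.array([[200., 150., 600., 450.]]))
    resize = Resize(scale=(200, 150), clip_object_border=True)
    print(resize(results)['gt_bboxes'])  # clipped to the 200x150 output image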
|
@@ -292,7 +297,7 @@ class Pad(BaseTransform):
            None.
        pad_to_square (bool): Whether to pad the image into a square.
            Currently only used for YOLOX. Defaults to False.
        pad_val (int or dict, optional): A dict for padding value.
        pad_val (int or dict): A dict for padding value.
            If ``type(pad_val) == int``, the value used to pad ``seg`` is 255.
            Defaults to ``dict(img=0, seg=255)``.
        padding_mode (str): Type of padding. Should be: constant, edge,
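An illustrative sketch of the two ``pad_val`` forms described above (the target size is made up):

    from mmcv.transforms import Pad

    # int: used for the image; the segmentation map is still padded with 255
    pad_a = Pad(size=(640, 640), pad_val=0)
    # dict: separate padding values for image and segmentation map
    pad_b = Pad(size=(640, 640), pad_val=dict(img=0, seg=255))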
|
@@ -374,9 +379,11 @@ class Pad(BaseTransform):
        ``results['pad_shape']``."""
        if results.get('gt_semantic_seg', None) is not None:
            pad_val = self.pad_val.get('seg', 255)
            if isinstance(pad_val, int) and results['gt_semantic_seg'].ndim == 3:
                pad_val = tuple(
                    [pad_val for _ in range(results['gt_semantic_seg'].shape[2])])
            if isinstance(pad_val,
                          int) and results['gt_semantic_seg'].ndim == 3:
                pad_val = tuple([
                    pad_val for _ in range(results['gt_semantic_seg'].shape[2])
                ])
            results['gt_semantic_seg'] = mmcv.impad(
                results['gt_semantic_seg'],
                shape=results['pad_shape'][:2],
@@ -407,59 +414,551 @@ class Pad(BaseTransform):


@TRANSFORMS.register_module()
class RandomFlip(BaseTransform):
    """Flip the image & bbox & keypoints & segmentation map.

    There are 3 flip modes:

    - ``prob`` is float, ``direction`` is string

      the image will be flipped on the given direction with probability
      of ``prob``. E.g., ``prob=0.5``, ``direction='horizontal'``,
      then image will be horizontally flipped with probability of 0.5.

    - ``prob`` is float, ``direction`` is list of string

      the image will be flipped on the given direction with probability of
      ``prob/len(direction)``. E.g., ``prob=0.5``,
      ``direction=['horizontal', 'vertical']``, then image will be
      horizontally flipped with probability of 0.25, vertically with
      probability of 0.25.

    - ``prob`` is list of float, ``direction`` is list of string

      given ``len(prob) == len(direction)``, the image will
      be ``direction[i]``-ly flipped with probability of ``prob[i]``.
      E.g., ``prob=[0.3, 0.5]``, ``direction=['horizontal',
      'vertical']``, then image will be horizontally flipped with
      probability of 0.3, vertically with probability of 0.5.
class CenterCrop(BaseTransform):
    """Crop the center of the image and segmentation masks. If the crop area
    exceeds the original image and ``pad_mode`` is not None, the original
    image will be padded before cropping.

    Required Keys:

    - img
    - gt_bboxes
    - gt_keypoints
    - gt_semantic_seg (optional)

    Modified Keys:

    - img
    - gt_bboxes
    - gt_keypoints
    - height
    - width
    - gt_semantic_seg (optional)

    Added Key:

    - pad_shape

    Args:
        crop_size (Union[int, Tuple[int, int]]): Expected size after cropping
            with the format of (h, w). If set to an integer, then cropping
            height and width are equal to this integer.
        pad_val (Union[Number, Dict[str, Number]]): A dict for
            padding value. To specify how to set this argument, please see
            the docstring of class ``Pad``. Defaults to
            ``dict(img=0, seg=255)``.
        pad_mode (str, optional): Type of padding. Should be: 'constant',
            'edge', 'reflect' or 'symmetric'. For details, please see the
            docstring of class ``Pad``. Defaults to 'constant'.
        pad_cfg (str): Base config for padding. Defaults to
            ``dict(type='Pad')``.
    """

    def __init__(
        self,
        crop_size: Union[int, Tuple[int, int]],
        pad_val: Union[Number, Dict[str, Number]] = dict(img=0, seg=255),
        pad_mode: Optional[str] = None,
        pad_cfg: dict = dict(type='Pad')
    ) -> None:  # flake8: noqa
        super().__init__()
        assert isinstance(crop_size, int) or (
            isinstance(crop_size, tuple) and len(crop_size) == 2
        ), 'The expected crop_size is an integer, or a tuple containing two ' \
           'integers'

        if isinstance(crop_size, int):
            crop_size = (crop_size, crop_size)
        assert crop_size[0] > 0 and crop_size[1] > 0
        self.crop_size = crop_size
        self.pad_val = pad_val
        self.pad_mode = pad_mode
        self.pad_cfg = pad_cfg

    def _crop_img(self, results: dict, bboxes: np.ndarray) -> None:
        """Crop image.

        Args:
            results (dict): Result dict contains the data to transform.
            bboxes (np.ndarray): Shape (4, ), location of cropped bboxes.
        """
        if results.get('img', None) is not None:
            img = mmcv.imcrop(results['img'], bboxes=bboxes)
            img_shape = img.shape
            results['img'] = img
            results['height'] = img_shape[0]
            results['width'] = img_shape[1]
            results['pad_shape'] = img_shape

    def _crop_seg_map(self, results: dict, bboxes: np.ndarray) -> None:
        """Crop semantic segmentation map.

        Args:
            results (dict): Result dict contains the data to transform.
            bboxes (np.ndarray): Shape (4, ), location of cropped bboxes.
        """
        if results.get('gt_semantic_seg', None) is not None:
            img = mmcv.imcrop(results['gt_semantic_seg'], bboxes=bboxes)
            results['gt_semantic_seg'] = img

    def transform(self, results: dict) -> dict:
        """Apply center crop on results.

        Args:
            results (dict): Result dict contains the data to transform.

        Returns:
            dict: Results with CenterCropped image and semantic segmentation
            map.
        """
        crop_height, crop_width = self.crop_size[0], self.crop_size[1]

        assert 'img' in results, '`img` is not found in results'
        img = results['img']
        # img.shape has length 2 for grayscale, length 3 for color
        img_height, img_width = img.shape[:2]

        if crop_height > img_height or crop_width > img_width:
            if self.pad_mode is not None:
                # pad the area
                img_height = max(img_height, crop_height)
                img_width = max(img_width, crop_width)
                pad_size = (img_width, img_height)
                _pad_cfg = self.pad_cfg.copy()
                _pad_cfg.update(
                    dict(
                        size=pad_size,
                        pad_val=self.pad_val,
                        padding_mode=self.pad_mode))
                pad_transform = TRANSFORMS.build(_pad_cfg)
                results = pad_transform(results)
            else:
                crop_height = min(crop_height, img_height)
                crop_width = min(crop_width, img_width)

        y1 = max(0, int(round((img_height - crop_height) / 2.)))
        x1 = max(0, int(round((img_width - crop_width) / 2.)))
        y2 = min(img_height, y1 + crop_height) - 1
        x2 = min(img_width, x1 + crop_width) - 1
        bboxes = np.array([x1, y1, x2, y2])

        # crop the image
        self._crop_img(results, bboxes)
        # crop the gt_semantic_seg
        self._crop_seg_map(results, bboxes)
        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f', crop_size = {self.crop_size}'
        repr_str += f', pad_val = {self.pad_val}'
        repr_str += f', pad_mode = {self.pad_mode}'
        return repr_str

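A minimal, illustrative usage sketch of ``CenterCrop`` (mirroring the tests added below; the random data is arbitrary):

    import numpy as np
    from mmcv.transforms import CenterCrop

    results = dict(
        img=np.random.randint(0, 255, (300, 400, 3), dtype=np.uint8),
        gt_semantic_seg=np.random.randint(0, 19, (300, 400), dtype=np.uint8))
    center_crop = CenterCrop(crop_size=224)
    results = center_crop(results)
    assert results['height'] == 224 and results['width'] == 224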
@TRANSFORMS.register_module()
class RandomGrayscale(BaseTransform):
    """Randomly convert image to grayscale with a probability.

    Required Key:

    - img

    Modified Key:

    - img

    Added Keys:

    - grayscale
    - grayscale_weights

    Args:
        prob (float): Probability that image should be converted to
            grayscale. Defaults to 0.1.
        keep_channel (bool): Whether to keep the channel number the same as
            the input. Defaults to False.
        channel_weights (tuple): Channel weights to compute the gray
            image. Defaults to (1., 1., 1.).
        color_format (str): Color format set to be any of 'bgr',
            'rgb', 'hsv'. Note: an 'hsv' image will be converted to 'bgr'
            format no matter whether it is grayscaled. Defaults to 'bgr'.
    """

    def __init__(self,
                 prob: float = 0.1,
                 keep_channel: bool = False,
                 channel_weights: Sequence[float] = (1., 1., 1.),
                 color_format: str = 'bgr') -> None:
        super().__init__()
        assert 0. <= prob <= 1., ('The range of ``prob`` value is [0., 1.],' +
                                  f' but got {prob} instead')
        self.prob = prob
        self.keep_channel = keep_channel
        self.channel_weights = channel_weights
        assert color_format in ['bgr', 'rgb', 'hsv']
        self.color_format = color_format

    def transform(self, results: dict) -> dict:
        """Apply random grayscale on results.

        Args:
            results (dict): Result dict contains the data to transform.

        Returns:
            dict: Results with grayscale image.
        """
        img = results['img']
        # convert hsv to bgr
        if self.color_format == 'hsv':
            img = mmcv.hsv2bgr(img)
        img = img[..., None] if img.ndim == 2 else img
        num_output_channels = img.shape[2]
        if random.random() < self.prob:
            if num_output_channels > 1:
                assert num_output_channels == len(
                    self.channel_weights
                ), 'The length of ``channel_weights`` is supposed to be ' \
                   f'num_output_channels, but got {len(self.channel_weights)}' \
                   ' instead.'
                normalized_weights = (
                    np.array(self.channel_weights) / sum(self.channel_weights))
                img = (normalized_weights * img).sum(axis=2)
                if self.keep_channel:
                    img = img[:, :, None]
                    results['img'] = np.dstack(
                        [img for _ in range(num_output_channels)])
                else:
                    results['img'] = img
                return results
        results['img'] = img
        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f', prob = {self.prob}'
        repr_str += f', keep_channel = {self.keep_channel}'
        repr_str += f', channel_weights = {self.channel_weights}'
        repr_str += f', color_format = {self.color_format}'
        return repr_str

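A minimal, illustrative usage sketch of ``RandomGrayscale`` (the weights are the ITU-R 601 coefficients also used in the tests below):

    import numpy as np
    from mmcv.transforms import RandomGrayscale

    gray = RandomGrayscale(
        prob=1., keep_channel=True, channel_weights=(0.299, 0.587, 0.114))
    results = dict(img=np.random.rand(10, 10, 3).astype(np.float32))
    out = gray(results)['img']  # still (10, 10, 3); all three channels equal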
@TRANSFORMS.register_module()
class MultiScaleFlipAug(BaseTransform):
    """Test-time augmentation with multiple scales and flipping.

    An example configuration is as follows:

    .. code-block::

        dict(
            type='MultiScaleFlipAug',
            img_scale=[(1333, 400), (1333, 800)],
            flip=True,
            transforms=[
                dict(type='Normalize', **img_norm_cfg),
                dict(type='Pad', size_divisor=1),
                dict(type='ImageToTensor', keys=['img']),
                dict(type='Collect', keys=['img'])
            ])

    ``results`` will be resized using all the sizes in ``img_scale``.
    If ``flip`` is True, then flipped results will also be added into the
    output list.

    For the above configuration, there are four combinations of resize
    and flip:

    - Resize to (1333, 400) + no flip
    - Resize to (1333, 400) + flip
    - Resize to (1333, 800) + no flip
    - Resize to (1333, 800) + flip

    The four results are then transformed with ``transforms`` argument.
    After that, results are wrapped into lists of the same length as follows:

    .. code-block::

        dict(
            img=[...],
            img_shape=[...],
            scale=[(1333, 400), (1333, 400), (1333, 800), (1333, 800)],
            flip=[False, True, False, True],
            ...
        )

    Required Keys:

    - Depending on the requirements of the ``transforms`` parameter.

    Modified Keys:

    - All output keys of each transform.

    Args:
        transforms (list[dict]): Transforms to be applied to each resized
            and flipped data.
        img_scale (tuple | list[tuple] | None): Images scales for resizing.
        flip (bool): Whether to apply flip augmentation. Defaults to False.
        flip_direction (str | list[str]): Flip augmentation directions,
            options are "horizontal", "vertical" and "diagonal". If
            flip_direction is a list, multiple flip augmentations will be
            applied. It has no effect when flip == False. Defaults to
            "horizontal".
        resize_cfg (dict): Base config for resizing. Defaults to
            ``dict(type='Resize', keep_ratio=True)``.
        flip_cfg (dict): Base config for flipping. Defaults to
            ``dict(type='RandomFlip')``.
    """

    def __init__(
        self,
        transforms: List[dict],
        img_scale: Optional[Union[Tuple, List[Tuple]]] = None,
        flip: bool = False,
        flip_direction: Union[str, List[str]] = 'horizontal',
        resize_cfg: dict = dict(type='Resize', keep_ratio=True),
        flip_cfg: dict = dict(type='RandomFlip')
    ) -> None:
        super().__init__()
        self.transforms = Compose(transforms)  # type: ignore
        assert img_scale is not None
        self.img_scale = img_scale if isinstance(img_scale,
                                                 list) else [img_scale]
        self.scale_key = 'scale'
        assert mmcv.is_list_of(self.img_scale, tuple)

        self.flip = flip
        self.flip_direction = flip_direction if isinstance(
            flip_direction, list) else [flip_direction]
        assert mmcv.is_list_of(self.flip_direction, str)
        if not self.flip and self.flip_direction != ['horizontal']:
            warnings.warn(
                'flip_direction has no effect when flip is set to False')
        self.resize_cfg = resize_cfg
        self.flip_cfg = flip_cfg

    def transform(self, results: dict) -> dict:
        """Apply test time augment transforms on results.

        Args:
            results (dict): Result dict contains the data to transform.

        Returns:
            dict: The augmented data, where each value is wrapped
            into a list.
        """

        aug_data = []
        flip_args = [(False, '')]
        if self.flip:
            flip_args += [(True, direction)
                          for direction in self.flip_direction]
        for scale in self.img_scale:
            for flip, direction in flip_args:
                _resize_cfg = self.resize_cfg.copy()
                _resize_cfg.update(scale=scale)
                _resize_flip = [_resize_cfg]

                if flip:
                    _flip_cfg = self.flip_cfg.copy()
                    _flip_cfg.update(prob=1.0, direction=direction)
                    _resize_flip.append(_flip_cfg)
                else:
                    results['flip'] = False
                    results['flip_direction'] = None

                resize_flip = Compose(_resize_flip)
                _results = results.copy()
                _results = resize_flip(_results)
                data = self.transforms(_results)
                aug_data.append(data)
        # list of dict to dict of list
        aug_data_dict = {key: [] for key in aug_data[0]}
        for data in aug_data:
            for key, val in data.items():
                aug_data_dict[key].append(val)
        return aug_data_dict

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f', transforms={self.transforms}'
        repr_str += f', img_scale={self.img_scale}'
        repr_str += f', flip={self.flip}'
        repr_str += f', flip_direction={self.flip_direction}'
        return repr_str

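A minimal, illustrative usage sketch of ``MultiScaleFlipAug`` (taken from the tests added below; the input image is arbitrary):

    import numpy as np
    from mmcv.transforms import TRANSFORMS

    tta = TRANSFORMS.build(
        dict(
            type='MultiScaleFlipAug',
            transforms=[],
            img_scale=[(1333, 800), (800, 600), (640, 480)],
            flip=True,
            flip_direction=['horizontal', 'vertical', 'diagonal']))
    results = tta(dict(img=np.zeros((300, 400, 3), dtype=np.uint8)))
    # 3 scales x (1 unflipped + 3 flip directions) = 12 augmented copies
    assert len(results['img']) == 12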
@TRANSFORMS.register_module()
class RandomMultiscaleResize(BaseTransform):
    """Resize images & bbox & mask from a list of multiple scales.

    This transform resizes the input image to some scale. Bboxes and masks
    are then resized with the same scale factor. Resize scale will be
    randomly selected from ``scales``.

    How to choose the target scale to resize the image will follow the rules
    below:

    - if ``scales`` is a list of tuple, the target scale is sampled from the
      list uniformly.
    - if ``scales`` is a tuple, the target scale will be set to the tuple.

    Required Keys:

    - img
    - gt_bboxes (optional)
    - gt_semantic_seg (optional)
    - gt_keypoints (optional)

    Modified Keys:

    - img
    - height
    - width
    - gt_bboxes (optional)
    - gt_semantic_seg (optional)
    - gt_keypoints (optional)

    Added Keys:

    - scale
    - scale_factor
    - scale_idx
    - keep_ratio

    Args:
        scales (Union[list, Tuple]): Images scales for resizing.
        keep_ratio (bool): Whether to keep the aspect ratio when
            resizing the image. Defaults to False.
        clip_object_border (bool): Whether to clip the objects outside
            the border of the image. Defaults to True.
        backend (str): Image resize backend, choices are 'cv2' and
            'pillow'. These two backends generate slightly different results.
            Defaults to 'cv2'.
        interpolation (str): The mode of interpolation, support
            "bilinear", "bicubic", "nearest". Defaults to "bilinear".
    """

    def __init__(
        self,
        scales: Union[list, Tuple],
        keep_ratio: bool = False,
        clip_object_border: bool = True,
        backend: str = 'cv2',
        interpolation: str = 'bilinear',
        resize_cfg: dict = dict(type='Resize')
    ) -> None:
        super().__init__()
        if isinstance(scales, list):
            self.scales = scales
        else:
            self.scales = [scales]
        assert mmcv.is_list_of(self.scales, tuple)
        self.keep_ratio = keep_ratio
        self.clip_object_border = clip_object_border
        self.backend = backend
        self.interpolation = interpolation

        self.resize_cfg = resize_cfg

    @staticmethod
    def random_select(scales: List[Tuple]) -> Tuple[Number, int]:
        """Randomly select an img_scale from given candidates.

        Args:
            scales (list[tuple]): Images scales for selection.

        Returns:
            (tuple, int): Returns a tuple ``(img_scale, scale_idx)``,
            where ``img_scale`` is the selected image scale and
            ``scale_idx`` is the selected index in the given candidates.
        """

        assert mmcv.is_list_of(scales, tuple)
        scale_idx = np.random.randint(len(scales))
        scale = scales[scale_idx]
        return scale, scale_idx

    def transform(self, results: dict) -> dict:
        """Apply resize transforms on results from a list of scales.

        Args:
            results (dict): Result dict contains the data to transform.

        Returns:
            dict: Resized results, 'img', 'gt_bboxes', 'gt_semantic_seg',
            'gt_keypoints', 'scale', 'scale_factor', 'height', 'width',
            and 'keep_ratio' keys are updated in result dict.
        """

        target_scale, scale_idx = self.random_select(self.scales)
        _resize_cfg = self.resize_cfg.copy()
        _resize_cfg.update(
            dict(
                scale=target_scale,
                keep_ratio=self.keep_ratio,
                clip_object_border=self.clip_object_border,
                backend=self.backend,
                interpolation=self.interpolation))
        resize_transform = TRANSFORMS.build(_resize_cfg)
        results = resize_transform(results)
        results['scale_idx'] = scale_idx
        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f', scales={self.scales}'
        repr_str += f', keep_ratio={self.keep_ratio}'
        repr_str += f', clip_object_border={self.clip_object_border}'
        repr_str += f', backend={self.backend}'
        repr_str += f', interpolation={self.interpolation}'
        return repr_str

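A minimal, illustrative usage sketch of ``RandomMultiscaleResize`` (scales are (w, h) tuples, mirroring the tests below):

    import numpy as np
    from mmcv.transforms import RandomMultiscaleResize

    resize = RandomMultiscaleResize(scales=[(1333, 800), (1333, 600)])
    results = resize(dict(img=np.zeros((300, 400, 3), dtype=np.uint8)))
    scale_wh = (results['img'].shape[1], results['img'].shape[0])
    assert scale_wh in [(1333, 800), (1333, 600)]
    print(results['scale_idx'])  # index of the sampled scale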
@TRANSFORMS.register_module()
class RandomFlip(BaseTransform):
    """Flip the image & bbox & keypoints & segmentation map. Added or Updated
    keys: flip, flip_direction, img, gt_bboxes, gt_semantic_seg, and
    gt_keypoints. There are 3 flip modes:

    - ``prob`` is float, ``direction`` is string: the image will be
      ``direction``-ly flipped with probability of ``prob``.
      E.g., ``prob=0.5``, ``direction='horizontal'``,
      then image will be horizontally flipped with probability of 0.5.
    - ``prob`` is float, ``direction`` is list of string: the image will
      be ``direction[i]``-ly flipped with probability of
      ``prob/len(direction)``.
      E.g., ``prob=0.5``, ``direction=['horizontal', 'vertical']``,
      then image will be horizontally flipped with probability of 0.25,
      vertically with probability of 0.25.
    - ``prob`` is list of float, ``direction`` is list of string:
      given ``len(prob) == len(direction)``, the image will
      be ``direction[i]``-ly flipped with probability of ``prob[i]``.
      E.g., ``prob=[0.3, 0.5]``, ``direction=['horizontal',
      'vertical']``, then image will be horizontally flipped with
      probability of 0.3, vertically with probability of 0.5.

    Required Keys:
    - img
    - gt_bboxes (optional)
    - gt_semantic_seg (optional)
    - gt_keypoints (optional)

    Modified Keys:
    - img
    - gt_bboxes (optional)
    - gt_semantic_seg (optional)
    - gt_keypoints (optional)

    Added Keys:
    - flip
    - flip_direction

    Args:
        prob (float | list[float], optional): The flipping probability.
            Defaults to None.
        direction (str | list[str]): The flipping direction. Options are
            'horizontal', 'vertical' and 'diagonal'.
            If input is a list, the length must equal ``prob``. Each
            element in ``prob`` indicates the flip probability of
            corresponding direction. Defaults to horizontal.
            corresponding direction. Defaults to 'horizontal'.
    """

    def __init__(
@@ -473,8 +972,8 @@ class RandomFlip(BaseTransform):
        elif isinstance(prob, float):
            assert 0 <= prob <= 1
        else:
            raise ValueError(f"probs must be float or list of float, but \
                              got '{type(prob)}'.")
            raise ValueError(f'probs must be float or list of float, but \
                              got `{type(prob)}`.')
        self.prob = prob

        valid_directions = ['horizontal', 'vertical', 'diagonal']
@@ -484,8 +983,8 @@ class RandomFlip(BaseTransform):
            assert mmcv.is_list_of(direction, str)
            assert set(direction).issubset(set(valid_directions))
        else:
            raise ValueError(f"direction must be either str or list of str, \
                               but got '{type(direction)}'.")
            raise ValueError(f'direction must be either str or list of str, \
                               but got `{type(direction)}`.')
        self.direction = direction

        if isinstance(prob, list):
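An illustrative sketch of the three flip modes described in the ``RandomFlip`` docstring above (the probabilities are arbitrary):

    from mmcv.transforms import RandomFlip

    # one direction, flipped with probability 0.5
    flip_a = RandomFlip(prob=0.5, direction='horizontal')
    # one probability shared across directions: 0.25 horizontal, 0.25 vertical
    flip_b = RandomFlip(prob=0.5, direction=['horizontal', 'vertical'])
    # one probability per direction
    flip_c = RandomFlip(prob=[0.3, 0.5], direction=['horizontal', 'vertical'])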
tests/test_transforms/test_transforms_processing.py

@@ -6,7 +6,18 @@ import numpy as np
import pytest

import mmcv
from mmcv.transforms import Normalize, Pad, RandomFlip, RandomResize, Resize
from mmcv.transforms import (TRANSFORMS, Normalize, Pad, RandomFlip,
                             RandomResize, Resize)

try:
    import torch
except ModuleNotFoundError:
    torch = None
else:
    import torchvision

from numpy.testing import assert_array_almost_equal, assert_array_equal
from PIL import Image


class TestNormalize:

@@ -223,6 +234,416 @@ class TestPad:
            "pad_val={'img': 0, 'seg': 255}), padding_mode=edge)")


class TestCenterCrop:

    @classmethod
    def setup_class(cls):
        img = mmcv.imread(
            osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
        cls.original_img = copy.deepcopy(img)
        seg = np.random.randint(0, 19, (300, 400)).astype(np.uint8)
        cls.gt_semantic_map = copy.deepcopy(seg)

    @staticmethod
    def reset_results(results, original_img, gt_semantic_map):
        results['img'] = copy.deepcopy(original_img)
        results['gt_semantic_seg'] = copy.deepcopy(gt_semantic_map)
        return results

    @pytest.mark.skipif(
        condition=torch is None, reason='No torch in current env')
    def test_error(self):
        # test assertion if size is smaller than 0
        with pytest.raises(AssertionError):
            transform = dict(type='CenterCrop', crop_size=-1)
            TRANSFORMS.build(transform)

        # test assertion if size is tuple but one value is smaller than 0
        with pytest.raises(AssertionError):
            transform = dict(type='CenterCrop', crop_size=(224, -1))
            TRANSFORMS.build(transform)

        # test assertion if size is tuple and len(size) < 2
        with pytest.raises(AssertionError):
            transform = dict(type='CenterCrop', crop_size=(224, ))
            TRANSFORMS.build(transform)

        # test assertion if size is tuple len(size) > 2
        with pytest.raises(AssertionError):
            transform = dict(type='CenterCrop', crop_size=(224, 224, 3))
            TRANSFORMS.build(transform)

    def test_repr(self):
        # test repr
        transform = dict(type='CenterCrop', crop_size=224)
        center_crop_module = TRANSFORMS.build(transform)
        assert isinstance(repr(center_crop_module), str)

    def test_transform(self):
        results = {}
        self.reset_results(results, self.original_img, self.gt_semantic_map)

        # test CenterCrop when size is int
        transform = dict(type='CenterCrop', crop_size=224)
        center_crop_module = TRANSFORMS.build(transform)
        results = center_crop_module(results)
        assert results['height'] == 224
        assert results['width'] == 224
        assert (results['img'] == self.original_img[38:262, 88:312, ...]).all()
        assert (
            results['gt_semantic_seg'] == self.gt_semantic_map[38:262,
                                                               88:312]).all()

        # test CenterCrop when size is tuple
        transform = dict(type='CenterCrop', crop_size=(224, 224))
        center_crop_module = TRANSFORMS.build(transform)
        results = self.reset_results(results, self.original_img,
                                     self.gt_semantic_map)
        results = center_crop_module(results)
        assert results['height'] == 224
        assert results['width'] == 224
        assert (results['img'] == self.original_img[38:262, 88:312, ...]).all()
        assert (
            results['gt_semantic_seg'] == self.gt_semantic_map[38:262,
                                                               88:312]).all()

        # test CenterCrop when crop_height != crop_width
        transform = dict(type='CenterCrop', crop_size=(256, 224))
        center_crop_module = TRANSFORMS.build(transform)
        results = self.reset_results(results, self.original_img,
                                     self.gt_semantic_map)
        results = center_crop_module(results)
        assert results['height'] == 256
        assert results['width'] == 224
        assert (results['img'] == self.original_img[22:278, 88:312, ...]).all()
        assert (
            results['gt_semantic_seg'] == self.gt_semantic_map[22:278,
                                                               88:312]).all()

        # test CenterCrop when crop_size is equal to img.shape
        img_height, img_width, _ = self.original_img.shape
        transform = dict(type='CenterCrop', crop_size=(img_height, img_width))
        center_crop_module = TRANSFORMS.build(transform)
        results = self.reset_results(results, self.original_img,
                                     self.gt_semantic_map)
        results = center_crop_module(results)
        assert results['height'] == 300
        assert results['width'] == 400
        assert (results['img'] == self.original_img).all()
        assert (results['gt_semantic_seg'] == self.gt_semantic_map).all()

        # test CenterCrop when crop_size is larger than img.shape
        transform = dict(
            type='CenterCrop', crop_size=(img_height * 2, img_width * 2))
        center_crop_module = TRANSFORMS.build(transform)
        results = self.reset_results(results, self.original_img,
                                     self.gt_semantic_map)
        results = center_crop_module(results)
        assert results['height'] == 300
        assert results['width'] == 400
        assert (results['img'] == self.original_img).all()
        assert (results['gt_semantic_seg'] == self.gt_semantic_map).all()

        # test with padding
        transform = dict(
            type='CenterCrop',
            crop_size=(img_height * 2, img_width // 2),
            pad_mode='constant',
            pad_val=12)
        center_crop_module = TRANSFORMS.build(transform)
        results = self.reset_results(results, self.original_img,
                                     self.gt_semantic_map)
        results = center_crop_module(results)
        assert results['height'] == 600
        assert results['width'] == 200
        assert results['img'].shape[:2] == results['gt_semantic_seg'].shape
        assert (results['img'][300:600, 100:300, ...] == 12).all()
        assert (results['gt_semantic_seg'][300:600, 100:300] == 255).all()

        transform = dict(
            type='CenterCrop',
            crop_size=(img_height * 2, img_width // 2),
            pad_mode='constant',
            pad_val=dict(img=13, seg=33))
        center_crop_module = TRANSFORMS.build(transform)
        results = self.reset_results(results, self.original_img,
                                     self.gt_semantic_map)
        results = center_crop_module(results)
        assert results['height'] == 600
        assert results['width'] == 200
        assert (results['img'][300:600, 100:300, ...] == 13).all()
        assert (results['gt_semantic_seg'][300:600, 100:300] == 33).all()

        # test CenterCrop when crop_width is smaller than img_width
        transform = dict(
            type='CenterCrop', crop_size=(img_height, img_width // 2))
        center_crop_module = TRANSFORMS.build(transform)
        results = self.reset_results(results, self.original_img,
                                     self.gt_semantic_map)
        results = center_crop_module(results)
        assert results['height'] == img_height
        assert results['width'] == img_width // 2
        assert (results['img'] == self.original_img[:, 100:300, ...]).all()
        assert (
            results['gt_semantic_seg'] == self.gt_semantic_map[:,
                                                               100:300]).all()

        # test CenterCrop when crop_height is smaller than img_height
        transform = dict(
            type='CenterCrop', crop_size=(img_height // 2, img_width))
        center_crop_module = TRANSFORMS.build(transform)
        results = self.reset_results(results, self.original_img,
                                     self.gt_semantic_map)
        results = center_crop_module(results)
        assert results['height'] == img_height // 2
        assert results['width'] == img_width
        assert (results['img'] == self.original_img[75:225, ...]).all()
        assert (results['gt_semantic_seg'] == self.gt_semantic_map[75:225,
                                                                   ...]).all()

    @pytest.mark.skipif(
        condition=torch is None, reason='No torch in current env')
    def test_torchvision_compare(self):
        # compare results with torchvision
        results = {}
        transform = dict(type='CenterCrop', crop_size=224)
        center_crop_module = TRANSFORMS.build(transform)
        results = self.reset_results(results, self.original_img,
                                     self.gt_semantic_map)
        results = center_crop_module(results)
        center_crop_module = torchvision.transforms.CenterCrop(size=224)
        pil_img = Image.fromarray(self.original_img)
        pil_seg = Image.fromarray(self.gt_semantic_map)
        cropped_img = center_crop_module(pil_img)
        cropped_img = np.array(cropped_img)
        cropped_seg = center_crop_module(pil_seg)
        cropped_seg = np.array(cropped_seg)
        assert np.equal(results['img'], cropped_img).all()
        assert np.equal(results['gt_semantic_seg'], cropped_seg).all()

class TestRandomGrayscale:

    @classmethod
    def setup_class(cls):
        cls.img = np.random.rand(10, 10, 3).astype(np.float32)

    def test_repr(self):
        # test repr
        transform = dict(
            type='RandomGrayscale',
            prob=1.,
            channel_weights=(0.299, 0.587, 0.114),
            keep_channel=True)
        random_gray_scale_module = TRANSFORMS.build(transform)
        assert isinstance(repr(random_gray_scale_module), str)

    def test_error(self):
        # test invalid argument
        transform = dict(type='RandomGrayscale', prob=2)
        with pytest.raises(AssertionError):
            TRANSFORMS.build(transform)

    def test_transform(self):
        results = dict()
        # test rgb2gray, return the grayscale image with prob = 1.
        transform = dict(
            type='RandomGrayscale',
            prob=1.,
            channel_weights=(0.299, 0.587, 0.114),
            keep_channel=True)

        random_gray_scale_module = TRANSFORMS.build(transform)
        results['img'] = copy.deepcopy(self.img)
        img = random_gray_scale_module(results)['img']
        computed_gray = (
            self.img[:, :, 0] * 0.299 + self.img[:, :, 1] * 0.587 +
            self.img[:, :, 2] * 0.114)
        for i in range(img.shape[2]):
            assert_array_almost_equal(img[:, :, i], computed_gray, decimal=4)
        assert img.shape == (10, 10, 3)

        # test rgb2gray, return the original image with p=0.
        transform = dict(type='RandomGrayscale', prob=0.)
        random_gray_scale_module = TRANSFORMS.build(transform)
        results['img'] = copy.deepcopy(self.img)
        img = random_gray_scale_module(results)['img']
        assert_array_equal(img, self.img)
        assert img.shape == (10, 10, 3)

        # test image with one channel
        transform = dict(type='RandomGrayscale', prob=1.)
        results['img'] = self.img[:, :, 0:1]
        random_gray_scale_module = TRANSFORMS.build(transform)
        img = random_gray_scale_module(results)['img']
        assert_array_equal(img, self.img[:, :, 0:1])
        assert img.shape == (10, 10, 1)

class TestMultiScaleFlipAug:

    @classmethod
    def setup_class(cls):
        cls.img = mmcv.imread(
            osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
        cls.original_img = copy.deepcopy(cls.img)

    def test_error(self):
        # test assertion if img_scale is None
        with pytest.raises(AssertionError):
            transform = dict(
                type='MultiScaleFlipAug', img_scale=None, transforms=[])
            TRANSFORMS.build(transform)

        # test assertion if img_scale is not tuple or list of tuple
        with pytest.raises(AssertionError):
            transform = dict(
                type='MultiScaleFlipAug', img_scale=[1333, 800], transforms=[])
            TRANSFORMS.build(transform)

        # test assertion if flip_direction is not str or list of str
        with pytest.raises(AssertionError):
            transform = dict(
                type='MultiScaleFlipAug',
                img_scale=[(1333, 800)],
                flip_direction=1,
                transforms=[])
            TRANSFORMS.build(transform)

    @pytest.mark.skipif(
        condition=torch is None, reason='No torch in current env')
    def test_multi_scale_flip_aug(self):
        # test with empty transforms
        transform = dict(
            type='MultiScaleFlipAug',
            transforms=[],
            img_scale=[(1333, 800), (800, 600), (640, 480)],
            flip=True,
            flip_direction=['horizontal', 'vertical', 'diagonal'])
        multi_scale_flip_aug_module = TRANSFORMS.build(transform)
        results = dict()
        results['img'] = copy.deepcopy(self.original_img)
        results = multi_scale_flip_aug_module(results)
        assert len(results['img']) == 12

        # test with flip=False
        transform = dict(
            type='MultiScaleFlipAug',
            transforms=[],
            img_scale=[(1333, 800), (800, 600), (640, 480)],
            flip=False,
            flip_direction=['horizontal', 'vertical', 'diagonal'])
        multi_scale_flip_aug_module = TRANSFORMS.build(transform)
        results = dict()
        results['img'] = copy.deepcopy(self.original_img)
        results = multi_scale_flip_aug_module(results)
        assert len(results['img']) == 3

        # test with transforms
        img_norm_cfg = dict(
            mean=[123.675, 116.28, 103.53],
            std=[58.395, 57.12, 57.375],
            to_rgb=True)
        transforms_cfg = [
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
        ]
        transform = dict(
            type='MultiScaleFlipAug',
            transforms=transforms_cfg,
            img_scale=[(1333, 800), (800, 600), (640, 480)],
            flip=True,
            flip_direction=['horizontal', 'vertical', 'diagonal'])
        multi_scale_flip_aug_module = TRANSFORMS.build(transform)
        results = dict()
        results['img'] = copy.deepcopy(self.original_img)
        results = multi_scale_flip_aug_module(results)
        assert len(results['img']) == 12

class TestRandomMultiscaleResize:

    @classmethod
    def setup_class(cls):
        cls.img = mmcv.imread(
            osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
        cls.original_img = copy.deepcopy(cls.img)

    def reset_results(self, results):
        results['img'] = copy.deepcopy(self.original_img)
        results['gt_semantic_seg'] = copy.deepcopy(self.original_img)

    def test_repr(self):
        # test repr
        transform = dict(
            type='RandomMultiscaleResize', scales=[(1333, 800), (1333, 600)])
        random_multiscale_resize = TRANSFORMS.build(transform)
        assert isinstance(repr(random_multiscale_resize), str)

    def test_error(self):
        # test assertion if size is smaller than 0
        with pytest.raises(AssertionError):
            transform = dict(type='RandomMultiscaleResize', scales=[0.5, 1, 2])
            TRANSFORMS.build(transform)

    def test_random_multiscale_resize(self):
        results = dict()
        # test with one scale
        transform = dict(type='RandomMultiscaleResize', scales=[(1333, 800)])
        random_multiscale_resize = TRANSFORMS.build(transform)
        self.reset_results(results)
        results = random_multiscale_resize(results)
        assert results['img'].shape == (800, 1333, 3)

        # test with multi scales
        _scale_choice = [(1333, 800), (1333, 600)]
        transform = dict(type='RandomMultiscaleResize', scales=_scale_choice)
        random_multiscale_resize = TRANSFORMS.build(transform)
        self.reset_results(results)
        results = random_multiscale_resize(results)
        assert (results['img'].shape[1],
                results['img'].shape[0]) in _scale_choice

        # test keep_ratio
        transform = dict(
            type='RandomMultiscaleResize',
            scales=[(900, 600)],
            keep_ratio=True)
        random_multiscale_resize = TRANSFORMS.build(transform)
        self.reset_results(results)
        _input_ratio = results['img'].shape[0] / results['img'].shape[1]
        results = random_multiscale_resize(results)
        _output_ratio = results['img'].shape[0] / results['img'].shape[1]
        assert_array_almost_equal(_input_ratio, _output_ratio)

        # test clip_object_border
        gt_bboxes = [[200, 150, 600, 450]]
        transform = dict(
            type='RandomMultiscaleResize',
            scales=[(200, 150)],
            clip_object_border=True)
        random_multiscale_resize = TRANSFORMS.build(transform)
        self.reset_results(results)
        results['gt_bboxes'] = np.array(gt_bboxes)
        results = random_multiscale_resize(results)
        assert results['img'].shape == (150, 200, 3)
        assert np.equal(results['gt_bboxes'], np.array([[100, 75, 200,
                                                         150]])).all()

        transform = dict(
            type='RandomMultiscaleResize',
            scales=[(200, 150)],
            clip_object_border=False)
        random_multiscale_resize = TRANSFORMS.build(transform)
        self.reset_results(results)
        results['gt_bboxes'] = np.array(gt_bboxes)
        results = random_multiscale_resize(results)
        assert results['img'].shape == (150, 200, 3)
        assert np.equal(results['gt_bboxes'], np.array([[100, 75, 300,
                                                         225]])).all()

class TestRandomFlip:

    def test_init(self):