mirror of https://github.com/open-mmlab/mmcv.git
[Feature]: Add Part2 of data transform (#1730)
* [Refactor]: New commit of Part2 of data transform
* [Fix]: Fix lint
* [Fix]: Change flip resize to prefix
* [Refactor]: Delete redundant code in ToTensor
* [Fix]: optional
* [Fix]: Change the description of RandomFlip
* [Refactor]: Change flip_with_flip_direction to flip_on_direction

Co-authored-by: Your <you@example.com>
parent 9e4b2ff58e
commit 53070ebccf
mmcv/transforms/__init__.py
@@ -1,10 +1,22 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .builder import TRANSFORMS
 from .loading import LoadAnnotation, LoadImageFromFile
-from .processing import Normalize, Pad, Resize
+from .processing import Normalize, Pad, RandomFlip, RandomResize, Resize
 from .wrappers import ApplyToMultiple, Compose, RandomChoice, Remap
 
-__all__ = [
-    'TRANSFORMS', 'ApplyToMultiple', 'Compose', 'RandomChoice', 'Remap',
-    'LoadImageFromFile', 'LoadAnnotation', 'Normalize', 'Resize', 'Pad'
-]
+try:
+    import torch  # noqa: F401
+except ImportError:
+    __all__ = [
+        'TRANSFORMS', 'ApplyToMultiple', 'Compose', 'RandomChoice', 'Remap',
+        'LoadImageFromFile', 'LoadAnnotation', 'Normalize', 'Resize', 'Pad',
+        'RandomFlip', 'RandomResize'
+    ]
+else:
+    from .formatting import ImageToTensor, ToTensor, to_tensor
+
+    __all__ = [
+        'TRANSFORMS', 'ApplyToMultiple', 'Compose', 'RandomChoice', 'Remap',
+        'LoadImageFromFile', 'LoadAnnotation', 'Normalize', 'Resize', 'Pad',
+        'ToTensor', 'to_tensor', 'ImageToTensor', 'RandomFlip', 'RandomResize'
+    ]
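The ``try``/``except ImportError`` around ``import torch`` keeps the new formatting transforms optional: ``ToTensor``, ``to_tensor`` and ``ImageToTensor`` only enter the package namespace and ``__all__`` when torch is installed, while the numpy-based transforms are exported either way. A minimal sketch of how downstream code can cope with both cases (the ``HAS_TORCH`` name is ours, not part of the API):

    import mmcv.transforms as transforms

    # hasattr works because the torch-only names are bound in the else branch
    HAS_TORCH = hasattr(transforms, 'ToTensor')

    if HAS_TORCH:
        pipeline = [transforms.ToTensor(keys=['img_label'])]
    else:
        pipeline = []  # fall back to the numpy-only transforms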
mmcv/transforms/formatting.py (new file)
@@ -0,0 +1,125 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Sequence, Union
+
+import numpy as np
+import torch
+
+import mmcv
+from .base import BaseTransform
+from .builder import TRANSFORMS
+
+
+def to_tensor(
+    data: Union[torch.Tensor, np.ndarray, Sequence, int,
+                float]) -> torch.Tensor:
+    """Convert objects of various python types to :obj:`torch.Tensor`.
+
+    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
+    :class:`Sequence`, :class:`int` and :class:`float`.
+
+    Args:
+        data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to
+            be converted.
+
+    Returns:
+        torch.Tensor: the converted data.
+    """
+    if isinstance(data, torch.Tensor):
+        return data
+    elif isinstance(data, np.ndarray):
+        return torch.from_numpy(data)
+    elif isinstance(data, Sequence) and not mmcv.is_str(data):
+        return torch.tensor(data)
+    elif isinstance(data, int):
+        return torch.LongTensor([data])
+    elif isinstance(data, float):
+        return torch.FloatTensor([data])
+    else:
+        raise TypeError(f'type {type(data)} cannot be converted to tensor.')
+
+
+@TRANSFORMS.register_module()
+class ToTensor(BaseTransform):
+    """Convert some results to :obj:`torch.Tensor` by given keys.
+
+    Required keys:
+
+    - all the keys in ``keys``
+
+    Modified Keys:
+
+    - all the keys in ``keys``
+
+    Args:
+        keys (Sequence[str]): Keys that need to be converted to Tensor.
+    """
+
+    def __init__(self, keys: Sequence[str]) -> None:
+        self.keys = keys
+
+    def transform(self, results: dict) -> dict:
+        """Transform function to convert data to `torch.Tensor`.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            dict: `keys` in results will be updated.
+        """
+        for key in self.keys:
+            # Walk dot-separated keys, e.g. 'instances.bbox', into the
+            # nested dict and convert the leaf value in place.
+            key_list = key.split('.')
+            cur_item = results
+            for i in range(len(key_list)):
+                if key_list[i] not in cur_item:
+                    raise KeyError(f'Can not find key {key}')
+                if i == len(key_list) - 1:
+                    cur_item[key_list[i]] = to_tensor(cur_item[key_list[i]])
+                    break
+                cur_item = cur_item[key_list[i]]
+
+        return results
+
+    def __repr__(self) -> str:
+        return self.__class__.__name__ + f'(keys={self.keys})'
+
+
+@TRANSFORMS.register_module()
+class ImageToTensor(BaseTransform):
+    """Convert image to :obj:`torch.Tensor` by given keys.
+
+    The dimension order of input image is (H, W, C). The pipeline will
+    convert it to (C, H, W). If only 2 dimensions (H, W) are given, the
+    output would be (1, H, W).
+
+    Required keys:
+
+    - all the keys in ``keys``
+
+    Modified Keys:
+
+    - all the keys in ``keys``
+
+    Args:
+        keys (Sequence[str]): Keys of images to be converted to Tensor.
+    """
+
+    def __init__(self, keys: Sequence[str]) -> None:
+        self.keys = keys
+
+    def transform(self, results: dict) -> dict:
+        """Transform function to convert image in results to
+        :obj:`torch.Tensor` and transpose the channel order.
+
+        Args:
+            results (dict): Result dict contains the image data to convert.
+
+        Returns:
+            dict: The result dict contains the image converted
+            to :obj:`torch.Tensor` and transposed to (C, H, W) order.
+        """
+        for key in self.keys:
+            img = results[key]
+            if len(img.shape) < 3:
+                img = np.expand_dims(img, -1)
+            results[key] = (to_tensor(img.transpose(2, 0, 1))).contiguous()
+        return results
+
+    def __repr__(self) -> str:
+        return self.__class__.__name__ + f'(keys={self.keys})'
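As a usage sketch (the key names are borrowed from the unit tests below, not required by the API): ``ToTensor`` follows dot-separated keys into nested dicts before converting the leaf value, and ``ImageToTensor`` reorders (H, W, C) images to (C, H, W):

    import numpy as np
    from mmcv.transforms import ImageToTensor, ToTensor

    results = {
        'img': np.zeros((224, 224, 3)),
        'instances': {'bbox': [[0, 0, 10, 10]]},
    }
    results = ToTensor(keys=['instances.bbox']).transform(results)
    results = ImageToTensor(keys=['img']).transform(results)
    assert results['instances']['bbox'].shape == (1, 4)
    assert results['img'].shape == (3, 224, 224)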
mmcv/transforms/processing.py
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Optional, Sequence, Tuple, Union
+from typing import Iterable, List, Optional, Sequence, Tuple, Union
 
 import numpy as np
 
@@ -400,3 +400,408 @@ class Pad(BaseTransform):
         repr_str += f'pad_val={self.pad_val}), '
         repr_str += f'padding_mode={self.padding_mode})'
         return repr_str
+
+
+@TRANSFORMS.register_module()
+class RandomFlip(BaseTransform):
+    """Flip the image & bbox & keypoints & segmentation map.
+
+    There are 3 flip modes:
+
+    - ``prob`` is float, ``direction`` is string:
+
+      the image will be flipped on the given direction with probability
+      of ``prob``. E.g., ``prob=0.5``, ``direction='horizontal'``,
+      then image will be horizontally flipped with probability of 0.5.
+
+    - ``prob`` is float, ``direction`` is list of string:
+
+      the image will be flipped on the given direction with probability of
+      ``prob/len(direction)``. E.g., ``prob=0.5``,
+      ``direction=['horizontal', 'vertical']``, then image will be
+      horizontally flipped with probability of 0.25, vertically with
+      probability of 0.25.
+
+    - ``prob`` is list of float, ``direction`` is list of string:
+
+      given ``len(prob) == len(direction)``, the image will
+      be ``direction[i]`` ly flipped with probability of ``prob[i]``.
+      E.g., ``prob=[0.3, 0.5]``, ``direction=['horizontal',
+      'vertical']``, then image will be horizontally flipped with
+      probability of 0.3, vertically with probability of 0.5.
+
+    Required Keys:
+
+    - img
+    - gt_bboxes
+    - gt_semantic_seg
+    - gt_keypoints
+
+    Modified Keys:
+
+    - img
+    - gt_bboxes
+    - gt_semantic_seg
+    - gt_keypoints
+
+    Added Keys:
+
+    - flip
+    - flip_direction
+
+    Args:
+        prob (float | list[float], optional): The flipping probability.
+            Defaults to None.
+        direction (str | list[str]): The flipping direction. Options are
+            'horizontal', 'vertical' and 'diagonal'. If input is a list,
+            the length must equal ``prob``. Each element in ``prob``
+            indicates the flip probability of the corresponding direction.
+            Defaults to 'horizontal'.
+    """
+
+    def __init__(
+            self,
+            prob: Optional[Union[float, Iterable[float]]] = None,
+            direction: Union[str,
+                             Sequence[Optional[str]]] = 'horizontal') -> None:
+        if isinstance(prob, list):
+            assert mmcv.is_list_of(prob, float)
+            assert 0 <= sum(prob) <= 1
+        elif isinstance(prob, float):
+            assert 0 <= prob <= 1
+        else:
+            raise ValueError(f"probs must be float or list of float, but \
+                got '{type(prob)}'.")
+        self.prob = prob
+
+        valid_directions = ['horizontal', 'vertical', 'diagonal']
+        if isinstance(direction, str):
+            assert direction in valid_directions
+        elif isinstance(direction, list):
+            assert mmcv.is_list_of(direction, str)
+            assert set(direction).issubset(set(valid_directions))
+        else:
+            raise ValueError(f"direction must be either str or list of str, \
+                but got '{type(direction)}'.")
+        self.direction = direction
+
+        if isinstance(prob, list):
+            assert len(prob) == len(self.direction)
+
+    def flip_bbox(self, bboxes: np.ndarray, img_shape: Tuple[int, int],
+                  direction: str) -> np.ndarray:
+        """Flip bboxes horizontally, vertically or diagonally.
+
+        Args:
+            bboxes (numpy.ndarray): Bounding boxes, shape (..., 4*k)
+            img_shape (tuple[int]): Image shape (height, width)
+            direction (str): Flip direction. Options are 'horizontal',
+                'vertical' and 'diagonal'.
+
+        Returns:
+            numpy.ndarray: Flipped bounding boxes.
+        """
+        assert bboxes.shape[-1] % 4 == 0
+        flipped = bboxes.copy()
+        h, w = img_shape
+        if direction == 'horizontal':
+            flipped[..., 0::4] = w - bboxes[..., 2::4]
+            flipped[..., 2::4] = w - bboxes[..., 0::4]
+        elif direction == 'vertical':
+            flipped[..., 1::4] = h - bboxes[..., 3::4]
+            flipped[..., 3::4] = h - bboxes[..., 1::4]
+        elif direction == 'diagonal':
+            flipped[..., 0::4] = w - bboxes[..., 2::4]
+            flipped[..., 1::4] = h - bboxes[..., 3::4]
+            flipped[..., 2::4] = w - bboxes[..., 0::4]
+            flipped[..., 3::4] = h - bboxes[..., 1::4]
+        else:
+            raise ValueError(
+                f"Flipping direction must be 'horizontal', 'vertical', \
+                or 'diagonal', but got '{direction}'")
+        return flipped
+
+    def flip_keypoints(self, keypoints: np.ndarray,
+                       img_shape: Tuple[int, int],
+                       direction: str) -> np.ndarray:
+        """Flip keypoints horizontally, vertically or diagonally.
+
+        Args:
+            keypoints (numpy.ndarray): Keypoints, shape (..., 2)
+            img_shape (tuple[int]): Image shape (height, width)
+            direction (str): Flip direction. Options are 'horizontal',
+                'vertical' and 'diagonal'.
+
+        Returns:
+            numpy.ndarray: Flipped keypoints.
+        """
+        meta_info = keypoints[..., 2:]
+        keypoints = keypoints[..., :2]
+        flipped = keypoints.copy()
+        h, w = img_shape
+        if direction == 'horizontal':
+            flipped[..., 0::2] = w - keypoints[..., 0::2]
+        elif direction == 'vertical':
+            flipped[..., 1::2] = h - keypoints[..., 1::2]
+        elif direction == 'diagonal':
+            flipped[..., 0::2] = w - keypoints[..., 0::2]
+            flipped[..., 1::2] = h - keypoints[..., 1::2]
+        else:
+            raise ValueError(
+                f"Flipping direction must be 'horizontal', 'vertical', \
+                or 'diagonal', but got '{direction}'")
+        # concatenate the flipped coordinates (not the originals) with the
+        # untouched meta info, e.g. visibility flags
+        flipped = np.concatenate([flipped, meta_info], axis=-1)
+        return flipped
+
+    def _choose_direction(self) -> Optional[str]:
+        """Choose the flip direction according to `prob` and `direction`."""
+        if isinstance(self.direction,
+                      Sequence) and not isinstance(self.direction, str):
+            # None means non-flip
+            direction_list: list = list(self.direction) + [None]
+        elif isinstance(self.direction, str):
+            # None means non-flip
+            direction_list = [self.direction, None]
+
+        if isinstance(self.prob, list):
+            non_prob: float = 1 - sum(self.prob)
+            prob_list = self.prob + [non_prob]
+        elif isinstance(self.prob, float):
+            non_prob = 1. - self.prob
+            # exclude non-flip
+            single_ratio = self.prob / (len(direction_list) - 1)
+            prob_list = [single_ratio] * (len(direction_list) - 1) + [non_prob]
+
+        cur_dir = np.random.choice(direction_list, p=prob_list)
+
+        return cur_dir
+
+    def _flip(self, results: dict) -> None:
+        """Flip images, bounding boxes, semantic segmentation map and
+        keypoints."""
+        # flip image
+        results['img'] = mmcv.imflip(
+            results['img'], direction=results['flip_direction'])
+
+        img_shape = results['img'].shape[:2]
+
+        # flip bboxes
+        if results.get('gt_bboxes', None) is not None:
+            results['gt_bboxes'] = self.flip_bbox(results['gt_bboxes'],
+                                                  img_shape,
+                                                  results['flip_direction'])
+
+        # flip keypoints
+        if results.get('gt_keypoints', None) is not None:
+            results['gt_keypoints'] = self.flip_keypoints(
+                results['gt_keypoints'], img_shape, results['flip_direction'])
+
+        # flip segs
+        if results.get('gt_semantic_seg', None) is not None:
+            results['gt_semantic_seg'] = mmcv.imflip(
+                results['gt_semantic_seg'],
+                direction=results['flip_direction'])
+
+    def _flip_on_direction(self, results: dict) -> None:
+        """Function to flip images, bounding boxes, semantic segmentation map
+        and keypoints."""
+        cur_dir = self._choose_direction()
+        if cur_dir is None:
+            results['flip'] = False
+            results['flip_direction'] = None
+        else:
+            results['flip'] = True
+            results['flip_direction'] = cur_dir
+            self._flip(results)
+
+    def transform(self, results: dict) -> dict:
+        """Transform function to flip images, bounding boxes, semantic
+        segmentation map and keypoints.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            dict: Flipped results, 'img', 'gt_bboxes', 'gt_semantic_seg',
+            'gt_keypoints', 'flip', and 'flip_direction' keys are
+            updated in result dict.
+        """
+        self._flip_on_direction(results)
+
+        return results
+
+    def __repr__(self) -> str:
+        repr_str = self.__class__.__name__
+        repr_str += f'(prob={self.prob}, '
+        repr_str += f'direction={self.direction})'
+
+        return repr_str
+
+
+@TRANSFORMS.register_module()
+class RandomResize(BaseTransform):
+    """Random resize images & bbox & keypoints.
+
+    Added or updated keys: scale, scale_factor, keep_ratio, img, height,
+    width, gt_bboxes, gt_semantic_seg, and gt_keypoints.
+
+    The target scale to resize the image is chosen by the rules below:
+
+    - if ``scale`` is a list of tuples, the first value of the target scale
+      is sampled uniformly from [``scale[0][0]``, ``scale[1][0]``] and the
+      second value is sampled uniformly from
+      [``scale[0][1]``, ``scale[1][1]``].
+    - if ``scale`` is a tuple, the first and second values of the target
+      scale are equal to the first and second values of ``scale`` multiplied
+      by a value sampled uniformly from
+      [``ratio_range[0]``, ``ratio_range[1]``].
+
+    Required Keys:
+
+    - img
+    - gt_bboxes
+    - gt_semantic_seg
+    - gt_keypoints
+
+    Modified Keys:
+
+    - img
+    - gt_bboxes
+    - gt_semantic_seg
+    - gt_keypoints
+
+    Added Keys:
+
+    - scale
+    - scale_factor
+    - keep_ratio
+
+    Args:
+        scale (tuple or list[tuple], optional): Image scales for resizing.
+            Defaults to None.
+        ratio_range (tuple[float], optional): (min_ratio, max_ratio).
+            Defaults to None.
+        keep_ratio (bool): Whether to keep the aspect ratio when resizing the
+            image. Defaults to True.
+        clip_object_border (bool): Whether to clip the objects outside the
+            border of the image. In some datasets like MOT17, the gt bboxes
+            are allowed to cross the border of images. Therefore, we don't
+            need to clip the gt bboxes in these cases. Defaults to True.
+        backend (str): Image resize backend, choices are 'cv2' and 'pillow'.
+            These two backends generate slightly different results. Defaults
+            to 'cv2'.
+        interpolation (str): How to interpolate the original image when
+            resizing. Defaults to 'bilinear'.
+    """
+
+    def __init__(self,
+                 scale: Union[Tuple[int, int], List[Tuple[int, int]]] = None,
+                 ratio_range: Tuple[float, float] = None,
+                 keep_ratio: bool = True,
+                 clip_object_border: bool = True,
+                 backend: str = 'cv2',
+                 interpolation: str = 'bilinear') -> None:
+
+        assert scale is not None
+
+        self.scale = scale
+        self.ratio_range = ratio_range
+        self.keep_ratio = keep_ratio
+        self.clip_object_border = clip_object_border
+        self.backend = backend
+        self.interpolation = interpolation
+
+        # create an empty Resize object; its scale is overwritten with the
+        # randomly sampled scale on every call to ``transform``
+        self.resize = Resize(0)
+        self.resize.keep_ratio = keep_ratio
+        self.resize.clip_object_border = clip_object_border
+        self.resize.backend = backend
+        self.resize.interpolation = interpolation
+
+    @staticmethod
+    def _random_sample(scales: Sequence[Tuple[int, int]]) -> Tuple[int, int]:
+        """Private function to randomly sample a scale from a list of tuples.
+
+        Args:
+            scales (list[tuple]): Image scale range for sampling.
+                There must be two tuples in scales, which specify the lower
+                and upper bound of image scales.
+
+        Returns:
+            tuple: Returns the target scale.
+        """
+        assert mmcv.is_list_of(scales, tuple) and len(scales) == 2
+        scale_long = [max(s) for s in scales]
+        scale_short = [min(s) for s in scales]
+        long_edge = np.random.randint(min(scale_long), max(scale_long) + 1)
+        short_edge = np.random.randint(min(scale_short), max(scale_short) + 1)
+        scale = (long_edge, short_edge)
+        return scale
+
+    @staticmethod
+    def _random_sample_ratio(
+            scale: tuple, ratio_range: Tuple[float, float]) -> Tuple[int, int]:
+        """Private function to randomly sample a scale from a tuple.
+
+        A ratio will be randomly sampled from the range specified by
+        ``ratio_range``. Then it would be multiplied with ``scale`` to
+        generate the sampled scale.
+
+        Args:
+            scale (tuple): Image scale base to multiply with ratio.
+            ratio_range (tuple[float]): The minimum and maximum ratio to
+                scale the ``scale``.
+
+        Returns:
+            tuple: Returns the target scale.
+        """
+        assert isinstance(scale, tuple) and len(scale) == 2
+        min_ratio, max_ratio = ratio_range
+        assert min_ratio <= max_ratio
+        ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio
+        scale = int(scale[0] * ratio), int(scale[1] * ratio)
+        return scale
+
+    def _random_scale(self, results: dict) -> None:
+        """Private function to randomly sample a scale according to the type
+        of ``scale``.
+
+        Args:
+            results (dict): Result dict from :obj:`dataset`.
+
+        Returns:
+            dict: One new key 'scale' is added into ``results``,
+            which would be used by subsequent pipelines.
+        """
+        if isinstance(self.scale, tuple):
+            assert self.ratio_range is not None and len(self.ratio_range) == 2
+            scale: Tuple[int, int] = self._random_sample_ratio(
+                self.scale, self.ratio_range)
+        elif mmcv.is_list_of(self.scale, tuple):
+            scale = self._random_sample(self.scale)
+        else:
+            raise NotImplementedError(f"Do not support sampling function \
+                for '{self.scale}'")
+
+        results['scale'] = scale
+
+    def transform(self, results: dict) -> dict:
+        """Transform function to resize images, bounding boxes and semantic
+        segmentation map.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            dict: Resized results, 'img', 'gt_bboxes', 'gt_semantic_seg',
+            'gt_keypoints', 'scale', 'scale_factor', 'height', 'width',
+            and 'keep_ratio' keys are updated in result dict.
+        """
+        self._random_scale(results)
+        self.resize.scale = results['scale']
+        results = self.resize.transform(results)
+        return results
+
+    def __repr__(self) -> str:
+        repr_str = self.__class__.__name__
+        repr_str += f'(scale={self.scale}, '
+        repr_str += f'ratio_range={self.ratio_range}, '
+        repr_str += f'keep_ratio={self.keep_ratio}, '
+        repr_str += f'clip_object_border={self.clip_object_border}, '
+        repr_str += f'backend={self.backend}, '
+        repr_str += f'interpolation={self.interpolation})'
+        return repr_str
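A short sketch of the two new transforms in use (values chosen to match the tests below): flipping a box ``[x1, y1, x2, y2]`` horizontally in an image of width ``w`` maps it to ``[w - x2, y1, w - x1, y2]``, so ``[0, 1, 100, 101]`` in a 224-pixel-wide image becomes ``[124, 1, 224, 101]``; a diagonal flip additionally mirrors the y coordinates the same way:

    import numpy as np
    from mmcv.transforms import RandomFlip, RandomResize

    results = {
        'img': np.random.random((224, 224, 3)),
        'gt_bboxes': np.array([[0, 1, 100, 101]]),
    }

    flip = RandomFlip(prob=1.0, direction='horizontal')  # prob=1.0: always flip
    results = flip.transform(results)
    assert (results['gt_bboxes'] == np.array([[124, 1, 224, 101]])).all()
    assert results['flip'] and results['flip_direction'] == 'horizontal'

    # scale is sampled as (224, 224) times a ratio drawn from [1.0, 2.0]
    resize = RandomResize(scale=(224, 224), ratio_range=(1.0, 2.0))
    results = resize.transform(results)
    assert 224 <= results['scale'][0] <= 448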
tests/test_transforms/test_transforms_formatting.py (new file)
@@ -0,0 +1,101 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+try:
+    import torch
+except ModuleNotFoundError:
+    torch = None
+else:
+    from mmcv.transforms import ImageToTensor, ToTensor, to_tensor
+
+import copy
+
+import numpy as np
+import pytest
+
+
+@pytest.mark.skipif(condition=torch is None, reason='No torch in current env')
+def test_to_tensor():
+
+    # The type of the input object is torch.Tensor
+    data_tensor = torch.tensor([1, 2, 3])
+    tensor_from_tensor = to_tensor(data_tensor)
+    assert isinstance(tensor_from_tensor, torch.Tensor)
+
+    # The type of the input object is numpy.ndarray
+    data_numpy = np.array([1, 2, 3])
+    tensor_from_numpy = to_tensor(data_numpy)
+    assert isinstance(tensor_from_numpy, torch.Tensor)
+
+    # The type of the input object is list
+    data_list = [1, 2, 3]
+    tensor_from_list = to_tensor(data_list)
+    assert isinstance(tensor_from_list, torch.Tensor)
+
+    # The type of the input object is int
+    data_int = 1
+    tensor_from_int = to_tensor(data_int)
+    assert isinstance(tensor_from_int, torch.Tensor)
+
+    # The type of the input object is float
+    data_float = 1.0
+    tensor_from_float = to_tensor(data_float)
+    assert isinstance(tensor_from_float, torch.Tensor)
+
+    # The type of the input object is invalid
+    with pytest.raises(TypeError):
+        data_str = '123'
+        _ = to_tensor(data_str)
+
+
+@pytest.mark.skipif(condition=torch is None, reason='No torch in current env')
+class TestToTensor:
+
+    def test_init(self):
+        TRANSFORM = ToTensor(keys=['img_label'])
+        assert TRANSFORM.keys == ['img_label']
+
+    def test_transform(self):
+        TRANSFORMS = ToTensor(['instances.bbox', 'img_label'])
+
+        # Test multi-level key and single-level key (multi-level key is
+        # not in results)
+        with pytest.raises(KeyError):
+            results = {'instances': {'label': [1]}, 'img_label': [1]}
+            results_tensor = TRANSFORMS.transform(copy.deepcopy(results))
+            assert isinstance(results_tensor['instances']['label'], list)
+            assert isinstance(results_tensor['img_label'], torch.Tensor)
+
+        # Test multi-level key (multi-level key is in results)
+        results = {'instances': {'bbox': [[0, 0, 10, 10]]}, 'img_label': [1]}
+        results_tensor = TRANSFORMS.transform(copy.deepcopy(results))
+        assert isinstance(results_tensor['instances']['bbox'], torch.Tensor)
+
+    def test_repr(self):
+        TRANSFORMS = ToTensor(['instances.bbox', 'img_label'])
+        TRANSFORMS_str = str(TRANSFORMS)
+        assert isinstance(TRANSFORMS_str, str)
+
+
+@pytest.mark.skipif(condition=torch is None, reason='No torch in current env')
+class TestImageToTensor:
+
+    def test_init(self):
+        TRANSFORMS = ImageToTensor(['img'])
+        assert TRANSFORMS.keys == ['img']
+
+    def test_transform(self):
+        TRANSFORMS = ImageToTensor(['img'])
+
+        # image only has one channel
+        results = {'img': np.zeros((224, 224))}
+        results = TRANSFORMS.transform(results)
+        assert results['img'].shape == (1, 224, 224)
+
+        # image has three channels
+        results = {'img': np.zeros((224, 224, 3))}
+        results = TRANSFORMS.transform(results)
+        assert results['img'].shape == (3, 224, 224)
+
+    def test_repr(self):
+        TRANSFORMS = ImageToTensor(['img'])
+        TRANSFORMS_str = str(TRANSFORMS)
+        assert isinstance(TRANSFORMS_str, str)
tests/test_transforms/test_transforms_processing.py
@@ -6,7 +6,7 @@ import numpy as np
 import pytest
 
 import mmcv
-from mmcv.transforms import Normalize, Pad, Resize
+from mmcv.transforms import Normalize, Pad, RandomFlip, RandomResize, Resize
 
 
 class TestNormalize:
@@ -183,3 +183,154 @@ class TestPad:
         assert repr(trans) == (
             'Pad(size=None, size_divisor=11, pad_to_square=True, '
             "pad_val={'img': 0, 'seg': 255}), padding_mode=edge)")
+
+
+class TestRandomFlip:
+
+    def test_init(self):
+
+        # prob is float
+        TRANSFORMS = RandomFlip(0.1)
+        assert TRANSFORMS.prob == 0.1
+
+        # prob is None
+        with pytest.raises(ValueError):
+            TRANSFORMS = RandomFlip(None)
+            assert TRANSFORMS.prob is None
+
+        # prob is a list
+        TRANSFORMS = RandomFlip([0.1, 0.2], ['horizontal', 'vertical'])
+        assert len(TRANSFORMS.prob) == 2
+        assert len(TRANSFORMS.direction) == 2
+
+        # direction is an invalid type
+        with pytest.raises(ValueError):
+            TRANSFORMS = RandomFlip(0.1, 1)
+
+        # prob is an invalid type
+        with pytest.raises(ValueError):
+            TRANSFORMS = RandomFlip('0.1')
+
+    def test_transform(self):
+
+        results = {
+            'img': np.random.random((224, 224, 3)),
+            'gt_bboxes': np.array([[0, 1, 100, 101]]),
+            'gt_keypoints': np.array([[[100, 100, 1.0]]]),
+            'gt_semantic_seg': np.random.random((224, 224, 3))
+        }
+
+        # horizontal flip
+        TRANSFORMS = RandomFlip([1.0], ['horizontal'])
+        results_update = TRANSFORMS.transform(copy.deepcopy(results))
+        assert (results_update['gt_bboxes'] == np.array([[124, 1, 224,
+                                                          101]])).all()
+
+        # diagonal flip
+        TRANSFORMS = RandomFlip([1.0], ['diagonal'])
+        results_update = TRANSFORMS.transform(copy.deepcopy(results))
+        assert (results_update['gt_bboxes'] == np.array([[124, 123, 224,
+                                                          223]])).all()
+
+        # vertical flip
+        TRANSFORMS = RandomFlip([1.0], ['vertical'])
+        results_update = TRANSFORMS.transform(copy.deepcopy(results))
+        assert (results_update['gt_bboxes'] == np.array([[0, 123, 100,
+                                                          223]])).all()
+
+        # horizontal flip with the default direction
+        TRANSFORMS = RandomFlip(1.0)
+        results_update = TRANSFORMS.transform(copy.deepcopy(results))
+        assert (results_update['gt_bboxes'] == np.array([[124, 1, 224,
+                                                          101]])).all()
+
+        # no flip when prob is 0.0
+        TRANSFORMS = RandomFlip(0.0)
+        results_update = TRANSFORMS.transform(copy.deepcopy(results))
+        assert (results_update['gt_bboxes'] == np.array([[0, 1, 100,
+                                                          101]])).all()
+
+        # flip direction is invalid in bbox flip
+        with pytest.raises(ValueError):
+            TRANSFORMS = RandomFlip(1.0)
+            results_update = TRANSFORMS.flip_bbox(results['gt_bboxes'],
+                                                  (224, 224), 'invalid')
+
+        # flip direction is invalid in keypoints flip
+        with pytest.raises(ValueError):
+            TRANSFORMS = RandomFlip(1.0)
+            results_update = TRANSFORMS.flip_keypoints(
+                results['gt_keypoints'], (224, 224), 'invalid')
+
+    def test_repr(self):
+        TRANSFORMS = RandomFlip(0.1)
+        TRANSFORMS_str = str(TRANSFORMS)
+        assert isinstance(TRANSFORMS_str, str)
+
+
+class TestRandomResize:
+
+    def test_init(self):
+        TRANSFORMS = RandomResize(
+            (224, 224),
+            (1.0, 2.0),
+        )
+        assert TRANSFORMS.scale == (224, 224)
+
+    def test_repr(self):
+        TRANSFORMS = RandomResize(
+            (224, 224),
+            (1.0, 2.0),
+        )
+        TRANSFORMS_str = str(TRANSFORMS)
+        assert isinstance(TRANSFORMS_str, str)
+
+    def test_transform(self):
+
+        # choose target scale from init when scale is a tuple and
+        # ratio_range is given
+        results = {}
+        TRANSFORMS = RandomResize((224, 224), (1.0, 2.0))
+        results_update = TRANSFORMS.transform(copy.deepcopy(results))
+        assert 224 <= results_update['scale'][0] <= 448
+        assert 224 <= results_update['scale'][1] <= 448
+
+        # keep ratio is True
+        results = {
+            'img': np.random.random((224, 224, 3)),
+            'gt_semantic_seg': np.random.random((224, 224, 3)),
+            'gt_bboxes': np.array([[0, 0, 112, 112]]),
+            'gt_keypoints': np.array([[[112, 112]]])
+        }
+        TRANSFORMS = RandomResize((224, 224), (1.0, 2.0), keep_ratio=True)
+        results_update = TRANSFORMS.transform(copy.deepcopy(results))
+        assert 224 <= results_update['height'] <= 448
+        assert 224 <= results_update['width'] <= 448
+        assert results_update['keep_ratio']
+        # with keep_ratio the box x2 (112) scales together with the image,
+        # so it lands in [112, 224]
+        assert results_update['gt_bboxes'][0][2] >= 112
+        assert results_update['gt_bboxes'][0][2] <= 224
+
+        # keep ratio is False
+        TRANSFORMS = RandomResize((224, 224), (1.0, 2.0), keep_ratio=False)
+        results_update = TRANSFORMS.transform(copy.deepcopy(results))
+
+        # choose target scale from init when scale is a list of tuples
+        results = {}
+        TRANSFORMS = RandomResize([(224, 448), (112, 224)], keep_ratio=True)
+        results_update = TRANSFORMS.transform(copy.deepcopy(results))
+        assert 224 <= results_update['scale'][0] <= 448
+        assert 112 <= results_update['scale'][1] <= 224
+
+        # the type of scale is invalid, so transform raises
+        # NotImplementedError when sampling the target scale
+        with pytest.raises(NotImplementedError):
+            results = {}
+            TRANSFORMS = RandomResize([(224, 448), [112, 224]],
+                                      keep_ratio=True)
+            results_update = TRANSFORMS.transform(copy.deepcopy(results))
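The new tests can be run on their own; the invocation below assumes the test paths used in this commit and a source checkout with pytest available:

    python -m pytest tests/test_transforms/test_transforms_formatting.py
    python -m pytest tests/test_transforms/test_transforms_processing.py

The formatting tests skip themselves via ``pytest.mark.skipif`` when torch is not importable, matching the optional import in ``mmcv/transforms/__init__.py``.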