mirror of https://github.com/open-mmlab/mmcv.git
[Feature]: Add Part2 of data transform (#1730)
* [Refactor]: New commit of Part2 of data transform * [Fix]: Fix lint * [Fix]: Change flip resize to prefix * [Refactor]: Delete redundant code in ToTensor * [Fix]: optional * [Fix]: Change the description of RandomFlip * [Refactor]: Change flip_with_flip_direction to flip_on_direction Co-authored-by: Your <you@example.com> pull/2133/head
parent
9e4b2ff58e
commit
53070ebccf
|
@ -1,10 +1,22 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .builder import TRANSFORMS
|
||||
from .loading import LoadAnnotation, LoadImageFromFile
|
||||
from .processing import Normalize, Pad, Resize
|
||||
from .processing import Normalize, Pad, RandomFlip, RandomResize, Resize
|
||||
from .wrappers import ApplyToMultiple, Compose, RandomChoice, Remap
|
||||
|
||||
__all__ = [
|
||||
'TRANSFORMS', 'ApplyToMultiple', 'Compose', 'RandomChoice', 'Remap',
|
||||
'LoadImageFromFile', 'LoadAnnotation', 'Normalize', 'Resize', 'Pad'
|
||||
]
|
||||
# The tensor-formatting transforms need PyTorch. Probe for it once at import
# time and only export those names when the import succeeds.
try:
    import torch  # noqa: F401
except ImportError:
    # PyTorch is not installed: export the torch-free transforms only.
    __all__ = [
        'TRANSFORMS',
        'ApplyToMultiple',
        'Compose',
        'RandomChoice',
        'Remap',
        'LoadImageFromFile',
        'LoadAnnotation',
        'Normalize',
        'Resize',
        'Pad',
        'RandomFlip',
        'RandomResize',
    ]
else:
    # PyTorch is available: additionally expose the formatting transforms.
    from .formatting import ImageToTensor, ToTensor, to_tensor

    __all__ = [
        'TRANSFORMS',
        'ApplyToMultiple',
        'Compose',
        'RandomChoice',
        'Remap',
        'LoadImageFromFile',
        'LoadAnnotation',
        'Normalize',
        'Resize',
        'Pad',
        'ToTensor',
        'to_tensor',
        'ImageToTensor',
        'RandomFlip',
        'RandomResize',
    ]
|
||||
|
|
|
@ -0,0 +1,125 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Sequence, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import mmcv
|
||||
from .base import BaseTransform
|
||||
from .builder import TRANSFORMS
|
||||
|
||||
|
||||
def to_tensor(
    data: Union[torch.Tensor, np.ndarray, Sequence, int,
                float]) -> torch.Tensor:
    """Turn a supported python object into a :obj:`torch.Tensor`.

    The accepted input types are :class:`torch.Tensor` (returned as is),
    :class:`numpy.ndarray`, non-string :class:`Sequence`, :class:`int` and
    :class:`float`.

    Args:
        data (torch.Tensor | numpy.ndarray | Sequence | int | float): The
            object to convert.

    Returns:
        torch.Tensor: The converted data.

    Raises:
        TypeError: If ``data`` is of none of the supported types (strings
            included).
    """
    # Tensors pass through untouched.
    if isinstance(data, torch.Tensor):
        return data
    if isinstance(data, np.ndarray):
        return torch.from_numpy(data)
    # Strings are sequences too, but they are not convertible.
    if isinstance(data, Sequence) and not mmcv.is_str(data):
        return torch.tensor(data)
    # Scalars are wrapped into a one-element tensor.
    if isinstance(data, int):
        return torch.LongTensor([data])
    if isinstance(data, float):
        return torch.FloatTensor([data])
    raise TypeError(f'type {type(data)} cannot be converted to tensor.')
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
class ToTensor(BaseTransform):
    """Convert selected entries of the results to :obj:`torch.Tensor`.

    A key may address a nested entry by joining the levels with a dot, e.g.
    ``'instances.bbox'`` refers to ``results['instances']['bbox']``.

    Required keys:

    - all these keys in `keys`

    Modified Keys:

    - all these keys in `keys`

    Args:
        keys (Sequence[str]): Keys that need to be converted to Tensor.
    """

    def __init__(self, keys: Sequence[str]) -> None:
        self.keys = keys

    def transform(self, results: dict) -> dict:
        """Convert every entry named in ``keys`` to `torch.Tensor` in place.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: `keys` in results will be updated.

        Raises:
            KeyError: If any level of a (dotted) key is missing.
        """
        for key in self.keys:
            # All levels but the last only locate the containing mapping;
            # the last level names the entry that gets converted.
            *parents, leaf = key.split('.')

            container = results
            for level in parents:
                if level not in container:
                    raise KeyError(f'Can not find key {key}')
                container = container[level]

            if leaf not in container:
                raise KeyError(f'Can not find key {key}')
            container[leaf] = to_tensor(container[leaf])

        return results

    def __repr__(self) -> str:
        return self.__class__.__name__ + f'(keys={self.keys})'
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
class ImageToTensor(BaseTransform):
    """Convert image to :obj:`torch.Tensor` by given keys.

    The dimension order of input image is (H, W, C). The pipeline will convert
    it to (C, H, W). If only 2 dimension (H, W) is given, the output would be
    (1, H, W).

    Required keys:

    - all these keys in `keys`

    Modified Keys:

    - all these keys in `keys`

    Args:
        keys (Sequence[str]): Key of images to be converted to Tensor.
    """

    def __init__(self, keys: Sequence[str]) -> None:
        # ``keys`` is only ever iterated; it is a sequence of key names, as
        # documented above, not a dict.
        self.keys = keys

    def transform(self, results: dict) -> dict:
        """Transform function to convert image in results to
        :obj:`torch.Tensor` and transpose the channel order.

        Args:
            results (dict): Result dict contains the image data to convert.

        Returns:
            dict: The result dict contains the image converted
            to :obj:`torch.Tensor` and transposed to (C, H, W) order.
        """
        for key in self.keys:
            img = results[key]
            # A 2-D (H, W) image gets a trailing channel axis so that the
            # transpose below always sees (H, W, C).
            if len(img.shape) < 3:
                img = np.expand_dims(img, -1)
            # The transposed array is converted first and then made
            # contiguous on the tensor side.
            results[key] = (to_tensor(img.transpose(2, 0, 1))).contiguous()
        return results

    def __repr__(self) -> str:
        return self.__class__.__name__ + f'(keys={self.keys})'
|
|
@ -1,5 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Optional, Sequence, Tuple, Union
|
||||
from typing import Iterable, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -400,3 +400,408 @@ class Pad(BaseTransform):
|
|||
repr_str += f'pad_val={self.pad_val}), '
|
||||
repr_str += f'padding_mode={self.padding_mode})'
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
class RandomFlip(BaseTransform):
    """Flip the image & bbox & keypoints & segmentation map.

    There are 3 flip modes:

    - ``prob`` is float, ``direction`` is string

      the image will be flipped on the given direction with probability
      of ``prob`` . E.g., ``prob=0.5``, ``direction='horizontal'``,
      then image will be horizontally flipped with probability of 0.5.

    - ``prob`` is float, ``direction`` is list of string

      the image will be flipped on the given direction with probability of
      ``prob/len(direction)``. E.g., ``prob=0.5``,
      ``direction=['horizontal', 'vertical']``, then image will be
      horizontally flipped with probability of 0.25, vertically with
      probability of 0.25.

    - ``prob`` is list of float, ``direction`` is list of string

      given ``len(prob) == len(direction)``, the image will
      be ``direction[i]`` ly flipped with probability of ``prob[i]``.
      E.g., ``prob=[0.3, 0.5]``, ``direction=['horizontal',
      'vertical']``, then image will be horizontally flipped with
      probability of 0.3, vertically with probability of 0.5.

    Required Keys:

    - img
    - gt_bboxes (optional)
    - gt_semantic_seg (optional)
    - gt_keypoints (optional)

    Modified Keys:

    - img
    - gt_bboxes
    - gt_semantic_seg
    - gt_keypoints

    Added Keys:

    - flip
    - flip_direction

    Args:
        prob (float | list[float]): The flipping probability. Although the
            signature default is None, a value must be given: anything
            other than a float or a list raises ``ValueError``.
        direction (str | list[str]): The flipping direction. Options
            are 'horizontal', 'vertical' and 'diagonal'. If ``prob`` is a
            list, ``direction`` must have the same length and each element
            in ``prob`` indicates the flip probability of the corresponding
            direction. Defaults to 'horizontal'.
    """

    def __init__(
        self,
        prob: Optional[Union[float, Iterable[float]]] = None,
        direction: Union[str,
                         Sequence[Optional[str]]] = 'horizontal') -> None:
        if isinstance(prob, list):
            assert mmcv.is_list_of(prob, float)
            # The remainder ``1 - sum(prob)`` is the no-flip probability.
            assert 0 <= sum(prob) <= 1
        elif isinstance(prob, float):
            assert 0 <= prob <= 1
        else:
            raise ValueError('probs must be float or list of float, but '
                             f"got '{type(prob)}'.")
        self.prob = prob

        valid_directions = ['horizontal', 'vertical', 'diagonal']
        if isinstance(direction, str):
            assert direction in valid_directions
        elif isinstance(direction, list):
            assert mmcv.is_list_of(direction, str)
            assert set(direction).issubset(set(valid_directions))
        else:
            raise ValueError('direction must be either str or list of str, '
                             f"but got '{type(direction)}'.")
        self.direction = direction

        if isinstance(prob, list):
            assert len(prob) == len(self.direction)

    def flip_bbox(self, bboxes: np.ndarray, img_shape: Tuple[int, int],
                  direction: str) -> np.ndarray:
        """Flip bboxes horizontally, vertically or diagonally.

        Args:
            bboxes (numpy.ndarray): Bounding boxes, shape (..., 4*k)
            img_shape (tuple[int]): Image shape (height, width)
            direction (str): Flip direction. Options are 'horizontal',
                'vertical' and 'diagonal'.

        Returns:
            numpy.ndarray: Flipped bounding boxes (a new array).

        Raises:
            ValueError: If ``direction`` is not one of the options above.
        """
        assert bboxes.shape[-1] % 4 == 0
        flipped = bboxes.copy()
        h, w = img_shape
        # Mirroring swaps the roles of the min and max coordinate on the
        # flipped axis, hence x1 is computed from x2 and vice versa.
        if direction == 'horizontal':
            flipped[..., 0::4] = w - bboxes[..., 2::4]
            flipped[..., 2::4] = w - bboxes[..., 0::4]
        elif direction == 'vertical':
            flipped[..., 1::4] = h - bboxes[..., 3::4]
            flipped[..., 3::4] = h - bboxes[..., 1::4]
        elif direction == 'diagonal':
            flipped[..., 0::4] = w - bboxes[..., 2::4]
            flipped[..., 1::4] = h - bboxes[..., 3::4]
            flipped[..., 2::4] = w - bboxes[..., 0::4]
            flipped[..., 3::4] = h - bboxes[..., 1::4]
        else:
            raise ValueError(
                "Flipping direction must be 'horizontal', 'vertical', "
                f"or 'diagonal', but got '{direction}'")
        return flipped

    def flip_keypoints(self, keypoints: np.ndarray, img_shape: Tuple[int, int],
                       direction: str) -> np.ndarray:
        """Flip keypoints horizontally, vertically or diagonally.

        Args:
            keypoints (numpy.ndarray): Keypoints whose last dimension holds
                ``(x, y)`` optionally followed by extra per-point columns.
                The extra columns are carried over unchanged.
            img_shape (tuple[int]): Image shape (height, width)
            direction (str): Flip direction. Options are 'horizontal',
                'vertical' and 'diagonal'.

        Returns:
            numpy.ndarray: Flipped keypoints (a new array).

        Raises:
            ValueError: If ``direction`` is not one of the options above.
        """
        # Split coordinates from any trailing per-point columns so that only
        # the coordinates are mirrored.
        meta_info = keypoints[..., 2:]
        keypoints = keypoints[..., :2]
        flipped = keypoints.copy()
        h, w = img_shape
        if direction == 'horizontal':
            flipped[..., 0] = w - keypoints[..., 0]
        elif direction == 'vertical':
            flipped[..., 1] = h - keypoints[..., 1]
        elif direction == 'diagonal':
            flipped[..., 0] = w - keypoints[..., 0]
            flipped[..., 1] = h - keypoints[..., 1]
        else:
            raise ValueError(
                "Flipping direction must be 'horizontal', 'vertical', "
                f"or 'diagonal', but got '{direction}'")
        # Re-attach the extra columns to the *flipped* coordinates.
        flipped = np.concatenate([flipped, meta_info], axis=-1)
        return flipped

    def _choose_direction(self) -> Optional[str]:
        """Choose the flip direction according to `prob` and `direction`.

        Returns:
            str | None: The sampled direction, or None when no flip should
            be applied.
        """
        if isinstance(self.direction,
                      Sequence) and not isinstance(self.direction, str):
            # None means non-flip
            direction_list: list = list(self.direction) + [None]
        elif isinstance(self.direction, str):
            # None means non-flip
            direction_list = [self.direction, None]

        if isinstance(self.prob, list):
            non_prob: float = 1 - sum(self.prob)
            prob_list = self.prob + [non_prob]
        elif isinstance(self.prob, float):
            non_prob = 1. - self.prob
            # A single probability is shared evenly by all real directions
            # (the trailing None entry is excluded from the split).
            single_ratio = self.prob / (len(direction_list) - 1)
            prob_list = [single_ratio] * (len(direction_list) - 1) + [non_prob]

        cur_dir = np.random.choice(direction_list, p=prob_list)

        return cur_dir

    def _flip(self, results: dict) -> None:
        """Flip images, bounding boxes, semantic segmentation map and
        keypoints along ``results['flip_direction']``, in place."""
        # flip image
        results['img'] = mmcv.imflip(
            results['img'], direction=results['flip_direction'])

        img_shape = results['img'].shape[:2]

        # flip bboxes
        if results.get('gt_bboxes', None) is not None:
            results['gt_bboxes'] = self.flip_bbox(results['gt_bboxes'],
                                                  img_shape,
                                                  results['flip_direction'])

        # flip keypoints
        if results.get('gt_keypoints', None) is not None:
            results['gt_keypoints'] = self.flip_keypoints(
                results['gt_keypoints'], img_shape, results['flip_direction'])

        # flip segs
        if results.get('gt_semantic_seg', None) is not None:
            results['gt_semantic_seg'] = mmcv.imflip(
                results['gt_semantic_seg'],
                direction=results['flip_direction'])

    def _flip_on_direction(self, results: dict) -> None:
        """Sample a direction, record it in ``results`` and flip images,
        bounding boxes, semantic segmentation map and keypoints if one was
        chosen."""
        cur_dir = self._choose_direction()
        if cur_dir is None:
            results['flip'] = False
            results['flip_direction'] = None
        else:
            results['flip'] = True
            results['flip_direction'] = cur_dir
            self._flip(results)

    def transform(self, results: dict) -> dict:
        """Transform function to flip images, bounding boxes, semantic
        segmentation map and keypoints.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Flipped results, 'img', 'gt_bboxes', 'gt_semantic_seg',
            'gt_keypoints', 'flip', and 'flip_direction' keys are
            updated in result dict.
        """
        self._flip_on_direction(results)

        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(prob={self.prob}, '
        repr_str += f'direction={self.direction})'

        return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
class RandomResize(BaseTransform):
    """Random resize images & bbox & keypoints.

    A target scale is sampled on every call and the actual resizing is
    delegated to an internal :class:`Resize` instance configured with
    ``keep_ratio``, ``clip_object_border``, ``backend`` and
    ``interpolation``.

    Added or updated keys: scale, scale_factor, keep_ratio, img, height, width,
    gt_bboxes, gt_semantic_seg, and gt_keypoints.

    How to choose the target scale to resize the image will follow the rules
    below:

    - if `scale` is a list of two tuples, the first value of the target scale
      is sampled uniformly between the smallest and the largest *long edge*
      of the two tuples, and the second value is sampled uniformly between
      the smallest and the largest *short edge* of the two tuples (both
      bounds inclusive).
    - if `scale` is a tuple, the first and second values of the target scale
      are equal to the first and second values of `scale` multiplied by a
      value sampled from [`ratio_range[0]`, `ratio_range[1]`] uniformly.
      ``ratio_range`` is required in this case.

    Required Keys:

    - img
    - gt_bboxes
    - gt_semantic_seg
    - gt_keypoints

    Modified Keys:

    - img
    - gt_bboxes
    - gt_semantic_seg
    - gt_keypoints

    Added Keys:

    - scale
    - scale_factor
    - keep_ratio

    Args:
        scale (tuple or list[tuple]): Images scales for resizing. Although
            the signature default is None, a value must be given.
        ratio_range (tuple[float], optional): (min_ratio, max_ratio).
            Defaults to None.
        keep_ratio (bool): Whether to keep the aspect ratio when resizing the
            image. Defaults to True.
        clip_object_border (bool): Whether to clip the objects
            outside the border of the image. In some dataset like MOT17, the
            gt bboxes are allowed to cross the border of images. Therefore,
            we don't need to clip the gt bboxes in these cases.
            Defaults to True.
        backend (str): Image resize backend, choices are 'cv2' and 'pillow'.
            These two backends generates slightly different results. Defaults
            to 'cv2'.
        interpolation (str): How to interpolate the original image when
            resizing. Defaults to 'bilinear'.
    """

    def __init__(self,
                 scale: Optional[Union[Tuple[int, int],
                                       List[Tuple[int, int]]]] = None,
                 ratio_range: Optional[Tuple[float, float]] = None,
                 keep_ratio: bool = True,
                 clip_object_border: bool = True,
                 backend: str = 'cv2',
                 interpolation: str = 'bilinear') -> None:

        assert scale is not None

        self.scale = scale
        self.ratio_range = ratio_range
        self.keep_ratio = keep_ratio
        self.clip_object_border = clip_object_border
        self.backend = backend
        self.interpolation = interpolation

        # Create a placeholder Resize object; its scale is overwritten with
        # the freshly sampled one on every ``transform`` call.
        self.resize = Resize(0)
        self.resize.keep_ratio = keep_ratio
        self.resize.clip_object_border = clip_object_border
        self.resize.backend = backend
        self.resize.interpolation = interpolation

    @staticmethod
    def _random_sample(scales: Sequence[Tuple[int, int]]) -> Tuple[int, int]:
        """Private function to randomly sample a scale from a list of tuples.

        Args:
            scales (list[tuple]): Images scale range for sampling.
                There must be two tuples in scales, which specify the lower
                and upper bound of image scales.

        Returns:
            tuple: The target scale as ``(long_edge, short_edge)``.
        """

        assert mmcv.is_list_of(scales, tuple) and len(scales) == 2
        scale_long = [max(s) for s in scales]
        scale_short = [min(s) for s in scales]
        # ``randint`` excludes the upper bound, hence the ``+ 1``.
        long_edge = np.random.randint(min(scale_long), max(scale_long) + 1)
        short_edge = np.random.randint(min(scale_short), max(scale_short) + 1)
        scale = (long_edge, short_edge)
        return scale

    @staticmethod
    def _random_sample_ratio(
            scale: tuple, ratio_range: Tuple[float, float]) -> Tuple[int, int]:
        """Private function to randomly sample a scale from a tuple.

        A ratio will be randomly sampled from the range specified by
        ``ratio_range``. Then it would be multiplied with ``scale`` to
        generate sampled scale.

        Args:
            scale (tuple): Images scale base to multiply with ratio.
            ratio_range (tuple[float]): The minimum and maximum ratio to scale
                the ``scale``.

        Returns:
            tuple: Returns the target scale (both values truncated to int).
        """

        assert isinstance(scale, tuple) and len(scale) == 2
        min_ratio, max_ratio = ratio_range
        assert min_ratio <= max_ratio
        ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio
        scale = int(scale[0] * ratio), int(scale[1] * ratio)
        return scale

    def _random_scale(self, results: dict) -> None:
        """Private function to randomly sample a scale according to the type
        of `scale` and store it in ``results['scale']``.

        Args:
            results (dict): Result dict from :obj:`dataset`. The key
                'scale' is added in place.

        Raises:
            NotImplementedError: If ``self.scale`` is neither a tuple nor a
                list of tuples.
        """

        if isinstance(self.scale, tuple):
            assert self.ratio_range is not None and len(self.ratio_range) == 2
            scale: Tuple[int, int] = self._random_sample_ratio(
                self.scale, self.ratio_range)
        elif mmcv.is_list_of(self.scale, tuple):
            scale = self._random_sample(self.scale)
        else:
            raise NotImplementedError('Do not support sampling function '
                                      f"for '{self.scale}'")

        results['scale'] = scale

    def transform(self, results: dict) -> dict:
        """Transform function to resize images, bounding boxes, semantic
        segmentation map.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Resized results, 'img', 'gt_bboxes', 'gt_semantic_seg',
            'gt_keypoints', 'scale', 'scale_factor', 'height', 'width',
            and 'keep_ratio' keys are updated in result dict.
        """
        self._random_scale(results)
        # Hand the sampled scale to the wrapped Resize, which does the work.
        self.resize.scale = results['scale']
        results = self.resize.transform(results)
        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(scale={self.scale}, '
        repr_str += f'ratio_range={self.ratio_range}, '
        repr_str += f'keep_ratio={self.keep_ratio}, '
        repr_str += f'clip_object_border={self.clip_object_border}, '
        repr_str += f'backend={self.backend}, '
        repr_str += f'interpolation={self.interpolation})'
        return repr_str
|
||||
|
|
|
@ -0,0 +1,101 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
try:
|
||||
import torch
|
||||
except ModuleNotFoundError:
|
||||
torch = None
|
||||
else:
|
||||
from mmcv.transforms import ToTensor, to_tensor, ImageToTensor
|
||||
|
||||
import copy
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.skipif(condition=torch is None, reason='No torch in current env')
def test_to_tensor():
    """Every supported input type converts; a string is rejected."""
    # One representative per supported type: torch.Tensor, numpy.ndarray,
    # list, int and float.
    supported_inputs = (
        torch.tensor([1, 2, 3]),
        np.array([1, 2, 3]),
        [1, 2, 3],
        1,
        1.0,
    )
    for sample in supported_inputs:
        converted = to_tensor(sample)
        assert isinstance(converted, torch.Tensor)

    # The type of the input object is invalid
    with pytest.raises(TypeError):
        to_tensor('123')
|
||||
|
||||
|
||||
@pytest.mark.skipif(condition=torch is None, reason='No torch in current env')
class TestToTensor:
    """Tests for the ``ToTensor`` transform."""

    def test_init(self):
        transform = ToTensor(keys=['img_label'])
        assert transform.keys == ['img_label']

    def test_transform(self):
        transform = ToTensor(['instances.bbox', 'img_label'])

        # Multi-level key that is missing from the results: the transform
        # must raise. Only the raising call lives inside the context manager
        # (anything after it would never run).
        results = {'instances': {'label': [1]}, 'img_label': [1]}
        with pytest.raises(KeyError):
            transform.transform(copy.deepcopy(results))

        # Multi-level key and single-level key are both present: both
        # entries are converted.
        results = {'instances': {'bbox': [[0, 0, 10, 10]]}, 'img_label': [1]}
        results_tensor = transform.transform(copy.deepcopy(results))
        assert isinstance(results_tensor['instances']['bbox'], torch.Tensor)
        assert isinstance(results_tensor['img_label'], torch.Tensor)

    def test_repr(self):
        transform = ToTensor(['instances.bbox', 'img_label'])
        assert isinstance(str(transform), str)
        assert repr(transform) == (
            "ToTensor(keys=['instances.bbox', 'img_label'])")
|
||||
|
||||
|
||||
@pytest.mark.skipif(condition=torch is None, reason='No torch in current env')
class TestImageToTensor:
    """Tests for the ``ImageToTensor`` transform."""

    def test_init(self):
        to_chw = ImageToTensor(['img'])
        assert to_chw.keys == ['img']

    def test_transform(self):
        to_chw = ImageToTensor(['img'])

        # A 2-D (H, W) image gains a leading channel axis of size one.
        gray = to_chw.transform({'img': np.zeros((224, 224))})
        assert gray['img'].shape == (1, 224, 224)

        # A 3-channel (H, W, C) image is transposed to (C, H, W).
        color = to_chw.transform({'img': np.zeros((224, 224, 3))})
        assert color['img'].shape == (3, 224, 224)

    def test_repr(self):
        to_chw = ImageToTensor(['img'])
        text = str(to_chw)
        assert isinstance(text, str)
|
|
@ -6,7 +6,7 @@ import numpy as np
|
|||
import pytest
|
||||
|
||||
import mmcv
|
||||
from mmcv.transforms import Normalize, Pad, Resize
|
||||
from mmcv.transforms import Normalize, Pad, RandomFlip, RandomResize, Resize
|
||||
|
||||
|
||||
class TestNormalize:
|
||||
|
@ -183,3 +183,154 @@ class TestPad:
|
|||
assert repr(trans) == (
|
||||
'Pad(size=None, size_divisor=11, pad_to_square=True, '
|
||||
"pad_val={'img': 0, 'seg': 255}), padding_mode=edge)")
|
||||
|
||||
|
||||
class TestRandomFlip:
    """Tests for the ``RandomFlip`` transform."""

    def test_init(self):

        # prob is float
        transform = RandomFlip(0.1)
        assert transform.prob == 0.1

        # prob is None: rejected at construction time
        with pytest.raises(ValueError):
            RandomFlip(None)

        # prob is a list
        transform = RandomFlip([0.1, 0.2], ['horizontal', 'vertical'])
        assert len(transform.prob) == 2
        assert len(transform.direction) == 2

        # direction is an invalid type
        with pytest.raises(ValueError):
            RandomFlip(0.1, 1)

        # prob is an invalid type
        with pytest.raises(ValueError):
            RandomFlip('0.1')

    def test_transform(self):

        results = {
            'img': np.random.random((224, 224, 3)),
            'gt_bboxes': np.array([[0, 1, 100, 101]]),
            'gt_keypoints': np.array([[[100, 100, 1.0]]]),
            'gt_semantic_seg': np.random.random((224, 224, 3))
        }

        # horizontal flip
        transform = RandomFlip([1.0], ['horizontal'])
        results_update = transform.transform(copy.deepcopy(results))
        assert (results_update['gt_bboxes'] == np.array([[124, 1, 224,
                                                          101]])).all()

        # diagonal flip
        transform = RandomFlip([1.0], ['diagonal'])
        results_update = transform.transform(copy.deepcopy(results))
        assert (results_update['gt_bboxes'] == np.array([[124, 123, 224,
                                                          223]])).all()

        # vertical flip
        transform = RandomFlip([1.0], ['vertical'])
        results_update = transform.transform(copy.deepcopy(results))
        assert (results_update['gt_bboxes'] == np.array([[0, 123, 100,
                                                          223]])).all()

        # horizontal flip when direction is left at its default
        transform = RandomFlip(1.0)
        results_update = transform.transform(copy.deepcopy(results))
        assert (results_update['gt_bboxes'] == np.array([[124, 1, 224,
                                                          101]])).all()

        # probability zero: nothing is flipped
        transform = RandomFlip(0.0)
        results_update = transform.transform(copy.deepcopy(results))
        assert (results_update['gt_bboxes'] == np.array([[0, 1, 100,
                                                          101]])).all()

        # Build the transform outside the ``raises`` blocks so that only the
        # call expected to fail is guarded.
        transform = RandomFlip(1.0)

        # flip direction is invalid in bbox flip
        with pytest.raises(ValueError):
            transform.flip_bbox(results['gt_bboxes'], (224, 224), 'invalid')

        # flip direction is invalid in keypoints flip
        with pytest.raises(ValueError):
            transform.flip_keypoints(results['gt_keypoints'], (224, 224),
                                     'invalid')

    def test_repr(self):
        transform = RandomFlip(0.1)
        transform_str = str(transform)
        assert isinstance(transform_str, str)
|
||||
|
||||
|
||||
class TestRandomResize:
    """Tests for the ``RandomResize`` transform."""

    def test_init(self):
        transform = RandomResize(
            (224, 224),
            (1.0, 2.0),
        )
        assert transform.scale == (224, 224)

    def test_repr(self):
        transform = RandomResize(
            (224, 224),
            (1.0, 2.0),
        )
        transform_str = str(transform)
        assert isinstance(transform_str, str)

    def test_transform(self):

        # scale is a tuple: the target scale is ``scale`` times a ratio
        # sampled from ``ratio_range``
        results = {}
        transform = RandomResize((224, 224), (1.0, 2.0))
        results_update = transform.transform(copy.deepcopy(results))
        assert 224 <= results_update['scale'][0] <= 448
        assert 224 <= results_update['scale'][1] <= 448

        # keep ratio is True
        results = {
            'img': np.random.random((224, 224, 3)),
            'gt_semantic_seg': np.random.random((224, 224, 3)),
            'gt_bboxes': np.array([[0, 0, 112, 112]]),
            'gt_keypoints': np.array([[[112, 112]]])
        }
        transform = RandomResize((224, 224), (1.0, 2.0), keep_ratio=True)
        results_update = transform.transform(copy.deepcopy(results))
        assert 224 <= results_update['height']
        assert 448 >= results_update['height']
        assert 224 <= results_update['width']
        assert 448 >= results_update['width']
        assert results_update['keep_ratio']
        # The transform ran on a deep copy, so the input dict must still
        # hold the original box.
        assert results['gt_bboxes'][0][2] == 112

        # keep ratio is False (smoke test: must run without raising)
        transform = RandomResize((224, 224), (1.0, 2.0), keep_ratio=False)
        results_update = transform.transform(copy.deepcopy(results))

        # scale is a list of two tuples: each edge is sampled between the
        # corresponding bounds
        results = {}
        transform = RandomResize([(224, 448), (112, 224)], keep_ratio=True)
        results_update = transform.transform(copy.deepcopy(results))
        assert 224 <= results_update['scale'][0] <= 448
        assert 112 <= results_update['scale'][1] <= 224

        # the type of scale is invalid: construction succeeds, the error
        # surfaces when a scale is sampled inside ``transform``
        results = {}
        transform = RandomResize([(224, 448), [112, 224]], keep_ratio=True)
        with pytest.raises(NotImplementedError):
            transform.transform(copy.deepcopy(results))
|
||||
|
|
Loading…
Reference in New Issue