[Feature] Add Part1 of data transform (#1736)

* [Feature] Add Part1 of data transform

* api.rst transform->transforms

* fix typo

* fix typo

* rename test_processing into test_transforms_processing for avoiding same name

* fix mypy

* fix comment

* fix comment
pull/2133/head
liukuikun 2022-02-25 11:36:11 +08:00 committed by zhouzaida
parent 95b5a07b90
commit 9e4b2ff58e
8 changed files with 994 additions and 9 deletions

View File

@ -48,7 +48,7 @@ ops
.. automodule:: mmcv.ops
:members:
transform
transforms
---------
.. automodule:: mmcv.transform
.. automodule:: mmcv.transforms
:members:

View File

@ -3,7 +3,7 @@
from .arraymisc import *
from .fileio import *
from .image import *
from .transform import *
from .transforms import *
from .utils import *
from .version import *
from .video import *

View File

@ -1,5 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import TRANSFORMS
from .loading import LoadAnnotation, LoadImageFromFile
from .processing import Normalize, Pad, Resize
from .wrappers import ApplyToMultiple, Compose, RandomChoice, Remap
__all__ = ['TRANSFORMS', 'ApplyToMultiple', 'Compose', 'RandomChoice', 'Remap']
__all__ = [
'TRANSFORMS', 'ApplyToMultiple', 'Compose', 'RandomChoice', 'Remap',
'LoadImageFromFile', 'LoadAnnotation', 'Normalize', 'Resize', 'Pad'
]

View File

@ -0,0 +1,271 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import mmcv
from .base import BaseTransform
from .builder import TRANSFORMS
@TRANSFORMS.register_module()
class LoadImageFromFile(BaseTransform):
"""Load an image from file.
Required Keys:
- img_path
Modified Keys:
- img
- width
- height
- ori_width
- ori_height
Args:
to_float32 (bool): Whether to convert the loaded image to a float32
numpy array. If set to False, the loaded image is an uint8 array.
Defaults to False.
color_type (str): The flag argument for :func:``mmcv.imfrombytes``.
Defaults to 'color'.
imdecode_backend (str): The image decoding backend type. The backend
argument for :func:``mmcv.imfrombytes``.
See :func:``mmcv.imfrombytes`` for details.
Defaults to 'cv2'.
file_client_args (dict): Arguments to instantiate a FileClient.
See :class:`mmcv.fileio.FileClient` for details.
Defaults to ``dict(backend='disk')``.
"""
def __init__(
self,
to_float32: bool = False,
color_type: str = 'color',
imdecode_backend: str = 'cv2',
file_client_args: dict = dict(backend='disk')
) -> None:
self.to_float32 = to_float32
self.color_type = color_type
self.imdecode_backend = imdecode_backend
self.file_client_args = file_client_args.copy()
self.file_client = mmcv.FileClient(**self.file_client_args)
def transform(self, results: dict) -> dict:
"""Functions to load image.
Args:
results (dict): Result dict from :obj:``mmcv.BaseDataset``.
Returns:
dict: The dict contains loaded image and meta information.
"""
filename = results['img_path']
img_bytes = self.file_client.get(filename)
img = mmcv.imfrombytes(
img_bytes, flag=self.color_type, backend=self.imdecode_backend)
if self.to_float32:
img = img.astype(np.float32)
results['img'] = img
height, width = img.shape[:2]
results['height'] = height
results['width'] = width
results['ori_height'] = height
results['ori_width'] = width
return results
def __repr__(self):
repr_str = (f'{self.__class__.__name__}('
f'to_float32={self.to_float32}, '
f"color_type='{self.color_type}', "
f"imdecode_backend='{self.imdecode_backend}', "
f'file_client_args={self.file_client_args})')
return repr_str
class LoadAnnotation(BaseTransform):
"""Load and process the ``instances`` and ``seg_map`` annotation provided
by dataset.
The annotation format is as the following:
.. code-block:: python
{
'instances':
[
{
# List of 4 numbers representing the bounding box of the
# instance, in (x1, y1, x2, y2) order.
'bbox': [x1, y1, x2, y2],
# Label of image classification.
'bbox_label': 1,
# Used in key point detection.
# Can only load the format of [x1, y1, v1,…, xn, yn, vn]. v[i]
# means the visibility of this keypoint. n must be equal to the
# number of keypoint categories.
'keypoints': [x1, y1, v1, ..., xn, yn, vn]
}
]
# Filename of semantic or panoptic segmentation ground truth file.
'seg_map': 'a/b/c'
}
After this module, the annotation has been changed to the format below:
.. code-block:: python
{
# In (x1, y1, x2, y2) order, float type. N is the number of bboxes
# in an image
'gt_bboxes': np.ndarray(N, 4)
# In int type.
'gt_bboxes_labels': np.ndarray(N, )
# In uint8 type.
'gt_semantic_seg': np.ndarray (H, W)
# in (x, y, v) order, float type.
'gt_keypoints': np.ndarray(N, NK, 3)
}
Required Keys:
- instances
- bbox (optional)
- bbox_label
- keypoints (optional)
- seg_map (optional)
Added Keys:
- gt_bboxes
- gt_bboxes_labels
- gt_semantic_seg
- gt_keypoints
Args:
with_bbox (bool): Whether to parse and load the bbox annotation.
Defaults to True.
with_label (bool): Whether to parse and load the label annotation.
Defaults to True.
with_seg (bool): Whether to parse and load the semantic segmentation
annotation. Defaults to False.
with_kps (bool): Whether to parse and load the keypoints annotation.
Defaults to False.
imdecode_backend (str): The image decoding backend type. The backend
argument for :func:``mmcv.imfrombytes``.
See :fun:``mmcv.imfrombytes`` for details.
Defaults to 'cv2'.
file_client_args (dict): Arguments to instantiate a FileClient.
See :class:``mmcv.fileio.FileClient`` for details.
Defaults to ``dict(backend='disk')``.
"""
def __init__(
self,
with_bbox: bool = True,
with_label: bool = True,
with_seg: bool = False,
with_kps: bool = False,
imdecode_backend: str = 'cv2',
file_client_args: dict = dict(backend='disk')
) -> None:
super().__init__()
self.with_bbox = with_bbox
self.with_label = with_label
self.with_seg = with_seg
self.with_kps = with_kps
self.imdecode_backend = imdecode_backend
self.file_client_args = file_client_args.copy()
self.file_client = mmcv.FileClient(**self.file_client_args)
def _load_bboxes(self, results: dict) -> None:
"""Private function to load bounding box annotations.
Args:
results (dict): Result dict from :obj:``mmcv.BaseDataset``.
Returns:
dict: The dict contains loaded bounding box annotations.
"""
gt_bboxes = []
for instance in results['instances']:
gt_bboxes.append(instance['bbox'])
results['gt_bboxes'] = np.array(gt_bboxes)
def _load_labels(self, results: dict) -> None:
"""Private function to load label annotations.
Args:
results (dict): Result dict from :obj :obj:``mmcv.BaseDataset``.
Returns:
dict: The dict contains loaded label annotations.
"""
gt_bboxes_labels = []
for instance in results['instances']:
gt_bboxes_labels.append(instance['bbox_label'])
results['gt_bboxes_labels'] = np.array(gt_bboxes_labels)
def _load_semantic_seg(self, results: dict) -> None:
"""Private function to load semantic segmentation annotations.
Args:
results (dict): Result dict from :obj:``mmcv.BaseDataset``.
Returns:
dict: The dict contains loaded semantic segmentation annotations.
"""
img_bytes = self.file_client.get(results['seg_map'])
results['gt_semantic_seg'] = mmcv.imfrombytes(
img_bytes, flag='unchanged',
backend=self.imdecode_backend).squeeze()
def _load_kps(self, results: dict) -> None:
"""Private function to load keypoints annotations.
Args:
results (dict): Result dict from :obj:``mmcv.BaseDataset``.
Returns:
dict: The dict contains loaded keypoints annotations.
"""
gt_keypoints = []
for instance in results['instances']:
gt_keypoints.append(instance['keypoints'])
results['gt_keypoints'] = np.array(gt_keypoints).reshape(
(len(gt_keypoints), -1, 3))
def transform(self, results: dict) -> dict:
"""Function to load multiple types annotations.
Args:
results (dict): Result dict from :obj:``mmcv.BaseDataset``.
Returns:
dict: The dict contains loaded bounding box, label and
semantic segmentation and keypoints annotations.
"""
if self.with_bbox:
self._load_bboxes(results)
if self.with_label:
self._load_labels(results)
if self.with_seg:
self._load_semantic_seg(results)
if self.with_kps:
self._load_kps(results)
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(with_bbox={self.with_bbox}, '
repr_str += f'with_label={self.with_label}, '
repr_str += f'with_seg={self.with_seg}, '
repr_str += f'with_kps={self.with_kps}, '
repr_str += f"imdecode_backend='{self.imdecode_backend}', "
repr_str += f'file_client_args={self.file_client_args})'
return repr_str

View File

@ -0,0 +1,402 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence, Tuple, Union
import numpy as np
import mmcv
from mmcv.image.geometric import _scale_size
from .base import BaseTransform
from .builder import TRANSFORMS
@TRANSFORMS.register_module()
class Normalize(BaseTransform):
"""Normalize the image.
Required Keys:
- img
Added Keys:
- img_norm_cfg
- mean
- std
- to_rgb
Args:
mean (sequence): Mean values of 3 channels.
std (sequence): Std values of 3 channels.
to_rgb (bool): Whether to convert the image from BGR to RGB before
normlizing the image. If ``to_rgb=True``, the order of mean and std
should be RGB. If ``to_rgb=False``, the order of mean and std
should be BGR. Defaults to True.
"""
def __init__(self,
mean: Sequence[float],
std: Sequence[float],
to_rgb: bool = True) -> None:
self.mean = np.array(mean, dtype=np.float32)
self.std = np.array(std, dtype=np.float32)
self.to_rgb = to_rgb
def transform(self, results: dict) -> dict:
"""Function to normalize images.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Normalized results, key 'img_norm_cfg' key is added in to
result dict.
"""
results['img'] = mmcv.imnormalize(results['img'], self.mean, self.std,
self.to_rgb)
results['img_norm_cfg'] = dict(
mean=self.mean, std=self.std, to_rgb=self.to_rgb)
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(mean={self.mean}, std={self.std}, to_rgb={self.to_rgb})'
return repr_str
@TRANSFORMS.register_module()
class Resize(BaseTransform):
"""Resize images & bbox & seg & keypoints.
This transform resizes the input image according to ``scale`` or
``scale_factor``. Bboxes, seg map and keypoints are then resized with the
same scale factor.
if ``scale`` and ``scale_factor`` are both set, it will use ``scale`` to
resize.
Required Keys:
- img
- gt_bboxes (optional)
- gt_semantic_seg (optional)
- gt_keypoints (optional)
Modified Keys:
- img
- gt_bboxes
- gt_semantic_seg
- gt_keypoints
- height
- width
Added Keys:
- scale
- scale_factor
- keep_ratio
Args:
scale (int or tuple): Images scales for resizing. Defaults to None
scale_factor (float or tuple[float]): Scale factors for resizing.
Defaults to None.
keep_ratio (bool): Whether to keep the aspect ratio when resizing the
image. Defaults to False.
clip_object_border (bool, optional): Whether to clip the objects
outside the border of the image. In some dataset like MOT17, the gt
bboxes are allowed to cross the border of images. Therefore, we
don't need to clip the gt bboxes in these cases. Defaults to True.
backend (str): Image resize backend, choices are 'cv2' and 'pillow'.
These two backends generates slightly different results. Defaults
to 'cv2'.
interpolation (str): Interpolation method, accepted values are
"nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
backend, "nearest", "bilinear" for 'pillow' backend. Defaults
to 'cv2'.
"""
def __init__(self,
scale: Optional[Union[int, Tuple[int, int]]] = None,
scale_factor: Optional[Union[float, Tuple[float,
float]]] = None,
keep_ratio: bool = False,
clip_object_border: bool = True,
backend: str = 'cv2',
interpolation='bilinear') -> None:
assert scale is not None or scale_factor is not None, (
'`scale` and'
'`scale_factor` can not both be `None`')
if scale is None:
self.scale = None
else:
if isinstance(scale, int):
self.scale = (scale, scale)
else:
self.scale = scale
self.backend = backend
self.interpolation = interpolation
self.keep_ratio = keep_ratio
self.clip_object_border = clip_object_border
if scale_factor is None:
self.scale_factor = None
elif isinstance(scale_factor, float):
self.scale_factor = (scale_factor, scale_factor)
elif isinstance(scale_factor, tuple):
assert (len(scale_factor)) == 2
self.scale_factor = scale_factor
else:
raise TypeError(
f'expect scale_factor is float or Tuple(float), but'
f'get {type(scale_factor)}')
def _resize_img(self, results: dict) -> None:
"""Resize images with ``results['scale']``."""
if results.get('img', None) is not None:
if self.keep_ratio:
img, scale_factor = mmcv.imrescale(
results['img'],
results['scale'],
interpolation=self.interpolation,
return_scale=True,
backend=self.backend)
# the w_scale and h_scale has minor difference
# a real fix should be done in the mmcv.imrescale in the future
new_h, new_w = img.shape[:2]
h, w = results['img'].shape[:2]
w_scale = new_w / w
h_scale = new_h / h
else:
img, w_scale, h_scale = mmcv.imresize(
results['img'],
results['scale'],
interpolation=self.interpolation,
return_scale=True,
backend=self.backend)
results['img'] = img
results['height'], results['width'] = img.shape[:2]
results['scale'] = img.shape[:2][::-1]
results['scale_factor'] = (w_scale, h_scale)
results['keep_ratio'] = self.keep_ratio
def _resize_bboxes(self, results: dict) -> None:
"""Resize bounding boxes with ``results['scale_factor']``."""
if results.get('gt_bboxes', None) is not None:
bboxes = results['gt_bboxes'] * np.tile(
np.array(results['scale_factor']), 2)
if self.clip_object_border:
bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, results['width'])
bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0,
results['height'])
results['gt_bboxes'] = bboxes
def _resize_seg(self, results: dict) -> None:
"""Resize semantic segmentation map with ``results['scale']``."""
if results.get('gt_semantic_seg', None) is not None:
if self.keep_ratio:
gt_seg = mmcv.imrescale(
results['gt_semantic_seg'],
results['scale'],
interpolation=self.interpolation,
backend=self.backend)
else:
gt_seg = mmcv.imresize(
results['gt_semantic_seg'],
results['scale'],
interpolation=self.interpolation,
backend=self.backend)
results['gt_semantic_seg'] = gt_seg
def _resize_keypoints(self, results: dict) -> None:
"""Resize keypoints with ``results['scale_factor']``."""
if results.get('gt_keypoints', None) is not None:
keypoints = results['gt_keypoints']
keypoints[:, :, :2] = keypoints[:, :, :2] * np.array(
results['scale_factor'])
if self.clip_object_border:
keypoints[:, :, 0] = np.clip(keypoints[:, :, 0], 0,
results['width'])
keypoints[:, :, 1] = np.clip(keypoints[:, :, 1], 0,
results['height'])
results['gt_keypoints'] = keypoints
def transform(self, results: dict) -> dict:
"""Transform function to resize images, bounding boxes, semantic
segmentation map and keypoints.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Resized results, 'img', 'gt_bboxes', 'gt_semantic_seg',
'gt_keypoints', 'scale', 'scale_factor', 'height', 'width',
and 'keep_ratio' keys are updated in result dict.
"""
if self.scale:
results['scale'] = self.scale
else:
img_shape = results['img'].shape[:2]
results['scale'] = _scale_size(img_shape[::-1], self.scale_factor)
self._resize_img(results)
self._resize_bboxes(results)
self._resize_seg(results)
self._resize_keypoints(results)
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(scale={self.scale}, '
repr_str += f'scale_factor={self.scale_factor}, '
repr_str += f'keep_ratio={self.keep_ratio}, '
repr_str += f'clip_object_border={self.clip_object_border}), '
repr_str += f'backend={self.backend}), '
repr_str += f'interpolation={self.interpolation})'
return repr_str
@TRANSFORMS.register_module()
class Pad(BaseTransform):
"""Pad the image & segmentation map.
There are three padding modes: (1) pad to a fixed size and (2) pad to the
minimum size that is divisible by some number. and (3)pad to square. Also,
pad to square and pad to the minimum size can be used as the same time.
Required Keys:
- img
- gt_bboxes (optional)
- gt_semantic_seg (optional)
Modified Keys:
- img
- gt_semantic_seg
- height
- width
Added Keys:
- pad_shape
- pad_fixed_size
- pad_size_divisor
Args:
size (tuple, optional): Fixed padding size.
Expected padding shape (h, w)Defaults to None.
size_divisor (int, optional): The divisor of padded size. Defaults to
None.
pad_to_square (bool): Whether to pad the image into a square.
Currently only used for YOLOX. Defaults to False.
pad_val (int or dict, optional): A dict for padding value.
if ``type(pad_val) == int``, the val to pad seg is 255. Defaults to
``dict(img=0, seg=255)``.
padding_mode (str): Type of padding. Should be: constant, edge,
reflect or symmetric. Defaults to 'constant'.
- constant: pads with a constant value, this value is specified
with pad_val.
- edge: pads with the last value at the edge of the image.
- reflect: pads with reflection of image without repeating the last
value on the edge. For example, padding [1, 2, 3, 4] with 2
elements on both sides in reflect mode will result in
[3, 2, 1, 2, 3, 4, 3, 2].
- symmetric: pads with reflection of image repeating the last value
on the edge. For example, padding [1, 2, 3, 4] with 2 elements on
both sides in symmetric mode will result in
[2, 1, 1, 2, 3, 4, 4, 3]
"""
def __init__(self,
size: Optional[Tuple[int, int]] = None,
size_divisor: Optional[int] = None,
pad_to_square: bool = False,
pad_val: Union[int, dict] = dict(img=0, seg=255),
padding_mode: str = 'constant') -> None:
self.size = size
self.size_divisor = size_divisor
if isinstance(pad_val, int):
pad_val = dict(img=pad_val, seg=255)
assert isinstance(pad_val, dict), 'pad_val '
self.pad_val = pad_val
self.pad_to_square = pad_to_square
if pad_to_square:
assert size is None, \
'The size and size_divisor must be None ' \
'when pad2square is True'
else:
assert size is not None or size_divisor is not None, \
'only one of size and size_divisor should be valid'
assert size is None or size_divisor is None
assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
self.padding_mode = padding_mode
def _pad_img(self, results: dict) -> None:
"""Pad images according to ``self.size``."""
pad_val = self.pad_val.get('img', 0)
size = None
if self.pad_to_square:
max_size = max(results['img'].shape[:2])
size = (max_size, max_size)
if self.size_divisor is not None:
if size is None:
size = (results['img'].shape[0], results['img'].shape[1])
pad_h = int(np.ceil(
size[0] / self.size_divisor)) * self.size_divisor
pad_w = int(np.ceil(
size[1] / self.size_divisor)) * self.size_divisor
size = (pad_h, pad_w)
elif self.size is not None:
size = self.size[::-1]
padded_img = mmcv.impad(
results['img'],
shape=size,
pad_val=pad_val,
padding_mode=self.padding_mode)
results['img'] = padded_img
results['pad_shape'] = padded_img.shape
results['pad_fixed_size'] = self.size
results['pad_size_divisor'] = self.size_divisor
results['height'] = padded_img.shape[0]
results['width'] = padded_img.shape[1]
def _pad_seg(self, results: dict) -> None:
"""Pad semantic segmentation map according to
``results['pad_shape']``."""
if results.get('gt_semantic_seg', None) is not None:
pad_val = self.pad_val.get('seg', 255)
results['gt_semantic_seg'] = mmcv.impad(
results['gt_semantic_seg'],
shape=results['pad_shape'][:2],
pad_val=pad_val,
padding_mode=self.padding_mode)
def transform(self, results: dict) -> dict:
"""Call function to pad images, masks, semantic segmentation maps.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Updated result dict.
"""
self._pad_img(results)
self._pad_seg(results)
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(size={self.size}, '
repr_str += f'size_divisor={self.size_divisor}, '
repr_str += f'pad_to_square={self.pad_to_square}, '
repr_str += f'pad_val={self.pad_val}), '
repr_str += f'padding_mode={self.padding_mode})'
return repr_str

View File

@ -0,0 +1,122 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
import numpy as np
from mmcv.transforms import LoadAnnotation, LoadImageFromFile
class TestLoadImageFromFile:
def test_load_img(self):
data_prefix = osp.join(osp.dirname(__file__), '../data')
results = dict(img_path=osp.join(data_prefix, 'color.jpg'))
transform = LoadImageFromFile()
results = transform(copy.deepcopy(results))
assert results['img_path'] == osp.join(data_prefix, 'color.jpg')
assert results['img'].shape == (300, 400, 3)
assert results['img'].dtype == np.uint8
assert results['height'] == 300
assert results['width'] == 400
assert results['ori_height'] == 300
assert results['ori_width'] == 400
assert repr(transform) == transform.__class__.__name__ + \
"(to_float32=False, color_type='color', " + \
"imdecode_backend='cv2', file_client_args={'backend': 'disk'})"
# to_float32
transform = LoadImageFromFile(to_float32=True)
results = transform(copy.deepcopy(results))
assert results['img'].dtype == np.float32
# gray image
results = dict(img_path=osp.join(data_prefix, 'grayscale.jpg'))
transform = LoadImageFromFile()
results = transform(copy.deepcopy(results))
assert results['img'].shape == (300, 400, 3)
assert results['img'].dtype == np.uint8
transform = LoadImageFromFile(color_type='unchanged')
results = transform(copy.deepcopy(results))
assert results['img'].shape == (300, 400)
assert results['img'].dtype == np.uint8
class TestLoadAnnotation:
def setup_class(cls):
data_prefix = osp.join(osp.dirname(__file__), '../data')
seg_map = osp.join(data_prefix, 'grayscale.jpg')
cls.results = {
'seg_map':
seg_map,
'instances': [{
'bbox': [0, 0, 10, 20],
'bbox_label': 1,
'keypoints': [1, 2, 3]
}, {
'bbox': [10, 10, 110, 120],
'bbox_label': 2,
'keypoints': [4, 5, 6]
}]
}
def test_load_bboxes(self):
transform = LoadAnnotation(
with_bbox=True,
with_label=False,
with_seg=False,
with_kps=False,
)
results = transform(copy.deepcopy(self.results))
assert 'gt_bboxes' in results
assert (results['gt_bboxes'] == np.array([[0, 0, 10, 20],
[10, 10, 110, 120]])).all()
def test_load_labels(self):
transform = LoadAnnotation(
with_bbox=False,
with_label=True,
with_seg=False,
with_kps=False,
)
results = transform(copy.deepcopy(self.results))
assert 'gt_bboxes_labels' in results
assert (results['gt_bboxes_labels'] == np.array([1, 2])).all()
def test_load_kps(self):
transform = LoadAnnotation(
with_bbox=False,
with_label=False,
with_seg=False,
with_kps=True,
)
results = transform(copy.deepcopy(self.results))
assert 'gt_keypoints' in results
assert (results['gt_keypoints'] == np.array([[[1, 2, 3]],
[[4, 5, 6]]])).all()
def test_load_seg_map(self):
transform = LoadAnnotation(
with_bbox=False,
with_label=False,
with_seg=True,
with_kps=False,
)
results = transform(copy.deepcopy(self.results))
assert 'gt_semantic_seg' in results
assert results['gt_semantic_seg'].shape[:2] == (300, 400)
def test_repr(self):
transform = LoadAnnotation(
with_bbox=True,
with_label=False,
with_seg=False,
with_kps=False,
)
assert repr(transform) == ('LoadAnnotation(with_bbox=True, '
'with_label=False, with_seg=False, '
"with_kps=False, imdecode_backend='cv2', "
"file_client_args={'backend': 'disk'})")

View File

@ -0,0 +1,185 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
import numpy as np
import pytest
import mmcv
from mmcv.transforms import Normalize, Pad, Resize
class TestNormalize:
def test_normalize(self):
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True)
transform = Normalize(**img_norm_cfg)
results = dict()
img = mmcv.imread(
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
original_img = copy.deepcopy(img)
results['img'] = img
results = transform(results)
mean = np.array(img_norm_cfg['mean'])
std = np.array(img_norm_cfg['std'])
converted_img = (original_img[..., ::-1] - mean) / std
assert np.allclose(results['img'], converted_img)
def test_repr(self):
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True)
transform = Normalize(**img_norm_cfg)
assert repr(transform) == ('Normalize(mean=[123.675 116.28 103.53 ], '
'std=[58.395 57.12 57.375], to_rgb=True)')
class TestResize:
def test_resize(self):
data_info = dict(
img=np.random.random((1333, 800, 3)),
gt_semantic_seg=np.random.random((1333, 800, 3)),
gt_bboxes=np.array([[0, 0, 112, 112]]),
gt_keypoints=np.array([[[20, 50, 1]]]))
with pytest.raises(AssertionError):
transform = Resize(scale=None, scale_factor=None)
with pytest.raises(TypeError):
transform = Resize(scale_factor=[])
# test scale is int
transform = Resize(scale=2000)
results = transform(copy.deepcopy(data_info))
assert results['img'].shape[:2] == (2000, 2000)
assert results['scale_factor'] == (2000 / 800, 2000 / 1333)
# test scale is tuple
transform = Resize(scale=(2000, 2000))
results = transform(copy.deepcopy(data_info))
assert results['img'].shape[:2] == (2000, 2000)
assert results['scale_factor'] == (2000 / 800, 2000 / 1333)
# test scale_factor is float
transform = Resize(scale_factor=2.0)
results = transform(copy.deepcopy(data_info))
assert results['img'].shape[:2] == (2666, 1600)
assert results['scale_factor'] == (2.0, 2.0)
# test scale_factor is tuple
transform = Resize(scale_factor=(1.5, 2))
results = transform(copy.deepcopy(data_info))
assert results['img'].shape[:2] == (2666, 1200)
assert results['scale_factor'] == (1.5, 2)
# test keep_ratio is True
transform = Resize(scale=(2000, 2000), keep_ratio=True)
results = transform(copy.deepcopy(data_info))
assert results['img'].shape[:2] == (2000, 1200)
assert results['scale'] == (1200, 2000)
assert results['scale_factor'] == (1200 / 800, 2000 / 1333)
# test resize_bboxes/seg/kps
transform = Resize(scale_factor=(1.5, 2))
results = transform(copy.deepcopy(data_info))
assert (results['gt_bboxes'] == np.array([[0, 0, 168, 224]])).all()
assert (results['gt_keypoints'] == np.array([[[30, 100, 1]]])).all()
assert results['gt_semantic_seg'].shape[:2] == (2666, 1200)
# test clip_object_border = False
data_info = dict(
img=np.random.random((300, 400, 3)),
gt_bboxes=np.array([[200, 150, 600, 450]]))
transform = Resize(scale=(200, 150), clip_object_border=False)
results = transform(data_info)
assert (results['gt_bboxes'] == np.array([100, 75, 300, 225])).all()
def test_repr(self):
transform = Resize(scale=(2000, 2000), keep_ratio=True)
assert repr(transform) == ('Resize(scale=(2000, 2000), '
'scale_factor=None, keep_ratio=True, '
'clip_object_border=True), backend=cv2), '
'interpolation=bilinear)')
class TestPad:
def test_pad(self):
# test size and size_divisor are both set
with pytest.raises(AssertionError):
Pad(size=(10, 10), size_divisor=2)
# test size and size_divisor are both None
with pytest.raises(AssertionError):
Pad(size=None, size_divisor=None)
# test size and pad_to_square are both None
with pytest.raises(AssertionError):
Pad(size=(10, 10), pad_to_square=True)
# test pad_val is not int or tuple
with pytest.raises(AssertionError):
Pad(size=(10, 10), pad_val=[])
# test padding_mode is not 'constant', 'edge', 'reflect' or 'symmetric'
with pytest.raises(AssertionError):
Pad(size=(10, 10), padding_mode='edg')
data_info = dict(
img=np.random.random((1333, 800, 3)),
gt_semantic_seg=np.random.random((1333, 800, 3)),
gt_bboxes=np.array([[0, 0, 112, 112]]),
gt_keypoints=np.array([[[20, 50, 1]]]))
# test pad img / gt_semantic_seg with size
trans = Pad(size=(1200, 2000))
results = trans(copy.deepcopy(data_info))
assert results['img'].shape[:2] == (2000, 1200)
assert results['gt_semantic_seg'].shape[:2] == (2000, 1200)
# test pad img/gt_semantic_seg with size_divisor
trans = Pad(size_divisor=11)
results = trans(copy.deepcopy(data_info))
assert results['img'].shape[:2] == (1342, 803)
assert results['gt_semantic_seg'].shape[:2] == (1342, 803)
# test pad img/gt_semantic_seg with pad_to_square
trans = Pad(pad_to_square=True)
results = trans(copy.deepcopy(data_info))
assert results['img'].shape[:2] == (1333, 1333)
assert results['gt_semantic_seg'].shape[:2] == (1333, 1333)
# test pad img/gt_semantic_seg with pad_to_square and size_divisor
trans = Pad(pad_to_square=True, size_divisor=11)
results = trans(copy.deepcopy(data_info))
assert results['img'].shape[:2] == (1342, 1342)
assert results['gt_semantic_seg'].shape[:2] == (1342, 1342)
# test pad img/gt_semantic_seg with pad_to_square and size_divisor
trans = Pad(pad_to_square=True, size_divisor=11)
results = trans(copy.deepcopy(data_info))
assert results['img'].shape[:2] == (1342, 1342)
assert results['gt_semantic_seg'].shape[:2] == (1342, 1342)
# test padding_mode
new_img = np.ones((1333, 800, 3))
data_info['img'] = new_img
trans = Pad(pad_to_square=True, padding_mode='edge')
results = trans(copy.deepcopy(data_info))
assert (results['img'] == np.ones((1333, 1333, 3))).all()
# test pad_val
new_img = np.zeros((1333, 800, 3))
data_info['img'] = new_img
trans = Pad(pad_to_square=True, pad_val=0)
results = trans(copy.deepcopy(data_info))
assert (results['img'] == np.zeros((1333, 1333, 3))).all()
def test_repr(self):
trans = Pad(pad_to_square=True, size_divisor=11, padding_mode='edge')
assert repr(trans) == (
'Pad(size=None, size_divisor=11, pad_to_square=True, '
"pad_val={'img': 0, 'seg': 255}), padding_mode=edge)")

View File

@ -4,11 +4,11 @@ import warnings
import numpy as np
import pytest
from mmcv.transform.base import BaseTransform
from mmcv.transform.builder import TRANSFORMS
from mmcv.transform.utils import cache_random_params, cacheable_method
from mmcv.transform.wrappers import (ApplyToMultiple, Compose, RandomChoice,
Remap)
from mmcv.transforms.base import BaseTransform
from mmcv.transforms.builder import TRANSFORMS
from mmcv.transforms.utils import cache_random_params, cacheable_method
from mmcv.transforms.wrappers import (ApplyToMultiple, Compose, RandomChoice,
Remap)
@TRANSFORMS.register_module()