[Feature] Add P1 DataTransform (#1843)

* [Feature] Add P1 DataTransform

* fix unit test error

* fix @cache_randomness location
This commit is contained in:
MengzhangLI 2022-08-01 18:46:36 +08:00 committed by MeowZheng
parent 76c5ce1396
commit ecab73a892
3 changed files with 593 additions and 42 deletions

View File

@ -105,8 +105,8 @@ class MultiImageMixDataset:
transform_type in self._skip_type_keys: transform_type in self._skip_type_keys:
continue continue
if hasattr(transform, 'get_indexes'): if hasattr(transform, 'get_indices'):
indexes = transform.get_indexes(self.dataset) indexes = transform.get_indices(self.dataset)
if not isinstance(indexes, collections.abc.Sequence): if not isinstance(indexes, collections.abc.Sequence):
indexes = [indexes] indexes = [indexes]
mix_results = [ mix_results = [

View File

@ -9,13 +9,25 @@ from mmcv.transforms.utils import cache_randomness
from mmcv.utils import is_tuple_of from mmcv.utils import is_tuple_of
from numpy import random from numpy import random
from mmseg.datasets.dataset_wrappers import MultiImageMixDataset
from mmseg.registry import TRANSFORMS from mmseg.registry import TRANSFORMS
@TRANSFORMS.register_module() @TRANSFORMS.register_module()
class ResizeToMultiple(object): class ResizeToMultiple(BaseTransform):
"""Resize images & seg to multiple of divisor. """Resize images & seg to multiple of divisor.
Required Keys:
- img
- gt_seg_map
Modified Keys:
- img
- img_shape
- pad_shape
Args: Args:
size_divisor (int): images and gt seg maps need to resize to multiple size_divisor (int): images and gt seg maps need to resize to multiple
of size_divisor. Default: 32. of size_divisor. Default: 32.
@ -27,7 +39,7 @@ class ResizeToMultiple(object):
self.size_divisor = size_divisor self.size_divisor = size_divisor
self.interpolation = interpolation self.interpolation = interpolation
def __call__(self, results): def transform(self, results: dict) -> dict:
"""Call function to resize images, semantic segmentation map to """Call function to resize images, semantic segmentation map to
multiple of size divisor. multiple of size divisor.
@ -70,9 +82,17 @@ class ResizeToMultiple(object):
@TRANSFORMS.register_module() @TRANSFORMS.register_module()
class Rerange(object): class Rerange(BaseTransform):
"""Rerange the image pixel value. """Rerange the image pixel value.
Required Keys:
- img
Modified Keys:
- img
Args: Args:
min_value (float or int): Minimum value of the reranged image. min_value (float or int): Minimum value of the reranged image.
Default: 0. Default: 0.
@ -87,7 +107,7 @@ class Rerange(object):
self.min_value = min_value self.min_value = min_value
self.max_value = max_value self.max_value = max_value
def __call__(self, results): def transform(self, results: dict) -> dict:
"""Call function to rerange images. """Call function to rerange images.
Args: Args:
@ -116,12 +136,20 @@ class Rerange(object):
@TRANSFORMS.register_module() @TRANSFORMS.register_module()
class CLAHE(object): class CLAHE(BaseTransform):
"""Use CLAHE method to process the image. """Use CLAHE method to process the image.
See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J]. See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J].
Graphics Gems, 1994:474-485.` for more information. Graphics Gems, 1994:474-485.` for more information.
Required Keys:
- img
Modified Keys:
- img
Args: Args:
clip_limit (float): Threshold for contrast limiting. Default: 40.0. clip_limit (float): Threshold for contrast limiting. Default: 40.0.
tile_grid_size (tuple[int]): Size of grid for histogram equalization. tile_grid_size (tuple[int]): Size of grid for histogram equalization.
@ -136,7 +164,7 @@ class CLAHE(object):
assert len(tile_grid_size) == 2 assert len(tile_grid_size) == 2
self.tile_grid_size = tile_grid_size self.tile_grid_size = tile_grid_size
def __call__(self, results): def transform(self, results: dict) -> dict:
"""Call function to Use CLAHE method process images. """Call function to Use CLAHE method process images.
Args: Args:
@ -167,13 +195,13 @@ class RandomCrop(BaseTransform):
Required Keys: Required Keys:
- img - img
- gt_semantic_seg - gt_seg_map
Modified Keys: Modified Keys:
- img - img
- img_shape - img_shape
- gt_semantic_seg - gt_seg_map
Args: Args:
@ -293,9 +321,19 @@ class RandomCrop(BaseTransform):
@TRANSFORMS.register_module() @TRANSFORMS.register_module()
class RandomRotate(object): class RandomRotate(BaseTransform):
"""Rotate the image & seg. """Rotate the image & seg.
Required Keys:
- img
- gt_seg_map
Modified Keys:
- img
- gt_seg_map
Args: Args:
prob (float): The rotation probability. prob (float): The rotation probability.
degree (float, tuple[float]): Range of degrees to select from. If degree (float, tuple[float]): Range of degrees to select from. If
@ -332,7 +370,12 @@ class RandomRotate(object):
self.center = center self.center = center
self.auto_bound = auto_bound self.auto_bound = auto_bound
def __call__(self, results): @cache_randomness
def generate_degree(self):
    """Draw the random decisions used by one rotation.

    Returns:
        tuple: ``(rotate, degree)`` where ``rotate`` tells whether a
            uniform draw fell below ``self.prob`` and ``degree`` is sampled
            uniformly between the smallest and largest entry of
            ``self.degree``.
    """
    # The probability draw must come first so the random stream is
    # consumed in the same order on every call.
    should_rotate = np.random.rand() < self.prob
    lower = min(*self.degree)
    upper = max(*self.degree)
    angle = np.random.uniform(lower, upper)
    return should_rotate, angle
def transform(self, results: dict) -> dict:
"""Call function to rotate image, semantic segmentation maps. """Call function to rotate image, semantic segmentation maps.
Args: Args:
@ -342,8 +385,7 @@ class RandomRotate(object):
dict: Rotated results. dict: Rotated results.
""" """
rotate = True if np.random.rand() < self.prob else False rotate, degree = self.generate_degree()
degree = np.random.uniform(min(*self.degree), max(*self.degree))
if rotate: if rotate:
# rotate image # rotate image
results['img'] = mmcv.imrotate( results['img'] = mmcv.imrotate(
@ -376,9 +418,18 @@ class RandomRotate(object):
@TRANSFORMS.register_module() @TRANSFORMS.register_module()
class RGB2Gray(object): class RGB2Gray(BaseTransform):
"""Convert RGB image to grayscale image. """Convert RGB image to grayscale image.
Required Keys:
- img
Modified Keys:
- img
- img_shape
This transform calculate the weighted mean of input image channels with This transform calculate the weighted mean of input image channels with
``weights`` and then expand the channels to ``out_channels``. When ``weights`` and then expand the channels to ``out_channels``. When
``out_channels`` is None, the number of output channels is the same as ``out_channels`` is None, the number of output channels is the same as
@ -399,7 +450,7 @@ class RGB2Gray(object):
assert isinstance(item, (float, int)) assert isinstance(item, (float, int))
self.weights = weights self.weights = weights
def __call__(self, results): def transform(self, results: dict) -> dict:
"""Call function to convert RGB image to grayscale image. """Call function to convert RGB image to grayscale image.
Args: Args:
@ -431,9 +482,17 @@ class RGB2Gray(object):
@TRANSFORMS.register_module() @TRANSFORMS.register_module()
class AdjustGamma(object): class AdjustGamma(BaseTransform):
"""Using gamma correction to process the image. """Using gamma correction to process the image.
Required Keys:
- img
Modified Keys:
- img
Args: Args:
gamma (float or int): Gamma value used in gamma correction. gamma (float or int): Gamma value used in gamma correction.
Default: 1.0. Default: 1.0.
@ -447,7 +506,7 @@ class AdjustGamma(object):
self.table = np.array([(i / 255.0)**inv_gamma * 255 self.table = np.array([(i / 255.0)**inv_gamma * 255
for i in np.arange(256)]).astype('uint8') for i in np.arange(256)]).astype('uint8')
def __call__(self, results): def transform(self, results: dict) -> dict:
"""Call function to process the image with gamma correction. """Call function to process the image with gamma correction.
Args: Args:
@ -467,9 +526,17 @@ class AdjustGamma(object):
@TRANSFORMS.register_module() @TRANSFORMS.register_module()
class SegRescale(object): class SegRescale(BaseTransform):
"""Rescale semantic segmentation maps. """Rescale semantic segmentation maps.
Required Keys:
- gt_seg_map
Modified Keys:
- gt_seg_map
Args: Args:
scale_factor (float): The scale factor of the final output. scale_factor (float): The scale factor of the final output.
""" """
@ -477,7 +544,7 @@ class SegRescale(object):
def __init__(self, scale_factor=1): def __init__(self, scale_factor=1):
self.scale_factor = scale_factor self.scale_factor = scale_factor
def __call__(self, results): def transform(self, results: dict) -> dict:
"""Call function to scale the semantic segmentation map. """Call function to scale the semantic segmentation map.
Args: Args:
@ -667,11 +734,22 @@ class PhotoMetricDistortion(BaseTransform):
@TRANSFORMS.register_module() @TRANSFORMS.register_module()
class RandomCutOut(object): class RandomCutOut(BaseTransform):
"""CutOut operation. """CutOut operation.
Randomly drop some regions of image used in Randomly drop some regions of image used in
`Cutout <https://arxiv.org/abs/1708.04552>`_. `Cutout <https://arxiv.org/abs/1708.04552>`_.
Required Keys:
- img
- gt_seg_map
Modified Keys:
- img
- gt_seg_map
Args: Args:
prob (float): cutout probability. prob (float): cutout probability.
n_holes (int | tuple[int, int]): Number of regions to be dropped. n_holes (int | tuple[int, int]): Number of regions to be dropped.
@ -721,16 +799,38 @@ class RandomCutOut(object):
if not isinstance(self.candidates, list): if not isinstance(self.candidates, list):
self.candidates = [self.candidates] self.candidates = [self.candidates]
def __call__(self, results): @cache_randomness
def do_cutout(self):
    """Return whether a uniform draw in [0, 1) fell below ``self.prob``."""
    draw = np.random.rand()
    return draw < self.prob
@cache_randomness
def generate_patches(self, results):
    """Sample every random quantity needed to cut holes out of an image.

    Args:
        results (dict): Result dict; ``results['img']`` must expose a
            three-element ``shape``.

    Returns:
        tuple: ``(cutout, n_holes, x1_lst, y1_lst, index_lst)`` - the
            cutout decision, the number of holes (0 when no cutout is
            applied), and per-hole lists with a column drawn from
            ``[0, w)``, a row drawn from ``[0, h)`` and an index into
            ``self.candidates``.
    """
    apply_cutout = self.do_cutout()
    height, width, _ = results['img'].shape
    hole_count = (
        np.random.randint(self.n_holes[0], self.n_holes[1] + 1)
        if apply_cutout else 0)

    columns, rows, candidate_ids = [], [], []
    # Keep the x -> y -> candidate draw order inside each iteration so the
    # random stream is consumed identically for a given seed.
    for _ in range(hole_count):
        columns.append(np.random.randint(0, width))
        rows.append(np.random.randint(0, height))
        candidate_ids.append(np.random.randint(0, len(self.candidates)))
    return apply_cutout, hole_count, columns, rows, candidate_ids
def transform(self, results: dict) -> dict:
"""Call function to drop some regions of image.""" """Call function to drop some regions of image."""
cutout = True if np.random.rand() < self.prob else False cutout, n_holes, x1_lst, y1_lst, index_lst = self.generate_patches(
results)
if cutout: if cutout:
h, w, c = results['img'].shape h, w, c = results['img'].shape
n_holes = np.random.randint(self.n_holes[0], self.n_holes[1] + 1) for i in range(n_holes):
for _ in range(n_holes): x1 = x1_lst[i]
x1 = np.random.randint(0, w) y1 = y1_lst[i]
y1 = np.random.randint(0, h) index = index_lst[i]
index = np.random.randint(0, len(self.candidates))
if not self.with_ratio: if not self.with_ratio:
cutout_w, cutout_h = self.candidates[index] cutout_w, cutout_h = self.candidates[index]
else: else:
@ -759,7 +859,7 @@ class RandomCutOut(object):
@TRANSFORMS.register_module() @TRANSFORMS.register_module()
class RandomMosaic(object): class RandomMosaic(BaseTransform):
"""Mosaic augmentation. Given 4 images, mosaic transform combines them into """Mosaic augmentation. Given 4 images, mosaic transform combines them into
one output image. The output image is composed of the parts from each sub- one output image. The output image is composed of the parts from each sub-
image. image.
@ -789,6 +889,19 @@ class RandomMosaic(object):
sample another 3 images from the custom dataset. sample another 3 images from the custom dataset.
3. Sub image will be cropped if image is larger than mosaic patch 3. Sub image will be cropped if image is larger than mosaic patch
Required Keys:
- img
- gt_seg_map
- mix_results
Modified Keys:
- img
- img_shape
- ori_shape
- gt_seg_map
Args: Args:
prob (float): mosaic probability. prob (float): mosaic probability.
img_scale (Sequence[int]): Image size after mosaic pipeline of img_scale (Sequence[int]): Image size after mosaic pipeline of
@ -815,7 +928,11 @@ class RandomMosaic(object):
self.pad_val = pad_val self.pad_val = pad_val
self.seg_pad_val = seg_pad_val self.seg_pad_val = seg_pad_val
def __call__(self, results): @cache_randomness
def do_mosaic(self):
    """Return whether a uniform draw in [0, 1) fell below ``self.prob``."""
    draw = np.random.rand()
    return draw < self.prob
def transform(self, results: dict) -> dict:
"""Call function to make a mosaic of image. """Call function to make a mosaic of image.
Args: Args:
@ -824,13 +941,13 @@ class RandomMosaic(object):
Returns: Returns:
dict: Result dict with mosaic transformed. dict: Result dict with mosaic transformed.
""" """
mosaic = True if np.random.rand() < self.prob else False mosaic = self.do_mosaic()
if mosaic: if mosaic:
results = self._mosaic_transform_img(results) results = self._mosaic_transform_img(results)
results = self._mosaic_transform_seg(results) results = self._mosaic_transform_seg(results)
return results return results
def get_indexes(self, dataset): def get_indices(self, dataset: MultiImageMixDataset) -> list:
"""Call function to collect indexes. """Call function to collect indexes.
Args: Args:
@ -843,7 +960,16 @@ class RandomMosaic(object):
indexes = [random.randint(0, len(dataset)) for _ in range(3)] indexes = [random.randint(0, len(dataset)) for _ in range(3)]
return indexes return indexes
def _mosaic_transform_img(self, results): @cache_randomness
def generate_mosaic_center(self):
    """Sample the mosaic center.

    Returns:
        tuple[int, int]: ``(center_x, center_y)``. Each coordinate is a
            ratio drawn uniformly from ``self.center_ratio_range``, scaled
            by ``self.img_scale[1]`` for x and ``self.img_scale[0]`` for y,
            then truncated to int.
    """
    # x is drawn before y, so the random stream is consumed in a fixed order.
    ratio_x = random.uniform(*self.center_ratio_range)
    x_center = int(ratio_x * self.img_scale[1])
    ratio_y = random.uniform(*self.center_ratio_range)
    y_center = int(ratio_y * self.img_scale[0])
    return x_center, y_center
def _mosaic_transform_img(self, results: dict) -> dict:
"""Mosaic transform function. """Mosaic transform function.
Args: Args:
@ -866,10 +992,7 @@ class RandomMosaic(object):
dtype=results['img'].dtype) dtype=results['img'].dtype)
# mosaic center x, y # mosaic center x, y
self.center_x = int( self.center_x, self.center_y = self.generate_mosaic_center()
random.uniform(*self.center_ratio_range) * self.img_scale[1])
self.center_y = int(
random.uniform(*self.center_ratio_range) * self.img_scale[0])
center_position = (self.center_x, self.center_y) center_position = (self.center_x, self.center_y)
loc_strs = ('top_left', 'top_right', 'bottom_left', 'bottom_right') loc_strs = ('top_left', 'top_right', 'bottom_left', 'bottom_right')
@ -902,7 +1025,7 @@ class RandomMosaic(object):
return results return results
def _mosaic_transform_seg(self, results): def _mosaic_transform_seg(self, results: dict) -> dict:
"""Mosaic transform function for label annotations. """Mosaic transform function for label annotations.
Args: Args:
@ -953,7 +1076,8 @@ class RandomMosaic(object):
return results return results
def _mosaic_combine(self, loc, center_position_xy, img_shape_wh): def _mosaic_combine(self, loc: str, center_position_xy: Sequence[float],
img_shape_wh: Sequence[int]) -> tuple:
"""Calculate global coordinate of mosaic image and local coordinate of """Calculate global coordinate of mosaic image and local coordinate of
cropped sub-image. cropped sub-image.

View File

@ -7,11 +7,9 @@ import numpy as np
import pytest import pytest
from PIL import Image from PIL import Image
from mmseg.datasets.transforms import * # noqa
from mmseg.datasets.transforms import PhotoMetricDistortion, RandomCrop from mmseg.datasets.transforms import PhotoMetricDistortion, RandomCrop
from mmseg.registry import TRANSFORMS from mmseg.registry import TRANSFORMS
from mmseg.utils import register_all_modules
register_all_modules()
def test_resize(): def test_resize():
@ -233,6 +231,72 @@ def test_random_crop():
assert results['gt_semantic_seg'].shape[:2] == (h - 20, w - 20) assert results['gt_semantic_seg'].shape[:2] == (h - 20, w - 20)
def test_rgb2gray():
    """Check RGB2Gray argument validation, its repr and output shapes."""

    def make_inputs():
        # Load a fresh image/label pair and wrap it with the meta keys the
        # transform is fed in this test.
        image = mmcv.imread(
            osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
        label = np.array(
            Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))
        sample = dict(
            img=image,
            gt_semantic_seg=label,
            seg_fields=['gt_semantic_seg'],
            img_shape=image.shape,
            ori_shape=image.shape,
            # Set initial values for default meta_keys
            pad_shape=image.shape,
            scale_factor=1.0)
        return image.shape, sample

    # A negative out_channels and a non-tuple weights must both be rejected
    # when the transform is built.
    invalid_kwargs = (
        dict(out_channels=-1),
        dict(out_channels=1, weights=1.1),
    )
    for kwargs in invalid_kwargs:
        with pytest.raises(AssertionError):
            TRANSFORMS.build(dict(type='RGB2Gray', **kwargs))

    # out_channels left at None: the output is expected to keep the input
    # channel count.
    to_gray = TRANSFORMS.build(dict(type='RGB2Gray'))
    expected_repr = (f'RGB2Gray('
                     f'out_channels={None}, '
                     f'weights={(0.299, 0.587, 0.114)})')
    assert str(to_gray) == expected_repr

    (h, w, c), sample = make_inputs()
    output = to_gray(sample)
    assert output['img'].shape == (h, w, c)
    assert output['img_shape'] == (h, w, c)
    assert output['ori_shape'] == (h, w, c)

    # out_channels = 2: the gray image is expected with two channels.
    to_gray = TRANSFORMS.build(dict(type='RGB2Gray', out_channels=2))
    expected_repr = (f'RGB2Gray('
                     f'out_channels={2}, '
                     f'weights={(0.299, 0.587, 0.114)})')
    assert str(to_gray) == expected_repr

    (h, w, c), sample = make_inputs()
    output = to_gray(sample)
    assert output['img'].shape == (h, w, 2)
    assert output['img_shape'] == (h, w, 2)
def test_photo_metric_distortion(): def test_photo_metric_distortion():
results = dict() results = dict()
@ -252,3 +316,366 @@ def test_photo_metric_distortion():
assert (results['gt_semantic_seg'] == seg).all() assert (results['gt_semantic_seg'] == seg).all()
assert results['img_shape'] == img.shape assert results['img_shape'] == img.shape
def test_rerange():
    """Check Rerange argument validation and the rescaled pixel values."""
    # min_value / max_value given as lists must be rejected.
    with pytest.raises(AssertionError):
        TRANSFORMS.build(
            dict(type='Rerange', min_value=[0], max_value=[255]))

    # min_value >= max_value must be rejected.
    with pytest.raises(AssertionError):
        TRANSFORMS.build(dict(type='Rerange', min_value=1, max_value=1))

    # A constant image (img_min_value == img_max_value) must be rejected.
    # Both the build and the call stay inside the context, as either may
    # raise.
    with pytest.raises(AssertionError):
        flat_rerange = TRANSFORMS.build(
            dict(type='Rerange', min_value=0, max_value=1))
        flat_rerange(dict(img=np.array([[1, 1], [1, 1]])))

    # Default configuration applied to a real image.
    rerange = TRANSFORMS.build(dict(type='Rerange'))
    image = mmcv.imread(
        osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
    reference = copy.deepcopy(image)
    sample = dict(
        img=image,
        img_shape=image.shape,
        ori_shape=image.shape,
        # Set initial values for default meta_keys
        pad_shape=image.shape,
        scale_factor=1.0)

    output = rerange(sample)

    # Recompute the expected linear mapping from the untouched copy.
    lowest = np.min(reference)
    highest = np.max(reference)
    expected = (reference - lowest) / (highest - lowest) * 255

    assert np.allclose(output['img'], expected)
    assert str(rerange) == f'Rerange(min_value={0}, max_value={255})'
def test_CLAHE():
    """Check CLAHE argument validation and its per-channel output."""
    # clip_limit=None, a float tile grid and a three-element tile grid must
    # each be rejected when the transform is built.
    invalid_kwargs = (
        dict(clip_limit=None),
        dict(tile_grid_size=(8.0, 8.0)),
        dict(tile_grid_size=(9, 9, 9)),
    )
    for kwargs in invalid_kwargs:
        with pytest.raises(AssertionError):
            TRANSFORMS.build(dict(type='CLAHE', **kwargs))

    clahe = TRANSFORMS.build(dict(type='CLAHE', clip_limit=2))
    image = mmcv.imread(
        osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
    reference = copy.deepcopy(image)
    sample = dict(
        img=image,
        img_shape=image.shape,
        ori_shape=image.shape,
        # Set initial values for default meta_keys
        pad_shape=image.shape,
        scale_factor=1.0)

    output = clahe(sample)

    # Build the expected image channel by channel from the untouched copy.
    expected = np.empty(reference.shape)
    for channel in range(reference.shape[2]):
        plane = np.array(reference[:, :, channel], dtype=np.uint8)
        expected[:, :, channel] = mmcv.clahe(plane, 2, (8, 8))

    assert np.allclose(output['img'], expected)
    assert str(clahe) == f'CLAHE(clip_limit={2}, tile_grid_size={(8, 8)})'
def test_adjust_gamma():
    """Check AdjustGamma argument validation and its lookup-table output."""
    # gamma=0 and a list-valued gamma must both be rejected at build time.
    for bad_gamma in (0, [1.2]):
        with pytest.raises(AssertionError):
            TRANSFORMS.build(dict(type='AdjustGamma', gamma=bad_gamma))

    # test with gamma = 1.2
    adjust = TRANSFORMS.build(dict(type='AdjustGamma', gamma=1.2))
    image = mmcv.imread(
        osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
    reference = copy.deepcopy(image)
    sample = dict(
        img=image,
        img_shape=image.shape,
        ori_shape=image.shape,
        # Set initial values for default meta_keys
        pad_shape=image.shape,
        scale_factor=1.0)

    output = adjust(sample)

    # Rebuild the 256-entry lookup table independently and apply it to the
    # untouched copy.
    inverse_gamma = 1.0 / 1.2
    lookup = np.array([((level / 255.0)**inverse_gamma) * 255
                       for level in np.arange(0, 256)]).astype('uint8')
    expected = mmcv.lut_transform(
        np.array(reference, dtype=np.uint8), lookup)

    assert np.allclose(output['img'], expected)
    assert str(adjust) == f'AdjustGamma(gamma={1.2})'
def test_rotate():
    """Check RandomRotate argument validation, repr and output size."""
    # A negative scalar degree and a three-element degree tuple must both
    # be rejected at build time.
    for bad_degree in (-10, (10., 20., 30.)):
        with pytest.raises(AssertionError):
            TRANSFORMS.build(
                dict(type='RandomRotate', prob=0.5, degree=bad_degree))

    rotate = TRANSFORMS.build(dict(type='RandomRotate', degree=10., prob=1.))
    expected_repr = (f'RandomRotate('
                     f'prob={1.}, '
                     f'degree=({-10.}, {10.}), '
                     f'pad_val={0}, '
                     f'seg_pad_val={255}, '
                     f'center={None}, '
                     f'auto_bound={False})')
    assert str(rotate) == expected_repr

    image = mmcv.imread(
        osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
    h, w, _ = image.shape
    label = np.array(
        Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))
    sample = dict(
        img=image,
        gt_semantic_seg=label,
        seg_fields=['gt_semantic_seg'],
        img_shape=image.shape,
        ori_shape=image.shape,
        # Set initial values for default meta_keys
        pad_shape=image.shape,
        scale_factor=1.0)

    output = rotate(sample)
    # The rotated image is expected to keep the input height and width.
    assert output['img'].shape[:2] == (h, w)
def test_seg_rescale():
    """Check SegRescale output shapes for scale factors 0.5 and 1."""
    label = np.array(
        Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))
    base_sample = dict(
        gt_semantic_seg=label, seg_fields=['gt_semantic_seg'])
    h, w = label.shape

    # (scale_factor, expected label shape)
    cases = (
        (1. / 2, (h // 2, w // 2)),
        (1, (h, w)),
    )
    for factor, expected_shape in cases:
        rescale = TRANSFORMS.build(
            dict(type='SegRescale', scale_factor=factor))
        # Hand over a shallow copy so every case starts from the same dict.
        rescaled = rescale(base_sample.copy())
        assert rescaled['gt_semantic_seg'].shape == expected_shape
def test_mosaic():
    """Check RandomMosaic validation and output size with prob 1 and 0."""
    # prob=1.5 and an integer img_scale must both be rejected at build time.
    for bad_kwargs in (dict(prob=1.5), dict(prob=1, img_scale=640)):
        with pytest.raises(AssertionError):
            TRANSFORMS.build(dict(type='RandomMosaic', **bad_kwargs))

    image = mmcv.imread(
        osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
    label = np.array(
        Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))
    color_sample = dict(
        img=image, gt_semantic_seg=label, seg_fields=['gt_semantic_seg'])

    mosaic = TRANSFORMS.build(
        dict(type='RandomMosaic', prob=1, img_scale=(10, 12)))
    assert 'Mosaic' in repr(mosaic)

    # Calling the transform before 'mix_results' is set must fail.
    with pytest.raises(AssertionError):
        mosaic(color_sample)

    # The copy is taken before 'mix_results' is attached, so the three
    # (shared) extra samples do not themselves contain 'mix_results'.
    color_sample['mix_results'] = [copy.deepcopy(color_sample)] * 3
    color_sample = mosaic(color_sample)
    assert color_sample['img'].shape[:2] == (20, 24)

    # Single-channel image with prob=0: the image size must stay unchanged.
    gray_sample = dict(
        img=image[:, :, 0],
        gt_semantic_seg=label,
        seg_fields=['gt_semantic_seg'])
    mosaic = TRANSFORMS.build(
        dict(type='RandomMosaic', prob=0, img_scale=(10, 12))
    )
    gray_sample['mix_results'] = [copy.deepcopy(gray_sample)] * 3
    gray_sample = mosaic(gray_sample)
    assert gray_sample['img'].shape[:2] == image.shape[:2]

    # Same sample through a prob=1 transform: output is twice img_scale.
    mosaic = TRANSFORMS.build(
        dict(type='RandomMosaic', prob=1, img_scale=(10, 12)))
    gray_sample = mosaic(gray_sample)
    assert gray_sample['img'].shape[:2] == (20, 24)
def test_cutout():
    """Check RandomCutOut argument validation and its fill behaviour."""
    # Every configuration below must be rejected when the transform is
    # built. The order follows the argument being exercised.
    invalid_kwargs = (
        # test prob
        dict(prob=1.5, n_holes=1),
        # test n_holes
        dict(prob=0.5, n_holes=(5, 3), cutout_shape=(8, 8)),
        dict(prob=0.5, n_holes=(3, 4, 5), cutout_shape=(8, 8)),
        # test cutout_shape and cutout_ratio
        dict(prob=0.5, n_holes=1, cutout_shape=8),
        dict(prob=0.5, n_holes=1, cutout_ratio=0.2),
        # either of cutout_shape and cutout_ratio should be given
        dict(prob=0.5, n_holes=1),
        dict(
            prob=0.5,
            n_holes=1,
            cutout_shape=(2, 2),
            cutout_ratio=(0.4, 0.4)),
        # test seg_fill_in
        dict(prob=0.5, n_holes=1, cutout_shape=(8, 8), seg_fill_in='a'),
        dict(prob=0.5, n_holes=1, cutout_shape=(8, 8), seg_fill_in=256),
    )
    for kwargs in invalid_kwargs:
        with pytest.raises(AssertionError):
            TRANSFORMS.build(dict(type='RandomCutOut', **kwargs))

    image = mmcv.imread(
        osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
    label = np.array(
        Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))
    sample = dict(
        img=image,
        gt_semantic_seg=label,
        seg_fields=['gt_semantic_seg'],
        img_shape=image.shape,
        ori_shape=image.shape,
        pad_shape=image.shape,
        img_fields=['img'])

    def run(**kwargs):
        # Build a RandomCutOut from kwargs and apply it to a private deep
        # copy, so 'sample' stays pristine for the next case.
        module = TRANSFORMS.build(dict(type='RandomCutOut', **kwargs))
        return module, module(copy.deepcopy(sample))

    # Fixed-size hole with the default fill: the pixel sum must drop.
    cutout, output = None, None
    cutout = TRANSFORMS.build(
        dict(type='RandomCutOut', prob=1, n_holes=1, cutout_shape=(10, 10)))
    assert 'cutout_shape' in repr(cutout)
    output = cutout(copy.deepcopy(sample))
    assert output['img'].sum() < image.sum()

    # Ratio-sized hole with the default fill: the pixel sum must drop.
    cutout = TRANSFORMS.build(
        dict(
            type='RandomCutOut', prob=1, n_holes=1, cutout_ratio=(0.8, 0.8)))
    assert 'cutout_ratio' in repr(cutout)
    output = cutout(copy.deepcopy(sample))
    assert output['img'].sum() < image.sum()

    # prob=0: neither the image nor the label may change.
    _, output = run(prob=0, n_holes=1, cutout_ratio=(0.8, 0.8))
    assert output['img'].sum() == image.sum()
    assert output['gt_semantic_seg'].sum() == label.sum()

    # White fill with seg_fill_in=None: image sum grows, label untouched.
    _, output = run(
        prob=1,
        n_holes=(2, 4),
        cutout_shape=[(10, 10), (15, 15)],
        fill_in=(255, 255, 255),
        seg_fill_in=None)
    assert output['img'].sum() > image.sum()
    assert output['gt_semantic_seg'].sum() == label.sum()

    # White fill with seg_fill_in=255: both image and label sums grow.
    _, output = run(
        prob=1,
        n_holes=1,
        cutout_ratio=(0.8, 0.8),
        fill_in=(255, 255, 255),
        seg_fill_in=255)
    assert output['img'].sum() > image.sum()
    assert output['gt_semantic_seg'].sum() > label.sum()
def test_resize_to_multiple():
    """Check that ResizeToMultiple rounds 213x232 inputs up to 224x256."""
    resize = TRANSFORMS.build(dict(type='ResizeToMultiple', size_divisor=32))

    # Image first, label second: keeps the random draws in a fixed order.
    image = np.random.randn(213, 232, 3)
    label = np.random.randint(0, 19, (213, 232))
    sample = dict(
        img=image,
        gt_semantic_seg=label,
        seg_fields=['gt_semantic_seg'],
        img_shape=image.shape,
        pad_shape=image.shape)

    output = resize(sample)

    target_hw = (224, 256)
    assert output['img'].shape == target_hw + (3, )
    assert output['gt_semantic_seg'].shape == target_hw
    assert output['img_shape'] == target_hw + (3, )