[Feature] Add P1 DataTransform (#1843)

* [Feature] Add P1 DataTransform

* fix unit test error

* fix @cache_randomness location
MengzhangLI 2022-08-01 18:46:36 +08:00 committed by MeowZheng
parent 76c5ce1396
commit ecab73a892
3 changed files with 593 additions and 42 deletions

View File

@ -105,8 +105,8 @@ class MultiImageMixDataset:
transform_type in self._skip_type_keys:
continue
if hasattr(transform, 'get_indexes'):
indexes = transform.get_indexes(self.dataset)
if hasattr(transform, 'get_indices'):
indexes = transform.get_indices(self.dataset)
if not isinstance(indexes, collections.abc.Sequence):
indexes = [indexes]
mix_results = [

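The hunk above renames the hook MultiImageMixDataset looks for from get_indexes to get_indices. For reference, a minimal sketch of that duck-typed protocol with a toy dataset and an illustrative mix transform; neither class below is part of the PR, and the real wrapper builds mix_results from the sampled dataset items as the truncated line shows.

```python
import collections.abc

import numpy as np


class ToyMixTransform:
    """Illustrative stand-in for a mix transform such as RandomMosaic."""

    def get_indices(self, dataset) -> list:
        # Ask the wrapper for 3 extra samples to blend with the current one.
        return [np.random.randint(0, len(dataset)) for _ in range(3)]

    def transform(self, results: dict) -> dict:
        # A real mix transform would blend results['mix_results'] into
        # results['img']; the toy version passes results through unchanged.
        return results


def feed_mix_results(dataset, transform, results):
    """Sketch of the wrapper logic changed in the hunk above."""
    if hasattr(transform, 'get_indices'):
        indices = transform.get_indices(dataset)
        if not isinstance(indices, collections.abc.Sequence):
            indices = [indices]
        results['mix_results'] = [dataset[index] for index in indices]
    return transform.transform(results)


dataset = [dict(img=np.zeros((2, 2, 3))) for _ in range(8)]
out = feed_mix_results(dataset, ToyMixTransform(), dict(img=np.ones((2, 2, 3))))
assert len(out['mix_results']) == 3
```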
View File

@ -9,13 +9,25 @@ from mmcv.transforms.utils import cache_randomness
from mmcv.utils import is_tuple_of
from numpy import random
from mmseg.datasets.dataset_wrappers import MultiImageMixDataset
from mmseg.registry import TRANSFORMS
@TRANSFORMS.register_module()
class ResizeToMultiple(object):
class ResizeToMultiple(BaseTransform):
"""Resize images & seg to multiple of divisor.
Required Keys:
- img
- gt_seg_map
Modified Keys:
- img
- img_shape
- pad_shape
Args:
size_divisor (int): images and gt seg maps need to resize to multiple
of size_divisor. Default: 32.
@ -27,7 +39,7 @@ class ResizeToMultiple(object):
self.size_divisor = size_divisor
self.interpolation = interpolation
def __call__(self, results):
def transform(self, results: dict) -> dict:
"""Call function to resize images, semantic segmentation map to
multiple of size divisor.
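ResizeToMultiple rounds both sides up to the next multiple of size_divisor; a small sketch of that size computation, consistent with the shapes asserted in test_resize_to_multiple below (213 x 232 becomes 224 x 256 for a divisor of 32).

```python
import numpy as np


def to_multiple(size: int, divisor: int = 32) -> int:
    # Round a side length up to the next multiple of `divisor`.
    return int(np.ceil(size / divisor)) * divisor


h, w = 213, 232
print(to_multiple(h), to_multiple(w))  # 224 256
```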
@ -70,9 +82,17 @@ class ResizeToMultiple(object):
@TRANSFORMS.register_module()
class Rerange(object):
class Rerange(BaseTransform):
"""Rerange the image pixel value.
Required Keys:
- img
Modified Keys:
- img
Args:
min_value (float or int): Minimum value of the reranged image.
Default: 0.
@ -87,7 +107,7 @@ class Rerange(object):
self.min_value = min_value
self.max_value = max_value
def __call__(self, results):
def transform(self, results: dict) -> dict:
"""Call function to rerange images.
Args:
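Rerange linearly maps pixel values from the image's own [min, max] onto [min_value, max_value]; a numpy sketch of that formula, generalized from the default [0, 255] case that test_rerange checks below.

```python
import numpy as np


def rerange(img: np.ndarray, min_value: float = 0,
            max_value: float = 255) -> np.ndarray:
    # The image minimum maps to `min_value` and the image maximum to
    # `max_value`; a constant image would make the denominator zero,
    # which the transform guards against with an assertion.
    img_min, img_max = np.min(img), np.max(img)
    assert img_min < img_max
    img = (img - img_min) / (img_max - img_min)
    return img * (max_value - min_value) + min_value


img = np.random.randint(10, 200, (4, 4, 3)).astype(np.float64)
out = rerange(img)
assert out.min() == 0 and out.max() == 255
```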
@ -116,12 +136,20 @@ class Rerange(object):
@TRANSFORMS.register_module()
class CLAHE(object):
class CLAHE(BaseTransform):
"""Use CLAHE method to process the image.
See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J].
Graphics Gems, 1994:474-485.` for more information.
Required Keys:
- img
Modified Keys:
- img
Args:
clip_limit (float): Threshold for contrast limiting. Default: 40.0.
tile_grid_size (tuple[int]): Size of grid for histogram equalization.
@ -136,7 +164,7 @@ class CLAHE(object):
assert len(tile_grid_size) == 2
self.tile_grid_size = tile_grid_size
def __call__(self, results):
def transform(self, results: dict) -> dict:
"""Call function to use the CLAHE method to process images.
Args:
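CLAHE is applied channel by channel; a sketch of the reference computation that test_CLAHE below uses to verify the transform, built directly on mmcv.clahe.

```python
import mmcv
import numpy as np


def clahe_per_channel(img: np.ndarray, clip_limit: float = 40.0,
                      tile_grid_size: tuple = (8, 8)) -> np.ndarray:
    # Contrast-limited adaptive histogram equalization on each channel.
    out = np.empty(img.shape)
    for i in range(img.shape[2]):
        out[:, :, i] = mmcv.clahe(
            np.array(img[:, :, i], dtype=np.uint8), clip_limit, tile_grid_size)
    return out
```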
@ -167,13 +195,13 @@ class RandomCrop(BaseTransform):
Required Keys:
- img
- gt_semantic_seg
- gt_seg_map
Modified Keys:
- img
- img_shape
- gt_semantic_seg
- gt_seg_map
Args:
@ -293,9 +321,19 @@ class RandomCrop(BaseTransform):
@TRANSFORMS.register_module()
class RandomRotate(object):
class RandomRotate(BaseTransform):
"""Rotate the image & seg.
Required Keys:
- img
- gt_seg_map
Modified Keys:
- img
- gt_seg_map
Args:
prob (float): The rotation probability.
degree (float, tuple[float]): Range of degrees to select from. If
@ -332,7 +370,12 @@ class RandomRotate(object):
self.center = center
self.auto_bound = auto_bound
def __call__(self, results):
@cache_randomness
def generate_degree(self):
return np.random.rand() < self.prob, np.random.uniform(
min(*self.degree), max(*self.degree))
def transform(self, results: dict) -> dict:
"""Call function to rotate image, semantic segmentation maps.
Args:
@ -342,8 +385,7 @@ class RandomRotate(object):
dict: Rotated results.
"""
rotate = True if np.random.rand() < self.prob else False
degree = np.random.uniform(min(*self.degree), max(*self.degree))
rotate, degree = self.generate_degree()
if rotate:
# rotate image
results['img'] = mmcv.imrotate(
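This is the migration pattern the PR applies throughout: the randomness moves into a method decorated with @cache_randomness, so transform() stays deterministic given those parameters, and mmcv's BaseTransform dispatches __call__ to transform(). A minimal sketch of the split on a toy transform; ToyRandomShift is illustrative and not part of the PR.

```python
import numpy as np
from mmcv.transforms import BaseTransform
from mmcv.transforms.utils import cache_randomness


class ToyRandomShift(BaseTransform):
    """Illustrative transform following the same split as RandomRotate."""

    def __init__(self, prob: float = 0.5, max_shift: int = 10) -> None:
        self.prob = prob
        self.max_shift = max_shift

    @cache_randomness
    def generate_shift(self):
        # All randomness lives here so it can be cached and replayed.
        return np.random.rand() < self.prob, np.random.randint(
            -self.max_shift, self.max_shift + 1)

    def transform(self, results: dict) -> dict:
        # Deterministic given the (possibly cached) random parameters.
        do_shift, shift = self.generate_shift()
        if do_shift:
            results['img'] = np.roll(results['img'], shift, axis=1)
        return results


out = ToyRandomShift(prob=1.)(dict(img=np.zeros((4, 4, 3))))
```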
@ -376,9 +418,18 @@ class RandomRotate(object):
@TRANSFORMS.register_module()
class RGB2Gray(object):
class RGB2Gray(BaseTransform):
"""Convert RGB image to grayscale image.
Required Keys:
- img
Modified Keys:
- img
- img_shape
This transform calculates the weighted mean of input image channels with
``weights`` and then expands the channels to ``out_channels``. When
``out_channels`` is None, the number of output channels is the same as
@ -399,7 +450,7 @@ class RGB2Gray(object):
assert isinstance(item, (float, int))
self.weights = weights
def __call__(self, results):
def transform(self, results: dict) -> dict:
"""Call function to convert RGB image to grayscale image.
Args:
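A numpy sketch of the array computation this docstring describes (weighted channel mean, then channel expansion); metadata updates such as img_shape are omitted here. The default weights follow the repr checked in test_rgb2gray below.

```python
import numpy as np


def rgb2gray(img: np.ndarray, weights=(0.299, 0.587, 0.114),
             out_channels=None) -> np.ndarray:
    # Weighted mean over the channel axis, then repeat to `out_channels`
    # (or to the original channel count when it is None).
    gray = (img * np.array(weights)).sum(axis=2, keepdims=True)
    repeats = img.shape[2] if out_channels is None else out_channels
    return np.repeat(gray, repeats, axis=2)


img = np.random.rand(8, 8, 3)
assert rgb2gray(img).shape == (8, 8, 3)
assert rgb2gray(img, out_channels=2).shape == (8, 8, 2)
```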
@ -431,9 +482,17 @@ class RGB2Gray(object):
@TRANSFORMS.register_module()
class AdjustGamma(object):
class AdjustGamma(BaseTransform):
"""Using gamma correction to process the image.
Required Keys:
- img
Modified Keys:
- img
Args:
gamma (float or int): Gamma value used in gamma correction.
Default: 1.0.
@ -447,7 +506,7 @@ class AdjustGamma(object):
self.table = np.array([(i / 255.0)**inv_gamma * 255
for i in np.arange(256)]).astype('uint8')
def __call__(self, results):
def transform(self, results: dict) -> dict:
"""Call function to process the image with gamma correction.
Args:
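The gamma table built in __init__ is applied with mmcv.lut_transform; a standalone sketch of the same computation, mirroring what test_adjust_gamma recomputes below.

```python
import mmcv
import numpy as np


def adjust_gamma(img: np.ndarray, gamma: float = 1.2) -> np.ndarray:
    # 256-entry lookup table for (x / 255) ** (1 / gamma) * 255,
    # applied to a uint8 image.
    inv_gamma = 1.0 / gamma
    table = np.array([(i / 255.0)**inv_gamma * 255
                      for i in np.arange(256)]).astype('uint8')
    return mmcv.lut_transform(np.array(img, dtype=np.uint8), table)
```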
@ -467,9 +526,17 @@ class AdjustGamma(object):
@TRANSFORMS.register_module()
class SegRescale(object):
class SegRescale(BaseTransform):
"""Rescale semantic segmentation maps.
Required Keys:
- gt_seg_map
Modified Keys:
- gt_seg_map
Args:
scale_factor (float): The scale factor of the final output.
"""
@ -477,7 +544,7 @@ class SegRescale(object):
def __init__(self, scale_factor=1):
self.scale_factor = scale_factor
def __call__(self, results):
def transform(self, results: dict) -> dict:
"""Call function to scale the semantic segmentation map.
Args:
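A minimal sketch of what SegRescale does to an annotation map, assuming nearest-neighbour interpolation via mmcv.imrescale; the interpolation choice is not visible in this hunk, but it is what keeps label ids intact, and the 1/2 case matches the (h // 2, w // 2) shape test_seg_rescale expects below.

```python
import mmcv
import numpy as np


def seg_rescale(gt_seg_map: np.ndarray, scale_factor: float = 0.5) -> np.ndarray:
    # Nearest-neighbour resizing so no new label values are interpolated in.
    if scale_factor == 1:
        return gt_seg_map
    return mmcv.imrescale(gt_seg_map, scale_factor, interpolation='nearest')


seg = np.random.randint(0, 19, (128, 128), dtype=np.uint8)
assert seg_rescale(seg, 0.5).shape == (64, 64)
```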
@ -667,11 +734,22 @@ class PhotoMetricDistortion(BaseTransform):
@TRANSFORMS.register_module()
class RandomCutOut(object):
class RandomCutOut(BaseTransform):
"""CutOut operation.
Randomly drop some regions of image used in
`Cutout <https://arxiv.org/abs/1708.04552>`_.
Required Keys:
- img
- gt_seg_map
Modified Keys:
- img
- gt_seg_map
Args:
prob (float): cutout probability.
n_holes (int | tuple[int, int]): Number of regions to be dropped.
@ -721,16 +799,38 @@ class RandomCutOut(object):
if not isinstance(self.candidates, list):
self.candidates = [self.candidates]
def __call__(self, results):
@cache_randomness
def do_cutout(self):
return np.random.rand() < self.prob
@cache_randomness
def generate_patches(self, results):
cutout = self.do_cutout()
h, w, _ = results['img'].shape
if cutout:
n_holes = np.random.randint(self.n_holes[0], self.n_holes[1] + 1)
else:
n_holes = 0
x1_lst = []
y1_lst = []
index_lst = []
for _ in range(n_holes):
x1_lst.append(np.random.randint(0, w))
y1_lst.append(np.random.randint(0, h))
index_lst.append(np.random.randint(0, len(self.candidates)))
return cutout, n_holes, x1_lst, y1_lst, index_lst
def transform(self, results: dict) -> dict:
"""Call function to drop some regions of image."""
cutout = True if np.random.rand() < self.prob else False
cutout, n_holes, x1_lst, y1_lst, index_lst = self.generate_patches(
results)
if cutout:
h, w, c = results['img'].shape
n_holes = np.random.randint(self.n_holes[0], self.n_holes[1] + 1)
for _ in range(n_holes):
x1 = np.random.randint(0, w)
y1 = np.random.randint(0, h)
index = np.random.randint(0, len(self.candidates))
for i in range(n_holes):
x1 = x1_lst[i]
y1 = y1_lst[i]
index = index_lst[i]
if not self.with_ratio:
cutout_w, cutout_h = self.candidates[index]
else:
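With the class registered and transform() in place, RandomCutOut is built from a config dict and called on a results dict. A usage sketch condensed from test_cutout below, with synthetic arrays standing in for the repository's test images; the check assumes the default fill value darkens the cut region, as the test's own sum comparison implies.

```python
import copy

import numpy as np
from mmseg.registry import TRANSFORMS
from mmseg.utils import register_all_modules

register_all_modules()

results = dict(
    img=np.random.randint(1, 255, (64, 64, 3), dtype=np.uint8),
    gt_semantic_seg=np.random.randint(0, 19, (64, 64), dtype=np.uint8),
    seg_fields=['gt_semantic_seg'])

# prob=1 guarantees a hole is cut, so the pixel sum can only go down.
cutout = TRANSFORMS.build(
    dict(type='RandomCutOut', prob=1, n_holes=1, cutout_shape=(10, 10)))
out = cutout(copy.deepcopy(results))
assert out['img'].sum() < results['img'].sum()
```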
@ -759,7 +859,7 @@ class RandomCutOut(object):
@TRANSFORMS.register_module()
class RandomMosaic(object):
class RandomMosaic(BaseTransform):
"""Mosaic augmentation. Given 4 images, mosaic transform combines them into
one output image. The output image is composed of the parts from each
sub-image.
@ -789,6 +889,19 @@ class RandomMosaic(object):
sample another 3 images from the custom dataset.
3. Sub image will be cropped if image is larger than mosaic patch
Required Keys:
- img
- gt_seg_map
- mix_results
Modified Keys:
- img
- img_shape
- ori_shape
- gt_seg_map
Args:
prob (float): mosaic probability.
img_scale (Sequence[int]): Image size after mosaic pipeline of
@ -815,7 +928,11 @@ class RandomMosaic(object):
self.pad_val = pad_val
self.seg_pad_val = seg_pad_val
def __call__(self, results):
@cache_randomness
def do_mosaic(self):
return np.random.rand() < self.prob
def transform(self, results: dict) -> dict:
"""Call function to make a mosaic of image.
Args:
@ -824,13 +941,13 @@ class RandomMosaic(object):
Returns:
dict: Result dict with mosaic transformed.
"""
mosaic = True if np.random.rand() < self.prob else False
mosaic = self.do_mosaic()
if mosaic:
results = self._mosaic_transform_img(results)
results = self._mosaic_transform_seg(results)
return results
def get_indexes(self, dataset):
def get_indices(self, dataset: MultiImageMixDataset) -> list:
"""Call function to collect indexes.
Args:
@ -843,7 +960,16 @@ class RandomMosaic(object):
indexes = [random.randint(0, len(dataset)) for _ in range(3)]
return indexes
def _mosaic_transform_img(self, results):
@cache_randomness
def generate_mosaic_center(self):
# mosaic center x, y
center_x = int(
random.uniform(*self.center_ratio_range) * self.img_scale[1])
center_y = int(
random.uniform(*self.center_ratio_range) * self.img_scale[0])
return center_x, center_y
def _mosaic_transform_img(self, results: dict) -> dict:
"""Mosaic transform function.
Args:
@ -866,10 +992,7 @@ class RandomMosaic(object):
dtype=results['img'].dtype)
# mosaic center x, y
self.center_x = int(
random.uniform(*self.center_ratio_range) * self.img_scale[1])
self.center_y = int(
random.uniform(*self.center_ratio_range) * self.img_scale[0])
self.center_x, self.center_y = self.generate_mosaic_center()
center_position = (self.center_x, self.center_y)
loc_strs = ('top_left', 'top_right', 'bottom_left', 'bottom_right')
@ -902,7 +1025,7 @@ class RandomMosaic(object):
return results
def _mosaic_transform_seg(self, results):
def _mosaic_transform_seg(self, results: dict) -> dict:
"""Mosaic transform function for label annotations.
Args:
@ -953,7 +1076,8 @@ class RandomMosaic(object):
return results
def _mosaic_combine(self, loc, center_position_xy, img_shape_wh):
def _mosaic_combine(self, loc: str, center_position_xy: Sequence[float],
img_shape_wh: Sequence[int]) -> tuple:
"""Calculate global coordinate of mosaic image and local coordinate of
cropped sub-image.

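Outside MultiImageMixDataset, the tests drive RandomMosaic by filling mix_results by hand; a condensed sketch of that flow with synthetic data, following test_mosaic below, where prob=1 and img_scale=(10, 12) yield a 20 x 24 mosaic.

```python
import copy

import numpy as np
from mmseg.registry import TRANSFORMS
from mmseg.utils import register_all_modules

register_all_modules()

results = dict(
    img=np.random.randint(0, 255, (30, 40, 3), dtype=np.uint8),
    gt_semantic_seg=np.random.randint(0, 19, (30, 40), dtype=np.uint8),
    seg_fields=['gt_semantic_seg'])
# RandomMosaic expects 3 extra samples in 'mix_results'; in a training
# pipeline MultiImageMixDataset fills them via get_indices().
results['mix_results'] = [copy.deepcopy(results)] * 3

mosaic = TRANSFORMS.build(dict(type='RandomMosaic', prob=1, img_scale=(10, 12)))
out = mosaic(results)
assert out['img'].shape[:2] == (20, 24)  # 2 * img_scale, as in test_mosaic
```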
View File

@ -7,11 +7,9 @@ import numpy as np
import pytest
from PIL import Image
from mmseg.datasets.transforms import * # noqa
from mmseg.datasets.transforms import PhotoMetricDistortion, RandomCrop
from mmseg.registry import TRANSFORMS
from mmseg.utils import register_all_modules
register_all_modules()
def test_resize():
@ -233,6 +231,72 @@ def test_random_crop():
assert results['gt_semantic_seg'].shape[:2] == (h - 20, w - 20)
def test_rgb2gray():
# test assertion out_channels should be greater than 0
with pytest.raises(AssertionError):
transform = dict(type='RGB2Gray', out_channels=-1)
TRANSFORMS.build(transform)
# test assertion weights should be tuple[float]
with pytest.raises(AssertionError):
transform = dict(type='RGB2Gray', out_channels=1, weights=1.1)
TRANSFORMS.build(transform)
# test out_channels is None
transform = dict(type='RGB2Gray')
transform = TRANSFORMS.build(transform)
assert str(transform) == f'RGB2Gray(' \
f'out_channels={None}, ' \
f'weights={(0.299, 0.587, 0.114)})'
results = dict()
img = mmcv.imread(
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
h, w, c = img.shape
seg = np.array(
Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))
results['img'] = img
results['gt_semantic_seg'] = seg
results['seg_fields'] = ['gt_semantic_seg']
results['img_shape'] = img.shape
results['ori_shape'] = img.shape
# Set initial values for default meta_keys
results['pad_shape'] = img.shape
results['scale_factor'] = 1.0
results = transform(results)
assert results['img'].shape == (h, w, c)
assert results['img_shape'] == (h, w, c)
assert results['ori_shape'] == (h, w, c)
# test out_channels = 2
transform = dict(type='RGB2Gray', out_channels=2)
transform = TRANSFORMS.build(transform)
assert str(transform) == f'RGB2Gray(' \
f'out_channels={2}, ' \
f'weights={(0.299, 0.587, 0.114)})'
results = dict()
img = mmcv.imread(
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
h, w, c = img.shape
seg = np.array(
Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))
results['img'] = img
results['gt_semantic_seg'] = seg
results['seg_fields'] = ['gt_semantic_seg']
results['img_shape'] = img.shape
results['ori_shape'] = img.shape
# Set initial values for default meta_keys
results['pad_shape'] = img.shape
results['scale_factor'] = 1.0
results = transform(results)
assert results['img'].shape == (h, w, 2)
assert results['img_shape'] == (h, w, 2)
def test_photo_metric_distortion():
results = dict()
@ -252,3 +316,366 @@ def test_photo_metric_distortion():
assert (results['gt_semantic_seg'] == seg).all()
assert results['img_shape'] == img.shape
def test_rerange():
# test assertion if min_value or max_value is illegal
with pytest.raises(AssertionError):
transform = dict(type='Rerange', min_value=[0], max_value=[255])
TRANSFORMS.build(transform)
# test assertion if min_value >= max_value
with pytest.raises(AssertionError):
transform = dict(type='Rerange', min_value=1, max_value=1)
TRANSFORMS.build(transform)
# test assertion if img_min_value == img_max_value
with pytest.raises(AssertionError):
transform = dict(type='Rerange', min_value=0, max_value=1)
transform = TRANSFORMS.build(transform)
results = dict()
results['img'] = np.array([[1, 1], [1, 1]])
transform(results)
img_rerange_cfg = dict()
transform = dict(type='Rerange', **img_rerange_cfg)
transform = TRANSFORMS.build(transform)
results = dict()
img = mmcv.imread(
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
original_img = copy.deepcopy(img)
results['img'] = img
results['img_shape'] = img.shape
results['ori_shape'] = img.shape
# Set initial values for default meta_keys
results['pad_shape'] = img.shape
results['scale_factor'] = 1.0
results = transform(results)
min_value = np.min(original_img)
max_value = np.max(original_img)
converted_img = (original_img - min_value) / (max_value - min_value) * 255
assert np.allclose(results['img'], converted_img)
assert str(transform) == f'Rerange(min_value={0}, max_value={255})'
def test_CLAHE():
# test assertion if clip_limit is None
with pytest.raises(AssertionError):
transform = dict(type='CLAHE', clip_limit=None)
TRANSFORMS.build(transform)
# test assertion if tile_grid_size is illegal
with pytest.raises(AssertionError):
transform = dict(type='CLAHE', tile_grid_size=(8.0, 8.0))
TRANSFORMS.build(transform)
# test assertion if tile_grid_size is illegal
with pytest.raises(AssertionError):
transform = dict(type='CLAHE', tile_grid_size=(9, 9, 9))
TRANSFORMS.build(transform)
transform = dict(type='CLAHE', clip_limit=2)
transform = TRANSFORMS.build(transform)
results = dict()
img = mmcv.imread(
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
original_img = copy.deepcopy(img)
results['img'] = img
results['img_shape'] = img.shape
results['ori_shape'] = img.shape
# Set initial values for default meta_keys
results['pad_shape'] = img.shape
results['scale_factor'] = 1.0
results = transform(results)
converted_img = np.empty(original_img.shape)
for i in range(original_img.shape[2]):
converted_img[:, :, i] = mmcv.clahe(
np.array(original_img[:, :, i], dtype=np.uint8), 2, (8, 8))
assert np.allclose(results['img'], converted_img)
assert str(transform) == f'CLAHE(clip_limit={2}, tile_grid_size={(8, 8)})'
def test_adjust_gamma():
# test assertion if gamma <= 0
with pytest.raises(AssertionError):
transform = dict(type='AdjustGamma', gamma=0)
TRANSFORMS.build(transform)
# test assertion if gamma is list
with pytest.raises(AssertionError):
transform = dict(type='AdjustGamma', gamma=[1.2])
TRANSFORMS.build(transform)
# test with gamma = 1.2
transform = dict(type='AdjustGamma', gamma=1.2)
transform = TRANSFORMS.build(transform)
results = dict()
img = mmcv.imread(
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
original_img = copy.deepcopy(img)
results['img'] = img
results['img_shape'] = img.shape
results['ori_shape'] = img.shape
# Set initial values for default meta_keys
results['pad_shape'] = img.shape
results['scale_factor'] = 1.0
results = transform(results)
inv_gamma = 1.0 / 1.2
table = np.array([((i / 255.0)**inv_gamma) * 255
for i in np.arange(0, 256)]).astype('uint8')
converted_img = mmcv.lut_transform(
np.array(original_img, dtype=np.uint8), table)
assert np.allclose(results['img'], converted_img)
assert str(transform) == f'AdjustGamma(gamma={1.2})'
def test_rotate():
# test assertion degree should be tuple[float] or float
with pytest.raises(AssertionError):
transform = dict(type='RandomRotate', prob=0.5, degree=-10)
TRANSFORMS.build(transform)
# test assertion degree should be tuple[float] or float
with pytest.raises(AssertionError):
transform = dict(type='RandomRotate', prob=0.5, degree=(10., 20., 30.))
TRANSFORMS.build(transform)
transform = dict(type='RandomRotate', degree=10., prob=1.)
transform = TRANSFORMS.build(transform)
assert str(transform) == f'RandomRotate(' \
f'prob={1.}, ' \
f'degree=({-10.}, {10.}), ' \
f'pad_val={0}, ' \
f'seg_pad_val={255}, ' \
f'center={None}, ' \
f'auto_bound={False})'
results = dict()
img = mmcv.imread(
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
h, w, _ = img.shape
seg = np.array(
Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))
results['img'] = img
results['gt_semantic_seg'] = seg
results['seg_fields'] = ['gt_semantic_seg']
results['img_shape'] = img.shape
results['ori_shape'] = img.shape
# Set initial values for default meta_keys
results['pad_shape'] = img.shape
results['scale_factor'] = 1.0
results = transform(results)
assert results['img'].shape[:2] == (h, w)
def test_seg_rescale():
results = dict()
seg = np.array(
Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))
results['gt_semantic_seg'] = seg
results['seg_fields'] = ['gt_semantic_seg']
h, w = seg.shape
transform = dict(type='SegRescale', scale_factor=1. / 2)
rescale_module = TRANSFORMS.build(transform)
rescale_results = rescale_module(results.copy())
assert rescale_results['gt_semantic_seg'].shape == (h // 2, w // 2)
transform = dict(type='SegRescale', scale_factor=1)
rescale_module = TRANSFORMS.build(transform)
rescale_results = rescale_module(results.copy())
assert rescale_results['gt_semantic_seg'].shape == (h, w)
def test_mosaic():
# test prob
with pytest.raises(AssertionError):
transform = dict(type='RandomMosaic', prob=1.5)
TRANSFORMS.build(transform)
# test assertion for invalid img_scale
with pytest.raises(AssertionError):
transform = dict(type='RandomMosaic', prob=1, img_scale=640)
TRANSFORMS.build(transform)
results = dict()
img = mmcv.imread(
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
seg = np.array(
Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))
results['img'] = img
results['gt_semantic_seg'] = seg
results['seg_fields'] = ['gt_semantic_seg']
transform = dict(type='RandomMosaic', prob=1, img_scale=(10, 12))
mosaic_module = TRANSFORMS.build(transform)
assert 'Mosaic' in repr(mosaic_module)
# test assertion for invalid mix_results
with pytest.raises(AssertionError):
mosaic_module(results)
results['mix_results'] = [copy.deepcopy(results)] * 3
results = mosaic_module(results)
assert results['img'].shape[:2] == (20, 24)
results = dict()
results['img'] = img[:, :, 0]
results['gt_semantic_seg'] = seg
results['seg_fields'] = ['gt_semantic_seg']
transform = dict(type='RandomMosaic', prob=0, img_scale=(10, 12))
mosaic_module = TRANSFORMS.build(transform)
results['mix_results'] = [copy.deepcopy(results)] * 3
results = mosaic_module(results)
assert results['img'].shape[:2] == img.shape[:2]
transform = dict(type='RandomMosaic', prob=1, img_scale=(10, 12))
mosaic_module = TRANSFORMS.build(transform)
results = mosaic_module(results)
assert results['img'].shape[:2] == (20, 24)
def test_cutout():
# test prob
with pytest.raises(AssertionError):
transform = dict(type='RandomCutOut', prob=1.5, n_holes=1)
TRANSFORMS.build(transform)
# test n_holes
with pytest.raises(AssertionError):
transform = dict(
type='RandomCutOut', prob=0.5, n_holes=(5, 3), cutout_shape=(8, 8))
TRANSFORMS.build(transform)
with pytest.raises(AssertionError):
transform = dict(
type='RandomCutOut',
prob=0.5,
n_holes=(3, 4, 5),
cutout_shape=(8, 8))
TRANSFORMS.build(transform)
# test cutout_shape and cutout_ratio
with pytest.raises(AssertionError):
transform = dict(
type='RandomCutOut', prob=0.5, n_holes=1, cutout_shape=8)
TRANSFORMS.build(transform)
with pytest.raises(AssertionError):
transform = dict(
type='RandomCutOut', prob=0.5, n_holes=1, cutout_ratio=0.2)
TRANSFORMS.build(transform)
# either of cutout_shape and cutout_ratio should be given
with pytest.raises(AssertionError):
transform = dict(type='RandomCutOut', prob=0.5, n_holes=1)
TRANSFORMS.build(transform)
with pytest.raises(AssertionError):
transform = dict(
type='RandomCutOut',
prob=0.5,
n_holes=1,
cutout_shape=(2, 2),
cutout_ratio=(0.4, 0.4))
TRANSFORMS.build(transform)
# test seg_fill_in
with pytest.raises(AssertionError):
transform = dict(
type='RandomCutOut',
prob=0.5,
n_holes=1,
cutout_shape=(8, 8),
seg_fill_in='a')
TRANSFORMS.build(transform)
with pytest.raises(AssertionError):
transform = dict(
type='RandomCutOut',
prob=0.5,
n_holes=1,
cutout_shape=(8, 8),
seg_fill_in=256)
TRANSFORMS.build(transform)
results = dict()
img = mmcv.imread(
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
seg = np.array(
Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))
results['img'] = img
results['gt_semantic_seg'] = seg
results['seg_fields'] = ['gt_semantic_seg']
results['img_shape'] = img.shape
results['ori_shape'] = img.shape
results['pad_shape'] = img.shape
results['img_fields'] = ['img']
transform = dict(
type='RandomCutOut', prob=1, n_holes=1, cutout_shape=(10, 10))
cutout_module = TRANSFORMS.build(transform)
assert 'cutout_shape' in repr(cutout_module)
cutout_result = cutout_module(copy.deepcopy(results))
assert cutout_result['img'].sum() < img.sum()
transform = dict(
type='RandomCutOut', prob=1, n_holes=1, cutout_ratio=(0.8, 0.8))
cutout_module = TRANSFORMS.build(transform)
assert 'cutout_ratio' in repr(cutout_module)
cutout_result = cutout_module(copy.deepcopy(results))
assert cutout_result['img'].sum() < img.sum()
transform = dict(
type='RandomCutOut', prob=0, n_holes=1, cutout_ratio=(0.8, 0.8))
cutout_module = TRANSFORMS.build(transform)
cutout_result = cutout_module(copy.deepcopy(results))
assert cutout_result['img'].sum() == img.sum()
assert cutout_result['gt_semantic_seg'].sum() == seg.sum()
transform = dict(
type='RandomCutOut',
prob=1,
n_holes=(2, 4),
cutout_shape=[(10, 10), (15, 15)],
fill_in=(255, 255, 255),
seg_fill_in=None)
cutout_module = TRANSFORMS.build(transform)
cutout_result = cutout_module(copy.deepcopy(results))
assert cutout_result['img'].sum() > img.sum()
assert cutout_result['gt_semantic_seg'].sum() == seg.sum()
transform = dict(
type='RandomCutOut',
prob=1,
n_holes=1,
cutout_ratio=(0.8, 0.8),
fill_in=(255, 255, 255),
seg_fill_in=255)
cutout_module = TRANSFORMS.build(transform)
cutout_result = cutout_module(copy.deepcopy(results))
assert cutout_result['img'].sum() > img.sum()
assert cutout_result['gt_semantic_seg'].sum() > seg.sum()
def test_resize_to_multiple():
transform = dict(type='ResizeToMultiple', size_divisor=32)
transform = TRANSFORMS.build(transform)
img = np.random.randn(213, 232, 3)
seg = np.random.randint(0, 19, (213, 232))
results = dict()
results['img'] = img
results['gt_semantic_seg'] = seg
results['seg_fields'] = ['gt_semantic_seg']
results['img_shape'] = img.shape
results['pad_shape'] = img.shape
results = transform(results)
assert results['img'].shape == (224, 256, 3)
assert results['gt_semantic_seg'].shape == (224, 256)
assert results['img_shape'] == (224, 256, 3)