diff --git a/mmseg/datasets/dataset_wrappers.py b/mmseg/datasets/dataset_wrappers.py index 57136e33f..933eb50d9 100644 --- a/mmseg/datasets/dataset_wrappers.py +++ b/mmseg/datasets/dataset_wrappers.py @@ -105,8 +105,8 @@ class MultiImageMixDataset: transform_type in self._skip_type_keys: continue - if hasattr(transform, 'get_indexes'): - indexes = transform.get_indexes(self.dataset) + if hasattr(transform, 'get_indices'): + indexes = transform.get_indices(self.dataset) if not isinstance(indexes, collections.abc.Sequence): indexes = [indexes] mix_results = [ diff --git a/mmseg/datasets/transforms/transforms.py b/mmseg/datasets/transforms/transforms.py index 52c61953b..8485ab2b0 100644 --- a/mmseg/datasets/transforms/transforms.py +++ b/mmseg/datasets/transforms/transforms.py @@ -9,13 +9,25 @@ from mmcv.transforms.utils import cache_randomness from mmcv.utils import is_tuple_of from numpy import random +from mmseg.datasets.dataset_wrappers import MultiImageMixDataset from mmseg.registry import TRANSFORMS @TRANSFORMS.register_module() -class ResizeToMultiple(object): +class ResizeToMultiple(BaseTransform): """Resize images & seg to multiple of divisor. + Required Keys: + + - img + - gt_seg_map + + Modified Keys: + + - img + - img_shape + - pad_shape + Args: size_divisor (int): images and gt seg maps need to resize to multiple of size_divisor. Default: 32. @@ -27,7 +39,7 @@ class ResizeToMultiple(object): self.size_divisor = size_divisor self.interpolation = interpolation - def __call__(self, results): + def transform(self, results: dict) -> dict: """Call function to resize images, semantic segmentation map to multiple of size divisor. @@ -70,9 +82,17 @@ class ResizeToMultiple(object): @TRANSFORMS.register_module() -class Rerange(object): +class Rerange(BaseTransform): """Rerange the image pixel value. + Required Keys: + + - img + + Modified Keys: + + - img + Args: min_value (float or int): Minimum value of the reranged image. Default: 0. @@ -87,7 +107,7 @@ class Rerange(object): self.min_value = min_value self.max_value = max_value - def __call__(self, results): + def transform(self, results: dict) -> dict: """Call function to rerange images. Args: @@ -116,12 +136,20 @@ class Rerange(object): @TRANSFORMS.register_module() -class CLAHE(object): +class CLAHE(BaseTransform): """Use CLAHE method to process the image. See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J]. Graphics Gems, 1994:474-485.` for more information. + Required Keys: + + - img + + Modified Keys: + + - img + Args: clip_limit (float): Threshold for contrast limiting. Default: 40.0. tile_grid_size (tuple[int]): Size of grid for histogram equalization. @@ -136,7 +164,7 @@ class CLAHE(object): assert len(tile_grid_size) == 2 self.tile_grid_size = tile_grid_size - def __call__(self, results): + def transform(self, results: dict) -> dict: """Call function to Use CLAHE method process images. Args: @@ -167,13 +195,13 @@ class RandomCrop(BaseTransform): Required Keys: - img - - gt_semantic_seg + - gt_seg_map Modified Keys: - img - img_shape - - gt_semantic_seg + - gt_seg_map Args: @@ -293,9 +321,19 @@ class RandomCrop(BaseTransform): @TRANSFORMS.register_module() -class RandomRotate(object): +class RandomRotate(BaseTransform): """Rotate the image & seg. + Required Keys: + + - img + - gt_seg_map + + Modified Keys: + + - img + - gt_seg_map + Args: prob (float): The rotation probability. degree (float, tuple[float]): Range of degrees to select from. If @@ -332,7 +370,12 @@ class RandomRotate(object): self.center = center self.auto_bound = auto_bound - def __call__(self, results): + @cache_randomness + def generate_degree(self): + return np.random.rand() < self.prob, np.random.uniform( + min(*self.degree), max(*self.degree)) + + def transform(self, results: dict) -> dict: """Call function to rotate image, semantic segmentation maps. Args: @@ -342,8 +385,7 @@ class RandomRotate(object): dict: Rotated results. """ - rotate = True if np.random.rand() < self.prob else False - degree = np.random.uniform(min(*self.degree), max(*self.degree)) + rotate, degree = self.generate_degree() if rotate: # rotate image results['img'] = mmcv.imrotate( @@ -376,9 +418,18 @@ class RandomRotate(object): @TRANSFORMS.register_module() -class RGB2Gray(object): +class RGB2Gray(BaseTransform): """Convert RGB image to grayscale image. + Required Keys: + + - img + + Modified Keys: + + - img + - img_shape + This transform calculate the weighted mean of input image channels with ``weights`` and then expand the channels to ``out_channels``. When ``out_channels`` is None, the number of output channels is the same as @@ -399,7 +450,7 @@ class RGB2Gray(object): assert isinstance(item, (float, int)) self.weights = weights - def __call__(self, results): + def transform(self, results: dict) -> dict: """Call function to convert RGB image to grayscale image. Args: @@ -431,9 +482,17 @@ class RGB2Gray(object): @TRANSFORMS.register_module() -class AdjustGamma(object): +class AdjustGamma(BaseTransform): """Using gamma correction to process the image. + Required Keys: + + - img + + Modified Keys: + + - img + Args: gamma (float or int): Gamma value used in gamma correction. Default: 1.0. @@ -447,7 +506,7 @@ class AdjustGamma(object): self.table = np.array([(i / 255.0)**inv_gamma * 255 for i in np.arange(256)]).astype('uint8') - def __call__(self, results): + def transform(self, results: dict) -> dict: """Call function to process the image with gamma correction. Args: @@ -467,9 +526,17 @@ class AdjustGamma(object): @TRANSFORMS.register_module() -class SegRescale(object): +class SegRescale(BaseTransform): """Rescale semantic segmentation maps. + Required Keys: + + - gt_seg_map + + Modified Keys: + + - gt_seg_map + Args: scale_factor (float): The scale factor of the final output. """ @@ -477,7 +544,7 @@ class SegRescale(object): def __init__(self, scale_factor=1): self.scale_factor = scale_factor - def __call__(self, results): + def transform(self, results: dict) -> dict: """Call function to scale the semantic segmentation map. Args: @@ -667,11 +734,22 @@ class PhotoMetricDistortion(BaseTransform): @TRANSFORMS.register_module() -class RandomCutOut(object): +class RandomCutOut(BaseTransform): """CutOut operation. Randomly drop some regions of image used in `Cutout `_. + + Required Keys: + + - img + - gt_seg_map + + Modified Keys: + + - img + - gt_seg_map + Args: prob (float): cutout probability. n_holes (int | tuple[int, int]): Number of regions to be dropped. @@ -721,16 +799,38 @@ class RandomCutOut(object): if not isinstance(self.candidates, list): self.candidates = [self.candidates] - def __call__(self, results): + @cache_randomness + def do_cutout(self): + return np.random.rand() < self.prob + + @cache_randomness + def generate_patches(self, results): + cutout = self.do_cutout() + + h, w, _ = results['img'].shape + if cutout: + n_holes = np.random.randint(self.n_holes[0], self.n_holes[1] + 1) + else: + n_holes = 0 + x1_lst = [] + y1_lst = [] + index_lst = [] + for _ in range(n_holes): + x1_lst.append(np.random.randint(0, w)) + y1_lst.append(np.random.randint(0, h)) + index_lst.append(np.random.randint(0, len(self.candidates))) + return cutout, n_holes, x1_lst, y1_lst, index_lst + + def transform(self, results: dict) -> dict: """Call function to drop some regions of image.""" - cutout = True if np.random.rand() < self.prob else False + cutout, n_holes, x1_lst, y1_lst, index_lst = self.generate_patches( + results) if cutout: h, w, c = results['img'].shape - n_holes = np.random.randint(self.n_holes[0], self.n_holes[1] + 1) - for _ in range(n_holes): - x1 = np.random.randint(0, w) - y1 = np.random.randint(0, h) - index = np.random.randint(0, len(self.candidates)) + for i in range(n_holes): + x1 = x1_lst[i] + y1 = y1_lst[i] + index = index_lst[i] if not self.with_ratio: cutout_w, cutout_h = self.candidates[index] else: @@ -759,7 +859,7 @@ class RandomCutOut(object): @TRANSFORMS.register_module() -class RandomMosaic(object): +class RandomMosaic(BaseTransform): """Mosaic augmentation. Given 4 images, mosaic transform combines them into one output image. The output image is composed of the parts from each sub- image. @@ -789,6 +889,19 @@ class RandomMosaic(object): sample another 3 images from the custom dataset. 3. Sub image will be cropped if image is larger than mosaic patch + Required Keys: + + - img + - gt_seg_map + - mix_results + + Modified Keys: + + - img + - img_shape + - ori_shape + - gt_seg_map + Args: prob (float): mosaic probability. img_scale (Sequence[int]): Image size after mosaic pipeline of @@ -815,7 +928,11 @@ class RandomMosaic(object): self.pad_val = pad_val self.seg_pad_val = seg_pad_val - def __call__(self, results): + @cache_randomness + def do_mosaic(self): + return np.random.rand() < self.prob + + def transform(self, results: dict) -> dict: """Call function to make a mosaic of image. Args: @@ -824,13 +941,13 @@ class RandomMosaic(object): Returns: dict: Result dict with mosaic transformed. """ - mosaic = True if np.random.rand() < self.prob else False + mosaic = self.do_mosaic() if mosaic: results = self._mosaic_transform_img(results) results = self._mosaic_transform_seg(results) return results - def get_indexes(self, dataset): + def get_indices(self, dataset: MultiImageMixDataset) -> list: """Call function to collect indexes. Args: @@ -843,7 +960,16 @@ class RandomMosaic(object): indexes = [random.randint(0, len(dataset)) for _ in range(3)] return indexes - def _mosaic_transform_img(self, results): + @cache_randomness + def generate_mosaic_center(self): + # mosaic center x, y + center_x = int( + random.uniform(*self.center_ratio_range) * self.img_scale[1]) + center_y = int( + random.uniform(*self.center_ratio_range) * self.img_scale[0]) + return center_x, center_y + + def _mosaic_transform_img(self, results: dict) -> dict: """Mosaic transform function. Args: @@ -866,10 +992,7 @@ class RandomMosaic(object): dtype=results['img'].dtype) # mosaic center x, y - self.center_x = int( - random.uniform(*self.center_ratio_range) * self.img_scale[1]) - self.center_y = int( - random.uniform(*self.center_ratio_range) * self.img_scale[0]) + self.center_x, self.center_y = self.generate_mosaic_center() center_position = (self.center_x, self.center_y) loc_strs = ('top_left', 'top_right', 'bottom_left', 'bottom_right') @@ -902,7 +1025,7 @@ class RandomMosaic(object): return results - def _mosaic_transform_seg(self, results): + def _mosaic_transform_seg(self, results: dict) -> dict: """Mosaic transform function for label annotations. Args: @@ -953,7 +1076,8 @@ class RandomMosaic(object): return results - def _mosaic_combine(self, loc, center_position_xy, img_shape_wh): + def _mosaic_combine(self, loc: str, center_position_xy: Sequence[float], + img_shape_wh: Sequence[int]) -> tuple: """Calculate global coordinate of mosaic image and local coordinate of cropped sub-image. diff --git a/tests/test_datasets/test_transform.py b/tests/test_datasets/test_transform.py index fac21f2a8..bf4accf67 100644 --- a/tests/test_datasets/test_transform.py +++ b/tests/test_datasets/test_transform.py @@ -7,11 +7,9 @@ import numpy as np import pytest from PIL import Image +from mmseg.datasets.transforms import * # noqa from mmseg.datasets.transforms import PhotoMetricDistortion, RandomCrop from mmseg.registry import TRANSFORMS -from mmseg.utils import register_all_modules - -register_all_modules() def test_resize(): @@ -233,6 +231,72 @@ def test_random_crop(): assert results['gt_semantic_seg'].shape[:2] == (h - 20, w - 20) +def test_rgb2gray(): + # test assertion out_channels should be greater than 0 + with pytest.raises(AssertionError): + transform = dict(type='RGB2Gray', out_channels=-1) + TRANSFORMS.build(transform) + # test assertion weights should be tuple[float] + with pytest.raises(AssertionError): + transform = dict(type='RGB2Gray', out_channels=1, weights=1.1) + TRANSFORMS.build(transform) + + # test out_channels is None + transform = dict(type='RGB2Gray') + transform = TRANSFORMS.build(transform) + + assert str(transform) == f'RGB2Gray(' \ + f'out_channels={None}, ' \ + f'weights={(0.299, 0.587, 0.114)})' + + results = dict() + img = mmcv.imread( + osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color') + h, w, c = img.shape + seg = np.array( + Image.open(osp.join(osp.dirname(__file__), '../data/seg.png'))) + results['img'] = img + results['gt_semantic_seg'] = seg + results['seg_fields'] = ['gt_semantic_seg'] + results['img_shape'] = img.shape + results['ori_shape'] = img.shape + # Set initial values for default meta_keys + results['pad_shape'] = img.shape + results['scale_factor'] = 1.0 + + results = transform(results) + assert results['img'].shape == (h, w, c) + assert results['img_shape'] == (h, w, c) + assert results['ori_shape'] == (h, w, c) + + # test out_channels = 2 + transform = dict(type='RGB2Gray', out_channels=2) + transform = TRANSFORMS.build(transform) + + assert str(transform) == f'RGB2Gray(' \ + f'out_channels={2}, ' \ + f'weights={(0.299, 0.587, 0.114)})' + + results = dict() + img = mmcv.imread( + osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color') + h, w, c = img.shape + seg = np.array( + Image.open(osp.join(osp.dirname(__file__), '../data/seg.png'))) + results['img'] = img + results['gt_semantic_seg'] = seg + results['seg_fields'] = ['gt_semantic_seg'] + results['img_shape'] = img.shape + results['ori_shape'] = img.shape + # Set initial values for default meta_keys + results['pad_shape'] = img.shape + results['scale_factor'] = 1.0 + + results = transform(results) + assert results['img'].shape == (h, w, 2) + assert results['img_shape'] == (h, w, 2) + + def test_photo_metric_distortion(): results = dict() @@ -252,3 +316,366 @@ def test_photo_metric_distortion(): assert (results['gt_semantic_seg'] == seg).all() assert results['img_shape'] == img.shape + + +def test_rerange(): + # test assertion if min_value or max_value is illegal + with pytest.raises(AssertionError): + transform = dict(type='Rerange', min_value=[0], max_value=[255]) + TRANSFORMS.build(transform) + + # test assertion if min_value >= max_value + with pytest.raises(AssertionError): + transform = dict(type='Rerange', min_value=1, max_value=1) + TRANSFORMS.build(transform) + + # test assertion if img_min_value == img_max_value + with pytest.raises(AssertionError): + transform = dict(type='Rerange', min_value=0, max_value=1) + transform = TRANSFORMS.build(transform) + results = dict() + results['img'] = np.array([[1, 1], [1, 1]]) + transform(results) + + img_rerange_cfg = dict() + transform = dict(type='Rerange', **img_rerange_cfg) + transform = TRANSFORMS.build(transform) + results = dict() + img = mmcv.imread( + osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color') + original_img = copy.deepcopy(img) + results['img'] = img + results['img_shape'] = img.shape + results['ori_shape'] = img.shape + # Set initial values for default meta_keys + results['pad_shape'] = img.shape + results['scale_factor'] = 1.0 + + results = transform(results) + + min_value = np.min(original_img) + max_value = np.max(original_img) + converted_img = (original_img - min_value) / (max_value - min_value) * 255 + + assert np.allclose(results['img'], converted_img) + assert str(transform) == f'Rerange(min_value={0}, max_value={255})' + + +def test_CLAHE(): + # test assertion if clip_limit is None + with pytest.raises(AssertionError): + transform = dict(type='CLAHE', clip_limit=None) + TRANSFORMS.build(transform) + + # test assertion if tile_grid_size is illegal + with pytest.raises(AssertionError): + transform = dict(type='CLAHE', tile_grid_size=(8.0, 8.0)) + TRANSFORMS.build(transform) + + # test assertion if tile_grid_size is illegal + with pytest.raises(AssertionError): + transform = dict(type='CLAHE', tile_grid_size=(9, 9, 9)) + TRANSFORMS.build(transform) + + transform = dict(type='CLAHE', clip_limit=2) + transform = TRANSFORMS.build(transform) + results = dict() + img = mmcv.imread( + osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color') + original_img = copy.deepcopy(img) + results['img'] = img + results['img_shape'] = img.shape + results['ori_shape'] = img.shape + # Set initial values for default meta_keys + results['pad_shape'] = img.shape + results['scale_factor'] = 1.0 + + results = transform(results) + + converted_img = np.empty(original_img.shape) + for i in range(original_img.shape[2]): + converted_img[:, :, i] = mmcv.clahe( + np.array(original_img[:, :, i], dtype=np.uint8), 2, (8, 8)) + + assert np.allclose(results['img'], converted_img) + assert str(transform) == f'CLAHE(clip_limit={2}, tile_grid_size={(8, 8)})' + + +def test_adjust_gamma(): + # test assertion if gamma <= 0 + with pytest.raises(AssertionError): + transform = dict(type='AdjustGamma', gamma=0) + TRANSFORMS.build(transform) + + # test assertion if gamma is list + with pytest.raises(AssertionError): + transform = dict(type='AdjustGamma', gamma=[1.2]) + TRANSFORMS.build(transform) + + # test with gamma = 1.2 + transform = dict(type='AdjustGamma', gamma=1.2) + transform = TRANSFORMS.build(transform) + results = dict() + img = mmcv.imread( + osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color') + original_img = copy.deepcopy(img) + results['img'] = img + results['img_shape'] = img.shape + results['ori_shape'] = img.shape + # Set initial values for default meta_keys + results['pad_shape'] = img.shape + results['scale_factor'] = 1.0 + + results = transform(results) + + inv_gamma = 1.0 / 1.2 + table = np.array([((i / 255.0)**inv_gamma) * 255 + for i in np.arange(0, 256)]).astype('uint8') + converted_img = mmcv.lut_transform( + np.array(original_img, dtype=np.uint8), table) + assert np.allclose(results['img'], converted_img) + assert str(transform) == f'AdjustGamma(gamma={1.2})' + + +def test_rotate(): + # test assertion degree should be tuple[float] or float + with pytest.raises(AssertionError): + transform = dict(type='RandomRotate', prob=0.5, degree=-10) + TRANSFORMS.build(transform) + # test assertion degree should be tuple[float] or float + with pytest.raises(AssertionError): + transform = dict(type='RandomRotate', prob=0.5, degree=(10., 20., 30.)) + TRANSFORMS.build(transform) + + transform = dict(type='RandomRotate', degree=10., prob=1.) + transform = TRANSFORMS.build(transform) + + assert str(transform) == f'RandomRotate(' \ + f'prob={1.}, ' \ + f'degree=({-10.}, {10.}), ' \ + f'pad_val={0}, ' \ + f'seg_pad_val={255}, ' \ + f'center={None}, ' \ + f'auto_bound={False})' + + results = dict() + img = mmcv.imread( + osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color') + h, w, _ = img.shape + seg = np.array( + Image.open(osp.join(osp.dirname(__file__), '../data/seg.png'))) + results['img'] = img + results['gt_semantic_seg'] = seg + results['seg_fields'] = ['gt_semantic_seg'] + results['img_shape'] = img.shape + results['ori_shape'] = img.shape + # Set initial values for default meta_keys + results['pad_shape'] = img.shape + results['scale_factor'] = 1.0 + + results = transform(results) + assert results['img'].shape[:2] == (h, w) + + +def test_seg_rescale(): + results = dict() + seg = np.array( + Image.open(osp.join(osp.dirname(__file__), '../data/seg.png'))) + results['gt_semantic_seg'] = seg + results['seg_fields'] = ['gt_semantic_seg'] + h, w = seg.shape + + transform = dict(type='SegRescale', scale_factor=1. / 2) + rescale_module = TRANSFORMS.build(transform) + rescale_results = rescale_module(results.copy()) + assert rescale_results['gt_semantic_seg'].shape == (h // 2, w // 2) + + transform = dict(type='SegRescale', scale_factor=1) + rescale_module = TRANSFORMS.build(transform) + rescale_results = rescale_module(results.copy()) + assert rescale_results['gt_semantic_seg'].shape == (h, w) + + +def test_mosaic(): + # test prob + with pytest.raises(AssertionError): + transform = dict(type='RandomMosaic', prob=1.5) + TRANSFORMS.build(transform) + # test assertion for invalid img_scale + with pytest.raises(AssertionError): + transform = dict(type='RandomMosaic', prob=1, img_scale=640) + TRANSFORMS.build(transform) + + results = dict() + img = mmcv.imread( + osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color') + seg = np.array( + Image.open(osp.join(osp.dirname(__file__), '../data/seg.png'))) + + results['img'] = img + results['gt_semantic_seg'] = seg + results['seg_fields'] = ['gt_semantic_seg'] + + transform = dict(type='RandomMosaic', prob=1, img_scale=(10, 12)) + mosaic_module = TRANSFORMS.build(transform) + assert 'Mosaic' in repr(mosaic_module) + + # test assertion for invalid mix_results + with pytest.raises(AssertionError): + mosaic_module(results) + + results['mix_results'] = [copy.deepcopy(results)] * 3 + results = mosaic_module(results) + assert results['img'].shape[:2] == (20, 24) + + results = dict() + results['img'] = img[:, :, 0] + results['gt_semantic_seg'] = seg + results['seg_fields'] = ['gt_semantic_seg'] + + transform = dict(type='RandomMosaic', prob=0, img_scale=(10, 12)) + mosaic_module = TRANSFORMS.build(transform) + results['mix_results'] = [copy.deepcopy(results)] * 3 + results = mosaic_module(results) + assert results['img'].shape[:2] == img.shape[:2] + + transform = dict(type='RandomMosaic', prob=1, img_scale=(10, 12)) + mosaic_module = TRANSFORMS.build(transform) + results = mosaic_module(results) + assert results['img'].shape[:2] == (20, 24) + + +def test_cutout(): + # test prob + with pytest.raises(AssertionError): + transform = dict(type='RandomCutOut', prob=1.5, n_holes=1) + TRANSFORMS.build(transform) + # test n_holes + with pytest.raises(AssertionError): + transform = dict( + type='RandomCutOut', prob=0.5, n_holes=(5, 3), cutout_shape=(8, 8)) + TRANSFORMS.build(transform) + with pytest.raises(AssertionError): + transform = dict( + type='RandomCutOut', + prob=0.5, + n_holes=(3, 4, 5), + cutout_shape=(8, 8)) + TRANSFORMS.build(transform) + # test cutout_shape and cutout_ratio + with pytest.raises(AssertionError): + transform = dict( + type='RandomCutOut', prob=0.5, n_holes=1, cutout_shape=8) + TRANSFORMS.build(transform) + with pytest.raises(AssertionError): + transform = dict( + type='RandomCutOut', prob=0.5, n_holes=1, cutout_ratio=0.2) + TRANSFORMS.build(transform) + # either of cutout_shape and cutout_ratio should be given + with pytest.raises(AssertionError): + transform = dict(type='RandomCutOut', prob=0.5, n_holes=1) + TRANSFORMS.build(transform) + with pytest.raises(AssertionError): + transform = dict( + type='RandomCutOut', + prob=0.5, + n_holes=1, + cutout_shape=(2, 2), + cutout_ratio=(0.4, 0.4)) + TRANSFORMS.build(transform) + # test seg_fill_in + with pytest.raises(AssertionError): + transform = dict( + type='RandomCutOut', + prob=0.5, + n_holes=1, + cutout_shape=(8, 8), + seg_fill_in='a') + TRANSFORMS.build(transform) + with pytest.raises(AssertionError): + transform = dict( + type='RandomCutOut', + prob=0.5, + n_holes=1, + cutout_shape=(8, 8), + seg_fill_in=256) + TRANSFORMS.build(transform) + + results = dict() + img = mmcv.imread( + osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color') + + seg = np.array( + Image.open(osp.join(osp.dirname(__file__), '../data/seg.png'))) + + results['img'] = img + results['gt_semantic_seg'] = seg + results['seg_fields'] = ['gt_semantic_seg'] + results['img_shape'] = img.shape + results['ori_shape'] = img.shape + results['pad_shape'] = img.shape + results['img_fields'] = ['img'] + + transform = dict( + type='RandomCutOut', prob=1, n_holes=1, cutout_shape=(10, 10)) + cutout_module = TRANSFORMS.build(transform) + assert 'cutout_shape' in repr(cutout_module) + cutout_result = cutout_module(copy.deepcopy(results)) + assert cutout_result['img'].sum() < img.sum() + + transform = dict( + type='RandomCutOut', prob=1, n_holes=1, cutout_ratio=(0.8, 0.8)) + cutout_module = TRANSFORMS.build(transform) + assert 'cutout_ratio' in repr(cutout_module) + cutout_result = cutout_module(copy.deepcopy(results)) + assert cutout_result['img'].sum() < img.sum() + + transform = dict( + type='RandomCutOut', prob=0, n_holes=1, cutout_ratio=(0.8, 0.8)) + cutout_module = TRANSFORMS.build(transform) + cutout_result = cutout_module(copy.deepcopy(results)) + assert cutout_result['img'].sum() == img.sum() + assert cutout_result['gt_semantic_seg'].sum() == seg.sum() + + transform = dict( + type='RandomCutOut', + prob=1, + n_holes=(2, 4), + cutout_shape=[(10, 10), (15, 15)], + fill_in=(255, 255, 255), + seg_fill_in=None) + cutout_module = TRANSFORMS.build(transform) + cutout_result = cutout_module(copy.deepcopy(results)) + assert cutout_result['img'].sum() > img.sum() + assert cutout_result['gt_semantic_seg'].sum() == seg.sum() + + transform = dict( + type='RandomCutOut', + prob=1, + n_holes=1, + cutout_ratio=(0.8, 0.8), + fill_in=(255, 255, 255), + seg_fill_in=255) + cutout_module = TRANSFORMS.build(transform) + cutout_result = cutout_module(copy.deepcopy(results)) + assert cutout_result['img'].sum() > img.sum() + assert cutout_result['gt_semantic_seg'].sum() > seg.sum() + + +def test_resize_to_multiple(): + transform = dict(type='ResizeToMultiple', size_divisor=32) + transform = TRANSFORMS.build(transform) + + img = np.random.randn(213, 232, 3) + seg = np.random.randint(0, 19, (213, 232)) + results = dict() + results['img'] = img + results['gt_semantic_seg'] = seg + results['seg_fields'] = ['gt_semantic_seg'] + results['img_shape'] = img.shape + results['pad_shape'] = img.shape + + results = transform(results) + assert results['img'].shape == (224, 256, 3) + assert results['gt_semantic_seg'].shape == (224, 256) + assert results['img_shape'] == (224, 256, 3)