mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Feature] Add P1 DataTransform (#1843)
* [Feature] Add P1 DataTransform
* fix unit test error
* fix @cache_randomness location
This commit is contained in:
parent
76c5ce1396
commit
ecab73a892
@ -105,8 +105,8 @@ class MultiImageMixDataset:
|
||||
transform_type in self._skip_type_keys:
|
||||
continue
|
||||
|
||||
if hasattr(transform, 'get_indexes'):
|
||||
indexes = transform.get_indexes(self.dataset)
|
||||
if hasattr(transform, 'get_indices'):
|
||||
indexes = transform.get_indices(self.dataset)
|
||||
if not isinstance(indexes, collections.abc.Sequence):
|
||||
indexes = [indexes]
|
||||
mix_results = [
|
||||
|
@ -9,13 +9,25 @@ from mmcv.transforms.utils import cache_randomness
|
||||
from mmcv.utils import is_tuple_of
|
||||
from numpy import random
|
||||
|
||||
from mmseg.datasets.dataset_wrappers import MultiImageMixDataset
|
||||
from mmseg.registry import TRANSFORMS
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class ResizeToMultiple(object):
|
||||
class ResizeToMultiple(BaseTransform):
|
||||
"""Resize images & seg to multiple of divisor.
|
||||
|
||||
Required Keys:
|
||||
|
||||
- img
|
||||
- gt_seg_map
|
||||
|
||||
Modified Keys:
|
||||
|
||||
- img
|
||||
- img_shape
|
||||
- pad_shape
|
||||
|
||||
Args:
|
||||
size_divisor (int): images and gt seg maps need to resize to multiple
|
||||
of size_divisor. Default: 32.
|
||||
@ -27,7 +39,7 @@ class ResizeToMultiple(object):
|
||||
self.size_divisor = size_divisor
|
||||
self.interpolation = interpolation
|
||||
|
||||
def __call__(self, results):
|
||||
def transform(self, results: dict) -> dict:
|
||||
"""Call function to resize images, semantic segmentation map to
|
||||
multiple of size divisor.
|
||||
|
||||
@ -70,9 +82,17 @@ class ResizeToMultiple(object):
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class Rerange(object):
|
||||
class Rerange(BaseTransform):
|
||||
"""Rerange the image pixel value.
|
||||
|
||||
Required Keys:
|
||||
|
||||
- img
|
||||
|
||||
Modified Keys:
|
||||
|
||||
- img
|
||||
|
||||
Args:
|
||||
min_value (float or int): Minimum value of the reranged image.
|
||||
Default: 0.
|
||||
@ -87,7 +107,7 @@ class Rerange(object):
|
||||
self.min_value = min_value
|
||||
self.max_value = max_value
|
||||
|
||||
def __call__(self, results):
|
||||
def transform(self, results: dict) -> dict:
|
||||
"""Call function to rerange images.
|
||||
|
||||
Args:
|
||||
@ -116,12 +136,20 @@ class Rerange(object):
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class CLAHE(object):
|
||||
class CLAHE(BaseTransform):
|
||||
"""Use CLAHE method to process the image.
|
||||
|
||||
See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J].
|
||||
Graphics Gems, 1994:474-485.` for more information.
|
||||
|
||||
Required Keys:
|
||||
|
||||
- img
|
||||
|
||||
Modified Keys:
|
||||
|
||||
- img
|
||||
|
||||
Args:
|
||||
clip_limit (float): Threshold for contrast limiting. Default: 40.0.
|
||||
tile_grid_size (tuple[int]): Size of grid for histogram equalization.
|
||||
@ -136,7 +164,7 @@ class CLAHE(object):
|
||||
assert len(tile_grid_size) == 2
|
||||
self.tile_grid_size = tile_grid_size
|
||||
|
||||
def __call__(self, results):
|
||||
def transform(self, results: dict) -> dict:
|
||||
"""Call function to Use CLAHE method process images.
|
||||
|
||||
Args:
|
||||
@ -167,13 +195,13 @@ class RandomCrop(BaseTransform):
|
||||
Required Keys:
|
||||
|
||||
- img
|
||||
- gt_semantic_seg
|
||||
- gt_seg_map
|
||||
|
||||
Modified Keys:
|
||||
|
||||
- img
|
||||
- img_shape
|
||||
- gt_semantic_seg
|
||||
- gt_seg_map
|
||||
|
||||
|
||||
Args:
|
||||
@ -293,9 +321,19 @@ class RandomCrop(BaseTransform):
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class RandomRotate(object):
|
||||
class RandomRotate(BaseTransform):
|
||||
"""Rotate the image & seg.
|
||||
|
||||
Required Keys:
|
||||
|
||||
- img
|
||||
- gt_seg_map
|
||||
|
||||
Modified Keys:
|
||||
|
||||
- img
|
||||
- gt_seg_map
|
||||
|
||||
Args:
|
||||
prob (float): The rotation probability.
|
||||
degree (float, tuple[float]): Range of degrees to select from. If
|
||||
@ -332,7 +370,12 @@ class RandomRotate(object):
|
||||
self.center = center
|
||||
self.auto_bound = auto_bound
|
||||
|
||||
def __call__(self, results):
|
||||
@cache_randomness
|
||||
def generate_degree(self):
|
||||
return np.random.rand() < self.prob, np.random.uniform(
|
||||
min(*self.degree), max(*self.degree))
|
||||
|
||||
def transform(self, results: dict) -> dict:
|
||||
"""Call function to rotate image, semantic segmentation maps.
|
||||
|
||||
Args:
|
||||
@ -342,8 +385,7 @@ class RandomRotate(object):
|
||||
dict: Rotated results.
|
||||
"""
|
||||
|
||||
rotate = True if np.random.rand() < self.prob else False
|
||||
degree = np.random.uniform(min(*self.degree), max(*self.degree))
|
||||
rotate, degree = self.generate_degree()
|
||||
if rotate:
|
||||
# rotate image
|
||||
results['img'] = mmcv.imrotate(
|
||||
@ -376,9 +418,18 @@ class RandomRotate(object):
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class RGB2Gray(object):
|
||||
class RGB2Gray(BaseTransform):
|
||||
"""Convert RGB image to grayscale image.
|
||||
|
||||
Required Keys:
|
||||
|
||||
- img
|
||||
|
||||
Modified Keys:
|
||||
|
||||
- img
|
||||
- img_shape
|
||||
|
||||
This transform calculates the weighted mean of input image channels with
|
||||
``weights`` and then expand the channels to ``out_channels``. When
|
||||
``out_channels`` is None, the number of output channels is the same as
|
||||
@ -399,7 +450,7 @@ class RGB2Gray(object):
|
||||
assert isinstance(item, (float, int))
|
||||
self.weights = weights
|
||||
|
||||
def __call__(self, results):
|
||||
def transform(self, results: dict) -> dict:
|
||||
"""Call function to convert RGB image to grayscale image.
|
||||
|
||||
Args:
|
||||
@ -431,9 +482,17 @@ class RGB2Gray(object):
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class AdjustGamma(object):
|
||||
class AdjustGamma(BaseTransform):
|
||||
"""Using gamma correction to process the image.
|
||||
|
||||
Required Keys:
|
||||
|
||||
- img
|
||||
|
||||
Modified Keys:
|
||||
|
||||
- img
|
||||
|
||||
Args:
|
||||
gamma (float or int): Gamma value used in gamma correction.
|
||||
Default: 1.0.
|
||||
@ -447,7 +506,7 @@ class AdjustGamma(object):
|
||||
self.table = np.array([(i / 255.0)**inv_gamma * 255
|
||||
for i in np.arange(256)]).astype('uint8')
|
||||
|
||||
def __call__(self, results):
|
||||
def transform(self, results: dict) -> dict:
|
||||
"""Call function to process the image with gamma correction.
|
||||
|
||||
Args:
|
||||
@ -467,9 +526,17 @@ class AdjustGamma(object):
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class SegRescale(object):
|
||||
class SegRescale(BaseTransform):
|
||||
"""Rescale semantic segmentation maps.
|
||||
|
||||
Required Keys:
|
||||
|
||||
- gt_seg_map
|
||||
|
||||
Modified Keys:
|
||||
|
||||
- gt_seg_map
|
||||
|
||||
Args:
|
||||
scale_factor (float): The scale factor of the final output.
|
||||
"""
|
||||
@ -477,7 +544,7 @@ class SegRescale(object):
|
||||
def __init__(self, scale_factor=1):
|
||||
self.scale_factor = scale_factor
|
||||
|
||||
def __call__(self, results):
|
||||
def transform(self, results: dict) -> dict:
|
||||
"""Call function to scale the semantic segmentation map.
|
||||
|
||||
Args:
|
||||
@ -667,11 +734,22 @@ class PhotoMetricDistortion(BaseTransform):
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class RandomCutOut(object):
|
||||
class RandomCutOut(BaseTransform):
|
||||
"""CutOut operation.
|
||||
|
||||
Randomly drop some regions of image used in
|
||||
`Cutout <https://arxiv.org/abs/1708.04552>`_.
|
||||
|
||||
Required Keys:
|
||||
|
||||
- img
|
||||
- gt_seg_map
|
||||
|
||||
Modified Keys:
|
||||
|
||||
- img
|
||||
- gt_seg_map
|
||||
|
||||
Args:
|
||||
prob (float): cutout probability.
|
||||
n_holes (int | tuple[int, int]): Number of regions to be dropped.
|
||||
@ -721,16 +799,38 @@ class RandomCutOut(object):
|
||||
if not isinstance(self.candidates, list):
|
||||
self.candidates = [self.candidates]
|
||||
|
||||
def __call__(self, results):
|
||||
@cache_randomness
|
||||
def do_cutout(self):
|
||||
return np.random.rand() < self.prob
|
||||
|
||||
@cache_randomness
|
||||
def generate_patches(self, results):
|
||||
cutout = self.do_cutout()
|
||||
|
||||
h, w, _ = results['img'].shape
|
||||
if cutout:
|
||||
n_holes = np.random.randint(self.n_holes[0], self.n_holes[1] + 1)
|
||||
else:
|
||||
n_holes = 0
|
||||
x1_lst = []
|
||||
y1_lst = []
|
||||
index_lst = []
|
||||
for _ in range(n_holes):
|
||||
x1_lst.append(np.random.randint(0, w))
|
||||
y1_lst.append(np.random.randint(0, h))
|
||||
index_lst.append(np.random.randint(0, len(self.candidates)))
|
||||
return cutout, n_holes, x1_lst, y1_lst, index_lst
|
||||
|
||||
def transform(self, results: dict) -> dict:
|
||||
"""Call function to drop some regions of image."""
|
||||
cutout = True if np.random.rand() < self.prob else False
|
||||
cutout, n_holes, x1_lst, y1_lst, index_lst = self.generate_patches(
|
||||
results)
|
||||
if cutout:
|
||||
h, w, c = results['img'].shape
|
||||
n_holes = np.random.randint(self.n_holes[0], self.n_holes[1] + 1)
|
||||
for _ in range(n_holes):
|
||||
x1 = np.random.randint(0, w)
|
||||
y1 = np.random.randint(0, h)
|
||||
index = np.random.randint(0, len(self.candidates))
|
||||
for i in range(n_holes):
|
||||
x1 = x1_lst[i]
|
||||
y1 = y1_lst[i]
|
||||
index = index_lst[i]
|
||||
if not self.with_ratio:
|
||||
cutout_w, cutout_h = self.candidates[index]
|
||||
else:
|
||||
@ -759,7 +859,7 @@ class RandomCutOut(object):
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class RandomMosaic(object):
|
||||
class RandomMosaic(BaseTransform):
|
||||
"""Mosaic augmentation. Given 4 images, mosaic transform combines them into
|
||||
one output image. The output image is composed of the parts from each sub-
|
||||
image.
|
||||
@ -789,6 +889,19 @@ class RandomMosaic(object):
|
||||
sample another 3 images from the custom dataset.
|
||||
3. Sub image will be cropped if image is larger than mosaic patch
|
||||
|
||||
Required Keys:
|
||||
|
||||
- img
|
||||
- gt_seg_map
|
||||
- mix_results
|
||||
|
||||
Modified Keys:
|
||||
|
||||
- img
|
||||
- img_shape
|
||||
- ori_shape
|
||||
- gt_seg_map
|
||||
|
||||
Args:
|
||||
prob (float): mosaic probability.
|
||||
img_scale (Sequence[int]): Image size after mosaic pipeline of
|
||||
@ -815,7 +928,11 @@ class RandomMosaic(object):
|
||||
self.pad_val = pad_val
|
||||
self.seg_pad_val = seg_pad_val
|
||||
|
||||
def __call__(self, results):
|
||||
@cache_randomness
|
||||
def do_mosaic(self):
|
||||
return np.random.rand() < self.prob
|
||||
|
||||
def transform(self, results: dict) -> dict:
|
||||
"""Call function to make a mosaic of image.
|
||||
|
||||
Args:
|
||||
@ -824,13 +941,13 @@ class RandomMosaic(object):
|
||||
Returns:
|
||||
dict: Result dict with mosaic transformed.
|
||||
"""
|
||||
mosaic = True if np.random.rand() < self.prob else False
|
||||
mosaic = self.do_mosaic()
|
||||
if mosaic:
|
||||
results = self._mosaic_transform_img(results)
|
||||
results = self._mosaic_transform_seg(results)
|
||||
return results
|
||||
|
||||
def get_indexes(self, dataset):
|
||||
def get_indices(self, dataset: MultiImageMixDataset) -> list:
|
||||
"""Call function to collect indexes.
|
||||
|
||||
Args:
|
||||
@ -843,7 +960,16 @@ class RandomMosaic(object):
|
||||
indexes = [random.randint(0, len(dataset)) for _ in range(3)]
|
||||
return indexes
|
||||
|
||||
def _mosaic_transform_img(self, results):
|
||||
@cache_randomness
|
||||
def generate_mosaic_center(self):
|
||||
# mosaic center x, y
|
||||
center_x = int(
|
||||
random.uniform(*self.center_ratio_range) * self.img_scale[1])
|
||||
center_y = int(
|
||||
random.uniform(*self.center_ratio_range) * self.img_scale[0])
|
||||
return center_x, center_y
|
||||
|
||||
def _mosaic_transform_img(self, results: dict) -> dict:
|
||||
"""Mosaic transform function.
|
||||
|
||||
Args:
|
||||
@ -866,10 +992,7 @@ class RandomMosaic(object):
|
||||
dtype=results['img'].dtype)
|
||||
|
||||
# mosaic center x, y
|
||||
self.center_x = int(
|
||||
random.uniform(*self.center_ratio_range) * self.img_scale[1])
|
||||
self.center_y = int(
|
||||
random.uniform(*self.center_ratio_range) * self.img_scale[0])
|
||||
self.center_x, self.center_y = self.generate_mosaic_center()
|
||||
center_position = (self.center_x, self.center_y)
|
||||
|
||||
loc_strs = ('top_left', 'top_right', 'bottom_left', 'bottom_right')
|
||||
@ -902,7 +1025,7 @@ class RandomMosaic(object):
|
||||
|
||||
return results
|
||||
|
||||
def _mosaic_transform_seg(self, results):
|
||||
def _mosaic_transform_seg(self, results: dict) -> dict:
|
||||
"""Mosaic transform function for label annotations.
|
||||
|
||||
Args:
|
||||
@ -953,7 +1076,8 @@ class RandomMosaic(object):
|
||||
|
||||
return results
|
||||
|
||||
def _mosaic_combine(self, loc, center_position_xy, img_shape_wh):
|
||||
def _mosaic_combine(self, loc: str, center_position_xy: Sequence[float],
|
||||
img_shape_wh: Sequence[int]) -> tuple:
|
||||
"""Calculate global coordinate of mosaic image and local coordinate of
|
||||
cropped sub-image.
|
||||
|
||||
|
@ -7,11 +7,9 @@ import numpy as np
|
||||
import pytest
|
||||
from PIL import Image
|
||||
|
||||
from mmseg.datasets.transforms import * # noqa
|
||||
from mmseg.datasets.transforms import PhotoMetricDistortion, RandomCrop
|
||||
from mmseg.registry import TRANSFORMS
|
||||
from mmseg.utils import register_all_modules
|
||||
|
||||
register_all_modules()
|
||||
|
||||
|
||||
def test_resize():
|
||||
@ -233,6 +231,72 @@ def test_random_crop():
|
||||
assert results['gt_semantic_seg'].shape[:2] == (h - 20, w - 20)
|
||||
|
||||
|
||||
def test_rgb2gray():
|
||||
# test assertion out_channels should be greater than 0
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='RGB2Gray', out_channels=-1)
|
||||
TRANSFORMS.build(transform)
|
||||
# test assertion weights should be tuple[float]
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='RGB2Gray', out_channels=1, weights=1.1)
|
||||
TRANSFORMS.build(transform)
|
||||
|
||||
# test out_channels is None
|
||||
transform = dict(type='RGB2Gray')
|
||||
transform = TRANSFORMS.build(transform)
|
||||
|
||||
assert str(transform) == f'RGB2Gray(' \
|
||||
f'out_channels={None}, ' \
|
||||
f'weights={(0.299, 0.587, 0.114)})'
|
||||
|
||||
results = dict()
|
||||
img = mmcv.imread(
|
||||
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
|
||||
h, w, c = img.shape
|
||||
seg = np.array(
|
||||
Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))
|
||||
results['img'] = img
|
||||
results['gt_semantic_seg'] = seg
|
||||
results['seg_fields'] = ['gt_semantic_seg']
|
||||
results['img_shape'] = img.shape
|
||||
results['ori_shape'] = img.shape
|
||||
# Set initial values for default meta_keys
|
||||
results['pad_shape'] = img.shape
|
||||
results['scale_factor'] = 1.0
|
||||
|
||||
results = transform(results)
|
||||
assert results['img'].shape == (h, w, c)
|
||||
assert results['img_shape'] == (h, w, c)
|
||||
assert results['ori_shape'] == (h, w, c)
|
||||
|
||||
# test out_channels = 2
|
||||
transform = dict(type='RGB2Gray', out_channels=2)
|
||||
transform = TRANSFORMS.build(transform)
|
||||
|
||||
assert str(transform) == f'RGB2Gray(' \
|
||||
f'out_channels={2}, ' \
|
||||
f'weights={(0.299, 0.587, 0.114)})'
|
||||
|
||||
results = dict()
|
||||
img = mmcv.imread(
|
||||
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
|
||||
h, w, c = img.shape
|
||||
seg = np.array(
|
||||
Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))
|
||||
results['img'] = img
|
||||
results['gt_semantic_seg'] = seg
|
||||
results['seg_fields'] = ['gt_semantic_seg']
|
||||
results['img_shape'] = img.shape
|
||||
results['ori_shape'] = img.shape
|
||||
# Set initial values for default meta_keys
|
||||
results['pad_shape'] = img.shape
|
||||
results['scale_factor'] = 1.0
|
||||
|
||||
results = transform(results)
|
||||
assert results['img'].shape == (h, w, 2)
|
||||
assert results['img_shape'] == (h, w, 2)
|
||||
|
||||
|
||||
def test_photo_metric_distortion():
|
||||
|
||||
results = dict()
|
||||
@ -252,3 +316,366 @@ def test_photo_metric_distortion():
|
||||
|
||||
assert (results['gt_semantic_seg'] == seg).all()
|
||||
assert results['img_shape'] == img.shape
|
||||
|
||||
|
||||
def test_rerange():
|
||||
# test assertion if min_value or max_value is illegal
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='Rerange', min_value=[0], max_value=[255])
|
||||
TRANSFORMS.build(transform)
|
||||
|
||||
# test assertion if min_value >= max_value
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='Rerange', min_value=1, max_value=1)
|
||||
TRANSFORMS.build(transform)
|
||||
|
||||
# test assertion if img_min_value == img_max_value
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='Rerange', min_value=0, max_value=1)
|
||||
transform = TRANSFORMS.build(transform)
|
||||
results = dict()
|
||||
results['img'] = np.array([[1, 1], [1, 1]])
|
||||
transform(results)
|
||||
|
||||
img_rerange_cfg = dict()
|
||||
transform = dict(type='Rerange', **img_rerange_cfg)
|
||||
transform = TRANSFORMS.build(transform)
|
||||
results = dict()
|
||||
img = mmcv.imread(
|
||||
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
|
||||
original_img = copy.deepcopy(img)
|
||||
results['img'] = img
|
||||
results['img_shape'] = img.shape
|
||||
results['ori_shape'] = img.shape
|
||||
# Set initial values for default meta_keys
|
||||
results['pad_shape'] = img.shape
|
||||
results['scale_factor'] = 1.0
|
||||
|
||||
results = transform(results)
|
||||
|
||||
min_value = np.min(original_img)
|
||||
max_value = np.max(original_img)
|
||||
converted_img = (original_img - min_value) / (max_value - min_value) * 255
|
||||
|
||||
assert np.allclose(results['img'], converted_img)
|
||||
assert str(transform) == f'Rerange(min_value={0}, max_value={255})'
|
||||
|
||||
|
||||
def test_CLAHE():
|
||||
# test assertion if clip_limit is None
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='CLAHE', clip_limit=None)
|
||||
TRANSFORMS.build(transform)
|
||||
|
||||
# test assertion if tile_grid_size contains non-int values
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='CLAHE', tile_grid_size=(8.0, 8.0))
|
||||
TRANSFORMS.build(transform)
|
||||
|
||||
# test assertion if tile_grid_size does not have exactly two elements
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='CLAHE', tile_grid_size=(9, 9, 9))
|
||||
TRANSFORMS.build(transform)
|
||||
|
||||
transform = dict(type='CLAHE', clip_limit=2)
|
||||
transform = TRANSFORMS.build(transform)
|
||||
results = dict()
|
||||
img = mmcv.imread(
|
||||
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
|
||||
original_img = copy.deepcopy(img)
|
||||
results['img'] = img
|
||||
results['img_shape'] = img.shape
|
||||
results['ori_shape'] = img.shape
|
||||
# Set initial values for default meta_keys
|
||||
results['pad_shape'] = img.shape
|
||||
results['scale_factor'] = 1.0
|
||||
|
||||
results = transform(results)
|
||||
|
||||
converted_img = np.empty(original_img.shape)
|
||||
for i in range(original_img.shape[2]):
|
||||
converted_img[:, :, i] = mmcv.clahe(
|
||||
np.array(original_img[:, :, i], dtype=np.uint8), 2, (8, 8))
|
||||
|
||||
assert np.allclose(results['img'], converted_img)
|
||||
assert str(transform) == f'CLAHE(clip_limit={2}, tile_grid_size={(8, 8)})'
|
||||
|
||||
|
||||
def test_adjust_gamma():
|
||||
# test assertion if gamma <= 0
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='AdjustGamma', gamma=0)
|
||||
TRANSFORMS.build(transform)
|
||||
|
||||
# test assertion if gamma is list
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='AdjustGamma', gamma=[1.2])
|
||||
TRANSFORMS.build(transform)
|
||||
|
||||
# test with gamma = 1.2
|
||||
transform = dict(type='AdjustGamma', gamma=1.2)
|
||||
transform = TRANSFORMS.build(transform)
|
||||
results = dict()
|
||||
img = mmcv.imread(
|
||||
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
|
||||
original_img = copy.deepcopy(img)
|
||||
results['img'] = img
|
||||
results['img_shape'] = img.shape
|
||||
results['ori_shape'] = img.shape
|
||||
# Set initial values for default meta_keys
|
||||
results['pad_shape'] = img.shape
|
||||
results['scale_factor'] = 1.0
|
||||
|
||||
results = transform(results)
|
||||
|
||||
inv_gamma = 1.0 / 1.2
|
||||
table = np.array([((i / 255.0)**inv_gamma) * 255
|
||||
for i in np.arange(0, 256)]).astype('uint8')
|
||||
converted_img = mmcv.lut_transform(
|
||||
np.array(original_img, dtype=np.uint8), table)
|
||||
assert np.allclose(results['img'], converted_img)
|
||||
assert str(transform) == f'AdjustGamma(gamma={1.2})'
|
||||
|
||||
|
||||
def test_rotate():
|
||||
# test assertion: a negative scalar degree is rejected
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='RandomRotate', prob=0.5, degree=-10)
|
||||
TRANSFORMS.build(transform)
|
||||
# test assertion: a degree tuple with three elements is rejected
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='RandomRotate', prob=0.5, degree=(10., 20., 30.))
|
||||
TRANSFORMS.build(transform)
|
||||
|
||||
transform = dict(type='RandomRotate', degree=10., prob=1.)
|
||||
transform = TRANSFORMS.build(transform)
|
||||
|
||||
assert str(transform) == f'RandomRotate(' \
|
||||
f'prob={1.}, ' \
|
||||
f'degree=({-10.}, {10.}), ' \
|
||||
f'pad_val={0}, ' \
|
||||
f'seg_pad_val={255}, ' \
|
||||
f'center={None}, ' \
|
||||
f'auto_bound={False})'
|
||||
|
||||
results = dict()
|
||||
img = mmcv.imread(
|
||||
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
|
||||
h, w, _ = img.shape
|
||||
seg = np.array(
|
||||
Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))
|
||||
results['img'] = img
|
||||
results['gt_semantic_seg'] = seg
|
||||
results['seg_fields'] = ['gt_semantic_seg']
|
||||
results['img_shape'] = img.shape
|
||||
results['ori_shape'] = img.shape
|
||||
# Set initial values for default meta_keys
|
||||
results['pad_shape'] = img.shape
|
||||
results['scale_factor'] = 1.0
|
||||
|
||||
results = transform(results)
|
||||
assert results['img'].shape[:2] == (h, w)
|
||||
|
||||
|
||||
def test_seg_rescale():
|
||||
results = dict()
|
||||
seg = np.array(
|
||||
Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))
|
||||
results['gt_semantic_seg'] = seg
|
||||
results['seg_fields'] = ['gt_semantic_seg']
|
||||
h, w = seg.shape
|
||||
|
||||
transform = dict(type='SegRescale', scale_factor=1. / 2)
|
||||
rescale_module = TRANSFORMS.build(transform)
|
||||
rescale_results = rescale_module(results.copy())
|
||||
assert rescale_results['gt_semantic_seg'].shape == (h // 2, w // 2)
|
||||
|
||||
transform = dict(type='SegRescale', scale_factor=1)
|
||||
rescale_module = TRANSFORMS.build(transform)
|
||||
rescale_results = rescale_module(results.copy())
|
||||
assert rescale_results['gt_semantic_seg'].shape == (h, w)
|
||||
|
||||
|
||||
def test_mosaic():
|
||||
# test prob
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='RandomMosaic', prob=1.5)
|
||||
TRANSFORMS.build(transform)
|
||||
# test assertion for invalid img_scale
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='RandomMosaic', prob=1, img_scale=640)
|
||||
TRANSFORMS.build(transform)
|
||||
|
||||
results = dict()
|
||||
img = mmcv.imread(
|
||||
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
|
||||
seg = np.array(
|
||||
Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))
|
||||
|
||||
results['img'] = img
|
||||
results['gt_semantic_seg'] = seg
|
||||
results['seg_fields'] = ['gt_semantic_seg']
|
||||
|
||||
transform = dict(type='RandomMosaic', prob=1, img_scale=(10, 12))
|
||||
mosaic_module = TRANSFORMS.build(transform)
|
||||
assert 'Mosaic' in repr(mosaic_module)
|
||||
|
||||
# test assertion for invalid mix_results
|
||||
with pytest.raises(AssertionError):
|
||||
mosaic_module(results)
|
||||
|
||||
results['mix_results'] = [copy.deepcopy(results)] * 3
|
||||
results = mosaic_module(results)
|
||||
assert results['img'].shape[:2] == (20, 24)
|
||||
|
||||
results = dict()
|
||||
results['img'] = img[:, :, 0]
|
||||
results['gt_semantic_seg'] = seg
|
||||
results['seg_fields'] = ['gt_semantic_seg']
|
||||
|
||||
transform = dict(type='RandomMosaic', prob=0, img_scale=(10, 12))
|
||||
mosaic_module = TRANSFORMS.build(transform)
|
||||
results['mix_results'] = [copy.deepcopy(results)] * 3
|
||||
results = mosaic_module(results)
|
||||
assert results['img'].shape[:2] == img.shape[:2]
|
||||
|
||||
transform = dict(type='RandomMosaic', prob=1, img_scale=(10, 12))
|
||||
mosaic_module = TRANSFORMS.build(transform)
|
||||
results = mosaic_module(results)
|
||||
assert results['img'].shape[:2] == (20, 24)
|
||||
|
||||
|
||||
def test_cutout():
|
||||
# test prob
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='RandomCutOut', prob=1.5, n_holes=1)
|
||||
TRANSFORMS.build(transform)
|
||||
# test n_holes
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(
|
||||
type='RandomCutOut', prob=0.5, n_holes=(5, 3), cutout_shape=(8, 8))
|
||||
TRANSFORMS.build(transform)
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(
|
||||
type='RandomCutOut',
|
||||
prob=0.5,
|
||||
n_holes=(3, 4, 5),
|
||||
cutout_shape=(8, 8))
|
||||
TRANSFORMS.build(transform)
|
||||
# test cutout_shape and cutout_ratio
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(
|
||||
type='RandomCutOut', prob=0.5, n_holes=1, cutout_shape=8)
|
||||
TRANSFORMS.build(transform)
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(
|
||||
type='RandomCutOut', prob=0.5, n_holes=1, cutout_ratio=0.2)
|
||||
TRANSFORMS.build(transform)
|
||||
# exactly one of cutout_shape and cutout_ratio should be given
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='RandomCutOut', prob=0.5, n_holes=1)
|
||||
TRANSFORMS.build(transform)
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(
|
||||
type='RandomCutOut',
|
||||
prob=0.5,
|
||||
n_holes=1,
|
||||
cutout_shape=(2, 2),
|
||||
cutout_ratio=(0.4, 0.4))
|
||||
TRANSFORMS.build(transform)
|
||||
# test seg_fill_in
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(
|
||||
type='RandomCutOut',
|
||||
prob=0.5,
|
||||
n_holes=1,
|
||||
cutout_shape=(8, 8),
|
||||
seg_fill_in='a')
|
||||
TRANSFORMS.build(transform)
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(
|
||||
type='RandomCutOut',
|
||||
prob=0.5,
|
||||
n_holes=1,
|
||||
cutout_shape=(8, 8),
|
||||
seg_fill_in=256)
|
||||
TRANSFORMS.build(transform)
|
||||
|
||||
results = dict()
|
||||
img = mmcv.imread(
|
||||
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
|
||||
|
||||
seg = np.array(
|
||||
Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))
|
||||
|
||||
results['img'] = img
|
||||
results['gt_semantic_seg'] = seg
|
||||
results['seg_fields'] = ['gt_semantic_seg']
|
||||
results['img_shape'] = img.shape
|
||||
results['ori_shape'] = img.shape
|
||||
results['pad_shape'] = img.shape
|
||||
results['img_fields'] = ['img']
|
||||
|
||||
transform = dict(
|
||||
type='RandomCutOut', prob=1, n_holes=1, cutout_shape=(10, 10))
|
||||
cutout_module = TRANSFORMS.build(transform)
|
||||
assert 'cutout_shape' in repr(cutout_module)
|
||||
cutout_result = cutout_module(copy.deepcopy(results))
|
||||
assert cutout_result['img'].sum() < img.sum()
|
||||
|
||||
transform = dict(
|
||||
type='RandomCutOut', prob=1, n_holes=1, cutout_ratio=(0.8, 0.8))
|
||||
cutout_module = TRANSFORMS.build(transform)
|
||||
assert 'cutout_ratio' in repr(cutout_module)
|
||||
cutout_result = cutout_module(copy.deepcopy(results))
|
||||
assert cutout_result['img'].sum() < img.sum()
|
||||
|
||||
transform = dict(
|
||||
type='RandomCutOut', prob=0, n_holes=1, cutout_ratio=(0.8, 0.8))
|
||||
cutout_module = TRANSFORMS.build(transform)
|
||||
cutout_result = cutout_module(copy.deepcopy(results))
|
||||
assert cutout_result['img'].sum() == img.sum()
|
||||
assert cutout_result['gt_semantic_seg'].sum() == seg.sum()
|
||||
|
||||
transform = dict(
|
||||
type='RandomCutOut',
|
||||
prob=1,
|
||||
n_holes=(2, 4),
|
||||
cutout_shape=[(10, 10), (15, 15)],
|
||||
fill_in=(255, 255, 255),
|
||||
seg_fill_in=None)
|
||||
cutout_module = TRANSFORMS.build(transform)
|
||||
cutout_result = cutout_module(copy.deepcopy(results))
|
||||
assert cutout_result['img'].sum() > img.sum()
|
||||
assert cutout_result['gt_semantic_seg'].sum() == seg.sum()
|
||||
|
||||
transform = dict(
|
||||
type='RandomCutOut',
|
||||
prob=1,
|
||||
n_holes=1,
|
||||
cutout_ratio=(0.8, 0.8),
|
||||
fill_in=(255, 255, 255),
|
||||
seg_fill_in=255)
|
||||
cutout_module = TRANSFORMS.build(transform)
|
||||
cutout_result = cutout_module(copy.deepcopy(results))
|
||||
assert cutout_result['img'].sum() > img.sum()
|
||||
assert cutout_result['gt_semantic_seg'].sum() > seg.sum()
|
||||
|
||||
|
||||
def test_resize_to_multiple():
|
||||
transform = dict(type='ResizeToMultiple', size_divisor=32)
|
||||
transform = TRANSFORMS.build(transform)
|
||||
|
||||
img = np.random.randn(213, 232, 3)
|
||||
seg = np.random.randint(0, 19, (213, 232))
|
||||
results = dict()
|
||||
results['img'] = img
|
||||
results['gt_semantic_seg'] = seg
|
||||
results['seg_fields'] = ['gt_semantic_seg']
|
||||
results['img_shape'] = img.shape
|
||||
results['pad_shape'] = img.shape
|
||||
|
||||
results = transform(results)
|
||||
assert results['img'].shape == (224, 256, 3)
|
||||
assert results['gt_semantic_seg'].shape == (224, 256)
|
||||
assert results['img_shape'] == (224, 256, 3)
|
||||
|
Loading…
x
Reference in New Issue
Block a user