mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Feature] Add P1 DataTransform (#1843)
* [Feature] Add P1 DataTransform * fix unit test error * fix @cache_randomness location
This commit is contained in:
parent
76c5ce1396
commit
ecab73a892
@ -105,8 +105,8 @@ class MultiImageMixDataset:
|
|||||||
transform_type in self._skip_type_keys:
|
transform_type in self._skip_type_keys:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if hasattr(transform, 'get_indexes'):
|
if hasattr(transform, 'get_indices'):
|
||||||
indexes = transform.get_indexes(self.dataset)
|
indexes = transform.get_indices(self.dataset)
|
||||||
if not isinstance(indexes, collections.abc.Sequence):
|
if not isinstance(indexes, collections.abc.Sequence):
|
||||||
indexes = [indexes]
|
indexes = [indexes]
|
||||||
mix_results = [
|
mix_results = [
|
||||||
|
@ -9,13 +9,25 @@ from mmcv.transforms.utils import cache_randomness
|
|||||||
from mmcv.utils import is_tuple_of
|
from mmcv.utils import is_tuple_of
|
||||||
from numpy import random
|
from numpy import random
|
||||||
|
|
||||||
|
from mmseg.datasets.dataset_wrappers import MultiImageMixDataset
|
||||||
from mmseg.registry import TRANSFORMS
|
from mmseg.registry import TRANSFORMS
|
||||||
|
|
||||||
|
|
||||||
@TRANSFORMS.register_module()
|
@TRANSFORMS.register_module()
|
||||||
class ResizeToMultiple(object):
|
class ResizeToMultiple(BaseTransform):
|
||||||
"""Resize images & seg to multiple of divisor.
|
"""Resize images & seg to multiple of divisor.
|
||||||
|
|
||||||
|
Required Keys:
|
||||||
|
|
||||||
|
- img
|
||||||
|
- gt_seg_map
|
||||||
|
|
||||||
|
Modified Keys:
|
||||||
|
|
||||||
|
- img
|
||||||
|
- img_shape
|
||||||
|
- pad_shape
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
size_divisor (int): images and gt seg maps need to resize to multiple
|
size_divisor (int): images and gt seg maps need to resize to multiple
|
||||||
of size_divisor. Default: 32.
|
of size_divisor. Default: 32.
|
||||||
@ -27,7 +39,7 @@ class ResizeToMultiple(object):
|
|||||||
self.size_divisor = size_divisor
|
self.size_divisor = size_divisor
|
||||||
self.interpolation = interpolation
|
self.interpolation = interpolation
|
||||||
|
|
||||||
def __call__(self, results):
|
def transform(self, results: dict) -> dict:
|
||||||
"""Call function to resize images, semantic segmentation map to
|
"""Call function to resize images, semantic segmentation map to
|
||||||
multiple of size divisor.
|
multiple of size divisor.
|
||||||
|
|
||||||
@ -70,9 +82,17 @@ class ResizeToMultiple(object):
|
|||||||
|
|
||||||
|
|
||||||
@TRANSFORMS.register_module()
|
@TRANSFORMS.register_module()
|
||||||
class Rerange(object):
|
class Rerange(BaseTransform):
|
||||||
"""Rerange the image pixel value.
|
"""Rerange the image pixel value.
|
||||||
|
|
||||||
|
Required Keys:
|
||||||
|
|
||||||
|
- img
|
||||||
|
|
||||||
|
Modified Keys:
|
||||||
|
|
||||||
|
- img
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
min_value (float or int): Minimum value of the reranged image.
|
min_value (float or int): Minimum value of the reranged image.
|
||||||
Default: 0.
|
Default: 0.
|
||||||
@ -87,7 +107,7 @@ class Rerange(object):
|
|||||||
self.min_value = min_value
|
self.min_value = min_value
|
||||||
self.max_value = max_value
|
self.max_value = max_value
|
||||||
|
|
||||||
def __call__(self, results):
|
def transform(self, results: dict) -> dict:
|
||||||
"""Call function to rerange images.
|
"""Call function to rerange images.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -116,12 +136,20 @@ class Rerange(object):
|
|||||||
|
|
||||||
|
|
||||||
@TRANSFORMS.register_module()
|
@TRANSFORMS.register_module()
|
||||||
class CLAHE(object):
|
class CLAHE(BaseTransform):
|
||||||
"""Use CLAHE method to process the image.
|
"""Use CLAHE method to process the image.
|
||||||
|
|
||||||
See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J].
|
See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J].
|
||||||
Graphics Gems, 1994:474-485.` for more information.
|
Graphics Gems, 1994:474-485.` for more information.
|
||||||
|
|
||||||
|
Required Keys:
|
||||||
|
|
||||||
|
- img
|
||||||
|
|
||||||
|
Modified Keys:
|
||||||
|
|
||||||
|
- img
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
clip_limit (float): Threshold for contrast limiting. Default: 40.0.
|
clip_limit (float): Threshold for contrast limiting. Default: 40.0.
|
||||||
tile_grid_size (tuple[int]): Size of grid for histogram equalization.
|
tile_grid_size (tuple[int]): Size of grid for histogram equalization.
|
||||||
@ -136,7 +164,7 @@ class CLAHE(object):
|
|||||||
assert len(tile_grid_size) == 2
|
assert len(tile_grid_size) == 2
|
||||||
self.tile_grid_size = tile_grid_size
|
self.tile_grid_size = tile_grid_size
|
||||||
|
|
||||||
def __call__(self, results):
|
def transform(self, results: dict) -> dict:
|
||||||
"""Call function to Use CLAHE method process images.
|
"""Call function to Use CLAHE method process images.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -167,13 +195,13 @@ class RandomCrop(BaseTransform):
|
|||||||
Required Keys:
|
Required Keys:
|
||||||
|
|
||||||
- img
|
- img
|
||||||
- gt_semantic_seg
|
- gt_seg_map
|
||||||
|
|
||||||
Modified Keys:
|
Modified Keys:
|
||||||
|
|
||||||
- img
|
- img
|
||||||
- img_shape
|
- img_shape
|
||||||
- gt_semantic_seg
|
- gt_seg_map
|
||||||
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -293,9 +321,19 @@ class RandomCrop(BaseTransform):
|
|||||||
|
|
||||||
|
|
||||||
@TRANSFORMS.register_module()
|
@TRANSFORMS.register_module()
|
||||||
class RandomRotate(object):
|
class RandomRotate(BaseTransform):
|
||||||
"""Rotate the image & seg.
|
"""Rotate the image & seg.
|
||||||
|
|
||||||
|
Required Keys:
|
||||||
|
|
||||||
|
- img
|
||||||
|
- gt_seg_map
|
||||||
|
|
||||||
|
Modified Keys:
|
||||||
|
|
||||||
|
- img
|
||||||
|
- gt_seg_map
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prob (float): The rotation probability.
|
prob (float): The rotation probability.
|
||||||
degree (float, tuple[float]): Range of degrees to select from. If
|
degree (float, tuple[float]): Range of degrees to select from. If
|
||||||
@ -332,7 +370,12 @@ class RandomRotate(object):
|
|||||||
self.center = center
|
self.center = center
|
||||||
self.auto_bound = auto_bound
|
self.auto_bound = auto_bound
|
||||||
|
|
||||||
def __call__(self, results):
|
@cache_randomness
|
||||||
|
def generate_degree(self):
|
||||||
|
return np.random.rand() < self.prob, np.random.uniform(
|
||||||
|
min(*self.degree), max(*self.degree))
|
||||||
|
|
||||||
|
def transform(self, results: dict) -> dict:
|
||||||
"""Call function to rotate image, semantic segmentation maps.
|
"""Call function to rotate image, semantic segmentation maps.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -342,8 +385,7 @@ class RandomRotate(object):
|
|||||||
dict: Rotated results.
|
dict: Rotated results.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
rotate = True if np.random.rand() < self.prob else False
|
rotate, degree = self.generate_degree()
|
||||||
degree = np.random.uniform(min(*self.degree), max(*self.degree))
|
|
||||||
if rotate:
|
if rotate:
|
||||||
# rotate image
|
# rotate image
|
||||||
results['img'] = mmcv.imrotate(
|
results['img'] = mmcv.imrotate(
|
||||||
@ -376,9 +418,18 @@ class RandomRotate(object):
|
|||||||
|
|
||||||
|
|
||||||
@TRANSFORMS.register_module()
|
@TRANSFORMS.register_module()
|
||||||
class RGB2Gray(object):
|
class RGB2Gray(BaseTransform):
|
||||||
"""Convert RGB image to grayscale image.
|
"""Convert RGB image to grayscale image.
|
||||||
|
|
||||||
|
Required Keys:
|
||||||
|
|
||||||
|
- img
|
||||||
|
|
||||||
|
Modified Keys:
|
||||||
|
|
||||||
|
- img
|
||||||
|
- img_shape
|
||||||
|
|
||||||
This transform calculate the weighted mean of input image channels with
|
This transform calculate the weighted mean of input image channels with
|
||||||
``weights`` and then expand the channels to ``out_channels``. When
|
``weights`` and then expand the channels to ``out_channels``. When
|
||||||
``out_channels`` is None, the number of output channels is the same as
|
``out_channels`` is None, the number of output channels is the same as
|
||||||
@ -399,7 +450,7 @@ class RGB2Gray(object):
|
|||||||
assert isinstance(item, (float, int))
|
assert isinstance(item, (float, int))
|
||||||
self.weights = weights
|
self.weights = weights
|
||||||
|
|
||||||
def __call__(self, results):
|
def transform(self, results: dict) -> dict:
|
||||||
"""Call function to convert RGB image to grayscale image.
|
"""Call function to convert RGB image to grayscale image.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -431,9 +482,17 @@ class RGB2Gray(object):
|
|||||||
|
|
||||||
|
|
||||||
@TRANSFORMS.register_module()
|
@TRANSFORMS.register_module()
|
||||||
class AdjustGamma(object):
|
class AdjustGamma(BaseTransform):
|
||||||
"""Using gamma correction to process the image.
|
"""Using gamma correction to process the image.
|
||||||
|
|
||||||
|
Required Keys:
|
||||||
|
|
||||||
|
- img
|
||||||
|
|
||||||
|
Modified Keys:
|
||||||
|
|
||||||
|
- img
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
gamma (float or int): Gamma value used in gamma correction.
|
gamma (float or int): Gamma value used in gamma correction.
|
||||||
Default: 1.0.
|
Default: 1.0.
|
||||||
@ -447,7 +506,7 @@ class AdjustGamma(object):
|
|||||||
self.table = np.array([(i / 255.0)**inv_gamma * 255
|
self.table = np.array([(i / 255.0)**inv_gamma * 255
|
||||||
for i in np.arange(256)]).astype('uint8')
|
for i in np.arange(256)]).astype('uint8')
|
||||||
|
|
||||||
def __call__(self, results):
|
def transform(self, results: dict) -> dict:
|
||||||
"""Call function to process the image with gamma correction.
|
"""Call function to process the image with gamma correction.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -467,9 +526,17 @@ class AdjustGamma(object):
|
|||||||
|
|
||||||
|
|
||||||
@TRANSFORMS.register_module()
|
@TRANSFORMS.register_module()
|
||||||
class SegRescale(object):
|
class SegRescale(BaseTransform):
|
||||||
"""Rescale semantic segmentation maps.
|
"""Rescale semantic segmentation maps.
|
||||||
|
|
||||||
|
Required Keys:
|
||||||
|
|
||||||
|
- gt_seg_map
|
||||||
|
|
||||||
|
Modified Keys:
|
||||||
|
|
||||||
|
- gt_seg_map
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
scale_factor (float): The scale factor of the final output.
|
scale_factor (float): The scale factor of the final output.
|
||||||
"""
|
"""
|
||||||
@ -477,7 +544,7 @@ class SegRescale(object):
|
|||||||
def __init__(self, scale_factor=1):
|
def __init__(self, scale_factor=1):
|
||||||
self.scale_factor = scale_factor
|
self.scale_factor = scale_factor
|
||||||
|
|
||||||
def __call__(self, results):
|
def transform(self, results: dict) -> dict:
|
||||||
"""Call function to scale the semantic segmentation map.
|
"""Call function to scale the semantic segmentation map.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -667,11 +734,22 @@ class PhotoMetricDistortion(BaseTransform):
|
|||||||
|
|
||||||
|
|
||||||
@TRANSFORMS.register_module()
|
@TRANSFORMS.register_module()
|
||||||
class RandomCutOut(object):
|
class RandomCutOut(BaseTransform):
|
||||||
"""CutOut operation.
|
"""CutOut operation.
|
||||||
|
|
||||||
Randomly drop some regions of image used in
|
Randomly drop some regions of image used in
|
||||||
`Cutout <https://arxiv.org/abs/1708.04552>`_.
|
`Cutout <https://arxiv.org/abs/1708.04552>`_.
|
||||||
|
|
||||||
|
Required Keys:
|
||||||
|
|
||||||
|
- img
|
||||||
|
- gt_seg_map
|
||||||
|
|
||||||
|
Modified Keys:
|
||||||
|
|
||||||
|
- img
|
||||||
|
- gt_seg_map
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prob (float): cutout probability.
|
prob (float): cutout probability.
|
||||||
n_holes (int | tuple[int, int]): Number of regions to be dropped.
|
n_holes (int | tuple[int, int]): Number of regions to be dropped.
|
||||||
@ -721,16 +799,38 @@ class RandomCutOut(object):
|
|||||||
if not isinstance(self.candidates, list):
|
if not isinstance(self.candidates, list):
|
||||||
self.candidates = [self.candidates]
|
self.candidates = [self.candidates]
|
||||||
|
|
||||||
def __call__(self, results):
|
@cache_randomness
|
||||||
|
def do_cutout(self):
|
||||||
|
return np.random.rand() < self.prob
|
||||||
|
|
||||||
|
@cache_randomness
|
||||||
|
def generate_patches(self, results):
|
||||||
|
cutout = self.do_cutout()
|
||||||
|
|
||||||
|
h, w, _ = results['img'].shape
|
||||||
|
if cutout:
|
||||||
|
n_holes = np.random.randint(self.n_holes[0], self.n_holes[1] + 1)
|
||||||
|
else:
|
||||||
|
n_holes = 0
|
||||||
|
x1_lst = []
|
||||||
|
y1_lst = []
|
||||||
|
index_lst = []
|
||||||
|
for _ in range(n_holes):
|
||||||
|
x1_lst.append(np.random.randint(0, w))
|
||||||
|
y1_lst.append(np.random.randint(0, h))
|
||||||
|
index_lst.append(np.random.randint(0, len(self.candidates)))
|
||||||
|
return cutout, n_holes, x1_lst, y1_lst, index_lst
|
||||||
|
|
||||||
|
def transform(self, results: dict) -> dict:
|
||||||
"""Call function to drop some regions of image."""
|
"""Call function to drop some regions of image."""
|
||||||
cutout = True if np.random.rand() < self.prob else False
|
cutout, n_holes, x1_lst, y1_lst, index_lst = self.generate_patches(
|
||||||
|
results)
|
||||||
if cutout:
|
if cutout:
|
||||||
h, w, c = results['img'].shape
|
h, w, c = results['img'].shape
|
||||||
n_holes = np.random.randint(self.n_holes[0], self.n_holes[1] + 1)
|
for i in range(n_holes):
|
||||||
for _ in range(n_holes):
|
x1 = x1_lst[i]
|
||||||
x1 = np.random.randint(0, w)
|
y1 = y1_lst[i]
|
||||||
y1 = np.random.randint(0, h)
|
index = index_lst[i]
|
||||||
index = np.random.randint(0, len(self.candidates))
|
|
||||||
if not self.with_ratio:
|
if not self.with_ratio:
|
||||||
cutout_w, cutout_h = self.candidates[index]
|
cutout_w, cutout_h = self.candidates[index]
|
||||||
else:
|
else:
|
||||||
@ -759,7 +859,7 @@ class RandomCutOut(object):
|
|||||||
|
|
||||||
|
|
||||||
@TRANSFORMS.register_module()
|
@TRANSFORMS.register_module()
|
||||||
class RandomMosaic(object):
|
class RandomMosaic(BaseTransform):
|
||||||
"""Mosaic augmentation. Given 4 images, mosaic transform combines them into
|
"""Mosaic augmentation. Given 4 images, mosaic transform combines them into
|
||||||
one output image. The output image is composed of the parts from each sub-
|
one output image. The output image is composed of the parts from each sub-
|
||||||
image.
|
image.
|
||||||
@ -789,6 +889,19 @@ class RandomMosaic(object):
|
|||||||
sample another 3 images from the custom dataset.
|
sample another 3 images from the custom dataset.
|
||||||
3. Sub image will be cropped if image is larger than mosaic patch
|
3. Sub image will be cropped if image is larger than mosaic patch
|
||||||
|
|
||||||
|
Required Keys:
|
||||||
|
|
||||||
|
- img
|
||||||
|
- gt_seg_map
|
||||||
|
- mix_results
|
||||||
|
|
||||||
|
Modified Keys:
|
||||||
|
|
||||||
|
- img
|
||||||
|
- img_shape
|
||||||
|
- ori_shape
|
||||||
|
- gt_seg_map
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prob (float): mosaic probability.
|
prob (float): mosaic probability.
|
||||||
img_scale (Sequence[int]): Image size after mosaic pipeline of
|
img_scale (Sequence[int]): Image size after mosaic pipeline of
|
||||||
@ -815,7 +928,11 @@ class RandomMosaic(object):
|
|||||||
self.pad_val = pad_val
|
self.pad_val = pad_val
|
||||||
self.seg_pad_val = seg_pad_val
|
self.seg_pad_val = seg_pad_val
|
||||||
|
|
||||||
def __call__(self, results):
|
@cache_randomness
|
||||||
|
def do_mosaic(self):
|
||||||
|
return np.random.rand() < self.prob
|
||||||
|
|
||||||
|
def transform(self, results: dict) -> dict:
|
||||||
"""Call function to make a mosaic of image.
|
"""Call function to make a mosaic of image.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -824,13 +941,13 @@ class RandomMosaic(object):
|
|||||||
Returns:
|
Returns:
|
||||||
dict: Result dict with mosaic transformed.
|
dict: Result dict with mosaic transformed.
|
||||||
"""
|
"""
|
||||||
mosaic = True if np.random.rand() < self.prob else False
|
mosaic = self.do_mosaic()
|
||||||
if mosaic:
|
if mosaic:
|
||||||
results = self._mosaic_transform_img(results)
|
results = self._mosaic_transform_img(results)
|
||||||
results = self._mosaic_transform_seg(results)
|
results = self._mosaic_transform_seg(results)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def get_indexes(self, dataset):
|
def get_indices(self, dataset: MultiImageMixDataset) -> list:
|
||||||
"""Call function to collect indexes.
|
"""Call function to collect indexes.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -843,7 +960,16 @@ class RandomMosaic(object):
|
|||||||
indexes = [random.randint(0, len(dataset)) for _ in range(3)]
|
indexes = [random.randint(0, len(dataset)) for _ in range(3)]
|
||||||
return indexes
|
return indexes
|
||||||
|
|
||||||
def _mosaic_transform_img(self, results):
|
@cache_randomness
|
||||||
|
def generate_mosaic_center(self):
|
||||||
|
# mosaic center x, y
|
||||||
|
center_x = int(
|
||||||
|
random.uniform(*self.center_ratio_range) * self.img_scale[1])
|
||||||
|
center_y = int(
|
||||||
|
random.uniform(*self.center_ratio_range) * self.img_scale[0])
|
||||||
|
return center_x, center_y
|
||||||
|
|
||||||
|
def _mosaic_transform_img(self, results: dict) -> dict:
|
||||||
"""Mosaic transform function.
|
"""Mosaic transform function.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -866,10 +992,7 @@ class RandomMosaic(object):
|
|||||||
dtype=results['img'].dtype)
|
dtype=results['img'].dtype)
|
||||||
|
|
||||||
# mosaic center x, y
|
# mosaic center x, y
|
||||||
self.center_x = int(
|
self.center_x, self.center_y = self.generate_mosaic_center()
|
||||||
random.uniform(*self.center_ratio_range) * self.img_scale[1])
|
|
||||||
self.center_y = int(
|
|
||||||
random.uniform(*self.center_ratio_range) * self.img_scale[0])
|
|
||||||
center_position = (self.center_x, self.center_y)
|
center_position = (self.center_x, self.center_y)
|
||||||
|
|
||||||
loc_strs = ('top_left', 'top_right', 'bottom_left', 'bottom_right')
|
loc_strs = ('top_left', 'top_right', 'bottom_left', 'bottom_right')
|
||||||
@ -902,7 +1025,7 @@ class RandomMosaic(object):
|
|||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def _mosaic_transform_seg(self, results):
|
def _mosaic_transform_seg(self, results: dict) -> dict:
|
||||||
"""Mosaic transform function for label annotations.
|
"""Mosaic transform function for label annotations.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -953,7 +1076,8 @@ class RandomMosaic(object):
|
|||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def _mosaic_combine(self, loc, center_position_xy, img_shape_wh):
|
def _mosaic_combine(self, loc: str, center_position_xy: Sequence[float],
|
||||||
|
img_shape_wh: Sequence[int]) -> tuple:
|
||||||
"""Calculate global coordinate of mosaic image and local coordinate of
|
"""Calculate global coordinate of mosaic image and local coordinate of
|
||||||
cropped sub-image.
|
cropped sub-image.
|
||||||
|
|
||||||
|
@ -7,11 +7,9 @@ import numpy as np
|
|||||||
import pytest
|
import pytest
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
from mmseg.datasets.transforms import * # noqa
|
||||||
from mmseg.datasets.transforms import PhotoMetricDistortion, RandomCrop
|
from mmseg.datasets.transforms import PhotoMetricDistortion, RandomCrop
|
||||||
from mmseg.registry import TRANSFORMS
|
from mmseg.registry import TRANSFORMS
|
||||||
from mmseg.utils import register_all_modules
|
|
||||||
|
|
||||||
register_all_modules()
|
|
||||||
|
|
||||||
|
|
||||||
def test_resize():
|
def test_resize():
|
||||||
@ -233,6 +231,72 @@ def test_random_crop():
|
|||||||
assert results['gt_semantic_seg'].shape[:2] == (h - 20, w - 20)
|
assert results['gt_semantic_seg'].shape[:2] == (h - 20, w - 20)
|
||||||
|
|
||||||
|
|
||||||
|
def test_rgb2gray():
|
||||||
|
# test assertion out_channels should be greater than 0
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
transform = dict(type='RGB2Gray', out_channels=-1)
|
||||||
|
TRANSFORMS.build(transform)
|
||||||
|
# test assertion weights should be tuple[float]
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
transform = dict(type='RGB2Gray', out_channels=1, weights=1.1)
|
||||||
|
TRANSFORMS.build(transform)
|
||||||
|
|
||||||
|
# test out_channels is None
|
||||||
|
transform = dict(type='RGB2Gray')
|
||||||
|
transform = TRANSFORMS.build(transform)
|
||||||
|
|
||||||
|
assert str(transform) == f'RGB2Gray(' \
|
||||||
|
f'out_channels={None}, ' \
|
||||||
|
f'weights={(0.299, 0.587, 0.114)})'
|
||||||
|
|
||||||
|
results = dict()
|
||||||
|
img = mmcv.imread(
|
||||||
|
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
|
||||||
|
h, w, c = img.shape
|
||||||
|
seg = np.array(
|
||||||
|
Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))
|
||||||
|
results['img'] = img
|
||||||
|
results['gt_semantic_seg'] = seg
|
||||||
|
results['seg_fields'] = ['gt_semantic_seg']
|
||||||
|
results['img_shape'] = img.shape
|
||||||
|
results['ori_shape'] = img.shape
|
||||||
|
# Set initial values for default meta_keys
|
||||||
|
results['pad_shape'] = img.shape
|
||||||
|
results['scale_factor'] = 1.0
|
||||||
|
|
||||||
|
results = transform(results)
|
||||||
|
assert results['img'].shape == (h, w, c)
|
||||||
|
assert results['img_shape'] == (h, w, c)
|
||||||
|
assert results['ori_shape'] == (h, w, c)
|
||||||
|
|
||||||
|
# test out_channels = 2
|
||||||
|
transform = dict(type='RGB2Gray', out_channels=2)
|
||||||
|
transform = TRANSFORMS.build(transform)
|
||||||
|
|
||||||
|
assert str(transform) == f'RGB2Gray(' \
|
||||||
|
f'out_channels={2}, ' \
|
||||||
|
f'weights={(0.299, 0.587, 0.114)})'
|
||||||
|
|
||||||
|
results = dict()
|
||||||
|
img = mmcv.imread(
|
||||||
|
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
|
||||||
|
h, w, c = img.shape
|
||||||
|
seg = np.array(
|
||||||
|
Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))
|
||||||
|
results['img'] = img
|
||||||
|
results['gt_semantic_seg'] = seg
|
||||||
|
results['seg_fields'] = ['gt_semantic_seg']
|
||||||
|
results['img_shape'] = img.shape
|
||||||
|
results['ori_shape'] = img.shape
|
||||||
|
# Set initial values for default meta_keys
|
||||||
|
results['pad_shape'] = img.shape
|
||||||
|
results['scale_factor'] = 1.0
|
||||||
|
|
||||||
|
results = transform(results)
|
||||||
|
assert results['img'].shape == (h, w, 2)
|
||||||
|
assert results['img_shape'] == (h, w, 2)
|
||||||
|
|
||||||
|
|
||||||
def test_photo_metric_distortion():
|
def test_photo_metric_distortion():
|
||||||
|
|
||||||
results = dict()
|
results = dict()
|
||||||
@ -252,3 +316,366 @@ def test_photo_metric_distortion():
|
|||||||
|
|
||||||
assert (results['gt_semantic_seg'] == seg).all()
|
assert (results['gt_semantic_seg'] == seg).all()
|
||||||
assert results['img_shape'] == img.shape
|
assert results['img_shape'] == img.shape
|
||||||
|
|
||||||
|
|
||||||
|
def test_rerange():
|
||||||
|
# test assertion if min_value or max_value is illegal
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
transform = dict(type='Rerange', min_value=[0], max_value=[255])
|
||||||
|
TRANSFORMS.build(transform)
|
||||||
|
|
||||||
|
# test assertion if min_value >= max_value
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
transform = dict(type='Rerange', min_value=1, max_value=1)
|
||||||
|
TRANSFORMS.build(transform)
|
||||||
|
|
||||||
|
# test assertion if img_min_value == img_max_value
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
transform = dict(type='Rerange', min_value=0, max_value=1)
|
||||||
|
transform = TRANSFORMS.build(transform)
|
||||||
|
results = dict()
|
||||||
|
results['img'] = np.array([[1, 1], [1, 1]])
|
||||||
|
transform(results)
|
||||||
|
|
||||||
|
img_rerange_cfg = dict()
|
||||||
|
transform = dict(type='Rerange', **img_rerange_cfg)
|
||||||
|
transform = TRANSFORMS.build(transform)
|
||||||
|
results = dict()
|
||||||
|
img = mmcv.imread(
|
||||||
|
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
|
||||||
|
original_img = copy.deepcopy(img)
|
||||||
|
results['img'] = img
|
||||||
|
results['img_shape'] = img.shape
|
||||||
|
results['ori_shape'] = img.shape
|
||||||
|
# Set initial values for default meta_keys
|
||||||
|
results['pad_shape'] = img.shape
|
||||||
|
results['scale_factor'] = 1.0
|
||||||
|
|
||||||
|
results = transform(results)
|
||||||
|
|
||||||
|
min_value = np.min(original_img)
|
||||||
|
max_value = np.max(original_img)
|
||||||
|
converted_img = (original_img - min_value) / (max_value - min_value) * 255
|
||||||
|
|
||||||
|
assert np.allclose(results['img'], converted_img)
|
||||||
|
assert str(transform) == f'Rerange(min_value={0}, max_value={255})'
|
||||||
|
|
||||||
|
|
||||||
|
def test_CLAHE():
|
||||||
|
# test assertion if clip_limit is None
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
transform = dict(type='CLAHE', clip_limit=None)
|
||||||
|
TRANSFORMS.build(transform)
|
||||||
|
|
||||||
|
# test assertion if tile_grid_size is illegal
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
transform = dict(type='CLAHE', tile_grid_size=(8.0, 8.0))
|
||||||
|
TRANSFORMS.build(transform)
|
||||||
|
|
||||||
|
# test assertion if tile_grid_size is illegal
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
transform = dict(type='CLAHE', tile_grid_size=(9, 9, 9))
|
||||||
|
TRANSFORMS.build(transform)
|
||||||
|
|
||||||
|
transform = dict(type='CLAHE', clip_limit=2)
|
||||||
|
transform = TRANSFORMS.build(transform)
|
||||||
|
results = dict()
|
||||||
|
img = mmcv.imread(
|
||||||
|
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
|
||||||
|
original_img = copy.deepcopy(img)
|
||||||
|
results['img'] = img
|
||||||
|
results['img_shape'] = img.shape
|
||||||
|
results['ori_shape'] = img.shape
|
||||||
|
# Set initial values for default meta_keys
|
||||||
|
results['pad_shape'] = img.shape
|
||||||
|
results['scale_factor'] = 1.0
|
||||||
|
|
||||||
|
results = transform(results)
|
||||||
|
|
||||||
|
converted_img = np.empty(original_img.shape)
|
||||||
|
for i in range(original_img.shape[2]):
|
||||||
|
converted_img[:, :, i] = mmcv.clahe(
|
||||||
|
np.array(original_img[:, :, i], dtype=np.uint8), 2, (8, 8))
|
||||||
|
|
||||||
|
assert np.allclose(results['img'], converted_img)
|
||||||
|
assert str(transform) == f'CLAHE(clip_limit={2}, tile_grid_size={(8, 8)})'
|
||||||
|
|
||||||
|
|
||||||
|
def test_adjust_gamma():
|
||||||
|
# test assertion if gamma <= 0
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
transform = dict(type='AdjustGamma', gamma=0)
|
||||||
|
TRANSFORMS.build(transform)
|
||||||
|
|
||||||
|
# test assertion if gamma is list
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
transform = dict(type='AdjustGamma', gamma=[1.2])
|
||||||
|
TRANSFORMS.build(transform)
|
||||||
|
|
||||||
|
# test with gamma = 1.2
|
||||||
|
transform = dict(type='AdjustGamma', gamma=1.2)
|
||||||
|
transform = TRANSFORMS.build(transform)
|
||||||
|
results = dict()
|
||||||
|
img = mmcv.imread(
|
||||||
|
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
|
||||||
|
original_img = copy.deepcopy(img)
|
||||||
|
results['img'] = img
|
||||||
|
results['img_shape'] = img.shape
|
||||||
|
results['ori_shape'] = img.shape
|
||||||
|
# Set initial values for default meta_keys
|
||||||
|
results['pad_shape'] = img.shape
|
||||||
|
results['scale_factor'] = 1.0
|
||||||
|
|
||||||
|
results = transform(results)
|
||||||
|
|
||||||
|
inv_gamma = 1.0 / 1.2
|
||||||
|
table = np.array([((i / 255.0)**inv_gamma) * 255
|
||||||
|
for i in np.arange(0, 256)]).astype('uint8')
|
||||||
|
converted_img = mmcv.lut_transform(
|
||||||
|
np.array(original_img, dtype=np.uint8), table)
|
||||||
|
assert np.allclose(results['img'], converted_img)
|
||||||
|
assert str(transform) == f'AdjustGamma(gamma={1.2})'
|
||||||
|
|
||||||
|
|
||||||
|
def test_rotate():
|
||||||
|
# test assertion degree should be tuple[float] or float
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
transform = dict(type='RandomRotate', prob=0.5, degree=-10)
|
||||||
|
TRANSFORMS.build(transform)
|
||||||
|
# test assertion degree should be tuple[float] or float
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
transform = dict(type='RandomRotate', prob=0.5, degree=(10., 20., 30.))
|
||||||
|
TRANSFORMS.build(transform)
|
||||||
|
|
||||||
|
transform = dict(type='RandomRotate', degree=10., prob=1.)
|
||||||
|
transform = TRANSFORMS.build(transform)
|
||||||
|
|
||||||
|
assert str(transform) == f'RandomRotate(' \
|
||||||
|
f'prob={1.}, ' \
|
||||||
|
f'degree=({-10.}, {10.}), ' \
|
||||||
|
f'pad_val={0}, ' \
|
||||||
|
f'seg_pad_val={255}, ' \
|
||||||
|
f'center={None}, ' \
|
||||||
|
f'auto_bound={False})'
|
||||||
|
|
||||||
|
results = dict()
|
||||||
|
img = mmcv.imread(
|
||||||
|
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
|
||||||
|
h, w, _ = img.shape
|
||||||
|
seg = np.array(
|
||||||
|
Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))
|
||||||
|
results['img'] = img
|
||||||
|
results['gt_semantic_seg'] = seg
|
||||||
|
results['seg_fields'] = ['gt_semantic_seg']
|
||||||
|
results['img_shape'] = img.shape
|
||||||
|
results['ori_shape'] = img.shape
|
||||||
|
# Set initial values for default meta_keys
|
||||||
|
results['pad_shape'] = img.shape
|
||||||
|
results['scale_factor'] = 1.0
|
||||||
|
|
||||||
|
results = transform(results)
|
||||||
|
assert results['img'].shape[:2] == (h, w)
|
||||||
|
|
||||||
|
|
||||||
|
def test_seg_rescale():
|
||||||
|
results = dict()
|
||||||
|
seg = np.array(
|
||||||
|
Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))
|
||||||
|
results['gt_semantic_seg'] = seg
|
||||||
|
results['seg_fields'] = ['gt_semantic_seg']
|
||||||
|
h, w = seg.shape
|
||||||
|
|
||||||
|
transform = dict(type='SegRescale', scale_factor=1. / 2)
|
||||||
|
rescale_module = TRANSFORMS.build(transform)
|
||||||
|
rescale_results = rescale_module(results.copy())
|
||||||
|
assert rescale_results['gt_semantic_seg'].shape == (h // 2, w // 2)
|
||||||
|
|
||||||
|
transform = dict(type='SegRescale', scale_factor=1)
|
||||||
|
rescale_module = TRANSFORMS.build(transform)
|
||||||
|
rescale_results = rescale_module(results.copy())
|
||||||
|
assert rescale_results['gt_semantic_seg'].shape == (h, w)
|
||||||
|
|
||||||
|
|
||||||
|
def test_mosaic():
|
||||||
|
# test prob
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
transform = dict(type='RandomMosaic', prob=1.5)
|
||||||
|
TRANSFORMS.build(transform)
|
||||||
|
# test assertion for invalid img_scale
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
transform = dict(type='RandomMosaic', prob=1, img_scale=640)
|
||||||
|
TRANSFORMS.build(transform)
|
||||||
|
|
||||||
|
results = dict()
|
||||||
|
img = mmcv.imread(
|
||||||
|
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
|
||||||
|
seg = np.array(
|
||||||
|
Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))
|
||||||
|
|
||||||
|
results['img'] = img
|
||||||
|
results['gt_semantic_seg'] = seg
|
||||||
|
results['seg_fields'] = ['gt_semantic_seg']
|
||||||
|
|
||||||
|
transform = dict(type='RandomMosaic', prob=1, img_scale=(10, 12))
|
||||||
|
mosaic_module = TRANSFORMS.build(transform)
|
||||||
|
assert 'Mosaic' in repr(mosaic_module)
|
||||||
|
|
||||||
|
# test assertion for invalid mix_results
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
mosaic_module(results)
|
||||||
|
|
||||||
|
results['mix_results'] = [copy.deepcopy(results)] * 3
|
||||||
|
results = mosaic_module(results)
|
||||||
|
assert results['img'].shape[:2] == (20, 24)
|
||||||
|
|
||||||
|
results = dict()
|
||||||
|
results['img'] = img[:, :, 0]
|
||||||
|
results['gt_semantic_seg'] = seg
|
||||||
|
results['seg_fields'] = ['gt_semantic_seg']
|
||||||
|
|
||||||
|
transform = dict(type='RandomMosaic', prob=0, img_scale=(10, 12))
|
||||||
|
mosaic_module = TRANSFORMS.build(transform)
|
||||||
|
results['mix_results'] = [copy.deepcopy(results)] * 3
|
||||||
|
results = mosaic_module(results)
|
||||||
|
assert results['img'].shape[:2] == img.shape[:2]
|
||||||
|
|
||||||
|
transform = dict(type='RandomMosaic', prob=1, img_scale=(10, 12))
|
||||||
|
mosaic_module = TRANSFORMS.build(transform)
|
||||||
|
results = mosaic_module(results)
|
||||||
|
assert results['img'].shape[:2] == (20, 24)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cutout():
|
||||||
|
# test prob
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
transform = dict(type='RandomCutOut', prob=1.5, n_holes=1)
|
||||||
|
TRANSFORMS.build(transform)
|
||||||
|
# test n_holes
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
transform = dict(
|
||||||
|
type='RandomCutOut', prob=0.5, n_holes=(5, 3), cutout_shape=(8, 8))
|
||||||
|
TRANSFORMS.build(transform)
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
transform = dict(
|
||||||
|
type='RandomCutOut',
|
||||||
|
prob=0.5,
|
||||||
|
n_holes=(3, 4, 5),
|
||||||
|
cutout_shape=(8, 8))
|
||||||
|
TRANSFORMS.build(transform)
|
||||||
|
# test cutout_shape and cutout_ratio
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
transform = dict(
|
||||||
|
type='RandomCutOut', prob=0.5, n_holes=1, cutout_shape=8)
|
||||||
|
TRANSFORMS.build(transform)
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
transform = dict(
|
||||||
|
type='RandomCutOut', prob=0.5, n_holes=1, cutout_ratio=0.2)
|
||||||
|
TRANSFORMS.build(transform)
|
||||||
|
# either of cutout_shape and cutout_ratio should be given
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
transform = dict(type='RandomCutOut', prob=0.5, n_holes=1)
|
||||||
|
TRANSFORMS.build(transform)
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
transform = dict(
|
||||||
|
type='RandomCutOut',
|
||||||
|
prob=0.5,
|
||||||
|
n_holes=1,
|
||||||
|
cutout_shape=(2, 2),
|
||||||
|
cutout_ratio=(0.4, 0.4))
|
||||||
|
TRANSFORMS.build(transform)
|
||||||
|
# test seg_fill_in
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
transform = dict(
|
||||||
|
type='RandomCutOut',
|
||||||
|
prob=0.5,
|
||||||
|
n_holes=1,
|
||||||
|
cutout_shape=(8, 8),
|
||||||
|
seg_fill_in='a')
|
||||||
|
TRANSFORMS.build(transform)
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
transform = dict(
|
||||||
|
type='RandomCutOut',
|
||||||
|
prob=0.5,
|
||||||
|
n_holes=1,
|
||||||
|
cutout_shape=(8, 8),
|
||||||
|
seg_fill_in=256)
|
||||||
|
TRANSFORMS.build(transform)
|
||||||
|
|
||||||
|
results = dict()
|
||||||
|
img = mmcv.imread(
|
||||||
|
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
|
||||||
|
|
||||||
|
seg = np.array(
|
||||||
|
Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))
|
||||||
|
|
||||||
|
results['img'] = img
|
||||||
|
results['gt_semantic_seg'] = seg
|
||||||
|
results['seg_fields'] = ['gt_semantic_seg']
|
||||||
|
results['img_shape'] = img.shape
|
||||||
|
results['ori_shape'] = img.shape
|
||||||
|
results['pad_shape'] = img.shape
|
||||||
|
results['img_fields'] = ['img']
|
||||||
|
|
||||||
|
transform = dict(
|
||||||
|
type='RandomCutOut', prob=1, n_holes=1, cutout_shape=(10, 10))
|
||||||
|
cutout_module = TRANSFORMS.build(transform)
|
||||||
|
assert 'cutout_shape' in repr(cutout_module)
|
||||||
|
cutout_result = cutout_module(copy.deepcopy(results))
|
||||||
|
assert cutout_result['img'].sum() < img.sum()
|
||||||
|
|
||||||
|
transform = dict(
|
||||||
|
type='RandomCutOut', prob=1, n_holes=1, cutout_ratio=(0.8, 0.8))
|
||||||
|
cutout_module = TRANSFORMS.build(transform)
|
||||||
|
assert 'cutout_ratio' in repr(cutout_module)
|
||||||
|
cutout_result = cutout_module(copy.deepcopy(results))
|
||||||
|
assert cutout_result['img'].sum() < img.sum()
|
||||||
|
|
||||||
|
transform = dict(
|
||||||
|
type='RandomCutOut', prob=0, n_holes=1, cutout_ratio=(0.8, 0.8))
|
||||||
|
cutout_module = TRANSFORMS.build(transform)
|
||||||
|
cutout_result = cutout_module(copy.deepcopy(results))
|
||||||
|
assert cutout_result['img'].sum() == img.sum()
|
||||||
|
assert cutout_result['gt_semantic_seg'].sum() == seg.sum()
|
||||||
|
|
||||||
|
transform = dict(
|
||||||
|
type='RandomCutOut',
|
||||||
|
prob=1,
|
||||||
|
n_holes=(2, 4),
|
||||||
|
cutout_shape=[(10, 10), (15, 15)],
|
||||||
|
fill_in=(255, 255, 255),
|
||||||
|
seg_fill_in=None)
|
||||||
|
cutout_module = TRANSFORMS.build(transform)
|
||||||
|
cutout_result = cutout_module(copy.deepcopy(results))
|
||||||
|
assert cutout_result['img'].sum() > img.sum()
|
||||||
|
assert cutout_result['gt_semantic_seg'].sum() == seg.sum()
|
||||||
|
|
||||||
|
transform = dict(
|
||||||
|
type='RandomCutOut',
|
||||||
|
prob=1,
|
||||||
|
n_holes=1,
|
||||||
|
cutout_ratio=(0.8, 0.8),
|
||||||
|
fill_in=(255, 255, 255),
|
||||||
|
seg_fill_in=255)
|
||||||
|
cutout_module = TRANSFORMS.build(transform)
|
||||||
|
cutout_result = cutout_module(copy.deepcopy(results))
|
||||||
|
assert cutout_result['img'].sum() > img.sum()
|
||||||
|
assert cutout_result['gt_semantic_seg'].sum() > seg.sum()
|
||||||
|
|
||||||
|
|
||||||
|
def test_resize_to_multiple():
|
||||||
|
transform = dict(type='ResizeToMultiple', size_divisor=32)
|
||||||
|
transform = TRANSFORMS.build(transform)
|
||||||
|
|
||||||
|
img = np.random.randn(213, 232, 3)
|
||||||
|
seg = np.random.randint(0, 19, (213, 232))
|
||||||
|
results = dict()
|
||||||
|
results['img'] = img
|
||||||
|
results['gt_semantic_seg'] = seg
|
||||||
|
results['seg_fields'] = ['gt_semantic_seg']
|
||||||
|
results['img_shape'] = img.shape
|
||||||
|
results['pad_shape'] = img.shape
|
||||||
|
|
||||||
|
results = transform(results)
|
||||||
|
assert results['img'].shape == (224, 256, 3)
|
||||||
|
assert results['gt_semantic_seg'].shape == (224, 256)
|
||||||
|
assert results['img_shape'] == (224, 256, 3)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user