[Feature] Add Rerange transform (#228)

* add rerange transform

* restore

* delete rerange config

* delete rerange config hint

* add min < max assert

* restore
pull/246/head
yamengxi 2020-11-10 18:30:07 +08:00 committed by GitHub
parent 7c68bca595
commit 2bd51e60fd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 90 additions and 0 deletions

View File

@ -392,6 +392,53 @@ class Normalize(object):
return repr_str
@PIPELINES.register_module()
class Rerange(object):
"""Rerange the image pixel value.
Args:
min_value (float or int): Minimum value of the reranged image.
Default: 0.
max_value (float or int): Maximum value of the reranged image.
Default: 255.
"""
def __init__(self, min_value=0, max_value=255):
assert isinstance(min_value, float) or isinstance(min_value, int)
assert isinstance(max_value, float) or isinstance(max_value, int)
assert min_value < max_value
self.min_value = min_value
self.max_value = max_value
def __call__(self, results):
"""Call function to rerange images.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Reranged results.
"""
img = results['img']
img_min_value = np.min(img)
img_max_value = np.max(img)
assert img_min_value < img_max_value
# rerange to [0, 1]
img = (img - img_min_value) / (img_max_value - img_min_value)
# rerange to [min_value, max_value]
img = img * (self.max_value - self.min_value) + self.min_value
results['img'] = img
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(min_value={self.min_value}, max_value={self.max_value})'
return repr_str
@PIPELINES.register_module()
class RandomCrop(object):
"""Random crop the image & seg.

View File

@ -330,6 +330,49 @@ def test_rgb2gray():
assert results['ori_shape'] == (h, w, c)
def test_rerange():
# test assertion if min_value or max_value is illegal
with pytest.raises(AssertionError):
transform = dict(type='Rerange', min_value=[0], max_value=[255])
build_from_cfg(transform, PIPELINES)
# test assertion if min_value >= max_value
with pytest.raises(AssertionError):
transform = dict(type='Rerange', min_value=1, max_value=1)
build_from_cfg(transform, PIPELINES)
# test assertion if img_min_value == img_max_value
with pytest.raises(AssertionError):
transform = dict(type='Rerange', min_value=0, max_value=1)
transform = build_from_cfg(transform, PIPELINES)
results = dict()
results['img'] = np.array([[1, 1], [1, 1]])
transform(results)
img_rerange_cfg = dict()
transform = dict(type='Rerange', **img_rerange_cfg)
transform = build_from_cfg(transform, PIPELINES)
results = dict()
img = mmcv.imread(
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
original_img = copy.deepcopy(img)
results['img'] = img
results['img_shape'] = img.shape
results['ori_shape'] = img.shape
# Set initial values for default meta_keys
results['pad_shape'] = img.shape
results['scale_factor'] = 1.0
results = transform(results)
min_value = np.min(original_img)
max_value = np.max(original_img)
converted_img = (original_img - min_value) / (max_value - min_value) * 255
assert np.allclose(results['img'], converted_img)
assert str(transform) == f'Rerange(min_value={0}, max_value={255})'
def test_seg_rescale():
results = dict()
seg = np.array(