mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Feature] add AdjustGamma transform (#232)
* add AdjustGamma transform * restore * change cv2 to mmcv * simplify AdjustGamma * fix syntax error * modify * fix syntax error * change mmcv version to 1.3.0 * fix lut function name error * fix syntax error * fix range
This commit is contained in:
parent
1530af6533
commit
e8d643fe3a
@ -3,7 +3,7 @@ import mmcv
|
|||||||
from .version import __version__, version_info
|
from .version import __version__, version_info
|
||||||
|
|
||||||
MMCV_MIN = '1.1.4'
|
MMCV_MIN = '1.1.4'
|
||||||
MMCV_MAX = '1.2.0'
|
MMCV_MAX = '1.3.0'
|
||||||
|
|
||||||
|
|
||||||
def digit_version(version_str):
|
def digit_version(version_str):
|
||||||
|
@ -650,6 +650,42 @@ class RGB2Gray(object):
|
|||||||
return repr_str
|
return repr_str
|
||||||
|
|
||||||
|
|
||||||
|
@PIPELINES.register_module()
|
||||||
|
class AdjustGamma(object):
|
||||||
|
"""Using gamma correction to process the image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
gamma (float or int): Gamma value used in gamma correction.
|
||||||
|
Default: 1.0.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, gamma=1.0):
|
||||||
|
assert isinstance(gamma, float) or isinstance(gamma, int)
|
||||||
|
assert gamma > 0
|
||||||
|
self.gamma = gamma
|
||||||
|
inv_gamma = 1.0 / gamma
|
||||||
|
self.table = np.array([(i / 255.0)**inv_gamma * 255
|
||||||
|
for i in np.arange(256)]).astype('uint8')
|
||||||
|
|
||||||
|
def __call__(self, results):
|
||||||
|
"""Call function to process the image with gamma correction.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results (dict): Result dict from loading pipeline.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Processed results.
|
||||||
|
"""
|
||||||
|
|
||||||
|
results['img'] = mmcv.lut_transform(
|
||||||
|
np.array(results['img'], dtype=np.uint8), self.table)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return self.__class__.__name__ + f'(gamma={self.gamma})'
|
||||||
|
|
||||||
|
|
||||||
@PIPELINES.register_module()
|
@PIPELINES.register_module()
|
||||||
class SegRescale(object):
|
class SegRescale(object):
|
||||||
"""Rescale semantic segmentation maps.
|
"""Rescale semantic segmentation maps.
|
||||||
|
@ -330,6 +330,42 @@ def test_rgb2gray():
|
|||||||
assert results['ori_shape'] == (h, w, c)
|
assert results['ori_shape'] == (h, w, c)
|
||||||
|
|
||||||
|
|
||||||
|
def test_adjust_gamma():
|
||||||
|
# test assertion if gamma <= 0
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
transform = dict(type='AdjustGamma', gamma=0)
|
||||||
|
build_from_cfg(transform, PIPELINES)
|
||||||
|
|
||||||
|
# test assertion if gamma is list
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
transform = dict(type='AdjustGamma', gamma=[1.2])
|
||||||
|
build_from_cfg(transform, PIPELINES)
|
||||||
|
|
||||||
|
# test with gamma = 1.2
|
||||||
|
transform = dict(type='AdjustGamma', gamma=1.2)
|
||||||
|
transform = build_from_cfg(transform, PIPELINES)
|
||||||
|
results = dict()
|
||||||
|
img = mmcv.imread(
|
||||||
|
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
|
||||||
|
original_img = copy.deepcopy(img)
|
||||||
|
results['img'] = img
|
||||||
|
results['img_shape'] = img.shape
|
||||||
|
results['ori_shape'] = img.shape
|
||||||
|
# Set initial values for default meta_keys
|
||||||
|
results['pad_shape'] = img.shape
|
||||||
|
results['scale_factor'] = 1.0
|
||||||
|
|
||||||
|
results = transform(results)
|
||||||
|
|
||||||
|
inv_gamma = 1.0 / 1.2
|
||||||
|
table = np.array([((i / 255.0)**inv_gamma) * 255
|
||||||
|
for i in np.arange(0, 256)]).astype('uint8')
|
||||||
|
converted_img = mmcv.lut_transform(
|
||||||
|
np.array(original_img, dtype=np.uint8), table)
|
||||||
|
assert np.allclose(results['img'], converted_img)
|
||||||
|
assert str(transform) == f'AdjustGamma(gamma={1.2})'
|
||||||
|
|
||||||
|
|
||||||
def test_rerange():
|
def test_rerange():
|
||||||
# test assertion if min_value or max_value is illegal
|
# test assertion if min_value or max_value is illegal
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user