mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Feature] add AdjustGamma transform (#232)
* add AdjustGamma transform * restore * change cv2 to mmcv * simplify AdjustGamma * fix syntax error * modify * fix syntax error * change mmcv version to 1.3.0 * fix lut function name error * fix syntax error * fix range
This commit is contained in:
parent
1530af6533
commit
e8d643fe3a
@ -3,7 +3,7 @@ import mmcv
|
||||
from .version import __version__, version_info
|
||||
|
||||
MMCV_MIN = '1.1.4'
|
||||
MMCV_MAX = '1.2.0'
|
||||
MMCV_MAX = '1.3.0'
|
||||
|
||||
|
||||
def digit_version(version_str):
|
||||
|
@ -650,6 +650,42 @@ class RGB2Gray(object):
|
||||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class AdjustGamma(object):
|
||||
"""Using gamma correction to process the image.
|
||||
|
||||
Args:
|
||||
gamma (float or int): Gamma value used in gamma correction.
|
||||
Default: 1.0.
|
||||
"""
|
||||
|
||||
def __init__(self, gamma=1.0):
|
||||
assert isinstance(gamma, float) or isinstance(gamma, int)
|
||||
assert gamma > 0
|
||||
self.gamma = gamma
|
||||
inv_gamma = 1.0 / gamma
|
||||
self.table = np.array([(i / 255.0)**inv_gamma * 255
|
||||
for i in np.arange(256)]).astype('uint8')
|
||||
|
||||
def __call__(self, results):
|
||||
"""Call function to process the image with gamma correction.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict from loading pipeline.
|
||||
|
||||
Returns:
|
||||
dict: Processed results.
|
||||
"""
|
||||
|
||||
results['img'] = mmcv.lut_transform(
|
||||
np.array(results['img'], dtype=np.uint8), self.table)
|
||||
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + f'(gamma={self.gamma})'
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class SegRescale(object):
|
||||
"""Rescale semantic segmentation maps.
|
||||
|
@ -330,6 +330,42 @@ def test_rgb2gray():
|
||||
assert results['ori_shape'] == (h, w, c)
|
||||
|
||||
|
||||
def test_adjust_gamma():
|
||||
# test assertion if gamma <= 0
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='AdjustGamma', gamma=0)
|
||||
build_from_cfg(transform, PIPELINES)
|
||||
|
||||
# test assertion if gamma is list
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='AdjustGamma', gamma=[1.2])
|
||||
build_from_cfg(transform, PIPELINES)
|
||||
|
||||
# test with gamma = 1.2
|
||||
transform = dict(type='AdjustGamma', gamma=1.2)
|
||||
transform = build_from_cfg(transform, PIPELINES)
|
||||
results = dict()
|
||||
img = mmcv.imread(
|
||||
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
|
||||
original_img = copy.deepcopy(img)
|
||||
results['img'] = img
|
||||
results['img_shape'] = img.shape
|
||||
results['ori_shape'] = img.shape
|
||||
# Set initial values for default meta_keys
|
||||
results['pad_shape'] = img.shape
|
||||
results['scale_factor'] = 1.0
|
||||
|
||||
results = transform(results)
|
||||
|
||||
inv_gamma = 1.0 / 1.2
|
||||
table = np.array([((i / 255.0)**inv_gamma) * 255
|
||||
for i in np.arange(0, 256)]).astype('uint8')
|
||||
converted_img = mmcv.lut_transform(
|
||||
np.array(original_img, dtype=np.uint8), table)
|
||||
assert np.allclose(results['img'], converted_img)
|
||||
assert str(transform) == f'AdjustGamma(gamma={1.2})'
|
||||
|
||||
|
||||
def test_rerange():
|
||||
# test assertion if min_value or max_value is illegal
|
||||
with pytest.raises(AssertionError):
|
||||
|
Loading…
x
Reference in New Issue
Block a user