add increasing in solarize and posterize

pull/249/head
lixinran 2021-05-12 10:49:27 +08:00
parent 9e93feabac
commit 128af36e9b
2 changed files with 41 additions and 7 deletions

View File

@ -510,18 +510,23 @@ class Solarize(object):
Args:
thr (int | float): The threshold above which the pixels value will be
inverted.
inverted when incresing is set to False.
prob (float): The probability for solarizing therefore should be in
range [0, 1]. Defaults to 0.5.
increasing (bool): When setting to True, the meaning of thr is
8 - actual thr.
"""
def __init__(self, thr, prob=0.5):
def __init__(self, thr, prob=0.5, increasing=False):
assert isinstance(thr, (int, float)), 'The thr type must '\
f'be int or float, but got {type(thr)} instead.'
assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \
f'got {prob} instead.'
self.thr = thr
if increasing:
self.thr = 256 - thr
else:
self.thr = thr
self.prob = prob
def __call__(self, results):
@ -588,18 +593,23 @@ class Posterize(object):
"""Posterize images (reduce the number of bits for each color channel).
Args:
bits (int | float): Number of bits for each pixel in the output img,
which should be less or equal to 8.
bits (int | float): Number of bits for each pixel in the output img
when increasing is False, which should be less or equal to 8.
prob (float): The probability for posterizing therefore should be in
range [0, 1]. Defaults to 0.5.
increasing (bool): When setting to True, the meaning of bits is
8 - actual number of bits.
"""
def __init__(self, bits, prob=0.5):
def __init__(self, bits, prob=0.5, increasing=False):
assert bits <= 8, f'The bits must be less than 8, got {bits} instead.'
assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \
f'got {prob} instead.'
self.bits = int(bits)
if increasing:
self.bits = 8 - int(bits)
else:
self.bits = int(bits)
self.prob = prob
def __call__(self, results):

View File

@ -727,6 +727,18 @@ def test_solarize():
assert (results['img'] == img_solarized).all()
assert (results['img'] == results['img2']).all()
# test case when thr=156
results = construct_toy_data_photometric()
transform = dict(type='Solarize', thr=156, prob=1., increasing=True)
pipeline = build_from_cfg(transform, PIPELINES)
results = pipeline(results)
img_solarized = np.array([[0, 127, 0], [1, 128, 1], [2, 126, 2]],
dtype=np.uint8)
img_solarized = np.stack([img_solarized, img_solarized, img_solarized],
axis=-1)
assert (results['img'] == img_solarized).all()
assert (results['img'] == results['img2']).all()
def test_solarize_add():
# test assertion for invalid type of magnitude
@ -822,6 +834,18 @@ def test_posterize():
assert (results['img'] == img_posterized).all()
assert (results['img'] == results['img2']).all()
# test case when bits=5, incresing= True
results = construct_toy_data_photometric()
transform = dict(type='Posterize', bits=5, prob=1., increasing=True)
pipeline = build_from_cfg(transform, PIPELINES)
results = pipeline(results)
img_posterized = np.array([[0, 128, 224], [0, 96, 224], [0, 128, 224]],
dtype=np.uint8)
img_posterized = np.stack([img_posterized, img_posterized, img_posterized],
axis=-1)
assert (results['img'] == img_posterized).all()
assert (results['img'] == results['img2']).all()
def test_contrast(nb_rand_test=100):