[Feature] Support AutoAug, AutoContrast, Equalize, Contrast, Brightness and Sharpness (#179)
* add AutoContrast, Equalize, Contrast, Brightness and Sharpness pipelines * add ImageNetPolicy * add configs * add unittest * remove config * rerun CI * rerun CI * [Fix] Update pip install mmcv command in ci (#187) * update pip install mmcv command in ci * update pip install mmcv command in ci * fix ci * fix cipull/194/head
parent
d412841e08
commit
93cd960466
|
@ -1,5 +1,7 @@
|
|||
from .auto_augment import (ColorTransform, Invert, Posterize, Rotate, Shear,
|
||||
Solarize, Translate)
|
||||
from .auto_augment import (AutoAugment, AutoContrast, Brightness,
|
||||
ColorTransform, Contrast, Equalize, Invert,
|
||||
Posterize, Rotate, Sharpness, Shear, Solarize,
|
||||
Translate)
|
||||
from .compose import Compose
|
||||
from .formating import (Collect, ImageToTensor, ToNumpy, ToPIL, ToTensor,
|
||||
Transpose, to_tensor)
|
||||
|
@ -12,5 +14,6 @@ __all__ = [
|
|||
'Transpose', 'Collect', 'LoadImageFromFile', 'Resize', 'CenterCrop',
|
||||
'RandomFlip', 'Normalize', 'RandomCrop', 'RandomResizedCrop',
|
||||
'RandomGrayscale', 'Shear', 'Translate', 'Rotate', 'Invert',
|
||||
'ColorTransform', 'Solarize', 'Posterize'
|
||||
'ColorTransform', 'Solarize', 'Posterize', 'AutoContrast', 'Equalize',
|
||||
'Contrast', 'Brightness', 'Sharpness', 'AutoAugment'
|
||||
]
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
import copy
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
|
||||
from ..builder import PIPELINES
|
||||
from .compose import Compose
|
||||
|
||||
|
||||
def random_negative(value, random_negative_prob):
|
||||
|
@ -9,6 +12,44 @@ def random_negative(value, random_negative_prob):
|
|||
return -value if np.random.rand() < random_negative_prob else value
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class AutoAugment(object):
|
||||
"""Auto augmentation.
|
||||
This data augmentation is proposed in `AutoAugment: Learning Augmentation
|
||||
Policies from Data <https://arxiv.org/abs/1805.09501>`_.
|
||||
|
||||
Args:
|
||||
policies (list[list[dict]]): The policies of auto augmentation. Each
|
||||
policy in ``policies`` is a specific augmentation policy, and is
|
||||
composed by several augmentations (dict). When AutoAugment is
|
||||
called, a random policy in ``policies`` will be selected to
|
||||
augment images.
|
||||
"""
|
||||
|
||||
def __init__(self, policies):
|
||||
assert isinstance(policies, list) and len(policies) > 0, \
|
||||
'Policies must be a non-empty list.'
|
||||
for policy in policies:
|
||||
assert isinstance(policy, list) and len(policy) > 0, \
|
||||
'Each policy in policies must be a non-empty list.'
|
||||
for augment in policy:
|
||||
assert isinstance(augment, dict) and 'type' in augment, \
|
||||
'Each specific augmentation must be a dict with key' \
|
||||
' "type".'
|
||||
|
||||
self.policies = copy.deepcopy(policies)
|
||||
self.sub_policy = [Compose(policy) for policy in self.policies]
|
||||
|
||||
def __call__(self, results):
|
||||
sub_policy = np.random.choice(self.sub_policy)
|
||||
return sub_policy(results)
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(policies={self.policies})'
|
||||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class Shear(object):
|
||||
"""Shear images.
|
||||
|
@ -261,6 +302,36 @@ class Rotate(object):
|
|||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class AutoContrast(object):
|
||||
"""Auto adjust image contrast.
|
||||
|
||||
Args:
|
||||
prob (float): The probability for performing invert therefore should
|
||||
be in range [0, 1]. Defaults to 0.5.
|
||||
"""
|
||||
|
||||
def __init__(self, prob=0.5):
|
||||
assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \
|
||||
f'got {prob} instead.'
|
||||
|
||||
self.prob = prob
|
||||
|
||||
def __call__(self, results):
|
||||
if np.random.rand() > self.prob:
|
||||
return results
|
||||
for key in results.get('img_fields', ['img']):
|
||||
img = results[key]
|
||||
img_contrasted = mmcv.auto_contrast(img)
|
||||
results[key] = img_contrasted.astype(img.dtype)
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(prob={self.prob})'
|
||||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class Invert(object):
|
||||
"""Invert images.
|
||||
|
@ -291,9 +362,157 @@ class Invert(object):
|
|||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class Equalize(object):
|
||||
"""Equalize the image histogram.
|
||||
|
||||
Args:
|
||||
prob (float): The probability for performing invert therefore should
|
||||
be in range [0, 1]. Defaults to 0.5.
|
||||
"""
|
||||
|
||||
def __init__(self, prob=0.5):
|
||||
assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \
|
||||
f'got {prob} instead.'
|
||||
|
||||
self.prob = prob
|
||||
|
||||
def __call__(self, results):
|
||||
if np.random.rand() > self.prob:
|
||||
return results
|
||||
for key in results.get('img_fields', ['img']):
|
||||
img = results[key]
|
||||
img_equalized = mmcv.imequalize(img)
|
||||
results[key] = img_equalized.astype(img.dtype)
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(prob={self.prob})'
|
||||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class Solarize(object):
|
||||
"""Solarize images (invert all pixel values above a threshold).
|
||||
|
||||
Args:
|
||||
thr (int | float): The threshold above which the pixels value will be
|
||||
inverted.
|
||||
prob (float): The probability for solarizing therefore should be in
|
||||
range [0, 1]. Defaults to 0.5.
|
||||
"""
|
||||
|
||||
def __init__(self, thr, prob=0.5):
|
||||
assert isinstance(thr, (int, float)), 'The thr type must '\
|
||||
f'be int or float, but got {type(thr)} instead.'
|
||||
assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \
|
||||
f'got {prob} instead.'
|
||||
|
||||
self.thr = thr
|
||||
self.prob = prob
|
||||
|
||||
def __call__(self, results):
|
||||
if np.random.rand() > self.prob:
|
||||
return results
|
||||
for key in results.get('img_fields', ['img']):
|
||||
img = results[key]
|
||||
img_solarized = mmcv.solarize(img, thr=self.thr)
|
||||
results[key] = img_solarized.astype(img.dtype)
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(thr={self.thr}, '
|
||||
repr_str += f'prob={self.prob})'
|
||||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class Posterize(object):
|
||||
"""Posterize images (reduce the number of bits for each color channel).
|
||||
|
||||
Args:
|
||||
bits (int): Number of bits for each pixel in the output img, which
|
||||
should be less or equal to 8.
|
||||
prob (float): The probability for posterizing therefore should be in
|
||||
range [0, 1]. Defaults to 0.5.
|
||||
"""
|
||||
|
||||
def __init__(self, bits, prob=0.5):
|
||||
assert isinstance(bits, int), 'The bits type must be int, '\
|
||||
f'but got {type(bits)} instead.'
|
||||
assert bits <= 8, f'The bits must be less than 8, got {bits} instead.'
|
||||
assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \
|
||||
f'got {prob} instead.'
|
||||
|
||||
self.bits = bits
|
||||
self.prob = prob
|
||||
|
||||
def __call__(self, results):
|
||||
if np.random.rand() > self.prob:
|
||||
return results
|
||||
for key in results.get('img_fields', ['img']):
|
||||
img = results[key]
|
||||
img_posterized = mmcv.posterize(img, bits=self.bits)
|
||||
results[key] = img_posterized.astype(img.dtype)
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(bits={self.bits}, '
|
||||
repr_str += f'prob={self.prob})'
|
||||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class Contrast(object):
|
||||
"""Adjust images contrast.
|
||||
|
||||
Args:
|
||||
magnitude (int | float): The magnitude used for adjusting contrast. A
|
||||
positive magnitude would enhance the contrast and a negative
|
||||
magnitude would make the image grayer. A magnitude=0 gives the
|
||||
origin img.
|
||||
prob (float): The probability for performing contrast adjusting
|
||||
therefore should be in range [0, 1]. Defaults to 0.5.
|
||||
random_negative_prob (float): The probability that turns the magnitude
|
||||
negative, which should be in range [0,1]. Defaults to 0.5.
|
||||
"""
|
||||
|
||||
def __init__(self, magnitude, prob=0.5, random_negative_prob=0.5):
|
||||
assert isinstance(magnitude, (int, float)), 'The magnitude type must '\
|
||||
f'be int or float, but got {type(magnitude)} instead.'
|
||||
assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \
|
||||
f'got {prob} instead.'
|
||||
assert 0 <= random_negative_prob <= 1.0, 'The random_negative_prob ' \
|
||||
f'should be in range [0,1], got {random_negative_prob} instead.'
|
||||
|
||||
self.magnitude = magnitude
|
||||
self.prob = prob
|
||||
self.random_negative_prob = random_negative_prob
|
||||
|
||||
def __call__(self, results):
|
||||
if np.random.rand() > self.prob:
|
||||
return results
|
||||
magnitude = random_negative(self.magnitude, self.random_negative_prob)
|
||||
for key in results.get('img_fields', ['img']):
|
||||
img = results[key]
|
||||
img_contrasted = mmcv.adjust_contrast(img, factor=1 + magnitude)
|
||||
results[key] = img_contrasted.astype(img.dtype)
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(magnitude={self.magnitude}, '
|
||||
repr_str += f'prob={self.prob}, '
|
||||
repr_str += f'random_negative_prob={self.random_negative_prob})'
|
||||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class ColorTransform(object):
|
||||
"""Adjust the color balance of images.
|
||||
"""Adjust images color balance.
|
||||
|
||||
Args:
|
||||
magnitude (int | float): The magnitude used for color transform. A
|
||||
|
@ -336,71 +555,90 @@ class ColorTransform(object):
|
|||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class Solarize(object):
|
||||
"""Solarize an image (invert all pixel values above a threshold).
|
||||
class Brightness(object):
|
||||
"""Adjust images brightness.
|
||||
|
||||
Args:
|
||||
thr (int | float): The threshold above which the pixels value will be
|
||||
inverted.
|
||||
prob (float): The probability for solarizing therefore should be in
|
||||
range [0, 1]. Defaults to 0.5.
|
||||
magnitude (int | float): The magnitude used for adjusting brightness. A
|
||||
positive magnitude would enhance the brightness and a negative
|
||||
magnitude would make the image darker. A magnitude=0 gives the
|
||||
origin img.
|
||||
prob (float): The probability for performing contrast adjusting
|
||||
therefore should be in range [0, 1]. Defaults to 0.5.
|
||||
random_negative_prob (float): The probability that turns the magnitude
|
||||
negative, which should be in range [0,1]. Defaults to 0.5.
|
||||
"""
|
||||
|
||||
def __init__(self, thr, prob=0.5):
|
||||
assert isinstance(thr, (int, float)), 'The thr type must '\
|
||||
f'be int or float, but got {type(thr)} instead.'
|
||||
def __init__(self, magnitude, prob=0.5, random_negative_prob=0.5):
|
||||
assert isinstance(magnitude, (int, float)), 'The magnitude type must '\
|
||||
f'be int or float, but got {type(magnitude)} instead.'
|
||||
assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \
|
||||
f'got {prob} instead.'
|
||||
assert 0 <= random_negative_prob <= 1.0, 'The random_negative_prob ' \
|
||||
f'should be in range [0,1], got {random_negative_prob} instead.'
|
||||
|
||||
self.thr = thr
|
||||
self.magnitude = magnitude
|
||||
self.prob = prob
|
||||
self.random_negative_prob = random_negative_prob
|
||||
|
||||
def __call__(self, results):
|
||||
if np.random.rand() > self.prob:
|
||||
return results
|
||||
magnitude = random_negative(self.magnitude, self.random_negative_prob)
|
||||
for key in results.get('img_fields', ['img']):
|
||||
img = results[key]
|
||||
img_solarized = mmcv.solarize(img, thr=self.thr)
|
||||
results[key] = img_solarized.astype(img.dtype)
|
||||
img_brightened = mmcv.adjust_brightness(img, factor=1 + magnitude)
|
||||
results[key] = img_brightened.astype(img.dtype)
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(thr={self.thr}, '
|
||||
repr_str += f'prob={self.prob})'
|
||||
repr_str += f'(magnitude={self.magnitude}, '
|
||||
repr_str += f'prob={self.prob}, '
|
||||
repr_str += f'random_negative_prob={self.random_negative_prob})'
|
||||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class Posterize(object):
|
||||
"""Posterize an image (reduce the number of bits for each color channel).
|
||||
class Sharpness(object):
|
||||
"""Adjust images sharpness.
|
||||
|
||||
Args:
|
||||
bits (int): Number of bits for each pixel in the output img, which
|
||||
should be less or equal to 8.
|
||||
prob (float): The probability for posterizing therefore should be in
|
||||
range [0, 1]. Defaults to 0.5.
|
||||
magnitude (int | float): The magnitude used for adjusting sharpness. A
|
||||
positive magnitude would enhance the sharpness and a negative
|
||||
magnitude would make the image bulr. A magnitude=0 gives the
|
||||
origin img.
|
||||
prob (float): The probability for performing contrast adjusting
|
||||
therefore should be in range [0, 1]. Defaults to 0.5.
|
||||
random_negative_prob (float): The probability that turns the magnitude
|
||||
negative, which should be in range [0,1]. Defaults to 0.5.
|
||||
"""
|
||||
|
||||
def __init__(self, bits, prob=0.5):
|
||||
assert isinstance(bits, int), 'The bits type must be int, '\
|
||||
f'but got {type(bits)} instead.'
|
||||
assert bits <= 8, f'The bits must be less than 8, got {bits} instead.'
|
||||
def __init__(self, magnitude, prob=0.5, random_negative_prob=0.5):
|
||||
assert isinstance(magnitude, (int, float)), 'The magnitude type must '\
|
||||
f'be int or float, but got {type(magnitude)} instead.'
|
||||
assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \
|
||||
f'got {prob} instead.'
|
||||
assert 0 <= random_negative_prob <= 1.0, 'The random_negative_prob ' \
|
||||
f'should be in range [0,1], got {random_negative_prob} instead.'
|
||||
|
||||
self.bits = bits
|
||||
self.magnitude = magnitude
|
||||
self.prob = prob
|
||||
self.random_negative_prob = random_negative_prob
|
||||
|
||||
def __call__(self, results):
|
||||
if np.random.rand() > self.prob:
|
||||
return results
|
||||
magnitude = random_negative(self.magnitude, self.random_negative_prob)
|
||||
for key in results.get('img_fields', ['img']):
|
||||
img = results[key]
|
||||
img_posterized = mmcv.posterize(img, bits=self.bits)
|
||||
results[key] = img_posterized.astype(img.dtype)
|
||||
img_sharpened = mmcv.adjust_sharpness(img, factor=1 + magnitude)
|
||||
results[key] = img_sharpened.astype(img.dtype)
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(bits={self.bits}, '
|
||||
repr_str += f'prob={self.prob})'
|
||||
repr_str += f'(magnitude={self.magnitude}, '
|
||||
repr_str += f'prob={self.prob}, '
|
||||
repr_str += f'random_negative_prob={self.random_negative_prob})'
|
||||
return repr_str
|
||||
|
|
|
@ -122,6 +122,14 @@ def test_shear():
|
|||
sheared_img = np.stack([sheared_img, sheared_img, sheared_img], axis=-1)
|
||||
assert (results['img'] == sheared_img).all()
|
||||
|
||||
# test auto aug with shear
|
||||
results = construct_toy_data()
|
||||
policies = [[transform]]
|
||||
autoaug = dict(type='AutoAugment', policies=policies)
|
||||
pipeline = build_from_cfg(autoaug, PIPELINES)
|
||||
results = pipeline(results)
|
||||
assert (results['img'] == sheared_img).all()
|
||||
|
||||
|
||||
def test_translate():
|
||||
# test assertion for invalid type of magnitude
|
||||
|
@ -326,6 +334,34 @@ def test_rotate():
|
|||
assert (results['img'] == results['img2']).all()
|
||||
|
||||
|
||||
def test_auto_contrast():
|
||||
# test assertion for invalid value of prob
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='AutoContrast', prob=100)
|
||||
build_from_cfg(transform, PIPELINES)
|
||||
|
||||
# test case when prob=0, therefore no auto_contrast
|
||||
results = construct_toy_data()
|
||||
transform = dict(type='AutoContrast', prob=0.)
|
||||
pipeline = build_from_cfg(transform, PIPELINES)
|
||||
results = pipeline(results)
|
||||
assert (results['img'] == results['ori_img']).all()
|
||||
|
||||
# test case when prob=1
|
||||
results = construct_toy_data()
|
||||
transform = dict(type='AutoContrast', prob=1.)
|
||||
pipeline = build_from_cfg(transform, PIPELINES)
|
||||
results = pipeline(results)
|
||||
auto_contrasted_img = np.array(
|
||||
[[0, 23, 46, 69], [92, 115, 139, 162], [185, 208, 231, 255]],
|
||||
dtype=np.uint8)
|
||||
auto_contrasted_img = np.stack(
|
||||
[auto_contrasted_img, auto_contrasted_img, auto_contrasted_img],
|
||||
axis=-1)
|
||||
assert (results['img'] == auto_contrasted_img).all()
|
||||
assert (results['img'] == results['img2']).all()
|
||||
|
||||
|
||||
def test_invert():
|
||||
# test assertion for invalid value of prob
|
||||
with pytest.raises(AssertionError):
|
||||
|
@ -353,70 +389,37 @@ def test_invert():
|
|||
assert (results['img'] == results['img2']).all()
|
||||
|
||||
|
||||
def test_color_transform():
|
||||
# test assertion for invalid type of magnitude
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='ColorTransform', magnitude=None)
|
||||
build_from_cfg(transform, PIPELINES)
|
||||
def test_equalize(nb_rand_test=100):
|
||||
|
||||
def _imequalize(img):
|
||||
# equalize the image using PIL.ImageOps.equalize
|
||||
from PIL import ImageOps, Image
|
||||
img = Image.fromarray(img)
|
||||
equalized_img = np.asarray(ImageOps.equalize(img))
|
||||
return equalized_img
|
||||
|
||||
# test assertion for invalid value of prob
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='ColorTransform', magnitude=0.5, prob=100)
|
||||
transform = dict(type='Equalize', prob=100)
|
||||
build_from_cfg(transform, PIPELINES)
|
||||
|
||||
# test assertion for invalid value of random_negative_prob
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(
|
||||
type='ColorTransform', magnitude=0.5, random_negative_prob=100)
|
||||
build_from_cfg(transform, PIPELINES)
|
||||
|
||||
# test case when magnitude=0, therefore no color transform
|
||||
results = construct_toy_data_photometric()
|
||||
transform = dict(type='ColorTransform', magnitude=0., prob=1.)
|
||||
# test case when prob=0, therefore no equalize
|
||||
results = construct_toy_data()
|
||||
transform = dict(type='Equalize', prob=0.)
|
||||
pipeline = build_from_cfg(transform, PIPELINES)
|
||||
results = pipeline(results)
|
||||
assert (results['img'] == results['ori_img']).all()
|
||||
|
||||
# test case when prob=0, therefore no color transform
|
||||
results = construct_toy_data_photometric()
|
||||
transform = dict(type='ColorTransform', magnitude=1., prob=0.)
|
||||
# test case when prob=1 with randomly sampled image.
|
||||
results = construct_toy_data()
|
||||
transform = dict(type='Equalize', prob=1.)
|
||||
pipeline = build_from_cfg(transform, PIPELINES)
|
||||
results = pipeline(results)
|
||||
assert (results['img'] == results['ori_img']).all()
|
||||
|
||||
# test case when magnitude=-1, therefore got gray img
|
||||
results = construct_toy_data_photometric()
|
||||
transform = dict(
|
||||
type='ColorTransform', magnitude=-1., prob=1., random_negative_prob=0)
|
||||
pipeline = build_from_cfg(transform, PIPELINES)
|
||||
results = pipeline(results)
|
||||
img_gray = mmcv.bgr2gray(results['ori_img'])
|
||||
img_gray = np.stack([img_gray, img_gray, img_gray], axis=-1)
|
||||
assert (results['img'] == img_gray).all()
|
||||
|
||||
# test case when magnitude=0.5
|
||||
results = construct_toy_data_photometric()
|
||||
transform = dict(
|
||||
type='ColorTransform', magnitude=.5, prob=1., random_negative_prob=0)
|
||||
pipeline = build_from_cfg(transform, PIPELINES)
|
||||
results = pipeline(results)
|
||||
img_r = np.round(
|
||||
np.clip((results['ori_img'] * 0.5 + img_gray * 0.5), 0,
|
||||
255)).astype(results['ori_img'].dtype)
|
||||
assert (results['img'] == img_r).all()
|
||||
assert (results['img'] == results['img2']).all()
|
||||
|
||||
# test case when magnitude=0.3, random_negative_prob=1
|
||||
results = construct_toy_data_photometric()
|
||||
transform = dict(
|
||||
type='ColorTransform', magnitude=.3, prob=1., random_negative_prob=1.)
|
||||
pipeline = build_from_cfg(transform, PIPELINES)
|
||||
results = pipeline(results)
|
||||
img_r = np.round(
|
||||
np.clip((results['ori_img'] * 0.7 + img_gray * 0.3), 0,
|
||||
255)).astype(results['ori_img'].dtype)
|
||||
assert (results['img'] == img_r).all()
|
||||
assert (results['img'] == results['img2']).all()
|
||||
for _ in range(nb_rand_test):
|
||||
img = np.clip(np.random.normal(0, 1, (1000, 1200, 3)) * 260, 0,
|
||||
255).astype(np.uint8)
|
||||
results['img'] = img
|
||||
results = pipeline(copy.deepcopy(results))
|
||||
assert (results['img'] == _imequalize(img)).all()
|
||||
|
||||
|
||||
def test_solarize():
|
||||
|
@ -512,3 +515,259 @@ def test_posterize():
|
|||
axis=-1)
|
||||
assert (results['img'] == img_posterized).all()
|
||||
assert (results['img'] == results['img2']).all()
|
||||
|
||||
|
||||
def test_contrast(nb_rand_test=100):
|
||||
|
||||
def _adjust_contrast(img, factor):
|
||||
from PIL.ImageEnhance import Contrast
|
||||
from PIL import Image
|
||||
# Image.fromarray defaultly supports RGB, not BGR.
|
||||
# convert from BGR to RGB
|
||||
img = Image.fromarray(img[..., ::-1], mode='RGB')
|
||||
contrasted_img = Contrast(img).enhance(factor)
|
||||
# convert from RGB to BGR
|
||||
return np.asarray(contrasted_img)[..., ::-1]
|
||||
|
||||
# test assertion for invalid type of magnitude
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='Contrast', magnitude=None)
|
||||
build_from_cfg(transform, PIPELINES)
|
||||
|
||||
# test assertion for invalid value of prob
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='Contrast', magnitude=0.5, prob=100)
|
||||
build_from_cfg(transform, PIPELINES)
|
||||
|
||||
# test assertion for invalid value of random_negative_prob
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(
|
||||
type='Contrast', magnitude=0.5, random_negative_prob=100)
|
||||
build_from_cfg(transform, PIPELINES)
|
||||
|
||||
# test case when magnitude=0, therefore no adjusting contrast
|
||||
results = construct_toy_data_photometric()
|
||||
transform = dict(type='Contrast', magnitude=0., prob=1.)
|
||||
pipeline = build_from_cfg(transform, PIPELINES)
|
||||
results = pipeline(results)
|
||||
assert (results['img'] == results['ori_img']).all()
|
||||
|
||||
# test case when prob=0, therefore no adjusting contrast
|
||||
results = construct_toy_data_photometric()
|
||||
transform = dict(type='Contrast', magnitude=1., prob=0.)
|
||||
pipeline = build_from_cfg(transform, PIPELINES)
|
||||
results = pipeline(results)
|
||||
assert (results['img'] == results['ori_img']).all()
|
||||
|
||||
# test case when prob=1 with randomly sampled image.
|
||||
results = construct_toy_data()
|
||||
for _ in range(nb_rand_test):
|
||||
magnitude = np.random.uniform() * np.random.choice([-1, 1])
|
||||
transform = dict(
|
||||
type='Contrast',
|
||||
magnitude=magnitude,
|
||||
prob=1.,
|
||||
random_negative_prob=0.)
|
||||
pipeline = build_from_cfg(transform, PIPELINES)
|
||||
img = np.clip(np.random.uniform(0, 1, (1200, 1000, 3)) * 260, 0,
|
||||
255).astype(np.uint8)
|
||||
results['img'] = img
|
||||
results = pipeline(copy.deepcopy(results))
|
||||
# Note the gap (less_equal 1) between PIL.ImageEnhance.Contrast
|
||||
# and mmcv.adjust_contrast comes from the gap that converts from
|
||||
# a color image to gray image using mmcv or PIL.
|
||||
np.testing.assert_allclose(
|
||||
results['img'],
|
||||
_adjust_contrast(img, 1 + magnitude),
|
||||
rtol=0,
|
||||
atol=1)
|
||||
|
||||
|
||||
def test_color_transform():
|
||||
# test assertion for invalid type of magnitude
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='ColorTransform', magnitude=None)
|
||||
build_from_cfg(transform, PIPELINES)
|
||||
|
||||
# test assertion for invalid value of prob
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='ColorTransform', magnitude=0.5, prob=100)
|
||||
build_from_cfg(transform, PIPELINES)
|
||||
|
||||
# test assertion for invalid value of random_negative_prob
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(
|
||||
type='ColorTransform', magnitude=0.5, random_negative_prob=100)
|
||||
build_from_cfg(transform, PIPELINES)
|
||||
|
||||
# test case when magnitude=0, therefore no color transform
|
||||
results = construct_toy_data_photometric()
|
||||
transform = dict(type='ColorTransform', magnitude=0., prob=1.)
|
||||
pipeline = build_from_cfg(transform, PIPELINES)
|
||||
results = pipeline(results)
|
||||
assert (results['img'] == results['ori_img']).all()
|
||||
|
||||
# test case when prob=0, therefore no color transform
|
||||
results = construct_toy_data_photometric()
|
||||
transform = dict(type='ColorTransform', magnitude=1., prob=0.)
|
||||
pipeline = build_from_cfg(transform, PIPELINES)
|
||||
results = pipeline(results)
|
||||
assert (results['img'] == results['ori_img']).all()
|
||||
|
||||
# test case when magnitude=-1, therefore got gray img
|
||||
results = construct_toy_data_photometric()
|
||||
transform = dict(
|
||||
type='ColorTransform', magnitude=-1., prob=1., random_negative_prob=0)
|
||||
pipeline = build_from_cfg(transform, PIPELINES)
|
||||
results = pipeline(results)
|
||||
img_gray = mmcv.bgr2gray(results['ori_img'])
|
||||
img_gray = np.stack([img_gray, img_gray, img_gray], axis=-1)
|
||||
assert (results['img'] == img_gray).all()
|
||||
|
||||
# test case when magnitude=0.5
|
||||
results = construct_toy_data_photometric()
|
||||
transform = dict(
|
||||
type='ColorTransform', magnitude=.5, prob=1., random_negative_prob=0)
|
||||
pipeline = build_from_cfg(transform, PIPELINES)
|
||||
results = pipeline(results)
|
||||
img_r = np.round(
|
||||
np.clip((results['ori_img'] * 0.5 + img_gray * 0.5), 0,
|
||||
255)).astype(results['ori_img'].dtype)
|
||||
assert (results['img'] == img_r).all()
|
||||
assert (results['img'] == results['img2']).all()
|
||||
|
||||
# test case when magnitude=0.3, random_negative_prob=1
|
||||
results = construct_toy_data_photometric()
|
||||
transform = dict(
|
||||
type='ColorTransform', magnitude=.3, prob=1., random_negative_prob=1.)
|
||||
pipeline = build_from_cfg(transform, PIPELINES)
|
||||
results = pipeline(results)
|
||||
img_r = np.round(
|
||||
np.clip((results['ori_img'] * 0.7 + img_gray * 0.3), 0,
|
||||
255)).astype(results['ori_img'].dtype)
|
||||
assert (results['img'] == img_r).all()
|
||||
assert (results['img'] == results['img2']).all()
|
||||
|
||||
|
||||
def test_brightness(nb_rand_test=100):
|
||||
|
||||
def _adjust_brightness(img, factor):
|
||||
# adjust the brightness of image using
|
||||
# PIL.ImageEnhance.Brightness
|
||||
from PIL.ImageEnhance import Brightness
|
||||
from PIL import Image
|
||||
img = Image.fromarray(img)
|
||||
brightened_img = Brightness(img).enhance(factor)
|
||||
return np.asarray(brightened_img)
|
||||
|
||||
# test assertion for invalid type of magnitude
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='Brightness', magnitude=None)
|
||||
build_from_cfg(transform, PIPELINES)
|
||||
|
||||
# test assertion for invalid value of prob
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='Brightness', magnitude=0.5, prob=100)
|
||||
build_from_cfg(transform, PIPELINES)
|
||||
|
||||
# test assertion for invalid value of random_negative_prob
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(
|
||||
type='Brightness', magnitude=0.5, random_negative_prob=100)
|
||||
build_from_cfg(transform, PIPELINES)
|
||||
|
||||
# test case when magnitude=0, therefore no adjusting brightness
|
||||
results = construct_toy_data_photometric()
|
||||
transform = dict(type='Brightness', magnitude=0., prob=1.)
|
||||
pipeline = build_from_cfg(transform, PIPELINES)
|
||||
results = pipeline(results)
|
||||
assert (results['img'] == results['ori_img']).all()
|
||||
|
||||
# test case when prob=0, therefore no adjusting brightness
|
||||
results = construct_toy_data_photometric()
|
||||
transform = dict(type='Brightness', magnitude=1., prob=0.)
|
||||
pipeline = build_from_cfg(transform, PIPELINES)
|
||||
results = pipeline(results)
|
||||
assert (results['img'] == results['ori_img']).all()
|
||||
|
||||
# test case when prob=1 with randomly sampled image.
|
||||
results = construct_toy_data()
|
||||
for _ in range(nb_rand_test):
|
||||
magnitude = np.random.uniform() * np.random.choice([-1, 1])
|
||||
transform = dict(
|
||||
type='Brightness',
|
||||
magnitude=magnitude,
|
||||
prob=1.,
|
||||
random_negative_prob=0.)
|
||||
pipeline = build_from_cfg(transform, PIPELINES)
|
||||
img = np.clip(np.random.uniform(0, 1, (1200, 1000, 3)) * 260, 0,
|
||||
255).astype(np.uint8)
|
||||
results['img'] = img
|
||||
results = pipeline(copy.deepcopy(results))
|
||||
np.testing.assert_allclose(
|
||||
results['img'],
|
||||
_adjust_brightness(img, 1 + magnitude),
|
||||
rtol=0,
|
||||
atol=1)
|
||||
|
||||
|
||||
def test_sharpness(nb_rand_test=100):
|
||||
|
||||
def _adjust_sharpness(img, factor):
|
||||
# adjust the sharpness of image using
|
||||
# PIL.ImageEnhance.Sharpness
|
||||
from PIL.ImageEnhance import Sharpness
|
||||
from PIL import Image
|
||||
img = Image.fromarray(img)
|
||||
sharpened_img = Sharpness(img).enhance(factor)
|
||||
return np.asarray(sharpened_img)
|
||||
|
||||
# test assertion for invalid type of magnitude
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='Sharpness', magnitude=None)
|
||||
build_from_cfg(transform, PIPELINES)
|
||||
|
||||
# test assertion for invalid value of prob
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='Sharpness', magnitude=0.5, prob=100)
|
||||
build_from_cfg(transform, PIPELINES)
|
||||
|
||||
# test assertion for invalid value of random_negative_prob
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(
|
||||
type='Sharpness', magnitude=0.5, random_negative_prob=100)
|
||||
build_from_cfg(transform, PIPELINES)
|
||||
|
||||
# test case when magnitude=0, therefore no adjusting sharpness
|
||||
results = construct_toy_data_photometric()
|
||||
transform = dict(type='Sharpness', magnitude=0., prob=1.)
|
||||
pipeline = build_from_cfg(transform, PIPELINES)
|
||||
results = pipeline(results)
|
||||
assert (results['img'] == results['ori_img']).all()
|
||||
|
||||
# test case when prob=0, therefore no adjusting sharpness
|
||||
results = construct_toy_data_photometric()
|
||||
transform = dict(type='Sharpness', magnitude=1., prob=0.)
|
||||
pipeline = build_from_cfg(transform, PIPELINES)
|
||||
results = pipeline(results)
|
||||
assert (results['img'] == results['ori_img']).all()
|
||||
|
||||
# test case when prob=1 with randomly sampled image.
|
||||
results = construct_toy_data()
|
||||
for _ in range(nb_rand_test):
|
||||
magnitude = np.random.uniform() * np.random.choice([-1, 1])
|
||||
transform = dict(
|
||||
type='Sharpness',
|
||||
magnitude=magnitude,
|
||||
prob=1.,
|
||||
random_negative_prob=0.)
|
||||
pipeline = build_from_cfg(transform, PIPELINES)
|
||||
img = np.clip(np.random.uniform(0, 1, (1200, 1000, 3)) * 260, 0,
|
||||
255).astype(np.uint8)
|
||||
results['img'] = img
|
||||
results = pipeline(copy.deepcopy(results))
|
||||
np.testing.assert_allclose(
|
||||
results['img'][1:-1, 1:-1],
|
||||
_adjust_sharpness(img, 1 + magnitude)[1:-1, 1:-1],
|
||||
rtol=0,
|
||||
atol=1)
|
||||
|
|
Loading…
Reference in New Issue