[Feature] Add color pipeline (#171)
* add ColorTransform pipeline * fix docstring * minor change * revised according to comments
pull/172/head
parent
c8033ece8e
commit
9614787fc4
|
@ -1,4 +1,4 @@
|
||||||
from .auto_augment import Invert, Rotate, Shear, Translate
|
from .auto_augment import ColorTransform, Invert, Rotate, Shear, Translate
|
||||||
from .compose import Compose
|
from .compose import Compose
|
||||||
from .formating import (Collect, ImageToTensor, ToNumpy, ToPIL, ToTensor,
|
from .formating import (Collect, ImageToTensor, ToNumpy, ToPIL, ToTensor,
|
||||||
Transpose, to_tensor)
|
Transpose, to_tensor)
|
||||||
|
@ -10,5 +10,6 @@ __all__ = [
|
||||||
'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToPIL', 'ToNumpy',
|
'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToPIL', 'ToNumpy',
|
||||||
'Transpose', 'Collect', 'LoadImageFromFile', 'Resize', 'CenterCrop',
|
'Transpose', 'Collect', 'LoadImageFromFile', 'Resize', 'CenterCrop',
|
||||||
'RandomFlip', 'Normalize', 'RandomCrop', 'RandomResizedCrop',
|
'RandomFlip', 'Normalize', 'RandomCrop', 'RandomResizedCrop',
|
||||||
'RandomGrayscale', 'Shear', 'Translate', 'Rotate', 'Invert'
|
'RandomGrayscale', 'Shear', 'Translate', 'Rotate', 'Invert',
|
||||||
|
'ColorTransform'
|
||||||
]
|
]
|
||||||
|
|
|
@ -264,6 +264,7 @@ class Rotate(object):
|
||||||
@PIPELINES.register_module()
|
@PIPELINES.register_module()
|
||||||
class Invert(object):
|
class Invert(object):
|
||||||
"""Invert images.
|
"""Invert images.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prob (float): The probability for performing invert therefore should
|
prob (float): The probability for performing invert therefore should
|
||||||
be in range [0, 1]. Defaults to 0.5.
|
be in range [0, 1]. Defaults to 0.5.
|
||||||
|
@ -288,3 +289,47 @@ class Invert(object):
|
||||||
repr_str = self.__class__.__name__
|
repr_str = self.__class__.__name__
|
||||||
repr_str += f'(prob={self.prob})'
|
repr_str += f'(prob={self.prob})'
|
||||||
return repr_str
|
return repr_str
|
||||||
|
|
||||||
|
|
||||||
|
@PIPELINES.register_module()
class ColorTransform(object):
    """Randomly strengthen or weaken the color of images.

    Every image listed in ``results['img_fields']`` (or just ``'img'`` when
    that key is absent) is passed through ``mmcv.adjust_color`` with
    ``alpha = 1 + magnitude``.

    Args:
        magnitude (int | float): Strength of the color adjustment. Positive
            values enhance the color, negative values make the image grayer
            and 0 keeps the original image.
        prob (float): Probability of applying the transform; must lie in
            [0, 1]. Defaults to 0.5.
        random_negative_prob (float): Probability of turning the magnitude
            negative; must lie in [0, 1]. Defaults to 0.5.
    """

    def __init__(self, magnitude, prob=0.5, random_negative_prob=0.5):
        # Validation uses asserts on purpose: callers (and the unit tests)
        # rely on AssertionError being raised for bad arguments.
        assert isinstance(magnitude, (int, float)), (
            'The magnitude type must be int or float, '
            f'but got {type(magnitude)} instead.')
        assert 0 <= prob <= 1.0, (
            f'The prob should be in range [0,1], got {prob} instead.')
        assert 0 <= random_negative_prob <= 1.0, (
            'The random_negative_prob should be in range [0,1], '
            f'got {random_negative_prob} instead.')

        self.magnitude = magnitude
        self.prob = prob
        self.random_negative_prob = random_negative_prob

    def __call__(self, results):
        """Apply the color adjustment to ``results`` with probability
        ``self.prob`` and return the (possibly modified) dict."""
        # One random draw decides whether the whole call is a no-op.
        if np.random.rand() > self.prob:
            return results

        # The sign of the magnitude is resolved once, so every image field
        # receives the same adjustment.
        alpha = 1 + random_negative(self.magnitude, self.random_negative_prob)
        for field in results.get('img_fields', ['img']):
            source = results[field]
            adjusted = mmcv.adjust_color(source, alpha=alpha)
            # Cast back so the stored image keeps its original dtype.
            results[field] = adjusted.astype(source.dtype)
        return results

    def __repr__(self):
        return (f'{self.__class__.__name__}('
                f'magnitude={self.magnitude}, '
                f'prob={self.prob}, '
                f'random_negative_prob={self.random_negative_prob})')
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import copy
|
import copy
|
||||||
|
|
||||||
|
import mmcv
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
from mmcv.utils import build_from_cfg
|
from mmcv.utils import build_from_cfg
|
||||||
|
@ -22,6 +23,21 @@ def construct_toy_data():
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def construct_toy_data_photometric():
    """Build a tiny ``results`` dict for testing photometric transforms.

    The image is a 3x3 uint8 pattern replicated over three channels, so every
    pixel is gray. ``'img'`` and ``'ori_img'`` refer to the same array, while
    ``'img2'`` is an independent copy.

    Returns:
        dict: Results dict holding the images, their shapes and the
            ``'img_fields'`` list (``['img', 'img2']``).
    """
    channel = np.array([[0, 128, 255], [1, 127, 254], [2, 129, 253]],
                       dtype=np.uint8)
    img = np.stack([channel] * 3, axis=-1)
    return {
        # image
        'ori_img': img,
        'img': img,
        'img2': copy.deepcopy(img),
        'img_shape': img.shape,
        'ori_shape': img.shape,
        'img_fields': ['img', 'img2'],
    }
|
||||||
|
|
||||||
|
|
||||||
def test_shear():
|
def test_shear():
|
||||||
# test assertion for invalid type of magnitude
|
# test assertion for invalid type of magnitude
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
|
@ -335,3 +351,69 @@ def test_invert():
|
||||||
axis=-1)
|
axis=-1)
|
||||||
assert (results['img'] == inverted_img).all()
|
assert (results['img'] == inverted_img).all()
|
||||||
assert (results['img'] == results['img2']).all()
|
assert (results['img'] == results['img2']).all()
|
||||||
|
|
||||||
|
|
||||||
|
def test_color_transform():
    """Check argument validation and the output of ``ColorTransform``.

    Expected images are built with the weights ``mmcv.adjust_color`` uses:
    ``alpha = 1 + magnitude`` for the source image and ``1 - alpha`` for its
    gray version.
    """
    # test assertion for invalid type of magnitude
    with pytest.raises(AssertionError):
        transform = dict(type='ColorTransform', magnitude=None)
        build_from_cfg(transform, PIPELINES)

    # test assertion for invalid value of prob
    with pytest.raises(AssertionError):
        transform = dict(type='ColorTransform', magnitude=0.5, prob=100)
        build_from_cfg(transform, PIPELINES)

    # test assertion for invalid value of random_negative_prob
    with pytest.raises(AssertionError):
        transform = dict(
            type='ColorTransform', magnitude=0.5, random_negative_prob=100)
        build_from_cfg(transform, PIPELINES)

    # test case when magnitude=0, therefore no color transform
    results = construct_toy_data_photometric()
    transform = dict(type='ColorTransform', magnitude=0., prob=1.)
    pipeline = build_from_cfg(transform, PIPELINES)
    results = pipeline(results)
    assert (results['img'] == results['ori_img']).all()

    # test case when prob=0, therefore no color transform
    results = construct_toy_data_photometric()
    transform = dict(type='ColorTransform', magnitude=1., prob=0.)
    pipeline = build_from_cfg(transform, PIPELINES)
    results = pipeline(results)
    assert (results['img'] == results['ori_img']).all()

    # test case when magnitude=-1, therefore got gray img
    results = construct_toy_data_photometric()
    transform = dict(
        type='ColorTransform', magnitude=-1., prob=1., random_negative_prob=0)
    pipeline = build_from_cfg(transform, PIPELINES)
    results = pipeline(results)
    img_gray = mmcv.bgr2gray(results['ori_img'])
    img_gray = np.stack([img_gray, img_gray, img_gray], axis=-1)
    assert (results['img'] == img_gray).all()

    # test case when magnitude=0.5
    results = construct_toy_data_photometric()
    transform = dict(
        type='ColorTransform', magnitude=.5, prob=1., random_negative_prob=0)
    pipeline = build_from_cfg(transform, PIPELINES)
    results = pipeline(results)
    # alpha = 1 + 0.5 = 1.5, so the gray image is weighted by 1 - 1.5 = -0.5.
    # (Weights of 0.5 / 0.5 would only match because the toy image is gray.)
    img_r = np.round(
        np.clip((results['ori_img'] * 1.5 + img_gray * -0.5), 0,
                255)).astype(results['ori_img'].dtype)
    assert (results['img'] == img_r).all()
    assert (results['img'] == results['img2']).all()

    # test case when magnitude=0.3, random_negative_prob=1
    results = construct_toy_data_photometric()
    transform = dict(
        type='ColorTransform', magnitude=.3, prob=1., random_negative_prob=1.)
    pipeline = build_from_cfg(transform, PIPELINES)
    results = pipeline(results)
    # magnitude is flipped to -0.3: alpha = 0.7 and the gray weight is 0.3.
    img_r = np.round(
        np.clip((results['ori_img'] * 0.7 + img_gray * 0.3), 0,
                255)).astype(results['ori_img'].dtype)
    assert (results['img'] == img_r).all()
    assert (results['img'] == results['img2']).all()
|
||||||
|
|
Loading…
Reference in New Issue