[Feature] Add color pipeline (#171)

* add ColorTransform pipeline

* fix docstring

* minor change

* revised according to comments
pull/172/head
LXXXXR 2021-03-09 19:28:50 +08:00 committed by GitHub
parent c8033ece8e
commit 9614787fc4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 130 additions and 2 deletions

View File

@ -1,4 +1,4 @@
from .auto_augment import Invert, Rotate, Shear, Translate from .auto_augment import ColorTransform, Invert, Rotate, Shear, Translate
from .compose import Compose from .compose import Compose
from .formating import (Collect, ImageToTensor, ToNumpy, ToPIL, ToTensor, from .formating import (Collect, ImageToTensor, ToNumpy, ToPIL, ToTensor,
Transpose, to_tensor) Transpose, to_tensor)
@ -10,5 +10,6 @@ __all__ = [
'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToPIL', 'ToNumpy', 'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToPIL', 'ToNumpy',
'Transpose', 'Collect', 'LoadImageFromFile', 'Resize', 'CenterCrop', 'Transpose', 'Collect', 'LoadImageFromFile', 'Resize', 'CenterCrop',
'RandomFlip', 'Normalize', 'RandomCrop', 'RandomResizedCrop', 'RandomFlip', 'Normalize', 'RandomCrop', 'RandomResizedCrop',
'RandomGrayscale', 'Shear', 'Translate', 'Rotate', 'Invert' 'RandomGrayscale', 'Shear', 'Translate', 'Rotate', 'Invert',
'ColorTransform'
] ]

View File

@ -264,6 +264,7 @@ class Rotate(object):
@PIPELINES.register_module() @PIPELINES.register_module()
class Invert(object): class Invert(object):
"""Invert images. """Invert images.
Args: Args:
prob (float): The probability for performing invert therefore should prob (float): The probability for performing invert therefore should
be in range [0, 1]. Defaults to 0.5. be in range [0, 1]. Defaults to 0.5.
@ -288,3 +289,47 @@ class Invert(object):
repr_str = self.__class__.__name__ repr_str = self.__class__.__name__
repr_str += f'(prob={self.prob})' repr_str += f'(prob={self.prob})'
return repr_str return repr_str
@PIPELINES.register_module()
class ColorTransform(object):
"""Adjust the color balance of images.
Args:
magnitude (int | float): The magnitude used for color transform. A
positive magnitude would enhance the color and a negative magnitude
would make the image grayer. A magnitude=0 gives the origin img.
prob (float): The probability for performing ColorTransform therefore
should be in range [0, 1]. Defaults to 0.5.
random_negative_prob (float): The probability that turns the magnitude
negative, which should be in range [0,1]. Defaults to 0.5.
"""
def __init__(self, magnitude, prob=0.5, random_negative_prob=0.5):
assert isinstance(magnitude, (int, float)), 'The magnitude type must '\
f'be int or float, but got {type(magnitude)} instead.'
assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \
f'got {prob} instead.'
assert 0 <= random_negative_prob <= 1.0, 'The random_negative_prob ' \
f'should be in range [0,1], got {random_negative_prob} instead.'
self.magnitude = magnitude
self.prob = prob
self.random_negative_prob = random_negative_prob
def __call__(self, results):
if np.random.rand() > self.prob:
return results
magnitude = random_negative(self.magnitude, self.random_negative_prob)
for key in results.get('img_fields', ['img']):
img = results[key]
img_color_adjusted = mmcv.adjust_color(img, alpha=1 + magnitude)
results[key] = img_color_adjusted.astype(img.dtype)
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(magnitude={self.magnitude}, '
repr_str += f'prob={self.prob}, '
repr_str += f'random_negative_prob={self.random_negative_prob})'
return repr_str

View File

@ -1,5 +1,6 @@
import copy import copy
import mmcv
import numpy as np import numpy as np
import pytest import pytest
from mmcv.utils import build_from_cfg from mmcv.utils import build_from_cfg
@ -22,6 +23,21 @@ def construct_toy_data():
return results return results
def construct_toy_data_photometric():
img = np.array([[0, 128, 255], [1, 127, 254], [2, 129, 253]],
dtype=np.uint8)
img = np.stack([img, img, img], axis=-1)
results = dict()
# image
results['ori_img'] = img
results['img'] = img
results['img2'] = copy.deepcopy(img)
results['img_shape'] = img.shape
results['ori_shape'] = img.shape
results['img_fields'] = ['img', 'img2']
return results
def test_shear(): def test_shear():
# test assertion for invalid type of magnitude # test assertion for invalid type of magnitude
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
@ -335,3 +351,69 @@ def test_invert():
axis=-1) axis=-1)
assert (results['img'] == inverted_img).all() assert (results['img'] == inverted_img).all()
assert (results['img'] == results['img2']).all() assert (results['img'] == results['img2']).all()
def test_color_transform():
# test assertion for invalid type of magnitude
with pytest.raises(AssertionError):
transform = dict(type='ColorTransform', magnitude=None)
build_from_cfg(transform, PIPELINES)
# test assertion for invalid value of prob
with pytest.raises(AssertionError):
transform = dict(type='ColorTransform', magnitude=0.5, prob=100)
build_from_cfg(transform, PIPELINES)
# test assertion for invalid value of random_negative_prob
with pytest.raises(AssertionError):
transform = dict(
type='ColorTransform', magnitude=0.5, random_negative_prob=100)
build_from_cfg(transform, PIPELINES)
# test case when magnitude=0, therefore no color transform
results = construct_toy_data_photometric()
transform = dict(type='ColorTransform', magnitude=0., prob=1.)
pipeline = build_from_cfg(transform, PIPELINES)
results = pipeline(results)
assert (results['img'] == results['ori_img']).all()
# test case when prob=0, therefore no color transform
results = construct_toy_data_photometric()
transform = dict(type='ColorTransform', magnitude=1., prob=0.)
pipeline = build_from_cfg(transform, PIPELINES)
results = pipeline(results)
assert (results['img'] == results['ori_img']).all()
# test case when magnitude=-1, therefore got gray img
results = construct_toy_data_photometric()
transform = dict(
type='ColorTransform', magnitude=-1., prob=1., random_negative_prob=0)
pipeline = build_from_cfg(transform, PIPELINES)
results = pipeline(results)
img_gray = mmcv.bgr2gray(results['ori_img'])
img_gray = np.stack([img_gray, img_gray, img_gray], axis=-1)
assert (results['img'] == img_gray).all()
# test case when magnitude=0.5
results = construct_toy_data_photometric()
transform = dict(
type='ColorTransform', magnitude=.5, prob=1., random_negative_prob=0)
pipeline = build_from_cfg(transform, PIPELINES)
results = pipeline(results)
img_r = np.round(
np.clip((results['ori_img'] * 0.5 + img_gray * 0.5), 0,
255)).astype(results['ori_img'].dtype)
assert (results['img'] == img_r).all()
assert (results['img'] == results['img2']).all()
# test case when magnitude=0.3, random_negative_prob=1
results = construct_toy_data_photometric()
transform = dict(
type='ColorTransform', magnitude=.3, prob=1., random_negative_prob=1.)
pipeline = build_from_cfg(transform, PIPELINES)
results = pipeline(results)
img_r = np.round(
np.clip((results['ori_img'] * 0.7 + img_gray * 0.3), 0,
255)).astype(results['ori_img'].dtype)
assert (results['img'] == img_r).all()
assert (results['img'] == results['img2']).all()