mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
add invert pipeline (#168)
This commit is contained in:
parent
b1fa298a66
commit
c8033ece8e
@ -1,4 +1,4 @@
|
||||
from .auto_augment import Rotate, Shear, Translate
|
||||
from .auto_augment import Invert, Rotate, Shear, Translate
|
||||
from .compose import Compose
|
||||
from .formating import (Collect, ImageToTensor, ToNumpy, ToPIL, ToTensor,
|
||||
Transpose, to_tensor)
|
||||
@ -10,5 +10,5 @@ __all__ = [
|
||||
'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToPIL', 'ToNumpy',
|
||||
'Transpose', 'Collect', 'LoadImageFromFile', 'Resize', 'CenterCrop',
|
||||
'RandomFlip', 'Normalize', 'RandomCrop', 'RandomResizedCrop',
|
||||
'RandomGrayscale', 'Shear', 'Translate', 'Rotate'
|
||||
'RandomGrayscale', 'Shear', 'Translate', 'Rotate', 'Invert'
|
||||
]
|
||||
|
@ -89,6 +89,7 @@ class Shear(object):
|
||||
@PIPELINES.register_module()
|
||||
class Translate(object):
|
||||
"""Translate images.
|
||||
|
||||
Args:
|
||||
magnitude (int | float): The magnitude used for translate. Note that
|
||||
the offset is calculated by magnitude * size in the corresponding
|
||||
@ -258,3 +259,32 @@ class Rotate(object):
|
||||
repr_str += f'random_negative_prob={self.random_negative_prob}, '
|
||||
repr_str += f'interpolation={self.interpolation})'
|
||||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class Invert(object):
|
||||
"""Invert images.
|
||||
Args:
|
||||
prob (float): The probability for performing invert therefore should
|
||||
be in range [0, 1]. Defaults to 0.5.
|
||||
"""
|
||||
|
||||
def __init__(self, prob=0.5):
|
||||
assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \
|
||||
f'got {prob} instead.'
|
||||
|
||||
self.prob = prob
|
||||
|
||||
def __call__(self, results):
|
||||
if np.random.rand() > self.prob:
|
||||
return results
|
||||
for key in results.get('img_fields', ['img']):
|
||||
img = results[key]
|
||||
img_inverted = mmcv.iminvert(img)
|
||||
results[key] = img_inverted.astype(img.dtype)
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(prob={self.prob})'
|
||||
return repr_str
|
||||
|
@ -308,3 +308,30 @@ def test_rotate():
|
||||
rotated_img = np.stack([rotated_img, rotated_img, rotated_img], axis=-1)
|
||||
assert (results['img'] == rotated_img).all()
|
||||
assert (results['img'] == results['img2']).all()
|
||||
|
||||
|
||||
def test_invert():
|
||||
# test assertion for invalid value of prob
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='Invert', prob=100)
|
||||
build_from_cfg(transform, PIPELINES)
|
||||
|
||||
# test case when prob=0, therefore no invert
|
||||
results = construct_toy_data()
|
||||
transform = dict(type='Invert', prob=0.)
|
||||
pipeline = build_from_cfg(transform, PIPELINES)
|
||||
results = pipeline(results)
|
||||
assert (results['img'] == results['ori_img']).all()
|
||||
|
||||
# test case when prob=1
|
||||
results = construct_toy_data()
|
||||
transform = dict(type='Invert', prob=1.)
|
||||
pipeline = build_from_cfg(transform, PIPELINES)
|
||||
results = pipeline(results)
|
||||
inverted_img = np.array(
|
||||
[[254, 253, 252, 251], [250, 249, 248, 247], [246, 245, 244, 243]],
|
||||
dtype=np.uint8)
|
||||
inverted_img = np.stack([inverted_img, inverted_img, inverted_img],
|
||||
axis=-1)
|
||||
assert (results['img'] == inverted_img).all()
|
||||
assert (results['img'] == results['img2']).all()
|
||||
|
Loading…
x
Reference in New Issue
Block a user