mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
add invert pipeline (#168)
This commit is contained in:
parent
b1fa298a66
commit
c8033ece8e
@ -1,4 +1,4 @@
|
|||||||
from .auto_augment import Rotate, Shear, Translate
|
from .auto_augment import Invert, Rotate, Shear, Translate
|
||||||
from .compose import Compose
|
from .compose import Compose
|
||||||
from .formating import (Collect, ImageToTensor, ToNumpy, ToPIL, ToTensor,
|
from .formating import (Collect, ImageToTensor, ToNumpy, ToPIL, ToTensor,
|
||||||
Transpose, to_tensor)
|
Transpose, to_tensor)
|
||||||
@ -10,5 +10,5 @@ __all__ = [
|
|||||||
'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToPIL', 'ToNumpy',
|
'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToPIL', 'ToNumpy',
|
||||||
'Transpose', 'Collect', 'LoadImageFromFile', 'Resize', 'CenterCrop',
|
'Transpose', 'Collect', 'LoadImageFromFile', 'Resize', 'CenterCrop',
|
||||||
'RandomFlip', 'Normalize', 'RandomCrop', 'RandomResizedCrop',
|
'RandomFlip', 'Normalize', 'RandomCrop', 'RandomResizedCrop',
|
||||||
'RandomGrayscale', 'Shear', 'Translate', 'Rotate'
|
'RandomGrayscale', 'Shear', 'Translate', 'Rotate', 'Invert'
|
||||||
]
|
]
|
||||||
|
@ -89,6 +89,7 @@ class Shear(object):
|
|||||||
@PIPELINES.register_module()
|
@PIPELINES.register_module()
|
||||||
class Translate(object):
|
class Translate(object):
|
||||||
"""Translate images.
|
"""Translate images.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
magnitude (int | float): The magnitude used for translate. Note that
|
magnitude (int | float): The magnitude used for translate. Note that
|
||||||
the offset is calculated by magnitude * size in the corresponding
|
the offset is calculated by magnitude * size in the corresponding
|
||||||
@ -258,3 +259,32 @@ class Rotate(object):
|
|||||||
repr_str += f'random_negative_prob={self.random_negative_prob}, '
|
repr_str += f'random_negative_prob={self.random_negative_prob}, '
|
||||||
repr_str += f'interpolation={self.interpolation})'
|
repr_str += f'interpolation={self.interpolation})'
|
||||||
return repr_str
|
return repr_str
|
||||||
|
|
||||||
|
|
||||||
|
@PIPELINES.register_module()
|
||||||
|
class Invert(object):
|
||||||
|
"""Invert images.
|
||||||
|
Args:
|
||||||
|
prob (float): The probability for performing invert therefore should
|
||||||
|
be in range [0, 1]. Defaults to 0.5.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, prob=0.5):
|
||||||
|
assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \
|
||||||
|
f'got {prob} instead.'
|
||||||
|
|
||||||
|
self.prob = prob
|
||||||
|
|
||||||
|
def __call__(self, results):
|
||||||
|
if np.random.rand() > self.prob:
|
||||||
|
return results
|
||||||
|
for key in results.get('img_fields', ['img']):
|
||||||
|
img = results[key]
|
||||||
|
img_inverted = mmcv.iminvert(img)
|
||||||
|
results[key] = img_inverted.astype(img.dtype)
|
||||||
|
return results
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
repr_str = self.__class__.__name__
|
||||||
|
repr_str += f'(prob={self.prob})'
|
||||||
|
return repr_str
|
||||||
|
@ -308,3 +308,30 @@ def test_rotate():
|
|||||||
rotated_img = np.stack([rotated_img, rotated_img, rotated_img], axis=-1)
|
rotated_img = np.stack([rotated_img, rotated_img, rotated_img], axis=-1)
|
||||||
assert (results['img'] == rotated_img).all()
|
assert (results['img'] == rotated_img).all()
|
||||||
assert (results['img'] == results['img2']).all()
|
assert (results['img'] == results['img2']).all()
|
||||||
|
|
||||||
|
|
||||||
|
def test_invert():
|
||||||
|
# test assertion for invalid value of prob
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
transform = dict(type='Invert', prob=100)
|
||||||
|
build_from_cfg(transform, PIPELINES)
|
||||||
|
|
||||||
|
# test case when prob=0, therefore no invert
|
||||||
|
results = construct_toy_data()
|
||||||
|
transform = dict(type='Invert', prob=0.)
|
||||||
|
pipeline = build_from_cfg(transform, PIPELINES)
|
||||||
|
results = pipeline(results)
|
||||||
|
assert (results['img'] == results['ori_img']).all()
|
||||||
|
|
||||||
|
# test case when prob=1
|
||||||
|
results = construct_toy_data()
|
||||||
|
transform = dict(type='Invert', prob=1.)
|
||||||
|
pipeline = build_from_cfg(transform, PIPELINES)
|
||||||
|
results = pipeline(results)
|
||||||
|
inverted_img = np.array(
|
||||||
|
[[254, 253, 252, 251], [250, 249, 248, 247], [246, 245, 244, 243]],
|
||||||
|
dtype=np.uint8)
|
||||||
|
inverted_img = np.stack([inverted_img, inverted_img, inverted_img],
|
||||||
|
axis=-1)
|
||||||
|
assert (results['img'] == inverted_img).all()
|
||||||
|
assert (results['img'] == results['img2']).all()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user