diff --git a/mmcls/datasets/pipelines/__init__.py b/mmcls/datasets/pipelines/__init__.py index 21eb12d51..dcf535c60 100644 --- a/mmcls/datasets/pipelines/__init__.py +++ b/mmcls/datasets/pipelines/__init__.py @@ -1,7 +1,7 @@ from .auto_augment import (AutoAugment, AutoContrast, Brightness, - ColorTransform, Contrast, Equalize, Invert, - Posterize, Rotate, Sharpness, Shear, Solarize, - Translate) + ColorTransform, Contrast, Cutout, Equalize, Invert, + Posterize, RandAugment, Rotate, Sharpness, Shear, + Solarize, SolarizeAdd, Translate) from .compose import Compose from .formating import (Collect, ImageToTensor, ToNumpy, ToPIL, ToTensor, Transpose, to_tensor) @@ -16,6 +16,6 @@ __all__ = [ 'RandomFlip', 'Normalize', 'RandomCrop', 'RandomResizedCrop', 'RandomGrayscale', 'Shear', 'Translate', 'Rotate', 'Invert', 'ColorTransform', 'Solarize', 'Posterize', 'AutoContrast', 'Equalize', - 'Contrast', 'Brightness', 'Sharpness', 'AutoAugment', 'Lighting', - 'ColorJitter' + 'Contrast', 'Brightness', 'Sharpness', 'AutoAugment', 'SolarizeAdd', + 'Cutout', 'RandAugment', 'Lighting', 'ColorJitter' ] diff --git a/mmcls/datasets/pipelines/auto_augment.py b/mmcls/datasets/pipelines/auto_augment.py index 1dd66e2ac..e4b48c892 100644 --- a/mmcls/datasets/pipelines/auto_augment.py +++ b/mmcls/datasets/pipelines/auto_augment.py @@ -1,4 +1,5 @@ import copy +import random import mmcv import numpy as np @@ -41,7 +42,7 @@ class AutoAugment(object): self.sub_policy = [Compose(policy) for policy in self.policies] def __call__(self, results): - sub_policy = np.random.choice(self.sub_policy) + sub_policy = random.choice(self.sub_policy) return sub_policy(results) def __repr__(self): @@ -50,6 +51,87 @@ class AutoAugment(object): return repr_str +@PIPELINES.register_module() +class RandAugment(object): + """Random augmentation. + This data augmentation is proposed in `RandAugment: Practical automated + data augmentation with a reduced search space + `_. + + Args: + policies (list[dict]): The policies of random augmentation. Each + policy in ``policies`` is one specific augmentation policy (dict). + The policy shall at least have key `type`, indicating the type of + augmentation. For those which have magnitude, (given to the fact + they are named differently in different augmentation, ) + `magnitude_key` and `magnitude_range` shall be the magnitude + argument (str) and the range of magnitude (tuple in the format or + (minval, maxval)), respectively. + num_policies (int): Number of policies to select from policies each + time. + magnitude_level (int | float): Magnitude level for all the augmentation + selected. + total_level (int | float): Total level for the magnitude. Defaults to + 30. + """ + + def __init__(self, + policies, + num_policies, + magnitude_level, + total_level=30): + assert isinstance(num_policies, int), 'Number of policies must be ' \ + f'of int type, got {type(num_policies)} instead.' + assert isinstance(magnitude_level, (int, float)), \ + 'Magnitude level must be of int or float type, ' \ + f'got {type(magnitude_level)} instead.' + assert isinstance(total_level, (int, float)), 'Total level must be ' \ + f'of int or float type, got {type(total_level)} instead.' + assert isinstance(policies, list) and len(policies) > 0, \ + 'Policies must be a non-empty list.' + for policy in policies: + assert isinstance(policy, dict) and 'type' in policy, \ + 'Each policy must be a dict with key "type".' + + assert num_policies > 0, 'num_policies must be greater than 0.' + assert magnitude_level >= 0, 'magnitude_level must be no less than 0.' + assert total_level > 0, 'total_level must be greater than 0.' + + self.num_policies = num_policies + self.magnitude_level = magnitude_level + self.total_level = total_level + self.policies = self._process_policies(policies) + + def _process_policies(self, policies): + processed_policies = [] + for policy in policies: + processed_policy = copy.deepcopy(policy) + magnitude_key = processed_policy.pop('magnitude_key', None) + if magnitude_key is not None: + minval, maxval = processed_policy.pop('magnitude_range') + magnitude_value = (float(self.magnitude_level) / + self.total_level) * float(maxval - + minval) + minval + processed_policy.update({magnitude_key: magnitude_value}) + processed_policies.append(processed_policy) + return processed_policies + + def __call__(self, results): + if self.num_policies == 0: + return results + sub_policy = random.choices(self.policies, k=self.num_policies) + sub_policy = Compose(sub_policy) + return sub_policy(results) + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(policies={self.policies}, ' + repr_str += f'num_policies={self.num_policies}, ' + repr_str += f'magnitude_level={self.magnitude_level}, ' + repr_str += f'total_level={self.total_level})' + return repr_str + + @PIPELINES.register_module() class Shear(object): """Shear images. @@ -428,25 +510,66 @@ class Solarize(object): return repr_str +@PIPELINES.register_module() +class SolarizeAdd(object): + """SolarizeAdd images (add a certain value to pixels below a threshold). + + Args: + magnitude (int | float): The value to be added to pixels below the thr. + thr (int | float): The threshold below which the pixels value will be + adjusted. + prob (float): The probability for solarizing therefore should be in + range [0, 1]. Defaults to 0.5. + """ + + def __init__(self, magnitude, thr=128, prob=0.5): + assert isinstance(magnitude, (int, float)), 'The thr magnitude must '\ + f'be int or float, but got {type(magnitude)} instead.' + assert isinstance(thr, (int, float)), 'The thr type must '\ + f'be int or float, but got {type(thr)} instead.' + assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \ + f'got {prob} instead.' + + self.magnitude = magnitude + self.thr = thr + self.prob = prob + + def __call__(self, results): + if np.random.rand() > self.prob: + return results + for key in results.get('img_fields', ['img']): + img = results[key] + img_solarized = np.where(img < self.thr, + np.minimum(img + self.magnitude, 255), + img) + results[key] = img_solarized.astype(img.dtype) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(magnitude={self.magnitude}, ' + repr_str += f'thr={self.thr}, ' + repr_str += f'prob={self.prob})' + return repr_str + + @PIPELINES.register_module() class Posterize(object): """Posterize images (reduce the number of bits for each color channel). Args: - bits (int): Number of bits for each pixel in the output img, which - should be less or equal to 8. + bits (int | float): Number of bits for each pixel in the output img, + which should be less or equal to 8. prob (float): The probability for posterizing therefore should be in range [0, 1]. Defaults to 0.5. """ def __init__(self, bits, prob=0.5): - assert isinstance(bits, int), 'The bits type must be int, '\ - f'but got {type(bits)} instead.' assert bits <= 8, f'The bits must be less than 8, got {bits} instead.' assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \ f'got {prob} instead.' - self.bits = bits + self.bits = int(bits) self.prob = prob def __call__(self, results): @@ -642,3 +765,51 @@ class Sharpness(object): repr_str += f'prob={self.prob}, ' repr_str += f'random_negative_prob={self.random_negative_prob})' return repr_str + + +@PIPELINES.register_module() +class Cutout(object): + """Cutout images. + + Args: + shape (int | float | tuple(int | float)): Expected cutout shape (h, w). + If given as a single value, the value will be used for + both h and w. + pad_val (int, tuple[int]): Pixel pad_val value for constant fill. If + it is a tuple, it must have the same length with the image + channels. Defaults to 128. + prob (float): The probability for performing cutout therefore should + be in range [0, 1]. Defaults to 0.5. + """ + + def __init__(self, shape, pad_val=128, prob=0.5): + if isinstance(shape, float): + shape = int(shape) + elif isinstance(shape, tuple): + shape = tuple(int(i) for i in shape) + elif not isinstance(shape, int): + raise TypeError( + 'shape must be of ' + f'type int, float or tuple, got {type(shape)} instead') + assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \ + f'got {prob} instead.' + + self.shape = shape + self.pad_val = pad_val + self.prob = prob + + def __call__(self, results): + if np.random.rand() > self.prob: + return results + for key in results.get('img_fields', ['img']): + img = results[key] + img_cutout = mmcv.cutout(img, self.shape, pad_val=self.pad_val) + results[key] = img_cutout.astype(img.dtype) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(shape={self.shape}, ' + repr_str += f'pad_val={self.pad_val}, ' + repr_str += f'prob={self.prob})' + return repr_str diff --git a/tests/test_pipelines/test_auto_augment.py b/tests/test_pipelines/test_auto_augment.py index 9c750af0b..bd23aca96 100644 --- a/tests/test_pipelines/test_auto_augment.py +++ b/tests/test_pipelines/test_auto_augment.py @@ -1,4 +1,5 @@ import copy +import random import mmcv import numpy as np @@ -38,6 +39,189 @@ def construct_toy_data_photometric(): return results +def test_rand_augment(): + policies = [ + dict( + type='Translate', + magnitude_key='magnitude', + magnitude_range=(0, 1), + pad_val=128, + prob=1., + direction='horizontal'), + dict(type='Invert', prob=1.), + dict( + type='Rotate', + magnitude_key='angle', + magnitude_range=(0, 30), + prob=0.) + ] + # test assertion for num_policies + with pytest.raises(AssertionError): + transform = dict( + type='RandAugment', + policies=policies, + num_policies=1.5, + magnitude_level=12) + build_from_cfg(transform, PIPELINES) + with pytest.raises(AssertionError): + transform = dict( + type='RandAugment', + policies=policies, + num_policies=-1, + magnitude_level=12) + build_from_cfg(transform, PIPELINES) + # test assertion for magnitude_level + with pytest.raises(AssertionError): + transform = dict( + type='RandAugment', + policies=policies, + num_policies=1, + magnitude_level=None) + build_from_cfg(transform, PIPELINES) + with pytest.raises(AssertionError): + transform = dict( + type='RandAugment', + policies=policies, + num_policies=1, + magnitude_level=-1) + build_from_cfg(transform, PIPELINES) + # test assertion for total_level + with pytest.raises(AssertionError): + transform = dict( + type='RandAugment', + policies=policies, + num_policies=1, + magnitude_level=12, + total_level=None) + build_from_cfg(transform, PIPELINES) + with pytest.raises(AssertionError): + transform = dict( + type='RandAugment', + policies=policies, + num_policies=1, + magnitude_level=12, + total_level=-30) + build_from_cfg(transform, PIPELINES) + # test assertion for policies + with pytest.raises(AssertionError): + transform = dict( + type='RandAugment', + policies=[], + num_policies=2, + magnitude_level=12) + build_from_cfg(transform, PIPELINES) + with pytest.raises(AssertionError): + invalid_policies = copy.deepcopy(policies) + invalid_policies.append(('Wrong_policy')) + transform = dict( + type='RandAugment', + policies=invalid_policies, + num_policies=2, + magnitude_level=12) + build_from_cfg(transform, PIPELINES) + with pytest.raises(AssertionError): + invalid_policies = copy.deepcopy(policies) + invalid_policies[2].pop('type') + transform = dict( + type='RandAugment', + policies=invalid_policies, + num_policies=2, + magnitude_level=12) + build_from_cfg(transform, PIPELINES) + with pytest.raises(KeyError): + invalid_policies = copy.deepcopy(policies) + invalid_policies[2].pop('magnitude_range') + transform = dict( + type='RandAugment', + policies=invalid_policies, + num_policies=2, + magnitude_level=12) + build_from_cfg(transform, PIPELINES) + + # test case where num_policies = 1 + random.seed(1) + np.random.seed(0) + results = construct_toy_data() + transform = dict( + type='RandAugment', + policies=policies, + num_policies=1, + magnitude_level=12) + pipeline = build_from_cfg(transform, PIPELINES) + results = pipeline(results) + # apply translate + img_augmented = np.array( + [[128, 128, 1, 2], [128, 128, 5, 6], [128, 128, 9, 10]], + dtype=np.uint8) + img_augmented = np.stack([img_augmented, img_augmented, img_augmented], + axis=-1) + assert (results['img'] == img_augmented).all() + + results = construct_toy_data() + transform = dict( + type='RandAugment', + policies=policies, + num_policies=1, + magnitude_level=12) + pipeline = build_from_cfg(transform, PIPELINES) + results = pipeline(results) + # apply rotation with prob=0. + assert (results['img'] == results['ori_img']).all() + + # test case where num_policies = 2 + random.seed(0) + np.random.seed(0) + results = construct_toy_data() + transform = dict( + type='RandAugment', + policies=policies, + num_policies=2, + magnitude_level=12) + pipeline = build_from_cfg(transform, PIPELINES) + results = pipeline(results) + # apply rotate and rotate with prob=0 + assert (results['img'] == results['ori_img']).all() + + results = construct_toy_data() + transform = dict( + type='RandAugment', + policies=policies, + num_policies=2, + magnitude_level=12) + pipeline = build_from_cfg(transform, PIPELINES) + results = pipeline(results) + # apply invert and translate + img_augmented = np.array( + [[252, 251, 128, 128], [248, 247, 128, 128], [244, 243, 128, 128]], + dtype=np.uint8) + img_augmented = np.stack([img_augmented, img_augmented, img_augmented], + axis=-1) + assert (results['img'] == img_augmented).all() + + results = construct_toy_data() + transform = dict( + type='RandAugment', + policies=policies, + num_policies=2, + magnitude_level=0) + pipeline = build_from_cfg(transform, PIPELINES) + results = pipeline(results) + # apply invert and invert + assert (results['img'] == results['ori_img']).all() + + # test case where magnitude_level = 0 + results = construct_toy_data() + transform = dict( + type='RandAugment', + policies=policies, + num_policies=2, + magnitude_level=0) + pipeline = build_from_cfg(transform, PIPELINES) + results = pipeline(results) + # apply rotate and translate + assert (results['img'] == results['ori_img']).all() + + def test_shear(): # test assertion for invalid type of magnitude with pytest.raises(AssertionError): @@ -467,12 +651,57 @@ def test_solarize(): assert (results['img'] == results['img2']).all() -def test_posterize(): - # test assertion for invalid type of bits +def test_solarize_add(): + # test assertion for invalid type of magnitude with pytest.raises(AssertionError): - transform = dict(type='Posterize', bits=4.5) + transform = dict(type='SolarizeAdd', magnitude=(1, 2)) build_from_cfg(transform, PIPELINES) + # test assertion for invalid type of thr + with pytest.raises(AssertionError): + transform = dict(type='SolarizeAdd', magnitude=100, thr=(1, 2)) + build_from_cfg(transform, PIPELINES) + + # test case when prob=0, therefore no solarize + results = construct_toy_data_photometric() + transform = dict(type='SolarizeAdd', magnitude=100, thr=128, prob=0.) + pipeline = build_from_cfg(transform, PIPELINES) + results = pipeline(results) + assert (results['img'] == results['ori_img']).all() + + # test case when thr=0, therefore no solarize + results = construct_toy_data_photometric() + transform = dict(type='SolarizeAdd', magnitude=100, thr=0, prob=1.) + pipeline = build_from_cfg(transform, PIPELINES) + results = pipeline(results) + assert (results['img'] == results['ori_img']).all() + + # test case when thr=128, magnitude=100 + results = construct_toy_data_photometric() + transform = dict(type='SolarizeAdd', magnitude=100, thr=128, prob=1.) + pipeline = build_from_cfg(transform, PIPELINES) + results = pipeline(results) + img_solarized = np.array( + [[100, 128, 255], [101, 227, 254], [102, 129, 253]], dtype=np.uint8) + img_solarized = np.stack([img_solarized, img_solarized, img_solarized], + axis=-1) + assert (results['img'] == img_solarized).all() + assert (results['img'] == results['img2']).all() + + # test case when thr=100, magnitude=50 + results = construct_toy_data_photometric() + transform = dict(type='SolarizeAdd', magnitude=50, thr=100, prob=1.) + pipeline = build_from_cfg(transform, PIPELINES) + results = pipeline(results) + img_solarized = np.array([[50, 128, 255], [51, 127, 254], [52, 129, 253]], + dtype=np.uint8) + img_solarized = np.stack([img_solarized, img_solarized, img_solarized], + axis=-1) + assert (results['img'] == img_solarized).all() + assert (results['img'] == results['img2']).all() + + +def test_posterize(): # test assertion for invalid value of bits with pytest.raises(AssertionError): transform = dict(type='Posterize', bits=10) @@ -771,3 +1000,59 @@ def test_sharpness(nb_rand_test=100): _adjust_sharpness(img, 1 + magnitude)[1:-1, 1:-1], rtol=0, atol=1) + + +def test_cutout(): + + # test assertion for invalid type of shape + with pytest.raises(TypeError): + transform = dict(type='Cutout', shape=None) + build_from_cfg(transform, PIPELINES) + + # test assertion for invalid value of prob + with pytest.raises(AssertionError): + transform = dict(type='Cutout', shape=1, prob=100) + build_from_cfg(transform, PIPELINES) + + # test case when prob=0, therefore no cutout + results = construct_toy_data() + transform = dict(type='Cutout', shape=2, prob=0.) + pipeline = build_from_cfg(transform, PIPELINES) + results = pipeline(results) + assert (results['img'] == results['ori_img']).all() + + # test case when shape=0, therefore no cutout + results = construct_toy_data() + transform = dict(type='Cutout', shape=0, prob=1.) + pipeline = build_from_cfg(transform, PIPELINES) + results = pipeline(results) + assert (results['img'] == results['ori_img']).all() + + # test case when shape=6, therefore the whole img has been cut + results = construct_toy_data() + transform = dict(type='Cutout', shape=6, prob=1.) + pipeline = build_from_cfg(transform, PIPELINES) + results = pipeline(results) + assert (results['img'] == np.ones_like(results['ori_img']) * 128).all() + + # test case when shape is int + np.random.seed(0) + results = construct_toy_data() + transform = dict(type='Cutout', shape=1, prob=1.) + pipeline = build_from_cfg(transform, PIPELINES) + results = pipeline(results) + img_cutout = np.array([[1, 2, 3, 4], [5, 128, 7, 8], [9, 10, 11, 12]], + dtype=np.uint8) + img_cutout = np.stack([img_cutout, img_cutout, img_cutout], axis=-1) + assert (results['img'] == img_cutout).all() + + # test case when shape is tuple + np.random.seed(0) + results = construct_toy_data() + transform = dict(type='Cutout', shape=(1, 2), pad_val=0, prob=1.) + pipeline = build_from_cfg(transform, PIPELINES) + results = pipeline(results) + img_cutout = np.array([[1, 2, 3, 4], [5, 0, 0, 8], [9, 10, 11, 12]], + dtype=np.uint8) + img_cutout = np.stack([img_cutout, img_cutout, img_cutout], axis=-1) + assert (results['img'] == img_cutout).all()