From dfbe3f62351e3785828d9322e244e3b0036d9704 Mon Sep 17 00:00:00 2001 From: "fangyixiao.vendor" Date: Mon, 16 May 2022 05:24:13 +0000 Subject: [PATCH] [Feature] add ColorJitter --- mmselfsup/datasets/pipelines/__init__.py | 4 +- mmselfsup/datasets/pipelines/transforms.py | 152 +++++++++++++++++- .../test_pipelines/test_transforms.py | 30 +++- 3 files changed, 180 insertions(+), 6 deletions(-) diff --git a/mmselfsup/datasets/pipelines/__init__.py b/mmselfsup/datasets/pipelines/__init__.py index 69aa6907..e09201fe 100644 --- a/mmselfsup/datasets/pipelines/__init__.py +++ b/mmselfsup/datasets/pipelines/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from .formatting import PackSelfSupInputs -from .transforms import (BEiTMaskGenerator, Lighting, RandomAug, +from .transforms import (BEiTMaskGenerator, ColorJitter, Lighting, RandomAug, RandomGaussianBlur, RandomPatchWithLabels, RandomResizedCropAndInterpolationWithTwoPic, RandomRotationWithLabels, RandomSolarize, @@ -9,7 +9,7 @@ from .wrappers import MultiView __all__ = [ 'RandomGaussianBlur', 'Lighting', 'RandomSolarize', 'RandomAug', - 'SimMIMMaskGenerator', 'BEiTMaskGenerator', + 'SimMIMMaskGenerator', 'BEiTMaskGenerator', 'ColorJitter', 'RandomResizedCropAndInterpolationWithTwoPic', 'PackSelfSupInputs', 'MultiView', 'RandomRotationWithLabels', 'RandomPatchWithLabels' ] diff --git a/mmselfsup/datasets/pipelines/transforms.py b/mmselfsup/datasets/pipelines/transforms.py index ac5b28d1..2ba11294 100644 --- a/mmselfsup/datasets/pipelines/transforms.py +++ b/mmselfsup/datasets/pipelines/transforms.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import math +import numbers import random import warnings from typing import Dict, List, Optional, Sequence, Tuple, Union @@ -8,7 +9,8 @@ import cv2 import numpy as np import torch import torchvision.transforms.functional as F -from mmcv.image import adjust_lighting, solarize +from mmcv.image import (adjust_brightness, adjust_color, adjust_contrast, + adjust_hue, adjust_lighting, solarize) from mmcv.transforms import BaseTransform from PIL import Image, ImageFilter from timm.data import create_transform @@ -645,3 +647,151 @@ class RandomPatchWithLabels(BaseTransform): def __repr__(self) -> str: repr_str = self.__class__.__name__ return repr_str + + +@TRANSFORMS.register_module() +class ColorJitter(BaseTransform): + """Randomly change the brightness, contrast, saturation and hue of an + image. + + Modified from + https://github.com/pytorch/vision/blob/main/torchvision/transforms/transforms.py + + Required Keys: + + - img + + Modified Keys: + + - img + + Args: + brightness (float or tuple of float (min, max)): How much to jitter + brightness. brightness_factor is chosen uniformly from + [max(0, 1 - brightness), 1 + brightness] or the given [min, max]. + Should be non negative numbers. + contrast (float or tuple of float (min, max)): How much to jitter + contrast. contrast_factor is chosen uniformly from + [max(0, 1 - contrast), 1 + contrast] or the given [min, max]. + Should be non negative numbers. + saturation (float or tuple of float (min, max)): How much to jitter + saturation. saturation_factor is chosen uniformly from + [max(0, 1 - saturation), 1 + saturation] or the given [min, max]. + Should be non negative numbers. + hue (float or tuple of float (min, max)): How much to jitter hue. + hue_factor is chosen uniformly from [-hue, hue] or the given + [min, max]. Should have 0 <= hue <= 0.5 or -0.5 <= min <= max <= 0.5. + To jitter hue, the pixel values of the input image has to be + non-negative for conversion to HSV space; thus it does not work + if you normalize your image to an interval with negative values, + or use an interpolation that generates negative values before using + this function. + """ # noqa: E501 + + def __init__(self, + brightness: Union[float, List[float]] = 0, + contrast: Union[float, List[float]] = 0, + saturation: Union[float, List[float]] = 0, + hue: Union[float, List[float]] = 0) -> None: + super().__init__() + self.brightness = self._check_input(brightness, 'brightness') + self.contrast = self._check_input(contrast, 'contrast') + self.saturation = self._check_input(saturation, 'saturation') + self.hue = self._check_input( + hue, 'hue', center=0, bound=(-0.5, 0.5), clip_first_on_zero=False) + + def _check_input( + self, + value: float, + name: str, + center: float = 1., + bound: Tuple = (0, float('inf')), + clip_first_on_zero: bool = True) -> Union[List[float], None]: + """Check the input and convert it to the tuple format.""" + + if isinstance(value, numbers.Number): + if value < 0: + raise ValueError( + f'If {name} is a single number, it must be non negative.') + value = [center - value, center + value] + if clip_first_on_zero: + value[0] = max(value[0], 0.0) + elif isinstance(value, (tuple, list)) and len(value) == 2: + if not bound[0] <= value[0] <= value[1] <= bound[1]: + raise ValueError(f'{name} values should be between {bound}') + else: + raise TypeError( + f'{name} should be a single number or a tuple with length 2.') + + # if value is 0 or (1., 1.) for brightness/contrast/saturation + # or (0., 0.) for hue, do nothing + if value[0] == value[1] == center: + value = None + return value + + @staticmethod + def get_params( + brightness: Optional[List[float]], + contrast: Optional[List[float]], + saturation: Optional[List[float]], + hue: Optional[List[float]], + ) -> Tuple[np.ndarray, Optional[float], Optional[float], Optional[float], + Optional[float]]: + """Get the parameters for the randomized transform to be applied on + image. + + Args: + brightness (tuple of float (min, max), optional): The range from + which the brightness_factor is chosen uniformly. Pass None to + turn off the transformation. + contrast (tuple of float (min, max), optional): The range from + which the contrast_factor is chosen uniformly. Pass None to + turn off the transformation. + saturation (tuple of float (min, max), optional): The range from + which the saturation_factor is chosen uniformly. Pass None to + turn off the transformation. + hue (tuple of float (min, max), optional): The range from which the + hue_factor is chosen uniformly. Pass None to turn off the + transformation. + + Returns: + tuple: The parameters used to apply the randomized transform + along with their random order. + """ + order = np.random.permutation(4) + + b = None if brightness is None else float( + np.random.uniform(brightness[0], brightness[1])) + c = None if contrast is None else float( + np.random.uniform(contrast[0], contrast[1])) + s = None if saturation is None else float( + np.random.uniform(saturation[0], saturation[1])) + h = None if hue is None else float(np.random.uniform(hue[0], hue[1])) + + return b, c, s, h, order + + def transform(self, results: Dict) -> Dict: + brightness_factor, contrast_factor, saturation_factor, hue_factor, \ + order = self.get_params( + self.brightness, self.contrast, self.saturation, self.hue) + + img = results['img'] + for fn_id in order: + if fn_id == 0 and brightness_factor is not None: + img = adjust_brightness(img, brightness_factor) + elif fn_id == 1 and contrast_factor is not None: + img = adjust_contrast(img, contrast_factor) + elif fn_id == 2 and saturation_factor is not None: + img = adjust_color(img, saturation_factor) + elif fn_id == 3 and hue_factor is not None: + img = adjust_hue(img, hue_factor) + results['img'] = img + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(brightness={self.brightness}, ' + repr_str += f'contrast={self.contrast}, ' + repr_str += f'saturation={self.saturation},' + repr_str += f'saturation={self.hue})' + return repr_str diff --git a/tests/test_datasets/test_pipelines/test_transforms.py b/tests/test_datasets/test_pipelines/test_transforms.py index a969500b..f838183e 100644 --- a/tests/test_datasets/test_pipelines/test_transforms.py +++ b/tests/test_datasets/test_pipelines/test_transforms.py @@ -4,9 +4,9 @@ import pytest import torch from mmselfsup.datasets.pipelines import ( - BEiTMaskGenerator, Lighting, RandomGaussianBlur, RandomPatchWithLabels, - RandomResizedCropAndInterpolationWithTwoPic, RandomRotationWithLabels, - RandomSolarize, SimMIMMaskGenerator) + BEiTMaskGenerator, ColorJitter, Lighting, RandomGaussianBlur, + RandomPatchWithLabels, RandomResizedCropAndInterpolationWithTwoPic, + RandomRotationWithLabels, RandomSolarize, SimMIMMaskGenerator) def test_simmim_mask_gen(): @@ -139,3 +139,27 @@ def test_random_patch(): # test transform assert list(results['img'].shape) == [8, 6, 53, 53] assert list(results['patch_label'].shape) == [8] + + +def test_color_jitter(): + with pytest.raises(ValueError): + transform = ColorJitter(-1, 0, 0, 0) + + with pytest.raises(ValueError): + transform = ColorJitter(0, 0, 0, [0, 1]) + + with pytest.raises(TypeError): + transform = ColorJitter('test', 0, 0, 0) + + original_img = torch.rand((224, 224, 3)).numpy().astype(np.uint8) + results = {'img': original_img} + + transform = ColorJitter(0, 0, 0, 0) + results = transform(results) + assert np.equal(results['img'], original_img).all() + + transform = ColorJitter(0.4, 0.4, 0.2, 0.1) + results = transform(results) + assert results['img'].shape == original_img.shape + + assert isinstance(str(transform), str)