[Feature] add ColorJitter

pull/352/head
fangyixiao.vendor 2022-05-16 05:24:13 +00:00 committed by fangyixiao18
parent 11c3a6a3af
commit dfbe3f6235
3 changed files with 180 additions and 6 deletions

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .formatting import PackSelfSupInputs
from .transforms import (BEiTMaskGenerator, Lighting, RandomAug,
from .transforms import (BEiTMaskGenerator, ColorJitter, Lighting, RandomAug,
RandomGaussianBlur, RandomPatchWithLabels,
RandomResizedCropAndInterpolationWithTwoPic,
RandomRotationWithLabels, RandomSolarize,
@ -9,7 +9,7 @@ from .wrappers import MultiView
__all__ = [
'RandomGaussianBlur', 'Lighting', 'RandomSolarize', 'RandomAug',
'SimMIMMaskGenerator', 'BEiTMaskGenerator',
'SimMIMMaskGenerator', 'BEiTMaskGenerator', 'ColorJitter',
'RandomResizedCropAndInterpolationWithTwoPic', 'PackSelfSupInputs',
'MultiView', 'RandomRotationWithLabels', 'RandomPatchWithLabels'
]

View File

@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import numbers
import random
import warnings
from typing import Dict, List, Optional, Sequence, Tuple, Union
@ -8,7 +9,8 @@ import cv2
import numpy as np
import torch
import torchvision.transforms.functional as F
from mmcv.image import adjust_lighting, solarize
from mmcv.image import (adjust_brightness, adjust_color, adjust_contrast,
adjust_hue, adjust_lighting, solarize)
from mmcv.transforms import BaseTransform
from PIL import Image, ImageFilter
from timm.data import create_transform
@ -645,3 +647,151 @@ class RandomPatchWithLabels(BaseTransform):
def __repr__(self) -> str:
repr_str = self.__class__.__name__
return repr_str
@TRANSFORMS.register_module()
class ColorJitter(BaseTransform):
"""Randomly change the brightness, contrast, saturation and hue of an
image.
Modified from
https://github.com/pytorch/vision/blob/main/torchvision/transforms/transforms.py
Required Keys:
- img
Modified Keys:
- img
Args:
brightness (float or tuple of float (min, max)): How much to jitter
brightness. brightness_factor is chosen uniformly from
[max(0, 1 - brightness), 1 + brightness] or the given [min, max].
Should be non negative numbers.
contrast (float or tuple of float (min, max)): How much to jitter
contrast. contrast_factor is chosen uniformly from
[max(0, 1 - contrast), 1 + contrast] or the given [min, max].
Should be non negative numbers.
saturation (float or tuple of float (min, max)): How much to jitter
saturation. saturation_factor is chosen uniformly from
[max(0, 1 - saturation), 1 + saturation] or the given [min, max].
Should be non negative numbers.
hue (float or tuple of float (min, max)): How much to jitter hue.
hue_factor is chosen uniformly from [-hue, hue] or the given
[min, max]. Should have 0 <= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
To jitter hue, the pixel values of the input image has to be
non-negative for conversion to HSV space; thus it does not work
if you normalize your image to an interval with negative values,
or use an interpolation that generates negative values before using
this function.
""" # noqa: E501
def __init__(self,
brightness: Union[float, List[float]] = 0,
contrast: Union[float, List[float]] = 0,
saturation: Union[float, List[float]] = 0,
hue: Union[float, List[float]] = 0) -> None:
super().__init__()
self.brightness = self._check_input(brightness, 'brightness')
self.contrast = self._check_input(contrast, 'contrast')
self.saturation = self._check_input(saturation, 'saturation')
self.hue = self._check_input(
hue, 'hue', center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)
def _check_input(
self,
value: float,
name: str,
center: float = 1.,
bound: Tuple = (0, float('inf')),
clip_first_on_zero: bool = True) -> Union[List[float], None]:
"""Check the input and convert it to the tuple format."""
if isinstance(value, numbers.Number):
if value < 0:
raise ValueError(
f'If {name} is a single number, it must be non negative.')
value = [center - value, center + value]
if clip_first_on_zero:
value[0] = max(value[0], 0.0)
elif isinstance(value, (tuple, list)) and len(value) == 2:
if not bound[0] <= value[0] <= value[1] <= bound[1]:
raise ValueError(f'{name} values should be between {bound}')
else:
raise TypeError(
f'{name} should be a single number or a tuple with length 2.')
# if value is 0 or (1., 1.) for brightness/contrast/saturation
# or (0., 0.) for hue, do nothing
if value[0] == value[1] == center:
value = None
return value
@staticmethod
def get_params(
brightness: Optional[List[float]],
contrast: Optional[List[float]],
saturation: Optional[List[float]],
hue: Optional[List[float]],
) -> Tuple[np.ndarray, Optional[float], Optional[float], Optional[float],
Optional[float]]:
"""Get the parameters for the randomized transform to be applied on
image.
Args:
brightness (tuple of float (min, max), optional): The range from
which the brightness_factor is chosen uniformly. Pass None to
turn off the transformation.
contrast (tuple of float (min, max), optional): The range from
which the contrast_factor is chosen uniformly. Pass None to
turn off the transformation.
saturation (tuple of float (min, max), optional): The range from
which the saturation_factor is chosen uniformly. Pass None to
turn off the transformation.
hue (tuple of float (min, max), optional): The range from which the
hue_factor is chosen uniformly. Pass None to turn off the
transformation.
Returns:
tuple: The parameters used to apply the randomized transform
along with their random order.
"""
order = np.random.permutation(4)
b = None if brightness is None else float(
np.random.uniform(brightness[0], brightness[1]))
c = None if contrast is None else float(
np.random.uniform(contrast[0], contrast[1]))
s = None if saturation is None else float(
np.random.uniform(saturation[0], saturation[1]))
h = None if hue is None else float(np.random.uniform(hue[0], hue[1]))
return b, c, s, h, order
def transform(self, results: Dict) -> Dict:
brightness_factor, contrast_factor, saturation_factor, hue_factor, \
order = self.get_params(
self.brightness, self.contrast, self.saturation, self.hue)
img = results['img']
for fn_id in order:
if fn_id == 0 and brightness_factor is not None:
img = adjust_brightness(img, brightness_factor)
elif fn_id == 1 and contrast_factor is not None:
img = adjust_contrast(img, contrast_factor)
elif fn_id == 2 and saturation_factor is not None:
img = adjust_color(img, saturation_factor)
elif fn_id == 3 and hue_factor is not None:
img = adjust_hue(img, hue_factor)
results['img'] = img
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(brightness={self.brightness}, '
repr_str += f'contrast={self.contrast}, '
repr_str += f'saturation={self.saturation},'
repr_str += f'saturation={self.hue})'
return repr_str

View File

@ -4,9 +4,9 @@ import pytest
import torch
from mmselfsup.datasets.pipelines import (
BEiTMaskGenerator, Lighting, RandomGaussianBlur, RandomPatchWithLabels,
RandomResizedCropAndInterpolationWithTwoPic, RandomRotationWithLabels,
RandomSolarize, SimMIMMaskGenerator)
BEiTMaskGenerator, ColorJitter, Lighting, RandomGaussianBlur,
RandomPatchWithLabels, RandomResizedCropAndInterpolationWithTwoPic,
RandomRotationWithLabels, RandomSolarize, SimMIMMaskGenerator)
def test_simmim_mask_gen():
@ -139,3 +139,27 @@ def test_random_patch():
# test transform
assert list(results['img'].shape) == [8, 6, 53, 53]
assert list(results['patch_label'].shape) == [8]
def test_color_jitter():
with pytest.raises(ValueError):
transform = ColorJitter(-1, 0, 0, 0)
with pytest.raises(ValueError):
transform = ColorJitter(0, 0, 0, [0, 1])
with pytest.raises(TypeError):
transform = ColorJitter('test', 0, 0, 0)
original_img = torch.rand((224, 224, 3)).numpy().astype(np.uint8)
results = {'img': original_img}
transform = ColorJitter(0, 0, 0, 0)
results = transform(results)
assert np.equal(results['img'], original_img).all()
transform = ColorJitter(0.4, 0.4, 0.2, 0.1)
results = transform(results)
assert results['img'].shape == original_img.shape
assert isinstance(str(transform), str)