[Feature] add ColorJitter
parent
11c3a6a3af
commit
dfbe3f6235
|
@ -1,6 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .formatting import PackSelfSupInputs
|
||||
from .transforms import (BEiTMaskGenerator, Lighting, RandomAug,
|
||||
from .transforms import (BEiTMaskGenerator, ColorJitter, Lighting, RandomAug,
|
||||
RandomGaussianBlur, RandomPatchWithLabels,
|
||||
RandomResizedCropAndInterpolationWithTwoPic,
|
||||
RandomRotationWithLabels, RandomSolarize,
|
||||
|
@ -9,7 +9,7 @@ from .wrappers import MultiView
|
|||
|
||||
__all__ = [
|
||||
'RandomGaussianBlur', 'Lighting', 'RandomSolarize', 'RandomAug',
|
||||
'SimMIMMaskGenerator', 'BEiTMaskGenerator',
|
||||
'SimMIMMaskGenerator', 'BEiTMaskGenerator', 'ColorJitter',
|
||||
'RandomResizedCropAndInterpolationWithTwoPic', 'PackSelfSupInputs',
|
||||
'MultiView', 'RandomRotationWithLabels', 'RandomPatchWithLabels'
|
||||
]
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
import numbers
|
||||
import random
|
||||
import warnings
|
||||
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
@ -8,7 +9,8 @@ import cv2
|
|||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms.functional as F
|
||||
from mmcv.image import adjust_lighting, solarize
|
||||
from mmcv.image import (adjust_brightness, adjust_color, adjust_contrast,
|
||||
adjust_hue, adjust_lighting, solarize)
|
||||
from mmcv.transforms import BaseTransform
|
||||
from PIL import Image, ImageFilter
|
||||
from timm.data import create_transform
|
||||
|
@ -645,3 +647,151 @@ class RandomPatchWithLabels(BaseTransform):
|
|||
def __repr__(self) -> str:
|
||||
repr_str = self.__class__.__name__
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class ColorJitter(BaseTransform):
|
||||
"""Randomly change the brightness, contrast, saturation and hue of an
|
||||
image.
|
||||
|
||||
Modified from
|
||||
https://github.com/pytorch/vision/blob/main/torchvision/transforms/transforms.py
|
||||
|
||||
Required Keys:
|
||||
|
||||
- img
|
||||
|
||||
Modified Keys:
|
||||
|
||||
- img
|
||||
|
||||
Args:
|
||||
brightness (float or tuple of float (min, max)): How much to jitter
|
||||
brightness. brightness_factor is chosen uniformly from
|
||||
[max(0, 1 - brightness), 1 + brightness] or the given [min, max].
|
||||
Should be non negative numbers.
|
||||
contrast (float or tuple of float (min, max)): How much to jitter
|
||||
contrast. contrast_factor is chosen uniformly from
|
||||
[max(0, 1 - contrast), 1 + contrast] or the given [min, max].
|
||||
Should be non negative numbers.
|
||||
saturation (float or tuple of float (min, max)): How much to jitter
|
||||
saturation. saturation_factor is chosen uniformly from
|
||||
[max(0, 1 - saturation), 1 + saturation] or the given [min, max].
|
||||
Should be non negative numbers.
|
||||
hue (float or tuple of float (min, max)): How much to jitter hue.
|
||||
hue_factor is chosen uniformly from [-hue, hue] or the given
|
||||
[min, max]. Should have 0 <= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
|
||||
To jitter hue, the pixel values of the input image has to be
|
||||
non-negative for conversion to HSV space; thus it does not work
|
||||
if you normalize your image to an interval with negative values,
|
||||
or use an interpolation that generates negative values before using
|
||||
this function.
|
||||
""" # noqa: E501
|
||||
|
||||
def __init__(self,
|
||||
brightness: Union[float, List[float]] = 0,
|
||||
contrast: Union[float, List[float]] = 0,
|
||||
saturation: Union[float, List[float]] = 0,
|
||||
hue: Union[float, List[float]] = 0) -> None:
|
||||
super().__init__()
|
||||
self.brightness = self._check_input(brightness, 'brightness')
|
||||
self.contrast = self._check_input(contrast, 'contrast')
|
||||
self.saturation = self._check_input(saturation, 'saturation')
|
||||
self.hue = self._check_input(
|
||||
hue, 'hue', center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)
|
||||
|
||||
def _check_input(
|
||||
self,
|
||||
value: float,
|
||||
name: str,
|
||||
center: float = 1.,
|
||||
bound: Tuple = (0, float('inf')),
|
||||
clip_first_on_zero: bool = True) -> Union[List[float], None]:
|
||||
"""Check the input and convert it to the tuple format."""
|
||||
|
||||
if isinstance(value, numbers.Number):
|
||||
if value < 0:
|
||||
raise ValueError(
|
||||
f'If {name} is a single number, it must be non negative.')
|
||||
value = [center - value, center + value]
|
||||
if clip_first_on_zero:
|
||||
value[0] = max(value[0], 0.0)
|
||||
elif isinstance(value, (tuple, list)) and len(value) == 2:
|
||||
if not bound[0] <= value[0] <= value[1] <= bound[1]:
|
||||
raise ValueError(f'{name} values should be between {bound}')
|
||||
else:
|
||||
raise TypeError(
|
||||
f'{name} should be a single number or a tuple with length 2.')
|
||||
|
||||
# if value is 0 or (1., 1.) for brightness/contrast/saturation
|
||||
# or (0., 0.) for hue, do nothing
|
||||
if value[0] == value[1] == center:
|
||||
value = None
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def get_params(
|
||||
brightness: Optional[List[float]],
|
||||
contrast: Optional[List[float]],
|
||||
saturation: Optional[List[float]],
|
||||
hue: Optional[List[float]],
|
||||
) -> Tuple[np.ndarray, Optional[float], Optional[float], Optional[float],
|
||||
Optional[float]]:
|
||||
"""Get the parameters for the randomized transform to be applied on
|
||||
image.
|
||||
|
||||
Args:
|
||||
brightness (tuple of float (min, max), optional): The range from
|
||||
which the brightness_factor is chosen uniformly. Pass None to
|
||||
turn off the transformation.
|
||||
contrast (tuple of float (min, max), optional): The range from
|
||||
which the contrast_factor is chosen uniformly. Pass None to
|
||||
turn off the transformation.
|
||||
saturation (tuple of float (min, max), optional): The range from
|
||||
which the saturation_factor is chosen uniformly. Pass None to
|
||||
turn off the transformation.
|
||||
hue (tuple of float (min, max), optional): The range from which the
|
||||
hue_factor is chosen uniformly. Pass None to turn off the
|
||||
transformation.
|
||||
|
||||
Returns:
|
||||
tuple: The parameters used to apply the randomized transform
|
||||
along with their random order.
|
||||
"""
|
||||
order = np.random.permutation(4)
|
||||
|
||||
b = None if brightness is None else float(
|
||||
np.random.uniform(brightness[0], brightness[1]))
|
||||
c = None if contrast is None else float(
|
||||
np.random.uniform(contrast[0], contrast[1]))
|
||||
s = None if saturation is None else float(
|
||||
np.random.uniform(saturation[0], saturation[1]))
|
||||
h = None if hue is None else float(np.random.uniform(hue[0], hue[1]))
|
||||
|
||||
return b, c, s, h, order
|
||||
|
||||
def transform(self, results: Dict) -> Dict:
|
||||
brightness_factor, contrast_factor, saturation_factor, hue_factor, \
|
||||
order = self.get_params(
|
||||
self.brightness, self.contrast, self.saturation, self.hue)
|
||||
|
||||
img = results['img']
|
||||
for fn_id in order:
|
||||
if fn_id == 0 and brightness_factor is not None:
|
||||
img = adjust_brightness(img, brightness_factor)
|
||||
elif fn_id == 1 and contrast_factor is not None:
|
||||
img = adjust_contrast(img, contrast_factor)
|
||||
elif fn_id == 2 and saturation_factor is not None:
|
||||
img = adjust_color(img, saturation_factor)
|
||||
elif fn_id == 3 and hue_factor is not None:
|
||||
img = adjust_hue(img, hue_factor)
|
||||
results['img'] = img
|
||||
return results
|
||||
|
||||
def __repr__(self) -> str:
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(brightness={self.brightness}, '
|
||||
repr_str += f'contrast={self.contrast}, '
|
||||
repr_str += f'saturation={self.saturation},'
|
||||
repr_str += f'saturation={self.hue})'
|
||||
return repr_str
|
||||
|
|
|
@ -4,9 +4,9 @@ import pytest
|
|||
import torch
|
||||
|
||||
from mmselfsup.datasets.pipelines import (
|
||||
BEiTMaskGenerator, Lighting, RandomGaussianBlur, RandomPatchWithLabels,
|
||||
RandomResizedCropAndInterpolationWithTwoPic, RandomRotationWithLabels,
|
||||
RandomSolarize, SimMIMMaskGenerator)
|
||||
BEiTMaskGenerator, ColorJitter, Lighting, RandomGaussianBlur,
|
||||
RandomPatchWithLabels, RandomResizedCropAndInterpolationWithTwoPic,
|
||||
RandomRotationWithLabels, RandomSolarize, SimMIMMaskGenerator)
|
||||
|
||||
|
||||
def test_simmim_mask_gen():
|
||||
|
@ -139,3 +139,27 @@ def test_random_patch():
|
|||
# test transform
|
||||
assert list(results['img'].shape) == [8, 6, 53, 53]
|
||||
assert list(results['patch_label'].shape) == [8]
|
||||
|
||||
|
||||
def test_color_jitter():
|
||||
with pytest.raises(ValueError):
|
||||
transform = ColorJitter(-1, 0, 0, 0)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
transform = ColorJitter(0, 0, 0, [0, 1])
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
transform = ColorJitter('test', 0, 0, 0)
|
||||
|
||||
original_img = torch.rand((224, 224, 3)).numpy().astype(np.uint8)
|
||||
results = {'img': original_img}
|
||||
|
||||
transform = ColorJitter(0, 0, 0, 0)
|
||||
results = transform(results)
|
||||
assert np.equal(results['img'], original_img).all()
|
||||
|
||||
transform = ColorJitter(0.4, 0.4, 0.2, 0.1)
|
||||
results = transform(results)
|
||||
assert results['img'].shape == original_img.shape
|
||||
|
||||
assert isinstance(str(transform), str)
|
||||
|
|
Loading…
Reference in New Issue