[Refactor] Add backend option for ColorJitter (#353)

* add backend option for ColorJitter

* revise docstring
This commit is contained in:
Yixiao Fang 2022-07-19 19:00:29 +08:00 committed by GitHub
parent c7101c7648
commit d4ee6e5c2c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -614,19 +614,23 @@ class ColorJitter(BaseTransform):
if you normalize your image to an interval with negative values,
or use an interpolation that generates negative values before using
this function.
backend (str): The type of image processing backend. Options are
`cv2`, `pillow`. Defaults to `pillow`.
""" # noqa: E501
def __init__(self,
brightness: Union[float, List[float]] = 0,
contrast: Union[float, List[float]] = 0,
saturation: Union[float, List[float]] = 0,
hue: Union[float, List[float]] = 0) -> None:
hue: Union[float, List[float]] = 0,
backend: str = 'pillow') -> None:
super().__init__()
self.brightness = self._check_input(brightness, 'brightness')
self.contrast = self._check_input(contrast, 'contrast')
self.saturation = self._check_input(saturation, 'saturation')
self.hue = self._check_input(
hue, 'hue', center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)
self.backend = backend
def _check_input(
self,
@ -706,13 +710,13 @@ class ColorJitter(BaseTransform):
img = results['img'].astype('uint8')
for fn_id in order:
if fn_id == 0 and brightness_factor is not None:
img = adjust_brightness(img, brightness_factor)
img = adjust_brightness(img, brightness_factor, self.backend)
elif fn_id == 1 and contrast_factor is not None:
img = adjust_contrast(img, contrast_factor)
img = adjust_contrast(img, contrast_factor, self.backend)
elif fn_id == 2 and saturation_factor is not None:
img = adjust_color(img, saturation_factor)
img = adjust_color(img, saturation_factor, self.backend)
elif fn_id == 3 and hue_factor is not None:
img = adjust_hue(img, hue_factor)
img = adjust_hue(img, hue_factor, self.backend)
results['img'] = img
return results