mirror of https://github.com/open-mmlab/mmcv.git
[Refactor] Add pillow backend for ColorJitter related functions (#2127)
* add 'backend' for ColorJitter related functions * add unittest * fix unittest * add bgr transpose and revise ut * update unittest * revise docstringpull/2188/head
parent
bad822d784
commit
b2ac245602
|
@ -1,9 +1,14 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image, ImageEnhance
|
||||
|
||||
from ..utils import is_tuple_of
|
||||
from .colorspace import bgr2gray, gray2bgr
|
||||
from .io import imread_backend
|
||||
|
||||
|
||||
def imnormalize(img, mean, std, to_rgb=True):
|
||||
|
@ -97,7 +102,7 @@ def posterize(img, bits):
|
|||
return img
|
||||
|
||||
|
||||
def adjust_color(img, alpha=1, beta=None, gamma=0):
|
||||
def adjust_color(img, alpha=1, beta=None, gamma=0, backend=None):
|
||||
r"""It blends the source image and its gray image:
|
||||
|
||||
.. math::
|
||||
|
@ -110,22 +115,41 @@ def adjust_color(img, alpha=1, beta=None, gamma=0):
|
|||
If None, it's assigned the value (1 - `alpha`).
|
||||
gamma (int | float): Scalar added to each sum.
|
||||
Same as :func:`cv2.addWeighted`. Default 0.
|
||||
backend (str | None): The image processing backend type. Options are
|
||||
`cv2`, `pillow`, `None`. If backend is None, the global
|
||||
``imread_backend`` specified by ``mmcv.use_backend()`` will be
|
||||
used. Defaults to None.
|
||||
|
||||
Returns:
|
||||
ndarray: Colored image which has the same size and dtype as input.
|
||||
"""
|
||||
gray_img = bgr2gray(img)
|
||||
gray_img = np.tile(gray_img[..., None], [1, 1, 3])
|
||||
if beta is None:
|
||||
beta = 1 - alpha
|
||||
colored_img = cv2.addWeighted(img, alpha, gray_img, beta, gamma)
|
||||
if not colored_img.dtype == np.uint8:
|
||||
# Note when the dtype of `img` is not the default `np.uint8`
|
||||
# (e.g. np.float32), the value in `colored_img` got from cv2
|
||||
# is not guaranteed to be in range [0, 255], so here clip
|
||||
# is needed.
|
||||
colored_img = np.clip(colored_img, 0, 255)
|
||||
return colored_img
|
||||
if backend is None:
|
||||
backend = imread_backend
|
||||
if backend not in ['cv2', 'pillow']:
|
||||
raise ValueError(f'backend: {backend} is not supported.'
|
||||
f"Supported backends are 'cv2', 'pillow'")
|
||||
|
||||
if backend == 'pillow':
|
||||
assert img.dtype == np.uint8, 'Pillow backend only support uint8 type'
|
||||
warnings.warn("Only use 'alpha' for pillow backend.")
|
||||
# Image.fromarray defaultly supports RGB, not BGR.
|
||||
pil_image = Image.fromarray(img[..., ::-1], mode='RGB')
|
||||
enhancer = ImageEnhance.Color(pil_image)
|
||||
pil_image = enhancer.enhance(alpha)
|
||||
return np.array(pil_image, dtype=img.dtype)[..., ::-1]
|
||||
else:
|
||||
gray_img = bgr2gray(img)
|
||||
gray_img = np.tile(gray_img[..., None], [1, 1, 3])
|
||||
if beta is None:
|
||||
beta = 1 - alpha
|
||||
colored_img = cv2.addWeighted(img, alpha, gray_img, beta, gamma)
|
||||
if not colored_img.dtype == np.uint8:
|
||||
# Note when the dtype of `img` is not the default `np.uint8`
|
||||
# (e.g. np.float32), the value in `colored_img` got from cv2
|
||||
# is not guaranteed to be in range [0, 255], so here clip
|
||||
# is needed.
|
||||
colored_img = np.clip(colored_img, 0, 255)
|
||||
return colored_img.astype(img.dtype)
|
||||
|
||||
|
||||
def imequalize(img):
|
||||
|
@ -173,7 +197,7 @@ def imequalize(img):
|
|||
return equalized_img.astype(img.dtype)
|
||||
|
||||
|
||||
def adjust_brightness(img, factor=1.):
|
||||
def adjust_brightness(img, factor=1., backend=None):
|
||||
"""Adjust image brightness.
|
||||
|
||||
This function controls the brightness of an image. An
|
||||
|
@ -190,22 +214,40 @@ def adjust_brightness(img, factor=1.):
|
|||
Factor 1.0 returns the original image, lower
|
||||
factors mean less color (brightness, contrast,
|
||||
etc), and higher values more. Default 1.
|
||||
backend (str | None): The image processing backend type. Options are
|
||||
`cv2`, `pillow`, `None`. If backend is None, the global
|
||||
``imread_backend`` specified by ``mmcv.use_backend()`` will be
|
||||
used. Defaults to None.
|
||||
|
||||
Returns:
|
||||
ndarray: The brightened image.
|
||||
"""
|
||||
degenerated = np.zeros_like(img)
|
||||
# Note manually convert the dtype to np.float32, to
|
||||
# achieve as close results as PIL.ImageEnhance.Brightness.
|
||||
# Set beta=1-factor, and gamma=0
|
||||
brightened_img = cv2.addWeighted(
|
||||
img.astype(np.float32), factor, degenerated.astype(np.float32),
|
||||
1 - factor, 0)
|
||||
brightened_img = np.clip(brightened_img, 0, 255)
|
||||
return brightened_img.astype(img.dtype)
|
||||
if backend is None:
|
||||
backend = imread_backend
|
||||
if backend not in ['cv2', 'pillow']:
|
||||
raise ValueError(f'backend: {backend} is not supported.'
|
||||
f"Supported backends are 'cv2', 'pillow'")
|
||||
|
||||
if backend == 'pillow':
|
||||
assert img.dtype == np.uint8, 'Pillow backend only support uint8 type'
|
||||
# Image.fromarray defaultly supports RGB, not BGR.
|
||||
pil_image = Image.fromarray(img[..., ::-1], mode='RGB')
|
||||
enhancer = ImageEnhance.Brightness(pil_image)
|
||||
pil_image = enhancer.enhance(factor)
|
||||
return np.array(pil_image, dtype=img.dtype)[..., ::-1]
|
||||
else:
|
||||
degenerated = np.zeros_like(img)
|
||||
# Note manually convert the dtype to np.float32, to
|
||||
# achieve as close results as PIL.ImageEnhance.Brightness.
|
||||
# Set beta=1-factor, and gamma=0
|
||||
brightened_img = cv2.addWeighted(
|
||||
img.astype(np.float32), factor, degenerated.astype(np.float32),
|
||||
1 - factor, 0)
|
||||
brightened_img = np.clip(brightened_img, 0, 255)
|
||||
return brightened_img.astype(img.dtype)
|
||||
|
||||
|
||||
def adjust_contrast(img, factor=1.):
|
||||
def adjust_contrast(img, factor=1., backend=None):
|
||||
"""Adjust image contrast.
|
||||
|
||||
This function controls the contrast of an image. An
|
||||
|
@ -219,20 +261,38 @@ def adjust_contrast(img, factor=1.):
|
|||
Args:
|
||||
img (ndarray): Image to be contrasted. BGR order.
|
||||
factor (float): Same as :func:`mmcv.adjust_brightness`.
|
||||
backend (str | None): The image processing backend type. Options are
|
||||
`cv2`, `pillow`, `None`. If backend is None, the global
|
||||
``imread_backend`` specified by ``mmcv.use_backend()`` will be
|
||||
used. Defaults to None.
|
||||
|
||||
Returns:
|
||||
ndarray: The contrasted image.
|
||||
"""
|
||||
gray_img = bgr2gray(img)
|
||||
hist = np.histogram(gray_img, 256, (0, 255))[0]
|
||||
mean = round(np.sum(gray_img) / np.sum(hist))
|
||||
degenerated = (np.ones_like(img[..., 0]) * mean).astype(img.dtype)
|
||||
degenerated = gray2bgr(degenerated)
|
||||
contrasted_img = cv2.addWeighted(
|
||||
img.astype(np.float32), factor, degenerated.astype(np.float32),
|
||||
1 - factor, 0)
|
||||
contrasted_img = np.clip(contrasted_img, 0, 255)
|
||||
return contrasted_img.astype(img.dtype)
|
||||
if backend is None:
|
||||
backend = imread_backend
|
||||
if backend not in ['cv2', 'pillow']:
|
||||
raise ValueError(f'backend: {backend} is not supported.'
|
||||
f"Supported backends are 'cv2', 'pillow'")
|
||||
|
||||
if backend == 'pillow':
|
||||
assert img.dtype == np.uint8, 'Pillow backend only support uint8 type'
|
||||
# Image.fromarray defaultly supports RGB, not BGR.
|
||||
pil_image = Image.fromarray(img[..., ::-1], mode='RGB')
|
||||
enhancer = ImageEnhance.Contrast(pil_image)
|
||||
pil_image = enhancer.enhance(factor)
|
||||
return np.array(pil_image, dtype=img.dtype)[..., ::-1]
|
||||
else:
|
||||
gray_img = bgr2gray(img)
|
||||
hist = np.histogram(gray_img, 256, (0, 255))[0]
|
||||
mean = round(np.sum(gray_img) / np.sum(hist))
|
||||
degenerated = (np.ones_like(img[..., 0]) * mean).astype(img.dtype)
|
||||
degenerated = gray2bgr(degenerated)
|
||||
contrasted_img = cv2.addWeighted(
|
||||
img.astype(np.float32), factor, degenerated.astype(np.float32),
|
||||
1 - factor, 0)
|
||||
contrasted_img = np.clip(contrasted_img, 0, 255)
|
||||
return contrasted_img.astype(img.dtype)
|
||||
|
||||
|
||||
def auto_contrast(img, cutoff=0):
|
||||
|
@ -428,7 +488,9 @@ def clahe(img, clip_limit=40.0, tile_grid_size=(8, 8)):
|
|||
return clahe.apply(np.array(img, dtype=np.uint8))
|
||||
|
||||
|
||||
def adjust_hue(img: np.ndarray, hue_factor: float) -> np.ndarray:
|
||||
def adjust_hue(img: np.ndarray,
|
||||
hue_factor: float,
|
||||
backend: Optional[str] = None) -> np.ndarray:
|
||||
"""Adjust hue of an image.
|
||||
|
||||
The image hue is adjusted by converting the image to HSV and cyclically
|
||||
|
@ -449,23 +511,51 @@ def adjust_hue(img: np.ndarray, hue_factor: float) -> np.ndarray:
|
|||
HSV space in positive and negative direction respectively.
|
||||
0 means no shift. Therefore, both -0.5 and 0.5 will give an image
|
||||
with complementary colors while 0 gives the original image.
|
||||
backend (str | None): The image processing backend type. Options are
|
||||
`cv2`, `pillow`, `None`. If backend is None, the global
|
||||
``imread_backend`` specified by ``mmcv.use_backend()`` will be
|
||||
used. Defaults to None.
|
||||
|
||||
Returns:
|
||||
ndarray: Hue adjusted image.
|
||||
"""
|
||||
if backend is None:
|
||||
backend = imread_backend
|
||||
if backend not in ['cv2', 'pillow']:
|
||||
raise ValueError(f'backend: {backend} is not supported.'
|
||||
f"Supported backends are 'cv2', 'pillow'")
|
||||
|
||||
if not (-0.5 <= hue_factor <= 0.5):
|
||||
raise ValueError(f'hue_factor:{hue_factor} is not in [-0.5, 0.5].')
|
||||
if not (isinstance(img, np.ndarray) and (img.ndim in {2, 3})):
|
||||
raise TypeError('img should be ndarray with dim=[2 or 3].')
|
||||
|
||||
dtype = img.dtype
|
||||
img = img.astype(np.uint8)
|
||||
hsv_img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV_FULL)
|
||||
h, s, v = cv2.split(hsv_img)
|
||||
h = h.astype(np.uint8)
|
||||
# uint8 addition take cares of rotation across boundaries
|
||||
with np.errstate(over='ignore'):
|
||||
h += np.uint8(hue_factor * 255)
|
||||
hsv_img = cv2.merge([h, s, v])
|
||||
return cv2.cvtColor(hsv_img, cv2.COLOR_HSV2RGB_FULL).astype(dtype)
|
||||
if backend == 'pillow':
|
||||
assert img.dtype == np.uint8, 'Pillow backend only support uint8 type'
|
||||
# Image.fromarray defaultly supports RGB, not BGR.
|
||||
pil_image = Image.fromarray(img[..., ::-1], mode='RGB')
|
||||
input_mode = pil_image.mode
|
||||
if input_mode in {'L', '1', 'I', 'F'}:
|
||||
return pil_image
|
||||
|
||||
h, s, v = pil_image.convert('HSV').split()
|
||||
|
||||
np_h = np.array(h, dtype=np.uint8)
|
||||
# uint8 addition take cares of rotation across boundaries
|
||||
with np.errstate(over='ignore'):
|
||||
np_h += np.uint8(hue_factor * 255)
|
||||
h = Image.fromarray(np_h, 'L')
|
||||
|
||||
pil_image = Image.merge('HSV', (h, s, v)).convert(input_mode)
|
||||
return np.array(pil_image, dtype=img.dtype)[..., ::-1]
|
||||
else:
|
||||
dtype = img.dtype
|
||||
img = img.astype(np.uint8)
|
||||
hsv_img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV_FULL)
|
||||
h, s, v = cv2.split(hsv_img)
|
||||
h = h.astype(np.uint8)
|
||||
# uint8 addition take cares of rotation across boundaries
|
||||
with np.errstate(over='ignore'):
|
||||
h += np.uint8(hue_factor * 255)
|
||||
hsv_img = cv2.merge([h, s, v])
|
||||
return cv2.cvtColor(hsv_img, cv2.COLOR_HSV2BGR_FULL).astype(dtype)
|
||||
|
|
|
@ -77,7 +77,7 @@ class TestPhotometric:
|
|||
dtype=np.uint8)
|
||||
assert_array_equal(mmcv.posterize(img, 3), img_r)
|
||||
|
||||
def test_adjust_color(self):
|
||||
def test_adjust_color(self, nb_rand_test=100):
|
||||
img = np.array([[0, 128, 255], [1, 127, 254], [2, 129, 253]],
|
||||
dtype=np.uint8)
|
||||
img = np.stack([img, img, img], axis=-1)
|
||||
|
@ -108,6 +108,23 @@ class TestPhotometric:
|
|||
np.round(mmcv.adjust_color(img, 0.8, -0.6, gamma=-0.6)),
|
||||
np.round(np.clip(img * 0.8 - 0.6 * img_r - 0.6, 0, 255)))
|
||||
|
||||
# test equalize with randomly sampled image.
|
||||
for _ in range(nb_rand_test):
|
||||
img = np.clip(np.random.normal(0, 1, (256, 256, 3)) * 260, 0,
|
||||
255).astype(np.uint8)
|
||||
factor = np.random.uniform()
|
||||
cv2_img = mmcv.adjust_color(img, alpha=factor)
|
||||
pil_img = mmcv.adjust_color(img, alpha=factor, backend='pillow')
|
||||
np.testing.assert_allclose(cv2_img, pil_img, rtol=0, atol=2)
|
||||
|
||||
# the input type must be uint8 for pillow backend
|
||||
with pytest.raises(AssertionError):
|
||||
mmcv.adjust_color(img.astype(np.float32), backend='pillow')
|
||||
|
||||
# backend must be 'cv2' or 'pillow'
|
||||
with pytest.raises(ValueError):
|
||||
mmcv.adjust_color(img.astype(np.uint8), backend='not support')
|
||||
|
||||
def test_imequalize(self, nb_rand_test=100):
|
||||
|
||||
def _imequalize(img):
|
||||
|
@ -138,15 +155,6 @@ class TestPhotometric:
|
|||
|
||||
def test_adjust_brightness(self, nb_rand_test=100):
|
||||
|
||||
def _adjust_brightness(img, factor):
|
||||
# adjust the brightness of image using
|
||||
# PIL.ImageEnhance.Brightness
|
||||
from PIL import Image
|
||||
from PIL.ImageEnhance import Brightness
|
||||
img = Image.fromarray(img)
|
||||
brightened_img = Brightness(img).enhance(factor)
|
||||
return np.asarray(brightened_img)
|
||||
|
||||
img = np.array([[0, 128, 255], [1, 127, 254], [2, 129, 253]],
|
||||
dtype=np.uint8)
|
||||
img = np.stack([img, img, img], axis=-1)
|
||||
|
@ -162,23 +170,21 @@ class TestPhotometric:
|
|||
factor = np.random.uniform() + np.random.choice([0, 1])
|
||||
np.testing.assert_allclose(
|
||||
mmcv.adjust_brightness(img, factor).astype(np.int32),
|
||||
_adjust_brightness(img, factor).astype(np.int32),
|
||||
mmcv.adjust_brightness(img, factor,
|
||||
backend='pillow').astype(np.int32),
|
||||
rtol=0,
|
||||
atol=1)
|
||||
|
||||
# the input type must be uint8 for pillow backend
|
||||
with pytest.raises(AssertionError):
|
||||
mmcv.adjust_brightness(img.astype(np.float32), backend='pillow')
|
||||
|
||||
# backend must be 'cv2' or 'pillow'
|
||||
with pytest.raises(ValueError):
|
||||
mmcv.adjust_brightness(img.astype(np.uint8), backend='not support')
|
||||
|
||||
def test_adjust_contrast(self, nb_rand_test=100):
|
||||
|
||||
def _adjust_contrast(img, factor):
|
||||
from PIL import Image
|
||||
from PIL.ImageEnhance import Contrast
|
||||
|
||||
# Image.fromarray defaultly supports RGB, not BGR.
|
||||
# convert from BGR to RGB
|
||||
img = Image.fromarray(img[..., ::-1], mode='RGB')
|
||||
contrasted_img = Contrast(img).enhance(factor)
|
||||
# convert from RGB to BGR
|
||||
return np.asarray(contrasted_img)[..., ::-1]
|
||||
|
||||
img = np.array([[0, 128, 255], [1, 127, 254], [2, 129, 253]],
|
||||
dtype=np.uint8)
|
||||
img = np.stack([img, img, img], axis=-1)
|
||||
|
@ -186,7 +192,8 @@ class TestPhotometric:
|
|||
assert_array_equal(mmcv.adjust_contrast(img, 1.), img)
|
||||
# test case with factor 0.0
|
||||
assert_array_equal(
|
||||
mmcv.adjust_contrast(img, 0.), _adjust_contrast(img, 0.))
|
||||
mmcv.adjust_contrast(img, 0.),
|
||||
mmcv.adjust_contrast(img, 0., backend='pillow'))
|
||||
# test adjust_contrast with randomly sampled images and factors.
|
||||
for _ in range(nb_rand_test):
|
||||
img = np.clip(
|
||||
|
@ -198,10 +205,19 @@ class TestPhotometric:
|
|||
# a color image to gray image using mmcv or PIL.
|
||||
np.testing.assert_allclose(
|
||||
mmcv.adjust_contrast(img, factor).astype(np.int32),
|
||||
_adjust_contrast(img, factor).astype(np.int32),
|
||||
mmcv.adjust_contrast(img, factor,
|
||||
backend='pillow').astype(np.int32),
|
||||
rtol=0,
|
||||
atol=1)
|
||||
|
||||
# the input type must be uint8 pillow backend
|
||||
with pytest.raises(AssertionError):
|
||||
mmcv.adjust_contrast(img.astype(np.float32), backend='pillow')
|
||||
|
||||
# backend must be 'cv2' or 'pillow'
|
||||
with pytest.raises(ValueError):
|
||||
mmcv.adjust_contrast(img.astype(np.uint8), backend='not support')
|
||||
|
||||
def test_auto_contrast(self, nb_rand_test=100):
|
||||
|
||||
def _auto_contrast(img, cutoff=0):
|
||||
|
@ -380,24 +396,10 @@ class TestPhotometric:
|
|||
assert id(img_std) != id(self.img[:, :, i])
|
||||
|
||||
def test_adjust_hue(self):
|
||||
# test case with img is not ndarray
|
||||
from PIL import Image
|
||||
|
||||
def _adjust_hue(img, hue_factor):
|
||||
input_mode = img.mode
|
||||
if input_mode in {'L', '1', 'I', 'F'}:
|
||||
return img
|
||||
h, s, v = img.convert('HSV').split()
|
||||
np_h = np.array(h, dtype=np.uint8)
|
||||
# uint8 addition take cares of rotation across boundaries
|
||||
with np.errstate(over='ignore'):
|
||||
np_h += np.uint8(hue_factor * 255)
|
||||
h = Image.fromarray(np_h, 'L')
|
||||
img = Image.merge('HSV', (h, s, v)).convert(input_mode)
|
||||
return img
|
||||
|
||||
pil_img = Image.fromarray(self.img)
|
||||
|
||||
# test case with img is not ndarray
|
||||
with pytest.raises(TypeError):
|
||||
mmcv.adjust_hue(pil_img, hue_factor=0.0)
|
||||
|
||||
|
@ -408,7 +410,17 @@ class TestPhotometric:
|
|||
mmcv.adjust_hue(self.img, hue_factor=0.6)
|
||||
|
||||
for i in np.arange(-0.5, 0.5, 0.2):
|
||||
pil_res = _adjust_hue(pil_img, hue_factor=i)
|
||||
pil_res = mmcv.adjust_hue(self.img, hue_factor=i, backend='pillow')
|
||||
pil_res = np.array(pil_res)
|
||||
cv2_res = mmcv.adjust_hue(self.img, hue_factor=i)
|
||||
assert np.allclose(pil_res, cv2_res, atol=10.0)
|
||||
|
||||
# test pillow backend
|
||||
with pytest.raises(AssertionError):
|
||||
mmcv.adjust_hue(
|
||||
self.img.astype(np.float32), hue_factor=0, backend='pillow')
|
||||
|
||||
# backend must be 'cv2' or 'pillow'
|
||||
with pytest.raises(ValueError):
|
||||
mmcv.adjust_hue(
|
||||
self.img.astype(np.uint8), hue_factor=0, backend='not support')
|
||||
|
|
Loading…
Reference in New Issue