[Feature] Add CLAHE method (#647)

* add CLAHE

* add CLAHE

* restore

* Add docstring

* modify docstring

* modify CLAHE to clahe

* fix syntax error

* simplify assert

* simplify assert

* add assert test

* fix unittest bug

* fix syntax bug

* fix assert bug
pull/651/head
yamengxi 2020-11-11 22:34:14 +08:00 committed by GitHub
parent d9ef9dabe2
commit c6c230df1b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 78 additions and 3 deletions

View File

@ -8,8 +8,9 @@ from .geometric import (imcrop, imflip, imflip_, impad, impad_to_multiple,
from .io import imfrombytes, imread, imwrite, supported_backends, use_backend
from .misc import tensor2imgs
from .photometric import (adjust_brightness, adjust_color, adjust_contrast,
imdenormalize, imequalize, iminvert, imnormalize,
imnormalize_, lut_transform, posterize, solarize)
clahe, imdenormalize, imequalize, iminvert,
imnormalize, imnormalize_, lut_transform, posterize,
solarize)
__all__ = [
'bgr2gray', 'bgr2hls', 'bgr2hsv', 'bgr2rgb', 'gray2bgr', 'gray2rgb',
@ -20,5 +21,5 @@ __all__ = [
'imnormalize', 'imnormalize_', 'iminvert', 'posterize', 'solarize',
'rgb2ycbcr', 'bgr2ycbcr', 'ycbcr2rgb', 'ycbcr2bgr', 'tensor2imgs',
'imshear', 'imtranslate', 'adjust_color', 'imequalize',
'adjust_brightness', 'adjust_contrast', 'lut_transform'
'adjust_brightness', 'adjust_contrast', 'lut_transform', 'clahe'
]

View File

@ -1,6 +1,7 @@
import cv2
import numpy as np
from ..utils import is_tuple_of
from .colorspace import bgr2gray, gray2bgr
@ -248,3 +249,29 @@ def lut_transform(img, lut_table):
assert lut_table.shape == (256, )
return cv2.LUT(np.array(img, dtype=np.uint8), lut_table)
def clahe(img, clip_limit=40.0, tile_grid_size=(8, 8)):
"""Use CLAHE method to process the image.
See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J].
Graphics Gems, 1994:474-485.` for more information.
Args:
img (ndarray): Image to be processed.
clip_limit (float): Threshold for contrast limiting. Default: 40.0.
tile_grid_size (tuple[int]): Size of grid for histogram equalization.
Input image will be divided into equally sized rectangular tiles.
It defines the number of tiles in row and column. Default: (8, 8).
Returns:
ndarray: The processed image.
"""
assert isinstance(img, np.ndarray)
assert img.ndim == 2
assert isinstance(clip_limit, (float, int))
assert is_tuple_of(tile_grid_size, int)
assert len(tile_grid_size) == 2
clahe = cv2.createCLAHE(clip_limit, tile_grid_size)
return clahe.apply(np.array(img, dtype=np.uint8))

View File

@ -3,6 +3,7 @@ import os.path as osp
import cv2
import numpy as np
import pytest
from numpy.testing import assert_array_equal
import mmcv
@ -204,6 +205,18 @@ class TestPhotometric:
def test_lut_transform(self):
lut_table = np.array(list(range(256)))
# test assertion image values should between 0 and 255.
with pytest.raises(AssertionError):
mmcv.lut_transform(np.array([256]), lut_table)
with pytest.raises(AssertionError):
mmcv.lut_transform(np.array([-1]), lut_table)
# test assertion lut_table should be ndarray with shape (256, )
with pytest.raises(AssertionError):
mmcv.lut_transform(np.array([0]), list(range(256)))
with pytest.raises(AssertionError):
mmcv.lut_transform(np.array([1]), np.array(list(range(257))))
img = mmcv.lut_transform(self.img, lut_table)
baseline = cv2.LUT(self.img, lut_table)
assert np.allclose(img, baseline)
@ -219,3 +232,37 @@ class TestPhotometric:
img = mmcv.lut_transform(input_img, lut_table)
baseline = cv2.LUT(np.array(input_img, dtype=np.uint8), lut_table)
assert np.allclose(img, baseline)
def test_clahe(self):
def _clahe(img, clip_limit=40.0, tile_grid_size=(8, 8)):
clahe = cv2.createCLAHE(clip_limit, tile_grid_size)
return clahe.apply(np.array(img, dtype=np.uint8))
# test assertion image should have the right shape
with pytest.raises(AssertionError):
mmcv.clahe(self.img)
# test assertion tile_grid_size should be a tuple with 2 integers
with pytest.raises(AssertionError):
mmcv.clahe(self.img[:, :, 0], tile_grid_size=(8.0, 8.0))
with pytest.raises(AssertionError):
mmcv.clahe(self.img[:, :, 0], tile_grid_size=(8, 8, 8))
with pytest.raises(AssertionError):
mmcv.clahe(self.img[:, :, 0], tile_grid_size=[8, 8])
# test with different channels
for i in range(self.img.shape[-1]):
img = mmcv.clahe(self.img[:, :, i])
img_std = _clahe(self.img[:, :, i])
assert np.allclose(img, img_std)
assert id(img) != id(self.img[:, :, i])
assert id(img_std) != id(self.img[:, :, i])
# test case with clip_limit=1.2
for i in range(self.img.shape[-1]):
img = mmcv.clahe(self.img[:, :, i], 1.2)
img_std = _clahe(self.img[:, :, i], 1.2)
assert np.allclose(img, img_std)
assert id(img) != id(self.img[:, :, i])
assert id(img_std) != id(self.img[:, :, i])