From c6c230df1b780976ee99f59e4644941967db39f9 Mon Sep 17 00:00:00 2001 From: yamengxi <49829199+yamengxi@users.noreply.github.com> Date: Wed, 11 Nov 2020 22:34:14 +0800 Subject: [PATCH] [Feature] Add CLAHE method (#647) * add CLAHE * add CLAHE * restore * Add docstring * modify docstring * modify CLAHE to clahe * fix syntax error * simplify assert * simplify assert * add assert test * fix unittest bug * fix syntax bug * fix assert bug --- mmcv/image/__init__.py | 7 +++-- mmcv/image/photometric.py | 27 ++++++++++++++++ tests/test_image/test_photometric.py | 47 ++++++++++++++++++++++++++++ 3 files changed, 78 insertions(+), 3 deletions(-) diff --git a/mmcv/image/__init__.py b/mmcv/image/__init__.py index ba789bf62..06d0c9d88 100644 --- a/mmcv/image/__init__.py +++ b/mmcv/image/__init__.py @@ -8,8 +8,9 @@ from .geometric import (imcrop, imflip, imflip_, impad, impad_to_multiple, from .io import imfrombytes, imread, imwrite, supported_backends, use_backend from .misc import tensor2imgs from .photometric import (adjust_brightness, adjust_color, adjust_contrast, - imdenormalize, imequalize, iminvert, imnormalize, - imnormalize_, lut_transform, posterize, solarize) + clahe, imdenormalize, imequalize, iminvert, + imnormalize, imnormalize_, lut_transform, posterize, + solarize) __all__ = [ 'bgr2gray', 'bgr2hls', 'bgr2hsv', 'bgr2rgb', 'gray2bgr', 'gray2rgb', @@ -20,5 +21,5 @@ __all__ = [ 'imnormalize', 'imnormalize_', 'iminvert', 'posterize', 'solarize', 'rgb2ycbcr', 'bgr2ycbcr', 'ycbcr2rgb', 'ycbcr2bgr', 'tensor2imgs', 'imshear', 'imtranslate', 'adjust_color', 'imequalize', - 'adjust_brightness', 'adjust_contrast', 'lut_transform' + 'adjust_brightness', 'adjust_contrast', 'lut_transform', 'clahe' ] diff --git a/mmcv/image/photometric.py b/mmcv/image/photometric.py index 93427e220..f0279274b 100644 --- a/mmcv/image/photometric.py +++ b/mmcv/image/photometric.py @@ -1,6 +1,7 @@ import cv2 import numpy as np +from ..utils import is_tuple_of from .colorspace import bgr2gray, gray2bgr @@ -248,3 +249,29 @@ def lut_transform(img, lut_table): assert lut_table.shape == (256, ) return cv2.LUT(np.array(img, dtype=np.uint8), lut_table) + + +def clahe(img, clip_limit=40.0, tile_grid_size=(8, 8)): + """Use CLAHE method to process the image. + + See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J]. + Graphics Gems, 1994:474-485.` for more information. + + Args: + img (ndarray): Image to be processed. + clip_limit (float): Threshold for contrast limiting. Default: 40.0. + tile_grid_size (tuple[int]): Size of grid for histogram equalization. + Input image will be divided into equally sized rectangular tiles. + It defines the number of tiles in row and column. Default: (8, 8). + + Returns: + ndarray: The processed image. + """ + assert isinstance(img, np.ndarray) + assert img.ndim == 2 + assert isinstance(clip_limit, (float, int)) + assert is_tuple_of(tile_grid_size, int) + assert len(tile_grid_size) == 2 + + clahe = cv2.createCLAHE(clip_limit, tile_grid_size) + return clahe.apply(np.array(img, dtype=np.uint8)) diff --git a/tests/test_image/test_photometric.py b/tests/test_image/test_photometric.py index e68adc34c..70185d94d 100644 --- a/tests/test_image/test_photometric.py +++ b/tests/test_image/test_photometric.py @@ -3,6 +3,7 @@ import os.path as osp import cv2 import numpy as np +import pytest from numpy.testing import assert_array_equal import mmcv @@ -204,6 +205,18 @@ class TestPhotometric: def test_lut_transform(self): lut_table = np.array(list(range(256))) + # test assertion image values should between 0 and 255. + with pytest.raises(AssertionError): + mmcv.lut_transform(np.array([256]), lut_table) + with pytest.raises(AssertionError): + mmcv.lut_transform(np.array([-1]), lut_table) + + # test assertion lut_table should be ndarray with shape (256, ) + with pytest.raises(AssertionError): + mmcv.lut_transform(np.array([0]), list(range(256))) + with pytest.raises(AssertionError): + mmcv.lut_transform(np.array([1]), np.array(list(range(257)))) + img = mmcv.lut_transform(self.img, lut_table) baseline = cv2.LUT(self.img, lut_table) assert np.allclose(img, baseline) @@ -219,3 +232,37 @@ class TestPhotometric: img = mmcv.lut_transform(input_img, lut_table) baseline = cv2.LUT(np.array(input_img, dtype=np.uint8), lut_table) assert np.allclose(img, baseline) + + def test_clahe(self): + + def _clahe(img, clip_limit=40.0, tile_grid_size=(8, 8)): + clahe = cv2.createCLAHE(clip_limit, tile_grid_size) + return clahe.apply(np.array(img, dtype=np.uint8)) + + # test assertion image should have the right shape + with pytest.raises(AssertionError): + mmcv.clahe(self.img) + + # test assertion tile_grid_size should be a tuple with 2 integers + with pytest.raises(AssertionError): + mmcv.clahe(self.img[:, :, 0], tile_grid_size=(8.0, 8.0)) + with pytest.raises(AssertionError): + mmcv.clahe(self.img[:, :, 0], tile_grid_size=(8, 8, 8)) + with pytest.raises(AssertionError): + mmcv.clahe(self.img[:, :, 0], tile_grid_size=[8, 8]) + + # test with different channels + for i in range(self.img.shape[-1]): + img = mmcv.clahe(self.img[:, :, i]) + img_std = _clahe(self.img[:, :, i]) + assert np.allclose(img, img_std) + assert id(img) != id(self.img[:, :, i]) + assert id(img_std) != id(self.img[:, :, i]) + + # test case with clip_limit=1.2 + for i in range(self.img.shape[-1]): + img = mmcv.clahe(self.img[:, :, i], 1.2) + img_std = _clahe(self.img[:, :, i], 1.2) + assert np.allclose(img, img_std) + assert id(img) != id(self.img[:, :, i]) + assert id(img_std) != id(self.img[:, :, i])