diff --git a/mmcv/image/__init__.py b/mmcv/image/__init__.py index dda6acae4..dd985c54f 100644 --- a/mmcv/image/__init__.py +++ b/mmcv/image/__init__.py @@ -1,18 +1,19 @@ # Copyright (c) Open-MMLab. All rights reserved. from .colorspace import (bgr2gray, bgr2hls, bgr2hsv, bgr2rgb, gray2bgr, - gray2rgb, hls2bgr, hsv2bgr, iminvert, posterize, - rgb2bgr, rgb2gray, solarize) -from .geometry import (imcrop, imflip, imflip_, impad, impad_to_multiple, - imrotate) + gray2rgb, hls2bgr, hsv2bgr, imconvert, rgb2bgr, + rgb2gray) +from .geometric import (imcrop, imflip, imflip_, impad, impad_to_multiple, + imrescale, imresize, imresize_like, imrotate, + rescale_size) from .io import imfrombytes, imread, imwrite, supported_backends, use_backend -from .normalize import imdenormalize, imnormalize, imnormalize_ -from .resize import imrescale, imresize, imresize_like, rescale_size +from .photometric import (imdenormalize, iminvert, imnormalize, imnormalize_, + posterize, solarize) __all__ = [ - 'solarize', 'posterize', 'imread', 'imwrite', 'imfrombytes', 'bgr2gray', - 'rgb2gray', 'gray2bgr', 'gray2rgb', 'bgr2rgb', 'rgb2bgr', 'bgr2hsv', - 'hsv2bgr', 'bgr2hls', 'hls2bgr', 'iminvert', 'imflip', 'imflip_', - 'imrotate', 'imcrop', 'impad', 'impad_to_multiple', 'imnormalize', - 'imnormalize_', 'imdenormalize', 'imresize', 'imresize_like', 'imrescale', - 'use_backend', 'supported_backends', 'rescale_size' + 'bgr2gray', 'bgr2hls', 'bgr2hsv', 'bgr2rgb', 'gray2bgr', 'gray2rgb', + 'hls2bgr', 'hsv2bgr', 'imconvert', 'rgb2bgr', 'rgb2gray', 'imrescale', + 'imresize', 'imresize_like', 'rescale_size', 'imcrop', 'imflip', 'imflip_', + 'impad', 'impad_to_multiple', 'imrotate', 'imfrombytes', 'imread', + 'imwrite', 'supported_backends', 'use_backend', 'imdenormalize', + 'imnormalize', 'imnormalize_', 'iminvert', 'posterize', 'solarize' ] diff --git a/mmcv/image/colorspace.py b/mmcv/image/colorspace.py index f0d5a503c..ac001a509 100644 --- a/mmcv/image/colorspace.py +++ b/mmcv/image/colorspace.py @@ -1,46 +1,21 @@ # Copyright (c) Open-MMLab. All rights reserved. import cv2 -import numpy as np -def solarize(img, thr=128): - """Solarize an image (invert all pixel values above a threshold) +def imconvert(img, src, dst): + """Convert an image from the src colorspace to dst colorspace. Args: - img (ndarray): Image to be solarized. - thr (int): Threshold for solarizing (0 - 255). + img (ndarray): The input image. + src (str): The source colorspace, e.g., 'rgb', 'hsv'. + dst (str): The destination colorspace, e.g., 'rgb', 'hsv'. Returns: - ndarray: The solarized image. + ndarray: The converted image. """ - img = np.where(img < thr, img, 255 - img) - return img - - -def posterize(img, bits): - """Posterize an image (reduce the number of bits for each color channel) - - Args: - img (ndarray): Image to be posterized. - bits (int): Number of bits (1 to 8) to use for posterizing. - - Returns: - ndarray: The posterized image. - """ - shift = 8 - bits - img = np.left_shift(np.right_shift(img, shift), shift) - return img - - -def iminvert(img): - """Invert (negate) an image - Args: - img (ndarray): Image to be inverted. - - Returns: - ndarray: The inverted image. - """ - return np.full_like(img, 255) - img + code = getattr(cv2, 'COLOR_{}2{}'.format(src.upper(), dst.upper())) + out_img = cv2.cvtColor(img, code) + return out_img def bgr2gray(img, keepdim=False): diff --git a/mmcv/image/geometry.py b/mmcv/image/geometric.py similarity index 55% rename from mmcv/image/geometry.py rename to mmcv/image/geometric.py index e3947b428..325bd430e 100644 --- a/mmcv/image/geometry.py +++ b/mmcv/image/geometric.py @@ -1,10 +1,142 @@ # Copyright (c) Open-MMLab. All rights reserved. -from __future__ import division - import cv2 import numpy as np +def _scale_size(size, scale): + """Rescale a size by a ratio. + + Args: + size (tuple[int]): (w, h). + scale (float): Scaling factor. + + Returns: + tuple[int]: scaled size. + """ + w, h = size + return int(w * float(scale) + 0.5), int(h * float(scale) + 0.5) + + +interp_codes = { + 'nearest': cv2.INTER_NEAREST, + 'bilinear': cv2.INTER_LINEAR, + 'bicubic': cv2.INTER_CUBIC, + 'area': cv2.INTER_AREA, + 'lanczos': cv2.INTER_LANCZOS4 +} + + +def imresize(img, + size, + return_scale=False, + interpolation='bilinear', + out=None): + """Resize image to a given size. + + Args: + img (ndarray): The input image. + size (tuple[int]): Target size (w, h). + return_scale (bool): Whether to return `w_scale` and `h_scale`. + interpolation (str): Interpolation method, accepted values are + "nearest", "bilinear", "bicubic", "area", "lanczos". + out (ndarray): The output destination. + + Returns: + tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or + `resized_img`. + """ + h, w = img.shape[:2] + resized_img = cv2.resize( + img, size, dst=out, interpolation=interp_codes[interpolation]) + if not return_scale: + return resized_img + else: + w_scale = size[0] / w + h_scale = size[1] / h + return resized_img, w_scale, h_scale + + +def imresize_like(img, dst_img, return_scale=False, interpolation='bilinear'): + """Resize image to the same size of a given image. + + Args: + img (ndarray): The input image. + dst_img (ndarray): The target image. + return_scale (bool): Whether to return `w_scale` and `h_scale`. + interpolation (str): Same as :func:`resize`. + + Returns: + tuple or ndarray: (`resized_img`, `w_scale`, `h_scale`) or + `resized_img`. + """ + h, w = dst_img.shape[:2] + return imresize(img, (w, h), return_scale, interpolation) + + +def rescale_size(old_size, scale, return_scale=False): + """Calculate the new size to be rescaled to. + + Args: + old_size (tuple[int]): The old size (w, h) of image. + scale (float | tuple[int]): The scaling factor or maximum size. + If it is a float number, then the image will be rescaled by this + factor, else if it is a tuple of 2 integers, then the image will + be rescaled as large as possible within the scale. + return_scale (bool): Whether to return the scaling factor besides the + rescaled image size. + + Returns: + tuple[int]: The new rescaled image size. + """ + w, h = old_size + if isinstance(scale, (float, int)): + if scale <= 0: + raise ValueError( + 'Invalid scale {}, must be positive.'.format(scale)) + scale_factor = scale + elif isinstance(scale, tuple): + max_long_edge = max(scale) + max_short_edge = min(scale) + scale_factor = min(max_long_edge / max(h, w), + max_short_edge / min(h, w)) + else: + raise TypeError( + 'Scale must be a number or tuple of int, but got {}'.format( + type(scale))) + + new_size = _scale_size((w, h), scale_factor) + + if return_scale: + return new_size, scale_factor + else: + return new_size + + +def imrescale(img, scale, return_scale=False, interpolation='bilinear'): + """Resize image while keeping the aspect ratio. + + Args: + img (ndarray): The input image. + scale (float | tuple[int]): The scaling factor or maximum size. + If it is a float number, then the image will be rescaled by this + factor, else if it is a tuple of 2 integers, then the image will + be rescaled as large as possible within the scale. + return_scale (bool): Whether to return the scaling factor besides the + rescaled image. + interpolation (str): Same as :func:`resize`. + + Returns: + ndarray: The rescaled image. + """ + h, w = img.shape[:2] + new_size, scale_factor = rescale_size((w, h), scale, return_scale=True) + rescaled_img = imresize(img, new_size, interpolation=interpolation) + if return_scale: + return rescaled_img, scale_factor + else: + return rescaled_img + + def imflip(img, direction='horizontal'): """Flip an image horizontally or vertically. @@ -24,12 +156,13 @@ def imflip(img, direction='horizontal'): def imflip_(img, direction='horizontal'): """Inplace flip an image horizontally or vertically. + Args: img (ndarray): Image to be flipped. direction (str): The flip direction, either "horizontal" or "vertical". Returns: - ndarray: The flipped image(inplace). + ndarray: The flipped image (inplace). """ assert direction in ['horizontal', 'vertical'] if direction == 'horizontal': @@ -50,8 +183,9 @@ def imrotate(img, img (ndarray): Image to be rotated. angle (float): Rotation angle in degrees, positive values mean clockwise rotation. - center (tuple): Center of the rotation in the source image, by default - it is the center of the image. + center (tuple[float], optional): Center point (w, h) of the rotation in + the source image. If not specified, the center of the image will be + used. scale (float): Isotropic scale factor. border_value (int): Border value. auto_bound (bool): Whether to adjust the image size to cover the whole @@ -86,7 +220,7 @@ def bbox_clip(bboxes, img_shape): Args: bboxes (ndarray): Shape (..., 4*k) - img_shape (tuple): (height, width) of the image. + img_shape (tuple[int]): (height, width) of the image. Returns: ndarray: Clipped bboxes. @@ -105,7 +239,7 @@ def bbox_scaling(bboxes, scale, clip_shape=None): Args: bboxes (ndarray): Shape(..., 4). scale (float): Scaling factor. - clip_shape (tuple, optional): If specified, bboxes that exceed the + clip_shape (tuple[int], optional): If specified, bboxes that exceed the boundary will be clipped according to the given shape (h, w). Returns: @@ -135,11 +269,11 @@ def imcrop(img, bboxes, scale=1.0, pad_fill=None): bboxes (ndarray): Shape (k, 4) or (4, ), location of cropped bboxes. scale (float, optional): Scale ratio of bboxes, the default value 1.0 means no padding. - pad_fill (number or list): Value to be filled for padding, None for - no padding. + pad_fill (Number | list[Number]): Value to be filled for padding. + Default: None, which means no padding. Returns: - list or ndarray: The cropped image patches. + list[ndarray] | ndarray: The cropped image patches. """ chn = 1 if img.ndim == 2 else img.shape[2] if pad_fill is not None: @@ -184,8 +318,9 @@ def impad(img, shape, pad_val=0): Args: img (ndarray): Image to be padded. - shape (tuple): Expected padding shape. - pad_val (number or sequence): Values to be filled in padding areas. + shape (tuple[int]): Expected padding shape (h, w). + pad_val (Number | Sequence[Number]): Values to be filled in padding + areas. Default: 0. Returns: ndarray: The padded image. @@ -209,7 +344,7 @@ def impad_to_multiple(img, divisor, pad_val=0): Args: img (ndarray): Image to be padded. divisor (int): Padded image edges will be multiple to divisor. - pad_val (number or sequence): Same as :func:`impad`. + pad_val (Number | Sequence[Number]): Same as :func:`impad`. Returns: ndarray: The padded image. diff --git a/mmcv/image/normalize.py b/mmcv/image/photometric.py similarity index 63% rename from mmcv/image/normalize.py rename to mmcv/image/photometric.py index e735fb03b..a7b345a76 100644 --- a/mmcv/image/normalize.py +++ b/mmcv/image/photometric.py @@ -1,4 +1,3 @@ -# Copyright (c) Open-MMLab. All rights reserved. import cv2 import numpy as np @@ -51,3 +50,44 @@ def imdenormalize(img, mean, std, to_bgr=True): if to_bgr: cv2.cvtColor(img, cv2.COLOR_RGB2BGR, img) # inplace return img + + +def iminvert(img): + """Invert (negate) an image + + Args: + img (ndarray): Image to be inverted. + + Returns: + ndarray: The inverted image. + """ + return np.full_like(img, 255) - img + + +def solarize(img, thr=128): + """Solarize an image (invert all pixel values above a threshold) + + Args: + img (ndarray): Image to be solarized. + thr (int): Threshold for solarizing (0 - 255). + + Returns: + ndarray: The solarized image. + """ + img = np.where(img < thr, img, 255 - img) + return img + + +def posterize(img, bits): + """Posterize an image (reduce the number of bits for each color channel) + + Args: + img (ndarray): Image to be posterized. + bits (int): Number of bits (1 to 8) to use for posterizing. + + Returns: + ndarray: The posterized image. + """ + shift = 8 - bits + img = np.left_shift(np.right_shift(img, shift), shift) + return img diff --git a/mmcv/image/resize.py b/mmcv/image/resize.py deleted file mode 100644 index 55e92ff92..000000000 --- a/mmcv/image/resize.py +++ /dev/null @@ -1,138 +0,0 @@ -# Copyright (c) Open-MMLab. All rights reserved. -from __future__ import division - -import cv2 - - -def _scale_size(size, scale): - """Rescale a size by a ratio. - - Args: - size (tuple): w, h. - scale (float): Scaling factor. - - Returns: - tuple[int]: scaled size. - """ - w, h = size - return int(w * float(scale) + 0.5), int(h * float(scale) + 0.5) - - -interp_codes = { - 'nearest': cv2.INTER_NEAREST, - 'bilinear': cv2.INTER_LINEAR, - 'bicubic': cv2.INTER_CUBIC, - 'area': cv2.INTER_AREA, - 'lanczos': cv2.INTER_LANCZOS4 -} - - -def imresize(img, - size, - return_scale=False, - interpolation='bilinear', - out=None): - """Resize image to a given size. - - Args: - img (ndarray): The input image. - size (tuple): Target (w, h). - return_scale (bool): Whether to return `w_scale` and `h_scale`. - interpolation (str): Interpolation method, accepted values are - "nearest", "bilinear", "bicubic", "area", "lanczos". - out (ndarray): The output destination. - - Returns: - tuple or ndarray: (`resized_img`, `w_scale`, `h_scale`) or - `resized_img`. - """ - h, w = img.shape[:2] - resized_img = cv2.resize( - img, size, dst=out, interpolation=interp_codes[interpolation]) - if not return_scale: - return resized_img - else: - w_scale = size[0] / w - h_scale = size[1] / h - return resized_img, w_scale, h_scale - - -def imresize_like(img, dst_img, return_scale=False, interpolation='bilinear'): - """Resize image to the same size of a given image. - - Args: - img (ndarray): The input image. - dst_img (ndarray): The target image. - return_scale (bool): Whether to return `w_scale` and `h_scale`. - interpolation (str): Same as :func:`resize`. - - Returns: - tuple or ndarray: (`resized_img`, `w_scale`, `h_scale`) or - `resized_img`. - """ - h, w = dst_img.shape[:2] - return imresize(img, (w, h), return_scale, interpolation) - - -def rescale_size(old_size, scale, return_scale=False): - """Calculate the new size to be rescaled to. - - Args: - old_size (tuple[int]): The old size of image. - scale (float or tuple[int]): The scaling factor or maximum size. - If it is a float number, then the image will be rescaled by this - factor, else if it is a tuple of 2 integers, then the image will - be rescaled as large as possible within the scale. - return_scale (bool): Whether to return the scaling factor besides the - rescaled image size. - - Returns: - tuple[int]: The new rescaled image size. - """ - w, h = old_size - if isinstance(scale, (float, int)): - if scale <= 0: - raise ValueError( - 'Invalid scale {}, must be positive.'.format(scale)) - scale_factor = scale - elif isinstance(scale, tuple): - max_long_edge = max(scale) - max_short_edge = min(scale) - scale_factor = min(max_long_edge / max(h, w), - max_short_edge / min(h, w)) - else: - raise TypeError( - 'Scale must be a number or tuple of int, but got {}'.format( - type(scale))) - - new_size = _scale_size((w, h), scale_factor) - - if return_scale: - return new_size, scale_factor - else: - return new_size - - -def imrescale(img, scale, return_scale=False, interpolation='bilinear'): - """Resize image while keeping the aspect ratio. - - Args: - img (ndarray): The input image. - scale (float or tuple[int]): The scaling factor or maximum size. - If it is a float number, then the image will be rescaled by this - factor, else if it is a tuple of 2 integers, then the image will - be rescaled as large as possible within the scale. - return_scale (bool): Whether to return the scaling factor besides the - rescaled image. - interpolation (str): Same as :func:`resize`. - - Returns: - ndarray: The rescaled image. - """ - h, w = img.shape[:2] - new_size, scale_factor = rescale_size((w, h), scale, return_scale=True) - rescaled_img = imresize(img, new_size, interpolation=interpolation) - if return_scale: - return rescaled_img, scale_factor - else: - return rescaled_img diff --git a/tests/test_image.py b/tests/test_image.py index 0ab43d31a..64810ec8a 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -12,7 +12,7 @@ from numpy.testing import assert_array_almost_equal, assert_array_equal import mmcv -class TestImage(object): +class TestIO: @classmethod def setup_class(cls): @@ -23,8 +23,6 @@ class TestImage(object): osp.dirname(__file__), 'data/grayscale.jpg') cls.gray_img_path_obj = Path(cls.gray_img_path) cls.img = cv2.imread(cls.img_path) - cls.mean = np.float32(np.array([123.675, 116.28, 103.53])) - cls.std = np.float32(np.array([58.395, 57.12, 57.375])) def assert_img_equal(self, img, ref_img, ratio_thr=0.999): assert img.shape == ref_img.shape @@ -129,36 +127,8 @@ class TestImage(object): os.remove(out_file) self.assert_img_equal(img, rewrite_img) - def test_imnormalize(self): - rgbimg = self.img[:, :, ::-1] - baseline = (rgbimg - self.mean) / self.std - img = mmcv.imnormalize(self.img, self.mean, self.std) - assert np.allclose(img, baseline) - assert id(img) != id(self.img) - img = mmcv.imnormalize(rgbimg, self.mean, self.std, to_rgb=False) - assert np.allclose(img, baseline) - assert id(img) != id(rgbimg) - def test_imnormalize_(self): - img_for_normalize = np.float32(self.img.copy()) - rgbimg_for_normalize = np.float32(self.img[:, :, ::-1].copy()) - baseline = (rgbimg_for_normalize - self.mean) / self.std - img = mmcv.imnormalize_(img_for_normalize, self.mean, self.std) - assert np.allclose(img_for_normalize, baseline) - assert id(img) == id(img_for_normalize) - img = mmcv.imnormalize_( - rgbimg_for_normalize, self.mean, self.std, to_rgb=False) - assert np.allclose(img, baseline) - assert id(img) == id(rgbimg_for_normalize) - - def test_imdenormalize(self): - normimg = (self.img[:, :, ::-1] - self.mean) / self.std - rgbbaseline = (normimg * self.std + self.mean) - bgrbaseline = rgbbaseline[:, :, ::-1] - img = mmcv.imdenormalize(normimg, self.mean, self.std) - assert np.allclose(img, bgrbaseline) - img = mmcv.imdenormalize(normimg, self.mean, self.std, to_bgr=False) - assert np.allclose(img, rgbbaseline) +class TestColorSpace: def test_bgr2gray(self): in_img = np.random.rand(10, 10, 3).astype(np.float32) @@ -266,6 +236,29 @@ class TestImage(object): computed_hls[i, j, :] = [h, _l, s] assert_array_almost_equal(out_img, computed_hls, decimal=2) + @pytest.mark.parametrize('src,dst,ref', + [('bgr', 'gray', cv2.COLOR_BGR2GRAY), + ('rgb', 'gray', cv2.COLOR_RGB2GRAY), + ('bgr', 'rgb', cv2.COLOR_BGR2RGB), + ('rgb', 'bgr', cv2.COLOR_RGB2BGR), + ('bgr', 'hsv', cv2.COLOR_BGR2HSV), + ('hsv', 'bgr', cv2.COLOR_HSV2BGR), + ('bgr', 'hls', cv2.COLOR_BGR2HLS), + ('hls', 'bgr', cv2.COLOR_HLS2BGR)]) + def test_imconvert(self, src, dst, ref): + img = np.random.rand(10, 10, 3).astype(np.float32) + assert_array_equal( + mmcv.imconvert(img, src, dst), cv2.cvtColor(img, ref)) + + +class TestGeometric: + + @classmethod + def setup_class(cls): + # the test img resolution is 400x300 + cls.img_path = osp.join(osp.dirname(__file__), 'data/color.jpg') + cls.img = cv2.imread(cls.img_path) + def test_imresize(self): resized_img = mmcv.imresize(self.img, (1000, 600)) assert resized_img.shape == (600, 1000, 3) @@ -441,32 +434,32 @@ class TestImage(object): assert patch.shape == (100, 100, 3) patch_path = osp.join(osp.dirname(__file__), 'data/patches') ref_patch = np.load(patch_path + '/0.npy') - self.assert_img_equal(patch, ref_patch) + assert_array_equal(patch, ref_patch) assert isinstance(patches, list) and len(patches) == 1 - self.assert_img_equal(patches[0], ref_patch) + assert_array_equal(patches[0], ref_patch) # crop with no scaling and padding patches = mmcv.imcrop(self.img, bboxes) assert len(patches) == bboxes.shape[0] for i in range(len(patches)): ref_patch = np.load(patch_path + '/{}.npy'.format(i)) - self.assert_img_equal(patches[i], ref_patch) + assert_array_equal(patches[i], ref_patch) # crop with scaling and no padding patches = mmcv.imcrop(self.img, bboxes, 1.2) for i in range(len(patches)): ref_patch = np.load(patch_path + '/scale_{}.npy'.format(i)) - self.assert_img_equal(patches[i], ref_patch) + assert_array_equal(patches[i], ref_patch) # crop with scaling and padding patches = mmcv.imcrop(self.img, bboxes, 1.2, pad_fill=[255, 255, 0]) for i in range(len(patches)): ref_patch = np.load(patch_path + '/pad_{}.npy'.format(i)) - self.assert_img_equal(patches[i], ref_patch) + assert_array_equal(patches[i], ref_patch) patches = mmcv.imcrop(self.img, bboxes, 1.2, pad_fill=0) for i in range(len(patches)): ref_patch = np.load(patch_path + '/pad0_{}.npy'.format(i)) - self.assert_img_equal(patches[i], ref_patch) + assert_array_equal(patches[i], ref_patch) def test_impad(self): # grayscale image @@ -536,6 +529,48 @@ class TestImage(object): with pytest.raises(ValueError): mmcv.imrotate(img, 90, center=(0, 0), auto_bound=True) + +class TestPhotometric: + + @classmethod + def setup_class(cls): + # the test img resolution is 400x300 + cls.img_path = osp.join(osp.dirname(__file__), 'data/color.jpg') + cls.img = cv2.imread(cls.img_path) + cls.mean = np.array([123.675, 116.28, 103.53], dtype=np.float32) + cls.std = np.array([58.395, 57.12, 57.375], dtype=np.float32) + + def test_imnormalize(self): + rgb_img = self.img[:, :, ::-1] + baseline = (rgb_img - self.mean) / self.std + img = mmcv.imnormalize(self.img, self.mean, self.std) + assert np.allclose(img, baseline) + assert id(img) != id(self.img) + img = mmcv.imnormalize(rgb_img, self.mean, self.std, to_rgb=False) + assert np.allclose(img, baseline) + assert id(img) != id(rgb_img) + + def test_imnormalize_(self): + img_for_normalize = np.float32(self.img) + rgb_img_for_normalize = np.float32(self.img[:, :, ::-1]) + baseline = (rgb_img_for_normalize - self.mean) / self.std + img = mmcv.imnormalize_(img_for_normalize, self.mean, self.std) + assert np.allclose(img_for_normalize, baseline) + assert id(img) == id(img_for_normalize) + img = mmcv.imnormalize_( + rgb_img_for_normalize, self.mean, self.std, to_rgb=False) + assert np.allclose(img, baseline) + assert id(img) == id(rgb_img_for_normalize) + + def test_imdenormalize(self): + norm_img = (self.img[:, :, ::-1] - self.mean) / self.std + rgb_baseline = (norm_img * self.std + self.mean) + bgr_baseline = rgb_baseline[:, :, ::-1] + img = mmcv.imdenormalize(norm_img, self.mean, self.std) + assert np.allclose(img, bgr_baseline) + img = mmcv.imdenormalize(norm_img, self.mean, self.std, to_bgr=False) + assert np.allclose(img, rgb_baseline) + def test_iminvert(self): img = np.array([[0, 128, 255], [1, 127, 254], [2, 129, 253]], dtype=np.uint8)