Refactoring mmcv.images (#239)

* refactoring mmcv.images

* update docstring and minor fix

* some renames
pull/246/head
Kai Chen 2020-04-23 00:34:51 +08:00 committed by GitHub
parent 010b1a0ffc
commit a0618d1051
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 284 additions and 236 deletions

View File

@ -1,18 +1,19 @@
# Copyright (c) Open-MMLab. All rights reserved.
from .colorspace import (bgr2gray, bgr2hls, bgr2hsv, bgr2rgb, gray2bgr,
gray2rgb, hls2bgr, hsv2bgr, iminvert, posterize,
rgb2bgr, rgb2gray, solarize)
from .geometry import (imcrop, imflip, imflip_, impad, impad_to_multiple,
imrotate)
gray2rgb, hls2bgr, hsv2bgr, imconvert, rgb2bgr,
rgb2gray)
from .geometric import (imcrop, imflip, imflip_, impad, impad_to_multiple,
imrescale, imresize, imresize_like, imrotate,
rescale_size)
from .io import imfrombytes, imread, imwrite, supported_backends, use_backend
from .normalize import imdenormalize, imnormalize, imnormalize_
from .resize import imrescale, imresize, imresize_like, rescale_size
from .photometric import (imdenormalize, iminvert, imnormalize, imnormalize_,
posterize, solarize)
__all__ = [
'solarize', 'posterize', 'imread', 'imwrite', 'imfrombytes', 'bgr2gray',
'rgb2gray', 'gray2bgr', 'gray2rgb', 'bgr2rgb', 'rgb2bgr', 'bgr2hsv',
'hsv2bgr', 'bgr2hls', 'hls2bgr', 'iminvert', 'imflip', 'imflip_',
'imrotate', 'imcrop', 'impad', 'impad_to_multiple', 'imnormalize',
'imnormalize_', 'imdenormalize', 'imresize', 'imresize_like', 'imrescale',
'use_backend', 'supported_backends', 'rescale_size'
'bgr2gray', 'bgr2hls', 'bgr2hsv', 'bgr2rgb', 'gray2bgr', 'gray2rgb',
'hls2bgr', 'hsv2bgr', 'imconvert', 'rgb2bgr', 'rgb2gray', 'imrescale',
'imresize', 'imresize_like', 'rescale_size', 'imcrop', 'imflip', 'imflip_',
'impad', 'impad_to_multiple', 'imrotate', 'imfrombytes', 'imread',
'imwrite', 'supported_backends', 'use_backend', 'imdenormalize',
'imnormalize', 'imnormalize_', 'iminvert', 'posterize', 'solarize'
]

View File

@ -1,46 +1,21 @@
# Copyright (c) Open-MMLab. All rights reserved.
import cv2
import numpy as np
def solarize(img, thr=128):
"""Solarize an image (invert all pixel values above a threshold)
def imconvert(img, src, dst):
"""Convert an image from the src colorspace to dst colorspace.
Args:
img (ndarray): Image to be solarized.
thr (int): Threshold for solarizing (0 - 255).
img (ndarray): The input image.
src (str): The source colorspace, e.g., 'rgb', 'hsv'.
dst (str): The destination colorspace, e.g., 'rgb', 'hsv'.
Returns:
ndarray: The solarized image.
ndarray: The converted image.
"""
img = np.where(img < thr, img, 255 - img)
return img
def posterize(img, bits):
"""Posterize an image (reduce the number of bits for each color channel)
Args:
img (ndarray): Image to be posterized.
bits (int): Number of bits (1 to 8) to use for posterizing.
Returns:
ndarray: The posterized image.
"""
shift = 8 - bits
img = np.left_shift(np.right_shift(img, shift), shift)
return img
def iminvert(img):
"""Invert (negate) an image
Args:
img (ndarray): Image to be inverted.
Returns:
ndarray: The inverted image.
"""
return np.full_like(img, 255) - img
code = getattr(cv2, 'COLOR_{}2{}'.format(src.upper(), dst.upper()))
out_img = cv2.cvtColor(img, code)
return out_img
def bgr2gray(img, keepdim=False):

View File

@ -1,10 +1,142 @@
# Copyright (c) Open-MMLab. All rights reserved.
from __future__ import division
import cv2
import numpy as np
def _scale_size(size, scale):
    """Scale a (w, h) size by a single ratio.

    Each side is multiplied by ``scale``, 0.5 is added and the result is
    truncated with ``int``.

    Args:
        size (tuple[int]): Original size as (w, h).
        scale (float): Scaling factor applied to both sides.

    Returns:
        tuple[int]: The scaled size as (w, h).
    """
    width, height = size
    factor = float(scale)
    scaled_w = int(width * factor + 0.5)
    scaled_h = int(height * factor + 0.5)
    return scaled_w, scaled_h
# Maps the interpolation names accepted by the resize helpers to the
# corresponding OpenCV interpolation flags.
interp_codes = {
    'nearest': cv2.INTER_NEAREST,
    'bilinear': cv2.INTER_LINEAR,
    'bicubic': cv2.INTER_CUBIC,
    'area': cv2.INTER_AREA,
    'lanczos': cv2.INTER_LANCZOS4,
}


def imresize(img,
             size,
             return_scale=False,
             interpolation='bilinear',
             out=None):
    """Resize an image to an explicit target size.

    Args:
        img (ndarray): The input image.
        size (tuple[int]): Target size as (w, h).
        return_scale (bool): Whether to also return `w_scale` and `h_scale`.
        interpolation (str): Name of the interpolation method; one of
            "nearest", "bilinear", "bicubic", "area", "lanczos".
        out (ndarray): The output destination, forwarded to ``cv2.resize``
            as ``dst``.

    Returns:
        tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) when
            ``return_scale`` is true, otherwise just `resized_img`.
    """
    src_h, src_w = img.shape[:2]
    flag = interp_codes[interpolation]
    resized_img = cv2.resize(img, size, dst=out, interpolation=flag)
    if return_scale:
        # Per-axis ratio of the requested size to the source size.
        return resized_img, size[0] / src_w, size[1] / src_h
    return resized_img
def imresize_like(img, dst_img, return_scale=False, interpolation='bilinear'):
    """Resize an image so that it matches the size of another image.

    Args:
        img (ndarray): The input image.
        dst_img (ndarray): The image whose height and width are the target.
        return_scale (bool): Whether to also return `w_scale` and `h_scale`.
        interpolation (str): Same as :func:`imresize`.

    Returns:
        tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or
            `resized_img`.
    """
    target_h, target_w = dst_img.shape[:2]
    # imresize expects the size in (w, h) order.
    target_size = (target_w, target_h)
    return imresize(img, target_size, return_scale, interpolation)
def rescale_size(old_size, scale, return_scale=False):
    """Compute the size an image would have after rescaling.

    Args:
        old_size (tuple[int]): The current size (w, h) of the image.
        scale (float | tuple[int]): The scaling factor or maximum size.
            A number is used directly as the scaling factor. A tuple is
            read as a size limit: its larger value bounds the long edge
            and its smaller value bounds the short edge, and the largest
            factor satisfying both bounds is used.
        return_scale (bool): Whether to return the scaling factor besides
            the rescaled image size.

    Returns:
        tuple[int] | tuple: The new (w, h) size, or
            (`new_size`, `scale_factor`) when ``return_scale`` is true.

    Raises:
        ValueError: If ``scale`` is a number that is not positive.
        TypeError: If ``scale`` is neither a number nor a tuple.
    """
    width, height = old_size

    if not isinstance(scale, (float, int, tuple)):
        raise TypeError(
            'Scale must be a number or tuple of int, but got {}'.format(
                type(scale)))

    if isinstance(scale, tuple):
        long_limit = max(scale)
        short_limit = min(scale)
        # The tighter of the two edge constraints decides the factor.
        factor = min(long_limit / max(height, width),
                     short_limit / min(height, width))
    else:
        if scale <= 0:
            raise ValueError(
                'Invalid scale {}, must be positive.'.format(scale))
        factor = scale

    new_size = _scale_size((width, height), factor)
    return (new_size, factor) if return_scale else new_size
def imrescale(img, scale, return_scale=False, interpolation='bilinear'):
    """Resize an image while keeping its aspect ratio.

    Args:
        img (ndarray): The input image.
        scale (float | tuple[int]): The scaling factor or maximum size,
            interpreted as in :func:`rescale_size`.
        return_scale (bool): Whether to return the scaling factor besides
            the rescaled image.
        interpolation (str): Same as :func:`imresize`.

    Returns:
        ndarray | tuple: The rescaled image, or
            (`rescaled_img`, `scale_factor`) when ``return_scale`` is true.
    """
    img_h, img_w = img.shape[:2]
    # The factor is always requested here; it is only handed back to the
    # caller when return_scale is set.
    target_size, factor = rescale_size(
        (img_w, img_h), scale, return_scale=True)
    rescaled_img = imresize(img, target_size, interpolation=interpolation)
    return (rescaled_img, factor) if return_scale else rescaled_img
def imflip(img, direction='horizontal'):
"""Flip an image horizontally or vertically.
@ -24,12 +156,13 @@ def imflip(img, direction='horizontal'):
def imflip_(img, direction='horizontal'):
"""Inplace flip an image horizontally or vertically.
Args:
img (ndarray): Image to be flipped.
direction (str): The flip direction, either "horizontal" or "vertical".
Returns:
ndarray: The flipped image(inplace).
ndarray: The flipped image (inplace).
"""
assert direction in ['horizontal', 'vertical']
if direction == 'horizontal':
@ -50,8 +183,9 @@ def imrotate(img,
img (ndarray): Image to be rotated.
angle (float): Rotation angle in degrees, positive values mean
clockwise rotation.
center (tuple): Center of the rotation in the source image, by default
it is the center of the image.
center (tuple[float], optional): Center point (w, h) of the rotation in
the source image. If not specified, the center of the image will be
used.
scale (float): Isotropic scale factor.
border_value (int): Border value.
auto_bound (bool): Whether to adjust the image size to cover the whole
@ -86,7 +220,7 @@ def bbox_clip(bboxes, img_shape):
Args:
bboxes (ndarray): Shape (..., 4*k)
img_shape (tuple): (height, width) of the image.
img_shape (tuple[int]): (height, width) of the image.
Returns:
ndarray: Clipped bboxes.
@ -105,7 +239,7 @@ def bbox_scaling(bboxes, scale, clip_shape=None):
Args:
bboxes (ndarray): Shape(..., 4).
scale (float): Scaling factor.
clip_shape (tuple, optional): If specified, bboxes that exceed the
clip_shape (tuple[int], optional): If specified, bboxes that exceed the
boundary will be clipped according to the given shape (h, w).
Returns:
@ -135,11 +269,11 @@ def imcrop(img, bboxes, scale=1.0, pad_fill=None):
bboxes (ndarray): Shape (k, 4) or (4, ), location of cropped bboxes.
scale (float, optional): Scale ratio of bboxes, the default value
1.0 means no padding.
pad_fill (number or list): Value to be filled for padding, None for
no padding.
pad_fill (Number | list[Number]): Value to be filled for padding.
Default: None, which means no padding.
Returns:
list or ndarray: The cropped image patches.
list[ndarray] | ndarray: The cropped image patches.
"""
chn = 1 if img.ndim == 2 else img.shape[2]
if pad_fill is not None:
@ -184,8 +318,9 @@ def impad(img, shape, pad_val=0):
Args:
img (ndarray): Image to be padded.
shape (tuple): Expected padding shape.
pad_val (number or sequence): Values to be filled in padding areas.
shape (tuple[int]): Expected padding shape (h, w).
pad_val (Number | Sequence[Number]): Values to be filled in padding
areas. Default: 0.
Returns:
ndarray: The padded image.
@ -209,7 +344,7 @@ def impad_to_multiple(img, divisor, pad_val=0):
Args:
img (ndarray): Image to be padded.
divisor (int): Padded image edges will be multiple to divisor.
pad_val (number or sequence): Same as :func:`impad`.
pad_val (Number | Sequence[Number]): Same as :func:`impad`.
Returns:
ndarray: The padded image.

View File

@ -1,4 +1,3 @@
# Copyright (c) Open-MMLab. All rights reserved.
import cv2
import numpy as np
@ -51,3 +50,44 @@ def imdenormalize(img, mean, std, to_bgr=True):
if to_bgr:
cv2.cvtColor(img, cv2.COLOR_RGB2BGR, img) # inplace
return img
def iminvert(img):
    """Invert (negate) an image by subtracting it from 255.

    Args:
        img (ndarray): Image to be inverted.

    Returns:
        ndarray: The inverted image.
    """
    # An array shaped and typed like the input, filled with 255.
    white = np.full_like(img, 255)
    return white - img
def solarize(img, thr=128):
    """Solarize an image (invert every pixel value not below a threshold).

    Args:
        img (ndarray): Image to be solarized.
        thr (int): Threshold for solarizing (0 - 255).

    Returns:
        ndarray: The solarized image.
    """
    keep = img < thr
    inverted = 255 - img
    # Values below the threshold pass through; the rest are inverted.
    return np.where(keep, img, inverted)
def posterize(img, bits):
    """Posterize an image (reduce the number of bits per color channel).

    Args:
        img (ndarray): Image to be posterized.
        bits (int): Number of bits (1 to 8) to use for posterizing.

    Returns:
        ndarray: The posterized image.
    """
    # Number of low-order bits to clear out of each 8-bit value.
    dropped = 8 - bits
    truncated = np.right_shift(img, dropped)
    return np.left_shift(truncated, dropped)

View File

@ -1,138 +0,0 @@
# Copyright (c) Open-MMLab. All rights reserved.
from __future__ import division
import cv2
def _scale_size(size, scale):
"""Rescale a size by a ratio.
Args:
size (tuple): w, h.
scale (float): Scaling factor.
Returns:
tuple[int]: scaled size.
"""
w, h = size
return int(w * float(scale) + 0.5), int(h * float(scale) + 0.5)
interp_codes = {
'nearest': cv2.INTER_NEAREST,
'bilinear': cv2.INTER_LINEAR,
'bicubic': cv2.INTER_CUBIC,
'area': cv2.INTER_AREA,
'lanczos': cv2.INTER_LANCZOS4
}
def imresize(img,
size,
return_scale=False,
interpolation='bilinear',
out=None):
"""Resize image to a given size.
Args:
img (ndarray): The input image.
size (tuple): Target (w, h).
return_scale (bool): Whether to return `w_scale` and `h_scale`.
interpolation (str): Interpolation method, accepted values are
"nearest", "bilinear", "bicubic", "area", "lanczos".
out (ndarray): The output destination.
Returns:
tuple or ndarray: (`resized_img`, `w_scale`, `h_scale`) or
`resized_img`.
"""
h, w = img.shape[:2]
resized_img = cv2.resize(
img, size, dst=out, interpolation=interp_codes[interpolation])
if not return_scale:
return resized_img
else:
w_scale = size[0] / w
h_scale = size[1] / h
return resized_img, w_scale, h_scale
def imresize_like(img, dst_img, return_scale=False, interpolation='bilinear'):
"""Resize image to the same size of a given image.
Args:
img (ndarray): The input image.
dst_img (ndarray): The target image.
return_scale (bool): Whether to return `w_scale` and `h_scale`.
interpolation (str): Same as :func:`resize`.
Returns:
tuple or ndarray: (`resized_img`, `w_scale`, `h_scale`) or
`resized_img`.
"""
h, w = dst_img.shape[:2]
return imresize(img, (w, h), return_scale, interpolation)
def rescale_size(old_size, scale, return_scale=False):
"""Calculate the new size to be rescaled to.
Args:
old_size (tuple[int]): The old size of image.
scale (float or tuple[int]): The scaling factor or maximum size.
If it is a float number, then the image will be rescaled by this
factor, else if it is a tuple of 2 integers, then the image will
be rescaled as large as possible within the scale.
return_scale (bool): Whether to return the scaling factor besides the
rescaled image size.
Returns:
tuple[int]: The new rescaled image size.
"""
w, h = old_size
if isinstance(scale, (float, int)):
if scale <= 0:
raise ValueError(
'Invalid scale {}, must be positive.'.format(scale))
scale_factor = scale
elif isinstance(scale, tuple):
max_long_edge = max(scale)
max_short_edge = min(scale)
scale_factor = min(max_long_edge / max(h, w),
max_short_edge / min(h, w))
else:
raise TypeError(
'Scale must be a number or tuple of int, but got {}'.format(
type(scale)))
new_size = _scale_size((w, h), scale_factor)
if return_scale:
return new_size, scale_factor
else:
return new_size
def imrescale(img, scale, return_scale=False, interpolation='bilinear'):
"""Resize image while keeping the aspect ratio.
Args:
img (ndarray): The input image.
scale (float or tuple[int]): The scaling factor or maximum size.
If it is a float number, then the image will be rescaled by this
factor, else if it is a tuple of 2 integers, then the image will
be rescaled as large as possible within the scale.
return_scale (bool): Whether to return the scaling factor besides the
rescaled image.
interpolation (str): Same as :func:`resize`.
Returns:
ndarray: The rescaled image.
"""
h, w = img.shape[:2]
new_size, scale_factor = rescale_size((w, h), scale, return_scale=True)
rescaled_img = imresize(img, new_size, interpolation=interpolation)
if return_scale:
return rescaled_img, scale_factor
else:
return rescaled_img

View File

@ -12,7 +12,7 @@ from numpy.testing import assert_array_almost_equal, assert_array_equal
import mmcv
class TestImage(object):
class TestIO:
@classmethod
def setup_class(cls):
@ -23,8 +23,6 @@ class TestImage(object):
osp.dirname(__file__), 'data/grayscale.jpg')
cls.gray_img_path_obj = Path(cls.gray_img_path)
cls.img = cv2.imread(cls.img_path)
cls.mean = np.float32(np.array([123.675, 116.28, 103.53]))
cls.std = np.float32(np.array([58.395, 57.12, 57.375]))
def assert_img_equal(self, img, ref_img, ratio_thr=0.999):
assert img.shape == ref_img.shape
@ -129,36 +127,8 @@ class TestImage(object):
os.remove(out_file)
self.assert_img_equal(img, rewrite_img)
def test_imnormalize(self):
rgbimg = self.img[:, :, ::-1]
baseline = (rgbimg - self.mean) / self.std
img = mmcv.imnormalize(self.img, self.mean, self.std)
assert np.allclose(img, baseline)
assert id(img) != id(self.img)
img = mmcv.imnormalize(rgbimg, self.mean, self.std, to_rgb=False)
assert np.allclose(img, baseline)
assert id(img) != id(rgbimg)
def test_imnormalize_(self):
img_for_normalize = np.float32(self.img.copy())
rgbimg_for_normalize = np.float32(self.img[:, :, ::-1].copy())
baseline = (rgbimg_for_normalize - self.mean) / self.std
img = mmcv.imnormalize_(img_for_normalize, self.mean, self.std)
assert np.allclose(img_for_normalize, baseline)
assert id(img) == id(img_for_normalize)
img = mmcv.imnormalize_(
rgbimg_for_normalize, self.mean, self.std, to_rgb=False)
assert np.allclose(img, baseline)
assert id(img) == id(rgbimg_for_normalize)
def test_imdenormalize(self):
normimg = (self.img[:, :, ::-1] - self.mean) / self.std
rgbbaseline = (normimg * self.std + self.mean)
bgrbaseline = rgbbaseline[:, :, ::-1]
img = mmcv.imdenormalize(normimg, self.mean, self.std)
assert np.allclose(img, bgrbaseline)
img = mmcv.imdenormalize(normimg, self.mean, self.std, to_bgr=False)
assert np.allclose(img, rgbbaseline)
class TestColorSpace:
def test_bgr2gray(self):
in_img = np.random.rand(10, 10, 3).astype(np.float32)
@ -266,6 +236,29 @@ class TestImage(object):
computed_hls[i, j, :] = [h, _l, s]
assert_array_almost_equal(out_img, computed_hls, decimal=2)
@pytest.mark.parametrize('src,dst,ref', [
    ('bgr', 'gray', cv2.COLOR_BGR2GRAY),
    ('rgb', 'gray', cv2.COLOR_RGB2GRAY),
    ('bgr', 'rgb', cv2.COLOR_BGR2RGB),
    ('rgb', 'bgr', cv2.COLOR_RGB2BGR),
    ('bgr', 'hsv', cv2.COLOR_BGR2HSV),
    ('hsv', 'bgr', cv2.COLOR_HSV2BGR),
    ('bgr', 'hls', cv2.COLOR_BGR2HLS),
    ('hls', 'bgr', cv2.COLOR_HLS2BGR),
])
def test_imconvert(self, src, dst, ref):
    """imconvert must match cv2.cvtColor called with the reference code."""
    img = np.random.rand(10, 10, 3).astype(np.float32)
    converted = mmcv.imconvert(img, src, dst)
    expected = cv2.cvtColor(img, ref)
    assert_array_equal(converted, expected)
class TestGeometric:
@classmethod
def setup_class(cls):
    """Load the color test image once for every test in the class."""
    # the test img resolution is 400x300
    test_dir = osp.dirname(__file__)
    cls.img_path = osp.join(test_dir, 'data/color.jpg')
    cls.img = cv2.imread(cls.img_path)
def test_imresize(self):
resized_img = mmcv.imresize(self.img, (1000, 600))
assert resized_img.shape == (600, 1000, 3)
@ -441,32 +434,32 @@ class TestImage(object):
assert patch.shape == (100, 100, 3)
patch_path = osp.join(osp.dirname(__file__), 'data/patches')
ref_patch = np.load(patch_path + '/0.npy')
self.assert_img_equal(patch, ref_patch)
assert_array_equal(patch, ref_patch)
assert isinstance(patches, list) and len(patches) == 1
self.assert_img_equal(patches[0], ref_patch)
assert_array_equal(patches[0], ref_patch)
# crop with no scaling and padding
patches = mmcv.imcrop(self.img, bboxes)
assert len(patches) == bboxes.shape[0]
for i in range(len(patches)):
ref_patch = np.load(patch_path + '/{}.npy'.format(i))
self.assert_img_equal(patches[i], ref_patch)
assert_array_equal(patches[i], ref_patch)
# crop with scaling and no padding
patches = mmcv.imcrop(self.img, bboxes, 1.2)
for i in range(len(patches)):
ref_patch = np.load(patch_path + '/scale_{}.npy'.format(i))
self.assert_img_equal(patches[i], ref_patch)
assert_array_equal(patches[i], ref_patch)
# crop with scaling and padding
patches = mmcv.imcrop(self.img, bboxes, 1.2, pad_fill=[255, 255, 0])
for i in range(len(patches)):
ref_patch = np.load(patch_path + '/pad_{}.npy'.format(i))
self.assert_img_equal(patches[i], ref_patch)
assert_array_equal(patches[i], ref_patch)
patches = mmcv.imcrop(self.img, bboxes, 1.2, pad_fill=0)
for i in range(len(patches)):
ref_patch = np.load(patch_path + '/pad0_{}.npy'.format(i))
self.assert_img_equal(patches[i], ref_patch)
assert_array_equal(patches[i], ref_patch)
def test_impad(self):
# grayscale image
@ -536,6 +529,48 @@ class TestImage(object):
with pytest.raises(ValueError):
mmcv.imrotate(img, 90, center=(0, 0), auto_bound=True)
class TestPhotometric:
@classmethod
def setup_class(cls):
    """Load the color test image and the normalization statistics."""
    # the test img resolution is 400x300
    test_dir = osp.dirname(__file__)
    cls.img_path = osp.join(test_dir, 'data/color.jpg')
    cls.img = cv2.imread(cls.img_path)
    # Per-channel mean and std shared by the (de)normalization tests.
    cls.mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
    cls.std = np.array([58.395, 57.12, 57.375], dtype=np.float32)
def test_imnormalize(self):
    """imnormalize must return a new, correctly normalized array."""
    rgb_img = self.img[:, :, ::-1]
    expected = (rgb_img - self.mean) / self.std

    # Default call: the result is compared against the channel-reversed
    # baseline and must not be the input object itself.
    from_default = mmcv.imnormalize(self.img, self.mean, self.std)
    assert np.allclose(from_default, expected)
    assert from_default is not self.img

    # Channel-reversed input with the conversion switched off.
    from_rgb = mmcv.imnormalize(
        rgb_img, self.mean, self.std, to_rgb=False)
    assert np.allclose(from_rgb, expected)
    assert from_rgb is not rgb_img
def test_imnormalize_(self):
    """imnormalize_ must normalize its argument and return that object."""
    bgr_buffer = np.float32(self.img)
    rgb_buffer = np.float32(self.img[:, :, ::-1])
    # Computed before any call below, so it is derived from the untouched
    # contents of rgb_buffer.
    expected = (rgb_buffer - self.mean) / self.std

    result = mmcv.imnormalize_(bgr_buffer, self.mean, self.std)
    # The input buffer itself is checked, not just the return value.
    assert np.allclose(bgr_buffer, expected)
    assert result is bgr_buffer

    result = mmcv.imnormalize_(
        rgb_buffer, self.mean, self.std, to_rgb=False)
    assert np.allclose(result, expected)
    assert result is rgb_buffer
def test_imdenormalize(self):
    """imdenormalize must undo normalization, with and without to_bgr."""
    norm_img = (self.img[:, :, ::-1] - self.mean) / self.std
    expected_rgb = norm_img * self.std + self.mean
    expected_bgr = expected_rgb[:, :, ::-1]

    # Default call is compared against the channel-reversed baseline.
    restored = mmcv.imdenormalize(norm_img, self.mean, self.std)
    assert np.allclose(restored, expected_bgr)

    restored = mmcv.imdenormalize(
        norm_img, self.mean, self.std, to_bgr=False)
    assert np.allclose(restored, expected_rgb)
def test_iminvert(self):
img = np.array([[0, 128, 255], [1, 127, 254], [2, 129, 253]],
dtype=np.uint8)