mirror of https://github.com/open-mmlab/mmcv.git
Refactoring mmcv.images (#239)
* refactoring mmcv.images * update docstring and minor fix * some renames
pull/246/head
parent
010b1a0ffc
commit
a0618d1051
|
@ -1,18 +1,19 @@
|
|||
# Copyright (c) Open-MMLab. All rights reserved.
|
||||
from .colorspace import (bgr2gray, bgr2hls, bgr2hsv, bgr2rgb, gray2bgr,
|
||||
gray2rgb, hls2bgr, hsv2bgr, iminvert, posterize,
|
||||
rgb2bgr, rgb2gray, solarize)
|
||||
from .geometry import (imcrop, imflip, imflip_, impad, impad_to_multiple,
|
||||
imrotate)
|
||||
gray2rgb, hls2bgr, hsv2bgr, imconvert, rgb2bgr,
|
||||
rgb2gray)
|
||||
from .geometric import (imcrop, imflip, imflip_, impad, impad_to_multiple,
|
||||
imrescale, imresize, imresize_like, imrotate,
|
||||
rescale_size)
|
||||
from .io import imfrombytes, imread, imwrite, supported_backends, use_backend
|
||||
from .normalize import imdenormalize, imnormalize, imnormalize_
|
||||
from .resize import imrescale, imresize, imresize_like, rescale_size
|
||||
from .photometric import (imdenormalize, iminvert, imnormalize, imnormalize_,
|
||||
posterize, solarize)
|
||||
|
||||
__all__ = [
|
||||
'solarize', 'posterize', 'imread', 'imwrite', 'imfrombytes', 'bgr2gray',
|
||||
'rgb2gray', 'gray2bgr', 'gray2rgb', 'bgr2rgb', 'rgb2bgr', 'bgr2hsv',
|
||||
'hsv2bgr', 'bgr2hls', 'hls2bgr', 'iminvert', 'imflip', 'imflip_',
|
||||
'imrotate', 'imcrop', 'impad', 'impad_to_multiple', 'imnormalize',
|
||||
'imnormalize_', 'imdenormalize', 'imresize', 'imresize_like', 'imrescale',
|
||||
'use_backend', 'supported_backends', 'rescale_size'
|
||||
'bgr2gray', 'bgr2hls', 'bgr2hsv', 'bgr2rgb', 'gray2bgr', 'gray2rgb',
|
||||
'hls2bgr', 'hsv2bgr', 'imconvert', 'rgb2bgr', 'rgb2gray', 'imrescale',
|
||||
'imresize', 'imresize_like', 'rescale_size', 'imcrop', 'imflip', 'imflip_',
|
||||
'impad', 'impad_to_multiple', 'imrotate', 'imfrombytes', 'imread',
|
||||
'imwrite', 'supported_backends', 'use_backend', 'imdenormalize',
|
||||
'imnormalize', 'imnormalize_', 'iminvert', 'posterize', 'solarize'
|
||||
]
|
||||
|
|
|
@ -1,46 +1,21 @@
|
|||
# Copyright (c) Open-MMLab. All rights reserved.
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
def solarize(img, thr=128):
|
||||
"""Solarize an image (invert all pixel values above a threshold)
|
||||
def imconvert(img, src, dst):
|
||||
"""Convert an image from the src colorspace to dst colorspace.
|
||||
|
||||
Args:
|
||||
img (ndarray): Image to be solarized.
|
||||
thr (int): Threshold for solarizing (0 - 255).
|
||||
img (ndarray): The input image.
|
||||
src (str): The source colorspace, e.g., 'rgb', 'hsv'.
|
||||
dst (str): The destination colorspace, e.g., 'rgb', 'hsv'.
|
||||
|
||||
Returns:
|
||||
ndarray: The solarized image.
|
||||
ndarray: The converted image.
|
||||
"""
|
||||
img = np.where(img < thr, img, 255 - img)
|
||||
return img
|
||||
|
||||
|
||||
def posterize(img, bits):
|
||||
"""Posterize an image (reduce the number of bits for each color channel)
|
||||
|
||||
Args:
|
||||
img (ndarray): Image to be posterized.
|
||||
bits (int): Number of bits (1 to 8) to use for posterizing.
|
||||
|
||||
Returns:
|
||||
ndarray: The posterized image.
|
||||
"""
|
||||
shift = 8 - bits
|
||||
img = np.left_shift(np.right_shift(img, shift), shift)
|
||||
return img
|
||||
|
||||
|
||||
def iminvert(img):
|
||||
"""Invert (negate) an image
|
||||
Args:
|
||||
img (ndarray): Image to be inverted.
|
||||
|
||||
Returns:
|
||||
ndarray: The inverted image.
|
||||
"""
|
||||
return np.full_like(img, 255) - img
|
||||
code = getattr(cv2, 'COLOR_{}2{}'.format(src.upper(), dst.upper()))
|
||||
out_img = cv2.cvtColor(img, code)
|
||||
return out_img
|
||||
|
||||
|
||||
def bgr2gray(img, keepdim=False):
|
||||
|
|
|
@ -1,10 +1,142 @@
|
|||
# Copyright (c) Open-MMLab. All rights reserved.
|
||||
from __future__ import division
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _scale_size(size, scale):
    """Scale a (w, h) size by a ratio and convert the result to integers.

    Each edge is multiplied by the ratio, 0.5 is added and the value is
    truncated with ``int``.

    Args:
        size (tuple[int]): Original size as (w, h).
        scale (float): Scaling factor applied to both edges.

    Returns:
        tuple[int]: The scaled size as (w, h).
    """
    width, height = size
    ratio = float(scale)
    scaled_w = int(width * ratio + 0.5)
    scaled_h = int(height * ratio + 0.5)
    return scaled_w, scaled_h
|
||||
|
||||
|
||||
# Lookup table from interpolation method name to the matching OpenCV
# interpolation flag.
interp_codes = dict(
    nearest=cv2.INTER_NEAREST,
    bilinear=cv2.INTER_LINEAR,
    bicubic=cv2.INTER_CUBIC,
    area=cv2.INTER_AREA,
    lanczos=cv2.INTER_LANCZOS4,
)
|
||||
|
||||
|
||||
def imresize(img,
             size,
             return_scale=False,
             interpolation='bilinear',
             out=None):
    """Resize an image to the given target size.

    Args:
        img (ndarray): The input image.
        size (tuple[int]): Target size (w, h).
        return_scale (bool): Whether to also return `w_scale` and `h_scale`.
        interpolation (str): Interpolation method, accepted values are
            "nearest", "bilinear", "bicubic", "area", "lanczos".
        out (ndarray): The output destination.

    Returns:
        tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) if
            `return_scale` is True, otherwise `resized_img`.
    """
    src_h, src_w = img.shape[:2]
    interp_flag = interp_codes[interpolation]
    resized_img = cv2.resize(img, size, dst=out, interpolation=interp_flag)
    if return_scale:
        # Per-axis scale is the target edge divided by the source edge.
        return resized_img, size[0] / src_w, size[1] / src_h
    return resized_img
|
||||
|
||||
|
||||
def imresize_like(img, dst_img, return_scale=False, interpolation='bilinear'):
    """Resize an image so that it has the same size as another image.

    Args:
        img (ndarray): The input image.
        dst_img (ndarray): The image whose height and width are the target.
        return_scale (bool): Whether to also return `w_scale` and `h_scale`.
        interpolation (str): Same as :func:`imresize`.

    Returns:
        tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or
            `resized_img`.
    """
    target_h, target_w = dst_img.shape[:2]
    # imresize takes the size as (w, h), the reverse of the shape order.
    target_size = (target_w, target_h)
    return imresize(img, target_size, return_scale, interpolation)
|
||||
|
||||
|
||||
def rescale_size(old_size, scale, return_scale=False):
    """Compute the size an image should be rescaled to.

    Args:
        old_size (tuple[int]): The old size (w, h) of image.
        scale (float | tuple[int]): The scaling factor or maximum size.
            If it is a number, the size is multiplied by this factor. If it
            is a tuple, its larger value bounds the long edge and its
            smaller value bounds the short edge, and the largest factor
            satisfying both bounds is used.
        return_scale (bool): Whether to return the scaling factor besides the
            rescaled image size.

    Returns:
        tuple[int] | tuple: The new size (w, h), or
            (`new_size`, `scale_factor`) if `return_scale` is True.

    Raises:
        ValueError: If `scale` is a number that is not positive.
        TypeError: If `scale` is neither a number nor a tuple.
    """
    width, height = old_size
    is_number = isinstance(scale, (float, int))
    if not is_number and not isinstance(scale, tuple):
        raise TypeError(
            'Scale must be a number or tuple of int, but got {}'.format(
                type(scale)))

    if is_number:
        if scale <= 0:
            raise ValueError(
                'Invalid scale {}, must be positive.'.format(scale))
        factor = scale
    else:
        long_limit = max(scale)
        short_limit = min(scale)
        long_ratio = long_limit / max(height, width)
        short_ratio = short_limit / min(height, width)
        factor = min(long_ratio, short_ratio)

    new_size = _scale_size((width, height), factor)
    if return_scale:
        return new_size, factor
    return new_size
|
||||
|
||||
|
||||
def imrescale(img, scale, return_scale=False, interpolation='bilinear'):
    """Resize an image while keeping its aspect ratio.

    Args:
        img (ndarray): The input image.
        scale (float | tuple[int]): The scaling factor or maximum size,
            interpreted as in :func:`rescale_size`.
        return_scale (bool): Whether to return the scaling factor besides the
            rescaled image.
        interpolation (str): Same as :func:`imresize`.

    Returns:
        ndarray | tuple: The rescaled image, or
            (`rescaled_img`, `scale_factor`) if `return_scale` is True.
    """
    img_h, img_w = img.shape[:2]
    target_size, factor = rescale_size(
        (img_w, img_h), scale, return_scale=True)
    rescaled_img = imresize(img, target_size, interpolation=interpolation)
    return (rescaled_img, factor) if return_scale else rescaled_img
|
||||
|
||||
|
||||
def imflip(img, direction='horizontal'):
|
||||
"""Flip an image horizontally or vertically.
|
||||
|
||||
|
@ -24,12 +156,13 @@ def imflip(img, direction='horizontal'):
|
|||
|
||||
def imflip_(img, direction='horizontal'):
|
||||
"""Inplace flip an image horizontally or vertically.
|
||||
|
||||
Args:
|
||||
img (ndarray): Image to be flipped.
|
||||
direction (str): The flip direction, either "horizontal" or "vertical".
|
||||
|
||||
Returns:
|
||||
ndarray: The flipped image(inplace).
|
||||
ndarray: The flipped image (inplace).
|
||||
"""
|
||||
assert direction in ['horizontal', 'vertical']
|
||||
if direction == 'horizontal':
|
||||
|
@ -50,8 +183,9 @@ def imrotate(img,
|
|||
img (ndarray): Image to be rotated.
|
||||
angle (float): Rotation angle in degrees, positive values mean
|
||||
clockwise rotation.
|
||||
center (tuple): Center of the rotation in the source image, by default
|
||||
it is the center of the image.
|
||||
center (tuple[float], optional): Center point (w, h) of the rotation in
|
||||
the source image. If not specified, the center of the image will be
|
||||
used.
|
||||
scale (float): Isotropic scale factor.
|
||||
border_value (int): Border value.
|
||||
auto_bound (bool): Whether to adjust the image size to cover the whole
|
||||
|
@ -86,7 +220,7 @@ def bbox_clip(bboxes, img_shape):
|
|||
|
||||
Args:
|
||||
bboxes (ndarray): Shape (..., 4*k)
|
||||
img_shape (tuple): (height, width) of the image.
|
||||
img_shape (tuple[int]): (height, width) of the image.
|
||||
|
||||
Returns:
|
||||
ndarray: Clipped bboxes.
|
||||
|
@ -105,7 +239,7 @@ def bbox_scaling(bboxes, scale, clip_shape=None):
|
|||
Args:
|
||||
bboxes (ndarray): Shape(..., 4).
|
||||
scale (float): Scaling factor.
|
||||
clip_shape (tuple, optional): If specified, bboxes that exceed the
|
||||
clip_shape (tuple[int], optional): If specified, bboxes that exceed the
|
||||
boundary will be clipped according to the given shape (h, w).
|
||||
|
||||
Returns:
|
||||
|
@ -135,11 +269,11 @@ def imcrop(img, bboxes, scale=1.0, pad_fill=None):
|
|||
bboxes (ndarray): Shape (k, 4) or (4, ), location of cropped bboxes.
|
||||
scale (float, optional): Scale ratio of bboxes, the default value
|
||||
1.0 means no padding.
|
||||
pad_fill (number or list): Value to be filled for padding, None for
|
||||
no padding.
|
||||
pad_fill (Number | list[Number]): Value to be filled for padding.
|
||||
Default: None, which means no padding.
|
||||
|
||||
Returns:
|
||||
list or ndarray: The cropped image patches.
|
||||
list[ndarray] | ndarray: The cropped image patches.
|
||||
"""
|
||||
chn = 1 if img.ndim == 2 else img.shape[2]
|
||||
if pad_fill is not None:
|
||||
|
@ -184,8 +318,9 @@ def impad(img, shape, pad_val=0):
|
|||
|
||||
Args:
|
||||
img (ndarray): Image to be padded.
|
||||
shape (tuple): Expected padding shape.
|
||||
pad_val (number or sequence): Values to be filled in padding areas.
|
||||
shape (tuple[int]): Expected padding shape (h, w).
|
||||
pad_val (Number | Sequence[Number]): Values to be filled in padding
|
||||
areas. Default: 0.
|
||||
|
||||
Returns:
|
||||
ndarray: The padded image.
|
||||
|
@ -209,7 +344,7 @@ def impad_to_multiple(img, divisor, pad_val=0):
|
|||
Args:
|
||||
img (ndarray): Image to be padded.
|
||||
divisor (int): Padded image edges will be multiple to divisor.
|
||||
pad_val (number or sequence): Same as :func:`impad`.
|
||||
pad_val (Number | Sequence[Number]): Same as :func:`impad`.
|
||||
|
||||
Returns:
|
||||
ndarray: The padded image.
|
|
@ -1,4 +1,3 @@
|
|||
# Copyright (c) Open-MMLab. All rights reserved.
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
@ -51,3 +50,44 @@ def imdenormalize(img, mean, std, to_bgr=True):
|
|||
if to_bgr:
|
||||
cv2.cvtColor(img, cv2.COLOR_RGB2BGR, img) # inplace
|
||||
return img
|
||||
|
||||
|
||||
def iminvert(img):
    """Invert (negate) an image.

    Args:
        img (ndarray): Image to be inverted.

    Returns:
        ndarray: The inverted image, i.e. an array filled with 255 (same
            shape and dtype as `img`) minus `img`.
    """
    white = np.full_like(img, 255)
    return white - img
|
||||
|
||||
|
||||
def solarize(img, thr=128):
    """Solarize an image (invert all pixel values at or above a threshold).

    Args:
        img (ndarray): Image to be solarized.
        thr (int): Threshold for solarizing (0 - 255).

    Returns:
        ndarray: The solarized image.
    """
    keep_mask = img < thr
    inverted = 255 - img
    # Pixels below the threshold are kept, the rest take the inverted value.
    return np.where(keep_mask, img, inverted)
|
||||
|
||||
|
||||
def posterize(img, bits):
    """Posterize an image (reduce the number of bits for each color channel).

    Args:
        img (ndarray): Image to be posterized.
        bits (int): Number of bits (1 to 8) to use for posterizing.

    Returns:
        ndarray: The posterized image.
    """
    dropped_bits = 8 - bits
    # Shifting right then left by the same amount zeroes the low bits.
    truncated = np.right_shift(img, dropped_bits)
    return np.left_shift(truncated, dropped_bits)
|
|
@ -1,138 +0,0 @@
|
|||
# Copyright (c) Open-MMLab. All rights reserved.
|
||||
from __future__ import division
|
||||
|
||||
import cv2
|
||||
|
||||
|
||||
def _scale_size(size, scale):
    """Scale a (w, h) size by a ratio and convert the result to integers.

    Each edge is multiplied by the ratio, 0.5 is added and the value is
    truncated with ``int``.

    Args:
        size (tuple): Original size as (w, h).
        scale (float): Scaling factor applied to both edges.

    Returns:
        tuple[int]: The scaled size as (w, h).
    """
    width, height = size
    ratio = float(scale)
    scaled_w = int(width * ratio + 0.5)
    scaled_h = int(height * ratio + 0.5)
    return scaled_w, scaled_h
|
||||
|
||||
|
||||
# Lookup table from interpolation method name to the matching OpenCV
# interpolation flag.
interp_codes = dict(
    nearest=cv2.INTER_NEAREST,
    bilinear=cv2.INTER_LINEAR,
    bicubic=cv2.INTER_CUBIC,
    area=cv2.INTER_AREA,
    lanczos=cv2.INTER_LANCZOS4,
)
|
||||
|
||||
|
||||
def imresize(img,
             size,
             return_scale=False,
             interpolation='bilinear',
             out=None):
    """Resize an image to the given target size.

    Args:
        img (ndarray): The input image.
        size (tuple): Target size (w, h).
        return_scale (bool): Whether to also return `w_scale` and `h_scale`.
        interpolation (str): Interpolation method, accepted values are
            "nearest", "bilinear", "bicubic", "area", "lanczos".
        out (ndarray): The output destination.

    Returns:
        tuple or ndarray: (`resized_img`, `w_scale`, `h_scale`) if
            `return_scale` is True, otherwise `resized_img`.
    """
    src_h, src_w = img.shape[:2]
    interp_flag = interp_codes[interpolation]
    resized_img = cv2.resize(img, size, dst=out, interpolation=interp_flag)
    if return_scale:
        # Per-axis scale is the target edge divided by the source edge.
        return resized_img, size[0] / src_w, size[1] / src_h
    return resized_img
|
||||
|
||||
|
||||
def imresize_like(img, dst_img, return_scale=False, interpolation='bilinear'):
    """Resize an image so that it has the same size as another image.

    Args:
        img (ndarray): The input image.
        dst_img (ndarray): The image whose height and width are the target.
        return_scale (bool): Whether to also return `w_scale` and `h_scale`.
        interpolation (str): Same as :func:`imresize`.

    Returns:
        tuple or ndarray: (`resized_img`, `w_scale`, `h_scale`) or
            `resized_img`.
    """
    target_h, target_w = dst_img.shape[:2]
    # imresize takes the size as (w, h), the reverse of the shape order.
    target_size = (target_w, target_h)
    return imresize(img, target_size, return_scale, interpolation)
|
||||
|
||||
|
||||
def rescale_size(old_size, scale, return_scale=False):
    """Compute the size an image should be rescaled to.

    Args:
        old_size (tuple[int]): The old size (w, h) of image.
        scale (float or tuple[int]): The scaling factor or maximum size.
            If it is a number, the size is multiplied by this factor. If it
            is a tuple, its larger value bounds the long edge and its
            smaller value bounds the short edge, and the largest factor
            satisfying both bounds is used.
        return_scale (bool): Whether to return the scaling factor besides the
            rescaled image size.

    Returns:
        tuple: The new size (w, h), or (`new_size`, `scale_factor`) if
            `return_scale` is True.

    Raises:
        ValueError: If `scale` is a number that is not positive.
        TypeError: If `scale` is neither a number nor a tuple.
    """
    width, height = old_size
    is_number = isinstance(scale, (float, int))
    if not is_number and not isinstance(scale, tuple):
        raise TypeError(
            'Scale must be a number or tuple of int, but got {}'.format(
                type(scale)))

    if is_number:
        if scale <= 0:
            raise ValueError(
                'Invalid scale {}, must be positive.'.format(scale))
        factor = scale
    else:
        long_limit = max(scale)
        short_limit = min(scale)
        long_ratio = long_limit / max(height, width)
        short_ratio = short_limit / min(height, width)
        factor = min(long_ratio, short_ratio)

    new_size = _scale_size((width, height), factor)
    if return_scale:
        return new_size, factor
    return new_size
|
||||
|
||||
|
||||
def imrescale(img, scale, return_scale=False, interpolation='bilinear'):
    """Resize an image while keeping its aspect ratio.

    Args:
        img (ndarray): The input image.
        scale (float or tuple[int]): The scaling factor or maximum size,
            interpreted as in :func:`rescale_size`.
        return_scale (bool): Whether to return the scaling factor besides the
            rescaled image.
        interpolation (str): Same as :func:`imresize`.

    Returns:
        ndarray or tuple: The rescaled image, or
            (`rescaled_img`, `scale_factor`) if `return_scale` is True.
    """
    img_h, img_w = img.shape[:2]
    target_size, factor = rescale_size(
        (img_w, img_h), scale, return_scale=True)
    rescaled_img = imresize(img, target_size, interpolation=interpolation)
    return (rescaled_img, factor) if return_scale else rescaled_img
|
|
@ -12,7 +12,7 @@ from numpy.testing import assert_array_almost_equal, assert_array_equal
|
|||
import mmcv
|
||||
|
||||
|
||||
class TestImage(object):
|
||||
class TestIO:
|
||||
|
||||
@classmethod
|
||||
def setup_class(cls):
|
||||
|
@ -23,8 +23,6 @@ class TestImage(object):
|
|||
osp.dirname(__file__), 'data/grayscale.jpg')
|
||||
cls.gray_img_path_obj = Path(cls.gray_img_path)
|
||||
cls.img = cv2.imread(cls.img_path)
|
||||
cls.mean = np.float32(np.array([123.675, 116.28, 103.53]))
|
||||
cls.std = np.float32(np.array([58.395, 57.12, 57.375]))
|
||||
|
||||
def assert_img_equal(self, img, ref_img, ratio_thr=0.999):
|
||||
assert img.shape == ref_img.shape
|
||||
|
@ -129,36 +127,8 @@ class TestImage(object):
|
|||
os.remove(out_file)
|
||||
self.assert_img_equal(img, rewrite_img)
|
||||
|
||||
def test_imnormalize(self):
    """Check imnormalize against a manually normalized baseline."""
    rgb_input = self.img[:, :, ::-1]
    expected = (rgb_input - self.mean) / self.std

    # Default call on the stored image must match the channel-reversed
    # baseline and must not return the input object itself.
    normalized = mmcv.imnormalize(self.img, self.mean, self.std)
    assert np.allclose(normalized, expected)
    assert normalized is not self.img

    # With to_rgb=False the channel-reversed input is normalized as is.
    normalized = mmcv.imnormalize(
        rgb_input, self.mean, self.std, to_rgb=False)
    assert np.allclose(normalized, expected)
    assert normalized is not rgb_input
|
||||
|
||||
def test_imnormalize_(self):
    """Check that imnormalize_ normalizes its input array in place."""
    bgr_buffer = np.float32(self.img.copy())
    rgb_buffer = np.float32(self.img[:, :, ::-1].copy())
    # The baseline must be computed before rgb_buffer is modified in place.
    expected = (rgb_buffer - self.mean) / self.std

    result = mmcv.imnormalize_(bgr_buffer, self.mean, self.std)
    assert np.allclose(bgr_buffer, expected)
    assert result is bgr_buffer

    result = mmcv.imnormalize_(
        rgb_buffer, self.mean, self.std, to_rgb=False)
    assert np.allclose(result, expected)
    assert result is rgb_buffer
|
||||
|
||||
def test_imdenormalize(self):
    """Check imdenormalize with and without the channel swap."""
    normalized = (self.img[:, :, ::-1] - self.mean) / self.std
    expected_rgb = normalized * self.std + self.mean
    expected_bgr = expected_rgb[:, :, ::-1]

    restored = mmcv.imdenormalize(normalized, self.mean, self.std)
    assert np.allclose(restored, expected_bgr)

    restored = mmcv.imdenormalize(
        normalized, self.mean, self.std, to_bgr=False)
    assert np.allclose(restored, expected_rgb)
|
||||
class TestColorSpace:
|
||||
|
||||
def test_bgr2gray(self):
|
||||
in_img = np.random.rand(10, 10, 3).astype(np.float32)
|
||||
|
@ -266,6 +236,29 @@ class TestImage(object):
|
|||
computed_hls[i, j, :] = [h, _l, s]
|
||||
assert_array_almost_equal(out_img, computed_hls, decimal=2)
|
||||
|
||||
@pytest.mark.parametrize('src,dst,ref', [
    ('bgr', 'gray', cv2.COLOR_BGR2GRAY),
    ('rgb', 'gray', cv2.COLOR_RGB2GRAY),
    ('bgr', 'rgb', cv2.COLOR_BGR2RGB),
    ('rgb', 'bgr', cv2.COLOR_RGB2BGR),
    ('bgr', 'hsv', cv2.COLOR_BGR2HSV),
    ('hsv', 'bgr', cv2.COLOR_HSV2BGR),
    ('bgr', 'hls', cv2.COLOR_BGR2HLS),
    ('hls', 'bgr', cv2.COLOR_HLS2BGR),
])
def test_imconvert(self, src, dst, ref):
    """imconvert must equal cv2.cvtColor called with the reference code."""
    img = np.random.rand(10, 10, 3).astype(np.float32)
    converted = mmcv.imconvert(img, src, dst)
    expected = cv2.cvtColor(img, ref)
    assert_array_equal(converted, expected)
|
||||
|
||||
|
||||
class TestGeometric:
|
||||
|
||||
@classmethod
def setup_class(cls):
    """Load the color test image once and store it on the class."""
    # the test img resolution is 400x300
    test_dir = osp.dirname(__file__)
    cls.img_path = osp.join(test_dir, 'data/color.jpg')
    cls.img = cv2.imread(cls.img_path)
|
||||
|
||||
def test_imresize(self):
|
||||
resized_img = mmcv.imresize(self.img, (1000, 600))
|
||||
assert resized_img.shape == (600, 1000, 3)
|
||||
|
@ -441,32 +434,32 @@ class TestImage(object):
|
|||
assert patch.shape == (100, 100, 3)
|
||||
patch_path = osp.join(osp.dirname(__file__), 'data/patches')
|
||||
ref_patch = np.load(patch_path + '/0.npy')
|
||||
self.assert_img_equal(patch, ref_patch)
|
||||
assert_array_equal(patch, ref_patch)
|
||||
assert isinstance(patches, list) and len(patches) == 1
|
||||
self.assert_img_equal(patches[0], ref_patch)
|
||||
assert_array_equal(patches[0], ref_patch)
|
||||
|
||||
# crop with no scaling and padding
|
||||
patches = mmcv.imcrop(self.img, bboxes)
|
||||
assert len(patches) == bboxes.shape[0]
|
||||
for i in range(len(patches)):
|
||||
ref_patch = np.load(patch_path + '/{}.npy'.format(i))
|
||||
self.assert_img_equal(patches[i], ref_patch)
|
||||
assert_array_equal(patches[i], ref_patch)
|
||||
|
||||
# crop with scaling and no padding
|
||||
patches = mmcv.imcrop(self.img, bboxes, 1.2)
|
||||
for i in range(len(patches)):
|
||||
ref_patch = np.load(patch_path + '/scale_{}.npy'.format(i))
|
||||
self.assert_img_equal(patches[i], ref_patch)
|
||||
assert_array_equal(patches[i], ref_patch)
|
||||
|
||||
# crop with scaling and padding
|
||||
patches = mmcv.imcrop(self.img, bboxes, 1.2, pad_fill=[255, 255, 0])
|
||||
for i in range(len(patches)):
|
||||
ref_patch = np.load(patch_path + '/pad_{}.npy'.format(i))
|
||||
self.assert_img_equal(patches[i], ref_patch)
|
||||
assert_array_equal(patches[i], ref_patch)
|
||||
patches = mmcv.imcrop(self.img, bboxes, 1.2, pad_fill=0)
|
||||
for i in range(len(patches)):
|
||||
ref_patch = np.load(patch_path + '/pad0_{}.npy'.format(i))
|
||||
self.assert_img_equal(patches[i], ref_patch)
|
||||
assert_array_equal(patches[i], ref_patch)
|
||||
|
||||
def test_impad(self):
|
||||
# grayscale image
|
||||
|
@ -536,6 +529,48 @@ class TestImage(object):
|
|||
with pytest.raises(ValueError):
|
||||
mmcv.imrotate(img, 90, center=(0, 0), auto_bound=True)
|
||||
|
||||
|
||||
class TestPhotometric:
|
||||
|
||||
@classmethod
def setup_class(cls):
    """Load the color test image and the normalization statistics."""
    # the test img resolution is 400x300
    test_dir = osp.dirname(__file__)
    cls.img_path = osp.join(test_dir, 'data/color.jpg')
    cls.img = cv2.imread(cls.img_path)
    # Per-channel mean and std shared by the normalization tests.
    mean_values = [123.675, 116.28, 103.53]
    std_values = [58.395, 57.12, 57.375]
    cls.mean = np.array(mean_values, dtype=np.float32)
    cls.std = np.array(std_values, dtype=np.float32)
|
||||
|
||||
def test_imnormalize(self):
    """Check imnormalize against a manually normalized baseline."""
    rgb_input = self.img[:, :, ::-1]
    expected = (rgb_input - self.mean) / self.std

    # Default call on the stored image must match the channel-reversed
    # baseline and must not return the input object itself.
    normalized = mmcv.imnormalize(self.img, self.mean, self.std)
    assert np.allclose(normalized, expected)
    assert normalized is not self.img

    # With to_rgb=False the channel-reversed input is normalized as is.
    normalized = mmcv.imnormalize(
        rgb_input, self.mean, self.std, to_rgb=False)
    assert np.allclose(normalized, expected)
    assert normalized is not rgb_input
|
||||
|
||||
def test_imnormalize_(self):
    """Check that imnormalize_ normalizes its input array in place."""
    bgr_buffer = np.float32(self.img)
    rgb_buffer = np.float32(self.img[:, :, ::-1])
    # The baseline must be computed before rgb_buffer is modified in place.
    expected = (rgb_buffer - self.mean) / self.std

    result = mmcv.imnormalize_(bgr_buffer, self.mean, self.std)
    assert np.allclose(bgr_buffer, expected)
    assert result is bgr_buffer

    result = mmcv.imnormalize_(
        rgb_buffer, self.mean, self.std, to_rgb=False)
    assert np.allclose(result, expected)
    assert result is rgb_buffer
|
||||
|
||||
def test_imdenormalize(self):
    """Check imdenormalize with and without the channel swap."""
    normalized = (self.img[:, :, ::-1] - self.mean) / self.std
    expected_rgb = normalized * self.std + self.mean
    expected_bgr = expected_rgb[:, :, ::-1]

    restored = mmcv.imdenormalize(normalized, self.mean, self.std)
    assert np.allclose(restored, expected_bgr)

    restored = mmcv.imdenormalize(
        normalized, self.mean, self.std, to_bgr=False)
    assert np.allclose(restored, expected_rgb)
|
||||
|
||||
def test_iminvert(self):
|
||||
img = np.array([[0, 128, 255], [1, 127, 254], [2, 129, 253]],
|
||||
dtype=np.uint8)
|
||||
|
|
Loading…
Reference in New Issue