mirror of https://github.com/open-mmlab/mmcv.git
Add type hints for mmcv/image (#2089)
* Fix typehint * minor fix * minor fix * minor fix Co-authored-by: zhouzaida <zhouzaida@163.com>pull/2256/head
parent
3a311e85ae
commit
2fb2b91aed
|
@ -1,7 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import numbers
|
||||
import warnings
|
||||
from typing import Optional, Tuple
|
||||
from typing import List, Optional, Tuple, Union, no_type_check
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
@ -15,7 +15,10 @@ except ImportError:
|
|||
Image = None
|
||||
|
||||
|
||||
def _scale_size(size, scale):
|
||||
def _scale_size(
|
||||
size: Tuple[int, int],
|
||||
scale: Union[float, int, tuple],
|
||||
) -> Tuple[int, int]:
|
||||
"""Rescale a size by a ratio.
|
||||
|
||||
Args:
|
||||
|
@ -72,12 +75,14 @@ if Image is not None:
|
|||
}
|
||||
|
||||
|
||||
def imresize(img,
|
||||
size,
|
||||
return_scale=False,
|
||||
interpolation='bilinear',
|
||||
out=None,
|
||||
backend=None):
|
||||
def imresize(
|
||||
img: np.ndarray,
|
||||
size: Tuple[int, int],
|
||||
return_scale: bool = False,
|
||||
interpolation: str = 'bilinear',
|
||||
out: Optional[np.ndarray] = None,
|
||||
backend: Optional[str] = None
|
||||
) -> Union[Tuple[np.ndarray, float, float], np.ndarray]:
|
||||
"""Resize image to a given size.
|
||||
|
||||
Args:
|
||||
|
@ -119,15 +124,18 @@ def imresize(img,
|
|||
return resized_img, w_scale, h_scale
|
||||
|
||||
|
||||
def imresize_to_multiple(img,
|
||||
divisor,
|
||||
size=None,
|
||||
scale_factor=None,
|
||||
keep_ratio=False,
|
||||
return_scale=False,
|
||||
interpolation='bilinear',
|
||||
out=None,
|
||||
backend=None):
|
||||
@no_type_check
|
||||
def imresize_to_multiple(
|
||||
img: np.ndarray,
|
||||
divisor: Union[int, Tuple[int, int]],
|
||||
size: Union[int, Tuple[int, int], None] = None,
|
||||
scale_factor: Union[float, Tuple[float, float], None] = None,
|
||||
keep_ratio: bool = False,
|
||||
return_scale: bool = False,
|
||||
interpolation: str = 'bilinear',
|
||||
out: Optional[np.ndarray] = None,
|
||||
backend: Optional[str] = None
|
||||
) -> Union[Tuple[np.ndarray, float, float], np.ndarray]:
|
||||
"""Resize image according to a given size or scale factor and then rounds
|
||||
up the the resized or rescaled image size to the nearest value that can be
|
||||
divided by the divisor.
|
||||
|
@ -183,11 +191,13 @@ def imresize_to_multiple(img,
|
|||
return resized_img
|
||||
|
||||
|
||||
def imresize_like(img,
|
||||
dst_img,
|
||||
return_scale=False,
|
||||
interpolation='bilinear',
|
||||
backend=None):
|
||||
def imresize_like(
|
||||
img: np.ndarray,
|
||||
dst_img: np.ndarray,
|
||||
return_scale: bool = False,
|
||||
interpolation: str = 'bilinear',
|
||||
backend: Optional[str] = None
|
||||
) -> Union[Tuple[np.ndarray, float, float], np.ndarray]:
|
||||
"""Resize image to the same size of a given image.
|
||||
|
||||
Args:
|
||||
|
@ -205,7 +215,9 @@ def imresize_like(img,
|
|||
return imresize(img, (w, h), return_scale, interpolation, backend=backend)
|
||||
|
||||
|
||||
def rescale_size(old_size, scale, return_scale=False):
|
||||
def rescale_size(old_size: tuple,
|
||||
scale: Union[float, int, tuple],
|
||||
return_scale: bool = False) -> tuple:
|
||||
"""Calculate the new size to be rescaled to.
|
||||
|
||||
Args:
|
||||
|
@ -242,11 +254,13 @@ def rescale_size(old_size, scale, return_scale=False):
|
|||
return new_size
|
||||
|
||||
|
||||
def imrescale(img,
|
||||
scale,
|
||||
return_scale=False,
|
||||
interpolation='bilinear',
|
||||
backend=None):
|
||||
def imrescale(
|
||||
img: np.ndarray,
|
||||
scale: Union[float, Tuple[int, int]],
|
||||
return_scale: bool = False,
|
||||
interpolation: str = 'bilinear',
|
||||
backend: Optional[str] = None
|
||||
) -> Union[np.ndarray, Tuple[np.ndarray, float]]:
|
||||
"""Resize image while keeping the aspect ratio.
|
||||
|
||||
Args:
|
||||
|
@ -273,7 +287,7 @@ def imrescale(img,
|
|||
return rescaled_img
|
||||
|
||||
|
||||
def imflip(img, direction='horizontal'):
|
||||
def imflip(img: np.ndarray, direction: str = 'horizontal') -> np.ndarray:
|
||||
"""Flip an image horizontally or vertically.
|
||||
|
||||
Args:
|
||||
|
@ -293,7 +307,7 @@ def imflip(img, direction='horizontal'):
|
|||
return np.flip(img, axis=(0, 1))
|
||||
|
||||
|
||||
def imflip_(img, direction='horizontal'):
|
||||
def imflip_(img: np.ndarray, direction: str = 'horizontal') -> np.ndarray:
|
||||
"""Inplace flip an image horizontally or vertically.
|
||||
|
||||
Args:
|
||||
|
@ -373,7 +387,7 @@ def imrotate(img: np.ndarray,
|
|||
return rotated
|
||||
|
||||
|
||||
def bbox_clip(bboxes, img_shape):
|
||||
def bbox_clip(bboxes: np.ndarray, img_shape: Tuple[int, int]) -> np.ndarray:
|
||||
"""Clip bboxes to fit the image shape.
|
||||
|
||||
Args:
|
||||
|
@ -391,7 +405,9 @@ def bbox_clip(bboxes, img_shape):
|
|||
return clipped_bboxes
|
||||
|
||||
|
||||
def bbox_scaling(bboxes, scale, clip_shape=None):
|
||||
def bbox_scaling(bboxes: np.ndarray,
|
||||
scale: float,
|
||||
clip_shape: Optional[Tuple[int, int]] = None) -> np.ndarray:
|
||||
"""Scaling bboxes w.r.t the box center.
|
||||
|
||||
Args:
|
||||
|
@ -417,7 +433,12 @@ def bbox_scaling(bboxes, scale, clip_shape=None):
|
|||
return scaled_bboxes
|
||||
|
||||
|
||||
def imcrop(img, bboxes, scale=1.0, pad_fill=None):
|
||||
def imcrop(
|
||||
img: np.ndarray,
|
||||
bboxes: np.ndarray,
|
||||
scale: float = 1.0,
|
||||
pad_fill: Union[float, list, None] = None
|
||||
) -> Union[np.ndarray, List[np.ndarray]]:
|
||||
"""Crop image patches.
|
||||
|
||||
3 steps: scale the bboxes -> clip bboxes -> crop and pad.
|
||||
|
@ -450,10 +471,12 @@ def imcrop(img, bboxes, scale=1.0, pad_fill=None):
|
|||
patch = img[y1:y2 + 1, x1:x2 + 1, ...]
|
||||
else:
|
||||
_x1, _y1, _x2, _y2 = tuple(scaled_bboxes[i, :])
|
||||
patch_h = _y2 - _y1 + 1
|
||||
patch_w = _x2 - _x1 + 1
|
||||
if chn == 1:
|
||||
patch_shape = (_y2 - _y1 + 1, _x2 - _x1 + 1)
|
||||
patch_shape = (patch_h, patch_w)
|
||||
else:
|
||||
patch_shape = (_y2 - _y1 + 1, _x2 - _x1 + 1, chn)
|
||||
patch_shape = (patch_h, patch_w, chn) # type: ignore
|
||||
patch = np.array(
|
||||
pad_fill, dtype=img.dtype) * np.ones(
|
||||
patch_shape, dtype=img.dtype)
|
||||
|
@ -471,12 +494,12 @@ def imcrop(img, bboxes, scale=1.0, pad_fill=None):
|
|||
return patches
|
||||
|
||||
|
||||
def impad(img,
|
||||
def impad(img: np.ndarray,
|
||||
*,
|
||||
shape=None,
|
||||
padding=None,
|
||||
pad_val=0,
|
||||
padding_mode='constant'):
|
||||
shape: Optional[Tuple[int, int]] = None,
|
||||
padding: Union[int, tuple, None] = None,
|
||||
pad_val: Union[float, List] = 0,
|
||||
padding_mode: str = 'constant') -> np.ndarray:
|
||||
"""Pad the given image to a certain shape or pad on all sides with
|
||||
specified padding mode and padding value.
|
||||
|
||||
|
@ -554,7 +577,9 @@ def impad(img,
|
|||
return img
|
||||
|
||||
|
||||
def impad_to_multiple(img, divisor, pad_val=0):
|
||||
def impad_to_multiple(img: np.ndarray,
|
||||
divisor: int,
|
||||
pad_val: Union[float, List] = 0) -> np.ndarray:
|
||||
"""Pad an image to ensure each edge to be multiple to some number.
|
||||
|
||||
Args:
|
||||
|
@ -570,7 +595,9 @@ def impad_to_multiple(img, divisor, pad_val=0):
|
|||
return impad(img, shape=(pad_h, pad_w), pad_val=pad_val)
|
||||
|
||||
|
||||
def cutout(img, shape, pad_val=0):
|
||||
def cutout(img: np.ndarray,
|
||||
shape: Union[int, Tuple[int, int]],
|
||||
pad_val: Union[int, float, tuple] = 0) -> np.ndarray:
|
||||
"""Randomly cut out a rectangle from the original img.
|
||||
|
||||
Args:
|
||||
|
@ -614,7 +641,7 @@ def cutout(img, shape, pad_val=0):
|
|||
if img.ndim == 2:
|
||||
patch_shape = (y2 - y1, x2 - x1)
|
||||
else:
|
||||
patch_shape = (y2 - y1, x2 - x1, channels)
|
||||
patch_shape = (y2 - y1, x2 - x1, channels) # type: ignore
|
||||
|
||||
img_cutout = img.copy()
|
||||
patch = np.array(
|
||||
|
@ -625,7 +652,8 @@ def cutout(img, shape, pad_val=0):
|
|||
return img_cutout
|
||||
|
||||
|
||||
def _get_shear_matrix(magnitude, direction='horizontal'):
|
||||
def _get_shear_matrix(magnitude: Union[int, float],
|
||||
direction: str = 'horizontal') -> np.ndarray:
|
||||
"""Generate the shear matrix for transformation.
|
||||
|
||||
Args:
|
||||
|
@ -643,11 +671,11 @@ def _get_shear_matrix(magnitude, direction='horizontal'):
|
|||
return shear_matrix
|
||||
|
||||
|
||||
def imshear(img,
|
||||
magnitude,
|
||||
direction='horizontal',
|
||||
border_value=0,
|
||||
interpolation='bilinear'):
|
||||
def imshear(img: np.ndarray,
|
||||
magnitude: Union[int, float],
|
||||
direction: str = 'horizontal',
|
||||
border_value: Union[int, Tuple[int, int]] = 0,
|
||||
interpolation: str = 'bilinear') -> np.ndarray:
|
||||
"""Shear an image.
|
||||
|
||||
Args:
|
||||
|
@ -671,7 +699,7 @@ def imshear(img,
|
|||
elif img.ndim == 3:
|
||||
channels = img.shape[-1]
|
||||
if isinstance(border_value, int):
|
||||
border_value = tuple([border_value] * channels)
|
||||
border_value = tuple([border_value] * channels) # type: ignore
|
||||
elif isinstance(border_value, tuple):
|
||||
assert len(border_value) == channels, \
|
||||
'Expected the num of elements in tuple equals the channels' \
|
||||
|
@ -689,12 +717,13 @@ def imshear(img,
|
|||
# greater than 3 (e.g. shearing masks whose channels large
|
||||
# than 3) will raise TypeError in `cv2.warpAffine`.
|
||||
# Here simply slice the first 3 values in `border_value`.
|
||||
borderValue=border_value[:3],
|
||||
borderValue=border_value[:3], # type: ignore
|
||||
flags=cv2_interp_codes[interpolation])
|
||||
return sheared
|
||||
|
||||
|
||||
def _get_translate_matrix(offset, direction='horizontal'):
|
||||
def _get_translate_matrix(offset: Union[int, float],
|
||||
direction: str = 'horizontal') -> np.ndarray:
|
||||
"""Generate the translate matrix.
|
||||
|
||||
Args:
|
||||
|
@ -712,11 +741,11 @@ def _get_translate_matrix(offset, direction='horizontal'):
|
|||
return translate_matrix
|
||||
|
||||
|
||||
def imtranslate(img,
|
||||
offset,
|
||||
direction='horizontal',
|
||||
border_value=0,
|
||||
interpolation='bilinear'):
|
||||
def imtranslate(img: np.ndarray,
|
||||
offset: Union[int, float],
|
||||
direction: str = 'horizontal',
|
||||
border_value: Union[int, tuple] = 0,
|
||||
interpolation: str = 'bilinear') -> np.ndarray:
|
||||
"""Translate an image.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -3,6 +3,7 @@ import io
|
|||
import os.path as osp
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
@ -42,7 +43,7 @@ imread_flags = {
|
|||
imread_backend = 'cv2'
|
||||
|
||||
|
||||
def use_backend(backend):
|
||||
def use_backend(backend: str) -> None:
|
||||
"""Select a backend for image decoding.
|
||||
|
||||
Args:
|
||||
|
@ -68,7 +69,7 @@ def use_backend(backend):
|
|||
raise ImportError('`tifffile` is not installed')
|
||||
|
||||
|
||||
def _jpegflag(flag='color', channel_order='bgr'):
|
||||
def _jpegflag(flag: str = 'color', channel_order: str = 'bgr'):
|
||||
channel_order = channel_order.lower()
|
||||
if channel_order not in ['rgb', 'bgr']:
|
||||
raise ValueError('channel order must be either "rgb" or "bgr"')
|
||||
|
@ -84,7 +85,9 @@ def _jpegflag(flag='color', channel_order='bgr'):
|
|||
raise ValueError('flag must be "color" or "grayscale"')
|
||||
|
||||
|
||||
def _pillow2array(img, flag='color', channel_order='bgr'):
|
||||
def _pillow2array(img,
|
||||
flag: str = 'color',
|
||||
channel_order: str = 'bgr') -> np.ndarray:
|
||||
"""Convert a pillow image to numpy array.
|
||||
|
||||
Args:
|
||||
|
@ -139,11 +142,11 @@ def _pillow2array(img, flag='color', channel_order='bgr'):
|
|||
return array
|
||||
|
||||
|
||||
def imread(img_or_path,
|
||||
flag='color',
|
||||
channel_order='bgr',
|
||||
backend=None,
|
||||
file_client_args=None):
|
||||
def imread(img_or_path: Union[np.ndarray, str, Path],
|
||||
flag: str = 'color',
|
||||
channel_order: str = 'bgr',
|
||||
backend: Optional[str] = None,
|
||||
file_client_args: Optional[dict] = None) -> np.ndarray:
|
||||
"""Read an image.
|
||||
|
||||
Note:
|
||||
|
@ -207,7 +210,10 @@ def imread(img_or_path,
|
|||
'a pathlib.Path object')
|
||||
|
||||
|
||||
def imfrombytes(content, flag='color', channel_order='bgr', backend=None):
|
||||
def imfrombytes(content: bytes,
|
||||
flag: str = 'color',
|
||||
channel_order: str = 'bgr',
|
||||
backend: Optional[str] = None) -> np.ndarray:
|
||||
"""Read an image from bytes.
|
||||
|
||||
Args:
|
||||
|
@ -240,7 +246,8 @@ def imfrombytes(content, flag='color', channel_order='bgr', backend=None):
|
|||
f'backend: {backend} is not supported. Supported '
|
||||
"backends are 'cv2', 'turbojpeg', 'pillow', 'tifffile'")
|
||||
if backend == 'turbojpeg':
|
||||
img = jpeg.decode(content, _jpegflag(flag, channel_order))
|
||||
img = jpeg.decode( # type: ignore
|
||||
content, _jpegflag(flag, channel_order))
|
||||
if img.shape[-1] == 1:
|
||||
img = img[:, :, 0]
|
||||
return img
|
||||
|
@ -262,11 +269,11 @@ def imfrombytes(content, flag='color', channel_order='bgr', backend=None):
|
|||
return img
|
||||
|
||||
|
||||
def imwrite(img,
|
||||
file_path,
|
||||
params=None,
|
||||
auto_mkdir=None,
|
||||
file_client_args=None):
|
||||
def imwrite(img: np.ndarray,
|
||||
file_path: str,
|
||||
params: Optional[list] = None,
|
||||
auto_mkdir: Optional[bool] = None,
|
||||
file_client_args: Optional[dict] = None) -> bool:
|
||||
"""Write image to file.
|
||||
|
||||
Note:
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mmcv
|
||||
|
@ -9,7 +11,10 @@ except ImportError:
|
|||
torch = None
|
||||
|
||||
|
||||
def tensor2imgs(tensor, mean=None, std=None, to_rgb=True):
|
||||
def tensor2imgs(tensor,
|
||||
mean: Optional[tuple] = None,
|
||||
std: Optional[tuple] = None,
|
||||
to_rgb: bool = True) -> list:
|
||||
"""Convert tensor to 3-channel images or 1-channel gray images.
|
||||
|
||||
Args:
|
||||
|
|
Loading…
Reference in New Issue