From 0412097ec2cd042037c613a7c0f0ec3d325f05ae Mon Sep 17 00:00:00 2001 From: tripleMu Date: Mon, 12 Sep 2022 17:16:55 +0800 Subject: [PATCH] Add type hints for mmcv/image (#2089) * Fix typehint * minor fix * minor fix * minor fix Co-authored-by: zhouzaida --- mmcv/image/geometric.py | 143 ++++++++++++++++++++++++---------------- mmcv/image/io.py | 37 ++++++----- mmcv/image/misc.py | 7 +- 3 files changed, 114 insertions(+), 73 deletions(-) diff --git a/mmcv/image/geometric.py b/mmcv/image/geometric.py index 91f1979bf..a5d0f8fa3 100644 --- a/mmcv/image/geometric.py +++ b/mmcv/image/geometric.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import numbers import warnings -from typing import Optional, Tuple +from typing import List, Optional, Tuple, Union, no_type_check import cv2 import numpy as np @@ -15,7 +15,10 @@ except ImportError: Image = None -def _scale_size(size, scale): +def _scale_size( + size: Tuple[int, int], + scale: Union[float, int, tuple], +) -> Tuple[int, int]: """Rescale a size by a ratio. Args: @@ -72,12 +75,14 @@ if Image is not None: } -def imresize(img, - size, - return_scale=False, - interpolation='bilinear', - out=None, - backend=None): +def imresize( + img: np.ndarray, + size: Tuple[int, int], + return_scale: bool = False, + interpolation: str = 'bilinear', + out: Optional[np.ndarray] = None, + backend: Optional[str] = None +) -> Union[Tuple[np.ndarray, float, float], np.ndarray]: """Resize image to a given size. Args: @@ -119,15 +124,18 @@ def imresize(img, return resized_img, w_scale, h_scale -def imresize_to_multiple(img, - divisor, - size=None, - scale_factor=None, - keep_ratio=False, - return_scale=False, - interpolation='bilinear', - out=None, - backend=None): +@no_type_check +def imresize_to_multiple( + img: np.ndarray, + divisor: Union[int, Tuple[int, int]], + size: Union[int, Tuple[int, int], None] = None, + scale_factor: Union[float, Tuple[float, float], None] = None, + keep_ratio: bool = False, + return_scale: bool = False, + interpolation: str = 'bilinear', + out: Optional[np.ndarray] = None, + backend: Optional[str] = None +) -> Union[Tuple[np.ndarray, float, float], np.ndarray]: """Resize image according to a given size or scale factor and then rounds up the the resized or rescaled image size to the nearest value that can be divided by the divisor. @@ -183,11 +191,13 @@ def imresize_to_multiple(img, return resized_img -def imresize_like(img, - dst_img, - return_scale=False, - interpolation='bilinear', - backend=None): +def imresize_like( + img: np.ndarray, + dst_img: np.ndarray, + return_scale: bool = False, + interpolation: str = 'bilinear', + backend: Optional[str] = None +) -> Union[Tuple[np.ndarray, float, float], np.ndarray]: """Resize image to the same size of a given image. Args: @@ -205,7 +215,9 @@ def imresize_like(img, return imresize(img, (w, h), return_scale, interpolation, backend=backend) -def rescale_size(old_size, scale, return_scale=False): +def rescale_size(old_size: tuple, + scale: Union[float, int, tuple], + return_scale: bool = False) -> tuple: """Calculate the new size to be rescaled to. Args: @@ -242,11 +254,13 @@ def rescale_size(old_size, scale, return_scale=False): return new_size -def imrescale(img, - scale, - return_scale=False, - interpolation='bilinear', - backend=None): +def imrescale( + img: np.ndarray, + scale: Union[float, Tuple[int, int]], + return_scale: bool = False, + interpolation: str = 'bilinear', + backend: Optional[str] = None +) -> Union[np.ndarray, Tuple[np.ndarray, float]]: """Resize image while keeping the aspect ratio. Args: @@ -273,7 +287,7 @@ def imrescale(img, return rescaled_img -def imflip(img, direction='horizontal'): +def imflip(img: np.ndarray, direction: str = 'horizontal') -> np.ndarray: """Flip an image horizontally or vertically. Args: @@ -293,7 +307,7 @@ def imflip(img, direction='horizontal'): return np.flip(img, axis=(0, 1)) -def imflip_(img, direction='horizontal'): +def imflip_(img: np.ndarray, direction: str = 'horizontal') -> np.ndarray: """Inplace flip an image horizontally or vertically. Args: @@ -373,7 +387,7 @@ def imrotate(img: np.ndarray, return rotated -def bbox_clip(bboxes, img_shape): +def bbox_clip(bboxes: np.ndarray, img_shape: Tuple[int, int]) -> np.ndarray: """Clip bboxes to fit the image shape. Args: @@ -391,7 +405,9 @@ def bbox_clip(bboxes, img_shape): return clipped_bboxes -def bbox_scaling(bboxes, scale, clip_shape=None): +def bbox_scaling(bboxes: np.ndarray, + scale: float, + clip_shape: Optional[Tuple[int, int]] = None) -> np.ndarray: """Scaling bboxes w.r.t the box center. Args: @@ -417,7 +433,12 @@ def bbox_scaling(bboxes, scale, clip_shape=None): return scaled_bboxes -def imcrop(img, bboxes, scale=1.0, pad_fill=None): +def imcrop( + img: np.ndarray, + bboxes: np.ndarray, + scale: float = 1.0, + pad_fill: Union[float, list, None] = None +) -> Union[np.ndarray, List[np.ndarray]]: """Crop image patches. 3 steps: scale the bboxes -> clip bboxes -> crop and pad. @@ -450,10 +471,12 @@ def imcrop(img, bboxes, scale=1.0, pad_fill=None): patch = img[y1:y2 + 1, x1:x2 + 1, ...] else: _x1, _y1, _x2, _y2 = tuple(scaled_bboxes[i, :]) + patch_h = _y2 - _y1 + 1 + patch_w = _x2 - _x1 + 1 if chn == 1: - patch_shape = (_y2 - _y1 + 1, _x2 - _x1 + 1) + patch_shape = (patch_h, patch_w) else: - patch_shape = (_y2 - _y1 + 1, _x2 - _x1 + 1, chn) + patch_shape = (patch_h, patch_w, chn) # type: ignore patch = np.array( pad_fill, dtype=img.dtype) * np.ones( patch_shape, dtype=img.dtype) @@ -471,12 +494,12 @@ def imcrop(img, bboxes, scale=1.0, pad_fill=None): return patches -def impad(img, +def impad(img: np.ndarray, *, - shape=None, - padding=None, - pad_val=0, - padding_mode='constant'): + shape: Optional[Tuple[int, int]] = None, + padding: Union[int, tuple, None] = None, + pad_val: Union[float, List] = 0, + padding_mode: str = 'constant') -> np.ndarray: """Pad the given image to a certain shape or pad on all sides with specified padding mode and padding value. @@ -555,7 +578,9 @@ def impad(img, return img -def impad_to_multiple(img, divisor, pad_val=0): +def impad_to_multiple(img: np.ndarray, + divisor: int, + pad_val: Union[float, List] = 0) -> np.ndarray: """Pad an image to ensure each edge to be multiple to some number. Args: @@ -571,7 +596,9 @@ def impad_to_multiple(img, divisor, pad_val=0): return impad(img, shape=(pad_h, pad_w), pad_val=pad_val) -def cutout(img, shape, pad_val=0): +def cutout(img: np.ndarray, + shape: Union[int, Tuple[int, int]], + pad_val: Union[int, float, tuple] = 0) -> np.ndarray: """Randomly cut out a rectangle from the original img. Args: @@ -615,7 +642,7 @@ def cutout(img, shape, pad_val=0): if img.ndim == 2: patch_shape = (y2 - y1, x2 - x1) else: - patch_shape = (y2 - y1, x2 - x1, channels) + patch_shape = (y2 - y1, x2 - x1, channels) # type: ignore img_cutout = img.copy() patch = np.array( @@ -626,7 +653,8 @@ def cutout(img, shape, pad_val=0): return img_cutout -def _get_shear_matrix(magnitude, direction='horizontal'): +def _get_shear_matrix(magnitude: Union[int, float], + direction: str = 'horizontal') -> np.ndarray: """Generate the shear matrix for transformation. Args: @@ -644,11 +672,11 @@ def _get_shear_matrix(magnitude, direction='horizontal'): return shear_matrix -def imshear(img, - magnitude, - direction='horizontal', - border_value=0, - interpolation='bilinear'): +def imshear(img: np.ndarray, + magnitude: Union[int, float], + direction: str = 'horizontal', + border_value: Union[int, Tuple[int, int]] = 0, + interpolation: str = 'bilinear') -> np.ndarray: """Shear an image. Args: @@ -672,7 +700,7 @@ def imshear(img, elif img.ndim == 3: channels = img.shape[-1] if isinstance(border_value, int): - border_value = tuple([border_value] * channels) + border_value = tuple([border_value] * channels) # type: ignore elif isinstance(border_value, tuple): assert len(border_value) == channels, \ 'Expected the num of elements in tuple equals the channels' \ @@ -690,12 +718,13 @@ def imshear(img, # greater than 3 (e.g. shearing masks whose channels large # than 3) will raise TypeError in `cv2.warpAffine`. # Here simply slice the first 3 values in `border_value`. - borderValue=border_value[:3], + borderValue=border_value[:3], # type: ignore flags=cv2_interp_codes[interpolation]) return sheared -def _get_translate_matrix(offset, direction='horizontal'): +def _get_translate_matrix(offset: Union[int, float], + direction: str = 'horizontal') -> np.ndarray: """Generate the translate matrix. Args: @@ -713,11 +742,11 @@ def _get_translate_matrix(offset, direction='horizontal'): return translate_matrix -def imtranslate(img, - offset, - direction='horizontal', - border_value=0, - interpolation='bilinear'): +def imtranslate(img: np.ndarray, + offset: Union[int, float], + direction: str = 'horizontal', + border_value: Union[int, tuple] = 0, + interpolation: str = 'bilinear') -> np.ndarray: """Translate an image. Args: diff --git a/mmcv/image/io.py b/mmcv/image/io.py index af13d38b6..8d2c86235 100644 --- a/mmcv/image/io.py +++ b/mmcv/image/io.py @@ -3,6 +3,7 @@ import io import os.path as osp import warnings from pathlib import Path +from typing import Optional, Union import cv2 import numpy as np @@ -41,7 +42,7 @@ imread_flags = { imread_backend = 'cv2' -def use_backend(backend): +def use_backend(backend: str) -> None: """Select a backend for image decoding. Args: @@ -67,7 +68,7 @@ def use_backend(backend): raise ImportError('`tifffile` is not installed') -def _jpegflag(flag='color', channel_order='bgr'): +def _jpegflag(flag: str = 'color', channel_order: str = 'bgr'): channel_order = channel_order.lower() if channel_order not in ['rgb', 'bgr']: raise ValueError('channel order must be either "rgb" or "bgr"') @@ -83,7 +84,9 @@ def _jpegflag(flag='color', channel_order='bgr'): raise ValueError('flag must be "color" or "grayscale"') -def _pillow2array(img, flag='color', channel_order='bgr'): +def _pillow2array(img, + flag: str = 'color', + channel_order: str = 'bgr') -> np.ndarray: """Convert a pillow image to numpy array. Args: @@ -138,11 +141,11 @@ def _pillow2array(img, flag='color', channel_order='bgr'): return array -def imread(img_or_path, - flag='color', - channel_order='bgr', - backend=None, - file_client_args=None): +def imread(img_or_path: Union[np.ndarray, str, Path], + flag: str = 'color', + channel_order: str = 'bgr', + backend: Optional[str] = None, + file_client_args: Optional[dict] = None) -> np.ndarray: """Read an image. Note: @@ -206,7 +209,10 @@ def imread(img_or_path, 'a pathlib.Path object') -def imfrombytes(content, flag='color', channel_order='bgr', backend=None): +def imfrombytes(content: bytes, + flag: str = 'color', + channel_order: str = 'bgr', + backend: Optional[str] = None) -> np.ndarray: """Read an image from bytes. Args: @@ -239,7 +245,8 @@ def imfrombytes(content, flag='color', channel_order='bgr', backend=None): f'backend: {backend} is not supported. Supported ' "backends are 'cv2', 'turbojpeg', 'pillow', 'tifffile'") if backend == 'turbojpeg': - img = jpeg.decode(content, _jpegflag(flag, channel_order)) + img = jpeg.decode( # type: ignore + content, _jpegflag(flag, channel_order)) if img.shape[-1] == 1: img = img[:, :, 0] return img @@ -261,11 +268,11 @@ def imfrombytes(content, flag='color', channel_order='bgr', backend=None): return img -def imwrite(img, - file_path, - params=None, - auto_mkdir=None, - file_client_args=None): +def imwrite(img: np.ndarray, + file_path: str, + params: Optional[list] = None, + auto_mkdir: Optional[bool] = None, + file_client_args: Optional[dict] = None) -> bool: """Write image to file. Note: diff --git a/mmcv/image/misc.py b/mmcv/image/misc.py index 43934a689..e923cad4e 100644 --- a/mmcv/image/misc.py +++ b/mmcv/image/misc.py @@ -1,4 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + import numpy as np import mmcv @@ -9,7 +11,10 @@ except ImportError: torch = None -def tensor2imgs(tensor, mean=None, std=None, to_rgb=True): +def tensor2imgs(tensor, + mean: Optional[tuple] = None, + std: Optional[tuple] = None, + to_rgb: bool = True) -> list: """Convert tensor to 3-channel images or 1-channel gray images. Args: