diff --git a/mmcv/arraymisc/quantization.py b/mmcv/arraymisc/quantization.py index 8e47a3545..6182710d5 100644 --- a/mmcv/arraymisc/quantization.py +++ b/mmcv/arraymisc/quantization.py @@ -1,14 +1,20 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Union + import numpy as np -def quantize(arr, min_val, max_val, levels, dtype=np.int64): +def quantize(arr: np.ndarray, + min_val: Union[int, float], + max_val: Union[int, float], + levels: int, + dtype=np.int64) -> tuple: """Quantize an array of (-inf, inf) to [0, levels-1]. Args: arr (ndarray): Input array. - min_val (scalar): Minimum value to be clipped. - max_val (scalar): Maximum value to be clipped. + min_val (int or float): Minimum value to be clipped. + max_val (int or float): Maximum value to be clipped. levels (int): Quantization levels. dtype (np.type): The type of the quantized array. @@ -29,13 +35,17 @@ def quantize(arr, min_val, max_val, levels, dtype=np.int64): return quantized_arr -def dequantize(arr, min_val, max_val, levels, dtype=np.float64): +def dequantize(arr: np.ndarray, + min_val: Union[int, float], + max_val: Union[int, float], + levels: int, + dtype=np.float64) -> tuple: """Dequantize an array. Args: arr (ndarray): Input array. - min_val (scalar): Minimum value to be clipped. - max_val (scalar): Maximum value to be clipped. + min_val (int or float): Minimum value to be clipped. + max_val (int or float): Maximum value to be clipped. levels (int): Quantization levels. dtype (np.type): The type of the dequantized array. diff --git a/mmcv/tensorrt/init_plugins.py b/mmcv/tensorrt/init_plugins.py index 7c0eff0e4..909b9ae28 100644 --- a/mmcv/tensorrt/init_plugins.py +++ b/mmcv/tensorrt/init_plugins.py @@ -5,7 +5,7 @@ import os import warnings -def get_tensorrt_op_path(): +def get_tensorrt_op_path() -> str: """Get TensorRT plugins library path.""" # Following strings of text style are from colorama package bright_style, reset_style = '\x1b[1m', '\x1b[0m' @@ -31,7 +31,7 @@ def get_tensorrt_op_path(): plugin_is_loaded = False -def is_tensorrt_plugin_loaded(): +def is_tensorrt_plugin_loaded() -> bool: """Check if TensorRT plugins library is loaded or not. Returns: @@ -54,7 +54,7 @@ def is_tensorrt_plugin_loaded(): return plugin_is_loaded -def load_tensorrt_plugin(): +def load_tensorrt_plugin() -> None: """load TensorRT plugins library.""" # Following strings of text style are from colorama package diff --git a/mmcv/tensorrt/preprocess.py b/mmcv/tensorrt/preprocess.py index ca6b3674f..a0ad25428 100644 --- a/mmcv/tensorrt/preprocess.py +++ b/mmcv/tensorrt/preprocess.py @@ -5,7 +5,7 @@ import numpy as np import onnx -def preprocess_onnx(onnx_model): +def preprocess_onnx(onnx_model: onnx.ModelProto) -> onnx.ModelProto: """Modify onnx model to match with TensorRT plugins in mmcv. There are some conflict between onnx node definition and TensorRT limit. diff --git a/mmcv/tensorrt/tensorrt_utils.py b/mmcv/tensorrt/tensorrt_utils.py index 2ddff2cd5..ed99893d7 100644 --- a/mmcv/tensorrt/tensorrt_utils.py +++ b/mmcv/tensorrt/tensorrt_utils.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import warnings +from typing import Union import onnx import tensorrt as trt @@ -8,12 +9,12 @@ import torch from .preprocess import preprocess_onnx -def onnx2trt(onnx_model, - opt_shape_dict, - log_level=trt.Logger.ERROR, - fp16_mode=False, - max_workspace_size=0, - device_id=0): +def onnx2trt(onnx_model: Union[str, onnx.ModelProto], + opt_shape_dict: dict, + log_level: trt.ILogger.Severity = trt.Logger.ERROR, + fp16_mode: bool = False, + max_workspace_size: int = 0, + device_id: int = 0) -> trt.ICudaEngine: """Convert onnx model to tensorrt engine. Arguments: @@ -100,7 +101,7 @@ def onnx2trt(onnx_model, return engine -def save_trt_engine(engine, path): +def save_trt_engine(engine: trt.ICudaEngine, path: str) -> None: """Serialize TensorRT engine to disk. Arguments: @@ -124,7 +125,7 @@ def save_trt_engine(engine, path): f.write(bytearray(engine.serialize())) -def load_trt_engine(path): +def load_trt_engine(path: str) -> trt.ICudaEngine: """Deserialize TensorRT engine from disk. Arguments: @@ -153,7 +154,7 @@ def load_trt_engine(path): return engine -def torch_dtype_from_trt(dtype): +def torch_dtype_from_trt(dtype: trt.DataType) -> Union[torch.dtype, TypeError]: """Convert pytorch dtype to TensorRT dtype.""" if dtype == trt.bool: return torch.bool @@ -169,7 +170,8 @@ def torch_dtype_from_trt(dtype): raise TypeError('%s is not supported by torch' % dtype) -def torch_device_from_trt(device): +def torch_device_from_trt( + device: trt.TensorLocation) -> Union[torch.device, TypeError]: """Convert pytorch device to TensorRT device.""" if device == trt.TensorLocation.DEVICE: return torch.device('cuda') diff --git a/mmcv/video/io.py b/mmcv/video/io.py index 8d9bf9d3a..09fa770db 100644 --- a/mmcv/video/io.py +++ b/mmcv/video/io.py @@ -272,14 +272,14 @@ class VideoReader: self._vcap.release() -def frames2video(frame_dir, - video_file, - fps=30, - fourcc='XVID', - filename_tmpl='{:06d}.jpg', - start=0, - end=0, - show_progress=True): +def frames2video(frame_dir: str, + video_file: str, + fps: float = 30, + fourcc: str = 'XVID', + filename_tmpl: str = '{:06d}.jpg', + start: int = 0, + end: int = 0, + show_progress: bool = True) -> None: """Read the frame images from a directory and join them as a video. Args: diff --git a/mmcv/video/optflow.py b/mmcv/video/optflow.py index 2e6518900..7cea95349 100644 --- a/mmcv/video/optflow.py +++ b/mmcv/video/optflow.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import warnings +from typing import Tuple, Union import cv2 import numpy as np @@ -9,7 +10,11 @@ from mmcv.image import imread, imwrite from mmcv.utils import is_str -def flowread(flow_or_path, quantize=False, concat_axis=0, *args, **kwargs): +def flowread(flow_or_path: Union[np.ndarray, str], + quantize: bool = False, + concat_axis: int = 0, + *args, + **kwargs) -> np.ndarray: """Read an optical flow map. Args: @@ -58,7 +63,12 @@ def flowread(flow_or_path, quantize=False, concat_axis=0, *args, **kwargs): return flow.astype(np.float32) -def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs): +def flowwrite(flow: np.ndarray, + filename: str, + quantize: bool = False, + concat_axis: int = 0, + *args, + **kwargs) -> None: """Write optical flow to file. If the flow is not quantized, it will be saved as a .flo file losslessly, @@ -88,7 +98,9 @@ def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs): imwrite(dxdy, filename) -def quantize_flow(flow, max_val=0.02, norm=True): +def quantize_flow(flow: np.ndarray, + max_val: float = 0.02, + norm: bool = True) -> tuple: """Quantize flow to [0, 255]. After this step, the size of flow will be much smaller, and can be @@ -116,7 +128,10 @@ def quantize_flow(flow, max_val=0.02, norm=True): return tuple(flow_comps) -def dequantize_flow(dx, dy, max_val=0.02, denorm=True): +def dequantize_flow(dx: np.ndarray, + dy: np.ndarray, + max_val: float = 0.02, + denorm: bool = True) -> np.ndarray: """Recover from quantized flow. Args: @@ -140,12 +155,15 @@ def dequantize_flow(dx, dy, max_val=0.02, denorm=True): return flow -def flow_warp(img, flow, filling_value=0, interpolate_mode='nearest'): +def flow_warp(img: np.ndarray, + flow: np.ndarray, + filling_value: int = 0, + interpolate_mode: str = 'nearest') -> np.ndarray: """Use flow to warp img. Args: - img (ndarray, float or uint8): Image to be warped. - flow (ndarray, float): Optical Flow. + img (ndarray): Image to be warped. + flow (ndarray): Optical Flow. filling_value (int): The missing pixels will be set with filling_value. interpolate_mode (str): bilinear -> Bilinear Interpolation; nearest -> Nearest Neighbor. @@ -201,7 +219,7 @@ def flow_warp(img, flow, filling_value=0, interpolate_mode='nearest'): return output.astype(img.dtype) -def flow_from_bytes(content): +def flow_from_bytes(content: bytes) -> np.ndarray: """Read dense optical flow from bytes. .. note:: @@ -231,7 +249,7 @@ def flow_from_bytes(content): return flow -def sparse_flow_from_bytes(content): +def sparse_flow_from_bytes(content: bytes) -> Tuple[np.ndarray, np.ndarray]: """Read the optical flow in KITTI datasets from bytes. This function is modified from RAFT load the `KITTI datasets