[Enhancement] Add type hints for mmcv/arraymisc and mmcv/video (#1950)

* Add type hints

* Add type hints

* Fix int float about scalar

* Add type hints for mmcv/tensorrt

* Update mmcv/tensorrt/tensorrt_utils.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/arraymisc/quantization.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Ignore type hint for dtype

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
pull/1954/head
Song Lin 2022-05-11 21:55:41 +08:00 committed by GitHub
parent 882cab77bb
commit 94cc99d595
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 67 additions and 37 deletions

View File

@ -1,14 +1,20 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Union
import numpy as np
def quantize(arr, min_val, max_val, levels, dtype=np.int64):
def quantize(arr: np.ndarray,
min_val: Union[int, float],
max_val: Union[int, float],
levels: int,
dtype=np.int64) -> tuple:
"""Quantize an array of (-inf, inf) to [0, levels-1].
Args:
arr (ndarray): Input array.
min_val (scalar): Minimum value to be clipped.
max_val (scalar): Maximum value to be clipped.
min_val (int or float): Minimum value to be clipped.
max_val (int or float): Maximum value to be clipped.
levels (int): Quantization levels.
dtype (np.type): The type of the quantized array.
@ -29,13 +35,17 @@ def quantize(arr, min_val, max_val, levels, dtype=np.int64):
return quantized_arr
def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
def dequantize(arr: np.ndarray,
min_val: Union[int, float],
max_val: Union[int, float],
levels: int,
dtype=np.float64) -> tuple:
"""Dequantize an array.
Args:
arr (ndarray): Input array.
min_val (scalar): Minimum value to be clipped.
max_val (scalar): Maximum value to be clipped.
min_val (int or float): Minimum value to be clipped.
max_val (int or float): Maximum value to be clipped.
levels (int): Quantization levels.
dtype (np.type): The type of the dequantized array.

View File

@ -5,7 +5,7 @@ import os
import warnings
def get_tensorrt_op_path():
def get_tensorrt_op_path() -> str:
"""Get TensorRT plugins library path."""
# Following strings of text style are from colorama package
bright_style, reset_style = '\x1b[1m', '\x1b[0m'
@ -31,7 +31,7 @@ def get_tensorrt_op_path():
plugin_is_loaded = False
def is_tensorrt_plugin_loaded():
def is_tensorrt_plugin_loaded() -> bool:
"""Check if TensorRT plugins library is loaded or not.
Returns:
@ -54,7 +54,7 @@ def is_tensorrt_plugin_loaded():
return plugin_is_loaded
def load_tensorrt_plugin():
def load_tensorrt_plugin() -> None:
"""load TensorRT plugins library."""
# Following strings of text style are from colorama package

View File

@ -5,7 +5,7 @@ import numpy as np
import onnx
def preprocess_onnx(onnx_model):
def preprocess_onnx(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
"""Modify onnx model to match with TensorRT plugins in mmcv.
There are some conflict between onnx node definition and TensorRT limit.

View File

@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Union
import onnx
import tensorrt as trt
@ -8,12 +9,12 @@ import torch
from .preprocess import preprocess_onnx
def onnx2trt(onnx_model,
opt_shape_dict,
log_level=trt.Logger.ERROR,
fp16_mode=False,
max_workspace_size=0,
device_id=0):
def onnx2trt(onnx_model: Union[str, onnx.ModelProto],
opt_shape_dict: dict,
log_level: trt.ILogger.Severity = trt.Logger.ERROR,
fp16_mode: bool = False,
max_workspace_size: int = 0,
device_id: int = 0) -> trt.ICudaEngine:
"""Convert onnx model to tensorrt engine.
Arguments:
@ -100,7 +101,7 @@ def onnx2trt(onnx_model,
return engine
def save_trt_engine(engine, path):
def save_trt_engine(engine: trt.ICudaEngine, path: str) -> None:
"""Serialize TensorRT engine to disk.
Arguments:
@ -124,7 +125,7 @@ def save_trt_engine(engine, path):
f.write(bytearray(engine.serialize()))
def load_trt_engine(path):
def load_trt_engine(path: str) -> trt.ICudaEngine:
"""Deserialize TensorRT engine from disk.
Arguments:
@ -153,7 +154,7 @@ def load_trt_engine(path):
return engine
def torch_dtype_from_trt(dtype):
def torch_dtype_from_trt(dtype: trt.DataType) -> Union[torch.dtype, TypeError]:
"""Convert pytorch dtype to TensorRT dtype."""
if dtype == trt.bool:
return torch.bool
@ -169,7 +170,8 @@ def torch_dtype_from_trt(dtype):
raise TypeError('%s is not supported by torch' % dtype)
def torch_device_from_trt(device):
def torch_device_from_trt(
device: trt.TensorLocation) -> Union[torch.device, TypeError]:
"""Convert pytorch device to TensorRT device."""
if device == trt.TensorLocation.DEVICE:
return torch.device('cuda')

View File

@ -272,14 +272,14 @@ class VideoReader:
self._vcap.release()
def frames2video(frame_dir,
video_file,
fps=30,
fourcc='XVID',
filename_tmpl='{:06d}.jpg',
start=0,
end=0,
show_progress=True):
def frames2video(frame_dir: str,
video_file: str,
fps: float = 30,
fourcc: str = 'XVID',
filename_tmpl: str = '{:06d}.jpg',
start: int = 0,
end: int = 0,
show_progress: bool = True) -> None:
"""Read the frame images from a directory and join them as a video.
Args:

View File

@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Tuple, Union
import cv2
import numpy as np
@ -9,7 +10,11 @@ from mmcv.image import imread, imwrite
from mmcv.utils import is_str
def flowread(flow_or_path, quantize=False, concat_axis=0, *args, **kwargs):
def flowread(flow_or_path: Union[np.ndarray, str],
quantize: bool = False,
concat_axis: int = 0,
*args,
**kwargs) -> np.ndarray:
"""Read an optical flow map.
Args:
@ -58,7 +63,12 @@ def flowread(flow_or_path, quantize=False, concat_axis=0, *args, **kwargs):
return flow.astype(np.float32)
def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs):
def flowwrite(flow: np.ndarray,
filename: str,
quantize: bool = False,
concat_axis: int = 0,
*args,
**kwargs) -> None:
"""Write optical flow to file.
If the flow is not quantized, it will be saved as a .flo file losslessly,
@ -88,7 +98,9 @@ def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs):
imwrite(dxdy, filename)
def quantize_flow(flow, max_val=0.02, norm=True):
def quantize_flow(flow: np.ndarray,
max_val: float = 0.02,
norm: bool = True) -> tuple:
"""Quantize flow to [0, 255].
After this step, the size of flow will be much smaller, and can be
@ -116,7 +128,10 @@ def quantize_flow(flow, max_val=0.02, norm=True):
return tuple(flow_comps)
def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
def dequantize_flow(dx: np.ndarray,
dy: np.ndarray,
max_val: float = 0.02,
denorm: bool = True) -> np.ndarray:
"""Recover from quantized flow.
Args:
@ -140,12 +155,15 @@ def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
return flow
def flow_warp(img, flow, filling_value=0, interpolate_mode='nearest'):
def flow_warp(img: np.ndarray,
flow: np.ndarray,
filling_value: int = 0,
interpolate_mode: str = 'nearest') -> np.ndarray:
"""Use flow to warp img.
Args:
img (ndarray, float or uint8): Image to be warped.
flow (ndarray, float): Optical Flow.
img (ndarray): Image to be warped.
flow (ndarray): Optical Flow.
filling_value (int): The missing pixels will be set with filling_value.
interpolate_mode (str): bilinear -> Bilinear Interpolation;
nearest -> Nearest Neighbor.
@ -201,7 +219,7 @@ def flow_warp(img, flow, filling_value=0, interpolate_mode='nearest'):
return output.astype(img.dtype)
def flow_from_bytes(content):
def flow_from_bytes(content: bytes) -> np.ndarray:
"""Read dense optical flow from bytes.
.. note::
@ -231,7 +249,7 @@ def flow_from_bytes(content):
return flow
def sparse_flow_from_bytes(content):
def sparse_flow_from_bytes(content: bytes) -> Tuple[np.ndarray, np.ndarray]:
"""Read the optical flow in KITTI datasets from bytes.
This function is modified from RAFT load the `KITTI datasets