[Enhancement] Add type hints for mmcv/arraymisc and mmcv/video (#1950)

* Add type hints

* Add type hints

* Fix int float about scalar

* Add type hints for mmcv/tensorrt

* Update mmcv/tensorrt/tensorrt_utils.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/arraymisc/quantization.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Ignore type hint for dtype

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
This commit is contained in:
Song Lin 2022-05-11 21:55:41 +08:00 committed by GitHub
parent 882cab77bb
commit 94cc99d595
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 67 additions and 37 deletions

View File

@ -1,14 +1,20 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Union
import numpy as np import numpy as np
def quantize(arr, min_val, max_val, levels, dtype=np.int64): def quantize(arr: np.ndarray,
min_val: Union[int, float],
max_val: Union[int, float],
levels: int,
dtype=np.int64) -> tuple:
"""Quantize an array of (-inf, inf) to [0, levels-1]. """Quantize an array of (-inf, inf) to [0, levels-1].
Args: Args:
arr (ndarray): Input array. arr (ndarray): Input array.
min_val (scalar): Minimum value to be clipped. min_val (int or float): Minimum value to be clipped.
max_val (scalar): Maximum value to be clipped. max_val (int or float): Maximum value to be clipped.
levels (int): Quantization levels. levels (int): Quantization levels.
dtype (np.type): The type of the quantized array. dtype (np.type): The type of the quantized array.
@ -29,13 +35,17 @@ def quantize(arr, min_val, max_val, levels, dtype=np.int64):
return quantized_arr return quantized_arr
def dequantize(arr, min_val, max_val, levels, dtype=np.float64): def dequantize(arr: np.ndarray,
min_val: Union[int, float],
max_val: Union[int, float],
levels: int,
dtype=np.float64) -> tuple:
"""Dequantize an array. """Dequantize an array.
Args: Args:
arr (ndarray): Input array. arr (ndarray): Input array.
min_val (scalar): Minimum value to be clipped. min_val (int or float): Minimum value to be clipped.
max_val (scalar): Maximum value to be clipped. max_val (int or float): Maximum value to be clipped.
levels (int): Quantization levels. levels (int): Quantization levels.
dtype (np.type): The type of the dequantized array. dtype (np.type): The type of the dequantized array.

View File

@ -5,7 +5,7 @@ import os
import warnings import warnings
def get_tensorrt_op_path(): def get_tensorrt_op_path() -> str:
"""Get TensorRT plugins library path.""" """Get TensorRT plugins library path."""
# Following strings of text style are from colorama package # Following strings of text style are from colorama package
bright_style, reset_style = '\x1b[1m', '\x1b[0m' bright_style, reset_style = '\x1b[1m', '\x1b[0m'
@ -31,7 +31,7 @@ def get_tensorrt_op_path():
plugin_is_loaded = False plugin_is_loaded = False
def is_tensorrt_plugin_loaded(): def is_tensorrt_plugin_loaded() -> bool:
"""Check if TensorRT plugins library is loaded or not. """Check if TensorRT plugins library is loaded or not.
Returns: Returns:
@ -54,7 +54,7 @@ def is_tensorrt_plugin_loaded():
return plugin_is_loaded return plugin_is_loaded
def load_tensorrt_plugin(): def load_tensorrt_plugin() -> None:
"""load TensorRT plugins library.""" """load TensorRT plugins library."""
# Following strings of text style are from colorama package # Following strings of text style are from colorama package

View File

@ -5,7 +5,7 @@ import numpy as np
import onnx import onnx
def preprocess_onnx(onnx_model): def preprocess_onnx(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
"""Modify onnx model to match with TensorRT plugins in mmcv. """Modify onnx model to match with TensorRT plugins in mmcv.
There are some conflict between onnx node definition and TensorRT limit. There are some conflict between onnx node definition and TensorRT limit.

View File

@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import warnings import warnings
from typing import Union
import onnx import onnx
import tensorrt as trt import tensorrt as trt
@ -8,12 +9,12 @@ import torch
from .preprocess import preprocess_onnx from .preprocess import preprocess_onnx
def onnx2trt(onnx_model, def onnx2trt(onnx_model: Union[str, onnx.ModelProto],
opt_shape_dict, opt_shape_dict: dict,
log_level=trt.Logger.ERROR, log_level: trt.ILogger.Severity = trt.Logger.ERROR,
fp16_mode=False, fp16_mode: bool = False,
max_workspace_size=0, max_workspace_size: int = 0,
device_id=0): device_id: int = 0) -> trt.ICudaEngine:
"""Convert onnx model to tensorrt engine. """Convert onnx model to tensorrt engine.
Arguments: Arguments:
@ -100,7 +101,7 @@ def onnx2trt(onnx_model,
return engine return engine
def save_trt_engine(engine, path): def save_trt_engine(engine: trt.ICudaEngine, path: str) -> None:
"""Serialize TensorRT engine to disk. """Serialize TensorRT engine to disk.
Arguments: Arguments:
@ -124,7 +125,7 @@ def save_trt_engine(engine, path):
f.write(bytearray(engine.serialize())) f.write(bytearray(engine.serialize()))
def load_trt_engine(path): def load_trt_engine(path: str) -> trt.ICudaEngine:
"""Deserialize TensorRT engine from disk. """Deserialize TensorRT engine from disk.
Arguments: Arguments:
@ -153,7 +154,7 @@ def load_trt_engine(path):
return engine return engine
def torch_dtype_from_trt(dtype): def torch_dtype_from_trt(dtype: trt.DataType) -> Union[torch.dtype, TypeError]:
"""Convert pytorch dtype to TensorRT dtype.""" """Convert pytorch dtype to TensorRT dtype."""
if dtype == trt.bool: if dtype == trt.bool:
return torch.bool return torch.bool
@ -169,7 +170,8 @@ def torch_dtype_from_trt(dtype):
raise TypeError('%s is not supported by torch' % dtype) raise TypeError('%s is not supported by torch' % dtype)
def torch_device_from_trt(device): def torch_device_from_trt(
device: trt.TensorLocation) -> Union[torch.device, TypeError]:
"""Convert pytorch device to TensorRT device.""" """Convert pytorch device to TensorRT device."""
if device == trt.TensorLocation.DEVICE: if device == trt.TensorLocation.DEVICE:
return torch.device('cuda') return torch.device('cuda')

View File

@ -272,14 +272,14 @@ class VideoReader:
self._vcap.release() self._vcap.release()
def frames2video(frame_dir, def frames2video(frame_dir: str,
video_file, video_file: str,
fps=30, fps: float = 30,
fourcc='XVID', fourcc: str = 'XVID',
filename_tmpl='{:06d}.jpg', filename_tmpl: str = '{:06d}.jpg',
start=0, start: int = 0,
end=0, end: int = 0,
show_progress=True): show_progress: bool = True) -> None:
"""Read the frame images from a directory and join them as a video. """Read the frame images from a directory and join them as a video.
Args: Args:

View File

@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import warnings import warnings
from typing import Tuple, Union
import cv2 import cv2
import numpy as np import numpy as np
@ -9,7 +10,11 @@ from mmcv.image import imread, imwrite
from mmcv.utils import is_str from mmcv.utils import is_str
def flowread(flow_or_path, quantize=False, concat_axis=0, *args, **kwargs): def flowread(flow_or_path: Union[np.ndarray, str],
quantize: bool = False,
concat_axis: int = 0,
*args,
**kwargs) -> np.ndarray:
"""Read an optical flow map. """Read an optical flow map.
Args: Args:
@ -58,7 +63,12 @@ def flowread(flow_or_path, quantize=False, concat_axis=0, *args, **kwargs):
return flow.astype(np.float32) return flow.astype(np.float32)
def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs): def flowwrite(flow: np.ndarray,
filename: str,
quantize: bool = False,
concat_axis: int = 0,
*args,
**kwargs) -> None:
"""Write optical flow to file. """Write optical flow to file.
If the flow is not quantized, it will be saved as a .flo file losslessly, If the flow is not quantized, it will be saved as a .flo file losslessly,
@ -88,7 +98,9 @@ def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs):
imwrite(dxdy, filename) imwrite(dxdy, filename)
def quantize_flow(flow, max_val=0.02, norm=True): def quantize_flow(flow: np.ndarray,
max_val: float = 0.02,
norm: bool = True) -> tuple:
"""Quantize flow to [0, 255]. """Quantize flow to [0, 255].
After this step, the size of flow will be much smaller, and can be After this step, the size of flow will be much smaller, and can be
@ -116,7 +128,10 @@ def quantize_flow(flow, max_val=0.02, norm=True):
return tuple(flow_comps) return tuple(flow_comps)
def dequantize_flow(dx, dy, max_val=0.02, denorm=True): def dequantize_flow(dx: np.ndarray,
dy: np.ndarray,
max_val: float = 0.02,
denorm: bool = True) -> np.ndarray:
"""Recover from quantized flow. """Recover from quantized flow.
Args: Args:
@ -140,12 +155,15 @@ def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
return flow return flow
def flow_warp(img, flow, filling_value=0, interpolate_mode='nearest'): def flow_warp(img: np.ndarray,
flow: np.ndarray,
filling_value: int = 0,
interpolate_mode: str = 'nearest') -> np.ndarray:
"""Use flow to warp img. """Use flow to warp img.
Args: Args:
img (ndarray, float or uint8): Image to be warped. img (ndarray): Image to be warped.
flow (ndarray, float): Optical Flow. flow (ndarray): Optical Flow.
filling_value (int): The missing pixels will be set with filling_value. filling_value (int): The missing pixels will be set with filling_value.
interpolate_mode (str): bilinear -> Bilinear Interpolation; interpolate_mode (str): bilinear -> Bilinear Interpolation;
nearest -> Nearest Neighbor. nearest -> Nearest Neighbor.
@ -201,7 +219,7 @@ def flow_warp(img, flow, filling_value=0, interpolate_mode='nearest'):
return output.astype(img.dtype) return output.astype(img.dtype)
def flow_from_bytes(content): def flow_from_bytes(content: bytes) -> np.ndarray:
"""Read dense optical flow from bytes. """Read dense optical flow from bytes.
.. note:: .. note::
@ -231,7 +249,7 @@ def flow_from_bytes(content):
return flow return flow
def sparse_flow_from_bytes(content): def sparse_flow_from_bytes(content: bytes) -> Tuple[np.ndarray, np.ndarray]:
"""Read the optical flow in KITTI datasets from bytes. """Read the optical flow in KITTI datasets from bytes.
This function is modified from RAFT load the `KITTI datasets This function is modified from RAFT load the `KITTI datasets