mirror of https://github.com/open-mmlab/mmcv.git
[Enhancement] Add type hints for mmcv/arraymisc and mmcv/video (#1950)
* Add type hints * Add type hints * Fix int float about scalar * Add type hints for mmcv/tensorrt * Update mmcv/tensorrt/tensorrt_utils.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmcv/arraymisc/quantization.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Ignore type hint for dtype Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>pull/1954/head
parent
882cab77bb
commit
94cc99d595
|
@ -1,14 +1,20 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def quantize(arr, min_val, max_val, levels, dtype=np.int64):
|
||||
def quantize(arr: np.ndarray,
|
||||
min_val: Union[int, float],
|
||||
max_val: Union[int, float],
|
||||
levels: int,
|
||||
dtype=np.int64) -> tuple:
|
||||
"""Quantize an array of (-inf, inf) to [0, levels-1].
|
||||
|
||||
Args:
|
||||
arr (ndarray): Input array.
|
||||
min_val (scalar): Minimum value to be clipped.
|
||||
max_val (scalar): Maximum value to be clipped.
|
||||
min_val (int or float): Minimum value to be clipped.
|
||||
max_val (int or float): Maximum value to be clipped.
|
||||
levels (int): Quantization levels.
|
||||
dtype (np.type): The type of the quantized array.
|
||||
|
||||
|
@ -29,13 +35,17 @@ def quantize(arr, min_val, max_val, levels, dtype=np.int64):
|
|||
return quantized_arr
|
||||
|
||||
|
||||
def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
|
||||
def dequantize(arr: np.ndarray,
|
||||
min_val: Union[int, float],
|
||||
max_val: Union[int, float],
|
||||
levels: int,
|
||||
dtype=np.float64) -> tuple:
|
||||
"""Dequantize an array.
|
||||
|
||||
Args:
|
||||
arr (ndarray): Input array.
|
||||
min_val (scalar): Minimum value to be clipped.
|
||||
max_val (scalar): Maximum value to be clipped.
|
||||
min_val (int or float): Minimum value to be clipped.
|
||||
max_val (int or float): Maximum value to be clipped.
|
||||
levels (int): Quantization levels.
|
||||
dtype (np.type): The type of the dequantized array.
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ import os
|
|||
import warnings
|
||||
|
||||
|
||||
def get_tensorrt_op_path():
|
||||
def get_tensorrt_op_path() -> str:
|
||||
"""Get TensorRT plugins library path."""
|
||||
# Following strings of text style are from colorama package
|
||||
bright_style, reset_style = '\x1b[1m', '\x1b[0m'
|
||||
|
@ -31,7 +31,7 @@ def get_tensorrt_op_path():
|
|||
plugin_is_loaded = False
|
||||
|
||||
|
||||
def is_tensorrt_plugin_loaded():
|
||||
def is_tensorrt_plugin_loaded() -> bool:
|
||||
"""Check if TensorRT plugins library is loaded or not.
|
||||
|
||||
Returns:
|
||||
|
@ -54,7 +54,7 @@ def is_tensorrt_plugin_loaded():
|
|||
return plugin_is_loaded
|
||||
|
||||
|
||||
def load_tensorrt_plugin():
|
||||
def load_tensorrt_plugin() -> None:
|
||||
"""load TensorRT plugins library."""
|
||||
|
||||
# Following strings of text style are from colorama package
|
||||
|
|
|
@ -5,7 +5,7 @@ import numpy as np
|
|||
import onnx
|
||||
|
||||
|
||||
def preprocess_onnx(onnx_model):
|
||||
def preprocess_onnx(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
|
||||
"""Modify onnx model to match with TensorRT plugins in mmcv.
|
||||
|
||||
There are some conflict between onnx node definition and TensorRT limit.
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
from typing import Union
|
||||
|
||||
import onnx
|
||||
import tensorrt as trt
|
||||
|
@ -8,12 +9,12 @@ import torch
|
|||
from .preprocess import preprocess_onnx
|
||||
|
||||
|
||||
def onnx2trt(onnx_model,
|
||||
opt_shape_dict,
|
||||
log_level=trt.Logger.ERROR,
|
||||
fp16_mode=False,
|
||||
max_workspace_size=0,
|
||||
device_id=0):
|
||||
def onnx2trt(onnx_model: Union[str, onnx.ModelProto],
|
||||
opt_shape_dict: dict,
|
||||
log_level: trt.ILogger.Severity = trt.Logger.ERROR,
|
||||
fp16_mode: bool = False,
|
||||
max_workspace_size: int = 0,
|
||||
device_id: int = 0) -> trt.ICudaEngine:
|
||||
"""Convert onnx model to tensorrt engine.
|
||||
|
||||
Arguments:
|
||||
|
@ -100,7 +101,7 @@ def onnx2trt(onnx_model,
|
|||
return engine
|
||||
|
||||
|
||||
def save_trt_engine(engine, path):
|
||||
def save_trt_engine(engine: trt.ICudaEngine, path: str) -> None:
|
||||
"""Serialize TensorRT engine to disk.
|
||||
|
||||
Arguments:
|
||||
|
@ -124,7 +125,7 @@ def save_trt_engine(engine, path):
|
|||
f.write(bytearray(engine.serialize()))
|
||||
|
||||
|
||||
def load_trt_engine(path):
|
||||
def load_trt_engine(path: str) -> trt.ICudaEngine:
|
||||
"""Deserialize TensorRT engine from disk.
|
||||
|
||||
Arguments:
|
||||
|
@ -153,7 +154,7 @@ def load_trt_engine(path):
|
|||
return engine
|
||||
|
||||
|
||||
def torch_dtype_from_trt(dtype):
|
||||
def torch_dtype_from_trt(dtype: trt.DataType) -> Union[torch.dtype, TypeError]:
|
||||
"""Convert pytorch dtype to TensorRT dtype."""
|
||||
if dtype == trt.bool:
|
||||
return torch.bool
|
||||
|
@ -169,7 +170,8 @@ def torch_dtype_from_trt(dtype):
|
|||
raise TypeError('%s is not supported by torch' % dtype)
|
||||
|
||||
|
||||
def torch_device_from_trt(device):
|
||||
def torch_device_from_trt(
|
||||
device: trt.TensorLocation) -> Union[torch.device, TypeError]:
|
||||
"""Convert pytorch device to TensorRT device."""
|
||||
if device == trt.TensorLocation.DEVICE:
|
||||
return torch.device('cuda')
|
||||
|
|
|
@ -272,14 +272,14 @@ class VideoReader:
|
|||
self._vcap.release()
|
||||
|
||||
|
||||
def frames2video(frame_dir,
|
||||
video_file,
|
||||
fps=30,
|
||||
fourcc='XVID',
|
||||
filename_tmpl='{:06d}.jpg',
|
||||
start=0,
|
||||
end=0,
|
||||
show_progress=True):
|
||||
def frames2video(frame_dir: str,
|
||||
video_file: str,
|
||||
fps: float = 30,
|
||||
fourcc: str = 'XVID',
|
||||
filename_tmpl: str = '{:06d}.jpg',
|
||||
start: int = 0,
|
||||
end: int = 0,
|
||||
show_progress: bool = True) -> None:
|
||||
"""Read the frame images from a directory and join them as a video.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
from typing import Tuple, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
@ -9,7 +10,11 @@ from mmcv.image import imread, imwrite
|
|||
from mmcv.utils import is_str
|
||||
|
||||
|
||||
def flowread(flow_or_path, quantize=False, concat_axis=0, *args, **kwargs):
|
||||
def flowread(flow_or_path: Union[np.ndarray, str],
|
||||
quantize: bool = False,
|
||||
concat_axis: int = 0,
|
||||
*args,
|
||||
**kwargs) -> np.ndarray:
|
||||
"""Read an optical flow map.
|
||||
|
||||
Args:
|
||||
|
@ -58,7 +63,12 @@ def flowread(flow_or_path, quantize=False, concat_axis=0, *args, **kwargs):
|
|||
return flow.astype(np.float32)
|
||||
|
||||
|
||||
def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs):
|
||||
def flowwrite(flow: np.ndarray,
|
||||
filename: str,
|
||||
quantize: bool = False,
|
||||
concat_axis: int = 0,
|
||||
*args,
|
||||
**kwargs) -> None:
|
||||
"""Write optical flow to file.
|
||||
|
||||
If the flow is not quantized, it will be saved as a .flo file losslessly,
|
||||
|
@ -88,7 +98,9 @@ def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs):
|
|||
imwrite(dxdy, filename)
|
||||
|
||||
|
||||
def quantize_flow(flow, max_val=0.02, norm=True):
|
||||
def quantize_flow(flow: np.ndarray,
|
||||
max_val: float = 0.02,
|
||||
norm: bool = True) -> tuple:
|
||||
"""Quantize flow to [0, 255].
|
||||
|
||||
After this step, the size of flow will be much smaller, and can be
|
||||
|
@ -116,7 +128,10 @@ def quantize_flow(flow, max_val=0.02, norm=True):
|
|||
return tuple(flow_comps)
|
||||
|
||||
|
||||
def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
|
||||
def dequantize_flow(dx: np.ndarray,
|
||||
dy: np.ndarray,
|
||||
max_val: float = 0.02,
|
||||
denorm: bool = True) -> np.ndarray:
|
||||
"""Recover from quantized flow.
|
||||
|
||||
Args:
|
||||
|
@ -140,12 +155,15 @@ def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
|
|||
return flow
|
||||
|
||||
|
||||
def flow_warp(img, flow, filling_value=0, interpolate_mode='nearest'):
|
||||
def flow_warp(img: np.ndarray,
|
||||
flow: np.ndarray,
|
||||
filling_value: int = 0,
|
||||
interpolate_mode: str = 'nearest') -> np.ndarray:
|
||||
"""Use flow to warp img.
|
||||
|
||||
Args:
|
||||
img (ndarray, float or uint8): Image to be warped.
|
||||
flow (ndarray, float): Optical Flow.
|
||||
img (ndarray): Image to be warped.
|
||||
flow (ndarray): Optical Flow.
|
||||
filling_value (int): The missing pixels will be set with filling_value.
|
||||
interpolate_mode (str): bilinear -> Bilinear Interpolation;
|
||||
nearest -> Nearest Neighbor.
|
||||
|
@ -201,7 +219,7 @@ def flow_warp(img, flow, filling_value=0, interpolate_mode='nearest'):
|
|||
return output.astype(img.dtype)
|
||||
|
||||
|
||||
def flow_from_bytes(content):
|
||||
def flow_from_bytes(content: bytes) -> np.ndarray:
|
||||
"""Read dense optical flow from bytes.
|
||||
|
||||
.. note::
|
||||
|
@ -231,7 +249,7 @@ def flow_from_bytes(content):
|
|||
return flow
|
||||
|
||||
|
||||
def sparse_flow_from_bytes(content):
|
||||
def sparse_flow_from_bytes(content: bytes) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Read the optical flow in KITTI datasets from bytes.
|
||||
|
||||
This function is modified from RAFT load the `KITTI datasets
|
||||
|
|
Loading…
Reference in New Issue