[Dostring]add example in apis docstring (#214)
* add example in apis docstring * add backend example in docstring * rm blank linepull/178/head
parent
0096aacd3e
commit
2ce83b6c77
|
@ -21,6 +21,20 @@ def create_calib_table(calib_file: str,
|
|||
**kwargs) -> None:
|
||||
"""Create calibration table.
|
||||
|
||||
Examples:
|
||||
>>> from mmdeploy.apis import create_calib_table
|
||||
>>> from mmdeploy.utils import get_calib_filename, load_config
|
||||
>>> deploy_cfg = 'configs/mmdet/detection/' \
|
||||
'detection_tensorrt-int8_dynamic-320x320-1344x1344.py'
|
||||
>>> deploy_cfg = load_config(deploy_cfg)[0]
|
||||
>>> calib_file = get_calib_filename(deploy_cfg)
|
||||
>>> model_cfg = 'mmdetection/configs/fcos/' \
|
||||
'fcos_r50_caffe_fpn_gn-head_1x_coco.py'
|
||||
>>> model_checkpoint = 'checkpoints/' \
|
||||
'fcos_r50_caffe_fpn_gn-head_1x_coco-821213aa.pth'
|
||||
>>> create_calib_table(calib_file, deploy_cfg, \
|
||||
model_cfg, model_checkpoint, device='cuda:0')
|
||||
|
||||
Args:
|
||||
calib_file (str): Input calibration file.
|
||||
deploy_cfg (str | mmcv.Config): Deployment config.
|
||||
|
|
|
@ -23,6 +23,30 @@ def extract_model(model: Union[str, onnx.ModelProto],
|
|||
The sub-model is defined by the names of the input and output tensors
|
||||
exactly.
|
||||
|
||||
Examples:
|
||||
>>> from mmdeploy.apis import extract_model
|
||||
>>> model = 'work_dir/fastrcnn.onnx'
|
||||
>>> start = 'detector:input'
|
||||
>>> end = ['extract_feat:output', 'multiclass_nms[0]:input']
|
||||
>>> dynamic_axes = {
|
||||
'input': {
|
||||
0: 'batch',
|
||||
2: 'height',
|
||||
3: 'width'
|
||||
},
|
||||
'scores': {
|
||||
0: 'batch',
|
||||
1: 'num_boxes',
|
||||
},
|
||||
'boxes': {
|
||||
0: 'batch',
|
||||
1: 'num_boxes',
|
||||
}
|
||||
}
|
||||
>>> save_file = 'partition_model.onnx'
|
||||
>>> extract_model(model, start, end, dynamic_axes=dynamic_axes, \
|
||||
save_file=save_file)
|
||||
|
||||
Args:
|
||||
model (str | onnx.ModelProto): Input ONNX model to be extracted.
|
||||
start (str | Sequence[str]): Start marker(s) to extract.
|
||||
|
|
|
@ -14,6 +14,18 @@ def inference_model(model_cfg: Union[str, mmcv.Config],
|
|||
device: str) -> Any:
|
||||
"""Run inference with PyTorch or backend model and show results.
|
||||
|
||||
Examples:
|
||||
>>> from mmdeploy.apis import inference_model
|
||||
>>> model_cfg = 'mmdetection/configs/fcos/' \
|
||||
'fcos_r50_caffe_fpn_gn-head_1x_coco.py'
|
||||
>>> deploy_cfg = 'configs/mmdet/detection/' \
|
||||
'detection_onnxruntime_dynamic.py'
|
||||
>>> backend_files = ['work_dir/fcos.onnx']
|
||||
>>> img = 'demo.jpg'
|
||||
>>> device = 'cpu'
|
||||
>>> model_output = inference_model(model_cfg, deploy_cfg, \
|
||||
backend_files, img, device)
|
||||
|
||||
Args:
|
||||
model_cfg (str | mmcv.Config): Model config file or Config object.
|
||||
deploy_cfg (str | mmcv.Config): Deployment config file or Config
|
||||
|
|
|
@ -62,6 +62,21 @@ def torch2onnx(img: Any,
|
|||
device: str = 'cuda:0'):
|
||||
"""Convert PyTorch model to ONNX model.
|
||||
|
||||
Examples:
|
||||
>>> from mmdeploy.apis import torch2onnx
|
||||
>>> img = 'demo.jpg'
|
||||
>>> work_dir = 'work_dir'
|
||||
>>> save_file = 'fcos.onnx'
|
||||
>>> deploy_cfg = 'configs/mmdet/detection/' \
|
||||
'detection_onnxruntime_dynamic.py'
|
||||
>>> model_cfg = 'mmdetection/configs/fcos/' \
|
||||
'fcos_r50_caffe_fpn_gn-head_1x_coco.py'
|
||||
>>> model_checkpoint = 'checkpoints/' \
|
||||
'fcos_r50_caffe_fpn_gn-head_1x_coco-821213aa.pth'
|
||||
>>> device = 'cpu'
|
||||
>>> torch2onnx(img, work_dir, save_file, deploy_cfg, \
|
||||
model_cfg, model_checkpoint, device)
|
||||
|
||||
Args:
|
||||
img (str | np.ndarray | torch.Tensor): Input image used to assist
|
||||
converting model.
|
||||
|
|
|
@ -19,6 +19,18 @@ def visualize_model(model_cfg: Union[str, mmcv.Config],
|
|||
show_result: bool = False):
|
||||
"""Run inference with PyTorch or backend model and show results.
|
||||
|
||||
Examples:
|
||||
>>> from mmdeploy.apis import visualize_model
|
||||
>>> model_cfg = 'mmdetection/configs/fcos/' \
|
||||
'fcos_r50_caffe_fpn_gn-head_1x_coco.py'
|
||||
>>> deploy_cfg = 'configs/mmdet/detection/' \
|
||||
'detection_onnxruntime_dynamic.py'
|
||||
>>> model = 'work_dir/fcos.onnx'
|
||||
>>> img = 'demo.jpg'
|
||||
>>> device = 'cpu'
|
||||
>>> visualize_model(model_cfg, deploy_cfg, model, \
|
||||
img, device, show_result=True)
|
||||
|
||||
Args:
|
||||
model_cfg (str | mmcv.Config): Model config file or Config object.
|
||||
deploy_cfg (str | mmcv.Config): Deployment config file or Config
|
||||
|
|
|
@ -33,6 +33,13 @@ def onnx2ncnn(onnx_path: str, save_param: str, save_bin: str):
|
|||
a executable program to convert the `.onnx` file to a `.param` file and
|
||||
a `.bin` file. The output files will save to work_dir.
|
||||
|
||||
Example:
|
||||
>>> from mmdeploy.backend.ncnn.onnx2ncnn import onnx2ncnn
|
||||
>>> onnx_path = 'work_dir/end2end.onnx'
|
||||
>>> save_param = 'work_dir/end2end.param'
|
||||
>>> save_bin = 'work_dir/end2end.bin'
|
||||
>>> onnx2ncnn(onnx_path, save_param, save_bin)
|
||||
|
||||
Args:
|
||||
onnx_path (str): The path of the onnx model.
|
||||
save_param (str): The path to save the output `.param` file.
|
||||
|
|
|
@ -58,6 +58,14 @@ def onnx2openvino(input_info: Dict[str, Union[List[int], torch.Size]],
|
|||
output_names: List[str], onnx_path: str, work_dir: str):
|
||||
"""Convert ONNX to OpenVINO.
|
||||
|
||||
Examples:
|
||||
>>> from mmdeploy.backend.openvino.onnx2openvino import onnx2openvino
|
||||
>>> input_info = {'input': [1,3,800,1344]}
|
||||
>>> output_names = ['dets', 'labels']
|
||||
>>> onnx_path = 'work_dir/end2end.onnx'
|
||||
>>> work_dir = 'work_dir'
|
||||
>>> onnx2openvino(input_info, output_names, onnx_path, work_dir)
|
||||
|
||||
Args:
|
||||
input_info (Dict[str, Union[List[int], torch.Size]]):
|
||||
The shape of each input.
|
||||
|
|
|
@ -21,6 +21,17 @@ def onnx2tensorrt(work_dir: str,
|
|||
**kwargs):
|
||||
"""Convert ONNX to TensorRT.
|
||||
|
||||
Examples:
|
||||
>>> from mmdeploy.backend.tensorrt.onnx2tensorrt import onnx2tensorrt
|
||||
>>> work_dir = 'work_dir'
|
||||
>>> save_file = 'end2end.engine'
|
||||
>>> model_id = 0
|
||||
>>> deploy_cfg = 'configs/mmdet/detection/' \
|
||||
'detection_tensorrt_dynamic-320x320-1344x1344.py'
|
||||
>>> onnx_model = 'work_dir/end2end.onnx'
|
||||
>>> onnx2tensorrt(work_dir, save_file, model_id, deploy_cfg, \
|
||||
onnx_model, 'cuda:0')
|
||||
|
||||
Args:
|
||||
work_dir (str): A working directory.
|
||||
save_file (str): The base name of the file to save TensorRT engine.
|
||||
|
|
Loading…
Reference in New Issue