[Dostring]add example in apis docstring (#214)

* add example in apis docstring

* add backend example in docstring

* rm blank line
pull/178/head
VVsssssk 2022-03-10 11:36:44 +08:00 committed by GitHub
parent 0096aacd3e
commit 2ce83b6c77
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 103 additions and 0 deletions

View File

@ -21,6 +21,20 @@ def create_calib_table(calib_file: str,
**kwargs) -> None:
"""Create calibration table.
Examples:
>>> from mmdeploy.apis import create_calib_table
>>> from mmdeploy.utils import get_calib_filename, load_config
>>> deploy_cfg = 'configs/mmdet/detection/' \
'detection_tensorrt-int8_dynamic-320x320-1344x1344.py'
>>> deploy_cfg = load_config(deploy_cfg)[0]
>>> calib_file = get_calib_filename(deploy_cfg)
>>> model_cfg = 'mmdetection/configs/fcos/' \
'fcos_r50_caffe_fpn_gn-head_1x_coco.py'
>>> model_checkpoint = 'checkpoints/' \
'fcos_r50_caffe_fpn_gn-head_1x_coco-821213aa.pth'
>>> create_calib_table(calib_file, deploy_cfg, \
model_cfg, model_checkpoint, device='cuda:0')
Args:
calib_file (str): Input calibration file.
deploy_cfg (str | mmcv.Config): Deployment config.

View File

@ -23,6 +23,30 @@ def extract_model(model: Union[str, onnx.ModelProto],
The sub-model is defined by the names of the input and output tensors
exactly.
Examples:
>>> from mmdeploy.apis import extract_model
>>> model = 'work_dir/fastrcnn.onnx'
>>> start = 'detector:input'
>>> end = ['extract_feat:output', 'multiclass_nms[0]:input']
>>> dynamic_axes = {
'input': {
0: 'batch',
2: 'height',
3: 'width'
},
'scores': {
0: 'batch',
1: 'num_boxes',
},
'boxes': {
0: 'batch',
1: 'num_boxes',
}
}
>>> save_file = 'partition_model.onnx'
>>> extract_model(model, start, end, dynamic_axes=dynamic_axes, \
save_file=save_file)
Args:
model (str | onnx.ModelProto): Input ONNX model to be extracted.
start (str | Sequence[str]): Start marker(s) to extract.

View File

@ -14,6 +14,18 @@ def inference_model(model_cfg: Union[str, mmcv.Config],
device: str) -> Any:
"""Run inference with PyTorch or backend model and show results.
Examples:
>>> from mmdeploy.apis import inference_model
>>> model_cfg = 'mmdetection/configs/fcos/' \
'fcos_r50_caffe_fpn_gn-head_1x_coco.py'
>>> deploy_cfg = 'configs/mmdet/detection/' \
'detection_onnxruntime_dynamic.py'
>>> backend_files = ['work_dir/fcos.onnx']
>>> img = 'demo.jpg'
>>> device = 'cpu'
>>> model_output = inference_model(model_cfg, deploy_cfg, \
backend_files, img, device)
Args:
model_cfg (str | mmcv.Config): Model config file or Config object.
deploy_cfg (str | mmcv.Config): Deployment config file or Config

View File

@ -62,6 +62,21 @@ def torch2onnx(img: Any,
device: str = 'cuda:0'):
"""Convert PyTorch model to ONNX model.
Examples:
>>> from mmdeploy.apis import torch2onnx
>>> img = 'demo.jpg'
>>> work_dir = 'work_dir'
>>> save_file = 'fcos.onnx'
>>> deploy_cfg = 'configs/mmdet/detection/' \
'detection_onnxruntime_dynamic.py'
>>> model_cfg = 'mmdetection/configs/fcos/' \
'fcos_r50_caffe_fpn_gn-head_1x_coco.py'
>>> model_checkpoint = 'checkpoints/' \
'fcos_r50_caffe_fpn_gn-head_1x_coco-821213aa.pth'
>>> device = 'cpu'
>>> torch2onnx(img, work_dir, save_file, deploy_cfg, \
model_cfg, model_checkpoint, device)
Args:
img (str | np.ndarray | torch.Tensor): Input image used to assist
converting model.

View File

@ -19,6 +19,18 @@ def visualize_model(model_cfg: Union[str, mmcv.Config],
show_result: bool = False):
"""Run inference with PyTorch or backend model and show results.
Examples:
>>> from mmdeploy.apis import visualize_model
>>> model_cfg = 'mmdetection/configs/fcos/' \
'fcos_r50_caffe_fpn_gn-head_1x_coco.py'
>>> deploy_cfg = 'configs/mmdet/detection/' \
'detection_onnxruntime_dynamic.py'
>>> model = 'work_dir/fcos.onnx'
>>> img = 'demo.jpg'
>>> device = 'cpu'
>>> visualize_model(model_cfg, deploy_cfg, model, \
img, device, show_result=True)
Args:
model_cfg (str | mmcv.Config): Model config file or Config object.
deploy_cfg (str | mmcv.Config): Deployment config file or Config

View File

@ -33,6 +33,13 @@ def onnx2ncnn(onnx_path: str, save_param: str, save_bin: str):
a executable program to convert the `.onnx` file to a `.param` file and
a `.bin` file. The output files will save to work_dir.
Example:
>>> from mmdeploy.backend.ncnn.onnx2ncnn import onnx2ncnn
>>> onnx_path = 'work_dir/end2end.onnx'
>>> save_param = 'work_dir/end2end.param'
>>> save_bin = 'work_dir/end2end.bin'
>>> onnx2ncnn(onnx_path, save_param, save_bin)
Args:
onnx_path (str): The path of the onnx model.
save_param (str): The path to save the output `.param` file.

View File

@ -58,6 +58,14 @@ def onnx2openvino(input_info: Dict[str, Union[List[int], torch.Size]],
output_names: List[str], onnx_path: str, work_dir: str):
"""Convert ONNX to OpenVINO.
Examples:
>>> from mmdeploy.backend.openvino.onnx2openvino import onnx2openvino
>>> input_info = {'input': [1,3,800,1344]}
>>> output_names = ['dets', 'labels']
>>> onnx_path = 'work_dir/end2end.onnx'
>>> work_dir = 'work_dir'
>>> onnx2openvino(input_info, output_names, onnx_path, work_dir)
Args:
input_info (Dict[str, Union[List[int], torch.Size]]):
The shape of each input.

View File

@ -21,6 +21,17 @@ def onnx2tensorrt(work_dir: str,
**kwargs):
"""Convert ONNX to TensorRT.
Examples:
>>> from mmdeploy.backend.tensorrt.onnx2tensorrt import onnx2tensorrt
>>> work_dir = 'work_dir'
>>> save_file = 'end2end.engine'
>>> model_id = 0
>>> deploy_cfg = 'configs/mmdet/detection/' \
'detection_tensorrt_dynamic-320x320-1344x1344.py'
>>> onnx_model = 'work_dir/end2end.onnx'
>>> onnx2tensorrt(work_dir, save_file, model_id, deploy_cfg, \
onnx_model, 'cuda:0')
Args:
work_dir (str): A working directory.
save_file (str): The base name of the file to save TensorRT engine.