mmdeploy/mmdeploy/apis/extract_model.py
q.yao b32fc41bed
[Refactor][API2.0] Api refactor2.0 (#529)
* [refactor][API2.0]  Add onnx export and jit trace (#419)

* first commit

* add async call

* add new api onnx export and jit trace

* add decorator

* fix ci

* fix torchscript ci

* fix loader

* better pipemanager

* remove comment, better import

* add kwargs

* remove comment

* better pipeline manager

* remove print

* [Refactor][API2.0] Api partition calibration (#433)

* first commit

* add async call

* add new api onnx export and jit trace

* add decorator

* fix ci

* fix torchscript ci

* fix loader

* better pipemanager

* remove comment, better import

* add partition

* move calibration

* Better create_calib_table

* better deploy

* add kwargs

* remove comment

* better pipeline manager

* rename api, remove redundant variable, and misc

* [Refactor][API2.0] Api ncnn openvino (#435)

* first commit

* add async call

* add new api onnx export and jit trace

* add decorator

* fix ci

* fix torchscript ci

* fix loader

* better pipemanager

* remove comment, better import

* add ncnn api

* finish ncnn api

* add openvino support

* add kwargs

* remove comment

* better pipeline manager

* merge fix

* merge util and onnx2ncnn

* fix docstring

* [Refactor][API2.0] API for TensorRT (#519)

* first commit

* add async call

* add new api onnx export and jit trace

* add decorator

* fix ci

* fix torchscript ci

* fix loader

* better pipemanager

* remove comment, better import

* add partition

* move calibration

* Better create_calib_table

* better deploy

* add kwargs

* remove comment

* Add tensorrt API

* better pipeline manager

* add tensorrt new api

* remove print

* rename api, remove redundant variable, and misc

* add docstring

* [Refactor][API2.0] Api ppl other (#528)

* first commit

* add async call

* add new api onnx export and jit trace

* add decorator

* fix ci

* fix torchscript ci

* fix loader

* better pipemanager

* remove comment, better import

* add kwargs

* Add new APIs for pplnn sdk and misc

* remove comment

* better pipeline manager

* merge fix

* update tools/onnx2pplnn.py

* rename function
2022-05-31 09:18:18 +08:00

68 lines
2.4 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Iterable, Optional, Union
import onnx
from .core import PIPELINE_MANAGER
from .onnx import extract_partition
@PIPELINE_MANAGER.register_pipeline()
def extract_model(model: Union[str, onnx.ModelProto],
                  start_marker: Union[str, Iterable[str]],
                  end_marker: Union[str, Iterable[str]],
                  start_name_map: Optional[Dict[str, str]] = None,
                  end_name_map: Optional[Dict[str, str]] = None,
                  dynamic_axes: Optional[Dict[str, Dict[int, str]]] = None,
                  save_file: Optional[str] = None) -> onnx.ModelProto:
    """Extract partition-model from an ONNX model.

    The partition-model is defined by the names of the input and output
    tensors exactly.

    This is a pipeline-registered wrapper: all arguments are forwarded
    unchanged to ``extract_partition`` from the sibling ``.onnx`` package,
    and its result is returned as-is.

    Examples:
        >>> from mmdeploy.apis import extract_model
        >>> model = 'work_dir/fastrcnn.onnx'
        >>> start_marker = 'detector:input'
        >>> end_marker = ['extract_feat:output', 'multiclass_nms[0]:input']
        >>> dynamic_axes = {
        ...     'input': {
        ...         0: 'batch',
        ...         2: 'height',
        ...         3: 'width'
        ...     },
        ...     'scores': {
        ...         0: 'batch',
        ...         1: 'num_boxes',
        ...     },
        ...     'boxes': {
        ...         0: 'batch',
        ...         1: 'num_boxes',
        ...     }
        ... }
        >>> save_file = 'partition_model.onnx'
        >>> extract_model(model, start_marker, end_marker,
        ...               dynamic_axes=dynamic_axes,
        ...               save_file=save_file)

    Args:
        model (str | onnx.ModelProto): Input ONNX model to be extracted,
            e.g. a path such as ``'work_dir/fastrcnn.onnx'`` or a loaded
            ``onnx.ModelProto``.
        start_marker (str | Iterable[str]): Start marker(s) to extract.
        end_marker (str | Iterable[str]): End marker(s) to extract.
        start_name_map (Dict[str, str], optional): A mapping of start
            names, defaults to `None`. Its exact semantics are defined by
            ``extract_partition``.
        end_name_map (Dict[str, str], optional): A mapping of end names,
            defaults to `None`. Its exact semantics are defined by
            ``extract_partition``.
        dynamic_axes (Dict[str, Dict[int, str]], optional): A dictionary
            to specify dynamic axes of input/output, mapping a tensor name
            to ``{axis_index: axis_name}``. Defaults to `None`.
        save_file (str, optional): A file to save the extracted model,
            defaults to `None`.

    Returns:
        onnx.ModelProto: The extracted model.
    """
    # Arguments are passed positionally, so this order must stay in sync
    # with the parameter order of ``extract_partition``.
    return extract_partition(model, start_marker, end_marker, start_name_map,
                             end_name_map, dynamic_axes, save_file)