mmdeploy/tools/extract.py

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import logging
import os.path as osp

import onnx
import onnx.helper

from mmdeploy.apis.onnx import extract_partition
from mmdeploy.utils import get_root_logger


def parse_args():
    parser = argparse.ArgumentParser(
        description='Extract model based on markers.')
    parser.add_argument('input_model', help='Input ONNX model')
    parser.add_argument('output_model', help='Output ONNX model')
    parser.add_argument(
        '--start',
        help='Start markers, format: func:type, e.g. backbone:input')
    parser.add_argument(
        '--end', help='End markers, same format as --start')
    parser.add_argument(
        '--log-level',
        help='set log level',
        default='INFO',
        choices=list(logging._nameToLevel.keys()))
    args = parser.parse_args()

    # Comma-separated markers become lists; an omitted option becomes [].
    args.start = args.start.split(',') if args.start else []
    args.end = args.end.split(',') if args.end else []

    return args


def collect_available_marks(model):
    """Collect the names of all 'Mark' nodes in the ONNX graph.

    The 'func' attribute of each Mark node identifies the marked function,
    e.g. 'backbone'; duplicates are skipped.
    """
    marks = []
    for node in model.graph.node:
        if node.op_type == 'Mark':
            for attr in node.attribute:
                if attr.name == 'func':
                    # String attributes are stored as bytes; decode to str.
                    func = str(onnx.helper.get_attribute_value(attr), 'utf-8')
                    if func not in marks:
                        marks.append(func)
    return marks
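

# A minimal sketch of the attribute layout collect_available_marks expects
# (assumption for illustration: 'Mark' nodes store the marked function name
# in a string attribute named 'func', which onnx returns as bytes):
#
#     node = onnx.helper.make_node('Mark', inputs=['x'], outputs=['y'],
#                                  func='backbone')
#     attr = node.attribute[0]
#     str(onnx.helper.get_attribute_value(attr), 'utf-8')  # -> 'backbone'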
def main():
    args = parse_args()
    logger = get_root_logger(log_level=args.log_level)

    model = onnx.load(args.input_model)

    # Report every mark found in the graph so the user can pick valid
    # --start/--end values.
    marks = collect_available_marks(model)
    logger.info('Available marks:\n    {}'.format('\n    '.join(marks)))

    extracted_model = extract_partition(model, args.start, args.end)

    # Make sure the output file ends with the .onnx suffix.
    if osp.splitext(args.output_model)[-1] != '.onnx':
        args.output_model += '.onnx'
    onnx.save(extracted_model, args.output_model)


if __name__ == '__main__':
    main()
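
Example invocation (a sketch for illustration: `end2end.onnx` and `backbone.onnx` are placeholder file names, and the `backbone:input` / `backbone:output` markers follow the `func:type` format from the `--start` help text; the marks actually present in a model are printed by the script itself):

    python tools/extract.py end2end.onnx backbone.onnx \
        --start backbone:input --end backbone:output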