mmdeploy/tests/test_apis/test_extract.py

# Copyright (c) OpenMMLab. All rights reserved.
import tempfile

import onnx
import torch

from mmdeploy.apis.onnx import extract_partition
from mmdeploy.core import mark

output_file = tempfile.NamedTemporaryFile(suffix='.onnx').name


def test_extract():

    # `mark` tags the inputs and outputs of `add` in the exported graph so the
    # partition can later be cut at 'add:input' / 'add:output'; the output
    # tensor is named 'z'.
    @mark('add', outputs='z')
    def add(x, y):
        return torch.add(x, y)

    class TestModel(torch.nn.Module):

        def __init__(self):
            super().__init__()

        def forward(self, x, y):
            return add(x, y)

    model = TestModel().eval()

    # Dummy inputs used only to trace the model during ONNX export.
    x = torch.rand(2, 3, 4)
    y = torch.rand(2, 3, 4)
    torch.onnx.export(model, (x, y), output_file)
    onnx_model = onnx.load(output_file)
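
    # Cut the graph between the 'add' markers: the partition should keep the
    # original input/output names and contain only the Add node.
    extracted = extract_partition(onnx_model, 'add:input', 'add:output')

    assert extracted.graph.input[0].name == 'x'
    assert extracted.graph.input[1].name == 'y'
    assert extracted.graph.output[0].name == 'z'
    assert extracted.graph.node[0].op_type == 'Add'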
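

# A minimal, hypothetical sketch of the idea exercised by test_extract. This is
# NOT mmdeploy's implementation of extract_partition. It keeps only the nodes
# that sit between a set of boundary input tensors and boundary output tensors.
# Nodes are modelled as plain (op_type, inputs, outputs) tuples rather than
# onnx.NodeProto, and the helper name `_extract_subgraph_sketch` is
# illustrative only.
def _extract_subgraph_sketch(nodes, input_names, output_names):
    """Return the nodes needed to compute `output_names` from `input_names`.

    Walks backwards from the requested outputs and stops at the boundary
    inputs, so nodes upstream or downstream of the partition are dropped.
    """
    # Map each tensor name to the index of the node that produces it.
    producer = {}
    for idx, (_, _, outputs) in enumerate(nodes):
        for name in outputs:
            producer[name] = idx

    boundary = set(input_names)
    kept = set()
    stack = list(output_names)
    while stack:
        name = stack.pop()
        # Stop at the partition boundary and at tensors no node produces
        # (graph inputs or initializers).
        if name in boundary or name not in producer:
            continue
        idx = producer[name]
        if idx in kept:
            continue
        kept.add(idx)
        # Continue the backward walk through this node's inputs.
        stack.extend(nodes[idx][1])

    # Return the kept nodes in their original (topological) order.
    return [nodes[idx] for idx in sorted(kept)]


# Usage: for nodes Mul(a, b) -> x, Relu(b) -> y, Add(x, y) -> z, Neg(z) -> out,
# extracting between ['x', 'y'] and ['z'] keeps only the Add node, mirroring
# the final assertion in test_extract.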