# mmdeploy/tests/test_apis/test_onnx2tensorrt.py
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
import mmcv
import pytest
import torch
import torch.nn as nn
from mmdeploy.utils import Backend
from mmdeploy.utils.test import backend_checker
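
# Temporary output paths for the exported ONNX model and the TensorRT engine,
# plus a dummy 1x3x8x8 input tensor shared by the helpers and the test below.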
onnx_file = tempfile.NamedTemporaryFile(suffix='.onnx').name
engine_file = tempfile.NamedTemporaryFile(suffix='.engine').name
test_img = torch.rand([1, 3, 8, 8])
@pytest.mark.skip(reason='This is not a test class but a utility class.')
class TestModel(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x * 0.5


test_model = TestModel().eval().cuda()
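

# A minimal TensorRT deploy config: fp16 disabled, a 1 GiB workspace and a
# static 1x3x8x8 input shape matching the dummy input above.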
def get_deploy_cfg():
    deploy_cfg = mmcv.Config(
        dict(
            backend_config=dict(
                type='tensorrt',
                common_config=dict(
                    fp16_mode=False, max_workspace_size=1 << 30),
                model_inputs=[
                    dict(
                        input_shapes=dict(
                            input=dict(
                                min_shape=[1, 3, 8, 8],
                                opt_shape=[1, 3, 8, 8],
                                max_shape=[1, 3, 8, 8])))
                ])))
    return deploy_cfg
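

# Export the utility model to ONNX, declaring dynamic axes for the batch and
# spatial dimensions of the input and the batch dimension of the output.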
def generate_onnx_file(model):
    with torch.no_grad():
        dynamic_axes = {
            'input': {
                0: 'batch',
                2: 'width',
                3: 'height'
            },
            'output': {
                0: 'batch'
            }
        }
        torch.onnx.export(
            model,
            test_img,
            onnx_file,
            output_names=['output'],
            input_names=['input'],
            keep_initializers_as_inputs=True,
            do_constant_folding=True,
            verbose=False,
            opset_version=11,
            dynamic_axes=dynamic_axes)
        assert osp.exists(onnx_file)
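

# backend_checker is expected to skip the test below when the TensorRT backend
# is not available in the current environment.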
@backend_checker(Backend.TENSORRT)
def test_onnx2tensorrt():
    from mmdeploy.apis.tensorrt import onnx2tensorrt
    from mmdeploy.backend.tensorrt import load

    model = test_model
    generate_onnx_file(model)
    deploy_cfg = get_deploy_cfg()

    work_dir, save_file = osp.split(engine_file)
    onnx2tensorrt(work_dir, save_file, 0, deploy_cfg, onnx_file)
    assert osp.exists(work_dir)
    assert osp.exists(engine_file)

    engine = load(engine_file)
    assert engine is not None