mmdeploy/tests/test_apis/test_onnx2tensorrt.py

110 lines
2.6 KiB
Python
Raw Normal View History

import os
import os.path as osp
import shutil
import mmcv
import pytest
import torch
import torch.multiprocessing as mp
from torch import nn
import mmdeploy.apis.tensorrt as trt_apis
# Skip the whole module when the TensorRT apis of mmdeploy cannot be loaded.
# A module-level ``pytest.skip`` must pass ``allow_module_level=True``;
# without it pytest reports a collection error instead of skipping.
if not trt_apis.is_available():
    pytest.skip('TensorRT apis is not prepared.', allow_module_level=True)
trt = pytest.importorskip('tensorrt', reason='Import tensorrt failed.')
# The test below runs the model and the engine on CUDA tensors.
if not torch.cuda.is_available():
    pytest.skip('CUDA is not available.', allow_module_level=True)

# load apis from trt_apis
TRTWrapper = trt_apis.TRTWrapper
onnx2tensorrt = trt_apis.onnx2tensorrt

# Shared double passed to ``onnx2tensorrt`` as ``ret_value``; the test
# asserts it equals 0 after conversion.
ret_value = mp.Value('d', 0, lock=False)

# Scratch directory used by the tests and the file names created inside it.
work_dir = './tmp/'
onnx_file = 'tmp.onnx'
save_file = 'tmp.engine'
@pytest.fixture(autouse=True)
def clear_workdir_after_test():
    """Provide every test with a freshly created, empty ``work_dir``.

    Any existing ``work_dir`` is deleted and recreated before the test body
    runs. After the test finishes, the directory is deleted again.
    """

    def _purge_work_dir():
        # Remove the scratch directory only if something is there.
        if osp.exists(work_dir):
            shutil.rmtree(work_dir)

    # setup: start from an empty directory
    _purge_work_dir()
    os.mkdir(work_dir)

    yield

    # teardown: leave nothing behind
    _purge_work_dir()
def test_onnx2tensorrt():
    """Round-trip a trivial model through ONNX and TensorRT.

    The model adds 1 to its input. It is exported to ONNX with dynamic
    batch/height/width axes on ``input``. It is then converted to a TensorRT
    engine with ``onnx2tensorrt``, and the engine output is checked against
    ``x + 1``.
    """

    # dummy model: output = input + 1
    class TestModel(nn.Module):

        def forward(self, x):
            return x + 1

    model = TestModel().eval().cuda()
    x = torch.rand(1, 3, 64, 64).cuda()
    onnx_path = osp.join(work_dir, onnx_file)

    # export to onnx with dynamic batch and spatial dimensions on the input
    torch.onnx.export(
        model,
        x,
        onnx_path,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={'input': {
            0: 'batch',
            2: 'height',
            3: 'width'
        }})
    assert osp.exists(onnx_path)

    # deploy config; the three shapes in opt_shape_dict are presumably the
    # [min, opt, max] profile for 'input' (TODO confirm against onnx2tensorrt).
    # The 64x64 test input lies inside that range.
    deploy_cfg = mmcv.Config(
        dict(
            backend='tensorrt',
            tensorrt_params=dict(
                shared_params=dict(
                    log_level=trt.Logger.WARNING, fp16_mode=False),
                model_params=[
                    dict(
                        opt_shape_dict=dict(
                            input=[[1, 3, 32, 32], [1, 3, 64, 64],
                                   [1, 3, 128, 128]]),
                        max_workspace_size=1 << 30)
                ])))

    # convert to engine; the positional 0 presumably selects
    # model_params[0] (TODO confirm)
    onnx2tensorrt(
        work_dir,
        save_file,
        0,
        deploy_cfg=deploy_cfg,
        onnx_model=onnx_path,
        ret_value=ret_value)
    assert ret_value.value == 0
    assert osp.exists(work_dir)
    assert osp.exists(osp.join(work_dir, save_file))

    # run the engine; x is already on CUDA, so it can be fed directly
    trt_model = TRTWrapper(osp.join(work_dir, save_file))
    with torch.no_grad():
        trt_output = trt_model({'input': x})['output']
    torch.testing.assert_allclose(trt_output, x + 1)