# mmdeploy/tests/test_apis/test_wrapper.py

import tempfile

import pytest
import torch
import torch.nn as nn

from mmdeploy.utils.constants import Backend

onnx_file = tempfile.NamedTemporaryFile(suffix='.onnx').name
test_img = torch.rand([1, 3, 64, 64])


class TestModel(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x * 0.5


model = TestModel().eval().cuda()
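
# Note: TestModel has no parameters or buffers, so ``.eval().cuda()`` is
# effectively a no-op here, and the CPU ``test_img`` traces through the
# model cleanly during ONNX export even on a CUDA-less machine.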

@pytest.fixture(autouse=True, scope='module')
def generate_onnx_file():
    """Export the test model to a temporary ONNX file once per module."""
    with torch.no_grad():
        dynamic_axes = {
            'input': {
                0: 'batch',
                2: 'width',
                3: 'height'
            },
            'output': {
                0: 'batch'
            }
        }
        torch.onnx.export(
            model,
            test_img,
            onnx_file,
            output_names=['output'],
            input_names=['input'],
            keep_initializers_as_inputs=True,
            do_constant_folding=True,
            verbose=False,
            opset_version=11,
            dynamic_axes=dynamic_axes)
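
# The export above keeps batch, width and height dynamic, so a single ONNX
# file can back engines built for different shape ranges; the TensorRT
# profile in onnx2backend below pins all three shapes to [1, 3, 64, 64].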

def check_backend_available(backend):
    """Skip the current test when the backend's requirements are not met."""
    if backend == Backend.TENSORRT:
        from mmdeploy.apis.tensorrt import is_available as trt_available
        if not trt_available():
            pytest.skip(
                'TensorRT is not installed or custom ops are not compiled.')
        if not torch.cuda.is_available():
            pytest.skip('CUDA is not available.')

def onnx2backend(backend, onnx_file):
    """Convert the ONNX file to a backend-specific model file."""
    if backend == Backend.TENSORRT:
        from mmdeploy.apis.tensorrt import (create_trt_engine,
                                            save_trt_engine)
        backend_file = tempfile.NamedTemporaryFile(suffix='.engine').name
        engine = create_trt_engine(
            onnx_file, {
                'input': {
                    'min_shape': [1, 3, 64, 64],
                    'opt_shape': [1, 3, 64, 64],
                    'max_shape': [1, 3, 64, 64]
                }
            })
        save_trt_engine(engine, backend_file)
        return backend_file

def create_wrapper(backend, engine_file):
    """Build a backend wrapper around the converted model file."""
    if backend == Backend.TENSORRT:
        from mmdeploy.apis.tensorrt import TRTWrapper
        trt_model = TRTWrapper(engine_file)
        return trt_model

def run_wrapper(backend, wrapper, input):
    """Run one forward pass through the wrapper and return CPU results."""
    if backend == Backend.TENSORRT:
        input = input.cuda()
        results = wrapper({'input': input})['output']
        results = results.detach().cpu()
        return results
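
# Supporting another backend means extending ALL_BACKEND below and adding a
# matching branch to each helper above. A minimal sketch for ONNX Runtime
# (``ORTWrapper`` and its import path are assumptions, not verified against
# this codebase):
#
#     def create_wrapper(backend, engine_file):
#         if backend == Backend.TENSORRT:
#             from mmdeploy.apis.tensorrt import TRTWrapper
#             return TRTWrapper(engine_file)
#         elif backend == Backend.ONNXRUNTIME:
#             from mmdeploy.apis.onnxruntime import ORTWrapper  # hypothetical
#             return ORTWrapper(engine_file)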

ALL_BACKEND = [Backend.TENSORRT]


@pytest.mark.parametrize('backend', ALL_BACKEND)
def test_wrapper(backend):
    check_backend_available(backend)
    model_file = onnx2backend(backend, onnx_file)
    assert model_file is not None
    wrapper = create_wrapper(backend, model_file)
    assert wrapper is not None
    results = run_wrapper(backend, wrapper, test_img)
    assert results is not None
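
# To run this test locally (assuming TensorRT and a CUDA device are set up):
#
#     pytest tests/test_apis/test_wrapper.py -v
#
# Without TensorRT or CUDA the test skips itself via check_backend_available.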