# Source file: mmdeploy/tests/test_backend/test_wrapper.py
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import subprocess
import tempfile
import mmengine
import pytest
import torch
import torch.nn as nn
from mmdeploy.utils.constants import Backend
from mmdeploy.utils.test import check_backend
# Scratch paths for the exported IR models. NamedTemporaryFile is used only
# to obtain a unique path: the handle is discarded immediately, so the
# underlying file is removed once the handle is garbage collected and only
# the free path name remains (NOTE(review): relies on delete-on-close
# semantics; confirm this is intended on all platforms).
onnx_file = tempfile.NamedTemporaryFile(suffix='.onnx').name
ts_file = tempfile.NamedTemporaryFile(suffix='.pt').name
# Fixed random input shared by model export and every wrapper run.
test_img = torch.rand(1, 3, 8, 8)
# I/O tensor names used consistently for every backend conversion.
output_names = ['output']
input_names = ['input']
target_platform = 'rk3588'  # RKNN pre-compiled models are tied to a device.
@pytest.mark.skip(reason='This is not a test class but a utility class.')
class TestModel(nn.Module):
    """Minimal module whose forward adds the fixed ``test_img`` tensor.

    The ``pytest.mark.skip`` marker keeps pytest from collecting this
    class as a test case despite its ``Test*`` name.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Output equals the input plus the module-level constant image.
        return x + test_img


# Shared eval-mode instance exported/traced by the fixtures below.
model = TestModel().eval()
@pytest.fixture(autouse=True, scope='module')
def generate_onnx_file():
    """Export the shared ``model`` to ONNX once for the whole module."""
    export_options = dict(
        output_names=output_names,
        input_names=input_names,
        keep_initializers_as_inputs=True,
        do_constant_folding=True,
        verbose=False,
        opset_version=11,
        dynamic_axes=None)
    # Gradients are irrelevant for export; trace under no_grad.
    with torch.no_grad():
        torch.onnx.export(model, test_img, onnx_file, **export_options)
@pytest.fixture(autouse=True, scope='module')
def generate_torchscript_file():
    """Trace the shared ``model`` to TorchScript once for the module."""
    from mmengine import Config

    from mmdeploy.apis.torch_jit import trace

    backend_name = Backend.TORCHSCRIPT.value
    cfg = Config({'backend_config': dict(type=backend_name)})
    dummy_input = torch.rand(1, 3, 8, 8)
    trace(
        model,
        dummy_input,
        output_path_prefix=osp.splitext(ts_file)[0],
        backend=backend_name,
        context_info=dict(deploy_cfg=cfg))
def ir2backend(backend, onnx_file, ts_file):
    """Convert an IR model (ONNX or TorchScript) to ``backend``'s format.

    Args:
        backend (Backend): Target backend enum member.
        onnx_file (str): Path to the exported ONNX model, consumed by most
            backends.
        ts_file (str): Path to the traced TorchScript model, consumed by
            the TORCHSCRIPT and COREML branches.

    Returns:
        str | tuple[str, str]: Path(s) of the converted model file(s).
        PPLNN returns ``(model, algo)`` and NCNN ``(param, bin)`` pairs.

    Raises:
        NotImplementedError: If no converter exists for ``backend``.
    """
    if backend == Backend.TENSORRT:
        from mmdeploy.backend.tensorrt import from_onnx
        backend_file = tempfile.NamedTemporaryFile(suffix='.engine').name
        # Static 1x3x8x8 profile: min/opt/max all equal the test image shape.
        from_onnx(
            onnx_file,
            osp.splitext(backend_file)[0], {
                'input': {
                    'min_shape': [1, 3, 8, 8],
                    'opt_shape': [1, 3, 8, 8],
                    'max_shape': [1, 3, 8, 8]
                }
            })
        return backend_file
    elif backend == Backend.ONNXRUNTIME:
        # ONNX Runtime loads the ONNX file directly; nothing to convert.
        return onnx_file
    elif backend == Backend.PPLNN:
        from mmdeploy.apis.pplnn import from_onnx
        output_file_prefix = tempfile.NamedTemporaryFile().name
        from_onnx(onnx_file, output_file_prefix=output_file_prefix)
        # PPLNN needs both the optimized model and its algorithm file.
        algo_file = output_file_prefix + '.json'
        output_file = output_file_prefix + '.onnx'
        return output_file, algo_file
    elif backend == Backend.NCNN:
        from mmdeploy.backend.ncnn.init_plugins import get_onnx2ncnn_path
        # onnx2ncnn is an external converter executable shipped with mmdeploy.
        onnx2ncnn_path = get_onnx2ncnn_path()
        param_file = tempfile.NamedTemporaryFile(suffix='.param').name
        bin_file = tempfile.NamedTemporaryFile(suffix='.bin').name
        subprocess.call([onnx2ncnn_path, onnx_file, param_file, bin_file])
        return param_file, bin_file
    elif backend == Backend.OPENVINO:
        from mmdeploy.apis.openvino import from_onnx, get_output_model_file
        # NOTE(review): the TemporaryDirectory object is not kept alive, so
        # the directory may be cleaned up once it is garbage collected —
        # confirm the converted file is consumed before that happens.
        backend_dir = tempfile.TemporaryDirectory().name
        backend_file = get_output_model_file(onnx_file, backend_dir)
        input_info = {'input': test_img.shape}
        # Local shadow of the module-level output_names with the same value.
        output_names = ['output']
        work_dir = backend_dir
        from_onnx(onnx_file, work_dir, input_info, output_names)
        return backend_file
    elif backend == Backend.RKNN:
        from mmdeploy.apis.rknn import onnx2rknn
        rknn_file = onnx_file.replace('.onnx', '.rknn')
        # RKNN conversion is driven by a deploy config rather than kwargs.
        deploy_cfg = mmengine.Config(
            dict(
                backend_config=dict(
                    type='rknn',
                    common_config=dict(target_platform=target_platform),
                    quantization_config=dict(
                        do_quantization=False, dataset=None),
                    input_size_list=[[3, 8, 8]])))
        onnx2rknn(onnx_file, rknn_file, deploy_cfg)
        return rknn_file
    elif backend == Backend.ASCEND:
        from mmdeploy.apis.ascend import from_onnx
        # NOTE(review): same GC-dependent TemporaryDirectory lifetime caveat
        # as the OPENVINO branch above.
        backend_dir = tempfile.TemporaryDirectory().name
        work_dir = backend_dir
        file_name = osp.splitext(osp.split(onnx_file)[1])[0]
        # Ascend emits an .om file named after the input ONNX file.
        backend_file = osp.join(work_dir, file_name + '.om')
        model_inputs = mmengine.Config(
            dict(input_shapes=dict(input=test_img.shape)))
        from_onnx(onnx_file, work_dir, model_inputs)
        return backend_file
    elif backend == Backend.TVM:
        from mmdeploy.backend.tvm import from_onnx, get_library_ext
        # Library extension is platform-dependent (.so/.dylib/.dll).
        ext = get_library_ext()
        lib_file = tempfile.NamedTemporaryFile(suffix=ext).name
        shape = {'input': test_img.shape}
        dtype = {'input': 'float32'}
        target = 'llvm'
        tuner_dict = dict(type='DefaultTuner', target=target)
        from_onnx(
            onnx_file, lib_file, shape=shape, dtype=dtype, tuner=tuner_dict)
        assert osp.exists(lib_file)
        return lib_file
    elif backend == Backend.TORCHSCRIPT:
        # The traced TorchScript file is used as-is.
        return ts_file
    elif backend == Backend.COREML:
        # Local shadow of the module-level output_names with the same value.
        output_names = ['output']
        from mmdeploy.backend.coreml.torchscript2coreml import (
            from_torchscript, get_model_suffix)
        backend_dir = tempfile.TemporaryDirectory().name
        work_dir = backend_dir
        torchscript_name = osp.splitext(osp.split(ts_file)[1])[0]
        output_file_prefix = osp.join(work_dir, torchscript_name)
        convert_to = 'mlprogram'
        # Fixed-shape conversion: min/default/max all equal 1x3x8x8.
        from_torchscript(
            ts_file,
            output_file_prefix,
            input_names=input_names,
            output_names=output_names,
            input_shapes=dict(
                input=dict(
                    min_shape=[1, 3, 8, 8],
                    default_shape=[1, 3, 8, 8],
                    max_shape=[1, 3, 8, 8])),
            convert_to=convert_to)
        # The final file suffix depends on the chosen convert_to format.
        suffix = get_model_suffix(convert_to)
        return output_file_prefix + suffix
    else:
        raise NotImplementedError(
            f'Convert for {backend.value} has not been implemented.')
def create_wrapper(backend, model_files):
    """Build a backend inference wrapper around the converted file(s)."""
    from mmdeploy.backend.base import get_backend_manager
    manager = get_backend_manager(backend.value)
    cfg = None
    if isinstance(model_files, str):
        # Managers expect a list even for single-file backends.
        model_files = [model_files]
    elif backend == Backend.RKNN:
        # RKNN wrappers additionally need the target platform config.
        cfg = dict(
            backend_config=dict(
                common_config=dict(target_platform=target_platform)))
    return manager.build_wrapper(
        model_files,
        input_names=input_names,
        output_names=output_names,
        deploy_cfg=cfg)
def run_wrapper(backend, wrapper, input):
    """Run one forward pass through ``wrapper`` and return its output."""
    # TensorRT engines only accept CUDA tensors.
    if backend == Backend.TENSORRT:
        input = input.cuda()
    outputs = wrapper({'input': input})
    if backend == Backend.RKNN:
        # NOTE(review): the RKNN wrapper result is returned as-is — it is
        # not unpacked via the 'output' key here; confirm against upstream.
        return outputs
    return outputs['output'].detach().cpu()
# Parametrization set: every real backend, excluding the pseudo entries
# that have no standalone wrapper to test.
ALL_BACKEND = [
    b for b in Backend
    if b not in (Backend.DEFAULT, Backend.PYTORCH, Backend.SDK)
]
@pytest.mark.parametrize('backend', ALL_BACKEND)
def test_wrapper(backend):
    """End-to-end check: convert, wrap, and run each supported backend."""
    # Skips the case when the backend is not installed in this env.
    check_backend(backend, True)
    converted = ir2backend(backend, onnx_file, ts_file)
    assert converted is not None
    wrapper = create_wrapper(backend, converted)
    assert wrapper is not None
    assert run_wrapper(backend, wrapper, test_img) is not None