# Copyright (c) OpenMMLab. All rights reserved.
"""End-to-end tests for mmdeploy backend wrappers.

A tiny model is exported to ONNX and TorchScript once per module; each
parametrized test then converts the IR to a backend format, builds the
backend wrapper, and runs one forward pass.
"""
import os.path as osp
import subprocess
import tempfile

import mmengine
import pytest
import torch
import torch.nn as nn

from mmdeploy.utils.constants import Backend
from mmdeploy.utils.test import check_backend

# NamedTemporaryFile here is used only to mint a unique path; the empty file
# it creates is thrown away and the exporters recreate it later.
onnx_file = tempfile.NamedTemporaryFile(suffix='.onnx').name
ts_file = tempfile.NamedTemporaryFile(suffix='.pt').name
test_img = torch.rand(1, 3, 8, 8)
output_names = ['output']
input_names = ['input']
target_platform = 'rk3588'  # rknn pre-compiled model need device


@pytest.mark.skip(reason='This a not test class but a utility class.')
class TestModel(nn.Module):
    """Trivial module: adds the fixed ``test_img`` tensor to its input.

    The ``Test`` prefix makes pytest try to collect it, hence the skip mark.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x + test_img


model = TestModel().eval()


@pytest.fixture(autouse=True, scope='module')
def generate_onnx_file():
    """Export ``model`` to ONNX at ``onnx_file`` once for the whole module."""
    with torch.no_grad():
        torch.onnx.export(
            model,
            test_img,
            onnx_file,
            output_names=output_names,
            input_names=input_names,
            keep_initializers_as_inputs=True,
            do_constant_folding=True,
            verbose=False,
            opset_version=11,
            dynamic_axes=None)


@pytest.fixture(autouse=True, scope='module')
def generate_torchscript_file():
    """Trace ``model`` to TorchScript at ``ts_file`` once for the module."""
    from mmengine import Config
    backend = Backend.TORCHSCRIPT.value
    deploy_cfg = Config({'backend_config': dict(type=backend)})
    from mmdeploy.apis.torch_jit import trace
    context_info = dict(deploy_cfg=deploy_cfg)
    output_prefix = osp.splitext(ts_file)[0]
    example_inputs = torch.rand(1, 3, 8, 8)
    trace(
        model,
        example_inputs,
        output_path_prefix=output_prefix,
        backend=backend,
        context_info=context_info)


def ir2backend(backend, onnx_file, ts_file):
    """Convert an IR model (ONNX or TorchScript) to ``backend``'s format.

    Args:
        backend (Backend): target backend enum member.
        onnx_file (str): path to the exported ONNX model.
        ts_file (str): path to the traced TorchScript model.

    Returns:
        str | tuple[str, str]: backend model file path(s) as expected by
        the corresponding backend wrapper.

    Raises:
        NotImplementedError: if conversion for ``backend`` is not handled.
    """
    if backend == Backend.TENSORRT:
        from mmdeploy.backend.tensorrt import from_onnx
        backend_file = tempfile.NamedTemporaryFile(suffix='.engine').name
        from_onnx(
            onnx_file, osp.splitext(backend_file)[0], {
                'input': {
                    'min_shape': [1, 3, 8, 8],
                    'opt_shape': [1, 3, 8, 8],
                    'max_shape': [1, 3, 8, 8]
                }
            })
        return backend_file
    elif backend == Backend.ONNXRUNTIME:
        # onnxruntime consumes the ONNX file directly.
        return onnx_file
    elif backend == Backend.PPLNN:
        from mmdeploy.apis.pplnn import from_onnx
        output_file_prefix = tempfile.NamedTemporaryFile().name
        from_onnx(onnx_file, output_file_prefix=output_file_prefix)
        algo_file = output_file_prefix + '.json'
        output_file = output_file_prefix + '.onnx'
        return output_file, algo_file
    elif backend == Backend.NCNN:
        from mmdeploy.backend.ncnn.init_plugins import get_onnx2ncnn_path
        onnx2ncnn_path = get_onnx2ncnn_path()
        param_file = tempfile.NamedTemporaryFile(suffix='.param').name
        bin_file = tempfile.NamedTemporaryFile(suffix='.bin').name
        subprocess.call([onnx2ncnn_path, onnx_file, param_file, bin_file])
        return param_file, bin_file
    elif backend == Backend.OPENVINO:
        from mmdeploy.apis.openvino import from_onnx, get_output_model_file
        # mkdtemp keeps the directory alive. TemporaryDirectory().name would
        # be cleaned up as soon as the unreferenced object is garbage
        # collected, leaving a dangling path for from_onnx to write into.
        backend_dir = tempfile.mkdtemp()
        backend_file = get_output_model_file(onnx_file, backend_dir)
        input_info = {'input': test_img.shape}
        output_names = ['output']
        work_dir = backend_dir
        from_onnx(onnx_file, work_dir, input_info, output_names)
        return backend_file
    elif backend == Backend.RKNN:
        from mmdeploy.apis.rknn import onnx2rknn
        rknn_file = onnx_file.replace('.onnx', '.rknn')
        deploy_cfg = mmengine.Config(
            dict(
                backend_config=dict(
                    type='rknn',
                    common_config=dict(target_platform=target_platform),
                    quantization_config=dict(
                        do_quantization=False, dataset=None),
                    input_size_list=[[3, 8, 8]])))
        onnx2rknn(onnx_file, rknn_file, deploy_cfg)
        return rknn_file
    elif backend == Backend.ASCEND:
        from mmdeploy.apis.ascend import from_onnx
        # See OPENVINO branch: mkdtemp avoids premature cleanup.
        backend_dir = tempfile.mkdtemp()
        work_dir = backend_dir
        file_name = osp.splitext(osp.split(onnx_file)[1])[0]
        backend_file = osp.join(work_dir, file_name + '.om')
        model_inputs = mmengine.Config(
            dict(input_shapes=dict(input=test_img.shape)))
        from_onnx(onnx_file, work_dir, model_inputs)
        return backend_file
    elif backend == Backend.TVM:
        from mmdeploy.backend.tvm import from_onnx, get_library_ext
        ext = get_library_ext()
        lib_file = tempfile.NamedTemporaryFile(suffix=ext).name
        shape = {'input': test_img.shape}
        dtype = {'input': 'float32'}
        target = 'llvm'
        tuner_dict = dict(type='DefaultTuner', target=target)
        from_onnx(
            onnx_file, lib_file, shape=shape, dtype=dtype, tuner=tuner_dict)
        assert osp.exists(lib_file)
        return lib_file
    elif backend == Backend.TORCHSCRIPT:
        # The traced TorchScript file is already the backend format.
        return ts_file
    elif backend == Backend.COREML:
        output_names = ['output']
        from mmdeploy.backend.coreml.torchscript2coreml import (
            from_torchscript, get_model_suffix)
        # See OPENVINO branch: mkdtemp avoids premature cleanup.
        backend_dir = tempfile.mkdtemp()
        work_dir = backend_dir
        torchscript_name = osp.splitext(osp.split(ts_file)[1])[0]
        output_file_prefix = osp.join(work_dir, torchscript_name)
        convert_to = 'mlprogram'
        from_torchscript(
            ts_file,
            output_file_prefix,
            input_names=input_names,
            output_names=output_names,
            input_shapes=dict(
                input=dict(
                    min_shape=[1, 3, 8, 8],
                    default_shape=[1, 3, 8, 8],
                    max_shape=[1, 3, 8, 8])),
            convert_to=convert_to)
        suffix = get_model_suffix(convert_to)
        return output_file_prefix + suffix
    else:
        raise NotImplementedError(
            f'Convert for {backend.value} has not been implemented.')


def create_wrapper(backend, model_files):
    """Build the backend wrapper for ``model_files``.

    RKNN additionally needs a deploy config carrying the target platform;
    all other backends take only the model file list.
    """
    from mmdeploy.backend.base import get_backend_manager
    backend_mgr = get_backend_manager(backend.value)
    deploy_cfg = None
    if isinstance(model_files, str):
        model_files = [model_files]
    elif backend == Backend.RKNN:
        deploy_cfg = dict(
            backend_config=dict(
                common_config=dict(target_platform=target_platform)))
    return backend_mgr.build_wrapper(
        model_files,
        input_names=input_names,
        output_names=output_names,
        deploy_cfg=deploy_cfg)


def run_wrapper(backend, wrapper, input):
    """Run one forward pass through ``wrapper`` and return its output.

    TensorRT requires the input on GPU; RKNN wrappers return raw results
    rather than a name-keyed dict, so the 'output' lookup is skipped.
    """
    if backend == Backend.TENSORRT:
        input = input.cuda()
    results = wrapper({'input': input})
    if backend != Backend.RKNN:
        results = results['output']
    results = results.detach().cpu()
    return results


# All concrete backends; DEFAULT/PYTORCH/SDK have no conversion path here.
ALL_BACKEND = list(Backend)
ALL_BACKEND.remove(Backend.DEFAULT)
ALL_BACKEND.remove(Backend.PYTORCH)
ALL_BACKEND.remove(Backend.SDK)


@pytest.mark.parametrize('backend', ALL_BACKEND)
def test_wrapper(backend):
    """Convert, wrap, and run the model on ``backend``; skip if unavailable."""
    check_backend(backend, True)
    model_files = ir2backend(backend, onnx_file, ts_file)
    assert model_files is not None
    wrapper = create_wrapper(backend, model_files)
    assert wrapper is not None
    results = run_wrapper(backend, wrapper, test_img)
    assert results is not None