[Unittest]: Add ncnn test exporter and topk test (#84)
* add ncnn test exporter in test_ops.py * add ncnn test exporter in utils.py * fix ncnn unittest * fix lint * fix lint * fix lint isort * remove ncnn roi_align pytest * add ncnn topk unittest * update to new api * fix lint * add comments * skip class * skip ncnn Co-authored-by: hanrui1sensetime <hanrui1@sensetime.com> Co-authored-by: maningsheng <mnsheng@yeah.net>pull/12/head
parent
2073f3327e
commit
0bef0513c6
|
@ -1,3 +1,3 @@
|
|||
from .utils import TestOnnxRTExporter, TestTensorRTExporter
|
||||
from .utils import TestNCNNExporter, TestOnnxRTExporter, TestTensorRTExporter
|
||||
|
||||
__all__ = ['TestTensorRTExporter', 'TestOnnxRTExporter']
|
||||
__all__ = ['TestTensorRTExporter', 'TestOnnxRTExporter', 'TestNCNNExporter']
|
||||
|
|
|
@ -4,10 +4,11 @@ import torch.nn as nn
|
|||
|
||||
from mmdeploy.core import register_extra_symbolics
|
||||
from mmdeploy.utils.test import WrapFunction
|
||||
from .utils import TestOnnxRTExporter, TestTensorRTExporter
|
||||
from .utils import TestNCNNExporter, TestOnnxRTExporter, TestTensorRTExporter
|
||||
|
||||
TEST_ONNXRT = TestOnnxRTExporter()
|
||||
TEST_TENSORRT = TestTensorRTExporter()
|
||||
TEST_NCNN = TestNCNNExporter()
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend', [TEST_TENSORRT, TEST_ONNXRT])
|
||||
|
@ -185,3 +186,49 @@ def test_instance_norm(backend,
|
|||
dynamic_axes=dynamic_axes,
|
||||
output_names=['output'],
|
||||
save_dir=save_dir)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend', [TEST_NCNN])
|
||||
@pytest.mark.parametrize('k', [1, 3, 5])
|
||||
@pytest.mark.parametrize('dim', [1, 2, 3])
|
||||
@pytest.mark.parametrize('largest', [True, False])
|
||||
@pytest.mark.parametrize('sorted', [True, False])
|
||||
def test_topk(backend,
|
||||
k,
|
||||
dim,
|
||||
largest,
|
||||
sorted,
|
||||
input_list=None,
|
||||
save_dir=None):
|
||||
backend.check_env()
|
||||
|
||||
if not input_list:
|
||||
input = torch.rand(1, 8, 12, 17)
|
||||
assert input.shape[0] == 1, (f'ncnn batch must be 1, \
|
||||
but not {input.shape[0]}')
|
||||
cfg = dict()
|
||||
register_extra_symbolics(cfg=cfg, opset=11)
|
||||
|
||||
def wrapped_function(inputs):
|
||||
return torch.Tensor.topk(inputs, k, dim, largest, sorted)
|
||||
|
||||
wrapped_model = WrapFunction(wrapped_function)
|
||||
|
||||
# when the 'sorted' attribute is False, pytorch will return
|
||||
# a hard to expect result, which only features that the topk
|
||||
# number is right. So the Topk unittest only check whether the
|
||||
# topk elements are right, all the possible order will be accepted.
|
||||
if not sorted:
|
||||
backend.run_and_validate(
|
||||
wrapped_model, [input.float()],
|
||||
'topk' + f'_no_sorted_dim_{dim}',
|
||||
input_names=['inputs'],
|
||||
output_names=['data', 'index'],
|
||||
save_dir=save_dir)
|
||||
else:
|
||||
backend.run_and_validate(
|
||||
wrapped_model, [input.float()],
|
||||
'topk',
|
||||
input_names=['inputs'],
|
||||
output_names=['data', 'index'],
|
||||
save_dir=save_dir)
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
|
||||
import mmcv
|
||||
|
@ -6,6 +7,7 @@ import onnx
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
import mmdeploy.apis.ncnn as ncnn_apis
|
||||
import mmdeploy.apis.onnxruntime as ort_apis
|
||||
import mmdeploy.apis.tensorrt as trt_apis
|
||||
from mmdeploy.utils.test import assert_allclose
|
||||
|
@ -149,3 +151,77 @@ class TestTensorRTExporter:
|
|||
trt_outputs[name].cpu().float() for name in output_names
|
||||
]
|
||||
assert_allclose(model_outputs, trt_outputs, tolerate_small_mismatch)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='This a not test class but a utility class.')
|
||||
class TestNCNNExporter:
|
||||
|
||||
def __init__(self):
|
||||
self.backend_name = 'ncnn'
|
||||
|
||||
def check_env(self):
|
||||
if not ncnn_apis.is_available():
|
||||
pytest.skip(
|
||||
'NCNN is not installed or custom ops are not compiled.')
|
||||
|
||||
def run_and_validate(self,
|
||||
model,
|
||||
inputs_list,
|
||||
model_name='tmp',
|
||||
tolerate_small_mismatch=False,
|
||||
do_constant_folding=True,
|
||||
dynamic_axes=None,
|
||||
output_names=None,
|
||||
input_names=None,
|
||||
save_dir=None):
|
||||
if not save_dir:
|
||||
onnx_file_path = tempfile.NamedTemporaryFile().name
|
||||
ncnn_param_path = tempfile.NamedTemporaryFile().name
|
||||
ncnn_bin_path = tempfile.NamedTemporaryFile().name
|
||||
else:
|
||||
onnx_file_path = os.path.join(save_dir, model_name + '.onnx')
|
||||
ncnn_param_path = os.path.join(save_dir, model_name + '.param')
|
||||
ncnn_bin_path = os.path.join(save_dir, model_name + '.bin')
|
||||
|
||||
with torch.no_grad():
|
||||
torch.onnx.export(
|
||||
model,
|
||||
tuple(inputs_list),
|
||||
onnx_file_path,
|
||||
export_params=True,
|
||||
keep_initializers_as_inputs=True,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
do_constant_folding=do_constant_folding,
|
||||
dynamic_axes=dynamic_axes,
|
||||
opset_version=11)
|
||||
|
||||
onnx2ncnn_path = ncnn_apis.get_onnx2ncnn_path()
|
||||
subprocess.call(
|
||||
[onnx2ncnn_path, onnx_file_path, ncnn_param_path, ncnn_bin_path])
|
||||
|
||||
with torch.no_grad():
|
||||
model_outputs = model(*inputs_list)
|
||||
if isinstance(model_outputs, torch.Tensor):
|
||||
model_outputs = [model_outputs]
|
||||
else:
|
||||
model_outputs = list(model_outputs)
|
||||
model_outputs = [
|
||||
model_output.float() for model_output in model_outputs
|
||||
]
|
||||
ncnn_model = ncnn_apis.NCNNWrapper(ncnn_param_path, ncnn_bin_path,
|
||||
output_names)
|
||||
ncnn_outputs = ncnn_model(dict(zip(input_names, inputs_list)))
|
||||
ncnn_outputs = [ncnn_outputs[name] for name in output_names]
|
||||
|
||||
if model_name.startswith('topk_no_sorted'):
|
||||
dim = int(model_name.split('_')[-1])
|
||||
model_outputs = torch.stack(model_outputs, dim=-1).\
|
||||
sort(dim=dim).values
|
||||
ncnn_outputs = torch.stack(ncnn_outputs, dim=-1).\
|
||||
sort(dim=dim).values
|
||||
assert_allclose([model_outputs], [ncnn_outputs],
|
||||
tolerate_small_mismatch)
|
||||
else:
|
||||
assert_allclose(model_outputs, ncnn_outputs,
|
||||
tolerate_small_mismatch)
|
||||
|
|
Loading…
Reference in New Issue