[Unittest]: Add ncnn test exporter and topk test (#84)

* add ncnn test exporter in test_ops.py

* add ncnn test exporter in utils.py

* fix ncnn unittest

* fix lint

* fix lint

* fix lint isort

* remove ncnn roi_align pytest

* add ncnn topk unittest

* update to new api

* fix lint

* add comments

* skip class

* skip ncnn

Co-authored-by: hanrui1sensetime <hanrui1@sensetime.com>
Co-authored-by: maningsheng <mnsheng@yeah.net>
VVsssssk 2021-09-28 14:20:04 +08:00 committed by GitHub
parent 2073f3327e
commit 0bef0513c6
3 changed files with 126 additions and 3 deletions


@@ -1,3 +1,3 @@
from .utils import TestOnnxRTExporter, TestTensorRTExporter
from .utils import TestNCNNExporter, TestOnnxRTExporter, TestTensorRTExporter
__all__ = ['TestTensorRTExporter', 'TestOnnxRTExporter']
__all__ = ['TestTensorRTExporter', 'TestOnnxRTExporter', 'TestNCNNExporter']


@@ -4,10 +4,11 @@ import torch.nn as nn
from mmdeploy.core import register_extra_symbolics
from mmdeploy.utils.test import WrapFunction
from .utils import TestOnnxRTExporter, TestTensorRTExporter
from .utils import TestNCNNExporter, TestOnnxRTExporter, TestTensorRTExporter
TEST_ONNXRT = TestOnnxRTExporter()
TEST_TENSORRT = TestTensorRTExporter()
TEST_NCNN = TestNCNNExporter()


@pytest.mark.parametrize('backend', [TEST_TENSORRT, TEST_ONNXRT])
@@ -185,3 +186,49 @@ def test_instance_norm(backend,
        dynamic_axes=dynamic_axes,
        output_names=['output'],
        save_dir=save_dir)

@pytest.mark.parametrize('backend', [TEST_NCNN])
@pytest.mark.parametrize('k', [1, 3, 5])
@pytest.mark.parametrize('dim', [1, 2, 3])
@pytest.mark.parametrize('largest', [True, False])
@pytest.mark.parametrize('sorted', [True, False])
def test_topk(backend,
              k,
              dim,
              largest,
              sorted,
              input_list=None,
              save_dir=None):
    backend.check_env()

    if not input_list:
        input = torch.rand(1, 8, 12, 17)
    assert input.shape[0] == 1, (f'ncnn batch must be 1, \
        but not {input.shape[0]}')

    cfg = dict()
    register_extra_symbolics(cfg=cfg, opset=11)

    def wrapped_function(inputs):
        return torch.Tensor.topk(inputs, k, dim, largest, sorted)

    wrapped_model = WrapFunction(wrapped_function)
    # When the 'sorted' attribute is False, PyTorch returns the top-k
    # elements in an unspecified order, so this unittest only checks that
    # the returned top-k elements are correct and accepts any ordering
    # (a standalone sketch of this check follows the test).
    if not sorted:
        backend.run_and_validate(
            wrapped_model, [input.float()],
            'topk' + f'_no_sorted_dim_{dim}',
            input_names=['inputs'],
            output_names=['data', 'index'],
            save_dir=save_dir)
    else:
        backend.run_and_validate(
            wrapped_model, [input.float()],
            'topk',
            input_names=['inputs'],
            output_names=['data', 'index'],
            save_dir=save_dir)
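
For reference, a minimal standalone sketch (not part of this diff) of the order-insensitive check the ncnn test utility applies when sorted=False: the (value, index) outputs are stacked and re-sorted along the topk dimension, so two results holding the same elements in different orders compare equal. The input shape, k and dim below are illustrative assumptions.

import torch

x = torch.rand(1, 8, 12, 17)
k, dim = 3, 2

# Two topk results that contain the same elements in a different order.
values, indices = x.topk(k, dim=dim, largest=True, sorted=False)
perm = torch.randperm(k)
shuffled = values.index_select(dim, perm), indices.index_select(dim, perm)


def canonical(vals, idxs):
    # Sort along the topk dimension so the element order no longer matters.
    stacked = torch.stack([vals.float(), idxs.float()], dim=-1)
    return stacked.sort(dim=dim).values


assert torch.allclose(canonical(values, indices), canonical(*shuffled))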


@@ -1,4 +1,5 @@
import os
import subprocess
import tempfile
import mmcv
@@ -6,6 +7,7 @@ import onnx
import pytest
import torch
import mmdeploy.apis.ncnn as ncnn_apis
import mmdeploy.apis.onnxruntime as ort_apis
import mmdeploy.apis.tensorrt as trt_apis
from mmdeploy.utils.test import assert_allclose
@@ -149,3 +151,77 @@ class TestTensorRTExporter:
            trt_outputs[name].cpu().float() for name in output_names
        ]
        assert_allclose(model_outputs, trt_outputs, tolerate_small_mismatch)

@pytest.mark.skip(reason='This is not a test class but a utility class.')
class TestNCNNExporter:

    def __init__(self):
        self.backend_name = 'ncnn'

    def check_env(self):
        if not ncnn_apis.is_available():
            pytest.skip(
                'NCNN is not installed or custom ops are not compiled.')

    def run_and_validate(self,
                         model,
                         inputs_list,
                         model_name='tmp',
                         tolerate_small_mismatch=False,
                         do_constant_folding=True,
                         dynamic_axes=None,
                         output_names=None,
                         input_names=None,
                         save_dir=None):
        if not save_dir:
            onnx_file_path = tempfile.NamedTemporaryFile().name
            ncnn_param_path = tempfile.NamedTemporaryFile().name
            ncnn_bin_path = tempfile.NamedTemporaryFile().name
        else:
            onnx_file_path = os.path.join(save_dir, model_name + '.onnx')
            ncnn_param_path = os.path.join(save_dir, model_name + '.param')
            ncnn_bin_path = os.path.join(save_dir, model_name + '.bin')
        with torch.no_grad():
            torch.onnx.export(
                model,
                tuple(inputs_list),
                onnx_file_path,
                export_params=True,
                keep_initializers_as_inputs=True,
                input_names=input_names,
                output_names=output_names,
                do_constant_folding=do_constant_folding,
                dynamic_axes=dynamic_axes,
                opset_version=11)

        # Convert the exported ONNX model into ncnn .param/.bin files with
        # the onnx2ncnn tool.
        onnx2ncnn_path = ncnn_apis.get_onnx2ncnn_path()
        subprocess.call(
            [onnx2ncnn_path, onnx_file_path, ncnn_param_path, ncnn_bin_path])
        with torch.no_grad():
            model_outputs = model(*inputs_list)

        if isinstance(model_outputs, torch.Tensor):
            model_outputs = [model_outputs]
        else:
            model_outputs = list(model_outputs)
        model_outputs = [
            model_output.float() for model_output in model_outputs
        ]

        ncnn_model = ncnn_apis.NCNNWrapper(ncnn_param_path, ncnn_bin_path,
                                           output_names)
        ncnn_outputs = ncnn_model(dict(zip(input_names, inputs_list)))
        ncnn_outputs = [ncnn_outputs[name] for name in output_names]
        if model_name.startswith('topk_no_sorted'):
            # For unsorted topk, stack the (value, index) outputs and sort
            # them along the topk dimension so that differences in element
            # order are ignored by the comparison.
            dim = int(model_name.split('_')[-1])
            model_outputs = torch.stack(model_outputs, dim=-1).\
                sort(dim=dim).values
            ncnn_outputs = torch.stack(ncnn_outputs, dim=-1).\
                sort(dim=dim).values
            assert_allclose([model_outputs], [ncnn_outputs],
                            tolerate_small_mismatch)
        else:
            assert_allclose(model_outputs, ncnn_outputs,
                            tolerate_small_mismatch)
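
As a usage note, a hedged sketch of how another op test could reuse this utility, mirroring test_topk above; the wrapped sigmoid model, the test name and the input shape are illustrative assumptions, not part of this commit.

import torch

from mmdeploy.utils.test import WrapFunction
from .utils import TestNCNNExporter

TEST_NCNN = TestNCNNExporter()


def test_sigmoid_ncnn_sketch(save_dir=None):
    # Illustrative only: not part of the diff above.
    TEST_NCNN.check_env()  # skips if ncnn or its custom ops are unavailable
    wrapped_model = WrapFunction(lambda inputs: torch.sigmoid(inputs))
    TEST_NCNN.run_and_validate(
        wrapped_model, [torch.rand(1, 4, 8, 8)],  # ncnn batch must be 1
        'sigmoid',
        input_names=['inputs'],
        output_names=['output'],
        save_dir=save_dir)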