# Copyright (c) OpenMMLab. All rights reserved.
import tempfile

import onnx
import pytest
import torch
from mmcv import Config

from mmdeploy.core import RewriterContext

# Shared temporary path for the ONNX files exported by the tests below.
onnx_file = tempfile.NamedTemporaryFile(suffix='.onnx').name


@pytest.fixture(autouse=False, scope='function')
def prepare_symbolics():
    """Enter an mmdeploy RewriterContext for the TensorRT backend during a test."""
    context = RewriterContext(
        Config({'backend_config': {
            'type': 'tensorrt'
        }}), 'tensorrt', opset=11)
    context.enter()

    yield

    context.exit()


@pytest.fixture(autouse=False, scope='function')
def prepare_symbolics_ncnn():
    """Enter an mmdeploy RewriterContext for the ncnn backend during a test."""
    context = RewriterContext(
        Config({'backend_config': {
            'type': 'ncnn'
        }}), 'ncnn', opset=11)
    context.enter()

    yield

    context.exit()


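# NOTE: the two fixtures above keep the rewriter active across the ``yield``
# by calling ``enter()``/``exit()`` explicitly.  A minimal sketch of the same
# setup outside pytest, using only the calls shown above (illustrative, not
# executed by these tests):
#
#     context = RewriterContext(
#         Config({'backend_config': {'type': 'tensorrt'}}),
#         'tensorrt', opset=11)
#     context.enter()
#     try:
#         ...  # export models while the TensorRT symbolics are registered
#     finally:
#         context.exit()

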
class OpModel(torch.nn.Module):

    def __init__(self, func, *args):
        super().__init__()
        self._func = func
        self._arg_tuple = args

    def forward(self, x):
        return self._func(x, *self._arg_tuple)


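# ``OpModel`` binds the remaining arguments of a functional op at construction
# time, so ``torch.onnx.export`` sees a Module that takes a single tensor
# input.  For example, ``OpModel(torch.nn.functional.adaptive_avg_pool2d,
# [2, 2])`` builds a module whose forward runs ``adaptive_avg_pool2d(x, [2, 2])``.

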
def get_model_onnx_nodes(model, x, onnx_file=onnx_file):
    """Export ``model`` to ONNX with opset 11 and return the graph nodes."""
    torch.onnx.export(model, x, onnx_file, opset_version=11)
    onnx_model = onnx.load(onnx_file)
    nodes = onnx_model.graph.node
    return nodes


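# The helper returns the graph's NodeProto entries; the tests below inspect
# ``op_type``, ``domain`` and ``attribute``.  A quick way to eyeball an
# exported graph when a test fails (plain onnx protobuf API, not used by the
# tests themselves):
#
#     for node in get_model_onnx_nodes(model, x):
#         print(node.op_type, node.domain,
#               [attr.name for attr in node.attribute])

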
@pytest.mark.usefixtures('prepare_symbolics')
class TestAdaptivePool:

    def test_adaptive_pool_1d_global(self):
        x = torch.rand(2, 2, 2)
        model = OpModel(torch.nn.functional.adaptive_avg_pool1d, [1]).eval()
        nodes = get_model_onnx_nodes(model, x)
        assert nodes[0].op_type == 'GlobalAveragePool'

    def test_adaptive_pool_1d(self):
        x = torch.rand(2, 2, 2)
        model = OpModel(torch.nn.functional.adaptive_avg_pool1d, [2]).eval()
        nodes = get_model_onnx_nodes(model, x)
        assert nodes[0].op_type == 'AveragePool'

    def test_adaptive_pool_2d_global(self):
        x = torch.rand(2, 2, 2)
        model = OpModel(torch.nn.functional.adaptive_avg_pool2d, [1, 1]).eval()
        nodes = get_model_onnx_nodes(model, x)
        assert nodes[0].op_type == 'GlobalAveragePool'

    def test_adaptive_pool_2d(self):
        x = torch.rand(2, 2, 2)
        model = OpModel(torch.nn.functional.adaptive_avg_pool2d, [2, 2]).eval()
        nodes = get_model_onnx_nodes(model, x)
        assert nodes[0].op_type == 'AveragePool'

    def test_adaptive_pool_3d_global(self):
        x = torch.rand(2, 2, 2, 2)
        model = OpModel(torch.nn.functional.adaptive_avg_pool3d,
                        [1, 1, 1]).eval()
        nodes = get_model_onnx_nodes(model, x)
        assert nodes[0].op_type == 'GlobalAveragePool'

    def test_adaptive_pool_3d(self):
        x = torch.rand(2, 2, 2, 2)
        model = OpModel(torch.nn.functional.adaptive_avg_pool3d,
                        [2, 2, 2]).eval()
        nodes = get_model_onnx_nodes(model, x)
        assert nodes[0].op_type == 'AveragePool'


@pytest.mark.usefixtures('prepare_symbolics_ncnn')
def test_adaptive_pool_2d_ncnn():
    x = torch.rand(2, 2, 2)
    model = OpModel(torch.nn.functional.adaptive_avg_pool2d,
                    torch.tensor([2, 2], dtype=torch.int64)).eval()
    nodes = get_model_onnx_nodes(model, x)
    assert nodes[1].op_type == 'AdaptiveAvgPool2d'
    assert nodes[1].domain == 'mmdeploy'


@pytest.mark.usefixtures('prepare_symbolics')
def test_grid_sampler():
    x = torch.rand(1, 1, 2, 2)
    flow = torch.zeros([1, 2, 2, 2])
    model = OpModel(torch.grid_sampler, flow, 0, 0, False).eval()
    nodes = get_model_onnx_nodes(model, x)
    assert nodes[1].op_type == 'grid_sampler'
    assert nodes[1].domain == 'mmdeploy'


@pytest.mark.usefixtures('prepare_symbolics')
def test_instance_norm():
    x = torch.rand(1, 2, 2, 2)
    model = OpModel(torch.group_norm, 1, torch.rand([2]), torch.rand([2]),
                    1e-05).eval()
    nodes = get_model_onnx_nodes(model, x)
    # The rewritten symbolic exports group_norm through the custom
    # TRTInstanceNormalization op in the mmdeploy domain.
    assert nodes[4].op_type == 'TRTInstanceNormalization'
    assert nodes[4].domain == 'mmdeploy'


@pytest.mark.usefixtures('prepare_symbolics_ncnn')
class TestLinear:

    def check(self, nodes):
        print(nodes)

        from packaging.version import parse as version_parse
        version = version_parse(torch.__version__)
        # torch <= 1.8 is expected to export linear through MatMul here;
        # later versions are expected to emit Gemm.
        target = 'Gemm'
        if version.major <= 1 and version.minor <= 8:
            target = 'MatMul'
        exist = False
        for node in nodes:
            if node.op_type == target:
                exist = True
                break

        assert exist is True

    def test_normal(self):
        x = torch.rand(1, 2, 3)
        w = torch.rand(2, 3)
        bias = torch.rand(2)
        model = OpModel(torch.nn.functional.linear, w, bias).eval()
        nodes = get_model_onnx_nodes(model, x)
        self.check(nodes)

    def test_no_bias(self):
        x = torch.rand(1, 2, 3)
        w = torch.rand(2, 3)
        model = OpModel(torch.nn.functional.linear, w).eval()
        nodes = get_model_onnx_nodes(model, x)
        self.check(nodes)


@pytest.mark.usefixtures('prepare_symbolics')
class TestSqueeze:

    def test_squeeze_default(self):
        x = torch.rand(1, 1, 2, 2)
        model = OpModel(torch.squeeze)
        nodes = get_model_onnx_nodes(model, x)
        assert nodes[0].attribute[0].ints == [0, 1]
        assert nodes[0].op_type == 'Squeeze'

    def test_squeeze(self):
        x = torch.rand(1, 1, 2, 2)
        model = OpModel(torch.squeeze, 0)
        nodes = get_model_onnx_nodes(model, x)
        assert nodes[0].attribute[0].ints == [0]
        assert nodes[0].op_type == 'Squeeze'


@pytest.mark.usefixtures('prepare_symbolics')
def test_hardsigmoid():
    x = torch.rand(1, 2, 3, 4)
    model = torch.nn.Hardsigmoid().eval()
    nodes = get_model_onnx_nodes(model, x)
    assert nodes[0].op_type == 'HardSigmoid'
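

# A typical standalone run of this module (the file path is illustrative, not
# taken from the repository layout):
#
#     pytest tests/test_symbolic_rewrites.py -v -k 'adaptive_pool or linear'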