mmdeploy/tests/test_pytorch/test_pytorch_ops.py
AllentDan 4fc8828af8 fix ocr UT
fix ci and lint

fix det

fix cuda ci

fix mmdet test

update object detection

fix ut

fix layer norm ut

update ut

lock mmedit version

fix mmocr mmcls ut

add conftest.py

fix ocr ut

fix mmedit ci

install mmedit from source

fix rknn model and prepare_onnx_paddings__tensorrt UT

docstring

fix coreml export

update mmocr config

small test

recovery assert

fix ci
2022-09-29 16:37:36 +08:00

192 lines
5.5 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import tempfile
import onnx
import pytest
import torch
from mmengine import Config
from mmdeploy.core import RewriterContext
# Default path the tests in this module export ONNX models to. Only the
# generated name is kept; the NamedTemporaryFile object itself is not retained.
# The suffix includes the leading dot so the path ends with a real ``.onnx``
# extension instead of having ``onnx`` glued onto the random stem.
onnx_file = tempfile.NamedTemporaryFile(suffix='.onnx').name
@pytest.fixture(autouse=False, scope='function')
def prepare_symbolics():
    """Run the requesting test inside a ``tensorrt`` ``RewriterContext``.

    The context is built from an ONNX export config plus a ``tensorrt``
    backend config with opset 11. It is entered before control is handed to
    the test and exited once the test gives control back to the fixture.
    """
    onnx_settings = dict(
        type='onnx',
        export_params=True,
        keep_initializers_as_inputs=False,
        opset_version=11,
        save_file='end2end.onnx',
        input_names=['input'],
        output_names=['output'],
        input_shape=None)
    deploy_cfg = Config(
        dict(
            onnx_config=onnx_settings,
            backend_config=dict(type='tensorrt')))
    rewriter = RewriterContext(deploy_cfg, 'tensorrt', opset=11)
    rewriter.enter()
    yield
    rewriter.exit()
@pytest.fixture(autouse=False, scope='function')
def prepare_symbolics_ncnn():
    """Run the requesting test inside an ``ncnn`` ``RewriterContext``.

    The context only carries a backend config of type ``ncnn`` and uses
    opset 11. It is entered before the test and exited afterwards.
    """
    ncnn_cfg = Config(dict(backend_config=dict(type='ncnn')))
    rewriter = RewriterContext(ncnn_cfg, 'ncnn', opset=11)
    rewriter.enter()
    yield
    rewriter.exit()
class OpModel(torch.nn.Module):
    """Module wrapper that applies a stored callable to its input.

    Args:
        func: Callable invoked as ``func(x, *args)`` by ``forward``.
        *args: Extra positional arguments stored at construction time and
            passed to ``func`` after the input.
    """

    def __init__(self, func, *args):
        super().__init__()
        self._arg_tuple = args
        self._func = func

    def forward(self, x):
        """Return ``func(x, *args)`` using the stored callable and args."""
        extra_args = self._arg_tuple
        return self._func(x, *extra_args)
def get_model_onnx_nodes(model, x, onnx_file=onnx_file):
    """Export ``model`` to ONNX and return the nodes of the saved graph.

    Args:
        model: Model handed to ``torch.onnx.export``.
        x: Example input used for the export.
        onnx_file: Path the ONNX file is written to and then loaded from.
            Defaults to the module-level ``onnx_file`` path.

    Returns:
        The ``graph.node`` container of the loaded ONNX model.
    """
    torch.onnx.export(model, x, onnx_file, opset_version=11)
    return onnx.load(onnx_file).graph.node
@pytest.mark.usefixtures('prepare_symbolics')
class TestAdaptivePool:
    """Check the ONNX nodes exported for ``adaptive_avg_pool2d``."""

    def test_adaptive_pool_2d_global(self):
        """Output size ``[1, 1]``: first node must be ``GlobalAveragePool``."""
        x = torch.rand(2, 2, 2)
        pool = OpModel(torch.nn.functional.adaptive_avg_pool2d, [1, 1])
        exported = get_model_onnx_nodes(pool.eval(), x)
        first_node = exported[0]
        assert first_node.op_type == 'GlobalAveragePool'

    def test_adaptive_pool_2d(self):
        """Output size ``[2, 2]``: last node must be ``AveragePool``."""
        x = torch.rand(2, 2, 2)
        pool = OpModel(torch.nn.functional.adaptive_avg_pool2d, [2, 2])
        exported = get_model_onnx_nodes(pool.eval(), x)
        last_node = exported[-1]
        assert last_node.op_type == 'AveragePool'
@pytest.mark.usefixtures('prepare_symbolics_ncnn')
def test_adaptive_pool_2d_ncnn():
    """Under the ncnn context, node 1 must be mmdeploy ``AdaptiveAvgPool2d``.

    The output size is passed as an int64 tensor rather than a list.
    """
    x = torch.rand(2, 2, 2)
    output_size = torch.tensor([2, 2], dtype=torch.int64)
    pool = OpModel(torch.nn.functional.adaptive_avg_pool2d, output_size)
    exported = get_model_onnx_nodes(pool.eval(), x)
    pool_node = exported[1]
    assert pool_node.op_type == 'AdaptiveAvgPool2d'
    assert pool_node.domain == 'mmdeploy'
@pytest.mark.usefixtures('prepare_symbolics')
def test_grid_sampler():
    """Exported ``torch.grid_sampler``: node 1 is mmdeploy ``grid_sampler``."""
    x = torch.rand(1, 1, 2, 2)
    grid = torch.zeros([1, 2, 2, 2])
    sampler = OpModel(torch.grid_sampler, grid, 0, 0, False)
    exported = get_model_onnx_nodes(sampler.eval(), x)
    sampler_node = exported[1]
    assert sampler_node.op_type == 'grid_sampler'
    assert sampler_node.domain == 'mmdeploy'
@pytest.mark.usefixtures('prepare_symbolics')
def test_roll():
    """Exported ``torch.roll`` must end with ``Slice`` then ``Concat``."""
    x = torch.rand(1, 4, 4, 4)
    roller = OpModel(torch.roll, [1, 1, 1], [3, 3, 3])
    exported = get_model_onnx_nodes(roller.eval(), x)
    second_to_last = exported[-2]
    assert second_to_last.op_type == 'Slice'
    last = exported[-1]
    assert last.op_type == 'Concat'
@pytest.mark.usefixtures('prepare_symbolics')
def test_instance_norm():
    """Exported ``torch.group_norm``: node 4 is ``TRTInstanceNormalization``.

    The node is also expected to live in the ``mmdeploy`` domain.
    """
    x = torch.rand(1, 2, 2, 2)
    norm = OpModel(torch.group_norm, 1, torch.rand([2]), torch.rand([2]),
                   1e-05)
    exported = get_model_onnx_nodes(norm.eval(), x)
    norm_node = exported[4]
    assert norm_node.op_type == 'TRTInstanceNormalization'
    assert norm_node.domain == 'mmdeploy'
@pytest.mark.usefixtures('prepare_symbolics_ncnn')
class TestLinear:
    """Check that ``torch.nn.functional.linear`` exports a matmul-style node.

    Both tests run under the ncnn rewriter context and only require that at
    least one ``Gemm`` or ``MatMul`` node is present in the exported graph.
    """

    def check(self, nodes):
        """Assert that ``nodes`` contains a ``Gemm`` or ``MatMul`` node.

        Args:
            nodes: Iterable of ONNX nodes exposing an ``op_type`` attribute.

        Raises:
            AssertionError: If no node has op type ``Gemm`` or ``MatMul``.
        """
        op_types = [node.op_type for node in nodes]
        # Report the op types only when the assertion fails, rather than
        # printing the whole node list on every run.
        assert any(op in ('Gemm', 'MatMul') for op in op_types), (
            f'expected a Gemm or MatMul node, got op types: {op_types}')

    def test_normal(self):
        """Linear with weight and bias."""
        x = torch.rand(1, 2, 3)
        w = torch.rand(2, 3)
        bias = torch.rand(2)
        model = OpModel(torch.nn.functional.linear, w, bias).eval()
        nodes = get_model_onnx_nodes(model, x)
        self.check(nodes)

    def test_no_bias(self):
        """Linear with weight only."""
        x = torch.rand(1, 2, 3)
        w = torch.rand(2, 3)
        model = OpModel(torch.nn.functional.linear, w).eval()
        nodes = get_model_onnx_nodes(model, x)
        self.check(nodes)
@pytest.mark.usefixtures('prepare_symbolics')
class TestSqueeze:
    """Check the ``Squeeze`` node exported for ``torch.squeeze``."""

    def test_squeeze_default(self):
        """Without a dim argument the first attribute must be ``[0, 1]``."""
        x = torch.rand(1, 1, 2, 2)
        exported = get_model_onnx_nodes(OpModel(torch.squeeze), x)
        squeeze_node = exported[0]
        assert squeeze_node.attribute[0].ints == [0, 1]
        assert squeeze_node.op_type == 'Squeeze'

    def test_squeeze(self):
        """With dim ``0`` the first attribute must be ``[0]``."""
        x = torch.rand(1, 1, 2, 2)
        exported = get_model_onnx_nodes(OpModel(torch.squeeze, 0), x)
        squeeze_node = exported[0]
        assert squeeze_node.attribute[0].ints == [0]
        assert squeeze_node.op_type == 'Squeeze'
@pytest.mark.usefixtures('prepare_symbolics')
def test_hardsigmoid():
    """Exported ``torch.nn.Hardsigmoid`` must start with ``HardSigmoid``."""
    x = torch.rand(1, 2, 3, 4)
    activation = torch.nn.Hardsigmoid().eval()
    first_node = get_model_onnx_nodes(activation, x)[0]
    assert first_node.op_type == 'HardSigmoid'
@pytest.mark.usefixtures('prepare_symbolics')
def test_layer_norm():
    """Exported ``LayerNorm(4)`` must declare an output of shape (2, 1, 4).

    The model is exported to the module-level ``onnx_file`` path and the
    first graph output's dimensions are compared one axis at a time.
    """
    x = torch.rand(2, 1, 4)
    layer_norm = torch.nn.LayerNorm(4).eval()
    torch.onnx.export(layer_norm, x, onnx_file, opset_version=11)
    first_output = onnx.load(onnx_file).graph.output[0]
    out_dims = first_output.type.tensor_type.shape.dim
    for axis, expected in enumerate((2, 1, 4)):
        assert out_dims[axis].dim_value == expected