mmdeploy/tests/test_core/test_mark.py
Yifan Zhou 4de5bbb461
[Unittest]: Add demos for core/apis/utils (#91)
* add unittests

* lint

* modify .gitignore, remove useless files

* remove empty.py and generate it when used

* Update according to comments

1. Use tempfile
2. Delete inference test (which will be tested in each codebase)
3. Refine calibration test

* update annotation

* Add export_info

* Reduce data scale, fix assert

* update json blank line

* add backend check
2021-09-27 16:10:47 +08:00

70 lines
1.5 KiB
Python

import os
import tempfile

import onnx
import torch

from mmdeploy.core import mark
from mmdeploy.core.optimizers import attribute_to_dict
# Unique path ending in ``.onnx`` for exported models. Only the ``name`` of
# the NamedTemporaryFile is kept. The file object is discarded right after
# this expression. With the default ``delete=True``, its file is removed when
# the object is closed or garbage-collected (on CPython, as soon as the
# expression finishes), so the path is normally free to be written to.
# NOTE(review): nothing visible here deletes a file later written to this
# path — confirm whether cleanup happens elsewhere.
output_file = tempfile.NamedTemporaryFile(suffix='.onnx').name
def test_mark():
    """Check that ``@mark`` surrounds a function with ``Mark`` nodes in ONNX.

    The marked ``add`` declares inputs ``a``/``b`` and output ``c``. After
    exporting a module that calls it, the graph is expected to hold, in
    order: one ``Mark`` node per input, the original ``Add`` node, and one
    ``Mark`` node for the output.
    """

    @mark('add', inputs=['a', 'b'], outputs='c')
    def add(x, y):
        return torch.add(x, y)

    class TestModel(torch.nn.Module):

        def forward(self, x, y):
            return add(x, y)

    model = TestModel().eval()

    # dummy input
    x = torch.rand(2, 3, 4)
    y = torch.rand(2, 3, 4)

    # Export into a private temporary directory so the generated model file
    # is deleted when the test finishes instead of being left on disk.
    # ``onnx.load`` reads the whole model into memory, so it remains usable
    # after the directory is removed.
    with tempfile.TemporaryDirectory() as tmp_dir:
        onnx_file = os.path.join(tmp_dir, 'test_mark.onnx')
        torch.onnx.export(model, (x, y), onnx_file)
        onnx_model = onnx.load(onnx_file)

    nodes = onnx_model.graph.node

    def check_mark(node, **expected_attrs):
        # Every Mark node is expected in the custom ``mmcv`` domain. Its
        # attributes record which marked function and argument it wraps.
        # ``dtype=1`` is presumably ONNX's float32 code (torch.rand yields
        # float32) — TODO confirm against the Mark op definition.
        assert node.op_type == 'Mark'
        assert node.domain == 'mmcv'
        assert attribute_to_dict(node.attribute) == dict(
            dtype=1,
            func='add',
            func_id=0,
            shape=[2, 3, 4],
            **expected_attrs)

    check_mark(nodes[0], id=0, type='input', name='a')
    check_mark(nodes[1], id=1, type='input', name='b')
    assert nodes[2].op_type == 'Add'
    check_mark(nodes[3], id=0, type='output', name='c')