2021-07-07 14:15:26 +08:00
|
|
|
import os
|
2021-07-12 16:26:44 +08:00
|
|
|
|
|
|
|
import onnx
|
2021-07-07 14:15:26 +08:00
|
|
|
import pytest
|
2021-07-12 16:26:44 +08:00
|
|
|
import torch
|
|
|
|
|
|
|
|
from mmdeploy.core import mark
|
2021-09-01 15:43:49 +08:00
|
|
|
from mmdeploy.core.optimizers import attribute_to_dict
|
2021-07-07 14:15:26 +08:00
|
|
|
|
|
|
|
# Path of the ONNX file every test in this module exports to and reloads;
# the autouse fixture deletes it before and after each test.
output_file = 'test_mark.onnx'
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def clear_work_dir_after_test():
    """Remove the exported ONNX file around every test in this module.

    The file is deleted (if present) once before the test body runs and once
    again during fixture teardown, so no test sees a stale export and no
    artefact is left in the working directory afterwards.
    """

    def _purge_output():
        # Only delete when present; a missing file is the normal case.
        if os.path.exists(output_file):
            os.remove(output_file)

    _purge_output()  # before the test
    yield
    _purge_output()  # after the test
|
|
|
|
|
|
|
|
|
|
|
|
def test_mark():
    """Export a model whose ``add`` is wrapped by ``mark`` and inspect it.

    The exported ONNX graph is asserted to contain, in order: a ``Mark``
    node for input ``a``, a ``Mark`` node for input ``b``, the ``Add`` op,
    and a ``Mark`` node for output ``c``.
    """

    @mark('add', inputs=['a', 'b'], outputs='c')
    def add(x, y):
        return torch.add(x, y)

    class TestModel(torch.nn.Module):

        def forward(self, x, y):
            return add(x, y)

    def check_mark_node(node, mark_id, mark_type, mark_name):
        # All Mark nodes share op type, domain and the func metadata; only
        # the id, the input/output type and the tensor name differ.
        assert node.op_type == 'Mark'
        assert node.domain == 'mmcv'
        assert attribute_to_dict(node.attribute) == dict(
            dtype=1,
            func='add',
            func_id=0,
            id=mark_id,
            type=mark_type,
            name=mark_name,
            shape=[2, 3, 4])

    model = TestModel().eval()

    # dummy input
    x, y = torch.rand(2, 3, 4), torch.rand(2, 3, 4)

    torch.onnx.export(model, (x, y), output_file)
    onnx_model = onnx.load(output_file)
    nodes = onnx_model.graph.node

    check_mark_node(nodes[0], 0, 'input', 'a')
    check_mark_node(nodes[1], 1, 'input', 'b')
    assert nodes[2].op_type == 'Add'
    check_mark_node(nodes[3], 0, 'output', 'c')
|
2021-07-07 14:15:26 +08:00
|
|
|
|
|
|
|
|
|
|
|
def test_extract():
    """Extract the marked ``add`` sub-graph from an exported ONNX model.

    ``mark`` is given only an output name (``z``). The extracted model is
    asserted to have graph inputs ``x`` and ``y``, graph output ``z``, and an
    ``Add`` as its first node.
    """
    from mmdeploy.apis import extract_model

    @mark('add', outputs='z')
    def add(x, y):
        return torch.add(x, y)

    class TestModel(torch.nn.Module):

        def forward(self, x, y):
            return add(x, y)

    model = TestModel().eval()

    # dummy input
    dummy_inputs = (torch.rand(2, 3, 4), torch.rand(2, 3, 4))

    torch.onnx.export(model, dummy_inputs, output_file)
    onnx_model = onnx.load(output_file)

    # Cut between the start and end marks emitted for the 'add' function.
    extracted = extract_model(onnx_model, 'add:input', 'add:output')

    graph = extracted.graph
    assert graph.input[0].name == 'x'
    assert graph.input[1].name == 'y'
    assert graph.output[0].name == 'z'
    assert graph.node[0].op_type == 'Add'
|