mmdeploy/tests/test_core/test_mark.py
q.yao 2b98040b26
[Feature] Refactor v1 (#56)
* [Refactor] add enum class and use functions to get configuration (#40)

* add task and codebase enum class

* use functions to get config

* Refactor wrappers of mmcls and mmseg (#41)

* move wrappers of cls & det to apis

* remove get_classes_from_config

* rename onnx_helper to onnx_utils

* move import to outside of class

* refactor ortwrappers

* Refactor build dataset and dataloader for mmseg (#44)

* refactor build_dataset and build_dataloader for mmcls and mmseg

* remove repeated classes

* set build_dataloader with shuffle=False

* [Refactor] pplwrapper and mmocr refactor (#46)

* add

* add pplwrapper and refactor mmocr

* fix lint

* remove unused arguments

* apply dict input for pplwrapper and ortwrapper

* add condition before import ppl and ort stuff

* update ppl (#51)

* Refactor return value and extract_model (#54)

* remove ret_value

* refactor extract_model

* fix typo

* resolve comments

* [Refactor] Refactor model inference pipeline (#52)

* move attribute_to_dict to extract_model

* simplify the inference and visualization

* remove unused import

* [Feature] Support SRCNN in mmedit with ONNXRuntime and TensorRT (#45)

* finish mmedit-ort

* edit __init__ files

* add noqa

* add tensorrt support

* 1. Rename "base.py"
2. Move srcnn.py to correct directory

* fix bugs

* remove figures

* align to refactor-v1

* update comment in srcnn

* fix lint

* newfunc -> new_func

* Add visualize.py

split visualize() in each codebase

* fix lint

* fix lint

* remove unnecessary code in ORTRestorer

* remove .api

* edit super(), remove dataset

* [Refactor]: Change name of split to partition (#57)

* refactor mmcls configs

* refactor mmdet configs and split params

* rename rest split to partition from master

* remove base.py

* fix init of inference class

* fix mmocr init, add show_result alias

Co-authored-by: AllentDan <41138331+AllentDan@users.noreply.github.com>
Co-authored-by: RunningLeon <maningsheng@sensetime.com>
Co-authored-by: Yifan Zhou <singlezombie@163.com>
2021-09-01 15:43:49 +08:00


import os

import onnx
import pytest
import torch

from mmdeploy.core import mark
from mmdeploy.core.optimizers import attribute_to_dict

output_file = 'test_mark.onnx'


@pytest.fixture(autouse=True)
def clear_work_dir_after_test():
    # clear tmp output before test
    if os.path.exists(output_file):
        os.remove(output_file)
    yield
    # clear tmp output after test
    if os.path.exists(output_file):
        os.remove(output_file)


def test_mark():
    # mark `add` so its inputs 'a', 'b' and output 'c' are tagged in the graph
    @mark('add', inputs=['a', 'b'], outputs='c')
    def add(x, y):
        return torch.add(x, y)

    class TestModel(torch.nn.Module):

        def __init__(self):
            super().__init__()

        def forward(self, x, y):
            return add(x, y)

    model = TestModel().eval()

    # dummy input
    x = torch.rand(2, 3, 4)
    y = torch.rand(2, 3, 4)

    torch.onnx.export(model, (x, y), output_file)
    onnx_model = onnx.load(output_file)
    nodes = onnx_model.graph.node

    # expected node order: Mark(a), Mark(b), Add, Mark(c)
    assert nodes[0].op_type == 'Mark'
    assert nodes[0].domain == 'mmcv'
    assert attribute_to_dict(nodes[0].attribute) == dict(
        func='add', id=0, type='input', name='a')

    assert nodes[1].op_type == 'Mark'
    assert nodes[1].domain == 'mmcv'
    assert attribute_to_dict(nodes[1].attribute) == dict(
        func='add', id=1, type='input', name='b')

    assert nodes[2].op_type == 'Add'

    assert nodes[3].op_type == 'Mark'
    assert nodes[3].domain == 'mmcv'
    assert attribute_to_dict(nodes[3].attribute) == dict(
        func='add', id=0, type='output', name='c')


def test_extract():
    from mmdeploy.apis import extract_model

    # only the output name is given; input names are left unspecified
    @mark('add', outputs='z')
    def add(x, y):
        return torch.add(x, y)

    class TestModel(torch.nn.Module):

        def __init__(self):
            super().__init__()

        def forward(self, x, y):
            return add(x, y)

    model = TestModel().eval()

    # dummy input
    x = torch.rand(2, 3, 4)
    y = torch.rand(2, 3, 4)

    torch.onnx.export(model, (x, y), output_file)
    onnx_model = onnx.load(output_file)

    # cut the graph between the input marks and output marks of `add`
    extracted = extract_model(onnx_model, 'add:input', 'add:output')
    assert extracted.graph.input[0].name == 'x'
    assert extracted.graph.input[1].name == 'y'
    assert extracted.graph.output[0].name == 'z'
    assert extracted.graph.node[0].op_type == 'Add'
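The assertions in `test_mark` describe what `mark` is expected to leave in the exported graph: one `Mark` node per named input (ids 0 and 1), the wrapped `Add`, then one `Mark` node for the output (id 0), each carrying `func`, `id`, `type` and `name` attributes. The sketch below reproduces only that bookkeeping in plain Python and assumes nothing about mmdeploy internals: `trace_marks` is a hypothetical decorator that appends `(func, id, type, name)` tuples to a caller-supplied list instead of emitting ONNX `Mark` ops. Falling back to the wrapped function's parameter names when `inputs` is omitted is a choice made for this sketch only.

```python
import functools
import inspect
from typing import Callable, List, Optional, Sequence, Tuple, Union

# (func, id, type, name): the same fields test_mark reads back via attribute_to_dict
Record = Tuple[str, int, str, str]


def trace_marks(func_name: str,
                inputs: Optional[Sequence[str]] = None,
                outputs: Union[str, Sequence[str], None] = None,
                *,
                log: List[Record]) -> Callable:
    """Hypothetical decorator: log boundary markers around each call.

    Not mmdeploy's `mark`; it records tuples in `log` instead of
    inserting ONNX nodes into an exported graph.
    """

    def decorator(fn: Callable) -> Callable:
        # Sketch assumption: unnamed inputs fall back to the parameter names.
        in_names = (list(inputs) if inputs is not None else list(
            inspect.signature(fn).parameters))
        # A single output name may be passed as a plain string.
        out_names = [outputs] if isinstance(outputs, str) else list(outputs or [])

        @functools.wraps(fn)
        def wrapper(*args):
            # Input markers are numbered from 0 in argument order.
            for i, name in enumerate(in_names[:len(args)]):
                log.append((func_name, i, 'input', name))
            result = fn(*args)
            # Treat a tuple result as several outputs, anything else as one.
            results = result if isinstance(result, tuple) else (result,)
            # Output markers restart their numbering at 0.
            for i, name in enumerate(out_names[:len(results)]):
                log.append((func_name, i, 'output', name))
            return result

        return wrapper

    return decorator


if __name__ == '__main__':
    records: List[Record] = []

    @trace_marks('add', inputs=['a', 'b'], outputs='c', log=records)
    def add(x, y):
        return x + y

    add(1, 2)
    print(records)
    # [('add', 0, 'input', 'a'), ('add', 1, 'input', 'b'), ('add', 0, 'output', 'c')]
```

The usage mirrors the first half of `test_mark`: input ids count up in argument order while output ids restart at 0, which matches the `id=0` asserted for the output mark `c`.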
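`test_extract` checks that cutting the exported graph at the `add` marks leaves a single `Add` node. The core of such a cut is a backward walk from the requested output tensors that stops at the requested input tensors. The sketch below shows that walk on a toy node list; `Node`, `extract_subgraph` and the tensor names `m0`, `m1`, `s` are hypothetical, and it does not reproduce how `extract_model` resolves `'add:input'`/`'add:output'`, renames graph inputs and outputs, or handles ONNX protobufs.

```python
from typing import Dict, List, NamedTuple, Sequence, Set


class Node(NamedTuple):
    op_type: str
    inputs: List[str]
    outputs: List[str]


def extract_subgraph(nodes: Sequence[Node], input_names: Sequence[str],
                     output_names: Sequence[str]) -> List[Node]:
    """Return the nodes needed to compute `output_names` from `input_names`.

    Nodes are assumed to be topologically sorted, as in an ONNX graph;
    the result keeps their original order.
    """
    # Map every tensor name to the index of the node that produces it.
    producer: Dict[str, int] = {}
    for idx, node in enumerate(nodes):
        for out in node.outputs:
            producer[out] = idx

    boundary: Set[str] = set(input_names)
    visited: Set[str] = set()
    keep: Set[int] = set()
    stack = list(output_names)
    while stack:
        tensor = stack.pop()
        # Stop at the requested inputs and at tensors already handled.
        if tensor in boundary or tensor in visited:
            continue
        visited.add(tensor)
        idx = producer.get(tensor)
        if idx is None:
            # Not produced by any node: a graph input or constant outside the cut.
            continue
        keep.add(idx)
        stack.extend(nodes[idx].inputs)
    return [nodes[i] for i in sorted(keep)]


# Toy graph following the node order asserted in test_mark: marks around an Add.
graph = [
    Node('Mark', ['x'], ['m0']),
    Node('Mark', ['y'], ['m1']),
    Node('Add', ['m0', 'm1'], ['s']),
    Node('Mark', ['s'], ['z']),
]
print([n.op_type for n in extract_subgraph(graph, ['m0', 'm1'], ['s'])])
# ['Add']
```

Walking backward from the outputs, rather than forward from the inputs, means branches that never feed the requested outputs are dropped automatically.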