# mmdeploy/tests/test_core/test_mark.py

# Copyright (c) OpenMMLab. All rights reserved.
import tempfile

import onnx
import torch

from mmdeploy.core import mark
from mmdeploy.core.optimizers import attribute_to_dict

output_file = tempfile.NamedTemporaryFile(suffix='.onnx').name


def test_mark():

    # Wrap `add` so that exporting it to ONNX surrounds the op with Mark
    # nodes for its inputs 'a', 'b' and its output 'c'.
    @mark('add', inputs=['a', 'b'], outputs='c')
    def add(x, y):
        return torch.add(x, y)

    class TestModel(torch.nn.Module):

        def __init__(self):
            super().__init__()

        def forward(self, x, y):
            return add(x, y)

    model = TestModel().eval()

    # dummy input
    x = torch.rand(2, 3, 4)
    y = torch.rand(2, 3, 4)

    torch.onnx.export(model, (x, y), output_file)
    onnx_model = onnx.load(output_file)
    nodes = onnx_model.graph.node

    # Expected node order: Mark(a), Mark(b), Add, Mark(c).
    # dtype=1 is onnx.TensorProto.FLOAT, matching the float32 dummy inputs.
    assert nodes[0].op_type == 'Mark'
    assert nodes[0].domain == 'mmdeploy'
    assert attribute_to_dict(nodes[0].attribute) == dict(
        dtype=1,
        func='add',
        func_id=0,
        id=0,
        type='input',
        name='a',
        shape=[2, 3, 4])

    assert nodes[1].op_type == 'Mark'
    assert nodes[1].domain == 'mmdeploy'
    assert attribute_to_dict(nodes[1].attribute) == dict(
        dtype=1,
        func='add',
        func_id=0,
        id=1,
        type='input',
        name='b',
        shape=[2, 3, 4])

    assert nodes[2].op_type == 'Add'

    assert nodes[3].op_type == 'Mark'
    assert nodes[3].domain == 'mmdeploy'
    assert attribute_to_dict(nodes[3].attribute) == dict(
        dtype=1,
        func='add',
        func_id=0,
        id=0,
        type='output',
        name='c',
        shape=[2, 3, 4])
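The assertions in `test_mark` pin down a specific layout: one `input` Mark per argument (numbered by `id`), then the wrapped `Add`, then one `output` Mark, each carrying `func`, `func_id`, `type`, `name` and `shape` attributes. The pure-Python sketch below models only that ordering and attribute layout so it runs without torch or onnx. Everything in it (`Node`, `sketch_mark`, `shape_of`, `add_lists`) is hypothetical and is not how `mmdeploy.core.mark` is implemented; the real decorator produces nodes with op type `Mark` in the `mmdeploy` domain during `torch.onnx.export`. The sketch also omits `dtype` and fixes `func_id` at 0, so repeated calls are not modelled.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Sequence


@dataclass
class Node:
    """Hypothetical stand-in for an ONNX node: an op type plus its attributes."""
    op_type: str
    attrs: Dict[str, Any] = field(default_factory=dict)


def shape_of(value: Any) -> List[int]:
    """Shape of a nested list, e.g. a 2x3x4 nested list gives [2, 3, 4]."""
    shape: List[int] = []
    while isinstance(value, list):
        shape.append(len(value))
        value = value[0] if value else None
    return shape


def add_lists(x: Any, y: Any) -> Any:
    """Element-wise addition of two equally shaped nested lists."""
    if isinstance(x, list):
        return [add_lists(a, b) for a, b in zip(x, y)]
    return x + y


def sketch_mark(graph: List[Node], func: str, inputs: Sequence[str],
                outputs: str) -> Callable:
    """Hypothetical analogue of a marking decorator (not mmdeploy's code).

    Appends one 'input' Mark per argument, lets the wrapped function record
    its own op, then appends one 'output' Mark. func_id is fixed at 0, so
    only a single call of the wrapped function is modelled.
    """
    def decorator(fn: Callable) -> Callable:
        def wrapper(*args: Any) -> Any:
            for idx, (name, arg) in enumerate(zip(inputs, args)):
                graph.append(Node('Mark', dict(
                    func=func, func_id=0, id=idx, type='input',
                    name=name, shape=shape_of(arg))))
            result = fn(*args)
            graph.append(Node('Mark', dict(
                func=func, func_id=0, id=0, type='output',
                name=outputs, shape=shape_of(result))))
            return result
        return wrapper
    return decorator


graph: List[Node] = []


@sketch_mark(graph, 'add', inputs=['a', 'b'], outputs='c')
def add(x, y):
    graph.append(Node('Add'))  # record the wrapped op itself
    return add_lists(x, y)


# 2x3x4 dummy inputs, mirroring torch.rand(2, 3, 4) in the test.
x = [[[1.0] * 4 for _ in range(3)] for _ in range(2)]
y = [[[0.5] * 4 for _ in range(3)] for _ in range(2)]
z = add(x, y)
print([node.op_type for node in graph])  # ['Mark', 'Mark', 'Add', 'Mark']
```

In this sketch the decorator only sees call boundaries, so the wrapped op always lands between the input and output Marks, which is the same ordering the `nodes[0]` to `nodes[3]` assertions check on the exported graph.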