mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
* [WIP] Refactor v2.0 (#163) * Refactor backend wrapper * Refactor mmdet.inference * Fix * merge * refactor utils * Use deployer and deploy_model to manage pipeline * Resolve comments * Add a real inference api function * rename wrappers * Set execute to private method * Rename deployer deploy_model * Refactor task * remove type hint * lint * Resolve comments * resolve comments * lint * docstring * [Fix]: Fix bugs in details in refactor branch (#192) * [WIP] Refactor v2.0 (#163) * Refactor backend wrapper * Refactor mmdet.inference * Fix * merge * refactor utils * Use deployer and deploy_model to manage pipeline * Resolve comments * Add a real inference api function * rename wrappers * Set execute to private method * Rename deployer deploy_model * Refactor task * remove type hint * lint * Resolve comments * resolve comments * lint * docstring * Fix errors * lint * resolve comments * fix bugs * conflict * lint and typo * Resolve comment * refactor mmseg (#201) * support mmseg * fix docstring * fix docstring * [Refactor]: Get the count of backend files (#202) * Fix backend files * resolve comments * lint * Fix ncnn * [Refactor]: Refactor folders of mmdet (#200) * Move folders * lint * test object detection model * lint * reset changes * fix openvino * resolve comments * __init__.py * Fix path * [Refactor]: move mmseg (#206) * [Refactor]: Refactor mmedit (#205) * feature mmedit * edit2.0 * edit * refactor mmedit * fix __init__.py * fix __init__ * fix formai * fix comment * fix comment * Fix wrong func_name of ConvFCBBoxHead (#209) * [Refactor]: Refactor mmdet unit test (#207) * Move folders * lint * test object detection model * lint * WIP * remove print * finish unit test * Fix tests * resolve comments * Add mask test * lint * resolve comments * Refine cfg file * Move files * add files * Fix path * [Unittest]: Refine the unit tests in mmdet #214 * [Refactor] refactor mmocr to mmdeploy/codebase (#213) * refactor mmocr to mmdeploy/codebase * fix docstring of 
show_result * fix docstring of visualize * refine docstring * replace print with logging * refince codes * resolve comments * resolve comments * [Refactor]: mmseg tests (#210) * refactor mmseg tests * rename test_codebase * update * add model.py * fix * [Refactor] Refactor mmcls and the package (#217) * refactor mmcls * fix yapf * fix isort * refactor-mmcls-package * fix print to logging * fix docstrings according to others comments * fix comments * fix comments * fix allentdans comment in pr215 * remove mmocr init * [Refactor] Refactor mmedit tests (#212) * feature mmedit * edit2.0 * edit * refactor mmedit * fix __init__.py * fix __init__ * fix formai * fix comment * fix comment * buff * edit test and code refactor * refactor dir * refactor tests/mmedit * fix docstring * add test coverage * fix lint * fix comment * fix comment * Update typehint (#216) * update type hint * update docstring * update * remove file * fix ppl * Refine get_predefined_partition_cfg * fix tensorrt version > 8 * move parse_cuda_device_id to device.py * Fix cascade * onnx2ncnn docstring Co-authored-by: Yifan Zhou <singlezombie@163.com> Co-authored-by: RunningLeon <maningsheng@sensetime.com> Co-authored-by: VVsssssk <88368822+VVsssssk@users.noreply.github.com> Co-authored-by: AllentDan <41138331+AllentDan@users.noreply.github.com> Co-authored-by: hanrui1sensetime <83800577+hanrui1sensetime@users.noreply.github.com>
114 lines
3.4 KiB
Python
114 lines
3.4 KiB
Python
import tempfile
|
|
|
|
import onnx
|
|
import pytest
|
|
import torch
|
|
|
|
from mmdeploy.core import RewriterContext
|
|
|
|
# Scratch path shared by every ONNX export in this module. Only the unique
# name of the NamedTemporaryFile is kept; the file object itself is discarded
# immediately, and torch.onnx.export later writes the model to this path.
# ``tempfile`` appends ``suffix`` verbatim (it does not insert a dot), so the
# dot must be part of the suffix for the path to end in ``.onnx``.
onnx_file = tempfile.NamedTemporaryFile(suffix='.onnx').name
|
|
|
|
|
|
@pytest.fixture(autouse=True, scope='module')
def prepare_symbolics():
    """Keep a TensorRT ``RewriterContext`` active for this whole module.

    The context (empty config, ``'tensorrt'`` backend, opset 11) is entered
    once before the first test in the module and exited after the last one.
    Presumably this is what makes the exports below pick up mmdeploy's
    backend-specific symbolic rewrites — confirm against RewriterContext.
    """
    rewriter_ctx = RewriterContext({}, 'tensorrt', opset=11)
    rewriter_ctx.enter()
    # Hand control to the tests; teardown resumes here afterwards.
    yield
    rewriter_ctx.exit()
|
|
|
|
|
|
class OpModel(torch.nn.Module):
    """Wrap a plain callable as an ``nn.Module`` so it can be exported.

    Args:
        func (Callable): Called in ``forward`` with the module input as its
            first positional argument.
        *args: Extra positional arguments passed to ``func`` after the
            input on every call.
    """

    def __init__(self, func, *args):
        super().__init__()
        self._func = func
        self._arg_tuple = args

    def forward(self, x):
        """Return ``func(x, *args)`` using the stored callable and args."""
        fn = self._func
        extra_args = self._arg_tuple
        return fn(x, *extra_args)
|
|
|
|
|
|
def get_model_onnx_nodes(model, x, onnx_file=onnx_file):
    """Export ``model`` to ONNX (opset 11) and return the graph's nodes.

    Args:
        model (torch.nn.Module): Module to export.
        x (torch.Tensor): Example input used for the export trace.
        onnx_file (str): Path the model is written to and then read back
            from. Defaults to the module-level scratch path.

    Returns:
        The ``graph.node`` container of the reloaded ONNX model.
    """
    torch.onnx.export(model, x, onnx_file, opset_version=11)
    exported = onnx.load(onnx_file)
    return exported.graph.node
|
|
|
|
|
|
class TestAdaptivePool:
    """Check the ONNX op that ``adaptive_avg_pool{1,2,3}d`` exports to.

    An output size of all ones is expected to export as
    ``GlobalAveragePool``; the other sizes used here as ``AveragePool``.
    Only the first node of each exported graph is inspected.
    """

    @staticmethod
    def _first_op_type(pool_fn, output_size, input_shape):
        """Export ``pool_fn(x, output_size)`` and return node 0's op type."""
        inputs = torch.rand(*input_shape)
        wrapped = OpModel(pool_fn, output_size).eval()
        return get_model_onnx_nodes(wrapped, inputs)[0].op_type

    def test_adaptive_pool_1d_global(self):
        op_type = self._first_op_type(
            torch.nn.functional.adaptive_avg_pool1d, [1], (2, 2, 2))
        assert op_type == 'GlobalAveragePool'

    def test_adaptive_pool_1d(self):
        op_type = self._first_op_type(
            torch.nn.functional.adaptive_avg_pool1d, [2], (2, 2, 2))
        assert op_type == 'AveragePool'

    def test_adaptive_pool_2d_global(self):
        op_type = self._first_op_type(
            torch.nn.functional.adaptive_avg_pool2d, [1, 1], (2, 2, 2))
        assert op_type == 'GlobalAveragePool'

    def test_adaptive_pool_2d(self):
        op_type = self._first_op_type(
            torch.nn.functional.adaptive_avg_pool2d, [2, 2], (2, 2, 2))
        assert op_type == 'AveragePool'

    def test_adaptive_pool_3d_global(self):
        op_type = self._first_op_type(
            torch.nn.functional.adaptive_avg_pool3d, [1, 1, 1],
            (2, 2, 2, 2))
        assert op_type == 'GlobalAveragePool'

    def test_adaptive_pool_3d(self):
        op_type = self._first_op_type(
            torch.nn.functional.adaptive_avg_pool3d, [2, 2, 2],
            (2, 2, 2, 2))
        assert op_type == 'AveragePool'
|
|
|
|
|
|
def test_grid_sampler():
    """``torch.grid_sampler`` should export to the ``mmcv::grid_sampler`` op.

    The node checked is the one at index 1 of the exported graph.
    """
    inp = torch.rand(1, 1, 2, 2)
    grid = torch.zeros([1, 2, 2, 2])
    # Positional args after the input: grid, interpolation_mode,
    # padding_mode, align_corners.
    sampler = OpModel(torch.grid_sampler, grid, 0, 0, False).eval()
    target = get_model_onnx_nodes(sampler, inp)[1]
    assert target.op_type == 'grid_sampler'
    assert target.domain == 'mmcv'
|
|
|
|
|
|
def test_instance_norm():
    """Single-group ``torch.group_norm`` should export as a custom op.

    Expects ``mmcv::TRTInstanceNormalization`` at index 4 of the exported
    graph.
    """
    inp = torch.rand(1, 2, 2, 2)
    # Evaluated left to right: weight first, then bias.
    weight, bias = torch.rand([2]), torch.rand([2])
    # Positional args after the input: num_groups, weight, bias, eps.
    norm = OpModel(torch.group_norm, 1, weight, bias, 1e-05).eval()
    target = get_model_onnx_nodes(norm, inp)[4]
    assert target.op_type == 'TRTInstanceNormalization'
    assert target.domain == 'mmcv'
|
|
|
|
|
|
class TestSqueeze:
    """Check how ``torch.squeeze`` exports at opset 11.

    The export must be a ``Squeeze`` node whose first attribute (``axes`` at
    opset 11) lists the squeezed dimensions explicitly.
    """

    def test_squeeze_default(self):
        # No dim given: both leading size-1 dims are expected in the axes.
        x = torch.rand(1, 1, 2, 2)
        # .eval() for consistency with every other OpModel in this module.
        model = OpModel(torch.squeeze).eval()
        nodes = get_model_onnx_nodes(model, x)
        # list() so the check does not depend on how the protobuf
        # repeated-field container implements equality with a list.
        assert list(nodes[0].attribute[0].ints) == [0, 1]
        assert nodes[0].op_type == 'Squeeze'

    def test_squeeze(self):
        # Explicit dim=0: only that axis is expected.
        x = torch.rand(1, 1, 2, 2)
        model = OpModel(torch.squeeze, 0).eval()
        nodes = get_model_onnx_nodes(model, x)
        assert list(nodes[0].attribute[0].ints) == [0]
        assert nodes[0].op_type == 'Squeeze'
|