mmdeploy/tests/test_pytorch/test_pytorch_ops.py
q.yao 3a785f1223
[Refactor] Refactor codebase (#220)
* [WIP] Refactor v2.0 (#163)

* Refactor backend wrapper

* Refactor mmdet.inference

* Fix

* merge

* refactor utils

* Use deployer and deploy_model to manage pipeline

* Resolve comments

* Add a real inference api function

* rename wrappers

* Set execute to private method

* Rename deployer deploy_model

* Refactor task

* remove type hint

* lint

* Resolve comments

* resolve comments

* lint

* docstring

* [Fix]: Fix bugs in details in refactor branch (#192)

* [WIP] Refactor v2.0 (#163)

* Refactor backend wrapper

* Refactor mmdet.inference

* Fix

* merge

* refactor utils

* Use deployer and deploy_model to manage pipeline

* Resolve comments

* Add a real inference api function

* rename wrappers

* Set execute to private method

* Rename deployer deploy_model

* Refactor task

* remove type hint

* lint

* Resolve comments

* resolve comments

* lint

* docstring

* Fix errors

* lint

* resolve comments

* fix bugs

* conflict

* lint and typo

* Resolve comment

* refactor mmseg (#201)

* support mmseg

* fix docstring

* fix docstring

* [Refactor]: Get the count of backend files (#202)

* Fix backend files

* resolve comments

* lint

* Fix ncnn

* [Refactor]: Refactor folders of mmdet (#200)

* Move folders

* lint

* test object detection model

* lint

* reset changes

* fix openvino

* resolve comments

* __init__.py

* Fix path

* [Refactor]: move mmseg (#206)

* [Refactor]: Refactor mmedit (#205)

* feature mmedit

* edit2.0

* edit

* refactor mmedit

* fix __init__.py

* fix __init__

* fix format

* fix comment

* fix comment

* Fix wrong func_name of ConvFCBBoxHead (#209)

* [Refactor]: Refactor mmdet unit test (#207)

* Move folders

* lint

* test object detection model

* lint

* WIP

* remove print

* finish unit test

* Fix tests

* resolve comments

* Add mask test

* lint

* resolve comments

* Refine cfg file

* Move files

* add files

* Fix path

* [Unittest]: Refine the unit tests in mmdet #214

* [Refactor] refactor mmocr to mmdeploy/codebase (#213)

* refactor mmocr to mmdeploy/codebase

* fix docstring of show_result

* fix docstring of visualize

* refine docstring

* replace print with logging

* refine code

* resolve comments

* resolve comments

* [Refactor]: mmseg tests (#210)

* refactor mmseg tests

* rename test_codebase

* update

* add model.py

* fix

* [Refactor] Refactor mmcls and the package (#217)

* refactor mmcls

* fix yapf

* fix isort

* refactor-mmcls-package

* fix print to logging

* fix docstrings according to others comments

* fix comments

* fix comments

* fix AllentDan's comment in PR #215

* remove mmocr init

* [Refactor] Refactor mmedit tests (#212)

* feature mmedit

* edit2.0

* edit

* refactor mmedit

* fix __init__.py

* fix __init__

* fix format

* fix comment

* fix comment

* buff

* edit test and code refactor

* refactor dir

* refactor tests/mmedit

* fix docstring

* add test coverage

* fix lint

* fix comment

* fix comment

* Update typehint (#216)

* update type hint

* update docstring

* update

* remove file

* fix ppl

* Refine get_predefined_partition_cfg

* fix tensorrt version > 8

* move parse_cuda_device_id to device.py

* Fix cascade

* onnx2ncnn docstring

Co-authored-by: Yifan Zhou <singlezombie@163.com>
Co-authored-by: RunningLeon <maningsheng@sensetime.com>
Co-authored-by: VVsssssk <88368822+VVsssssk@users.noreply.github.com>
Co-authored-by: AllentDan <41138331+AllentDan@users.noreply.github.com>
Co-authored-by: hanrui1sensetime <83800577+hanrui1sensetime@users.noreply.github.com>
2021-11-25 09:57:05 +08:00

114 lines
3.4 KiB
Python

import tempfile
import onnx
import pytest
import torch
from mmdeploy.core import RewriterContext
# Scratch path shared by every ONNX export in this module. Only the generated
# name is kept: the NamedTemporaryFile object itself is discarded right away
# (with the default delete=True it is closed and removed when collected).
# The suffix needs its leading dot -- suffix='onnx' would produce a name such
# as 'tmpab12onnx' rather than 'tmpab12.onnx'.
onnx_file = tempfile.NamedTemporaryFile(suffix='.onnx').name
@pytest.fixture(autouse=True, scope='module')
def prepare_symbolics():
    """Hold a ``RewriterContext`` open for the whole test module.

    The context is built with an empty dict, the ``'tensorrt'`` backend and
    ``opset=11``. Because the fixture is ``autouse`` with ``scope='module'``,
    it is entered once before the first test of this module runs and exited
    after the last one finishes, so the exports done by the tests happen
    while it is active.
    """
    rewriter_ctx = RewriterContext({}, 'tensorrt', opset=11)
    rewriter_ctx.enter()
    yield
    rewriter_ctx.exit()
class OpModel(torch.nn.Module):
    """Adapt a plain callable into an ``nn.Module`` for ONNX export.

    Calling ``forward(x)`` evaluates ``func(x, *args)``, where ``args`` are
    the extra positional arguments captured when the wrapper was created.

    Args:
        func (Callable): Operator under test; it receives the input first.
        *args: Additional positional arguments passed after the input.
    """

    def __init__(self, func, *args):
        super().__init__()
        self._func = func
        self._arg_tuple = args

    def forward(self, x):
        # Prepend the traced input to the frozen extra arguments.
        call_args = (x, ) + self._arg_tuple
        return self._func(*call_args)
def get_model_onnx_nodes(model, x, onnx_file=onnx_file):
    """Export ``model`` to ONNX with opset 11 and return its graph nodes.

    Args:
        model (torch.nn.Module): Module to trace.
        x: Example input passed to ``torch.onnx.export`` for tracing.
        onnx_file (str): File the model is written to and read back from;
            defaults to the module-level scratch path.

    Returns:
        The ``graph.node`` field of the re-loaded ONNX model.
    """
    torch.onnx.export(model, x, onnx_file, opset_version=11)
    return onnx.load(onnx_file).graph.node
class TestAdaptivePool:
    """Check the first exported op of ``adaptive_avg_pool{1,2,3}d``.

    An output size of all ones is expected to export as ``GlobalAveragePool``.
    Any other output size is expected to export as ``AveragePool``.

    NOTE(review): the 2d cases feed a 3-D tensor and the 3d cases a 4-D
    tensor, which PyTorch treats as unbatched input. ONNX pooling ops
    presumably read such tensors with one spatial dim fewer than the op name
    suggests. The assertions only inspect ``op_type``, so confirm whether
    batched inputs were intended.
    """

    @staticmethod
    def _first_op_type(pool_fn, output_size, x):
        """Export ``pool_fn(x, output_size)`` and return the first node's op."""
        exported = get_model_onnx_nodes(OpModel(pool_fn, output_size).eval(), x)
        return exported[0].op_type

    def test_adaptive_pool_1d_global(self):
        x = torch.rand(2, 2, 2)
        op_type = self._first_op_type(
            torch.nn.functional.adaptive_avg_pool1d, [1], x)
        assert op_type == 'GlobalAveragePool'

    def test_adaptive_pool_1d(self):
        x = torch.rand(2, 2, 2)
        op_type = self._first_op_type(
            torch.nn.functional.adaptive_avg_pool1d, [2], x)
        assert op_type == 'AveragePool'

    def test_adaptive_pool_2d_global(self):
        x = torch.rand(2, 2, 2)
        op_type = self._first_op_type(
            torch.nn.functional.adaptive_avg_pool2d, [1, 1], x)
        assert op_type == 'GlobalAveragePool'

    def test_adaptive_pool_2d(self):
        x = torch.rand(2, 2, 2)
        op_type = self._first_op_type(
            torch.nn.functional.adaptive_avg_pool2d, [2, 2], x)
        assert op_type == 'AveragePool'

    def test_adaptive_pool_3d_global(self):
        x = torch.rand(2, 2, 2, 2)
        op_type = self._first_op_type(
            torch.nn.functional.adaptive_avg_pool3d, [1, 1, 1], x)
        assert op_type == 'GlobalAveragePool'

    def test_adaptive_pool_3d(self):
        x = torch.rand(2, 2, 2, 2)
        op_type = self._first_op_type(
            torch.nn.functional.adaptive_avg_pool3d, [2, 2, 2], x)
        assert op_type == 'AveragePool'
def test_grid_sampler():
    """``torch.grid_sampler`` should export as the ``mmcv`` custom op."""
    x = torch.rand(1, 1, 2, 2)
    flow = torch.zeros([1, 2, 2, 2])
    # Trailing args are forwarded to torch.grid_sampler unchanged; presumably
    # (interpolation_mode, padding_mode, align_corners) -- confirm in torch.
    model = OpModel(torch.grid_sampler, flow, 0, 0, False).eval()
    # The node at index 1 is the one expected to be the grid sampler.
    grid_node = get_model_onnx_nodes(model, x)[1]
    assert grid_node.op_type == 'grid_sampler'
    assert grid_node.domain == 'mmcv'
def test_instance_norm():
    """``torch.group_norm`` should export as mmcv's TRTInstanceNormalization."""
    x = torch.rand(1, 2, 2, 2)
    weight, bias = torch.rand([2]), torch.rand([2])
    # Positional args after the input: num_groups=1, weight, bias, eps.
    model = OpModel(torch.group_norm, 1, weight, bias, 1e-05).eval()
    norm_node = get_model_onnx_nodes(model, x)[4]
    assert norm_node.op_type == 'TRTInstanceNormalization'
    assert norm_node.domain == 'mmcv'
class TestSqueeze:
    """Check that ``torch.squeeze`` exports as ``Squeeze`` with explicit axes."""

    def test_squeeze_default(self):
        x = torch.rand(1, 1, 2, 2)
        # No ``dim``: every size-1 axis of x (axes 0 and 1) is squeezed.
        # ``.eval()`` keeps this model consistent with the other tests.
        model = OpModel(torch.squeeze).eval()
        nodes = get_model_onnx_nodes(model, x)
        # Check the op type first so an unexpected node fails with a clear
        # message rather than an IndexError on ``attribute``.
        assert nodes[0].op_type == 'Squeeze'
        assert nodes[0].attribute[0].ints == [0, 1]

    def test_squeeze(self):
        x = torch.rand(1, 1, 2, 2)
        # Explicit ``dim=0``: only axis 0 is squeezed.
        model = OpModel(torch.squeeze, 0).eval()
        nodes = get_model_onnx_nodes(model, x)
        assert nodes[0].op_type == 'Squeeze'
        assert nodes[0].attribute[0].ints == [0]