mmdeploy/tests/test_ops/test_ops.py
RunningLeon 4d8ea40f55
Sync v0.7.0 to dev-1.x (#907)
* make -install -> make install (#621)

change `make -install` to `make install`

https://github.com/open-mmlab/mmdeploy/issues/618

* [Fix] fix csharp api detector release result (#620)

* fix csharp api detector release result

* fix wrong count arg of xxx_release_result in c# api

* [Enhancement] Support two-stage rotated detector TensorRT. (#530)

* upload

* add fake_multiclass_nms_rotated

* delete unused code

* align with pytorch

* Update delta_midpointoffset_rbbox_coder.py

* add trt rotated roi align

* add index feature in nms

* not good

* fix index

* add ut

* add benchmark

* move to csrc/mmdeploy

* update unit test

Co-authored-by: zytx121 <592267829@qq.com>

* Reduce mmcls version dependency (#635)

* fix shufflenetv2 with trt (#645)

* fix shufflenetv2 and pspnet

* fix ci

* remove print

* ' -> " (#654)

If a string contains a variable, single quotes leave it unexpanded, while double quotes substitute the variable's value into the string when it is parsed.

* ' -> " (#655)

same with https://github.com/open-mmlab/mmdeploy/pull/654

* Support deployment of Segmenter (#587)

* support segmentor with ncnn

* update regression yml

* replace chunk with split to support ts

* update regression yml

* update docs

* fix segmenter ncnn inference failure brought by #477

* add test

* fix test for ncnn and trt

* fix lint

* export nn.linear to Gemm op in onnx for ncnn

* fix ci

* simplify `Expand` (#617)

* Fix typo (#625)

* Add make install in en docs

* Add make install in zh docs

* Fix typo

* Merge and add windows build

Co-authored-by: tripleMu <865626@163.com>

* [Enhancement] Fix ncnn unittest (#626)

* optimize-csp-darknet

* replace floordiv with torch.div

* update csp_darknet default implement

* fix test

* [Enhancement] TensorRT Anchor generator plugin (#646)

* custom trt anchor generator

* add ut

* add docstring, update doc

* Add partition doc and sample code (#599)

* update torch2onnx tool to support onnx partition

* add model partition of yolov3

* add cn doc

* update torch2onnx tool to support onnx partition

* add model partition of yolov3

* add cn doc

* add to index.rst

* resolve comment

* resolve comments

* fix lint

* change caption level in docs

* update docs (#624)

* Add java apis and demos (#563)

* add java classifier detector

* add segmentor

* fix lint

* add ImageRestorer java apis and demo

* remove useless count parameter for Segmentor and Restorer, add PoseDetector

* add RotatedDetection java api and demo

* add Ocr java demo and apis

* remove mmrotate ncnn java api and demo

* fix lint

* sync java api folder after rebase to master

* fix include

* remove record

* fix java apis dir path in cmake

* add java demo readme

* fix lint mdformat

* add test javaapi ci

* fix lint

* fix flake8

* fix test javaapi ci

* refactor readme.md

* fix install opencv for ci

* fix install opencv : add permission

* add all codebases and mmcv install

* add torch

* install mmdeploy

* fix image path

* fix picture path

* fix import ncnn

* fix import ncnn

* add submodule of pybind

* fix pybind submodule

* change download to git clone for submodule

* fix ncnn dir

* fix README error

* simplify the github ci

* fix ci

* fix yapf

* add JNI as required

* fix Capitalize

* fix Capitalize

* fix copyright

* ignore .class changed

* add OpenJDK installation docs

* install target of javaapi

* simplify ci

* add jar

* fix ci

* fix ci

* fix test java command

* debugging what failed

* debugging what failed

* debugging what failed

* add java version info

* install openjdk

* add java env var

* fix export

* fix export

* fix export

* fix export

* fix picture path

* fix picture path

* fix file name

* fix file name

* fix README

* remove java_api strategy

* fix python version

* format task name

* move args position

* extract common utils code

* show image class result

* add detector result

* segmentation result format

* add ImageRestorer result

* add PoseDetection java result format

* fix ci

* stage ocr

* add visualize

* move utils

* fix lint

* fix ocr bugs

* fix ci demo

* fix java classpath for ci

* fix popd

* fix ocr demo text garbled

* fix ci

* fix ci

* fix ci

* fix path of utils ci

* update the circleci config file by adding workflows both for linux, windows and linux-gpu (#368)

* update circleci by adding more workflows

* fix test workflow failure on windows platform

* fix docker exec command for SDK unittests

* Fixed tensorrt plugin not found in Windows (#672)

* update introduction.png (#674)

* [Enhancement] Add fuse select assign pass (#589)

* Add fuse select assign pass

* move code to csrc

* add config flag

* remove bool cast

* fix export sdk info of input shape (#667)

* Update get_started.md (#675)

Fix backend model assignment

* Update get_started.md (#676)

Fix backend model assignment

* [Fix] fix clang build (#677)

* fix clang build

* fix ndk build

* fix ndk build

* switch to `std::filesystem` for clang-7 and later

* Deploy the Swin Transformer on TensorRT. (#652)

* resolve conflicts

* update ut and docs

* fix ut

* refine docstring

* add comments and refine UT

* resolve comments

* resolve comments

* update doc

* add roll export

* check backend

* update regression test

* bump version to 0.6.0 (#680)

* bump version to 0.6.0

* update version

* pass img_metas while exporting to onnx (#681)

* pass img_metas while exporting to onnx

* remove try-catch in tools for better debugging

* use get

* fix typo

* [Fix] fix ssd ncnn ut (#692)

* fix ssd ncnn ut

* fix yapf

* fix passing img_metas to pytorch2onnx for mmedit (#700)

* fix passing img_metas for mmdet3d (#707)

* [Fix] Fix android build (#698)

* fix android build

* fix cmake

* fix url link

* fix wrong exit code in pipeline_manager (#715)

* fix exit

* change to general exit errorcode=1

* fix passing wrong backend type (#719)

* Rename onnx2ncnn to mmdeploy_onnx2ncnn (#694)

* improvement(tools/onnx2ncnn.py): rename to mmdeploy_onnx2ncnn

* format(tools/deploy.py): clean code

* fix(init_plugins.py): improve if condition

* fix(CI): update target

* fix(test_onnx2ncnn.py): update desc

* Update init_plugins.py

* [Fix] Fix mmdet ort static shape bug (#687)

* fix shape

* add device

* fix yapf

* fix rewriter for transforms

* reverse image shape

* fix ut of distance2bbox

* fix rewriter name

* fix c4 for torchscript (#724)

* [Enhancement] Standardize C API (#634)

* unify C API naming

* fix demo and move apis/c/* -> apis/c/mmdeploy/*

* fix lint

* fix C# project

* fix Java API

* [Enhancement] Support Slide Vertex TRT (#650)

* reorganize mmrotate

* fix

* add hbb2obb

* add ut

* fix rotated nms

* update docs

* update benchmark

* update test

* remove ort regression test, remove comment

* Fix get-started rendering issues in readthedocs (#740)

* fix mermaid markdown rendering issue in readthedocs

* fix error in C++ example

* fix error in c++ example in zh_cn get_started doc

* [Fix] set default topk for dump info (#702)

* set default topk for dump info

* remove redundant docstrings

* add ci densenet

* fix classification warnings

* fix mmcls version

* fix logger.warnings

* add version control (#754)

* fix satrn for ORT (#753)

* fix satrn for ORT

* move rewrite into pytorch

* Add inference latency test tool (#665)

* add profile tool

* remove print envs in profile tool

* set cudnn_benchmark to True

* add doc

* update tests

* fix typo

* support test with images from a directory

* update doc

* resolve comments

* [Enhancement] Add CSE ONNX pass (#647)

* Add fuse select assign pass

* move code to csrc

* add config flag

* Add fuse select assign pass

* Add CSE for ONNX

* remove useless code

* Test robot

Just test robot

* Update README.md

Revert

* [Fix] fix yolox point_generator (#758)

* fix yolox point_generator

* add a UT

* resolve comments

* fix comment lines

* limit markdown version (#773)

* [Enhancement] Better index put ONNX export. (#704)

* Add rewriter for tensor setitem

* add version check

* Upgrade Dockerfile to use TensorRT==8.2.4.2 (#706)

* Upgrade TensorRT to 8.2.4.2

* upgrade pytorch&mmcv in CPU Dockerfile

* Delete redundant port example in Docker

* change 160x160-608x608 to 64x64-608x608 for yolov3

* [Fix] reduce log verbosity & improve error reporting (#755)

* reduce log verbosity & improve error reporting

* improve error reporting

* [Enhancement] Support latest ppl.nn & ppl.cv (#564)

* support latest ppl.nn

* fix pplnn for model convertor

* fix lint

* update memory policy

* import algo from buffer

* update ppl.cv

* use `ppl.cv==0.7.0`

* document supported ppl.nn version

* skip pplnn dependency when building shared libs

* [Fix][P0] Fix for torch1.12 (#751)

* fix for torch1.12

* add comment

* fix check env (#785)

* [Fix] fix cascade mask rcnn (#787)

* fix cascade mask rcnn

* fix lint

* add regression

* [Feature] Support RoITransRoIHead (#713)

* [Feature] Support RoITransRoIHead

* Add docs

* Add mmrotate models regression test

* Add a draft for test code

* change the argument name

* fix test code

* fix minor change for not class agnostic case

* fix sample for test code

* fix sample for test code

* Add mmrotate in requirements

* Revert "Add mmrotate in requirements"

This reverts commit 043490075e6dbe4a8fb98e94b2b583b91fc5038d.

* [Fix] fix triu (#792)

* fix triu

* triu -> triu_default

* [Enhancement] Install Optimizer by setuptools (#690)

* Add fuse select assign pass

* move code to csrc

* add config flag

* Add fuse select assign pass

* Add CSE for ONNX

* remove useless code

* Install optimizer by setup tools

* fix comment

* [Feature] support MMRotate model with le135 (#788)

* support MMRotate model with le135

* cse before fuse select assign

* remove unused import

* [Fix] Support macOS build (#762)

* fix macOS build

* fix missing

* add option to build & install examples (#822)

* [Fix] Fix setup on non-linux-x64 (#811)

* fix setup

* replace long to int64_t

* [Feature] support build single sdk library (#806)

* build single lib for c api

* update csharp doc & project

* update test build

* fix test build

* fix

* update document for building android sdk (#817)

Co-authored-by: dwSun <dwsunny@icloud.com>

* [Enhancement] support kwargs in SDK python bindings (#794)

* support-kwargs

* make '__call__' as single image inference and add 'batch' API to deal with batch images inference

* fix linting error and typo

* fix lint

* improvement(sdk): add sdk code coverage (#808)

* feat(doc): add CI

* CI(sdk): add sdk coverage

* style(test): code format

* fix(CI): update coverage.info path

* improvement(CI): use internal image

* improvement(CI): push coverage info once

* [Feature] Add C++ API for SDK (#831)

* add C++ API

* unify result type & add examples

* minor fix

* install cxx API headers

* fix Mat, add more examples

* fix monolithic build & fix lint

* install examples correctly

* fix lint

* feat(tools/deploy.py): support snpe (#789)

* fix(tools/deploy.py): support snpe

* improvement(backend/snpe): review advices

* docs(backend/snpe): update build

* docs(backend/snpe): server support specify port

* docs(backend/snpe): update path

* fix(backend/snpe): time counter missing argument

* docs(backend/snpe): add missing argument

* docs(backend/snpe): update download and using

* improvement(snpe_net.cpp): load model with modeldata

* Support setup on environment with no PyTorch (#843)

* support test with multi batch (#829)

* support test with multi batch

* resolve comment

* import algorithm from buffer (#793)

* [Enhancement] build sdk python api in a standalone manner (#810)

* build sdk python api in a standalone manner

* enable MMDEPLOY_BUILD_SDK_MONOLITHIC and MMDEPLOY_BUILD_EXAMPLES in prebuild config

* link mmdeploy to python target when monolithic option is on

* checkin README to describe precompiled package build procedure

* use packaging.version.parse(python_version) instead of list(python_version)

* fix according to review results

* rebase master

* rollback cmake.in and apis/python/CMakeLists.txt

* reorganize files in install/example

* let cmake detect visual studio instead of specifying 2019

* rename whl name of precompiled package

* fix according to review results

* Fix SDK backend (#844)

* fix mmpose python api (#852)

* add prebuild package usage docs on windows (#816)

* add prebuild package usage docs on windows

* fix lint

* update

* try fix lint

* add en docs

* update

* update

* update faq

* fix typo (#862)

* [Enhancement] Improve get_started documents and bump version to 0.7.0 (#813)

* simplify commands in get_started

* add installation commands for Windows

* fix typo

* limit markdown and sphinx_markdown_tables version

* adopt html <details open> tag

* bump mmdeploy version

* bump mmdeploy version

* update get_started

* update get_started

* use python3.8 instead of python3.7

* remove duplicate section

* resolve issue #856

* update according to review results

* add reference to prebuilt_package_windows.md

* fix error when build sdk demos

* fix mmcls

Co-authored-by: Ryan_Huang <44900829+DrRyanHuang@users.noreply.github.com>
Co-authored-by: Chen Xin <xinchen.tju@gmail.com>
Co-authored-by: q.yao <yaoqian@sensetime.com>
Co-authored-by: zytx121 <592267829@qq.com>
Co-authored-by: Li Zhang <lzhang329@gmail.com>
Co-authored-by: tripleMu <gpu@163.com>
Co-authored-by: tripleMu <865626@163.com>
Co-authored-by: hanrui1sensetime <83800577+hanrui1sensetime@users.noreply.github.com>
Co-authored-by: lvhan028 <lvhan_028@163.com>
Co-authored-by: Bryan Glen Suello <11388006+bgsuello@users.noreply.github.com>
Co-authored-by: zambranohally <63218980+zambranohally@users.noreply.github.com>
Co-authored-by: AllentDan <41138331+AllentDan@users.noreply.github.com>
Co-authored-by: tpoisonooo <khj.application@aliyun.com>
Co-authored-by: Hakjin Lee <nijkah@gmail.com>
Co-authored-by: 孙德伟 <5899962+dwSun@users.noreply.github.com>
Co-authored-by: dwSun <dwsunny@icloud.com>
Co-authored-by: Chen Xin <irexyc@gmail.com>
2022-08-19 09:30:13 +08:00


# Copyright (c) OpenMMLab. All rights reserved.
import onnx
import pytest
import torch
import torch.nn as nn
from mmcv import Config
from onnx.helper import (make_graph, make_model, make_node,
                         make_tensor_value_info)

from mmdeploy.core import RewriterContext
from mmdeploy.utils.test import WrapFunction, assert_allclose
from .utils import TestNCNNExporter, TestOnnxRTExporter, TestTensorRTExporter

TEST_ONNXRT = TestOnnxRTExporter()
TEST_TENSORRT = TestTensorRTExporter()
TEST_NCNN = TestNCNNExporter()


@pytest.mark.parametrize('backend', [TEST_TENSORRT])
@pytest.mark.parametrize('pool_h,pool_w,spatial_scale,sampling_ratio',
                         [(2, 2, 1.0, 2), (4, 4, 2.0, 4)])
def test_roi_align(backend,
                   pool_h,
                   pool_w,
                   spatial_scale,
                   sampling_ratio,
                   input_list=None,
                   save_dir=None):
    backend.check_env()

    if input_list is None:
        input = torch.rand(1, 1, 16, 16, dtype=torch.float32)
        single_roi = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
    else:
        input = torch.tensor(input_list[0], dtype=torch.float32)
        single_roi = torch.tensor(input_list[1], dtype=torch.float32)

    from mmcv.ops import roi_align

    def wrapped_function(torch_input, torch_rois):
        # mmcv's roi_align takes output_size as (h, w).
        return roi_align(torch_input, torch_rois, (pool_h, pool_w),
                         spatial_scale, sampling_ratio, 'avg', True)

    wrapped_model = WrapFunction(wrapped_function).eval()

    with RewriterContext(
            Config({'backend_config': {
                'type': backend.backend_name
            }}),
            backend=backend.backend_name,
            opset=11):
        backend.run_and_validate(
            wrapped_model, [input, single_roi],
            'roi_align',
            input_names=['input', 'rois'],
            output_names=['roi_feat'],
            save_dir=save_dir)
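

# A plain-Python sketch (an illustration, not code from mmcv or mmdeploy) of
# where an mmcv-style RoIAlign with ``aligned=True`` places its sampling
# points along one axis: the RoI is scaled by ``spatial_scale`` and shifted
# by half a pixel, split into ``pooled_size`` bins, and each bin is sampled
# at ``sampling_ratio`` evenly spaced points whose bilinear samples are then
# pooled ('avg' in ``test_roi_align``). The helper name is hypothetical, and
# the adaptive case (``sampling_ratio <= 0``) is deliberately not covered.
def _roi_align_sample_coords(start, end, spatial_scale, pooled_size,
                             sampling_ratio):
    """Return per-bin sample coordinates along one axis (hypothetical)."""
    if sampling_ratio <= 0:
        raise ValueError('adaptive sampling_ratio is not covered here')
    # aligned=True: shift by -0.5 so the RoI aligns with pixel centers.
    roi_start = start * spatial_scale - 0.5
    roi_end = end * spatial_scale - 0.5
    bin_size = (roi_end - roi_start) / pooled_size
    coords = []
    for b in range(pooled_size):
        bin_start = roi_start + b * bin_size
        # Centers of ``sampling_ratio`` equal sub-bins inside this bin.
        coords.append([
            bin_start + (i + 0.5) * bin_size / sampling_ratio
            for i in range(sampling_ratio)
        ])
    return coords


# For the default RoI [0, 0, 4, 4] with pool size 2, spatial_scale 1.0 and
# sampling_ratio 2, the coordinates along x are [[0.0, 1.0], [2.0, 3.0]].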


@pytest.mark.parametrize('backend', [TEST_TENSORRT, TEST_ONNXRT])
@pytest.mark.parametrize('mode', ['bilinear', 'nearest'])
@pytest.mark.parametrize('padding_mode', ['zeros', 'border', 'reflection'])
@pytest.mark.parametrize('align_corners', [True, False])
def test_grid_sample(backend,
                     mode,
                     padding_mode,
                     align_corners,
                     input_list=None,
                     save_dir=None):
    backend.check_env()

    if input_list is None:
        input = torch.rand(1, 1, 10, 10)
    else:
        input = torch.tensor(input_list[0])
    # Identity affine transform; sampling it at twice the input resolution
    # makes grid_sample upsample the input by 2x.
    theta = torch.Tensor([[[1, 0, 0], [0, 1, 0]]])
    grid = nn.functional.affine_grid(
        theta, (1, 1, input.shape[2] * 2, input.shape[3] * 2)).type_as(input)

    def wrapped_function(inputs, grid):
        return nn.functional.grid_sample(
            inputs,
            grid,
            mode=mode,
            padding_mode=padding_mode,
            align_corners=align_corners)

    wrapped_model = WrapFunction(wrapped_function).eval()

    with RewriterContext(
            Config({'backend_config': {
                'type': backend.backend_name
            }}),
            backend=backend.backend_name,
            opset=11):
        backend.run_and_validate(
            wrapped_model, [input, grid],
            'grid_sampler',
            input_names=['input', 'grid'],
            output_names=['output'],
            save_dir=save_dir)
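

# A plain-Python sketch of how ``grid_sample`` maps a normalized grid
# coordinate in [-1, 1] back to a fractional pixel index along an axis of
# ``size`` pixels, following the ``align_corners`` convention described in
# the PyTorch documentation. With ``align_corners=True``, -1 and 1 hit the
# centers of the corner pixels; with False, they hit the outer edges of the
# corner pixels, so border samples depend on ``padding_mode``. The helper
# name is hypothetical.
def _unnormalize_grid_coord(coord, size, align_corners):
    """Map a normalized coordinate to a (fractional) pixel index."""
    if align_corners:
        return (coord + 1) / 2 * (size - 1)
    return ((coord + 1) * size - 1) / 2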


@pytest.mark.parametrize('backend', [TEST_TENSORRT])
@pytest.mark.parametrize('dynamic_export', [True, False])
@pytest.mark.parametrize('mode', ['bicubic', 'nearest'])
@pytest.mark.parametrize('align_corners', [True, False])
@pytest.mark.parametrize('output_size', [[10, 20], None])
@pytest.mark.parametrize('scale_factor', [2])
@pytest.mark.parametrize('n, c, h, w', [(2, 3, 5, 10)])
def test_bicubic_interpolate(backend,
                             dynamic_export,
                             mode,
                             align_corners,
                             output_size,
                             scale_factor,
                             n,
                             c,
                             h,
                             w,
                             input_list=None,
                             save_dir=None):
    backend.check_env()

    if input_list is None:
        input = torch.randn(n, c, h, w)
    else:
        input = torch.tensor(input_list[0])
    if dynamic_export:
        dynamic_axes = {
            'input': {
                0: 'n',
                2: 'h',
                3: 'w',
            },
            'output': {
                0: 'n',
                2: 'h',
                3: 'w',
            },
        }
    else:
        dynamic_axes = None

    # PyTorch only accepts align_corners for interpolating modes such as
    # 'bicubic', so it must be unset for 'nearest'.
    if mode == 'nearest':
        align_corners = None
    if output_size is None:
        resize = nn.Upsample(
            scale_factor=scale_factor, mode=mode, align_corners=align_corners)
    else:
        resize = nn.Upsample(
            size=output_size, mode=mode, align_corners=align_corners)
    expected_result = resize(input).cuda()
    wrapped_model = WrapFunction(resize).eval()

    with RewriterContext(cfg={}, backend=backend.backend_name, opset=11):
        backend.run_and_validate(
            wrapped_model, [input],
            'bicubic_interpolate',
            input_names=['input'],
            dynamic_axes=dynamic_axes,
            output_names=['output'],
            save_dir=save_dir,
            expected_result=expected_result)
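

# ``test_bicubic_interpolate`` and ``test_instance_norm`` build the same
# ``dynamic_axes`` mapping by hand: for every named tensor, axis 0 (batch)
# and axes 2/3 (height/width) are dynamic while the channel axis stays
# static. A hypothetical helper that assembles this dict (it does nothing
# beyond building the mapping handed to ``run_and_validate``) could be:
def _nchw_dynamic_axes(names, dynamic=True):
    """Return an NCHW ``dynamic_axes`` dict for ``names``, or None."""
    if not dynamic:
        return None
    # A fresh inner dict per name, so callers can modify one safely.
    return {name: {0: 'n', 2: 'h', 3: 'w'} for name in names}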


@pytest.mark.parametrize('backend', [TEST_TENSORRT, TEST_ONNXRT])
@pytest.mark.parametrize('in_channels,out_channels,stride,padding,'
                         'dilation,groups,deform_groups,kernel_size',
                         [(3, 64, 1, 0, 1, 1, 1, 3),
                          (1, 32, 3, 2, 1, 1, 1, 3)])
@pytest.mark.parametrize('bias', [True, False])
def test_modulated_deform_conv(backend,
                               in_channels,
                               out_channels,
                               stride,
                               padding,
                               dilation,
                               groups,
                               deform_groups,
                               kernel_size,
                               bias,
                               input_list=None,
                               save_dir=None):
    backend.check_env()

    if input_list is None:
        input = torch.rand(
            1, in_channels, 28, 28, requires_grad=False)  # (n, c, h, w)
    else:
        input = torch.tensor(input_list[0])
    conv_offset = nn.Conv2d(
        in_channels=in_channels,
        out_channels=deform_groups * 3 * kernel_size * kernel_size,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=True)
    out = conv_offset(input)
    o1, o2, mask = torch.chunk(out, 3, dim=1)
    offset = torch.cat((o1, o2), dim=1)
    mask = torch.sigmoid(mask)

    from mmcv.ops import ModulatedDeformConv2d
    model = ModulatedDeformConv2d(in_channels, out_channels, kernel_size,
                                  stride, padding, dilation, groups,
                                  deform_groups, bias).eval()

    with RewriterContext(cfg={}, backend=backend.backend_name, opset=11):
        backend.run_and_validate(
            model, [input, offset, mask],
            'modulated_deform_conv',
            input_names=['input', 'offset', 'mask'],
            output_names=['output'],
            save_dir=save_dir)
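

# A plain-Python sketch of the channel bookkeeping behind ``conv_offset`` in
# ``test_modulated_deform_conv``: a modulated deformable convolution needs,
# per deformable group and per kernel tap, a 2-D offset (one value per
# spatial axis) plus one modulation scalar, hence
# ``deform_groups * 3 * k * k`` channels. Chunking them into three equal
# parts gives two offset halves (concatenated into ``offset``) and one
# ``mask``. The helper name is hypothetical.
def _modulated_offset_channels(deform_groups, kernel_size):
    """Return (offset_channels, mask_channels) for a square kernel."""
    taps = deform_groups * kernel_size * kernel_size
    return 2 * taps, taps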


@pytest.mark.parametrize('backend', [TEST_TENSORRT])
@pytest.mark.parametrize('in_channels,out_channels,stride,padding,'
                         'dilation,groups,deform_groups,kernel_size',
                         [(3, 64, 1, 0, 1, 1, 1, 3),
                          (1, 32, 3, 2, 1, 1, 1, 3)])
def test_deform_conv(backend,
                     in_channels,
                     out_channels,
                     stride,
                     padding,
                     dilation,
                     groups,
                     deform_groups,
                     kernel_size,
                     input_list=None,
                     save_dir=None):
    backend.check_env()

    if input_list is None:
        input = torch.rand(
            1, in_channels, 28, 28, requires_grad=False)  # (n, c, h, w)
    else:
        input = torch.tensor(input_list[0])
    conv_offset = nn.Conv2d(
        in_channels=in_channels,
        out_channels=deform_groups * 2 * kernel_size * kernel_size,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=True)
    offset = conv_offset(input)

    from mmcv.ops import DeformConv2d
    model = DeformConv2d(in_channels, out_channels, kernel_size, stride,
                         padding, dilation, groups, deform_groups).eval()

    with RewriterContext(cfg={}, backend=backend.backend_name, opset=11):
        backend.run_and_validate(
            model, [input, offset],
            'deform_conv',
            input_names=['input', 'offset'],
            output_names=['output'],
            save_dir=save_dir)
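

# Why ``conv_offset`` reuses the main convolution's stride, padding and
# dilation: the offset map must match the deformable convolution's output
# spatial size, with one set of offsets per output location. The standard
# convolution output-size formula (as documented for ``torch.nn.Conv2d``) is
# sketched below in plain Python; the helper name is hypothetical.
def _conv_out_size(size, kernel_size, stride=1, padding=0, dilation=1):
    """Spatial output size of a 2-D convolution along one axis."""
    # Span covered by a dilated kernel.
    effective_kernel = dilation * (kernel_size - 1) + 1
    return (size + 2 * padding - effective_kernel) // stride + 1


# For the parametrized cases on a 28x28 input this gives 26 for
# (stride=1, padding=0) and 10 for (stride=3, padding=2).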


@pytest.mark.parametrize('backend', [TEST_TENSORRT])
@pytest.mark.parametrize('dynamic_export', [True, False])
@pytest.mark.parametrize('fp16_mode', [True, False])
@pytest.mark.parametrize('n, c, h, w', [(2, 3, 10, 10)])
def test_instance_norm(backend,
                       dynamic_export,
                       fp16_mode,
                       n,
                       c,
                       h,
                       w,
                       input_list=None,
                       save_dir=None):
    backend.check_env()

    if input_list is None:
        input = torch.randn(n, c, h, w)
    else:
        input = torch.tensor(input_list[0])
    if dynamic_export:
        dynamic_axes = {
            'input': {
                0: 'n',
                2: 'h',
                3: 'w',
            },
            'output': {
                0: 'n',
                2: 'h',
                3: 'w',
            },
        }
    else:
        dynamic_axes = None

    norm = nn.InstanceNorm2d(c, affine=True)
    wrapped_model = WrapFunction(norm).eval()

    with RewriterContext(cfg={}, backend=backend.backend_name, opset=11):
        backend.run_and_validate(
            wrapped_model, [input],
            'instance_norm',
            input_names=['input'],
            dynamic_axes=dynamic_axes,
            output_names=['output'],
            save_dir=save_dir)
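

# A plain-Python reference for what ``nn.InstanceNorm2d(c, affine=True)``
# computes on one (sample, channel) plane: normalize by that plane's own
# mean and biased variance over H and W, then apply the per-channel affine
# ``weight`` and ``bias``. PyTorch's default ``eps`` is 1e-5. The helper is
# hypothetical and only makes the reduction axes explicit.
def _instance_norm_plane(plane, weight=1.0, bias=0.0, eps=1e-5):
    """Instance-normalize one H x W plane given as a list of rows."""
    values = [v for row in plane for v in row]
    mean = sum(values) / len(values)
    # Biased variance: divide by N, not N - 1.
    var = sum((v - mean) ** 2 for v in values) / len(values)
    scale = weight / (var + eps) ** 0.5
    return [[(v - mean) * scale + bias for v in row] for row in plane]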


@pytest.mark.parametrize('backend', [TEST_TENSORRT])
@pytest.mark.parametrize('num_classes,pre_topk,after_topk,iou_threshold,'
                         'score_threshold,background_label_id',
                         [(5, 6, 3, 0.7, 0.1, -1)])
def test_batched_nms(backend,
                     num_classes,
                     pre_topk,
                     after_topk,
                     iou_threshold,
                     score_threshold,
                     background_label_id,
                     input_list=None,
                     save_dir=None):
    backend.check_env()

    if input_list is None:
        nms_boxes = torch.tensor([[[291.1746, 316.2263, 343.5029, 347.7312],
                                   [288.4846, 315.0447, 343.7267, 346.5630],
                                   [288.5307, 318.1989, 341.6425, 349.7222],
                                   [918.9102, 83.7463, 933.3920, 164.9041],
                                   [895.5786, 78.2361, 907.8049, 172.0883],
                                   [292.5816, 316.5563, 340.3462, 352.9989],
                                   [609.4592, 83.5447, 631.2532, 144.0749],
                                   [917.7308, 85.5870, 933.2839, 168.4530],
                                   [895.5138, 79.3596, 908.2865, 171.0418],
                                   [291.4747, 318.6987, 347.1208, 349.5754]]])
        scores = torch.tensor([[[0.9577, 0.9745, 0.3030, 0.6589, 0.2742],
                                [0.1618, 0.7963, 0.5124, 0.6964, 0.6850],
                                [0.8425, 0.4843, 0.9489, 0.8068, 0.7340],
                                [0.7337, 0.4340, 0.9923, 0.0704, 0.4506],
                                [0.3090, 0.5606, 0.6939, 0.3764, 0.6920],
                                [0.0044, 0.7986, 0.2221, 0.2782, 0.4378],
                                [0.7293, 0.2735, 0.8381, 0.0264, 0.6278],
                                [0.7144, 0.1066, 0.4125, 0.4041, 0.8819],
                                [0.4963, 0.7891, 0.6908, 0.1499, 0.5584],
                                [0.4385, 0.6035, 0.0508, 0.0662, 0.5938]]])
    else:
        nms_boxes = torch.tensor(input_list[0], dtype=torch.float32)
        scores = torch.tensor(input_list[1], dtype=torch.float32)

    from mmdeploy.codebase.mmdet.core.post_processing import _multiclass_nms
    expected_result = _multiclass_nms(
        nms_boxes,
        scores,
        iou_threshold=iou_threshold,
        score_threshold=score_threshold,
        pre_top_k=pre_topk + 1,
        keep_top_k=after_topk + 1)
    expected_result = (expected_result[0][:, 0:-1, :],
                       expected_result[1][:, 0:-1])

    boxes = nms_boxes.unsqueeze(2).tile(num_classes, 1)
    from mmdeploy.mmcv.ops.nms import TRTBatchedNMSop
    batched_nms = TRTBatchedNMSop.apply

    def wrapped_function(boxes, scores):
        return batched_nms(boxes, scores, num_classes, pre_topk, after_topk,
                           iou_threshold, score_threshold, background_label_id)

    wrapped_model = WrapFunction(wrapped_function)

    with RewriterContext(cfg={}, backend=backend.backend_name, opset=11):
        backend.run_and_validate(
            wrapped_model, [boxes, scores],
            'batched_nms',
            input_names=['boxes', 'scores'],
            output_names=['batched_nms_bboxes', 'inds'],
            expected_result=expected_result,
            save_dir=save_dir)
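

# A plain-Python sketch of greedy single-class NMS, the procedure a batched
# (per-class) NMS repeats for every class: boxes are visited in descending
# score order, boxes below ``score_threshold`` are dropped, and a box is
# kept only if its IoU with every already-kept box is at most
# ``iou_threshold``. It assumes (x1, y1, x2, y2) boxes and is an
# illustration, not the TensorRT plugin's implementation; top-k truncation
# and threshold tie handling may differ there. Both helper names are
# hypothetical.
def _iou(a, b):
    """IoU of two (x1, y1, x2, y2) boxes."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def _greedy_nms(boxes, scores, iou_threshold, score_threshold=0.0):
    """Return indices of kept boxes, highest score first."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    for i in order:
        if scores[i] < score_threshold:
            continue
        # Keep the box only if it does not overlap any kept box too much.
        if all(_iou(boxes[i], boxes[j]) <= iou_threshold for j in keep):
            keep.append(i)
    return keep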


@pytest.mark.parametrize('backend', [TEST_TENSORRT])
@pytest.mark.parametrize('num_classes,pre_topk,after_topk,iou_threshold,'
                         'score_threshold,background_label_id',
                         [(5, 6, 3, 0.7, 0.1, -1)])
def test_batched_rotated_nms(backend,
                             num_classes,
                             pre_topk,
                             after_topk,
                             iou_threshold,
                             score_threshold,
                             background_label_id,
                             input_list=None,
                             save_dir=None):
    backend.check_env()
    pytest.importorskip('mmrotate', reason='mmrotate is not installed.')

    if input_list is None:
        nms_boxes = torch.tensor(
            [[[291.1746, 316.2263, 343.5029, 347.7312, 1.],
              [288.4846, 315.0447, 343.7267, 346.5630, 2.],
              [288.5307, 318.1989, 341.6425, 349.7222, 3.],
              [918.9102, 83.7463, 933.3920, 164.9041, 4.],
              [895.5786, 78.2361, 907.8049, 172.0883, 5.],
              [292.5816, 316.5563, 340.3462, 352.9989, 6.],
              [609.4592, 83.5447, 631.2532, 144.0749, 7.],
              [917.7308, 85.5870, 933.2839, 168.4530, 8.],
              [895.5138, 79.3596, 908.2865, 171.0418, 9.],
              [291.4747, 318.6987, 347.1208, 349.5754, 10.]]])
        scores = torch.tensor([[[0.9577, 0.9745, 0.3030, 0.6589, 0.2742],
                                [0.1618, 0.7963, 0.5124, 0.6964, 0.6850],
                                [0.8425, 0.4843, 0.9489, 0.8068, 0.7340],
                                [0.7337, 0.4340, 0.9923, 0.0704, 0.4506],
                                [0.3090, 0.5606, 0.6939, 0.3764, 0.6920],
                                [0.0044, 0.7986, 0.2221, 0.2782, 0.4378],
                                [0.7293, 0.2735, 0.8381, 0.0264, 0.6278],
                                [0.7144, 0.1066, 0.4125, 0.4041, 0.8819],
                                [0.4963, 0.7891, 0.6908, 0.1499, 0.5584],
                                [0.4385, 0.6035, 0.0508, 0.0662, 0.5938]]])
    else:
        nms_boxes = torch.tensor(input_list[0], dtype=torch.float32)
        scores = torch.tensor(input_list[1], dtype=torch.float32)

    from mmdeploy.codebase.mmrotate.core.post_processing.bbox_nms import \
        _multiclass_nms_rotated
    expected_result = _multiclass_nms_rotated(
        nms_boxes,
        scores,
        iou_threshold=iou_threshold,
        score_threshold=score_threshold,
        pre_top_k=pre_topk + 1,
        keep_top_k=after_topk + 1)
    expected_result = (expected_result[0][:, 0:-1, :],
                       expected_result[1][:, 0:-1])

    boxes = nms_boxes.unsqueeze(2).tile(num_classes, 1)

    from mmdeploy.mmcv.ops.nms_rotated import TRTBatchedRotatedNMSop
    batched_rotated_nms = TRTBatchedRotatedNMSop.apply

    def wrapped_function(boxes, scores):
        return batched_rotated_nms(boxes, scores, num_classes, pre_topk,
                                   after_topk, iou_threshold, score_threshold,
                                   background_label_id)

    wrapped_model = WrapFunction(wrapped_function)

    with RewriterContext(cfg={}, backend=backend.backend_name, opset=11):
        backend.run_and_validate(
            wrapped_model, [boxes, scores],
            'batched_rotated_nms',
            input_names=['boxes', 'scores'],
            output_names=['batched_rotated_nms_bboxes', 'inds'],
            expected_result=expected_result,
            save_dir=save_dir)


@pytest.mark.parametrize('backend', [TEST_TENSORRT])
@pytest.mark.parametrize(
    'out_size,pool_mode,sampling_ratio,roi_scale_factor,'
    'finest_scale,featmap_strides,aligned',
    [(tuple([2, 2]), 0, 2, 1.0, 2, list([2.0, 4.0]), 1),
     (tuple([2, 2]), 1, 2, 1.0, 2, list([2.0, 4.0]), 1)])
def test_multi_level_roi_align(backend,
                               out_size,
                               pool_mode,
                               sampling_ratio,
                               roi_scale_factor,
                               finest_scale,
                               featmap_strides,
                               aligned,
                               input_list=None,
                               save_dir=None):
    backend.check_env()

    if input_list is None:
        input = [
            torch.tensor([[[[0.3014, 0.7334, 0.6502, 0.1689],
                            [0.3031, 0.3735, 0.6032, 0.1644],
                            [0.0393, 0.4415, 0.3858, 0.2657],
                            [0.5766, 0.0211, 0.6384, 0.0016]],
                           [[0.0811, 0.6255, 0.0247, 0.3471],
                            [0.1390, 0.9298, 0.6178, 0.6636],
                            [0.2243, 0.2024, 0.2366, 0.3660],
                            [0.1050, 0.2301, 0.7489, 0.7506]],
                           [[0.3868, 0.1706, 0.2390, 0.8494],
                            [0.2643, 0.9347, 0.0412, 0.5790],
                            [0.6202, 0.0682, 0.0390, 0.5296],
                            [0.5383, 0.1221, 0.6344, 0.1514]]]]),
            torch.tensor([[[[0.1939, 0.9983, 0.4031, 0.2712],
                            [0.7929, 0.1504, 0.0946, 0.5030],
                            [0.1421, 0.7908, 0.9595, 0.4198],
                            [0.6880, 0.4722, 0.9896, 0.2266]],
                           [[0.0778, 0.4232, 0.0736, 0.0168],
                            [0.2887, 0.8461, 0.1140, 0.9582],
                            [0.5169, 0.4924, 0.8275, 0.5530],
                            [0.8961, 0.7466, 0.5976, 0.3760]],
                           [[0.1542, 0.5028, 0.8412, 0.6617],
                            [0.3751, 0.2798, 0.3835, 0.8640],
                            [0.5821, 0.6588, 0.1324, 0.7619],
                            [0.9178, 0.7282, 0.0291, 0.3028]]]])
        ]
        rois = torch.tensor([[0., 0., 0., 4., 4.]])
        if pool_mode == 1:
            expected_result = torch.tensor([[[[0.1939, 0.3950],
                                              [0.3437, 0.4543]],
                                             [[0.0778, 0.1641],
                                              [0.1305, 0.2301]],
                                             [[0.1542, 0.2413],
                                              [0.2094, 0.2688]]]])
        else:
            expected_result = torch.tensor([[[[0.1939, 0.4956],
                                              [0.4185, 0.5167]],
                                             [[0.0778, 0.2073],
                                              [0.1569, 0.3162]],
                                             [[0.1542, 0.2849],
                                              [0.2370, 0.3053]]]])
    else:
        input = input_list[0]
        rois = input_list[1]
        expected_result = input_list[2]
    input_name = [('input_' + str(i)) for i in range(len(featmap_strides))]
    input_name.insert(0, 'rois')

    inputs = [
        onnx.helper.make_tensor_value_info(
            input_name[i + 1], onnx.TensorProto.FLOAT, shape=input[i].shape)
        for i in range(len(input_name) - 1)
    ]
    inputs.append(
        onnx.helper.make_tensor_value_info(
            'rois', onnx.TensorProto.FLOAT, shape=rois.shape))
    outputs = [
        onnx.helper.make_tensor_value_info(
            'bbox_feats', onnx.TensorProto.FLOAT, shape=expected_result.shape)
    ]
    node = onnx.helper.make_node(
        'MMCVMultiLevelRoiAlign',
        input_name, ['bbox_feats'],
        'MMCVMultiLevelRoiAlign_0',
        None,
        'mmdeploy',
        pool_mode=pool_mode,
        aligned=aligned,
        featmap_strides=featmap_strides,
        finest_scale=finest_scale,
        output_height=out_size[0],
        output_width=out_size[1],
        roi_scale_factor=roi_scale_factor,
        sampling_ratio=sampling_ratio)

    graph = onnx.helper.make_graph([node], 'torch-jit-export', inputs, outputs)
    onnx_model = onnx.helper.make_model(
        graph, producer_name='pytorch', producer_version='1.8')
    onnx_model.opset_import[0].version = 11
    onnx_model.opset_import.append(
        onnx.onnx_ml_pb2.OperatorSetIdProto(domain='mmdeploy', version=1))

    backend.run_and_validate(
        onnx_model, [rois, *input],
        'multi_level_roi_align',
        input_names=input_name,
        output_names=['bbox_feats'],
        expected_result=expected_result,
        save_dir=save_dir)
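

# A plain-Python sketch of how a multi-level RoI extractor is commonly
# assumed to choose a feature level per RoI. mmdet's
# ``SingleRoIExtractor.map_roi_levels`` computes the RoI scale sqrt(w * h)
# and assigns level floor(log2(scale / finest_scale + 1e-6)), clamped to the
# available levels. Assuming the MMCVMultiLevelRoiAlign plugin follows the
# same rule, the default RoI [0, 0, 4, 4] with ``finest_scale=2`` maps to
# level 1 (stride 4.0); this is consistent with ``expected_result`` starting
# at 0.1939, the first value of the second feature map. The helper name is
# hypothetical.
def _map_roi_level(roi, finest_scale, num_levels):
    """Return the feature level for a (batch_idx, x1, y1, x2, y2) RoI."""
    # Local import keeps this illustrative helper self-contained.
    from math import floor, log2
    _, x1, y1, x2, y2 = roi
    scale = ((x2 - x1) * (y2 - y1)) ** 0.5
    level = floor(log2(scale / finest_scale + 1e-6))
    return min(max(level, 0), num_levels - 1)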


@pytest.mark.parametrize('backend', [TEST_NCNN])
@pytest.mark.parametrize('k', [1, 3, 5])
@pytest.mark.parametrize('dim', [1, 2, 3])
@pytest.mark.parametrize('largest', [True, False])
@pytest.mark.parametrize('sorted', [True, False])
def test_topk(backend,
              k,
              dim,
              largest,
              sorted,
              input_list=None,
              save_dir=None):
    backend.check_env()

    if input_list is None:
        input = torch.rand(1, 8, 12, 17)
    else:
        input = input_list[0]
    assert input.shape[0] == 1, ('ncnn batch must be 1, '
                                 f'but got {input.shape[0]}')

    def topk_function(inputs):
        return torch.Tensor.topk(inputs, k, dim, largest, sorted)

    wrapped_model = WrapFunction(topk_function)

    # When `sorted` is False, PyTorch returns the top-k elements in an
    # unspecified order; only the set of returned elements is guaranteed.
    # The unsorted case therefore only checks that the correct top-k
    # elements are returned and accepts any ordering.
    with RewriterContext(cfg={}, backend=backend.backend_name, opset=11):
        if not sorted:
            backend.run_and_validate(
                wrapped_model, [input.float()],
                f'topk_no_sorted_dim_{dim}',
                input_names=['inputs'],
                output_names=['data', 'index'],
                save_dir=save_dir)
        else:
            backend.run_and_validate(
                wrapped_model, [input.float()],
                'topk',
                input_names=['inputs'],
                output_names=['data', 'index'],
                save_dir=save_dir)
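

# A plain-Python sketch of the order-insensitive comparison described in
# ``test_topk``'s comment: when ``sorted=False`` only the set of returned
# top-k values is well defined, so both sides are sorted before comparing
# element-wise. The helper name and the tolerance are hypothetical.
def _same_topk_values(expected, actual, atol=1e-5):
    """Compare two top-k value lists while ignoring their order."""
    if len(expected) != len(actual):
        return False
    return all(
        abs(e - a) <= atol
        for e, a in zip(sorted(expected), sorted(actual)))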


@pytest.mark.parametrize('backend', [TEST_NCNN])
@pytest.mark.parametrize('dim, n, c, h, w', [(1, 1, 1, 1, 8), (2, 1, 1, 5, 7),
                                             (3, 1, 3, 10, 15)])
def test_shape(backend,
               dim,
               n,
               c,
               h,
               w,
               input_names=['input'],
               output_names=['output'],
               tolerate_small_mismatch=False,
               input_list=None,
               save_dir=None):
    backend.check_env()

    orig_shape = (n, c, h, w)[-dim - 1:]
    if input_list is None:
        input = torch.rand(orig_shape)
    else:
        input = input_list[0]
        assert input.dim() == dim + 1, 'input.dim() must equal dim + 1'
        assert tuple(input.shape) == orig_shape, \
            'input.shape must be the same as orig_shape'

    assert input.shape[0] == 1, ('ncnn batch must be 1, '
                                 f'but got {input.shape[0]}')

    shape_node = make_node('Shape', input_names, output_names)
    assert len(input_names) == 1, 'length of input_names must be 1'
    assert len(output_names) == 1, 'length of output_names must be 1'
    shape_graph = make_graph([shape_node], 'shape_graph', [
        make_tensor_value_info(input_names[0], onnx.TensorProto.FLOAT,
                               orig_shape)
    ], [
        make_tensor_value_info(output_names[0], onnx.TensorProto.FLOAT,
                               (dim + 1, ))
    ])
    shape_model = make_model(shape_graph)

    with RewriterContext(cfg={}, backend=backend.backend_name, opset=11):
        ncnn_model = backend.onnx2ncnn(shape_model, 'shape', output_names,
                                       save_dir)

    # An ncnn Mat carries an implicit batch dimension, so the ncnn output
    # has 2 dimensions rather than 1; unsqueeze the expected shape to match.
    model_outputs = [torch.tensor(orig_shape).unsqueeze(0).float()]
    ncnn_outputs = ncnn_model(dict(zip(input_names, [input])))
    ncnn_outputs = [ncnn_outputs[name] for name in output_names]
    assert_allclose(model_outputs, ncnn_outputs, tolerate_small_mismatch)


@pytest.mark.parametrize('backend', [TEST_NCNN])
@pytest.mark.parametrize('dim, n, c, h, w', [(1, 1, 1, 1, 8), (2, 1, 1, 5, 7),
                                             (3, 1, 3, 10, 15)])
@pytest.mark.parametrize('val', [0., 1., -3, 4.25])
def test_constantofshape(backend,
                         dim,
                         n,
                         c,
                         h,
                         w,
                         val,
                         input_names=['input'],
                         output_names=['output'],
                         tolerate_small_mismatch=False,
                         input_list=None,
                         save_dir=None):
    backend.check_env()

    if input_list is None:
        input = torch.tensor((n, c, h, w)[-dim - 1:]).unsqueeze(0)
    else:
        input = input_list[0]
        assert input.dim() == dim + 1, 'input.dim() must equal dim + 1'
        assert tuple(input.shape) == (n, c, h, w)[-dim - 1:], \
            'input.shape must match (n, c, h, w)[-dim - 1:]'
    assert input.shape[0] == 1, (
        f'ncnn input batch must be 1, got {input.shape[0]}')
    assert input[0][0] == 1, (
        f'ncnn output mat batch must be 1, got {input[0][0]}')

    constantofshape_node = make_node(
        'ConstantOfShape', input_names, output_names, value=float(val))
    assert len(input_names) == 1, 'length of input_names must be 1'
    assert len(output_names) == 1, 'length of output_names must be 1'
    constantofshape_graph = make_graph(
        [constantofshape_node], 'constantofshape_graph', [
            make_tensor_value_info(input_names[0], onnx.TensorProto.FLOAT,
                                   input.shape)
        ], [
            make_tensor_value_info(output_names[0], onnx.TensorProto.FLOAT,
                                   torch.Size(input[0]))
        ])
    constantofshape_model = make_model(constantofshape_graph)

    with RewriterContext(cfg={}, backend=backend.backend_name, opset=11):
        ncnn_model = backend.onnx2ncnn(constantofshape_model,
                                       'constantofshape', output_names,
                                       save_dir)

    # The shape tensor fed to ncnn carries an implicit batch row, so
    # input[0] holds the requested output shape.
    model_outputs = [torch.fill_(torch.rand(tuple(input[0])), val)]
    ncnn_outputs = ncnn_model(dict(zip(input_names, [input.float()])))
    ncnn_outputs = [ncnn_outputs[name] for name in output_names]
    assert_allclose(model_outputs, ncnn_outputs, tolerate_small_mismatch)
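

# Illustrative sketch (hypothetical helper, assuming NumPy): ONNX
# ConstantOfShape takes a 1-D shape and returns a tensor of that shape
# filled with `value`, which is what `torch.fill_(torch.rand(...), val)`
# reproduces in test_constantofshape. For dim=2 the test's shape input is
# [[1, 5, 7]], so input[0] == [1, 5, 7] and the output shape is (1, 5, 7).
def _sketch_constant_of_shape(shape, value):
    import numpy as np
    # float32 mirrors the FLOAT output declared in the test graph.
    return np.full(tuple(int(s) for s in shape), value, dtype=np.float32)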


@pytest.mark.parametrize('backend', [TEST_NCNN])
@pytest.mark.parametrize('axis, data_dims, indice_dims', [(0, 1, 1), (0, 2, 1),
                                                          (1, 2, 1), (0, 3, 1),
                                                          (1, 3, 1),
                                                          (2, 3, 1)])
def test_gather(backend,
                axis,
                data_dims,
                indice_dims,
                input_names=['input', 'indices'],
                output_names=['output'],
                tolerate_small_mismatch=False,
                input_list=None,
                save_dir=None):
    backend.check_env()

    if input_list is None:
        # The actual data rank is data_dims + 1 (leading ncnn batch dim).
        data = torch.rand((8, 12, 17)[-data_dims:]).unsqueeze(0)
        indice = torch.randint(0, 8, (3, 4, 5)[-indice_dims:]).unsqueeze(0)
    else:
        data = input_list[0]
        indice = input_list[1]
    assert data.shape[0] == 1, (
        f'ncnn batch must be 1, but got {data.shape[0]}')
    assert indice.shape[0] == 1, (
        f'ncnn batch must be 1, but got {indice.shape[0]}')

    gather_node = make_node('Gather', input_names, output_names, axis=axis + 1)
    gather_graph = make_graph([gather_node], 'gather_graph', [
        make_tensor_value_info(input_names[0], onnx.TensorProto.FLOAT, None),
        make_tensor_value_info(input_names[1], onnx.TensorProto.INT64, None)
    ], [make_tensor_value_info(output_names[0], onnx.TensorProto.FLOAT, None)])
    gather_model = make_model(gather_graph)

    with RewriterContext(cfg={}, backend=backend.backend_name, opset=11):
        ncnn_model = backend.onnx2ncnn(gather_model, 'gather', output_names,
                                       save_dir)

    # Compute the reference result with onnxruntime. ncnn mats carry an
    # implicit batch dimension, so ncnn receives the full `indice` tensor
    # while onnxruntime is fed `indice[0]` without the batch row.
    import importlib.util
    assert importlib.util.find_spec('onnxruntime') is not None, \
        'onnxruntime not installed.'
    import numpy as np
    import onnxruntime
    session = onnxruntime.InferenceSession(gather_model.SerializeToString())
    model_outputs = session.run(
        output_names,
        dict(
            zip(input_names, [
                np.array(data, dtype=np.float32),
                np.array(indice[0], dtype=np.int64)
            ])))
    ncnn_outputs = ncnn_model(
        dict(zip(input_names, [data.float(), indice.float()])))
    ncnn_outputs = [ncnn_outputs[name] for name in output_names]
    assert_allclose(model_outputs, ncnn_outputs, tolerate_small_mismatch)
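

# Illustrative sketch (hypothetical helper, assuming NumPy): ONNX Gather
# along axis `a` behaves like numpy.take(data, indices, axis=a), so the
# output shape is data.shape[:a] + indices.shape + data.shape[a + 1:].
# test_gather builds the node with `axis + 1` because `data` keeps the
# leading ncnn batch dimension of 1 when it is passed to onnxruntime.
def _sketch_onnx_gather(data, indices, axis):
    import numpy as np
    return np.take(
        np.asarray(data), np.asarray(indices, dtype=np.int64), axis=axis)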


@pytest.mark.parametrize('backend', [TEST_NCNN])
@pytest.mark.parametrize('dim', [1, 2, 3])
def test_tensorslice(backend, dim, input_list=None, save_dir=None):
    backend.check_env()

    if input_list is None:
        input = torch.rand((8, 12, 17)[-dim:]).unsqueeze(0)
    else:
        input = input_list[0]
        assert input.dim() == dim + 1, (
            f'input.dim() must equal dim + 1, expected: {dim + 1}, '
            f'got: {input.dim()}')
    assert input.shape[0] == 1, (
        f'ncnn batch must be 1, but got {input.shape[0]}')

    def tensorslice_function(inputs):
        if dim == 1:
            return inputs[:, 2:17:7]
        if dim == 2:
            return inputs[:, 3:12:4, 2:15:3]
        if dim == 3:
            return inputs[:, 0:8:2, 2:12:4, 2:17:7]

    wrapped_model = WrapFunction(tensorslice_function)

    with RewriterContext(cfg={}, backend=backend.backend_name, opset=11):
        backend.run_and_validate(
            wrapped_model, [input.float()],
            'tensorslice',
            input_names=['inputs'],
            output_names=['outputs'],
            save_dir=save_dir)


@pytest.mark.parametrize('backend', [TEST_NCNN])
@pytest.mark.parametrize('input_dim, output_dim', [(1, 1), (1, 2), (1, 3),
                                                   (2, 2), (2, 3), (3, 3)])
def test_expand(backend,
                input_dim,
                output_dim,
                input_list=None,
                save_dir=None):
    backend.check_env()

    if input_list is None:
        input = torch.rand((1, 12, 1)[-input_dim:]).unsqueeze(0)
        target = torch.rand((8, 12, 17)[-output_dim:]).unsqueeze(0)
    else:
        input = input_list[0]
        target = input_list[1]
    assert input.shape[0] == 1, (
        f'ncnn batch must be 1, but got {input.shape[0]}')
    assert target.shape[0] == 1, (
        f'ncnn batch must be 1, but got {target.shape[0]}')

    def expand_function(input, target):
        return input.expand_as(target)

    wrapped_model = WrapFunction(expand_function)

    with RewriterContext(cfg={}, backend=backend.backend_name, opset=11):
        backend.run_and_validate(
            wrapped_model, [input.float(), target.float()],
            'expand',
            input_names=['input', 'shape'],
            output_names=['output'],
            save_dir=save_dir)
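

# Illustrative sketch (hypothetical helper, assuming NumPy): for the shapes
# used in test_expand, `input.expand_as(target)` is a unidirectional
# broadcast (trailing dims aligned, size-1 dims stretched), which
# numpy.broadcast_to reproduces. Note that ONNX Expand itself is defined
# with bidirectional broadcasting, so the two are not identical in general.
def _sketch_expand_as(array, target_shape):
    import numpy as np
    # Raises ValueError when the shapes are not broadcast-compatible.
    return np.broadcast_to(np.asarray(array), tuple(target_shape))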


@pytest.mark.parametrize('backend', [TEST_ONNXRT])
@pytest.mark.parametrize('iou_threshold', [0.1, 0.3])
@pytest.mark.parametrize('score_threshold', [0., 0.1])
def test_nms_rotated(backend, iou_threshold, score_threshold, save_dir=None):
    backend.check_env()

    boxes = torch.tensor(
        [[[60, 75, 20, 50, 0], [65, 80, 10, 40, 0], [30, 30, 40, 40, 0]],
         [[60, 75, 20, 50, 0], [65, 80, 10, 40, 0], [30, 30, 40, 40, 0]]],
        dtype=torch.float32)
    scores = torch.tensor(
        [[[0.5, 0.1, 0.1], [0.1, 0.6, 0.1], [0.1, 0.1, 0.7], [0.1, 0.1, 0.1]],
         [[0.1, 0.1, 0.1], [0.7, 0.1, 0.1], [0.1, 0.6, 0.1], [0.1, 0.1, 0.5]]],
        dtype=torch.float32)

    from mmdeploy.mmcv.ops import ONNXNMSRotatedOp

    def wrapped_function(torch_boxes, torch_scores):
        return ONNXNMSRotatedOp.apply(torch_boxes, torch_scores, iou_threshold,
                                      score_threshold)

    wrapped_model = WrapFunction(wrapped_function).eval()

    with RewriterContext(
            Config({'backend_config': {
                'type': backend.backend_name
            }}),
            backend=backend.backend_name,
            opset=11):
        backend.run_and_validate(
            wrapped_model, [boxes, scores],
            'nms_rotated',
            input_names=['boxes', 'scores'],
            output_names=['keep_inds'],
            save_dir=save_dir)


@pytest.mark.parametrize('backend', [TEST_ONNXRT])
@pytest.mark.parametrize('pool_h,pool_w,spatial_scale,sampling_ratio',
                         [(2, 2, 1.0, 2), (4, 4, 2.0, 4)])
def test_roi_align_rotated(backend,
                           pool_h,
                           pool_w,
                           spatial_scale,
                           sampling_ratio,
                           input_list=None,
                           save_dir=None):
    backend.check_env()

    if input_list is None:
        # input = torch.rand(1, 1, 16, 16, dtype=torch.float32)
        input = torch.tensor([[[[1., 2.], [3., 4.]]]], dtype=torch.float32)
        single_roi = torch.tensor([[0., 0.5, 0.5, 1., 1., 0]],
                                  dtype=torch.float32)
    else:
        input = torch.tensor(input_list[0], dtype=torch.float32)
        single_roi = torch.tensor(input_list[1], dtype=torch.float32)

    from mmcv.ops import roi_align_rotated

    def wrapped_function(torch_input, torch_rois):
        return roi_align_rotated(torch_input, torch_rois, (pool_w, pool_h),
                                 spatial_scale, sampling_ratio, True, False)

    wrapped_model = WrapFunction(wrapped_function).eval()

    with RewriterContext(
            Config({'backend_config': {
                'type': backend.backend_name
            }}),
            backend=backend.backend_name,
            opset=11):
        backend.run_and_validate(
            wrapped_model, [input, single_roi],
            'roi_align_rotated',
            input_names=['input', 'rois'],
            output_names=['roi_feat'],
            save_dir=save_dir)


@pytest.mark.parametrize('backend', [TEST_TENSORRT])
@pytest.mark.parametrize(
    'out_size, clockwise, sampling_ratio, roi_scale_factor,'
    ' finest_scale, featmap_strides, aligned',
    [(tuple([2, 2]), False, 2, 1.0, 2, list([1.0]), 1)])
def test_multi_level_rotated_roi_align(backend,
                                       out_size,
                                       clockwise,
                                       sampling_ratio,
                                       roi_scale_factor,
                                       finest_scale,
                                       featmap_strides,
                                       aligned,
                                       input_list=None,
                                       save_dir=None):
    backend.check_env()

    if input_list is None:
        import numpy as np
        input = [
            torch.tensor([[[[1., 2., 5., 6.], [3., 4., 7., 8.],
                            [9., 10., 13., 14.], [11., 12., 15., 16.]]]])
        ]
        rois = torch.tensor([[0., 1.5, 1.5, 3., 3., np.pi / 2]])
        expected_result = torch.tensor([[[[7.5625, 1.9375], [10.375, 4.75]]]])
    else:
        input = input_list[0]
        rois = input_list[1]
        expected_result = input_list[2]
    input_name = [('input_' + str(i)) for i in range(len(featmap_strides))]
    input_name.insert(0, 'rois')

    inputs = [
        onnx.helper.make_tensor_value_info(
            input_name[i + 1], onnx.TensorProto.FLOAT, shape=input[i].shape)
        for i in range(len(input_name) - 1)
    ]
    inputs.append(
        onnx.helper.make_tensor_value_info(
            'rois', onnx.TensorProto.FLOAT, shape=rois.shape))
    outputs = [
        onnx.helper.make_tensor_value_info(
            'bbox_feats', onnx.TensorProto.FLOAT, shape=expected_result.shape)
    ]
    node = onnx.helper.make_node(
        'MMCVMultiLevelRotatedRoiAlign',
        input_name, ['bbox_feats'],
        'MMCVMultiLevelRotatedRoiAlign_0',
        None,
        'mmdeploy',
        featmap_strides=featmap_strides,
        finest_scale=finest_scale,
        output_height=out_size[0],
        output_width=out_size[1],
        clockwise=clockwise,
        roi_scale_factor=roi_scale_factor,
        sampling_ratio=sampling_ratio,
        aligned=aligned)
    graph = onnx.helper.make_graph([node], 'torch-jit-export', inputs, outputs)
    onnx_model = onnx.helper.make_model(
        graph, producer_name='pytorch', producer_version='1.8')
    onnx_model.opset_import[0].version = 11
    onnx_model.opset_import.append(
        onnx.onnx_ml_pb2.OperatorSetIdProto(domain='mmdeploy', version=1))

    backend.run_and_validate(
        onnx_model, [rois, *input],
        'multi_level_rotated_roi_align',
        input_names=input_name,
        output_names=['bbox_feats'],
        expected_result=expected_result,
        save_dir=save_dir)


@pytest.mark.parametrize('backend', [TEST_TENSORRT])
@pytest.mark.parametrize('strides', [(4, 4)])
def test_trt_grid_priors(backend, strides, input_list=None, save_dir=None):
    backend.check_env()

    if input_list is None:
        input = torch.rand(1, 3, 2, 2)
        base_anchors = torch.tensor([[-22.6274, -11.3137, 22.6274, 11.3137],
                                     [-16.0000, -16.0000, 16.0000, 16.0000],
                                     [-11.3137, -22.6274, 11.3137, 22.6274]])

        expected_result = torch.tensor([[-22.6274, -11.3137, 22.6274, 11.3137],
                                        [-16.0000, -16.0000, 16.0000, 16.0000],
                                        [-11.3137, -22.6274, 11.3137, 22.6274],
                                        [-18.6274, -11.3137, 26.6274, 11.3137],
                                        [-12.0000, -16.0000, 20.0000, 16.0000],
                                        [-7.3137, -22.6274, 15.3137, 22.6274],
                                        [-22.6274, -7.3137, 22.6274, 15.3137],
                                        [-16.0000, -12.0000, 16.0000, 20.0000],
                                        [-11.3137, -18.6274, 11.3137, 26.6274],
                                        [-18.6274, -7.3137, 26.6274, 15.3137],
                                        [-12.0000, -12.0000, 20.0000, 20.0000],
                                        [-7.3137, -18.6274, 15.3137, 26.6274]])
    else:
        input = input_list[0]
        base_anchors = input_list[1]
        expected_result = input_list[2]

    input_name = ['input']
    output_name = ['output']

    class GridPriorsTestOps(torch.autograd.Function):

        @staticmethod
        def forward(ctx, base_anchor, feat_h, feat_w, stride_h: int,
                    stride_w: int):
            a = base_anchor.shape[0]
            return base_anchor.new_empty(feat_h * feat_w * a, 4)

        @staticmethod
        def symbolic(g, base_anchor, feat_h, feat_w, stride_h: int,
                     stride_w: int):
            from torch.onnx import symbolic_helper
            feat_h = symbolic_helper._unsqueeze_helper(g, feat_h, [0])
            feat_w = symbolic_helper._unsqueeze_helper(g, feat_w, [0])

            zero_h = g.op(
                'ConstantOfShape',
                feat_h,
                value_t=torch.tensor([0], dtype=torch.long),
            )

            zero_w = g.op(
                'ConstantOfShape',
                feat_w,
                value_t=torch.tensor([0], dtype=torch.long),
            )

            return g.op(
                'mmdeploy::GridPriorsTRT',
                base_anchor,
                zero_h,
                zero_w,
                stride_h_i=stride_h,
                stride_w_i=stride_w)

    class GridPriorsTestModel(torch.nn.Module):

        def __init__(self, strides, base_anchors=base_anchors) -> None:
            super().__init__()
            self.strides = strides
            self.base_anchors = base_anchors

        def forward(self, x):
            base_anchors = self.base_anchors
            h, w = x.shape[2:]
            strides = self.strides
            return GridPriorsTestOps.apply(base_anchors, h, w, strides[0],
                                           strides[1])

    model = GridPriorsTestModel(strides=strides)

    backend.run_and_validate(
        model, [input],
        'trt_grid_priors',
        input_names=input_name,
        output_names=output_name,
        expected_result=expected_result,
        dynamic_axes=dict(input={
            2: 'h',
            3: 'w'
        }),
        save_dir=save_dir)
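

# Minimal pure-Python sketch (hypothetical helper) of the grid-prior layout
# that `expected_result` in test_trt_grid_priors encodes: each base anchor
# is shifted by (x * stride_w, y * stride_h) for every feature-map cell,
# with rows ordered y-major then x, and all base anchors of one cell kept
# together. The layout is inferred from the expected values, not taken from
# the GridPriorsTRT plugin source. With the test's three base anchors, a
# 2x2 feature map and strides (4, 4), it yields the 12 expected rows (up to
# float rounding).
def _sketch_grid_priors(base_anchors, feat_h, feat_w, stride_h, stride_w):
    priors = []
    for y in range(feat_h):
        for x in range(feat_w):
            shift_x, shift_y = x * stride_w, y * stride_h
            # Every base anchor is translated by the same cell offset.
            for x1, y1, x2, y2 in base_anchors:
                priors.append([x1 + shift_x, y1 + shift_y,
                               x2 + shift_x, y2 + shift_y])
    return priors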