mmdeploy/tests/test_mmdet/test_core.py
AllentDan 5818095fe0
[Unittests] Add a demo for codebase rewrite part unittests (#89)
* save codes

* add test_model

* save codes

* wrap func

* reformat

* fix lint

* refine docstring

* remove pkl in .gitignore

* add pkl

* apply channel 3

* add function and trt backend rewrite unittest

* fix lint and typo

* add skip condition

* fix typo

* define deploy config inside func and keep ortwrapper original

* speed up and remove ctx

* only inference if no backends

* fix ci

* fix ci

* [Fix] Fix test_calibration (#101)

* fix test calibration

* Modify cuda to cpu

* add tensorrt check

* Revert "[Fix] Fix test_calibration (#101)"

This reverts commit 3f8b8384bfd880538050798d2567f1c137a36174.

Co-authored-by: maningsheng <mnsheng@yeah.net>
Co-authored-by: Yifan Zhou <singlezombie@163.com>
2021-09-28 19:29:58 +08:00

69 lines
2.3 KiB
Python

import importlib
import mmcv
import pytest
import torch
from mmdeploy.mmdet.core.post_processing.bbox_nms import multiclass_nms
from mmdeploy.utils.test import WrapFunction, get_rewrite_outputs
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
@pytest.mark.skipif(
    not importlib.util.find_spec('tensorrt'), reason='requires tensorrt')
def test_multiclass_nms_static():
    """Smoke-test ``multiclass_nms`` through the rewrite test helper.

    The test is skipped when CUDA is unavailable or when no ``tensorrt``
    module spec can be found. It builds a deploy config that selects the
    ``tensorrt`` backend with identical min/opt/max input shapes, wraps
    ``multiclass_nms`` with fixed keyword arguments, hands the wrapped
    function plus random CUDA inputs to ``get_rewrite_outputs``, and only
    checks that the returned value is not ``None``.
    """
    # Imported here rather than at module level so the file can still be
    # imported (and this test skipped) on machines without TensorRT.
    import tensorrt as trt

    # Values shared between the deploy config and the wrapped call.
    per_class_limit = 200
    kept_after_nms = 100
    box_dims = (1, 500, 4)
    score_dims = (1, 500, 80)

    def fixed_shape(dims):
        """Return a shape entry whose min, opt and max all equal ``dims``."""
        # A fresh list per bound, so no list object is shared between keys.
        return {
            bound: list(dims)
            for bound in ('min_shape', 'opt_shape', 'max_shape')
        }

    onnx_section = {
        'output_names': ['dets', 'labels'],
        'input_shape': None,
    }
    backend_section = {
        'type': 'tensorrt',
        'common_config': {
            'fp16_mode': False,
            'log_level': trt.Logger.INFO,
            'max_workspace_size': 1 << 30,
        },
        'model_inputs': [{
            'input_shapes': {
                'boxes': fixed_shape(box_dims),
                'scores': fixed_shape(score_dims),
            },
        }],
    }
    codebase_section = {
        'type': 'mmdet',
        'task': 'ObjectDetection',
        'post_processing': {
            'score_threshold': 0.05,
            'iou_threshold': 0.5,
            'max_output_boxes_per_class': per_class_limit,
            'pre_top_k': -1,
            'keep_top_k': kept_after_nms,
            'background_label_id': -1,
        },
    }
    deploy_cfg = mmcv.Config({
        'onnx_config': onnx_section,
        'backend_config': backend_section,
        'codebase_config': codebase_section,
    })

    # Random inputs are drawn in the same order as before: boxes, then scores.
    inputs = {'boxes': torch.rand(*box_dims).cuda()}
    inputs['scores'] = torch.rand(*score_dims).cuda()

    nms_under_test = WrapFunction(
        multiclass_nms,
        max_output_boxes_per_class=per_class_limit,
        keep_top_k=kept_after_nms)
    outputs = get_rewrite_outputs(
        nms_under_test, model_inputs=inputs, deploy_cfg=deploy_cfg)

    assert outputs is not None, (
        'Got unexpected rewrite outputs: {}'.format(outputs))