mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
69 lines
2.3 KiB
Python
69 lines
2.3 KiB
Python
|
import importlib
|
||
|
|
||
|
import mmcv
|
||
|
import pytest
|
||
|
import torch
|
||
|
|
||
|
from mmdeploy.mmdet.core.post_processing.bbox_nms import multiclass_nms
|
||
|
from mmdeploy.utils.test import WrapFunction, get_rewrite_outputs
|
||
|
|
||
|
|
||
|
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
|
||
|
@pytest.mark.skipif(
|
||
|
not importlib.util.find_spec('tensorrt'), reason='requires tensorrt')
|
||
|
def test_multiclass_nms_static():
|
||
|
|
||
|
import tensorrt as trt
|
||
|
deploy_cfg = mmcv.Config(
|
||
|
dict(
|
||
|
onnx_config=dict(
|
||
|
output_names=['dets', 'labels'], input_shape=None),
|
||
|
backend_config=dict(
|
||
|
type='tensorrt',
|
||
|
common_config=dict(
|
||
|
fp16_mode=False,
|
||
|
log_level=trt.Logger.INFO,
|
||
|
max_workspace_size=1 << 30),
|
||
|
model_inputs=[
|
||
|
dict(
|
||
|
input_shapes=dict(
|
||
|
boxes=dict(
|
||
|
min_shape=[1, 500, 4],
|
||
|
opt_shape=[1, 500, 4],
|
||
|
max_shape=[1, 500, 4]),
|
||
|
scores=dict(
|
||
|
min_shape=[1, 500, 80],
|
||
|
opt_shape=[1, 500, 80],
|
||
|
max_shape=[1, 500, 80])))
|
||
|
]),
|
||
|
codebase_config=dict(
|
||
|
type='mmdet',
|
||
|
task='ObjectDetection',
|
||
|
post_processing=dict(
|
||
|
score_threshold=0.05,
|
||
|
iou_threshold=0.5,
|
||
|
max_output_boxes_per_class=200,
|
||
|
pre_top_k=-1,
|
||
|
keep_top_k=100,
|
||
|
background_label_id=-1,
|
||
|
))))
|
||
|
|
||
|
boxes = torch.rand(1, 500, 4).cuda()
|
||
|
scores = torch.rand(1, 500, 80).cuda()
|
||
|
max_output_boxes_per_class = 200
|
||
|
keep_top_k = 100
|
||
|
wrapped_func = WrapFunction(
|
||
|
multiclass_nms,
|
||
|
max_output_boxes_per_class=max_output_boxes_per_class,
|
||
|
keep_top_k=keep_top_k)
|
||
|
rewrite_outputs = get_rewrite_outputs(
|
||
|
wrapped_func,
|
||
|
model_inputs={
|
||
|
'boxes': boxes,
|
||
|
'scores': scores
|
||
|
},
|
||
|
deploy_cfg=deploy_cfg)
|
||
|
|
||
|
assert rewrite_outputs is not None, 'Got unexpected rewrite '\
|
||
|
'outputs: {}'.format(rewrite_outputs)
|