mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
* save codes * add test_model * save codes * wrap func * reformat * fix lint * refine docstring * remove pkl in .gitignore * add pkl * apply channel 3 * add function and trt backend rewrite unittest * fix lint and typo * add skip condition * fix typo * define deploy config inside func and keep ortwrapper original * speed up and remove ctx * only inference if no backends * fix ci * fix ci * [Fix] Fix test_calibration (#101) * fix test calibration * Modify cuda to cpu * add tensorrt check * Revert "[Fix] Fix test_calibration (#101)" This reverts commit 3f8b8384bfd880538050798d2567f1c137a36174. Co-authored-by: maningsheng <mnsheng@yeah.net> Co-authored-by: Yifan Zhou <singlezombie@163.com>
69 lines
2.3 KiB
Python
69 lines
2.3 KiB
Python
import importlib
|
|
|
|
import mmcv
|
|
import pytest
|
|
import torch
|
|
|
|
from mmdeploy.mmdet.core.post_processing.bbox_nms import multiclass_nms
|
|
from mmdeploy.utils.test import WrapFunction, get_rewrite_outputs
|
|
|
|
|
|
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
@pytest.mark.skipif(
    not importlib.util.find_spec('tensorrt'), reason='requires tensorrt')
def test_multiclass_nms_static():
    """Check that ``multiclass_nms`` goes through ``get_rewrite_outputs``.

    A TensorRT deploy config is built whose input profile is static (min,
    opt and max shapes are identical for both ``boxes`` and ``scores``).
    Random CUDA inputs of exactly that shape are generated. The NMS limits
    bound into the wrapped function are the same values written into the
    config's ``post_processing`` section.

    The test only asserts that ``get_rewrite_outputs`` returns something
    other than ``None``; it does not inspect the detections.
    """
    import tensorrt as trt

    # Every size/limit that must agree between the deploy config and the
    # call below is defined exactly once here, so they cannot drift apart.
    batch_size, num_boxes, num_classes = 1, 500, 80
    max_output_boxes_per_class = 200
    keep_top_k = 100

    boxes_shape = [batch_size, num_boxes, 4]
    scores_shape = [batch_size, num_boxes, num_classes]

    def static_profile(shape):
        # min == opt == max -> static TensorRT input profile. Fresh list
        # copies so the three entries never share one mutable object.
        return dict(
            min_shape=list(shape),
            opt_shape=list(shape),
            max_shape=list(shape))

    deploy_cfg = mmcv.Config(
        dict(
            onnx_config=dict(
                output_names=['dets', 'labels'], input_shape=None),
            backend_config=dict(
                type='tensorrt',
                common_config=dict(
                    fp16_mode=False,
                    log_level=trt.Logger.INFO,
                    max_workspace_size=1 << 30),
                model_inputs=[
                    dict(
                        input_shapes=dict(
                            boxes=static_profile(boxes_shape),
                            scores=static_profile(scores_shape)))
                ]),
            codebase_config=dict(
                type='mmdet',
                task='ObjectDetection',
                post_processing=dict(
                    score_threshold=0.05,
                    iou_threshold=0.5,
                    max_output_boxes_per_class=max_output_boxes_per_class,
                    pre_top_k=-1,
                    keep_top_k=keep_top_k,
                    background_label_id=-1,
                ))))

    # Inputs match the static profile above exactly.
    boxes = torch.rand(*boxes_shape).cuda()
    scores = torch.rand(*scores_shape).cuda()

    # Bind the same NMS limits that the deploy config declares.
    wrapped_func = WrapFunction(
        multiclass_nms,
        max_output_boxes_per_class=max_output_boxes_per_class,
        keep_top_k=keep_top_k)
    rewrite_outputs = get_rewrite_outputs(
        wrapped_func,
        model_inputs={
            'boxes': boxes,
            'scores': scores
        },
        deploy_cfg=deploy_cfg)

    assert rewrite_outputs is not None, 'Got unexpected rewrite '\
        'outputs: {}'.format(rewrite_outputs)