[Unittests] MMDet unittests (#112)

* add mmdet unittests

* remove redundant img meta info

* import wrapper from backend util

* force add wrapper

* use smaller nuance

* add seed everything

* add create input part and some inference part unittests

* fix lint

* skip ppl

* remove pyppl

* add dataset files

* import ncnn_ext inside ncnn warpper

* use cpu device to create input

* add pssd and ptsd unittests

* clear mmdet unittests

* wrap function to enable rewrite

* refine codes and resolve comments

* move mmdet inside test func

* remove redundant line

* test visualize in mmdeploy.apis

* use less memory

* resolve comments

* fix ci

* move pytest.skip inside test function and use 3 channel input
This commit is contained in:
AllentDan 2021-10-13 17:24:11 +08:00 committed by GitHub
parent 6fdf6b8616
commit d4828c7836
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 1549 additions and 184 deletions

View File

@ -1,10 +1,10 @@
import importlib
from typing import Dict, Iterable, Optional
import ncnn
import numpy as np
import torch
from mmdeploy.apis.ncnn import ncnn_ext
from mmdeploy.utils.timer import TimeCounter
@ -37,7 +37,9 @@ class NCNNWrapper(torch.nn.Module):
super(NCNNWrapper, self).__init__()
net = ncnn.Net()
ncnn_ext.register_mm_custom_layers(net)
if importlib.util.find_spec('mmdeploy.apis.ncnn.ncnn_ext'):
from mmdeploy.apis.ncnn import ncnn_ext
ncnn_ext.register_mm_custom_layers(net)
net.load_param(param_file)
net.load_model(bin_file)

View File

@ -333,12 +333,12 @@ def build_dataloader(codebase: Codebase, dataset: Dataset,
raise NotImplementedError(f'Unknown codebase type: {codebase.value}')
def get_tensor_from_input(codebase: Codebase, input_data: tuple):
def get_tensor_from_input(codebase: Codebase, input_data: Dict[str, Any]):
"""Get input tensor from input data.
Args:
codebase (Codebase): Specifying codebase type.
input_data (tuple): Input data containing meta info and image tensor.
input_data (dict): Input data containing meta info and image tensor.
Returns:
torch.Tensor: An image in `Tensor`.

View File

@ -48,13 +48,13 @@ def tblr2bboxes(ctx,
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.core.bbox.coder.tblr_bbox_coder.tblr2bboxes',
backend='ncnn')
def delta2bbox_ncnn(ctx,
priors,
tblr,
normalizer=4.0,
normalize_by_wh=True,
max_shape=None,
clip_border=True):
def tblr2bboxes_ncnn(ctx,
priors,
tblr,
normalizer=4.0,
normalize_by_wh=True,
max_shape=None,
clip_border=True):
"""Rewrite for ONNX exporting of NCNN backend."""
assert priors.size(0) == tblr.size(0)
if priors.ndim == 3:

View File

@ -1,4 +1,4 @@
from typing import Any, Optional, Sequence, Union
from typing import Any, Dict, Optional, Sequence, Union
import mmcv
import numpy as np
@ -151,11 +151,11 @@ def build_dataloader(dataset: Dataset,
**kwargs)
def get_tensor_from_input(input_data: tuple):
def get_tensor_from_input(input_data: Dict[str, Any]):
"""Get input tensor from input data.
Args:
input_data (tuple): Input data containing meta info and image tensor.
input_data (dict): Input data containing meta info and image tensor.
Returns:
torch.Tensor: An image in `Tensor`.
"""

View File

@ -40,8 +40,13 @@ def simple_test_bboxes_of_bbox_test_mixin(ctx, self, x, img_metas, proposals,
MaskTestMixin.simple_test_mask')
def simple_test_mask_of_mask_test_mixin(ctx, self, x, img_metas, det_bboxes,
det_labels, **kwargs):
assert det_bboxes.shape[1] != 0, 'Can not record MaskHead as it \
has not been executed this time'
if det_bboxes.shape[1] == 0:
bboxes_shape, labels_shape = list(det_bboxes.shape), list(
det_labels.shape)
bboxes_shape[1], labels_shape[1] = 1, 1
det_bboxes = torch.tensor([[[0., 0., 1., 1.,
0.]]]).expand(*bboxes_shape)
det_labels = torch.tensor([[0]]).expand(*labels_shape)
batch_size = det_bboxes.size(0)
det_bboxes = det_bboxes[..., :4]

View File

@ -11,6 +11,8 @@ def size_of_tensor_static(ctx, self, *args):
ret = ctx.origin_func(self, *args)
if isinstance(ret, torch.Tensor):
ret = int(ret)
elif isinstance(ret, int):
return (ret)
else:
ret = [int(r) for r in ret]
ret = tuple(ret)

View File

@ -57,6 +57,54 @@ class WrapModel(nn.Module):
return func(*args, **kwargs)
class SwitchBackendWrapper:
"""A switcher for backend wrapper for unit tests.
Examples:
>>> from mmdeploy.utils.test import SwitchBackendWrapper
>>> from mmdeploy.apis.onnxruntime.onnxruntime_utils import ORTWrapper
>>> SwitchBackendWrapper.set(ORTWrapper, outputs=outputs)
>>> ...
>>> SwitchBackendWrapper.recover(ORTWrapper)
"""
init = None
forward = None
call = None
class BackendWrapper(torch.nn.Module):
"""A dummy wrapper for unit tests."""
def __init__(self, *args, **kwargs):
self.output_names = ['dets', 'labels']
def forward(self, *args, **kwargs):
return self.outputs
def __call__(self, *args, **kwds):
return self.forward(*args, **kwds)
@staticmethod
def set(obj, **kwargs):
"""Replace attributes in backend wrappers with dummy items."""
SwitchBackendWrapper.init = obj.__init__
SwitchBackendWrapper.forward = obj.forward
SwitchBackendWrapper.call = obj.__call__
obj.__init__ = SwitchBackendWrapper.BackendWrapper.__init__
obj.forward = SwitchBackendWrapper.BackendWrapper.forward
obj.__call__ = SwitchBackendWrapper.BackendWrapper.__call__
for k, v in kwargs.items():
setattr(obj, k, v)
@staticmethod
def recover(obj):
assert SwitchBackendWrapper.init is not None and \
SwitchBackendWrapper.forward is not None,\
'recover method must be called after exchange'
obj.__init__ = SwitchBackendWrapper.init
obj.forward = SwitchBackendWrapper.forward
obj.__call__ = SwitchBackendWrapper.call
def assert_allclose(expected: List[Union[torch.Tensor, np.ndarray]],
actual: List[Union[torch.Tensor, np.ndarray]],
tolerate_small_mismatch: bool = False):
@ -180,6 +228,9 @@ def get_rewrite_outputs(wrapped_model: nn.Module, model_inputs: dict,
backend_feats[input_names[i]] = feature_list[i]
else:
backend_feats[str(i)] = feature_list[i]
elif backend == Backend.NCNN:
return ctx_outputs
with torch.no_grad():
backend_outputs = backend_model.forward(backend_feats)
return backend_outputs

View File

@ -0,0 +1,39 @@
{
"images": [
{
"file_name": "000000000139.jpg",
"height": 800,
"width": 800,
"id": 0
}
],
"annotations": [
{
"bbox": [
0,
0,
20,
20
],
"area": 400.00,
"score": 1.0,
"category_id": 1,
"id": 1,
"image_id": 0
}
],
"categories": [
{
"id": 1,
"name": "bus",
"supercategory": "none"
},
{
"id": 2,
"name": "car",
"supercategory": "none"
}
],
"licenses": [],
"info": null
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 158 KiB

View File

@ -0,0 +1,231 @@
{
"type": "MaskRCNN",
"backbone": {
"type": "ResNet",
"depth": 50,
"num_stages": 4,
"out_indices": [
0,
1,
2,
3
],
"frozen_stages": 1,
"norm_cfg": {
"type": "BN",
"requires_grad": true
},
"norm_eval": true,
"style": "pytorch",
"init_cfg": {
"type": "Pretrained",
"checkpoint": "torchvision://resnet50"
}
},
"neck": {
"type": "FPN",
"in_channels": [
256,
512,
1024,
2048
],
"out_channels": 256,
"num_outs": 5
},
"rpn_head": {
"type": "RPNHead",
"in_channels": 256,
"feat_channels": 256,
"anchor_generator": {
"type": "AnchorGenerator",
"scales": [
8
],
"ratios": [
0.5,
1.0,
2.0
],
"strides": [
4,
8,
16,
32,
64
]
},
"bbox_coder": {
"type": "DeltaXYWHBBoxCoder",
"target_means": [
0.0,
0.0,
0.0,
0.0
],
"target_stds": [
1.0,
1.0,
1.0,
1.0
]
},
"loss_cls": {
"type": "CrossEntropyLoss",
"use_sigmoid": true,
"loss_weight": 1.0
},
"loss_bbox": {
"type": "L1Loss",
"loss_weight": 1.0
}
},
"roi_head": {
"type": "StandardRoIHead",
"bbox_roi_extractor": {
"type": "SingleRoIExtractor",
"roi_layer": {
"type": "RoIAlign",
"output_size": 7,
"sampling_ratio": 0
},
"out_channels": 256,
"featmap_strides": [
4,
8,
16,
32
]
},
"bbox_head": {
"type": "Shared2FCBBoxHead",
"in_channels": 256,
"fc_out_channels": 1024,
"roi_feat_size": 7,
"num_classes": 80,
"bbox_coder": {
"type": "DeltaXYWHBBoxCoder",
"target_means": [
0.0,
0.0,
0.0,
0.0
],
"target_stds": [
0.1,
0.1,
0.2,
0.2
]
},
"reg_class_agnostic": false,
"loss_cls": {
"type": "CrossEntropyLoss",
"use_sigmoid": false,
"loss_weight": 1.0
},
"loss_bbox": {
"type": "L1Loss",
"loss_weight": 1.0
}
},
"mask_roi_extractor": {
"type": "SingleRoIExtractor",
"roi_layer": {
"type": "RoIAlign",
"output_size": 14,
"sampling_ratio": 0
},
"out_channels": 256,
"featmap_strides": [
4,
8,
16,
32
]
},
"mask_head": {
"type": "FCNMaskHead",
"num_convs": 4,
"in_channels": 256,
"conv_out_channels": 256,
"num_classes": 80,
"loss_mask": {
"type": "CrossEntropyLoss",
"use_mask": true,
"loss_weight": 1.0
}
}
},
"train_cfg": {
"rpn": {
"assigner": {
"type": "MaxIoUAssigner",
"pos_iou_thr": 0.7,
"neg_iou_thr": 0.3,
"min_pos_iou": 0.3,
"match_low_quality": true,
"ignore_iof_thr": -1
},
"sampler": {
"type": "RandomSampler",
"num": 256,
"pos_fraction": 0.5,
"neg_pos_ub": -1,
"add_gt_as_proposals": false
},
"allowed_border": -1,
"pos_weight": -1,
"debug": false
},
"rpn_proposal": {
"nms_pre": 2000,
"max_per_img": 1000,
"nms": {
"type": "nms",
"iou_threshold": 0.7
},
"min_bbox_size": 0
},
"rcnn": {
"assigner": {
"type": "MaxIoUAssigner",
"pos_iou_thr": 0.5,
"neg_iou_thr": 0.5,
"min_pos_iou": 0.5,
"match_low_quality": true,
"ignore_iof_thr": -1
},
"sampler": {
"type": "RandomSampler",
"num": 512,
"pos_fraction": 0.25,
"neg_pos_ub": -1,
"add_gt_as_proposals": true
},
"mask_size": 28,
"pos_weight": -1,
"debug": false
}
},
"test_cfg": {
"rpn": {
"nms_pre": 1000,
"max_per_img": 1000,
"nms": {
"type": "nms",
"iou_threshold": 0.7
},
"min_bbox_size": 0
},
"rcnn": {
"score_thr": 0.05,
"nms": {
"type": "nms",
"iou_threshold": 0.5
},
"max_per_img": 100,
"mask_thr_binary": 0.5
}
}
}

View File

@ -0,0 +1,110 @@
{
"type": "RetinaNet",
"backbone": {
"type": "ResNet",
"depth": 50,
"num_stages": 4,
"out_indices": [
0,
1,
2,
3
],
"frozen_stages": 1,
"norm_cfg": {
"type": "BN",
"requires_grad": true
},
"norm_eval": true,
"style": "pytorch",
"init_cfg": {
"type": "Pretrained",
"checkpoint": "torchvision://resnet50"
}
},
"neck": {
"type": "FPN",
"in_channels": [
256,
512,
1024,
2048
],
"out_channels": 256,
"start_level": 1,
"add_extra_convs": "on_input",
"num_outs": 5
},
"bbox_head": {
"type": "RetinaHead",
"num_classes": 80,
"in_channels": 256,
"stacked_convs": 4,
"feat_channels": 256,
"anchor_generator": {
"type": "AnchorGenerator",
"octave_base_scale": 4,
"scales_per_octave": 3,
"ratios": [
0.5,
1.0,
2.0
],
"strides": [
8,
16,
32,
64,
128
]
},
"bbox_coder": {
"type": "DeltaXYWHBBoxCoder",
"target_means": [
0.0,
0.0,
0.0,
0.0
],
"target_stds": [
1.0,
1.0,
1.0,
1.0
]
},
"loss_cls": {
"type": "FocalLoss",
"use_sigmoid": true,
"gamma": 2.0,
"alpha": 0.25,
"loss_weight": 1.0
},
"loss_bbox": {
"type": "L1Loss",
"loss_weight": 1.0
}
},
"train_cfg": {
"assigner": {
"type": "MaxIoUAssigner",
"pos_iou_thr": 0.5,
"neg_iou_thr": 0.4,
"min_pos_iou": 0,
"ignore_iof_thr": -1
},
"allowed_border": -1,
"pos_weight": -1,
"debug": false
},
"test_cfg": {
"nms_pre": 1000,
"min_bbox_size": 0,
"score_thr": 0.05,
"nms": {
"type": "nms",
"iou_threshold": 0.5
},
"max_per_img": 100
}
}

View File

@ -1,68 +0,0 @@
import importlib
import mmcv
import pytest
import torch
from mmdeploy.mmdet.core.post_processing.bbox_nms import multiclass_nms
from mmdeploy.utils.test import WrapFunction, get_rewrite_outputs
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
@pytest.mark.skipif(
not importlib.util.find_spec('tensorrt'), reason='requires tensorrt')
def test_multiclass_nms_static():
import tensorrt as trt
deploy_cfg = mmcv.Config(
dict(
onnx_config=dict(
output_names=['dets', 'labels'], input_shape=None),
backend_config=dict(
type='tensorrt',
common_config=dict(
fp16_mode=False,
log_level=trt.Logger.INFO,
max_workspace_size=1 << 30),
model_inputs=[
dict(
input_shapes=dict(
boxes=dict(
min_shape=[1, 500, 4],
opt_shape=[1, 500, 4],
max_shape=[1, 500, 4]),
scores=dict(
min_shape=[1, 500, 80],
opt_shape=[1, 500, 80],
max_shape=[1, 500, 80])))
]),
codebase_config=dict(
type='mmdet',
task='ObjectDetection',
post_processing=dict(
score_threshold=0.05,
iou_threshold=0.5,
max_output_boxes_per_class=200,
pre_top_k=-1,
keep_top_k=100,
background_label_id=-1,
))))
boxes = torch.rand(1, 500, 4).cuda()
scores = torch.rand(1, 500, 80).cuda()
max_output_boxes_per_class = 200
keep_top_k = 100
wrapped_func = WrapFunction(
multiclass_nms,
max_output_boxes_per_class=max_output_boxes_per_class,
keep_top_k=keep_top_k)
rewrite_outputs = get_rewrite_outputs(
wrapped_func,
model_inputs={
'boxes': boxes,
'scores': scores
},
deploy_cfg=deploy_cfg)
assert rewrite_outputs is not None, 'Got unexpected rewrite '\
'outputs: {}'.format(rewrite_outputs)

View File

@ -0,0 +1,549 @@
import importlib
import mmcv
import numpy as np
import pytest
import torch
import mmdeploy.apis.ncnn as ncnn_apis
import mmdeploy.apis.onnxruntime as ort_apis
import mmdeploy.apis.ppl as ppl_apis
import mmdeploy.apis.tensorrt as trt_apis
from mmdeploy.utils.test import SwitchBackendWrapper
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
@pytest.mark.skipif(
not importlib.util.find_spec('tensorrt'), reason='requires tensorrt')
def test_TensorRTDetector():
# force add backend wrapper regardless of plugins
# make sure TensorRTDetector can use TRTWrapper inside itself
from mmdeploy.apis.tensorrt.tensorrt_utils import TRTWrapper
trt_apis.__dict__.update({'TRTWrapper': TRTWrapper})
# simplify backend inference
outputs = {
'dets': torch.rand(1, 100, 5).cuda(),
'labels': torch.rand(1, 100).cuda()
}
SwitchBackendWrapper.set(TRTWrapper, outputs=outputs)
from mmdeploy.mmdet.apis.inference import TensorRTDetector
trt_detector = TensorRTDetector('', ['' for i in range(80)], 0)
imgs = [torch.rand(1, 3, 64, 64).cuda()]
img_metas = [[{
'ori_shape': [64, 64, 3],
'img_shape': [64, 64, 3],
'scale_factor': [2.09, 1.87, 2.09, 1.87],
}]]
results = trt_detector.forward(imgs, img_metas)
assert results is not None, 'failed to get output using TensorRTDetector'
SwitchBackendWrapper.recover(TRTWrapper)
@pytest.mark.skipif(
not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
def test_ONNXRuntimeDetector():
# force add backend wrapper regardless of plugins
# make sure ONNXRuntimeDetector can use ORTWrapper inside itself
from mmdeploy.apis.onnxruntime.onnxruntime_utils import ORTWrapper
ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})
# simplify backend inference
outputs = (torch.rand(1, 100, 5), torch.rand(1, 100))
SwitchBackendWrapper.set(ORTWrapper, outputs=outputs)
from mmdeploy.mmdet.apis.inference import ONNXRuntimeDetector
ort_detector = ONNXRuntimeDetector('', ['' for i in range(80)], 0)
imgs = [torch.rand(1, 3, 64, 64)]
img_metas = [[{
'ori_shape': [64, 64, 3],
'img_shape': [64, 64, 3],
'scale_factor': [2.09, 1.87, 2.09, 1.87],
}]]
results = ort_detector.forward(imgs, img_metas)
assert results is not None, 'failed to get output using '\
'ONNXRuntimeDetector'
SwitchBackendWrapper.recover(ORTWrapper)
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
@pytest.mark.skipif(
not importlib.util.find_spec('pyppl'), reason='requires pyppl')
def test_PPLDetector():
# force add backend wrapper regardless of plugins
# make sure PPLDetector can use PPLWrapper inside itself
from mmdeploy.apis.ppl.ppl_utils import PPLWrapper
ppl_apis.__dict__.update({'PPLWrapper': PPLWrapper})
# simplify backend inference
outputs = (torch.rand(1, 100, 5), torch.rand(1, 100))
SwitchBackendWrapper.set(PPLWrapper, outputs=outputs)
from mmdeploy.mmdet.apis.inference import PPLDetector
ppl_detector = PPLDetector('', ['' for i in range(80)], 0)
imgs = [torch.rand(1, 3, 64, 64)]
img_metas = [[{
'ori_shape': [64, 64, 3],
'img_shape': [64, 64, 3],
'scale_factor': [2.09, 1.87, 2.09, 1.87],
}]]
results = ppl_detector.forward(imgs, img_metas)
assert results is not None, 'failed to get output using PPLDetector'
SwitchBackendWrapper.recover(PPLWrapper)
def get_test_cfg_and_post_processing():
test_cfg = {
'nms_pre': 100,
'min_bbox_size': 0,
'score_thr': 0.05,
'nms': {
'type': 'nms',
'iou_threshold': 0.5
},
'max_per_img': 10
}
post_processing = {
'score_threshold': 0.05,
'iou_threshold': 0.5,
'max_output_boxes_per_class': 20,
'pre_top_k': -1,
'keep_top_k': 10,
'background_label_id': -1
}
return test_cfg, post_processing
def test_PartitionSingleStageDetector():
test_cfg, post_processing = get_test_cfg_and_post_processing()
model_cfg = mmcv.Config(dict(model=dict(test_cfg=test_cfg)))
deploy_cfg = mmcv.Config(
dict(codebase_config=dict(post_processing=post_processing)))
from mmdeploy.mmdet.apis.inference import PartitionSingleStageDetector
pss_detector = PartitionSingleStageDetector(['' for i in range(80)],
model_cfg=model_cfg,
deploy_cfg=deploy_cfg,
device_id=0)
scores = torch.rand(1, 120, 80)
bboxes = torch.rand(1, 120, 4)
results = pss_detector.partition0_postprocess(scores=scores, bboxes=bboxes)
assert results is not None, 'failed to get output using '\
'partition0_postprocess of PartitionSingleStageDetector'
@pytest.mark.skipif(
not importlib.util.find_spec('ncnn'), reason='requires ncnn')
def test_NCNNPSSDetector():
test_cfg, post_processing = get_test_cfg_and_post_processing()
model_cfg = mmcv.Config(dict(model=dict(test_cfg=test_cfg)))
deploy_cfg = mmcv.Config(
dict(codebase_config=dict(post_processing=post_processing)))
# force add backend wrapper regardless of plugins
# make sure NCNNPSSDetector can use NCNNWrapper inside itself
from mmdeploy.apis.ncnn.ncnn_utils import NCNNWrapper
ncnn_apis.__dict__.update({'NCNNWrapper': NCNNWrapper})
# simplify backend inference
outputs = {
'scores': torch.rand(1, 120, 80),
'boxes': torch.rand(1, 120, 4)
}
SwitchBackendWrapper.set(
NCNNWrapper,
outputs=outputs,
model_cfg=model_cfg,
deploy_cfg=deploy_cfg)
from mmdeploy.mmdet.apis.inference import NCNNPSSDetector
ncnn_pss_detector = NCNNPSSDetector(['', ''], ['' for i in range(80)],
model_cfg=model_cfg,
deploy_cfg=deploy_cfg,
device_id=0)
imgs = [torch.rand(1, 3, 32, 32)]
img_metas = [[{
'ori_shape': [32, 32, 3],
'img_shape': [32, 32, 3],
'scale_factor': [2.09, 1.87, 2.09, 1.87],
}]]
results = ncnn_pss_detector.forward(imgs, img_metas)
assert results is not None, 'failed to get output using NCNNPSSDetector'
SwitchBackendWrapper.recover(NCNNWrapper)
@pytest.mark.skipif(
not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
def test_ONNXRuntimePSSDetector():
test_cfg, post_processing = get_test_cfg_and_post_processing()
model_cfg = mmcv.Config(dict(model=dict(test_cfg=test_cfg)))
deploy_cfg = mmcv.Config(
dict(codebase_config=dict(post_processing=post_processing)))
# force add backend wrapper regardless of plugins
# make sure ONNXRuntimePSSDetector can use ORTWrapper inside itself
from mmdeploy.apis.onnxruntime.onnxruntime_utils import ORTWrapper
ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})
# simplify backend inference
outputs = [
np.random.rand(1, 120, 80).astype(np.float32),
np.random.rand(1, 120, 4).astype(np.float32)
]
SwitchBackendWrapper.set(
ORTWrapper,
outputs=outputs,
model_cfg=model_cfg,
deploy_cfg=deploy_cfg)
from mmdeploy.mmdet.apis.inference import ONNXRuntimePSSDetector
ort_pss_detector = ONNXRuntimePSSDetector(
'', ['' for i in range(80)],
model_cfg=model_cfg,
deploy_cfg=deploy_cfg,
device_id=0)
imgs = [torch.rand(1, 3, 32, 32)]
img_metas = [[{
'ori_shape': [32, 32, 3],
'img_shape': [32, 32, 3],
'scale_factor': [2.09, 1.87, 2.09, 1.87],
}]]
results = ort_pss_detector.forward(imgs, img_metas)
assert results is not None, 'failed to get output using '
'ONNXRuntimePSSDetector'
SwitchBackendWrapper.recover(ORTWrapper)
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
@pytest.mark.skipif(
not importlib.util.find_spec('tensorrt'), reason='requires tensorrt')
def test_TensorRTPSSDetector():
test_cfg, post_processing = get_test_cfg_and_post_processing()
model_cfg = mmcv.Config(dict(model=dict(test_cfg=test_cfg)))
deploy_cfg = mmcv.Config(
dict(codebase_config=dict(post_processing=post_processing)))
# force add backend wrapper regardless of plugins
# make sure TensorRTPSSDetector can use TRTWrapper inside itself
from mmdeploy.apis.tensorrt.tensorrt_utils import TRTWrapper
trt_apis.__dict__.update({'TRTWrapper': TRTWrapper})
# simplify backend inference
outputs = {
'scores': torch.rand(1, 120, 80).cuda(),
'boxes': torch.rand(1, 120, 4).cuda()
}
SwitchBackendWrapper.set(
TRTWrapper,
outputs=outputs,
model_cfg=model_cfg,
deploy_cfg=deploy_cfg)
from mmdeploy.mmdet.apis.inference import TensorRTPSSDetector
trt_pss_detector = TensorRTPSSDetector(
'', ['' for i in range(80)],
model_cfg=model_cfg,
deploy_cfg=deploy_cfg,
device_id=0)
imgs = [torch.rand(1, 3, 32, 32).cuda()]
img_metas = [[{
'ori_shape': [32, 32, 3],
'img_shape': [32, 32, 3],
'scale_factor': [2.09, 1.87, 2.09, 1.87],
}]]
results = trt_pss_detector.forward(imgs, img_metas)
assert results is not None, 'failed to get output using '
'TensorRTPSSDetector'
SwitchBackendWrapper.recover(TRTWrapper)
def prepare_model_deploy_cfgs():
test_cfg, post_processing = get_test_cfg_and_post_processing()
bbox_roi_extractor = {
'type': 'SingleRoIExtractor',
'roi_layer': {
'type': 'RoIAlign',
'output_size': 7,
'sampling_ratio': 0
},
'out_channels': 8,
'featmap_strides': [4]
}
bbox_head = {
'type': 'Shared2FCBBoxHead',
'in_channels': 8,
'fc_out_channels': 1024,
'roi_feat_size': 7,
'num_classes': 80,
'bbox_coder': {
'type': 'DeltaXYWHBBoxCoder',
'target_means': [0.0, 0.0, 0.0, 0.0],
'target_stds': [0.1, 0.1, 0.2, 0.2]
},
'reg_class_agnostic': False,
'loss_cls': {
'type': 'CrossEntropyLoss',
'use_sigmoid': False,
'loss_weight': 1.0
},
'loss_bbox': {
'type': 'L1Loss',
'loss_weight': 1.0
}
}
roi_head = dict(bbox_roi_extractor=bbox_roi_extractor, bbox_head=bbox_head)
model_cfg = mmcv.Config(
dict(
model=dict(
test_cfg=dict(rpn=test_cfg, rcnn=test_cfg),
roi_head=roi_head)))
deploy_cfg = mmcv.Config(
dict(codebase_config=dict(post_processing=post_processing)))
return model_cfg, deploy_cfg
def test_PartitionTwoStageDetector():
model_cfg, deploy_cfg = prepare_model_deploy_cfgs()
from mmdeploy.mmdet.apis.inference import PartitionTwoStageDetector
pts_detector = PartitionTwoStageDetector(['' for i in range(80)],
model_cfg=model_cfg,
deploy_cfg=deploy_cfg,
device_id=0)
feats = [torch.randn(1, 8, 14, 14) for i in range(5)]
scores = torch.rand(1, 50, 1)
bboxes = torch.rand(1, 50, 4)
bboxes[..., 2:4] = 2 * bboxes[..., :2]
results = pts_detector.partition0_postprocess(
x=feats, scores=scores, bboxes=bboxes)
assert results is not None, 'failed to get output using '\
'partition0_postprocess of PartitionTwoStageDetector'
rois = torch.rand(1, 10, 5)
cls_score = torch.rand(10, 81)
bbox_pred = torch.rand(10, 320)
img_metas = [[{
'ori_shape': [32, 32, 3],
'img_shape': [32, 32, 3],
'scale_factor': [2.09, 1.87, 2.09, 1.87],
}]]
results = pts_detector.partition1_postprocess(
rois=rois,
cls_score=cls_score,
bbox_pred=bbox_pred,
img_metas=img_metas)
assert results is not None, 'failed to get output using '\
'partition1_postprocess of PartitionTwoStageDetector'
class DummyPTSDetector(torch.nn.Module):
"""A dummy wrapper for unit tests."""
def __init__(self, *args, **kwargs):
self.output_names = ['dets', 'labels']
def partition0_postprocess(self, *args, **kwargs):
return self.outputs0
def partition1_postprocess(self, *args, **kwargs):
return self.outputs1
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
@pytest.mark.skipif(
not importlib.util.find_spec('tensorrt'), reason='requires tensorrt')
def test_TensorRTPTSDetector():
model_cfg, deploy_cfg = prepare_model_deploy_cfgs()
# force add backend wrapper regardless of plugins
# make sure TensorRTPTSDetector can use TRTWrapper inside itself
from mmdeploy.apis.tensorrt.tensorrt_utils import TRTWrapper
trt_apis.__dict__.update({'TRTWrapper': TRTWrapper})
# simplify backend inference
outputs = {
'scores': torch.rand(1, 12, 80).cuda(),
'boxes': torch.rand(1, 12, 4).cuda(),
'cls_score': torch.rand(1, 12, 80).cuda(),
'bbox_pred': torch.rand(1, 12, 4).cuda()
}
SwitchBackendWrapper.set(TRTWrapper, outputs=outputs)
TRTWrapper.model_cfg = model_cfg
TRTWrapper.deploy_cfg = deploy_cfg
# replace original function in PartitionTwoStageDetector
from mmdeploy.mmdet.apis.inference import PartitionTwoStageDetector
PartitionTwoStageDetector.__init__ = DummyPTSDetector.__init__
PartitionTwoStageDetector.partition0_postprocess = \
DummyPTSDetector.partition0_postprocess
PartitionTwoStageDetector.partition1_postprocess = \
DummyPTSDetector.partition1_postprocess
PartitionTwoStageDetector.outputs0 = [torch.rand(2, 3).cuda()] * 2
PartitionTwoStageDetector.outputs1 = [
torch.rand(1, 9, 5).cuda(),
torch.rand(1, 9).cuda()
]
PartitionTwoStageDetector.device_id = 0
PartitionTwoStageDetector.CLASSES = ['' for i in range(80)]
from mmdeploy.mmdet.apis.inference import TensorRTPTSDetector
trt_pts_detector = TensorRTPTSDetector(['', ''], ['' for i in range(80)],
model_cfg=model_cfg,
deploy_cfg=deploy_cfg,
device_id=0)
imgs = [torch.rand(1, 3, 32, 32).cuda()]
img_metas = [[{
'ori_shape': [32, 32, 3],
'img_shape': [32, 32, 3],
'scale_factor': [2.09, 1.87, 2.09, 1.87],
}]]
results = trt_pts_detector.forward(imgs, img_metas)
assert results is not None, 'failed to get output using '
'TensorRTPTSDetector'
SwitchBackendWrapper.recover(TRTWrapper)
@pytest.mark.skipif(
not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
def test_ONNXRuntimePTSDetector():
model_cfg, deploy_cfg = prepare_model_deploy_cfgs()
# force add backend wrapper regardless of plugins
# make sure ONNXRuntimePTSDetector can use TRTWrapper inside itself
from mmdeploy.apis.onnxruntime.onnxruntime_utils import ORTWrapper
ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})
# simplify backend inference
outputs = [
np.random.rand(1, 12, 80).astype(np.float32),
np.random.rand(1, 12, 4).astype(np.float32),
] * 2
SwitchBackendWrapper.set(
ORTWrapper,
outputs=outputs,
model_cfg=model_cfg,
deploy_cfg=deploy_cfg)
# replace original function in PartitionTwoStageDetector
from mmdeploy.mmdet.apis.inference import PartitionTwoStageDetector
PartitionTwoStageDetector.__init__ = DummyPTSDetector.__init__
PartitionTwoStageDetector.partition0_postprocess = \
DummyPTSDetector.partition0_postprocess
PartitionTwoStageDetector.partition1_postprocess = \
DummyPTSDetector.partition1_postprocess
PartitionTwoStageDetector.outputs0 = [torch.rand(2, 3)] * 2
PartitionTwoStageDetector.outputs1 = [
torch.rand(1, 9, 5), torch.rand(1, 9)
]
PartitionTwoStageDetector.device_id = -1
PartitionTwoStageDetector.CLASSES = ['' for i in range(80)]
from mmdeploy.mmdet.apis.inference import ONNXRuntimePTSDetector
ort_pts_detector = ONNXRuntimePTSDetector(['', ''],
['' for i in range(80)],
model_cfg=model_cfg,
deploy_cfg=deploy_cfg,
device_id=0)
imgs = [torch.rand(1, 3, 32, 32)]
img_metas = [[{
'ori_shape': [32, 32, 3],
'img_shape': [32, 32, 3],
'scale_factor': [2.09, 1.87, 2.09, 1.87],
}]]
results = ort_pts_detector.forward(imgs, img_metas)
assert results is not None, 'failed to get output using '
'ONNXRuntimePTSDetector'
SwitchBackendWrapper.recover(ORTWrapper)
@pytest.mark.skipif(
not importlib.util.find_spec('ncnn'), reason='requires ncnn')
def test_NCNNPTSDetector():
model_cfg, deploy_cfg = prepare_model_deploy_cfgs()
num_outs = dict(model=dict(neck=dict(num_outs=0)))
model_cfg.update(num_outs)
# force add backend wrapper regardless of plugins
# make sure NCNNPTSDetector can use TRTWrapper inside itself
from mmdeploy.apis.ncnn.ncnn_utils import NCNNWrapper
ncnn_apis.__dict__.update({'NCNNWrapper': NCNNWrapper})
# simplify backend inference
outputs = {
'scores': torch.rand(1, 12, 80),
'boxes': torch.rand(1, 12, 4),
'cls_score': torch.rand(1, 12, 80),
'bbox_pred': torch.rand(1, 12, 4)
}
SwitchBackendWrapper.set(
NCNNWrapper,
outputs=outputs,
model_cfg=model_cfg,
deploy_cfg=deploy_cfg)
# replace original function in PartitionTwoStageDetector
from mmdeploy.mmdet.apis.inference import PartitionTwoStageDetector
PartitionTwoStageDetector.__init__ = DummyPTSDetector.__init__
PartitionTwoStageDetector.partition0_postprocess = \
DummyPTSDetector.partition0_postprocess
PartitionTwoStageDetector.partition1_postprocess = \
DummyPTSDetector.partition1_postprocess
PartitionTwoStageDetector.outputs0 = [torch.rand(2, 3)] * 2
PartitionTwoStageDetector.outputs1 = [
torch.rand(1, 9, 5), torch.rand(1, 9)
]
PartitionTwoStageDetector.device_id = -1
PartitionTwoStageDetector.CLASSES = ['' for i in range(80)]
from mmdeploy.mmdet.apis.inference import NCNNPTSDetector
ncnn_pts_detector = NCNNPTSDetector(
[''] * 4, [''] * 80,
model_cfg=model_cfg,
deploy_cfg=deploy_cfg,
device_id=0)
imgs = [torch.rand(1, 3, 32, 32)]
img_metas = [[{
'ori_shape': [32, 32, 3],
'img_shape': [32, 32, 3],
'scale_factor': [2.09, 1.87, 2.09, 1.87],
}]]
results = ncnn_pts_detector.forward(imgs, img_metas)
assert results is not None, 'failed to get output using '
'NCNNPTSDetector'
SwitchBackendWrapper.recover(NCNNWrapper)
@pytest.mark.skipif(
not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
def test_build_detector():
_, post_processing = get_test_cfg_and_post_processing()
model_cfg = mmcv.Config(dict(data=dict(test={'type': 'CocoDataset'})))
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type='onnxruntime'),
codebase_config=dict(
type='mmdet', post_processing=post_processing)))
from mmdeploy.apis.onnxruntime.onnxruntime_utils import ORTWrapper
ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})
# simplify backend inference
SwitchBackendWrapper.set(
ORTWrapper, model_cfg=model_cfg, deploy_cfg=deploy_cfg)
from mmdeploy.apis.utils import init_backend_model
detector = init_backend_model([''], model_cfg, deploy_cfg, -1)
assert detector is not None
SwitchBackendWrapper.recover(ORTWrapper)

View File

@ -0,0 +1,147 @@
import importlib
import mmcv
import numpy as np
import pytest
import torch
from mmdeploy.utils.test import WrapFunction, get_rewrite_outputs
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
@pytest.mark.skipif(
not importlib.util.find_spec('tensorrt'), reason='requires tensorrt')
def test_multiclass_nms_static():
import tensorrt as trt
from mmdeploy.mmdet.core import multiclass_nms
deploy_cfg = mmcv.Config(
dict(
onnx_config=dict(
output_names=['dets', 'labels'], input_shape=None),
backend_config=dict(
type='tensorrt',
common_config=dict(
fp16_mode=False,
log_level=trt.Logger.INFO,
max_workspace_size=1 << 20),
model_inputs=[
dict(
input_shapes=dict(
boxes=dict(
min_shape=[1, 5, 4],
opt_shape=[1, 5, 4],
max_shape=[1, 5, 4]),
scores=dict(
min_shape=[1, 5, 8],
opt_shape=[1, 5, 8],
max_shape=[1, 5, 8])))
]),
codebase_config=dict(
type='mmdet',
task='ObjectDetection',
post_processing=dict(
score_threshold=0.05,
iou_threshold=0.5,
max_output_boxes_per_class=20,
pre_top_k=-1,
keep_top_k=10,
background_label_id=-1,
))))
boxes = torch.rand(1, 5, 4).cuda()
scores = torch.rand(1, 5, 8).cuda()
max_output_boxes_per_class = 20
keep_top_k = 10
wrapped_func = WrapFunction(
multiclass_nms,
max_output_boxes_per_class=max_output_boxes_per_class,
keep_top_k=keep_top_k)
rewrite_outputs = get_rewrite_outputs(
wrapped_func,
model_inputs={
'boxes': boxes,
'scores': scores
},
deploy_cfg=deploy_cfg)
assert rewrite_outputs is not None, 'Got unexpected rewrite '\
'outputs: {}'.format(rewrite_outputs)
@pytest.mark.parametrize('backend_type', ['onnxruntime', 'ncnn'])
def test_delta2bbox(backend_type):
pytest.importorskip(backend_type, reason=f'requires {backend_type}')
deploy_cfg = mmcv.Config(
dict(
onnx_config=dict(
output_names=['dets', 'labels'], input_shape=None),
backend_config=dict(type=backend_type, model_inputs=None),
codebase_config=dict(type='mmdet', task='ObjectDetection')))
# wrap function to enable rewrite
def delta2bbox(*args, **kwargs):
import mmdet
return mmdet.core.bbox.coder.delta_xywh_bbox_coder.delta2bbox(
*args, **kwargs)
rois = torch.rand(1, 5, 4)
deltas = torch.rand(1, 5, 4)
original_outputs = delta2bbox(rois, deltas)
# wrap function to nn.Module, enable torch.onn.export
wrapped_func = WrapFunction(delta2bbox)
rewrite_outputs = get_rewrite_outputs(
wrapped_func,
model_inputs={
'rois': rois,
'deltas': deltas
},
deploy_cfg=deploy_cfg)
model_output = original_outputs.squeeze().cpu().numpy()
rewrite_output = rewrite_outputs[0].squeeze()
assert np.allclose(model_output, rewrite_output, rtol=1e-03, atol=1e-05)
@pytest.mark.parametrize('backend_type', ['onnxruntime', 'ncnn'])
def test_tblr2bbox(backend_type):
pytest.importorskip(backend_type, reason=f'requires {backend_type}')
deploy_cfg = mmcv.Config(
dict(
onnx_config=dict(
output_names=['dets', 'labels'], input_shape=None),
backend_config=dict(type=backend_type, model_inputs=None),
codebase_config=dict(type='mmdet', task='ObjectDetection')))
# wrap function to enable rewrite
def tblr2bboxes(*args, **kwargs):
import mmdet
return mmdet.core.bbox.coder.tblr_bbox_coder.tblr2bboxes(
*args, **kwargs)
priors = torch.rand(1, 5, 4)
tblr = torch.rand(1, 5, 4)
original_outputs = tblr2bboxes(priors, tblr)
# wrap function to nn.Module, enable torch.onn.export
wrapped_func = WrapFunction(tblr2bboxes)
rewrite_outputs = get_rewrite_outputs(
wrapped_func,
model_inputs={
'priors': priors,
'tblr': tblr
},
deploy_cfg=deploy_cfg)
model_output = original_outputs.squeeze().cpu().numpy()
rewrite_output = rewrite_outputs[0].squeeze()
assert np.allclose(model_output, rewrite_output, rtol=1e-03, atol=1e-05)
def test_distance2bbox():
from mmdeploy.mmdet.core import distance2bbox
points = torch.rand(3, 2)
distance = torch.rand(3, 4)
bbox = distance2bbox(points, distance)
assert bbox.shape == torch.Size([3, 4])

View File

@ -0,0 +1,104 @@
import mmcv
import numpy as np
import pytest
import torch
from mmdeploy.apis.utils import (build_dataloader, build_dataset, create_input,
get_tensor_from_input)
from mmdeploy.utils.constants import Codebase, Task
def test_create_input():
task = Task.OBJECT_DETECTION
test = dict(pipeline=[{
'type': 'LoadImageFromWebcam'
}, {
'type':
'MultiScaleFlipAug',
'img_scale': [32, 32],
'flip':
False,
'transforms': [{
'type': 'Resize',
'keep_ratio': True
}, {
'type': 'RandomFlip'
}, {
'type': 'Normalize',
'mean': [123.675, 116.28, 103.53],
'std': [58.395, 57.12, 57.375],
'to_rgb': True
}, {
'type': 'Pad',
'size_divisor': 32
}, {
'type': 'DefaultFormatBundle'
}, {
'type': 'Collect',
'keys': ['img']
}]
}])
data = dict(test=test)
model_cfg = mmcv.Config(dict(data=data))
imgs = [np.random.rand(32, 32, 3)]
inputs = create_input(
Codebase.MMDET,
task,
model_cfg,
imgs,
input_shape=(32, 32),
device='cpu')
assert inputs is not None, 'Failed to create input'
@pytest.mark.parametrize('input_data', [{'img': [torch.ones(3, 4, 5)]}])
def test_get_tensor_from_input(input_data):
inputs = get_tensor_from_input(Codebase.MMDET, input_data)
assert inputs is not None, 'Failed to get tensor from input'
def test_build_dataset():
data = dict(
test={
'type': 'CocoDataset',
'ann_file': 'tests/test_mmdet/data/coco_sample.json',
'img_prefix': 'tests/test_mmdet/data/imgs/',
'pipeline': [
{
'type': 'LoadImageFromFile'
},
]
})
dataset_cfg = mmcv.Config(dict(data=data))
dataset = build_dataset(
Codebase.MMDET, dataset_cfg=dataset_cfg, dataset_type='test')
assert dataset is not None, 'Failed to build dataset'
dataloader = build_dataloader(Codebase.MMDET, dataset, 1, 1)
assert dataloader is not None, 'Failed to build dataloader'
def test_clip_bboxes():
from mmdeploy.mmdet.export import clip_bboxes
x1 = torch.rand(3, 2) * 224
y1 = torch.rand(3, 2) * 224
x2 = x1 * 2
y2 = y1 * 2
outs = clip_bboxes(x1, y1, x2, y2, [224, 224])
for out in outs:
assert int(out.max()) <= 224
def test_pad_with_value():
from mmdeploy.mmdet.export import pad_with_value
x = torch.rand(3, 2)
padded_x = pad_with_value(x, pad_dim=1, pad_size=4, pad_value=0)
assert np.allclose(
padded_x.shape, torch.Size([3, 6]), rtol=1e-03, atol=1e-05)
assert np.allclose(padded_x.sum(), x.sum(), rtol=1e-03, atol=1e-05)
@pytest.mark.parametrize('partition_type', ['single_stage', 'two_stage'])
def test_get_partition_cfg(partition_type):
from mmdeploy.mmdet.export import get_partition_cfg
partition_cfg = get_partition_cfg(partition_type=partition_type)
assert partition_cfg is not None

View File

@ -0,0 +1,293 @@
import copy
import importlib
import os
import random
import tempfile
import mmcv
import numpy as np
import pytest
import torch
from mmdeploy.utils.constants import Backend, Codebase
from mmdeploy.utils.test import (WrapModel, get_model_outputs,
get_rewrite_outputs)
def seed_everything(seed=1029):
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.enabled = False
def get_anchor_head_model():
"""AnchorHead Config."""
test_cfg = mmcv.Config(
dict(
deploy_nms_pre=0,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.5),
max_per_img=100))
from mmdet.models import AnchorHead
model = AnchorHead(num_classes=4, in_channels=1, test_cfg=test_cfg)
model.requires_grad_(False)
return model
def get_fcos_head_model():
"""FCOS Head Config."""
test_cfg = mmcv.Config(
dict(
deploy_nms_pre=0,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.5),
max_per_img=100))
from mmdet.models import FCOSHead
model = FCOSHead(num_classes=4, in_channels=1, test_cfg=test_cfg)
model.requires_grad_(False)
return model
def get_rpn_head_model():
"""RPN Head Config."""
test_cfg = mmcv.Config(
dict(
deploy_nms_pre=0,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.5),
max_per_img=100))
from mmdet.models import RPNHead
model = RPNHead(in_channels=1, test_cfg=test_cfg)
model.requires_grad_(False)
return model
@pytest.mark.parametrize('backend_type', ['onnxruntime', 'ncnn'])
def test_anchor_head_get_bboxes(backend_type):
"""Test get_bboxes rewrite of anchor head."""
pytest.importorskip(backend_type, reason=f'requires {backend_type}')
anchor_head = get_anchor_head_model()
anchor_head.cpu().eval()
s = 128
img_metas = [{
'scale_factor': np.ones(4),
'pad_shape': (s, s, 3),
'img_shape': (s, s, 3)
}]
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type=backend_type),
onnx_config=dict(
output_names=['dets', 'labels'], input_shape=None),
codebase_config=dict(
type='mmdet',
task='ObjectDetection',
post_processing=dict(
score_threshold=0.05,
iou_threshold=0.5,
max_output_boxes_per_class=200,
pre_top_k=-1,
keep_top_k=100,
background_label_id=-1,
))))
# the cls_score's size: (1, 36, 32, 32), (1, 36, 16, 16),
# (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2).
# the bboxes's size: (1, 36, 32, 32), (1, 36, 16, 16),
# (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2)
seed_everything(1234)
cls_score = [
torch.rand(1, 36, pow(2, i), pow(2, i)) for i in range(5, 0, -1)
]
seed_everything(5678)
bboxes = [torch.rand(1, 36, pow(2, i), pow(2, i)) for i in range(5, 0, -1)]
# to get outputs of pytorch model
model_inputs = {
'cls_scores': cls_score,
'bbox_preds': bboxes,
'img_metas': img_metas
}
model_outputs = get_model_outputs(anchor_head, 'get_bboxes', model_inputs)
# to get outputs of onnx model after rewrite
img_metas[0]['img_shape'] = torch.Tensor([s, s])
wrapped_model = WrapModel(
anchor_head, 'get_bboxes', img_metas=img_metas[0], with_nms=True)
rewrite_inputs = {
'cls_scores': cls_score,
'bbox_preds': bboxes,
}
rewrite_outputs = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
for model_output, rewrite_output in zip(model_outputs[0], rewrite_outputs):
model_output = model_output.squeeze().cpu().numpy()
rewrite_output = rewrite_output.squeeze()
# hard code to make two tensors with the same shape
# rewrite and original codes applied different nms strategy
assert np.allclose(
model_output[:rewrite_output.shape[0]],
rewrite_output,
rtol=1e-03,
atol=1e-05)
@pytest.mark.parametrize('backend_type', ['onnxruntime', 'ncnn'])
def test_get_bboxes_of_fcos_head(backend_type):
pytest.importorskip(backend_type, reason=f'requires {backend_type}')
fcos_head = get_fcos_head_model()
fcos_head.cpu().eval()
s = 128
img_metas = [{
'scale_factor': np.ones(4),
'pad_shape': (s, s, 3),
'img_shape': (s, s, 3)
}]
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type=backend_type),
onnx_config=dict(
output_names=['dets', 'labels'], input_shape=None),
codebase_config=dict(
type='mmdet',
task='ObjectDetection',
post_processing=dict(
score_threshold=0.05,
iou_threshold=0.5,
max_output_boxes_per_class=200,
pre_top_k=-1,
keep_top_k=100,
background_label_id=-1,
))))
# the cls_score's size: (1, 36, 32, 32), (1, 36, 16, 16),
# (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2).
# the bboxes's size: (1, 36, 32, 32), (1, 36, 16, 16),
# (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2)
seed_everything(1234)
cls_score = [
torch.rand(1, fcos_head.num_classes, pow(2, i), pow(2, i))
for i in range(5, 0, -1)
]
seed_everything(5678)
bboxes = [torch.rand(1, 4, pow(2, i), pow(2, i)) for i in range(5, 0, -1)]
seed_everything(9101)
centernesses = [
torch.rand(1, 1, pow(2, i), pow(2, i)) for i in range(5, 0, -1)
]
# to get outputs of pytorch model
model_inputs = {
'cls_scores': cls_score,
'bbox_preds': bboxes,
'centernesses': centernesses,
'img_metas': img_metas
}
model_outputs = get_model_outputs(fcos_head, 'get_bboxes', model_inputs)
# to get outputs of onnx model after rewrite
img_metas[0]['img_shape'] = torch.Tensor([s, s])
wrapped_model = WrapModel(
fcos_head, 'get_bboxes', img_metas=img_metas[0], with_nms=True)
rewrite_inputs = {
'cls_scores': cls_score,
'bbox_preds': bboxes,
'centernesses': centernesses
}
rewrite_outputs = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
for model_output, rewrite_output in zip(model_outputs[0], rewrite_outputs):
model_output = model_output.squeeze().cpu().numpy()
rewrite_output = rewrite_output.squeeze()
# hard code to make two tensors with the same shape
# rewrite and original codes applied different nms strategy
assert np.allclose(
model_output[:rewrite_output.shape[0]],
rewrite_output,
rtol=1e-03,
atol=1e-05)
def _replace_r50_with_r18(model):
"""Replace ResNet50 with ResNet18 in config."""
model = copy.deepcopy(model)
if model.backbone.type == 'ResNet':
model.backbone.depth = 18
model.backbone.base_channels = 2
model.neck.in_channels = [2, 4, 8, 16]
return model
@pytest.mark.parametrize('model_cfg_path', [
'tests/test_mmdet/data/single_stage_model.json',
'tests/test_mmdet/data/mask_model.json'
])
@pytest.mark.skipif(
not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
def test_forward_of_base_detector_and_visualize(model_cfg_path):
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type='onnxruntime'),
onnx_config=dict(
output_names=['dets', 'labels'], input_shape=None),
codebase_config=dict(
type='mmdet',
task='ObjectDetection',
post_processing=dict(
score_threshold=0.05,
iou_threshold=0.5,
max_output_boxes_per_class=200,
pre_top_k=-1,
keep_top_k=100,
background_label_id=-1,
))))
model_cfg = mmcv.Config(dict(model=mmcv.load(model_cfg_path)))
model_cfg.model = _replace_r50_with_r18(model_cfg.model)
from mmdet.apis import init_detector
model = init_detector(model_cfg, None, 'cpu')
img = torch.randn(1, 3, 64, 64)
rewrite_inputs = {'img': img}
rewrite_outputs = get_rewrite_outputs(
wrapped_model=model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
from mmdeploy.apis.utils import visualize
output_file = tempfile.NamedTemporaryFile(suffix='.jpg').name
model.CLASSES = [''] * 80
visualize(
Codebase.MMDET,
img.squeeze().permute(1, 2, 0).numpy(),
result=[torch.rand(0, 5).numpy()] * 80,
model=model,
output_file=output_file,
backend=Backend.ONNXRUNTIME,
show_result=False)
assert rewrite_outputs is not None

View File

@ -1,100 +0,0 @@
import os.path as osp
import mmcv
import numpy as np
import torch
from mmdeploy.utils.test import (WrapModel, get_model_outputs,
get_rewrite_outputs)
data_path = osp.join(osp.dirname(__file__), 'data')
def get_anchor_head_model():
"""AnchorHead Config."""
test_cfg = mmcv.Config(
dict(
deploy_nms_pre=0,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.5),
max_per_img=100))
from mmdet.models import AnchorHead
model = AnchorHead(num_classes=4, in_channels=1, test_cfg=test_cfg)
model.requires_grad_(False)
return model
def test_anchor_head_get_bboxes():
"""Test get_bboxes rewrite of anchor head."""
anchor_head = get_anchor_head_model()
anchor_head.cpu().eval()
s = 128
img_metas = [{
'scale_factor': np.ones(4),
'pad_shape': (s, s, 3),
'img_shape': (s, s, 2)
}]
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type='onnxruntime'),
onnx_config=dict(
output_names=['dets', 'labels'], input_shape=None),
codebase_config=dict(
type='mmdet',
task='ObjectDetection',
post_processing=dict(
score_threshold=0.05,
iou_threshold=0.5,
max_output_boxes_per_class=200,
pre_top_k=-1,
keep_top_k=100,
background_label_id=-1,
))))
# The data of anchor_head_get_bboxes.pkl contains two parts:
# cls_score(list(Tensor)) and bboxes(list(Tensor)),
# where each torch.Tensor is generated by torch.rand().
# the cls_score's size: (1, 36, 32, 32), (1, 36, 16, 16),
# (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2).
# the bboxes's size: (1, 36, 32, 32), (1, 36, 16, 16),
# (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2)
retina_head_data = 'anchor_head_get_bboxes.pkl'
feats = mmcv.load(osp.join(data_path, retina_head_data))
cls_score = feats[:5]
bboxes = feats[5:]
# to get outputs of pytorch model
model_inputs = {
'cls_scores': cls_score,
'bbox_preds': bboxes,
'img_metas': img_metas
}
model_outputs = get_model_outputs(anchor_head, 'get_bboxes', model_inputs)
# to get outputs of onnx model after rewrite
img_metas[0]['img_shape'] = torch.Tensor([s, s])
wrapped_model = WrapModel(
anchor_head, 'get_bboxes', img_metas=img_metas[0], with_nms=True)
rewrite_inputs = {
'cls_scores': cls_score,
'bbox_preds': bboxes,
}
rewrite_outputs = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
for model_output, rewrite_output in zip(model_outputs[0], rewrite_outputs):
model_output = model_output.squeeze().cpu().numpy()
rewrite_output = rewrite_output.squeeze()
# hard code to make two tensors with the same shape
# rewrite and original codes applied different nms strategy
assert np.allclose(
model_output[:rewrite_output.shape[0]],
rewrite_output,
rtol=1e-03,
atol=1e-05)