[Unittests] MMDet unittests (#112)

* add mmdet unittests

* remove redundant img meta info

* import wrapper from backend util

* force add wrapper

* use smaller nuance

* add seed everything

* add create input part and some inference part unittests

* fix lint

* skip ppl

* remove pyppl

* add dataset files

* import ncnn_ext inside ncnn warpper

* use cpu device to create input

* add pssd and ptsd unittests

* clear mmdet unittests

* wrap function to enable rewrite

* refine codes and resolve comments

* move mmdet inside test func

* remove redundant line

* test visualize in mmdeploy.apis

* use less memory

* resolve comments

* fix ci

* move pytest.skip inside test function and use 3 channel input
This commit is contained in:
AllentDan 2021-10-13 17:24:11 +08:00 committed by GitHub
parent 6fdf6b8616
commit d4828c7836
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 1549 additions and 184 deletions

View File

@ -1,10 +1,10 @@
import importlib
from typing import Dict, Iterable, Optional from typing import Dict, Iterable, Optional
import ncnn import ncnn
import numpy as np import numpy as np
import torch import torch
from mmdeploy.apis.ncnn import ncnn_ext
from mmdeploy.utils.timer import TimeCounter from mmdeploy.utils.timer import TimeCounter
@ -37,7 +37,9 @@ class NCNNWrapper(torch.nn.Module):
super(NCNNWrapper, self).__init__() super(NCNNWrapper, self).__init__()
net = ncnn.Net() net = ncnn.Net()
ncnn_ext.register_mm_custom_layers(net) if importlib.util.find_spec('mmdeploy.apis.ncnn.ncnn_ext'):
from mmdeploy.apis.ncnn import ncnn_ext
ncnn_ext.register_mm_custom_layers(net)
net.load_param(param_file) net.load_param(param_file)
net.load_model(bin_file) net.load_model(bin_file)

View File

@ -333,12 +333,12 @@ def build_dataloader(codebase: Codebase, dataset: Dataset,
raise NotImplementedError(f'Unknown codebase type: {codebase.value}') raise NotImplementedError(f'Unknown codebase type: {codebase.value}')
def get_tensor_from_input(codebase: Codebase, input_data: tuple): def get_tensor_from_input(codebase: Codebase, input_data: Dict[str, Any]):
"""Get input tensor from input data. """Get input tensor from input data.
Args: Args:
codebase (Codebase): Specifying codebase type. codebase (Codebase): Specifying codebase type.
input_data (tuple): Input data containing meta info and image tensor. input_data (dict): Input data containing meta info and image tensor.
Returns: Returns:
torch.Tensor: An image in `Tensor`. torch.Tensor: An image in `Tensor`.

View File

@ -48,13 +48,13 @@ def tblr2bboxes(ctx,
@FUNCTION_REWRITER.register_rewriter( @FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.core.bbox.coder.tblr_bbox_coder.tblr2bboxes', func_name='mmdet.core.bbox.coder.tblr_bbox_coder.tblr2bboxes',
backend='ncnn') backend='ncnn')
def delta2bbox_ncnn(ctx, def tblr2bboxes_ncnn(ctx,
priors, priors,
tblr, tblr,
normalizer=4.0, normalizer=4.0,
normalize_by_wh=True, normalize_by_wh=True,
max_shape=None, max_shape=None,
clip_border=True): clip_border=True):
"""Rewrite for ONNX exporting of NCNN backend.""" """Rewrite for ONNX exporting of NCNN backend."""
assert priors.size(0) == tblr.size(0) assert priors.size(0) == tblr.size(0)
if priors.ndim == 3: if priors.ndim == 3:

View File

@ -1,4 +1,4 @@
from typing import Any, Optional, Sequence, Union from typing import Any, Dict, Optional, Sequence, Union
import mmcv import mmcv
import numpy as np import numpy as np
@ -151,11 +151,11 @@ def build_dataloader(dataset: Dataset,
**kwargs) **kwargs)
def get_tensor_from_input(input_data: tuple): def get_tensor_from_input(input_data: Dict[str, Any]):
"""Get input tensor from input data. """Get input tensor from input data.
Args: Args:
input_data (tuple): Input data containing meta info and image tensor. input_data (dict): Input data containing meta info and image tensor.
Returns: Returns:
torch.Tensor: An image in `Tensor`. torch.Tensor: An image in `Tensor`.
""" """

View File

@ -40,8 +40,13 @@ def simple_test_bboxes_of_bbox_test_mixin(ctx, self, x, img_metas, proposals,
MaskTestMixin.simple_test_mask') MaskTestMixin.simple_test_mask')
def simple_test_mask_of_mask_test_mixin(ctx, self, x, img_metas, det_bboxes, def simple_test_mask_of_mask_test_mixin(ctx, self, x, img_metas, det_bboxes,
det_labels, **kwargs): det_labels, **kwargs):
assert det_bboxes.shape[1] != 0, 'Can not record MaskHead as it \ if det_bboxes.shape[1] == 0:
has not been executed this time' bboxes_shape, labels_shape = list(det_bboxes.shape), list(
det_labels.shape)
bboxes_shape[1], labels_shape[1] = 1, 1
det_bboxes = torch.tensor([[[0., 0., 1., 1.,
0.]]]).expand(*bboxes_shape)
det_labels = torch.tensor([[0]]).expand(*labels_shape)
batch_size = det_bboxes.size(0) batch_size = det_bboxes.size(0)
det_bboxes = det_bboxes[..., :4] det_bboxes = det_bboxes[..., :4]

View File

@ -11,6 +11,8 @@ def size_of_tensor_static(ctx, self, *args):
ret = ctx.origin_func(self, *args) ret = ctx.origin_func(self, *args)
if isinstance(ret, torch.Tensor): if isinstance(ret, torch.Tensor):
ret = int(ret) ret = int(ret)
elif isinstance(ret, int):
return (ret)
else: else:
ret = [int(r) for r in ret] ret = [int(r) for r in ret]
ret = tuple(ret) ret = tuple(ret)

View File

@ -57,6 +57,54 @@ class WrapModel(nn.Module):
return func(*args, **kwargs) return func(*args, **kwargs)
class SwitchBackendWrapper:
"""A switcher for backend wrapper for unit tests.
Examples:
>>> from mmdeploy.utils.test import SwitchBackendWrapper
>>> from mmdeploy.apis.onnxruntime.onnxruntime_utils import ORTWrapper
>>> SwitchBackendWrapper.set(ORTWrapper, outputs=outputs)
>>> ...
>>> SwitchBackendWrapper.recover(ORTWrapper)
"""
init = None
forward = None
call = None
class BackendWrapper(torch.nn.Module):
"""A dummy wrapper for unit tests."""
def __init__(self, *args, **kwargs):
self.output_names = ['dets', 'labels']
def forward(self, *args, **kwargs):
return self.outputs
def __call__(self, *args, **kwds):
return self.forward(*args, **kwds)
@staticmethod
def set(obj, **kwargs):
"""Replace attributes in backend wrappers with dummy items."""
SwitchBackendWrapper.init = obj.__init__
SwitchBackendWrapper.forward = obj.forward
SwitchBackendWrapper.call = obj.__call__
obj.__init__ = SwitchBackendWrapper.BackendWrapper.__init__
obj.forward = SwitchBackendWrapper.BackendWrapper.forward
obj.__call__ = SwitchBackendWrapper.BackendWrapper.__call__
for k, v in kwargs.items():
setattr(obj, k, v)
@staticmethod
def recover(obj):
assert SwitchBackendWrapper.init is not None and \
SwitchBackendWrapper.forward is not None,\
'recover method must be called after exchange'
obj.__init__ = SwitchBackendWrapper.init
obj.forward = SwitchBackendWrapper.forward
obj.__call__ = SwitchBackendWrapper.call
def assert_allclose(expected: List[Union[torch.Tensor, np.ndarray]], def assert_allclose(expected: List[Union[torch.Tensor, np.ndarray]],
actual: List[Union[torch.Tensor, np.ndarray]], actual: List[Union[torch.Tensor, np.ndarray]],
tolerate_small_mismatch: bool = False): tolerate_small_mismatch: bool = False):
@ -180,6 +228,9 @@ def get_rewrite_outputs(wrapped_model: nn.Module, model_inputs: dict,
backend_feats[input_names[i]] = feature_list[i] backend_feats[input_names[i]] = feature_list[i]
else: else:
backend_feats[str(i)] = feature_list[i] backend_feats[str(i)] = feature_list[i]
elif backend == Backend.NCNN:
return ctx_outputs
with torch.no_grad(): with torch.no_grad():
backend_outputs = backend_model.forward(backend_feats) backend_outputs = backend_model.forward(backend_feats)
return backend_outputs return backend_outputs

View File

@ -0,0 +1,39 @@
{
"images": [
{
"file_name": "000000000139.jpg",
"height": 800,
"width": 800,
"id": 0
}
],
"annotations": [
{
"bbox": [
0,
0,
20,
20
],
"area": 400.00,
"score": 1.0,
"category_id": 1,
"id": 1,
"image_id": 0
}
],
"categories": [
{
"id": 1,
"name": "bus",
"supercategory": "none"
},
{
"id": 2,
"name": "car",
"supercategory": "none"
}
],
"licenses": [],
"info": null
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 158 KiB

View File

@ -0,0 +1,231 @@
{
"type": "MaskRCNN",
"backbone": {
"type": "ResNet",
"depth": 50,
"num_stages": 4,
"out_indices": [
0,
1,
2,
3
],
"frozen_stages": 1,
"norm_cfg": {
"type": "BN",
"requires_grad": true
},
"norm_eval": true,
"style": "pytorch",
"init_cfg": {
"type": "Pretrained",
"checkpoint": "torchvision://resnet50"
}
},
"neck": {
"type": "FPN",
"in_channels": [
256,
512,
1024,
2048
],
"out_channels": 256,
"num_outs": 5
},
"rpn_head": {
"type": "RPNHead",
"in_channels": 256,
"feat_channels": 256,
"anchor_generator": {
"type": "AnchorGenerator",
"scales": [
8
],
"ratios": [
0.5,
1.0,
2.0
],
"strides": [
4,
8,
16,
32,
64
]
},
"bbox_coder": {
"type": "DeltaXYWHBBoxCoder",
"target_means": [
0.0,
0.0,
0.0,
0.0
],
"target_stds": [
1.0,
1.0,
1.0,
1.0
]
},
"loss_cls": {
"type": "CrossEntropyLoss",
"use_sigmoid": true,
"loss_weight": 1.0
},
"loss_bbox": {
"type": "L1Loss",
"loss_weight": 1.0
}
},
"roi_head": {
"type": "StandardRoIHead",
"bbox_roi_extractor": {
"type": "SingleRoIExtractor",
"roi_layer": {
"type": "RoIAlign",
"output_size": 7,
"sampling_ratio": 0
},
"out_channels": 256,
"featmap_strides": [
4,
8,
16,
32
]
},
"bbox_head": {
"type": "Shared2FCBBoxHead",
"in_channels": 256,
"fc_out_channels": 1024,
"roi_feat_size": 7,
"num_classes": 80,
"bbox_coder": {
"type": "DeltaXYWHBBoxCoder",
"target_means": [
0.0,
0.0,
0.0,
0.0
],
"target_stds": [
0.1,
0.1,
0.2,
0.2
]
},
"reg_class_agnostic": false,
"loss_cls": {
"type": "CrossEntropyLoss",
"use_sigmoid": false,
"loss_weight": 1.0
},
"loss_bbox": {
"type": "L1Loss",
"loss_weight": 1.0
}
},
"mask_roi_extractor": {
"type": "SingleRoIExtractor",
"roi_layer": {
"type": "RoIAlign",
"output_size": 14,
"sampling_ratio": 0
},
"out_channels": 256,
"featmap_strides": [
4,
8,
16,
32
]
},
"mask_head": {
"type": "FCNMaskHead",
"num_convs": 4,
"in_channels": 256,
"conv_out_channels": 256,
"num_classes": 80,
"loss_mask": {
"type": "CrossEntropyLoss",
"use_mask": true,
"loss_weight": 1.0
}
}
},
"train_cfg": {
"rpn": {
"assigner": {
"type": "MaxIoUAssigner",
"pos_iou_thr": 0.7,
"neg_iou_thr": 0.3,
"min_pos_iou": 0.3,
"match_low_quality": true,
"ignore_iof_thr": -1
},
"sampler": {
"type": "RandomSampler",
"num": 256,
"pos_fraction": 0.5,
"neg_pos_ub": -1,
"add_gt_as_proposals": false
},
"allowed_border": -1,
"pos_weight": -1,
"debug": false
},
"rpn_proposal": {
"nms_pre": 2000,
"max_per_img": 1000,
"nms": {
"type": "nms",
"iou_threshold": 0.7
},
"min_bbox_size": 0
},
"rcnn": {
"assigner": {
"type": "MaxIoUAssigner",
"pos_iou_thr": 0.5,
"neg_iou_thr": 0.5,
"min_pos_iou": 0.5,
"match_low_quality": true,
"ignore_iof_thr": -1
},
"sampler": {
"type": "RandomSampler",
"num": 512,
"pos_fraction": 0.25,
"neg_pos_ub": -1,
"add_gt_as_proposals": true
},
"mask_size": 28,
"pos_weight": -1,
"debug": false
}
},
"test_cfg": {
"rpn": {
"nms_pre": 1000,
"max_per_img": 1000,
"nms": {
"type": "nms",
"iou_threshold": 0.7
},
"min_bbox_size": 0
},
"rcnn": {
"score_thr": 0.05,
"nms": {
"type": "nms",
"iou_threshold": 0.5
},
"max_per_img": 100,
"mask_thr_binary": 0.5
}
}
}

View File

@ -0,0 +1,110 @@
{
"type": "RetinaNet",
"backbone": {
"type": "ResNet",
"depth": 50,
"num_stages": 4,
"out_indices": [
0,
1,
2,
3
],
"frozen_stages": 1,
"norm_cfg": {
"type": "BN",
"requires_grad": true
},
"norm_eval": true,
"style": "pytorch",
"init_cfg": {
"type": "Pretrained",
"checkpoint": "torchvision://resnet50"
}
},
"neck": {
"type": "FPN",
"in_channels": [
256,
512,
1024,
2048
],
"out_channels": 256,
"start_level": 1,
"add_extra_convs": "on_input",
"num_outs": 5
},
"bbox_head": {
"type": "RetinaHead",
"num_classes": 80,
"in_channels": 256,
"stacked_convs": 4,
"feat_channels": 256,
"anchor_generator": {
"type": "AnchorGenerator",
"octave_base_scale": 4,
"scales_per_octave": 3,
"ratios": [
0.5,
1.0,
2.0
],
"strides": [
8,
16,
32,
64,
128
]
},
"bbox_coder": {
"type": "DeltaXYWHBBoxCoder",
"target_means": [
0.0,
0.0,
0.0,
0.0
],
"target_stds": [
1.0,
1.0,
1.0,
1.0
]
},
"loss_cls": {
"type": "FocalLoss",
"use_sigmoid": true,
"gamma": 2.0,
"alpha": 0.25,
"loss_weight": 1.0
},
"loss_bbox": {
"type": "L1Loss",
"loss_weight": 1.0
}
},
"train_cfg": {
"assigner": {
"type": "MaxIoUAssigner",
"pos_iou_thr": 0.5,
"neg_iou_thr": 0.4,
"min_pos_iou": 0,
"ignore_iof_thr": -1
},
"allowed_border": -1,
"pos_weight": -1,
"debug": false
},
"test_cfg": {
"nms_pre": 1000,
"min_bbox_size": 0,
"score_thr": 0.05,
"nms": {
"type": "nms",
"iou_threshold": 0.5
},
"max_per_img": 100
}
}

View File

@ -1,68 +0,0 @@
import importlib
import mmcv
import pytest
import torch
from mmdeploy.mmdet.core.post_processing.bbox_nms import multiclass_nms
from mmdeploy.utils.test import WrapFunction, get_rewrite_outputs
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
@pytest.mark.skipif(
not importlib.util.find_spec('tensorrt'), reason='requires tensorrt')
def test_multiclass_nms_static():
import tensorrt as trt
deploy_cfg = mmcv.Config(
dict(
onnx_config=dict(
output_names=['dets', 'labels'], input_shape=None),
backend_config=dict(
type='tensorrt',
common_config=dict(
fp16_mode=False,
log_level=trt.Logger.INFO,
max_workspace_size=1 << 30),
model_inputs=[
dict(
input_shapes=dict(
boxes=dict(
min_shape=[1, 500, 4],
opt_shape=[1, 500, 4],
max_shape=[1, 500, 4]),
scores=dict(
min_shape=[1, 500, 80],
opt_shape=[1, 500, 80],
max_shape=[1, 500, 80])))
]),
codebase_config=dict(
type='mmdet',
task='ObjectDetection',
post_processing=dict(
score_threshold=0.05,
iou_threshold=0.5,
max_output_boxes_per_class=200,
pre_top_k=-1,
keep_top_k=100,
background_label_id=-1,
))))
boxes = torch.rand(1, 500, 4).cuda()
scores = torch.rand(1, 500, 80).cuda()
max_output_boxes_per_class = 200
keep_top_k = 100
wrapped_func = WrapFunction(
multiclass_nms,
max_output_boxes_per_class=max_output_boxes_per_class,
keep_top_k=keep_top_k)
rewrite_outputs = get_rewrite_outputs(
wrapped_func,
model_inputs={
'boxes': boxes,
'scores': scores
},
deploy_cfg=deploy_cfg)
assert rewrite_outputs is not None, 'Got unexpected rewrite '\
'outputs: {}'.format(rewrite_outputs)

View File

@ -0,0 +1,549 @@
import importlib
import mmcv
import numpy as np
import pytest
import torch
import mmdeploy.apis.ncnn as ncnn_apis
import mmdeploy.apis.onnxruntime as ort_apis
import mmdeploy.apis.ppl as ppl_apis
import mmdeploy.apis.tensorrt as trt_apis
from mmdeploy.utils.test import SwitchBackendWrapper
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
@pytest.mark.skipif(
not importlib.util.find_spec('tensorrt'), reason='requires tensorrt')
def test_TensorRTDetector():
# force add backend wrapper regardless of plugins
# make sure TensorRTDetector can use TRTWrapper inside itself
from mmdeploy.apis.tensorrt.tensorrt_utils import TRTWrapper
trt_apis.__dict__.update({'TRTWrapper': TRTWrapper})
# simplify backend inference
outputs = {
'dets': torch.rand(1, 100, 5).cuda(),
'labels': torch.rand(1, 100).cuda()
}
SwitchBackendWrapper.set(TRTWrapper, outputs=outputs)
from mmdeploy.mmdet.apis.inference import TensorRTDetector
trt_detector = TensorRTDetector('', ['' for i in range(80)], 0)
imgs = [torch.rand(1, 3, 64, 64).cuda()]
img_metas = [[{
'ori_shape': [64, 64, 3],
'img_shape': [64, 64, 3],
'scale_factor': [2.09, 1.87, 2.09, 1.87],
}]]
results = trt_detector.forward(imgs, img_metas)
assert results is not None, 'failed to get output using TensorRTDetector'
SwitchBackendWrapper.recover(TRTWrapper)
@pytest.mark.skipif(
not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
def test_ONNXRuntimeDetector():
# force add backend wrapper regardless of plugins
# make sure ONNXRuntimeDetector can use ORTWrapper inside itself
from mmdeploy.apis.onnxruntime.onnxruntime_utils import ORTWrapper
ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})
# simplify backend inference
outputs = (torch.rand(1, 100, 5), torch.rand(1, 100))
SwitchBackendWrapper.set(ORTWrapper, outputs=outputs)
from mmdeploy.mmdet.apis.inference import ONNXRuntimeDetector
ort_detector = ONNXRuntimeDetector('', ['' for i in range(80)], 0)
imgs = [torch.rand(1, 3, 64, 64)]
img_metas = [[{
'ori_shape': [64, 64, 3],
'img_shape': [64, 64, 3],
'scale_factor': [2.09, 1.87, 2.09, 1.87],
}]]
results = ort_detector.forward(imgs, img_metas)
assert results is not None, 'failed to get output using '\
'ONNXRuntimeDetector'
SwitchBackendWrapper.recover(ORTWrapper)
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
@pytest.mark.skipif(
not importlib.util.find_spec('pyppl'), reason='requires pyppl')
def test_PPLDetector():
# force add backend wrapper regardless of plugins
# make sure PPLDetector can use PPLWrapper inside itself
from mmdeploy.apis.ppl.ppl_utils import PPLWrapper
ppl_apis.__dict__.update({'PPLWrapper': PPLWrapper})
# simplify backend inference
outputs = (torch.rand(1, 100, 5), torch.rand(1, 100))
SwitchBackendWrapper.set(PPLWrapper, outputs=outputs)
from mmdeploy.mmdet.apis.inference import PPLDetector
ppl_detector = PPLDetector('', ['' for i in range(80)], 0)
imgs = [torch.rand(1, 3, 64, 64)]
img_metas = [[{
'ori_shape': [64, 64, 3],
'img_shape': [64, 64, 3],
'scale_factor': [2.09, 1.87, 2.09, 1.87],
}]]
results = ppl_detector.forward(imgs, img_metas)
assert results is not None, 'failed to get output using PPLDetector'
SwitchBackendWrapper.recover(PPLWrapper)
def get_test_cfg_and_post_processing():
test_cfg = {
'nms_pre': 100,
'min_bbox_size': 0,
'score_thr': 0.05,
'nms': {
'type': 'nms',
'iou_threshold': 0.5
},
'max_per_img': 10
}
post_processing = {
'score_threshold': 0.05,
'iou_threshold': 0.5,
'max_output_boxes_per_class': 20,
'pre_top_k': -1,
'keep_top_k': 10,
'background_label_id': -1
}
return test_cfg, post_processing
def test_PartitionSingleStageDetector():
test_cfg, post_processing = get_test_cfg_and_post_processing()
model_cfg = mmcv.Config(dict(model=dict(test_cfg=test_cfg)))
deploy_cfg = mmcv.Config(
dict(codebase_config=dict(post_processing=post_processing)))
from mmdeploy.mmdet.apis.inference import PartitionSingleStageDetector
pss_detector = PartitionSingleStageDetector(['' for i in range(80)],
model_cfg=model_cfg,
deploy_cfg=deploy_cfg,
device_id=0)
scores = torch.rand(1, 120, 80)
bboxes = torch.rand(1, 120, 4)
results = pss_detector.partition0_postprocess(scores=scores, bboxes=bboxes)
assert results is not None, 'failed to get output using '\
'partition0_postprocess of PartitionSingleStageDetector'
@pytest.mark.skipif(
not importlib.util.find_spec('ncnn'), reason='requires ncnn')
def test_NCNNPSSDetector():
test_cfg, post_processing = get_test_cfg_and_post_processing()
model_cfg = mmcv.Config(dict(model=dict(test_cfg=test_cfg)))
deploy_cfg = mmcv.Config(
dict(codebase_config=dict(post_processing=post_processing)))
# force add backend wrapper regardless of plugins
# make sure NCNNPSSDetector can use NCNNWrapper inside itself
from mmdeploy.apis.ncnn.ncnn_utils import NCNNWrapper
ncnn_apis.__dict__.update({'NCNNWrapper': NCNNWrapper})
# simplify backend inference
outputs = {
'scores': torch.rand(1, 120, 80),
'boxes': torch.rand(1, 120, 4)
}
SwitchBackendWrapper.set(
NCNNWrapper,
outputs=outputs,
model_cfg=model_cfg,
deploy_cfg=deploy_cfg)
from mmdeploy.mmdet.apis.inference import NCNNPSSDetector
ncnn_pss_detector = NCNNPSSDetector(['', ''], ['' for i in range(80)],
model_cfg=model_cfg,
deploy_cfg=deploy_cfg,
device_id=0)
imgs = [torch.rand(1, 3, 32, 32)]
img_metas = [[{
'ori_shape': [32, 32, 3],
'img_shape': [32, 32, 3],
'scale_factor': [2.09, 1.87, 2.09, 1.87],
}]]
results = ncnn_pss_detector.forward(imgs, img_metas)
assert results is not None, 'failed to get output using NCNNPSSDetector'
SwitchBackendWrapper.recover(NCNNWrapper)
@pytest.mark.skipif(
not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
def test_ONNXRuntimePSSDetector():
test_cfg, post_processing = get_test_cfg_and_post_processing()
model_cfg = mmcv.Config(dict(model=dict(test_cfg=test_cfg)))
deploy_cfg = mmcv.Config(
dict(codebase_config=dict(post_processing=post_processing)))
# force add backend wrapper regardless of plugins
# make sure ONNXRuntimePSSDetector can use ORTWrapper inside itself
from mmdeploy.apis.onnxruntime.onnxruntime_utils import ORTWrapper
ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})
# simplify backend inference
outputs = [
np.random.rand(1, 120, 80).astype(np.float32),
np.random.rand(1, 120, 4).astype(np.float32)
]
SwitchBackendWrapper.set(
ORTWrapper,
outputs=outputs,
model_cfg=model_cfg,
deploy_cfg=deploy_cfg)
from mmdeploy.mmdet.apis.inference import ONNXRuntimePSSDetector
ort_pss_detector = ONNXRuntimePSSDetector(
'', ['' for i in range(80)],
model_cfg=model_cfg,
deploy_cfg=deploy_cfg,
device_id=0)
imgs = [torch.rand(1, 3, 32, 32)]
img_metas = [[{
'ori_shape': [32, 32, 3],
'img_shape': [32, 32, 3],
'scale_factor': [2.09, 1.87, 2.09, 1.87],
}]]
results = ort_pss_detector.forward(imgs, img_metas)
assert results is not None, 'failed to get output using '
'ONNXRuntimePSSDetector'
SwitchBackendWrapper.recover(ORTWrapper)
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
@pytest.mark.skipif(
not importlib.util.find_spec('tensorrt'), reason='requires tensorrt')
def test_TensorRTPSSDetector():
test_cfg, post_processing = get_test_cfg_and_post_processing()
model_cfg = mmcv.Config(dict(model=dict(test_cfg=test_cfg)))
deploy_cfg = mmcv.Config(
dict(codebase_config=dict(post_processing=post_processing)))
# force add backend wrapper regardless of plugins
# make sure TensorRTPSSDetector can use TRTWrapper inside itself
from mmdeploy.apis.tensorrt.tensorrt_utils import TRTWrapper
trt_apis.__dict__.update({'TRTWrapper': TRTWrapper})
# simplify backend inference
outputs = {
'scores': torch.rand(1, 120, 80).cuda(),
'boxes': torch.rand(1, 120, 4).cuda()
}
SwitchBackendWrapper.set(
TRTWrapper,
outputs=outputs,
model_cfg=model_cfg,
deploy_cfg=deploy_cfg)
from mmdeploy.mmdet.apis.inference import TensorRTPSSDetector
trt_pss_detector = TensorRTPSSDetector(
'', ['' for i in range(80)],
model_cfg=model_cfg,
deploy_cfg=deploy_cfg,
device_id=0)
imgs = [torch.rand(1, 3, 32, 32).cuda()]
img_metas = [[{
'ori_shape': [32, 32, 3],
'img_shape': [32, 32, 3],
'scale_factor': [2.09, 1.87, 2.09, 1.87],
}]]
results = trt_pss_detector.forward(imgs, img_metas)
assert results is not None, 'failed to get output using '
'TensorRTPSSDetector'
SwitchBackendWrapper.recover(TRTWrapper)
def prepare_model_deploy_cfgs():
test_cfg, post_processing = get_test_cfg_and_post_processing()
bbox_roi_extractor = {
'type': 'SingleRoIExtractor',
'roi_layer': {
'type': 'RoIAlign',
'output_size': 7,
'sampling_ratio': 0
},
'out_channels': 8,
'featmap_strides': [4]
}
bbox_head = {
'type': 'Shared2FCBBoxHead',
'in_channels': 8,
'fc_out_channels': 1024,
'roi_feat_size': 7,
'num_classes': 80,
'bbox_coder': {
'type': 'DeltaXYWHBBoxCoder',
'target_means': [0.0, 0.0, 0.0, 0.0],
'target_stds': [0.1, 0.1, 0.2, 0.2]
},
'reg_class_agnostic': False,
'loss_cls': {
'type': 'CrossEntropyLoss',
'use_sigmoid': False,
'loss_weight': 1.0
},
'loss_bbox': {
'type': 'L1Loss',
'loss_weight': 1.0
}
}
roi_head = dict(bbox_roi_extractor=bbox_roi_extractor, bbox_head=bbox_head)
model_cfg = mmcv.Config(
dict(
model=dict(
test_cfg=dict(rpn=test_cfg, rcnn=test_cfg),
roi_head=roi_head)))
deploy_cfg = mmcv.Config(
dict(codebase_config=dict(post_processing=post_processing)))
return model_cfg, deploy_cfg
def test_PartitionTwoStageDetector():
model_cfg, deploy_cfg = prepare_model_deploy_cfgs()
from mmdeploy.mmdet.apis.inference import PartitionTwoStageDetector
pts_detector = PartitionTwoStageDetector(['' for i in range(80)],
model_cfg=model_cfg,
deploy_cfg=deploy_cfg,
device_id=0)
feats = [torch.randn(1, 8, 14, 14) for i in range(5)]
scores = torch.rand(1, 50, 1)
bboxes = torch.rand(1, 50, 4)
bboxes[..., 2:4] = 2 * bboxes[..., :2]
results = pts_detector.partition0_postprocess(
x=feats, scores=scores, bboxes=bboxes)
assert results is not None, 'failed to get output using '\
'partition0_postprocess of PartitionTwoStageDetector'
rois = torch.rand(1, 10, 5)
cls_score = torch.rand(10, 81)
bbox_pred = torch.rand(10, 320)
img_metas = [[{
'ori_shape': [32, 32, 3],
'img_shape': [32, 32, 3],
'scale_factor': [2.09, 1.87, 2.09, 1.87],
}]]
results = pts_detector.partition1_postprocess(
rois=rois,
cls_score=cls_score,
bbox_pred=bbox_pred,
img_metas=img_metas)
assert results is not None, 'failed to get output using '\
'partition1_postprocess of PartitionTwoStageDetector'
class DummyPTSDetector(torch.nn.Module):
"""A dummy wrapper for unit tests."""
def __init__(self, *args, **kwargs):
self.output_names = ['dets', 'labels']
def partition0_postprocess(self, *args, **kwargs):
return self.outputs0
def partition1_postprocess(self, *args, **kwargs):
return self.outputs1
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
@pytest.mark.skipif(
not importlib.util.find_spec('tensorrt'), reason='requires tensorrt')
def test_TensorRTPTSDetector():
model_cfg, deploy_cfg = prepare_model_deploy_cfgs()
# force add backend wrapper regardless of plugins
# make sure TensorRTPTSDetector can use TRTWrapper inside itself
from mmdeploy.apis.tensorrt.tensorrt_utils import TRTWrapper
trt_apis.__dict__.update({'TRTWrapper': TRTWrapper})
# simplify backend inference
outputs = {
'scores': torch.rand(1, 12, 80).cuda(),
'boxes': torch.rand(1, 12, 4).cuda(),
'cls_score': torch.rand(1, 12, 80).cuda(),
'bbox_pred': torch.rand(1, 12, 4).cuda()
}
SwitchBackendWrapper.set(TRTWrapper, outputs=outputs)
TRTWrapper.model_cfg = model_cfg
TRTWrapper.deploy_cfg = deploy_cfg
# replace original function in PartitionTwoStageDetector
from mmdeploy.mmdet.apis.inference import PartitionTwoStageDetector
PartitionTwoStageDetector.__init__ = DummyPTSDetector.__init__
PartitionTwoStageDetector.partition0_postprocess = \
DummyPTSDetector.partition0_postprocess
PartitionTwoStageDetector.partition1_postprocess = \
DummyPTSDetector.partition1_postprocess
PartitionTwoStageDetector.outputs0 = [torch.rand(2, 3).cuda()] * 2
PartitionTwoStageDetector.outputs1 = [
torch.rand(1, 9, 5).cuda(),
torch.rand(1, 9).cuda()
]
PartitionTwoStageDetector.device_id = 0
PartitionTwoStageDetector.CLASSES = ['' for i in range(80)]
from mmdeploy.mmdet.apis.inference import TensorRTPTSDetector
trt_pts_detector = TensorRTPTSDetector(['', ''], ['' for i in range(80)],
model_cfg=model_cfg,
deploy_cfg=deploy_cfg,
device_id=0)
imgs = [torch.rand(1, 3, 32, 32).cuda()]
img_metas = [[{
'ori_shape': [32, 32, 3],
'img_shape': [32, 32, 3],
'scale_factor': [2.09, 1.87, 2.09, 1.87],
}]]
results = trt_pts_detector.forward(imgs, img_metas)
assert results is not None, 'failed to get output using '
'TensorRTPTSDetector'
SwitchBackendWrapper.recover(TRTWrapper)
@pytest.mark.skipif(
not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
def test_ONNXRuntimePTSDetector():
model_cfg, deploy_cfg = prepare_model_deploy_cfgs()
# force add backend wrapper regardless of plugins
# make sure ONNXRuntimePTSDetector can use TRTWrapper inside itself
from mmdeploy.apis.onnxruntime.onnxruntime_utils import ORTWrapper
ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})
# simplify backend inference
outputs = [
np.random.rand(1, 12, 80).astype(np.float32),
np.random.rand(1, 12, 4).astype(np.float32),
] * 2
SwitchBackendWrapper.set(
ORTWrapper,
outputs=outputs,
model_cfg=model_cfg,
deploy_cfg=deploy_cfg)
# replace original function in PartitionTwoStageDetector
from mmdeploy.mmdet.apis.inference import PartitionTwoStageDetector
PartitionTwoStageDetector.__init__ = DummyPTSDetector.__init__
PartitionTwoStageDetector.partition0_postprocess = \
DummyPTSDetector.partition0_postprocess
PartitionTwoStageDetector.partition1_postprocess = \
DummyPTSDetector.partition1_postprocess
PartitionTwoStageDetector.outputs0 = [torch.rand(2, 3)] * 2
PartitionTwoStageDetector.outputs1 = [
torch.rand(1, 9, 5), torch.rand(1, 9)
]
PartitionTwoStageDetector.device_id = -1
PartitionTwoStageDetector.CLASSES = ['' for i in range(80)]
from mmdeploy.mmdet.apis.inference import ONNXRuntimePTSDetector
ort_pts_detector = ONNXRuntimePTSDetector(['', ''],
['' for i in range(80)],
model_cfg=model_cfg,
deploy_cfg=deploy_cfg,
device_id=0)
imgs = [torch.rand(1, 3, 32, 32)]
img_metas = [[{
'ori_shape': [32, 32, 3],
'img_shape': [32, 32, 3],
'scale_factor': [2.09, 1.87, 2.09, 1.87],
}]]
results = ort_pts_detector.forward(imgs, img_metas)
assert results is not None, 'failed to get output using '
'ONNXRuntimePTSDetector'
SwitchBackendWrapper.recover(ORTWrapper)
@pytest.mark.skipif(
not importlib.util.find_spec('ncnn'), reason='requires ncnn')
def test_NCNNPTSDetector():
model_cfg, deploy_cfg = prepare_model_deploy_cfgs()
num_outs = dict(model=dict(neck=dict(num_outs=0)))
model_cfg.update(num_outs)
# force add backend wrapper regardless of plugins
# make sure NCNNPTSDetector can use TRTWrapper inside itself
from mmdeploy.apis.ncnn.ncnn_utils import NCNNWrapper
ncnn_apis.__dict__.update({'NCNNWrapper': NCNNWrapper})
# simplify backend inference
outputs = {
'scores': torch.rand(1, 12, 80),
'boxes': torch.rand(1, 12, 4),
'cls_score': torch.rand(1, 12, 80),
'bbox_pred': torch.rand(1, 12, 4)
}
SwitchBackendWrapper.set(
NCNNWrapper,
outputs=outputs,
model_cfg=model_cfg,
deploy_cfg=deploy_cfg)
# replace original function in PartitionTwoStageDetector
from mmdeploy.mmdet.apis.inference import PartitionTwoStageDetector
PartitionTwoStageDetector.__init__ = DummyPTSDetector.__init__
PartitionTwoStageDetector.partition0_postprocess = \
DummyPTSDetector.partition0_postprocess
PartitionTwoStageDetector.partition1_postprocess = \
DummyPTSDetector.partition1_postprocess
PartitionTwoStageDetector.outputs0 = [torch.rand(2, 3)] * 2
PartitionTwoStageDetector.outputs1 = [
torch.rand(1, 9, 5), torch.rand(1, 9)
]
PartitionTwoStageDetector.device_id = -1
PartitionTwoStageDetector.CLASSES = ['' for i in range(80)]
from mmdeploy.mmdet.apis.inference import NCNNPTSDetector
ncnn_pts_detector = NCNNPTSDetector(
[''] * 4, [''] * 80,
model_cfg=model_cfg,
deploy_cfg=deploy_cfg,
device_id=0)
imgs = [torch.rand(1, 3, 32, 32)]
img_metas = [[{
'ori_shape': [32, 32, 3],
'img_shape': [32, 32, 3],
'scale_factor': [2.09, 1.87, 2.09, 1.87],
}]]
results = ncnn_pts_detector.forward(imgs, img_metas)
assert results is not None, 'failed to get output using '
'NCNNPTSDetector'
SwitchBackendWrapper.recover(NCNNWrapper)
@pytest.mark.skipif(
not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
def test_build_detector():
_, post_processing = get_test_cfg_and_post_processing()
model_cfg = mmcv.Config(dict(data=dict(test={'type': 'CocoDataset'})))
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type='onnxruntime'),
codebase_config=dict(
type='mmdet', post_processing=post_processing)))
from mmdeploy.apis.onnxruntime.onnxruntime_utils import ORTWrapper
ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})
# simplify backend inference
SwitchBackendWrapper.set(
ORTWrapper, model_cfg=model_cfg, deploy_cfg=deploy_cfg)
from mmdeploy.apis.utils import init_backend_model
detector = init_backend_model([''], model_cfg, deploy_cfg, -1)
assert detector is not None
SwitchBackendWrapper.recover(ORTWrapper)

View File

@ -0,0 +1,147 @@
import importlib
import mmcv
import numpy as np
import pytest
import torch
from mmdeploy.utils.test import WrapFunction, get_rewrite_outputs
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
@pytest.mark.skipif(
not importlib.util.find_spec('tensorrt'), reason='requires tensorrt')
def test_multiclass_nms_static():
import tensorrt as trt
from mmdeploy.mmdet.core import multiclass_nms
deploy_cfg = mmcv.Config(
dict(
onnx_config=dict(
output_names=['dets', 'labels'], input_shape=None),
backend_config=dict(
type='tensorrt',
common_config=dict(
fp16_mode=False,
log_level=trt.Logger.INFO,
max_workspace_size=1 << 20),
model_inputs=[
dict(
input_shapes=dict(
boxes=dict(
min_shape=[1, 5, 4],
opt_shape=[1, 5, 4],
max_shape=[1, 5, 4]),
scores=dict(
min_shape=[1, 5, 8],
opt_shape=[1, 5, 8],
max_shape=[1, 5, 8])))
]),
codebase_config=dict(
type='mmdet',
task='ObjectDetection',
post_processing=dict(
score_threshold=0.05,
iou_threshold=0.5,
max_output_boxes_per_class=20,
pre_top_k=-1,
keep_top_k=10,
background_label_id=-1,
))))
boxes = torch.rand(1, 5, 4).cuda()
scores = torch.rand(1, 5, 8).cuda()
max_output_boxes_per_class = 20
keep_top_k = 10
wrapped_func = WrapFunction(
multiclass_nms,
max_output_boxes_per_class=max_output_boxes_per_class,
keep_top_k=keep_top_k)
rewrite_outputs = get_rewrite_outputs(
wrapped_func,
model_inputs={
'boxes': boxes,
'scores': scores
},
deploy_cfg=deploy_cfg)
assert rewrite_outputs is not None, 'Got unexpected rewrite '\
'outputs: {}'.format(rewrite_outputs)
@pytest.mark.parametrize('backend_type', ['onnxruntime', 'ncnn'])
def test_delta2bbox(backend_type):
pytest.importorskip(backend_type, reason=f'requires {backend_type}')
deploy_cfg = mmcv.Config(
dict(
onnx_config=dict(
output_names=['dets', 'labels'], input_shape=None),
backend_config=dict(type=backend_type, model_inputs=None),
codebase_config=dict(type='mmdet', task='ObjectDetection')))
# wrap function to enable rewrite
def delta2bbox(*args, **kwargs):
import mmdet
return mmdet.core.bbox.coder.delta_xywh_bbox_coder.delta2bbox(
*args, **kwargs)
rois = torch.rand(1, 5, 4)
deltas = torch.rand(1, 5, 4)
original_outputs = delta2bbox(rois, deltas)
# wrap function to nn.Module, enable torch.onn.export
wrapped_func = WrapFunction(delta2bbox)
rewrite_outputs = get_rewrite_outputs(
wrapped_func,
model_inputs={
'rois': rois,
'deltas': deltas
},
deploy_cfg=deploy_cfg)
model_output = original_outputs.squeeze().cpu().numpy()
rewrite_output = rewrite_outputs[0].squeeze()
assert np.allclose(model_output, rewrite_output, rtol=1e-03, atol=1e-05)
@pytest.mark.parametrize('backend_type', ['onnxruntime', 'ncnn'])
def test_tblr2bbox(backend_type):
pytest.importorskip(backend_type, reason=f'requires {backend_type}')
deploy_cfg = mmcv.Config(
dict(
onnx_config=dict(
output_names=['dets', 'labels'], input_shape=None),
backend_config=dict(type=backend_type, model_inputs=None),
codebase_config=dict(type='mmdet', task='ObjectDetection')))
# wrap function to enable rewrite
def tblr2bboxes(*args, **kwargs):
import mmdet
return mmdet.core.bbox.coder.tblr_bbox_coder.tblr2bboxes(
*args, **kwargs)
priors = torch.rand(1, 5, 4)
tblr = torch.rand(1, 5, 4)
original_outputs = tblr2bboxes(priors, tblr)
# wrap function to nn.Module, enable torch.onn.export
wrapped_func = WrapFunction(tblr2bboxes)
rewrite_outputs = get_rewrite_outputs(
wrapped_func,
model_inputs={
'priors': priors,
'tblr': tblr
},
deploy_cfg=deploy_cfg)
model_output = original_outputs.squeeze().cpu().numpy()
rewrite_output = rewrite_outputs[0].squeeze()
assert np.allclose(model_output, rewrite_output, rtol=1e-03, atol=1e-05)
def test_distance2bbox():
from mmdeploy.mmdet.core import distance2bbox
points = torch.rand(3, 2)
distance = torch.rand(3, 4)
bbox = distance2bbox(points, distance)
assert bbox.shape == torch.Size([3, 4])

View File

@ -0,0 +1,104 @@
import mmcv
import numpy as np
import pytest
import torch
from mmdeploy.apis.utils import (build_dataloader, build_dataset, create_input,
get_tensor_from_input)
from mmdeploy.utils.constants import Codebase, Task
def test_create_input():
task = Task.OBJECT_DETECTION
test = dict(pipeline=[{
'type': 'LoadImageFromWebcam'
}, {
'type':
'MultiScaleFlipAug',
'img_scale': [32, 32],
'flip':
False,
'transforms': [{
'type': 'Resize',
'keep_ratio': True
}, {
'type': 'RandomFlip'
}, {
'type': 'Normalize',
'mean': [123.675, 116.28, 103.53],
'std': [58.395, 57.12, 57.375],
'to_rgb': True
}, {
'type': 'Pad',
'size_divisor': 32
}, {
'type': 'DefaultFormatBundle'
}, {
'type': 'Collect',
'keys': ['img']
}]
}])
data = dict(test=test)
model_cfg = mmcv.Config(dict(data=data))
imgs = [np.random.rand(32, 32, 3)]
inputs = create_input(
Codebase.MMDET,
task,
model_cfg,
imgs,
input_shape=(32, 32),
device='cpu')
assert inputs is not None, 'Failed to create input'
@pytest.mark.parametrize('input_data', [{'img': [torch.ones(3, 4, 5)]}])
def test_get_tensor_from_input(input_data):
inputs = get_tensor_from_input(Codebase.MMDET, input_data)
assert inputs is not None, 'Failed to get tensor from input'
def test_build_dataset():
data = dict(
test={
'type': 'CocoDataset',
'ann_file': 'tests/test_mmdet/data/coco_sample.json',
'img_prefix': 'tests/test_mmdet/data/imgs/',
'pipeline': [
{
'type': 'LoadImageFromFile'
},
]
})
dataset_cfg = mmcv.Config(dict(data=data))
dataset = build_dataset(
Codebase.MMDET, dataset_cfg=dataset_cfg, dataset_type='test')
assert dataset is not None, 'Failed to build dataset'
dataloader = build_dataloader(Codebase.MMDET, dataset, 1, 1)
assert dataloader is not None, 'Failed to build dataloader'
def test_clip_bboxes():
from mmdeploy.mmdet.export import clip_bboxes
x1 = torch.rand(3, 2) * 224
y1 = torch.rand(3, 2) * 224
x2 = x1 * 2
y2 = y1 * 2
outs = clip_bboxes(x1, y1, x2, y2, [224, 224])
for out in outs:
assert int(out.max()) <= 224
def test_pad_with_value():
from mmdeploy.mmdet.export import pad_with_value
x = torch.rand(3, 2)
padded_x = pad_with_value(x, pad_dim=1, pad_size=4, pad_value=0)
assert np.allclose(
padded_x.shape, torch.Size([3, 6]), rtol=1e-03, atol=1e-05)
assert np.allclose(padded_x.sum(), x.sum(), rtol=1e-03, atol=1e-05)
@pytest.mark.parametrize('partition_type', ['single_stage', 'two_stage'])
def test_get_partition_cfg(partition_type):
from mmdeploy.mmdet.export import get_partition_cfg
partition_cfg = get_partition_cfg(partition_type=partition_type)
assert partition_cfg is not None

View File

@ -0,0 +1,293 @@
import copy
import importlib
import os
import random
import tempfile
import mmcv
import numpy as np
import pytest
import torch
from mmdeploy.utils.constants import Backend, Codebase
from mmdeploy.utils.test import (WrapModel, get_model_outputs,
get_rewrite_outputs)
def seed_everything(seed=1029):
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.enabled = False
def get_anchor_head_model():
"""AnchorHead Config."""
test_cfg = mmcv.Config(
dict(
deploy_nms_pre=0,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.5),
max_per_img=100))
from mmdet.models import AnchorHead
model = AnchorHead(num_classes=4, in_channels=1, test_cfg=test_cfg)
model.requires_grad_(False)
return model
def get_fcos_head_model():
"""FCOS Head Config."""
test_cfg = mmcv.Config(
dict(
deploy_nms_pre=0,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.5),
max_per_img=100))
from mmdet.models import FCOSHead
model = FCOSHead(num_classes=4, in_channels=1, test_cfg=test_cfg)
model.requires_grad_(False)
return model
def get_rpn_head_model():
"""RPN Head Config."""
test_cfg = mmcv.Config(
dict(
deploy_nms_pre=0,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.5),
max_per_img=100))
from mmdet.models import RPNHead
model = RPNHead(in_channels=1, test_cfg=test_cfg)
model.requires_grad_(False)
return model
@pytest.mark.parametrize('backend_type', ['onnxruntime', 'ncnn'])
def test_anchor_head_get_bboxes(backend_type):
"""Test get_bboxes rewrite of anchor head."""
pytest.importorskip(backend_type, reason=f'requires {backend_type}')
anchor_head = get_anchor_head_model()
anchor_head.cpu().eval()
s = 128
img_metas = [{
'scale_factor': np.ones(4),
'pad_shape': (s, s, 3),
'img_shape': (s, s, 3)
}]
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type=backend_type),
onnx_config=dict(
output_names=['dets', 'labels'], input_shape=None),
codebase_config=dict(
type='mmdet',
task='ObjectDetection',
post_processing=dict(
score_threshold=0.05,
iou_threshold=0.5,
max_output_boxes_per_class=200,
pre_top_k=-1,
keep_top_k=100,
background_label_id=-1,
))))
# the cls_score's size: (1, 36, 32, 32), (1, 36, 16, 16),
# (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2).
# the bboxes's size: (1, 36, 32, 32), (1, 36, 16, 16),
# (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2)
seed_everything(1234)
cls_score = [
torch.rand(1, 36, pow(2, i), pow(2, i)) for i in range(5, 0, -1)
]
seed_everything(5678)
bboxes = [torch.rand(1, 36, pow(2, i), pow(2, i)) for i in range(5, 0, -1)]
# to get outputs of pytorch model
model_inputs = {
'cls_scores': cls_score,
'bbox_preds': bboxes,
'img_metas': img_metas
}
model_outputs = get_model_outputs(anchor_head, 'get_bboxes', model_inputs)
# to get outputs of onnx model after rewrite
img_metas[0]['img_shape'] = torch.Tensor([s, s])
wrapped_model = WrapModel(
anchor_head, 'get_bboxes', img_metas=img_metas[0], with_nms=True)
rewrite_inputs = {
'cls_scores': cls_score,
'bbox_preds': bboxes,
}
rewrite_outputs = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
for model_output, rewrite_output in zip(model_outputs[0], rewrite_outputs):
model_output = model_output.squeeze().cpu().numpy()
rewrite_output = rewrite_output.squeeze()
# hard code to make two tensors with the same shape
# rewrite and original codes applied different nms strategy
assert np.allclose(
model_output[:rewrite_output.shape[0]],
rewrite_output,
rtol=1e-03,
atol=1e-05)
@pytest.mark.parametrize('backend_type', ['onnxruntime', 'ncnn'])
def test_get_bboxes_of_fcos_head(backend_type):
pytest.importorskip(backend_type, reason=f'requires {backend_type}')
fcos_head = get_fcos_head_model()
fcos_head.cpu().eval()
s = 128
img_metas = [{
'scale_factor': np.ones(4),
'pad_shape': (s, s, 3),
'img_shape': (s, s, 3)
}]
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type=backend_type),
onnx_config=dict(
output_names=['dets', 'labels'], input_shape=None),
codebase_config=dict(
type='mmdet',
task='ObjectDetection',
post_processing=dict(
score_threshold=0.05,
iou_threshold=0.5,
max_output_boxes_per_class=200,
pre_top_k=-1,
keep_top_k=100,
background_label_id=-1,
))))
# the cls_score's size: (1, 36, 32, 32), (1, 36, 16, 16),
# (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2).
# the bboxes's size: (1, 36, 32, 32), (1, 36, 16, 16),
# (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2)
seed_everything(1234)
cls_score = [
torch.rand(1, fcos_head.num_classes, pow(2, i), pow(2, i))
for i in range(5, 0, -1)
]
seed_everything(5678)
bboxes = [torch.rand(1, 4, pow(2, i), pow(2, i)) for i in range(5, 0, -1)]
seed_everything(9101)
centernesses = [
torch.rand(1, 1, pow(2, i), pow(2, i)) for i in range(5, 0, -1)
]
# to get outputs of pytorch model
model_inputs = {
'cls_scores': cls_score,
'bbox_preds': bboxes,
'centernesses': centernesses,
'img_metas': img_metas
}
model_outputs = get_model_outputs(fcos_head, 'get_bboxes', model_inputs)
# to get outputs of onnx model after rewrite
img_metas[0]['img_shape'] = torch.Tensor([s, s])
wrapped_model = WrapModel(
fcos_head, 'get_bboxes', img_metas=img_metas[0], with_nms=True)
rewrite_inputs = {
'cls_scores': cls_score,
'bbox_preds': bboxes,
'centernesses': centernesses
}
rewrite_outputs = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
for model_output, rewrite_output in zip(model_outputs[0], rewrite_outputs):
model_output = model_output.squeeze().cpu().numpy()
rewrite_output = rewrite_output.squeeze()
# hard code to make two tensors with the same shape
# rewrite and original codes applied different nms strategy
assert np.allclose(
model_output[:rewrite_output.shape[0]],
rewrite_output,
rtol=1e-03,
atol=1e-05)
def _replace_r50_with_r18(model):
"""Replace ResNet50 with ResNet18 in config."""
model = copy.deepcopy(model)
if model.backbone.type == 'ResNet':
model.backbone.depth = 18
model.backbone.base_channels = 2
model.neck.in_channels = [2, 4, 8, 16]
return model
@pytest.mark.parametrize('model_cfg_path', [
'tests/test_mmdet/data/single_stage_model.json',
'tests/test_mmdet/data/mask_model.json'
])
@pytest.mark.skipif(
not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
def test_forward_of_base_detector_and_visualize(model_cfg_path):
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type='onnxruntime'),
onnx_config=dict(
output_names=['dets', 'labels'], input_shape=None),
codebase_config=dict(
type='mmdet',
task='ObjectDetection',
post_processing=dict(
score_threshold=0.05,
iou_threshold=0.5,
max_output_boxes_per_class=200,
pre_top_k=-1,
keep_top_k=100,
background_label_id=-1,
))))
model_cfg = mmcv.Config(dict(model=mmcv.load(model_cfg_path)))
model_cfg.model = _replace_r50_with_r18(model_cfg.model)
from mmdet.apis import init_detector
model = init_detector(model_cfg, None, 'cpu')
img = torch.randn(1, 3, 64, 64)
rewrite_inputs = {'img': img}
rewrite_outputs = get_rewrite_outputs(
wrapped_model=model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
from mmdeploy.apis.utils import visualize
output_file = tempfile.NamedTemporaryFile(suffix='.jpg').name
model.CLASSES = [''] * 80
visualize(
Codebase.MMDET,
img.squeeze().permute(1, 2, 0).numpy(),
result=[torch.rand(0, 5).numpy()] * 80,
model=model,
output_file=output_file,
backend=Backend.ONNXRUNTIME,
show_result=False)
assert rewrite_outputs is not None

View File

@ -1,100 +0,0 @@
import os.path as osp
import mmcv
import numpy as np
import torch
from mmdeploy.utils.test import (WrapModel, get_model_outputs,
get_rewrite_outputs)
data_path = osp.join(osp.dirname(__file__), 'data')
def get_anchor_head_model():
"""AnchorHead Config."""
test_cfg = mmcv.Config(
dict(
deploy_nms_pre=0,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.5),
max_per_img=100))
from mmdet.models import AnchorHead
model = AnchorHead(num_classes=4, in_channels=1, test_cfg=test_cfg)
model.requires_grad_(False)
return model
def test_anchor_head_get_bboxes():
"""Test get_bboxes rewrite of anchor head."""
anchor_head = get_anchor_head_model()
anchor_head.cpu().eval()
s = 128
img_metas = [{
'scale_factor': np.ones(4),
'pad_shape': (s, s, 3),
'img_shape': (s, s, 2)
}]
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type='onnxruntime'),
onnx_config=dict(
output_names=['dets', 'labels'], input_shape=None),
codebase_config=dict(
type='mmdet',
task='ObjectDetection',
post_processing=dict(
score_threshold=0.05,
iou_threshold=0.5,
max_output_boxes_per_class=200,
pre_top_k=-1,
keep_top_k=100,
background_label_id=-1,
))))
# The data of anchor_head_get_bboxes.pkl contains two parts:
# cls_score(list(Tensor)) and bboxes(list(Tensor)),
# where each torch.Tensor is generated by torch.rand().
# the cls_score's size: (1, 36, 32, 32), (1, 36, 16, 16),
# (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2).
# the bboxes's size: (1, 36, 32, 32), (1, 36, 16, 16),
# (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2)
retina_head_data = 'anchor_head_get_bboxes.pkl'
feats = mmcv.load(osp.join(data_path, retina_head_data))
cls_score = feats[:5]
bboxes = feats[5:]
# to get outputs of pytorch model
model_inputs = {
'cls_scores': cls_score,
'bbox_preds': bboxes,
'img_metas': img_metas
}
model_outputs = get_model_outputs(anchor_head, 'get_bboxes', model_inputs)
# to get outputs of onnx model after rewrite
img_metas[0]['img_shape'] = torch.Tensor([s, s])
wrapped_model = WrapModel(
anchor_head, 'get_bboxes', img_metas=img_metas[0], with_nms=True)
rewrite_inputs = {
'cls_scores': cls_score,
'bbox_preds': bboxes,
}
rewrite_outputs = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
for model_output, rewrite_output in zip(model_outputs[0], rewrite_outputs):
model_output = model_output.squeeze().cpu().numpy()
rewrite_output = rewrite_output.squeeze()
# hard code to make two tensors with the same shape
# rewrite and original codes applied different nms strategy
assert np.allclose(
model_output[:rewrite_output.shape[0]],
rewrite_output,
rtol=1e-03,
atol=1e-05)