Mirror of https://github.com/open-mmlab/mmdeploy.git, synced 2025-01-14 08:09:43 +08:00
[Unittests] MMDet unittests (#112)
* add mmdet unittests
* remove redundant img meta info
* import wrapper from backend util
* force add wrapper
* use smaller nuance
* add seed everything
* add create input part and some inference part unittests
* fix lint
* skip ppl
* remove pyppl
* add dataset files
* import ncnn_ext inside ncnn wrapper
* use cpu device to create input
* add pssd and ptsd unittests
* clear mmdet unittests
* wrap function to enable rewrite
* refine codes and resolve comments
* move mmdet inside test func
* remove redundant line
* test visualize in mmdeploy.apis
* use less memory
* resolve comments
* fix ci
* move pytest.skip inside test function and use 3 channel input
This commit is contained in: parent 6fdf6b8616, commit d4828c7836
@@ -1,10 +1,10 @@
+import importlib
 from typing import Dict, Iterable, Optional

 import ncnn
 import numpy as np
 import torch

-from mmdeploy.apis.ncnn import ncnn_ext
 from mmdeploy.utils.timer import TimeCounter


@@ -37,7 +37,9 @@ class NCNNWrapper(torch.nn.Module):
         super(NCNNWrapper, self).__init__()

         net = ncnn.Net()
-        ncnn_ext.register_mm_custom_layers(net)
+        if importlib.util.find_spec('mmdeploy.apis.ncnn.ncnn_ext'):
+            from mmdeploy.apis.ncnn import ncnn_ext
+            ncnn_ext.register_mm_custom_layers(net)
         net.load_param(param_file)
         net.load_model(bin_file)
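The guarded registration above lets the wrapper module be imported even when the compiled ncnn extension has not been built, which the new unit tests below rely on. Below is a minimal sketch of the same pattern, assuming mmdeploy and the ncnn Python bindings are installed; build_net is an illustrative helper, not an mmdeploy API:

import importlib.util

import ncnn


def build_net(param_file: str, bin_file: str) -> ncnn.Net:
    net = ncnn.Net()
    # Register mmdeploy's custom layers only if the compiled extension exists,
    # so constructing the wrapper never fails on an install without it.
    if importlib.util.find_spec('mmdeploy.apis.ncnn.ncnn_ext'):
        from mmdeploy.apis.ncnn import ncnn_ext
        ncnn_ext.register_mm_custom_layers(net)
    net.load_param(param_file)
    net.load_model(bin_file)
    return net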
@@ -333,12 +333,12 @@ def build_dataloader(codebase: Codebase, dataset: Dataset,
         raise NotImplementedError(f'Unknown codebase type: {codebase.value}')


-def get_tensor_from_input(codebase: Codebase, input_data: tuple):
+def get_tensor_from_input(codebase: Codebase, input_data: Dict[str, Any]):
     """Get input tensor from input data.

     Args:
         codebase (Codebase): Specifying codebase type.
-        input_data (tuple): Input data containing meta info and image tensor.
+        input_data (dict): Input data containing meta info and image tensor.

     Returns:
         torch.Tensor: An image in `Tensor`.
|
@ -48,13 +48,13 @@ def tblr2bboxes(ctx,
|
|||||||
@FUNCTION_REWRITER.register_rewriter(
|
@FUNCTION_REWRITER.register_rewriter(
|
||||||
func_name='mmdet.core.bbox.coder.tblr_bbox_coder.tblr2bboxes',
|
func_name='mmdet.core.bbox.coder.tblr_bbox_coder.tblr2bboxes',
|
||||||
backend='ncnn')
|
backend='ncnn')
|
||||||
def delta2bbox_ncnn(ctx,
|
def tblr2bboxes_ncnn(ctx,
|
||||||
priors,
|
priors,
|
||||||
tblr,
|
tblr,
|
||||||
normalizer=4.0,
|
normalizer=4.0,
|
||||||
normalize_by_wh=True,
|
normalize_by_wh=True,
|
||||||
max_shape=None,
|
max_shape=None,
|
||||||
clip_border=True):
|
clip_border=True):
|
||||||
"""Rewrite for ONNX exporting of NCNN backend."""
|
"""Rewrite for ONNX exporting of NCNN backend."""
|
||||||
assert priors.size(0) == tblr.size(0)
|
assert priors.size(0) == tblr.size(0)
|
||||||
if priors.ndim == 3:
|
if priors.ndim == 3:
|
||||||
|
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Sequence, Union
+from typing import Any, Dict, Optional, Sequence, Union

 import mmcv
 import numpy as np
@@ -151,11 +151,11 @@ def build_dataloader(dataset: Dataset,
         **kwargs)


-def get_tensor_from_input(input_data: tuple):
+def get_tensor_from_input(input_data: Dict[str, Any]):
     """Get input tensor from input data.

     Args:
-        input_data (tuple): Input data containing meta info and image tensor.
+        input_data (dict): Input data containing meta info and image tensor.
     Returns:
         torch.Tensor: An image in `Tensor`.
     """
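For reference, a minimal sketch of the dict-shaped input these signatures now expect. The body of get_tensor_from_input is not shown in this diff, so the implementation below is an assumption modeled on the {'img': [...]} layout used by the new test_get_tensor_from_input test later in this commit; the helper name is hypothetical:

from typing import Any, Dict

import torch


def get_tensor_from_input_sketch(input_data: Dict[str, Any]) -> torch.Tensor:
    # MMDetection test pipelines collect the image under the 'img' key as a
    # list with one tensor per augmentation; take the first (and only) one.
    return input_data['img'][0]


img = get_tensor_from_input_sketch({'img': [torch.ones(3, 4, 5)]})
assert img.shape == (3, 4, 5)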
|
@ -40,8 +40,13 @@ def simple_test_bboxes_of_bbox_test_mixin(ctx, self, x, img_metas, proposals,
|
|||||||
MaskTestMixin.simple_test_mask')
|
MaskTestMixin.simple_test_mask')
|
||||||
def simple_test_mask_of_mask_test_mixin(ctx, self, x, img_metas, det_bboxes,
|
def simple_test_mask_of_mask_test_mixin(ctx, self, x, img_metas, det_bboxes,
|
||||||
det_labels, **kwargs):
|
det_labels, **kwargs):
|
||||||
assert det_bboxes.shape[1] != 0, 'Can not record MaskHead as it \
|
if det_bboxes.shape[1] == 0:
|
||||||
has not been executed this time'
|
bboxes_shape, labels_shape = list(det_bboxes.shape), list(
|
||||||
|
det_labels.shape)
|
||||||
|
bboxes_shape[1], labels_shape[1] = 1, 1
|
||||||
|
det_bboxes = torch.tensor([[[0., 0., 1., 1.,
|
||||||
|
0.]]]).expand(*bboxes_shape)
|
||||||
|
det_labels = torch.tensor([[0]]).expand(*labels_shape)
|
||||||
|
|
||||||
batch_size = det_bboxes.size(0)
|
batch_size = det_bboxes.size(0)
|
||||||
det_bboxes = det_bboxes[..., :4]
|
det_bboxes = det_bboxes[..., :4]
|
||||||
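With this change an empty proposal set no longer aborts mask-head recording; a single dummy box/label pair is expanded to the expected shape instead. A small standalone sketch of the expand trick, with made-up input shapes:

import torch

det_bboxes = torch.zeros(1, 0, 5)   # no detections survived NMS
det_labels = torch.zeros(1, 0)

if det_bboxes.shape[1] == 0:
    bboxes_shape, labels_shape = list(det_bboxes.shape), list(det_labels.shape)
    bboxes_shape[1], labels_shape[1] = 1, 1
    # expand() broadcasts the dummy box without copying data
    det_bboxes = torch.tensor([[[0., 0., 1., 1., 0.]]]).expand(*bboxes_shape)
    det_labels = torch.tensor([[0]]).expand(*labels_shape)

print(det_bboxes.shape, det_labels.shape)  # torch.Size([1, 1, 5]) torch.Size([1, 1])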
|
@ -11,6 +11,8 @@ def size_of_tensor_static(ctx, self, *args):
|
|||||||
ret = ctx.origin_func(self, *args)
|
ret = ctx.origin_func(self, *args)
|
||||||
if isinstance(ret, torch.Tensor):
|
if isinstance(ret, torch.Tensor):
|
||||||
ret = int(ret)
|
ret = int(ret)
|
||||||
|
elif isinstance(ret, int):
|
||||||
|
return (ret)
|
||||||
else:
|
else:
|
||||||
ret = [int(r) for r in ret]
|
ret = [int(r) for r in ret]
|
||||||
ret = tuple(ret)
|
ret = tuple(ret)
|
||||||
|
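The new elif branch simply returns scalar results untouched; the surrounding rewrite exists presumably because Tensor.size() can hand back traced tensors rather than plain ints during ONNX export. A tiny standalone illustration of the normalization being applied (not the registered rewriter itself):

import torch


def normalize_size(ret):
    # Mirror of the rewritten logic: collapse traced size values to plain ints.
    if isinstance(ret, torch.Tensor):
        return int(ret)
    elif isinstance(ret, int):
        return ret
    else:
        return tuple(int(r) for r in ret)


assert normalize_size(torch.tensor(4)) == 4
assert normalize_size(3) == 3
assert normalize_size(torch.ones(2, 5).size()) == (2, 5)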
@@ -57,6 +57,54 @@ class WrapModel(nn.Module):
         return func(*args, **kwargs)


+class SwitchBackendWrapper:
+    """A switcher for backend wrapper for unit tests.
+
+    Examples:
+        >>> from mmdeploy.utils.test import SwitchBackendWrapper
+        >>> from mmdeploy.apis.onnxruntime.onnxruntime_utils import ORTWrapper
+        >>> SwitchBackendWrapper.set(ORTWrapper, outputs=outputs)
+        >>> ...
+        >>> SwitchBackendWrapper.recover(ORTWrapper)
+    """
+    init = None
+    forward = None
+    call = None
+
+    class BackendWrapper(torch.nn.Module):
+        """A dummy wrapper for unit tests."""
+
+        def __init__(self, *args, **kwargs):
+            self.output_names = ['dets', 'labels']
+
+        def forward(self, *args, **kwargs):
+            return self.outputs
+
+        def __call__(self, *args, **kwds):
+            return self.forward(*args, **kwds)
+
+    @staticmethod
+    def set(obj, **kwargs):
+        """Replace attributes in backend wrappers with dummy items."""
+        SwitchBackendWrapper.init = obj.__init__
+        SwitchBackendWrapper.forward = obj.forward
+        SwitchBackendWrapper.call = obj.__call__
+        obj.__init__ = SwitchBackendWrapper.BackendWrapper.__init__
+        obj.forward = SwitchBackendWrapper.BackendWrapper.forward
+        obj.__call__ = SwitchBackendWrapper.BackendWrapper.__call__
+        for k, v in kwargs.items():
+            setattr(obj, k, v)
+
+    @staticmethod
+    def recover(obj):
+        assert SwitchBackendWrapper.init is not None and \
+            SwitchBackendWrapper.forward is not None,\
+            'recover method must be called after exchange'
+        obj.__init__ = SwitchBackendWrapper.init
+        obj.forward = SwitchBackendWrapper.forward
+        obj.__call__ = SwitchBackendWrapper.call
+
+
 def assert_allclose(expected: List[Union[torch.Tensor, np.ndarray]],
                     actual: List[Union[torch.Tensor, np.ndarray]],
                     tolerate_small_mismatch: bool = False):
@@ -180,6 +228,9 @@ def get_rewrite_outputs(wrapped_model: nn.Module, model_inputs: dict,
                 backend_feats[input_names[i]] = feature_list[i]
             else:
                 backend_feats[str(i)] = feature_list[i]
+    elif backend == Backend.NCNN:
+        return ctx_outputs
+
     with torch.no_grad():
         backend_outputs = backend_model.forward(backend_feats)
     return backend_outputs
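A condensed sketch of how the new tests below use this switcher to stub out a backend, assuming onnxruntime support is available; it mirrors test_ONNXRuntimeDetector from this commit:

import torch

import mmdeploy.apis.onnxruntime as ort_apis
from mmdeploy.apis.onnxruntime.onnxruntime_utils import ORTWrapper
from mmdeploy.utils.test import SwitchBackendWrapper

# Make sure the detector can find ORTWrapper even without compiled custom ops.
ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})

# Replace the real backend call with canned outputs for the duration of a test.
SwitchBackendWrapper.set(
    ORTWrapper, outputs=(torch.rand(1, 100, 5), torch.rand(1, 100)))
try:
    from mmdeploy.mmdet.apis.inference import ONNXRuntimeDetector
    detector = ONNXRuntimeDetector('', ['' for i in range(80)], 0)
    results = detector.forward([torch.rand(1, 3, 64, 64)],
                               [[{'ori_shape': [64, 64, 3],
                                  'img_shape': [64, 64, 3],
                                  'scale_factor': [2.09, 1.87, 2.09, 1.87]}]])
finally:
    # Always restore the original wrapper so other tests see the real class.
    SwitchBackendWrapper.recover(ORTWrapper)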
tests/test_mmdet/data/coco_sample.json (new file, 39 lines) @@ -0,0 +1,39 @@

{
    "images": [
        {"file_name": "000000000139.jpg", "height": 800, "width": 800, "id": 0}
    ],
    "annotations": [
        {"bbox": [0, 0, 20, 20], "area": 400.00, "score": 1.0, "category_id": 1, "id": 1, "image_id": 0}
    ],
    "categories": [
        {"id": 1, "name": "bus", "supercategory": "none"},
        {"id": 2, "name": "car", "supercategory": "none"}
    ],
    "licenses": [],
    "info": null
}
tests/test_mmdet/data/imgs/000000000139.jpg (new executable file, binary image, 158 KiB; not shown)
tests/test_mmdet/data/mask_model.json (new file, 231 lines) @@ -0,0 +1,231 @@

{
    "type": "MaskRCNN",
    "backbone": {
        "type": "ResNet", "depth": 50, "num_stages": 4, "out_indices": [0, 1, 2, 3],
        "frozen_stages": 1, "norm_cfg": {"type": "BN", "requires_grad": true},
        "norm_eval": true, "style": "pytorch",
        "init_cfg": {"type": "Pretrained", "checkpoint": "torchvision://resnet50"}
    },
    "neck": {"type": "FPN", "in_channels": [256, 512, 1024, 2048], "out_channels": 256, "num_outs": 5},
    "rpn_head": {
        "type": "RPNHead", "in_channels": 256, "feat_channels": 256,
        "anchor_generator": {"type": "AnchorGenerator", "scales": [8], "ratios": [0.5, 1.0, 2.0], "strides": [4, 8, 16, 32, 64]},
        "bbox_coder": {"type": "DeltaXYWHBBoxCoder", "target_means": [0.0, 0.0, 0.0, 0.0], "target_stds": [1.0, 1.0, 1.0, 1.0]},
        "loss_cls": {"type": "CrossEntropyLoss", "use_sigmoid": true, "loss_weight": 1.0},
        "loss_bbox": {"type": "L1Loss", "loss_weight": 1.0}
    },
    "roi_head": {
        "type": "StandardRoIHead",
        "bbox_roi_extractor": {
            "type": "SingleRoIExtractor",
            "roi_layer": {"type": "RoIAlign", "output_size": 7, "sampling_ratio": 0},
            "out_channels": 256, "featmap_strides": [4, 8, 16, 32]
        },
        "bbox_head": {
            "type": "Shared2FCBBoxHead", "in_channels": 256, "fc_out_channels": 1024,
            "roi_feat_size": 7, "num_classes": 80,
            "bbox_coder": {"type": "DeltaXYWHBBoxCoder", "target_means": [0.0, 0.0, 0.0, 0.0], "target_stds": [0.1, 0.1, 0.2, 0.2]},
            "reg_class_agnostic": false,
            "loss_cls": {"type": "CrossEntropyLoss", "use_sigmoid": false, "loss_weight": 1.0},
            "loss_bbox": {"type": "L1Loss", "loss_weight": 1.0}
        },
        "mask_roi_extractor": {
            "type": "SingleRoIExtractor",
            "roi_layer": {"type": "RoIAlign", "output_size": 14, "sampling_ratio": 0},
            "out_channels": 256, "featmap_strides": [4, 8, 16, 32]
        },
        "mask_head": {
            "type": "FCNMaskHead", "num_convs": 4, "in_channels": 256,
            "conv_out_channels": 256, "num_classes": 80,
            "loss_mask": {"type": "CrossEntropyLoss", "use_mask": true, "loss_weight": 1.0}
        }
    },
    "train_cfg": {
        "rpn": {
            "assigner": {"type": "MaxIoUAssigner", "pos_iou_thr": 0.7, "neg_iou_thr": 0.3, "min_pos_iou": 0.3, "match_low_quality": true, "ignore_iof_thr": -1},
            "sampler": {"type": "RandomSampler", "num": 256, "pos_fraction": 0.5, "neg_pos_ub": -1, "add_gt_as_proposals": false},
            "allowed_border": -1, "pos_weight": -1, "debug": false
        },
        "rpn_proposal": {"nms_pre": 2000, "max_per_img": 1000, "nms": {"type": "nms", "iou_threshold": 0.7}, "min_bbox_size": 0},
        "rcnn": {
            "assigner": {"type": "MaxIoUAssigner", "pos_iou_thr": 0.5, "neg_iou_thr": 0.5, "min_pos_iou": 0.5, "match_low_quality": true, "ignore_iof_thr": -1},
            "sampler": {"type": "RandomSampler", "num": 512, "pos_fraction": 0.25, "neg_pos_ub": -1, "add_gt_as_proposals": true},
            "mask_size": 28, "pos_weight": -1, "debug": false
        }
    },
    "test_cfg": {
        "rpn": {"nms_pre": 1000, "max_per_img": 1000, "nms": {"type": "nms", "iou_threshold": 0.7}, "min_bbox_size": 0},
        "rcnn": {"score_thr": 0.05, "nms": {"type": "nms", "iou_threshold": 0.5}, "max_per_img": 100, "mask_thr_binary": 0.5}
    }
}
tests/test_mmdet/data/single_stage_model.json (new file, 110 lines) @@ -0,0 +1,110 @@

{
    "type": "RetinaNet",
    "backbone": {
        "type": "ResNet", "depth": 50, "num_stages": 4, "out_indices": [0, 1, 2, 3],
        "frozen_stages": 1, "norm_cfg": {"type": "BN", "requires_grad": true},
        "norm_eval": true, "style": "pytorch",
        "init_cfg": {"type": "Pretrained", "checkpoint": "torchvision://resnet50"}
    },
    "neck": {
        "type": "FPN", "in_channels": [256, 512, 1024, 2048], "out_channels": 256,
        "start_level": 1, "add_extra_convs": "on_input", "num_outs": 5
    },
    "bbox_head": {
        "type": "RetinaHead", "num_classes": 80, "in_channels": 256,
        "stacked_convs": 4, "feat_channels": 256,
        "anchor_generator": {
            "type": "AnchorGenerator", "octave_base_scale": 4, "scales_per_octave": 3,
            "ratios": [0.5, 1.0, 2.0], "strides": [8, 16, 32, 64, 128]
        },
        "bbox_coder": {"type": "DeltaXYWHBBoxCoder", "target_means": [0.0, 0.0, 0.0, 0.0], "target_stds": [1.0, 1.0, 1.0, 1.0]},
        "loss_cls": {"type": "FocalLoss", "use_sigmoid": true, "gamma": 2.0, "alpha": 0.25, "loss_weight": 1.0},
        "loss_bbox": {"type": "L1Loss", "loss_weight": 1.0}
    },
    "train_cfg": {
        "assigner": {"type": "MaxIoUAssigner", "pos_iou_thr": 0.5, "neg_iou_thr": 0.4, "min_pos_iou": 0, "ignore_iof_thr": -1},
        "allowed_border": -1, "pos_weight": -1, "debug": false
    },
    "test_cfg": {
        "nms_pre": 1000, "min_bbox_size": 0, "score_thr": 0.05,
        "nms": {"type": "nms", "iou_threshold": 0.5}, "max_per_img": 100
    }
}
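These two JSON files appear to be plain MMDetection model configs stored as test data. This diff does not show how the tests consume them, so the following is only an assumed, typical way to turn such a file into a detector using mmcv and mmdet APIs:

import json

import mmcv
from mmdet.models import build_detector

# Assumed usage, not taken from this commit: load the stored config dict and
# build the corresponding detector from it.
with open('tests/test_mmdet/data/single_stage_model.json') as f:
    model_dict = json.load(f)

model_cfg = mmcv.Config(dict(model=model_dict))
detector = build_detector(model_cfg.model)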
@@ -1,68 +0,0 @@ (file deleted; 68 lines removed)

import importlib

import mmcv
import pytest
import torch

from mmdeploy.mmdet.core.post_processing.bbox_nms import multiclass_nms
from mmdeploy.utils.test import WrapFunction, get_rewrite_outputs


@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
@pytest.mark.skipif(
    not importlib.util.find_spec('tensorrt'), reason='requires tensorrt')
def test_multiclass_nms_static():
    import tensorrt as trt
    deploy_cfg = mmcv.Config(
        dict(
            onnx_config=dict(
                output_names=['dets', 'labels'], input_shape=None),
            backend_config=dict(
                type='tensorrt',
                common_config=dict(
                    fp16_mode=False,
                    log_level=trt.Logger.INFO,
                    max_workspace_size=1 << 30),
                model_inputs=[
                    dict(
                        input_shapes=dict(
                            boxes=dict(
                                min_shape=[1, 500, 4],
                                opt_shape=[1, 500, 4],
                                max_shape=[1, 500, 4]),
                            scores=dict(
                                min_shape=[1, 500, 80],
                                opt_shape=[1, 500, 80],
                                max_shape=[1, 500, 80])))
                ]),
            codebase_config=dict(
                type='mmdet',
                task='ObjectDetection',
                post_processing=dict(
                    score_threshold=0.05,
                    iou_threshold=0.5,
                    max_output_boxes_per_class=200,
                    pre_top_k=-1,
                    keep_top_k=100,
                    background_label_id=-1,
                ))))

    boxes = torch.rand(1, 500, 4).cuda()
    scores = torch.rand(1, 500, 80).cuda()
    max_output_boxes_per_class = 200
    keep_top_k = 100
    wrapped_func = WrapFunction(
        multiclass_nms,
        max_output_boxes_per_class=max_output_boxes_per_class,
        keep_top_k=keep_top_k)
    rewrite_outputs = get_rewrite_outputs(
        wrapped_func,
        model_inputs={
            'boxes': boxes,
            'scores': scores
        },
        deploy_cfg=deploy_cfg)

    assert rewrite_outputs is not None, 'Got unexpected rewrite '\
        'outputs: {}'.format(rewrite_outputs)
tests/test_mmdet/test_mmdet_apis.py (new file, 549 lines) @@ -0,0 +1,549 @@

import importlib

import mmcv
import numpy as np
import pytest
import torch

import mmdeploy.apis.ncnn as ncnn_apis
import mmdeploy.apis.onnxruntime as ort_apis
import mmdeploy.apis.ppl as ppl_apis
import mmdeploy.apis.tensorrt as trt_apis
from mmdeploy.utils.test import SwitchBackendWrapper


@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
@pytest.mark.skipif(
    not importlib.util.find_spec('tensorrt'), reason='requires tensorrt')
def test_TensorRTDetector():
    # force add backend wrapper regardless of plugins
    # make sure TensorRTDetector can use TRTWrapper inside itself
    from mmdeploy.apis.tensorrt.tensorrt_utils import TRTWrapper
    trt_apis.__dict__.update({'TRTWrapper': TRTWrapper})

    # simplify backend inference
    outputs = {
        'dets': torch.rand(1, 100, 5).cuda(),
        'labels': torch.rand(1, 100).cuda()
    }
    SwitchBackendWrapper.set(TRTWrapper, outputs=outputs)

    from mmdeploy.mmdet.apis.inference import TensorRTDetector
    trt_detector = TensorRTDetector('', ['' for i in range(80)], 0)
    imgs = [torch.rand(1, 3, 64, 64).cuda()]
    img_metas = [[{
        'ori_shape': [64, 64, 3],
        'img_shape': [64, 64, 3],
        'scale_factor': [2.09, 1.87, 2.09, 1.87],
    }]]

    results = trt_detector.forward(imgs, img_metas)
    assert results is not None, 'failed to get output using TensorRTDetector'
    SwitchBackendWrapper.recover(TRTWrapper)


@pytest.mark.skipif(
    not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
def test_ONNXRuntimeDetector():
    # force add backend wrapper regardless of plugins
    # make sure ONNXRuntimeDetector can use ORTWrapper inside itself
    from mmdeploy.apis.onnxruntime.onnxruntime_utils import ORTWrapper
    ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})

    # simplify backend inference
    outputs = (torch.rand(1, 100, 5), torch.rand(1, 100))
    SwitchBackendWrapper.set(ORTWrapper, outputs=outputs)

    from mmdeploy.mmdet.apis.inference import ONNXRuntimeDetector
    ort_detector = ONNXRuntimeDetector('', ['' for i in range(80)], 0)
    imgs = [torch.rand(1, 3, 64, 64)]
    img_metas = [[{
        'ori_shape': [64, 64, 3],
        'img_shape': [64, 64, 3],
        'scale_factor': [2.09, 1.87, 2.09, 1.87],
    }]]

    results = ort_detector.forward(imgs, img_metas)
    assert results is not None, 'failed to get output using '\
        'ONNXRuntimeDetector'
    SwitchBackendWrapper.recover(ORTWrapper)


@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
@pytest.mark.skipif(
    not importlib.util.find_spec('pyppl'), reason='requires pyppl')
def test_PPLDetector():
    # force add backend wrapper regardless of plugins
    # make sure PPLDetector can use PPLWrapper inside itself
    from mmdeploy.apis.ppl.ppl_utils import PPLWrapper
    ppl_apis.__dict__.update({'PPLWrapper': PPLWrapper})

    # simplify backend inference
    outputs = (torch.rand(1, 100, 5), torch.rand(1, 100))
    SwitchBackendWrapper.set(PPLWrapper, outputs=outputs)

    from mmdeploy.mmdet.apis.inference import PPLDetector
    ppl_detector = PPLDetector('', ['' for i in range(80)], 0)
    imgs = [torch.rand(1, 3, 64, 64)]
    img_metas = [[{
        'ori_shape': [64, 64, 3],
        'img_shape': [64, 64, 3],
        'scale_factor': [2.09, 1.87, 2.09, 1.87],
    }]]

    results = ppl_detector.forward(imgs, img_metas)
    assert results is not None, 'failed to get output using PPLDetector'
    SwitchBackendWrapper.recover(PPLWrapper)


def get_test_cfg_and_post_processing():
    test_cfg = {
        'nms_pre': 100,
        'min_bbox_size': 0,
        'score_thr': 0.05,
        'nms': {'type': 'nms', 'iou_threshold': 0.5},
        'max_per_img': 10
    }
    post_processing = {
        'score_threshold': 0.05,
        'iou_threshold': 0.5,
        'max_output_boxes_per_class': 20,
        'pre_top_k': -1,
        'keep_top_k': 10,
        'background_label_id': -1
    }
    return test_cfg, post_processing


def test_PartitionSingleStageDetector():
    test_cfg, post_processing = get_test_cfg_and_post_processing()
    model_cfg = mmcv.Config(dict(model=dict(test_cfg=test_cfg)))
    deploy_cfg = mmcv.Config(
        dict(codebase_config=dict(post_processing=post_processing)))

    from mmdeploy.mmdet.apis.inference import PartitionSingleStageDetector
    pss_detector = PartitionSingleStageDetector(['' for i in range(80)],
                                                model_cfg=model_cfg,
                                                deploy_cfg=deploy_cfg,
                                                device_id=0)
    scores = torch.rand(1, 120, 80)
    bboxes = torch.rand(1, 120, 4)

    results = pss_detector.partition0_postprocess(scores=scores, bboxes=bboxes)
    assert results is not None, 'failed to get output using '\
        'partition0_postprocess of PartitionSingleStageDetector'


@pytest.mark.skipif(
    not importlib.util.find_spec('ncnn'), reason='requires ncnn')
def test_NCNNPSSDetector():
    test_cfg, post_processing = get_test_cfg_and_post_processing()
    model_cfg = mmcv.Config(dict(model=dict(test_cfg=test_cfg)))
    deploy_cfg = mmcv.Config(
        dict(codebase_config=dict(post_processing=post_processing)))

    # force add backend wrapper regardless of plugins
    # make sure NCNNPSSDetector can use NCNNWrapper inside itself
    from mmdeploy.apis.ncnn.ncnn_utils import NCNNWrapper
    ncnn_apis.__dict__.update({'NCNNWrapper': NCNNWrapper})

    # simplify backend inference
    outputs = {
        'scores': torch.rand(1, 120, 80),
        'boxes': torch.rand(1, 120, 4)
    }
    SwitchBackendWrapper.set(
        NCNNWrapper, outputs=outputs, model_cfg=model_cfg,
        deploy_cfg=deploy_cfg)

    from mmdeploy.mmdet.apis.inference import NCNNPSSDetector

    ncnn_pss_detector = NCNNPSSDetector(['', ''], ['' for i in range(80)],
                                        model_cfg=model_cfg,
                                        deploy_cfg=deploy_cfg,
                                        device_id=0)
    imgs = [torch.rand(1, 3, 32, 32)]
    img_metas = [[{
        'ori_shape': [32, 32, 3],
        'img_shape': [32, 32, 3],
        'scale_factor': [2.09, 1.87, 2.09, 1.87],
    }]]

    results = ncnn_pss_detector.forward(imgs, img_metas)
    assert results is not None, 'failed to get output using NCNNPSSDetector'
    SwitchBackendWrapper.recover(NCNNWrapper)


@pytest.mark.skipif(
    not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
def test_ONNXRuntimePSSDetector():
    test_cfg, post_processing = get_test_cfg_and_post_processing()
    model_cfg = mmcv.Config(dict(model=dict(test_cfg=test_cfg)))
    deploy_cfg = mmcv.Config(
        dict(codebase_config=dict(post_processing=post_processing)))

    # force add backend wrapper regardless of plugins
    # make sure ONNXRuntimePSSDetector can use ORTWrapper inside itself
    from mmdeploy.apis.onnxruntime.onnxruntime_utils import ORTWrapper
    ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})

    # simplify backend inference
    outputs = [
        np.random.rand(1, 120, 80).astype(np.float32),
        np.random.rand(1, 120, 4).astype(np.float32)
    ]
    SwitchBackendWrapper.set(
        ORTWrapper, outputs=outputs, model_cfg=model_cfg,
        deploy_cfg=deploy_cfg)

    from mmdeploy.mmdet.apis.inference import ONNXRuntimePSSDetector

    ort_pss_detector = ONNXRuntimePSSDetector(
        '', ['' for i in range(80)],
        model_cfg=model_cfg,
        deploy_cfg=deploy_cfg,
        device_id=0)
    imgs = [torch.rand(1, 3, 32, 32)]
    img_metas = [[{
        'ori_shape': [32, 32, 3],
        'img_shape': [32, 32, 3],
        'scale_factor': [2.09, 1.87, 2.09, 1.87],
    }]]

    results = ort_pss_detector.forward(imgs, img_metas)
    assert results is not None, 'failed to get output using '\
        'ONNXRuntimePSSDetector'
    SwitchBackendWrapper.recover(ORTWrapper)


@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
@pytest.mark.skipif(
    not importlib.util.find_spec('tensorrt'), reason='requires tensorrt')
def test_TensorRTPSSDetector():
    test_cfg, post_processing = get_test_cfg_and_post_processing()
    model_cfg = mmcv.Config(dict(model=dict(test_cfg=test_cfg)))
    deploy_cfg = mmcv.Config(
        dict(codebase_config=dict(post_processing=post_processing)))

    # force add backend wrapper regardless of plugins
    # make sure TensorRTPSSDetector can use TRTWrapper inside itself
    from mmdeploy.apis.tensorrt.tensorrt_utils import TRTWrapper
    trt_apis.__dict__.update({'TRTWrapper': TRTWrapper})

    # simplify backend inference
    outputs = {
        'scores': torch.rand(1, 120, 80).cuda(),
        'boxes': torch.rand(1, 120, 4).cuda()
    }
    SwitchBackendWrapper.set(
        TRTWrapper, outputs=outputs, model_cfg=model_cfg,
        deploy_cfg=deploy_cfg)

    from mmdeploy.mmdet.apis.inference import TensorRTPSSDetector

    trt_pss_detector = TensorRTPSSDetector(
        '', ['' for i in range(80)],
        model_cfg=model_cfg,
        deploy_cfg=deploy_cfg,
        device_id=0)
    imgs = [torch.rand(1, 3, 32, 32).cuda()]
    img_metas = [[{
        'ori_shape': [32, 32, 3],
        'img_shape': [32, 32, 3],
        'scale_factor': [2.09, 1.87, 2.09, 1.87],
    }]]

    results = trt_pss_detector.forward(imgs, img_metas)
    assert results is not None, 'failed to get output using '\
        'TensorRTPSSDetector'
    SwitchBackendWrapper.recover(TRTWrapper)


def prepare_model_deploy_cfgs():
    test_cfg, post_processing = get_test_cfg_and_post_processing()
    bbox_roi_extractor = {
        'type': 'SingleRoIExtractor',
        'roi_layer': {'type': 'RoIAlign', 'output_size': 7, 'sampling_ratio': 0},
        'out_channels': 8,
        'featmap_strides': [4]
    }
    bbox_head = {
        'type': 'Shared2FCBBoxHead',
        'in_channels': 8,
        'fc_out_channels': 1024,
        'roi_feat_size': 7,
        'num_classes': 80,
        'bbox_coder': {
            'type': 'DeltaXYWHBBoxCoder',
            'target_means': [0.0, 0.0, 0.0, 0.0],
            'target_stds': [0.1, 0.1, 0.2, 0.2]
        },
        'reg_class_agnostic': False,
        'loss_cls': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 1.0
        },
        'loss_bbox': {'type': 'L1Loss', 'loss_weight': 1.0}
    }
    roi_head = dict(bbox_roi_extractor=bbox_roi_extractor, bbox_head=bbox_head)
    model_cfg = mmcv.Config(
        dict(
            model=dict(
                test_cfg=dict(rpn=test_cfg, rcnn=test_cfg),
                roi_head=roi_head)))
    deploy_cfg = mmcv.Config(
        dict(codebase_config=dict(post_processing=post_processing)))
    return model_cfg, deploy_cfg


def test_PartitionTwoStageDetector():
    model_cfg, deploy_cfg = prepare_model_deploy_cfgs()
    from mmdeploy.mmdet.apis.inference import PartitionTwoStageDetector
    pts_detector = PartitionTwoStageDetector(['' for i in range(80)],
                                             model_cfg=model_cfg,
                                             deploy_cfg=deploy_cfg,
                                             device_id=0)
    feats = [torch.randn(1, 8, 14, 14) for i in range(5)]
    scores = torch.rand(1, 50, 1)
    bboxes = torch.rand(1, 50, 4)
    bboxes[..., 2:4] = 2 * bboxes[..., :2]
    results = pts_detector.partition0_postprocess(
        x=feats, scores=scores, bboxes=bboxes)
    assert results is not None, 'failed to get output using '\
        'partition0_postprocess of PartitionTwoStageDetector'

    rois = torch.rand(1, 10, 5)
    cls_score = torch.rand(10, 81)
    bbox_pred = torch.rand(10, 320)
    img_metas = [[{
        'ori_shape': [32, 32, 3],
        'img_shape': [32, 32, 3],
        'scale_factor': [2.09, 1.87, 2.09, 1.87],
    }]]
    results = pts_detector.partition1_postprocess(
        rois=rois,
        cls_score=cls_score,
        bbox_pred=bbox_pred,
        img_metas=img_metas)
    assert results is not None, 'failed to get output using '\
        'partition1_postprocess of PartitionTwoStageDetector'


class DummyPTSDetector(torch.nn.Module):
    """A dummy wrapper for unit tests."""

    def __init__(self, *args, **kwargs):
        self.output_names = ['dets', 'labels']

    def partition0_postprocess(self, *args, **kwargs):
        return self.outputs0

    def partition1_postprocess(self, *args, **kwargs):
        return self.outputs1


@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
@pytest.mark.skipif(
    not importlib.util.find_spec('tensorrt'), reason='requires tensorrt')
def test_TensorRTPTSDetector():
    model_cfg, deploy_cfg = prepare_model_deploy_cfgs()

    # force add backend wrapper regardless of plugins
    # make sure TensorRTPTSDetector can use TRTWrapper inside itself
    from mmdeploy.apis.tensorrt.tensorrt_utils import TRTWrapper
    trt_apis.__dict__.update({'TRTWrapper': TRTWrapper})

    # simplify backend inference
    outputs = {
        'scores': torch.rand(1, 12, 80).cuda(),
        'boxes': torch.rand(1, 12, 4).cuda(),
        'cls_score': torch.rand(1, 12, 80).cuda(),
        'bbox_pred': torch.rand(1, 12, 4).cuda()
    }
    SwitchBackendWrapper.set(TRTWrapper, outputs=outputs)
    TRTWrapper.model_cfg = model_cfg
    TRTWrapper.deploy_cfg = deploy_cfg

    # replace original function in PartitionTwoStageDetector
    from mmdeploy.mmdet.apis.inference import PartitionTwoStageDetector
    PartitionTwoStageDetector.__init__ = DummyPTSDetector.__init__
    PartitionTwoStageDetector.partition0_postprocess = \
        DummyPTSDetector.partition0_postprocess
    PartitionTwoStageDetector.partition1_postprocess = \
        DummyPTSDetector.partition1_postprocess
    PartitionTwoStageDetector.outputs0 = [torch.rand(2, 3).cuda()] * 2
    PartitionTwoStageDetector.outputs1 = [
        torch.rand(1, 9, 5).cuda(),
        torch.rand(1, 9).cuda()
    ]
    PartitionTwoStageDetector.device_id = 0
    PartitionTwoStageDetector.CLASSES = ['' for i in range(80)]

    from mmdeploy.mmdet.apis.inference import TensorRTPTSDetector
    trt_pts_detector = TensorRTPTSDetector(['', ''], ['' for i in range(80)],
                                           model_cfg=model_cfg,
                                           deploy_cfg=deploy_cfg,
                                           device_id=0)

    imgs = [torch.rand(1, 3, 32, 32).cuda()]
    img_metas = [[{
        'ori_shape': [32, 32, 3],
        'img_shape': [32, 32, 3],
        'scale_factor': [2.09, 1.87, 2.09, 1.87],
    }]]
    results = trt_pts_detector.forward(imgs, img_metas)
    assert results is not None, 'failed to get output using '\
        'TensorRTPTSDetector'
    SwitchBackendWrapper.recover(TRTWrapper)


@pytest.mark.skipif(
    not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
def test_ONNXRuntimePTSDetector():
    model_cfg, deploy_cfg = prepare_model_deploy_cfgs()

    # force add backend wrapper regardless of plugins
    # make sure ONNXRuntimePTSDetector can use TRTWrapper inside itself
    from mmdeploy.apis.onnxruntime.onnxruntime_utils import ORTWrapper
    ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})

    # simplify backend inference
    outputs = [
        np.random.rand(1, 12, 80).astype(np.float32),
        np.random.rand(1, 12, 4).astype(np.float32),
    ] * 2
    SwitchBackendWrapper.set(
        ORTWrapper, outputs=outputs, model_cfg=model_cfg,
        deploy_cfg=deploy_cfg)

    # replace original function in PartitionTwoStageDetector
    from mmdeploy.mmdet.apis.inference import PartitionTwoStageDetector
    PartitionTwoStageDetector.__init__ = DummyPTSDetector.__init__
    PartitionTwoStageDetector.partition0_postprocess = \
        DummyPTSDetector.partition0_postprocess
    PartitionTwoStageDetector.partition1_postprocess = \
        DummyPTSDetector.partition1_postprocess
    PartitionTwoStageDetector.outputs0 = [torch.rand(2, 3)] * 2
    PartitionTwoStageDetector.outputs1 = [
        torch.rand(1, 9, 5), torch.rand(1, 9)
    ]
    PartitionTwoStageDetector.device_id = -1
    PartitionTwoStageDetector.CLASSES = ['' for i in range(80)]

    from mmdeploy.mmdet.apis.inference import ONNXRuntimePTSDetector
    ort_pts_detector = ONNXRuntimePTSDetector(['', ''],
                                              ['' for i in range(80)],
                                              model_cfg=model_cfg,
                                              deploy_cfg=deploy_cfg,
                                              device_id=0)

    imgs = [torch.rand(1, 3, 32, 32)]
    img_metas = [[{
        'ori_shape': [32, 32, 3],
        'img_shape': [32, 32, 3],
        'scale_factor': [2.09, 1.87, 2.09, 1.87],
    }]]
    results = ort_pts_detector.forward(imgs, img_metas)
    assert results is not None, 'failed to get output using '\
        'ONNXRuntimePTSDetector'
    SwitchBackendWrapper.recover(ORTWrapper)


@pytest.mark.skipif(
    not importlib.util.find_spec('ncnn'), reason='requires ncnn')
def test_NCNNPTSDetector():
    model_cfg, deploy_cfg = prepare_model_deploy_cfgs()
    num_outs = dict(model=dict(neck=dict(num_outs=0)))
    model_cfg.update(num_outs)

    # force add backend wrapper regardless of plugins
    # make sure NCNNPTSDetector can use TRTWrapper inside itself
    from mmdeploy.apis.ncnn.ncnn_utils import NCNNWrapper
    ncnn_apis.__dict__.update({'NCNNWrapper': NCNNWrapper})

    # simplify backend inference
    outputs = {
        'scores': torch.rand(1, 12, 80),
        'boxes': torch.rand(1, 12, 4),
        'cls_score': torch.rand(1, 12, 80),
        'bbox_pred': torch.rand(1, 12, 4)
    }
    SwitchBackendWrapper.set(
        NCNNWrapper, outputs=outputs, model_cfg=model_cfg,
        deploy_cfg=deploy_cfg)

    # replace original function in PartitionTwoStageDetector
    from mmdeploy.mmdet.apis.inference import PartitionTwoStageDetector
    PartitionTwoStageDetector.__init__ = DummyPTSDetector.__init__
    PartitionTwoStageDetector.partition0_postprocess = \
        DummyPTSDetector.partition0_postprocess
    PartitionTwoStageDetector.partition1_postprocess = \
        DummyPTSDetector.partition1_postprocess
    PartitionTwoStageDetector.outputs0 = [torch.rand(2, 3)] * 2
    PartitionTwoStageDetector.outputs1 = [
        torch.rand(1, 9, 5), torch.rand(1, 9)
    ]
    PartitionTwoStageDetector.device_id = -1
    PartitionTwoStageDetector.CLASSES = ['' for i in range(80)]

    from mmdeploy.mmdet.apis.inference import NCNNPTSDetector
    ncnn_pts_detector = NCNNPTSDetector(
        [''] * 4, [''] * 80,
        model_cfg=model_cfg,
        deploy_cfg=deploy_cfg,
        device_id=0)

    imgs = [torch.rand(1, 3, 32, 32)]
    img_metas = [[{
        'ori_shape': [32, 32, 3],
        'img_shape': [32, 32, 3],
        'scale_factor': [2.09, 1.87, 2.09, 1.87],
    }]]
    results = ncnn_pts_detector.forward(imgs, img_metas)
    assert results is not None, 'failed to get output using '\
        'NCNNPTSDetector'
    SwitchBackendWrapper.recover(NCNNWrapper)


@pytest.mark.skipif(
    not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
def test_build_detector():
    _, post_processing = get_test_cfg_and_post_processing()
    model_cfg = mmcv.Config(dict(data=dict(test={'type': 'CocoDataset'})))
    deploy_cfg = mmcv.Config(
        dict(
            backend_config=dict(type='onnxruntime'),
            codebase_config=dict(
                type='mmdet', post_processing=post_processing)))

    from mmdeploy.apis.onnxruntime.onnxruntime_utils import ORTWrapper
    ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})

    # simplify backend inference
    SwitchBackendWrapper.set(
        ORTWrapper, model_cfg=model_cfg, deploy_cfg=deploy_cfg)
    from mmdeploy.apis.utils import init_backend_model
    detector = init_backend_model([''], model_cfg, deploy_cfg, -1)
    assert detector is not None
    SwitchBackendWrapper.recover(ORTWrapper)
tests/test_mmdet/test_mmdet_core.py (new file, 147 lines) @@ -0,0 +1,147 @@

import importlib

import mmcv
import numpy as np
import pytest
import torch

from mmdeploy.utils.test import WrapFunction, get_rewrite_outputs


@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
@pytest.mark.skipif(
    not importlib.util.find_spec('tensorrt'), reason='requires tensorrt')
def test_multiclass_nms_static():
    import tensorrt as trt
    from mmdeploy.mmdet.core import multiclass_nms
    deploy_cfg = mmcv.Config(
        dict(
            onnx_config=dict(
                output_names=['dets', 'labels'], input_shape=None),
            backend_config=dict(
                type='tensorrt',
                common_config=dict(
                    fp16_mode=False,
                    log_level=trt.Logger.INFO,
                    max_workspace_size=1 << 20),
                model_inputs=[
                    dict(
                        input_shapes=dict(
                            boxes=dict(
                                min_shape=[1, 5, 4],
                                opt_shape=[1, 5, 4],
                                max_shape=[1, 5, 4]),
                            scores=dict(
                                min_shape=[1, 5, 8],
                                opt_shape=[1, 5, 8],
                                max_shape=[1, 5, 8])))
                ]),
            codebase_config=dict(
                type='mmdet',
                task='ObjectDetection',
                post_processing=dict(
                    score_threshold=0.05,
                    iou_threshold=0.5,
                    max_output_boxes_per_class=20,
                    pre_top_k=-1,
                    keep_top_k=10,
                    background_label_id=-1,
                ))))

    boxes = torch.rand(1, 5, 4).cuda()
    scores = torch.rand(1, 5, 8).cuda()
    max_output_boxes_per_class = 20
    keep_top_k = 10
    wrapped_func = WrapFunction(
        multiclass_nms,
        max_output_boxes_per_class=max_output_boxes_per_class,
        keep_top_k=keep_top_k)
    rewrite_outputs = get_rewrite_outputs(
        wrapped_func,
        model_inputs={
            'boxes': boxes,
            'scores': scores
        },
        deploy_cfg=deploy_cfg)

    assert rewrite_outputs is not None, 'Got unexpected rewrite '\
        'outputs: {}'.format(rewrite_outputs)


@pytest.mark.parametrize('backend_type', ['onnxruntime', 'ncnn'])
def test_delta2bbox(backend_type):
    pytest.importorskip(backend_type, reason=f'requires {backend_type}')
    deploy_cfg = mmcv.Config(
        dict(
            onnx_config=dict(
                output_names=['dets', 'labels'], input_shape=None),
            backend_config=dict(type=backend_type, model_inputs=None),
            codebase_config=dict(type='mmdet', task='ObjectDetection')))

    # wrap function to enable rewrite
    def delta2bbox(*args, **kwargs):
        import mmdet
        return mmdet.core.bbox.coder.delta_xywh_bbox_coder.delta2bbox(
            *args, **kwargs)

    rois = torch.rand(1, 5, 4)
    deltas = torch.rand(1, 5, 4)
    original_outputs = delta2bbox(rois, deltas)

    # wrap function to nn.Module, enable torch.onnx.export
    wrapped_func = WrapFunction(delta2bbox)
    rewrite_outputs = get_rewrite_outputs(
        wrapped_func,
        model_inputs={
            'rois': rois,
            'deltas': deltas
        },
        deploy_cfg=deploy_cfg)

    model_output = original_outputs.squeeze().cpu().numpy()
    rewrite_output = rewrite_outputs[0].squeeze()
    assert np.allclose(model_output, rewrite_output, rtol=1e-03, atol=1e-05)


@pytest.mark.parametrize('backend_type', ['onnxruntime', 'ncnn'])
def test_tblr2bbox(backend_type):
    pytest.importorskip(backend_type, reason=f'requires {backend_type}')
    deploy_cfg = mmcv.Config(
        dict(
            onnx_config=dict(
                output_names=['dets', 'labels'], input_shape=None),
            backend_config=dict(type=backend_type, model_inputs=None),
            codebase_config=dict(type='mmdet', task='ObjectDetection')))

    # wrap function to enable rewrite
    def tblr2bboxes(*args, **kwargs):
        import mmdet
        return mmdet.core.bbox.coder.tblr_bbox_coder.tblr2bboxes(
            *args, **kwargs)

    priors = torch.rand(1, 5, 4)
    tblr = torch.rand(1, 5, 4)
    original_outputs = tblr2bboxes(priors, tblr)

    # wrap function to nn.Module, enable torch.onnx.export
    wrapped_func = WrapFunction(tblr2bboxes)
    rewrite_outputs = get_rewrite_outputs(
        wrapped_func,
        model_inputs={
            'priors': priors,
            'tblr': tblr
        },
        deploy_cfg=deploy_cfg)

    model_output = original_outputs.squeeze().cpu().numpy()
    rewrite_output = rewrite_outputs[0].squeeze()
    assert np.allclose(model_output, rewrite_output, rtol=1e-03, atol=1e-05)


def test_distance2bbox():
    from mmdeploy.mmdet.core import distance2bbox
    points = torch.rand(3, 2)
    distance = torch.rand(3, 4)
    bbox = distance2bbox(points, distance)
    assert bbox.shape == torch.Size([3, 4])
tests/test_mmdet/test_mmdet_export.py (new file, 104 lines) @@ -0,0 +1,104 @@

import mmcv
import numpy as np
import pytest
import torch

from mmdeploy.apis.utils import (build_dataloader, build_dataset, create_input,
                                 get_tensor_from_input)
from mmdeploy.utils.constants import Codebase, Task


def test_create_input():
    task = Task.OBJECT_DETECTION
    test = dict(pipeline=[{
        'type': 'LoadImageFromWebcam'
    }, {
        'type': 'MultiScaleFlipAug',
        'img_scale': [32, 32],
        'flip': False,
        'transforms': [{
            'type': 'Resize',
            'keep_ratio': True
        }, {
            'type': 'RandomFlip'
        }, {
            'type': 'Normalize',
            'mean': [123.675, 116.28, 103.53],
            'std': [58.395, 57.12, 57.375],
            'to_rgb': True
        }, {
            'type': 'Pad',
            'size_divisor': 32
        }, {
            'type': 'DefaultFormatBundle'
        }, {
            'type': 'Collect',
            'keys': ['img']
        }]
    }])
    data = dict(test=test)
    model_cfg = mmcv.Config(dict(data=data))
    imgs = [np.random.rand(32, 32, 3)]
    inputs = create_input(
        Codebase.MMDET,
        task,
        model_cfg,
        imgs,
        input_shape=(32, 32),
        device='cpu')
    assert inputs is not None, 'Failed to create input'


@pytest.mark.parametrize('input_data', [{'img': [torch.ones(3, 4, 5)]}])
def test_get_tensor_from_input(input_data):
    inputs = get_tensor_from_input(Codebase.MMDET, input_data)
    assert inputs is not None, 'Failed to get tensor from input'


def test_build_dataset():
    data = dict(
        test={
            'type': 'CocoDataset',
            'ann_file': 'tests/test_mmdet/data/coco_sample.json',
            'img_prefix': 'tests/test_mmdet/data/imgs/',
            'pipeline': [
                {
                    'type': 'LoadImageFromFile'
                },
            ]
        })
    dataset_cfg = mmcv.Config(dict(data=data))
    dataset = build_dataset(
        Codebase.MMDET, dataset_cfg=dataset_cfg, dataset_type='test')
    assert dataset is not None, 'Failed to build dataset'
    dataloader = build_dataloader(Codebase.MMDET, dataset, 1, 1)
    assert dataloader is not None, 'Failed to build dataloader'


def test_clip_bboxes():
    from mmdeploy.mmdet.export import clip_bboxes
    x1 = torch.rand(3, 2) * 224
    y1 = torch.rand(3, 2) * 224
    x2 = x1 * 2
    y2 = y1 * 2
    outs = clip_bboxes(x1, y1, x2, y2, [224, 224])
    for out in outs:
        assert int(out.max()) <= 224


def test_pad_with_value():
    from mmdeploy.mmdet.export import pad_with_value
    x = torch.rand(3, 2)
    padded_x = pad_with_value(x, pad_dim=1, pad_size=4, pad_value=0)
    assert np.allclose(
        padded_x.shape, torch.Size([3, 6]), rtol=1e-03, atol=1e-05)
    assert np.allclose(padded_x.sum(), x.sum(), rtol=1e-03, atol=1e-05)


@pytest.mark.parametrize('partition_type', ['single_stage', 'two_stage'])
def test_get_partition_cfg(partition_type):
    from mmdeploy.mmdet.export import get_partition_cfg
    partition_cfg = get_partition_cfg(partition_type=partition_type)
    assert partition_cfg is not None
293
tests/test_mmdet/test_mmdet_models.py
Normal file
293
tests/test_mmdet/test_mmdet_models.py
Normal file
@ -0,0 +1,293 @@
import copy
import importlib
import os
import random
import tempfile

import mmcv
import numpy as np
import pytest
import torch

from mmdeploy.utils.constants import Backend, Codebase
from mmdeploy.utils.test import (WrapModel, get_model_outputs,
                                 get_rewrite_outputs)


def seed_everything(seed=1029):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.enabled = False


def get_anchor_head_model():
    """AnchorHead Config."""
    test_cfg = mmcv.Config(
        dict(
            deploy_nms_pre=0,
            min_bbox_size=0,
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100))

    from mmdet.models import AnchorHead
    model = AnchorHead(num_classes=4, in_channels=1, test_cfg=test_cfg)
    model.requires_grad_(False)

    return model


def get_fcos_head_model():
    """FCOS Head Config."""
    test_cfg = mmcv.Config(
        dict(
            deploy_nms_pre=0,
            min_bbox_size=0,
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100))

    from mmdet.models import FCOSHead
    model = FCOSHead(num_classes=4, in_channels=1, test_cfg=test_cfg)

    model.requires_grad_(False)
    return model


def get_rpn_head_model():
    """RPN Head Config."""
    test_cfg = mmcv.Config(
        dict(
            deploy_nms_pre=0,
            min_bbox_size=0,
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100))
    from mmdet.models import RPNHead
    model = RPNHead(in_channels=1, test_cfg=test_cfg)

    model.requires_grad_(False)
    return model


@pytest.mark.parametrize('backend_type', ['onnxruntime', 'ncnn'])
def test_anchor_head_get_bboxes(backend_type):
    """Test get_bboxes rewrite of anchor head."""
    pytest.importorskip(backend_type, reason=f'requires {backend_type}')
    anchor_head = get_anchor_head_model()
    anchor_head.cpu().eval()
    s = 128
    img_metas = [{
        'scale_factor': np.ones(4),
        'pad_shape': (s, s, 3),
        'img_shape': (s, s, 3)
    }]

    deploy_cfg = mmcv.Config(
        dict(
            backend_config=dict(type=backend_type),
            onnx_config=dict(
                output_names=['dets', 'labels'], input_shape=None),
            codebase_config=dict(
                type='mmdet',
                task='ObjectDetection',
                post_processing=dict(
                    score_threshold=0.05,
                    iou_threshold=0.5,
                    max_output_boxes_per_class=200,
                    pre_top_k=-1,
                    keep_top_k=100,
                    background_label_id=-1,
                ))))

    # the cls_score's size: (1, 36, 32, 32), (1, 36, 16, 16),
    # (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2).
    # the bboxes's size: (1, 36, 32, 32), (1, 36, 16, 16),
    # (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2)
    seed_everything(1234)
    cls_score = [
        torch.rand(1, 36, pow(2, i), pow(2, i)) for i in range(5, 0, -1)
    ]
    seed_everything(5678)
    bboxes = [torch.rand(1, 36, pow(2, i), pow(2, i)) for i in range(5, 0, -1)]

    # to get outputs of pytorch model
    model_inputs = {
        'cls_scores': cls_score,
        'bbox_preds': bboxes,
        'img_metas': img_metas
    }
    model_outputs = get_model_outputs(anchor_head, 'get_bboxes', model_inputs)

    # to get outputs of onnx model after rewrite
    img_metas[0]['img_shape'] = torch.Tensor([s, s])
    wrapped_model = WrapModel(
        anchor_head, 'get_bboxes', img_metas=img_metas[0], with_nms=True)
    rewrite_inputs = {
        'cls_scores': cls_score,
        'bbox_preds': bboxes,
    }
    rewrite_outputs = get_rewrite_outputs(
        wrapped_model=wrapped_model,
        model_inputs=rewrite_inputs,
        deploy_cfg=deploy_cfg)

    for model_output, rewrite_output in zip(model_outputs[0], rewrite_outputs):
        model_output = model_output.squeeze().cpu().numpy()
        rewrite_output = rewrite_output.squeeze()
        # hard code to make two tensors with the same shape
        # rewrite and original codes applied different nms strategy
        assert np.allclose(
            model_output[:rewrite_output.shape[0]],
            rewrite_output,
            rtol=1e-03,
            atol=1e-05)


@pytest.mark.parametrize('backend_type', ['onnxruntime', 'ncnn'])
def test_get_bboxes_of_fcos_head(backend_type):
    pytest.importorskip(backend_type, reason=f'requires {backend_type}')
    fcos_head = get_fcos_head_model()
    fcos_head.cpu().eval()
    s = 128
    img_metas = [{
        'scale_factor': np.ones(4),
        'pad_shape': (s, s, 3),
        'img_shape': (s, s, 3)
    }]

    deploy_cfg = mmcv.Config(
        dict(
            backend_config=dict(type=backend_type),
            onnx_config=dict(
                output_names=['dets', 'labels'], input_shape=None),
            codebase_config=dict(
                type='mmdet',
                task='ObjectDetection',
                post_processing=dict(
                    score_threshold=0.05,
                    iou_threshold=0.5,
                    max_output_boxes_per_class=200,
                    pre_top_k=-1,
                    keep_top_k=100,
                    background_label_id=-1,
                ))))

    # the cls_score's size: (1, 36, 32, 32), (1, 36, 16, 16),
    # (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2).
    # the bboxes's size: (1, 36, 32, 32), (1, 36, 16, 16),
    # (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2)
    seed_everything(1234)
    cls_score = [
        torch.rand(1, fcos_head.num_classes, pow(2, i), pow(2, i))
        for i in range(5, 0, -1)
    ]
    seed_everything(5678)
    bboxes = [torch.rand(1, 4, pow(2, i), pow(2, i)) for i in range(5, 0, -1)]

    seed_everything(9101)
    centernesses = [
        torch.rand(1, 1, pow(2, i), pow(2, i)) for i in range(5, 0, -1)
    ]

    # to get outputs of pytorch model
    model_inputs = {
        'cls_scores': cls_score,
        'bbox_preds': bboxes,
        'centernesses': centernesses,
        'img_metas': img_metas
    }
    model_outputs = get_model_outputs(fcos_head, 'get_bboxes', model_inputs)

    # to get outputs of onnx model after rewrite
    img_metas[0]['img_shape'] = torch.Tensor([s, s])
    wrapped_model = WrapModel(
        fcos_head, 'get_bboxes', img_metas=img_metas[0], with_nms=True)
    rewrite_inputs = {
        'cls_scores': cls_score,
        'bbox_preds': bboxes,
        'centernesses': centernesses
    }
    rewrite_outputs = get_rewrite_outputs(
        wrapped_model=wrapped_model,
        model_inputs=rewrite_inputs,
        deploy_cfg=deploy_cfg)

    for model_output, rewrite_output in zip(model_outputs[0], rewrite_outputs):
        model_output = model_output.squeeze().cpu().numpy()
        rewrite_output = rewrite_output.squeeze()
        # hard code to make two tensors with the same shape
        # rewrite and original codes applied different nms strategy
        assert np.allclose(
            model_output[:rewrite_output.shape[0]],
            rewrite_output,
            rtol=1e-03,
            atol=1e-05)


def _replace_r50_with_r18(model):
    """Replace ResNet50 with ResNet18 in config."""
    model = copy.deepcopy(model)
    if model.backbone.type == 'ResNet':
        model.backbone.depth = 18
        model.backbone.base_channels = 2
        model.neck.in_channels = [2, 4, 8, 16]
    return model


@pytest.mark.parametrize('model_cfg_path', [
    'tests/test_mmdet/data/single_stage_model.json',
    'tests/test_mmdet/data/mask_model.json'
])
@pytest.mark.skipif(
    not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
def test_forward_of_base_detector_and_visualize(model_cfg_path):
    deploy_cfg = mmcv.Config(
        dict(
            backend_config=dict(type='onnxruntime'),
            onnx_config=dict(
                output_names=['dets', 'labels'], input_shape=None),
            codebase_config=dict(
                type='mmdet',
                task='ObjectDetection',
                post_processing=dict(
                    score_threshold=0.05,
                    iou_threshold=0.5,
                    max_output_boxes_per_class=200,
                    pre_top_k=-1,
                    keep_top_k=100,
                    background_label_id=-1,
                ))))

    model_cfg = mmcv.Config(dict(model=mmcv.load(model_cfg_path)))
    model_cfg.model = _replace_r50_with_r18(model_cfg.model)
    from mmdet.apis import init_detector
    model = init_detector(model_cfg, None, 'cpu')

    img = torch.randn(1, 3, 64, 64)
    rewrite_inputs = {'img': img}
    rewrite_outputs = get_rewrite_outputs(
        wrapped_model=model,
        model_inputs=rewrite_inputs,
        deploy_cfg=deploy_cfg)

    from mmdeploy.apis.utils import visualize
    output_file = tempfile.NamedTemporaryFile(suffix='.jpg').name
    model.CLASSES = [''] * 80
    visualize(
        Codebase.MMDET,
        img.squeeze().permute(1, 2, 0).numpy(),
        result=[torch.rand(0, 5).numpy()] * 80,
        model=model,
        output_file=output_file,
        backend=Backend.ONNXRUNTIME,
        show_result=False)

    assert rewrite_outputs is not None
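
The tests above lean on WrapModel to freeze keyword arguments such as img_metas and with_nms, so that only the tensors meant for the exported graph are passed through forward(). As a rough guide, the wrapper is assumed to behave like the minimal sketch below; the real helper lives in mmdeploy.utils.test, and the class name here is purely illustrative.

import torch


class WrapModelSketch(torch.nn.Module):
    """Illustrative stand-in for mmdeploy.utils.test.WrapModel (assumed behaviour)."""

    def __init__(self, model, func_name, **fixed_kwargs):
        super().__init__()
        assert hasattr(model, func_name), f'model has no method {func_name}'
        self.model = model
        self.func_name = func_name
        self.fixed_kwargs = fixed_kwargs

    def forward(self, *args, **kwargs):
        # Merge the frozen kwargs (e.g. img_metas, with_nms) with the tensor
        # inputs supplied at trace/export time, then call the bound method.
        kwargs.update(self.fixed_kwargs)
        func = getattr(self.model, self.func_name)
        return func(*args, **kwargs)

With such a wrapper, get_rewrite_outputs only needs the cls_scores and bbox_preds tensors, which mirrors how the exported ONNX model is fed at inference time.
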
@ -1,100 +0,0 @@
import os.path as osp

import mmcv
import numpy as np
import torch

from mmdeploy.utils.test import (WrapModel, get_model_outputs,
                                 get_rewrite_outputs)

data_path = osp.join(osp.dirname(__file__), 'data')


def get_anchor_head_model():
    """AnchorHead Config."""
    test_cfg = mmcv.Config(
        dict(
            deploy_nms_pre=0,
            min_bbox_size=0,
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100))

    from mmdet.models import AnchorHead
    model = AnchorHead(num_classes=4, in_channels=1, test_cfg=test_cfg)
    model.requires_grad_(False)

    return model


def test_anchor_head_get_bboxes():
    """Test get_bboxes rewrite of anchor head."""
    anchor_head = get_anchor_head_model()
    anchor_head.cpu().eval()
    s = 128
    img_metas = [{
        'scale_factor': np.ones(4),
        'pad_shape': (s, s, 3),
        'img_shape': (s, s, 2)
    }]

    deploy_cfg = mmcv.Config(
        dict(
            backend_config=dict(type='onnxruntime'),
            onnx_config=dict(
                output_names=['dets', 'labels'], input_shape=None),
            codebase_config=dict(
                type='mmdet',
                task='ObjectDetection',
                post_processing=dict(
                    score_threshold=0.05,
                    iou_threshold=0.5,
                    max_output_boxes_per_class=200,
                    pre_top_k=-1,
                    keep_top_k=100,
                    background_label_id=-1,
                ))))

    # The data of anchor_head_get_bboxes.pkl contains two parts:
    # cls_score(list(Tensor)) and bboxes(list(Tensor)),
    # where each torch.Tensor is generated by torch.rand().
    # the cls_score's size: (1, 36, 32, 32), (1, 36, 16, 16),
    # (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2).
    # the bboxes's size: (1, 36, 32, 32), (1, 36, 16, 16),
    # (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2)
    retina_head_data = 'anchor_head_get_bboxes.pkl'
    feats = mmcv.load(osp.join(data_path, retina_head_data))
    cls_score = feats[:5]
    bboxes = feats[5:]

    # to get outputs of pytorch model
    model_inputs = {
        'cls_scores': cls_score,
        'bbox_preds': bboxes,
        'img_metas': img_metas
    }
    model_outputs = get_model_outputs(anchor_head, 'get_bboxes', model_inputs)

    # to get outputs of onnx model after rewrite
    img_metas[0]['img_shape'] = torch.Tensor([s, s])
    wrapped_model = WrapModel(
        anchor_head, 'get_bboxes', img_metas=img_metas[0], with_nms=True)
    rewrite_inputs = {
        'cls_scores': cls_score,
        'bbox_preds': bboxes,
    }
    rewrite_outputs = get_rewrite_outputs(
        wrapped_model=wrapped_model,
        model_inputs=rewrite_inputs,
        deploy_cfg=deploy_cfg)

    for model_output, rewrite_output in zip(model_outputs[0], rewrite_outputs):
        model_output = model_output.squeeze().cpu().numpy()
        rewrite_output = rewrite_output.squeeze()
        # hard code to make two tensors with the same shape
        # rewrite and original codes applied different nms strategy
        assert np.allclose(
            model_output[:rewrite_output.shape[0]],
            rewrite_output,
            rtol=1e-03,
            atol=1e-05)
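
The deleted test read its multi-level features from a pickled fixture and sliced them into score maps and bbox maps (feats[:5] and feats[5:]); the new tests generate equivalent inputs on the fly with seeded torch.rand calls. A minimal sketch of producing such a fixture deterministically is shown below; the shapes follow the comments above, and the five-plus-five layout is an assumption based on the slicing in the removed code.

import torch


def make_anchor_head_feats(seed=1234, num_levels=5, channels=36):
    """Build deterministic multi-level feature maps like the old .pkl fixture."""
    # Feature-map sizes 32, 16, 8, 4, 2, matching the documented shapes.
    sizes = [2**i for i in range(num_levels, 0, -1)]
    torch.manual_seed(seed)
    cls_scores = [torch.rand(1, channels, s, s) for s in sizes]
    torch.manual_seed(seed + 1)
    bbox_preds = [torch.rand(1, channels, s, s) for s in sizes]
    # Score maps first, then bbox maps, like the feats[:5] / feats[5:] split.
    return cls_scores + bbox_preds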