mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
[Unittest]: Test for CascadeRoIHead (#141)
* Fix include and lib paths for onnxruntime. * Fixes for SSD export test * Add onnx2openvino and OpenVINODetector. Test models: ssd, retinanet, fcos, fsaf. * Add support for two-stage models: faster_rcnn, cascade_rcnn * Add doc * Add strip_doc_string for openvino. * Fix openvino preprocess. * Add OpenVINO to test_wrapper.py. * Fix * Add openvino_execute. * Removed preprocessing. * Fix onnxruntime cmake. * Rewrote postprocessing and forward, added docstrings and fixes. * Added device type change to OpenVINOWrapper. * Update forward_of_single_roi_extractor_dynamic_openvino and fix doc. * Update docs. * Add OpenVINODetector and onn2openvino tests. * Add input_info to onnx2openvino. * Add TestOpenVINOExporter and test_single_roi_extractor. * Moved get_input_shape_from_cfg to openvino_utils.py and added test. * Added test_cascade_roi_head. * Add backend.check_env() to tests. * Add OpenVINO to get_rewrite_outputs and to some tests in test_mmdet_models. * Moved test_single_roi_extractor to test_mmdet_models. * Removed TestOpenVINOExporter. * Added test_cascade_roi_head. * Fix onnxruntime outputs type.
This commit is contained in:
parent
9e227b228b
commit
ba7ae81d3e
@ -366,3 +366,134 @@ def test_single_roi_extractor(backend_type):
|
||||
backend_output = backend_output.squeeze()
|
||||
assert np.allclose(
|
||||
model_output, backend_output, rtol=1e-03, atol=1e-05)
|
||||
|
||||
|
||||
def get_cascade_roi_head():
|
||||
"""CascadeRoIHead Config."""
|
||||
num_stages = 3
|
||||
stage_loss_weights = [1, 0.5, 0.25]
|
||||
bbox_roi_extractor = {
|
||||
'type': 'SingleRoIExtractor',
|
||||
'roi_layer': {
|
||||
'type': 'RoIAlign',
|
||||
'output_size': 7,
|
||||
'sampling_ratio': 0
|
||||
},
|
||||
'out_channels': 64,
|
||||
'featmap_strides': [4, 8, 16, 32]
|
||||
}
|
||||
all_target_stds = [[0.1, 0.1, 0.2, 0.2], [0.05, 0.05, 0.1, 0.1],
|
||||
[0.033, 0.033, 0.067, 0.067]]
|
||||
bbox_head = [{
|
||||
'type': 'Shared2FCBBoxHead',
|
||||
'in_channels': 64,
|
||||
'fc_out_channels': 1024,
|
||||
'roi_feat_size': 7,
|
||||
'num_classes': 80,
|
||||
'bbox_coder': {
|
||||
'type': 'DeltaXYWHBBoxCoder',
|
||||
'target_means': [0.0, 0.0, 0.0, 0.0],
|
||||
'target_stds': target_stds
|
||||
},
|
||||
'reg_class_agnostic': True,
|
||||
'loss_cls': {
|
||||
'type': 'CrossEntropyLoss',
|
||||
'use_sigmoid': False,
|
||||
'loss_weight': 1.0
|
||||
},
|
||||
'loss_bbox': {
|
||||
'type': 'SmoothL1Loss',
|
||||
'beta': 1.0,
|
||||
'loss_weight': 1.0
|
||||
}
|
||||
} for target_stds in all_target_stds]
|
||||
|
||||
test_cfg = mmcv.Config(
|
||||
dict(
|
||||
score_thr=0.05,
|
||||
nms=mmcv.Config(dict(type='nms', iou_threshold=0.5)),
|
||||
max_per_img=100))
|
||||
|
||||
from mmdet.models import CascadeRoIHead
|
||||
model = CascadeRoIHead(
|
||||
num_stages,
|
||||
stage_loss_weights,
|
||||
bbox_roi_extractor,
|
||||
bbox_head,
|
||||
test_cfg=test_cfg).eval()
|
||||
return model
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend_type', ['onnxruntime', 'openvino'])
|
||||
def test_cascade_roi_head(backend_type):
|
||||
pytest.importorskip(backend_type, reason=f'requires {backend_type}')
|
||||
|
||||
cascade_roi_head = get_cascade_roi_head()
|
||||
seed_everything(1234)
|
||||
x = [
|
||||
torch.rand((1, 64, 200, 304)),
|
||||
torch.rand((1, 64, 100, 152)),
|
||||
torch.rand((1, 64, 50, 76)),
|
||||
torch.rand((1, 64, 25, 38)),
|
||||
]
|
||||
proposals = torch.tensor([[587.8285, 52.1405, 886.2484, 341.5644, 0.5]])
|
||||
img_metas = mmcv.Config({
|
||||
'img_shape': torch.tensor([800, 1216]),
|
||||
'ori_shape': torch.tensor([800, 1216]),
|
||||
'scale_factor': torch.tensor([1, 1, 1, 1])
|
||||
})
|
||||
|
||||
model_inputs = {
|
||||
'x': x,
|
||||
'proposal_list': [proposals],
|
||||
'img_metas': [img_metas]
|
||||
}
|
||||
model_outputs = get_model_outputs(cascade_roi_head, 'simple_test',
|
||||
model_inputs)
|
||||
processed_model_outputs = []
|
||||
for output in model_outputs[0]:
|
||||
if output.shape == (0, 5):
|
||||
processed_model_outputs.append(np.zeros((1, 5)))
|
||||
else:
|
||||
processed_model_outputs.append(output)
|
||||
processed_model_outputs = np.array(processed_model_outputs).squeeze()
|
||||
processed_model_outputs = processed_model_outputs[None, :, :]
|
||||
|
||||
output_names = ['results']
|
||||
deploy_cfg = mmcv.Config(
|
||||
dict(
|
||||
backend_config=dict(type=backend_type),
|
||||
onnx_config=dict(output_names=output_names, input_shape=None),
|
||||
codebase_config=dict(
|
||||
type='mmdet',
|
||||
task='ObjectDetection',
|
||||
post_processing=dict(
|
||||
score_threshold=0.05,
|
||||
iou_threshold=0.5,
|
||||
max_output_boxes_per_class=200,
|
||||
pre_top_k=-1,
|
||||
keep_top_k=100,
|
||||
background_label_id=-1))))
|
||||
model_inputs = {'x': x, 'proposals': proposals.unsqueeze(0)}
|
||||
wrapped_model = WrapModel(
|
||||
cascade_roi_head, 'simple_test', img_metas=img_metas)
|
||||
backend_outputs, _ = get_rewrite_outputs(
|
||||
wrapped_model=wrapped_model,
|
||||
model_inputs=model_inputs,
|
||||
deploy_cfg=deploy_cfg)
|
||||
processed_backend_outputs = []
|
||||
if isinstance(backend_outputs, dict):
|
||||
processed_backend_outputs = [
|
||||
backend_outputs[name] for name in output_names
|
||||
if name in backend_outputs
|
||||
]
|
||||
elif isinstance(backend_outputs, (list, tuple)) and \
|
||||
backend_outputs[0].shape == (1, 0, 5):
|
||||
processed_backend_outputs = np.zeros((1, 80, 5))
|
||||
else:
|
||||
processed_backend_outputs = backend_outputs
|
||||
assert np.allclose(
|
||||
processed_model_outputs,
|
||||
processed_backend_outputs,
|
||||
rtol=1e-03,
|
||||
atol=1e-05)
|
||||
|
Loading…
x
Reference in New Issue
Block a user