Update ncnn test (#298)

* update ncnn test

* type hint

* update test ocr

* update mmseg ut

* ignore ncnn rpn head test

* add logging

* fix ssd base dense head test

* recover batch in ncnn wrapper

* fix ncnn_ops_ut

* fix yapf

* recover test ops

* fix run_with_backend False

* Revert "fix run_with_backend False"

This reverts commit 83f8f915a25e800f5c2db339584d164ba40b2d9b.

* disable ncnn test test_pytorch_functions.py

Co-authored-by: hanrui1sensetime <83800577+hanrui1sensetime@users.noreply.github.com>
Co-authored-by: hanrui1sensetime <hanrui1@sensetime.com>
RunningLeon 2021-12-20 14:05:13 +08:00 committed by GitHub
parent 270d98a8a2
commit fabdb473bb
11 changed files with 295 additions and 441 deletions

View File

@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import importlib
import logging
from typing import Dict, Optional, Sequence
import ncnn
@@ -77,36 +78,46 @@ class NCNNWrapper(BaseWrapper):
"""
input_list = list(inputs.values())
batch_size = input_list[0].size(0)
# assert batch_size == 1, 'Only batch_size=1 is supported!'
if batch_size > 1:
logging.warning(
f'ncnn only supports batch_size = 1, but got {batch_size}')
for input_tensor in input_list[1:]:
assert input_tensor.size(
0) == batch_size, 'All tensors should have the same batch size'
assert input_tensor.device.type == 'cpu', \
'NCNN only supports cpu device'
# set output names
output_names = self._output_names
# create output dict
outputs = {name: [None] * batch_size for name in output_names}
# run inference
for batch_id in range(batch_size):
# create extractor
ex = self._net.create_extractor()
# create extractor
ex = self._net.create_extractor()
# set inputs
for name, input_tensor in inputs.items():
data = input_tensor[0].contiguous().cpu().numpy()
input_mat = ncnn.Mat(data)
ex.input(name, input_mat)
# set inputs
for name, input_tensor in inputs.items():
data = input_tensor[batch_id].contiguous()
data = data.detach().cpu().numpy()
input_mat = ncnn.Mat(data)
ex.input(name, input_mat)
# get outputs
result = self.__ncnn_execute(
extractor=ex, output_names=output_names)
for name in output_names:
mat = result[name]
# deal with special case
if mat.empty():
mat = None
logging.warning(
f'The "{name}" output of ncnn model is empty.')
outputs[name][batch_id] = torch.from_numpy(np.array(mat))
# stack outputs together
for name, output_tensor in outputs.items():
outputs[name] = torch.stack(output_tensor)
# get outputs
result = self.__ncnn_execute(extractor=ex, output_names=output_names)
for name in output_names:
mat = result[name]
# deal with special case
if mat.empty():
outputs[name] = None
continue
outputs[name] = torch.from_numpy(np.array(mat)).unsqueeze(0)
return outputs
@TimeCounter.count_time()
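The batched path above splits the inputs along dim 0, runs one ncnn extractor pass per sample, and stacks the per-sample outputs back into batched tensors. A minimal standalone sketch of that split-run-stack pattern, with a hypothetical run_single callable standing in for a real extractor pass:

import numpy as np
import torch

def run_batched(inputs, run_single, output_names):
    # `run_single` is a hypothetical stand-in for one ncnn extractor
    # pass: it maps {name: np.ndarray} -> {name: np.ndarray}.
    batch_size = next(iter(inputs.values())).size(0)
    outputs = {name: [None] * batch_size for name in output_names}
    for batch_id in range(batch_size):
        # ncnn runs batch_size = 1, so feed one sample at a time
        feeds = {
            name: tensor[batch_id].contiguous().detach().cpu().numpy()
            for name, tensor in inputs.items()
        }
        result = run_single(feeds)
        for name in output_names:
            outputs[name][batch_id] = torch.from_numpy(
                np.array(result[name]))
    # stack the per-sample results back into batched tensors
    return {name: torch.stack(mats) for name, mats in outputs.items()}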

View File

@@ -147,7 +147,7 @@ class ObjectDetection(BaseTask):
output_file: str,
window_name: str,
show_result: bool = False,
score_thr=0.3):
score_thr: float = 0.3):
"""Visualize predictions of a model.
Args:

View File

@@ -1,226 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdeploy.codebase.mmdet import get_post_processing_params, multiclass_nms
from mmdeploy.codebase.mmdet.core.ops.detection_output import \
ncnn_detection_output_forward
from mmdeploy.codebase.mmdet.core.ops.prior_box import ncnn_prior_box_forward
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils import is_dynamic_shape
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.dense_heads.AnchorHead.get_bboxes', backend='ncnn')
def anchor_head__get_bboxes__ncnn(ctx,
self,
cls_scores,
bbox_preds,
img_metas,
with_nms=True,
cfg=None,
**kwargs):
"""Rewrite `get_bboxes` of AnchorHead for NCNN backend.
Shape nodes and batch inference are not supported by ncnn. This function
transforms dynamic shapes to constant shapes and removes batch inference.
Args:
ctx (ContextCaller): The context with additional information.
cls_scores (list[Tensor]): Box scores for each level in the
feature pyramid, has shape
(N, num_anchors * num_classes, H, W).
bbox_preds (list[Tensor]): Box energies / deltas for each
level in the feature pyramid, has shape
(N, num_anchors * 4, H, W).
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
with_nms (bool): If True, do nms before return boxes.
Default: True.
cfg (mmcv.Config | None): Test / postprocessing configuration,
if None, test_cfg would be used.
Default: None.
Returns:
If self is an instance of SSDHead:
Tensor: outputs, shape is [N, num_det, 5].
Otherwise:
If with_nms == True:
tuple[Tensor, Tensor]: (dets, labels),
`dets` of shape [N, num_det, 5] and `labels` of shape
[N, num_det].
Else:
tuple[Tensor, Tensor]: batch_mlvl_bboxes, batch_mlvl_scores
"""
from mmdet.models.dense_heads import SSDHead
# For now, the ncnn PriorBox and DetectionOutput adaptation is only
# used for SSDHead.
# TODO: Adapt all AnchorHead instances for ncnn PriorBox and
# DetectionOutput. Then this branch can be removed and the code
# unified.
if not isinstance(self, SSDHead):
assert len(cls_scores) == len(bbox_preds)
deploy_cfg = ctx.cfg
assert not is_dynamic_shape(deploy_cfg)
num_levels = len(cls_scores)
device = cls_scores[0].device
featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
mlvl_anchors = self.anchor_generator.grid_anchors(
featmap_sizes, device=device)
mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)]
mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)]
cfg = self.test_cfg if cfg is None else cfg
assert len(mlvl_cls_scores) == len(mlvl_bbox_preds) \
== len(mlvl_anchors)
batch_size = 1
pre_topk = cfg.get('nms_pre', -1)
# loop over features, decode boxes
mlvl_valid_bboxes = []
mlvl_valid_anchors = []
mlvl_scores = []
for level_id, cls_score, bbox_pred, anchors in zip(
range(num_levels), mlvl_cls_scores, mlvl_bbox_preds,
mlvl_anchors):
assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
cls_score = cls_score.permute(0, 2, 3,
1).reshape(batch_size, -1,
self.cls_out_channels)
if self.use_sigmoid_cls:
scores = cls_score.sigmoid()
else:
scores = cls_score.softmax(-1)
bbox_pred = bbox_pred.permute(0, 2, 3, 1).\
reshape(batch_size, -1, 4)
# use static anchor if input shape is static
anchors = anchors.expand_as(bbox_pred).data
if pre_topk > 0:
# Get maximum scores for foreground classes.
if self.use_sigmoid_cls:
max_scores, _ = scores.max(-1)
else:
# note that we set FG labels to [0, num_class-1]
# since mmdet v2.0
# BG cat_id: num_class
max_scores, _ = scores[..., :-1].max(-1)
_, topk_inds = max_scores.topk(pre_topk)
topk_inds = topk_inds.view(-1)
anchors = anchors[:, topk_inds, :]
bbox_pred = bbox_pred[:, topk_inds, :]
scores = scores[:, topk_inds, :]
mlvl_valid_bboxes.append(bbox_pred)
mlvl_scores.append(scores)
mlvl_valid_anchors.append(anchors)
batch_mlvl_valid_bboxes = torch.cat(mlvl_valid_bboxes, dim=1)
batch_mlvl_scores = torch.cat(mlvl_scores, dim=1)
batch_mlvl_anchors = torch.cat(mlvl_valid_anchors, dim=1)
batch_mlvl_bboxes = self.bbox_coder.decode(
batch_mlvl_anchors,
batch_mlvl_valid_bboxes,
max_shape=img_metas[0]['img_shape'])
# ignore background class
if not self.use_sigmoid_cls:
batch_mlvl_scores = batch_mlvl_scores[..., :self.num_classes]
if not with_nms:
return batch_mlvl_bboxes, batch_mlvl_scores
post_params = get_post_processing_params(deploy_cfg)
max_output_boxes_per_class = post_params.max_output_boxes_per_class
iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
score_threshold = cfg.get('score_thr', post_params.score_threshold)
pre_top_k = post_params.pre_top_k
keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
return multiclass_nms(
batch_mlvl_bboxes,
batch_mlvl_scores,
max_output_boxes_per_class,
iou_threshold=iou_threshold,
score_threshold=score_threshold,
pre_top_k=pre_top_k,
keep_top_k=keep_top_k)
else:
assert len(cls_scores) == len(bbox_preds)
deploy_cfg = ctx.cfg
num_levels = len(cls_scores)
aspect_ratio = [
ratio[ratio > 1].detach().cpu().numpy()
for ratio in self.anchor_generator.ratios
]
min_sizes = self.anchor_generator.base_sizes
max_sizes = min_sizes[1:] + \
img_metas[0]['img_shape'][0:1].tolist()
img_height = img_metas[0]['img_shape'][0].item()
img_width = img_metas[0]['img_shape'][1].item()
# without this reshape, the concat would fail in ncnn.
mlvl_anchors = [
ncnn_prior_box_forward(cls_scores[i], aspect_ratio[i], img_height,
img_width, max_sizes[i:i + 1],
min_sizes[i:i + 1]).reshape(1, 2, -1)
for i in range(num_levels)
]
mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)]
mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)]
cfg = self.test_cfg if cfg is None else cfg
assert len(mlvl_cls_scores) == len(mlvl_bbox_preds) \
== len(mlvl_anchors)
batch_size = 1
mlvl_valid_bboxes = []
mlvl_scores = []
for level_id, cls_score, bbox_pred in zip(
range(num_levels), mlvl_cls_scores, mlvl_bbox_preds):
assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
cls_score = cls_score.permute(0, 2, 3,
1).reshape(batch_size, -1,
self.cls_out_channels)
bbox_pred = bbox_pred.permute(0, 2, 3, 1). \
reshape(batch_size, -1, 4)
mlvl_valid_bboxes.append(bbox_pred)
mlvl_scores.append(cls_score)
# The ncnn DetectionOutput layer expects the background class at
# index 0, while mmdetection places it at index self.num_classes,
# so the scores must be reordered for ncnn.
batch_mlvl_valid_bboxes = torch.cat(mlvl_valid_bboxes, dim=1)
batch_mlvl_scores = torch.cat(mlvl_scores, dim=1)
if self.use_sigmoid_cls:
batch_mlvl_scores = batch_mlvl_scores.sigmoid()
else:
batch_mlvl_scores = batch_mlvl_scores.softmax(-1)
batch_mlvl_anchors = torch.cat(mlvl_anchors, dim=2)
batch_mlvl_scores = torch.cat([
batch_mlvl_scores[:, :, self.num_classes:],
batch_mlvl_scores[:, :, 0:self.num_classes]
],
dim=2)
batch_mlvl_valid_bboxes = batch_mlvl_valid_bboxes. \
reshape(batch_size, 1, -1)
batch_mlvl_scores = batch_mlvl_scores.reshape(batch_size, 1, -1)
batch_mlvl_anchors = batch_mlvl_anchors.reshape(batch_size, 2, -1)
post_params = get_post_processing_params(deploy_cfg)
iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
score_threshold = cfg.get('score_thr', post_params.score_threshold)
pre_top_k = post_params.pre_top_k
keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
output__ncnn = ncnn_detection_output_forward(
batch_mlvl_valid_bboxes, batch_mlvl_scores, batch_mlvl_anchors,
score_threshold, iou_threshold, pre_top_k, keep_top_k,
self.num_classes + 1)
return output__ncnn
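The deleted rewrite above also documents the score layout ncnn's DetectionOutput expects: background class first, whereas mmdetection keeps it last. A toy sketch of that reordering (helper name hypothetical):

import torch

def move_background_first(scores, num_classes):
    # scores: (N, num_priors, num_classes + 1) with the background
    # column last (mmdetection layout); return the ncnn layout with
    # the background column first.
    return torch.cat(
        [scores[:, :, num_classes:], scores[:, :, :num_classes]], dim=2)

scores = torch.rand(1, 4, 21)  # e.g. 20 foreground classes + background
assert move_background_first(scores, 20).shape == scores.shape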

View File

@@ -9,7 +9,7 @@ from mmdeploy.core import FUNCTION_REWRITER
def tensor__size__ncnn(ctx, self, *args):
"""Rewrite `size` for NCNN backend.
ONNX Shape node is not supported in ncnn. This function return integal
ONNX Shape node is not supported in ncnn. This function returns an integer
instead of Torch.Size to avoid ONNX Shape node.
"""

View File

@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import tempfile
import warnings
from typing import Any, Callable, Dict, List, Tuple, Union
@@ -9,7 +10,6 @@ import pytest
import torch
from torch import nn
# Register the rewrite functions
import mmdeploy.codebase # noqa: F401,F403
from mmdeploy.core import RewriterContext, patch_model
from mmdeploy.utils import Backend, get_backend, get_onnx_config
@@ -390,7 +390,7 @@ def get_backend_outputs(onnx_file_path: str,
flatten_model_inputs = get_flatten_inputs(model_inputs)
input_names = [k for k, v in flatten_model_inputs.items() if k != 'ctx']
output_names = get_onnx_config(deploy_cfg).get('output_names', None)
backend_files = [onnx_file_path]
# prepare backend model and input features
if backend == Backend.TENSORRT:
# convert to engine
@@ -436,9 +436,20 @@ def get_backend_outputs(onnx_file_path: str,
backend_feats[input_names[i]] = feature_list[i]
else:
backend_feats[str(i)] = feature_list[i]
backend_files = [onnx_file_path]
device = 'cpu'
elif backend == Backend.NCNN:
return None
import mmdeploy.apis.ncnn as ncnn_apis
if not (ncnn_apis.is_available() and ncnn_apis.is_plugin_available()):
return None
work_dir = tempfile.TemporaryDirectory().name
param_path, bin_path = ncnn_apis.get_output_model_file(
onnx_file_path, work_dir)
ncnn_apis.onnx2ncnn(onnx_file_path, param_path, bin_path)
backend_files = [param_path, bin_path]
backend_feats = flatten_model_inputs
device = 'cpu'
elif backend == Backend.OPENVINO:
import mmdeploy.apis.openvino as openvino_apis
if not openvino_apis.is_available():
@@ -473,13 +484,16 @@ def get_backend_outputs(onnx_file_path: str,
def get_rewrite_outputs(wrapped_model: nn.Module,
model_inputs: Dict[str, Union[Tuple, List,
torch.Tensor]],
deploy_cfg: mmcv.Config) -> Tuple[Any, bool]:
deploy_cfg: mmcv.Config,
run_with_backend: bool = True) -> Tuple[Any, bool]:
"""To get outputs of generated onnx model after rewrite.
Args:
wrapped_model (nn.Module): The input model.
model_inputs (dict): Inputs for model.
deploy_cfg (mmcv.Config): Deployment config.
run_with_backend (bool): Whether to run inference with backend.
Default is True.
Returns:
Tuple[Any, bool]: The outputs of the model, and a flag that is
True if the outputs were produced by a backend.
@@ -494,8 +508,10 @@ def get_rewrite_outputs(wrapped_model: nn.Module,
onnx_file_path = get_onnx_model(wrapped_model, model_inputs, deploy_cfg)
backend_outputs = get_backend_outputs(onnx_file_path, model_inputs,
deploy_cfg)
backend_outputs = None
if run_with_backend:
backend_outputs = get_backend_outputs(onnx_file_path, model_inputs,
deploy_cfg)
if backend_outputs is None:
return ctx_outputs, False
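For reference, a typical call under the new signature, mirroring the tests later in this commit; when run_with_backend is False the second return value is always False, so a test can only assert that export and tracing succeeded:

rewrite_outputs, is_backend_output = get_rewrite_outputs(
    wrapped_model=wrapped_model,
    model_inputs=rewrite_inputs,
    deploy_cfg=deploy_cfg,
    run_with_backend=False)  # export and trace only, skip backend inference
assert rewrite_outputs is not None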

View File

@@ -108,7 +108,7 @@ def test_shufflenetv2_backbone__forward(backend_type: Backend):
model_outputs = model.forward(imgs)
wrapped_model = WrapModel(model, 'forward')
rewrite_inputs = {'x': imgs}
rewrite_outputs, _ = get_rewrite_outputs(
rewrite_outputs, is_backend_output = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)

View File

@@ -66,9 +66,12 @@ def test_multiclass_nms_static():
'outputs: {}'.format(rewrite_outputs)
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME, Backend.NCNN])
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
@pytest.mark.parametrize('add_ctr_clamp', [True, False])
def test_delta2bbox(backend_type: Backend, add_ctr_clamp: bool):
@pytest.mark.parametrize('clip_border,max_shape',
[(False, None), (True, torch.tensor([100, 200]))])
def test_delta2bbox(backend_type: Backend, add_ctr_clamp: bool,
clip_border: bool, max_shape: tuple):
check_backend(backend_type)
deploy_cfg = mmcv.Config(
dict(
@@ -91,21 +94,21 @@ def test_delta2bbox(backend_type: Backend, add_ctr_clamp: bool):
rewrite_outputs, is_backend_output = get_rewrite_outputs(
wrapped_func,
model_inputs={
'rois': rois,
'deltas': deltas
'rois': rois.unsqueeze(0),
'deltas': deltas.unsqueeze(0)
},
deploy_cfg=deploy_cfg)
if is_backend_output:
model_output = original_outputs.squeeze().cpu().numpy()
rewrite_output = rewrite_outputs[0].squeeze()
rewrite_output = rewrite_outputs[0].squeeze().cpu().numpy()
assert np.allclose(
model_output, rewrite_output, rtol=1e-03, atol=1e-05)
else:
assert rewrite_outputs is not None
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME, Backend.NCNN])
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
def test_tblr2bbox(backend_type: Backend):
check_backend(backend_type)
deploy_cfg = mmcv.Config(

View File

@@ -185,14 +185,6 @@ def test_get_bboxes_of_fcos_head(backend_type: Backend):
torch.rand(1, 1, pow(2, i), pow(2, i)) for i in range(5, 0, -1)
]
# to get outputs of pytorch model
model_inputs = {
'cls_scores': cls_score,
'bbox_preds': bboxes,
'centernesses': centernesses,
'img_metas': img_metas
}
model_outputs = get_model_outputs(fcos_head, 'get_bboxes', model_inputs)
# to get outputs of onnx model after rewrite
img_metas[0]['img_shape'] = torch.Tensor([s, s])
wrapped_model = WrapModel(
@@ -205,26 +197,9 @@ def test_get_bboxes_of_fcos_head(backend_type: Backend):
rewrite_outputs, is_backend_output = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
if is_backend_output:
if isinstance(rewrite_outputs, dict):
rewrite_outputs = [
value for name, value in rewrite_outputs.items()
if name in output_names
]
for model_output, rewrite_output in zip(model_outputs[0],
rewrite_outputs):
model_output = model_output.squeeze().cpu().numpy()
rewrite_output = rewrite_output.squeeze()
# hard code to make two tensors with the same shape
# rewrite and original codes applied different nms strategy
assert np.allclose(
model_output[:rewrite_output.shape[0]],
rewrite_output,
rtol=1e-03,
atol=1e-05)
else:
assert rewrite_outputs is not None
deploy_cfg=deploy_cfg,
run_with_backend=False)
assert rewrite_outputs is not None
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME, Backend.NCNN])
@@ -275,10 +250,13 @@ def test_get_bboxes_of_rpn_head(backend_type: Backend):
'cls_scores': cls_score,
'bbox_preds': bboxes,
}
# do not run with the ncnn backend
run_with_backend = backend_type not in [Backend.NCNN]
rewrite_outputs, is_backend_output = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
deploy_cfg=deploy_cfg,
run_with_backend=run_with_backend)
assert rewrite_outputs is not None
@@ -785,7 +763,7 @@ def test_yolov3_head_get_bboxes_ncnn():
'img_shape': (s, s, 3)
}]
output_names = ['outout']
output_names = ['detection_output']
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type=backend_type.value),
@@ -1019,7 +997,7 @@ def test_get_bboxes_of_vfnet_head(backend_type: Backend):
@pytest.mark.parametrize('backend_type',
[Backend.ONNXRUNTIME, Backend.NCNN, Backend.OPENVINO])
[Backend.ONNXRUNTIME, Backend.OPENVINO])
def test_base_dense_head_get_bboxes(backend_type: Backend):
"""Test get_bboxes rewrite of base dense head."""
check_backend(backend_type)
@@ -1102,10 +1080,72 @@ def test_base_dense_head_get_bboxes(backend_type: Backend):
assert rewrite_outputs is not None
@pytest.mark.parametrize('backend_type', [Backend.NCNN])
def test_ssd_head_get_bboxes(backend_type: Backend):
"""Test get_bboxes rewrite of anchor head."""
def test_base_dense_head_get_bboxes__ncnn():
"""Test get_bboxes rewrite of base dense head."""
backend_type = Backend.NCNN
check_backend(backend_type)
anchor_head = get_anchor_head_model()
anchor_head.cpu().eval()
s = 128
img_metas = [{
'scale_factor': np.ones(4),
'pad_shape': (s, s, 3),
'img_shape': (s, s, 3)
}]
output_names = ['output']
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type=backend_type.value),
onnx_config=dict(output_names=output_names, input_shape=None),
codebase_config=dict(
type='mmdet',
task='ObjectDetection',
model_type='ncnn_end2end',
post_processing=dict(
score_threshold=0.05,
iou_threshold=0.5,
max_output_boxes_per_class=200,
pre_top_k=5000,
keep_top_k=100,
background_label_id=-1,
))))
# the cls_score's size: (1, 36, 32, 32), (1, 36, 16, 16),
# (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2).
# the bboxes's size: (1, 36, 32, 32), (1, 36, 16, 16),
# (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2)
seed_everything(1234)
cls_score = [
torch.rand(1, 36, pow(2, i), pow(2, i)) for i in range(5, 0, -1)
]
seed_everything(5678)
bboxes = [torch.rand(1, 36, pow(2, i), pow(2, i)) for i in range(5, 0, -1)]
# to get outputs of onnx model after rewrite
img_metas[0]['img_shape'] = torch.Tensor([s, s])
wrapped_model = WrapModel(
anchor_head, 'get_bboxes', img_metas=img_metas, with_nms=True)
rewrite_inputs = {
'cls_scores': cls_score,
'bbox_preds': bboxes,
}
rewrite_outputs, is_backend_output = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
# output should be of shape [1, N, 6]
if is_backend_output:
rewrite_outputs = rewrite_outputs[0]
assert rewrite_outputs.shape[-1] == 6
@pytest.mark.parametrize('is_dynamic', [True, False])
def test_ssd_head_get_bboxes__ncnn(is_dynamic: bool):
"""Test get_bboxes rewrite of ssd head for ncnn."""
check_backend(Backend.NCNN)
ssd_head = get_ssd_head_model()
ssd_head.cpu().eval()
s = 128
@@ -1115,13 +1155,30 @@ def test_ssd_head_get_bboxes(backend_type: Backend):
'img_shape': (s, s, 3)
}]
output_names = ['output']
input_names = ['input']
dynamic_axes = None
if is_dynamic:
dynamic_axes = {
input_names[0]: {
2: 'height',
3: 'width'
},
output_names[0]: {
1: 'num_dets',
}
}
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type=backend_type.value),
onnx_config=dict(output_names=output_names, input_shape=None),
backend_config=dict(type=Backend.NCNN.value),
onnx_config=dict(
input_names=input_names,
output_names=output_names,
input_shape=None,
dynamic_axes=dynamic_axes),
codebase_config=dict(
type='mmdet',
task='ObjectDetection',
model_type='ncnn_end2end',
post_processing=dict(
score_threshold=0.05,
iou_threshold=0.5,
@@ -1149,16 +1206,8 @@ def test_ssd_head_get_bboxes(backend_type: Backend):
for i in range(num_prior)
]
# to get outputs of pytorch model
model_inputs = {
'cls_scores': cls_score,
'bbox_preds': bboxes,
'img_metas': img_metas
}
model_outputs = get_model_outputs(ssd_head, 'get_bboxes', model_inputs)
# to get outputs of onnx model after rewrite
img_metas[0]['img_shape'] = torch.tensor([s, s], dtype=torch.int32)
img_metas[0]['img_shape'] = torch.tensor([s, s]) if is_dynamic else [s, s]
wrapped_model = WrapModel(
ssd_head, 'get_bboxes', img_metas=img_metas, with_nms=True)
rewrite_inputs = {
@@ -1170,19 +1219,8 @@ def test_ssd_head_get_bboxes(backend_type: Backend):
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
# output should be of shape [1, N, 6]
if is_backend_output:
if isinstance(rewrite_outputs, dict):
rewrite_outputs = convert_to_list(rewrite_outputs, output_names)
for model_output, rewrite_output in zip(model_outputs[0],
rewrite_outputs):
model_output = model_output.squeeze().cpu().numpy()
rewrite_output = rewrite_output.squeeze().cpu().numpy()
# hard code to make two tensors with the same shape
# rewrite and original codes applied different nms strategy
assert np.allclose(
model_output[:rewrite_output.shape[0]],
rewrite_output,
rtol=1e-03,
atol=1e-05)
else:
assert rewrite_outputs is not None
rewrite_outputs = rewrite_outputs[0]
assert rewrite_outputs.shape[-1] == 6
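The dynamic_axes mapping above marks the input's height/width and the output's detection count as symbolic during export. The same mechanism on a toy model (model and file names hypothetical):

import torch

model = torch.nn.Conv2d(3, 8, 3, padding=1).eval()
dummy = torch.rand(1, 3, 128, 128)
torch.onnx.export(
    model, dummy, 'toy.onnx',
    input_names=['input'], output_names=['output'],
    # mark the spatial dims symbolic so the exported graph
    # accepts other input resolutions
    dynamic_axes={
        'input': {2: 'height', 3: 'width'},
        'output': {2: 'height', 3: 'width'}
    })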

View File

@@ -134,17 +134,17 @@ def get_base_recognizer_model():
return model
@pytest.mark.parametrize('backend_type', [Backend.NCNN])
def test_bidirectionallstm(backend_type: Backend):
@pytest.mark.parametrize('backend', [Backend.NCNN])
def test_bidirectionallstm(backend: Backend):
"""Test forward rewrite of bidirectionallstm."""
check_backend(backend_type)
check_backend(backend)
bilstm = get_bidirectionallstm_model()
bilstm.cpu().eval()
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type=backend_type.value),
onnx_config=dict(input_shape=None),
backend_config=dict(type=backend.value),
onnx_config=dict(output_names=['output'], input_shape=None),
codebase_config=dict(
type='mmocr',
task='TextRecognition',
@@ -161,25 +161,30 @@ def test_bidirectionallstm(backend_type: Backend):
# to get outputs of onnx model after rewrite
wrapped_model = WrapModel(bilstm, 'forward')
rewrite_inputs = {'input': input}
rewrite_outputs = get_rewrite_outputs(
rewrite_outputs, is_backend_output = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
for model_output, rewrite_output in zip(model_outputs, rewrite_outputs):
model_output = model_output.squeeze().cpu().numpy()
rewrite_output = rewrite_output.squeeze()
deploy_cfg=deploy_cfg,
run_with_backend=False)
if is_backend_output:
model_output = model_outputs.cpu().numpy()
rewrite_output = rewrite_outputs[0].cpu().numpy()
assert np.allclose(
model_output, rewrite_output, rtol=1e-03, atol=1e-05)
else:
assert rewrite_outputs is not None
def test_simple_test_of_single_stage_text_detector():
@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME])
def test_simple_test_of_single_stage_text_detector(backend: Backend):
"""Test simple_test single_stage_text_detector."""
check_backend(backend)
single_stage_text_detector = get_single_stage_text_detector_model()
single_stage_text_detector.eval()
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type='default'),
backend_config=dict(type=backend.value),
onnx_config=dict(input_shape=None),
codebase_config=dict(
type='mmocr',
@@ -187,40 +192,36 @@ def test_simple_test_of_single_stage_text_detector():
)))
input = torch.rand(1, 3, 64, 64)
img_metas = [{
'ori_shape': [64, 64, 3],
'img_shape': [64, 64, 3],
'pad_shape': [64, 64, 3],
'scale_factor': [1., 1., 1., 1],
}]
x = single_stage_text_detector.extract_feat(input)
model_outputs = single_stage_text_detector.bbox_head(x)
wrapped_model = WrapModel(single_stage_text_detector, 'simple_test')
rewrite_inputs = {'img': input, 'img_metas': img_metas[0]}
rewrite_outputs = get_rewrite_outputs(
rewrite_inputs = {'img': input}
rewrite_outputs, is_backend_output = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
for model_output, rewrite_output in zip(model_outputs, rewrite_outputs):
model_output = model_output.squeeze().cpu().numpy()
rewrite_output = rewrite_output.squeeze()
assert np.allclose(
model_output, rewrite_output, rtol=1e-03, atol=1e-05)
deploy_cfg=deploy_cfg,
run_with_backend=False)
if is_backend_output:
rewrite_outputs = rewrite_outputs[0]
model_outputs = model_outputs.cpu().numpy()
rewrite_outputs = rewrite_outputs.cpu().numpy()
assert np.allclose(model_outputs, rewrite_outputs, rtol=1e-03, atol=1e-05)
@pytest.mark.parametrize('backend_type', [Backend.NCNN])
@pytest.mark.parametrize('backend', [Backend.NCNN])
@pytest.mark.parametrize('rnn_flag', [True, False])
def test_crnndecoder(backend_type: Backend, rnn_flag: bool):
def test_crnndecoder(backend: Backend, rnn_flag: bool):
"""Test forward rewrite of crnndecoder."""
check_backend(backend_type)
check_backend(backend)
crnn_decoder = get_crnn_decoder_model(rnn_flag)
crnn_decoder.cpu().eval()
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type=backend_type.value),
backend_config=dict(type=backend.value),
onnx_config=dict(input_shape=None),
codebase_config=dict(
type='mmocr',
@@ -250,32 +251,39 @@ def test_crnndecoder(backend_type: Backend, rnn_flag: bool):
targets_dict=targets_dict,
img_metas=img_metas)
rewrite_inputs = {'feat': input}
rewrite_outputs = get_rewrite_outputs(
rewrite_outputs, is_backend_output = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
for model_output, rewrite_output in zip(model_outputs, rewrite_outputs):
model_output = model_output.squeeze().cpu().numpy()
rewrite_output = rewrite_output.squeeze()
assert np.allclose(
model_output, rewrite_output, rtol=1e-03, atol=1e-05)
deploy_cfg=deploy_cfg,
run_with_backend=False)
if is_backend_output:
for model_output, rewrite_output in zip(model_outputs,
rewrite_outputs):
model_output = model_output.squeeze().cpu().numpy()
rewrite_output = rewrite_output.squeeze()
assert np.allclose(
model_output, rewrite_output, rtol=1e-03, atol=1e-05)
else:
assert rewrite_outputs is not None
@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME])
@pytest.mark.parametrize(
'img_metas', [[[{}]], [[{
'resize_shape': [32, 32],
'valid_ratio': 1.0
}]]])
@pytest.mark.parametrize('is_dynamic', [True, False])
def test_forward_of_base_recognizer(img_metas, is_dynamic):
def test_forward_of_base_recognizer(img_metas, is_dynamic, backend):
"""Test forward base_recognizer."""
check_backend(backend)
base_recognizer = get_base_recognizer_model()
base_recognizer.eval()
if not is_dynamic:
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type='ncnn'),
backend_config=dict(type=backend.value),
onnx_config=dict(input_shape=None),
codebase_config=dict(
type='mmocr',
@@ -284,7 +292,7 @@ def test_forward_of_base_recognizer(img_metas, is_dynamic):
else:
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type='ncnn'),
backend_config=dict(type=backend.value),
onnx_config=dict(
input_shape=None,
dynamic_axes={
@@ -317,26 +325,29 @@ def test_forward_of_base_recognizer(img_metas, is_dynamic):
rewrite_inputs = {
'img': input,
}
rewrite_outputs = get_rewrite_outputs(
rewrite_outputs, is_backend_output = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
for model_output, rewrite_output in zip(model_outputs, rewrite_outputs):
model_output = model_output.squeeze().cpu().numpy()
rewrite_output = rewrite_output.squeeze()
assert np.allclose(
model_output, rewrite_output, rtol=1e-03, atol=1e-05)
if is_backend_output:
rewrite_outputs = rewrite_outputs[0]
model_outputs = model_outputs.cpu().numpy()
rewrite_outputs = rewrite_outputs.cpu().numpy()
assert np.allclose(model_outputs, rewrite_outputs, rtol=1e-03, atol=1e-05)
def test_simple_test_of_encode_decode_recognizer():
@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME])
def test_simple_test_of_encode_decode_recognizer(backend):
"""Test simple_test encode_decode_recognizer."""
check_backend(backend)
encode_decode_recognizer = get_encode_decode_recognizer_model()
encode_decode_recognizer.eval()
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type='default'),
backend_config=dict(type=backend.value),
onnx_config=dict(input_shape=None),
codebase_config=dict(
type='mmocr',
@@ -356,28 +367,28 @@ def test_simple_test_of_encode_decode_recognizer():
wrapped_model = WrapModel(
encode_decode_recognizer, 'simple_test', img_metas=img_metas)
rewrite_inputs = {'img': input}
rewrite_outputs = get_rewrite_outputs(
rewrite_outputs, is_backend_output = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
if is_backend_output:
rewrite_outputs = rewrite_outputs[0]
for model_output, rewrite_output in zip(model_outputs, rewrite_outputs):
model_output = model_output.squeeze().cpu().numpy()
rewrite_output = rewrite_output.squeeze()
assert np.allclose(
model_output, rewrite_output, rtol=1e-03, atol=1e-05)
model_outputs = model_outputs.cpu().numpy()
rewrite_outputs = rewrite_outputs.cpu().numpy()
assert np.allclose(model_outputs, rewrite_outputs, rtol=1e-03, atol=1e-05)
@pytest.mark.parametrize('backend_type', [Backend.TENSORRT])
def test_forward_of_fpnc(backend_type: Backend):
@pytest.mark.parametrize('backend', [Backend.TENSORRT])
def test_forward_of_fpnc(backend: Backend):
"""Test forward rewrite of fpnc."""
check_backend(backend_type)
check_backend(backend)
fpnc = get_fpnc_neck_model()
fpnc.eval()
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(
type=backend_type.value,
type=backend.value,
common_config=dict(max_workspace_size=1 << 30),
model_inputs=[
dict(
@@ -399,22 +410,17 @@ def test_forward_of_fpnc(backend_type: Backend):
rewrite_inputs = {
'inputs': input,
}
rewrite_outputs, is_need_name = get_rewrite_outputs(
rewrite_outputs, is_backend_output = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
if is_need_name:
model_output = model_outputs[0].squeeze().cpu().numpy()
rewrite_output = rewrite_outputs[0].squeeze().cpu().numpy()
assert np.allclose(
model_output, rewrite_output, rtol=1e-03, atol=1e-05)
else:
for model_output, rewrite_output in zip(model_outputs,
rewrite_outputs):
model_output = model_output.squeeze().cpu().numpy()
rewrite_output = rewrite_output.squeeze().cpu().numpy()
assert np.allclose(
model_output, rewrite_output, rtol=1e-03, atol=1e-05)
if is_backend_output:
rewrite_outputs = rewrite_outputs[0]
model_outputs = model_outputs.cpu().numpy()
rewrite_outputs = rewrite_outputs.cpu().numpy()
assert np.allclose(model_outputs, rewrite_outputs, rtol=1e-03, atol=1e-05)
def get_sar_model_cfg(decoder_type: str):
@@ -468,11 +474,11 @@ def get_sar_model_cfg(decoder_type: str):
dict(model=model, data=dict(test=dict(pipeline=test_pipeline))))
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME])
@pytest.mark.parametrize('decoder_type',
['SequentialSARDecoder', 'ParallelSARDecoder'])
def test_sar_model(backend_type: Backend, decoder_type):
check_backend(backend_type)
def test_sar_model(backend: Backend, decoder_type):
check_backend(backend)
import os.path as osp
import onnx
from mmocr.models.textrecog import SARNet
@@ -483,7 +489,7 @@ def test_sar_model(backend_type: Backend, decoder_type):
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type=backend_type.value),
backend_config=dict(type=backend.value),
onnx_config=dict(input_shape=None),
codebase_config=dict(
type='mmocr',
@@ -492,11 +498,11 @@ def test_sar_model(backend_type: Backend, decoder_type):
# patch model
pytorch_model.cfg = sar_cfg
patched_model = patch_model(
pytorch_model, cfg=deploy_cfg, backend=backend_type.value)
pytorch_model, cfg=deploy_cfg, backend=backend.value)
onnx_file_path = tempfile.NamedTemporaryFile(suffix='.onnx').name
input_names = [k for k, v in model_inputs.items() if k != 'ctx']
with RewriterContext(
cfg=deploy_cfg, backend=backend_type.value), torch.no_grad():
cfg=deploy_cfg, backend=backend.value), torch.no_grad():
torch.onnx.export(
patched_model,
tuple([v for k, v in model_inputs.items()]),

View File

@@ -9,8 +9,8 @@ from mmseg.models import BACKBONES, HEADS
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmdeploy.codebase import import_codebase
from mmdeploy.utils import Codebase
from mmdeploy.utils.test import (WrapModel, get_model_outputs,
from mmdeploy.utils import Backend, Codebase
from mmdeploy.utils.test import (WrapModel, check_backend, get_model_outputs,
get_rewrite_outputs)
import_codebase(Codebase.MMSEG)
@@ -93,14 +93,15 @@ def _demo_mm_inputs(input_shape=(1, 3, 8, 16), num_classes=10):
return mm_inputs
@pytest.mark.parametrize('backend_type', ['onnxruntime', 'ncnn'])
def test_encoderdecoder_simple_test(backend_type):
@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME, Backend.OPENVINO])
def test_encoderdecoder_simple_test(backend):
check_backend(backend)
segmentor = get_model()
segmentor.cpu().eval()
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type=backend_type),
backend_config=dict(type=backend.value),
onnx_config=dict(output_names=['result'], input_shape=None),
codebase_config=dict(type='mmseg', task='Segmentation')))
@@ -127,23 +128,21 @@ def test_encoderdecoder_simple_test(backend_type):
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
if is_backend_output:
rewrite_outputs = rewrite_outputs[0]
model_outputs = torch.tensor(model_outputs[0])
model_outputs = model_outputs.unsqueeze(0).unsqueeze(0)
assert torch.allclose(rewrite_outputs, model_outputs)
else:
assert rewrite_outputs is not None
model_outputs = torch.tensor(model_outputs[0])
rewrite_outputs = rewrite_outputs[0].to(model_outputs).reshape(
model_outputs.shape)
assert torch.allclose(rewrite_outputs, model_outputs)
@pytest.mark.parametrize('backend_type', ['onnxruntime', 'ncnn'])
def test_basesegmentor_forward(backend_type):
@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME, Backend.OPENVINO])
def test_basesegmentor_forward(backend):
check_backend(backend)
segmentor = get_model()
segmentor.cpu().eval()
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type=backend_type),
backend_config=dict(type=backend.value),
onnx_config=dict(output_names=['result'], input_shape=None),
codebase_config=dict(type='mmseg', task='Segmentation')))
@@ -167,17 +166,16 @@ def test_basesegmentor_forward(backend_type):
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
if is_backend_output:
rewrite_outputs = torch.tensor(rewrite_outputs[0])
model_outputs = torch.tensor(model_outputs[0])
model_outputs = model_outputs.unsqueeze(0).unsqueeze(0)
assert torch.allclose(rewrite_outputs, model_outputs)
else:
assert rewrite_outputs is not None
model_outputs = torch.tensor(model_outputs[0])
rewrite_outputs = rewrite_outputs[0].to(model_outputs).reshape(
model_outputs.shape)
assert torch.allclose(rewrite_outputs, model_outputs)
@pytest.mark.parametrize('backend_type', ['onnxruntime', 'ncnn'])
def test_aspphead_forward(backend_type):
@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME, Backend.OPENVINO])
def test_aspphead_forward(backend):
check_backend(backend)
from mmseg.models.decode_heads import ASPPHead
head = ASPPHead(
in_channels=32, channels=16, num_classes=19,
@@ -185,7 +183,7 @@ def test_aspphead_forward(backend_type):
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type=backend_type),
backend_config=dict(type=backend.value),
onnx_config=dict(
output_names=['result'], input_shape=(1, 32, 45, 45)),
codebase_config=dict(type='mmseg', task='Segmentation')))
@@ -200,21 +198,23 @@ def test_aspphead_forward(backend_type):
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
if is_backend_output:
rewrite_outputs = torch.tensor(rewrite_outputs[0])
assert torch.allclose(
rewrite_outputs, model_outputs, rtol=1e-03, atol=1e-05)
else:
assert rewrite_outputs is not None
rewrite_outputs = rewrite_outputs[0]
rewrite_outputs = rewrite_outputs.to(model_outputs).reshape(
model_outputs.shape)
assert torch.allclose(
rewrite_outputs, model_outputs, rtol=1e-03, atol=1e-05)
@pytest.mark.parametrize('backend_type', ['onnxruntime', 'ncnn'])
def test_psphead_forward(backend_type):
@pytest.mark.parametrize('backend',
[Backend.ONNXRUNTIME, Backend.OPENVINO, Backend.NCNN])
def test_psphead_forward(backend):
check_backend(backend)
from mmseg.models.decode_heads import PSPHead
head = PSPHead(in_channels=32, channels=16, num_classes=19).eval()
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type=backend_type),
backend_config=dict(type=backend.value),
onnx_config=dict(output_names=['result'], input_shape=None),
codebase_config=dict(type='mmseg', task='Segmentation')))
inputs = [torch.randn(1, 32, 45, 45)]
@@ -228,7 +228,7 @@ def test_psphead_forward(backend_type):
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
if is_backend_output:
rewrite_outputs = torch.tensor(rewrite_outputs[0])
assert torch.allclose(rewrite_outputs, model_outputs, rtol=1, atol=1)
else:
assert rewrite_outputs is not None
rewrite_outputs = rewrite_outputs[0]
rewrite_outputs = rewrite_outputs.to(model_outputs).reshape(
model_outputs.shape)
assert torch.allclose(rewrite_outputs, model_outputs, rtol=1, atol=1)
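The comparisons above first move the backend output onto the reference tensor's dtype and device and reshape it before calling allclose. A compact helper capturing that pattern (hypothetical, not part of this commit):

import torch

def close_to_reference(backend_out, reference, rtol=1e-3, atol=1e-5):
    # Align dtype/device via Tensor.to(other), then match shape,
    # before the numeric comparison.
    aligned = backend_out.to(reference).reshape(reference.shape)
    return torch.allclose(aligned, reference, rtol=rtol, atol=atol)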

View File

@@ -49,7 +49,8 @@ def test_get_attribute():
rewrite_outputs, _ = get_rewrite_outputs(
wrapped_func,
model_inputs={'tensor': input},
deploy_cfg=deploy_cfg_ncnn)
deploy_cfg=deploy_cfg_ncnn,
run_with_backend=False)
assert rewrite_outputs is not None, \
    'Got unexpected rewrite outputs: {}'.format(rewrite_outputs)
@@ -69,7 +70,8 @@ def test_group_norm_ncnn():
rewrite_output, _ = get_rewrite_outputs(
wrapped_func,
model_inputs={'input': input},
deploy_cfg=deploy_cfg_ncnn)
deploy_cfg=deploy_cfg_ncnn,
run_with_backend=False)
assert np.allclose(model_output, rewrite_output, rtol=1e-03, atol=1e-05)
@@ -86,7 +88,8 @@ def test_interpolate_static():
rewrite_output, _ = get_rewrite_outputs(
wrapped_func,
model_inputs={'input': input},
deploy_cfg=deploy_cfg_ncnn)
deploy_cfg=deploy_cfg_ncnn,
run_with_backend=False)
assert np.allclose(model_output, rewrite_output, rtol=1e-03, atol=1e-05)
@@ -105,7 +108,8 @@ def test_linear_ncnn():
rewrite_output, _ = get_rewrite_outputs(
wrapped_func,
model_inputs={'input': input},
deploy_cfg=deploy_cfg_ncnn)
deploy_cfg=deploy_cfg_ncnn,
run_with_backend=False)
assert np.allclose(model_output, rewrite_output, rtol=1e-03, atol=1e-05)
@@ -148,7 +152,8 @@ def test_size_of_tensor_static():
rewrite_outputs, _ = get_rewrite_outputs(
wrapped_func,
model_inputs={'input': input},
deploy_cfg=deploy_cfg_ncnn)
deploy_cfg=deploy_cfg_ncnn,
run_with_backend=False)
assert rewrite_outputs is not None, \
    'Got unexpected rewrite outputs: {}'.format(rewrite_outputs)
@@ -175,7 +180,8 @@ class TestTopk:
output, _ = get_rewrite_outputs(
wrapped_func,
model_inputs={'input': TestTopk.input},
deploy_cfg=deploy_cfg_ncnn)
deploy_cfg=deploy_cfg_ncnn,
run_with_backend=False)
assert np.allclose(model_output, output[1], rtol=1e-03, atol=1e-05)