mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
Update ncnn test (#298)
* update ncnn test * type hint * update test ocr * update mmseg ut * ignore ncnn rpn head test * add logging * fix ssd base dense head test * recover bacth in ncnn wrapper * fix ncnn_ops_ut * fix yapf * recover test ops * fix run_with_backend False * Revert "fix run_with_backend False" This reverts commit 83f8f915a25e800f5c2db339584d164ba40b2d9b. * disable ncnn test test_pytorch_functions.py Co-authored-by: hanrui1sensetime <83800577+hanrui1sensetime@users.noreply.github.com> Co-authored-by: hanrui1sensetime <hanrui1@sensetime.com>
This commit is contained in:
parent
270d98a8a2
commit
fabdb473bb
@ -1,5 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import importlib
|
||||
import logging
|
||||
from typing import Dict, Optional, Sequence
|
||||
|
||||
import ncnn
|
||||
@ -77,36 +78,46 @@ class NCNNWrapper(BaseWrapper):
|
||||
"""
|
||||
input_list = list(inputs.values())
|
||||
batch_size = input_list[0].size(0)
|
||||
# assert batch_size == 1, 'Only batch_size=1 is supported!'
|
||||
if batch_size > 1:
|
||||
logging.warning(
|
||||
f'ncnn only support batch_size = 1, but given {batch_size}')
|
||||
for input_tensor in input_list[1:]:
|
||||
assert input_tensor.size(
|
||||
0) == batch_size, 'All tensors should have same batch size'
|
||||
assert input_tensor.device.type == 'cpu', \
|
||||
'NCNN only supports cpu device'
|
||||
|
||||
# set output names
|
||||
output_names = self._output_names
|
||||
|
||||
# create output dict
|
||||
outputs = dict([name, [None] * batch_size] for name in output_names)
|
||||
# run inference
|
||||
for batch_id in range(batch_size):
|
||||
# create extractor
|
||||
ex = self._net.create_extractor()
|
||||
|
||||
# create extractor
|
||||
ex = self._net.create_extractor()
|
||||
# set inputs
|
||||
for name, input_tensor in inputs.items():
|
||||
data = input_tensor[0].contiguous().cpu().numpy()
|
||||
input_mat = ncnn.Mat(data)
|
||||
ex.input(name, input_mat)
|
||||
# set inputs
|
||||
for name, input_tensor in inputs.items():
|
||||
data = input_tensor[batch_id].contiguous()
|
||||
data = data.detach().cpu().numpy()
|
||||
input_mat = ncnn.Mat(data)
|
||||
ex.input(name, input_mat)
|
||||
|
||||
# get outputs
|
||||
result = self.__ncnn_execute(
|
||||
extractor=ex, output_names=output_names)
|
||||
for name in output_names:
|
||||
mat = result[name]
|
||||
# deal with special case
|
||||
if mat.empty():
|
||||
mat = None
|
||||
logging.warning(
|
||||
f'The "{name}" output of ncnn model is empty.')
|
||||
outputs[name][batch_id] = torch.from_numpy(np.array(mat))
|
||||
|
||||
# stack outputs together
|
||||
for name, output_tensor in outputs.items():
|
||||
outputs[name] = torch.stack(output_tensor)
|
||||
|
||||
# get outputs
|
||||
result = self.__ncnn_execute(extractor=ex, output_names=output_names)
|
||||
for name in output_names:
|
||||
mat = result[name]
|
||||
# deal with special case
|
||||
if mat.empty():
|
||||
outputs[name] = None
|
||||
continue
|
||||
outputs[name] = torch.from_numpy(np.array(mat)).unsqueeze(0)
|
||||
return outputs
|
||||
|
||||
@TimeCounter.count_time()
|
||||
|
@ -147,7 +147,7 @@ class ObjectDetection(BaseTask):
|
||||
output_file: str,
|
||||
window_name: str,
|
||||
show_result: bool = False,
|
||||
score_thr=0.3):
|
||||
score_thr: float = 0.3):
|
||||
"""Visualize predictions of a model.
|
||||
|
||||
Args:
|
||||
|
@ -1,226 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmdeploy.codebase.mmdet import get_post_processing_params, multiclass_nms
|
||||
from mmdeploy.codebase.mmdet.core.ops.detection_output import \
|
||||
ncnn_detection_output_forward
|
||||
from mmdeploy.codebase.mmdet.core.ops.prior_box import ncnn_prior_box_forward
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
from mmdeploy.utils import is_dynamic_shape
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.dense_heads.AnchorHead.get_bboxes', backend='ncnn')
|
||||
def anchor_head__get_bboxes__ncnn(ctx,
|
||||
self,
|
||||
cls_scores,
|
||||
bbox_preds,
|
||||
img_metas,
|
||||
with_nms=True,
|
||||
cfg=None,
|
||||
**kwargs):
|
||||
"""Rewrite `get_bboxes` of AnchorHead for NCNN backend.
|
||||
|
||||
Shape node and batch inference is not supported by ncnn. This function
|
||||
transform dynamic shape to constant shape and remove batch inference.
|
||||
|
||||
Args:
|
||||
ctx (ContextCaller): The context with additional information.
|
||||
cls_scores (list[Tensor]): Box scores for each level in the
|
||||
feature pyramid, has shape
|
||||
(N, num_anchors * num_classes, H, W).
|
||||
bbox_preds (list[Tensor]): Box energies / deltas for each
|
||||
level in the feature pyramid, has shape
|
||||
(N, num_anchors * 4, H, W).
|
||||
img_metas (list[dict]): Meta information of each image, e.g.,
|
||||
image size, scaling factor, etc.
|
||||
with_nms (bool): If True, do nms before return boxes.
|
||||
Default: True.
|
||||
cfg (mmcv.Config | None): Test / postprocessing configuration,
|
||||
if None, test_cfg would be used.
|
||||
Default: None.
|
||||
|
||||
|
||||
Returns:
|
||||
If isinstance(self, SSDHead) == True:
|
||||
Tensor: outputs, shape is [N, num_det, 5].
|
||||
If isinstance(self, SSDHead) == False:
|
||||
If with_nms == True:
|
||||
tuple[Tensor, Tensor]: (dets, labels),
|
||||
`dets` of shape [N, num_det, 5] and `labels` of shape
|
||||
[N, num_det].
|
||||
Else:
|
||||
tuple[Tensor, Tensor]: batch_mlvl_bboxes, batch_mlvl_scores
|
||||
"""
|
||||
from mmdet.models.dense_heads import SSDHead
|
||||
|
||||
# now the ncnn PriorBox and DetectionOutput adaption is only used for
|
||||
# SSDHead.
|
||||
# TODO: Adapt all of the AnchorHead instances for ncnn PriorBox and
|
||||
# DetectionOutput. Then, the determine statement will be removed, and
|
||||
# the code will be unified.
|
||||
if not isinstance(self, SSDHead):
|
||||
assert len(cls_scores) == len(bbox_preds)
|
||||
deploy_cfg = ctx.cfg
|
||||
assert not is_dynamic_shape(deploy_cfg)
|
||||
num_levels = len(cls_scores)
|
||||
|
||||
device = cls_scores[0].device
|
||||
featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
|
||||
mlvl_anchors = self.anchor_generator.grid_anchors(
|
||||
featmap_sizes, device=device)
|
||||
|
||||
mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)]
|
||||
mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)]
|
||||
|
||||
cfg = self.test_cfg if cfg is None else cfg
|
||||
assert len(mlvl_cls_scores) == len(mlvl_bbox_preds) \
|
||||
== len(mlvl_anchors)
|
||||
batch_size = 1
|
||||
pre_topk = cfg.get('nms_pre', -1)
|
||||
|
||||
# loop over features, decode boxes
|
||||
mlvl_valid_bboxes = []
|
||||
mlvl_valid_anchors = []
|
||||
mlvl_scores = []
|
||||
for level_id, cls_score, bbox_pred, anchors in zip(
|
||||
range(num_levels), mlvl_cls_scores, mlvl_bbox_preds,
|
||||
mlvl_anchors):
|
||||
assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
|
||||
cls_score = cls_score.permute(0, 2, 3,
|
||||
1).reshape(batch_size, -1,
|
||||
self.cls_out_channels)
|
||||
if self.use_sigmoid_cls:
|
||||
scores = cls_score.sigmoid()
|
||||
else:
|
||||
scores = cls_score.softmax(-1)
|
||||
bbox_pred = bbox_pred.permute(0, 2, 3, 1).\
|
||||
reshape(batch_size, -1, 4)
|
||||
|
||||
# use static anchor if input shape is static
|
||||
anchors = anchors.expand_as(bbox_pred).data
|
||||
|
||||
if pre_topk > 0:
|
||||
# Get maximum scores for foreground classes.
|
||||
if self.use_sigmoid_cls:
|
||||
max_scores, _ = scores.max(-1)
|
||||
else:
|
||||
# remind that we set FG labels to [0, num_class-1]
|
||||
# since mmdet v2.0
|
||||
# BG cat_id: num_class
|
||||
max_scores, _ = scores[..., :-1].max(-1)
|
||||
_, topk_inds = max_scores.topk(pre_topk)
|
||||
|
||||
topk_inds = topk_inds.view(-1)
|
||||
anchors = anchors[:, topk_inds, :]
|
||||
bbox_pred = bbox_pred[:, topk_inds, :]
|
||||
scores = scores[:, topk_inds, :]
|
||||
|
||||
mlvl_valid_bboxes.append(bbox_pred)
|
||||
mlvl_scores.append(scores)
|
||||
mlvl_valid_anchors.append(anchors)
|
||||
|
||||
batch_mlvl_valid_bboxes = torch.cat(mlvl_valid_bboxes, dim=1)
|
||||
batch_mlvl_scores = torch.cat(mlvl_scores, dim=1)
|
||||
batch_mlvl_anchors = torch.cat(mlvl_valid_anchors, dim=1)
|
||||
batch_mlvl_bboxes = self.bbox_coder.decode(
|
||||
batch_mlvl_anchors,
|
||||
batch_mlvl_valid_bboxes,
|
||||
max_shape=img_metas[0]['img_shape'])
|
||||
|
||||
# ignore background class
|
||||
if not self.use_sigmoid_cls:
|
||||
batch_mlvl_scores = batch_mlvl_scores[..., :self.num_classes]
|
||||
if not with_nms:
|
||||
return batch_mlvl_bboxes, batch_mlvl_scores
|
||||
|
||||
post_params = get_post_processing_params(deploy_cfg)
|
||||
max_output_boxes_per_class = post_params.max_output_boxes_per_class
|
||||
iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
|
||||
score_threshold = cfg.get('score_thr', post_params.score_threshold)
|
||||
pre_top_k = post_params.pre_top_k
|
||||
keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
|
||||
return multiclass_nms(
|
||||
batch_mlvl_bboxes,
|
||||
batch_mlvl_scores,
|
||||
max_output_boxes_per_class,
|
||||
iou_threshold=iou_threshold,
|
||||
score_threshold=score_threshold,
|
||||
pre_top_k=pre_top_k,
|
||||
keep_top_k=keep_top_k)
|
||||
else:
|
||||
assert len(cls_scores) == len(bbox_preds)
|
||||
deploy_cfg = ctx.cfg
|
||||
num_levels = len(cls_scores)
|
||||
aspect_ratio = [
|
||||
ratio[ratio > 1].detach().cpu().numpy()
|
||||
for ratio in self.anchor_generator.ratios
|
||||
]
|
||||
min_sizes = self.anchor_generator.base_sizes
|
||||
max_sizes = min_sizes[1:] + \
|
||||
img_metas[0]['img_shape'][0:1].tolist()
|
||||
img_height = img_metas[0]['img_shape'][0].item()
|
||||
img_width = img_metas[0]['img_shape'][1].item()
|
||||
|
||||
# if no reshape, concat will be error in ncnn.
|
||||
mlvl_anchors = [
|
||||
ncnn_prior_box_forward(cls_scores[i], aspect_ratio[i], img_height,
|
||||
img_width, max_sizes[i:i + 1],
|
||||
min_sizes[i:i + 1]).reshape(1, 2, -1)
|
||||
for i in range(num_levels)
|
||||
]
|
||||
|
||||
mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)]
|
||||
mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)]
|
||||
|
||||
cfg = self.test_cfg if cfg is None else cfg
|
||||
assert len(mlvl_cls_scores) == len(mlvl_bbox_preds) \
|
||||
== len(mlvl_anchors)
|
||||
batch_size = 1
|
||||
|
||||
mlvl_valid_bboxes = []
|
||||
mlvl_scores = []
|
||||
for level_id, cls_score, bbox_pred in zip(
|
||||
range(num_levels), mlvl_cls_scores, mlvl_bbox_preds):
|
||||
assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
|
||||
cls_score = cls_score.permute(0, 2, 3,
|
||||
1).reshape(batch_size, -1,
|
||||
self.cls_out_channels)
|
||||
bbox_pred = bbox_pred.permute(0, 2, 3, 1). \
|
||||
reshape(batch_size, -1, 4)
|
||||
|
||||
mlvl_valid_bboxes.append(bbox_pred)
|
||||
mlvl_scores.append(cls_score)
|
||||
|
||||
# NCNN DetectionOutput layer uses background class at 0 position, but
|
||||
# in mmdetection, background class is at self.num_classes position.
|
||||
# We should adapt for ncnn.
|
||||
batch_mlvl_valid_bboxes = torch.cat(mlvl_valid_bboxes, dim=1)
|
||||
batch_mlvl_scores = torch.cat(mlvl_scores, dim=1)
|
||||
if self.use_sigmoid_cls:
|
||||
batch_mlvl_scores = batch_mlvl_scores.sigmoid()
|
||||
else:
|
||||
batch_mlvl_scores = batch_mlvl_scores.softmax(-1)
|
||||
batch_mlvl_anchors = torch.cat(mlvl_anchors, dim=2)
|
||||
batch_mlvl_scores = torch.cat([
|
||||
batch_mlvl_scores[:, :, self.num_classes:],
|
||||
batch_mlvl_scores[:, :, 0:self.num_classes]
|
||||
],
|
||||
dim=2)
|
||||
batch_mlvl_valid_bboxes = batch_mlvl_valid_bboxes. \
|
||||
reshape(batch_size, 1, -1)
|
||||
batch_mlvl_scores = batch_mlvl_scores.reshape(batch_size, 1, -1)
|
||||
batch_mlvl_anchors = batch_mlvl_anchors.reshape(batch_size, 2, -1)
|
||||
|
||||
post_params = get_post_processing_params(deploy_cfg)
|
||||
iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
|
||||
score_threshold = cfg.get('score_thr', post_params.score_threshold)
|
||||
pre_top_k = post_params.pre_top_k
|
||||
keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
|
||||
|
||||
output__ncnn = ncnn_detection_output_forward(
|
||||
batch_mlvl_valid_bboxes, batch_mlvl_scores, batch_mlvl_anchors,
|
||||
score_threshold, iou_threshold, pre_top_k, keep_top_k,
|
||||
self.num_classes + 1)
|
||||
|
||||
return output__ncnn
|
@ -9,7 +9,7 @@ from mmdeploy.core import FUNCTION_REWRITER
|
||||
def tensor__size__ncnn(ctx, self, *args):
|
||||
"""Rewrite `size` for NCNN backend.
|
||||
|
||||
ONNX Shape node is not supported in ncnn. This function return integal
|
||||
ONNX Shape node is not supported in ncnn. This function return integer
|
||||
instead of Torch.Size to avoid ONNX Shape node.
|
||||
"""
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
import tempfile
|
||||
import warnings
|
||||
from typing import Any, Callable, Dict, List, Tuple, Union
|
||||
@ -9,7 +10,6 @@ import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
# Register the rewrite functions
|
||||
import mmdeploy.codebase # noqa: F401,F403
|
||||
from mmdeploy.core import RewriterContext, patch_model
|
||||
from mmdeploy.utils import Backend, get_backend, get_onnx_config
|
||||
@ -390,7 +390,7 @@ def get_backend_outputs(onnx_file_path: str,
|
||||
flatten_model_inputs = get_flatten_inputs(model_inputs)
|
||||
input_names = [k for k, v in flatten_model_inputs.items() if k != 'ctx']
|
||||
output_names = get_onnx_config(deploy_cfg).get('output_names', None)
|
||||
backend_files = [onnx_file_path]
|
||||
|
||||
# prepare backend model and input features
|
||||
if backend == Backend.TENSORRT:
|
||||
# convert to engine
|
||||
@ -436,9 +436,20 @@ def get_backend_outputs(onnx_file_path: str,
|
||||
backend_feats[input_names[i]] = feature_list[i]
|
||||
else:
|
||||
backend_feats[str(i)] = feature_list[i]
|
||||
backend_files = [onnx_file_path]
|
||||
device = 'cpu'
|
||||
elif backend == Backend.NCNN:
|
||||
return None
|
||||
import mmdeploy.apis.ncnn as ncnn_apis
|
||||
if not (ncnn_apis.is_available() and ncnn_apis.is_plugin_available()):
|
||||
return None
|
||||
work_dir = tempfile.TemporaryDirectory().name
|
||||
param_path, bin_path = ncnn_apis.get_output_model_file(
|
||||
onnx_file_path, work_dir)
|
||||
ncnn_apis.onnx2ncnn(onnx_file_path, param_path, bin_path)
|
||||
backend_files = [param_path, bin_path]
|
||||
backend_feats = flatten_model_inputs
|
||||
device = 'cpu'
|
||||
|
||||
elif backend == Backend.OPENVINO:
|
||||
import mmdeploy.apis.openvino as openvino_apis
|
||||
if not openvino_apis.is_available():
|
||||
@ -473,13 +484,16 @@ def get_backend_outputs(onnx_file_path: str,
|
||||
def get_rewrite_outputs(wrapped_model: nn.Module,
|
||||
model_inputs: Dict[str, Union[Tuple, List,
|
||||
torch.Tensor]],
|
||||
deploy_cfg: mmcv.Config) -> Tuple[Any, bool]:
|
||||
deploy_cfg: mmcv.Config,
|
||||
run_with_backend: bool = True) -> Tuple[Any, bool]:
|
||||
"""To get outputs of generated onnx model after rewrite.
|
||||
|
||||
Args:
|
||||
wrapped_model (nn.Module): The input model.
|
||||
model_inputs (dict): Inputs for model.
|
||||
deploy_cfg (mmcv.Config): Deployment config.
|
||||
run_with_backend (bool): Whether to run inference with backend.
|
||||
Default is True.
|
||||
|
||||
Returns:
|
||||
List[torch.Tensor]: The outputs of model.
|
||||
@ -494,8 +508,10 @@ def get_rewrite_outputs(wrapped_model: nn.Module,
|
||||
|
||||
onnx_file_path = get_onnx_model(wrapped_model, model_inputs, deploy_cfg)
|
||||
|
||||
backend_outputs = get_backend_outputs(onnx_file_path, model_inputs,
|
||||
deploy_cfg)
|
||||
backend_outputs = None
|
||||
if run_with_backend:
|
||||
backend_outputs = get_backend_outputs(onnx_file_path, model_inputs,
|
||||
deploy_cfg)
|
||||
|
||||
if backend_outputs is None:
|
||||
return ctx_outputs, False
|
||||
|
@ -108,7 +108,7 @@ def test_shufflenetv2_backbone__forward(backend_type: Backend):
|
||||
model_outputs = model.forward(imgs)
|
||||
wrapped_model = WrapModel(model, 'forward')
|
||||
rewrite_inputs = {'x': imgs}
|
||||
rewrite_outputs, _ = get_rewrite_outputs(
|
||||
rewrite_outputs, is_backend_output = get_rewrite_outputs(
|
||||
wrapped_model=wrapped_model,
|
||||
model_inputs=rewrite_inputs,
|
||||
deploy_cfg=deploy_cfg)
|
||||
|
@ -66,9 +66,12 @@ def test_multiclass_nms_static():
|
||||
'outputs: {}'.format(rewrite_outputs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME, Backend.NCNN])
|
||||
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
|
||||
@pytest.mark.parametrize('add_ctr_clamp', [True, False])
|
||||
def test_delta2bbox(backend_type: Backend, add_ctr_clamp: bool):
|
||||
@pytest.mark.parametrize('clip_border,max_shape',
|
||||
[(False, None), (True, torch.tensor([100, 200]))])
|
||||
def test_delta2bbox(backend_type: Backend, add_ctr_clamp: bool,
|
||||
clip_border: bool, max_shape: tuple):
|
||||
check_backend(backend_type)
|
||||
deploy_cfg = mmcv.Config(
|
||||
dict(
|
||||
@ -91,21 +94,21 @@ def test_delta2bbox(backend_type: Backend, add_ctr_clamp: bool):
|
||||
rewrite_outputs, is_backend_output = get_rewrite_outputs(
|
||||
wrapped_func,
|
||||
model_inputs={
|
||||
'rois': rois,
|
||||
'deltas': deltas
|
||||
'rois': rois.unsqueeze(0),
|
||||
'deltas': deltas.unsqueeze(0)
|
||||
},
|
||||
deploy_cfg=deploy_cfg)
|
||||
|
||||
if is_backend_output:
|
||||
model_output = original_outputs.squeeze().cpu().numpy()
|
||||
rewrite_output = rewrite_outputs[0].squeeze()
|
||||
rewrite_output = rewrite_outputs[0].squeeze().cpu().numpy()
|
||||
assert np.allclose(
|
||||
model_output, rewrite_output, rtol=1e-03, atol=1e-05)
|
||||
else:
|
||||
assert rewrite_outputs is not None
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME, Backend.NCNN])
|
||||
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
|
||||
def test_tblr2bbox(backend_type: Backend):
|
||||
check_backend(backend_type)
|
||||
deploy_cfg = mmcv.Config(
|
||||
|
@ -185,14 +185,6 @@ def test_get_bboxes_of_fcos_head(backend_type: Backend):
|
||||
torch.rand(1, 1, pow(2, i), pow(2, i)) for i in range(5, 0, -1)
|
||||
]
|
||||
|
||||
# to get outputs of pytorch model
|
||||
model_inputs = {
|
||||
'cls_scores': cls_score,
|
||||
'bbox_preds': bboxes,
|
||||
'centernesses': centernesses,
|
||||
'img_metas': img_metas
|
||||
}
|
||||
model_outputs = get_model_outputs(fcos_head, 'get_bboxes', model_inputs)
|
||||
# to get outputs of onnx model after rewrite
|
||||
img_metas[0]['img_shape'] = torch.Tensor([s, s])
|
||||
wrapped_model = WrapModel(
|
||||
@ -205,26 +197,9 @@ def test_get_bboxes_of_fcos_head(backend_type: Backend):
|
||||
rewrite_outputs, is_backend_output = get_rewrite_outputs(
|
||||
wrapped_model=wrapped_model,
|
||||
model_inputs=rewrite_inputs,
|
||||
deploy_cfg=deploy_cfg)
|
||||
if is_backend_output:
|
||||
if isinstance(rewrite_outputs, dict):
|
||||
rewrite_outputs = [
|
||||
value for name, value in rewrite_outputs.items()
|
||||
if name in output_names
|
||||
]
|
||||
for model_output, rewrite_output in zip(model_outputs[0],
|
||||
rewrite_outputs):
|
||||
model_output = model_output.squeeze().cpu().numpy()
|
||||
rewrite_output = rewrite_output.squeeze()
|
||||
# hard code to make two tensors with the same shape
|
||||
# rewrite and original codes applied different nms strategy
|
||||
assert np.allclose(
|
||||
model_output[:rewrite_output.shape[0]],
|
||||
rewrite_output,
|
||||
rtol=1e-03,
|
||||
atol=1e-05)
|
||||
else:
|
||||
assert rewrite_outputs is not None
|
||||
deploy_cfg=deploy_cfg,
|
||||
run_with_backend=False)
|
||||
assert rewrite_outputs is not None
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME, Backend.NCNN])
|
||||
@ -275,10 +250,13 @@ def test_get_bboxes_of_rpn_head(backend_type: Backend):
|
||||
'cls_scores': cls_score,
|
||||
'bbox_preds': bboxes,
|
||||
}
|
||||
# do not run with ncnn backend
|
||||
run_with_backend = False if backend_type in [Backend.NCNN] else True
|
||||
rewrite_outputs, is_backend_output = get_rewrite_outputs(
|
||||
wrapped_model=wrapped_model,
|
||||
model_inputs=rewrite_inputs,
|
||||
deploy_cfg=deploy_cfg)
|
||||
deploy_cfg=deploy_cfg,
|
||||
run_with_backend=run_with_backend)
|
||||
assert rewrite_outputs is not None
|
||||
|
||||
|
||||
@ -785,7 +763,7 @@ def test_yolov3_head_get_bboxes_ncnn():
|
||||
'img_shape': (s, s, 3)
|
||||
}]
|
||||
|
||||
output_names = ['outout']
|
||||
output_names = ['detection_output']
|
||||
deploy_cfg = mmcv.Config(
|
||||
dict(
|
||||
backend_config=dict(type=backend_type.value),
|
||||
@ -1019,7 +997,7 @@ def test_get_bboxes_of_vfnet_head(backend_type: Backend):
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend_type',
|
||||
[Backend.ONNXRUNTIME, Backend.NCNN, Backend.OPENVINO])
|
||||
[Backend.ONNXRUNTIME, Backend.OPENVINO])
|
||||
def test_base_dense_head_get_bboxes(backend_type: Backend):
|
||||
"""Test get_bboxes rewrite of base dense head."""
|
||||
check_backend(backend_type)
|
||||
@ -1102,10 +1080,72 @@ def test_base_dense_head_get_bboxes(backend_type: Backend):
|
||||
assert rewrite_outputs is not None
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend_type', [Backend.NCNN])
|
||||
def test_ssd_head_get_bboxes(backend_type: Backend):
|
||||
"""Test get_bboxes rewrite of anchor head."""
|
||||
def test_base_dense_head_get_bboxes__ncnn():
|
||||
"""Test get_bboxes rewrite of base dense head."""
|
||||
backend_type = Backend.NCNN
|
||||
check_backend(backend_type)
|
||||
anchor_head = get_anchor_head_model()
|
||||
anchor_head.cpu().eval()
|
||||
s = 128
|
||||
img_metas = [{
|
||||
'scale_factor': np.ones(4),
|
||||
'pad_shape': (s, s, 3),
|
||||
'img_shape': (s, s, 3)
|
||||
}]
|
||||
|
||||
output_names = ['output']
|
||||
deploy_cfg = mmcv.Config(
|
||||
dict(
|
||||
backend_config=dict(type=backend_type.value),
|
||||
onnx_config=dict(output_names=output_names, input_shape=None),
|
||||
codebase_config=dict(
|
||||
type='mmdet',
|
||||
task='ObjectDetection',
|
||||
model_type='ncnn_end2end',
|
||||
post_processing=dict(
|
||||
score_threshold=0.05,
|
||||
iou_threshold=0.5,
|
||||
max_output_boxes_per_class=200,
|
||||
pre_top_k=5000,
|
||||
keep_top_k=100,
|
||||
background_label_id=-1,
|
||||
))))
|
||||
|
||||
# the cls_score's size: (1, 36, 32, 32), (1, 36, 16, 16),
|
||||
# (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2).
|
||||
# the bboxes's size: (1, 36, 32, 32), (1, 36, 16, 16),
|
||||
# (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2)
|
||||
seed_everything(1234)
|
||||
cls_score = [
|
||||
torch.rand(1, 36, pow(2, i), pow(2, i)) for i in range(5, 0, -1)
|
||||
]
|
||||
seed_everything(5678)
|
||||
bboxes = [torch.rand(1, 36, pow(2, i), pow(2, i)) for i in range(5, 0, -1)]
|
||||
|
||||
# to get outputs of onnx model after rewrite
|
||||
img_metas[0]['img_shape'] = torch.Tensor([s, s])
|
||||
wrapped_model = WrapModel(
|
||||
anchor_head, 'get_bboxes', img_metas=img_metas, with_nms=True)
|
||||
rewrite_inputs = {
|
||||
'cls_scores': cls_score,
|
||||
'bbox_preds': bboxes,
|
||||
}
|
||||
rewrite_outputs, is_backend_output = get_rewrite_outputs(
|
||||
wrapped_model=wrapped_model,
|
||||
model_inputs=rewrite_inputs,
|
||||
deploy_cfg=deploy_cfg)
|
||||
|
||||
# output should be of shape [1, N, 6]
|
||||
if is_backend_output:
|
||||
rewrite_outputs = rewrite_outputs[0]
|
||||
|
||||
assert rewrite_outputs.shape[-1] == 6
|
||||
|
||||
|
||||
@pytest.mark.parametrize('is_dynamic', [True, False])
|
||||
def test_ssd_head_get_bboxes__ncnn(is_dynamic: bool):
|
||||
"""Test get_bboxes rewrite of ssd head for ncnn."""
|
||||
check_backend(Backend.NCNN)
|
||||
ssd_head = get_ssd_head_model()
|
||||
ssd_head.cpu().eval()
|
||||
s = 128
|
||||
@ -1115,13 +1155,30 @@ def test_ssd_head_get_bboxes(backend_type: Backend):
|
||||
'img_shape': (s, s, 3)
|
||||
}]
|
||||
output_names = ['output']
|
||||
input_names = ['input']
|
||||
dynamic_axes = None
|
||||
if is_dynamic:
|
||||
dynamic_axes = {
|
||||
input_names[0]: {
|
||||
2: 'height',
|
||||
3: 'width'
|
||||
},
|
||||
output_names[0]: {
|
||||
1: 'num_dets',
|
||||
}
|
||||
}
|
||||
deploy_cfg = mmcv.Config(
|
||||
dict(
|
||||
backend_config=dict(type=backend_type.value),
|
||||
onnx_config=dict(output_names=output_names, input_shape=None),
|
||||
backend_config=dict(type=Backend.NCNN.value),
|
||||
onnx_config=dict(
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
input_shape=None,
|
||||
dynamic_axes=dynamic_axes),
|
||||
codebase_config=dict(
|
||||
type='mmdet',
|
||||
task='ObjectDetection',
|
||||
model_type='ncnn_end2end',
|
||||
post_processing=dict(
|
||||
score_threshold=0.05,
|
||||
iou_threshold=0.5,
|
||||
@ -1149,16 +1206,8 @@ def test_ssd_head_get_bboxes(backend_type: Backend):
|
||||
for i in range(num_prior)
|
||||
]
|
||||
|
||||
# to get outputs of pytorch model
|
||||
model_inputs = {
|
||||
'cls_scores': cls_score,
|
||||
'bbox_preds': bboxes,
|
||||
'img_metas': img_metas
|
||||
}
|
||||
model_outputs = get_model_outputs(ssd_head, 'get_bboxes', model_inputs)
|
||||
|
||||
# to get outputs of onnx model after rewrite
|
||||
img_metas[0]['img_shape'] = torch.tensor([s, s], dtype=torch.int32)
|
||||
img_metas[0]['img_shape'] = torch.tensor([s, s]) if is_dynamic else [s, s]
|
||||
wrapped_model = WrapModel(
|
||||
ssd_head, 'get_bboxes', img_metas=img_metas, with_nms=True)
|
||||
rewrite_inputs = {
|
||||
@ -1170,19 +1219,8 @@ def test_ssd_head_get_bboxes(backend_type: Backend):
|
||||
model_inputs=rewrite_inputs,
|
||||
deploy_cfg=deploy_cfg)
|
||||
|
||||
# output should be of shape [1, N, 6]
|
||||
if is_backend_output:
|
||||
if isinstance(rewrite_outputs, dict):
|
||||
rewrite_outputs = convert_to_list(rewrite_outputs, output_names)
|
||||
for model_output, rewrite_output in zip(model_outputs[0],
|
||||
rewrite_outputs):
|
||||
model_output = model_output.squeeze().cpu().numpy()
|
||||
rewrite_output = rewrite_output.squeeze().cpu().numpy()
|
||||
# hard code to make two tensors with the same shape
|
||||
# rewrite and original codes applied different nms strategy
|
||||
assert np.allclose(
|
||||
model_output[:rewrite_output.shape[0]],
|
||||
rewrite_output,
|
||||
rtol=1e-03,
|
||||
atol=1e-05)
|
||||
else:
|
||||
assert rewrite_outputs is not None
|
||||
rewrite_outputs = rewrite_outputs[0]
|
||||
|
||||
assert rewrite_outputs.shape[-1] == 6
|
||||
|
@ -134,17 +134,17 @@ def get_base_recognizer_model():
|
||||
return model
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend_type', [Backend.NCNN])
|
||||
def test_bidirectionallstm(backend_type: Backend):
|
||||
@pytest.mark.parametrize('backend', [Backend.NCNN])
|
||||
def test_bidirectionallstm(backend: Backend):
|
||||
"""Test forward rewrite of bidirectionallstm."""
|
||||
check_backend(backend_type)
|
||||
check_backend(backend)
|
||||
bilstm = get_bidirectionallstm_model()
|
||||
bilstm.cpu().eval()
|
||||
|
||||
deploy_cfg = mmcv.Config(
|
||||
dict(
|
||||
backend_config=dict(type=backend_type.value),
|
||||
onnx_config=dict(input_shape=None),
|
||||
backend_config=dict(type=backend.value),
|
||||
onnx_config=dict(output_names=['output'], input_shape=None),
|
||||
codebase_config=dict(
|
||||
type='mmocr',
|
||||
task='TextRecognition',
|
||||
@ -161,25 +161,30 @@ def test_bidirectionallstm(backend_type: Backend):
|
||||
# to get outputs of onnx model after rewrite
|
||||
wrapped_model = WrapModel(bilstm, 'forward')
|
||||
rewrite_inputs = {'input': input}
|
||||
rewrite_outputs = get_rewrite_outputs(
|
||||
rewrite_outputs, is_backend_output = get_rewrite_outputs(
|
||||
wrapped_model=wrapped_model,
|
||||
model_inputs=rewrite_inputs,
|
||||
deploy_cfg=deploy_cfg)
|
||||
for model_output, rewrite_output in zip(model_outputs, rewrite_outputs):
|
||||
model_output = model_output.squeeze().cpu().numpy()
|
||||
rewrite_output = rewrite_output.squeeze()
|
||||
deploy_cfg=deploy_cfg,
|
||||
run_with_backend=False)
|
||||
if is_backend_output:
|
||||
model_output = model_outputs.cpu().numpy()
|
||||
rewrite_output = rewrite_outputs[0].cpu().numpy()
|
||||
assert np.allclose(
|
||||
model_output, rewrite_output, rtol=1e-03, atol=1e-05)
|
||||
else:
|
||||
assert rewrite_outputs is not None
|
||||
|
||||
|
||||
def test_simple_test_of_single_stage_text_detector():
|
||||
@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME])
|
||||
def test_simple_test_of_single_stage_text_detector(backend: Backend):
|
||||
"""Test simple_test single_stage_text_detector."""
|
||||
check_backend(backend)
|
||||
single_stage_text_detector = get_single_stage_text_detector_model()
|
||||
single_stage_text_detector.eval()
|
||||
|
||||
deploy_cfg = mmcv.Config(
|
||||
dict(
|
||||
backend_config=dict(type='default'),
|
||||
backend_config=dict(type=backend.value),
|
||||
onnx_config=dict(input_shape=None),
|
||||
codebase_config=dict(
|
||||
type='mmocr',
|
||||
@ -187,40 +192,36 @@ def test_simple_test_of_single_stage_text_detector():
|
||||
)))
|
||||
|
||||
input = torch.rand(1, 3, 64, 64)
|
||||
img_metas = [{
|
||||
'ori_shape': [64, 64, 3],
|
||||
'img_shape': [64, 64, 3],
|
||||
'pad_shape': [64, 64, 3],
|
||||
'scale_factor': [1., 1., 1., 1],
|
||||
}]
|
||||
|
||||
x = single_stage_text_detector.extract_feat(input)
|
||||
model_outputs = single_stage_text_detector.bbox_head(x)
|
||||
|
||||
wrapped_model = WrapModel(single_stage_text_detector, 'simple_test')
|
||||
rewrite_inputs = {'img': input, 'img_metas': img_metas[0]}
|
||||
rewrite_outputs = get_rewrite_outputs(
|
||||
rewrite_inputs = {'img': input}
|
||||
rewrite_outputs, is_backend_output = get_rewrite_outputs(
|
||||
wrapped_model=wrapped_model,
|
||||
model_inputs=rewrite_inputs,
|
||||
deploy_cfg=deploy_cfg)
|
||||
for model_output, rewrite_output in zip(model_outputs, rewrite_outputs):
|
||||
model_output = model_output.squeeze().cpu().numpy()
|
||||
rewrite_output = rewrite_output.squeeze()
|
||||
assert np.allclose(
|
||||
model_output, rewrite_output, rtol=1e-03, atol=1e-05)
|
||||
deploy_cfg=deploy_cfg,
|
||||
run_with_backend=False)
|
||||
|
||||
if is_backend_output:
|
||||
rewrite_outputs = rewrite_outputs[0]
|
||||
|
||||
model_outputs = model_outputs.cpu().numpy()
|
||||
rewrite_outputs = rewrite_outputs.cpu().numpy()
|
||||
assert np.allclose(model_outputs, rewrite_outputs, rtol=1e-03, atol=1e-05)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend_type', [Backend.NCNN])
|
||||
@pytest.mark.parametrize('backend', [Backend.NCNN])
|
||||
@pytest.mark.parametrize('rnn_flag', [True, False])
|
||||
def test_crnndecoder(backend_type: Backend, rnn_flag: bool):
|
||||
def test_crnndecoder(backend: Backend, rnn_flag: bool):
|
||||
"""Test forward rewrite of crnndecoder."""
|
||||
check_backend(backend_type)
|
||||
check_backend(backend)
|
||||
crnn_decoder = get_crnn_decoder_model(rnn_flag)
|
||||
crnn_decoder.cpu().eval()
|
||||
|
||||
deploy_cfg = mmcv.Config(
|
||||
dict(
|
||||
backend_config=dict(type=backend_type.value),
|
||||
backend_config=dict(type=backend.value),
|
||||
onnx_config=dict(input_shape=None),
|
||||
codebase_config=dict(
|
||||
type='mmocr',
|
||||
@ -250,32 +251,39 @@ def test_crnndecoder(backend_type: Backend, rnn_flag: bool):
|
||||
targets_dict=targets_dict,
|
||||
img_metas=img_metas)
|
||||
rewrite_inputs = {'feat': input}
|
||||
rewrite_outputs = get_rewrite_outputs(
|
||||
rewrite_outputs, is_backend_output = get_rewrite_outputs(
|
||||
wrapped_model=wrapped_model,
|
||||
model_inputs=rewrite_inputs,
|
||||
deploy_cfg=deploy_cfg)
|
||||
for model_output, rewrite_output in zip(model_outputs, rewrite_outputs):
|
||||
model_output = model_output.squeeze().cpu().numpy()
|
||||
rewrite_output = rewrite_output.squeeze()
|
||||
assert np.allclose(
|
||||
model_output, rewrite_output, rtol=1e-03, atol=1e-05)
|
||||
deploy_cfg=deploy_cfg,
|
||||
run_with_backend=False)
|
||||
if is_backend_output:
|
||||
for model_output, rewrite_output in zip(model_outputs,
|
||||
rewrite_outputs):
|
||||
model_output = model_output.squeeze().cpu().numpy()
|
||||
rewrite_output = rewrite_output.squeeze()
|
||||
assert np.allclose(
|
||||
model_output, rewrite_output, rtol=1e-03, atol=1e-05)
|
||||
else:
|
||||
assert rewrite_outputs is not None
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME])
|
||||
@pytest.mark.parametrize(
|
||||
'img_metas', [[[{}]], [[{
|
||||
'resize_shape': [32, 32],
|
||||
'valid_ratio': 1.0
|
||||
}]]])
|
||||
@pytest.mark.parametrize('is_dynamic', [True, False])
|
||||
def test_forward_of_base_recognizer(img_metas, is_dynamic):
|
||||
def test_forward_of_base_recognizer(img_metas, is_dynamic, backend):
|
||||
"""Test forward base_recognizer."""
|
||||
check_backend(backend)
|
||||
base_recognizer = get_base_recognizer_model()
|
||||
base_recognizer.eval()
|
||||
|
||||
if not is_dynamic:
|
||||
deploy_cfg = mmcv.Config(
|
||||
dict(
|
||||
backend_config=dict(type='ncnn'),
|
||||
backend_config=dict(type=backend.value),
|
||||
onnx_config=dict(input_shape=None),
|
||||
codebase_config=dict(
|
||||
type='mmocr',
|
||||
@ -284,7 +292,7 @@ def test_forward_of_base_recognizer(img_metas, is_dynamic):
|
||||
else:
|
||||
deploy_cfg = mmcv.Config(
|
||||
dict(
|
||||
backend_config=dict(type='ncnn'),
|
||||
backend_config=dict(type=backend.value),
|
||||
onnx_config=dict(
|
||||
input_shape=None,
|
||||
dynamic_axes={
|
||||
@ -317,26 +325,29 @@ def test_forward_of_base_recognizer(img_metas, is_dynamic):
|
||||
rewrite_inputs = {
|
||||
'img': input,
|
||||
}
|
||||
rewrite_outputs = get_rewrite_outputs(
|
||||
rewrite_outputs, is_backend_output = get_rewrite_outputs(
|
||||
wrapped_model=wrapped_model,
|
||||
model_inputs=rewrite_inputs,
|
||||
deploy_cfg=deploy_cfg)
|
||||
|
||||
for model_output, rewrite_output in zip(model_outputs, rewrite_outputs):
|
||||
model_output = model_output.squeeze().cpu().numpy()
|
||||
rewrite_output = rewrite_output.squeeze()
|
||||
assert np.allclose(
|
||||
model_output, rewrite_output, rtol=1e-03, atol=1e-05)
|
||||
if is_backend_output:
|
||||
rewrite_outputs = rewrite_outputs[0]
|
||||
|
||||
model_outputs = model_outputs.cpu().numpy()
|
||||
rewrite_outputs = rewrite_outputs.cpu().numpy()
|
||||
assert np.allclose(model_outputs, rewrite_outputs, rtol=1e-03, atol=1e-05)
|
||||
|
||||
|
||||
def test_simple_test_of_encode_decode_recognizer():
|
||||
@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME])
|
||||
def test_simple_test_of_encode_decode_recognizer(backend):
|
||||
"""Test simple_test encode_decode_recognizer."""
|
||||
check_backend(backend)
|
||||
encode_decode_recognizer = get_encode_decode_recognizer_model()
|
||||
encode_decode_recognizer.eval()
|
||||
|
||||
deploy_cfg = mmcv.Config(
|
||||
dict(
|
||||
backend_config=dict(type='default'),
|
||||
backend_config=dict(type=backend.value),
|
||||
onnx_config=dict(input_shape=None),
|
||||
codebase_config=dict(
|
||||
type='mmocr',
|
||||
@ -356,28 +367,28 @@ def test_simple_test_of_encode_decode_recognizer():
|
||||
wrapped_model = WrapModel(
|
||||
encode_decode_recognizer, 'simple_test', img_metas=img_metas)
|
||||
rewrite_inputs = {'img': input}
|
||||
rewrite_outputs = get_rewrite_outputs(
|
||||
rewrite_outputs, is_backend_output = get_rewrite_outputs(
|
||||
wrapped_model=wrapped_model,
|
||||
model_inputs=rewrite_inputs,
|
||||
deploy_cfg=deploy_cfg)
|
||||
if is_backend_output:
|
||||
rewrite_outputs = rewrite_outputs[0]
|
||||
|
||||
for model_output, rewrite_output in zip(model_outputs, rewrite_outputs):
|
||||
model_output = model_output.squeeze().cpu().numpy()
|
||||
rewrite_output = rewrite_output.squeeze()
|
||||
assert np.allclose(
|
||||
model_output, rewrite_output, rtol=1e-03, atol=1e-05)
|
||||
model_outputs = model_outputs.cpu().numpy()
|
||||
rewrite_outputs = rewrite_outputs.cpu().numpy()
|
||||
assert np.allclose(model_outputs, rewrite_outputs, rtol=1e-03, atol=1e-05)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend_type', [Backend.TENSORRT])
|
||||
def test_forward_of_fpnc(backend_type: Backend):
|
||||
@pytest.mark.parametrize('backend', [Backend.TENSORRT])
|
||||
def test_forward_of_fpnc(backend: Backend):
|
||||
"""Test forward rewrite of fpnc."""
|
||||
check_backend(backend_type)
|
||||
check_backend(backend)
|
||||
fpnc = get_fpnc_neck_model()
|
||||
fpnc.eval()
|
||||
deploy_cfg = mmcv.Config(
|
||||
dict(
|
||||
backend_config=dict(
|
||||
type=backend_type.value,
|
||||
type=backend.value,
|
||||
common_config=dict(max_workspace_size=1 << 30),
|
||||
model_inputs=[
|
||||
dict(
|
||||
@ -399,22 +410,17 @@ def test_forward_of_fpnc(backend_type: Backend):
|
||||
rewrite_inputs = {
|
||||
'inputs': input,
|
||||
}
|
||||
rewrite_outputs, is_need_name = get_rewrite_outputs(
|
||||
rewrite_outputs, is_backend_output = get_rewrite_outputs(
|
||||
wrapped_model=wrapped_model,
|
||||
model_inputs=rewrite_inputs,
|
||||
deploy_cfg=deploy_cfg)
|
||||
if is_need_name:
|
||||
model_output = model_outputs[0].squeeze().cpu().numpy()
|
||||
rewrite_output = rewrite_outputs[0].squeeze().cpu().numpy()
|
||||
assert np.allclose(
|
||||
model_output, rewrite_output, rtol=1e-03, atol=1e-05)
|
||||
else:
|
||||
for model_output, rewrite_output in zip(model_outputs,
|
||||
rewrite_outputs):
|
||||
model_output = model_output.squeeze().cpu().numpy()
|
||||
rewrite_output = rewrite_output.squeeze().cpu().numpy()
|
||||
assert np.allclose(
|
||||
model_output, rewrite_output, rtol=1e-03, atol=1e-05)
|
||||
|
||||
if is_backend_output:
|
||||
rewrite_outputs = rewrite_outputs[0]
|
||||
|
||||
model_outputs = model_outputs.cpu().numpy()
|
||||
rewrite_outputs = rewrite_outputs.cpu().numpy()
|
||||
assert np.allclose(model_outputs, rewrite_outputs, rtol=1e-03, atol=1e-05)
|
||||
|
||||
|
||||
def get_sar_model_cfg(decoder_type: str):
|
||||
@ -468,11 +474,11 @@ def get_sar_model_cfg(decoder_type: str):
|
||||
dict(model=model, data=dict(test=dict(pipeline=test_pipeline))))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
|
||||
@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME])
|
||||
@pytest.mark.parametrize('decoder_type',
|
||||
['SequentialSARDecoder', 'ParallelSARDecoder'])
|
||||
def test_sar_model(backend_type: Backend, decoder_type):
|
||||
check_backend(backend_type)
|
||||
def test_sar_model(backend: Backend, decoder_type):
|
||||
check_backend(backend)
|
||||
import os.path as osp
|
||||
import onnx
|
||||
from mmocr.models.textrecog import SARNet
|
||||
@ -483,7 +489,7 @@ def test_sar_model(backend_type: Backend, decoder_type):
|
||||
|
||||
deploy_cfg = mmcv.Config(
|
||||
dict(
|
||||
backend_config=dict(type=backend_type.value),
|
||||
backend_config=dict(type=backend.value),
|
||||
onnx_config=dict(input_shape=None),
|
||||
codebase_config=dict(
|
||||
type='mmocr',
|
||||
@ -492,11 +498,11 @@ def test_sar_model(backend_type: Backend, decoder_type):
|
||||
# patch model
|
||||
pytorch_model.cfg = sar_cfg
|
||||
patched_model = patch_model(
|
||||
pytorch_model, cfg=deploy_cfg, backend=backend_type.value)
|
||||
pytorch_model, cfg=deploy_cfg, backend=backend.value)
|
||||
onnx_file_path = tempfile.NamedTemporaryFile(suffix='.onnx').name
|
||||
input_names = [k for k, v in model_inputs.items() if k != 'ctx']
|
||||
with RewriterContext(
|
||||
cfg=deploy_cfg, backend=backend_type.value), torch.no_grad():
|
||||
cfg=deploy_cfg, backend=backend.value), torch.no_grad():
|
||||
torch.onnx.export(
|
||||
patched_model,
|
||||
tuple([v for k, v in model_inputs.items()]),
|
||||
|
@ -9,8 +9,8 @@ from mmseg.models import BACKBONES, HEADS
|
||||
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
|
||||
|
||||
from mmdeploy.codebase import import_codebase
|
||||
from mmdeploy.utils import Codebase
|
||||
from mmdeploy.utils.test import (WrapModel, get_model_outputs,
|
||||
from mmdeploy.utils import Backend, Codebase
|
||||
from mmdeploy.utils.test import (WrapModel, check_backend, get_model_outputs,
|
||||
get_rewrite_outputs)
|
||||
|
||||
import_codebase(Codebase.MMSEG)
|
||||
@ -93,14 +93,15 @@ def _demo_mm_inputs(input_shape=(1, 3, 8, 16), num_classes=10):
|
||||
return mm_inputs
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend_type', ['onnxruntime', 'ncnn'])
|
||||
def test_encoderdecoder_simple_test(backend_type):
|
||||
@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME, Backend.OPENVINO])
|
||||
def test_encoderdecoder_simple_test(backend):
|
||||
check_backend(backend)
|
||||
segmentor = get_model()
|
||||
segmentor.cpu().eval()
|
||||
|
||||
deploy_cfg = mmcv.Config(
|
||||
dict(
|
||||
backend_config=dict(type=backend_type),
|
||||
backend_config=dict(type=backend.value),
|
||||
onnx_config=dict(output_names=['result'], input_shape=None),
|
||||
codebase_config=dict(type='mmseg', task='Segmentation')))
|
||||
|
||||
@ -127,23 +128,21 @@ def test_encoderdecoder_simple_test(backend_type):
|
||||
model_inputs=rewrite_inputs,
|
||||
deploy_cfg=deploy_cfg)
|
||||
|
||||
if is_backend_output:
|
||||
rewrite_outputs = rewrite_outputs[0]
|
||||
model_outputs = torch.tensor(model_outputs[0])
|
||||
model_outputs = model_outputs.unsqueeze(0).unsqueeze(0)
|
||||
assert torch.allclose(rewrite_outputs, model_outputs)
|
||||
else:
|
||||
assert rewrite_outputs is not None
|
||||
model_outputs = torch.tensor(model_outputs[0])
|
||||
rewrite_outputs = rewrite_outputs[0].to(model_outputs).reshape(
|
||||
model_outputs.shape)
|
||||
assert torch.allclose(rewrite_outputs, model_outputs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend_type', ['onnxruntime', 'ncnn'])
|
||||
def test_basesegmentor_forward(backend_type):
|
||||
@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME, Backend.OPENVINO])
|
||||
def test_basesegmentor_forward(backend):
|
||||
check_backend(backend)
|
||||
segmentor = get_model()
|
||||
segmentor.cpu().eval()
|
||||
|
||||
deploy_cfg = mmcv.Config(
|
||||
dict(
|
||||
backend_config=dict(type=backend_type),
|
||||
backend_config=dict(type=backend.value),
|
||||
onnx_config=dict(output_names=['result'], input_shape=None),
|
||||
codebase_config=dict(type='mmseg', task='Segmentation')))
|
||||
|
||||
@ -167,17 +166,16 @@ def test_basesegmentor_forward(backend_type):
|
||||
wrapped_model=wrapped_model,
|
||||
model_inputs=rewrite_inputs,
|
||||
deploy_cfg=deploy_cfg)
|
||||
if is_backend_output:
|
||||
rewrite_outputs = torch.tensor(rewrite_outputs[0])
|
||||
model_outputs = torch.tensor(model_outputs[0])
|
||||
model_outputs = model_outputs.unsqueeze(0).unsqueeze(0)
|
||||
assert torch.allclose(rewrite_outputs, model_outputs)
|
||||
else:
|
||||
assert rewrite_outputs is not None
|
||||
|
||||
model_outputs = torch.tensor(model_outputs[0])
|
||||
rewrite_outputs = rewrite_outputs[0].to(model_outputs).reshape(
|
||||
model_outputs.shape)
|
||||
assert torch.allclose(rewrite_outputs, model_outputs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend_type', ['onnxruntime', 'ncnn'])
|
||||
def test_aspphead_forward(backend_type):
|
||||
@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME, Backend.OPENVINO])
|
||||
def test_aspphead_forward(backend):
|
||||
check_backend(backend)
|
||||
from mmseg.models.decode_heads import ASPPHead
|
||||
head = ASPPHead(
|
||||
in_channels=32, channels=16, num_classes=19,
|
||||
@ -185,7 +183,7 @@ def test_aspphead_forward(backend_type):
|
||||
|
||||
deploy_cfg = mmcv.Config(
|
||||
dict(
|
||||
backend_config=dict(type=backend_type),
|
||||
backend_config=dict(type=backend.value),
|
||||
onnx_config=dict(
|
||||
output_names=['result'], input_shape=(1, 32, 45, 45)),
|
||||
codebase_config=dict(type='mmseg', task='Segmentation')))
|
||||
@ -200,21 +198,23 @@ def test_aspphead_forward(backend_type):
|
||||
model_inputs=rewrite_inputs,
|
||||
deploy_cfg=deploy_cfg)
|
||||
if is_backend_output:
|
||||
rewrite_outputs = torch.tensor(rewrite_outputs[0])
|
||||
assert torch.allclose(
|
||||
rewrite_outputs, model_outputs, rtol=1e-03, atol=1e-05)
|
||||
else:
|
||||
assert rewrite_outputs is not None
|
||||
rewrite_outputs = rewrite_outputs[0]
|
||||
rewrite_outputs = rewrite_outputs.to(model_outputs).reshape(
|
||||
model_outputs.shape)
|
||||
assert torch.allclose(
|
||||
rewrite_outputs, model_outputs, rtol=1e-03, atol=1e-05)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend_type', ['onnxruntime', 'ncnn'])
|
||||
def test_psphead_forward(backend_type):
|
||||
@pytest.mark.parametrize('backend',
|
||||
[Backend.ONNXRUNTIME, Backend.OPENVINO, Backend.NCNN])
|
||||
def test_psphead_forward(backend):
|
||||
check_backend(backend)
|
||||
from mmseg.models.decode_heads import PSPHead
|
||||
head = PSPHead(in_channels=32, channels=16, num_classes=19).eval()
|
||||
|
||||
deploy_cfg = mmcv.Config(
|
||||
dict(
|
||||
backend_config=dict(type=backend_type),
|
||||
backend_config=dict(type=backend.value),
|
||||
onnx_config=dict(output_names=['result'], input_shape=None),
|
||||
codebase_config=dict(type='mmseg', task='Segmentation')))
|
||||
inputs = [torch.randn(1, 32, 45, 45)]
|
||||
@ -228,7 +228,7 @@ def test_psphead_forward(backend_type):
|
||||
model_inputs=rewrite_inputs,
|
||||
deploy_cfg=deploy_cfg)
|
||||
if is_backend_output:
|
||||
rewrite_outputs = torch.tensor(rewrite_outputs[0])
|
||||
assert torch.allclose(rewrite_outputs, model_outputs, rtol=1, atol=1)
|
||||
else:
|
||||
assert rewrite_outputs is not None
|
||||
rewrite_outputs = rewrite_outputs[0]
|
||||
rewrite_outputs = rewrite_outputs.to(model_outputs).reshape(
|
||||
model_outputs.shape)
|
||||
assert torch.allclose(rewrite_outputs, model_outputs, rtol=1, atol=1)
|
||||
|
@ -49,7 +49,8 @@ def test_get_attribute():
|
||||
rewrite_outputs, _ = get_rewrite_outputs(
|
||||
wrapped_func,
|
||||
model_inputs={'tensor': input},
|
||||
deploy_cfg=deploy_cfg_ncnn)
|
||||
deploy_cfg=deploy_cfg_ncnn,
|
||||
run_with_backend=False)
|
||||
|
||||
assert rewrite_outputs is not None, 'Got unexpected rewrite '
|
||||
'outputs: {}'.format(rewrite_outputs)
|
||||
@ -69,7 +70,8 @@ def test_group_norm_ncnn():
|
||||
rewrite_output, _ = get_rewrite_outputs(
|
||||
wrapped_func,
|
||||
model_inputs={'input': input},
|
||||
deploy_cfg=deploy_cfg_ncnn)
|
||||
deploy_cfg=deploy_cfg_ncnn,
|
||||
run_with_backend=False)
|
||||
|
||||
assert np.allclose(model_output, rewrite_output, rtol=1e-03, atol=1e-05)
|
||||
|
||||
@ -86,7 +88,8 @@ def test_interpolate_static():
|
||||
rewrite_output, _ = get_rewrite_outputs(
|
||||
wrapped_func,
|
||||
model_inputs={'input': input},
|
||||
deploy_cfg=deploy_cfg_ncnn)
|
||||
deploy_cfg=deploy_cfg_ncnn,
|
||||
run_with_backend=False)
|
||||
|
||||
assert np.allclose(model_output, rewrite_output, rtol=1e-03, atol=1e-05)
|
||||
|
||||
@ -105,7 +108,8 @@ def test_linear_ncnn():
|
||||
rewrite_output, _ = get_rewrite_outputs(
|
||||
wrapped_func,
|
||||
model_inputs={'input': input},
|
||||
deploy_cfg=deploy_cfg_ncnn)
|
||||
deploy_cfg=deploy_cfg_ncnn,
|
||||
run_with_backend=False)
|
||||
|
||||
assert np.allclose(model_output, rewrite_output, rtol=1e-03, atol=1e-05)
|
||||
|
||||
@ -148,7 +152,8 @@ def test_size_of_tensor_static():
|
||||
rewrite_outputs, _ = get_rewrite_outputs(
|
||||
wrapped_func,
|
||||
model_inputs={'input': input},
|
||||
deploy_cfg=deploy_cfg_ncnn)
|
||||
deploy_cfg=deploy_cfg_ncnn,
|
||||
run_with_backend=False)
|
||||
|
||||
assert rewrite_outputs is not None, 'Got unexpected rewrite '
|
||||
'outputs: {}'.format(rewrite_outputs)
|
||||
@ -175,7 +180,8 @@ class TestTopk:
|
||||
output, _ = get_rewrite_outputs(
|
||||
wrapped_func,
|
||||
model_inputs={'input': TestTopk.input},
|
||||
deploy_cfg=deploy_cfg_ncnn)
|
||||
deploy_cfg=deploy_cfg_ncnn,
|
||||
run_with_backend=False)
|
||||
|
||||
assert np.allclose(model_output, output[1], rtol=1e-03, atol=1e-05)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user