[Feature] Support RTMDet and RTMPose ncnn deployment (#1857)
* support rtmpose ncnn * fix docformatter * fix docformatter * fix classname from tauj to dev-1.x branch * rename file * fix comments * remove unused rewriter * fix norm * fix lint * fix rtmcc_block * fix norm * add ut * fix origin_func * fix norm * fix rtmdet_head * add ut * false run_with_backend for ncnn * fix lintpull/1901/head
parent
423e27a4fe
commit
d48187cf68
|
@ -0,0 +1,5 @@
|
|||
_base_ = ['../_base_/base_static.py', '../../_base_/backends/ncnn.py']
|
||||
|
||||
backend_config = dict(precision='FP16')
|
||||
codebase_config = dict(model_type='ncnn_end2end')
|
||||
onnx_config = dict(output_names=['detection_output'], input_shape=[320, 320])
|
|
@ -0,0 +1,4 @@
|
|||
_base_ = ['../_base_/base_static.py', '../../_base_/backends/ncnn.py']
|
||||
|
||||
codebase_config = dict(model_type='ncnn_end2end')
|
||||
onnx_config = dict(output_names=['detection_output'], input_shape=[320, 320])
|
|
@ -0,0 +1,4 @@
|
|||
_base_ = ['./pose-detection_static.py', '../_base_/backends/ncnn.py']
|
||||
|
||||
backend_config = dict(precision='FP16')
|
||||
onnx_config = dict(input_shape=[192, 256], output_names=['simcc_x', 'simcc_y'])
|
|
@ -2200,8 +2200,6 @@ int main(int argc, char** argv) {
|
|||
}
|
||||
fprintf(pp, " 4=%d", keepdims);
|
||||
fprintf(pp, " 5=1");
|
||||
// Force set Reduction for FP32, FP16 may exceed for some models.
|
||||
fprintf(pp, " 31=15");
|
||||
} else if (op == "Reorg") {
|
||||
int stride = get_node_attr_i(node, "stride", 1);
|
||||
fprintf(pp, " 0=%d", stride);
|
||||
|
|
|
@ -9,6 +9,7 @@ from torch import Tensor
|
|||
from mmdeploy.codebase.mmdet import get_post_processing_params
|
||||
from mmdeploy.core import FUNCTION_REWRITER, mark
|
||||
from mmdeploy.mmcv.ops import multiclass_nms
|
||||
from mmdeploy.utils import Backend
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
|
@ -105,3 +106,120 @@ def rtmdet_head__predict_by_feat(self,
|
|||
score_threshold=score_threshold,
|
||||
pre_top_k=pre_top_k,
|
||||
keep_top_k=keep_top_k)
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.dense_heads.rtmdet_head.'
|
||||
'RTMDetHead.predict_by_feat',
|
||||
backend=Backend.NCNN.value)
|
||||
def rtmdet_head__predict_by_feat__ncnn(
|
||||
self,
|
||||
cls_scores: List[Tensor],
|
||||
bbox_preds: List[Tensor],
|
||||
batch_img_metas: Optional[List[dict]] = None,
|
||||
cfg: Optional[ConfigDict] = None,
|
||||
rescale: bool = False,
|
||||
with_nms: bool = True):
|
||||
"""Rewrite `predict_by_feat` of RTMDetHead for ncnn backend.
|
||||
1. Decode the prior to a box format for ncnn DetectionOutput layer to do
|
||||
the post-processing.
|
||||
2. Batch dimension is not supported by ncnn, but supported by pytorch.
|
||||
The negative value of axis in torch.cat is rewritten as corresponding
|
||||
positive value to avoid axis shift.
|
||||
3. 2-dimension tensor broadcast of `BinaryOps` operator is not supported by
|
||||
ncnn. This function unsqueeze 2-dimension tensor to 3-dimension tensor for
|
||||
correct `BinaryOps` calculation by ncnn.
|
||||
Args:
|
||||
cls_scores (list[Tensor]): Classification scores for all
|
||||
scale levels, each is a 4D-tensor, has shape
|
||||
(batch_size, num_priors * num_classes, H, W).
|
||||
bbox_preds (list[Tensor]): Box energies / deltas for all
|
||||
scale levels, each is a 4D-tensor, has shape
|
||||
(batch_size, num_priors * 4, H, W).
|
||||
objectnesses (list[Tensor], Optional): Score factor for
|
||||
all scale level, each is a 4D-tensor, has shape
|
||||
(batch_size, 1, H, W).
|
||||
batch_img_metas (list[dict], Optional): Batch image meta info.
|
||||
Defaults to None.
|
||||
cfg (ConfigDict, optional): Test / postprocessing
|
||||
configuration, if None, test_cfg would be used.
|
||||
Defaults to None.
|
||||
rescale (bool): If True, return boxes in original image space.
|
||||
Defaults to False.
|
||||
with_nms (bool): If True, do nms before return boxes.
|
||||
Defaults to True.
|
||||
Returns:
|
||||
output__ncnn (Tensor): outputs, shape is [N, num_det, 6].
|
||||
"""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
from mmdeploy.codebase.mmdet.ops import ncnn_detection_output_forward
|
||||
from mmdeploy.utils import get_root_logger
|
||||
from mmdeploy.utils.config_utils import is_dynamic_shape
|
||||
dynamic_flag = is_dynamic_shape(ctx.cfg)
|
||||
if dynamic_flag:
|
||||
logger = get_root_logger()
|
||||
logger.warning('RTMDet does not support dynamic shape with ncnn.')
|
||||
img_height = int(batch_img_metas[0]['img_shape'][0])
|
||||
img_width = int(batch_img_metas[0]['img_shape'][1])
|
||||
|
||||
assert len(cls_scores) == len(bbox_preds)
|
||||
device = cls_scores[0].device
|
||||
cfg = self.test_cfg if cfg is None else cfg
|
||||
batch_size = bbox_preds[0].shape[0]
|
||||
featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
|
||||
mlvl_priors = self.prior_generator.grid_priors(
|
||||
featmap_sizes, device=device, with_stride=True)
|
||||
mlvl_priors = [mlvl_prior.unsqueeze(0) for mlvl_prior in mlvl_priors]
|
||||
flatten_priors = torch.cat(mlvl_priors, dim=1)
|
||||
|
||||
flatten_cls_scores = [
|
||||
cls_score.permute(0, 2, 3, 1).reshape(batch_size, -1,
|
||||
self.cls_out_channels)
|
||||
for cls_score in cls_scores
|
||||
]
|
||||
flatten_bbox_preds = [
|
||||
bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4)
|
||||
for bbox_pred in bbox_preds
|
||||
]
|
||||
|
||||
cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
|
||||
dummy_cls_scores = torch.zeros(
|
||||
batch_size, cls_scores.shape[-2], 1, device=cls_scores.device)
|
||||
|
||||
batch_mlvl_scores = torch.cat([dummy_cls_scores, cls_scores], dim=2)
|
||||
|
||||
flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
|
||||
assert flatten_priors.shape[-1] == 4, f'rtmdet needs (B, N, 4) priors, got\
|
||||
(B, N, {flatten_priors.shape[-1]})'
|
||||
|
||||
tl_x = (flatten_priors[:, :, 0:1] -
|
||||
flatten_bbox_preds[:, :, 0:1]) / img_width
|
||||
tl_y = (flatten_priors[:, :, 1:2] -
|
||||
flatten_bbox_preds[:, :, 1:2]) / img_height
|
||||
br_x = (flatten_priors[:, :, 0:1] +
|
||||
flatten_bbox_preds[:, :, 2:3]) / img_width
|
||||
br_y = (flatten_priors[:, :, 1:2] +
|
||||
flatten_bbox_preds[:, :, 3:4]) / img_height
|
||||
prior_box_ncnn = torch.stack([tl_x, tl_y, br_x, br_y], -1)
|
||||
|
||||
scores = batch_mlvl_scores
|
||||
|
||||
batch_mlvl_bboxes = flatten_bbox_preds.reshape(batch_size, 1, -1)
|
||||
batch_mlvl_scores = scores.reshape(batch_size, 1, -1)
|
||||
batch_mlvl_priors = prior_box_ncnn.reshape(batch_size, 1, -1)
|
||||
batch_mlvl_vars = torch.ones_like(batch_mlvl_priors)
|
||||
batch_mlvl_priors = torch.cat([batch_mlvl_priors, batch_mlvl_vars], dim=1)
|
||||
deploy_cfg = ctx.cfg
|
||||
post_params = get_post_processing_params(deploy_cfg)
|
||||
iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
|
||||
score_threshold = cfg.get('score_thr', post_params.score_threshold)
|
||||
pre_top_k = post_params.pre_top_k
|
||||
keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
|
||||
|
||||
vars = torch.tensor([1, 1, 1, 1], dtype=torch.float32)
|
||||
output__ncnn = ncnn_detection_output_forward(
|
||||
batch_mlvl_bboxes, batch_mlvl_scores, batch_mlvl_priors,
|
||||
score_threshold, iou_threshold, pre_top_k, keep_top_k,
|
||||
self.num_classes + 1,
|
||||
vars.cpu().detach().numpy())
|
||||
return output__ncnn
|
||||
|
|
|
@ -2,3 +2,4 @@
|
|||
|
||||
from . import heads # noqa: F401,F403
|
||||
from . import pose_estimators # noqa: F401,F403
|
||||
from . import utils # noqa: F401,F403
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
from . import rtmcc_block
|
||||
|
||||
__all__ = ['rtmcc_block']
|
|
@ -0,0 +1,91 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from mmpose.models.utils import rope
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmpose.models.utils.rtmcc_block.ScaleNorm.forward', backend='ncnn')
|
||||
def scalenorm__forward__ncnn(self, x):
|
||||
"""Rewrite `scalenorm` for ncnn backend.
|
||||
|
||||
Rewrite scalenorm to avoid FP16 exceed in ncnn Android platform.
|
||||
"""
|
||||
# The one-dim of Fubinious norm is equal to L2Norm.
|
||||
# Set p=2 explicitly to map torch.norm to ReduceL2 onnx op,
|
||||
# which will avoid FP16 exceed.
|
||||
norm = torch.norm(x, dim=2, keepdim=True)
|
||||
norm = norm * self.scale
|
||||
# Rewrite for ncnn binaryop broadcast.
|
||||
norm = norm.clamp(min=self.eps)
|
||||
return (x.unsqueeze(2) / norm.unsqueeze(2)).squeeze(2) * self.g
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmpose.models.utils.rtmcc_block.RTMCCBlock._forward', backend='ncnn')
|
||||
def rtmccblock___forward_ncnn(self, inputs):
|
||||
"""Rewrite `_forward` of RTMBlock for ncnn backend.
|
||||
|
||||
Rewrite the matmul and avoid unbind for ncnn backend.
|
||||
"""
|
||||
if self.attn_type == 'self-attn':
|
||||
x = inputs
|
||||
else:
|
||||
x, k, v = inputs
|
||||
|
||||
x = self.ln(x)
|
||||
uv = self.uv(x)
|
||||
if self.attn_type == 'self-attn':
|
||||
uv = self.act_fn(uv)
|
||||
u = uv[..., :self.e]
|
||||
v = uv[..., self.e:2 * self.e]
|
||||
base = uv[..., 2 * self.e:2 * self.e + self.s]
|
||||
|
||||
q = (base.unsqueeze(1) * self.gamma[None, None, 0:1, :] +
|
||||
self.beta[None, None, 0:1, :]).squeeze(1)
|
||||
k = (base.unsqueeze(1) * self.gamma[None, None, 1:2, :] +
|
||||
self.beta[None, None, 1:2, :]).squeeze(1)
|
||||
|
||||
if self.pos_enc:
|
||||
q = rope(q, dim=1)
|
||||
k = rope(k, dim=1)
|
||||
else:
|
||||
u, q = torch.split(self.act_fn(uv), [self.e, self.s], dim=-1)
|
||||
|
||||
k = self.k_fc(k)
|
||||
v = self.v_fc(v)
|
||||
|
||||
if self.pos_enc:
|
||||
q = rope(q, 1)
|
||||
k = rope(k, 1)
|
||||
qk = torch.bmm(q, k.permute(0, 2, 1))
|
||||
if self.use_rel_bias:
|
||||
if self.attn_type == 'self-attn':
|
||||
bias = self.rel_pos_bias(q.size(1))
|
||||
else:
|
||||
bias = self.rel_pos_bias(q.size(1), k.size(1))
|
||||
qk += bias[:, :q.size(1), :k.size(1)]
|
||||
|
||||
kernel = torch.square(F.relu(qk / self.sqrt_s))
|
||||
if self.dropout_rate > 0.:
|
||||
kernel = self.dropout(kernel)
|
||||
|
||||
x = u * torch.bmm(kernel, v)
|
||||
x = self.o(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmpose.models.utils.rtmcc_block.Scale.forward', backend='ncnn')
|
||||
def scale__forward_ncnn(self, x):
|
||||
"""Rewrite `forward` of Scale for ncnn backend.
|
||||
|
||||
Adapt the shape to avoid ncnn BinaryOp seg fault.
|
||||
"""
|
||||
x = x.unsqueeze(1)
|
||||
scale = self.scale[None, None, None, :]
|
||||
return (x * scale).squeeze(1)
|
|
@ -1,5 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
import torch
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
@ -39,3 +41,26 @@ def normalize__ncnn(input: torch.Tensor,
|
|||
input.transpose(1, dim), p=p, dim=1,
|
||||
eps=eps).transpose(1, dim)
|
||||
return output
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(func_name='torch.norm', backend='ncnn')
|
||||
def norm__ncnn(input: torch.Tensor,
|
||||
p: Optional[Union[int, str]] = 'fro',
|
||||
dim: Optional[Union[int, Sequence]] = None,
|
||||
keepdim: Optional[bool] = False,
|
||||
out: Optional[torch.Tensor] = None,
|
||||
dtype: Optional[torch.dtype] = None):
|
||||
"""Rewrite `torch.norm` for ncnn backend.
|
||||
|
||||
Rewrite torch.norm when p is Frobenius norm to avoid FP16 exceed in ncnn
|
||||
Android platform.
|
||||
"""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
origin_func = ctx.origin_func
|
||||
if p == 'fro' and (isinstance(dim, int) or len(dim) == 1):
|
||||
# Substitute Frobenius norm with L2 norm.
|
||||
return origin_func(
|
||||
input, p=2, dim=dim, keepdim=keepdim, out=out, dtype=dtype)
|
||||
else:
|
||||
return origin_func(
|
||||
input, p=p, dim=dim, keepdim=keepdim, out=out, dtype=dtype)
|
||||
|
|
|
@ -2121,3 +2121,88 @@ def test_solo_head_predict_by_feat(backend_type: Backend):
|
|||
atol=1e-05)
|
||||
else:
|
||||
assert rewrite_outputs is not None
|
||||
|
||||
|
||||
def get_rtmdet_head_model():
|
||||
|
||||
from mmdet.models.dense_heads import RTMDetHead
|
||||
from mmdet.models.task_modules.prior_generators.point_generator import \
|
||||
MlvlPointGenerator
|
||||
|
||||
test_cfg = Config(
|
||||
dict(
|
||||
deploy_nms_pre=0,
|
||||
min_bbox_size=0,
|
||||
score_thr=0.05,
|
||||
nms=dict(type='nms', iou_threshold=0.6),
|
||||
max_per_img=100))
|
||||
model = RTMDetHead(1, 64)
|
||||
model.prior_generator = MlvlPointGenerator([8, 4, 2])
|
||||
model.test_cfg = test_cfg
|
||||
|
||||
model.requires_grad_(False)
|
||||
return model
|
||||
|
||||
|
||||
def test_rtmdet_head_predict_by_feat_ncnn():
|
||||
"""Test predict_by_feat rewrite of yolov3 head."""
|
||||
backend_type = Backend.NCNN
|
||||
check_backend(backend_type)
|
||||
rtmdet_head = get_rtmdet_head_model()
|
||||
rtmdet_head.cpu().eval()
|
||||
s = 320
|
||||
batch_img_metas = [{
|
||||
'scale_factor': np.ones(4),
|
||||
'pad_shape': (s, s, 3),
|
||||
'img_shape': (s, s, 3)
|
||||
}]
|
||||
|
||||
output_names = ['detection_output']
|
||||
deploy_cfg = Config(
|
||||
dict(
|
||||
backend_config=dict(type=backend_type.value),
|
||||
onnx_config=dict(output_names=output_names, input_shape=None),
|
||||
codebase_config=dict(
|
||||
type='mmdet',
|
||||
model_type='ncnn_end2end',
|
||||
task='ObjectDetection',
|
||||
post_processing=dict(
|
||||
score_threshold=0.05,
|
||||
iou_threshold=0.45,
|
||||
confidence_threshold=0.005,
|
||||
max_output_boxes_per_class=200,
|
||||
pre_top_k=-1,
|
||||
keep_top_k=10,
|
||||
background_label_id=-1,
|
||||
))))
|
||||
|
||||
seed_everything(1234)
|
||||
cls_scores = [
|
||||
torch.rand(1, 1, 40, 40),
|
||||
torch.rand(1, 1, 20, 20),
|
||||
torch.rand(1, 1, 10, 10)
|
||||
]
|
||||
|
||||
bbox_preds = [
|
||||
torch.rand(1, 4, 40, 40),
|
||||
torch.rand(1, 4, 20, 20),
|
||||
torch.rand(1, 4, 10, 10)
|
||||
]
|
||||
|
||||
# to get outputs of onnx model after rewrite
|
||||
wrapped_model = WrapModel(
|
||||
rtmdet_head,
|
||||
'predict_by_feat',
|
||||
batch_img_metas=batch_img_metas,
|
||||
with_nms=True)
|
||||
rewrite_inputs = {'cls_scores': cls_scores, 'bbox_preds': bbox_preds}
|
||||
rewrite_outputs, is_backend_output = get_rewrite_outputs(
|
||||
wrapped_model=wrapped_model,
|
||||
model_inputs=rewrite_inputs,
|
||||
deploy_cfg=deploy_cfg,
|
||||
run_with_backend=False)
|
||||
# output should be of shape [1, N, 6]
|
||||
if is_backend_output:
|
||||
assert rewrite_outputs[0].shape[-1] == 6
|
||||
else:
|
||||
assert rewrite_outputs.shape[-1] == 6
|
||||
|
|
|
@ -6,6 +6,11 @@ from mmdeploy.codebase import import_codebase
|
|||
from mmdeploy.utils import Backend, Codebase
|
||||
from mmdeploy.utils.test import WrapModel, check_backend, get_rewrite_outputs
|
||||
|
||||
try:
|
||||
from torch.testing import assert_close as torch_assert_close
|
||||
except Exception:
|
||||
from torch.testing import assert_allclose as torch_assert_close
|
||||
|
||||
try:
|
||||
import_codebase(Codebase.MMPOSE)
|
||||
except ImportError:
|
||||
|
@ -108,3 +113,78 @@ def test_estimator_forward(backend_type: Backend):
|
|||
run_with_backend=False,
|
||||
deploy_cfg=deploy_cfg)
|
||||
assert isinstance(rewrite_outputs, torch.Tensor)
|
||||
|
||||
|
||||
def get_scale_norm_model():
|
||||
from mmpose.models.utils.rtmcc_block import ScaleNorm
|
||||
|
||||
model = ScaleNorm(48)
|
||||
model.requires_grad_(False)
|
||||
return model
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend_type', [Backend.NCNN])
|
||||
def test_scale_norm_forward(backend_type: Backend):
|
||||
check_backend(backend_type, True)
|
||||
deploy_cfg = generate_mmpose_deploy_config(backend_type.value)
|
||||
model = get_scale_norm_model()
|
||||
x = torch.rand(1, 17, 48)
|
||||
wrapped_model = WrapModel(model, 'forward')
|
||||
model_outputs = model.forward(x)
|
||||
rewrite_inputs = {'x': x}
|
||||
rewrite_outputs, _ = get_rewrite_outputs(
|
||||
wrapped_model=wrapped_model,
|
||||
model_inputs=rewrite_inputs,
|
||||
deploy_cfg=deploy_cfg,
|
||||
run_with_backend=False)
|
||||
torch_assert_close(rewrite_outputs, model_outputs)
|
||||
|
||||
|
||||
def get_rtmcc_block_model():
|
||||
from mmpose.models.utils.rtmcc_block import RTMCCBlock
|
||||
|
||||
model = RTMCCBlock(48, 48, 48)
|
||||
model.requires_grad_(False)
|
||||
return model
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend_type', [Backend.NCNN])
|
||||
def test_rtmcc_block_forward(backend_type: Backend):
|
||||
check_backend(backend_type, True)
|
||||
deploy_cfg = generate_mmpose_deploy_config(backend_type.value)
|
||||
model = get_rtmcc_block_model()
|
||||
inputs = torch.rand(1, 17, 48)
|
||||
wrapped_model = WrapModel(model, '_forward')
|
||||
model_outputs = model._forward(inputs)
|
||||
rewrite_inputs = {'inputs': inputs}
|
||||
rewrite_outputs, _ = get_rewrite_outputs(
|
||||
wrapped_model=wrapped_model,
|
||||
model_inputs=rewrite_inputs,
|
||||
deploy_cfg=deploy_cfg,
|
||||
run_with_backend=False)
|
||||
torch_assert_close(rewrite_outputs, model_outputs)
|
||||
|
||||
|
||||
def get_scale_model():
|
||||
from mmpose.models.utils.rtmcc_block import Scale
|
||||
|
||||
model = Scale(48)
|
||||
model.requires_grad_(False)
|
||||
return model
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend_type', [Backend.NCNN])
|
||||
def test_scale_forward(backend_type: Backend):
|
||||
check_backend(backend_type, True)
|
||||
deploy_cfg = generate_mmpose_deploy_config(backend_type.value)
|
||||
model = get_scale_model()
|
||||
x = torch.rand(1, 17, 48)
|
||||
wrapped_model = WrapModel(model, 'forward')
|
||||
model_outputs = model.forward(x)
|
||||
rewrite_inputs = {'x': x}
|
||||
rewrite_outputs, _ = get_rewrite_outputs(
|
||||
wrapped_model=wrapped_model,
|
||||
model_inputs=rewrite_inputs,
|
||||
deploy_cfg=deploy_cfg,
|
||||
run_with_backend=False)
|
||||
torch_assert_close(rewrite_outputs, model_outputs)
|
||||
|
|
|
@ -166,6 +166,28 @@ def test_linear_ncnn():
|
|||
assert np.allclose(model_output, rewrite_output[0], rtol=1e-03, atol=1e-05)
|
||||
|
||||
|
||||
@backend_checker(Backend.NCNN)
|
||||
def test_norm_ncnn():
|
||||
import onnx
|
||||
|
||||
import mmdeploy.apis.ncnn as ncnn_apis
|
||||
from mmdeploy.utils.test import get_onnx_model
|
||||
|
||||
input = torch.rand(1, 17, 24)
|
||||
wrapped_func = WrapFunction(torch.norm, p='fro', dim=2, keepdim=True)
|
||||
model_inputs = {'input': input}
|
||||
ir_file_path = get_onnx_model(wrapped_func, model_inputs, deploy_cfg_ncnn)
|
||||
assert osp.exists(ir_file_path)
|
||||
onnx_model = onnx.load(ir_file_path)
|
||||
nodes = onnx_model.graph.node
|
||||
assert nodes[-1].name.startswith('ReduceL2')
|
||||
ncnn_files_prefix = osp.splitext(ir_file_path)[0]
|
||||
ncnn_apis.from_onnx(ir_file_path, ncnn_files_prefix)
|
||||
param_path, bin_path = ncnn_apis.get_output_model_file(ir_file_path)
|
||||
assert osp.exists(param_path)
|
||||
assert osp.exists(bin_path)
|
||||
|
||||
|
||||
@backend_checker(Backend.TENSORRT)
|
||||
def test_repeat_static():
|
||||
input = torch.rand([1])
|
||||
|
|
Loading…
Reference in New Issue