[Fix] fix unittest and suppress warning (#1552)
* fix unittest and some warning
* fix read string
parent 0e656067a6
commit e0ed95ebc4
@@ -201,14 +201,15 @@ nvinfer1::IPluginV2 *TRTRoIAlignCreator::createPlugin(
     if (field_name.compare("mode") == 0) {
       int data_size = fc->fields[i].length;
       ASSERT(data_size > 0);
       const char *data_start = static_cast<const char *>(fc->fields[i].data);
-      std::string poolModeStr(data_start, data_size);
-      if (poolModeStr == "avg") {
+      std::string pool_mode(data_start);
+      if (pool_mode == "avg") {
         poolMode = 1;
-      } else if (poolModeStr == "max") {
+      } else if (pool_mode == "max") {
         poolMode = 0;
       } else {
-        std::cout << "Unknown pool mode \"" << poolModeStr << "\"." << std::endl;
+        std::cout << "Unknown pool mode \"" << pool_mode << "\"." << std::endl;
       }
       ASSERT(poolMode >= 0);
     }
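The "fix read string" part: the buffer behind a `kCHAR` plugin field can report a `length` that includes the trailing NUL, so constructing the `std::string` from `(data_start, data_size)` can yield `"avg\0"`, which never compares equal to `"avg"`; constructing from the pointer alone stops at the NUL. A minimal Python illustration of the same pitfall (the buffer contents are hypothetical, not mmdeploy code):

```python
# The buffer a parser hands to createPlugin may include the terminator.
raw = b"avg\x00"

with_len = raw.decode()                   # like std::string(data_start, data_size)
to_nul = raw.split(b"\x00")[0].decode()   # like std::string(data_start)

assert with_len != "avg"   # "avg\x00": the comparison silently fails
assert to_nul == "avg"     # stops at the NUL, as the fixed code does
```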
@@ -13,15 +13,17 @@ from mmdeploy.utils import get_root_logger
 from .init_plugins import load_tensorrt_plugin
 
 
-def save(engine: trt.ICudaEngine, path: str) -> None:
+def save(engine: Any, path: str) -> None:
     """Serialize TensorRT engine to disk.
 
     Args:
-        engine (tensorrt.ICudaEngine): TensorRT engine to be serialized.
+        engine (Any): TensorRT engine to be serialized.
         path (str): The absolute disk path to write the engine.
     """
     with open(path, mode='wb') as f:
-        f.write(bytearray(engine.serialize()))
+        if isinstance(engine, trt.ICudaEngine):
+            engine = engine.serialize()
+        f.write(bytearray(engine))
 
 
 def load(path: str, allocator: Optional[Any] = None) -> trt.ICudaEngine:
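`save` now accepts both forms because, with the change below, `from_onnx` may hand it an already-serialized network instead of a live `trt.ICudaEngine`. A hedged usage sketch, assuming `save` is re-exported from `mmdeploy.backend.tensorrt` as in the current package layout:

```python
from mmdeploy.backend.tensorrt import save

# An already-serialized plan (placeholder bytes here) is written as-is;
# a live trt.ICudaEngine would be serialized first by the isinstance branch.
plan = b'\x00placeholder-plan'  # stands in for builder.build_serialized_network(...)
save(plan, '/tmp/end2end.engine')
```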
@@ -226,7 +228,10 @@ def from_onnx(onnx_model: Union[str, onnx.ModelProto],
             builder.int8_calibrator = config.int8_calibrator
 
     # create engine
-    engine = builder.build_engine(network, config)
+    if hasattr(builder, 'build_serialized_network'):
+        engine = builder.build_serialized_network(network, config)
+    else:
+        engine = builder.build_engine(network, config)
 
     assert engine is not None, 'Failed to create TensorRT engine'
 
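Preferring `build_serialized_network` avoids the deprecation warning TensorRT 8 emits for `build_engine`. Note the two branches return different types (`IHostMemory` versus `ICudaEngine`), which is exactly why `save` above gained its `isinstance` dispatch. A hedged sketch of normalizing either result back to an engine (the logger severity is an assumption, and this helper is not mmdeploy API):

```python
import tensorrt as trt


def as_engine(obj):
    """Return a trt.ICudaEngine for either build_engine or
    build_serialized_network output (sketch only)."""
    if isinstance(obj, trt.ICudaEngine):
        return obj
    runtime = trt.Runtime(trt.Logger(trt.Logger.WARNING))
    # IHostMemory supports the buffer protocol, so bytes(obj) is valid.
    return runtime.deserialize_cuda_engine(bytes(obj))
```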
@@ -601,7 +601,7 @@ class NCNNEnd2EndModel(End2EndModel):
         scores = out[:, :, 1:2]
         boxes = out[:, :, 2:6] * scales
         dets = torch.cat([boxes, scores], dim=2)
-        return dets, torch.tensor(labels, dtype=torch.int32)
+        return dets, labels.to(torch.int32)
 
 
 @__BACKEND_MODEL.register_module('sdk')
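`labels` is already a tensor here, and calling `torch.tensor()` on an existing tensor triggers the copy-construct `UserWarning`; casting with `.to()` is the quiet equivalent. For illustration:

```python
import torch

labels = torch.zeros(3)

# torch.tensor() on an existing tensor emits a UserWarning
# ("To copy construct from a tensor, it is recommended to use
#  sourceTensor.clone().detach() ..."):
noisy = torch.tensor(labels, dtype=torch.int32)

# .to() performs the same dtype cast without the warning:
quiet = labels.to(torch.int32)
```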
@@ -46,7 +46,7 @@ def focus__forward__ncnn(self, x):
 
     x = x.reshape(batch_size, c * h, 1, w)
     _b, _c, _h, _w = x.shape
-    g = _c // 2
+    g = torch.div(_c, 2, rounding_mode='floor')
     # fuse to ncnn's shufflechannel
     x = x.view(_b, g, 2, _h, _w)
     x = torch.transpose(x, 1, 2).contiguous()
@@ -55,13 +55,14 @@ def focus__forward__ncnn(self, x):
     x = x.reshape(_b, c * h * w, 1, 1)
 
     _b, _c, _h, _w = x.shape
-    g = _c // 2
+    g = torch.div(_c, 2, rounding_mode='floor')
     # fuse to ncnn's shufflechannel
     x = x.view(_b, g, 2, _h, _w)
     x = torch.transpose(x, 1, 2).contiguous()
     x = x.view(_b, -1, _h, _w)
 
-    x = x.reshape(_b, c * 4, h // 2, w // 2)
+    x = x.reshape(_b, c * 4, torch.div(h, 2, rounding_mode='floor'),
+                  torch.div(w, 2, rounding_mode='floor'))
 
     return self.conv(x)
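During ONNX export the `shape` entries are traced tensors, and tensor `//` goes through PyTorch's deprecated `__floordiv__` path, which warns on the torch releases these tests run against; `torch.div(..., rounding_mode='floor')` is the explicit replacement with identical results for the non-negative sizes involved. A small check:

```python
import torch

_c = torch.tensor(8)  # stands in for a traced shape entry

g_old = _c // 2                                  # warns on torch 1.8-1.12
g_new = torch.div(_c, 2, rounding_mode='floor')  # explicit, warning-free

assert int(g_old) == int(g_new) == 4
```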
@@ -198,8 +199,12 @@ def shift_window_msa__forward__default(self, query, hw_shape):
             [query,
              query.new_zeros(B, C, self.window_size, query.shape[-1])],
             dim=-2)
-    slice_h = (H + self.window_size - 1) // self.window_size * self.window_size
-    slice_w = (W + self.window_size - 1) // self.window_size * self.window_size
+    slice_h = torch.div(
+        (H + self.window_size - 1), self.window_size,
+        rounding_mode='floor') * self.window_size
+    slice_w = torch.div(
+        (W + self.window_size - 1), self.window_size,
+        rounding_mode='floor') * self.window_size
     query = query[:, :, :slice_h, :slice_w]
     query = query.permute(0, 2, 3, 1).contiguous()
     H_pad, W_pad = query.shape[1], query.shape[2]
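The `slice_h`/`slice_w` expressions round `H` and `W` up to the next multiple of `window_size` (ceil division written via floor division), so the same `rounding_mode='floor'` substitution applies. For example:

```python
import torch

H = torch.tensor(17)
window_size = 7

# ceil(H / window_size) * window_size, in floor-division form:
slice_h = torch.div(
    H + window_size - 1, window_size, rounding_mode='floor') * window_size
assert int(slice_h) == 21  # 17 rounded up to a multiple of 7
```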
@@ -328,6 +328,29 @@ class RewriterRegistry:
 
         return decorator
 
+    def remove_record(self, object: Any, filter_cb: Optional[Callable] = None):
+        """Remove record.
+
+        Args:
+            object (Any): The object to remove.
+            filter_cb (Callable): Check if the object needs to be removed.
+                Defaults to None.
+        """
+        key_to_pop = []
+        for key, records in self._rewrite_records.items():
+            for rec in records:
+                if rec['_object'] == object:
+                    if filter_cb is not None:
+                        if filter_cb(rec):
+                            continue
+                    key_to_pop.append((key, rec))
+
+        for key, rec in key_to_pop:
+            records = self._rewrite_records[key]
+            records.remove(rec)
+            if len(records) == 0:
+                self._rewrite_records.pop(key)
+
 
 class ContextCaller:
     """A callable object used in RewriteContext.
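`remove_record` deregisters every record attached to a rewriter object; records for which `filter_cb` returns True are kept. It is what lets the rewriter tests further down clean up after themselves. A hedged usage sketch (the rewriter and the `backend` record key are illustrative, not taken from the diff):

```python
from mmdeploy.core import FUNCTION_REWRITER


@FUNCTION_REWRITER.register_rewriter(func_name='torch.add')
def add_rewriter(ctx, x, y):
    return x - y


# Drop every record registered for add_rewriter ...
FUNCTION_REWRITER._registry.remove_record(add_rewriter)

# ... or keep a subset: records where filter_cb returns True survive.
FUNCTION_REWRITER._registry.remove_record(
    add_rewriter, filter_cb=lambda rec: rec.get('backend') == 'tensorrt')
```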
@@ -511,7 +511,7 @@ def multiclass_nms(boxes: Tensor,
 
 
 @FUNCTION_REWRITER.register_rewriter(
-    func_name='mmdeploy.mmcv.ops.nms.bbox_nms._multiclass_nms',
+    func_name='mmdeploy.mmcv.ops.nms._multiclass_nms',
     backend=Backend.COREML.value)
 def multiclass_nms__coreml(boxes: Tensor,
                            scores: Tensor,
@@ -574,8 +574,7 @@ def multiclass_nms__coreml(boxes: Tensor,
 
 
 @FUNCTION_REWRITER.register_rewriter(
-    func_name='mmdeploy.mmcv.ops.nms.bbox_nms._multiclass_nms',
-    ir=IR.TORCHSCRIPT)
+    func_name='mmdeploy.mmcv.ops.nms._multiclass_nms', ir=IR.TORCHSCRIPT)
 def multiclass_nms__torchscript(boxes: Tensor,
                                 scores: Tensor,
                                 max_output_boxes_per_class: int = 1000,
@@ -676,8 +675,7 @@ class AscendBatchNMSOp(torch.autograd.Function):
 
 
 @FUNCTION_REWRITER.register_rewriter(
-    func_name='mmdeploy.mmcv.ops.nms.bbox_nms._multiclass_nms',
-    backend='ascend')
+    func_name='mmdeploy.mmcv.ops.nms._multiclass_nms', backend='ascend')
 def multiclass_nms__ascend(boxes: Tensor,
                            scores: Tensor,
                            max_output_boxes_per_class: int = 1000,
@@ -14,6 +14,11 @@ from mmengine import Config
 from mmengine.model import BaseModel
 from torch import nn
 
+try:
+    from torch.testing import assert_close as torch_assert_close
+except Exception:
+    from torch.testing import assert_allclose as torch_assert_close
+
 import mmdeploy.codebase  # noqa: F401,F403
 from mmdeploy.core import RewriterContext, patch_model
 from mmdeploy.utils import (IR, Backend, get_backend, get_dynamic_axes,
@@ -293,8 +298,7 @@ def assert_allclose(expected: List[Union[torch.Tensor, np.ndarray]],
         if isinstance(actual[i], (list, np.ndarray)):
             actual[i] = torch.tensor(actual[i])
         try:
-            torch.testing.assert_allclose(
-                actual[i], expected[i], rtol=1e-03, atol=1e-05)
+            torch_assert_close(actual[i], expected[i], rtol=1e-03, atol=1e-05)
         except AssertionError as error:
             if tolerate_small_mismatch:
                 assert '(0.00%)' in str(error), str(error)
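This try/except alias (repeated in the other test files below) targets `torch.testing.assert_close`, the public replacement for `assert_allclose`, which PyTorch deprecated in 1.12; older installs fall back to the legacy name. Both spellings accept the keyword arguments used throughout these tests:

```python
import torch

try:  # PyTorch >= 1.9 ships assert_close
    from torch.testing import assert_close as torch_assert_close
except Exception:  # older PyTorch: fall back to the deprecated name
    from torch.testing import assert_allclose as torch_assert_close

torch_assert_close(torch.ones(2), torch.ones(2), rtol=1e-03, atol=1e-05)
```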
@@ -9,6 +9,12 @@ import mmengine
 import numpy as np
 import pytest
 import torch
 
+try:
+    from torch.testing import assert_close as torch_assert_close
+except Exception:
+    from torch.testing import assert_allclose as torch_assert_close
+
 from mmengine import Config
 from mmengine.config import ConfigDict
@@ -237,7 +243,7 @@ def test__anchorgenerator__single_level_grid_priors():
     # test forward
     with RewriterContext({}, backend_type):
         wrap_output = wrapped_func(x)
-    torch.testing.assert_allclose(output, wrap_output)
+    torch_assert_close(output, wrap_output)
 
     onnx_prefix = tempfile.NamedTemporaryFile().name
@@ -341,23 +347,6 @@ def get_ssd_head_model():
     return model
 
 
-def get_fcos_head_model():
-    """FCOS Head Config."""
-    test_cfg = Config(
-        dict(
-            deploy_nms_pre=0,
-            min_bbox_size=0,
-            score_thr=0.05,
-            nms=dict(type='nms', iou_threshold=0.5),
-            max_per_img=100))
-
-    from mmdet.models.dense_heads import FCOSHead
-    model = FCOSHead(num_classes=4, in_channels=1, test_cfg=test_cfg)
-
-    model.requires_grad_(False)
-    return model
-
-
 def get_focus_backbone_model():
     """Backbone Focus Config."""
     from mmdet.models.backbones.csp_darknet import Focus
@@ -412,10 +401,8 @@ def get_reppoints_head_model():
 
 def get_detrhead_model():
     """DETR head Config."""
-    from mmdet.models import build_head
-    from mmdet.utils import register_all_modules
-    register_all_modules()
-    model = build_head(
+    from mmdet.registry import MODELS
+    model = MODELS.build(
         dict(
             type='DETRHead',
             num_classes=4,
@@ -431,8 +418,7 @@ def get_detrhead_model():
                     dict(
                         type='MultiheadAttention',
                         embed_dims=4,
-                        num_heads=1,
-                        dropout=0.1)
+                        num_heads=1)
                 ],
                 ffn_cfgs=dict(
                     type='FFN',
@@ -442,8 +428,6 @@ def get_detrhead_model():
                     ffn_drop=0.,
                     act_cfg=dict(type='ReLU', inplace=True),
                 ),
-                feedforward_channels=32,
-                ffn_dropout=0.1,
                 operation_order=('self_attn', 'norm', 'ffn', 'norm'))),
         decoder=dict(
             type='DetrTransformerDecoder',
@@ -454,8 +438,7 @@ def get_detrhead_model():
                     attn_cfgs=dict(
                         type='MultiheadAttention',
                         embed_dims=4,
-                        num_heads=1,
-                        dropout=0.1),
+                        num_heads=1),
                     ffn_cfgs=dict(
                         type='FFN',
                         embed_dims=4,
@@ -465,7 +448,6 @@ def get_detrhead_model():
                     act_cfg=dict(type='ReLU', inplace=True),
                 ),
                 feedforward_channels=32,
-                ffn_dropout=0.1,
                 operation_order=('self_attn', 'norm', 'cross_attn',
                                  'norm', 'ffn', 'norm')),
         )),
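These four config edits silence mmcv's transformer deprecation warnings: `MultiheadAttention` warns that its top-level `dropout` argument is deprecated, and the transformer layer warns about legacy top-level `feedforward_channels`/`ffn_dropout` now that FFN options live in `ffn_cfgs`. If the dropout rates were still wanted, the assumed modern spelling would be roughly:

```python
# Hypothetical equivalents of the removed keys in the newer config schema:
attn_cfgs = dict(
    type='MultiheadAttention', embed_dims=4, num_heads=1,
    attn_drop=0.1)  # replaces the deprecated dropout=0.1
ffn_cfgs = dict(
    type='FFN', embed_dims=4, feedforward_channels=32,
    ffn_drop=0.1)   # replaces the deprecated top-level ffn_dropout=0.1
```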
@@ -536,7 +518,7 @@ def test_focus_forward(backend_type):
     for model_output, rewrite_output in zip(model_outputs[0], rewrite_outputs):
         model_output = model_output.squeeze()
         rewrite_output = rewrite_output.squeeze()
-        torch.testing.assert_allclose(
+        torch_assert_close(
             model_output, rewrite_output, rtol=1e-03, atol=1e-05)
 
@@ -578,77 +560,6 @@ def test_l2norm_forward(backend_type):
         model_output[0], rewrite_output, rtol=1e-03, atol=1e-05)
 
 
-def test_predict_by_feat_of_fcos_head_ncnn():
-    backend_type = Backend.NCNN
-    check_backend(backend_type)
-    fcos_head = get_fcos_head_model()
-    fcos_head.cpu().eval()
-    s = 128
-    batch_img_metas = [{
-        'scale_factor': np.ones(4),
-        'pad_shape': (s, s, 3),
-        'img_shape': (s, s, 3)
-    }]
-
-    output_names = ['detection_output']
-    deploy_cfg = Config(
-        dict(
-            backend_config=dict(type=backend_type.value),
-            onnx_config=dict(output_names=output_names, input_shape=None),
-            codebase_config=dict(
-                type='mmdet',
-                task='ObjectDetection',
-                model_type='ncnn_end2end',
-                post_processing=dict(
-                    score_threshold=0.05,
-                    iou_threshold=0.5,
-                    max_output_boxes_per_class=200,
-                    pre_top_k=5000,
-                    keep_top_k=100,
-                    background_label_id=-1,
-                ))))
-
-    # the cls_score's size: (1, 36, 32, 32), (1, 36, 16, 16),
-    # (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2).
-    # the bboxes's size: (1, 36, 32, 32), (1, 36, 16, 16),
-    # (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2)
-    seed_everything(1234)
-    cls_score = [
-        torch.rand(1, fcos_head.num_classes, pow(2, i), pow(2, i))
-        for i in range(5, 0, -1)
-    ]
-    seed_everything(5678)
-    bboxes = [torch.rand(1, 4, pow(2, i), pow(2, i)) for i in range(5, 0, -1)]
-
-    seed_everything(9101)
-    centernesses = [
-        torch.rand(1, 1, pow(2, i), pow(2, i)) for i in range(5, 0, -1)
-    ]
-
-    # to get outputs of onnx model after rewrite
-    batch_img_metas[0]['img_shape'] = torch.Tensor([s, s])
-    wrapped_model = WrapModel(
-        fcos_head,
-        'predict_by_feat',
-        batch_img_metas=batch_img_metas,
-        with_nms=True)
-    rewrite_inputs = {
-        'cls_scores': cls_score,
-        'bbox_preds': bboxes,
-        'centernesses': centernesses
-    }
-    rewrite_outputs, is_backend_output = get_rewrite_outputs(
-        wrapped_model=wrapped_model,
-        model_inputs=rewrite_inputs,
-        deploy_cfg=deploy_cfg)
-
-    # output should be of shape [1, N, 6]
-    if is_backend_output:
-        assert rewrite_outputs[0].shape[-1] == 6
-    else:
-        assert rewrite_outputs.shape[-1] == 6
-
-
 @pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME, Backend.NCNN])
 def test_predict_by_feat_of_rpn_head(backend_type: Backend):
     check_backend(backend_type)
@@ -57,6 +57,10 @@ class TestVoxelDetectionModel:
             deploy_cfg=deploy_cfg,
             model_cfg=model_cfg)
 
+    @classmethod
+    def teardown_class(cls):
+        cls.wrapper.recover()
+
     @pytest.mark.skipif(
         reason='Only support GPU test',
         condition=not torch.cuda.is_available())
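Without a `teardown_class`, the backend wrapper patched in during `setup_class` stays active and can leak into unrelated tests; `recover()` restores the original class. A minimal sketch of the pairing, assuming mmdeploy's `SwitchBackendWrapper` test utility (the stub outputs are illustrative):

```python
from mmdeploy.backend.onnxruntime import ORTWrapper
from mmdeploy.utils.test import SwitchBackendWrapper


class TestSomeBackendModel:

    @classmethod
    def setup_class(cls):
        # Replace the real ORT wrapper with a stub returning fixed outputs.
        cls.wrapper = SwitchBackendWrapper(ORTWrapper)
        cls.wrapper.set(outputs={'dets': None})  # illustrative stub outputs

    @classmethod
    def teardown_class(cls):
        cls.wrapper.recover()  # undo the patch so later tests see the real class
```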
@@ -1,5 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import numpy as np
 import pytest
 import torch
 
@@ -93,30 +92,6 @@ def get_cross_resolution_weighting_model():
     return model
 
 
-@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
-def test_cross_resolution_weighting_forward(backend_type: Backend):
-    check_backend(backend_type, True)
-    model = get_cross_resolution_weighting_model()
-    model.cpu().eval()
-    imgs = torch.rand(1, 16, 16, 16)
-    deploy_cfg = generate_mmpose_deploy_config(backend_type.value)
-    rewrite_inputs = {'x': imgs}
-    model_outputs = model.forward(imgs)
-    wrapped_model = WrapModel(model, 'forward')
-    rewrite_outputs, is_backend_output = get_rewrite_outputs(
-        wrapped_model=wrapped_model,
-        model_inputs=rewrite_inputs,
-        deploy_cfg=deploy_cfg)
-    if isinstance(rewrite_outputs, dict):
-        rewrite_outputs = rewrite_outputs['output']
-    for model_output, rewrite_output in zip(model_outputs, rewrite_outputs):
-        model_output = model_output.cpu().numpy()
-        if isinstance(rewrite_output, torch.Tensor):
-            rewrite_output = rewrite_output.detach().cpu().numpy()
-        assert np.allclose(
-            model_output, rewrite_output, rtol=1e-03, atol=1e-05)
-
-
 @pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
 def test_estimator_forward(backend_type: Backend):
     check_backend(backend_type, True)
@@ -1,6 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
 
+try:
+    from torch.testing import assert_close as torch_assert_close
+except Exception:
+    from torch.testing import assert_allclose as torch_assert_close
+
 from mmdeploy.core import FUNCTION_REWRITER, RewriterContext
 from mmdeploy.core.rewriters.function_rewriter import FunctionRewriter
 from mmdeploy.core.rewriters.rewriter_utils import collect_env
@ -26,19 +31,19 @@ def test_function_rewriter():
|
|||
with RewriterContext(cfg, backend='tensorrt'):
|
||||
result = torch.add(x, y)
|
||||
# replace add with sub
|
||||
torch.testing.assert_allclose(result, x - y)
|
||||
torch_assert_close(result, x - y)
|
||||
result = torch.mul(x, y)
|
||||
# replace add with sub
|
||||
torch.testing.assert_allclose(result, x - y)
|
||||
torch_assert_close(result, x - y)
|
||||
|
||||
result = torch.add(x, y)
|
||||
# recovery origin function
|
||||
torch.testing.assert_allclose(result, x + y)
|
||||
torch_assert_close(result, x + y)
|
||||
|
||||
with RewriterContext(cfg):
|
||||
result = torch.add(x, y)
|
||||
# replace should not happen with wrong backend
|
||||
torch.testing.assert_allclose(result, x + y)
|
||||
torch_assert_close(result, x + y)
|
||||
|
||||
# test different config
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
|
@@ -49,16 +54,16 @@ def test_function_rewriter():
     with RewriterContext(cfg, backend='tensorrt'):
         result = x.add(y)
         # replace add with multi
-        torch.testing.assert_allclose(result, x * y)
+        torch_assert_close(result, x * y)
 
     result = x.add(y)
     # recovery origin function
-    torch.testing.assert_allclose(result, x + y)
+    torch_assert_close(result, x + y)
 
     with RewriterContext(cfg):
         result = x.add(y)
         # replace add with multi
-        torch.testing.assert_allclose(result, x * y)
+        torch_assert_close(result, x * y)
 
     # test origin_func
     @FUNCTION_REWRITER.register_rewriter(
@@ -70,11 +75,15 @@ def test_function_rewriter():
     with RewriterContext(cfg):
         result = torch.add(x, y)
         # replace with origin + 1
-        torch.testing.assert_allclose(result, x + y + 1)
+        torch_assert_close(result, x + y + 1)
 
     # remove torch.add
     del FUNCTION_REWRITER._origin_functions[-1]
-    torch.testing.assert_allclose(torch.add(x, y), x + y)
+    torch_assert_close(torch.add(x, y), x + y)
+
+    FUNCTION_REWRITER._registry.remove_record(sub_func)
+    FUNCTION_REWRITER._registry.remove_record(mul_func_class)
+    FUNCTION_REWRITER._registry.remove_record(origin_add_func)
 
 
 def test_rewrite_empty_function():
@@ -1,6 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
 
+try:
+    from torch.testing import assert_close as torch_assert_close
+except Exception:
+    from torch.testing import assert_allclose as torch_assert_close
 from mmdeploy.core import MODULE_REWRITER, patch_model
 
 
@@ -29,7 +33,7 @@ def test_module_rewriter():
     rewritten_model = patch_model(model, cfg=cfg, backend='tensorrt')
     rewritten_bottle_nect = rewritten_model.layer1[0]
     rewritten_result = rewritten_bottle_nect(x)
-    torch.testing.assert_allclose(rewritten_result, result * 2)
+    torch_assert_close(rewritten_result, result * 2)
 
     # wrong backend should not be rewritten
     model = resnet50().eval()
@@ -38,7 +42,7 @@ def test_module_rewriter():
     rewritten_model = patch_model(model, cfg=cfg)
     rewritten_bottle_nect = rewritten_model.layer1[0]
     rewritten_result = rewritten_bottle_nect(x)
-    torch.testing.assert_allclose(rewritten_result, result)
+    torch_assert_close(rewritten_result, result)
 
 
 def test_pass_redundant_args_to_model():
@@ -769,17 +769,13 @@ def test_gather(backend,
     assert importlib.util.find_spec('onnxruntime') is not None, 'onnxruntime \
         not installed.'
 
-    import numpy as np
-    import onnxruntime
-    session = onnxruntime.InferenceSession(gather_model.SerializeToString())
-    model_outputs = session.run(
-        output_names,
-        dict(
-            zip(input_names, [
-                np.array(data, dtype=np.float32),
-                np.array(indice[0], dtype=np.int64)
-            ])))
-    model_outputs = [model_output for model_output in model_outputs]
+    from mmdeploy.backend.onnxruntime import ORTWrapper
+    ort_model = ORTWrapper(
+        gather_model.SerializeToString(),
+        device='cpu',
+        output_names=output_names)
+    model_outputs = ort_model(dict(zip(input_names, [data, indice[0]])))
+    model_outputs = ort_model.output_to_list(model_outputs)
 
     ncnn_outputs = ncnn_model(
         dict(zip(input_names, [data.float(), indice.float()])))
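Routing the reference run through `ORTWrapper` drops the manual numpy conversions (the wrapper accepts torch tensors directly) and exercises mmdeploy's own backend wrapper instead of a raw `onnxruntime.InferenceSession`. A hedged sketch of the call pattern (`onnx_bytes` and the tensor names are placeholders):

```python
import torch

from mmdeploy.backend.onnxruntime import ORTWrapper

# onnx_bytes would come from model.SerializeToString() or an .onnx file path.
ort_model = ORTWrapper(onnx_bytes, device='cpu', output_names=['output'])

# Inputs are torch tensors keyed by graph input name; outputs come back as a
# name->tensor dict that output_to_list() flattens in output_names order.
outputs = ort_model({'input': torch.rand(1, 3, 4)})
output_list = ort_model.output_to_list(outputs)
```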