* cherry-pick PR1352 to master * fix test_ops with teardown and skip * remove useless line * fix lint Co-authored-by: q.yao <yaoqian@sensetime.com>pull/1589/head
parent
fc98472e9c
commit
f21dc4e7d3
csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn
mmdeploy/mmcv/ops
tests/test_ops
|
@ -39,6 +39,9 @@ void fuse_rewrite_gather(onnx::GraphProto* mutable_graph,
|
|||
if (gather->op_type() != "Gather") {
|
||||
continue;
|
||||
}
|
||||
if (weights.find(std::string(gather->input(1))) == weights.end()) {
|
||||
continue;
|
||||
}
|
||||
auto indices = get_node_attr_from_input_ai(weights[gather->input(1)]);
|
||||
if (indices.size() != 1) {
|
||||
continue;
|
||||
|
|
|
@ -42,7 +42,7 @@ def roi_align_default(ctx, g, input: Tensor, rois: Tensor,
|
|||
MMCVRoiAlign op for onnx.
|
||||
"""
|
||||
backend = get_backend(ctx.cfg)
|
||||
if backend == Backend.PPLNN:
|
||||
if backend == Backend.PPLNN or backend == Backend.TENSORRT:
|
||||
domain = 'mmcv'
|
||||
return g.op(
|
||||
f'{domain}::MMCVRoiAlign',
|
||||
|
|
|
@ -16,6 +16,15 @@ TEST_TENSORRT = TestTensorRTExporter()
|
|||
TEST_NCNN = TestNCNNExporter()
|
||||
|
||||
|
||||
@pytest.fixture(scope='function')
|
||||
def disable_cudnn():
|
||||
cudnn_enable = torch.backends.cudnn.enabled
|
||||
torch.backends.cudnn.enabled = False
|
||||
|
||||
yield
|
||||
torch.backends.cudnn.enabled = cudnn_enable
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend', [TEST_TENSORRT])
|
||||
@pytest.mark.parametrize('pool_h,pool_w,spatial_scale,sampling_ratio',
|
||||
[(2, 2, 1.0, 2), (4, 4, 2.0, 4)])
|
||||
|
@ -264,7 +273,8 @@ def test_deform_conv(backend,
|
|||
@pytest.mark.parametrize('dynamic_export', [True, False])
|
||||
@pytest.mark.parametrize('fp16_mode', [True, False])
|
||||
@pytest.mark.parametrize('n, c, h, w', [(2, 3, 10, 10)])
|
||||
def test_instance_norm(backend,
|
||||
def test_instance_norm(disable_cudnn,
|
||||
backend,
|
||||
dynamic_export,
|
||||
fp16_mode,
|
||||
n,
|
||||
|
@ -293,8 +303,7 @@ def test_instance_norm(backend,
|
|||
else:
|
||||
dynamic_axes = None
|
||||
|
||||
norm = nn.InstanceNorm2d(c, affine=True)
|
||||
wrapped_model = WrapFunction(norm).eval()
|
||||
wrapped_model = nn.InstanceNorm2d(c, affine=True).eval().cuda()
|
||||
|
||||
with RewriterContext(cfg={}, backend=backend.backend_name, opset=11):
|
||||
backend.run_and_validate(
|
||||
|
@ -765,11 +774,11 @@ def test_gather(backend,
|
|||
# so the ncnn_outputs has 2 dimensions, not 1.
|
||||
import importlib
|
||||
|
||||
import onnxruntime
|
||||
assert importlib.util.find_spec('onnxruntime') is not None, \
|
||||
'onnxruntime not installed.'
|
||||
if importlib.util.find_spec('onnxruntime') is None:
|
||||
pytest.skip('onnxruntime not installed.')
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime
|
||||
session = onnxruntime.InferenceSession(gather_model.SerializeToString())
|
||||
model_outputs = session.run(
|
||||
output_names,
|
||||
|
@ -779,7 +788,6 @@ def test_gather(backend,
|
|||
np.array(indice[0], dtype=np.int64)
|
||||
])))
|
||||
model_outputs = [model_output for model_output in model_outputs]
|
||||
|
||||
ncnn_outputs = ncnn_model(
|
||||
dict(zip(input_names, [data.float(), indice.float()])))
|
||||
ncnn_outputs = [ncnn_outputs[name] for name in output_names]
|
||||
|
|
Loading…
Reference in New Issue