[Fix] fix test ops (#1352)

* fix test ops

* fix name
This commit is contained in:
q.yao 2022-11-14 11:19:57 +08:00 committed by GitHub
parent ed2d7680f7
commit 8b5a6dd7eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 9 additions and 4 deletions

View File

@ -11,6 +11,9 @@ void fuse_rewrite_gather(onnx::GraphProto* mutable_graph,
if (gather->op_type() != "Gather") {
continue;
}
if (weights.find(std::string(gather->input(1))) == weights.end()) {
continue;
}
auto indices = get_node_attr_from_input_ai(weights[gather->input(1)]);
if (indices.size() != 1) {
continue;

View File

@ -42,7 +42,7 @@ def roi_align_default(ctx, g, input: Tensor, rois: Tensor,
MMCVRoiAlign op for onnx.
"""
backend = get_backend(ctx.cfg)
if backend == Backend.PPLNN:
if backend == Backend.PPLNN or backend == Backend.TENSORRT:
domain = 'mmcv'
return g.op(
f'{domain}::MMCVRoiAlign',

View File

@ -293,9 +293,10 @@ def test_instance_norm(backend,
else:
dynamic_axes = None
norm = nn.InstanceNorm2d(c, affine=True)
wrapped_model = WrapFunction(norm).eval()
wrapped_model = nn.InstanceNorm2d(c, affine=True).eval().cuda()
cudnn_enable = torch.backends.cudnn.enabled
torch.backends.cudnn.enabled = False
with RewriterContext(cfg={}, backend=backend.backend_name, opset=11):
backend.run_and_validate(
wrapped_model, [input],
@ -304,6 +305,7 @@ def test_instance_norm(backend,
dynamic_axes=dynamic_axes,
output_names=['output'],
save_dir=save_dir)
torch.backends.cudnn.enabled = cudnn_enable
@pytest.mark.parametrize('backend', [TEST_TENSORRT])
@ -765,11 +767,11 @@ def test_gather(backend,
# so the ncnn_outputs has 2 dimensions, not 1.
import importlib
import onnxruntime
assert importlib.util.find_spec('onnxruntime') is not None, 'onnxruntime \
not installed.'
import numpy as np
import onnxruntime
session = onnxruntime.InferenceSession(gather_model.SerializeToString())
model_outputs = session.run(
output_names,