mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
parent
ed2d7680f7
commit
8b5a6dd7eb
@ -11,6 +11,9 @@ void fuse_rewrite_gather(onnx::GraphProto* mutable_graph,
|
||||
if (gather->op_type() != "Gather") {
|
||||
continue;
|
||||
}
|
||||
if (weights.find(std::string(gather->input(1))) == weights.end()) {
|
||||
continue;
|
||||
}
|
||||
auto indices = get_node_attr_from_input_ai(weights[gather->input(1)]);
|
||||
if (indices.size() != 1) {
|
||||
continue;
|
||||
|
@ -42,7 +42,7 @@ def roi_align_default(ctx, g, input: Tensor, rois: Tensor,
|
||||
MMCVRoiAlign op for onnx.
|
||||
"""
|
||||
backend = get_backend(ctx.cfg)
|
||||
if backend == Backend.PPLNN:
|
||||
if backend == Backend.PPLNN or backend == Backend.TENSORRT:
|
||||
domain = 'mmcv'
|
||||
return g.op(
|
||||
f'{domain}::MMCVRoiAlign',
|
||||
|
@ -293,9 +293,10 @@ def test_instance_norm(backend,
|
||||
else:
|
||||
dynamic_axes = None
|
||||
|
||||
norm = nn.InstanceNorm2d(c, affine=True)
|
||||
wrapped_model = WrapFunction(norm).eval()
|
||||
wrapped_model = nn.InstanceNorm2d(c, affine=True).eval().cuda()
|
||||
|
||||
cudnn_enable = torch.backends.cudnn.enabled
|
||||
torch.backends.cudnn.enabled = False
|
||||
with RewriterContext(cfg={}, backend=backend.backend_name, opset=11):
|
||||
backend.run_and_validate(
|
||||
wrapped_model, [input],
|
||||
@ -304,6 +305,7 @@ def test_instance_norm(backend,
|
||||
dynamic_axes=dynamic_axes,
|
||||
output_names=['output'],
|
||||
save_dir=save_dir)
|
||||
torch.backends.cudnn.enabled = cudnn_enable
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend', [TEST_TENSORRT])
|
||||
@ -765,11 +767,11 @@ def test_gather(backend,
|
||||
# so the ncnn_outputs has 2 dimensions, not 1.
|
||||
import importlib
|
||||
|
||||
import onnxruntime
|
||||
assert importlib.util.find_spec('onnxruntime') is not None, 'onnxruntime \
|
||||
not installed.'
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime
|
||||
session = onnxruntime.InferenceSession(gather_model.SerializeToString())
|
||||
model_outputs = session.run(
|
||||
output_names,
|
||||
|
Loading…
x
Reference in New Issue
Block a user