mirror of https://github.com/open-mmlab/mmdeploy.git
commit 8b5a6dd7eb (parent ed2d7680f7)
@@ -11,6 +11,9 @@ void fuse_rewrite_gather(onnx::GraphProto* mutable_graph,
     if (gather->op_type() != "Gather") {
       continue;
     }
+    if (weights.find(std::string(gather->input(1))) == weights.end()) {
+      continue;
+    }
     auto indices = get_node_attr_from_input_ai(weights[gather->input(1)]);
     if (indices.size() != 1) {
       continue;
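The added lines make the fusion skip any Gather node whose indices input is not a constant initializer, instead of indexing `weights` unconditionally. Below is a minimal Python sketch of the same guard using the `onnx` package; the helper name `get_constant_indices` is illustrative and not part of mmdeploy.

import onnx
from onnx import numpy_helper


def get_constant_indices(model: onnx.ModelProto, node: onnx.NodeProto):
    """Return the constant indices of a Gather node, or None when the
    indices input has no matching initializer (i.e. it is computed at
    runtime and the rewrite should be skipped)."""
    if node.op_type != 'Gather':
        return None
    initializers = {init.name: init for init in model.graph.initializer}
    # Same intent as the added C++ check: never look up weights[...] unless
    # the entry actually exists.
    if node.input[1] not in initializers:
        return None
    return numpy_helper.to_array(initializers[node.input[1]])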
@@ -42,7 +42,7 @@ def roi_align_default(ctx, g, input: Tensor, rois: Tensor,
     MMCVRoiAlign op for onnx.
     """
     backend = get_backend(ctx.cfg)
-    if backend == Backend.PPLNN:
+    if backend == Backend.PPLNN or backend == Backend.TENSORRT:
         domain = 'mmcv'
         return g.op(
             f'{domain}::MMCVRoiAlign',
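With this change the exporter emits the custom `mmcv::MMCVRoiAlign` op for TensorRT as well as PPLNN. A minimal sketch of that dispatch follows; the `Backend` enum and `choose_roi_align_domain` helper are illustrative stand-ins rather than mmdeploy's actual API, and using an empty string for the default ONNX domain is an assumption of the sketch.

import enum


class Backend(enum.Enum):
    PPLNN = 'pplnn'
    TENSORRT = 'tensorrt'
    DEFAULT = 'default'


def choose_roi_align_domain(backend: Backend) -> str:
    # Backends that ship a custom MMCVRoiAlign plugin get the op emitted
    # under the 'mmcv' domain; every other backend keeps the default export
    # path (empty string stands for the standard ONNX domain here).
    if backend in (Backend.PPLNN, Backend.TENSORRT):
        return 'mmcv'
    return ''


assert choose_roi_align_domain(Backend.TENSORRT) == 'mmcv'
assert choose_roi_align_domain(Backend.DEFAULT) == ''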
@@ -293,9 +293,10 @@ def test_instance_norm(backend,
     else:
         dynamic_axes = None
 
-    norm = nn.InstanceNorm2d(c, affine=True)
-    wrapped_model = WrapFunction(norm).eval()
+    wrapped_model = nn.InstanceNorm2d(c, affine=True).eval().cuda()
 
+    cudnn_enable = torch.backends.cudnn.enabled
+    torch.backends.cudnn.enabled = False
     with RewriterContext(cfg={}, backend=backend.backend_name, opset=11):
         backend.run_and_validate(
             wrapped_model, [input],
@@ -304,6 +305,7 @@ def test_instance_norm(backend,
             dynamic_axes=dynamic_axes,
             output_names=['output'],
             save_dir=save_dir)
+    torch.backends.cudnn.enabled = cudnn_enable
 
 
 @pytest.mark.parametrize('backend', [TEST_TENSORRT])
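Together these two hunks make test_instance_norm export the `nn.InstanceNorm2d` module directly on CUDA with cuDNN disabled, then restore the original flag afterwards. A minimal sketch of that save/disable/restore pattern, written as a context manager for clarity rather than the inline toggling the test uses:

import contextlib

import torch


@contextlib.contextmanager
def cudnn_disabled():
    # Remember the current flag, disable cuDNN for the exported run, and
    # restore the original value even if validation raises.
    previous = torch.backends.cudnn.enabled
    torch.backends.cudnn.enabled = False
    try:
        yield
    finally:
        torch.backends.cudnn.enabled = previous

Wrapping the export and validation in `with cudnn_disabled():` is equivalent to the inline flag handling shown in the two hunks above.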
@@ -765,11 +767,11 @@ def test_gather(backend,
     # so the ncnn_outputs has 2 dimensions, not 1.
     import importlib
 
-    import onnxruntime
     assert importlib.util.find_spec('onnxruntime') is not None, 'onnxruntime \
         not installed.'
 
     import numpy as np
+    import onnxruntime
     session = onnxruntime.InferenceSession(gather_model.SerializeToString())
     model_outputs = session.run(
         output_names,
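Moving `import onnxruntime` below the `find_spec` assertion means a missing package fails with the intended assertion message instead of an ImportError at the import line. A minimal, self-contained sketch of that ordering, independent of the test's fixtures such as `gather_model`:

import importlib.util

assert importlib.util.find_spec('onnxruntime') is not None, \
    'onnxruntime not installed.'

import onnxruntime  # noqa: E402

# Sanity check that the runtime loaded; a real test would build an
# InferenceSession from a serialized model as in the hunk above.
print(onnxruntime.get_available_providers())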