【Fix】Fix ppl bug about grid sample (#325)
* fix ppl problems * fix roialign * fix grid_sampler bug * fix grid sampler * fix config * fix testpull/1/head
parent
199253ce94
commit
9d4d52078b
|
@ -1,3 +1,4 @@
|
||||||
_base_ = ['./base_static.py']
|
_base_ = ['./base_static.py']
|
||||||
|
|
||||||
onnx_config = dict(output_names=['dets', 'labels', 'masks'])
|
onnx_config = dict(output_names=['dets', 'labels', 'masks'])
|
||||||
|
codebase_config = dict(post_processing=dict(export_postprocess_mask=False))
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
from torch.onnx.symbolic_helper import parse_args
|
from torch.onnx.symbolic_helper import parse_args
|
||||||
|
|
||||||
from mmdeploy.core import SYMBOLIC_REWRITER
|
from mmdeploy.core import SYMBOLIC_REWRITER
|
||||||
|
from mmdeploy.utils import Backend, get_backend
|
||||||
|
|
||||||
|
|
||||||
@parse_args('v', 'v', 'i', 'i', 'i')
|
@parse_args('v', 'v', 'i', 'i', 'i')
|
||||||
|
@ -26,10 +27,36 @@ def grid_sampler(g,
|
||||||
align_corners_i=align_corners)
|
align_corners_i=align_corners)
|
||||||
|
|
||||||
|
|
||||||
|
@parse_args('v', 'v', 'i', 'i', 'i')
|
||||||
|
def grid_sampler_ppl(g,
|
||||||
|
input,
|
||||||
|
grid,
|
||||||
|
interpolation_mode,
|
||||||
|
padding_mode,
|
||||||
|
align_corners=False):
|
||||||
|
"""Symbolic function for `grid_sampler`.
|
||||||
|
|
||||||
|
PyTorch does not support export grid_sampler to ONNX by default. We add the
|
||||||
|
support here. `grid_sampler` will be exported as ONNX node
|
||||||
|
'mmdeploy::grid_sampler'
|
||||||
|
"""
|
||||||
|
return g.op(
|
||||||
|
'mmcv::grid_sampler',
|
||||||
|
input,
|
||||||
|
grid,
|
||||||
|
interpolation_mode_i=interpolation_mode,
|
||||||
|
padding_mode_i=padding_mode,
|
||||||
|
align_corners_i=align_corners)
|
||||||
|
|
||||||
|
|
||||||
@SYMBOLIC_REWRITER.register_symbolic('grid_sampler', is_pytorch=True)
|
@SYMBOLIC_REWRITER.register_symbolic('grid_sampler', is_pytorch=True)
|
||||||
def grid_sampler__default(ctx, *args):
|
def grid_sampler__default(ctx, *args):
|
||||||
"""Register default symbolic function for `grid_sampler`.
|
"""Register default symbolic function for `grid_sampler`.
|
||||||
|
|
||||||
Add support to grid_sample to ONNX.
|
Add support to grid_sample to ONNX.
|
||||||
"""
|
"""
|
||||||
|
backend = get_backend(ctx.cfg)
|
||||||
|
if backend == Backend.PPLNN:
|
||||||
|
return grid_sampler_ppl(*args)
|
||||||
|
else:
|
||||||
return grid_sampler(*args)
|
return grid_sampler(*args)
|
||||||
|
|
|
@ -4,6 +4,7 @@ import tempfile
|
||||||
import onnx
|
import onnx
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
from mmcv import Config
|
||||||
|
|
||||||
from mmdeploy.core import RewriterContext
|
from mmdeploy.core import RewriterContext
|
||||||
|
|
||||||
|
@ -12,7 +13,10 @@ onnx_file = tempfile.NamedTemporaryFile(suffix='onnx').name
|
||||||
|
|
||||||
@pytest.fixture(autouse=True, scope='module')
|
@pytest.fixture(autouse=True, scope='module')
|
||||||
def prepare_symbolics():
|
def prepare_symbolics():
|
||||||
context = RewriterContext({}, 'tensorrt', opset=11)
|
context = RewriterContext(
|
||||||
|
Config({'backend_config': {
|
||||||
|
'type': 'tensorrt'
|
||||||
|
}}), 'tensorrt', opset=11)
|
||||||
context.enter()
|
context.enter()
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
Loading…
Reference in New Issue