【Fix】Fix ppl bug about grid sample (#325)
* fix ppl problems * fix roialign * fix grid_sampler bug * fix grid sampler * fix config * fix testpull/1/head
parent
199253ce94
commit
9d4d52078b
|
@ -1,3 +1,4 @@
|
|||
_base_ = ['./base_static.py']
|
||||
|
||||
onnx_config = dict(output_names=['dets', 'labels', 'masks'])
|
||||
codebase_config = dict(post_processing=dict(export_postprocess_mask=False))
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
from torch.onnx.symbolic_helper import parse_args
|
||||
|
||||
from mmdeploy.core import SYMBOLIC_REWRITER
|
||||
from mmdeploy.utils import Backend, get_backend
|
||||
|
||||
|
||||
@parse_args('v', 'v', 'i', 'i', 'i')
|
||||
|
@ -26,10 +27,36 @@ def grid_sampler(g,
|
|||
align_corners_i=align_corners)
|
||||
|
||||
|
||||
@parse_args('v', 'v', 'i', 'i', 'i')
|
||||
def grid_sampler_ppl(g,
|
||||
input,
|
||||
grid,
|
||||
interpolation_mode,
|
||||
padding_mode,
|
||||
align_corners=False):
|
||||
"""Symbolic function for `grid_sampler`.
|
||||
|
||||
PyTorch does not support export grid_sampler to ONNX by default. We add the
|
||||
support here. `grid_sampler` will be exported as ONNX node
|
||||
'mmdeploy::grid_sampler'
|
||||
"""
|
||||
return g.op(
|
||||
'mmcv::grid_sampler',
|
||||
input,
|
||||
grid,
|
||||
interpolation_mode_i=interpolation_mode,
|
||||
padding_mode_i=padding_mode,
|
||||
align_corners_i=align_corners)
|
||||
|
||||
|
||||
@SYMBOLIC_REWRITER.register_symbolic('grid_sampler', is_pytorch=True)
|
||||
def grid_sampler__default(ctx, *args):
|
||||
"""Register default symbolic function for `grid_sampler`.
|
||||
|
||||
Add support to grid_sample to ONNX.
|
||||
"""
|
||||
return grid_sampler(*args)
|
||||
backend = get_backend(ctx.cfg)
|
||||
if backend == Backend.PPLNN:
|
||||
return grid_sampler_ppl(*args)
|
||||
else:
|
||||
return grid_sampler(*args)
|
||||
|
|
|
@ -4,6 +4,7 @@ import tempfile
|
|||
import onnx
|
||||
import pytest
|
||||
import torch
|
||||
from mmcv import Config
|
||||
|
||||
from mmdeploy.core import RewriterContext
|
||||
|
||||
|
@ -12,7 +13,10 @@ onnx_file = tempfile.NamedTemporaryFile(suffix='onnx').name
|
|||
|
||||
@pytest.fixture(autouse=True, scope='module')
|
||||
def prepare_symbolics():
|
||||
context = RewriterContext({}, 'tensorrt', opset=11)
|
||||
context = RewriterContext(
|
||||
Config({'backend_config': {
|
||||
'type': 'tensorrt'
|
||||
}}), 'tensorrt', opset=11)
|
||||
context.enter()
|
||||
|
||||
yield
|
||||
|
|
Loading…
Reference in New Issue