【Fix】Fix ppl bug about grid sample (#325)

* fix ppl problems

* fix roialign

* fix grid_sampler bug

* fix grid sampler

* fix config

* fix test
pull/1/head
VVsssssk 2021-12-23 12:11:07 +08:00 committed by GitHub
parent 199253ce94
commit 9d4d52078b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 34 additions and 2 deletions

View File

@ -1,3 +1,4 @@
_base_ = ['./base_static.py']
onnx_config = dict(output_names=['dets', 'labels', 'masks'])
codebase_config = dict(post_processing=dict(export_postprocess_mask=False))

View File

@ -2,6 +2,7 @@
from torch.onnx.symbolic_helper import parse_args
from mmdeploy.core import SYMBOLIC_REWRITER
from mmdeploy.utils import Backend, get_backend
@parse_args('v', 'v', 'i', 'i', 'i')
@ -26,10 +27,36 @@ def grid_sampler(g,
align_corners_i=align_corners)
@parse_args('v', 'v', 'i', 'i', 'i')
def grid_sampler_ppl(g,
input,
grid,
interpolation_mode,
padding_mode,
align_corners=False):
"""Symbolic function for `grid_sampler`.
PyTorch does not support export grid_sampler to ONNX by default. We add the
support here. `grid_sampler` will be exported as ONNX node
'mmdeploy::grid_sampler'
"""
return g.op(
'mmcv::grid_sampler',
input,
grid,
interpolation_mode_i=interpolation_mode,
padding_mode_i=padding_mode,
align_corners_i=align_corners)
@SYMBOLIC_REWRITER.register_symbolic('grid_sampler', is_pytorch=True)
def grid_sampler__default(ctx, *args):
"""Register default symbolic function for `grid_sampler`.
Add support to grid_sample to ONNX.
"""
return grid_sampler(*args)
backend = get_backend(ctx.cfg)
if backend == Backend.PPLNN:
return grid_sampler_ppl(*args)
else:
return grid_sampler(*args)

View File

@ -4,6 +4,7 @@ import tempfile
import onnx
import pytest
import torch
from mmcv import Config
from mmdeploy.core import RewriterContext
@ -12,7 +13,10 @@ onnx_file = tempfile.NamedTemporaryFile(suffix='onnx').name
@pytest.fixture(autouse=True, scope='module')
def prepare_symbolics():
context = RewriterContext({}, 'tensorrt', opset=11)
context = RewriterContext(
Config({'backend_config': {
'type': 'tensorrt'
}}), 'tensorrt', opset=11)
context.enter()
yield