fix satrn for ORT (#753)

* fix satrn for ORT

* move rewrite into pytorch
pull/704/head
AllentDan 2022-07-19 09:47:42 +08:00 committed by GitHub
parent bc7de32fb1
commit 394fb55809
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 53 additions and 1 deletions

View File

@ -6,6 +6,7 @@ from .getattribute import tensor__getattribute__ncnn
from .group_norm import group_norm__ncnn
from .interpolate import interpolate__ncnn, interpolate__tensorrt
from .linear import linear__ncnn
from .masked_fill import masked_fill__onnxruntime
from .normalize import normalize__ncnn
from .repeat import tensor__repeat__tensorrt
from .size import tensor__size__ncnn
@ -17,5 +18,5 @@ __all__ = [
'interpolate__tensorrt', 'linear__ncnn', 'tensor__repeat__tensorrt',
'tensor__size__ncnn', 'topk__dynamic', 'topk__tensorrt', 'chunk__ncnn',
'triu', 'atan2__default', 'normalize__ncnn', 'expand__ncnn',
'chunk__torchscript'
'chunk__torchscript', 'masked_fill__onnxruntime'
]

View File

@ -0,0 +1,25 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Union
import torch
from torch.types import Number
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils.constants import Backend
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.masked_fill', backend=Backend.ONNXRUNTIME.value)
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.Tensor.masked_fill', backend=Backend.ONNXRUNTIME.value)
def masked_fill__onnxruntime(
ctx, input, mask: torch.Tensor, value: Union[torch.Tensor,
Number]) -> torch.Tensor:
"""Rewrite `masked_fill` for onnxruntime backend.
SATRN model as example, when value is set to `float('-inf')`, the results
of ORT inferencing turns out to be NAN.
"""
if value == float('-inf'):
value = -1e34 # hard coding number
return ctx.origin_func(input, mask, value)

View File

@ -283,3 +283,29 @@ def test_normalize_ncnn(input, dim):
param_path, bin_path = ncnn_apis.get_output_model_file(ir_file_path)
assert osp.exists(param_path)
assert osp.exists(bin_path)
@backend_checker(Backend.ONNXRUNTIME)
@pytest.mark.parametrize(
'input',
[torch.rand(1, 16, 16), torch.rand(1, 3, 16, 16)])
def test_masked_fill_onnxruntime(input):
mask = input > 0
value = float('-inf')
def masked_fill_caller(*arg, **kwargs):
return torch.masked_fill(*arg, **kwargs)
deploy_cfg_ort = mmcv.Config(
dict(
onnx_config=dict(input_shape=None),
backend_config=dict(type='onnxruntime'),
codebase_config=dict(type='mmdet', task='ObjectDetection')))
wrapped_func = WrapFunction(masked_fill_caller, mask=mask, value=value)
rewrite_output, _ = get_rewrite_outputs(
wrapped_func,
model_inputs={'input': input},
deploy_cfg=deploy_cfg_ort,
run_with_backend=True)
assert rewrite_output is not None