parent
bc7de32fb1
commit
394fb55809
|
@ -6,6 +6,7 @@ from .getattribute import tensor__getattribute__ncnn
|
|||
from .group_norm import group_norm__ncnn
|
||||
from .interpolate import interpolate__ncnn, interpolate__tensorrt
|
||||
from .linear import linear__ncnn
|
||||
from .masked_fill import masked_fill__onnxruntime
|
||||
from .normalize import normalize__ncnn
|
||||
from .repeat import tensor__repeat__tensorrt
|
||||
from .size import tensor__size__ncnn
|
||||
|
@ -17,5 +18,5 @@ __all__ = [
|
|||
'interpolate__tensorrt', 'linear__ncnn', 'tensor__repeat__tensorrt',
|
||||
'tensor__size__ncnn', 'topk__dynamic', 'topk__tensorrt', 'chunk__ncnn',
|
||||
'triu', 'atan2__default', 'normalize__ncnn', 'expand__ncnn',
|
||||
'chunk__torchscript'
|
||||
'chunk__torchscript', 'masked_fill__onnxruntime'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from torch.types import Number
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
from mmdeploy.utils.constants import Backend
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='torch.masked_fill', backend=Backend.ONNXRUNTIME.value)
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='torch.Tensor.masked_fill', backend=Backend.ONNXRUNTIME.value)
|
||||
def masked_fill__onnxruntime(
|
||||
ctx, input, mask: torch.Tensor, value: Union[torch.Tensor,
|
||||
Number]) -> torch.Tensor:
|
||||
"""Rewrite `masked_fill` for onnxruntime backend.
|
||||
|
||||
SATRN model as example, when value is set to `float('-inf')`, the results
|
||||
of ORT inferencing turns out to be NAN.
|
||||
"""
|
||||
if value == float('-inf'):
|
||||
value = -1e34 # hard coding number
|
||||
return ctx.origin_func(input, mask, value)
|
|
@ -283,3 +283,29 @@ def test_normalize_ncnn(input, dim):
|
|||
param_path, bin_path = ncnn_apis.get_output_model_file(ir_file_path)
|
||||
assert osp.exists(param_path)
|
||||
assert osp.exists(bin_path)
|
||||
|
||||
|
||||
@backend_checker(Backend.ONNXRUNTIME)
|
||||
@pytest.mark.parametrize(
|
||||
'input',
|
||||
[torch.rand(1, 16, 16), torch.rand(1, 3, 16, 16)])
|
||||
def test_masked_fill_onnxruntime(input):
|
||||
mask = input > 0
|
||||
value = float('-inf')
|
||||
|
||||
def masked_fill_caller(*arg, **kwargs):
|
||||
return torch.masked_fill(*arg, **kwargs)
|
||||
|
||||
deploy_cfg_ort = mmcv.Config(
|
||||
dict(
|
||||
onnx_config=dict(input_shape=None),
|
||||
backend_config=dict(type='onnxruntime'),
|
||||
codebase_config=dict(type='mmdet', task='ObjectDetection')))
|
||||
|
||||
wrapped_func = WrapFunction(masked_fill_caller, mask=mask, value=value)
|
||||
rewrite_output, _ = get_rewrite_outputs(
|
||||
wrapped_func,
|
||||
model_inputs={'input': input},
|
||||
deploy_cfg=deploy_cfg_ort,
|
||||
run_with_backend=True)
|
||||
assert rewrite_output is not None
|
||||
|
|
Loading…
Reference in New Issue