From 394fb5580958bf4143d86b3de523887543a214e6 Mon Sep 17 00:00:00 2001 From: AllentDan <41138331+AllentDan@users.noreply.github.com> Date: Tue, 19 Jul 2022 09:47:42 +0800 Subject: [PATCH] fix satrn for ORT (#753) * fix satrn for ORT * move rewrite into pytorch --- mmdeploy/pytorch/functions/__init__.py | 3 ++- mmdeploy/pytorch/functions/masked_fill.py | 25 +++++++++++++++++++ tests/test_pytorch/test_pytorch_functions.py | 26 ++++++++++++++++++++ 3 files changed, 53 insertions(+), 1 deletion(-) create mode 100644 mmdeploy/pytorch/functions/masked_fill.py diff --git a/mmdeploy/pytorch/functions/__init__.py b/mmdeploy/pytorch/functions/__init__.py index 337392f53..3e84ee6ed 100644 --- a/mmdeploy/pytorch/functions/__init__.py +++ b/mmdeploy/pytorch/functions/__init__.py @@ -6,6 +6,7 @@ from .getattribute import tensor__getattribute__ncnn from .group_norm import group_norm__ncnn from .interpolate import interpolate__ncnn, interpolate__tensorrt from .linear import linear__ncnn +from .masked_fill import masked_fill__onnxruntime from .normalize import normalize__ncnn from .repeat import tensor__repeat__tensorrt from .size import tensor__size__ncnn @@ -17,5 +18,5 @@ __all__ = [ 'interpolate__tensorrt', 'linear__ncnn', 'tensor__repeat__tensorrt', 'tensor__size__ncnn', 'topk__dynamic', 'topk__tensorrt', 'chunk__ncnn', 'triu', 'atan2__default', 'normalize__ncnn', 'expand__ncnn', - 'chunk__torchscript' + 'chunk__torchscript', 'masked_fill__onnxruntime' ] diff --git a/mmdeploy/pytorch/functions/masked_fill.py b/mmdeploy/pytorch/functions/masked_fill.py new file mode 100644 index 000000000..5e4f67b45 --- /dev/null +++ b/mmdeploy/pytorch/functions/masked_fill.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Union + +import torch +from torch.types import Number + +from mmdeploy.core import FUNCTION_REWRITER +from mmdeploy.utils.constants import Backend + + +@FUNCTION_REWRITER.register_rewriter( + func_name='torch.masked_fill', backend=Backend.ONNXRUNTIME.value) +@FUNCTION_REWRITER.register_rewriter( + func_name='torch.Tensor.masked_fill', backend=Backend.ONNXRUNTIME.value) +def masked_fill__onnxruntime( + ctx, input, mask: torch.Tensor, value: Union[torch.Tensor, + Number]) -> torch.Tensor: + """Rewrite `masked_fill` for onnxruntime backend. + + SATRN model as example, when value is set to `float('-inf')`, the results + of ORT inferencing turns out to be NAN. + """ + if value == float('-inf'): + value = -1e34 # hard coding number + return ctx.origin_func(input, mask, value) diff --git a/tests/test_pytorch/test_pytorch_functions.py b/tests/test_pytorch/test_pytorch_functions.py index ba9604097..fca863af3 100644 --- a/tests/test_pytorch/test_pytorch_functions.py +++ b/tests/test_pytorch/test_pytorch_functions.py @@ -283,3 +283,29 @@ def test_normalize_ncnn(input, dim): param_path, bin_path = ncnn_apis.get_output_model_file(ir_file_path) assert osp.exists(param_path) assert osp.exists(bin_path) + + +@backend_checker(Backend.ONNXRUNTIME) +@pytest.mark.parametrize( + 'input', + [torch.rand(1, 16, 16), torch.rand(1, 3, 16, 16)]) +def test_masked_fill_onnxruntime(input): + mask = input > 0 + value = float('-inf') + + def masked_fill_caller(*arg, **kwargs): + return torch.masked_fill(*arg, **kwargs) + + deploy_cfg_ort = mmcv.Config( + dict( + onnx_config=dict(input_shape=None), + backend_config=dict(type='onnxruntime'), + codebase_config=dict(type='mmdet', task='ObjectDetection'))) + + wrapped_func = WrapFunction(masked_fill_caller, mask=mask, value=value) + rewrite_output, _ = get_rewrite_outputs( + wrapped_func, + model_inputs={'input': input}, + deploy_cfg=deploy_cfg_ort, + run_with_backend=True) + assert rewrite_output is not None