diff --git a/mmdeploy/mmcv/ops/modulated_deform_conv.py b/mmdeploy/mmcv/ops/modulated_deform_conv.py index 64fd9fdd7..3c7adcc21 100644 --- a/mmdeploy/mmcv/ops/modulated_deform_conv.py +++ b/mmdeploy/mmcv/ops/modulated_deform_conv.py @@ -1,5 +1,32 @@ # Copyright (c) OpenMMLab. All rights reserved. -from mmdeploy.core import SYMBOLIC_REWRITER +import torch + +from mmdeploy.core import FUNCTION_REWRITER, SYMBOLIC_REWRITER +from mmdeploy.utils import IR + + +@FUNCTION_REWRITER.register_rewriter( + 'mmcv.ops.modulated_deform_conv.modulated_deform_conv2d', + ir=IR.TORCHSCRIPT) +def modulated_deform_conv__torchscript(input, offset, mask, weight, bias, + stride, padding, dilation, groups, + deform_groups): + """rewriter for the custom torchscript mdcn op.""" + from mmdeploy.backend.torchscript import get_ops_path, ops_available + assert ops_available(), 'torchscript custom ops is required.' + torch.ops.load_library(get_ops_path()) + from torch.nn.modules.utils import _pair + kernel_h, kernel_w = weight.shape[-2:] + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + with_bias = bias is not None + if not with_bias: + bias = input.new_empty(0) + return torch.ops.mmdeploy.modulated_deform_conv( + input, weight, bias, offset, mask, kernel_h, kernel_w, stride[1], + stride[0], padding[1], padding[0], dilation[1], dilation[0], groups, + deform_groups, with_bias) @SYMBOLIC_REWRITER.register_symbolic( diff --git a/tests/test_mmcv/test_mmcv_ops.py b/tests/test_mmcv/test_mmcv_ops.py index 8c3cb185e..8267a7de5 100644 --- a/tests/test_mmcv/test_mmcv_ops.py +++ b/tests/test_mmcv/test_mmcv_ops.py @@ -221,3 +221,25 @@ def test_multiclass_nms__ascend(): assert rewrite_outputs is not None, 'Got unexpected rewrite '\ 'outputs: {}'.format(rewrite_outputs) + + +def test_modulated_deform_conv(): + check_backend(Backend.TORCHSCRIPT) + from mmdeploy.backend.torchscript import ops_available + + if not ops_available(): + pytest.skip('torchscript custom ops is required.') + + from mmcv.ops import ModulatedDeformConv2dPack + + from mmdeploy.apis.torch_jit import trace + + model = ModulatedDeformConv2dPack(3, 1, 1).eval() + x = torch.rand(1, 3, 16, 16) + + jit_model = trace(model, x, None, backend='torchscript') + + out = model(x) + jit_out = jit_model(x) + + torch.testing.assert_allclose(out, jit_out)