diff --git a/mmdeploy/mmcv/ops/modulated_deform_conv.py b/mmdeploy/mmcv/ops/modulated_deform_conv.py index df3c338a8..3c9bd4008 100644 --- a/mmdeploy/mmcv/ops/modulated_deform_conv.py +++ b/mmdeploy/mmcv/ops/modulated_deform_conv.py @@ -1,5 +1,32 @@ # Copyright (c) OpenMMLab. All rights reserved. -from mmdeploy.core import SYMBOLIC_REWRITER +import torch + +from mmdeploy.core import FUNCTION_REWRITER, SYMBOLIC_REWRITER +from mmdeploy.utils import IR + + +@FUNCTION_REWRITER.register_rewriter( + 'mmcv.ops.modulated_deform_conv.modulated_deform_conv2d', + ir=IR.TORCHSCRIPT) +def modulated_deform_conv__torchscript(ctx, input, offset, mask, weight, bias, + stride, padding, dilation, groups, + deform_groups): + """rewriter for the custom torchscript mdcn op.""" + from mmdeploy.backend.torchscript import get_ops_path, ops_available + assert ops_available(), 'torchscript custom ops is required.' + torch.ops.load_library(get_ops_path()) + from torch.nn.modules.utils import _pair + kernel_h, kernel_w = weight.shape[-2:] + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + with_bias = bias is not None + if not with_bias: + bias = input.new_empty(0) + return torch.ops.mmdeploy.modulated_deform_conv( + input, weight, bias, offset, mask, kernel_h, kernel_w, stride[1], + stride[0], padding[1], padding[0], dilation[1], dilation[0], groups, + deform_groups, with_bias) @SYMBOLIC_REWRITER.register_symbolic( diff --git a/tests/test_mmcv/test_mmcv_ops.py b/tests/test_mmcv/test_mmcv_ops.py index 4d41dc2fd..bd3df9d26 100644 --- a/tests/test_mmcv/test_mmcv_ops.py +++ b/tests/test_mmcv/test_mmcv_ops.py @@ -114,3 +114,25 @@ def test_patch_embed_ncnn(): with RewriterContext({}, backend='ncnn'), torch.no_grad(): _, shape = wrapped_model(input) assert shape[0] == patch_cfg['input_size'] / patch_cfg['stride'] + + +def test_modulated_deform_conv(): + check_backend(Backend.TORCHSCRIPT) + from mmdeploy.backend.torchscript import ops_available + + if not ops_available(): + pytest.skip('torchscript custom ops is required.') + + from mmcv.ops import ModulatedDeformConv2dPack + + from mmdeploy.apis.torch_jit import trace + + model = ModulatedDeformConv2dPack(3, 1, 1).eval() + x = torch.rand(1, 3, 16, 16) + + jit_model = trace(model, x, None, backend='torchscript') + + out = model(x) + jit_out = jit_model(x) + + torch.testing.assert_allclose(out, jit_out)