support torchjit mdcn (#1508)

pull/1530/head
q.yao 2022-12-12 10:28:49 +08:00 committed by GitHub
parent c1ca5a3dbf
commit af4d304004
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 50 additions and 1 deletions

View File

@ -1,5 +1,32 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdeploy.core import SYMBOLIC_REWRITER
import torch
from mmdeploy.core import FUNCTION_REWRITER, SYMBOLIC_REWRITER
from mmdeploy.utils import IR
@FUNCTION_REWRITER.register_rewriter(
'mmcv.ops.modulated_deform_conv.modulated_deform_conv2d',
ir=IR.TORCHSCRIPT)
def modulated_deform_conv__torchscript(ctx, input, offset, mask, weight, bias,
stride, padding, dilation, groups,
deform_groups):
"""rewriter for the custom torchscript mdcn op."""
from mmdeploy.backend.torchscript import get_ops_path, ops_available
assert ops_available(), 'torchscript custom ops is required.'
torch.ops.load_library(get_ops_path())
from torch.nn.modules.utils import _pair
kernel_h, kernel_w = weight.shape[-2:]
stride = _pair(stride)
padding = _pair(padding)
dilation = _pair(dilation)
with_bias = bias is not None
if not with_bias:
bias = input.new_empty(0)
return torch.ops.mmdeploy.modulated_deform_conv(
input, weight, bias, offset, mask, kernel_h, kernel_w, stride[1],
stride[0], padding[1], padding[0], dilation[1], dilation[0], groups,
deform_groups, with_bias)
@SYMBOLIC_REWRITER.register_symbolic(

View File

@ -114,3 +114,25 @@ def test_patch_embed_ncnn():
with RewriterContext({}, backend='ncnn'), torch.no_grad():
_, shape = wrapped_model(input)
assert shape[0] == patch_cfg['input_size'] / patch_cfg['stride']
def test_modulated_deform_conv():
check_backend(Backend.TORCHSCRIPT)
from mmdeploy.backend.torchscript import ops_available
if not ops_available():
pytest.skip('torchscript custom ops is required.')
from mmcv.ops import ModulatedDeformConv2dPack
from mmdeploy.apis.torch_jit import trace
model = ModulatedDeformConv2dPack(3, 1, 1).eval()
x = torch.rand(1, 3, 16, 16)
jit_model = trace(model, x, None, backend='torchscript')
out = model(x)
jit_out = jit_model(x)
torch.testing.assert_allclose(out, jit_out)