support torchjit mdcn (#1508)
parent
c1ca5a3dbf
commit
af4d304004
|
@ -1,5 +1,32 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmdeploy.core import SYMBOLIC_REWRITER
|
||||
import torch
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER, SYMBOLIC_REWRITER
|
||||
from mmdeploy.utils import IR
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmcv.ops.modulated_deform_conv.modulated_deform_conv2d',
|
||||
ir=IR.TORCHSCRIPT)
|
||||
def modulated_deform_conv__torchscript(ctx, input, offset, mask, weight, bias,
|
||||
stride, padding, dilation, groups,
|
||||
deform_groups):
|
||||
"""rewriter for the custom torchscript mdcn op."""
|
||||
from mmdeploy.backend.torchscript import get_ops_path, ops_available
|
||||
assert ops_available(), 'torchscript custom ops is required.'
|
||||
torch.ops.load_library(get_ops_path())
|
||||
from torch.nn.modules.utils import _pair
|
||||
kernel_h, kernel_w = weight.shape[-2:]
|
||||
stride = _pair(stride)
|
||||
padding = _pair(padding)
|
||||
dilation = _pair(dilation)
|
||||
with_bias = bias is not None
|
||||
if not with_bias:
|
||||
bias = input.new_empty(0)
|
||||
return torch.ops.mmdeploy.modulated_deform_conv(
|
||||
input, weight, bias, offset, mask, kernel_h, kernel_w, stride[1],
|
||||
stride[0], padding[1], padding[0], dilation[1], dilation[0], groups,
|
||||
deform_groups, with_bias)
|
||||
|
||||
|
||||
@SYMBOLIC_REWRITER.register_symbolic(
|
||||
|
|
|
@ -114,3 +114,25 @@ def test_patch_embed_ncnn():
|
|||
with RewriterContext({}, backend='ncnn'), torch.no_grad():
|
||||
_, shape = wrapped_model(input)
|
||||
assert shape[0] == patch_cfg['input_size'] / patch_cfg['stride']
|
||||
|
||||
|
||||
def test_modulated_deform_conv():
|
||||
check_backend(Backend.TORCHSCRIPT)
|
||||
from mmdeploy.backend.torchscript import ops_available
|
||||
|
||||
if not ops_available():
|
||||
pytest.skip('torchscript custom ops is required.')
|
||||
|
||||
from mmcv.ops import ModulatedDeformConv2dPack
|
||||
|
||||
from mmdeploy.apis.torch_jit import trace
|
||||
|
||||
model = ModulatedDeformConv2dPack(3, 1, 1).eval()
|
||||
x = torch.rand(1, 3, 16, 16)
|
||||
|
||||
jit_model = trace(model, x, None, backend='torchscript')
|
||||
|
||||
out = model(x)
|
||||
jit_out = jit_model(x)
|
||||
|
||||
torch.testing.assert_allclose(out, jit_out)
|
||||
|
|
Loading…
Reference in New Issue