mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
torchjit mdcn 1.x (#1536)
This commit is contained in:
parent
67c1cd2475
commit
84ee45bc35
@ -1,5 +1,32 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from mmdeploy.core import SYMBOLIC_REWRITER
|
import torch
|
||||||
|
|
||||||
|
from mmdeploy.core import FUNCTION_REWRITER, SYMBOLIC_REWRITER
|
||||||
|
from mmdeploy.utils import IR
|
||||||
|
|
||||||
|
|
||||||
|
@FUNCTION_REWRITER.register_rewriter(
|
||||||
|
'mmcv.ops.modulated_deform_conv.modulated_deform_conv2d',
|
||||||
|
ir=IR.TORCHSCRIPT)
|
||||||
|
def modulated_deform_conv__torchscript(input, offset, mask, weight, bias,
|
||||||
|
stride, padding, dilation, groups,
|
||||||
|
deform_groups):
|
||||||
|
"""rewriter for the custom torchscript mdcn op."""
|
||||||
|
from mmdeploy.backend.torchscript import get_ops_path, ops_available
|
||||||
|
assert ops_available(), 'torchscript custom ops is required.'
|
||||||
|
torch.ops.load_library(get_ops_path())
|
||||||
|
from torch.nn.modules.utils import _pair
|
||||||
|
kernel_h, kernel_w = weight.shape[-2:]
|
||||||
|
stride = _pair(stride)
|
||||||
|
padding = _pair(padding)
|
||||||
|
dilation = _pair(dilation)
|
||||||
|
with_bias = bias is not None
|
||||||
|
if not with_bias:
|
||||||
|
bias = input.new_empty(0)
|
||||||
|
return torch.ops.mmdeploy.modulated_deform_conv(
|
||||||
|
input, weight, bias, offset, mask, kernel_h, kernel_w, stride[1],
|
||||||
|
stride[0], padding[1], padding[0], dilation[1], dilation[0], groups,
|
||||||
|
deform_groups, with_bias)
|
||||||
|
|
||||||
|
|
||||||
@SYMBOLIC_REWRITER.register_symbolic(
|
@SYMBOLIC_REWRITER.register_symbolic(
|
||||||
|
@ -221,3 +221,25 @@ def test_multiclass_nms__ascend():
|
|||||||
|
|
||||||
assert rewrite_outputs is not None, 'Got unexpected rewrite '\
|
assert rewrite_outputs is not None, 'Got unexpected rewrite '\
|
||||||
'outputs: {}'.format(rewrite_outputs)
|
'outputs: {}'.format(rewrite_outputs)
|
||||||
|
|
||||||
|
|
||||||
|
def test_modulated_deform_conv():
|
||||||
|
check_backend(Backend.TORCHSCRIPT)
|
||||||
|
from mmdeploy.backend.torchscript import ops_available
|
||||||
|
|
||||||
|
if not ops_available():
|
||||||
|
pytest.skip('torchscript custom ops is required.')
|
||||||
|
|
||||||
|
from mmcv.ops import ModulatedDeformConv2dPack
|
||||||
|
|
||||||
|
from mmdeploy.apis.torch_jit import trace
|
||||||
|
|
||||||
|
model = ModulatedDeformConv2dPack(3, 1, 1).eval()
|
||||||
|
x = torch.rand(1, 3, 16, 16)
|
||||||
|
|
||||||
|
jit_model = trace(model, x, None, backend='torchscript')
|
||||||
|
|
||||||
|
out = model(x)
|
||||||
|
jit_out = jit_model(x)
|
||||||
|
|
||||||
|
torch.testing.assert_allclose(out, jit_out)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user