From 9a3f1249e5bdca0933dbc56739a841846d35707a Mon Sep 17 00:00:00 2001 From: Ming-Hsuan-Tu Date: Fri, 6 Nov 2020 13:01:28 +0800 Subject: [PATCH] [fix] Missing arguments when converting dcn to onnx (#624) * fix issues when converting deformable convolution to onnx * keep and for interface consistency Co-authored-by: maningsheng --- mmcv/ops/deform_conv.py | 28 ++++++++++++++++++---------- mmcv/ops/modulated_deform_conv.py | 10 +++++----- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/mmcv/ops/deform_conv.py b/mmcv/ops/deform_conv.py index 7a70fa910..250e096a5 100644 --- a/mmcv/ops/deform_conv.py +++ b/mmcv/ops/deform_conv.py @@ -20,20 +20,29 @@ ext_module = ext_loader.load_ext('_ext', [ class DeformConv2dFunction(Function): @staticmethod - def symbolic(g, input, offset, weight, stride, padding, dilation, groups, - deform_groups, bias, im2col_step): + def symbolic(g, + input, + offset, + weight, + stride, + padding, + dilation, + groups, + deform_groups, + bias=False, + im2col_step=32): return g.op( 'MMCVDeformConv2d', input, offset, weight, - stride=stride, - padding=padding, - dilation=dilation, - groups=groups, - deform_groups=deform_groups, - bias=bias, - im2col_step=im2col_step) + stride_i=stride, + padding_i=padding, + dilation_i=dilation, + groups_i=groups, + deform_groups_i=deform_groups, + bias_i=bias, + im2col_step_i=im2col_step) @staticmethod def forward(ctx, @@ -52,7 +61,6 @@ class DeformConv2dFunction(Function): f'Expected 4D tensor as input, got {input.dim()}D tensor \ instead.') assert bias is False, 'Only support bias is False.' - ctx.stride = _pair(stride) ctx.padding = _pair(padding) ctx.dilation = _pair(dilation) diff --git a/mmcv/ops/modulated_deform_conv.py b/mmcv/ops/modulated_deform_conv.py index 1770b6579..b8ff1adeb 100644 --- a/mmcv/ops/modulated_deform_conv.py +++ b/mmcv/ops/modulated_deform_conv.py @@ -27,11 +27,11 @@ class ModulatedDeformConv2dFunction(Function): mask, weight, bias, - stride=stride, - padding=padding, - dilation=dilation, - groups=groups, - deform_groups=deform_groups) + stride_i=stride, + padding_i=padding, + dilation_i=dilation, + groups_i=groups, + deform_groups_i=deform_groups) @staticmethod def forward(ctx,