[Fix] Fix deform_conv ops on Ascend NPU (#2805)

pull/2849/head^2
Yinlei Sun 2023-06-15 16:46:07 +08:00 committed by GitHub
parent 9c0f9cb98c
commit f07ed6bcc0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 3 deletions

View File

@ -56,7 +56,8 @@ class DeformConv2dFunction(Function):
input_tensor, grad_output, offset_out, weight, offset_all,
kernel_size=[weight.shape[3], weight.shape[2]],
stride=[1, 1, ctx.stride[0], ctx.stride[1]],
padding=[1, 1, ctx.padding[0], ctx.padding[1]],
padding=[ctx.padding[0], ctx.padding[0], ctx.padding[1],
ctx.padding[1]],
dilation=[1, 1, ctx.dilation[0], ctx.dilation[1]],
groups=ctx.groups, deformable_groups=ctx.deform_groups,
modulated=True)

View File

@ -63,7 +63,9 @@ class ModulatedDeformConv2dFunction(Function):
conv2d_bias,
kernel_size=[kernel_w, kernel_h],
stride=[1, 1, ctx.stride[0], ctx.stride[1]],
padding=[1, 1, ctx.padding[0], ctx.padding[1]],
padding=[
ctx.padding[0], ctx.padding[0], ctx.padding[1], ctx.padding[1]
],
dilation=[1, 1, ctx.dilation[0], ctx.dilation[1]],
groups=ctx.groups,
deformable_groups=ctx.deform_groups,
@ -83,7 +85,8 @@ class ModulatedDeformConv2dFunction(Function):
input_tensor, grad_output, offset_out, weight, offset_all,
kernel_size=[weight.shape[3], weight.shape[2]],
stride=[1, 1, ctx.stride[0], ctx.stride[1]],
padding=[1, 1, ctx.padding[0], ctx.padding[1]],
padding=[ctx.padding[0], ctx.padding[0], ctx.padding[1],
ctx.padding[1]],
dilation=[1, 1, ctx.dilation[0], ctx.dilation[1]],
groups=ctx.groups, deformable_groups=ctx.deform_groups,
modulated=True)