mirror of https://github.com/open-mmlab/mmcv.git
[Fix] Fix deform_conv ops on Ascend NPU (#2805)
parent
9c0f9cb98c
commit
f07ed6bcc0
|
@ -56,7 +56,8 @@ class DeformConv2dFunction(Function):
|
|||
input_tensor, grad_output, offset_out, weight, offset_all,
|
||||
kernel_size=[weight.shape[3], weight.shape[2]],
|
||||
stride=[1, 1, ctx.stride[0], ctx.stride[1]],
|
||||
padding=[1, 1, ctx.padding[0], ctx.padding[1]],
|
||||
padding=[ctx.padding[0], ctx.padding[0], ctx.padding[1],
|
||||
ctx.padding[1]],
|
||||
dilation=[1, 1, ctx.dilation[0], ctx.dilation[1]],
|
||||
groups=ctx.groups, deformable_groups=ctx.deform_groups,
|
||||
modulated=True)
|
||||
|
|
|
@ -63,7 +63,9 @@ class ModulatedDeformConv2dFunction(Function):
|
|||
conv2d_bias,
|
||||
kernel_size=[kernel_w, kernel_h],
|
||||
stride=[1, 1, ctx.stride[0], ctx.stride[1]],
|
||||
padding=[1, 1, ctx.padding[0], ctx.padding[1]],
|
||||
padding=[
|
||||
ctx.padding[0], ctx.padding[0], ctx.padding[1], ctx.padding[1]
|
||||
],
|
||||
dilation=[1, 1, ctx.dilation[0], ctx.dilation[1]],
|
||||
groups=ctx.groups,
|
||||
deformable_groups=ctx.deform_groups,
|
||||
|
@ -83,7 +85,8 @@ class ModulatedDeformConv2dFunction(Function):
|
|||
input_tensor, grad_output, offset_out, weight, offset_all,
|
||||
kernel_size=[weight.shape[3], weight.shape[2]],
|
||||
stride=[1, 1, ctx.stride[0], ctx.stride[1]],
|
||||
padding=[1, 1, ctx.padding[0], ctx.padding[1]],
|
||||
padding=[ctx.padding[0], ctx.padding[0], ctx.padding[1],
|
||||
ctx.padding[1]],
|
||||
dilation=[1, 1, ctx.dilation[0], ctx.dilation[1]],
|
||||
groups=ctx.groups, deformable_groups=ctx.deform_groups,
|
||||
modulated=True)
|
||||
|
|
Loading…
Reference in New Issue