mirror of https://github.com/open-mmlab/mmcv.git
fix dcn forward and backward when batchsize is larger than im2col_step (#1212)
parent
571e3e5fc7
commit
eaebb30637
|
@ -278,6 +278,8 @@ void DeformConvForwardCUDAKernelLauncher(Tensor input, Tensor weight,
|
|||
}
|
||||
columns =
|
||||
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
||||
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
|
||||
weight.size(3), weight.size(4)});
|
||||
}
|
||||
|
||||
output_buffer = output_buffer.view(
|
||||
|
@ -384,6 +386,8 @@ void DeformConvBackwardInputCUDAKernelLauncher(
|
|||
deformable_col2im(columns, offset[elt], nInputPlane, inputHeight,
|
||||
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
|
||||
dilationW, im2col_step, deformable_group, gradInput[elt]);
|
||||
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
|
||||
weight.size(3), weight.size(4)});
|
||||
}
|
||||
|
||||
gradOutput.transpose_(1, 2);
|
||||
|
|
|
@ -386,6 +386,8 @@ void DeformConvBackwardInputCUDAKernelLauncher(
|
|||
deformable_col2im(columns, offset[elt], nInputPlane, inputHeight,
|
||||
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
|
||||
dilationW, im2col_step, deformable_group, gradInput[elt]);
|
||||
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
|
||||
weight.size(3), weight.size(4)});
|
||||
}
|
||||
|
||||
gradOutput.transpose_(1, 2);
|
||||
|
|
Loading…
Reference in New Issue