mirror of https://github.com/open-mmlab/mmcv.git
Add DCN and Modulated DCN CPU implementation (#1278)
* DCN cpu version * add modulated dcn cpu version * move deform_conv_shape_check to deform conv utils * add inline to deform_conv_shape_check * add tests * run linter * add newline at file end * run pre-commit against modulated deform conv cpp * update saconv test * run clang-format * remove cuda device inline * refactor dcn cuda/cpu functions * remove DCN util * remove DCN util hpp from all included files * Addressing PR comment by refactoring modulated-DCN * fix lint in cpp filespull/1309/head
parent
5617ad72d0
commit
e621e08d54
|
@ -101,422 +101,3 @@ void deformable_col2im_coord(
|
|||
}));
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
void deform_conv_shape_check(Tensor input, Tensor offset, Tensor *gradOutput,
|
||||
Tensor weight, int kH, int kW, int dH, int dW,
|
||||
int padH, int padW, int dilationH, int dilationW,
|
||||
int group, int deformable_group) {
|
||||
TORCH_CHECK(
|
||||
weight.ndimension() == 4,
|
||||
"4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, but got: %s",
|
||||
weight.ndimension());
|
||||
|
||||
TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
|
||||
|
||||
TORCH_CHECK(kW > 0 && kH > 0,
|
||||
"kernel size should be greater than zero, but got kH: %d kW: %d",
|
||||
kH, kW);
|
||||
|
||||
TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW),
|
||||
"kernel size should be consistent with weight, ",
|
||||
"but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d",
|
||||
kH, kW, weight.size(2), weight.size(3));
|
||||
|
||||
TORCH_CHECK(dW > 0 && dH > 0,
|
||||
"stride should be greater than zero, but got dH: %d dW: %d", dH,
|
||||
dW);
|
||||
|
||||
TORCH_CHECK(
|
||||
dilationW > 0 && dilationH > 0,
|
||||
"dilation should be greater than 0, but got dilationH: %d dilationW: %d",
|
||||
dilationH, dilationW);
|
||||
|
||||
int ndim = input.ndimension();
|
||||
int dimf = 0;
|
||||
int dimh = 1;
|
||||
int dimw = 2;
|
||||
|
||||
if (ndim == 4) {
|
||||
dimf++;
|
||||
dimh++;
|
||||
dimw++;
|
||||
}
|
||||
|
||||
TORCH_CHECK(ndim == 3 || ndim == 4,
|
||||
"3D or 4D input tensor expected but got: %s", ndim);
|
||||
|
||||
long nInputPlane = weight.size(1) * group;
|
||||
long inputHeight = input.size(dimh);
|
||||
long inputWidth = input.size(dimw);
|
||||
long nOutputPlane = weight.size(0);
|
||||
long outputHeight =
|
||||
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
|
||||
long outputWidth =
|
||||
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
|
||||
|
||||
TORCH_CHECK(nInputPlane % deformable_group == 0,
|
||||
"input channels must divide deformable group size");
|
||||
|
||||
if (outputWidth < 1 || outputHeight < 1)
|
||||
AT_ERROR(
|
||||
"Given input size: (%ld x %ld x %ld). "
|
||||
"Calculated output size: (%ld x %ld x %ld). Output size is too small",
|
||||
nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight,
|
||||
outputWidth);
|
||||
|
||||
TORCH_CHECK(input.size(1) == nInputPlane,
|
||||
"invalid number of input planes, expected: %d, but got: %d",
|
||||
nInputPlane, input.size(1));
|
||||
|
||||
TORCH_CHECK((inputHeight >= kH && inputWidth >= kW),
|
||||
"input image is smaller than kernel");
|
||||
|
||||
TORCH_CHECK(
|
||||
(offset.size(2) == outputHeight && offset.size(3) == outputWidth),
|
||||
"invalid spatial size of offset, expected height: %d width: %d, but "
|
||||
"got height: %d width: %d",
|
||||
outputHeight, outputWidth, offset.size(2), offset.size(3));
|
||||
|
||||
TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW),
|
||||
"invalid number of channels of offset");
|
||||
|
||||
if (gradOutput != NULL) {
|
||||
TORCH_CHECK(
|
||||
gradOutput->size(dimf) == nOutputPlane,
|
||||
"invalid number of gradOutput planes, expected: %d, but got: %d",
|
||||
nOutputPlane, gradOutput->size(dimf));
|
||||
|
||||
TORCH_CHECK(
|
||||
(gradOutput->size(dimh) == outputHeight &&
|
||||
gradOutput->size(dimw) == outputWidth),
|
||||
"invalid size of gradOutput, expected height: %d width: %d , but "
|
||||
"got height: %d width: %d",
|
||||
outputHeight, outputWidth, gradOutput->size(dimh),
|
||||
gradOutput->size(dimw));
|
||||
}
|
||||
}
|
||||
|
||||
void DeformConvForwardCUDAKernelLauncher(Tensor input, Tensor weight,
|
||||
Tensor offset, Tensor output,
|
||||
Tensor columns, Tensor ones, int kW,
|
||||
int kH, int dW, int dH, int padW,
|
||||
int padH, int dilationW, int dilationH,
|
||||
int group, int deformable_group,
|
||||
int im2col_step) {
|
||||
// todo: resize columns to include im2col: done
|
||||
// todo: add im2col_step as input
|
||||
// todo: add new output buffer and transpose it to output (or directly
|
||||
// transpose output) todo: possibly change data indexing because of
|
||||
// parallel_imgs
|
||||
|
||||
deform_conv_shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH,
|
||||
padW, dilationH, dilationW, group, deformable_group);
|
||||
at::DeviceGuard guard(input.device());
|
||||
|
||||
int batch = 1;
|
||||
if (input.ndimension() == 3) {
|
||||
// Force batch
|
||||
batch = 0;
|
||||
input.unsqueeze_(0);
|
||||
offset.unsqueeze_(0);
|
||||
}
|
||||
|
||||
// todo: assert batchsize dividable by im2col_step
|
||||
|
||||
long batchSize = input.size(0);
|
||||
long nInputPlane = input.size(1);
|
||||
long inputHeight = input.size(2);
|
||||
long inputWidth = input.size(3);
|
||||
|
||||
long nOutputPlane = weight.size(0);
|
||||
|
||||
long outputWidth =
|
||||
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
|
||||
long outputHeight =
|
||||
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
|
||||
|
||||
TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
|
||||
|
||||
output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
|
||||
outputHeight, outputWidth});
|
||||
columns = at::zeros(
|
||||
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
|
||||
input.options());
|
||||
|
||||
if (ones.ndimension() != 2 ||
|
||||
ones.size(0) * ones.size(1) < outputHeight * outputWidth) {
|
||||
ones = at::ones({outputHeight, outputWidth}, input.options());
|
||||
}
|
||||
|
||||
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
|
||||
inputHeight, inputWidth});
|
||||
offset =
|
||||
offset.view({batchSize / im2col_step, im2col_step,
|
||||
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
||||
|
||||
Tensor output_buffer = at::zeros({batchSize / im2col_step, nOutputPlane,
|
||||
im2col_step * outputHeight, outputWidth},
|
||||
output.options());
|
||||
|
||||
output_buffer = output_buffer.view(
|
||||
{output_buffer.size(0), group, output_buffer.size(1) / group,
|
||||
output_buffer.size(2), output_buffer.size(3)});
|
||||
|
||||
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
|
||||
deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
|
||||
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
|
||||
dilationW, im2col_step, deformable_group, columns);
|
||||
|
||||
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
||||
weight = weight.view({group, weight.size(0) / group, weight.size(1),
|
||||
weight.size(2), weight.size(3)});
|
||||
|
||||
for (int g = 0; g < group; g++) {
|
||||
output_buffer[elt][g] = output_buffer[elt][g]
|
||||
.flatten(1)
|
||||
.addmm_(weight[g].flatten(1), columns[g])
|
||||
.view_as(output_buffer[elt][g]);
|
||||
}
|
||||
columns =
|
||||
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
||||
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
|
||||
weight.size(3), weight.size(4)});
|
||||
}
|
||||
|
||||
output_buffer = output_buffer.view(
|
||||
{output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2),
|
||||
output_buffer.size(3), output_buffer.size(4)});
|
||||
|
||||
output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane,
|
||||
im2col_step, outputHeight, outputWidth});
|
||||
output_buffer.transpose_(1, 2);
|
||||
output.copy_(output_buffer);
|
||||
output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth});
|
||||
|
||||
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
|
||||
offset = offset.view(
|
||||
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
||||
|
||||
if (batch == 0) {
|
||||
output = output.view({nOutputPlane, outputHeight, outputWidth});
|
||||
input = input.view({nInputPlane, inputHeight, inputWidth});
|
||||
offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
|
||||
}
|
||||
}
|
||||
|
||||
void DeformConvBackwardInputCUDAKernelLauncher(
|
||||
Tensor input, Tensor offset, Tensor gradOutput, Tensor gradInput,
|
||||
Tensor gradOffset, Tensor weight, Tensor columns, int kW, int kH, int dW,
|
||||
int dH, int padW, int padH, int dilationW, int dilationH, int group,
|
||||
int deformable_group, int im2col_step) {
|
||||
deform_conv_shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW,
|
||||
padH, padW, dilationH, dilationW, group,
|
||||
deformable_group);
|
||||
at::DeviceGuard guard(input.device());
|
||||
|
||||
int batch = 1;
|
||||
|
||||
if (input.ndimension() == 3) {
|
||||
// Force batch
|
||||
batch = 0;
|
||||
input = input.view({1, input.size(0), input.size(1), input.size(2)});
|
||||
offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
|
||||
gradOutput = gradOutput.view(
|
||||
{1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
|
||||
}
|
||||
|
||||
long batchSize = input.size(0);
|
||||
long nInputPlane = input.size(1);
|
||||
long inputHeight = input.size(2);
|
||||
long inputWidth = input.size(3);
|
||||
|
||||
long nOutputPlane = weight.size(0);
|
||||
|
||||
long outputWidth =
|
||||
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
|
||||
long outputHeight =
|
||||
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
|
||||
|
||||
TORCH_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset");
|
||||
gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
|
||||
columns = at::zeros(
|
||||
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
|
||||
input.options());
|
||||
|
||||
// change order of grad output
|
||||
gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
|
||||
nOutputPlane, outputHeight, outputWidth});
|
||||
gradOutput.transpose_(1, 2);
|
||||
|
||||
gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane,
|
||||
inputHeight, inputWidth});
|
||||
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
|
||||
inputHeight, inputWidth});
|
||||
gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step,
|
||||
deformable_group * 2 * kH * kW, outputHeight,
|
||||
outputWidth});
|
||||
offset =
|
||||
offset.view({batchSize / im2col_step, im2col_step,
|
||||
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
||||
|
||||
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
|
||||
// divide into groups
|
||||
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
||||
weight = weight.view({group, weight.size(0) / group, weight.size(1),
|
||||
weight.size(2), weight.size(3)});
|
||||
gradOutput = gradOutput.view(
|
||||
{gradOutput.size(0), group, gradOutput.size(1) / group,
|
||||
gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)});
|
||||
|
||||
for (int g = 0; g < group; g++) {
|
||||
columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
|
||||
gradOutput[elt][g].flatten(1), 0.0f, 1.0f);
|
||||
}
|
||||
|
||||
columns =
|
||||
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
||||
gradOutput = gradOutput.view(
|
||||
{gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2),
|
||||
gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)});
|
||||
|
||||
deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane,
|
||||
inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
|
||||
dilationH, dilationW, im2col_step, deformable_group,
|
||||
gradOffset[elt]);
|
||||
|
||||
deformable_col2im(columns, offset[elt], nInputPlane, inputHeight,
|
||||
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
|
||||
dilationW, im2col_step, deformable_group, gradInput[elt]);
|
||||
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
|
||||
weight.size(3), weight.size(4)});
|
||||
}
|
||||
|
||||
gradOutput.transpose_(1, 2);
|
||||
gradOutput =
|
||||
gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
|
||||
|
||||
gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
|
||||
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
|
||||
gradOffset = gradOffset.view(
|
||||
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
||||
offset = offset.view(
|
||||
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
||||
|
||||
if (batch == 0) {
|
||||
gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
|
||||
input = input.view({nInputPlane, inputHeight, inputWidth});
|
||||
gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth});
|
||||
offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
|
||||
gradOffset =
|
||||
gradOffset.view({offset.size(1), offset.size(2), offset.size(3)});
|
||||
}
|
||||
}
|
||||
|
||||
void DeformConvBackwardParametersCUDAKernelLauncher(
|
||||
Tensor input, Tensor offset, Tensor gradOutput, Tensor gradWeight,
|
||||
Tensor columns, Tensor ones, int kW, int kH, int dW, int dH, int padW,
|
||||
int padH, int dilationW, int dilationH, int group, int deformable_group,
|
||||
float scale, int im2col_step) {
|
||||
// todo: transpose and reshape outGrad
|
||||
// todo: reshape columns
|
||||
// todo: add im2col_step as input
|
||||
|
||||
deform_conv_shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH,
|
||||
dW, padH, padW, dilationH, dilationW, group,
|
||||
deformable_group);
|
||||
at::DeviceGuard guard(input.device());
|
||||
|
||||
int batch = 1;
|
||||
|
||||
if (input.ndimension() == 3) {
|
||||
// Force batch
|
||||
batch = 0;
|
||||
input = input.view(
|
||||
at::IntList({1, input.size(0), input.size(1), input.size(2)}));
|
||||
gradOutput = gradOutput.view(
|
||||
{1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
|
||||
}
|
||||
|
||||
long batchSize = input.size(0);
|
||||
long nInputPlane = input.size(1);
|
||||
long inputHeight = input.size(2);
|
||||
long inputWidth = input.size(3);
|
||||
|
||||
long nOutputPlane = gradWeight.size(0);
|
||||
|
||||
long outputWidth =
|
||||
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
|
||||
long outputHeight =
|
||||
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
|
||||
|
||||
TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
|
||||
|
||||
columns = at::zeros(
|
||||
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
|
||||
input.options());
|
||||
|
||||
gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
|
||||
nOutputPlane, outputHeight, outputWidth});
|
||||
gradOutput.transpose_(1, 2);
|
||||
|
||||
Tensor gradOutputBuffer = at::zeros_like(gradOutput);
|
||||
gradOutputBuffer =
|
||||
gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step,
|
||||
outputHeight, outputWidth});
|
||||
gradOutputBuffer = gradOutputBuffer.contiguous();
|
||||
gradOutputBuffer.copy_(gradOutput);
|
||||
gradOutputBuffer =
|
||||
gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane,
|
||||
im2col_step * outputHeight, outputWidth});
|
||||
|
||||
gradOutput.transpose_(1, 2);
|
||||
gradOutput =
|
||||
gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
|
||||
|
||||
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
|
||||
inputHeight, inputWidth});
|
||||
offset =
|
||||
offset.view({batchSize / im2col_step, im2col_step,
|
||||
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
||||
|
||||
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
|
||||
deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
|
||||
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
|
||||
dilationW, im2col_step, deformable_group, columns);
|
||||
|
||||
// divide into group
|
||||
gradOutputBuffer = gradOutputBuffer.view(
|
||||
{gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group,
|
||||
gradOutputBuffer.size(2), gradOutputBuffer.size(3)});
|
||||
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
||||
gradWeight =
|
||||
gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1),
|
||||
gradWeight.size(2), gradWeight.size(3)});
|
||||
|
||||
for (int g = 0; g < group; g++) {
|
||||
gradWeight[g] = gradWeight[g]
|
||||
.flatten(1)
|
||||
.addmm_(gradOutputBuffer[elt][g].flatten(1),
|
||||
columns[g].transpose(1, 0), 1.0, scale)
|
||||
.view_as(gradWeight[g]);
|
||||
}
|
||||
gradOutputBuffer = gradOutputBuffer.view(
|
||||
{gradOutputBuffer.size(0),
|
||||
gradOutputBuffer.size(1) * gradOutputBuffer.size(2),
|
||||
gradOutputBuffer.size(3), gradOutputBuffer.size(4)});
|
||||
columns =
|
||||
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
||||
gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1),
|
||||
gradWeight.size(2), gradWeight.size(3),
|
||||
gradWeight.size(4)});
|
||||
}
|
||||
|
||||
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
|
||||
offset = offset.view(
|
||||
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
||||
|
||||
if (batch == 0) {
|
||||
gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
|
||||
input = input.view({nInputPlane, inputHeight, inputWidth});
|
||||
}
|
||||
}
|
||||
|
|
|
@ -94,194 +94,3 @@ void modulated_deformable_col2im_coord_cuda(
|
|||
}));
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
void ModulatedDeformConvForwardCUDAKernelLauncher(
|
||||
Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset,
|
||||
Tensor mask, Tensor output, Tensor columns, int kernel_h, int kernel_w,
|
||||
const int stride_h, const int stride_w, const int pad_h, const int pad_w,
|
||||
const int dilation_h, const int dilation_w, const int group,
|
||||
const int deformable_group, const bool with_bias) {
|
||||
at::DeviceGuard guard(input.device());
|
||||
|
||||
const int batch = input.size(0);
|
||||
const int channels = input.size(1);
|
||||
const int height = input.size(2);
|
||||
const int width = input.size(3);
|
||||
|
||||
const int channels_out = weight.size(0);
|
||||
const int channels_kernel = weight.size(1);
|
||||
const int kernel_h_ = weight.size(2);
|
||||
const int kernel_w_ = weight.size(3);
|
||||
|
||||
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
|
||||
AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
|
||||
kernel_h_, kernel_w, kernel_h_, kernel_w_);
|
||||
if (channels != channels_kernel * group)
|
||||
AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
|
||||
channels, channels_kernel * group);
|
||||
|
||||
const int height_out =
|
||||
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
|
||||
const int width_out =
|
||||
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
|
||||
|
||||
if (ones.ndimension() != 2 ||
|
||||
ones.size(0) * ones.size(1) < height_out * width_out) {
|
||||
// Resize plane and fill with ones...
|
||||
ones = at::ones({height_out, width_out}, input.options());
|
||||
}
|
||||
|
||||
// resize output
|
||||
output = output.view({batch, channels_out, height_out, width_out}).zero_();
|
||||
// resize temporary columns
|
||||
columns =
|
||||
at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},
|
||||
input.options());
|
||||
|
||||
output = output.view({output.size(0), group, output.size(1) / group,
|
||||
output.size(2), output.size(3)});
|
||||
|
||||
for (int b = 0; b < batch; b++) {
|
||||
modulated_deformable_im2col_cuda(
|
||||
input[b], offset[b], mask[b], 1, channels, height, width, height_out,
|
||||
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
|
||||
dilation_h, dilation_w, deformable_group, columns);
|
||||
|
||||
// divide into group
|
||||
weight = weight.view({group, weight.size(0) / group, weight.size(1),
|
||||
weight.size(2), weight.size(3)});
|
||||
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
||||
|
||||
for (int g = 0; g < group; g++) {
|
||||
output[b][g] = output[b][g]
|
||||
.flatten(1)
|
||||
.addmm_(weight[g].flatten(1), columns[g])
|
||||
.view_as(output[b][g]);
|
||||
}
|
||||
|
||||
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
|
||||
weight.size(3), weight.size(4)});
|
||||
columns =
|
||||
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
||||
}
|
||||
|
||||
output = output.view({output.size(0), output.size(1) * output.size(2),
|
||||
output.size(3), output.size(4)});
|
||||
|
||||
if (with_bias) {
|
||||
output += bias.view({1, bias.size(0), 1, 1});
|
||||
}
|
||||
}
|
||||
|
||||
void ModulatedDeformConvBackwardCUDAKernelLauncher(
|
||||
Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset,
|
||||
Tensor mask, Tensor columns, Tensor grad_input, Tensor grad_weight,
|
||||
Tensor grad_bias, Tensor grad_offset, Tensor grad_mask, Tensor grad_output,
|
||||
int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
|
||||
int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
|
||||
const bool with_bias) {
|
||||
at::DeviceGuard guard(input.device());
|
||||
|
||||
const int batch = input.size(0);
|
||||
const int channels = input.size(1);
|
||||
const int height = input.size(2);
|
||||
const int width = input.size(3);
|
||||
|
||||
const int channels_kernel = weight.size(1);
|
||||
const int kernel_h_ = weight.size(2);
|
||||
const int kernel_w_ = weight.size(3);
|
||||
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
|
||||
AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
|
||||
kernel_h_, kernel_w, kernel_h_, kernel_w_);
|
||||
if (channels != channels_kernel * group)
|
||||
AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
|
||||
channels, channels_kernel * group);
|
||||
|
||||
const int height_out =
|
||||
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
|
||||
const int width_out =
|
||||
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
|
||||
|
||||
if (ones.ndimension() != 2 ||
|
||||
ones.size(0) * ones.size(1) < height_out * width_out) {
|
||||
// Resize plane and fill with ones...
|
||||
ones = at::ones({height_out, width_out}, input.options());
|
||||
}
|
||||
|
||||
grad_input = grad_input.view({batch, channels, height, width});
|
||||
columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},
|
||||
input.options());
|
||||
|
||||
grad_output =
|
||||
grad_output.view({grad_output.size(0), group, grad_output.size(1) / group,
|
||||
grad_output.size(2), grad_output.size(3)});
|
||||
|
||||
for (int b = 0; b < batch; b++) {
|
||||
// divide int group
|
||||
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
||||
weight = weight.view({group, weight.size(0) / group, weight.size(1),
|
||||
weight.size(2), weight.size(3)});
|
||||
|
||||
for (int g = 0; g < group; g++) {
|
||||
columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
|
||||
grad_output[b][g].flatten(1), 0.0f, 1.0f);
|
||||
}
|
||||
|
||||
columns =
|
||||
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
||||
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
|
||||
weight.size(3), weight.size(4)});
|
||||
|
||||
// gradient w.r.t. input coordinate data
|
||||
modulated_deformable_col2im_coord_cuda(
|
||||
columns, input[b], offset[b], mask[b], 1, channels, height, width,
|
||||
height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
|
||||
stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
|
||||
grad_mask[b]);
|
||||
// gradient w.r.t. input data
|
||||
modulated_deformable_col2im_cuda(
|
||||
columns, offset[b], mask[b], 1, channels, height, width, height_out,
|
||||
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
|
||||
dilation_h, dilation_w, deformable_group, grad_input[b]);
|
||||
|
||||
// gradient w.r.t. weight, dWeight should accumulate across the batch and
|
||||
// group
|
||||
modulated_deformable_im2col_cuda(
|
||||
input[b], offset[b], mask[b], 1, channels, height, width, height_out,
|
||||
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
|
||||
dilation_h, dilation_w, deformable_group, columns);
|
||||
|
||||
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
||||
grad_weight = grad_weight.view({group, grad_weight.size(0) / group,
|
||||
grad_weight.size(1), grad_weight.size(2),
|
||||
grad_weight.size(3)});
|
||||
if (with_bias)
|
||||
grad_bias = grad_bias.view({group, grad_bias.size(0) / group});
|
||||
|
||||
for (int g = 0; g < group; g++) {
|
||||
grad_weight[g] =
|
||||
grad_weight[g]
|
||||
.flatten(1)
|
||||
.addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
|
||||
.view_as(grad_weight[g]);
|
||||
if (with_bias) {
|
||||
grad_bias[g] =
|
||||
grad_bias[g]
|
||||
.view({-1, 1})
|
||||
.addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
|
||||
.view(-1);
|
||||
}
|
||||
}
|
||||
|
||||
columns =
|
||||
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
||||
grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
|
||||
grad_weight.size(2), grad_weight.size(3),
|
||||
grad_weight.size(4)});
|
||||
if (with_bias)
|
||||
grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
|
||||
}
|
||||
grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1),
|
||||
grad_output.size(2), grad_output.size(3),
|
||||
grad_output.size(4)});
|
||||
}
|
||||
|
|
|
@ -2,61 +2,152 @@
|
|||
#include "pytorch_cpp_helper.hpp"
|
||||
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
void DeformConvForwardCUDAKernelLauncher(Tensor input, Tensor weight,
|
||||
Tensor offset, Tensor output,
|
||||
Tensor columns, Tensor ones, int kW,
|
||||
int kH, int dW, int dH, int padW,
|
||||
int padH, int dilationW, int dilationH,
|
||||
int group, int deformable_group,
|
||||
int im2col_step);
|
||||
|
||||
void DeformConvBackwardInputCUDAKernelLauncher(
|
||||
Tensor input, Tensor offset, Tensor gradOutput, Tensor gradInput,
|
||||
Tensor gradOffset, Tensor weight, Tensor columns, int kW, int kH, int dW,
|
||||
int dH, int padW, int padH, int dilationW, int dilationH, int group,
|
||||
int deformable_group, int im2col_step);
|
||||
void deformable_im2col(Tensor data_im, Tensor data_offset, const int channels,
|
||||
const int height, const int width, const int ksize_h,
|
||||
const int ksize_w, const int pad_h, const int pad_w,
|
||||
const int stride_h, const int stride_w,
|
||||
const int dilation_h, const int dilation_w,
|
||||
const int parallel_imgs, const int deformable_group,
|
||||
Tensor data_col);
|
||||
|
||||
void DeformConvBackwardParametersCUDAKernelLauncher(
|
||||
Tensor input, Tensor offset, Tensor gradOutput, Tensor gradWeight,
|
||||
Tensor columns, Tensor ones, int kW, int kH, int dW, int dH, int padW,
|
||||
int padH, int dilationW, int dilationH, int group, int deformable_group,
|
||||
float scale, int im2col_step);
|
||||
void deformable_col2im(Tensor data_col, Tensor data_offset, const int channels,
|
||||
const int height, const int width, const int ksize_h,
|
||||
const int ksize_w, const int pad_h, const int pad_w,
|
||||
const int stride_h, const int stride_w,
|
||||
const int dilation_h, const int dilation_w,
|
||||
const int parallel_imgs, const int deformable_group,
|
||||
Tensor grad_im);
|
||||
|
||||
void deform_conv_forward_cuda(Tensor input, Tensor weight, Tensor offset,
|
||||
Tensor output, Tensor columns, Tensor ones,
|
||||
int kW, int kH, int dW, int dH, int padW,
|
||||
int padH, int dilationW, int dilationH, int group,
|
||||
int deformable_group, int im2col_step) {
|
||||
DeformConvForwardCUDAKernelLauncher(
|
||||
input, weight, offset, output, columns, ones, kW, kH, dW, dH, padW, padH,
|
||||
dilationW, dilationH, group, deformable_group, im2col_step);
|
||||
}
|
||||
void deformable_col2im_coord(
|
||||
Tensor data_col, Tensor data_im, Tensor data_offset, const int channels,
|
||||
const int height, const int width, const int ksize_h, const int ksize_w,
|
||||
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
||||
const int dilation_h, const int dilation_w, const int parallel_imgs,
|
||||
const int deformable_group, Tensor grad_offset);
|
||||
|
||||
void deform_conv_backward_input_cuda(Tensor input, Tensor offset,
|
||||
Tensor gradOutput, Tensor gradInput,
|
||||
Tensor gradOffset, Tensor weight,
|
||||
Tensor columns, int kW, int kH, int dW,
|
||||
int dH, int padW, int padH, int dilationW,
|
||||
int dilationH, int group,
|
||||
int deformable_group, int im2col_step) {
|
||||
DeformConvBackwardInputCUDAKernelLauncher(
|
||||
input, offset, gradOutput, gradInput, gradOffset, weight, columns, kW, kH,
|
||||
dW, dH, padW, padH, dilationW, dilationH, group, deformable_group,
|
||||
im2col_step);
|
||||
}
|
||||
|
||||
void deform_conv_backward_parameters_cuda(
|
||||
Tensor input, Tensor offset, Tensor gradOutput, Tensor gradWeight,
|
||||
Tensor columns, Tensor ones, int kW, int kH, int dW, int dH, int padW,
|
||||
int padH, int dilationW, int dilationH, int group, int deformable_group,
|
||||
float scale, int im2col_step) {
|
||||
DeformConvBackwardParametersCUDAKernelLauncher(
|
||||
input, offset, gradOutput, gradWeight, columns, ones, kW, kH, dW, dH,
|
||||
padW, padH, dilationW, dilationH, group, deformable_group, scale,
|
||||
im2col_step);
|
||||
}
|
||||
#endif
|
||||
|
||||
void deformable_im2col_cpu(Tensor data_im, Tensor data_offset,
|
||||
const int channels, const int height,
|
||||
const int width, const int ksize_h,
|
||||
const int ksize_w, const int pad_h, const int pad_w,
|
||||
const int stride_h, const int stride_w,
|
||||
const int dilation_h, const int dilation_w,
|
||||
const int parallel_imgs, const int deformable_group,
|
||||
Tensor data_col);
|
||||
|
||||
void deformable_col2im_cpu(Tensor data_col, Tensor data_offset,
|
||||
const int channels, const int height,
|
||||
const int width, const int ksize_h,
|
||||
const int ksize_w, const int pad_h, const int pad_w,
|
||||
const int stride_h, const int stride_w,
|
||||
const int dilation_h, const int dilation_w,
|
||||
const int parallel_imgs, const int deformable_group,
|
||||
Tensor grad_im);
|
||||
|
||||
void deformable_col2im_coord_cpu(
|
||||
Tensor data_col, Tensor data_im, Tensor data_offset, const int channels,
|
||||
const int height, const int width, const int ksize_h, const int ksize_w,
|
||||
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
||||
const int dilation_h, const int dilation_w, const int parallel_imgs,
|
||||
const int deformable_group, Tensor grad_offset);
|
||||
|
||||
void deform_conv_shape_check(at::Tensor input, at::Tensor offset,
|
||||
at::Tensor *gradOutput, at::Tensor weight, int kH,
|
||||
int kW, int dH, int dW, int padH, int padW,
|
||||
int dilationH, int dilationW, int group,
|
||||
int deformable_group) {
|
||||
TORCH_CHECK(
|
||||
weight.ndimension() == 4,
|
||||
"4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, but got: %s",
|
||||
weight.ndimension());
|
||||
|
||||
TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
|
||||
|
||||
TORCH_CHECK(kW > 0 && kH > 0,
|
||||
"kernel size should be greater than zero, but got kH: %d kW: %d",
|
||||
kH, kW);
|
||||
|
||||
TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW),
|
||||
"kernel size should be consistent with weight, ",
|
||||
"but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d",
|
||||
kH, kW, weight.size(2), weight.size(3));
|
||||
|
||||
TORCH_CHECK(dW > 0 && dH > 0,
|
||||
"stride should be greater than zero, but got dH: %d dW: %d", dH,
|
||||
dW);
|
||||
|
||||
TORCH_CHECK(
|
||||
dilationW > 0 && dilationH > 0,
|
||||
"dilation should be greater than 0, but got dilationH: %d dilationW: %d",
|
||||
dilationH, dilationW);
|
||||
|
||||
int ndim = input.ndimension();
|
||||
int dimf = 0;
|
||||
int dimh = 1;
|
||||
int dimw = 2;
|
||||
|
||||
if (ndim == 4) {
|
||||
dimf++;
|
||||
dimh++;
|
||||
dimw++;
|
||||
}
|
||||
|
||||
TORCH_CHECK(ndim == 3 || ndim == 4,
|
||||
"3D or 4D input tensor expected but got: %s", ndim);
|
||||
|
||||
long nInputPlane = weight.size(1) * group;
|
||||
long inputHeight = input.size(dimh);
|
||||
long inputWidth = input.size(dimw);
|
||||
long nOutputPlane = weight.size(0);
|
||||
long outputHeight =
|
||||
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
|
||||
long outputWidth =
|
||||
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
|
||||
|
||||
TORCH_CHECK(nInputPlane % deformable_group == 0,
|
||||
"input channels must divide deformable group size");
|
||||
|
||||
if (outputWidth < 1 || outputHeight < 1)
|
||||
AT_ERROR(
|
||||
"Given input size: (%ld x %ld x %ld). "
|
||||
"Calculated output size: (%ld x %ld x %ld). Output size is too small",
|
||||
nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight,
|
||||
outputWidth);
|
||||
|
||||
TORCH_CHECK(input.size(1) == nInputPlane,
|
||||
"invalid number of input planes, expected: %d, but got: %d",
|
||||
nInputPlane, input.size(1));
|
||||
|
||||
TORCH_CHECK((inputHeight >= kH && inputWidth >= kW),
|
||||
"input image is smaller than kernel");
|
||||
|
||||
TORCH_CHECK(
|
||||
(offset.size(2) == outputHeight && offset.size(3) == outputWidth),
|
||||
"invalid spatial size of offset, expected height: %d width: %d, but "
|
||||
"got height: %d width: %d",
|
||||
outputHeight, outputWidth, offset.size(2), offset.size(3));
|
||||
|
||||
TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW),
|
||||
"invalid number of channels of offset");
|
||||
|
||||
if (gradOutput != NULL) {
|
||||
TORCH_CHECK(
|
||||
gradOutput->size(dimf) == nOutputPlane,
|
||||
"invalid number of gradOutput planes, expected: %d, but got: %d",
|
||||
nOutputPlane, gradOutput->size(dimf));
|
||||
|
||||
TORCH_CHECK(
|
||||
(gradOutput->size(dimh) == outputHeight &&
|
||||
gradOutput->size(dimw) == outputWidth),
|
||||
"invalid size of gradOutput, expected height: %d width: %d , but "
|
||||
"got height: %d width: %d",
|
||||
outputHeight, outputWidth, gradOutput->size(dimh),
|
||||
gradOutput->size(dimw));
|
||||
}
|
||||
}
|
||||
|
||||
void deform_conv_forward(Tensor input, Tensor weight, Tensor offset,
|
||||
Tensor output, Tensor columns, Tensor ones, int kW,
|
||||
int kH, int dW, int dH, int padW, int padH,
|
||||
|
@ -70,15 +161,118 @@ void deform_conv_forward(Tensor input, Tensor weight, Tensor offset,
|
|||
CHECK_CUDA_INPUT(output);
|
||||
CHECK_CUDA_INPUT(columns);
|
||||
CHECK_CUDA_INPUT(ones);
|
||||
|
||||
deform_conv_forward_cuda(input, weight, offset, output, columns, ones, kW,
|
||||
kH, dW, dH, padW, padH, dilationW, dilationH,
|
||||
group, deformable_group, im2col_step);
|
||||
#else
|
||||
AT_ERROR("DeformConv is not compiled with GPU support");
|
||||
#endif
|
||||
} else {
|
||||
AT_ERROR("DeformConv is not implemented on CPU");
|
||||
CHECK_CPU_INPUT(input);
|
||||
CHECK_CPU_INPUT(offset);
|
||||
CHECK_CPU_INPUT(weight);
|
||||
CHECK_CPU_INPUT(output);
|
||||
CHECK_CPU_INPUT(columns);
|
||||
CHECK_CPU_INPUT(ones);
|
||||
}
|
||||
|
||||
deform_conv_shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH,
|
||||
padW, dilationH, dilationW, group, deformable_group);
|
||||
at::DeviceGuard guard(input.device());
|
||||
|
||||
int batch = 1;
|
||||
if (input.ndimension() == 3) {
|
||||
// Force batch
|
||||
batch = 0;
|
||||
input.unsqueeze_(0);
|
||||
offset.unsqueeze_(0);
|
||||
}
|
||||
|
||||
// todo: assert batchsize dividable by im2col_step
|
||||
|
||||
long batchSize = input.size(0);
|
||||
long nInputPlane = input.size(1);
|
||||
long inputHeight = input.size(2);
|
||||
long inputWidth = input.size(3);
|
||||
|
||||
long nOutputPlane = weight.size(0);
|
||||
|
||||
long outputWidth =
|
||||
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
|
||||
long outputHeight =
|
||||
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
|
||||
|
||||
TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
|
||||
|
||||
output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
|
||||
outputHeight, outputWidth});
|
||||
columns = at::zeros(
|
||||
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
|
||||
input.options());
|
||||
|
||||
if (ones.ndimension() != 2 ||
|
||||
ones.size(0) * ones.size(1) < outputHeight * outputWidth) {
|
||||
ones = at::ones({outputHeight, outputWidth}, input.options());
|
||||
}
|
||||
|
||||
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
|
||||
inputHeight, inputWidth});
|
||||
offset =
|
||||
offset.view({batchSize / im2col_step, im2col_step,
|
||||
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
||||
|
||||
Tensor output_buffer = at::zeros({batchSize / im2col_step, nOutputPlane,
|
||||
im2col_step * outputHeight, outputWidth},
|
||||
output.options());
|
||||
|
||||
output_buffer = output_buffer.view(
|
||||
{output_buffer.size(0), group, output_buffer.size(1) / group,
|
||||
output_buffer.size(2), output_buffer.size(3)});
|
||||
|
||||
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
|
||||
if (input.device().is_cuda()) {
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
|
||||
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
|
||||
dilationW, im2col_step, deformable_group, columns);
|
||||
#endif
|
||||
} else {
|
||||
deformable_im2col_cpu(input[elt], offset[elt], nInputPlane, inputHeight,
|
||||
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
|
||||
dilationW, im2col_step, deformable_group, columns);
|
||||
}
|
||||
|
||||
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
||||
weight = weight.view({group, weight.size(0) / group, weight.size(1),
|
||||
weight.size(2), weight.size(3)});
|
||||
|
||||
for (int g = 0; g < group; g++) {
|
||||
output_buffer[elt][g] = output_buffer[elt][g]
|
||||
.flatten(1)
|
||||
.addmm_(weight[g].flatten(1), columns[g])
|
||||
.view_as(output_buffer[elt][g]);
|
||||
}
|
||||
columns =
|
||||
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
||||
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
|
||||
weight.size(3), weight.size(4)});
|
||||
}
|
||||
|
||||
output_buffer = output_buffer.view(
|
||||
{output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2),
|
||||
output_buffer.size(3), output_buffer.size(4)});
|
||||
|
||||
output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane,
|
||||
im2col_step, outputHeight, outputWidth});
|
||||
output_buffer.transpose_(1, 2);
|
||||
output.copy_(output_buffer);
|
||||
output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth});
|
||||
|
||||
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
|
||||
offset = offset.view(
|
||||
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
||||
|
||||
if (batch == 0) {
|
||||
output = output.view({nOutputPlane, outputHeight, outputWidth});
|
||||
input = input.view({nInputPlane, inputHeight, inputWidth});
|
||||
offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -97,16 +291,134 @@ void deform_conv_backward_input(Tensor input, Tensor offset, Tensor gradOutput,
|
|||
CHECK_CUDA_INPUT(gradOffset);
|
||||
CHECK_CUDA_INPUT(weight);
|
||||
CHECK_CUDA_INPUT(columns);
|
||||
|
||||
deform_conv_backward_input_cuda(input, offset, gradOutput, gradInput,
|
||||
gradOffset, weight, columns, kW, kH, dW, dH,
|
||||
padW, padH, dilationW, dilationH, group,
|
||||
deformable_group, im2col_step);
|
||||
#else
|
||||
AT_ERROR("DeformConv is not compiled with GPU support");
|
||||
#endif
|
||||
} else {
|
||||
AT_ERROR("DeformConv is not implemented on CPU");
|
||||
CHECK_CPU_INPUT(input);
|
||||
CHECK_CPU_INPUT(offset);
|
||||
CHECK_CPU_INPUT(gradOutput);
|
||||
CHECK_CPU_INPUT(gradInput);
|
||||
CHECK_CPU_INPUT(gradOffset);
|
||||
CHECK_CPU_INPUT(weight);
|
||||
CHECK_CPU_INPUT(columns);
|
||||
}
|
||||
deform_conv_shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW,
|
||||
padH, padW, dilationH, dilationW, group,
|
||||
deformable_group);
|
||||
|
||||
at::DeviceGuard guard(input.device());
|
||||
|
||||
int batch = 1;
|
||||
if (input.ndimension() == 3) {
|
||||
// Force batch
|
||||
batch = 0;
|
||||
input = input.view({1, input.size(0), input.size(1), input.size(2)});
|
||||
offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
|
||||
gradOutput = gradOutput.view(
|
||||
{1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
|
||||
}
|
||||
|
||||
long batchSize = input.size(0);
|
||||
long nInputPlane = input.size(1);
|
||||
long inputHeight = input.size(2);
|
||||
long inputWidth = input.size(3);
|
||||
|
||||
long nOutputPlane = weight.size(0);
|
||||
|
||||
long outputWidth =
|
||||
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
|
||||
long outputHeight =
|
||||
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
|
||||
|
||||
TORCH_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset");
|
||||
gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
|
||||
columns = at::zeros(
|
||||
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
|
||||
input.options());
|
||||
|
||||
// change order of grad output
|
||||
gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
|
||||
nOutputPlane, outputHeight, outputWidth});
|
||||
gradOutput.transpose_(1, 2);
|
||||
|
||||
gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane,
|
||||
inputHeight, inputWidth});
|
||||
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
|
||||
inputHeight, inputWidth});
|
||||
gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step,
|
||||
deformable_group * 2 * kH * kW, outputHeight,
|
||||
outputWidth});
|
||||
offset =
|
||||
offset.view({batchSize / im2col_step, im2col_step,
|
||||
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
||||
|
||||
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
|
||||
// divide into groups
|
||||
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
||||
weight = weight.view({group, weight.size(0) / group, weight.size(1),
|
||||
weight.size(2), weight.size(3)});
|
||||
gradOutput = gradOutput.view(
|
||||
{gradOutput.size(0), group, gradOutput.size(1) / group,
|
||||
gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)});
|
||||
|
||||
for (int g = 0; g < group; g++) {
|
||||
columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
|
||||
gradOutput[elt][g].flatten(1), 0.0f, 1.0f);
|
||||
}
|
||||
|
||||
columns =
|
||||
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
||||
gradOutput = gradOutput.view(
|
||||
{gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2),
|
||||
gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)});
|
||||
|
||||
if (input.device().is_cuda()) {
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane,
|
||||
inputHeight, inputWidth, kH, kW, padH, padW, dH,
|
||||
dW, dilationH, dilationW, im2col_step,
|
||||
deformable_group, gradOffset[elt]);
|
||||
|
||||
deformable_col2im(columns, offset[elt], nInputPlane, inputHeight,
|
||||
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
|
||||
dilationW, im2col_step, deformable_group,
|
||||
gradInput[elt]);
|
||||
#endif
|
||||
} else {
|
||||
deformable_col2im_coord_cpu(columns, input[elt], offset[elt], nInputPlane,
|
||||
inputHeight, inputWidth, kH, kW, padH, padW,
|
||||
dH, dW, dilationH, dilationW, im2col_step,
|
||||
deformable_group, gradOffset[elt]);
|
||||
|
||||
deformable_col2im_cpu(columns, offset[elt], nInputPlane, inputHeight,
|
||||
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
|
||||
dilationW, im2col_step, deformable_group,
|
||||
gradInput[elt]);
|
||||
}
|
||||
|
||||
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
|
||||
weight.size(3), weight.size(4)});
|
||||
}
|
||||
|
||||
gradOutput.transpose_(1, 2);
|
||||
gradOutput =
|
||||
gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
|
||||
|
||||
gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
|
||||
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
|
||||
gradOffset = gradOffset.view(
|
||||
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
||||
offset = offset.view(
|
||||
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
||||
|
||||
if (batch == 0) {
|
||||
gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
|
||||
input = input.view({nInputPlane, inputHeight, inputWidth});
|
||||
gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth});
|
||||
offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
|
||||
gradOffset =
|
||||
gradOffset.view({offset.size(1), offset.size(2), offset.size(3)});
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -125,15 +437,122 @@ void deform_conv_backward_parameters(Tensor input, Tensor offset,
|
|||
CHECK_CUDA_INPUT(gradWeight);
|
||||
CHECK_CUDA_INPUT(columns);
|
||||
CHECK_CUDA_INPUT(ones);
|
||||
|
||||
deform_conv_backward_parameters_cuda(input, offset, gradOutput, gradWeight,
|
||||
columns, ones, kW, kH, dW, dH, padW,
|
||||
padH, dilationW, dilationH, group,
|
||||
deformable_group, scale, im2col_step);
|
||||
#else
|
||||
AT_ERROR("DeformConv is not compiled with GPU support");
|
||||
#endif
|
||||
} else {
|
||||
AT_ERROR("DeformConv is not implemented on CPU");
|
||||
CHECK_CPU_INPUT(input);
|
||||
CHECK_CPU_INPUT(offset);
|
||||
CHECK_CPU_INPUT(gradOutput);
|
||||
CHECK_CPU_INPUT(gradWeight);
|
||||
CHECK_CPU_INPUT(columns);
|
||||
CHECK_CPU_INPUT(ones);
|
||||
}
|
||||
|
||||
deform_conv_shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH,
|
||||
dW, padH, padW, dilationH, dilationW, group,
|
||||
deformable_group);
|
||||
at::DeviceGuard guard(input.device());
|
||||
|
||||
int batch = 1;
|
||||
|
||||
if (input.ndimension() == 3) {
|
||||
// Force batch
|
||||
batch = 0;
|
||||
input = input.view(
|
||||
at::IntList({1, input.size(0), input.size(1), input.size(2)}));
|
||||
gradOutput = gradOutput.view(
|
||||
{1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
|
||||
}
|
||||
|
||||
long batchSize = input.size(0);
|
||||
long nInputPlane = input.size(1);
|
||||
long inputHeight = input.size(2);
|
||||
long inputWidth = input.size(3);
|
||||
|
||||
long nOutputPlane = gradWeight.size(0);
|
||||
|
||||
long outputWidth =
|
||||
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
|
||||
long outputHeight =
|
||||
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
|
||||
|
||||
TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
|
||||
|
||||
columns = at::zeros(
|
||||
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
|
||||
input.options());
|
||||
|
||||
gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
|
||||
nOutputPlane, outputHeight, outputWidth});
|
||||
gradOutput.transpose_(1, 2);
|
||||
|
||||
Tensor gradOutputBuffer = at::zeros_like(gradOutput);
|
||||
gradOutputBuffer =
|
||||
gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step,
|
||||
outputHeight, outputWidth});
|
||||
gradOutputBuffer = gradOutputBuffer.contiguous();
|
||||
gradOutputBuffer.copy_(gradOutput);
|
||||
gradOutputBuffer =
|
||||
gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane,
|
||||
im2col_step * outputHeight, outputWidth});
|
||||
|
||||
gradOutput.transpose_(1, 2);
|
||||
gradOutput =
|
||||
gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
|
||||
|
||||
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
|
||||
inputHeight, inputWidth});
|
||||
offset =
|
||||
offset.view({batchSize / im2col_step, im2col_step,
|
||||
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
||||
|
||||
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
|
||||
if (input.device().is_cuda()) {
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
|
||||
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
|
||||
dilationW, im2col_step, deformable_group, columns);
|
||||
#endif
|
||||
} else {
|
||||
deformable_im2col_cpu(input[elt], offset[elt], nInputPlane, inputHeight,
|
||||
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
|
||||
dilationW, im2col_step, deformable_group, columns);
|
||||
}
|
||||
|
||||
// divide into group
|
||||
gradOutputBuffer = gradOutputBuffer.view(
|
||||
{gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group,
|
||||
gradOutputBuffer.size(2), gradOutputBuffer.size(3)});
|
||||
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
||||
gradWeight =
|
||||
gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1),
|
||||
gradWeight.size(2), gradWeight.size(3)});
|
||||
|
||||
for (int g = 0; g < group; g++) {
|
||||
gradWeight[g] = gradWeight[g]
|
||||
.flatten(1)
|
||||
.addmm_(gradOutputBuffer[elt][g].flatten(1),
|
||||
columns[g].transpose(1, 0), 1.0, scale)
|
||||
.view_as(gradWeight[g]);
|
||||
}
|
||||
gradOutputBuffer = gradOutputBuffer.view(
|
||||
{gradOutputBuffer.size(0),
|
||||
gradOutputBuffer.size(1) * gradOutputBuffer.size(2),
|
||||
gradOutputBuffer.size(3), gradOutputBuffer.size(4)});
|
||||
columns =
|
||||
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
||||
gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1),
|
||||
gradWeight.size(2), gradWeight.size(3),
|
||||
gradWeight.size(4)});
|
||||
}
|
||||
|
||||
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
|
||||
offset = offset.view(
|
||||
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
||||
|
||||
if (batch == 0) {
|
||||
gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
|
||||
input = input.view({nInputPlane, inputHeight, inputWidth});
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,377 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved
|
||||
#include "pytorch_cpp_helper.hpp"
|
||||
|
||||
template <typename T>
|
||||
T deformable_im2col_bilinear_cpu(const T *input, const int data_width,
|
||||
const int height, const int width, T h, T w) {
|
||||
if (h <= -1 || height <= h || w <= -1 || width <= w) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
int h_low = floor(h);
|
||||
int w_low = floor(w);
|
||||
int h_high = h_low + 1;
|
||||
int w_high = w_low + 1;
|
||||
|
||||
T lh = h - h_low;
|
||||
T lw = w - w_low;
|
||||
T hh = 1 - lh, hw = 1 - lw;
|
||||
|
||||
T v1 = 0;
|
||||
if (h_low >= 0 && w_low >= 0) v1 = input[h_low * data_width + w_low];
|
||||
T v2 = 0;
|
||||
if (h_low >= 0 && w_high <= width - 1)
|
||||
v2 = input[h_low * data_width + w_high];
|
||||
T v3 = 0;
|
||||
if (h_high <= height - 1 && w_low >= 0)
|
||||
v3 = input[h_high * data_width + w_low];
|
||||
T v4 = 0;
|
||||
if (h_high <= height - 1 && w_high <= width - 1)
|
||||
v4 = input[h_high * data_width + w_high];
|
||||
|
||||
T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
|
||||
|
||||
T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
|
||||
return val;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T get_gradient_weight_cpu(T argmax_h, T argmax_w, const int h, const int w,
|
||||
const int height, const int width) {
|
||||
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 ||
|
||||
argmax_w >= width) {
|
||||
// empty
|
||||
return 0;
|
||||
}
|
||||
|
||||
int argmax_h_low = floor(argmax_h);
|
||||
int argmax_w_low = floor(argmax_w);
|
||||
int argmax_h_high = argmax_h_low + 1;
|
||||
int argmax_w_high = argmax_w_low + 1;
|
||||
|
||||
T weight = 0;
|
||||
if (h == argmax_h_low && w == argmax_w_low)
|
||||
weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
|
||||
if (h == argmax_h_low && w == argmax_w_high)
|
||||
weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
|
||||
if (h == argmax_h_high && w == argmax_w_low)
|
||||
weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
|
||||
if (h == argmax_h_high && w == argmax_w_high)
|
||||
weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
|
||||
return weight;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T get_coordinate_weight_cpu(T argmax_h, T argmax_w, const int height,
|
||||
const int width, const T *im_data,
|
||||
const int data_width, const int bp_dir) {
|
||||
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 ||
|
||||
argmax_w >= width) {
|
||||
// empty
|
||||
return 0;
|
||||
}
|
||||
|
||||
int argmax_h_low = floor(argmax_h);
|
||||
int argmax_w_low = floor(argmax_w);
|
||||
int argmax_h_high = argmax_h_low + 1;
|
||||
int argmax_w_high = argmax_w_low + 1;
|
||||
|
||||
T weight = 0;
|
||||
|
||||
if (bp_dir == 0) {
|
||||
if (argmax_h_low >= 0 && argmax_w_low >= 0)
|
||||
weight += -1 * (argmax_w_low + 1 - argmax_w) *
|
||||
im_data[argmax_h_low * data_width + argmax_w_low];
|
||||
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
|
||||
weight += -1 * (argmax_w - argmax_w_low) *
|
||||
im_data[argmax_h_low * data_width + argmax_w_high];
|
||||
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
|
||||
weight += (argmax_w_low + 1 - argmax_w) *
|
||||
im_data[argmax_h_high * data_width + argmax_w_low];
|
||||
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
|
||||
weight += (argmax_w - argmax_w_low) *
|
||||
im_data[argmax_h_high * data_width + argmax_w_high];
|
||||
} else if (bp_dir == 1) {
|
||||
if (argmax_h_low >= 0 && argmax_w_low >= 0)
|
||||
weight += -1 * (argmax_h_low + 1 - argmax_h) *
|
||||
im_data[argmax_h_low * data_width + argmax_w_low];
|
||||
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
|
||||
weight += (argmax_h_low + 1 - argmax_h) *
|
||||
im_data[argmax_h_low * data_width + argmax_w_high];
|
||||
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
|
||||
weight += -1 * (argmax_h - argmax_h_low) *
|
||||
im_data[argmax_h_high * data_width + argmax_w_low];
|
||||
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
|
||||
weight += (argmax_h - argmax_h_low) *
|
||||
im_data[argmax_h_high * data_width + argmax_w_high];
|
||||
}
|
||||
|
||||
return weight;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void deformable_im2col_cpu_kernel(
|
||||
const int n, const T *data_im, const T *data_offset, const int height,
|
||||
const int width, const int kernel_h, const int kernel_w, const int pad_h,
|
||||
const int pad_w, const int stride_h, const int stride_w,
|
||||
const int dilation_h, const int dilation_w,
|
||||
const int channel_per_deformable_group, const int batch_size,
|
||||
const int num_channels, const int deformable_group, const int height_col,
|
||||
const int width_col, T *data_col) {
|
||||
for (int index = 0; index < n; index++) {
|
||||
// index index of output matrix
|
||||
const int w_col = index % width_col;
|
||||
const int h_col = (index / width_col) % height_col;
|
||||
const int b_col = (index / width_col / height_col) % batch_size;
|
||||
const int c_im = (index / width_col / height_col) / batch_size;
|
||||
const int c_col = c_im * kernel_h * kernel_w;
|
||||
|
||||
// compute deformable group index
|
||||
const int deformable_group_index = c_im / channel_per_deformable_group;
|
||||
|
||||
const int h_in = h_col * stride_h - pad_h;
|
||||
const int w_in = w_col * stride_w - pad_w;
|
||||
T *data_col_ptr =
|
||||
data_col +
|
||||
((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
|
||||
const T *data_im_ptr =
|
||||
data_im + (b_col * num_channels + c_im) * height * width;
|
||||
const T *data_offset_ptr =
|
||||
data_offset + (b_col * deformable_group + deformable_group_index) * 2 *
|
||||
kernel_h * kernel_w * height_col * width_col;
|
||||
|
||||
for (int i = 0; i < kernel_h; ++i) {
|
||||
for (int j = 0; j < kernel_w; ++j) {
|
||||
const int data_offset_h_ptr =
|
||||
((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
|
||||
const int data_offset_w_ptr =
|
||||
((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col +
|
||||
w_col;
|
||||
const T offset_h = data_offset_ptr[data_offset_h_ptr];
|
||||
const T offset_w = data_offset_ptr[data_offset_w_ptr];
|
||||
T val = static_cast<T>(0);
|
||||
const T h_im = h_in + i * dilation_h + offset_h;
|
||||
const T w_im = w_in + j * dilation_w + offset_w;
|
||||
if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
|
||||
val = deformable_im2col_bilinear_cpu(data_im_ptr, width, height,
|
||||
width, h_im, w_im);
|
||||
*data_col_ptr = val;
|
||||
data_col_ptr += batch_size * height_col * width_col;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void deformable_col2im_cpu_kernel(
|
||||
const int n, const T *data_col, const T *data_offset, const int channels,
|
||||
const int height, const int width, const int kernel_h, const int kernel_w,
|
||||
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
||||
const int dilation_h, const int dilation_w,
|
||||
const int channel_per_deformable_group, const int batch_size,
|
||||
const int deformable_group, const int height_col, const int width_col,
|
||||
T *grad_im) {
|
||||
for (int index = 0; index < n; index++) {
|
||||
const int j = (index / width_col / height_col / batch_size) % kernel_w;
|
||||
const int i =
|
||||
(index / width_col / height_col / batch_size / kernel_w) % kernel_h;
|
||||
const int c =
|
||||
index / width_col / height_col / batch_size / kernel_w / kernel_h;
|
||||
// compute the start and end of the output
|
||||
|
||||
const int deformable_group_index = c / channel_per_deformable_group;
|
||||
|
||||
int w_out = index % width_col;
|
||||
int h_out = (index / width_col) % height_col;
|
||||
int b = (index / width_col / height_col) % batch_size;
|
||||
int w_in = w_out * stride_w - pad_w;
|
||||
int h_in = h_out * stride_h - pad_h;
|
||||
|
||||
const T *data_offset_ptr =
|
||||
data_offset + (b * deformable_group + deformable_group_index) * 2 *
|
||||
kernel_h * kernel_w * height_col * width_col;
|
||||
const int data_offset_h_ptr =
|
||||
((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
|
||||
const int data_offset_w_ptr =
|
||||
((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
|
||||
const T offset_h = data_offset_ptr[data_offset_h_ptr];
|
||||
const T offset_w = data_offset_ptr[data_offset_w_ptr];
|
||||
const T cur_inv_h_data = h_in + i * dilation_h + offset_h;
|
||||
const T cur_inv_w_data = w_in + j * dilation_w + offset_w;
|
||||
|
||||
const T cur_top_grad = data_col[index];
|
||||
const int cur_h = (int)cur_inv_h_data;
|
||||
const int cur_w = (int)cur_inv_w_data;
|
||||
for (int dy = -2; dy <= 2; dy++) {
|
||||
for (int dx = -2; dx <= 2; dx++) {
|
||||
if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 &&
|
||||
cur_w + dx < width && abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
|
||||
abs(cur_inv_w_data - (cur_w + dx)) < 1) {
|
||||
int cur_bottom_grad_pos =
|
||||
((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
|
||||
T weight =
|
||||
get_gradient_weight_cpu(cur_inv_h_data, cur_inv_w_data,
|
||||
cur_h + dy, cur_w + dx, height, width);
|
||||
*(grad_im + cur_bottom_grad_pos) += weight * cur_top_grad;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void deformable_col2im_coord_cpu_kernel(
|
||||
const int n, const T *data_col, const T *data_im, const T *data_offset,
|
||||
const int channels, const int height, const int width, const int kernel_h,
|
||||
const int kernel_w, const int pad_h, const int pad_w, const int stride_h,
|
||||
const int stride_w, const int dilation_h, const int dilation_w,
|
||||
const int channel_per_deformable_group, const int batch_size,
|
||||
const int offset_channels, const int deformable_group, const int height_col,
|
||||
const int width_col, T *grad_offset) {
|
||||
for (int index = 0; index < n; index++) {
|
||||
T val = 0;
|
||||
int w = index % width_col;
|
||||
int h = (index / width_col) % height_col;
|
||||
int c = (index / width_col / height_col) % offset_channels;
|
||||
int b = (index / width_col / height_col) / offset_channels;
|
||||
// compute the start and end of the output
|
||||
|
||||
const int deformable_group_index = c / (2 * kernel_h * kernel_w);
|
||||
const int col_step = kernel_h * kernel_w;
|
||||
int cnt = 0;
|
||||
const T *data_col_ptr = data_col + deformable_group_index *
|
||||
channel_per_deformable_group *
|
||||
batch_size * width_col * height_col;
|
||||
const T *data_im_ptr =
|
||||
data_im + (b * deformable_group + deformable_group_index) *
|
||||
channel_per_deformable_group / kernel_h / kernel_w *
|
||||
height * width;
|
||||
const T *data_offset_ptr =
|
||||
data_offset + (b * deformable_group + deformable_group_index) * 2 *
|
||||
kernel_h * kernel_w * height_col * width_col;
|
||||
|
||||
const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
|
||||
|
||||
for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group;
|
||||
col_c += col_step) {
|
||||
const int col_pos =
|
||||
(((col_c * batch_size + b) * height_col) + h) * width_col + w;
|
||||
const int bp_dir = offset_c % 2;
|
||||
|
||||
int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
|
||||
int i =
|
||||
(col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
|
||||
int w_out = col_pos % width_col;
|
||||
int h_out = (col_pos / width_col) % height_col;
|
||||
int w_in = w_out * stride_w - pad_w;
|
||||
int h_in = h_out * stride_h - pad_h;
|
||||
const int data_offset_h_ptr =
|
||||
(((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
|
||||
const int data_offset_w_ptr =
|
||||
(((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col +
|
||||
w_out);
|
||||
const T offset_h = data_offset_ptr[data_offset_h_ptr];
|
||||
const T offset_w = data_offset_ptr[data_offset_w_ptr];
|
||||
T inv_h = h_in + i * dilation_h + offset_h;
|
||||
T inv_w = w_in + j * dilation_w + offset_w;
|
||||
if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
|
||||
inv_h = inv_w = -2;
|
||||
const T weight = get_coordinate_weight_cpu(
|
||||
inv_h, inv_w, height, width, data_im_ptr + cnt * height * width,
|
||||
width, bp_dir);
|
||||
val += weight * data_col_ptr[col_pos];
|
||||
cnt += 1;
|
||||
}
|
||||
|
||||
grad_offset[index] = val;
|
||||
}
|
||||
}
|
||||
|
||||
void deformable_im2col_cpu(Tensor data_im, Tensor data_offset,
|
||||
const int channels, const int height,
|
||||
const int width, const int ksize_h,
|
||||
const int ksize_w, const int pad_h, const int pad_w,
|
||||
const int stride_h, const int stride_w,
|
||||
const int dilation_h, const int dilation_w,
|
||||
const int parallel_imgs, const int deformable_group,
|
||||
Tensor data_col) {
|
||||
int height_col =
|
||||
(height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
|
||||
int width_col =
|
||||
(width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
|
||||
int num_kernels = channels * height_col * width_col * parallel_imgs;
|
||||
int channel_per_deformable_group = channels / deformable_group;
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
data_im.scalar_type(), "deformable_im2col_cpu", [&] {
|
||||
deformable_im2col_cpu_kernel<scalar_t>(
|
||||
num_kernels, data_im.data_ptr<scalar_t>(),
|
||||
data_offset.data_ptr<scalar_t>(), height, width, ksize_h, ksize_w,
|
||||
pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
|
||||
channel_per_deformable_group, parallel_imgs, channels,
|
||||
deformable_group, height_col, width_col,
|
||||
data_col.data_ptr<scalar_t>());
|
||||
});
|
||||
}
|
||||
|
||||
void deformable_col2im_cpu(Tensor data_col, Tensor data_offset,
|
||||
const int channels, const int height,
|
||||
const int width, const int ksize_h,
|
||||
const int ksize_w, const int pad_h, const int pad_w,
|
||||
const int stride_h, const int stride_w,
|
||||
const int dilation_h, const int dilation_w,
|
||||
const int parallel_imgs, const int deformable_group,
|
||||
Tensor grad_im) {
|
||||
// todo: make sure parallel_imgs is passed in correctly
|
||||
int height_col =
|
||||
(height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
|
||||
int width_col =
|
||||
(width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
|
||||
int num_kernels =
|
||||
channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs;
|
||||
int channel_per_deformable_group = channels / deformable_group;
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
data_col.scalar_type(), "deformable_col2im_gpu", ([&] {
|
||||
const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
|
||||
const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
|
||||
scalar_t *grad_im_ = grad_im.data_ptr<scalar_t>();
|
||||
|
||||
deformable_col2im_cpu_kernel<scalar_t>(
|
||||
num_kernels, data_col_, data_offset_, channels, height, width,
|
||||
ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w, dilation_h,
|
||||
dilation_w, channel_per_deformable_group, parallel_imgs,
|
||||
deformable_group, height_col, width_col, grad_im_);
|
||||
}));
|
||||
}
|
||||
|
||||
void deformable_col2im_coord_cpu(
|
||||
Tensor data_col, Tensor data_im, Tensor data_offset, const int channels,
|
||||
const int height, const int width, const int ksize_h, const int ksize_w,
|
||||
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
||||
const int dilation_h, const int dilation_w, const int parallel_imgs,
|
||||
const int deformable_group, Tensor grad_offset) {
|
||||
int height_col =
|
||||
(height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
|
||||
int width_col =
|
||||
(width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
|
||||
int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w *
|
||||
deformable_group * parallel_imgs;
|
||||
int channel_per_deformable_group =
|
||||
channels * ksize_h * ksize_w / deformable_group;
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
data_col.scalar_type(), "deformable_col2im_coord_cpu", ([&] {
|
||||
const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
|
||||
const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
|
||||
const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
|
||||
scalar_t *grad_offset_ = grad_offset.data_ptr<scalar_t>();
|
||||
|
||||
deformable_col2im_coord_cpu_kernel<scalar_t>(
|
||||
num_kernels, data_col_, data_im_, data_offset_, channels, height,
|
||||
width, ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w,
|
||||
dilation_h, dilation_w, channel_per_deformable_group, parallel_imgs,
|
||||
2 * ksize_h * ksize_w * deformable_group, deformable_group,
|
||||
height_col, width_col, grad_offset_);
|
||||
}));
|
||||
}
|
|
@ -2,48 +2,59 @@
|
|||
#include "pytorch_cpp_helper.hpp"
|
||||
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
void ModulatedDeformConvForwardCUDAKernelLauncher(
|
||||
Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset,
|
||||
Tensor mask, Tensor output, Tensor columns, int kernel_h, int kernel_w,
|
||||
const int stride_h, const int stride_w, const int pad_h, const int pad_w,
|
||||
const int dilation_h, const int dilation_w, const int group,
|
||||
const int deformable_group, const bool with_bias);
|
||||
|
||||
void ModulatedDeformConvBackwardCUDAKernelLauncher(
|
||||
Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset,
|
||||
Tensor mask, Tensor columns, Tensor grad_input, Tensor grad_weight,
|
||||
Tensor grad_bias, Tensor grad_offset, Tensor grad_mask, Tensor grad_output,
|
||||
int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
|
||||
int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
|
||||
const bool with_bias);
|
||||
void modulated_deformable_im2col_cuda(
|
||||
const Tensor data_im, const Tensor data_offset, const Tensor data_mask,
|
||||
const int batch_size, const int channels, const int height_im,
|
||||
const int width_im, const int height_col, const int width_col,
|
||||
const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w,
|
||||
const int stride_h, const int stride_w, const int dilation_h,
|
||||
const int dilation_w, const int deformable_group, Tensor data_col);
|
||||
|
||||
void modulated_deform_conv_forward_cuda(
|
||||
Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset,
|
||||
Tensor mask, Tensor output, Tensor columns, int kernel_h, int kernel_w,
|
||||
const int stride_h, const int stride_w, const int pad_h, const int pad_w,
|
||||
const int dilation_h, const int dilation_w, const int group,
|
||||
const int deformable_group, const bool with_bias) {
|
||||
ModulatedDeformConvForwardCUDAKernelLauncher(
|
||||
input, weight, bias, ones, offset, mask, output, columns, kernel_h,
|
||||
kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
|
||||
deformable_group, with_bias);
|
||||
}
|
||||
void modulated_deformable_col2im_cuda(
|
||||
const Tensor data_col, const Tensor data_offset, const Tensor data_mask,
|
||||
const int batch_size, const int channels, const int height_im,
|
||||
const int width_im, const int height_col, const int width_col,
|
||||
const int kernel_h, const int kernel_w, const int pad_h, const int pad_w,
|
||||
const int stride_h, const int stride_w, const int dilation_h,
|
||||
const int dilation_w, const int deformable_group, Tensor grad_im);
|
||||
|
||||
void modulated_deformable_col2im_coord_cuda(
|
||||
const Tensor data_col, const Tensor data_im, const Tensor data_offset,
|
||||
const Tensor data_mask, const int batch_size, const int channels,
|
||||
const int height_im, const int width_im, const int height_col,
|
||||
const int width_col, const int kernel_h, const int kernel_w,
|
||||
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
||||
const int dilation_h, const int dilation_w, const int deformable_group,
|
||||
Tensor grad_offset, Tensor grad_mask);
|
||||
|
||||
void modulated_deform_conv_backward_cuda(
|
||||
Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset,
|
||||
Tensor mask, Tensor columns, Tensor grad_input, Tensor grad_weight,
|
||||
Tensor grad_bias, Tensor grad_offset, Tensor grad_mask, Tensor grad_output,
|
||||
int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
|
||||
int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
|
||||
const bool with_bias) {
|
||||
ModulatedDeformConvBackwardCUDAKernelLauncher(
|
||||
input, weight, bias, ones, offset, mask, columns, grad_input, grad_weight,
|
||||
grad_bias, grad_offset, grad_mask, grad_output, kernel_h, kernel_w,
|
||||
stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
|
||||
deformable_group, with_bias);
|
||||
}
|
||||
#endif
|
||||
|
||||
void modulated_deformable_im2col_cpu(
|
||||
const Tensor data_im, const Tensor data_offset, const Tensor data_mask,
|
||||
const int batch_size, const int channels, const int height_im,
|
||||
const int width_im, const int height_col, const int width_col,
|
||||
const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w,
|
||||
const int stride_h, const int stride_w, const int dilation_h,
|
||||
const int dilation_w, const int deformable_group, Tensor data_col);
|
||||
|
||||
void modulated_deformable_col2im_cpu(
|
||||
const Tensor data_col, const Tensor data_offset, const Tensor data_mask,
|
||||
const int batch_size, const int channels, const int height_im,
|
||||
const int width_im, const int height_col, const int width_col,
|
||||
const int kernel_h, const int kernel_w, const int pad_h, const int pad_w,
|
||||
const int stride_h, const int stride_w, const int dilation_h,
|
||||
const int dilation_w, const int deformable_group, Tensor grad_im);
|
||||
|
||||
void modulated_deformable_col2im_coord_cpu(
|
||||
const Tensor data_col, const Tensor data_im, const Tensor data_offset,
|
||||
const Tensor data_mask, const int batch_size, const int channels,
|
||||
const int height_im, const int width_im, const int height_col,
|
||||
const int width_col, const int kernel_h, const int kernel_w,
|
||||
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
||||
const int dilation_h, const int dilation_w, const int deformable_group,
|
||||
Tensor grad_offset, Tensor grad_mask);
|
||||
|
||||
void modulated_deform_conv_forward(
|
||||
Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset,
|
||||
Tensor mask, Tensor output, Tensor columns, int kernel_h, int kernel_w,
|
||||
|
@ -61,15 +72,98 @@ void modulated_deform_conv_forward(
|
|||
CHECK_CUDA_INPUT(output);
|
||||
CHECK_CUDA_INPUT(columns);
|
||||
|
||||
modulated_deform_conv_forward_cuda(
|
||||
input, weight, bias, ones, offset, mask, output, columns, kernel_h,
|
||||
kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w,
|
||||
group, deformable_group, with_bias);
|
||||
#else
|
||||
AT_ERROR("ModulatedDeformConv is not compiled with GPU support");
|
||||
#endif
|
||||
} else {
|
||||
AT_ERROR("ModulatedDeformConv is not implemented on CPU");
|
||||
CHECK_CPU_INPUT(input);
|
||||
CHECK_CPU_INPUT(weight);
|
||||
CHECK_CPU_INPUT(bias);
|
||||
CHECK_CPU_INPUT(ones);
|
||||
CHECK_CPU_INPUT(offset);
|
||||
CHECK_CPU_INPUT(mask);
|
||||
CHECK_CPU_INPUT(output);
|
||||
CHECK_CPU_INPUT(columns);
|
||||
}
|
||||
|
||||
at::DeviceGuard guard(input.device());
|
||||
|
||||
const int batch = input.size(0);
|
||||
const int channels = input.size(1);
|
||||
const int height = input.size(2);
|
||||
const int width = input.size(3);
|
||||
|
||||
const int channels_out = weight.size(0);
|
||||
const int channels_kernel = weight.size(1);
|
||||
const int kernel_h_ = weight.size(2);
|
||||
const int kernel_w_ = weight.size(3);
|
||||
|
||||
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
|
||||
AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
|
||||
kernel_h_, kernel_w, kernel_h_, kernel_w_);
|
||||
if (channels != channels_kernel * group)
|
||||
AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
|
||||
channels, channels_kernel * group);
|
||||
|
||||
const int height_out =
|
||||
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
|
||||
const int width_out =
|
||||
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
|
||||
|
||||
if (ones.ndimension() != 2 ||
|
||||
ones.size(0) * ones.size(1) < height_out * width_out) {
|
||||
// Resize plane and fill with ones...
|
||||
ones = at::ones({height_out, width_out}, input.options());
|
||||
}
|
||||
|
||||
// resize output
|
||||
output = output.view({batch, channels_out, height_out, width_out}).zero_();
|
||||
// resize temporary columns
|
||||
columns =
|
||||
at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},
|
||||
input.options());
|
||||
|
||||
output = output.view({output.size(0), group, output.size(1) / group,
|
||||
output.size(2), output.size(3)});
|
||||
|
||||
for (int b = 0; b < batch; b++) {
|
||||
if (input.device().is_cuda()) {
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
modulated_deformable_im2col_cuda(
|
||||
input[b], offset[b], mask[b], 1, channels, height, width, height_out,
|
||||
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
|
||||
dilation_h, dilation_w, deformable_group, columns);
|
||||
#endif
|
||||
} else {
|
||||
modulated_deformable_im2col_cpu(
|
||||
input[b], offset[b], mask[b], 1, channels, height, width, height_out,
|
||||
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
|
||||
dilation_h, dilation_w, deformable_group, columns);
|
||||
}
|
||||
|
||||
// divide into group
|
||||
weight = weight.view({group, weight.size(0) / group, weight.size(1),
|
||||
weight.size(2), weight.size(3)});
|
||||
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
||||
|
||||
for (int g = 0; g < group; g++) {
|
||||
output[b][g] = output[b][g]
|
||||
.flatten(1)
|
||||
.addmm_(weight[g].flatten(1), columns[g])
|
||||
.view_as(output[b][g]);
|
||||
}
|
||||
|
||||
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
|
||||
weight.size(3), weight.size(4)});
|
||||
columns =
|
||||
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
||||
}
|
||||
|
||||
output = output.view({output.size(0), output.size(1) * output.size(2),
|
||||
output.size(3), output.size(4)});
|
||||
|
||||
if (with_bias) {
|
||||
output += bias.view({1, bias.size(0), 1, 1});
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -96,15 +190,149 @@ void modulated_deform_conv_backward(
|
|||
CHECK_CUDA_INPUT(grad_mask);
|
||||
CHECK_CUDA_INPUT(grad_output);
|
||||
|
||||
modulated_deform_conv_backward_cuda(
|
||||
input, weight, bias, ones, offset, mask, columns, grad_input,
|
||||
grad_weight, grad_bias, grad_offset, grad_mask, grad_output, kernel_h,
|
||||
kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w,
|
||||
group, deformable_group, with_bias);
|
||||
#else
|
||||
AT_ERROR("ModulatedDeformConv is not compiled with GPU support");
|
||||
#endif
|
||||
} else {
|
||||
AT_ERROR("ModulatedDeformConv is not implemented on CPU");
|
||||
CHECK_CPU_INPUT(input);
|
||||
CHECK_CPU_INPUT(weight);
|
||||
CHECK_CPU_INPUT(bias);
|
||||
CHECK_CPU_INPUT(ones);
|
||||
CHECK_CPU_INPUT(offset);
|
||||
CHECK_CPU_INPUT(mask);
|
||||
CHECK_CPU_INPUT(columns);
|
||||
CHECK_CPU_INPUT(grad_input);
|
||||
CHECK_CPU_INPUT(grad_weight);
|
||||
CHECK_CPU_INPUT(grad_bias);
|
||||
CHECK_CPU_INPUT(grad_offset);
|
||||
CHECK_CPU_INPUT(grad_mask);
|
||||
CHECK_CPU_INPUT(grad_output);
|
||||
}
|
||||
|
||||
at::DeviceGuard guard(input.device());
|
||||
|
||||
const int batch = input.size(0);
|
||||
const int channels = input.size(1);
|
||||
const int height = input.size(2);
|
||||
const int width = input.size(3);
|
||||
|
||||
const int channels_kernel = weight.size(1);
|
||||
const int kernel_h_ = weight.size(2);
|
||||
const int kernel_w_ = weight.size(3);
|
||||
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
|
||||
AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
|
||||
kernel_h_, kernel_w, kernel_h_, kernel_w_);
|
||||
if (channels != channels_kernel * group)
|
||||
AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
|
||||
channels, channels_kernel * group);
|
||||
|
||||
const int height_out =
|
||||
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
|
||||
const int width_out =
|
||||
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
|
||||
|
||||
if (ones.ndimension() != 2 ||
|
||||
ones.size(0) * ones.size(1) < height_out * width_out) {
|
||||
// Resize plane and fill with ones...
|
||||
ones = at::ones({height_out, width_out}, input.options());
|
||||
}
|
||||
|
||||
grad_input = grad_input.view({batch, channels, height, width});
|
||||
columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},
|
||||
input.options());
|
||||
|
||||
grad_output =
|
||||
grad_output.view({grad_output.size(0), group, grad_output.size(1) / group,
|
||||
grad_output.size(2), grad_output.size(3)});
|
||||
|
||||
for (int b = 0; b < batch; b++) {
|
||||
// divide int group
|
||||
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
||||
weight = weight.view({group, weight.size(0) / group, weight.size(1),
|
||||
weight.size(2), weight.size(3)});
|
||||
|
||||
for (int g = 0; g < group; g++) {
|
||||
columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
|
||||
grad_output[b][g].flatten(1), 0.0f, 1.0f);
|
||||
}
|
||||
|
||||
columns =
|
||||
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
||||
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
|
||||
weight.size(3), weight.size(4)});
|
||||
|
||||
if (input.device().is_cuda()) {
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
// gradient w.r.t. input coordinate data
|
||||
modulated_deformable_col2im_coord_cuda(
|
||||
columns, input[b], offset[b], mask[b], 1, channels, height, width,
|
||||
height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
|
||||
stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
|
||||
grad_mask[b]);
|
||||
// gradient w.r.t. input data
|
||||
modulated_deformable_col2im_cuda(
|
||||
columns, offset[b], mask[b], 1, channels, height, width, height_out,
|
||||
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
|
||||
dilation_h, dilation_w, deformable_group, grad_input[b]);
|
||||
|
||||
// gradient w.r.t. weight, dWeight should accumulate across the batch and
|
||||
// group
|
||||
modulated_deformable_im2col_cuda(
|
||||
input[b], offset[b], mask[b], 1, channels, height, width, height_out,
|
||||
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
|
||||
dilation_h, dilation_w, deformable_group, columns);
|
||||
#endif
|
||||
} else {
|
||||
// gradient w.r.t. input coordinate data
|
||||
modulated_deformable_col2im_coord_cpu(
|
||||
columns, input[b], offset[b], mask[b], 1, channels, height, width,
|
||||
height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
|
||||
stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
|
||||
grad_mask[b]);
|
||||
// gradient w.r.t. input data
|
||||
modulated_deformable_col2im_cpu(
|
||||
columns, offset[b], mask[b], 1, channels, height, width, height_out,
|
||||
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
|
||||
dilation_h, dilation_w, deformable_group, grad_input[b]);
|
||||
// gradient w.r.t. weight, dWeight should accumulate across the batch and
|
||||
// group
|
||||
modulated_deformable_im2col_cpu(
|
||||
input[b], offset[b], mask[b], 1, channels, height, width, height_out,
|
||||
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
|
||||
dilation_h, dilation_w, deformable_group, columns);
|
||||
}
|
||||
|
||||
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
||||
grad_weight = grad_weight.view({group, grad_weight.size(0) / group,
|
||||
grad_weight.size(1), grad_weight.size(2),
|
||||
grad_weight.size(3)});
|
||||
if (with_bias)
|
||||
grad_bias = grad_bias.view({group, grad_bias.size(0) / group});
|
||||
|
||||
for (int g = 0; g < group; g++) {
|
||||
grad_weight[g] =
|
||||
grad_weight[g]
|
||||
.flatten(1)
|
||||
.addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
|
||||
.view_as(grad_weight[g]);
|
||||
if (with_bias) {
|
||||
grad_bias[g] =
|
||||
grad_bias[g]
|
||||
.view({-1, 1})
|
||||
.addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
|
||||
.view(-1);
|
||||
}
|
||||
}
|
||||
|
||||
columns =
|
||||
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
||||
grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
|
||||
grad_weight.size(2), grad_weight.size(3),
|
||||
grad_weight.size(4)});
|
||||
if (with_bias)
|
||||
grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
|
||||
}
|
||||
grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1),
|
||||
grad_output.size(2), grad_output.size(3),
|
||||
grad_output.size(4)});
|
||||
}
|
||||
|
|
|
@ -0,0 +1,403 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved
|
||||
#include "pytorch_cpp_helper.hpp"
|
||||
|
||||
template <typename T>
|
||||
T dmcn_im2col_bilinear_cpu(const T *input, const int data_width,
|
||||
const int height, const int width, T h, T w) {
|
||||
int h_low = floorf(h);
|
||||
int w_low = floorf(w);
|
||||
int h_high = h_low + 1;
|
||||
int w_high = w_low + 1;
|
||||
|
||||
T lh = h - h_low;
|
||||
T lw = w - w_low;
|
||||
T hh = 1 - lh, hw = 1 - lw;
|
||||
|
||||
T v1 = 0;
|
||||
if (h_low >= 0 && w_low >= 0) v1 = input[h_low * data_width + w_low];
|
||||
T v2 = 0;
|
||||
if (h_low >= 0 && w_high <= width - 1)
|
||||
v2 = input[h_low * data_width + w_high];
|
||||
T v3 = 0;
|
||||
if (h_high <= height - 1 && w_low >= 0)
|
||||
v3 = input[h_high * data_width + w_low];
|
||||
T v4 = 0;
|
||||
if (h_high <= height - 1 && w_high <= width - 1)
|
||||
v4 = input[h_high * data_width + w_high];
|
||||
|
||||
T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
|
||||
|
||||
T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
|
||||
return val;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T dmcn_get_gradient_weight_cpu(T argmax_h, T argmax_w, const int h, const int w,
|
||||
const int height, const int width) {
|
||||
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 ||
|
||||
argmax_w >= width) {
|
||||
// empty
|
||||
return 0;
|
||||
}
|
||||
|
||||
int argmax_h_low = floorf(argmax_h);
|
||||
int argmax_w_low = floorf(argmax_w);
|
||||
int argmax_h_high = argmax_h_low + 1;
|
||||
int argmax_w_high = argmax_w_low + 1;
|
||||
|
||||
T weight = 0;
|
||||
if (h == argmax_h_low && w == argmax_w_low)
|
||||
weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
|
||||
if (h == argmax_h_low && w == argmax_w_high)
|
||||
weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
|
||||
if (h == argmax_h_high && w == argmax_w_low)
|
||||
weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
|
||||
if (h == argmax_h_high && w == argmax_w_high)
|
||||
weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
|
||||
return weight;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T dmcn_get_coordinate_weight_cpu(T argmax_h, T argmax_w, const int height,
|
||||
const int width, const T *im_data,
|
||||
const int data_width, const int bp_dir) {
|
||||
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 ||
|
||||
argmax_w >= width) {
|
||||
// empty
|
||||
return 0;
|
||||
}
|
||||
|
||||
int argmax_h_low = floorf(argmax_h);
|
||||
int argmax_w_low = floorf(argmax_w);
|
||||
int argmax_h_high = argmax_h_low + 1;
|
||||
int argmax_w_high = argmax_w_low + 1;
|
||||
|
||||
T weight = 0;
|
||||
|
||||
if (bp_dir == 0) {
|
||||
if (argmax_h_low >= 0 && argmax_w_low >= 0)
|
||||
weight += -1 * (argmax_w_low + 1 - argmax_w) *
|
||||
im_data[argmax_h_low * data_width + argmax_w_low];
|
||||
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
|
||||
weight += -1 * (argmax_w - argmax_w_low) *
|
||||
im_data[argmax_h_low * data_width + argmax_w_high];
|
||||
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
|
||||
weight += (argmax_w_low + 1 - argmax_w) *
|
||||
im_data[argmax_h_high * data_width + argmax_w_low];
|
||||
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
|
||||
weight += (argmax_w - argmax_w_low) *
|
||||
im_data[argmax_h_high * data_width + argmax_w_high];
|
||||
} else if (bp_dir == 1) {
|
||||
if (argmax_h_low >= 0 && argmax_w_low >= 0)
|
||||
weight += -1 * (argmax_h_low + 1 - argmax_h) *
|
||||
im_data[argmax_h_low * data_width + argmax_w_low];
|
||||
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
|
||||
weight += (argmax_h_low + 1 - argmax_h) *
|
||||
im_data[argmax_h_low * data_width + argmax_w_high];
|
||||
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
|
||||
weight += -1 * (argmax_h - argmax_h_low) *
|
||||
im_data[argmax_h_high * data_width + argmax_w_low];
|
||||
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
|
||||
weight += (argmax_h - argmax_h_low) *
|
||||
im_data[argmax_h_high * data_width + argmax_w_high];
|
||||
}
|
||||
|
||||
return weight;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void modulated_deformable_im2col_cpu_kernel(
|
||||
const int n, const T *data_im, const T *data_offset, const T *data_mask,
|
||||
const int height, const int width, const int kernel_h, const int kernel_w,
|
||||
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
||||
const int dilation_h, const int dilation_w,
|
||||
const int channel_per_deformable_group, const int batch_size,
|
||||
const int num_channels, const int deformable_group, const int height_col,
|
||||
const int width_col, T *data_col) {
|
||||
for (int index = 0; index < n; index++) {
|
||||
// index index of output matrix
|
||||
const int w_col = index % width_col;
|
||||
const int h_col = (index / width_col) % height_col;
|
||||
const int b_col = (index / width_col / height_col) % batch_size;
|
||||
const int c_im = (index / width_col / height_col) / batch_size;
|
||||
const int c_col = c_im * kernel_h * kernel_w;
|
||||
|
||||
// compute deformable group index
|
||||
const int deformable_group_index = c_im / channel_per_deformable_group;
|
||||
|
||||
const int h_in = h_col * stride_h - pad_h;
|
||||
const int w_in = w_col * stride_w - pad_w;
|
||||
|
||||
T *data_col_ptr =
|
||||
data_col +
|
||||
((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
|
||||
const T *data_im_ptr =
|
||||
data_im + (b_col * num_channels + c_im) * height * width;
|
||||
const T *data_offset_ptr =
|
||||
data_offset + (b_col * deformable_group + deformable_group_index) * 2 *
|
||||
kernel_h * kernel_w * height_col * width_col;
|
||||
|
||||
const T *data_mask_ptr =
|
||||
data_mask + (b_col * deformable_group + deformable_group_index) *
|
||||
kernel_h * kernel_w * height_col * width_col;
|
||||
|
||||
for (int i = 0; i < kernel_h; ++i) {
|
||||
for (int j = 0; j < kernel_w; ++j) {
|
||||
const int data_offset_h_ptr =
|
||||
((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
|
||||
const int data_offset_w_ptr =
|
||||
((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col +
|
||||
w_col;
|
||||
const int data_mask_hw_ptr =
|
||||
((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
|
||||
const T offset_h = data_offset_ptr[data_offset_h_ptr];
|
||||
const T offset_w = data_offset_ptr[data_offset_w_ptr];
|
||||
const T mask = data_mask_ptr[data_mask_hw_ptr];
|
||||
T val = static_cast<T>(0);
|
||||
const T h_im = h_in + i * dilation_h + offset_h;
|
||||
const T w_im = w_in + j * dilation_w + offset_w;
|
||||
if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
|
||||
val = dmcn_im2col_bilinear_cpu(data_im_ptr, width, height, width,
|
||||
h_im, w_im);
|
||||
*data_col_ptr = val * mask;
|
||||
data_col_ptr += batch_size * height_col * width_col;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void modulated_deformable_col2im_cpu_kernel(
|
||||
const int n, const T *data_col, const T *data_offset, const T *data_mask,
|
||||
const int channels, const int height, const int width, const int kernel_h,
|
||||
const int kernel_w, const int pad_h, const int pad_w, const int stride_h,
|
||||
const int stride_w, const int dilation_h, const int dilation_w,
|
||||
const int channel_per_deformable_group, const int batch_size,
|
||||
const int deformable_group, const int height_col, const int width_col,
|
||||
T *grad_im) {
|
||||
for (int index = 0; index < n; index++) {
|
||||
const int j = (index / width_col / height_col / batch_size) % kernel_w;
|
||||
const int i =
|
||||
(index / width_col / height_col / batch_size / kernel_w) % kernel_h;
|
||||
const int c =
|
||||
index / width_col / height_col / batch_size / kernel_w / kernel_h;
|
||||
// compute the start and end of the output
|
||||
|
||||
const int deformable_group_index = c / channel_per_deformable_group;
|
||||
|
||||
int w_out = index % width_col;
|
||||
int h_out = (index / width_col) % height_col;
|
||||
int b = (index / width_col / height_col) % batch_size;
|
||||
int w_in = w_out * stride_w - pad_w;
|
||||
int h_in = h_out * stride_h - pad_h;
|
||||
|
||||
const T *data_offset_ptr =
|
||||
data_offset + (b * deformable_group + deformable_group_index) * 2 *
|
||||
kernel_h * kernel_w * height_col * width_col;
|
||||
const T *data_mask_ptr =
|
||||
data_mask + (b * deformable_group + deformable_group_index) * kernel_h *
|
||||
kernel_w * height_col * width_col;
|
||||
const int data_offset_h_ptr =
|
||||
((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
|
||||
const int data_offset_w_ptr =
|
||||
((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
|
||||
const int data_mask_hw_ptr =
|
||||
((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
|
||||
const T offset_h = data_offset_ptr[data_offset_h_ptr];
|
||||
const T offset_w = data_offset_ptr[data_offset_w_ptr];
|
||||
const T mask = data_mask_ptr[data_mask_hw_ptr];
|
||||
const T cur_inv_h_data = h_in + i * dilation_h + offset_h;
|
||||
const T cur_inv_w_data = w_in + j * dilation_w + offset_w;
|
||||
|
||||
const T cur_top_grad = data_col[index] * mask;
|
||||
const int cur_h = (int)cur_inv_h_data;
|
||||
const int cur_w = (int)cur_inv_w_data;
|
||||
for (int dy = -2; dy <= 2; dy++) {
|
||||
for (int dx = -2; dx <= 2; dx++) {
|
||||
if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 &&
|
||||
cur_w + dx < width && abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
|
||||
abs(cur_inv_w_data - (cur_w + dx)) < 1) {
|
||||
int cur_bottom_grad_pos =
|
||||
((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
|
||||
T weight = dmcn_get_gradient_weight_cpu(cur_inv_h_data,
|
||||
cur_inv_w_data, cur_h + dy,
|
||||
cur_w + dx, height, width);
|
||||
*(grad_im + cur_bottom_grad_pos) += weight * cur_top_grad;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void modulated_deformable_col2im_coord_cpu_kernel(
|
||||
const int n, const T *data_col, const T *data_im, const T *data_offset,
|
||||
const T *data_mask, const int channels, const int height, const int width,
|
||||
const int kernel_h, const int kernel_w, const int pad_h, const int pad_w,
|
||||
const int stride_h, const int stride_w, const int dilation_h,
|
||||
const int dilation_w, const int channel_per_deformable_group,
|
||||
const int batch_size, const int offset_channels, const int deformable_group,
|
||||
const int height_col, const int width_col, T *grad_offset, T *grad_mask) {
|
||||
for (int index = 0; index < n; index++) {
|
||||
T val = 0, mval = 0;
|
||||
int w = index % width_col;
|
||||
int h = (index / width_col) % height_col;
|
||||
int c = (index / width_col / height_col) % offset_channels;
|
||||
int b = (index / width_col / height_col) / offset_channels;
|
||||
// compute the start and end of the output
|
||||
|
||||
const int deformable_group_index = c / (2 * kernel_h * kernel_w);
|
||||
const int col_step = kernel_h * kernel_w;
|
||||
int cnt = 0;
|
||||
const T *data_col_ptr = data_col + deformable_group_index *
|
||||
channel_per_deformable_group *
|
||||
batch_size * width_col * height_col;
|
||||
const T *data_im_ptr =
|
||||
data_im + (b * deformable_group + deformable_group_index) *
|
||||
channel_per_deformable_group / kernel_h / kernel_w *
|
||||
height * width;
|
||||
const T *data_offset_ptr =
|
||||
data_offset + (b * deformable_group + deformable_group_index) * 2 *
|
||||
kernel_h * kernel_w * height_col * width_col;
|
||||
const T *data_mask_ptr =
|
||||
data_mask + (b * deformable_group + deformable_group_index) * kernel_h *
|
||||
kernel_w * height_col * width_col;
|
||||
|
||||
const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
|
||||
|
||||
for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group;
|
||||
col_c += col_step) {
|
||||
const int col_pos =
|
||||
(((col_c * batch_size + b) * height_col) + h) * width_col + w;
|
||||
const int bp_dir = offset_c % 2;
|
||||
|
||||
int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
|
||||
int i =
|
||||
(col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
|
||||
int w_out = col_pos % width_col;
|
||||
int h_out = (col_pos / width_col) % height_col;
|
||||
int w_in = w_out * stride_w - pad_w;
|
||||
int h_in = h_out * stride_h - pad_h;
|
||||
const int data_offset_h_ptr =
|
||||
(((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
|
||||
const int data_offset_w_ptr =
|
||||
(((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col +
|
||||
w_out);
|
||||
const int data_mask_hw_ptr =
|
||||
(((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
|
||||
const T offset_h = data_offset_ptr[data_offset_h_ptr];
|
||||
const T offset_w = data_offset_ptr[data_offset_w_ptr];
|
||||
const T mask = data_mask_ptr[data_mask_hw_ptr];
|
||||
T inv_h = h_in + i * dilation_h + offset_h;
|
||||
T inv_w = w_in + j * dilation_w + offset_w;
|
||||
if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
|
||||
inv_h = inv_w = -2;
|
||||
else
|
||||
mval += data_col_ptr[col_pos] *
|
||||
dmcn_im2col_bilinear_cpu(data_im_ptr + cnt * height * width,
|
||||
width, height, width, inv_h, inv_w);
|
||||
const T weight = dmcn_get_coordinate_weight_cpu(
|
||||
inv_h, inv_w, height, width, data_im_ptr + cnt * height * width,
|
||||
width, bp_dir);
|
||||
val += weight * data_col_ptr[col_pos] * mask;
|
||||
cnt += 1;
|
||||
}
|
||||
// KERNEL_ASSIGN(grad_offset[index], offset_req, val);
|
||||
grad_offset[index] = val;
|
||||
if (offset_c % 2 == 0)
|
||||
// KERNEL_ASSIGN(grad_mask[(((b * deformable_group +
|
||||
// deformable_group_index) * kernel_h * kernel_w + offset_c / 2) *
|
||||
// height_col + h) * width_col + w], mask_req, mval);
|
||||
grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h *
|
||||
kernel_w +
|
||||
offset_c / 2) *
|
||||
height_col +
|
||||
h) *
|
||||
width_col +
|
||||
w] = mval;
|
||||
}
|
||||
}
|
||||
|
||||
void modulated_deformable_im2col_cpu(
|
||||
const Tensor data_im, const Tensor data_offset, const Tensor data_mask,
|
||||
const int batch_size, const int channels, const int height_im,
|
||||
const int width_im, const int height_col, const int width_col,
|
||||
const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w,
|
||||
const int stride_h, const int stride_w, const int dilation_h,
|
||||
const int dilation_w, const int deformable_group, Tensor data_col) {
|
||||
// num_axes should be smaller than block size
|
||||
const int channel_per_deformable_group = channels / deformable_group;
|
||||
const int num_kernels = channels * batch_size * height_col * width_col;
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
data_im.scalar_type(), "modulated_deformable_im2col_cpu", ([&] {
|
||||
const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
|
||||
const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
|
||||
const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
|
||||
scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
|
||||
|
||||
modulated_deformable_im2col_cpu_kernel(
|
||||
num_kernels, data_im_, data_offset_, data_mask_, height_im,
|
||||
width_im, kernel_h, kenerl_w, pad_h, pad_w, stride_h, stride_w,
|
||||
dilation_h, dilation_w, channel_per_deformable_group, batch_size,
|
||||
channels, deformable_group, height_col, width_col, data_col_);
|
||||
}));
|
||||
}
|
||||
|
||||
void modulated_deformable_col2im_cpu(
|
||||
const Tensor data_col, const Tensor data_offset, const Tensor data_mask,
|
||||
const int batch_size, const int channels, const int height_im,
|
||||
const int width_im, const int height_col, const int width_col,
|
||||
const int kernel_h, const int kernel_w, const int pad_h, const int pad_w,
|
||||
const int stride_h, const int stride_w, const int dilation_h,
|
||||
const int dilation_w, const int deformable_group, Tensor grad_im) {
|
||||
const int channel_per_deformable_group = channels / deformable_group;
|
||||
const int num_kernels =
|
||||
channels * kernel_h * kernel_w * batch_size * height_col * width_col;
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
data_col.scalar_type(), "modulated_deformable_col2im_cpu", ([&] {
|
||||
const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
|
||||
const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
|
||||
const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
|
||||
scalar_t *grad_im_ = grad_im.data_ptr<scalar_t>();
|
||||
|
||||
modulated_deformable_col2im_cpu_kernel(
|
||||
num_kernels, data_col_, data_offset_, data_mask_, channels,
|
||||
height_im, width_im, kernel_h, kernel_w, pad_h, pad_w, stride_h,
|
||||
stride_w, dilation_h, dilation_w, channel_per_deformable_group,
|
||||
batch_size, deformable_group, height_col, width_col, grad_im_);
|
||||
}));
|
||||
}
|
||||
|
||||
void modulated_deformable_col2im_coord_cpu(
|
||||
const Tensor data_col, const Tensor data_im, const Tensor data_offset,
|
||||
const Tensor data_mask, const int batch_size, const int channels,
|
||||
const int height_im, const int width_im, const int height_col,
|
||||
const int width_col, const int kernel_h, const int kernel_w,
|
||||
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
||||
const int dilation_h, const int dilation_w, const int deformable_group,
|
||||
Tensor grad_offset, Tensor grad_mask) {
|
||||
const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h *
|
||||
kernel_w * deformable_group;
|
||||
const int channel_per_deformable_group =
|
||||
channels * kernel_h * kernel_w / deformable_group;
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
data_col.scalar_type(), "modulated_deformable_col2im_coord_cpu", ([&] {
|
||||
const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
|
||||
const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
|
||||
const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
|
||||
const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
|
||||
scalar_t *grad_offset_ = grad_offset.data_ptr<scalar_t>();
|
||||
scalar_t *grad_mask_ = grad_mask.data_ptr<scalar_t>();
|
||||
|
||||
modulated_deformable_col2im_coord_cpu_kernel(
|
||||
num_kernels, data_col_, data_im_, data_offset_, data_mask_,
|
||||
channels, height_im, width_im, kernel_h, kernel_w, pad_h, pad_w,
|
||||
stride_h, stride_w, dilation_h, dilation_w,
|
||||
channel_per_deformable_group, batch_size,
|
||||
2 * kernel_h * kernel_w * deformable_group, deformable_group,
|
||||
height_col, width_col, grad_offset_, grad_mask_);
|
||||
}));
|
||||
}
|
|
@ -36,13 +36,16 @@ gt_deform_weight_grad = [[[[3.62, 0.], [0.40, 0.18]]]]
|
|||
|
||||
class TestDeformconv(object):
|
||||
|
||||
def _test_deformconv(self, dtype=torch.float, threshold=1e-3):
|
||||
if not torch.cuda.is_available():
|
||||
return
|
||||
def _test_deformconv(self,
|
||||
dtype=torch.float,
|
||||
threshold=1e-3,
|
||||
device='cuda'):
|
||||
if not torch.cuda.is_available() and device == 'cuda':
|
||||
pytest.skip('test requires GPU')
|
||||
from mmcv.ops import DeformConv2dPack
|
||||
c_in = 1
|
||||
c_out = 1
|
||||
x = torch.Tensor(input).cuda().type(dtype)
|
||||
x = torch.tensor(input, device=device, dtype=dtype)
|
||||
x.requires_grad = True
|
||||
model = DeformConv2dPack(c_in, c_out, 2, stride=1, padding=0)
|
||||
model.conv_offset.weight.data = torch.nn.Parameter(
|
||||
|
@ -51,7 +54,9 @@ class TestDeformconv(object):
|
|||
torch.Tensor(offset_bias).reshape(8))
|
||||
model.weight.data = torch.nn.Parameter(
|
||||
torch.Tensor(deform_weight).reshape(1, 1, 2, 2))
|
||||
model.cuda().type(dtype)
|
||||
if device == 'cuda':
|
||||
model.cuda()
|
||||
model.type(dtype)
|
||||
|
||||
out = model(x)
|
||||
out.backward(torch.ones_like(out))
|
||||
|
@ -67,6 +72,7 @@ class TestDeformconv(object):
|
|||
gt_deform_weight_grad, threshold)
|
||||
|
||||
from mmcv.ops import DeformConv2d
|
||||
|
||||
# test bias
|
||||
model = DeformConv2d(1, 1, 2, stride=1, padding=0)
|
||||
assert not hasattr(model, 'bias')
|
||||
|
@ -121,6 +127,7 @@ class TestDeformconv(object):
|
|||
gt_deform_weight_grad, threshold)
|
||||
|
||||
from mmcv.ops import DeformConv2d
|
||||
|
||||
# test bias
|
||||
model = DeformConv2d(1, 1, 2, stride=1, padding=0)
|
||||
assert not hasattr(model, 'bias')
|
||||
|
@ -135,9 +142,11 @@ class TestDeformconv(object):
|
|||
model = DeformConv2d(3, 4, 3, groups=3)
|
||||
|
||||
def test_deformconv(self):
|
||||
self._test_deformconv(torch.double, device='cpu')
|
||||
self._test_deformconv(torch.float, device='cpu', threshold=1e-1)
|
||||
self._test_deformconv(torch.double)
|
||||
self._test_deformconv(torch.float)
|
||||
self._test_deformconv(torch.half, 1e-1)
|
||||
self._test_deformconv(torch.half, threshold=1e-1)
|
||||
|
||||
# test amp when torch version >= '1.6.0', the type of
|
||||
# input data for deformconv might be torch.float or torch.half
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import os
|
||||
|
||||
import numpy
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmcv.utils import TORCH_VERSION, digit_version
|
||||
|
@ -37,11 +38,11 @@ dcn_offset_b_grad = [
|
|||
|
||||
class TestMdconv(object):
|
||||
|
||||
def _test_mdconv(self, dtype=torch.float):
|
||||
if not torch.cuda.is_available():
|
||||
return
|
||||
def _test_mdconv(self, dtype=torch.float, device='cuda'):
|
||||
if not torch.cuda.is_available() and device == 'cuda':
|
||||
pytest.skip('test requires GPU')
|
||||
from mmcv.ops import ModulatedDeformConv2dPack
|
||||
input = torch.tensor(input_t).cuda().type(dtype)
|
||||
input = torch.tensor(input_t, dtype=dtype, device=device)
|
||||
input.requires_grad = True
|
||||
|
||||
dcn = ModulatedDeformConv2dPack(
|
||||
|
@ -51,7 +52,11 @@ class TestMdconv(object):
|
|||
stride=1,
|
||||
padding=1,
|
||||
deform_groups=1,
|
||||
bias=False).cuda()
|
||||
bias=False)
|
||||
|
||||
if device == 'cuda':
|
||||
dcn.cuda()
|
||||
|
||||
dcn.weight.data.fill_(1.)
|
||||
dcn.type(dtype)
|
||||
output = dcn(input)
|
||||
|
@ -106,6 +111,8 @@ class TestMdconv(object):
|
|||
dcn_offset_b_grad, 1e-2)
|
||||
|
||||
def test_mdconv(self):
|
||||
self._test_mdconv(torch.double, device='cpu')
|
||||
self._test_mdconv(torch.float, device='cpu')
|
||||
self._test_mdconv(torch.double)
|
||||
self._test_mdconv(torch.float)
|
||||
self._test_mdconv(torch.half)
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
@ -33,9 +32,10 @@ def test_sacconv():
|
|||
refer_out = refer_conv(x)
|
||||
assert deform_sac_out.shape == refer_out.shape
|
||||
else:
|
||||
with pytest.raises(RuntimeError):
|
||||
# deform conv is not implemented on cpu
|
||||
deform_saconv(x)
|
||||
deform_sac_out = deform_saconv(x)
|
||||
refer_conv = nn.Conv2d(3, 5, kernel_size=3, padding=1)
|
||||
refer_out = refer_conv(x)
|
||||
assert deform_sac_out.shape == refer_out.shape
|
||||
|
||||
# test with groups >= 2
|
||||
x = torch.rand(1, 4, 256, 256)
|
||||
|
|
Loading…
Reference in New Issue