From e621e08d541634fd286ce0d314c3a05f4ead7e73 Mon Sep 17 00:00:00 2001 From: Eugene Liu Date: Sun, 29 Aug 2021 13:48:31 +0100 Subject: [PATCH] Add DCN and Modulated DCN CPU implementation (#1278) * DCN cpu version * add modulated dcn cpu version * move deform_conv_shape_check to deform conv utils * add inline to deform_conv_shape_check * add tests * run linter * add newline at file end * run pre-commit against modulated deform conv cpp * update saconv test * run clang-format * remove cuda device inline * refactor dcn cuda/cpu functions * remove DCN util * remove DCN util hpp from all included files * Addressing PR comment by refactoring modulated-DCN * fix lint in cpp files --- .../ops/csrc/pytorch/cuda/deform_conv_cuda.cu | 419 ------------- .../cuda/modulated_deform_conv_cuda.cu | 191 ------ mmcv/ops/csrc/pytorch/deform_conv.cpp | 551 +++++++++++++++--- mmcv/ops/csrc/pytorch/deform_conv_cpu.cpp | 377 ++++++++++++ .../csrc/pytorch/modulated_deform_conv.cpp | 324 ++++++++-- .../pytorch/modulated_deform_conv_cpu.cpp | 403 +++++++++++++ tests/test_ops/test_deform_conv.py | 21 +- tests/test_ops/test_modulated_deform_conv.py | 17 +- tests/test_ops/test_saconv.py | 8 +- 9 files changed, 1572 insertions(+), 739 deletions(-) create mode 100644 mmcv/ops/csrc/pytorch/deform_conv_cpu.cpp create mode 100644 mmcv/ops/csrc/pytorch/modulated_deform_conv_cpu.cpp diff --git a/mmcv/ops/csrc/pytorch/cuda/deform_conv_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/deform_conv_cuda.cu index 2a1aedcf6..24cf65e9e 100644 --- a/mmcv/ops/csrc/pytorch/cuda/deform_conv_cuda.cu +++ b/mmcv/ops/csrc/pytorch/cuda/deform_conv_cuda.cu @@ -101,422 +101,3 @@ void deformable_col2im_coord( })); AT_CUDA_CHECK(cudaGetLastError()); } - -void deform_conv_shape_check(Tensor input, Tensor offset, Tensor *gradOutput, - Tensor weight, int kH, int kW, int dH, int dW, - int padH, int padW, int dilationH, int dilationW, - int group, int deformable_group) { - TORCH_CHECK( - weight.ndimension() == 4, - "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, but got: %s", - weight.ndimension()); - - TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); - - TORCH_CHECK(kW > 0 && kH > 0, - "kernel size should be greater than zero, but got kH: %d kW: %d", - kH, kW); - - TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW), - "kernel size should be consistent with weight, ", - "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", - kH, kW, weight.size(2), weight.size(3)); - - TORCH_CHECK(dW > 0 && dH > 0, - "stride should be greater than zero, but got dH: %d dW: %d", dH, - dW); - - TORCH_CHECK( - dilationW > 0 && dilationH > 0, - "dilation should be greater than 0, but got dilationH: %d dilationW: %d", - dilationH, dilationW); - - int ndim = input.ndimension(); - int dimf = 0; - int dimh = 1; - int dimw = 2; - - if (ndim == 4) { - dimf++; - dimh++; - dimw++; - } - - TORCH_CHECK(ndim == 3 || ndim == 4, - "3D or 4D input tensor expected but got: %s", ndim); - - long nInputPlane = weight.size(1) * group; - long inputHeight = input.size(dimh); - long inputWidth = input.size(dimw); - long nOutputPlane = weight.size(0); - long outputHeight = - (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; - long outputWidth = - (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; - - TORCH_CHECK(nInputPlane % deformable_group == 0, - "input channels must divide deformable group size"); - - if (outputWidth < 1 || outputHeight < 1) - AT_ERROR( - "Given input size: (%ld x %ld x %ld). " - "Calculated output size: (%ld x %ld x %ld). Output size is too small", - nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight, - outputWidth); - - TORCH_CHECK(input.size(1) == nInputPlane, - "invalid number of input planes, expected: %d, but got: %d", - nInputPlane, input.size(1)); - - TORCH_CHECK((inputHeight >= kH && inputWidth >= kW), - "input image is smaller than kernel"); - - TORCH_CHECK( - (offset.size(2) == outputHeight && offset.size(3) == outputWidth), - "invalid spatial size of offset, expected height: %d width: %d, but " - "got height: %d width: %d", - outputHeight, outputWidth, offset.size(2), offset.size(3)); - - TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW), - "invalid number of channels of offset"); - - if (gradOutput != NULL) { - TORCH_CHECK( - gradOutput->size(dimf) == nOutputPlane, - "invalid number of gradOutput planes, expected: %d, but got: %d", - nOutputPlane, gradOutput->size(dimf)); - - TORCH_CHECK( - (gradOutput->size(dimh) == outputHeight && - gradOutput->size(dimw) == outputWidth), - "invalid size of gradOutput, expected height: %d width: %d , but " - "got height: %d width: %d", - outputHeight, outputWidth, gradOutput->size(dimh), - gradOutput->size(dimw)); - } -} - -void DeformConvForwardCUDAKernelLauncher(Tensor input, Tensor weight, - Tensor offset, Tensor output, - Tensor columns, Tensor ones, int kW, - int kH, int dW, int dH, int padW, - int padH, int dilationW, int dilationH, - int group, int deformable_group, - int im2col_step) { - // todo: resize columns to include im2col: done - // todo: add im2col_step as input - // todo: add new output buffer and transpose it to output (or directly - // transpose output) todo: possibly change data indexing because of - // parallel_imgs - - deform_conv_shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, - padW, dilationH, dilationW, group, deformable_group); - at::DeviceGuard guard(input.device()); - - int batch = 1; - if (input.ndimension() == 3) { - // Force batch - batch = 0; - input.unsqueeze_(0); - offset.unsqueeze_(0); - } - - // todo: assert batchsize dividable by im2col_step - - long batchSize = input.size(0); - long nInputPlane = input.size(1); - long inputHeight = input.size(2); - long inputWidth = input.size(3); - - long nOutputPlane = weight.size(0); - - long outputWidth = - (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; - long outputHeight = - (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; - - TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); - - output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane, - outputHeight, outputWidth}); - columns = at::zeros( - {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, - input.options()); - - if (ones.ndimension() != 2 || - ones.size(0) * ones.size(1) < outputHeight * outputWidth) { - ones = at::ones({outputHeight, outputWidth}, input.options()); - } - - input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, - inputHeight, inputWidth}); - offset = - offset.view({batchSize / im2col_step, im2col_step, - deformable_group * 2 * kH * kW, outputHeight, outputWidth}); - - Tensor output_buffer = at::zeros({batchSize / im2col_step, nOutputPlane, - im2col_step * outputHeight, outputWidth}, - output.options()); - - output_buffer = output_buffer.view( - {output_buffer.size(0), group, output_buffer.size(1) / group, - output_buffer.size(2), output_buffer.size(3)}); - - for (int elt = 0; elt < batchSize / im2col_step; elt++) { - deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight, - inputWidth, kH, kW, padH, padW, dH, dW, dilationH, - dilationW, im2col_step, deformable_group, columns); - - columns = columns.view({group, columns.size(0) / group, columns.size(1)}); - weight = weight.view({group, weight.size(0) / group, weight.size(1), - weight.size(2), weight.size(3)}); - - for (int g = 0; g < group; g++) { - output_buffer[elt][g] = output_buffer[elt][g] - .flatten(1) - .addmm_(weight[g].flatten(1), columns[g]) - .view_as(output_buffer[elt][g]); - } - columns = - columns.view({columns.size(0) * columns.size(1), columns.size(2)}); - weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), - weight.size(3), weight.size(4)}); - } - - output_buffer = output_buffer.view( - {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2), - output_buffer.size(3), output_buffer.size(4)}); - - output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane, - im2col_step, outputHeight, outputWidth}); - output_buffer.transpose_(1, 2); - output.copy_(output_buffer); - output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth}); - - input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); - offset = offset.view( - {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); - - if (batch == 0) { - output = output.view({nOutputPlane, outputHeight, outputWidth}); - input = input.view({nInputPlane, inputHeight, inputWidth}); - offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); - } -} - -void DeformConvBackwardInputCUDAKernelLauncher( - Tensor input, Tensor offset, Tensor gradOutput, Tensor gradInput, - Tensor gradOffset, Tensor weight, Tensor columns, int kW, int kH, int dW, - int dH, int padW, int padH, int dilationW, int dilationH, int group, - int deformable_group, int im2col_step) { - deform_conv_shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, - padH, padW, dilationH, dilationW, group, - deformable_group); - at::DeviceGuard guard(input.device()); - - int batch = 1; - - if (input.ndimension() == 3) { - // Force batch - batch = 0; - input = input.view({1, input.size(0), input.size(1), input.size(2)}); - offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)}); - gradOutput = gradOutput.view( - {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); - } - - long batchSize = input.size(0); - long nInputPlane = input.size(1); - long inputHeight = input.size(2); - long inputWidth = input.size(3); - - long nOutputPlane = weight.size(0); - - long outputWidth = - (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; - long outputHeight = - (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; - - TORCH_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset"); - gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); - columns = at::zeros( - {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, - input.options()); - - // change order of grad output - gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step, - nOutputPlane, outputHeight, outputWidth}); - gradOutput.transpose_(1, 2); - - gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane, - inputHeight, inputWidth}); - input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, - inputHeight, inputWidth}); - gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step, - deformable_group * 2 * kH * kW, outputHeight, - outputWidth}); - offset = - offset.view({batchSize / im2col_step, im2col_step, - deformable_group * 2 * kH * kW, outputHeight, outputWidth}); - - for (int elt = 0; elt < batchSize / im2col_step; elt++) { - // divide into groups - columns = columns.view({group, columns.size(0) / group, columns.size(1)}); - weight = weight.view({group, weight.size(0) / group, weight.size(1), - weight.size(2), weight.size(3)}); - gradOutput = gradOutput.view( - {gradOutput.size(0), group, gradOutput.size(1) / group, - gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)}); - - for (int g = 0; g < group; g++) { - columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), - gradOutput[elt][g].flatten(1), 0.0f, 1.0f); - } - - columns = - columns.view({columns.size(0) * columns.size(1), columns.size(2)}); - gradOutput = gradOutput.view( - {gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2), - gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)}); - - deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane, - inputHeight, inputWidth, kH, kW, padH, padW, dH, dW, - dilationH, dilationW, im2col_step, deformable_group, - gradOffset[elt]); - - deformable_col2im(columns, offset[elt], nInputPlane, inputHeight, - inputWidth, kH, kW, padH, padW, dH, dW, dilationH, - dilationW, im2col_step, deformable_group, gradInput[elt]); - weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), - weight.size(3), weight.size(4)}); - } - - gradOutput.transpose_(1, 2); - gradOutput = - gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); - - gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); - input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); - gradOffset = gradOffset.view( - {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); - offset = offset.view( - {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); - - if (batch == 0) { - gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); - input = input.view({nInputPlane, inputHeight, inputWidth}); - gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth}); - offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); - gradOffset = - gradOffset.view({offset.size(1), offset.size(2), offset.size(3)}); - } -} - -void DeformConvBackwardParametersCUDAKernelLauncher( - Tensor input, Tensor offset, Tensor gradOutput, Tensor gradWeight, - Tensor columns, Tensor ones, int kW, int kH, int dW, int dH, int padW, - int padH, int dilationW, int dilationH, int group, int deformable_group, - float scale, int im2col_step) { - // todo: transpose and reshape outGrad - // todo: reshape columns - // todo: add im2col_step as input - - deform_conv_shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, - dW, padH, padW, dilationH, dilationW, group, - deformable_group); - at::DeviceGuard guard(input.device()); - - int batch = 1; - - if (input.ndimension() == 3) { - // Force batch - batch = 0; - input = input.view( - at::IntList({1, input.size(0), input.size(1), input.size(2)})); - gradOutput = gradOutput.view( - {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); - } - - long batchSize = input.size(0); - long nInputPlane = input.size(1); - long inputHeight = input.size(2); - long inputWidth = input.size(3); - - long nOutputPlane = gradWeight.size(0); - - long outputWidth = - (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; - long outputHeight = - (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; - - TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); - - columns = at::zeros( - {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, - input.options()); - - gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step, - nOutputPlane, outputHeight, outputWidth}); - gradOutput.transpose_(1, 2); - - Tensor gradOutputBuffer = at::zeros_like(gradOutput); - gradOutputBuffer = - gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step, - outputHeight, outputWidth}); - gradOutputBuffer = gradOutputBuffer.contiguous(); - gradOutputBuffer.copy_(gradOutput); - gradOutputBuffer = - gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, - im2col_step * outputHeight, outputWidth}); - - gradOutput.transpose_(1, 2); - gradOutput = - gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); - - input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, - inputHeight, inputWidth}); - offset = - offset.view({batchSize / im2col_step, im2col_step, - deformable_group * 2 * kH * kW, outputHeight, outputWidth}); - - for (int elt = 0; elt < batchSize / im2col_step; elt++) { - deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight, - inputWidth, kH, kW, padH, padW, dH, dW, dilationH, - dilationW, im2col_step, deformable_group, columns); - - // divide into group - gradOutputBuffer = gradOutputBuffer.view( - {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group, - gradOutputBuffer.size(2), gradOutputBuffer.size(3)}); - columns = columns.view({group, columns.size(0) / group, columns.size(1)}); - gradWeight = - gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1), - gradWeight.size(2), gradWeight.size(3)}); - - for (int g = 0; g < group; g++) { - gradWeight[g] = gradWeight[g] - .flatten(1) - .addmm_(gradOutputBuffer[elt][g].flatten(1), - columns[g].transpose(1, 0), 1.0, scale) - .view_as(gradWeight[g]); - } - gradOutputBuffer = gradOutputBuffer.view( - {gradOutputBuffer.size(0), - gradOutputBuffer.size(1) * gradOutputBuffer.size(2), - gradOutputBuffer.size(3), gradOutputBuffer.size(4)}); - columns = - columns.view({columns.size(0) * columns.size(1), columns.size(2)}); - gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1), - gradWeight.size(2), gradWeight.size(3), - gradWeight.size(4)}); - } - - input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); - offset = offset.view( - {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); - - if (batch == 0) { - gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); - input = input.view({nInputPlane, inputHeight, inputWidth}); - } -} diff --git a/mmcv/ops/csrc/pytorch/cuda/modulated_deform_conv_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/modulated_deform_conv_cuda.cu index 6b7d64fc7..27b706702 100644 --- a/mmcv/ops/csrc/pytorch/cuda/modulated_deform_conv_cuda.cu +++ b/mmcv/ops/csrc/pytorch/cuda/modulated_deform_conv_cuda.cu @@ -94,194 +94,3 @@ void modulated_deformable_col2im_coord_cuda( })); AT_CUDA_CHECK(cudaGetLastError()); } - -void ModulatedDeformConvForwardCUDAKernelLauncher( - Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset, - Tensor mask, Tensor output, Tensor columns, int kernel_h, int kernel_w, - const int stride_h, const int stride_w, const int pad_h, const int pad_w, - const int dilation_h, const int dilation_w, const int group, - const int deformable_group, const bool with_bias) { - at::DeviceGuard guard(input.device()); - - const int batch = input.size(0); - const int channels = input.size(1); - const int height = input.size(2); - const int width = input.size(3); - - const int channels_out = weight.size(0); - const int channels_kernel = weight.size(1); - const int kernel_h_ = weight.size(2); - const int kernel_w_ = weight.size(3); - - if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) - AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).", - kernel_h_, kernel_w, kernel_h_, kernel_w_); - if (channels != channels_kernel * group) - AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).", - channels, channels_kernel * group); - - const int height_out = - (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; - const int width_out = - (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; - - if (ones.ndimension() != 2 || - ones.size(0) * ones.size(1) < height_out * width_out) { - // Resize plane and fill with ones... - ones = at::ones({height_out, width_out}, input.options()); - } - - // resize output - output = output.view({batch, channels_out, height_out, width_out}).zero_(); - // resize temporary columns - columns = - at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out}, - input.options()); - - output = output.view({output.size(0), group, output.size(1) / group, - output.size(2), output.size(3)}); - - for (int b = 0; b < batch; b++) { - modulated_deformable_im2col_cuda( - input[b], offset[b], mask[b], 1, channels, height, width, height_out, - width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, - dilation_h, dilation_w, deformable_group, columns); - - // divide into group - weight = weight.view({group, weight.size(0) / group, weight.size(1), - weight.size(2), weight.size(3)}); - columns = columns.view({group, columns.size(0) / group, columns.size(1)}); - - for (int g = 0; g < group; g++) { - output[b][g] = output[b][g] - .flatten(1) - .addmm_(weight[g].flatten(1), columns[g]) - .view_as(output[b][g]); - } - - weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), - weight.size(3), weight.size(4)}); - columns = - columns.view({columns.size(0) * columns.size(1), columns.size(2)}); - } - - output = output.view({output.size(0), output.size(1) * output.size(2), - output.size(3), output.size(4)}); - - if (with_bias) { - output += bias.view({1, bias.size(0), 1, 1}); - } -} - -void ModulatedDeformConvBackwardCUDAKernelLauncher( - Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset, - Tensor mask, Tensor columns, Tensor grad_input, Tensor grad_weight, - Tensor grad_bias, Tensor grad_offset, Tensor grad_mask, Tensor grad_output, - int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, - int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, - const bool with_bias) { - at::DeviceGuard guard(input.device()); - - const int batch = input.size(0); - const int channels = input.size(1); - const int height = input.size(2); - const int width = input.size(3); - - const int channels_kernel = weight.size(1); - const int kernel_h_ = weight.size(2); - const int kernel_w_ = weight.size(3); - if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) - AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).", - kernel_h_, kernel_w, kernel_h_, kernel_w_); - if (channels != channels_kernel * group) - AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).", - channels, channels_kernel * group); - - const int height_out = - (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; - const int width_out = - (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; - - if (ones.ndimension() != 2 || - ones.size(0) * ones.size(1) < height_out * width_out) { - // Resize plane and fill with ones... - ones = at::ones({height_out, width_out}, input.options()); - } - - grad_input = grad_input.view({batch, channels, height, width}); - columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out}, - input.options()); - - grad_output = - grad_output.view({grad_output.size(0), group, grad_output.size(1) / group, - grad_output.size(2), grad_output.size(3)}); - - for (int b = 0; b < batch; b++) { - // divide int group - columns = columns.view({group, columns.size(0) / group, columns.size(1)}); - weight = weight.view({group, weight.size(0) / group, weight.size(1), - weight.size(2), weight.size(3)}); - - for (int g = 0; g < group; g++) { - columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), - grad_output[b][g].flatten(1), 0.0f, 1.0f); - } - - columns = - columns.view({columns.size(0) * columns.size(1), columns.size(2)}); - weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), - weight.size(3), weight.size(4)}); - - // gradient w.r.t. input coordinate data - modulated_deformable_col2im_coord_cuda( - columns, input[b], offset[b], mask[b], 1, channels, height, width, - height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, - stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b], - grad_mask[b]); - // gradient w.r.t. input data - modulated_deformable_col2im_cuda( - columns, offset[b], mask[b], 1, channels, height, width, height_out, - width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, - dilation_h, dilation_w, deformable_group, grad_input[b]); - - // gradient w.r.t. weight, dWeight should accumulate across the batch and - // group - modulated_deformable_im2col_cuda( - input[b], offset[b], mask[b], 1, channels, height, width, height_out, - width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, - dilation_h, dilation_w, deformable_group, columns); - - columns = columns.view({group, columns.size(0) / group, columns.size(1)}); - grad_weight = grad_weight.view({group, grad_weight.size(0) / group, - grad_weight.size(1), grad_weight.size(2), - grad_weight.size(3)}); - if (with_bias) - grad_bias = grad_bias.view({group, grad_bias.size(0) / group}); - - for (int g = 0; g < group; g++) { - grad_weight[g] = - grad_weight[g] - .flatten(1) - .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1)) - .view_as(grad_weight[g]); - if (with_bias) { - grad_bias[g] = - grad_bias[g] - .view({-1, 1}) - .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1})) - .view(-1); - } - } - - columns = - columns.view({columns.size(0) * columns.size(1), columns.size(2)}); - grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1), - grad_weight.size(2), grad_weight.size(3), - grad_weight.size(4)}); - if (with_bias) - grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)}); - } - grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1), - grad_output.size(2), grad_output.size(3), - grad_output.size(4)}); -} diff --git a/mmcv/ops/csrc/pytorch/deform_conv.cpp b/mmcv/ops/csrc/pytorch/deform_conv.cpp index 4da104e83..455102744 100644 --- a/mmcv/ops/csrc/pytorch/deform_conv.cpp +++ b/mmcv/ops/csrc/pytorch/deform_conv.cpp @@ -2,61 +2,152 @@ #include "pytorch_cpp_helper.hpp" #ifdef MMCV_WITH_CUDA -void DeformConvForwardCUDAKernelLauncher(Tensor input, Tensor weight, - Tensor offset, Tensor output, - Tensor columns, Tensor ones, int kW, - int kH, int dW, int dH, int padW, - int padH, int dilationW, int dilationH, - int group, int deformable_group, - int im2col_step); -void DeformConvBackwardInputCUDAKernelLauncher( - Tensor input, Tensor offset, Tensor gradOutput, Tensor gradInput, - Tensor gradOffset, Tensor weight, Tensor columns, int kW, int kH, int dW, - int dH, int padW, int padH, int dilationW, int dilationH, int group, - int deformable_group, int im2col_step); +void deformable_im2col(Tensor data_im, Tensor data_offset, const int channels, + const int height, const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + Tensor data_col); -void DeformConvBackwardParametersCUDAKernelLauncher( - Tensor input, Tensor offset, Tensor gradOutput, Tensor gradWeight, - Tensor columns, Tensor ones, int kW, int kH, int dW, int dH, int padW, - int padH, int dilationW, int dilationH, int group, int deformable_group, - float scale, int im2col_step); +void deformable_col2im(Tensor data_col, Tensor data_offset, const int channels, + const int height, const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + Tensor grad_im); -void deform_conv_forward_cuda(Tensor input, Tensor weight, Tensor offset, - Tensor output, Tensor columns, Tensor ones, - int kW, int kH, int dW, int dH, int padW, - int padH, int dilationW, int dilationH, int group, - int deformable_group, int im2col_step) { - DeformConvForwardCUDAKernelLauncher( - input, weight, offset, output, columns, ones, kW, kH, dW, dH, padW, padH, - dilationW, dilationH, group, deformable_group, im2col_step); -} +void deformable_col2im_coord( + Tensor data_col, Tensor data_im, Tensor data_offset, const int channels, + const int height, const int width, const int ksize_h, const int ksize_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int parallel_imgs, + const int deformable_group, Tensor grad_offset); -void deform_conv_backward_input_cuda(Tensor input, Tensor offset, - Tensor gradOutput, Tensor gradInput, - Tensor gradOffset, Tensor weight, - Tensor columns, int kW, int kH, int dW, - int dH, int padW, int padH, int dilationW, - int dilationH, int group, - int deformable_group, int im2col_step) { - DeformConvBackwardInputCUDAKernelLauncher( - input, offset, gradOutput, gradInput, gradOffset, weight, columns, kW, kH, - dW, dH, padW, padH, dilationW, dilationH, group, deformable_group, - im2col_step); -} - -void deform_conv_backward_parameters_cuda( - Tensor input, Tensor offset, Tensor gradOutput, Tensor gradWeight, - Tensor columns, Tensor ones, int kW, int kH, int dW, int dH, int padW, - int padH, int dilationW, int dilationH, int group, int deformable_group, - float scale, int im2col_step) { - DeformConvBackwardParametersCUDAKernelLauncher( - input, offset, gradOutput, gradWeight, columns, ones, kW, kH, dW, dH, - padW, padH, dilationW, dilationH, group, deformable_group, scale, - im2col_step); -} #endif +void deformable_im2col_cpu(Tensor data_im, Tensor data_offset, + const int channels, const int height, + const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + Tensor data_col); + +void deformable_col2im_cpu(Tensor data_col, Tensor data_offset, + const int channels, const int height, + const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + Tensor grad_im); + +void deformable_col2im_coord_cpu( + Tensor data_col, Tensor data_im, Tensor data_offset, const int channels, + const int height, const int width, const int ksize_h, const int ksize_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int parallel_imgs, + const int deformable_group, Tensor grad_offset); + +void deform_conv_shape_check(at::Tensor input, at::Tensor offset, + at::Tensor *gradOutput, at::Tensor weight, int kH, + int kW, int dH, int dW, int padH, int padW, + int dilationH, int dilationW, int group, + int deformable_group) { + TORCH_CHECK( + weight.ndimension() == 4, + "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, but got: %s", + weight.ndimension()); + + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + + TORCH_CHECK(kW > 0 && kH > 0, + "kernel size should be greater than zero, but got kH: %d kW: %d", + kH, kW); + + TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW), + "kernel size should be consistent with weight, ", + "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", + kH, kW, weight.size(2), weight.size(3)); + + TORCH_CHECK(dW > 0 && dH > 0, + "stride should be greater than zero, but got dH: %d dW: %d", dH, + dW); + + TORCH_CHECK( + dilationW > 0 && dilationH > 0, + "dilation should be greater than 0, but got dilationH: %d dilationW: %d", + dilationH, dilationW); + + int ndim = input.ndimension(); + int dimf = 0; + int dimh = 1; + int dimw = 2; + + if (ndim == 4) { + dimf++; + dimh++; + dimw++; + } + + TORCH_CHECK(ndim == 3 || ndim == 4, + "3D or 4D input tensor expected but got: %s", ndim); + + long nInputPlane = weight.size(1) * group; + long inputHeight = input.size(dimh); + long inputWidth = input.size(dimw); + long nOutputPlane = weight.size(0); + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + + TORCH_CHECK(nInputPlane % deformable_group == 0, + "input channels must divide deformable group size"); + + if (outputWidth < 1 || outputHeight < 1) + AT_ERROR( + "Given input size: (%ld x %ld x %ld). " + "Calculated output size: (%ld x %ld x %ld). Output size is too small", + nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight, + outputWidth); + + TORCH_CHECK(input.size(1) == nInputPlane, + "invalid number of input planes, expected: %d, but got: %d", + nInputPlane, input.size(1)); + + TORCH_CHECK((inputHeight >= kH && inputWidth >= kW), + "input image is smaller than kernel"); + + TORCH_CHECK( + (offset.size(2) == outputHeight && offset.size(3) == outputWidth), + "invalid spatial size of offset, expected height: %d width: %d, but " + "got height: %d width: %d", + outputHeight, outputWidth, offset.size(2), offset.size(3)); + + TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW), + "invalid number of channels of offset"); + + if (gradOutput != NULL) { + TORCH_CHECK( + gradOutput->size(dimf) == nOutputPlane, + "invalid number of gradOutput planes, expected: %d, but got: %d", + nOutputPlane, gradOutput->size(dimf)); + + TORCH_CHECK( + (gradOutput->size(dimh) == outputHeight && + gradOutput->size(dimw) == outputWidth), + "invalid size of gradOutput, expected height: %d width: %d , but " + "got height: %d width: %d", + outputHeight, outputWidth, gradOutput->size(dimh), + gradOutput->size(dimw)); + } +} + void deform_conv_forward(Tensor input, Tensor weight, Tensor offset, Tensor output, Tensor columns, Tensor ones, int kW, int kH, int dW, int dH, int padW, int padH, @@ -70,15 +161,118 @@ void deform_conv_forward(Tensor input, Tensor weight, Tensor offset, CHECK_CUDA_INPUT(output); CHECK_CUDA_INPUT(columns); CHECK_CUDA_INPUT(ones); - - deform_conv_forward_cuda(input, weight, offset, output, columns, ones, kW, - kH, dW, dH, padW, padH, dilationW, dilationH, - group, deformable_group, im2col_step); #else AT_ERROR("DeformConv is not compiled with GPU support"); #endif } else { - AT_ERROR("DeformConv is not implemented on CPU"); + CHECK_CPU_INPUT(input); + CHECK_CPU_INPUT(offset); + CHECK_CPU_INPUT(weight); + CHECK_CPU_INPUT(output); + CHECK_CPU_INPUT(columns); + CHECK_CPU_INPUT(ones); + } + + deform_conv_shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, + padW, dilationH, dilationW, group, deformable_group); + at::DeviceGuard guard(input.device()); + + int batch = 1; + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input.unsqueeze_(0); + offset.unsqueeze_(0); + } + + // todo: assert batchsize dividable by im2col_step + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = weight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); + + output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane, + outputHeight, outputWidth}); + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < outputHeight * outputWidth) { + ones = at::ones({outputHeight, outputWidth}, input.options()); + } + + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + Tensor output_buffer = at::zeros({batchSize / im2col_step, nOutputPlane, + im2col_step * outputHeight, outputWidth}, + output.options()); + + output_buffer = output_buffer.view( + {output_buffer.size(0), group, output_buffer.size(1) / group, + output_buffer.size(2), output_buffer.size(3)}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + if (input.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, columns); +#endif + } else { + deformable_im2col_cpu(input[elt], offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, columns); + } + + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + + for (int g = 0; g < group; g++) { + output_buffer[elt][g] = output_buffer[elt][g] + .flatten(1) + .addmm_(weight[g].flatten(1), columns[g]) + .view_as(output_buffer[elt][g]); + } + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), + weight.size(3), weight.size(4)}); + } + + output_buffer = output_buffer.view( + {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2), + output_buffer.size(3), output_buffer.size(4)}); + + output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane, + im2col_step, outputHeight, outputWidth}); + output_buffer.transpose_(1, 2); + output.copy_(output_buffer); + output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + output = output.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); } } @@ -97,16 +291,134 @@ void deform_conv_backward_input(Tensor input, Tensor offset, Tensor gradOutput, CHECK_CUDA_INPUT(gradOffset); CHECK_CUDA_INPUT(weight); CHECK_CUDA_INPUT(columns); - - deform_conv_backward_input_cuda(input, offset, gradOutput, gradInput, - gradOffset, weight, columns, kW, kH, dW, dH, - padW, padH, dilationW, dilationH, group, - deformable_group, im2col_step); #else AT_ERROR("DeformConv is not compiled with GPU support"); #endif } else { - AT_ERROR("DeformConv is not implemented on CPU"); + CHECK_CPU_INPUT(input); + CHECK_CPU_INPUT(offset); + CHECK_CPU_INPUT(gradOutput); + CHECK_CPU_INPUT(gradInput); + CHECK_CPU_INPUT(gradOffset); + CHECK_CPU_INPUT(weight); + CHECK_CPU_INPUT(columns); + } + deform_conv_shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, + padH, padW, dilationH, dilationW, group, + deformable_group); + + at::DeviceGuard guard(input.device()); + + int batch = 1; + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input = input.view({1, input.size(0), input.size(1), input.size(2)}); + offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)}); + gradOutput = gradOutput.view( + {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); + } + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = weight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset"); + gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + // change order of grad output + gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step, + nOutputPlane, outputHeight, outputWidth}); + gradOutput.transpose_(1, 2); + + gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, + outputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + // divide into groups + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + gradOutput = gradOutput.view( + {gradOutput.size(0), group, gradOutput.size(1) / group, + gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)}); + + for (int g = 0; g < group; g++) { + columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), + gradOutput[elt][g].flatten(1), 0.0f, 1.0f); + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + gradOutput = gradOutput.view( + {gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2), + gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)}); + + if (input.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane, + inputHeight, inputWidth, kH, kW, padH, padW, dH, + dW, dilationH, dilationW, im2col_step, + deformable_group, gradOffset[elt]); + + deformable_col2im(columns, offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, + gradInput[elt]); +#endif + } else { + deformable_col2im_coord_cpu(columns, input[elt], offset[elt], nInputPlane, + inputHeight, inputWidth, kH, kW, padH, padW, + dH, dW, dilationH, dilationW, im2col_step, + deformable_group, gradOffset[elt]); + + deformable_col2im_cpu(columns, offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, + gradInput[elt]); + } + + weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), + weight.size(3), weight.size(4)}); + } + + gradOutput.transpose_(1, 2); + gradOutput = + gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + gradOffset = gradOffset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth}); + offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); + gradOffset = + gradOffset.view({offset.size(1), offset.size(2), offset.size(3)}); } } @@ -125,15 +437,122 @@ void deform_conv_backward_parameters(Tensor input, Tensor offset, CHECK_CUDA_INPUT(gradWeight); CHECK_CUDA_INPUT(columns); CHECK_CUDA_INPUT(ones); - - deform_conv_backward_parameters_cuda(input, offset, gradOutput, gradWeight, - columns, ones, kW, kH, dW, dH, padW, - padH, dilationW, dilationH, group, - deformable_group, scale, im2col_step); #else AT_ERROR("DeformConv is not compiled with GPU support"); #endif } else { - AT_ERROR("DeformConv is not implemented on CPU"); + CHECK_CPU_INPUT(input); + CHECK_CPU_INPUT(offset); + CHECK_CPU_INPUT(gradOutput); + CHECK_CPU_INPUT(gradWeight); + CHECK_CPU_INPUT(columns); + CHECK_CPU_INPUT(ones); + } + + deform_conv_shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, + dW, padH, padW, dilationH, dilationW, group, + deformable_group); + at::DeviceGuard guard(input.device()); + + int batch = 1; + + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input = input.view( + at::IntList({1, input.size(0), input.size(1), input.size(2)})); + gradOutput = gradOutput.view( + {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); + } + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = gradWeight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); + + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step, + nOutputPlane, outputHeight, outputWidth}); + gradOutput.transpose_(1, 2); + + Tensor gradOutputBuffer = at::zeros_like(gradOutput); + gradOutputBuffer = + gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step, + outputHeight, outputWidth}); + gradOutputBuffer = gradOutputBuffer.contiguous(); + gradOutputBuffer.copy_(gradOutput); + gradOutputBuffer = + gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, + im2col_step * outputHeight, outputWidth}); + + gradOutput.transpose_(1, 2); + gradOutput = + gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + if (input.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, columns); +#endif + } else { + deformable_im2col_cpu(input[elt], offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, columns); + } + + // divide into group + gradOutputBuffer = gradOutputBuffer.view( + {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group, + gradOutputBuffer.size(2), gradOutputBuffer.size(3)}); + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + gradWeight = + gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1), + gradWeight.size(2), gradWeight.size(3)}); + + for (int g = 0; g < group; g++) { + gradWeight[g] = gradWeight[g] + .flatten(1) + .addmm_(gradOutputBuffer[elt][g].flatten(1), + columns[g].transpose(1, 0), 1.0, scale) + .view_as(gradWeight[g]); + } + gradOutputBuffer = gradOutputBuffer.view( + {gradOutputBuffer.size(0), + gradOutputBuffer.size(1) * gradOutputBuffer.size(2), + gradOutputBuffer.size(3), gradOutputBuffer.size(4)}); + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1), + gradWeight.size(2), gradWeight.size(3), + gradWeight.size(4)}); + } + + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); } } diff --git a/mmcv/ops/csrc/pytorch/deform_conv_cpu.cpp b/mmcv/ops/csrc/pytorch/deform_conv_cpu.cpp new file mode 100644 index 000000000..cb0e638c5 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/deform_conv_cpu.cpp @@ -0,0 +1,377 @@ +// Copyright (c) OpenMMLab. All rights reserved +#include "pytorch_cpp_helper.hpp" + +template +T deformable_im2col_bilinear_cpu(const T *input, const int data_width, + const int height, const int width, T h, T w) { + if (h <= -1 || height <= h || w <= -1 || width <= w) { + return 0; + } + + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + T lh = h - h_low; + T lw = w - w_low; + T hh = 1 - lh, hw = 1 - lw; + + T v1 = 0; + if (h_low >= 0 && w_low >= 0) v1 = input[h_low * data_width + w_low]; + T v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = input[h_low * data_width + w_high]; + T v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = input[h_high * data_width + w_low]; + T v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = input[h_high * data_width + w_high]; + + T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +T get_gradient_weight_cpu(T argmax_h, T argmax_w, const int h, const int w, + const int height, const int width) { + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || + argmax_w >= width) { + // empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + T weight = 0; + if (h == argmax_h_low && w == argmax_w_low) + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + if (h == argmax_h_low && w == argmax_w_high) + weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); + if (h == argmax_h_high && w == argmax_w_low) + weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + +template +T get_coordinate_weight_cpu(T argmax_h, T argmax_w, const int height, + const int width, const T *im_data, + const int data_width, const int bp_dir) { + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || + argmax_w >= width) { + // empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + T weight = 0; + + if (bp_dir == 0) { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * + im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * + im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - argmax_w) * + im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_w - argmax_w_low) * + im_data[argmax_h_high * data_width + argmax_w_high]; + } else if (bp_dir == 1) { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * + im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * + im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += -1 * (argmax_h - argmax_h_low) * + im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_h - argmax_h_low) * + im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +template +void deformable_im2col_cpu_kernel( + const int n, const T *data_im, const T *data_offset, const int height, + const int width, const int kernel_h, const int kernel_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, const int batch_size, + const int num_channels, const int deformable_group, const int height_col, + const int width_col, T *data_col) { + for (int index = 0; index < n; index++) { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + T *data_col_ptr = + data_col + + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + const T *data_im_ptr = + data_im + (b_col * num_channels + c_im) * height * width; + const T *data_offset_ptr = + data_offset + (b_col * deformable_group + deformable_group_index) * 2 * + kernel_h * kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) { + for (int j = 0; j < kernel_w; ++j) { + const int data_offset_h_ptr = + ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = + ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + + w_col; + const T offset_h = data_offset_ptr[data_offset_h_ptr]; + const T offset_w = data_offset_ptr[data_offset_w_ptr]; + T val = static_cast(0); + const T h_im = h_in + i * dilation_h + offset_h; + const T w_im = w_in + j * dilation_w + offset_w; + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) + val = deformable_im2col_bilinear_cpu(data_im_ptr, width, height, + width, h_im, w_im); + *data_col_ptr = val; + data_col_ptr += batch_size * height_col * width_col; + } + } + } +} + +template +void deformable_col2im_cpu_kernel( + const int n, const T *data_col, const T *data_offset, const int channels, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, const int batch_size, + const int deformable_group, const int height_col, const int width_col, + T *grad_im) { + for (int index = 0; index < n; index++) { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = + (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = + index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const T *data_offset_ptr = + data_offset + (b * deformable_group + deformable_group_index) * 2 * + kernel_h * kernel_w * height_col * width_col; + const int data_offset_h_ptr = + ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = + ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const T offset_h = data_offset_ptr[data_offset_h_ptr]; + const T offset_w = data_offset_ptr[data_offset_w_ptr]; + const T cur_inv_h_data = h_in + i * dilation_h + offset_h; + const T cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const T cur_top_grad = data_col[index]; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) { + for (int dx = -2; dx <= 2; dx++) { + if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 && + cur_w + dx < width && abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) { + int cur_bottom_grad_pos = + ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + T weight = + get_gradient_weight_cpu(cur_inv_h_data, cur_inv_w_data, + cur_h + dy, cur_w + dx, height, width); + *(grad_im + cur_bottom_grad_pos) += weight * cur_top_grad; + } + } + } + } +} + +template +void deformable_col2im_coord_cpu_kernel( + const int n, const T *data_col, const T *data_im, const T *data_offset, + const int channels, const int height, const int width, const int kernel_h, + const int kernel_w, const int pad_h, const int pad_w, const int stride_h, + const int stride_w, const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, const int batch_size, + const int offset_channels, const int deformable_group, const int height_col, + const int width_col, T *grad_offset) { + for (int index = 0; index < n; index++) { + T val = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = (index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const T *data_col_ptr = data_col + deformable_group_index * + channel_per_deformable_group * + batch_size * width_col * height_col; + const T *data_im_ptr = + data_im + (b * deformable_group + deformable_group_index) * + channel_per_deformable_group / kernel_h / kernel_w * + height * width; + const T *data_offset_ptr = + data_offset + (b * deformable_group + deformable_group_index) * 2 * + kernel_h * kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; + col_c += col_step) { + const int col_pos = + (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = + (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = + (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = + (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + + w_out); + const T offset_h = data_offset_ptr[data_offset_h_ptr]; + const T offset_w = data_offset_ptr[data_offset_w_ptr]; + T inv_h = h_in + i * dilation_h + offset_h; + T inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) + inv_h = inv_w = -2; + const T weight = get_coordinate_weight_cpu( + inv_h, inv_w, height, width, data_im_ptr + cnt * height * width, + width, bp_dir); + val += weight * data_col_ptr[col_pos]; + cnt += 1; + } + + grad_offset[index] = val; + } +} + +void deformable_im2col_cpu(Tensor data_im, Tensor data_offset, + const int channels, const int height, + const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + Tensor data_col) { + int height_col = + (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = + (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = channels * height_col * width_col * parallel_imgs; + int channel_per_deformable_group = channels / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_im.scalar_type(), "deformable_im2col_cpu", [&] { + deformable_im2col_cpu_kernel( + num_kernels, data_im.data_ptr(), + data_offset.data_ptr(), height, width, ksize_h, ksize_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + channel_per_deformable_group, parallel_imgs, channels, + deformable_group, height_col, width_col, + data_col.data_ptr()); + }); +} + +void deformable_col2im_cpu(Tensor data_col, Tensor data_offset, + const int channels, const int height, + const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + Tensor grad_im) { + // todo: make sure parallel_imgs is passed in correctly + int height_col = + (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = + (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = + channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs; + int channel_per_deformable_group = channels / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "deformable_col2im_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *grad_im_ = grad_im.data_ptr(); + + deformable_col2im_cpu_kernel( + num_kernels, data_col_, data_offset_, channels, height, width, + ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w, dilation_h, + dilation_w, channel_per_deformable_group, parallel_imgs, + deformable_group, height_col, width_col, grad_im_); + })); +} + +void deformable_col2im_coord_cpu( + Tensor data_col, Tensor data_im, Tensor data_offset, const int channels, + const int height, const int width, const int ksize_h, const int ksize_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int parallel_imgs, + const int deformable_group, Tensor grad_offset) { + int height_col = + (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = + (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * + deformable_group * parallel_imgs; + int channel_per_deformable_group = + channels * ksize_h * ksize_w / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "deformable_col2im_coord_cpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *grad_offset_ = grad_offset.data_ptr(); + + deformable_col2im_coord_cpu_kernel( + num_kernels, data_col_, data_im_, data_offset_, channels, height, + width, ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, parallel_imgs, + 2 * ksize_h * ksize_w * deformable_group, deformable_group, + height_col, width_col, grad_offset_); + })); +} diff --git a/mmcv/ops/csrc/pytorch/modulated_deform_conv.cpp b/mmcv/ops/csrc/pytorch/modulated_deform_conv.cpp index 159d843e4..584bc5459 100644 --- a/mmcv/ops/csrc/pytorch/modulated_deform_conv.cpp +++ b/mmcv/ops/csrc/pytorch/modulated_deform_conv.cpp @@ -2,48 +2,59 @@ #include "pytorch_cpp_helper.hpp" #ifdef MMCV_WITH_CUDA -void ModulatedDeformConvForwardCUDAKernelLauncher( - Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset, - Tensor mask, Tensor output, Tensor columns, int kernel_h, int kernel_w, - const int stride_h, const int stride_w, const int pad_h, const int pad_w, - const int dilation_h, const int dilation_w, const int group, - const int deformable_group, const bool with_bias); -void ModulatedDeformConvBackwardCUDAKernelLauncher( - Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset, - Tensor mask, Tensor columns, Tensor grad_input, Tensor grad_weight, - Tensor grad_bias, Tensor grad_offset, Tensor grad_mask, Tensor grad_output, - int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, - int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, - const bool with_bias); +void modulated_deformable_im2col_cuda( + const Tensor data_im, const Tensor data_offset, const Tensor data_mask, + const int batch_size, const int channels, const int height_im, + const int width_im, const int height_col, const int width_col, + const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, const int dilation_h, + const int dilation_w, const int deformable_group, Tensor data_col); -void modulated_deform_conv_forward_cuda( - Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset, - Tensor mask, Tensor output, Tensor columns, int kernel_h, int kernel_w, - const int stride_h, const int stride_w, const int pad_h, const int pad_w, - const int dilation_h, const int dilation_w, const int group, - const int deformable_group, const bool with_bias) { - ModulatedDeformConvForwardCUDAKernelLauncher( - input, weight, bias, ones, offset, mask, output, columns, kernel_h, - kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, group, - deformable_group, with_bias); -} +void modulated_deformable_col2im_cuda( + const Tensor data_col, const Tensor data_offset, const Tensor data_mask, + const int batch_size, const int channels, const int height_im, + const int width_im, const int height_col, const int width_col, + const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, const int dilation_h, + const int dilation_w, const int deformable_group, Tensor grad_im); + +void modulated_deformable_col2im_coord_cuda( + const Tensor data_col, const Tensor data_im, const Tensor data_offset, + const Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int deformable_group, + Tensor grad_offset, Tensor grad_mask); -void modulated_deform_conv_backward_cuda( - Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset, - Tensor mask, Tensor columns, Tensor grad_input, Tensor grad_weight, - Tensor grad_bias, Tensor grad_offset, Tensor grad_mask, Tensor grad_output, - int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, - int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, - const bool with_bias) { - ModulatedDeformConvBackwardCUDAKernelLauncher( - input, weight, bias, ones, offset, mask, columns, grad_input, grad_weight, - grad_bias, grad_offset, grad_mask, grad_output, kernel_h, kernel_w, - stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, group, - deformable_group, with_bias); -} #endif +void modulated_deformable_im2col_cpu( + const Tensor data_im, const Tensor data_offset, const Tensor data_mask, + const int batch_size, const int channels, const int height_im, + const int width_im, const int height_col, const int width_col, + const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, const int dilation_h, + const int dilation_w, const int deformable_group, Tensor data_col); + +void modulated_deformable_col2im_cpu( + const Tensor data_col, const Tensor data_offset, const Tensor data_mask, + const int batch_size, const int channels, const int height_im, + const int width_im, const int height_col, const int width_col, + const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, const int dilation_h, + const int dilation_w, const int deformable_group, Tensor grad_im); + +void modulated_deformable_col2im_coord_cpu( + const Tensor data_col, const Tensor data_im, const Tensor data_offset, + const Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int deformable_group, + Tensor grad_offset, Tensor grad_mask); + void modulated_deform_conv_forward( Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset, Tensor mask, Tensor output, Tensor columns, int kernel_h, int kernel_w, @@ -61,15 +72,98 @@ void modulated_deform_conv_forward( CHECK_CUDA_INPUT(output); CHECK_CUDA_INPUT(columns); - modulated_deform_conv_forward_cuda( - input, weight, bias, ones, offset, mask, output, columns, kernel_h, - kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, - group, deformable_group, with_bias); #else AT_ERROR("ModulatedDeformConv is not compiled with GPU support"); #endif } else { - AT_ERROR("ModulatedDeformConv is not implemented on CPU"); + CHECK_CPU_INPUT(input); + CHECK_CPU_INPUT(weight); + CHECK_CPU_INPUT(bias); + CHECK_CPU_INPUT(ones); + CHECK_CPU_INPUT(offset); + CHECK_CPU_INPUT(mask); + CHECK_CPU_INPUT(output); + CHECK_CPU_INPUT(columns); + } + + at::DeviceGuard guard(input.device()); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + + const int channels_out = weight.size(0); + const int channels_kernel = weight.size(1); + const int kernel_h_ = weight.size(2); + const int kernel_w_ = weight.size(3); + + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).", + kernel_h_, kernel_w, kernel_h_, kernel_w_); + if (channels != channels_kernel * group) + AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).", + channels, channels_kernel * group); + + const int height_out = + (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = + (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < height_out * width_out) { + // Resize plane and fill with ones... + ones = at::ones({height_out, width_out}, input.options()); + } + + // resize output + output = output.view({batch, channels_out, height_out, width_out}).zero_(); + // resize temporary columns + columns = + at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out}, + input.options()); + + output = output.view({output.size(0), group, output.size(1) / group, + output.size(2), output.size(3)}); + + for (int b = 0; b < batch; b++) { + if (input.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + modulated_deformable_im2col_cuda( + input[b], offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, columns); +#endif + } else { + modulated_deformable_im2col_cpu( + input[b], offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, columns); + } + + // divide into group + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + + for (int g = 0; g < group; g++) { + output[b][g] = output[b][g] + .flatten(1) + .addmm_(weight[g].flatten(1), columns[g]) + .view_as(output[b][g]); + } + + weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), + weight.size(3), weight.size(4)}); + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + } + + output = output.view({output.size(0), output.size(1) * output.size(2), + output.size(3), output.size(4)}); + + if (with_bias) { + output += bias.view({1, bias.size(0), 1, 1}); } } @@ -96,15 +190,149 @@ void modulated_deform_conv_backward( CHECK_CUDA_INPUT(grad_mask); CHECK_CUDA_INPUT(grad_output); - modulated_deform_conv_backward_cuda( - input, weight, bias, ones, offset, mask, columns, grad_input, - grad_weight, grad_bias, grad_offset, grad_mask, grad_output, kernel_h, - kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, - group, deformable_group, with_bias); #else AT_ERROR("ModulatedDeformConv is not compiled with GPU support"); #endif } else { - AT_ERROR("ModulatedDeformConv is not implemented on CPU"); + CHECK_CPU_INPUT(input); + CHECK_CPU_INPUT(weight); + CHECK_CPU_INPUT(bias); + CHECK_CPU_INPUT(ones); + CHECK_CPU_INPUT(offset); + CHECK_CPU_INPUT(mask); + CHECK_CPU_INPUT(columns); + CHECK_CPU_INPUT(grad_input); + CHECK_CPU_INPUT(grad_weight); + CHECK_CPU_INPUT(grad_bias); + CHECK_CPU_INPUT(grad_offset); + CHECK_CPU_INPUT(grad_mask); + CHECK_CPU_INPUT(grad_output); } + + at::DeviceGuard guard(input.device()); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + + const int channels_kernel = weight.size(1); + const int kernel_h_ = weight.size(2); + const int kernel_w_ = weight.size(3); + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).", + kernel_h_, kernel_w, kernel_h_, kernel_w_); + if (channels != channels_kernel * group) + AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).", + channels, channels_kernel * group); + + const int height_out = + (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = + (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < height_out * width_out) { + // Resize plane and fill with ones... + ones = at::ones({height_out, width_out}, input.options()); + } + + grad_input = grad_input.view({batch, channels, height, width}); + columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out}, + input.options()); + + grad_output = + grad_output.view({grad_output.size(0), group, grad_output.size(1) / group, + grad_output.size(2), grad_output.size(3)}); + + for (int b = 0; b < batch; b++) { + // divide int group + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + + for (int g = 0; g < group; g++) { + columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), + grad_output[b][g].flatten(1), 0.0f, 1.0f); + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), + weight.size(3), weight.size(4)}); + + if (input.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + // gradient w.r.t. input coordinate data + modulated_deformable_col2im_coord_cuda( + columns, input[b], offset[b], mask[b], 1, channels, height, width, + height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, + stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b], + grad_mask[b]); + // gradient w.r.t. input data + modulated_deformable_col2im_cuda( + columns, offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, grad_input[b]); + + // gradient w.r.t. weight, dWeight should accumulate across the batch and + // group + modulated_deformable_im2col_cuda( + input[b], offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, columns); +#endif + } else { + // gradient w.r.t. input coordinate data + modulated_deformable_col2im_coord_cpu( + columns, input[b], offset[b], mask[b], 1, channels, height, width, + height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, + stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b], + grad_mask[b]); + // gradient w.r.t. input data + modulated_deformable_col2im_cpu( + columns, offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, grad_input[b]); + // gradient w.r.t. weight, dWeight should accumulate across the batch and + // group + modulated_deformable_im2col_cpu( + input[b], offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, columns); + } + + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + grad_weight = grad_weight.view({group, grad_weight.size(0) / group, + grad_weight.size(1), grad_weight.size(2), + grad_weight.size(3)}); + if (with_bias) + grad_bias = grad_bias.view({group, grad_bias.size(0) / group}); + + for (int g = 0; g < group; g++) { + grad_weight[g] = + grad_weight[g] + .flatten(1) + .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1)) + .view_as(grad_weight[g]); + if (with_bias) { + grad_bias[g] = + grad_bias[g] + .view({-1, 1}) + .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1})) + .view(-1); + } + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1), + grad_weight.size(2), grad_weight.size(3), + grad_weight.size(4)}); + if (with_bias) + grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)}); + } + grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1), + grad_output.size(2), grad_output.size(3), + grad_output.size(4)}); } diff --git a/mmcv/ops/csrc/pytorch/modulated_deform_conv_cpu.cpp b/mmcv/ops/csrc/pytorch/modulated_deform_conv_cpu.cpp new file mode 100644 index 000000000..89a81d733 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/modulated_deform_conv_cpu.cpp @@ -0,0 +1,403 @@ +// Copyright (c) OpenMMLab. All rights reserved +#include "pytorch_cpp_helper.hpp" + +template +T dmcn_im2col_bilinear_cpu(const T *input, const int data_width, + const int height, const int width, T h, T w) { + int h_low = floorf(h); + int w_low = floorf(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + T lh = h - h_low; + T lw = w - w_low; + T hh = 1 - lh, hw = 1 - lw; + + T v1 = 0; + if (h_low >= 0 && w_low >= 0) v1 = input[h_low * data_width + w_low]; + T v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = input[h_low * data_width + w_high]; + T v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = input[h_high * data_width + w_low]; + T v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = input[h_high * data_width + w_high]; + + T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +T dmcn_get_gradient_weight_cpu(T argmax_h, T argmax_w, const int h, const int w, + const int height, const int width) { + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || + argmax_w >= width) { + // empty + return 0; + } + + int argmax_h_low = floorf(argmax_h); + int argmax_w_low = floorf(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + T weight = 0; + if (h == argmax_h_low && w == argmax_w_low) + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + if (h == argmax_h_low && w == argmax_w_high) + weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); + if (h == argmax_h_high && w == argmax_w_low) + weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + +template +T dmcn_get_coordinate_weight_cpu(T argmax_h, T argmax_w, const int height, + const int width, const T *im_data, + const int data_width, const int bp_dir) { + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || + argmax_w >= width) { + // empty + return 0; + } + + int argmax_h_low = floorf(argmax_h); + int argmax_w_low = floorf(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + T weight = 0; + + if (bp_dir == 0) { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * + im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * + im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - argmax_w) * + im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_w - argmax_w_low) * + im_data[argmax_h_high * data_width + argmax_w_high]; + } else if (bp_dir == 1) { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * + im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * + im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += -1 * (argmax_h - argmax_h_low) * + im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_h - argmax_h_low) * + im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +template +void modulated_deformable_im2col_cpu_kernel( + const int n, const T *data_im, const T *data_offset, const T *data_mask, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, const int batch_size, + const int num_channels, const int deformable_group, const int height_col, + const int width_col, T *data_col) { + for (int index = 0; index < n; index++) { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + + T *data_col_ptr = + data_col + + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + const T *data_im_ptr = + data_im + (b_col * num_channels + c_im) * height * width; + const T *data_offset_ptr = + data_offset + (b_col * deformable_group + deformable_group_index) * 2 * + kernel_h * kernel_w * height_col * width_col; + + const T *data_mask_ptr = + data_mask + (b_col * deformable_group + deformable_group_index) * + kernel_h * kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) { + for (int j = 0; j < kernel_w; ++j) { + const int data_offset_h_ptr = + ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = + ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + + w_col; + const int data_mask_hw_ptr = + ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; + const T offset_h = data_offset_ptr[data_offset_h_ptr]; + const T offset_w = data_offset_ptr[data_offset_w_ptr]; + const T mask = data_mask_ptr[data_mask_hw_ptr]; + T val = static_cast(0); + const T h_im = h_in + i * dilation_h + offset_h; + const T w_im = w_in + j * dilation_w + offset_w; + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) + val = dmcn_im2col_bilinear_cpu(data_im_ptr, width, height, width, + h_im, w_im); + *data_col_ptr = val * mask; + data_col_ptr += batch_size * height_col * width_col; + } + } + } +} + +template +void modulated_deformable_col2im_cpu_kernel( + const int n, const T *data_col, const T *data_offset, const T *data_mask, + const int channels, const int height, const int width, const int kernel_h, + const int kernel_w, const int pad_h, const int pad_w, const int stride_h, + const int stride_w, const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, const int batch_size, + const int deformable_group, const int height_col, const int width_col, + T *grad_im) { + for (int index = 0; index < n; index++) { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = + (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = + index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const T *data_offset_ptr = + data_offset + (b * deformable_group + deformable_group_index) * 2 * + kernel_h * kernel_w * height_col * width_col; + const T *data_mask_ptr = + data_mask + (b * deformable_group + deformable_group_index) * kernel_h * + kernel_w * height_col * width_col; + const int data_offset_h_ptr = + ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = + ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const int data_mask_hw_ptr = + ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; + const T offset_h = data_offset_ptr[data_offset_h_ptr]; + const T offset_w = data_offset_ptr[data_offset_w_ptr]; + const T mask = data_mask_ptr[data_mask_hw_ptr]; + const T cur_inv_h_data = h_in + i * dilation_h + offset_h; + const T cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const T cur_top_grad = data_col[index] * mask; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) { + for (int dx = -2; dx <= 2; dx++) { + if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 && + cur_w + dx < width && abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) { + int cur_bottom_grad_pos = + ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + T weight = dmcn_get_gradient_weight_cpu(cur_inv_h_data, + cur_inv_w_data, cur_h + dy, + cur_w + dx, height, width); + *(grad_im + cur_bottom_grad_pos) += weight * cur_top_grad; + } + } + } + } +} + +template +void modulated_deformable_col2im_coord_cpu_kernel( + const int n, const T *data_col, const T *data_im, const T *data_offset, + const T *data_mask, const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, const int dilation_h, + const int dilation_w, const int channel_per_deformable_group, + const int batch_size, const int offset_channels, const int deformable_group, + const int height_col, const int width_col, T *grad_offset, T *grad_mask) { + for (int index = 0; index < n; index++) { + T val = 0, mval = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = (index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const T *data_col_ptr = data_col + deformable_group_index * + channel_per_deformable_group * + batch_size * width_col * height_col; + const T *data_im_ptr = + data_im + (b * deformable_group + deformable_group_index) * + channel_per_deformable_group / kernel_h / kernel_w * + height * width; + const T *data_offset_ptr = + data_offset + (b * deformable_group + deformable_group_index) * 2 * + kernel_h * kernel_w * height_col * width_col; + const T *data_mask_ptr = + data_mask + (b * deformable_group + deformable_group_index) * kernel_h * + kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; + col_c += col_step) { + const int col_pos = + (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = + (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = + (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = + (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + + w_out); + const int data_mask_hw_ptr = + (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); + const T offset_h = data_offset_ptr[data_offset_h_ptr]; + const T offset_w = data_offset_ptr[data_offset_w_ptr]; + const T mask = data_mask_ptr[data_mask_hw_ptr]; + T inv_h = h_in + i * dilation_h + offset_h; + T inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) + inv_h = inv_w = -2; + else + mval += data_col_ptr[col_pos] * + dmcn_im2col_bilinear_cpu(data_im_ptr + cnt * height * width, + width, height, width, inv_h, inv_w); + const T weight = dmcn_get_coordinate_weight_cpu( + inv_h, inv_w, height, width, data_im_ptr + cnt * height * width, + width, bp_dir); + val += weight * data_col_ptr[col_pos] * mask; + cnt += 1; + } + // KERNEL_ASSIGN(grad_offset[index], offset_req, val); + grad_offset[index] = val; + if (offset_c % 2 == 0) + // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + + // deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * + // height_col + h) * width_col + w], mask_req, mval); + grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * + kernel_w + + offset_c / 2) * + height_col + + h) * + width_col + + w] = mval; + } +} + +void modulated_deformable_im2col_cpu( + const Tensor data_im, const Tensor data_offset, const Tensor data_mask, + const int batch_size, const int channels, const int height_im, + const int width_im, const int height_col, const int width_col, + const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, const int dilation_h, + const int dilation_w, const int deformable_group, Tensor data_col) { + // num_axes should be smaller than block size + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * batch_size * height_col * width_col; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_im.scalar_type(), "modulated_deformable_im2col_cpu", ([&] { + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *data_col_ = data_col.data_ptr(); + + modulated_deformable_im2col_cpu_kernel( + num_kernels, data_im_, data_offset_, data_mask_, height_im, + width_im, kernel_h, kenerl_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, batch_size, + channels, deformable_group, height_col, width_col, data_col_); + })); +} + +void modulated_deformable_col2im_cpu( + const Tensor data_col, const Tensor data_offset, const Tensor data_mask, + const int batch_size, const int channels, const int height_im, + const int width_im, const int height_col, const int width_col, + const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, const int dilation_h, + const int dilation_w, const int deformable_group, Tensor grad_im) { + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = + channels * kernel_h * kernel_w * batch_size * height_col * width_col; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "modulated_deformable_col2im_cpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *grad_im_ = grad_im.data_ptr(); + + modulated_deformable_col2im_cpu_kernel( + num_kernels, data_col_, data_offset_, data_mask_, channels, + height_im, width_im, kernel_h, kernel_w, pad_h, pad_w, stride_h, + stride_w, dilation_h, dilation_w, channel_per_deformable_group, + batch_size, deformable_group, height_col, width_col, grad_im_); + })); +} + +void modulated_deformable_col2im_coord_cpu( + const Tensor data_col, const Tensor data_im, const Tensor data_offset, + const Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int deformable_group, + Tensor grad_offset, Tensor grad_mask) { + const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * + kernel_w * deformable_group; + const int channel_per_deformable_group = + channels * kernel_h * kernel_w / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "modulated_deformable_col2im_coord_cpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *grad_offset_ = grad_offset.data_ptr(); + scalar_t *grad_mask_ = grad_mask.data_ptr(); + + modulated_deformable_col2im_coord_cpu_kernel( + num_kernels, data_col_, data_im_, data_offset_, data_mask_, + channels, height_im, width_im, kernel_h, kernel_w, pad_h, pad_w, + stride_h, stride_w, dilation_h, dilation_w, + channel_per_deformable_group, batch_size, + 2 * kernel_h * kernel_w * deformable_group, deformable_group, + height_col, width_col, grad_offset_, grad_mask_); + })); +} diff --git a/tests/test_ops/test_deform_conv.py b/tests/test_ops/test_deform_conv.py index 5b7f3bfb9..0199477af 100644 --- a/tests/test_ops/test_deform_conv.py +++ b/tests/test_ops/test_deform_conv.py @@ -36,13 +36,16 @@ gt_deform_weight_grad = [[[[3.62, 0.], [0.40, 0.18]]]] class TestDeformconv(object): - def _test_deformconv(self, dtype=torch.float, threshold=1e-3): - if not torch.cuda.is_available(): - return + def _test_deformconv(self, + dtype=torch.float, + threshold=1e-3, + device='cuda'): + if not torch.cuda.is_available() and device == 'cuda': + pytest.skip('test requires GPU') from mmcv.ops import DeformConv2dPack c_in = 1 c_out = 1 - x = torch.Tensor(input).cuda().type(dtype) + x = torch.tensor(input, device=device, dtype=dtype) x.requires_grad = True model = DeformConv2dPack(c_in, c_out, 2, stride=1, padding=0) model.conv_offset.weight.data = torch.nn.Parameter( @@ -51,7 +54,9 @@ class TestDeformconv(object): torch.Tensor(offset_bias).reshape(8)) model.weight.data = torch.nn.Parameter( torch.Tensor(deform_weight).reshape(1, 1, 2, 2)) - model.cuda().type(dtype) + if device == 'cuda': + model.cuda() + model.type(dtype) out = model(x) out.backward(torch.ones_like(out)) @@ -67,6 +72,7 @@ class TestDeformconv(object): gt_deform_weight_grad, threshold) from mmcv.ops import DeformConv2d + # test bias model = DeformConv2d(1, 1, 2, stride=1, padding=0) assert not hasattr(model, 'bias') @@ -121,6 +127,7 @@ class TestDeformconv(object): gt_deform_weight_grad, threshold) from mmcv.ops import DeformConv2d + # test bias model = DeformConv2d(1, 1, 2, stride=1, padding=0) assert not hasattr(model, 'bias') @@ -135,9 +142,11 @@ class TestDeformconv(object): model = DeformConv2d(3, 4, 3, groups=3) def test_deformconv(self): + self._test_deformconv(torch.double, device='cpu') + self._test_deformconv(torch.float, device='cpu', threshold=1e-1) self._test_deformconv(torch.double) self._test_deformconv(torch.float) - self._test_deformconv(torch.half, 1e-1) + self._test_deformconv(torch.half, threshold=1e-1) # test amp when torch version >= '1.6.0', the type of # input data for deformconv might be torch.float or torch.half diff --git a/tests/test_ops/test_modulated_deform_conv.py b/tests/test_ops/test_modulated_deform_conv.py index 85443d913..b528c5112 100644 --- a/tests/test_ops/test_modulated_deform_conv.py +++ b/tests/test_ops/test_modulated_deform_conv.py @@ -1,6 +1,7 @@ import os import numpy +import pytest import torch from mmcv.utils import TORCH_VERSION, digit_version @@ -37,11 +38,11 @@ dcn_offset_b_grad = [ class TestMdconv(object): - def _test_mdconv(self, dtype=torch.float): - if not torch.cuda.is_available(): - return + def _test_mdconv(self, dtype=torch.float, device='cuda'): + if not torch.cuda.is_available() and device == 'cuda': + pytest.skip('test requires GPU') from mmcv.ops import ModulatedDeformConv2dPack - input = torch.tensor(input_t).cuda().type(dtype) + input = torch.tensor(input_t, dtype=dtype, device=device) input.requires_grad = True dcn = ModulatedDeformConv2dPack( @@ -51,7 +52,11 @@ class TestMdconv(object): stride=1, padding=1, deform_groups=1, - bias=False).cuda() + bias=False) + + if device == 'cuda': + dcn.cuda() + dcn.weight.data.fill_(1.) dcn.type(dtype) output = dcn(input) @@ -106,6 +111,8 @@ class TestMdconv(object): dcn_offset_b_grad, 1e-2) def test_mdconv(self): + self._test_mdconv(torch.double, device='cpu') + self._test_mdconv(torch.float, device='cpu') self._test_mdconv(torch.double) self._test_mdconv(torch.float) self._test_mdconv(torch.half) diff --git a/tests/test_ops/test_saconv.py b/tests/test_ops/test_saconv.py index 10fc2c5b7..b34865fc9 100644 --- a/tests/test_ops/test_saconv.py +++ b/tests/test_ops/test_saconv.py @@ -1,4 +1,3 @@ -import pytest import torch import torch.nn as nn @@ -33,9 +32,10 @@ def test_sacconv(): refer_out = refer_conv(x) assert deform_sac_out.shape == refer_out.shape else: - with pytest.raises(RuntimeError): - # deform conv is not implemented on cpu - deform_saconv(x) + deform_sac_out = deform_saconv(x) + refer_conv = nn.Conv2d(3, 5, kernel_size=3, padding=1) + refer_out = refer_conv(x) + assert deform_sac_out.shape == refer_out.shape # test with groups >= 2 x = torch.rand(1, 4, 256, 256)