From b92ea0b5df9cfa4222fec704890bd6c8901fe6b8 Mon Sep 17 00:00:00 2001
From: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>
Date: Thu, 23 Sep 2021 21:08:41 +0800
Subject: [PATCH] [Feature] Add Correlation CUDA op (#1361)

---
 docs/understand_mmcv/ops.md                   |   1 +
 docs_zh_CN/understand_mmcv/ops.md             |   1 +
 mmcv/ops/__init__.py                          |   4 +-
 mmcv/ops/correlation.py                       | 173 +++++++++++
 .../ops/csrc/common/cuda/correlation_cuda.cuh | 269 ++++++++++++++++++
 mmcv/ops/csrc/pytorch/correlation.cpp         | 116 ++++++++
 .../ops/csrc/pytorch/cuda/correlation_cuda.cu | 105 +++++++
 mmcv/ops/csrc/pytorch/pybind.cpp              |  19 ++
 tests/test_ops/test_corr.py                   |  49 ++++
 9 files changed, 736 insertions(+), 1 deletion(-)
 create mode 100644 mmcv/ops/correlation.py
 create mode 100644 mmcv/ops/csrc/common/cuda/correlation_cuda.cuh
 create mode 100644 mmcv/ops/csrc/pytorch/correlation.cpp
 create mode 100644 mmcv/ops/csrc/pytorch/cuda/correlation_cuda.cu
 create mode 100644 tests/test_ops/test_corr.py

diff --git a/docs/understand_mmcv/ops.md b/docs/understand_mmcv/ops.md
index ab5eb5f3b..e04f32c7c 100644
--- a/docs/understand_mmcv/ops.md
+++ b/docs/understand_mmcv/ops.md
@@ -22,3 +22,4 @@ We implement common CUDA ops used in detection, segmentation, etc.
 - SoftNMS
 - Synchronized BatchNorm
 - Weight standardization
+- Correlation
diff --git a/docs_zh_CN/understand_mmcv/ops.md b/docs_zh_CN/understand_mmcv/ops.md
index dfc1c96fe..0d7c17fd7 100644
--- a/docs_zh_CN/understand_mmcv/ops.md
+++ b/docs_zh_CN/understand_mmcv/ops.md
@@ -22,3 +22,4 @@ MMCV 提供了检测、分割等任务中常用的 CUDA 算子
 - SoftNMS
 - Synchronized BatchNorm
 - Weight standardization
+- Correlation
diff --git a/mmcv/ops/__init__.py b/mmcv/ops/__init__.py
index e849d9ce2..3a4a838e7 100644
--- a/mmcv/ops/__init__.py
+++ b/mmcv/ops/__init__.py
@@ -7,6 +7,7 @@ from .carafe import CARAFE, CARAFENaive, CARAFEPack, carafe, carafe_naive
 from .cc_attention import CrissCrossAttention
 from .contour_expand import contour_expand
 from .corner_pool import CornerPool
+from .correlation import Correlation
 from .deform_conv import DeformConv2d, DeformConv2dPack, deform_conv2d
 from .deform_roi_pool import (DeformRoIPool, DeformRoIPoolPack,
                               ModulatedDeformRoIPoolPack, deform_roi_pool)
@@ -53,5 +54,6 @@ __all__ = [
     'SAConv2d', 'TINShift', 'tin_shift', 'box_iou_rotated', 'nms_rotated',
     'ball_query', 'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu',
     'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'contour_expand',
-    'MultiScaleDeformableAttention', 'BorderAlign', 'border_align'
+    'MultiScaleDeformableAttention', 'BorderAlign', 'border_align',
+    'Correlation'
 ]
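With the export added to `mmcv.ops.__init__`, the new module becomes importable directly from `mmcv.ops`. A quick smoke check (illustrative snippet, not part of the patch):

```python
# Illustrative smoke check: the op is now exported from mmcv.ops.
from mmcv.ops import Correlation

corr = Correlation(max_displacement=1)
print(corr)  # repr() prints the configured hyper-parameters
```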
diff --git a/mmcv/ops/correlation.py b/mmcv/ops/correlation.py
new file mode 100644
index 000000000..2b85114e0
--- /dev/null
+++ b/mmcv/ops/correlation.py
@@ -0,0 +1,173 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch import Tensor, nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.utils import _pair
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+    '_ext', ['correlation_forward', 'correlation_backward'])
+
+
+class CorrelationFunction(Function):
+
+    @staticmethod
+    def forward(ctx,
+                input1,
+                input2,
+                kernel_size=1,
+                max_displacement=1,
+                stride=1,
+                padding=1,
+                dilation=1,
+                dilation_patch=1):
+
+        ctx.save_for_backward(input1, input2)
+
+        kH, kW = ctx.kernel_size = _pair(kernel_size)
+        patch_size = max_displacement * 2 + 1
+        ctx.patch_size = patch_size
+        dH, dW = ctx.stride = _pair(stride)
+        padH, padW = ctx.padding = _pair(padding)
+        dilationH, dilationW = ctx.dilation = _pair(dilation)
+        dilation_patchH, dilation_patchW = ctx.dilation_patch = _pair(
+            dilation_patch)
+
+        output_size = CorrelationFunction._output_size(ctx, input1)
+
+        output = input1.new_zeros(output_size)
+
+        ext_module.correlation_forward(input1, input2, output, kH, kW,
+                                       patch_size, patch_size, padH, padW,
+                                       dilationH, dilationW, dilation_patchH,
+                                       dilation_patchW, dH, dW)
+
+        return output
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        input1, input2 = ctx.saved_tensors
+
+        kH, kW = ctx.kernel_size
+        patch_size = ctx.patch_size
+        padH, padW = ctx.padding
+        dilationH, dilationW = ctx.dilation
+        dilation_patchH, dilation_patchW = ctx.dilation_patch
+        dH, dW = ctx.stride
+        grad_input1 = torch.zeros_like(input1)
+        grad_input2 = torch.zeros_like(input2)
+
+        ext_module.correlation_backward(grad_output, input1, input2,
+                                        grad_input1, grad_input2, kH, kW,
+                                        patch_size, patch_size, padH, padW,
+                                        dilationH, dilationW, dilation_patchH,
+                                        dilation_patchW, dH, dW)
+        return grad_input1, grad_input2, None, None, None, None, None, None
+
+    @staticmethod
+    def _output_size(ctx, input1):
+        iH, iW = input1.size(2), input1.size(3)
+        batch_size = input1.size(0)
+        kH, kW = ctx.kernel_size
+        patch_size = ctx.patch_size
+        dH, dW = ctx.stride
+        padH, padW = ctx.padding
+        dilationH, dilationW = ctx.dilation
+        dilatedKH = (kH - 1) * dilationH + 1
+        dilatedKW = (kW - 1) * dilationW + 1
+
+        oH = int((iH + 2 * padH - dilatedKH) / dH + 1)
+        oW = int((iW + 2 * padW - dilatedKW) / dW + 1)
+
+        output_size = (batch_size, patch_size, patch_size, oH, oW)
+        return output_size
+
+
+class Correlation(nn.Module):
+    r"""Correlation operator.
+
+    This correlation operator works for optical flow correlation computation.
+
+    There are two batched tensors with shape :math:`(N, C, H, W)`,
+    and the correlation output's shape is
+    :math:`(N, \text{max_displacement} \times 2 + 1,
+    \text{max_displacement} \times 2 + 1, H_{out}, W_{out})`
+
+    where
+
+    .. math::
+        H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding} -
+            \text{dilation} \times (\text{kernel_size} - 1) - 1}
+            {\text{stride}} + 1\right\rfloor
+
+    .. math::
+        W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding} -
+            \text{dilation} \times (\text{kernel_size} - 1) - 1}
+            {\text{stride}} + 1\right\rfloor
+
+    the correlation item :math:`(N_i, dx, dy)` is formed by taking the sliding
+    window convolution between input1 and shifted input2,
+
+    .. math::
+        Corr(N_i, dx, dy) =
+        \sum_{c=0}^{C-1}
+        input1(N_i, c) \star
+        \mathcal{S}(input2(N_i, c), dy, dx)
+
+    where :math:`\star` is the valid 2d sliding window convolution operator,
+    and :math:`\mathcal{S}` means shifting the input features (out-of-bounds
+    positions are zero-filled), and :math:`dx, dy` are the shifting distances,
+    :math:`dx, dy \in
+    [-\text{max_displacement} \times \text{dilation_patch},
+    \text{max_displacement} \times \text{dilation_patch}]`.
+
+    Args:
+        kernel_size (int): The size of the sliding window, i.e. the local
+            neighborhood around each center point that is involved in the
+            correlation computation. Defaults to 1.
+        max_displacement (int): The radius for computing the correlation
+            volume; the actual search range can be dilated by
+            ``dilation_patch``. Defaults to 1.
+        stride (int): The stride of the sliding blocks in the input spatial
+            dimensions. Defaults to 1.
+        padding (int): Zero padding added to all four sides of input1.
+            Defaults to 0.
+        dilation (int): The spacing of the local neighborhood that will be
+            involved in the correlation. Defaults to 1.
+        dilation_patch (int): The spacing between positions at which the
+            correlation is computed. Defaults to 1.
+    """
+
+    def __init__(self,
+                 kernel_size: int = 1,
+                 max_displacement: int = 1,
+                 stride: int = 1,
+                 padding: int = 0,
+                 dilation: int = 1,
+                 dilation_patch: int = 1) -> None:
+        super().__init__()
+        self.kernel_size = kernel_size
+        self.max_displacement = max_displacement
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.dilation_patch = dilation_patch
+
+    def forward(self, input1: Tensor, input2: Tensor) -> Tensor:
+        return CorrelationFunction.apply(input1, input2, self.kernel_size,
+                                         self.max_displacement, self.stride,
+                                         self.padding, self.dilation,
+                                         self.dilation_patch)
+
+    def __repr__(self) -> str:
+        s = self.__class__.__name__
+        s += f'(kernel_size={self.kernel_size}, '
+        s += f'max_displacement={self.max_displacement}, '
+        s += f'stride={self.stride}, '
+        s += f'padding={self.padding}, '
+        s += f'dilation={self.dilation}, '
+        s += f'dilation_patch={self.dilation_patch})'
+        return s
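For reference, a minimal usage sketch of the module above. The shapes follow the docstring's output formula, and the CUDA device requirement reflects the fact that only a GPU kernel is provided:

```python
# Minimal usage sketch (requires a CUDA build of mmcv; there is no CPU kernel).
import torch

from mmcv.ops import Correlation

corr = Correlation(max_displacement=3)  # patch size = 2 * 3 + 1 = 7
input1 = torch.randn(2, 16, 32, 32, device='cuda')
input2 = torch.randn(2, 16, 32, 32, device='cuda')
out = corr(input1, input2)
assert out.shape == (2, 7, 7, 32, 32)  # (N, patch, patch, H_out, W_out)
```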
diff --git a/mmcv/ops/csrc/common/cuda/correlation_cuda.cuh b/mmcv/ops/csrc/common/cuda/correlation_cuda.cuh
new file mode 100644
index 000000000..b479e21bf
--- /dev/null
+++ b/mmcv/ops/csrc/common/cuda/correlation_cuda.cuh
@@ -0,0 +1,269 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+// Modified from
+// https://github.com/ClementPinard/Pytorch-Correlation-extension/blob/master/Correlation_Module/correlation_cuda_kernel.cu
+// Original licence: Under MIT License
+
+#ifndef CORRELATION_CUDA
+#define CORRELATION_CUDA
+
+#ifdef MMCV_USE_PARROTS
+#include "parrots_cuda_helper.hpp"
+#else
+#include "pytorch_cuda_helper.hpp"
+#endif
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <torch/types.h>
+
+#include <iostream>
+#include <vector>
+
+using namespace torch;
+
+#define TensorAcc4R PackedTensorAccessor32<scalar_t, 4, RestrictPtrTraits>
+#define TensorAcc5R PackedTensorAccessor32<scalar_t, 5, RestrictPtrTraits>
+#define WITHIN_BOUNDS(x, y, H, W) (x >= 0 && x < H && y >= 0 && y < W)
+
+#define THREADS_FORWARD 32
+#define THREADS_BACKWARD 16
+
+template <typename scalar_t>
+__global__ void correlation_forward_cuda_kernel(
+    const TensorAcc4R rInput1, const TensorAcc4R rInput2, TensorAcc5R output,
+    int kH, int kW, int patchH, int patchW, int padH, int padW, int dilationH,
+    int dilationW, int dilation_patchH, int dilation_patchW, int dH, int dW) {
+  const int iH = rInput1.size(1);
+  const int iW = rInput1.size(2);
+  const int C = rInput1.size(3);
+
+  const int n = blockIdx.x;
+  const int h = blockIdx.y;
+  const int w = blockIdx.z;
+  const int thread = threadIdx.x;
+
+  const int start_i = -padH + h * dH;
+  const int start_j = -padW + w * dW;
+
+  const int patchRadH = dilation_patchH * (patchH - 1) / 2;
+  const int patchRadW = dilation_patchW * (patchW - 1) / 2;
+
+  __shared__ scalar_t prod_sum[THREADS_FORWARD];
+
+  for (int ph = 0; ph < patchH; ++ph) {
+    int ph_dilated = ph * dilation_patchH - patchRadH;
+    for (int pw = 0; pw < patchW; ++pw) {
+      int pw_dilated = pw * dilation_patchW - patchRadW;
+      prod_sum[thread] = 0;
+      for (int i = 0; i < kH; ++i) {
+        int i1 = start_i + i * dilationH;
+        int i2 = i1 + ph_dilated;
+        if (WITHIN_BOUNDS(i1, i2, iH, iH)) {
+          for (int j = 0; j < kW; ++j) {
+            int j1 = start_j + j * dilationW;
+            int j2 = j1 + pw_dilated;
+            if (WITHIN_BOUNDS(j1, j2, iW, iW)) {
+              for (int c = thread; c < C; c += THREADS_FORWARD) {
+                scalar_t v1 = rInput1[n][i1][j1][c];
+                scalar_t v2 = rInput2[n][i2][j2][c];
+                prod_sum[thread] += v1 * v2;
+              }
+            }
+          }
+        }
+      }
+      // reduce the per-thread partial sums in thread 0
+      __syncthreads();
+      if (thread == 0) {
+        scalar_t reduce_sum = 0;
+        for (int index = 0; index < THREADS_FORWARD; ++index) {
+          reduce_sum += prod_sum[index];
+        }
+        output[n][ph][pw][h][w] = reduce_sum;
+      }
+    }
+  }
+}
+
+template <typename scalar_t>
+__global__ void correlation_backward_cuda_kernel_input1(
+    const TensorAcc5R grad_output, const TensorAcc4R input2,
+    TensorAcc4R grad_input1, const int kH, const int kW, const int patchH,
+    const int patchW, const int padH, const int padW, const int dilationH,
+    const int dilationW, const int dilation_patchH, const int dilation_patchW,
+    const int dH, const int dW, const int batch) {
+  const int iH = input2.size(2);
+  const int iW = input2.size(3);
+
+  const int H = grad_output.size(3);
+  const int W = grad_output.size(4);
+
+  const int patchRadH = (patchH - 1) / 2;
+  const int patchRadW = (patchW - 1) / 2;
+
+  const int n = batch;
+  const int c = blockIdx.x;
+  const int h = blockIdx.y;
+  const int w = blockIdx.z;
+  const int ph_off = threadIdx.x;
+  const int pw_off = threadIdx.y;
+
+  const int h_2 = h + padH;
+  const int w_2 = w + padW;
+  const int min_h = h_2 - kH * dilationH;
+  const int min_w = w_2 - kW * dilationW;
+
+  __shared__ scalar_t prod_sum[THREADS_BACKWARD][THREADS_BACKWARD];
+  prod_sum[ph_off][pw_off] = 0;
+
+  for (int ph = ph_off; ph < patchH; ph += THREADS_BACKWARD) {
+    int i1 = h + dilation_patchH * (ph - patchRadH);
+    for (int pw = pw_off; pw < patchW; pw += THREADS_BACKWARD) {
+      int j1 = w + dilation_patchW * (pw - patchRadW);
+      if (WITHIN_BOUNDS(i1, j1, iH, iW)) {
+        scalar_t val = input2[n][c][i1][j1];
+        for (int h_3 = h_2; h_3 > min_h; h_3 -= dilationH) {
+          int i2 = (h_3) / dH;
+          if (i2 * dH != h_3) continue;
+          for (int w_3 = w_2; w_3 > min_w; w_3 -= dilationW) {
+            int j2 = (w_3) / dW;
+            if (j2 * dW != w_3) continue;
+            if (WITHIN_BOUNDS(i2, j2, H, W)) {
+              prod_sum[ph_off][pw_off] += grad_output[n][ph][pw][i2][j2] * val;
+            }
+          }
+        }
+      }
+    }
+  }
+
+  __syncthreads();
+
+  if (ph_off == 0 && pw_off == 0) {
+    scalar_t reduce_sum = 0;
+    for (int ph = 0; ph < THREADS_BACKWARD; ++ph) {
+      for (int pw = 0; pw < THREADS_BACKWARD; ++pw) {
+        reduce_sum += prod_sum[ph][pw];
+      }
+    }
+    grad_input1[n][c][h][w] = reduce_sum;
+  }
+}
+
+template <typename scalar_t>
+__global__ void correlation_backward_cuda_kernel_input2(
+    const TensorAcc5R grad_output, const TensorAcc4R input1,
+    TensorAcc4R grad_input2, int kH, int kW, int patchH, int patchW, int padH,
+    int padW, int dilationH, int dilationW, int dilation_patchH,
+    int dilation_patchW, int dH, int dW, int batch) {
+  const int iH = input1.size(2);
+  const int iW = input1.size(3);
+
+  const int patchRadH = (patchH - 1) / 2;
+  const int patchRadW = (patchW - 1) / 2;
+
+  const int H = grad_output.size(3);
+  const int W = grad_output.size(4);
+
+  const int dilatedKH = kH * dilationH;
+  const int dilatedKW = kW * dilationW;
+
+  const int n = batch;
+  const int c = blockIdx.x;
+  const int h = blockIdx.y;
+  const int w = blockIdx.z;
+  const int ph_off = threadIdx.x;
+  const int pw_off = threadIdx.y;
+
+  __shared__ scalar_t prod_sum[THREADS_BACKWARD][THREADS_BACKWARD];
+  prod_sum[ph_off][pw_off] = 0;
+
+  for (int ph = ph_off; ph < patchH; ph += THREADS_BACKWARD) {
+    int i1 = h - dilation_patchH * (ph - patchRadH);
+    for (int pw = pw_off; pw < patchW; pw += THREADS_BACKWARD) {
+      int j1 = w - dilation_patchW * (pw - patchRadW);
+      if (WITHIN_BOUNDS(i1, j1, iH, iW)) {
+        scalar_t val = input1[n][c][i1][j1];
+
+        const int h_2 = i1 + padH;
+        const int w_2 = j1 + padW;
+        const int min_h = h_2 - dilatedKH;
+        const int min_w = w_2 - dilatedKW;
+
+        for (int h_3 = h_2; h_3 > min_h; h_3 -= dilationH) {
+          int i2 = (h_3) / dH;
+          if (i2 * dH != h_3) continue;
+          for (int w_3 = w_2; w_3 > min_w; w_3 -= dilationW) {
+            int j2 = (w_3) / dW;
+            if (j2 * dW != w_3) continue;
+            if (WITHIN_BOUNDS(i2, j2, H, W)) {
+              prod_sum[ph_off][pw_off] += grad_output[n][ph][pw][i2][j2] * val;
+            }
+          }
+        }
+      }
+    }
+  }
+
+  __syncthreads();
+
+  if (ph_off == 0 && pw_off == 0) {
+    scalar_t reduce_sum = 0;
+    for (int ph = 0; ph < THREADS_BACKWARD; ++ph) {
+      for (int pw = 0; pw < THREADS_BACKWARD; ++pw) {
+        reduce_sum += prod_sum[ph][pw];
+      }
+    }
+    grad_input2[n][c][h][w] = reduce_sum;
+  }
+}
+#endif
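To make the kernel's indexing easier to follow, here is a slow pure-PyTorch sketch of the forward pass for the common `kernel_size=1`, `stride=1`, `padding=0`, `dilation=1` configuration. The function name is illustrative; out-of-bounds shifts contribute zero, mirroring the `WITHIN_BOUNDS` guard above:

```python
# Reference (non-CUDA) forward pass, assuming kernel_size=1, stride=1,
# padding=0, dilation=1. Illustrative only; the shipped op uses the kernel.
import torch
import torch.nn.functional as F


def correlation_ref(input1, input2, max_displacement=1, dilation_patch=1):
    n, c, h, w = input1.shape
    patch = 2 * max_displacement + 1
    rad = dilation_patch * max_displacement
    # Zero-pad input2 so every shift stays in bounds (zeros contribute 0).
    padded2 = F.pad(input2, (rad, rad, rad, rad))
    out = input1.new_zeros((n, patch, patch, h, w))
    for ph in range(patch):
        dy = ph * dilation_patch  # vertical offset into the padded tensor
        for pw in range(patch):
            dx = pw * dilation_patch  # horizontal offset
            shifted = padded2[:, :, dy:dy + h, dx:dx + w]
            # Channel-wise dot product at every spatial position.
            out[:, ph, pw] = (input1 * shifted).sum(dim=1)
    return out
```

Under these assumptions, `correlation_ref` reproduces, e.g., the unit test's `max_displacement=0` case, where the output reduces to the channel-wise dot product of the two inputs at every pixel.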
diff --git a/mmcv/ops/csrc/pytorch/correlation.cpp b/mmcv/ops/csrc/pytorch/correlation.cpp
new file mode 100644
index 000000000..af2864de5
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/correlation.cpp
@@ -0,0 +1,116 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+#include <torch/types.h>
+
+#include "pytorch_cpp_helper.hpp"
+
+#ifdef MMCV_WITH_CUDA
+
+void CorrelationForwardCUDAKernelLauncher(Tensor input1, Tensor input2,
+                                          Tensor output, int kH, int kW,
+                                          int patchH, int patchW, int padH,
+                                          int padW, int dilationH,
+                                          int dilationW, int dilation_patchH,
+                                          int dilation_patchW, int dH,
+                                          int dW);
+
+void CorrelationBackwardCUDAKernelLauncher(Tensor grad_output, Tensor input1,
+                                           Tensor input2, Tensor grad_input1,
+                                           Tensor grad_input2, int kH, int kW,
+                                           int patchH, int patchW, int padH,
+                                           int padW, int dilationH,
+                                           int dilationW, int dilation_patchH,
+                                           int dilation_patchW, int dH,
+                                           int dW);
+
+void correlation_cuda_forward(Tensor input1, Tensor input2, Tensor output,
+                              int kH, int kW, int patchH, int patchW,
+                              int padH, int padW, int dilationH,
+                              int dilationW, int dilation_patchH,
+                              int dilation_patchW, int dH, int dW) {
+  CorrelationForwardCUDAKernelLauncher(input1, input2, output, kH, kW,
+                                       patchH, patchW, padH, padW, dilationH,
+                                       dilationW, dilation_patchH,
+                                       dilation_patchW, dH, dW);
+}
+
+void correlation_cuda_backward(Tensor grad_output, Tensor input1,
+                               Tensor input2, Tensor grad_input1,
+                               Tensor grad_input2, int kH, int kW, int patchH,
+                               int patchW, int padH, int padW, int dilationH,
+                               int dilationW, int dilation_patchH,
+                               int dilation_patchW, int dH, int dW) {
+  CorrelationBackwardCUDAKernelLauncher(
+      grad_output, input1, input2, grad_input1, grad_input2, kH, kW, patchH,
+      patchW, padH, padW, dilationH, dilationW, dilation_patchH,
+      dilation_patchW, dH, dW);
+}
+
+#endif
+
+void correlation_forward(Tensor input1, Tensor input2, Tensor output, int kH,
+                         int kW, int patchH, int patchW, int padH, int padW,
+                         int dilationH, int dilationW, int dilation_patchH,
+                         int dilation_patchW, int dH, int dW) {
+  if (input1.device().is_cuda() && input2.device().is_cuda()) {
+#ifdef MMCV_WITH_CUDA
+    CHECK_CUDA_INPUT(input1);
+    CHECK_CUDA_INPUT(input2);
+    correlation_cuda_forward(input1, input2, output, kH, kW, patchH, patchW,
+                             padH, padW, dilationH, dilationW,
+                             dilation_patchH, dilation_patchW, dH, dW);
+#else
+    AT_ERROR("Correlation is not compiled with GPU support");
+#endif
+  } else {
+    AT_ERROR("Correlation is not implemented on CPU");
+  }
+}
+
+void correlation_backward(Tensor grad_output, Tensor input1, Tensor input2,
+                          Tensor grad_input1, Tensor grad_input2, int kH,
+                          int kW, int patchH, int patchW, int padH, int padW,
+                          int dilationH, int dilationW, int dilation_patchH,
+                          int dilation_patchW, int dH, int dW) {
+  if (input1.device().is_cuda() && input2.device().is_cuda()) {
+#ifdef MMCV_WITH_CUDA
+    CHECK_CUDA_INPUT(grad_output);
+    CHECK_CUDA_INPUT(input1);
+    CHECK_CUDA_INPUT(input2);
+    correlation_cuda_backward(grad_output, input1, input2, grad_input1,
+                              grad_input2, kH, kW, patchH, patchW, padH,
+                              padW, dilationH, dilationW, dilation_patchH,
+                              dilation_patchW, dH, dW);
+#else
+    AT_ERROR("Correlation is not compiled with GPU support");
+#endif
+  } else {
+    AT_ERROR("Correlation is not implemented on CPU");
+  }
+}
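Note that the C++ layer never computes output sizes itself; the output tensor is allocated in Python by `_output_size`. A worked instance of that arithmetic, matching the unit test's `gt_out_shape`:

```python
# Worked example of the output-size arithmetic used by _output_size,
# reproducing the unit test's gt_out_shape = (1, 1, 1, 3, 3).
iH = iW = 3
kernel_size, padding, stride, dilation = 1, 0, 1, 1
max_displacement = 0

dilatedK = (kernel_size - 1) * dilation + 1      # = 1
oH = (iH + 2 * padding - dilatedK) // stride + 1  # = 3
oW = (iW + 2 * padding - dilatedK) // stride + 1  # = 3
patch_size = 2 * max_displacement + 1             # = 1
print((1, patch_size, patch_size, oH, oW))        # (1, 1, 1, 3, 3)
```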
diff --git a/mmcv/ops/csrc/pytorch/cuda/correlation_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/correlation_cuda.cu
new file mode 100644
index 000000000..86991b362
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/cuda/correlation_cuda.cu
@@ -0,0 +1,105 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+// Modified from
+// https://github.com/ClementPinard/Pytorch-Correlation-extension/blob/master/Correlation_Module/correlation_cuda_kernel.cu
+// Original licence: Under MIT License
+
+#include "correlation_cuda.cuh"
+#include "pytorch_cuda_helper.hpp"
+
+void CorrelationForwardCUDAKernelLauncher(Tensor input1, Tensor input2,
+                                          Tensor output, int kH, int kW,
+                                          int patchH, int patchW, int padH,
+                                          int padW, int dilationH,
+                                          int dilationW, int dilation_patchH,
+                                          int dilation_patchW, int dH,
+                                          int dW) {
+  const int batch_size = input1.size(0);
+  const int iH = input1.size(2);
+  const int iW = input1.size(3);
+  const int dilatedKH = (kH - 1) * dilationH + 1;
+  const int dilatedKW = (kW - 1) * dilationW + 1;
+
+  const auto oH = (iH + 2 * padH - dilatedKH) / dH + 1;
+  const auto oW = (iW + 2 * padW - dilatedKW) / dW + 1;
+
+  auto trInput1 = input1.permute({0, 2, 3, 1}).contiguous();
+  auto trInput2 = input2.permute({0, 2, 3, 1}).contiguous();
+
+  const int threads = THREADS_FORWARD;
+  const dim3 blocks(batch_size, oH, oW);
+
+  at::cuda::CUDAGuard device_guard(input1.device());
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      input1.scalar_type(), "correlation_forward_cuda", ([&] {
+        TensorAcc4R trInput1_acc =
+            trInput1.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
+        TensorAcc4R trInput2_acc =
+            trInput2.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
+        TensorAcc5R output_acc =
+            output.packed_accessor32<scalar_t, 5, RestrictPtrTraits>();
+
+        correlation_forward_cuda_kernel<scalar_t><<<blocks, threads>>>(
+            trInput1_acc, trInput2_acc, output_acc, kH, kW, patchH, patchW,
+            padH, padW, dilationH, dilationW, dilation_patchH,
+            dilation_patchW, dH, dW);
+      }));
+}
+
+void CorrelationBackwardCUDAKernelLauncher(Tensor grad_output, Tensor input1,
+                                           Tensor input2, Tensor grad_input1,
+                                           Tensor grad_input2, int kH, int kW,
+                                           int patchH, int patchW, int padH,
+                                           int padW, int dilationH,
+                                           int dilationW, int dilation_patchH,
+                                           int dilation_patchW, int dH,
+                                           int dW) {
+  const int batch_size = input1.size(0);
+  const int iH = input1.size(2);
+  const int iW = input1.size(3);
+  const int C = input1.size(1);
+
+  const dim3 blocks(C, iH, iW);
+  const dim3 threads(THREADS_BACKWARD, THREADS_BACKWARD);
+
+  at::cuda::CUDAGuard device_guard(input1.device());
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      input1.scalar_type(), "correlation_backward_cuda", ([&] {
+        TensorAcc4R input1_acc =
+            input1.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
+        TensorAcc4R input2_acc =
+            input2.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
+        TensorAcc4R grad_input1_acc =
+            grad_input1.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
+        TensorAcc4R grad_input2_acc =
+            grad_input2.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
+        TensorAcc5R grad_output_acc =
+            grad_output.packed_accessor32<scalar_t, 5, RestrictPtrTraits>();
+
+        for (int n = 0; n < batch_size; ++n) {
+          correlation_backward_cuda_kernel_input1<scalar_t>
+              <<<blocks, threads>>>(grad_output_acc, input2_acc,
+                                    grad_input1_acc, kH, kW, patchH, patchW,
+                                    padH, padW, dilationH, dilationW,
+                                    dilation_patchH, dilation_patchW, dH, dW,
+                                    n);
+        }
+
+        for (int n = 0; n < batch_size; ++n) {
+          correlation_backward_cuda_kernel_input2<scalar_t>
+              <<<blocks, threads>>>(grad_output_acc, input1_acc,
+                                    grad_input2_acc, kH, kW, patchH, patchW,
+                                    padH, padW, dilationH, dilationW,
+                                    dilation_patchH, dilation_patchW, dH, dW,
+                                    n);
+        }
+      }));
+}
diff --git a/mmcv/ops/csrc/pytorch/pybind.cpp b/mmcv/ops/csrc/pytorch/pybind.cpp
index 7b888a261..0f173ef73 100644
--- a/mmcv/ops/csrc/pytorch/pybind.cpp
+++ b/mmcv/ops/csrc/pytorch/pybind.cpp
@@ -225,6 +225,23 @@ void border_align_backward(const Tensor &grad_output, const Tensor &boxes,
                            const Tensor &argmax_idx, Tensor grad_input,
                            const int pool_size);
 
+void correlation_forward(Tensor input1, Tensor input2, Tensor output,
+                         int kH, int kW, int patchH, int patchW,
+                         int padH, int padW,
+                         int dilationH, int dilationW,
+                         int dilation_patchH, int dilation_patchW,
+                         int dH, int dW);
+
+void correlation_backward(Tensor grad_output,
+                          Tensor input1, Tensor input2,
+                          Tensor grad_input1, Tensor grad_input2,
+                          int kH, int kW,
+                          int patchH, int patchW,
+                          int padH, int padW,
+                          int dilationH, int dilationW,
+                          int dilation_patchH, int dilation_patchW,
+                          int dH, int dW);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"),
         py::arg("kernel"), py::arg("up_x"), py::arg("up_y"), py::arg("down_x"),
@@ -452,4 +469,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "backward function of border_align", py::arg("grad_output"),
         py::arg("boxes"), py::arg("argmax_idx"), py::arg("grad_input"),
         py::arg("pool_size"));
+  m.def("correlation_forward", &correlation_forward, "Correlation forward");
+  m.def("correlation_backward", &correlation_backward, "Correlation backward");
 }
diff --git a/tests/test_ops/test_corr.py b/tests/test_ops/test_corr.py
new file mode 100644
index 000000000..2b34c6ad9
--- /dev/null
+++ b/tests/test_ops/test_corr.py
@@ -0,0 +1,49 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from mmcv.ops import Correlation
+
+_input1 = [[[[1., 2., 3.], [0., 1., 2.], [3., 5., 2.]]]]
+_input2 = [[[[1., 2., 3.], [3., 1., 2.], [8., 5., 2.]]]]
+_input2_2 = [[[[1., 2.], [3., 1.], [8., 5.]]]]
+gt_out_shape = (1, 1, 1, 3, 3)
+_gt_out = [[[[[1., 4., 9.], [0., 1., 4.], [24., 25., 4.]]]]]
+gt_input1_grad = [[[[1., 2., 3.], [3., 1., 2.], [8., 5., 2.]]]]
+_ap_gt_out = [[[[[1., 2., 3.], [3., 1., 2.], [8., 5., 2.]],
+                [[2., 4., 6.], [6., 2., 4.], [16., 10., 4.]],
+                [[3., 6., 9.], [9., 3., 6.], [24., 15., 6.]]],
+               [[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
+                [[1., 2., 3.], [3., 1., 2.], [8., 5., 2.]],
+                [[2., 4., 6.], [6., 2., 4.], [16., 10., 4.]]],
+               [[[3., 6., 9.], [9., 3., 6.], [24., 15., 6.]],
+                [[5., 10., 15.], [15., 5., 10.], [40., 25., 10.]],
+                [[2., 4., 6.], [6., 2., 4.], [16., 10., 4.]]]]]
+
+
+def assert_equal_tensor(tensor_a, tensor_b):
+
+    assert tensor_a.eq(tensor_b).all()
+
+
+class TestCorrelation:
+
+    def _test_correlation(self, dtype=torch.float):
+
+        layer = Correlation(max_displacement=0)
+
+        input1 = torch.tensor(_input1, dtype=dtype).cuda()
+        input2 = torch.tensor(_input2, dtype=dtype).cuda()
+        input1.requires_grad = True
+        input2.requires_grad = True
+        out = layer(input1, input2)
+        out.backward(torch.ones_like(out))
+
+        gt_out = torch.tensor(_gt_out, dtype=dtype)
+        assert_equal_tensor(out.cpu(), gt_out)
+        assert_equal_tensor(input1.grad.detach().cpu(), input2.cpu())
+        assert_equal_tensor(input2.grad.detach().cpu(), input1.cpu())
+
+    def test_correlation(self):
+        self._test_correlation(torch.float)
+        self._test_correlation(torch.double)
+        self._test_correlation(torch.half)
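One further check that could complement the tests above (hypothetical, not part of the patch) is a finite-difference gradient test in double precision, since the op provides an analytic backward:

```python
# Hypothetical extra test: compare the analytic backward against finite
# differences. gradcheck is too strict for half/float, so use double.
import torch
from torch.autograd import gradcheck

from mmcv.ops import Correlation


def test_correlation_gradcheck():
    layer = Correlation(max_displacement=0)
    input1 = torch.randn(
        1, 1, 3, 3, dtype=torch.double, device='cuda', requires_grad=True)
    input2 = torch.randn(
        1, 1, 3, 3, dtype=torch.double, device='cuda', requires_grad=True)
    assert gradcheck(layer, (input1, input2))
```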