[Feature] Add Correlation CUDA op (#1361)

Miao Zheng 2021-09-23 21:08:41 +08:00 committed by GitHub
parent f3dfc4135b
commit b92ea0b5df
9 changed files with 736 additions and 1 deletion


@ -22,3 +22,4 @@ We implement common CUDA ops used in detection, segmentation, etc.
- SoftNMS
- Synchronized BatchNorm
- Weight standardization
- Correlation


@ -22,3 +22,4 @@ MMCV provides the CUDA ops commonly used in detection, segmentation, and other tasks
- SoftNMS
- Synchronized BatchNorm
- Weight standardization
- Correlation


@ -7,6 +7,7 @@ from .carafe import CARAFE, CARAFENaive, CARAFEPack, carafe, carafe_naive
from .cc_attention import CrissCrossAttention
from .contour_expand import contour_expand
from .corner_pool import CornerPool
from .correlation import Correlation
from .deform_conv import DeformConv2d, DeformConv2dPack, deform_conv2d
from .deform_roi_pool import (DeformRoIPool, DeformRoIPoolPack,
ModulatedDeformRoIPoolPack, deform_roi_pool)
@ -53,5 +54,6 @@ __all__ = [
'SAConv2d', 'TINShift', 'tin_shift', 'box_iou_rotated', 'nms_rotated',
'ball_query', 'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu',
'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'contour_expand',
'MultiScaleDeformableAttention', 'BorderAlign', 'border_align',
'Correlation'
]


@ -0,0 +1,173 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import Tensor, nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair
from ..utils import ext_loader
ext_module = ext_loader.load_ext(
'_ext', ['correlation_forward', 'correlation_backward'])
class CorrelationFunction(Function):
@staticmethod
def forward(ctx,
input1,
input2,
kernel_size=1,
max_displacement=1,
stride=1,
padding=1,
dilation=1,
dilation_patch=1):
ctx.save_for_backward(input1, input2)
kH, kW = ctx.kernel_size = _pair(kernel_size)
patch_size = max_displacement * 2 + 1
ctx.patch_size = patch_size
dH, dW = ctx.stride = _pair(stride)
padH, padW = ctx.padding = _pair(padding)
dilationH, dilationW = ctx.dilation = _pair(dilation)
dilation_patchH, dilation_patchW = ctx.dilation_patch = _pair(
dilation_patch)
output_size = CorrelationFunction._output_size(ctx, input1)
output = input1.new_zeros(output_size)
ext_module.correlation_forward(input1, input2, output, kH, kW,
patch_size, patch_size, padH, padW,
dilationH, dilationW, dilation_patchH,
dilation_patchW, dH, dW)
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
input1, input2 = ctx.saved_tensors
kH, kW = ctx.kernel_size
patch_size = ctx.patch_size
padH, padW = ctx.padding
dilationH, dilationW = ctx.dilation
dilation_patchH, dilation_patchW = ctx.dilation_patch
dH, dW = ctx.stride
grad_input1 = torch.zeros_like(input1)
grad_input2 = torch.zeros_like(input2)
ext_module.correlation_backward(grad_output, input1, input2,
grad_input1, grad_input2, kH, kW,
patch_size, patch_size, padH, padW,
dilationH, dilationW, dilation_patchH,
dilation_patchW, dH, dW)
return grad_input1, grad_input2, None, None, None, None, None, None
@staticmethod
def _output_size(ctx, input1):
iH, iW = input1.size(2), input1.size(3)
batch_size = input1.size(0)
kH, kW = ctx.kernel_size
patch_size = ctx.patch_size
dH, dW = ctx.stride
padH, padW = ctx.padding
dilationH, dilationW = ctx.dilation
dilatedKH = (kH - 1) * dilationH + 1
dilatedKW = (kW - 1) * dilationW + 1
oH = int((iH + 2 * padH - dilatedKH) / dH + 1)
oW = int((iW + 2 * padW - dilatedKW) / dW + 1)
output_size = (batch_size, patch_size, patch_size, oH, oW)
return output_size
class Correlation(nn.Module):
r"""Correlation operator
This correlation operator works for optical flow correlation computation.
There are two batched tensors with shape :math:`(N, C, H, W)`,
and the correlation output's shape is
:math:`(N, \text{max_displacement} \times 2+1,
\text{max_displacement} \times 2+1,
H_{out}, W_{out})`
where
.. math::
H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding} -
\text{dilation} \times (\text{kernel_size} - 1) - 1}
{\text{stride}} + 1\right\rfloor
.. math::
W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding} -
\text{dilation} \times (\text{kernel_size} - 1) - 1}
{\text{stride}} + 1\right\rfloor
the correlation item :math:`(N_i, dx, dy)` is formed by taking the sliding-window
convolution between input1 and the shifted input2,
.. math::
Corr(N_i, dx, dy) =
\sum_{c=0}^{C-1}
input1(N_i, c) \star
\mathcal{S}(input2(N_i, c), dy, dx)
where :math:`\star` is the valid 2d sliding-window convolution operator,
:math:`\mathcal{S}` denotes shifting the input features (out-of-bound
positions are treated as zero), and :math:`dx, dy` are the shifting
distances, :math:`dx, dy \in
[-\text{max_displacement} \times \text{dilation_patch},
\text{max_displacement} \times \text{dilation_patch}]`.
Args:
kernel_size (int): The size of the sliding window, i.e. the local
neighborhood around each center point that takes part in the
correlation computation. Defaults to 1.
max_displacement (int): The radius used to build the correlation
volume; the actual search range is additionally scaled by
``dilation_patch``. Defaults to 1.
stride (int): The stride of the sliding blocks in the input spatial
dimensions. Defaults to 1.
padding (int): Zero padding added to all four sides of input1.
Defaults to 0.
dilation (int): The spacing within the local neighborhood that is
involved in the correlation. Defaults to 1.
dilation_patch (int): The spacing between the positions at which
correlation is computed. Defaults to 1.
"""
def __init__(self,
kernel_size: int = 1,
max_displacement: int = 1,
stride: int = 1,
padding: int = 0,
dilation: int = 1,
dilation_patch: int = 1) -> None:
super().__init__()
self.kernel_size = kernel_size
self.max_displacement = max_displacement
self.stride = stride
self.padding = padding
self.dilation = dilation
self.dilation_patch = dilation_patch
def forward(self, input1: Tensor, input2: Tensor) -> Tensor:
return CorrelationFunction.apply(input1, input2, self.kernel_size,
self.max_displacement, self.stride,
self.padding, self.dilation,
self.dilation_patch)
def __repr__(self) -> str:
s = self.__class__.__name__
s += f'(kernel_size={self.kernel_size}, '
s += f'max_displacement={self.max_displacement}, '
s += f'stride={self.stride}, '
s += f'padding={self.padding}, '
s += f'dilation={self.dilation}, '
s += f'dilation_patch={self.dilation_patch})'
return s
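
For reference, a minimal usage sketch of the new module (illustrative only, not part of this commit; the feature-map sizes are made up). With max_displacement=d the output cost volume has shape (N, 2d+1, 2d+1, H_out, W_out):

import torch
from mmcv.ops import Correlation

# assumes a CUDA-enabled build of mmcv-full
corr = Correlation(kernel_size=1, max_displacement=3, stride=1,
                   padding=0, dilation=1, dilation_patch=1)
feat1 = torch.randn(2, 64, 48, 64, device='cuda')  # (N, C, H, W)
feat2 = torch.randn(2, 64, 48, 64, device='cuda')
cost_volume = corr(feat1, feat2)
# patch size is 2 * 3 + 1 = 7, and with kernel_size=1, stride=1, padding=0
# the spatial size is unchanged:
assert cost_volume.shape == (2, 7, 7, 48, 64)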


@ -0,0 +1,269 @@
// Copyright (c) OpenMMLab. All rights reserved.
// Modified from
// https://github.com/ClementPinard/Pytorch-Correlation-extension/blob/master/Correlation_Module/correlation_cuda_kernel.cu
// Original licence: MIT License
#ifndef CORRELATION_CUDA
#define CORRELATION_CUDA
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/types.h>
#include <vector>
#include <iostream>
using namespace torch;
#define TensorAcc4R PackedTensorAccessor32<scalar_t, 4, RestrictPtrTraits>
#define TensorAcc5R PackedTensorAccessor32<scalar_t, 5, RestrictPtrTraits>
#define WITHIN_BOUNDS(x, y, H, W) (x >= 0 && x < H && y >= 0 && y < W)
#define THREADS_FORWARD 32
#define THREADS_BACKWARD 16
template <typename scalar_t>
__global__ void correlation_forward_cuda_kernel(const TensorAcc4R rInput1,
const TensorAcc4R rInput2,
TensorAcc5R output,
int kH, int kW,
int patchH, int patchW,
int padH, int padW,
int dilationH, int dilationW,
int dilation_patchH,
int dilation_patchW,
int dH, int dW)
{
const int iH = rInput1.size(1);
const int iW = rInput1.size(2);
const int C = rInput1.size(3);
const int n = blockIdx.x;
const int h = blockIdx.y;
const int w = blockIdx.z;
const int thread = threadIdx.x;
const int start_i = -padH + h * dH;
const int start_j = -padW + w * dW;
const int patchRadH = dilation_patchH * (patchH - 1) / 2;
const int patchRadW = dilation_patchW * (patchW - 1) / 2;
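// Each thread accumulates a partial sum over a strided subset of channels
// (c = threadIdx.x, threadIdx.x + THREADS_FORWARD, ...); thread 0 reduces the
// partial sums into the output element after __syncthreads().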
__shared__ scalar_t prod_sum[THREADS_FORWARD];
for (int ph = 0; ph < patchH; ++ph)
{
int ph_dilated = ph * dilation_patchH - patchRadH;
for (int pw = 0; pw < patchW; ++pw)
{
int pw_dilated = pw * dilation_patchW - patchRadW;
prod_sum[thread] = 0;
for (int i = 0; i < kH; ++i)
{
int i1 = start_i + i * dilationH;
int i2 = i1 + ph_dilated;
if WITHIN_BOUNDS (i1, i2, iH, iH)
{
for (int j = 0; j < kW; ++j)
{
int j1 = start_j + j * dilationW;
int j2 = j1 + pw_dilated;
if WITHIN_BOUNDS (j1, j2, iW, iW)
{
for (int c = thread; c < C; c += THREADS_FORWARD)
{
scalar_t v1 = rInput1[n][i1][j1][c];
scalar_t v2 = rInput2[n][i2][j2][c];
prod_sum[thread] += v1 * v2;
}
}
}
}
}
// accumulate
__syncthreads();
if (thread == 0)
{
scalar_t reduce_sum = 0;
for (int index = 0; index < THREADS_FORWARD; ++index)
{
reduce_sum += prod_sum[index];
}
output[n][ph][pw][h][w] = reduce_sum;
}
}
}
}
template <typename scalar_t>
__global__ void correlation_backward_cuda_kernel_input1(const TensorAcc5R grad_output,
const TensorAcc4R input2,
TensorAcc4R grad_input1,
const int kH, const int kW,
const int patchH, const int patchW,
const int padH, const int padW,
const int dilationH, const int dilationW,
const int dilation_patchH, const int dilation_patchW,
const int dH, const int dW,
const int batch)
{
const int iH = input2.size(2);
const int iW = input2.size(3);
const int H = grad_output.size(3);
const int W = grad_output.size(4);
const int patchRadH = (patchH - 1) / 2;
const int patchRadW = (patchW - 1) / 2;
const int n = batch;
const int c = blockIdx.x;
const int h = blockIdx.y;
const int w = blockIdx.z;
const int ph_off = threadIdx.x;
const int pw_off = threadIdx.y;
const int h_2 = h + padH;
const int w_2 = w + padW;
const int min_h = h_2 - kH * dilationH;
const int min_w = w_2 - kW * dilationW;
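// Each (threadIdx.x, threadIdx.y) thread handles a strided subset of the
// displacement patch; the h_3/w_3 loops invert the forward stride/dilation
// mapping to find the output positions that used this input pixel.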
__shared__ scalar_t prod_sum[THREADS_BACKWARD][THREADS_BACKWARD];
prod_sum[ph_off][pw_off] = 0;
for (int ph = ph_off; ph < patchH; ph += THREADS_BACKWARD)
{
int i1 = h + dilation_patchH * (ph - patchRadH);
for (int pw = pw_off; pw < patchW; pw += THREADS_BACKWARD)
{
int j1 = w + dilation_patchW * (pw - patchRadW);
if (WITHIN_BOUNDS(i1, j1, iH, iW))
{
scalar_t val = input2[n][c][i1][j1];
for (int h_3 = h_2; h_3 > min_h; h_3 -= dilationH)
{
int i2 = (h_3) / dH;
if (i2 * dH != h_3)
continue;
for (int w_3 = w_2; w_3 > min_w; w_3 -= dilationW)
{
int j2 = (w_3) / dW;
if (j2 * dW != w_3)
continue;
if WITHIN_BOUNDS (i2, j2, H, W)
{
prod_sum[ph_off][pw_off] += grad_output[n][ph][pw][i2][j2] * val;
}
}
}
}
}
}
__syncthreads();
if (ph_off == 0 && pw_off == 0)
{
scalar_t reduce_sum = 0;
for (int ph = 0; ph < THREADS_BACKWARD; ++ph)
{
for (int pw = 0; pw < THREADS_BACKWARD; ++pw)
{
reduce_sum += prod_sum[ph][pw];
}
}
grad_input1[n][c][h][w] = reduce_sum;
}
}
template <typename scalar_t>
__global__ void correlation_backward_cuda_kernel_input2(const TensorAcc5R grad_output,
const TensorAcc4R input1,
TensorAcc4R grad_input2,
int kH, int kW,
int patchH, int patchW,
int padH, int padW,
int dilationH, int dilationW,
int dilation_patchH, int dilation_patchW,
int dH, int dW,
int batch)
{
const int iH = input1.size(2);
const int iW = input1.size(3);
const int patchRadH = (patchH - 1) / 2;
const int patchRadW = (patchW - 1) / 2;
const int H = grad_output.size(3);
const int W = grad_output.size(4);
const int dilatedKH = kH * dilationH;
const int dilatedKW = kW * dilationW;
const int n = batch;
const int c = blockIdx.x;
const int h = blockIdx.y;
const int w = blockIdx.z;
const int ph_off = threadIdx.x;
const int pw_off = threadIdx.y;
__shared__ scalar_t prod_sum[THREADS_BACKWARD][THREADS_BACKWARD];
prod_sum[ph_off][pw_off] = 0;
for (int ph = ph_off; ph < patchH; ph += THREADS_BACKWARD)
{
int i1 = h - dilation_patchH * (ph - patchRadH);
for (int pw = pw_off; pw < patchW; pw += THREADS_BACKWARD)
{
int j1 = w - dilation_patchW * (pw - patchRadW);
if WITHIN_BOUNDS (i1, j1, iH, iW)
{
scalar_t val = input1[n][c][i1][j1];
const int h_2 = i1 + padH;
const int w_2 = j1 + padW;
const int min_h = h_2 - dilatedKH;
const int min_w = w_2 - dilatedKW;
for (int h_3 = h_2; h_3 > min_h; h_3 -= dilationH)
{
int i2 = (h_3) / dH;
if (i2 * dH != h_3)
continue;
for (int w_3 = w_2; w_3 > min_w; w_3 -= dilationW)
{
int j2 = (w_3) / dW;
if (j2 * dW != w_3)
continue;
if WITHIN_BOUNDS (i2, j2, H, W)
{
prod_sum[ph_off][pw_off] += grad_output[n][ph][pw][i2][j2] * val;
}
}
}
}
}
}
__syncthreads();
if (ph_off == 0 && pw_off == 0)
{
scalar_t reduce_sum = 0;
for (int ph = 0; ph < THREADS_BACKWARD; ++ph)
{
for (int pw = 0; pw < THREADS_BACKWARD; ++pw)
{
reduce_sum += prod_sum[ph][pw];
}
}
grad_input2[n][c][h][w] = reduce_sum;
}
}
#endif
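
To make the kernel's indexing easier to follow, here is a slow pure-PyTorch reference of the same forward computation (an illustrative sketch, not part of this commit; the function name is made up and the loops deliberately mirror the kernel):

import torch

def correlation_reference(input1, input2, kernel_size=1, max_displacement=1,
                          stride=1, padding=0, dilation=1, dilation_patch=1):
    # Loop-based reference of correlation_forward_cuda_kernel (for clarity only).
    N, C, iH, iW = input1.shape
    patch = 2 * max_displacement + 1
    patch_rad = dilation_patch * max_displacement
    dilated_k = (kernel_size - 1) * dilation + 1
    oH = (iH + 2 * padding - dilated_k) // stride + 1
    oW = (iW + 2 * padding - dilated_k) // stride + 1
    out = input1.new_zeros(N, patch, patch, oH, oW)

    def pixel(x, i, j):
        # Out-of-bounds samples contribute zero, matching WITHIN_BOUNDS.
        if 0 <= i < iH and 0 <= j < iW:
            return x[:, :, i, j]              # shape (N, C)
        return x.new_zeros(N, C)

    for ph in range(patch):
        dy = ph * dilation_patch - patch_rad      # row shift applied to input2
        for pw in range(patch):
            dx = pw * dilation_patch - patch_rad  # column shift
            for h in range(oH):
                for w in range(oW):
                    acc = input1.new_zeros(N)
                    for i in range(kernel_size):
                        i1 = -padding + h * stride + i * dilation
                        for j in range(kernel_size):
                            j1 = -padding + w * stride + j * dilation
                            acc += (pixel(input1, i1, j1) *
                                    pixel(input2, i1 + dy, j1 + dx)).sum(dim=1)
                    out[:, ph, pw, h, w] = acc
    return out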


@ -0,0 +1,116 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include <iostream>
#include "pytorch_cpp_helper.hpp"
#ifdef MMCV_WITH_CUDA
void CorrelationForwardCUDAKernelLauncher(Tensor input1, Tensor input2,
Tensor output, int kH, int kW,
int patchH, int patchW,
int padH, int padW,
int dilationH, int dilationW,
int dilation_patchH,
int dilation_patchW,
int dH, int dW);
void CorrelationBackwardCUDAKernelLauncher(Tensor grad_output, Tensor input1,
Tensor input2, Tensor grad_input1,
Tensor grad_input2, int kH, int kW,
int patchH, int patchW,
int padH, int padW,
int dilationH, int dilationW,
int dilation_patchH,
int dilation_patchW,
int dH, int dW);
void correlation_cuda_forward(Tensor input1, Tensor input2, Tensor output,
int kH, int kW, int patchH, int patchW,
int padH, int padW, int dilationH, int dilationW,
int dilation_patchH, int dilation_patchW,
int dH, int dW)
{
CorrelationForwardCUDAKernelLauncher(input1, input2, output, kH, kW,
patchH, patchW, padH, padW, dilationH,
dilationW, dilation_patchH,
dilation_patchW, dH, dW);
}
void correlation_cuda_backward(Tensor grad_output,
Tensor input1, Tensor input2,
Tensor grad_input1, Tensor grad_input2,
int kH, int kW, int patchH, int patchW,
int padH, int padW,
int dilationH, int dilationW,
int dilation_patchH, int dilation_patchW,
int dH, int dW)
{
CorrelationBackwardCUDAKernelLauncher(grad_output, input1, input2,
grad_input1, grad_input2, kH, kW,
patchH, patchW, padH, padW,
dilationH, dilationW,
dilation_patchH, dilation_patchW,
dH, dW);
}
#endif
void correlation_forward(Tensor input1, Tensor input2, Tensor output,
int kH, int kW, int patchH, int patchW,
int padH, int padW,
int dilationH, int dilationW,
int dilation_patchH, int dilation_patchW,
int dH, int dW)
{
if (input1.device().is_cuda() and input2.device().is_cuda())
{
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(input1);
CHECK_CUDA_INPUT(input2);
correlation_cuda_forward(input1, input2, output, kH, kW,
patchH, patchW, padH, padW,
dilationH, dilationW,
dilation_patchH, dilation_patchW,
dH, dW);
#else
AT_ERROR("Correlation is not compiled with GPU support");
#endif
}
else
{
AT_ERROR("Correlation is not implemented on CPU");
}
}
void correlation_backward(Tensor grad_output,
Tensor input1, Tensor input2,
Tensor grad_input1, Tensor grad_input2,
int kH, int kW,
int patchH, int patchW,
int padH, int padW,
int dilationH, int dilationW,
int dilation_patchH, int dilation_patchW,
int dH, int dW)
{
if (input1.device().is_cuda() and input2.device().is_cuda())
{
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(grad_output);
CHECK_CUDA_INPUT(input1);
CHECK_CUDA_INPUT(input2);
correlation_cuda_backward(grad_output, input1, input2,
grad_input1, grad_input2, kH, kW,
patchH, patchW, padH, padW,
dilationH, dilationW,
dilation_patchH, dilation_patchW,
dH, dW);
#else
AT_ERROR("Correlation is not compiled with GPU support");
#endif
}
else
{
AT_ERROR("Correlation is not implemented on CPU");
}
}


@ -0,0 +1,105 @@
// Copyright (c) OpenMMLab. All rights reserved.
// Modified from
// https://github.com/ClementPinard/Pytorch-Correlation-extension/blob/master/Correlation_Module/correlation_cuda_kernel.cu
// Original licence: MIT License
#include "correlation_cuda.cuh"
#include "pytorch_cuda_helper.hpp"
void CorrelationForwardCUDAKernelLauncher(Tensor input1, Tensor input2,
Tensor output, int kH, int kW,
int patchH, int patchW,
int padH, int padW,
int dilationH, int dilationW,
int dilation_patchH,
int dilation_patchW,
int dH, int dW)
{
const int batch_size = input1.size(0);
const int iH = input1.size(2);
const int iW = input1.size(3);
const int dilatedKH = (kH - 1) * dilationH + 1;
const int dilatedKW = (kW - 1) * dilationW + 1;
const auto oH = (iH + 2 * padH - dilatedKH) / dH + 1;
const auto oW = (iW + 2 * padW - dilatedKW) / dW + 1;
auto trInput1 = input1.permute({0, 2, 3, 1}).contiguous();
auto trInput2 = input2.permute({0, 2, 3, 1}).contiguous();
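// One thread block per output location (n, h, w); THREADS_FORWARD threads
// cooperate over the channel dimension (inputs were permuted to NHWC above
// so that per-channel reads are coalesced).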
const int threads = THREADS_FORWARD;
const dim3 blocks(batch_size, oH, oW);
at::cuda::CUDAGuard device_guard(input1.device());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.scalar_type(),
"correlation_forward_cuda",
([&]{
TensorAcc4R trInput1_acc = trInput1.packed_accessor32<scalar_t,4,RestrictPtrTraits>();
TensorAcc4R trInput2_acc = trInput2.packed_accessor32<scalar_t,4,RestrictPtrTraits>();
TensorAcc5R output_acc = output.packed_accessor32<scalar_t,5,RestrictPtrTraits>();
correlation_forward_cuda_kernel<scalar_t><<<blocks, threads, 0,
at::cuda::getCurrentCUDAStream()>>>(
trInput1_acc, trInput2_acc, output_acc,
kH, kW, patchH, patchW, padH, padW, dilationH, dilationW,
dilation_patchH, dilation_patchW, dH, dW);
}));
}
void CorrelationBackwardCUDAKernelLauncher(Tensor grad_output, Tensor input1,
Tensor input2, Tensor grad_input1,
Tensor grad_input2, int kH, int kW,
int patchH, int patchW,
int padH, int padW,
int dilationH, int dilationW,
int dilation_patchH,
int dilation_patchW,
int dH, int dW){
const int batch_size = input1.size(0);
const int iH = input1.size(2);
const int iW = input1.size(3);
const int C = input1.size(1);
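// One thread block per input location (c, h, w); a THREADS_BACKWARD x
// THREADS_BACKWARD tile of threads covers the displacement patch, and the
// backward kernels are launched once per batch element below.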
const dim3 blocks(C, iH, iW);
const dim3 threads(THREADS_BACKWARD, THREADS_BACKWARD);
at::cuda::CUDAGuard device_guard(input1.device());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.scalar_type(),
"correlation_backward_cuda",
([&]{
TensorAcc4R input1_acc = input1.packed_accessor32<scalar_t,4,RestrictPtrTraits>();
TensorAcc4R input2_acc = input2.packed_accessor32<scalar_t,4,RestrictPtrTraits>();
TensorAcc4R grad_input1_acc = grad_input1.packed_accessor32<scalar_t,4,RestrictPtrTraits>();
TensorAcc4R grad_input2_acc = grad_input2.packed_accessor32<scalar_t,4,RestrictPtrTraits>();
TensorAcc5R grad_output_acc = grad_output.packed_accessor32<scalar_t,5,RestrictPtrTraits>();
for (int n = 0; n < batch_size; ++n){
correlation_backward_cuda_kernel_input1<scalar_t><<<blocks, threads, 0,
at::cuda::getCurrentCUDAStream()>>>(
grad_output_acc, input2_acc, grad_input1_acc,
kH, kW, patchH, patchW, padH, padW,
dilationH, dilationW,
dilation_patchH, dilation_patchW,
dH, dW, n);
}
for (int n = 0; n < batch_size; ++n){
correlation_backward_cuda_kernel_input2<scalar_t><<<blocks, threads, 0,
at::cuda::getCurrentCUDAStream()>>>(
grad_output_acc, input1_acc, grad_input2_acc,
kH, kW, patchH, patchW, padH, padW,
dilationH, dilationW,
dilation_patchH, dilation_patchW,
dH, dW, n);
}
}));
}


@ -225,6 +225,23 @@ void border_align_backward(const Tensor &grad_output, const Tensor &boxes,
const Tensor &argmax_idx, Tensor grad_input,
const int pool_size);
void correlation_forward(Tensor input1, Tensor input2, Tensor output,
int kH, int kW, int patchH, int patchW,
int padH, int padW,
int dilationH, int dilationW,
int dilation_patchH, int dilation_patchW,
int dH, int dW);
void correlation_backward(Tensor grad_output,
Tensor input1, Tensor input2,
Tensor grad_input1, Tensor grad_input2,
int kH, int kW,
int patchH, int patchW,
int padH, int padW,
int dilationH, int dilationW,
int dilation_patchH, int dilation_patchW,
int dH, int dW);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"),
py::arg("kernel"), py::arg("up_x"), py::arg("up_y"), py::arg("down_x"),
@ -452,4 +469,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"backward function of border_align", py::arg("grad_output"),
py::arg("boxes"), py::arg("argmax_idx"), py::arg("grad_input"),
py::arg("pool_size"));
m.def("correlation_forward", &correlation_forward, "Correlation forward");
m.def("correlation_backward", &correlation_backward, "Correlation backward");
}


@ -0,0 +1,49 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.ops import Correlation
_input1 = [[[[1., 2., 3.], [0., 1., 2.], [3., 5., 2.]]]]
_input2 = [[[[1., 2., 3.], [3., 1., 2.], [8., 5., 2.]]]]
_input2_2 = [[[[1., 2.], [3., 1.], [8., 5.]]]]
gt_out_shape = (1, 1, 1, 3, 3)
_gt_out = [[[[[1., 4., 9.], [0., 1., 4.], [24., 25., 4.]]]]]
gt_input1_grad = [[[[1., 2., 3.], [3., 1., 2.], [8., 5., 2.]]]]
_ap_gt_out = [[[[[1., 2., 3.], [3., 1., 2.], [8., 5., 2.]],
[[2., 4., 6.], [6., 2., 4.], [16., 10., 4.]],
[[3., 6., 9.], [9., 3., 6.], [24., 15., 6.]]],
[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[1., 2., 3.], [3., 1., 2.], [8., 5., 2.]],
[[2., 4., 6.], [6., 2., 4.], [16., 10., 4.]]],
[[[3., 6., 9.], [9., 3., 6.], [24., 15., 6.]],
[[5., 10., 15.], [15., 5., 10.], [40., 25., 10.]],
[[2., 4., 6.], [6., 2., 4.], [16., 10., 4.]]]]]
def assert_equal_tensor(tensor_a, tensor_b):
assert tensor_a.eq(tensor_b).all()
class TestCorrelation:
def _test_correlation(self, dtype=torch.float):
layer = Correlation(max_displacement=0)
input1 = torch.tensor(_input1, dtype=dtype).cuda()
input2 = torch.tensor(_input2, dtype=dtype).cuda()
input1.requires_grad = True
input2.requires_grad = True
out = layer(input1, input2)
out.backward(torch.ones_like(out))
gt_out = torch.tensor(_gt_out, dtype=dtype)
assert_equal_tensor(out.cpu(), gt_out)
assert_equal_tensor(input1.grad.detach().cpu(), input2.cpu())
assert_equal_tensor(input2.grad.detach().cpu(), input1.cpu())
def test_correlation(self):
self._test_correlation(torch.float)
self._test_correlation(torch.double)
self._test_correlation(torch.half)
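
A numerical gradient check could complement the fixed-value test above. The sketch below is an illustrative addition, not part of this commit; it assumes a CUDA build and uses double precision so that torch.autograd.gradcheck tolerances are meaningful:

import torch
from torch.autograd import gradcheck

from mmcv.ops.correlation import CorrelationFunction

def test_correlation_gradcheck():
    if not torch.cuda.is_available():
        return
    x1 = torch.randn(1, 2, 5, 5, dtype=torch.double, device='cuda',
                     requires_grad=True)
    x2 = torch.randn(1, 2, 5, 5, dtype=torch.double, device='cuda',
                     requires_grad=True)
    # positional args: kernel_size, max_displacement, stride, padding,
    # dilation, dilation_patch
    assert gradcheck(
        lambda a, b: CorrelationFunction.apply(a, b, 1, 1, 1, 0, 1, 1),
        (x1, x2), eps=1e-6, atol=1e-4)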