mirror of https://github.com/open-mmlab/mmcv.git
[Fix] Revise unit test of correlation (#1368)
* [Fix] Revise unit test of correlation
* rename
* lint
* lint
* lint
* lint
parent 9d4571e3d0
commit 745aa7373e
@@ -15,8 +15,9 @@
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/types.h>
#include <vector>

#include <iostream>
#include <vector>

using namespace torch;

@@ -28,17 +29,10 @@ using namespace torch;
#define THREADS_BACKWARD 16

template <typename scalar_t>
__global__ void correlation_forward_cuda_kernel(const TensorAcc4R rInput1,
const TensorAcc4R rInput2,
TensorAcc5R output,
int kH, int kW,
int patchH, int patchW,
int padH, int padW,
int dilationH, int dilationW,
int dilation_patchH,
int dilation_patchW,
int dH, int dW)
{
__global__ void correlation_forward_cuda_kernel(
const TensorAcc4R rInput1, const TensorAcc4R rInput2, TensorAcc5R output,
int kH, int kW, int patchH, int patchW, int padH, int padW, int dilationH,
int dilationW, int dilation_patchH, int dilation_patchW, int dH, int dW) {
const int iH = rInput1.size(1);
const int iW = rInput1.size(2);
const int C = rInput1.size(3);

@@ -56,42 +50,35 @@ __global__ void correlation_forward_cuda_kernel(const TensorAcc4R rInput1,

__shared__ scalar_t prod_sum[THREADS_FORWARD];

for (int ph = 0; ph < patchH; ++ph)
{
for (int ph = 0; ph < patchH; ++ph) {
int ph_dilated = ph * dilation_patchH - patchRadH;
for (int pw = 0; pw < patchW; ++pw)
{
for (int pw = 0; pw < patchW; ++pw) {
int pw_dilated = pw * dilation_patchW - patchRadW;
prod_sum[thread] = 0;
for (int i = 0; i < kH; ++i)
{
for (int i = 0; i < kH; ++i) {
int i1 = start_i + i * dilationH;
int i2 = i1 + ph_dilated;
if WITHIN_BOUNDS (i1, i2, iH, iH)
{
for (int j = 0; j < kW; ++j)
{
int j1 = start_j + j * dilationW;
int j2 = j1 + pw_dilated;
if WITHIN_BOUNDS (j1, j2, iW, iW)
{
for (int c = thread; c < C; c += THREADS_FORWARD)
{
scalar_t v1 = rInput1[n][i1][j1][c];
scalar_t v2 = rInput2[n][i2][j2][c];
prod_sum[thread] += v1 * v2;
}
if
WITHIN_BOUNDS(i1, i2, iH, iH) {
for (int j = 0; j < kW; ++j) {
int j1 = start_j + j * dilationW;
int j2 = j1 + pw_dilated;
if
WITHIN_BOUNDS(j1, j2, iW, iW) {
for (int c = thread; c < C; c += THREADS_FORWARD) {
scalar_t v1 = rInput1[n][i1][j1][c];
scalar_t v2 = rInput2[n][i2][j2][c];
prod_sum[thread] += v1 * v2;
}
}
}
}
}
}
// accumulate
__syncthreads();
if (thread == 0)
{
if (thread == 0) {
scalar_t reduce_sum = 0;
for (int index = 0; index < THREADS_FORWARD; ++index)
{
for (int index = 0; index < THREADS_FORWARD; ++index) {
reduce_sum += prod_sum[index];
}
output[n][ph][pw][h][w] = reduce_sum;

@@ -101,18 +88,12 @@ __global__ void correlation_forward_cuda_kernel(const TensorAcc4R rInput1,
}

template <typename scalar_t>
__global__ void correlation_backward_cuda_kernel_input1(const TensorAcc5R grad_output,
const TensorAcc4R input2,
TensorAcc4R grad_input1,
const int kH, const int kW,
const int patchH, const int patchW,
const int padH, const int padW,
const int dilationH, const int dilationW,
const int dilation_patchH, const int dilation_patchW,
const int dH, const int dW,
const int batch)
{

__global__ void correlation_backward_cuda_kernel_input1(
const TensorAcc5R grad_output, const TensorAcc4R input2,
TensorAcc4R grad_input1, const int kH, const int kW, const int patchH,
const int patchW, const int padH, const int padW, const int dilationH,
const int dilationW, const int dilation_patchH, const int dilation_patchW,
const int dH, const int dW, const int batch) {
const int iH = input2.size(2);
const int iW = input2.size(3);

@@ -137,29 +118,23 @@ __global__ void correlation_backward_cuda_kernel_input1(const TensorAcc5R grad_o
__shared__ scalar_t prod_sum[THREADS_BACKWARD][THREADS_BACKWARD];
prod_sum[ph_off][pw_off] = 0;

for (int ph = ph_off; ph < patchH; ph += THREADS_BACKWARD)
{
for (int ph = ph_off; ph < patchH; ph += THREADS_BACKWARD) {
int i1 = h + dilation_patchH * (ph - patchRadH);
for (int pw = pw_off; pw < patchW; pw += THREADS_BACKWARD)
{
for (int pw = pw_off; pw < patchW; pw += THREADS_BACKWARD) {
int j1 = w + dilation_patchW * (pw - patchRadW);
if (WITHIN_BOUNDS(i1, j1, iH, iW))
{
if (WITHIN_BOUNDS(i1, j1, iH, iW)) {
scalar_t val = input2[n][c][i1][j1];
for (int h_3 = h_2; h_3 > min_h; h_3 -= dilationH)
{
for (int h_3 = h_2; h_3 > min_h; h_3 -= dilationH) {
int i2 = (h_3) / dH;
if (i2 * dH != h_3)
continue;
for (int w_3 = w_2; w_3 > min_w; w_3 -= dilationW)
{
if (i2 * dH != h_3) continue;
for (int w_3 = w_2; w_3 > min_w; w_3 -= dilationW) {
int j2 = (w_3) / dW;
if (j2 * dW != w_3)
continue;
if WITHIN_BOUNDS (i2, j2, H, W)
{
prod_sum[ph_off][pw_off] += grad_output[n][ph][pw][i2][j2] * val;
}
if (j2 * dW != w_3) continue;
if
WITHIN_BOUNDS(i2, j2, H, W) {
prod_sum[ph_off][pw_off] +=
grad_output[n][ph][pw][i2][j2] * val;
}
}
}
}

@@ -168,13 +143,10 @@ __global__ void correlation_backward_cuda_kernel_input1(const TensorAcc5R grad_o

__syncthreads();

if (ph_off == 0 && pw_off == 0)
{
if (ph_off == 0 && pw_off == 0) {
scalar_t reduce_sum = 0;
for (int ph = 0; ph < THREADS_BACKWARD; ++ph)
{
for (int pw = 0; pw < THREADS_BACKWARD; ++pw)
{
for (int ph = 0; ph < THREADS_BACKWARD; ++ph) {
for (int pw = 0; pw < THREADS_BACKWARD; ++pw) {
reduce_sum += prod_sum[ph][pw];
}
}

@@ -183,17 +155,11 @@ __global__ void correlation_backward_cuda_kernel_input1(const TensorAcc5R grad_o
}

template <typename scalar_t>
__global__ void correlation_backward_cuda_kernel_input2(const TensorAcc5R grad_output,
const TensorAcc4R input1,
TensorAcc4R grad_input2,
int kH, int kW,
int patchH, int patchW,
int padH, int padW,
int dilationH, int dilationW,
int dilation_patchH, int dilation_patchW,
int dH, int dW,
int batch)
{
__global__ void correlation_backward_cuda_kernel_input2(
const TensorAcc5R grad_output, const TensorAcc4R input1,
TensorAcc4R grad_input2, int kH, int kW, int patchH, int patchW, int padH,
int padW, int dilationH, int dilationW, int dilation_patchH,
int dilation_patchW, int dH, int dW, int batch) {
const int iH = input1.size(2);
const int iW = input1.size(3);

@@ -216,50 +182,42 @@ __global__ void correlation_backward_cuda_kernel_input2(const TensorAcc5R grad_o
__shared__ scalar_t prod_sum[THREADS_BACKWARD][THREADS_BACKWARD];
prod_sum[ph_off][pw_off] = 0;

for (int ph = ph_off; ph < patchH; ph += THREADS_BACKWARD)
{
for (int ph = ph_off; ph < patchH; ph += THREADS_BACKWARD) {
int i1 = h - dilation_patchH * (ph - patchRadH);
for (int pw = pw_off; pw < patchW; pw += THREADS_BACKWARD)
{
for (int pw = pw_off; pw < patchW; pw += THREADS_BACKWARD) {
int j1 = w - dilation_patchW * (pw - patchRadW);
if WITHIN_BOUNDS (i1, j1, iH, iW)
{
scalar_t val = input1[n][c][i1][j1];
if
WITHIN_BOUNDS(i1, j1, iH, iW) {
scalar_t val = input1[n][c][i1][j1];

const int h_2 = i1 + padH;
const int w_2 = j1 + padW;
const int min_h = h_2 - dilatedKH;
const int min_w = w_2 - dilatedKW;
const int h_2 = i1 + padH;
const int w_2 = j1 + padW;
const int min_h = h_2 - dilatedKH;
const int min_w = w_2 - dilatedKW;

for (int h_3 = h_2; h_3 > min_h; h_3 -= dilationH)
{
int i2 = (h_3) / dH;
if (i2 * dH != h_3)
continue;
for (int w_3 = w_2; w_3 > min_w; w_3 -= dilationW)
{
int j2 = (w_3) / dW;
if (j2 * dW != w_3)
continue;
if WITHIN_BOUNDS (i2, j2, H, W)
{
prod_sum[ph_off][pw_off] += grad_output[n][ph][pw][i2][j2] * val;
for (int h_3 = h_2; h_3 > min_h; h_3 -= dilationH) {
int i2 = (h_3) / dH;
if (i2 * dH != h_3) continue;
for (int w_3 = w_2; w_3 > min_w; w_3 -= dilationW) {
int j2 = (w_3) / dW;
if (j2 * dW != w_3) continue;
if
WITHIN_BOUNDS(i2, j2, H, W) {
prod_sum[ph_off][pw_off] +=
grad_output[n][ph][pw][i2][j2] * val;
}
}
}
}
}
}
}

__syncthreads();

if (ph_off == 0 && pw_off == 0)
{
if (ph_off == 0 && pw_off == 0) {
scalar_t reduce_sum = 0;
for (int ph = 0; ph < THREADS_BACKWARD; ++ph)
{
for (int pw = 0; pw < THREADS_BACKWARD; ++pw)
{
for (int ph = 0; ph < THREADS_BACKWARD; ++ph) {
for (int pw = 0; pw < THREADS_BACKWARD; ++pw) {
reduce_sum += prod_sum[ph][pw];
}
}

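For orientation, the forward kernel above computes, for every batch index n, output position (h, w) and patch offset (ph, pw), a dot product over channels between two (possibly dilated) kH x kW windows of the two inputs. The following pure-PyTorch sketch mirrors that loop nest; it is only an illustration of the kernel's logic, not code from this commit, and the start_i/start_j and patchRadH/patchRadW definitions marked as assumed live in parts of the kernel outside the hunks shown.

import torch

def correlation_forward_reference(input1, input2, kH=1, kW=1, patchH=1,
                                  patchW=1, padH=0, padW=0, dilationH=1,
                                  dilationW=1, dilation_patchH=1,
                                  dilation_patchW=1, dH=1, dW=1):
    # NCHW -> NHWC, mirroring the trInput1/trInput2 permutation in the launcher.
    r1 = input1.permute(0, 2, 3, 1).contiguous()
    r2 = input2.permute(0, 2, 3, 1).contiguous()
    N, iH, iW, C = r1.shape
    dilatedKH = (kH - 1) * dilationH + 1
    dilatedKW = (kW - 1) * dilationW + 1
    oH = (iH + 2 * padH - dilatedKH) // dH + 1
    oW = (iW + 2 * padW - dilatedKW) // dW + 1
    # Assumed definitions, set earlier in the kernel outside the hunks shown.
    patchRadH = dilation_patchH * (patchH - 1) // 2
    patchRadW = dilation_patchW * (patchW - 1) // 2
    out = input1.new_zeros((N, patchH, patchW, oH, oW))
    for n in range(N):
        for ph in range(patchH):
            ph_dilated = ph * dilation_patchH - patchRadH
            for pw in range(patchW):
                pw_dilated = pw * dilation_patchW - patchRadW
                for h in range(oH):
                    start_i = -padH + h * dH  # assumed definition
                    for w in range(oW):
                        start_j = -padW + w * dW  # assumed definition
                        acc = 0.0
                        for i in range(kH):
                            i1 = start_i + i * dilationH
                            i2 = i1 + ph_dilated
                            if not (0 <= i1 < iH and 0 <= i2 < iH):
                                continue
                            for j in range(kW):
                                j1 = start_j + j * dilationW
                                j2 = j1 + pw_dilated
                                if not (0 <= j1 < iW and 0 <= j2 < iW):
                                    continue
                                # Dot product over channels, as in the c-loop
                                # accumulated into prod_sum in the kernel.
                                acc += float((r1[n, i1, j1] * r2[n, i2, j2]).sum())
                        out[n, ph, pw, h, w] = acc
    return out
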
@@ -1,116 +1,87 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include <iostream>

#include "pytorch_cpp_helper.hpp"

#ifdef MMCV_WITH_CUDA

void CorrelationForwardCUDAKernelLauncher(Tensor input1, Tensor input2,
Tensor output, int kH, int kW,
int patchH, int patchW,
int padH, int padW,
int dilationH, int dilationW,
int dilation_patchH,
int dilation_patchW,
int dH, int dW);
int patchH, int patchW, int padH,
int padW, int dilationH,
int dilationW, int dilation_patchH,
int dilation_patchW, int dH, int dW);

void CorrelationBackwardCUDAKernelLauncher(Tensor grad_output, Tensor input1,
Tensor input2, Tensor grad_input1,
Tensor grad_input2, int kH, int kW,
int patchH, int patchW,
int padH, int padW,
int dilationH, int dilationW,
int dilation_patchH,
int dilation_patchW,
int dH, int dW);
int patchH, int patchW, int padH,
int padW, int dilationH,
int dilationW, int dilation_patchH,
int dilation_patchW, int dH, int dW);

void correlation_cuda_forward(Tensor input1, Tensor input2, Tensor output,
int kH, int kW, int patchH, int patchW,
int padH, int padW, int dilationH, int dilationW,
int dilation_patchH, int dilation_patchW,
int dH, int dW)
{

CorrelationForwardCUDAKernelLauncher(input1, input2, output, kH, kW,
patchH, patchW, padH, padW, dilationH,
dilationW, dilation_patchH,
dilation_patchW, dH, dW);
int kH, int kW, int patchH, int patchW, int padH,
int padW, int dilationH, int dilationW,
int dilation_patchH, int dilation_patchW, int dH,
int dW) {
CorrelationForwardCUDAKernelLauncher(
input1, input2, output, kH, kW, patchH, patchW, padH, padW, dilationH,
dilationW, dilation_patchH, dilation_patchW, dH, dW);
}

void correlation_cuda_backward(Tensor grad_output,
Tensor input1, Tensor input2,
Tensor grad_input1, Tensor grad_input2,
int kH, int kW, int patchH, int patchW,
int padH, int padW,
int dilationH, int dilationW,
int dilation_patchH, int dilation_patchW,
int dH, int dW)
{
CorrelationBackwardCUDAKernelLauncher(grad_output, input1, input2,
grad_input1, grad_input2, kH, kW,
patchH, patchW, padH, padW,
dilationH, dilationW,
dilation_patchH, dilation_patchW,
dH, dW);
void correlation_cuda_backward(Tensor grad_output, Tensor input1, Tensor input2,
Tensor grad_input1, Tensor grad_input2, int kH,
int kW, int patchH, int patchW, int padH,
int padW, int dilationH, int dilationW,
int dilation_patchH, int dilation_patchW, int dH,
int dW) {
CorrelationBackwardCUDAKernelLauncher(
grad_output, input1, input2, grad_input1, grad_input2, kH, kW, patchH,
patchW, padH, padW, dilationH, dilationW, dilation_patchH,
dilation_patchW, dH, dW);
}

#endif

void correlation_forward(Tensor input1, Tensor input2, Tensor output,
int kH, int kW, int patchH, int patchW,
int padH, int padW,
int dilationH, int dilationW,
int dilation_patchH, int dilation_patchW,
int dH, int dW)
{
if (input1.device().is_cuda() and input2.device().is_cuda())
{
void correlation_forward(Tensor input1, Tensor input2, Tensor output, int kH,
int kW, int patchH, int patchW, int padH, int padW,
int dilationH, int dilationW, int dilation_patchH,
int dilation_patchW, int dH, int dW) {
if (input1.device().is_cuda() and input2.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(input1);
CHECK_CUDA_INPUT(input2);
correlation_cuda_forward(input1, input2, output, kH, kW,
patchH, patchW, padH, padW,
dilationH, dilationW,
dilation_patchH, dilation_patchW,
dH, dW);
CHECK_CUDA_INPUT(input1);
CHECK_CUDA_INPUT(input2);
correlation_cuda_forward(input1, input2, output, kH, kW, patchH, patchW,
padH, padW, dilationH, dilationW, dilation_patchH,
dilation_patchW, dH, dW);
#else
AT_ERROR("Correlation is not compiled with GPU support");
AT_ERROR("Correlation is not compiled with GPU support");
#endif
}
else
{
AT_ERROR("Correlation is not implemented on CPU");
}
} else {
AT_ERROR("Correlation is not implemented on CPU");
}
}

void correlation_backward(Tensor grad_output,
Tensor input1, Tensor input2,
Tensor grad_input1, Tensor grad_input2,
int kH, int kW,
int patchH, int patchW,
int padH, int padW,
int dilationH, int dilationW,
int dilation_patchH, int dilation_patchW,
int dH, int dW)
{
if (input1.device().is_cuda() and input2.device().is_cuda())
{
void correlation_backward(Tensor grad_output, Tensor input1, Tensor input2,
Tensor grad_input1, Tensor grad_input2, int kH,
int kW, int patchH, int patchW, int padH, int padW,
int dilationH, int dilationW, int dilation_patchH,
int dilation_patchW, int dH, int dW) {
if (input1.device().is_cuda() and input2.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(grad_output);
CHECK_CUDA_INPUT(input1);
CHECK_CUDA_INPUT(input2);
correlation_cuda_backward(grad_output, input1, input2,
grad_input1, grad_input2, kH, kW,
patchH, patchW, padH, padW,
dilationH, dilationW,
dilation_patchH, dilation_patchW,
dH, dW);
CHECK_CUDA_INPUT(grad_output);
CHECK_CUDA_INPUT(input1);
CHECK_CUDA_INPUT(input2);
correlation_cuda_backward(grad_output, input1, input2, grad_input1,
grad_input2, kH, kW, patchH, patchW, padH, padW,
dilationH, dilationW, dilation_patchH,
dilation_patchW, dH, dW);

#else
AT_ERROR("Correlation is not compiled with GPU support");
AT_ERROR("Correlation is not compiled with GPU support");
#endif
}
else
{
AT_ERROR("Correlation is not implemented on CPU");
}
} else {
AT_ERROR("Correlation is not implemented on CPU");
}
}

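The backward path dispatched above is what torch.autograd.gradcheck ends up exercising through the Python-side Correlation module. A minimal sketch, assuming a CUDA build of mmcv and relying on the module's default constructor arguments (which are not shown anywhere in this diff):

import torch
from torch.autograd import gradcheck

from mmcv.ops import Correlation

# Hypothetical quick check: double-precision CUDA inputs so that gradcheck's
# numerical Jacobian is reliable; this drives correlation_backward above.
corr = Correlation()
input1 = torch.randn(1, 1, 6, 6, dtype=torch.double,
                     device='cuda', requires_grad=True)
input2 = torch.randn(1, 1, 6, 6, dtype=torch.double,
                     device='cuda', requires_grad=True)
assert gradcheck(corr, (input1, input2))
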
@@ -7,99 +7,87 @@
#include "pytorch_cuda_helper.hpp"

void CorrelationForwardCUDAKernelLauncher(Tensor input1, Tensor input2,
Tensor output, int kH, int kW,
int patchH, int patchW,
int padH, int padW,
int dilationH, int dilationW,
int dilation_patchH,
int dilation_patchW,
int dH, int dW)
{
Tensor output, int kH, int kW,
int patchH, int patchW, int padH,
int padW, int dilationH,
int dilationW, int dilation_patchH,
int dilation_patchW, int dH, int dW) {
const int batch_size = input1.size(0);
const int iH = input1.size(2);
const int iW = input1.size(3);
const int dilatedKH = (kH - 1) * dilationH + 1;
const int dilatedKW = (kW - 1) * dilationW + 1;

const int batch_size = input1.size(0);
const int iH = input1.size(2);
const int iW = input1.size(3);
const int dilatedKH = (kH - 1) * dilationH + 1;
const int dilatedKW = (kW - 1) * dilationW + 1;
const auto oH = (iH + 2 * padH - dilatedKH) / dH + 1;
const auto oW = (iW + 2 * padW - dilatedKW) / dW + 1;

auto trInput1 = input1.permute({0, 2, 3, 1}).contiguous();
auto trInput2 = input2.permute({0, 2, 3, 1}).contiguous();

const auto oH = (iH + 2 * padH - dilatedKH) / dH + 1;
const auto oW = (iW + 2 * padW - dilatedKW) / dW + 1;
const int threads = THREADS_FORWARD;
const dim3 blocks(batch_size, oH, oW);

at::cuda::CUDAGuard device_guard(input1.device());

auto trInput1 = input1.permute({0, 2, 3, 1}).contiguous();
auto trInput2 = input2.permute({0, 2, 3, 1}).contiguous();

const int threads = THREADS_FORWARD;
const dim3 blocks(batch_size, oH, oW);

at::cuda::CUDAGuard device_guard(input1.device());

AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.scalar_type(),
"correlation_forward_cuda",
([&]{
TensorAcc4R trInput1_acc = trInput1.packed_accessor32<scalar_t,4,RestrictPtrTraits>();
TensorAcc4R trInput2_acc = trInput2.packed_accessor32<scalar_t,4,RestrictPtrTraits>();
TensorAcc5R output_acc = output.packed_accessor32<scalar_t,5,RestrictPtrTraits>();

correlation_forward_cuda_kernel<scalar_t><<<blocks, threads, 0,
at::cuda::getCurrentCUDAStream()>>>(
trInput1_acc, trInput2_acc, output_acc,
kH, kW, patchH, patchW, padH, padW, dilationH, dilationW,
dilation_patchH, dilation_patchW, dH, dW);

}));
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input1.scalar_type(), "correlation_forward_cuda", ([&] {
TensorAcc4R trInput1_acc =
trInput1.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
TensorAcc4R trInput2_acc =
trInput2.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
TensorAcc5R output_acc =
output.packed_accessor32<scalar_t, 5, RestrictPtrTraits>();

correlation_forward_cuda_kernel<scalar_t>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
trInput1_acc, trInput2_acc, output_acc, kH, kW, patchH, patchW,
padH, padW, dilationH, dilationW, dilation_patchH,
dilation_patchW, dH, dW);
}));
}

void CorrelationBackwardCUDAKernelLauncher(
Tensor grad_output, Tensor input1, Tensor input2, Tensor grad_input1,
Tensor grad_input2, int kH, int kW, int patchH, int patchW, int padH,
int padW, int dilationH, int dilationW, int dilation_patchH,
int dilation_patchW, int dH, int dW) {
const int batch_size = input1.size(0);
const int iH = input1.size(2);
const int iW = input1.size(3);
const int C = input1.size(1);

void CorrelationBackwardCUDAKernelLauncher(Tensor grad_output, Tensor input1,
Tensor input2, Tensor grad_input1,
Tensor grad_input2, int kH, int kW,
int patchH, int patchW,
int padH, int padW,
int dilationH, int dilationW,
int dilation_patchH,
int dilation_patchW,
int dH, int dW){
const int batch_size = input1.size(0);
const int iH = input1.size(2);
const int iW = input1.size(3);
const int C = input1.size(1);
const dim3 blocks(C, iH, iW);
const dim3 threads(THREADS_BACKWARD, THREADS_BACKWARD);

const dim3 blocks(C, iH, iW);
const dim3 threads(THREADS_BACKWARD, THREADS_BACKWARD);
at::cuda::CUDAGuard device_guard(input1.device());

at::cuda::CUDAGuard device_guard(input1.device());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input1.scalar_type(), "correlation_backward_cuda", ([&] {
TensorAcc4R input1_acc =
input1.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
TensorAcc4R input2_acc =
input2.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
TensorAcc4R grad_input1_acc =
grad_input1.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
TensorAcc4R grad_input2_acc =
grad_input2.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
TensorAcc5R grad_output_acc =
grad_output.packed_accessor32<scalar_t, 5, RestrictPtrTraits>();

AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.scalar_type(),
"correlation_backward_cuda",
([&]{
TensorAcc4R input1_acc = input1.packed_accessor32<scalar_t,4,RestrictPtrTraits>();
TensorAcc4R input2_acc = input2.packed_accessor32<scalar_t,4,RestrictPtrTraits>();
TensorAcc4R grad_input1_acc = grad_input1.packed_accessor32<scalar_t,4,RestrictPtrTraits>();
TensorAcc4R grad_input2_acc = grad_input2.packed_accessor32<scalar_t,4,RestrictPtrTraits>();
TensorAcc5R grad_output_acc = grad_output.packed_accessor32<scalar_t,5,RestrictPtrTraits>();
for (int n = 0; n < batch_size; ++n) {
correlation_backward_cuda_kernel_input1<scalar_t>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_output_acc, input2_acc, grad_input1_acc, kH, kW, patchH,
patchW, padH, padW, dilationH, dilationW, dilation_patchH,
dilation_patchW, dH, dW, n);
}

for (int n = 0; n < batch_size; ++n){
correlation_backward_cuda_kernel_input1<scalar_t><<<blocks, threads, 0,
at::cuda::getCurrentCUDAStream()>>>(
grad_output_acc, input2_acc, grad_input1_acc,
kH, kW, patchH, patchW, padH, padW,
dilationH, dilationW,
dilation_patchH, dilation_patchW,
dH, dW, n);
}

for (int n = 0; n < batch_size; ++n){
correlation_backward_cuda_kernel_input2<scalar_t><<<blocks, threads, 0,
at::cuda::getCurrentCUDAStream()>>>(
grad_output_acc, input1_acc, grad_input2_acc,
kH, kW, patchH, patchW, padH, padW,
dilationH, dilationW,
dilation_patchH, dilation_patchW,
dH, dW, n);

}
}));
for (int n = 0; n < batch_size; ++n) {
correlation_backward_cuda_kernel_input2<scalar_t>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_output_acc, input1_acc, grad_input2_acc, kH, kW, patchH,
patchW, padH, padW, dilationH, dilationW, dilation_patchH,
dilation_patchW, dH, dW, n);
}
}));
}

@@ -225,22 +225,16 @@ void border_align_backward(const Tensor &grad_output, const Tensor &boxes,
const Tensor &argmax_idx, Tensor grad_input,
const int pool_size);

void correlation_forward(Tensor input1, Tensor input2, Tensor output,
int kH, int kW, int patchH, int patchW,
int padH, int padW,
int dilationH, int dilationW,
int dilation_patchH, int dilation_patchW,
int dH, int dW);
void correlation_forward(Tensor input1, Tensor input2, Tensor output, int kH,
int kW, int patchH, int patchW, int padH, int padW,
int dilationH, int dilationW, int dilation_patchH,
int dilation_patchW, int dH, int dW);

void correlation_backward(Tensor grad_output,
Tensor input1, Tensor input2,
Tensor grad_input1, Tensor grad_input2,
int kH, int kW,
int patchH, int patchW,
int padH, int padW,
int dilationH, int dilationW,
int dilation_patchH, int dilation_patchW,
int dH, int dW);
void correlation_backward(Tensor grad_output, Tensor input1, Tensor input2,
Tensor grad_input1, Tensor grad_input2, int kH,
int kW, int patchH, int patchW, int padH, int padW,
int dilationH, int dilationW, int dilation_patchH,
int dilation_patchW, int dH, int dW);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"),

@@ -1,23 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmcv.ops import Correlation

_input1 = [[[[1., 2., 3.], [0., 1., 2.], [3., 5., 2.]]]]
_input2 = [[[[1., 2., 3.], [3., 1., 2.], [8., 5., 2.]]]]
_input2_2 = [[[[1., 2.], [3., 1.], [8., 5.]]]]

gt_out_shape = (1, 1, 1, 3, 3)
_gt_out = [[[[[1., 4., 9.], [0., 1., 4.], [24., 25., 4.]]]]]
gt_input1_grad = [[[[1., 2., 3.], [3., 1., 2.], [8., 5., 2.]]]]
_ap_gt_out = [[[[[1., 2., 3.], [3., 1., 2.], [8., 5., 2.]],
                [[2., 4., 6.], [6., 2., 4.], [16., 10., 4.]],
                [[3., 6., 9.], [9., 3., 6.], [24., 15., 6.]]],
               [[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
                [[1., 2., 3.], [3., 1., 2.], [8., 5., 2.]],
                [[2., 4., 6.], [6., 2., 4.], [16., 10., 4.]]],
               [[[3., 6., 9.], [9., 3., 6.], [24., 15., 6.]],
                [[5., 10., 15.], [15., 5., 10.], [40., 25., 10.]],
                [[2., 4., 6.], [6., 2., 4.], [16., 10., 4.]]]]]


def assert_equal_tensor(tensor_a, tensor_b):

@@ -43,6 +35,8 @@ class TestCorrelation:
        assert_equal_tensor(input1.grad.detach().cpu(), input2.cpu())
        assert_equal_tensor(input2.grad.detach().cpu(), input1.cpu())

    @pytest.mark.skipif(
        not torch.cuda.is_available(), reason='requires CUDA support')
    def test_correlation(self):
        self._test_correlation(torch.float)
        self._test_correlation(torch.double)
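
For the fixtures above, _gt_out is simply the elementwise product of _input1 and _input2, which is what a 1x1 kernel with zero displacement produces. A hypothetical stand-alone run against those fixtures might look like the sketch below; max_displacement=0 is an inference from the expected values, since the test body that configures Correlation is outside the hunks shown.

import torch
from mmcv.ops import Correlation

# Reuses the _input1/_input2 fixtures defined at the top of the test file.
input1 = torch.tensor([[[[1., 2., 3.], [0., 1., 2.], [3., 5., 2.]]]],
                      device='cuda')
input2 = torch.tensor([[[[1., 2., 3.], [3., 1., 2.], [8., 5., 2.]]]],
                      device='cuda')

corr = Correlation(max_displacement=0)  # assumed configuration
out = corr(input1, input2)
assert out.shape == (1, 1, 1, 3, 3)  # gt_out_shape
# Matches _gt_out: the per-pixel product of the two single-channel maps.
assert torch.allclose(out.reshape(1, 1, 3, 3), input1 * input2)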