From 745aa7373eaf3f80f8e789a26ee889e61c151fa1 Mon Sep 17 00:00:00 2001 From: Miao Zheng <76149310+MeowZheng@users.noreply.github.com> Date: Sat, 25 Sep 2021 21:13:34 +0800 Subject: [PATCH] [Fix] Revise unit test of correlation (#1368) * [Fix] Revise unit test of correlation * rename * lint * lint * lint * lint --- .../ops/csrc/common/cuda/correlation_cuda.cuh | 188 +++++++----------- mmcv/ops/csrc/pytorch/correlation.cpp | 143 ++++++------- .../ops/csrc/pytorch/cuda/correlation_cuda.cu | 152 +++++++------- mmcv/ops/csrc/pytorch/pybind.cpp | 24 +-- .../{test_corr.py => test_correlation.py} | 14 +- 5 files changed, 213 insertions(+), 308 deletions(-) rename tests/test_ops/{test_corr.py => test_correlation.py} (67%) diff --git a/mmcv/ops/csrc/common/cuda/correlation_cuda.cuh b/mmcv/ops/csrc/common/cuda/correlation_cuda.cuh index b479e21bf..d049e9e2d 100644 --- a/mmcv/ops/csrc/common/cuda/correlation_cuda.cuh +++ b/mmcv/ops/csrc/common/cuda/correlation_cuda.cuh @@ -15,8 +15,9 @@ #include #include #include -#include + #include +#include using namespace torch; @@ -28,17 +29,10 @@ using namespace torch; #define THREADS_BACKWARD 16 template -__global__ void correlation_forward_cuda_kernel(const TensorAcc4R rInput1, - const TensorAcc4R rInput2, - TensorAcc5R output, - int kH, int kW, - int patchH, int patchW, - int padH, int padW, - int dilationH, int dilationW, - int dilation_patchH, - int dilation_patchW, - int dH, int dW) -{ +__global__ void correlation_forward_cuda_kernel( + const TensorAcc4R rInput1, const TensorAcc4R rInput2, TensorAcc5R output, + int kH, int kW, int patchH, int patchW, int padH, int padW, int dilationH, + int dilationW, int dilation_patchH, int dilation_patchW, int dH, int dW) { const int iH = rInput1.size(1); const int iW = rInput1.size(2); const int C = rInput1.size(3); @@ -56,42 +50,35 @@ __global__ void correlation_forward_cuda_kernel(const TensorAcc4R rInput1, __shared__ scalar_t prod_sum[THREADS_FORWARD]; - for (int ph = 0; ph < patchH; ++ph) - { + for (int ph = 0; ph < patchH; ++ph) { int ph_dilated = ph * dilation_patchH - patchRadH; - for (int pw = 0; pw < patchW; ++pw) - { + for (int pw = 0; pw < patchW; ++pw) { int pw_dilated = pw * dilation_patchW - patchRadW; prod_sum[thread] = 0; - for (int i = 0; i < kH; ++i) - { + for (int i = 0; i < kH; ++i) { int i1 = start_i + i * dilationH; int i2 = i1 + ph_dilated; - if WITHIN_BOUNDS (i1, i2, iH, iH) - { - for (int j = 0; j < kW; ++j) - { - int j1 = start_j + j * dilationW; - int j2 = j1 + pw_dilated; - if WITHIN_BOUNDS (j1, j2, iW, iW) - { - for (int c = thread; c < C; c += THREADS_FORWARD) - { - scalar_t v1 = rInput1[n][i1][j1][c]; - scalar_t v2 = rInput2[n][i2][j2][c]; - prod_sum[thread] += v1 * v2; - } + if + WITHIN_BOUNDS(i1, i2, iH, iH) { + for (int j = 0; j < kW; ++j) { + int j1 = start_j + j * dilationW; + int j2 = j1 + pw_dilated; + if + WITHIN_BOUNDS(j1, j2, iW, iW) { + for (int c = thread; c < C; c += THREADS_FORWARD) { + scalar_t v1 = rInput1[n][i1][j1][c]; + scalar_t v2 = rInput2[n][i2][j2][c]; + prod_sum[thread] += v1 * v2; + } + } } } - } } // accumulate __syncthreads(); - if (thread == 0) - { + if (thread == 0) { scalar_t reduce_sum = 0; - for (int index = 0; index < THREADS_FORWARD; ++index) - { + for (int index = 0; index < THREADS_FORWARD; ++index) { reduce_sum += prod_sum[index]; } output[n][ph][pw][h][w] = reduce_sum; @@ -101,18 +88,12 @@ __global__ void correlation_forward_cuda_kernel(const TensorAcc4R rInput1, } template -__global__ void correlation_backward_cuda_kernel_input1(const TensorAcc5R grad_output, - const TensorAcc4R input2, - TensorAcc4R grad_input1, - const int kH, const int kW, - const int patchH, const int patchW, - const int padH, const int padW, - const int dilationH, const int dilationW, - const int dilation_patchH, const int dilation_patchW, - const int dH, const int dW, - const int batch) -{ - +__global__ void correlation_backward_cuda_kernel_input1( + const TensorAcc5R grad_output, const TensorAcc4R input2, + TensorAcc4R grad_input1, const int kH, const int kW, const int patchH, + const int patchW, const int padH, const int padW, const int dilationH, + const int dilationW, const int dilation_patchH, const int dilation_patchW, + const int dH, const int dW, const int batch) { const int iH = input2.size(2); const int iW = input2.size(3); @@ -137,29 +118,23 @@ __global__ void correlation_backward_cuda_kernel_input1(const TensorAcc5R grad_o __shared__ scalar_t prod_sum[THREADS_BACKWARD][THREADS_BACKWARD]; prod_sum[ph_off][pw_off] = 0; - for (int ph = ph_off; ph < patchH; ph += THREADS_BACKWARD) - { + for (int ph = ph_off; ph < patchH; ph += THREADS_BACKWARD) { int i1 = h + dilation_patchH * (ph - patchRadH); - for (int pw = pw_off; pw < patchW; pw += THREADS_BACKWARD) - { + for (int pw = pw_off; pw < patchW; pw += THREADS_BACKWARD) { int j1 = w + dilation_patchW * (pw - patchRadW); - if (WITHIN_BOUNDS(i1, j1, iH, iW)) - { + if (WITHIN_BOUNDS(i1, j1, iH, iW)) { scalar_t val = input2[n][c][i1][j1]; - for (int h_3 = h_2; h_3 > min_h; h_3 -= dilationH) - { + for (int h_3 = h_2; h_3 > min_h; h_3 -= dilationH) { int i2 = (h_3) / dH; - if (i2 * dH != h_3) - continue; - for (int w_3 = w_2; w_3 > min_w; w_3 -= dilationW) - { + if (i2 * dH != h_3) continue; + for (int w_3 = w_2; w_3 > min_w; w_3 -= dilationW) { int j2 = (w_3) / dW; - if (j2 * dW != w_3) - continue; - if WITHIN_BOUNDS (i2, j2, H, W) - { - prod_sum[ph_off][pw_off] += grad_output[n][ph][pw][i2][j2] * val; - } + if (j2 * dW != w_3) continue; + if + WITHIN_BOUNDS(i2, j2, H, W) { + prod_sum[ph_off][pw_off] += + grad_output[n][ph][pw][i2][j2] * val; + } } } } @@ -168,13 +143,10 @@ __global__ void correlation_backward_cuda_kernel_input1(const TensorAcc5R grad_o __syncthreads(); - if (ph_off == 0 && pw_off == 0) - { + if (ph_off == 0 && pw_off == 0) { scalar_t reduce_sum = 0; - for (int ph = 0; ph < THREADS_BACKWARD; ++ph) - { - for (int pw = 0; pw < THREADS_BACKWARD; ++pw) - { + for (int ph = 0; ph < THREADS_BACKWARD; ++ph) { + for (int pw = 0; pw < THREADS_BACKWARD; ++pw) { reduce_sum += prod_sum[ph][pw]; } } @@ -183,17 +155,11 @@ __global__ void correlation_backward_cuda_kernel_input1(const TensorAcc5R grad_o } template -__global__ void correlation_backward_cuda_kernel_input2(const TensorAcc5R grad_output, - const TensorAcc4R input1, - TensorAcc4R grad_input2, - int kH, int kW, - int patchH, int patchW, - int padH, int padW, - int dilationH, int dilationW, - int dilation_patchH, int dilation_patchW, - int dH, int dW, - int batch) -{ +__global__ void correlation_backward_cuda_kernel_input2( + const TensorAcc5R grad_output, const TensorAcc4R input1, + TensorAcc4R grad_input2, int kH, int kW, int patchH, int patchW, int padH, + int padW, int dilationH, int dilationW, int dilation_patchH, + int dilation_patchW, int dH, int dW, int batch) { const int iH = input1.size(2); const int iW = input1.size(3); @@ -216,50 +182,42 @@ __global__ void correlation_backward_cuda_kernel_input2(const TensorAcc5R grad_o __shared__ scalar_t prod_sum[THREADS_BACKWARD][THREADS_BACKWARD]; prod_sum[ph_off][pw_off] = 0; - for (int ph = ph_off; ph < patchH; ph += THREADS_BACKWARD) - { + for (int ph = ph_off; ph < patchH; ph += THREADS_BACKWARD) { int i1 = h - dilation_patchH * (ph - patchRadH); - for (int pw = pw_off; pw < patchW; pw += THREADS_BACKWARD) - { + for (int pw = pw_off; pw < patchW; pw += THREADS_BACKWARD) { int j1 = w - dilation_patchW * (pw - patchRadW); - if WITHIN_BOUNDS (i1, j1, iH, iW) - { - scalar_t val = input1[n][c][i1][j1]; + if + WITHIN_BOUNDS(i1, j1, iH, iW) { + scalar_t val = input1[n][c][i1][j1]; - const int h_2 = i1 + padH; - const int w_2 = j1 + padW; - const int min_h = h_2 - dilatedKH; - const int min_w = w_2 - dilatedKW; + const int h_2 = i1 + padH; + const int w_2 = j1 + padW; + const int min_h = h_2 - dilatedKH; + const int min_w = w_2 - dilatedKW; - for (int h_3 = h_2; h_3 > min_h; h_3 -= dilationH) - { - int i2 = (h_3) / dH; - if (i2 * dH != h_3) - continue; - for (int w_3 = w_2; w_3 > min_w; w_3 -= dilationW) - { - int j2 = (w_3) / dW; - if (j2 * dW != w_3) - continue; - if WITHIN_BOUNDS (i2, j2, H, W) - { - prod_sum[ph_off][pw_off] += grad_output[n][ph][pw][i2][j2] * val; + for (int h_3 = h_2; h_3 > min_h; h_3 -= dilationH) { + int i2 = (h_3) / dH; + if (i2 * dH != h_3) continue; + for (int w_3 = w_2; w_3 > min_w; w_3 -= dilationW) { + int j2 = (w_3) / dW; + if (j2 * dW != w_3) continue; + if + WITHIN_BOUNDS(i2, j2, H, W) { + prod_sum[ph_off][pw_off] += + grad_output[n][ph][pw][i2][j2] * val; + } } } } - } } } __syncthreads(); - if (ph_off == 0 && pw_off == 0) - { + if (ph_off == 0 && pw_off == 0) { scalar_t reduce_sum = 0; - for (int ph = 0; ph < THREADS_BACKWARD; ++ph) - { - for (int pw = 0; pw < THREADS_BACKWARD; ++pw) - { + for (int ph = 0; ph < THREADS_BACKWARD; ++ph) { + for (int pw = 0; pw < THREADS_BACKWARD; ++pw) { reduce_sum += prod_sum[ph][pw]; } } diff --git a/mmcv/ops/csrc/pytorch/correlation.cpp b/mmcv/ops/csrc/pytorch/correlation.cpp index af2864de5..0687b1d12 100644 --- a/mmcv/ops/csrc/pytorch/correlation.cpp +++ b/mmcv/ops/csrc/pytorch/correlation.cpp @@ -1,116 +1,87 @@ // Copyright (c) OpenMMLab. All rights reserved. #include + #include "pytorch_cpp_helper.hpp" #ifdef MMCV_WITH_CUDA void CorrelationForwardCUDAKernelLauncher(Tensor input1, Tensor input2, Tensor output, int kH, int kW, - int patchH, int patchW, - int padH, int padW, - int dilationH, int dilationW, - int dilation_patchH, - int dilation_patchW, - int dH, int dW); + int patchH, int patchW, int padH, + int padW, int dilationH, + int dilationW, int dilation_patchH, + int dilation_patchW, int dH, int dW); void CorrelationBackwardCUDAKernelLauncher(Tensor grad_output, Tensor input1, Tensor input2, Tensor grad_input1, Tensor grad_input2, int kH, int kW, - int patchH, int patchW, - int padH, int padW, - int dilationH, int dilationW, - int dilation_patchH, - int dilation_patchW, - int dH, int dW); + int patchH, int patchW, int padH, + int padW, int dilationH, + int dilationW, int dilation_patchH, + int dilation_patchW, int dH, int dW); void correlation_cuda_forward(Tensor input1, Tensor input2, Tensor output, - int kH, int kW, int patchH, int patchW, - int padH, int padW, int dilationH, int dilationW, - int dilation_patchH, int dilation_patchW, - int dH, int dW) -{ - - CorrelationForwardCUDAKernelLauncher(input1, input2, output, kH, kW, - patchH, patchW, padH, padW, dilationH, - dilationW, dilation_patchH, - dilation_patchW, dH, dW); + int kH, int kW, int patchH, int patchW, int padH, + int padW, int dilationH, int dilationW, + int dilation_patchH, int dilation_patchW, int dH, + int dW) { + CorrelationForwardCUDAKernelLauncher( + input1, input2, output, kH, kW, patchH, patchW, padH, padW, dilationH, + dilationW, dilation_patchH, dilation_patchW, dH, dW); } -void correlation_cuda_backward(Tensor grad_output, - Tensor input1, Tensor input2, - Tensor grad_input1, Tensor grad_input2, - int kH, int kW, int patchH, int patchW, - int padH, int padW, - int dilationH, int dilationW, - int dilation_patchH, int dilation_patchW, - int dH, int dW) -{ - CorrelationBackwardCUDAKernelLauncher(grad_output, input1, input2, - grad_input1, grad_input2, kH, kW, - patchH, patchW, padH, padW, - dilationH, dilationW, - dilation_patchH, dilation_patchW, - dH, dW); +void correlation_cuda_backward(Tensor grad_output, Tensor input1, Tensor input2, + Tensor grad_input1, Tensor grad_input2, int kH, + int kW, int patchH, int patchW, int padH, + int padW, int dilationH, int dilationW, + int dilation_patchH, int dilation_patchW, int dH, + int dW) { + CorrelationBackwardCUDAKernelLauncher( + grad_output, input1, input2, grad_input1, grad_input2, kH, kW, patchH, + patchW, padH, padW, dilationH, dilationW, dilation_patchH, + dilation_patchW, dH, dW); } #endif -void correlation_forward(Tensor input1, Tensor input2, Tensor output, - int kH, int kW, int patchH, int patchW, - int padH, int padW, - int dilationH, int dilationW, - int dilation_patchH, int dilation_patchW, - int dH, int dW) -{ - if (input1.device().is_cuda() and input2.device().is_cuda()) - { +void correlation_forward(Tensor input1, Tensor input2, Tensor output, int kH, + int kW, int patchH, int patchW, int padH, int padW, + int dilationH, int dilationW, int dilation_patchH, + int dilation_patchW, int dH, int dW) { + if (input1.device().is_cuda() and input2.device().is_cuda()) { #ifdef MMCV_WITH_CUDA - CHECK_CUDA_INPUT(input1); - CHECK_CUDA_INPUT(input2); - correlation_cuda_forward(input1, input2, output, kH, kW, - patchH, patchW, padH, padW, - dilationH, dilationW, - dilation_patchH, dilation_patchW, - dH, dW); + CHECK_CUDA_INPUT(input1); + CHECK_CUDA_INPUT(input2); + correlation_cuda_forward(input1, input2, output, kH, kW, patchH, patchW, + padH, padW, dilationH, dilationW, dilation_patchH, + dilation_patchW, dH, dW); #else - AT_ERROR("Correlation is not compiled with GPU support"); + AT_ERROR("Correlation is not compiled with GPU support"); #endif - } - else - { - AT_ERROR("Correlation is not implemented on CPU"); - } + } else { + AT_ERROR("Correlation is not implemented on CPU"); + } } -void correlation_backward(Tensor grad_output, - Tensor input1, Tensor input2, - Tensor grad_input1, Tensor grad_input2, - int kH, int kW, - int patchH, int patchW, - int padH, int padW, - int dilationH, int dilationW, - int dilation_patchH, int dilation_patchW, - int dH, int dW) -{ - if (input1.device().is_cuda() and input2.device().is_cuda()) - { +void correlation_backward(Tensor grad_output, Tensor input1, Tensor input2, + Tensor grad_input1, Tensor grad_input2, int kH, + int kW, int patchH, int patchW, int padH, int padW, + int dilationH, int dilationW, int dilation_patchH, + int dilation_patchW, int dH, int dW) { + if (input1.device().is_cuda() and input2.device().is_cuda()) { #ifdef MMCV_WITH_CUDA - CHECK_CUDA_INPUT(grad_output); - CHECK_CUDA_INPUT(input1); - CHECK_CUDA_INPUT(input2); - correlation_cuda_backward(grad_output, input1, input2, - grad_input1, grad_input2, kH, kW, - patchH, patchW, padH, padW, - dilationH, dilationW, - dilation_patchH, dilation_patchW, - dH, dW); + CHECK_CUDA_INPUT(grad_output); + CHECK_CUDA_INPUT(input1); + CHECK_CUDA_INPUT(input2); + correlation_cuda_backward(grad_output, input1, input2, grad_input1, + grad_input2, kH, kW, patchH, patchW, padH, padW, + dilationH, dilationW, dilation_patchH, + dilation_patchW, dH, dW); #else - AT_ERROR("Correlation is not compiled with GPU support"); + AT_ERROR("Correlation is not compiled with GPU support"); #endif - } - else - { - AT_ERROR("Correlation is not implemented on CPU"); - } + } else { + AT_ERROR("Correlation is not implemented on CPU"); + } } diff --git a/mmcv/ops/csrc/pytorch/cuda/correlation_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/correlation_cuda.cu index 86991b362..56d2e644f 100644 --- a/mmcv/ops/csrc/pytorch/cuda/correlation_cuda.cu +++ b/mmcv/ops/csrc/pytorch/cuda/correlation_cuda.cu @@ -7,99 +7,87 @@ #include "pytorch_cuda_helper.hpp" void CorrelationForwardCUDAKernelLauncher(Tensor input1, Tensor input2, - Tensor output, int kH, int kW, - int patchH, int patchW, - int padH, int padW, - int dilationH, int dilationW, - int dilation_patchH, - int dilation_patchW, - int dH, int dW) -{ + Tensor output, int kH, int kW, + int patchH, int patchW, int padH, + int padW, int dilationH, + int dilationW, int dilation_patchH, + int dilation_patchW, int dH, int dW) { + const int batch_size = input1.size(0); + const int iH = input1.size(2); + const int iW = input1.size(3); + const int dilatedKH = (kH - 1) * dilationH + 1; + const int dilatedKW = (kW - 1) * dilationW + 1; - const int batch_size = input1.size(0); - const int iH = input1.size(2); - const int iW = input1.size(3); - const int dilatedKH = (kH - 1) * dilationH + 1; - const int dilatedKW = (kW - 1) * dilationW + 1; + const auto oH = (iH + 2 * padH - dilatedKH) / dH + 1; + const auto oW = (iW + 2 * padW - dilatedKW) / dW + 1; + auto trInput1 = input1.permute({0, 2, 3, 1}).contiguous(); + auto trInput2 = input2.permute({0, 2, 3, 1}).contiguous(); - const auto oH = (iH + 2 * padH - dilatedKH) / dH + 1; - const auto oW = (iW + 2 * padW - dilatedKW) / dW + 1; + const int threads = THREADS_FORWARD; + const dim3 blocks(batch_size, oH, oW); + at::cuda::CUDAGuard device_guard(input1.device()); - auto trInput1 = input1.permute({0, 2, 3, 1}).contiguous(); - auto trInput2 = input2.permute({0, 2, 3, 1}).contiguous(); - - const int threads = THREADS_FORWARD; - const dim3 blocks(batch_size, oH, oW); - - at::cuda::CUDAGuard device_guard(input1.device()); - - AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.scalar_type(), - "correlation_forward_cuda", - ([&]{ - TensorAcc4R trInput1_acc = trInput1.packed_accessor32(); - TensorAcc4R trInput2_acc = trInput2.packed_accessor32(); - TensorAcc5R output_acc = output.packed_accessor32(); - - correlation_forward_cuda_kernel<<>>( - trInput1_acc, trInput2_acc, output_acc, - kH, kW, patchH, patchW, padH, padW, dilationH, dilationW, - dilation_patchH, dilation_patchW, dH, dW); - - })); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input1.scalar_type(), "correlation_forward_cuda", ([&] { + TensorAcc4R trInput1_acc = + trInput1.packed_accessor32(); + TensorAcc4R trInput2_acc = + trInput2.packed_accessor32(); + TensorAcc5R output_acc = + output.packed_accessor32(); + correlation_forward_cuda_kernel + <<>>( + trInput1_acc, trInput2_acc, output_acc, kH, kW, patchH, patchW, + padH, padW, dilationH, dilationW, dilation_patchH, + dilation_patchW, dH, dW); + })); } +void CorrelationBackwardCUDAKernelLauncher( + Tensor grad_output, Tensor input1, Tensor input2, Tensor grad_input1, + Tensor grad_input2, int kH, int kW, int patchH, int patchW, int padH, + int padW, int dilationH, int dilationW, int dilation_patchH, + int dilation_patchW, int dH, int dW) { + const int batch_size = input1.size(0); + const int iH = input1.size(2); + const int iW = input1.size(3); + const int C = input1.size(1); -void CorrelationBackwardCUDAKernelLauncher(Tensor grad_output, Tensor input1, - Tensor input2, Tensor grad_input1, - Tensor grad_input2, int kH, int kW, - int patchH, int patchW, - int padH, int padW, - int dilationH, int dilationW, - int dilation_patchH, - int dilation_patchW, - int dH, int dW){ - const int batch_size = input1.size(0); - const int iH = input1.size(2); - const int iW = input1.size(3); - const int C = input1.size(1); + const dim3 blocks(C, iH, iW); + const dim3 threads(THREADS_BACKWARD, THREADS_BACKWARD); - const dim3 blocks(C, iH, iW); - const dim3 threads(THREADS_BACKWARD, THREADS_BACKWARD); + at::cuda::CUDAGuard device_guard(input1.device()); - at::cuda::CUDAGuard device_guard(input1.device()); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input1.scalar_type(), "correlation_backward_cuda", ([&] { + TensorAcc4R input1_acc = + input1.packed_accessor32(); + TensorAcc4R input2_acc = + input2.packed_accessor32(); + TensorAcc4R grad_input1_acc = + grad_input1.packed_accessor32(); + TensorAcc4R grad_input2_acc = + grad_input2.packed_accessor32(); + TensorAcc5R grad_output_acc = + grad_output.packed_accessor32(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.scalar_type(), - "correlation_backward_cuda", - ([&]{ - TensorAcc4R input1_acc = input1.packed_accessor32(); - TensorAcc4R input2_acc = input2.packed_accessor32(); - TensorAcc4R grad_input1_acc = grad_input1.packed_accessor32(); - TensorAcc4R grad_input2_acc = grad_input2.packed_accessor32(); - TensorAcc5R grad_output_acc = grad_output.packed_accessor32(); + for (int n = 0; n < batch_size; ++n) { + correlation_backward_cuda_kernel_input1 + <<>>( + grad_output_acc, input2_acc, grad_input1_acc, kH, kW, patchH, + patchW, padH, padW, dilationH, dilationW, dilation_patchH, + dilation_patchW, dH, dW, n); + } - for (int n = 0; n < batch_size; ++n){ - correlation_backward_cuda_kernel_input1<<>>( - grad_output_acc, input2_acc, grad_input1_acc, - kH, kW, patchH, patchW, padH, padW, - dilationH, dilationW, - dilation_patchH, dilation_patchW, - dH, dW, n); - } - - for (int n = 0; n < batch_size; ++n){ - correlation_backward_cuda_kernel_input2<<>>( - grad_output_acc, input1_acc, grad_input2_acc, - kH, kW, patchH, patchW, padH, padW, - dilationH, dilationW, - dilation_patchH, dilation_patchW, - dH, dW, n); - - } - })); + for (int n = 0; n < batch_size; ++n) { + correlation_backward_cuda_kernel_input2 + <<>>( + grad_output_acc, input1_acc, grad_input2_acc, kH, kW, patchH, + patchW, padH, padW, dilationH, dilationW, dilation_patchH, + dilation_patchW, dH, dW, n); + } + })); } diff --git a/mmcv/ops/csrc/pytorch/pybind.cpp b/mmcv/ops/csrc/pytorch/pybind.cpp index 0f173ef73..ee3d4e3bd 100644 --- a/mmcv/ops/csrc/pytorch/pybind.cpp +++ b/mmcv/ops/csrc/pytorch/pybind.cpp @@ -225,22 +225,16 @@ void border_align_backward(const Tensor &grad_output, const Tensor &boxes, const Tensor &argmax_idx, Tensor grad_input, const int pool_size); -void correlation_forward(Tensor input1, Tensor input2, Tensor output, - int kH, int kW, int patchH, int patchW, - int padH, int padW, - int dilationH, int dilationW, - int dilation_patchH, int dilation_patchW, - int dH, int dW); +void correlation_forward(Tensor input1, Tensor input2, Tensor output, int kH, + int kW, int patchH, int patchW, int padH, int padW, + int dilationH, int dilationW, int dilation_patchH, + int dilation_patchW, int dH, int dW); -void correlation_backward(Tensor grad_output, - Tensor input1, Tensor input2, - Tensor grad_input1, Tensor grad_input2, - int kH, int kW, - int patchH, int patchW, - int padH, int padW, - int dilationH, int dilationW, - int dilation_patchH, int dilation_patchW, - int dH, int dW); +void correlation_backward(Tensor grad_output, Tensor input1, Tensor input2, + Tensor grad_input1, Tensor grad_input2, int kH, + int kW, int patchH, int patchW, int padH, int padW, + int dilationH, int dilationW, int dilation_patchH, + int dilation_patchW, int dH, int dW); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"), diff --git a/tests/test_ops/test_corr.py b/tests/test_ops/test_correlation.py similarity index 67% rename from tests/test_ops/test_corr.py rename to tests/test_ops/test_correlation.py index 2b34c6ad9..6b75a9f38 100644 --- a/tests/test_ops/test_corr.py +++ b/tests/test_ops/test_correlation.py @@ -1,23 +1,15 @@ # Copyright (c) OpenMMLab. All rights reserved. +import pytest import torch from mmcv.ops import Correlation _input1 = [[[[1., 2., 3.], [0., 1., 2.], [3., 5., 2.]]]] _input2 = [[[[1., 2., 3.], [3., 1., 2.], [8., 5., 2.]]]] -_input2_2 = [[[[1., 2.], [3., 1.], [8., 5.]]]] + gt_out_shape = (1, 1, 1, 3, 3) _gt_out = [[[[[1., 4., 9.], [0., 1., 4.], [24., 25., 4.]]]]] gt_input1_grad = [[[[1., 2., 3.], [3., 1., 2.], [8., 5., 2.]]]] -_ap_gt_out = [[[[[1., 2., 3.], [3., 1., 2.], [8., 5., 2.]], - [[2., 4., 6.], [6., 2., 4.], [16., 10., 4.]], - [[3., 6., 9.], [9., 3., 6.], [24., 15., 6.]]], - [[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], - [[1., 2., 3.], [3., 1., 2.], [8., 5., 2.]], - [[2., 4., 6.], [6., 2., 4.], [16., 10., 4.]]], - [[[3., 6., 9.], [9., 3., 6.], [24., 15., 6.]], - [[5., 10., 15.], [15., 5., 10.], [40., 25., 10.]], - [[2., 4., 6.], [6., 2., 4.], [16., 10., 4.]]]]] def assert_equal_tensor(tensor_a, tensor_b): @@ -43,6 +35,8 @@ class TestCorrelation: assert_equal_tensor(input1.grad.detach().cpu(), input2.cpu()) assert_equal_tensor(input2.grad.detach().cpu(), input1.cpu()) + @pytest.mark.skipif( + not torch.cuda.is_available(), reason='requires CUDA support') def test_correlation(self): self._test_correlation(torch.float) self._test_correlation(torch.double)