mirror of https://github.com/open-mmlab/mmcv.git
Fix ms deform attn (#1823)
* rename grad_sampling_loc and grad_attn_weight
* recover cache initialize
parent 5b5d0c15bc
commit d929fa4136
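Two things change here, and neither is cosmetic. First, the backward kernels used to advance the `grad_sampling_loc` / `grad_attn_weight` parameters in place; `CUDA_1D_KERNEL_LOOP` is a grid-stride loop, so whenever the grid does not cover `n` in one pass, a thread's second `index` starts from an already-shifted pointer and its gradients land at stale offsets. The renamed `grad_sampling_loc_out` / `grad_attn_weight_out` pointers are derived from the untouched bases on every iteration instead. Second, the `__shared__` cache declarations (and `tid`) move back out of the loop; `__shared__` variables have block lifetime wherever they are declared, so this recovers the one-time set-up without changing the shared-memory layout. A minimal sketch of the pointer bug and the fix: the toy kernels and the `GRID_STRIDE_LOOP` macro below are illustrative only, not mmcv code.

// Minimal sketch (toy kernels, hypothetical names): why mutating an
// output pointer inside a grid-stride loop corrupts later iterations.
#define GRID_STRIDE_LOOP(i, n)                                 \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)

// Buggy pattern: `out += index` shifts the parameter itself, so a
// thread's second grid-stride iteration offsets from the moved base.
__global__ void write_buggy(float *out, int n) {
  GRID_STRIDE_LOOP(index, n) {
    out += index;
    out[0] = 1.0f;  // lands at a stale, accumulated offset
  }
}

// Fixed pattern (what the *_out rename implements): recompute a local
// pointer from the untouched base on every iteration.
__global__ void write_fixed(float *base, int n) {
  GRID_STRIDE_LOOP(index, n) {
    float *out = base + index;  // base never moves
    out[0] = 1.0f;
  }
}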
@@ -14,8 +14,6 @@
 #include "common_cuda_helper.hpp"
 #include "pytorch_cuda_helper.hpp"
 
-const int CUDA_NUM_THREADS = 1024;
-
 template <typename scalar_t>
 __device__ scalar_t ms_deform_attn_im2col_bilinear(
     const scalar_t *&bottom_data, const int &height, const int &width,
@@ -264,10 +262,11 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(
     const int channels, const int num_levels, const int num_query,
     const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
     scalar_t *grad_attn_weight) {
+  __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
+  __shared__ scalar_t cache_grad_attn_weight[blockSize];
+  unsigned int tid = threadIdx.x;
+  const int qid_stride = num_heads * channels;
   CUDA_1D_KERNEL_LOOP(index, n) {
-    __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
-    __shared__ scalar_t cache_grad_attn_weight[blockSize];
-    unsigned int tid = threadIdx.x;
     int _temp = index;
     const int c_col = _temp % channels;
     _temp /= channels;
@@ -282,11 +281,11 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(
     int data_weight_ptr = sampling_index * num_levels * num_point;
     int data_loc_w_ptr = data_weight_ptr << 1;
     const int grad_sampling_ptr = data_weight_ptr;
-    grad_sampling_loc += grad_sampling_ptr << 1;
-    grad_attn_weight += grad_sampling_ptr;
+    scalar_t *grad_sampling_loc_out =
+        grad_sampling_loc + (grad_sampling_ptr << 1);
+    scalar_t *grad_attn_weight_out = grad_attn_weight + grad_sampling_ptr;
     const int grad_weight_stride = 1;
     const int grad_loc_stride = 2;
-    const int qid_stride = num_heads * channels;
     const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
 
     for (int l_col = 0; l_col < num_levels; ++l_col) {
@@ -323,23 +322,23 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(
                    _grad_h = cache_grad_sampling_loc[1],
                    _grad_a = cache_grad_attn_weight[0];
           int sid = 2;
-          for (unsigned int tid = 1; tid < blockSize; ++tid) {
+          for (unsigned int _tid = 1; _tid < blockSize; ++_tid) {
             _grad_w += cache_grad_sampling_loc[sid];
             _grad_h += cache_grad_sampling_loc[sid + 1];
-            _grad_a += cache_grad_attn_weight[tid];
+            _grad_a += cache_grad_attn_weight[_tid];
             sid += 2;
           }
 
-          *grad_sampling_loc = _grad_w;
-          *(grad_sampling_loc + 1) = _grad_h;
-          *grad_attn_weight = _grad_a;
+          *grad_sampling_loc_out = _grad_w;
+          *(grad_sampling_loc_out + 1) = _grad_h;
+          *grad_attn_weight_out = _grad_a;
         }
         __syncthreads();
 
         data_weight_ptr += 1;
         data_loc_w_ptr += 2;
-        grad_attn_weight += grad_weight_stride;
-        grad_sampling_loc += grad_loc_stride;
+        grad_attn_weight_out += grad_weight_stride;
+        grad_sampling_loc_out += grad_loc_stride;
       }
     }
   }
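A side note on the `_tid` rename in the reduction above: with `tid` now declared once per kernel from `threadIdx.x`, reusing the same name for the serial accumulation counter inside `if (tid == 0)` would shadow the thread id for the rest of that scope. Legal C++, but easy to misread. A hypothetical illustration (names invented for the example, not mmcv code):

__global__ void shadow_demo(float *cache) {
  unsigned int tid = threadIdx.x;  // thread identity, used for gating
  if (tid == 0) {
    float acc = cache[0];
    // old style: `for (unsigned int tid = 1; ...)` would shadow the
    // identity above; any later use of `tid` in this scope reads the
    // loop counter, not threadIdx.x. `_tid` keeps the two apart.
    for (unsigned int _tid = 1; _tid < blockDim.x; ++_tid) {
      acc += cache[_tid];
    }
    cache[0] = acc;
  }
}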
@@ -354,10 +353,10 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(
     const int channels, const int num_levels, const int num_query,
     const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
     scalar_t *grad_attn_weight) {
+  __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
+  __shared__ scalar_t cache_grad_attn_weight[blockSize];
+  unsigned int tid = threadIdx.x;
   CUDA_1D_KERNEL_LOOP(index, n) {
-    __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
-    __shared__ scalar_t cache_grad_attn_weight[blockSize];
-    unsigned int tid = threadIdx.x;
     int _temp = index;
     const int c_col = _temp % channels;
     _temp /= channels;
@@ -372,8 +371,9 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(
     int data_weight_ptr = sampling_index * num_levels * num_point;
     int data_loc_w_ptr = data_weight_ptr << 1;
     const int grad_sampling_ptr = data_weight_ptr;
-    grad_sampling_loc += grad_sampling_ptr << 1;
-    grad_attn_weight += grad_sampling_ptr;
+    scalar_t *grad_sampling_loc_out =
+        grad_sampling_loc + (grad_sampling_ptr << 1);
+    scalar_t *grad_attn_weight_out = grad_attn_weight + grad_sampling_ptr;
     const int grad_weight_stride = 1;
     const int grad_loc_stride = 2;
     const int qid_stride = num_heads * channels;
@@ -422,16 +422,16 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(
         }
 
         if (tid == 0) {
-          *grad_sampling_loc = cache_grad_sampling_loc[0];
-          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
-          *grad_attn_weight = cache_grad_attn_weight[0];
+          *grad_sampling_loc_out = cache_grad_sampling_loc[0];
+          *(grad_sampling_loc_out + 1) = cache_grad_sampling_loc[1];
+          *grad_attn_weight_out = cache_grad_attn_weight[0];
         }
         __syncthreads();
 
         data_weight_ptr += 1;
         data_loc_w_ptr += 2;
-        grad_attn_weight += grad_weight_stride;
-        grad_sampling_loc += grad_loc_stride;
+        grad_attn_weight_out += grad_weight_stride;
+        grad_sampling_loc_out += grad_loc_stride;
       }
     }
   }
@@ -446,11 +446,11 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(
     const int channels, const int num_levels, const int num_query,
     const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
     scalar_t *grad_attn_weight) {
+  extern __shared__ int _s[];
+  scalar_t *cache_grad_sampling_loc = reinterpret_cast<scalar_t *>(_s);
+  scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+  unsigned int tid = threadIdx.x;
   CUDA_1D_KERNEL_LOOP(index, n) {
-    extern __shared__ int _s[];
-    scalar_t *cache_grad_sampling_loc = reinterpret_cast<scalar_t *>(_s);
-    scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
-    unsigned int tid = threadIdx.x;
     int _temp = index;
     const int c_col = _temp % channels;
     _temp /= channels;
@@ -465,8 +465,9 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(
     int data_weight_ptr = sampling_index * num_levels * num_point;
     int data_loc_w_ptr = data_weight_ptr << 1;
     const int grad_sampling_ptr = data_weight_ptr;
-    grad_sampling_loc += grad_sampling_ptr << 1;
-    grad_attn_weight += grad_sampling_ptr;
+    scalar_t *grad_sampling_loc_out =
+        grad_sampling_loc + (grad_sampling_ptr << 1);
+    scalar_t *grad_attn_weight_out = grad_attn_weight + grad_sampling_ptr;
     const int grad_weight_stride = 1;
     const int grad_loc_stride = 2;
     const int qid_stride = num_heads * channels;
@@ -506,23 +507,23 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(
                    _grad_h = cache_grad_sampling_loc[1],
                    _grad_a = cache_grad_attn_weight[0];
           int sid = 2;
-          for (unsigned int tid = 1; tid < blockDim.x; ++tid) {
+          for (unsigned int _tid = 1; _tid < blockDim.x; ++_tid) {
             _grad_w += cache_grad_sampling_loc[sid];
             _grad_h += cache_grad_sampling_loc[sid + 1];
-            _grad_a += cache_grad_attn_weight[tid];
+            _grad_a += cache_grad_attn_weight[_tid];
             sid += 2;
           }
 
-          *grad_sampling_loc = _grad_w;
-          *(grad_sampling_loc + 1) = _grad_h;
-          *grad_attn_weight = _grad_a;
+          *grad_sampling_loc_out = _grad_w;
+          *(grad_sampling_loc_out + 1) = _grad_h;
+          *grad_attn_weight_out = _grad_a;
         }
         __syncthreads();
 
         data_weight_ptr += 1;
         data_loc_w_ptr += 2;
-        grad_attn_weight += grad_weight_stride;
-        grad_sampling_loc += grad_loc_stride;
+        grad_attn_weight_out += grad_weight_stride;
+        grad_sampling_loc_out += grad_loc_stride;
       }
     }
   }
@@ -537,11 +538,11 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(
     const int channels, const int num_levels, const int num_query,
     const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
     scalar_t *grad_attn_weight) {
+  extern __shared__ int _s[];
+  scalar_t *cache_grad_sampling_loc = reinterpret_cast<scalar_t *>(_s);
+  scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+  unsigned int tid = threadIdx.x;
   CUDA_1D_KERNEL_LOOP(index, n) {
-    extern __shared__ int _s[];
-    scalar_t *cache_grad_sampling_loc = reinterpret_cast<scalar_t *>(_s);
-    scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
-    unsigned int tid = threadIdx.x;
     int _temp = index;
     const int c_col = _temp % channels;
     _temp /= channels;
@@ -556,8 +557,9 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(
     int data_weight_ptr = sampling_index * num_levels * num_point;
     int data_loc_w_ptr = data_weight_ptr << 1;
     const int grad_sampling_ptr = data_weight_ptr;
-    grad_sampling_loc += grad_sampling_ptr << 1;
-    grad_attn_weight += grad_sampling_ptr;
+    scalar_t *grad_sampling_loc_out =
+        grad_sampling_loc + (grad_sampling_ptr << 1);
+    scalar_t *grad_attn_weight_out = grad_attn_weight + grad_sampling_ptr;
     const int grad_weight_stride = 1;
     const int grad_loc_stride = 2;
     const int qid_stride = num_heads * channels;
@@ -615,16 +617,16 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(
         }
 
         if (tid == 0) {
-          *grad_sampling_loc = cache_grad_sampling_loc[0];
-          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
-          *grad_attn_weight = cache_grad_attn_weight[0];
+          *grad_sampling_loc_out = cache_grad_sampling_loc[0];
+          *(grad_sampling_loc_out + 1) = cache_grad_sampling_loc[1];
+          *grad_attn_weight_out = cache_grad_attn_weight[0];
         }
         __syncthreads();
 
         data_weight_ptr += 1;
         data_loc_w_ptr += 2;
-        grad_attn_weight += grad_weight_stride;
-        grad_sampling_loc += grad_loc_stride;
+        grad_attn_weight_out += grad_weight_stride;
+        grad_sampling_loc_out += grad_loc_stride;
       }
     }
   }
@@ -639,11 +641,11 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(
     const int channels, const int num_levels, const int num_query,
     const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
     scalar_t *grad_attn_weight) {
+  extern __shared__ int _s[];
+  scalar_t *cache_grad_sampling_loc = reinterpret_cast<scalar_t *>(_s);
+  scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+  unsigned int tid = threadIdx.x;
   CUDA_1D_KERNEL_LOOP(index, n) {
-    extern __shared__ int _s[];
-    scalar_t *cache_grad_sampling_loc = reinterpret_cast<scalar_t *>(_s);
-    scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
-    unsigned int tid = threadIdx.x;
     int _temp = index;
     const int c_col = _temp % channels;
     _temp /= channels;
@@ -658,8 +660,9 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(
     int data_weight_ptr = sampling_index * num_levels * num_point;
     int data_loc_w_ptr = data_weight_ptr << 1;
     const int grad_sampling_ptr = data_weight_ptr;
-    grad_sampling_loc += grad_sampling_ptr << 1;
-    grad_attn_weight += grad_sampling_ptr;
+    scalar_t *grad_sampling_loc_out =
+        grad_sampling_loc + (grad_sampling_ptr << 1);
+    scalar_t *grad_attn_weight_out = grad_attn_weight + grad_sampling_ptr;
     const int grad_weight_stride = 1;
     const int grad_loc_stride = 2;
     const int qid_stride = num_heads * channels;
@@ -717,16 +720,16 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(
         }
 
         if (tid == 0) {
-          atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
-          atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
-          atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
+          atomicAdd(grad_sampling_loc_out, cache_grad_sampling_loc[0]);
+          atomicAdd(grad_sampling_loc_out + 1, cache_grad_sampling_loc[1]);
+          atomicAdd(grad_attn_weight_out, cache_grad_attn_weight[0]);
         }
         __syncthreads();
 
         data_weight_ptr += 1;
         data_loc_w_ptr += 2;
-        grad_attn_weight += grad_weight_stride;
-        grad_sampling_loc += grad_loc_stride;
+        grad_attn_weight_out += grad_weight_stride;
+        grad_sampling_loc_out += grad_loc_stride;
       }
     }
   }
@@ -756,8 +759,9 @@ __global__ void ms_deformable_col2im_gpu_kernel_gm(
     int data_weight_ptr = sampling_index * num_levels * num_point;
     int data_loc_w_ptr = data_weight_ptr << 1;
     const int grad_sampling_ptr = data_weight_ptr;
-    grad_sampling_loc += grad_sampling_ptr << 1;
-    grad_attn_weight += grad_sampling_ptr;
+    scalar_t *grad_sampling_loc_out =
+        grad_sampling_loc + (grad_sampling_ptr << 1);
+    scalar_t *grad_attn_weight_out = grad_attn_weight + grad_sampling_ptr;
     const int grad_weight_stride = 1;
     const int grad_loc_stride = 2;
     const int qid_stride = num_heads * channels;
@@ -784,12 +788,12 @@ __global__ void ms_deformable_col2im_gpu_kernel_gm(
           ms_deform_attn_col2im_bilinear_gm(
               data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
               w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
-              grad_sampling_loc, grad_attn_weight);
+              grad_sampling_loc_out, grad_attn_weight_out);
         }
         data_weight_ptr += 1;
         data_loc_w_ptr += 2;
-        grad_attn_weight += grad_weight_stride;
-        grad_sampling_loc += grad_loc_stride;
+        grad_attn_weight_out += grad_weight_stride;
+        grad_sampling_loc_out += grad_loc_stride;
       }
     }
   }

@@ -31,7 +31,7 @@ void ms_deformable_im2col_cuda(cudaStream_t stream, const scalar_t *data_value,
                                const int num_point, scalar_t *data_col) {
   const int num_kernels = batch_size * num_query * num_heads * channels;
   const int num_actual_kernels = batch_size * num_query * num_heads * channels;
-  const int num_threads = CUDA_NUM_THREADS;
+  const int num_threads = THREADS_PER_BLOCK;
   ms_deformable_im2col_gpu_kernel<scalar_t>
       <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
           num_kernels, data_value, data_spatial_shapes, data_level_start_index,
@@ -54,11 +54,11 @@ void ms_deformable_col2im_cuda(
     const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
     scalar_t *grad_attn_weight) {
   const int num_threads =
-      (channels > CUDA_NUM_THREADS) ? CUDA_NUM_THREADS : channels;
+      (channels > THREADS_PER_BLOCK) ? THREADS_PER_BLOCK : channels;
   const int num_kernels = batch_size * num_query * num_heads * channels;
   const int num_actual_kernels = batch_size * num_query * num_heads * channels;
-  if (channels > 1024) {
-    if ((channels & 1023) == 0) {
+  if (channels > THREADS_PER_BLOCK) {
+    if ((channels & THREADS_PER_BLOCK - 1) == 0) {
       ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
           <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
              num_threads * 3 * sizeof(scalar_t), stream>>>(
@@ -178,16 +178,6 @@ void ms_deformable_col2im_cuda(
                          channels, num_levels, num_query, num_point, grad_value,
                          grad_sampling_loc, grad_attn_weight);
         break;
-      case 1024:
-        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t,
-                                                                      1024>
-            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
-               stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
-                         data_level_start_index, data_sampling_loc,
-                         data_attn_weight, batch_size, spatial_size, num_heads,
-                         channels, num_levels, num_query, num_point, grad_value,
-                         grad_sampling_loc, grad_attn_weight);
-        break;
       default:
         if (channels < 64) {
           ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
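On the dispatch side, the file-local `CUDA_NUM_THREADS` is dropped in favour of the `THREADS_PER_BLOCK` constant that `common_cuda_helper.hpp` already provides, and the hard-coded `channels > 1024` / `channels & 1023` tests follow suit. The mask test `(channels & THREADS_PER_BLOCK - 1) == 0` parses as `channels & (THREADS_PER_BLOCK - 1)` because `-` binds tighter than `&`, and for a power-of-two block size it asks exactly "is channels a multiple of the block size?". A small host-side sanity check, assuming `THREADS_PER_BLOCK` is a power of two (mmcv's helper defines it as 512, but verify against your checkout):

#include <cassert>

constexpr int kThreadsPerBlock = 512;  // stand-in for THREADS_PER_BLOCK

// For a power-of-two m, (x & (m - 1)) == 0 exactly when m divides x.
constexpr bool channels_divisible(int channels) {
  return (channels & (kThreadsPerBlock - 1)) == 0;
}

int main() {
  static_assert((kThreadsPerBlock & (kThreadsPerBlock - 1)) == 0,
                "mask trick requires a power-of-two block size");
  assert(channels_divisible(1024));   // takes the multi-block kernel path
  assert(!channels_divisible(1000));  // falls through to the other kernels
  return 0;
}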