Fix ms deform attn (#1823)

* rename grad_sampling_loc and grad_attn_weight

* recover cache initialize
q.yao 2022-03-24 21:55:33 +08:00 committed by GitHub
parent 5b5d0c15bc
commit d929fa4136
2 changed files with 73 additions and 79 deletions
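
In short: inside the grid-stride CUDA_1D_KERNEL_LOOP, the old col2im kernels advanced the grad_sampling_loc and grad_attn_weight kernel arguments in place. Whenever a thread handles more than one index, its next iteration then offsets from an already-shifted base pointer and writes to the wrong addresses. The rename introduces local grad_sampling_loc_out / grad_attn_weight_out pointers and leaves the arguments untouched; the second commit restores the original placement of the __shared__ cache declarations. Below is a minimal sketch of the pointer pattern only, not the real kernel: the toy offset math (index * num_point) and the kernel bodies are illustrative, and the loop macro is reproduced in the same spirit as the helper in common_cuda_helper.hpp.

// Grid-stride loop, matching the usual CUDA_1D_KERNEL_LOOP helper.
#define CUDA_1D_KERNEL_LOOP(i, n)                                \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n);   \
       i += blockDim.x * gridDim.x)

// Before: the kernel argument itself is advanced, so a thread's second
// grid-stride iteration starts from an already-shifted base pointer.
__global__ void sketch_before(int n, int num_point, float *grad_attn_weight) {
  CUDA_1D_KERNEL_LOOP(index, n) {
    grad_attn_weight += index * num_point;  // mutates the argument
    for (int p = 0; p < num_point; ++p) {
      atomicAdd(grad_attn_weight, 1.f);
      grad_attn_weight += 1;                // shift compounds across iterations
    }
  }
}

// After: offsets go through a local *_out pointer, so every iteration
// computes its addresses from the unmodified tensor base.
__global__ void sketch_after(int n, int num_point, float *grad_attn_weight) {
  CUDA_1D_KERNEL_LOOP(index, n) {
    float *grad_attn_weight_out = grad_attn_weight + index * num_point;
    for (int p = 0; p < num_point; ++p) {
      atomicAdd(grad_attn_weight_out, 1.f);
      grad_attn_weight_out += 1;
    }
  }
}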

View File

@@ -14,8 +14,6 @@
#include "common_cuda_helper.hpp"
#include "pytorch_cuda_helper.hpp"
const int CUDA_NUM_THREADS = 1024;
template <typename scalar_t>
__device__ scalar_t ms_deform_attn_im2col_bilinear(
const scalar_t *&bottom_data, const int &height, const int &width,
@@ -264,10 +262,11 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(
const int channels, const int num_levels, const int num_query,
const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight) {
__shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
__shared__ scalar_t cache_grad_attn_weight[blockSize];
unsigned int tid = threadIdx.x;
const int qid_stride = num_heads * channels;
CUDA_1D_KERNEL_LOOP(index, n) {
__shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
__shared__ scalar_t cache_grad_attn_weight[blockSize];
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
@@ -282,11 +281,11 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
scalar_t *grad_sampling_loc_out =
grad_sampling_loc + (grad_sampling_ptr << 1);
scalar_t *grad_attn_weight_out = grad_attn_weight + grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col = 0; l_col < num_levels; ++l_col) {
@@ -323,23 +322,23 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(
_grad_h = cache_grad_sampling_loc[1],
_grad_a = cache_grad_attn_weight[0];
int sid = 2;
for (unsigned int tid = 1; tid < blockSize; ++tid) {
for (unsigned int _tid = 1; _tid < blockSize; ++_tid) {
_grad_w += cache_grad_sampling_loc[sid];
_grad_h += cache_grad_sampling_loc[sid + 1];
_grad_a += cache_grad_attn_weight[tid];
_grad_a += cache_grad_attn_weight[_tid];
sid += 2;
}
*grad_sampling_loc = _grad_w;
*(grad_sampling_loc + 1) = _grad_h;
*grad_attn_weight = _grad_a;
*grad_sampling_loc_out = _grad_w;
*(grad_sampling_loc_out + 1) = _grad_h;
*grad_attn_weight_out = _grad_a;
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
grad_attn_weight_out += grad_weight_stride;
grad_sampling_loc_out += grad_loc_stride;
}
}
}
@@ -354,10 +353,10 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(
const int channels, const int num_levels, const int num_query,
const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight) {
__shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
__shared__ scalar_t cache_grad_attn_weight[blockSize];
unsigned int tid = threadIdx.x;
CUDA_1D_KERNEL_LOOP(index, n) {
__shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
__shared__ scalar_t cache_grad_attn_weight[blockSize];
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
@@ -372,8 +371,9 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
scalar_t *grad_sampling_loc_out =
grad_sampling_loc + (grad_sampling_ptr << 1);
scalar_t *grad_attn_weight_out = grad_attn_weight + grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
@@ -422,16 +422,16 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(
}
if (tid == 0) {
*grad_sampling_loc = cache_grad_sampling_loc[0];
*(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
*grad_attn_weight = cache_grad_attn_weight[0];
*grad_sampling_loc_out = cache_grad_sampling_loc[0];
*(grad_sampling_loc_out + 1) = cache_grad_sampling_loc[1];
*grad_attn_weight_out = cache_grad_attn_weight[0];
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
grad_attn_weight_out += grad_weight_stride;
grad_sampling_loc_out += grad_loc_stride;
}
}
}
@@ -446,11 +446,11 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(
const int channels, const int num_levels, const int num_query,
const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight) {
extern __shared__ int _s[];
scalar_t *cache_grad_sampling_loc = reinterpret_cast<scalar_t *>(_s);
scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
unsigned int tid = threadIdx.x;
CUDA_1D_KERNEL_LOOP(index, n) {
extern __shared__ int _s[];
scalar_t *cache_grad_sampling_loc = reinterpret_cast<scalar_t *>(_s);
scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
@@ -465,8 +465,9 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
scalar_t *grad_sampling_loc_out =
grad_sampling_loc + (grad_sampling_ptr << 1);
scalar_t *grad_attn_weight_out = grad_attn_weight + grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
@@ -506,23 +507,23 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(
_grad_h = cache_grad_sampling_loc[1],
_grad_a = cache_grad_attn_weight[0];
int sid = 2;
for (unsigned int tid = 1; tid < blockDim.x; ++tid) {
for (unsigned int _tid = 1; _tid < blockDim.x; ++_tid) {
_grad_w += cache_grad_sampling_loc[sid];
_grad_h += cache_grad_sampling_loc[sid + 1];
_grad_a += cache_grad_attn_weight[tid];
_grad_a += cache_grad_attn_weight[_tid];
sid += 2;
}
*grad_sampling_loc = _grad_w;
*(grad_sampling_loc + 1) = _grad_h;
*grad_attn_weight = _grad_a;
*grad_sampling_loc_out = _grad_w;
*(grad_sampling_loc_out + 1) = _grad_h;
*grad_attn_weight_out = _grad_a;
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
grad_attn_weight_out += grad_weight_stride;
grad_sampling_loc_out += grad_loc_stride;
}
}
}
@@ -537,11 +538,11 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(
const int channels, const int num_levels, const int num_query,
const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight) {
extern __shared__ int _s[];
scalar_t *cache_grad_sampling_loc = reinterpret_cast<scalar_t *>(_s);
scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
unsigned int tid = threadIdx.x;
CUDA_1D_KERNEL_LOOP(index, n) {
extern __shared__ int _s[];
scalar_t *cache_grad_sampling_loc = reinterpret_cast<scalar_t *>(_s);
scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
@@ -556,8 +557,9 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
scalar_t *grad_sampling_loc_out =
grad_sampling_loc + (grad_sampling_ptr << 1);
scalar_t *grad_attn_weight_out = grad_attn_weight + grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
@@ -615,16 +617,16 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(
}
if (tid == 0) {
*grad_sampling_loc = cache_grad_sampling_loc[0];
*(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
*grad_attn_weight = cache_grad_attn_weight[0];
*grad_sampling_loc_out = cache_grad_sampling_loc[0];
*(grad_sampling_loc_out + 1) = cache_grad_sampling_loc[1];
*grad_attn_weight_out = cache_grad_attn_weight[0];
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
grad_attn_weight_out += grad_weight_stride;
grad_sampling_loc_out += grad_loc_stride;
}
}
}
@@ -639,11 +641,11 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(
const int channels, const int num_levels, const int num_query,
const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight) {
extern __shared__ int _s[];
scalar_t *cache_grad_sampling_loc = reinterpret_cast<scalar_t *>(_s);
scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
unsigned int tid = threadIdx.x;
CUDA_1D_KERNEL_LOOP(index, n) {
extern __shared__ int _s[];
scalar_t *cache_grad_sampling_loc = reinterpret_cast<scalar_t *>(_s);
scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
@@ -658,8 +660,9 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
scalar_t *grad_sampling_loc_out =
grad_sampling_loc + (grad_sampling_ptr << 1);
scalar_t *grad_attn_weight_out = grad_attn_weight + grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
@@ -717,16 +720,16 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(
}
if (tid == 0) {
atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
atomicAdd(grad_sampling_loc_out, cache_grad_sampling_loc[0]);
atomicAdd(grad_sampling_loc_out + 1, cache_grad_sampling_loc[1]);
atomicAdd(grad_attn_weight_out, cache_grad_attn_weight[0]);
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
grad_attn_weight_out += grad_weight_stride;
grad_sampling_loc_out += grad_loc_stride;
}
}
}
@@ -756,8 +759,9 @@ __global__ void ms_deformable_col2im_gpu_kernel_gm(
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
scalar_t *grad_sampling_loc_out =
grad_sampling_loc + (grad_sampling_ptr << 1);
scalar_t *grad_attn_weight_out = grad_attn_weight + grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
@@ -784,12 +788,12 @@ __global__ void ms_deformable_col2im_gpu_kernel_gm(
ms_deform_attn_col2im_bilinear_gm(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
grad_sampling_loc, grad_attn_weight);
grad_sampling_loc_out, grad_attn_weight_out);
}
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
grad_attn_weight_out += grad_weight_stride;
grad_sampling_loc_out += grad_loc_stride;
}
}
}

View File

@@ -31,7 +31,7 @@ void ms_deformable_im2col_cuda(cudaStream_t stream, const scalar_t *data_value,
const int num_point, scalar_t *data_col) {
const int num_kernels = batch_size * num_query * num_heads * channels;
const int num_actual_kernels = batch_size * num_query * num_heads * channels;
const int num_threads = CUDA_NUM_THREADS;
const int num_threads = THREADS_PER_BLOCK;
ms_deformable_im2col_gpu_kernel<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
num_kernels, data_value, data_spatial_shapes, data_level_start_index,
@@ -54,11 +54,11 @@ void ms_deformable_col2im_cuda(
const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight) {
const int num_threads =
(channels > CUDA_NUM_THREADS) ? CUDA_NUM_THREADS : channels;
(channels > THREADS_PER_BLOCK) ? THREADS_PER_BLOCK : channels;
const int num_kernels = batch_size * num_query * num_heads * channels;
const int num_actual_kernels = batch_size * num_query * num_heads * channels;
if (channels > 1024) {
if ((channels & 1023) == 0) {
if (channels > THREADS_PER_BLOCK) {
if ((channels & THREADS_PER_BLOCK - 1) == 0) {
ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
num_threads * 3 * sizeof(scalar_t), stream>>>(
@@ -178,16 +178,6 @@ void ms_deformable_col2im_cuda(
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
case 1024:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t,
1024>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
default:
if (channels < 64) {
ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>