diff --git a/mmcv/ops/csrc/common/cuda/ms_deform_attn_cuda_kernel.cuh b/mmcv/ops/csrc/common/cuda/ms_deform_attn_cuda_kernel.cuh
index 4e59bd3dc..12225ffdb 100644
--- a/mmcv/ops/csrc/common/cuda/ms_deform_attn_cuda_kernel.cuh
+++ b/mmcv/ops/csrc/common/cuda/ms_deform_attn_cuda_kernel.cuh
@@ -14,8 +14,6 @@
 #include "common_cuda_helper.hpp"
 #include "pytorch_cuda_helper.hpp"
 
-const int CUDA_NUM_THREADS = 1024;
-
 template <typename scalar_t>
 __device__ scalar_t ms_deform_attn_im2col_bilinear(
     const scalar_t *&bottom_data, const int &height, const int &width,
@@ -264,10 +262,11 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(
     const int channels, const int num_levels, const int num_query,
     const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
     scalar_t *grad_attn_weight) {
+  __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
+  __shared__ scalar_t cache_grad_attn_weight[blockSize];
+  unsigned int tid = threadIdx.x;
+  const int qid_stride = num_heads * channels;
   CUDA_1D_KERNEL_LOOP(index, n) {
-    __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
-    __shared__ scalar_t cache_grad_attn_weight[blockSize];
-    unsigned int tid = threadIdx.x;
     int _temp = index;
     const int c_col = _temp % channels;
     _temp /= channels;
@@ -282,11 +281,11 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(
     int data_weight_ptr = sampling_index * num_levels * num_point;
     int data_loc_w_ptr = data_weight_ptr << 1;
     const int grad_sampling_ptr = data_weight_ptr;
-    grad_sampling_loc += grad_sampling_ptr << 1;
-    grad_attn_weight += grad_sampling_ptr;
+    scalar_t *grad_sampling_loc_out =
+        grad_sampling_loc + (grad_sampling_ptr << 1);
+    scalar_t *grad_attn_weight_out = grad_attn_weight + grad_sampling_ptr;
     const int grad_weight_stride = 1;
     const int grad_loc_stride = 2;
-    const int qid_stride = num_heads * channels;
     const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
 
     for (int l_col = 0; l_col < num_levels; ++l_col) {
@@ -323,23 +322,23 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(
                    _grad_h = cache_grad_sampling_loc[1],
                    _grad_a = cache_grad_attn_weight[0];
           int sid = 2;
-          for (unsigned int tid = 1; tid < blockSize; ++tid) {
+          for (unsigned int _tid = 1; _tid < blockSize; ++_tid) {
            _grad_w += cache_grad_sampling_loc[sid];
            _grad_h += cache_grad_sampling_loc[sid + 1];
-            _grad_a += cache_grad_attn_weight[tid];
+            _grad_a += cache_grad_attn_weight[_tid];
            sid += 2;
          }
 
-          *grad_sampling_loc = _grad_w;
-          *(grad_sampling_loc + 1) = _grad_h;
-          *grad_attn_weight = _grad_a;
+          *grad_sampling_loc_out = _grad_w;
+          *(grad_sampling_loc_out + 1) = _grad_h;
+          *grad_attn_weight_out = _grad_a;
         }
         __syncthreads();
 
         data_weight_ptr += 1;
         data_loc_w_ptr += 2;
-        grad_attn_weight += grad_weight_stride;
-        grad_sampling_loc += grad_loc_stride;
+        grad_attn_weight_out += grad_weight_stride;
+        grad_sampling_loc_out += grad_loc_stride;
       }
     }
   }
@@ -354,10 +353,10 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(
     const int channels, const int num_levels, const int num_query,
     const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
     scalar_t *grad_attn_weight) {
+  __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
+  __shared__ scalar_t cache_grad_attn_weight[blockSize];
+  unsigned int tid = threadIdx.x;
   CUDA_1D_KERNEL_LOOP(index, n) {
-    __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
-    __shared__ scalar_t cache_grad_attn_weight[blockSize];
-    unsigned int tid = threadIdx.x;
     int _temp = index;
     const int c_col = _temp % channels;
     _temp /= channels;
@@ -372,8 +371,9 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(
     int data_weight_ptr = sampling_index * num_levels * num_point;
     int data_loc_w_ptr = data_weight_ptr << 1;
     const int grad_sampling_ptr = data_weight_ptr;
-    grad_sampling_loc += grad_sampling_ptr << 1;
-    grad_attn_weight += grad_sampling_ptr;
+    scalar_t *grad_sampling_loc_out =
+        grad_sampling_loc + (grad_sampling_ptr << 1);
+    scalar_t *grad_attn_weight_out = grad_attn_weight + grad_sampling_ptr;
     const int grad_weight_stride = 1;
     const int grad_loc_stride = 2;
     const int qid_stride = num_heads * channels;
@@ -422,16 +422,16 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(
         }
 
         if (tid == 0) {
-          *grad_sampling_loc = cache_grad_sampling_loc[0];
-          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
-          *grad_attn_weight = cache_grad_attn_weight[0];
+          *grad_sampling_loc_out = cache_grad_sampling_loc[0];
+          *(grad_sampling_loc_out + 1) = cache_grad_sampling_loc[1];
+          *grad_attn_weight_out = cache_grad_attn_weight[0];
         }
         __syncthreads();
 
         data_weight_ptr += 1;
         data_loc_w_ptr += 2;
-        grad_attn_weight += grad_weight_stride;
-        grad_sampling_loc += grad_loc_stride;
+        grad_attn_weight_out += grad_weight_stride;
+        grad_sampling_loc_out += grad_loc_stride;
       }
     }
   }
@@ -446,11 +446,11 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(
     const int channels, const int num_levels, const int num_query,
     const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
     scalar_t *grad_attn_weight) {
+  extern __shared__ int _s[];
+  scalar_t *cache_grad_sampling_loc = reinterpret_cast<scalar_t *>(_s);
+  scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+  unsigned int tid = threadIdx.x;
   CUDA_1D_KERNEL_LOOP(index, n) {
-    extern __shared__ int _s[];
-    scalar_t *cache_grad_sampling_loc = reinterpret_cast<scalar_t *>(_s);
-    scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
-    unsigned int tid = threadIdx.x;
     int _temp = index;
     const int c_col = _temp % channels;
     _temp /= channels;
@@ -465,8 +465,9 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(
     int data_weight_ptr = sampling_index * num_levels * num_point;
     int data_loc_w_ptr = data_weight_ptr << 1;
     const int grad_sampling_ptr = data_weight_ptr;
-    grad_sampling_loc += grad_sampling_ptr << 1;
-    grad_attn_weight += grad_sampling_ptr;
+    scalar_t *grad_sampling_loc_out =
+        grad_sampling_loc + (grad_sampling_ptr << 1);
+    scalar_t *grad_attn_weight_out = grad_attn_weight + grad_sampling_ptr;
     const int grad_weight_stride = 1;
     const int grad_loc_stride = 2;
     const int qid_stride = num_heads * channels;
@@ -506,23 +507,23 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(
                    _grad_h = cache_grad_sampling_loc[1],
                    _grad_a = cache_grad_attn_weight[0];
           int sid = 2;
-          for (unsigned int tid = 1; tid < blockDim.x; ++tid) {
+          for (unsigned int _tid = 1; _tid < blockDim.x; ++_tid) {
            _grad_w += cache_grad_sampling_loc[sid];
            _grad_h += cache_grad_sampling_loc[sid + 1];
-            _grad_a += cache_grad_attn_weight[tid];
+            _grad_a += cache_grad_attn_weight[_tid];
            sid += 2;
          }
 
-          *grad_sampling_loc = _grad_w;
-          *(grad_sampling_loc + 1) = _grad_h;
-          *grad_attn_weight = _grad_a;
+          *grad_sampling_loc_out = _grad_w;
+          *(grad_sampling_loc_out + 1) = _grad_h;
+          *grad_attn_weight_out = _grad_a;
         }
         __syncthreads();
 
         data_weight_ptr += 1;
         data_loc_w_ptr += 2;
-        grad_attn_weight += grad_weight_stride;
-        grad_sampling_loc += grad_loc_stride;
+        grad_attn_weight_out += grad_weight_stride;
+        grad_sampling_loc_out += grad_loc_stride;
       }
     }
   }
@@ -537,11 +538,11 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(
     const int channels, const int num_levels, const int num_query,
     const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
     scalar_t *grad_attn_weight) {
+  extern __shared__ int _s[];
+  scalar_t *cache_grad_sampling_loc = reinterpret_cast<scalar_t *>(_s);
+  scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+  unsigned int tid = threadIdx.x;
   CUDA_1D_KERNEL_LOOP(index, n) {
-    extern __shared__ int _s[];
-    scalar_t *cache_grad_sampling_loc = reinterpret_cast<scalar_t *>(_s);
-    scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
-    unsigned int tid = threadIdx.x;
     int _temp = index;
     const int c_col = _temp % channels;
     _temp /= channels;
@@ -556,8 +557,9 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(
     int data_weight_ptr = sampling_index * num_levels * num_point;
     int data_loc_w_ptr = data_weight_ptr << 1;
     const int grad_sampling_ptr = data_weight_ptr;
-    grad_sampling_loc += grad_sampling_ptr << 1;
-    grad_attn_weight += grad_sampling_ptr;
+    scalar_t *grad_sampling_loc_out =
+        grad_sampling_loc + (grad_sampling_ptr << 1);
+    scalar_t *grad_attn_weight_out = grad_attn_weight + grad_sampling_ptr;
     const int grad_weight_stride = 1;
     const int grad_loc_stride = 2;
     const int qid_stride = num_heads * channels;
@@ -615,16 +617,16 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(
         }
 
         if (tid == 0) {
-          *grad_sampling_loc = cache_grad_sampling_loc[0];
-          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
-          *grad_attn_weight = cache_grad_attn_weight[0];
+          *grad_sampling_loc_out = cache_grad_sampling_loc[0];
+          *(grad_sampling_loc_out + 1) = cache_grad_sampling_loc[1];
+          *grad_attn_weight_out = cache_grad_attn_weight[0];
         }
         __syncthreads();
 
         data_weight_ptr += 1;
         data_loc_w_ptr += 2;
-        grad_attn_weight += grad_weight_stride;
-        grad_sampling_loc += grad_loc_stride;
+        grad_attn_weight_out += grad_weight_stride;
+        grad_sampling_loc_out += grad_loc_stride;
       }
     }
   }
@@ -639,11 +641,11 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(
     const int channels, const int num_levels, const int num_query,
     const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
     scalar_t *grad_attn_weight) {
+  extern __shared__ int _s[];
+  scalar_t *cache_grad_sampling_loc = reinterpret_cast<scalar_t *>(_s);
+  scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+  unsigned int tid = threadIdx.x;
   CUDA_1D_KERNEL_LOOP(index, n) {
-    extern __shared__ int _s[];
-    scalar_t *cache_grad_sampling_loc = reinterpret_cast<scalar_t *>(_s);
-    scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
-    unsigned int tid = threadIdx.x;
     int _temp = index;
     const int c_col = _temp % channels;
     _temp /= channels;
@@ -658,8 +660,9 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(
     int data_weight_ptr = sampling_index * num_levels * num_point;
     int data_loc_w_ptr = data_weight_ptr << 1;
     const int grad_sampling_ptr = data_weight_ptr;
-    grad_sampling_loc += grad_sampling_ptr << 1;
-    grad_attn_weight += grad_sampling_ptr;
+    scalar_t *grad_sampling_loc_out =
+        grad_sampling_loc + (grad_sampling_ptr << 1);
+    scalar_t *grad_attn_weight_out = grad_attn_weight + grad_sampling_ptr;
     const int grad_weight_stride = 1;
     const int grad_loc_stride = 2;
     const int qid_stride = num_heads * channels;
@@ -717,16 +720,16 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(
         }
 
         if (tid == 0) {
-          atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
-          atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
-          atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
+          atomicAdd(grad_sampling_loc_out, cache_grad_sampling_loc[0]);
+          atomicAdd(grad_sampling_loc_out + 1, cache_grad_sampling_loc[1]);
+          atomicAdd(grad_attn_weight_out, cache_grad_attn_weight[0]);
         }
         __syncthreads();
 
         data_weight_ptr += 1;
         data_loc_w_ptr += 2;
-        grad_attn_weight += grad_weight_stride;
-        grad_sampling_loc += grad_loc_stride;
+        grad_attn_weight_out += grad_weight_stride;
+        grad_sampling_loc_out += grad_loc_stride;
       }
     }
   }
@@ -756,8 +759,9 @@ __global__ void ms_deformable_col2im_gpu_kernel_gm(
     int data_weight_ptr = sampling_index * num_levels * num_point;
     int data_loc_w_ptr = data_weight_ptr << 1;
     const int grad_sampling_ptr = data_weight_ptr;
-    grad_sampling_loc += grad_sampling_ptr << 1;
-    grad_attn_weight += grad_sampling_ptr;
+    scalar_t *grad_sampling_loc_out =
+        grad_sampling_loc + (grad_sampling_ptr << 1);
+    scalar_t *grad_attn_weight_out = grad_attn_weight + grad_sampling_ptr;
     const int grad_weight_stride = 1;
     const int grad_loc_stride = 2;
     const int qid_stride = num_heads * channels;
@@ -784,12 +788,12 @@ __global__ void ms_deformable_col2im_gpu_kernel_gm(
           ms_deform_attn_col2im_bilinear_gm(
               data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
               w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
-              grad_sampling_loc, grad_attn_weight);
+              grad_sampling_loc_out, grad_attn_weight_out);
         }
         data_weight_ptr += 1;
         data_loc_w_ptr += 2;
-        grad_attn_weight += grad_weight_stride;
-        grad_sampling_loc += grad_loc_stride;
+        grad_attn_weight_out += grad_weight_stride;
+        grad_sampling_loc_out += grad_loc_stride;
       }
     }
   }
diff --git a/mmcv/ops/csrc/pytorch/cuda/ms_deform_attn_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/ms_deform_attn_cuda.cu
index 2fccaa213..fd191ee9c 100644
--- a/mmcv/ops/csrc/pytorch/cuda/ms_deform_attn_cuda.cu
+++ b/mmcv/ops/csrc/pytorch/cuda/ms_deform_attn_cuda.cu
@@ -31,7 +31,7 @@ void ms_deformable_im2col_cuda(cudaStream_t stream, const scalar_t *data_value,
                                const int num_point, scalar_t *data_col) {
   const int num_kernels = batch_size * num_query * num_heads * channels;
   const int num_actual_kernels = batch_size * num_query * num_heads * channels;
-  const int num_threads = CUDA_NUM_THREADS;
+  const int num_threads = THREADS_PER_BLOCK;
   ms_deformable_im2col_gpu_kernel<scalar_t>
       <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
           num_kernels, data_value, data_spatial_shapes, data_level_start_index,
@@ -54,11 +54,11 @@ void ms_deformable_col2im_cuda(
     const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
     scalar_t *grad_attn_weight) {
   const int num_threads =
-      (channels > CUDA_NUM_THREADS) ? CUDA_NUM_THREADS : channels;
+      (channels > THREADS_PER_BLOCK) ? THREADS_PER_BLOCK : channels;
   const int num_kernels = batch_size * num_query * num_heads * channels;
   const int num_actual_kernels = batch_size * num_query * num_heads * channels;
-  if (channels > 1024) {
-    if ((channels & 1023) == 0) {
+  if (channels > THREADS_PER_BLOCK) {
+    if ((channels & THREADS_PER_BLOCK - 1) == 0) {
       ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
           <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
              num_threads * 3 * sizeof(scalar_t), stream>>>(
@@ -178,16 +178,6 @@ void ms_deformable_col2im_cuda(
                          channels, num_levels, num_query, num_point, grad_value,
                          grad_sampling_loc, grad_attn_weight);
         break;
-      case 1024:
-        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t,
-                                                                      1024>
-            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
-               stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
-                         data_level_start_index, data_sampling_loc,
-                         data_attn_weight, batch_size, spatial_size, num_heads,
-                         channels, num_levels, num_query, num_point, grad_value,
-                         grad_sampling_loc, grad_attn_weight);
-        break;
       default:
         if (channels < 64) {
          ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>