Update ms_deform_attn_cuda.cu to support cuda 11.8

pull/383/head
Ziqi Gao (Roy) 2025-02-10 15:43:58 -08:00 committed by GitHub
parent b604d1fed5
commit 023c6b5c77
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 44 additions and 46 deletions

View File

@ -15,6 +15,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <c10/cuda/CUDAGuard.h>
namespace groundingdino {
@ -26,17 +27,17 @@ at::Tensor ms_deform_attn_cuda_forward(
const at::Tensor &attn_weight,
const int im2col_step)
{
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
TORCH_CHECK(value.is_contiguous(), "value tensor has to be contiguous");
TORCH_CHECK(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
TORCH_CHECK(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
TORCH_CHECK(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
TORCH_CHECK(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
TORCH_CHECK(value.device().type() == at::kCUDA, "value must be a CUDA tensor");
TORCH_CHECK(spatial_shapes.device().type() == at::kCUDA, "spatial_shapes must be a CUDA tensor");
TORCH_CHECK(level_start_index.device().type() == at::kCUDA, "level_start_index must be a CUDA tensor");
TORCH_CHECK(sampling_loc.device().type() == at::kCUDA, "sampling_loc must be a CUDA tensor");
TORCH_CHECK(attn_weight.device().type() == at::kCUDA, "attn_weight must be a CUDA tensor");
const int batch = value.size(0);
const int spatial_size = value.size(1);
@ -50,7 +51,7 @@ at::Tensor ms_deform_attn_cuda_forward(
const int im2col_step_ = std::min(batch, im2col_step);
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
TORCH_CHECK(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
@ -59,19 +60,19 @@ at::Tensor ms_deform_attn_cuda_forward(
auto per_value_size = spatial_size * num_heads * channels;
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
for (int n = 0; n < batch/im2col_step_; ++n)
{
auto columns = output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(value.scalar_type(), "ms_deform_attn_forward_cuda", ([&] {
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data<int64_t>(),
level_start_index.data<int64_t>(),
sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data_ptr<int64_t>(),
level_start_index.data_ptr<int64_t>(),
sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
columns.data<scalar_t>());
columns.data_ptr<scalar_t>());
}));
}
@ -80,7 +81,6 @@ at::Tensor ms_deform_attn_cuda_forward(
return output;
}
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
@ -90,20 +90,19 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
const at::Tensor &grad_output,
const int im2col_step)
{
TORCH_CHECK(value.is_contiguous(), "value tensor has to be contiguous");
TORCH_CHECK(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
TORCH_CHECK(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
TORCH_CHECK(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
TORCH_CHECK(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
TORCH_CHECK(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
TORCH_CHECK(value.device().type() == at::kCUDA, "value must be a CUDA tensor");
TORCH_CHECK(spatial_shapes.device().type() == at::kCUDA, "spatial_shapes must be a CUDA tensor");
TORCH_CHECK(level_start_index.device().type() == at::kCUDA, "level_start_index must be a CUDA tensor");
TORCH_CHECK(sampling_loc.device().type() == at::kCUDA, "sampling_loc must be a CUDA tensor");
TORCH_CHECK(attn_weight.device().type() == at::kCUDA, "attn_weight must be a CUDA tensor");
TORCH_CHECK(grad_output.device().type() == at::kCUDA, "grad_output must be a CUDA tensor");
const int batch = value.size(0);
const int spatial_size = value.size(1);
@ -117,7 +116,7 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
const int im2col_step_ = std::min(batch, im2col_step);
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
TORCH_CHECK(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
auto grad_value = at::zeros_like(value);
auto grad_sampling_loc = at::zeros_like(sampling_loc);
@ -132,19 +131,18 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
for (int n = 0; n < batch/im2col_step_; ++n)
{
auto grad_output_g = grad_output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(value.scalar_type(), "ms_deform_attn_backward_cuda", ([&] {
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
grad_output_g.data<scalar_t>(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data<int64_t>(),
level_start_index.data<int64_t>(),
sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
grad_output_g.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data_ptr<int64_t>(),
level_start_index.data_ptr<int64_t>(),
sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
grad_value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
grad_sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
grad_attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
}));
}
@ -153,4 +151,4 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
};
}
} // namespace groundingdino
} // namespace groundingdino