From 947976cc17d6dc084c715282c452b41127e2d093 Mon Sep 17 00:00:00 2001
From: Andrew Choi
Date: Wed, 16 Apr 2025 15:28:17 -0700
Subject: [PATCH] Add torch2.6 support for ms_deform_attn_cuda

---
 .../csrc/MsDeformAttn/ms_deform_attn_cuda.cu  | 49 ++++++++++++-------
 1 file changed, 31 insertions(+), 18 deletions(-)

diff --git a/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.cu b/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.cu
index d04fae8..e65670b 100644
--- a/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.cu
+++ b/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.cu
@@ -15,11 +15,24 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <cuda.h>
 #include <cuda_runtime.h>
+#include <torch/version.h>
+#include <torch/extension.h>
+
+// Check PyTorch version and define appropriate macros
+#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 6
+  // PyTorch 2.x and above
+  #define GET_TENSOR_TYPE(x) x.scalar_type()
+  #define IS_CUDA_TENSOR(x) x.device().is_cuda()
+#else
+  // PyTorch 1.x
+  #define GET_TENSOR_TYPE(x) x.type()
+  #define IS_CUDA_TENSOR(x) x.type().is_cuda()
+#endif
 
 namespace groundingdino {
 
 at::Tensor ms_deform_attn_cuda_forward(
-    const at::Tensor &value, 
+    const at::Tensor &value,
     const at::Tensor &spatial_shapes,
     const at::Tensor &level_start_index,
     const at::Tensor &sampling_loc,
@@ -32,11 +45,11 @@
     AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
     AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
 
-    AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
-    AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
-    AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
-    AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
-    AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
+    AT_ASSERTM(IS_CUDA_TENSOR(value), "value must be a CUDA tensor");
+    AT_ASSERTM(IS_CUDA_TENSOR(spatial_shapes), "spatial_shapes must be a CUDA tensor");
+    AT_ASSERTM(IS_CUDA_TENSOR(level_start_index), "level_start_index must be a CUDA tensor");
+    AT_ASSERTM(IS_CUDA_TENSOR(sampling_loc), "sampling_loc must be a CUDA tensor");
+    AT_ASSERTM(IS_CUDA_TENSOR(attn_weight), "attn_weight must be a CUDA tensor");
 
     const int batch = value.size(0);
     const int spatial_size = value.size(1);
@@ -51,7 +64,7 @@
     const int im2col_step_ = std::min(batch, im2col_step);
 
     AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
-    
+
     auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
 
     const int batch_n = im2col_step_;
@@ -62,7 +75,7 @@
     for (int n = 0; n < batch/im2col_step_; ++n)
     {
         auto columns = output_n.select(0, n);
-        AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
+        AT_DISPATCH_FLOATING_TYPES(GET_TENSOR_TYPE(value), "ms_deform_attn_forward_cuda", ([&] {
             ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
                 value.data<scalar_t>() + n * im2col_step_ * per_value_size,
                 spatial_shapes.data<int64_t>(),
@@ -82,7 +95,7 @@
 
 
 std::vector<at::Tensor> ms_deform_attn_cuda_backward(
-    const at::Tensor &value, 
+    const at::Tensor &value,
     const at::Tensor &spatial_shapes,
     const at::Tensor &level_start_index,
     const at::Tensor &sampling_loc,
@@ -98,12 +111,12 @@
     AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
     AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
 
-    AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
-    AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
-    AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
-    AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
-    AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
-    AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
+    AT_ASSERTM(IS_CUDA_TENSOR(value), "value must be a CUDA tensor");
+    AT_ASSERTM(IS_CUDA_TENSOR(spatial_shapes), "spatial_shapes must be a CUDA tensor");
+    AT_ASSERTM(IS_CUDA_TENSOR(level_start_index), "level_start_index must be a CUDA tensor");
+    AT_ASSERTM(IS_CUDA_TENSOR(sampling_loc), "sampling_loc must be a CUDA tensor");
+    AT_ASSERTM(IS_CUDA_TENSOR(attn_weight), "attn_weight must be a CUDA tensor");
+    AT_ASSERTM(IS_CUDA_TENSOR(grad_output), "grad_output must be a CUDA tensor");
 
     const int batch = value.size(0);
     const int spatial_size = value.size(1);
@@ -128,11 +141,11 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
     auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
     auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
     auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
-    
+
     for (int n = 0; n < batch/im2col_step_; ++n)
     {
         auto grad_output_g = grad_output_n.select(0, n);
-        AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
+        AT_DISPATCH_FLOATING_TYPES(GET_TENSOR_TYPE(value), "ms_deform_attn_backward_cuda", ([&] {
             ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
                                     grad_output_g.data<scalar_t>(),
                                     value.data<scalar_t>() + n * im2col_step_ * per_value_size,
@@ -153,4 +166,4 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
     };
 }
 
-} // namespace groundingdino
\ No newline at end of file
+} // namespace groundingdino
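
Note: the version gate in this patch relies on the TORCH_VERSION_MAJOR / TORCH_VERSION_MINOR macros that LibTorch ships, typically exposed through <torch/version.h>. The snippet below is a minimal, self-contained sketch of the same pattern and is not part of the patch: assert_cuda_float is a hypothetical helper, the exact include paths are an assumption about a standard LibTorch installation, and a CUDA-enabled build is only needed at runtime for the small main() demo.

// Illustrative sketch only -- mirrors the version-gated macros from the patch.
// Assumes TORCH_VERSION_MAJOR / TORCH_VERSION_MINOR come from <torch/version.h>
// and that LibTorch headers are on the include path.
#include <type_traits>
#include <ATen/ATen.h>
#include <torch/version.h>

#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 6
  // Newer PyTorch: prefer scalar_type() / device() over the deprecated Tensor::type().
  #define GET_TENSOR_TYPE(x) (x).scalar_type()
  #define IS_CUDA_TENSOR(x)  (x).device().is_cuda()
#else
  // Older PyTorch: the DeprecatedTypeProperties path still compiles.
  #define GET_TENSOR_TYPE(x) (x).type()
  #define IS_CUDA_TENSOR(x)  (x).type().is_cuda()
#endif

// Hypothetical helper showing how the macros are consumed, in the same way the
// patched file feeds them into AT_DISPATCH_FLOATING_TYPES and its CUDA checks.
void assert_cuda_float(const at::Tensor &t) {
  TORCH_CHECK(IS_CUDA_TENSOR(t), "t must be a CUDA tensor");
  AT_DISPATCH_FLOATING_TYPES(GET_TENSOR_TYPE(t), "assert_cuda_float", ([&] {
    // scalar_t is supplied by the dispatch macro for the resolved dtype;
    // a real kernel launch would go here.
    static_assert(std::is_floating_point<scalar_t>::value,
                  "dispatch resolved to a floating-point type");
  }));
}

int main() {
  // Only exercise the CUDA assertion when a CUDA device is actually available.
  if (at::hasCUDA()) {
    assert_cuda_float(at::rand({4}, at::TensorOptions().device(at::kCUDA)));
  }
  return 0;
}

Gating the API choice at compile time like this is what lets the single ms_deform_attn_cuda.cu source build against both pre-2.6 and 2.6+ LibTorch tensor APIs.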