diff --git a/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.cu b/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.cu
index d04fae8..f09bd24 100644
--- a/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.cu
+++ b/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.cu
@@ -15,6 +15,7 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <cuda.h>
 #include <cuda_runtime.h>
+#include <cuda_fp16.h>
 
 namespace groundingdino {
 
@@ -26,17 +27,17 @@ at::Tensor ms_deform_attn_cuda_forward(
     const at::Tensor &attn_weight,
     const int im2col_step)
 {
-    AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
-    AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
-    AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
-    AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
-    AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
+    TORCH_CHECK(value.is_contiguous(), "value tensor has to be contiguous");
+    TORCH_CHECK(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
+    TORCH_CHECK(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
+    TORCH_CHECK(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
+    TORCH_CHECK(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
 
-    AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
-    AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
-    AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
-    AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
-    AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
+    TORCH_CHECK(value.device().type() == at::kCUDA, "value must be a CUDA tensor");
+    TORCH_CHECK(spatial_shapes.device().type() == at::kCUDA, "spatial_shapes must be a CUDA tensor");
+    TORCH_CHECK(level_start_index.device().type() == at::kCUDA, "level_start_index must be a CUDA tensor");
+    TORCH_CHECK(sampling_loc.device().type() == at::kCUDA, "sampling_loc must be a CUDA tensor");
+    TORCH_CHECK(attn_weight.device().type() == at::kCUDA, "attn_weight must be a CUDA tensor");
 
     const int batch = value.size(0);
     const int spatial_size = value.size(1);
@@ -50,7 +51,7 @@ at::Tensor ms_deform_attn_cuda_forward(
 
     const int im2col_step_ = std::min(batch, im2col_step);
 
-    AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
+    TORCH_CHECK(batch % im2col_step_ == 0, "batch(", batch, ") must divide im2col_step(", im2col_step_, ")");
 
     auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
 
@@ -59,19 +60,19 @@ at::Tensor ms_deform_attn_cuda_forward(
     auto per_value_size = spatial_size * num_heads * channels;
     auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
     auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
+
     for (int n = 0; n < batch/im2col_step_; ++n)
     {
         auto columns = output_n.select(0, n);
-        AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
+        AT_DISPATCH_FLOATING_TYPES_AND_HALF(value.scalar_type(), "ms_deform_attn_forward_cuda", ([&] {
             ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
-                value.data<scalar_t>() + n * im2col_step_ * per_value_size,
-                spatial_shapes.data<int64_t>(),
-                level_start_index.data<int64_t>(),
-                sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
-                attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
+                value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
+                spatial_shapes.data_ptr<int64_t>(),
+                level_start_index.data_ptr<int64_t>(),
+                sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+                attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
                 batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
-                columns.data<scalar_t>());
-
+                columns.data_ptr<scalar_t>());
         }));
     }
 
@@ -80,7 +81,6 @@ at::Tensor ms_deform_attn_cuda_forward(
     return output;
 }
 
-
 std::vector<at::Tensor> ms_deform_attn_cuda_backward(
     const at::Tensor &value,
     const at::Tensor &spatial_shapes,
@@ -90,20 +90,19 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
     const at::Tensor &grad_output,
     const int im2col_step)
 {
+    TORCH_CHECK(value.is_contiguous(), "value tensor has to be contiguous");
+    TORCH_CHECK(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
+    TORCH_CHECK(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
+    TORCH_CHECK(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
+    TORCH_CHECK(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
+    TORCH_CHECK(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
 
-    AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
-    AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
-    AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
-    AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
-    AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
-    AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
-
-    AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
-    AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
-    AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
-    AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
-    AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
-    AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
+    TORCH_CHECK(value.device().type() == at::kCUDA, "value must be a CUDA tensor");
+    TORCH_CHECK(spatial_shapes.device().type() == at::kCUDA, "spatial_shapes must be a CUDA tensor");
+    TORCH_CHECK(level_start_index.device().type() == at::kCUDA, "level_start_index must be a CUDA tensor");
+    TORCH_CHECK(sampling_loc.device().type() == at::kCUDA, "sampling_loc must be a CUDA tensor");
+    TORCH_CHECK(attn_weight.device().type() == at::kCUDA, "attn_weight must be a CUDA tensor");
+    TORCH_CHECK(grad_output.device().type() == at::kCUDA, "grad_output must be a CUDA tensor");
 
     const int batch = value.size(0);
     const int spatial_size = value.size(1);
@@ -117,7 +116,7 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
 
     const int im2col_step_ = std::min(batch, im2col_step);
 
-    AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
+    TORCH_CHECK(batch % im2col_step_ == 0, "batch(", batch, ") must divide im2col_step(", im2col_step_, ")");
 
     auto grad_value = at::zeros_like(value);
     auto grad_sampling_loc = at::zeros_like(sampling_loc);
@@ -132,19 +131,18 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
     for (int n = 0; n < batch/im2col_step_; ++n)
     {
         auto grad_output_g = grad_output_n.select(0, n);
-        AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
+        AT_DISPATCH_FLOATING_TYPES_AND_HALF(value.scalar_type(), "ms_deform_attn_backward_cuda", ([&] {
             ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
-                grad_output_g.data<scalar_t>(),
-                value.data<scalar_t>() + n * im2col_step_ * per_value_size,
-                spatial_shapes.data<int64_t>(),
-                level_start_index.data<int64_t>(),
-                sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
-                attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
+                grad_output_g.data_ptr<scalar_t>(),
+                value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
+                spatial_shapes.data_ptr<int64_t>(),
+                level_start_index.data_ptr<int64_t>(),
+                sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+                attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
                 batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
-                grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
-                grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
-                grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
-
+                grad_value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
+                grad_sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+                grad_attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
         }));
     }
 
@@ -153,4 +151,4 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
     };
 }
 
-} // namespace groundingdino
\ No newline at end of file
+} // namespace groundingdino
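--
Note (not part of the patch): the sketch below is a minimal, hypothetical extension op illustrating the idioms this diff migrates to — TORCH_CHECK with streamed message arguments (unlike printf, it does not interpret "%d" specifiers; trailing values are concatenated into the message), device checks via Tensor::device() instead of the deprecated type().is_cuda(), the typed data_ptr<scalar_t>() accessor replacing the removed data<scalar_t>(), and AT_DISPATCH_FLOATING_TYPES_AND_HALF to additionally instantiate kernels for at::Half. The names demo_scale and scale_kernel are invented for illustration.

// demo_scale.cu -- illustrative sketch, assuming a standard PyTorch CUDA extension build
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAContext.h>

template <typename scalar_t>
__global__ void scale_kernel(const scalar_t *in, scalar_t *out, int64_t n, scalar_t factor)
{
    const int64_t i = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
    if (i < n)
        out[i] = in[i] * factor;  // at::Half provides device-side arithmetic operators
}

at::Tensor demo_scale(const at::Tensor &input, double factor)
{
    // TORCH_CHECK streams its trailing arguments into the error message,
    // so runtime values are appended directly rather than formatted with "%d".
    TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
    TORCH_CHECK(input.device().type() == at::kCUDA, "input must be a CUDA tensor");

    auto output = at::empty_like(input);
    const int64_t n = input.numel();
    const int threads = 256;
    const int blocks = static_cast<int>((n + threads - 1) / threads);

    // The AND_HALF variant also instantiates the lambda for at::Half,
    // which is what lets the patched kernels accept fp16 inputs.
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "demo_scale", ([&] {
        scale_kernel<scalar_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
            input.data_ptr<scalar_t>(),   // typed accessor; Tensor::data<T>() is removed
            output.data_ptr<scalar_t>(),
            n, static_cast<scalar_t>(factor));
    }));
    return output;
}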