mirror of https://github.com/open-mmlab/mmcv.git
[Feature]: support Multi-Scale-DeformAttention in deformable-detr (#878)
* add c++ ms_deform_atten * fix cpp lint * fix cpp lint * clang format * remove cmakefile * google style * clang-format precommit * use clang-format-lint-action * add transformer base class * add merge * add docstr * add pyargs * fix according to commments * resiger module * change to use basemodule * add _ between build function * split the name * fix according to comments * fix lint and fix unitest * fix cpp lint * fix bug of deformdetr_atten * fix drop out * fix residual * use CUDA_1D_KERNEL_LOOPpull/981/head
parent
0dd0c49a5b
commit
54a7ebb4ec
|
@ -1,11 +1,15 @@
|
|||
import copy
|
||||
import math
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from mmcv import ConfigDict
|
||||
from mmcv.cnn import Linear, build_activation_layer, build_norm_layer
|
||||
from mmcv.cnn import (Linear, build_activation_layer, build_norm_layer,
|
||||
constant_init, xavier_init)
|
||||
from mmcv.ops.multi_scale_deform_attn import (
|
||||
MultiScaleDeformableAttnFunction, multi_scale_deformable_attn_pytorch)
|
||||
from mmcv.runner.base_module import BaseModule
|
||||
from mmcv.utils import build_from_cfg
|
||||
from .registry import (ATTENTION, POSITIONAL_ENCODING, TRANSFORMER_LAYER,
|
||||
|
@ -135,6 +139,201 @@ class MultiheadAttention(BaseModule):
|
|||
return residual + self.dropout(out)
|
||||
|
||||
|
||||
@ATTENTION.register_module()
|
||||
class MultiScaleDeformableAttention(BaseModule):
|
||||
"""An attention module used in Deformable-Detr. `Deformable DETR:
|
||||
Deformable Transformers for End-to-End Object Detection.
|
||||
|
||||
<https://arxiv.org/pdf/2010.04159.pdf>`_.
|
||||
|
||||
Args:
|
||||
embed_dims (int): The embedding dimension of Attention.
|
||||
Default: 256.
|
||||
num_heads (int): Parallel attention heads. Default: 64.
|
||||
num_levels (int): The number of feature map used in
|
||||
Attention. Default: 4.
|
||||
num_points (int): The number of sampling points for
|
||||
each query in each head. Default: 4.
|
||||
im2col_step (int): The step used in image_to_column.
|
||||
Default: 64.
|
||||
dropout (float): A Dropout layer on `inp_residual`.
|
||||
Default: 0..
|
||||
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims=256,
|
||||
num_heads=8,
|
||||
num_levels=4,
|
||||
num_points=4,
|
||||
im2col_step=64,
|
||||
dropout=0.1,
|
||||
norm_cfg=None,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg)
|
||||
if embed_dims % num_heads != 0:
|
||||
raise ValueError(f'embed_dims must be divisible by num_heads, '
|
||||
f'but got {embed_dims} and {num_heads}')
|
||||
dim_per_head = embed_dims // num_heads
|
||||
self.norm_cfg = norm_cfg
|
||||
self.init_cfg = init_cfg
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
# you'd better set dim_per_head to a power of 2
|
||||
# which is more efficient in the CUDA implementation
|
||||
def _is_power_of_2(n):
|
||||
if (not isinstance(n, int)) or (n < 0):
|
||||
raise ValueError(
|
||||
'invalid input for _is_power_of_2: {} (type: {})'.format(
|
||||
n, type(n)))
|
||||
return (n & (n - 1) == 0) and n != 0
|
||||
|
||||
if not _is_power_of_2(dim_per_head):
|
||||
warnings.warn(
|
||||
"You'd better set embed_dims in "
|
||||
'MultiScaleDeformAttention to make '
|
||||
'the dimension of each attention head a power of 2 '
|
||||
'which is more efficient in our CUDA implementation.')
|
||||
|
||||
self.im2col_step = im2col_step
|
||||
self.embed_dims = embed_dims
|
||||
self.num_levels = num_levels
|
||||
self.num_heads = num_heads
|
||||
self.num_points = num_points
|
||||
self.sampling_offsets = nn.Linear(
|
||||
embed_dims, num_heads * num_levels * num_points * 2)
|
||||
self.attention_weights = nn.Linear(embed_dims,
|
||||
num_heads * num_levels * num_points)
|
||||
self.value_proj = nn.Linear(embed_dims, embed_dims)
|
||||
self.output_proj = nn.Linear(embed_dims, embed_dims)
|
||||
self.init_weight()
|
||||
|
||||
def init_weight(self):
|
||||
"""Default initialization for Parameters of Module."""
|
||||
constant_init(self.sampling_offsets, 0.)
|
||||
thetas = torch.arange(
|
||||
self.num_heads,
|
||||
dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
|
||||
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
|
||||
grid_init = (grid_init /
|
||||
grid_init.abs().max(-1, keepdim=True)[0]).view(
|
||||
self.num_heads, 1, 1,
|
||||
2).repeat(1, self.num_levels, self.num_points, 1)
|
||||
for i in range(self.num_points):
|
||||
grid_init[:, :, i, :] *= i + 1
|
||||
|
||||
self.sampling_offsets.bias.data = grid_init.view(-1)
|
||||
constant_init(self.attention_weights, val=0., bias=0.)
|
||||
xavier_init(self.value_proj, distribution='uniform', bias=0.)
|
||||
xavier_init(self.output_proj, distribution='uniform', bias=0.)
|
||||
|
||||
def forward(self,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
residual=None,
|
||||
query_pos=None,
|
||||
key_padding_mask=None,
|
||||
reference_points=None,
|
||||
spatial_shapes=None,
|
||||
level_start_index=None,
|
||||
**kwargs):
|
||||
"""Forward Function of MultiScaleDeformAttention.
|
||||
|
||||
Args:
|
||||
query (Tensor): Query of Transformer with shape
|
||||
(num_query, bs, embed_dims).
|
||||
key (Tensor): The key tensor with shape
|
||||
`(num_key, bs, embed_dims)`.
|
||||
value (Tensor): The value tensor with shape
|
||||
`(num_key, bs, embed_dims)`.
|
||||
residual (Tensor): The tensor used for addition, with the
|
||||
same shape as `x`. Default None. If None, `x` will be used.
|
||||
query_pos (Tensor): The positional encoding for `query`.
|
||||
Default: None.
|
||||
key_pos (Tensor): The positional encoding for `key`. Default
|
||||
None.
|
||||
reference_points (Tensor): The normalized reference
|
||||
points with shape (bs, num_query, num_levels, 2),
|
||||
all elements is range in [0, 1], top-left (0,0),
|
||||
bottom-right (1, 1), including padding area.
|
||||
or (N, Length_{query}, num_levels, 4), add
|
||||
additional two dimensions is (w, h) to
|
||||
form reference boxes.
|
||||
key_padding_mask (Tensor): ByteTensor for `query`, with
|
||||
shape [bs, num_key].
|
||||
spatial_shapes (Tensor): Spatial shape of features in
|
||||
different level. With shape (num_levels, 2),
|
||||
last dimension represent (h, w).
|
||||
level_start_index (Tensor): The start index of each level.
|
||||
A tensor has shape (num_levels) and can be represented
|
||||
as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
|
||||
|
||||
Returns:
|
||||
Tensor: forwarded results with shape [num_query, bs, embed_dims].
|
||||
"""
|
||||
|
||||
if key is None:
|
||||
key = query
|
||||
if value is None:
|
||||
value = key
|
||||
|
||||
if residual is None:
|
||||
inp_residual = query
|
||||
if query_pos is not None:
|
||||
query = query + query_pos
|
||||
|
||||
# change to (bs, num_query ,embed_dims)
|
||||
query = query.permute(1, 0, 2)
|
||||
value = value.permute(1, 0, 2)
|
||||
|
||||
bs, num_query, _ = query.shape
|
||||
bs, num_key, _ = value.shape
|
||||
assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_key
|
||||
|
||||
value = self.value_proj(value)
|
||||
if key_padding_mask is not None:
|
||||
value = value.masked_fill(key_padding_mask[..., None], 0.0)
|
||||
value = value.view(bs, num_key, self.num_heads, -1)
|
||||
sampling_offsets = self.sampling_offsets(query).view(
|
||||
bs, num_query, self.num_heads, self.num_levels, self.num_points, 2)
|
||||
attention_weights = self.attention_weights(query).view(
|
||||
bs, num_query, self.num_heads, self.num_levels * self.num_points)
|
||||
attention_weights = attention_weights.softmax(-1)
|
||||
|
||||
attention_weights = attention_weights.view(bs, num_query,
|
||||
self.num_heads,
|
||||
self.num_levels,
|
||||
self.num_points)
|
||||
if reference_points.shape[-1] == 2:
|
||||
offset_normalizer = torch.stack(
|
||||
[spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
|
||||
sampling_locations = reference_points[:, :, None, :, None, :] \
|
||||
+ sampling_offsets \
|
||||
/ offset_normalizer[None, None, None, :, None, :]
|
||||
elif reference_points.shape[-1] == 4:
|
||||
sampling_locations = reference_points[:, :, None, :, None, :2] \
|
||||
+ sampling_offsets / self.num_points \
|
||||
* reference_points[:, :, None, :, None, 2:] \
|
||||
* 0.5
|
||||
else:
|
||||
raise ValueError(
|
||||
f'Last dim of reference_points must be'
|
||||
f' 2 or 4, but get {reference_points.shape[-1]} instead.')
|
||||
if torch.cuda.is_available():
|
||||
output = MultiScaleDeformableAttnFunction.apply(
|
||||
value, spatial_shapes, level_start_index, sampling_locations,
|
||||
attention_weights, self.im2col_step)
|
||||
else:
|
||||
output = multi_scale_deformable_attn_pytorch(
|
||||
value, spatial_shapes, level_start_index, sampling_locations,
|
||||
attention_weights, self.im2col_step)
|
||||
output = self.output_proj(output).permute(1, 0, 2)
|
||||
# (num_query, bs ,embed_dims)
|
||||
return self.dropout(output) + inp_residual
|
||||
|
||||
|
||||
class FFN(BaseModule):
|
||||
"""Implements feed-forward networks (FFNs) with residual connection.
|
||||
|
||||
|
|
|
@ -0,0 +1,807 @@
|
|||
/*!
|
||||
**************************************************************************************************
|
||||
* Deformable DETR
|
||||
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
||||
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||
**************************************************************************************************
|
||||
* Modified from
|
||||
*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
||||
**************************************************************************************************
|
||||
*/
|
||||
#ifndef DEFORM_ATTN_CUDA_KERNEL
|
||||
#define DEFORM_ATTN_CUDA_KERNEL
|
||||
|
||||
#include "common_cuda_helper.hpp"
|
||||
#include "pytorch_cuda_helper.hpp"
|
||||
|
||||
const int CUDA_NUM_THREADS = 1024;
|
||||
inline int GET_BLOCKS(const int N, const int num_threads) {
|
||||
return (N + num_threads - 1) / num_threads;
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__device__ scalar_t ms_deform_attn_im2col_bilinear(
|
||||
const scalar_t *&bottom_data, const int &height, const int &width,
|
||||
const int &nheads, const int &channels, const scalar_t &h,
|
||||
const scalar_t &w, const int &m, const int &c) {
|
||||
const int h_low = floor(h);
|
||||
const int w_low = floor(w);
|
||||
const int h_high = h_low + 1;
|
||||
const int w_high = w_low + 1;
|
||||
|
||||
const scalar_t lh = h - h_low;
|
||||
const scalar_t lw = w - w_low;
|
||||
const scalar_t hh = 1 - lh, hw = 1 - lw;
|
||||
|
||||
const int w_stride = nheads * channels;
|
||||
const int h_stride = width * w_stride;
|
||||
const int h_low_ptr_offset = h_low * h_stride;
|
||||
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
|
||||
const int w_low_ptr_offset = w_low * w_stride;
|
||||
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
|
||||
const int base_ptr = m * channels + c;
|
||||
|
||||
scalar_t v1 = 0;
|
||||
if (h_low >= 0 && w_low >= 0) {
|
||||
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
|
||||
v1 = bottom_data[ptr1];
|
||||
}
|
||||
scalar_t v2 = 0;
|
||||
if (h_low >= 0 && w_high <= width - 1) {
|
||||
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
|
||||
v2 = bottom_data[ptr2];
|
||||
}
|
||||
scalar_t v3 = 0;
|
||||
if (h_high <= height - 1 && w_low >= 0) {
|
||||
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
|
||||
v3 = bottom_data[ptr3];
|
||||
}
|
||||
scalar_t v4 = 0;
|
||||
if (h_high <= height - 1 && w_high <= width - 1) {
|
||||
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
|
||||
v4 = bottom_data[ptr4];
|
||||
}
|
||||
|
||||
const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
|
||||
|
||||
const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
|
||||
return val;
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__device__ void ms_deform_attn_col2im_bilinear(
|
||||
const scalar_t *&bottom_data, const int &height, const int &width,
|
||||
const int &nheads, const int &channels, const scalar_t &h,
|
||||
const scalar_t &w, const int &m, const int &c, const scalar_t &top_grad,
|
||||
const scalar_t &attn_weight, scalar_t *&grad_value,
|
||||
scalar_t *grad_sampling_loc, scalar_t *grad_attn_weight) {
|
||||
const int h_low = floor(h);
|
||||
const int w_low = floor(w);
|
||||
const int h_high = h_low + 1;
|
||||
const int w_high = w_low + 1;
|
||||
|
||||
const scalar_t lh = h - h_low;
|
||||
const scalar_t lw = w - w_low;
|
||||
const scalar_t hh = 1 - lh, hw = 1 - lw;
|
||||
|
||||
const int w_stride = nheads * channels;
|
||||
const int h_stride = width * w_stride;
|
||||
const int h_low_ptr_offset = h_low * h_stride;
|
||||
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
|
||||
const int w_low_ptr_offset = w_low * w_stride;
|
||||
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
|
||||
const int base_ptr = m * channels + c;
|
||||
|
||||
const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
|
||||
const scalar_t top_grad_value = top_grad * attn_weight;
|
||||
scalar_t grad_h_weight = 0, grad_w_weight = 0;
|
||||
|
||||
scalar_t v1 = 0;
|
||||
if (h_low >= 0 && w_low >= 0) {
|
||||
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
|
||||
v1 = bottom_data[ptr1];
|
||||
grad_h_weight -= hw * v1;
|
||||
grad_w_weight -= hh * v1;
|
||||
atomicAdd(grad_value + ptr1, w1 * top_grad_value);
|
||||
}
|
||||
scalar_t v2 = 0;
|
||||
if (h_low >= 0 && w_high <= width - 1) {
|
||||
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
|
||||
v2 = bottom_data[ptr2];
|
||||
grad_h_weight -= lw * v2;
|
||||
grad_w_weight += hh * v2;
|
||||
atomicAdd(grad_value + ptr2, w2 * top_grad_value);
|
||||
}
|
||||
scalar_t v3 = 0;
|
||||
if (h_high <= height - 1 && w_low >= 0) {
|
||||
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
|
||||
v3 = bottom_data[ptr3];
|
||||
grad_h_weight += hw * v3;
|
||||
grad_w_weight -= lh * v3;
|
||||
atomicAdd(grad_value + ptr3, w3 * top_grad_value);
|
||||
}
|
||||
scalar_t v4 = 0;
|
||||
if (h_high <= height - 1 && w_high <= width - 1) {
|
||||
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
|
||||
v4 = bottom_data[ptr4];
|
||||
grad_h_weight += lw * v4;
|
||||
grad_w_weight += lh * v4;
|
||||
atomicAdd(grad_value + ptr4, w4 * top_grad_value);
|
||||
}
|
||||
|
||||
const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
|
||||
*grad_attn_weight = top_grad * val;
|
||||
*grad_sampling_loc = width * grad_w_weight * top_grad_value;
|
||||
*(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__device__ void ms_deform_attn_col2im_bilinear_gm(
|
||||
const scalar_t *&bottom_data, const int &height, const int &width,
|
||||
const int &nheads, const int &channels, const scalar_t &h,
|
||||
const scalar_t &w, const int &m, const int &c, const scalar_t &top_grad,
|
||||
const scalar_t &attn_weight, scalar_t *&grad_value,
|
||||
scalar_t *grad_sampling_loc, scalar_t *grad_attn_weight) {
|
||||
const int h_low = floor(h);
|
||||
const int w_low = floor(w);
|
||||
const int h_high = h_low + 1;
|
||||
const int w_high = w_low + 1;
|
||||
|
||||
const scalar_t lh = h - h_low;
|
||||
const scalar_t lw = w - w_low;
|
||||
const scalar_t hh = 1 - lh, hw = 1 - lw;
|
||||
|
||||
const int w_stride = nheads * channels;
|
||||
const int h_stride = width * w_stride;
|
||||
const int h_low_ptr_offset = h_low * h_stride;
|
||||
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
|
||||
const int w_low_ptr_offset = w_low * w_stride;
|
||||
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
|
||||
const int base_ptr = m * channels + c;
|
||||
|
||||
const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
|
||||
const scalar_t top_grad_value = top_grad * attn_weight;
|
||||
scalar_t grad_h_weight = 0, grad_w_weight = 0;
|
||||
|
||||
scalar_t v1 = 0;
|
||||
if (h_low >= 0 && w_low >= 0) {
|
||||
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
|
||||
v1 = bottom_data[ptr1];
|
||||
grad_h_weight -= hw * v1;
|
||||
grad_w_weight -= hh * v1;
|
||||
atomicAdd(grad_value + ptr1, w1 * top_grad_value);
|
||||
}
|
||||
scalar_t v2 = 0;
|
||||
if (h_low >= 0 && w_high <= width - 1) {
|
||||
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
|
||||
v2 = bottom_data[ptr2];
|
||||
grad_h_weight -= lw * v2;
|
||||
grad_w_weight += hh * v2;
|
||||
atomicAdd(grad_value + ptr2, w2 * top_grad_value);
|
||||
}
|
||||
scalar_t v3 = 0;
|
||||
if (h_high <= height - 1 && w_low >= 0) {
|
||||
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
|
||||
v3 = bottom_data[ptr3];
|
||||
grad_h_weight += hw * v3;
|
||||
grad_w_weight -= lh * v3;
|
||||
atomicAdd(grad_value + ptr3, w3 * top_grad_value);
|
||||
}
|
||||
scalar_t v4 = 0;
|
||||
if (h_high <= height - 1 && w_high <= width - 1) {
|
||||
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
|
||||
v4 = bottom_data[ptr4];
|
||||
grad_h_weight += lw * v4;
|
||||
grad_w_weight += lh * v4;
|
||||
atomicAdd(grad_value + ptr4, w4 * top_grad_value);
|
||||
}
|
||||
|
||||
const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
|
||||
atomicAdd(grad_attn_weight, top_grad * val);
|
||||
atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
|
||||
atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void ms_deformable_im2col_gpu_kernel(
|
||||
const int n, const scalar_t *data_value, const int64_t *data_spatial_shapes,
|
||||
const int64_t *data_level_start_index, const scalar_t *data_sampling_loc,
|
||||
const scalar_t *data_attn_weight, const int batch_size,
|
||||
const int spatial_size, const int num_heads, const int channels,
|
||||
const int num_levels, const int num_query, const int num_point,
|
||||
scalar_t *data_col) {
|
||||
CUDA_1D_KERNEL_LOOP(index, n) {
|
||||
int _temp = index;
|
||||
const int c_col = _temp % channels;
|
||||
_temp /= channels;
|
||||
const int sampling_index = _temp;
|
||||
const int m_col = _temp % num_heads;
|
||||
_temp /= num_heads;
|
||||
const int q_col = _temp % num_query;
|
||||
_temp /= num_query;
|
||||
const int b_col = _temp;
|
||||
|
||||
scalar_t *data_col_ptr = data_col + index;
|
||||
int data_weight_ptr = sampling_index * num_levels * num_point;
|
||||
int data_loc_w_ptr = data_weight_ptr << 1;
|
||||
const int qid_stride = num_heads * channels;
|
||||
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
|
||||
scalar_t col = 0;
|
||||
|
||||
for (int l_col = 0; l_col < num_levels; ++l_col) {
|
||||
const int level_start_id = data_level_start_index[l_col];
|
||||
const int spatial_h_ptr = l_col << 1;
|
||||
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
|
||||
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
|
||||
const scalar_t *data_value_ptr =
|
||||
data_value +
|
||||
(data_value_ptr_init_offset + level_start_id * qid_stride);
|
||||
for (int p_col = 0; p_col < num_point; ++p_col) {
|
||||
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
|
||||
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
|
||||
const scalar_t weight = data_attn_weight[data_weight_ptr];
|
||||
|
||||
const scalar_t h_im = loc_h * spatial_h - 0.5;
|
||||
const scalar_t w_im = loc_w * spatial_w - 0.5;
|
||||
|
||||
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
|
||||
col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h,
|
||||
spatial_w, num_heads, channels,
|
||||
h_im, w_im, m_col, c_col) *
|
||||
weight;
|
||||
}
|
||||
|
||||
data_weight_ptr += 1;
|
||||
data_loc_w_ptr += 2;
|
||||
}
|
||||
}
|
||||
*data_col_ptr = col;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, unsigned int blockSize>
|
||||
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(
|
||||
const int n, const scalar_t *grad_col, const scalar_t *data_value,
|
||||
const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
|
||||
const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight,
|
||||
const int batch_size, const int spatial_size, const int num_heads,
|
||||
const int channels, const int num_levels, const int num_query,
|
||||
const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
|
||||
scalar_t *grad_attn_weight) {
|
||||
CUDA_1D_KERNEL_LOOP(index, n) {
|
||||
__shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
|
||||
__shared__ scalar_t cache_grad_attn_weight[blockSize];
|
||||
unsigned int tid = threadIdx.x;
|
||||
int _temp = index;
|
||||
const int c_col = _temp % channels;
|
||||
_temp /= channels;
|
||||
const int sampling_index = _temp;
|
||||
const int m_col = _temp % num_heads;
|
||||
_temp /= num_heads;
|
||||
const int q_col = _temp % num_query;
|
||||
_temp /= num_query;
|
||||
const int b_col = _temp;
|
||||
|
||||
const scalar_t top_grad = grad_col[index];
|
||||
|
||||
int data_weight_ptr = sampling_index * num_levels * num_point;
|
||||
int data_loc_w_ptr = data_weight_ptr << 1;
|
||||
const int grad_sampling_ptr = data_weight_ptr;
|
||||
grad_sampling_loc += grad_sampling_ptr << 1;
|
||||
grad_attn_weight += grad_sampling_ptr;
|
||||
const int grad_weight_stride = 1;
|
||||
const int grad_loc_stride = 2;
|
||||
const int qid_stride = num_heads * channels;
|
||||
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
|
||||
|
||||
for (int l_col = 0; l_col < num_levels; ++l_col) {
|
||||
const int level_start_id = data_level_start_index[l_col];
|
||||
const int spatial_h_ptr = l_col << 1;
|
||||
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
|
||||
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
|
||||
const int value_ptr_offset =
|
||||
data_value_ptr_init_offset + level_start_id * qid_stride;
|
||||
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
|
||||
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
|
||||
|
||||
for (int p_col = 0; p_col < num_point; ++p_col) {
|
||||
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
|
||||
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
|
||||
const scalar_t weight = data_attn_weight[data_weight_ptr];
|
||||
|
||||
const scalar_t h_im = loc_h * spatial_h - 0.5;
|
||||
const scalar_t w_im = loc_w * spatial_w - 0.5;
|
||||
*(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
|
||||
*(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
|
||||
*(cache_grad_attn_weight + threadIdx.x) = 0;
|
||||
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
|
||||
ms_deform_attn_col2im_bilinear(
|
||||
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
|
||||
w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
|
||||
cache_grad_sampling_loc + (threadIdx.x << 1),
|
||||
cache_grad_attn_weight + threadIdx.x);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
if (tid == 0) {
|
||||
scalar_t _grad_w = cache_grad_sampling_loc[0],
|
||||
_grad_h = cache_grad_sampling_loc[1],
|
||||
_grad_a = cache_grad_attn_weight[0];
|
||||
int sid = 2;
|
||||
for (unsigned int tid = 1; tid < blockSize; ++tid) {
|
||||
_grad_w += cache_grad_sampling_loc[sid];
|
||||
_grad_h += cache_grad_sampling_loc[sid + 1];
|
||||
_grad_a += cache_grad_attn_weight[tid];
|
||||
sid += 2;
|
||||
}
|
||||
|
||||
*grad_sampling_loc = _grad_w;
|
||||
*(grad_sampling_loc + 1) = _grad_h;
|
||||
*grad_attn_weight = _grad_a;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
data_weight_ptr += 1;
|
||||
data_loc_w_ptr += 2;
|
||||
grad_attn_weight += grad_weight_stride;
|
||||
grad_sampling_loc += grad_loc_stride;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, unsigned int blockSize>
|
||||
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(
|
||||
const int n, const scalar_t *grad_col, const scalar_t *data_value,
|
||||
const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
|
||||
const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight,
|
||||
const int batch_size, const int spatial_size, const int num_heads,
|
||||
const int channels, const int num_levels, const int num_query,
|
||||
const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
|
||||
scalar_t *grad_attn_weight) {
|
||||
CUDA_1D_KERNEL_LOOP(index, n) {
|
||||
__shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
|
||||
__shared__ scalar_t cache_grad_attn_weight[blockSize];
|
||||
unsigned int tid = threadIdx.x;
|
||||
int _temp = index;
|
||||
const int c_col = _temp % channels;
|
||||
_temp /= channels;
|
||||
const int sampling_index = _temp;
|
||||
const int m_col = _temp % num_heads;
|
||||
_temp /= num_heads;
|
||||
const int q_col = _temp % num_query;
|
||||
_temp /= num_query;
|
||||
const int b_col = _temp;
|
||||
|
||||
const scalar_t top_grad = grad_col[index];
|
||||
|
||||
int data_weight_ptr = sampling_index * num_levels * num_point;
|
||||
int data_loc_w_ptr = data_weight_ptr << 1;
|
||||
const int grad_sampling_ptr = data_weight_ptr;
|
||||
grad_sampling_loc += grad_sampling_ptr << 1;
|
||||
grad_attn_weight += grad_sampling_ptr;
|
||||
const int grad_weight_stride = 1;
|
||||
const int grad_loc_stride = 2;
|
||||
const int qid_stride = num_heads * channels;
|
||||
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
|
||||
|
||||
for (int l_col = 0; l_col < num_levels; ++l_col) {
|
||||
const int level_start_id = data_level_start_index[l_col];
|
||||
const int spatial_h_ptr = l_col << 1;
|
||||
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
|
||||
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
|
||||
const int value_ptr_offset =
|
||||
data_value_ptr_init_offset + level_start_id * qid_stride;
|
||||
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
|
||||
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
|
||||
|
||||
for (int p_col = 0; p_col < num_point; ++p_col) {
|
||||
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
|
||||
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
|
||||
const scalar_t weight = data_attn_weight[data_weight_ptr];
|
||||
|
||||
const scalar_t h_im = loc_h * spatial_h - 0.5;
|
||||
const scalar_t w_im = loc_w * spatial_w - 0.5;
|
||||
*(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
|
||||
*(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
|
||||
*(cache_grad_attn_weight + threadIdx.x) = 0;
|
||||
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
|
||||
ms_deform_attn_col2im_bilinear(
|
||||
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
|
||||
w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
|
||||
cache_grad_sampling_loc + (threadIdx.x << 1),
|
||||
cache_grad_attn_weight + threadIdx.x);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (unsigned int s = blockSize / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) {
|
||||
const unsigned int xid1 = tid << 1;
|
||||
const unsigned int xid2 = (tid + s) << 1;
|
||||
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
|
||||
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
|
||||
cache_grad_sampling_loc[xid1 + 1] +=
|
||||
cache_grad_sampling_loc[xid2 + 1];
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if (tid == 0) {
|
||||
*grad_sampling_loc = cache_grad_sampling_loc[0];
|
||||
*(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
|
||||
*grad_attn_weight = cache_grad_attn_weight[0];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
data_weight_ptr += 1;
|
||||
data_loc_w_ptr += 2;
|
||||
grad_attn_weight += grad_weight_stride;
|
||||
grad_sampling_loc += grad_loc_stride;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(
|
||||
const int n, const scalar_t *grad_col, const scalar_t *data_value,
|
||||
const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
|
||||
const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight,
|
||||
const int batch_size, const int spatial_size, const int num_heads,
|
||||
const int channels, const int num_levels, const int num_query,
|
||||
const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
|
||||
scalar_t *grad_attn_weight) {
|
||||
CUDA_1D_KERNEL_LOOP(index, n) {
|
||||
extern __shared__ int _s[];
|
||||
scalar_t *cache_grad_sampling_loc = reinterpret_cast<scalar_t *>(_s);
|
||||
scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
|
||||
unsigned int tid = threadIdx.x;
|
||||
int _temp = index;
|
||||
const int c_col = _temp % channels;
|
||||
_temp /= channels;
|
||||
const int sampling_index = _temp;
|
||||
const int m_col = _temp % num_heads;
|
||||
_temp /= num_heads;
|
||||
const int q_col = _temp % num_query;
|
||||
_temp /= num_query;
|
||||
const int b_col = _temp;
|
||||
|
||||
const scalar_t top_grad = grad_col[index];
|
||||
|
||||
int data_weight_ptr = sampling_index * num_levels * num_point;
|
||||
int data_loc_w_ptr = data_weight_ptr << 1;
|
||||
const int grad_sampling_ptr = data_weight_ptr;
|
||||
grad_sampling_loc += grad_sampling_ptr << 1;
|
||||
grad_attn_weight += grad_sampling_ptr;
|
||||
const int grad_weight_stride = 1;
|
||||
const int grad_loc_stride = 2;
|
||||
const int qid_stride = num_heads * channels;
|
||||
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
|
||||
|
||||
for (int l_col = 0; l_col < num_levels; ++l_col) {
|
||||
const int level_start_id = data_level_start_index[l_col];
|
||||
const int spatial_h_ptr = l_col << 1;
|
||||
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
|
||||
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
|
||||
const int value_ptr_offset =
|
||||
data_value_ptr_init_offset + level_start_id * qid_stride;
|
||||
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
|
||||
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
|
||||
|
||||
for (int p_col = 0; p_col < num_point; ++p_col) {
|
||||
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
|
||||
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
|
||||
const scalar_t weight = data_attn_weight[data_weight_ptr];
|
||||
|
||||
const scalar_t h_im = loc_h * spatial_h - 0.5;
|
||||
const scalar_t w_im = loc_w * spatial_w - 0.5;
|
||||
*(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
|
||||
*(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
|
||||
*(cache_grad_attn_weight + threadIdx.x) = 0;
|
||||
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
|
||||
ms_deform_attn_col2im_bilinear(
|
||||
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
|
||||
w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
|
||||
cache_grad_sampling_loc + (threadIdx.x << 1),
|
||||
cache_grad_attn_weight + threadIdx.x);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
if (tid == 0) {
|
||||
scalar_t _grad_w = cache_grad_sampling_loc[0],
|
||||
_grad_h = cache_grad_sampling_loc[1],
|
||||
_grad_a = cache_grad_attn_weight[0];
|
||||
int sid = 2;
|
||||
for (unsigned int tid = 1; tid < blockDim.x; ++tid) {
|
||||
_grad_w += cache_grad_sampling_loc[sid];
|
||||
_grad_h += cache_grad_sampling_loc[sid + 1];
|
||||
_grad_a += cache_grad_attn_weight[tid];
|
||||
sid += 2;
|
||||
}
|
||||
|
||||
*grad_sampling_loc = _grad_w;
|
||||
*(grad_sampling_loc + 1) = _grad_h;
|
||||
*grad_attn_weight = _grad_a;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
data_weight_ptr += 1;
|
||||
data_loc_w_ptr += 2;
|
||||
grad_attn_weight += grad_weight_stride;
|
||||
grad_sampling_loc += grad_loc_stride;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(
|
||||
const int n, const scalar_t *grad_col, const scalar_t *data_value,
|
||||
const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
|
||||
const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight,
|
||||
const int batch_size, const int spatial_size, const int num_heads,
|
||||
const int channels, const int num_levels, const int num_query,
|
||||
const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
|
||||
scalar_t *grad_attn_weight) {
|
||||
CUDA_1D_KERNEL_LOOP(index, n) {
|
||||
extern __shared__ int _s[];
|
||||
scalar_t *cache_grad_sampling_loc = reinterpret_cast<scalar_t *>(_s);
|
||||
scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
|
||||
unsigned int tid = threadIdx.x;
|
||||
int _temp = index;
|
||||
const int c_col = _temp % channels;
|
||||
_temp /= channels;
|
||||
const int sampling_index = _temp;
|
||||
const int m_col = _temp % num_heads;
|
||||
_temp /= num_heads;
|
||||
const int q_col = _temp % num_query;
|
||||
_temp /= num_query;
|
||||
const int b_col = _temp;
|
||||
|
||||
const scalar_t top_grad = grad_col[index];
|
||||
|
||||
int data_weight_ptr = sampling_index * num_levels * num_point;
|
||||
int data_loc_w_ptr = data_weight_ptr << 1;
|
||||
const int grad_sampling_ptr = data_weight_ptr;
|
||||
grad_sampling_loc += grad_sampling_ptr << 1;
|
||||
grad_attn_weight += grad_sampling_ptr;
|
||||
const int grad_weight_stride = 1;
|
||||
const int grad_loc_stride = 2;
|
||||
const int qid_stride = num_heads * channels;
|
||||
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
|
||||
|
||||
for (int l_col = 0; l_col < num_levels; ++l_col) {
|
||||
const int level_start_id = data_level_start_index[l_col];
|
||||
const int spatial_h_ptr = l_col << 1;
|
||||
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
|
||||
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
|
||||
const int value_ptr_offset =
|
||||
data_value_ptr_init_offset + level_start_id * qid_stride;
|
||||
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
|
||||
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
|
||||
|
||||
for (int p_col = 0; p_col < num_point; ++p_col) {
|
||||
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
|
||||
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
|
||||
const scalar_t weight = data_attn_weight[data_weight_ptr];
|
||||
|
||||
const scalar_t h_im = loc_h * spatial_h - 0.5;
|
||||
const scalar_t w_im = loc_w * spatial_w - 0.5;
|
||||
*(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
|
||||
*(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
|
||||
*(cache_grad_attn_weight + threadIdx.x) = 0;
|
||||
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
|
||||
ms_deform_attn_col2im_bilinear(
|
||||
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
|
||||
w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
|
||||
cache_grad_sampling_loc + (threadIdx.x << 1),
|
||||
cache_grad_attn_weight + threadIdx.x);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (unsigned int s = blockDim.x / 2, spre = blockDim.x; s > 0;
|
||||
s >>= 1, spre >>= 1) {
|
||||
if (tid < s) {
|
||||
const unsigned int xid1 = tid << 1;
|
||||
const unsigned int xid2 = (tid + s) << 1;
|
||||
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
|
||||
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
|
||||
cache_grad_sampling_loc[xid1 + 1] +=
|
||||
cache_grad_sampling_loc[xid2 + 1];
|
||||
if (tid + (s << 1) < spre) {
|
||||
cache_grad_attn_weight[tid] +=
|
||||
cache_grad_attn_weight[tid + (s << 1)];
|
||||
cache_grad_sampling_loc[xid1] +=
|
||||
cache_grad_sampling_loc[xid2 + (s << 1)];
|
||||
cache_grad_sampling_loc[xid1 + 1] +=
|
||||
cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if (tid == 0) {
|
||||
*grad_sampling_loc = cache_grad_sampling_loc[0];
|
||||
*(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
|
||||
*grad_attn_weight = cache_grad_attn_weight[0];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
data_weight_ptr += 1;
|
||||
data_loc_w_ptr += 2;
|
||||
grad_attn_weight += grad_weight_stride;
|
||||
grad_sampling_loc += grad_loc_stride;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(
|
||||
const int n, const scalar_t *grad_col, const scalar_t *data_value,
|
||||
const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
|
||||
const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight,
|
||||
const int batch_size, const int spatial_size, const int num_heads,
|
||||
const int channels, const int num_levels, const int num_query,
|
||||
const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
|
||||
scalar_t *grad_attn_weight) {
|
||||
CUDA_1D_KERNEL_LOOP(index, n) {
|
||||
extern __shared__ int _s[];
|
||||
scalar_t *cache_grad_sampling_loc = reinterpret_cast<scalar_t *>(_s);
|
||||
scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
|
||||
unsigned int tid = threadIdx.x;
|
||||
int _temp = index;
|
||||
const int c_col = _temp % channels;
|
||||
_temp /= channels;
|
||||
const int sampling_index = _temp;
|
||||
const int m_col = _temp % num_heads;
|
||||
_temp /= num_heads;
|
||||
const int q_col = _temp % num_query;
|
||||
_temp /= num_query;
|
||||
const int b_col = _temp;
|
||||
|
||||
const scalar_t top_grad = grad_col[index];
|
||||
|
||||
int data_weight_ptr = sampling_index * num_levels * num_point;
|
||||
int data_loc_w_ptr = data_weight_ptr << 1;
|
||||
const int grad_sampling_ptr = data_weight_ptr;
|
||||
grad_sampling_loc += grad_sampling_ptr << 1;
|
||||
grad_attn_weight += grad_sampling_ptr;
|
||||
const int grad_weight_stride = 1;
|
||||
const int grad_loc_stride = 2;
|
||||
const int qid_stride = num_heads * channels;
|
||||
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
|
||||
|
||||
for (int l_col = 0; l_col < num_levels; ++l_col) {
|
||||
const int level_start_id = data_level_start_index[l_col];
|
||||
const int spatial_h_ptr = l_col << 1;
|
||||
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
|
||||
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
|
||||
const int value_ptr_offset =
|
||||
data_value_ptr_init_offset + level_start_id * qid_stride;
|
||||
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
|
||||
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
|
||||
|
||||
for (int p_col = 0; p_col < num_point; ++p_col) {
|
||||
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
|
||||
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
|
||||
const scalar_t weight = data_attn_weight[data_weight_ptr];
|
||||
|
||||
const scalar_t h_im = loc_h * spatial_h - 0.5;
|
||||
const scalar_t w_im = loc_w * spatial_w - 0.5;
|
||||
*(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
|
||||
*(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
|
||||
*(cache_grad_attn_weight + threadIdx.x) = 0;
|
||||
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
|
||||
ms_deform_attn_col2im_bilinear(
|
||||
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
|
||||
w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
|
||||
cache_grad_sampling_loc + (threadIdx.x << 1),
|
||||
cache_grad_attn_weight + threadIdx.x);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (unsigned int s = blockDim.x / 2, spre = blockDim.x; s > 0;
|
||||
s >>= 1, spre >>= 1) {
|
||||
if (tid < s) {
|
||||
const unsigned int xid1 = tid << 1;
|
||||
const unsigned int xid2 = (tid + s) << 1;
|
||||
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
|
||||
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
|
||||
cache_grad_sampling_loc[xid1 + 1] +=
|
||||
cache_grad_sampling_loc[xid2 + 1];
|
||||
if (tid + (s << 1) < spre) {
|
||||
cache_grad_attn_weight[tid] +=
|
||||
cache_grad_attn_weight[tid + (s << 1)];
|
||||
cache_grad_sampling_loc[xid1] +=
|
||||
cache_grad_sampling_loc[xid2 + (s << 1)];
|
||||
cache_grad_sampling_loc[xid1 + 1] +=
|
||||
cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if (tid == 0) {
|
||||
atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
|
||||
atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
|
||||
atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
data_weight_ptr += 1;
|
||||
data_loc_w_ptr += 2;
|
||||
grad_attn_weight += grad_weight_stride;
|
||||
grad_sampling_loc += grad_loc_stride;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void ms_deformable_col2im_gpu_kernel_gm(
|
||||
const int n, const scalar_t *grad_col, const scalar_t *data_value,
|
||||
const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
|
||||
const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight,
|
||||
const int batch_size, const int spatial_size, const int num_heads,
|
||||
const int channels, const int num_levels, const int num_query,
|
||||
const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
|
||||
scalar_t *grad_attn_weight) {
|
||||
CUDA_1D_KERNEL_LOOP(index, n) {
|
||||
int _temp = index;
|
||||
const int c_col = _temp % channels;
|
||||
_temp /= channels;
|
||||
const int sampling_index = _temp;
|
||||
const int m_col = _temp % num_heads;
|
||||
_temp /= num_heads;
|
||||
const int q_col = _temp % num_query;
|
||||
_temp /= num_query;
|
||||
const int b_col = _temp;
|
||||
|
||||
const scalar_t top_grad = grad_col[index];
|
||||
|
||||
int data_weight_ptr = sampling_index * num_levels * num_point;
|
||||
int data_loc_w_ptr = data_weight_ptr << 1;
|
||||
const int grad_sampling_ptr = data_weight_ptr;
|
||||
grad_sampling_loc += grad_sampling_ptr << 1;
|
||||
grad_attn_weight += grad_sampling_ptr;
|
||||
const int grad_weight_stride = 1;
|
||||
const int grad_loc_stride = 2;
|
||||
const int qid_stride = num_heads * channels;
|
||||
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
|
||||
|
||||
for (int l_col = 0; l_col < num_levels; ++l_col) {
|
||||
const int level_start_id = data_level_start_index[l_col];
|
||||
const int spatial_h_ptr = l_col << 1;
|
||||
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
|
||||
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
|
||||
const int value_ptr_offset =
|
||||
data_value_ptr_init_offset + level_start_id * qid_stride;
|
||||
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
|
||||
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
|
||||
|
||||
for (int p_col = 0; p_col < num_point; ++p_col) {
|
||||
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
|
||||
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
|
||||
const scalar_t weight = data_attn_weight[data_weight_ptr];
|
||||
|
||||
const scalar_t h_im = loc_h * spatial_h - 0.5;
|
||||
const scalar_t w_im = loc_w * spatial_w - 0.5;
|
||||
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
|
||||
ms_deform_attn_col2im_bilinear_gm(
|
||||
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
|
||||
w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
|
||||
grad_sampling_loc, grad_attn_weight);
|
||||
}
|
||||
data_weight_ptr += 1;
|
||||
data_loc_w_ptr += 2;
|
||||
grad_attn_weight += grad_weight_stride;
|
||||
grad_sampling_loc += grad_loc_stride;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // DEFORM_ATTN_CUDA_KERNEL
|
|
@ -0,0 +1,74 @@
|
|||
/*!
|
||||
**************************************************************************************************
|
||||
* Deformable DETR
|
||||
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
||||
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||
**************************************************************************************************
|
||||
* Modified from
|
||||
*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
||||
**************************************************************************************************
|
||||
*/
|
||||
|
||||
#include "pytorch_cpp_helper.hpp"
|
||||
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
Tensor ms_deform_attn_cuda_forward(const Tensor &value,
|
||||
const Tensor &spatial_shapes,
|
||||
const Tensor &level_start_index,
|
||||
const Tensor &sampling_loc,
|
||||
const Tensor &attn_weight,
|
||||
const int im2col_step);
|
||||
|
||||
std::vector<Tensor> ms_deform_attn_cuda_backward(
|
||||
const Tensor &value, const Tensor &spatial_shapes,
|
||||
const Tensor &level_start_index, const Tensor &sampling_loc,
|
||||
const Tensor &attn_weight, const Tensor &grad_output,
|
||||
const int im2col_step);
|
||||
|
||||
#endif
|
||||
|
||||
Tensor ms_deform_attn_forward(const Tensor &value, const Tensor &spatial_shapes,
|
||||
const Tensor &level_start_index,
|
||||
const Tensor &sampling_loc,
|
||||
const Tensor &attn_weight,
|
||||
const int im2col_step) {
|
||||
if (value.type().is_cuda()) {
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
CHECK_CUDA_INPUT(value)
|
||||
CHECK_CUDA_INPUT(spatial_shapes)
|
||||
CHECK_CUDA_INPUT(level_start_index)
|
||||
CHECK_CUDA_INPUT(sampling_loc)
|
||||
CHECK_CUDA_INPUT(attn_weight)
|
||||
return ms_deform_attn_cuda_forward(value, spatial_shapes, level_start_index,
|
||||
sampling_loc, attn_weight, im2col_step);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support");
|
||||
#endif
|
||||
}
|
||||
AT_ERROR("Not implemented on the CPU");
|
||||
}
|
||||
|
||||
std::vector<Tensor> ms_deform_attn_backward(const Tensor &value,
|
||||
const Tensor &spatial_shapes,
|
||||
const Tensor &level_start_index,
|
||||
const Tensor &sampling_loc,
|
||||
const Tensor &attn_weight,
|
||||
const Tensor &grad_output,
|
||||
const int im2col_step) {
|
||||
if (value.type().is_cuda()) {
|
||||
#ifdef MMCV_WITH_CUDA
|
||||
CHECK_CUDA_INPUT(value)
|
||||
CHECK_CUDA_INPUT(spatial_shapes)
|
||||
CHECK_CUDA_INPUT(level_start_index)
|
||||
CHECK_CUDA_INPUT(sampling_loc)
|
||||
CHECK_CUDA_INPUT(attn_weight)
|
||||
CHECK_CUDA_INPUT(grad_output)
|
||||
return ms_deform_attn_cuda_backward(value, spatial_shapes,
|
||||
level_start_index, sampling_loc,
|
||||
attn_weight, grad_output, im2col_step);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support");
|
||||
#endif
|
||||
}
|
||||
AT_ERROR("Not implemented on the CPU");
|
||||
}
|
|
@ -0,0 +1,365 @@
|
|||
/*!
|
||||
**************************************************************************************************
|
||||
* Deformable DETR
|
||||
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
||||
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||
**************************************************************************************************
|
||||
* Modified from
|
||||
*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
||||
**************************************************************************************************
|
||||
*/
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <THC/THCAtomics.cuh>
|
||||
#include <ms_deform_attn_cuda_kernel.cuh>
|
||||
#include <vector>
|
||||
|
||||
template <typename scalar_t>
|
||||
void ms_deformable_im2col_cuda(cudaStream_t stream, const scalar_t *data_value,
|
||||
const int64_t *data_spatial_shapes,
|
||||
const int64_t *data_level_start_index,
|
||||
const scalar_t *data_sampling_loc,
|
||||
const scalar_t *data_attn_weight,
|
||||
const int batch_size, const int spatial_size,
|
||||
const int num_heads, const int channels,
|
||||
const int num_levels, const int num_query,
|
||||
const int num_point, scalar_t *data_col) {
|
||||
const int num_kernels = batch_size * num_query * num_heads * channels;
|
||||
const int num_actual_kernels = batch_size * num_query * num_heads * channels;
|
||||
const int num_threads = CUDA_NUM_THREADS;
|
||||
ms_deformable_im2col_gpu_kernel<scalar_t>
|
||||
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
|
||||
num_kernels, data_value, data_spatial_shapes, data_level_start_index,
|
||||
data_sampling_loc, data_attn_weight, batch_size, spatial_size,
|
||||
num_heads, channels, num_levels, num_query, num_point, data_col);
|
||||
|
||||
cudaError_t err = cudaGetLastError();
|
||||
if (err != cudaSuccess) {
|
||||
printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void ms_deformable_col2im_cuda(
|
||||
cudaStream_t stream, const scalar_t *grad_col, const scalar_t *data_value,
|
||||
const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
|
||||
const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight,
|
||||
const int batch_size, const int spatial_size, const int num_heads,
|
||||
const int channels, const int num_levels, const int num_query,
|
||||
const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
|
||||
scalar_t *grad_attn_weight) {
|
||||
const int num_threads =
|
||||
(channels > CUDA_NUM_THREADS) ? CUDA_NUM_THREADS : channels;
|
||||
const int num_kernels = batch_size * num_query * num_heads * channels;
|
||||
const int num_actual_kernels = batch_size * num_query * num_heads * channels;
|
||||
if (channels > 1024) {
|
||||
if ((channels & 1023) == 0) {
|
||||
ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
|
||||
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
||||
num_threads * 3 * sizeof(scalar_t), stream>>>(
|
||||
num_kernels, grad_col, data_value, data_spatial_shapes,
|
||||
data_level_start_index, data_sampling_loc, data_attn_weight,
|
||||
batch_size, spatial_size, num_heads, channels, num_levels,
|
||||
num_query, num_point, grad_value, grad_sampling_loc,
|
||||
grad_attn_weight);
|
||||
} else {
|
||||
ms_deformable_col2im_gpu_kernel_gm<scalar_t>
|
||||
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
|
||||
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
|
||||
data_level_start_index, data_sampling_loc,
|
||||
data_attn_weight, batch_size, spatial_size, num_heads,
|
||||
channels, num_levels, num_query, num_point, grad_value,
|
||||
grad_sampling_loc, grad_attn_weight);
|
||||
}
|
||||
} else {
|
||||
switch (channels) {
|
||||
case 1:
|
||||
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t,
|
||||
1>
|
||||
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
|
||||
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
|
||||
data_level_start_index, data_sampling_loc,
|
||||
data_attn_weight, batch_size, spatial_size, num_heads,
|
||||
channels, num_levels, num_query, num_point, grad_value,
|
||||
grad_sampling_loc, grad_attn_weight);
|
||||
break;
|
||||
case 2:
|
||||
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t,
|
||||
2>
|
||||
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
|
||||
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
|
||||
data_level_start_index, data_sampling_loc,
|
||||
data_attn_weight, batch_size, spatial_size, num_heads,
|
||||
channels, num_levels, num_query, num_point, grad_value,
|
||||
grad_sampling_loc, grad_attn_weight);
|
||||
break;
|
||||
case 4:
|
||||
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t,
|
||||
4>
|
||||
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
|
||||
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
|
||||
data_level_start_index, data_sampling_loc,
|
||||
data_attn_weight, batch_size, spatial_size, num_heads,
|
||||
channels, num_levels, num_query, num_point, grad_value,
|
||||
grad_sampling_loc, grad_attn_weight);
|
||||
break;
|
||||
case 8:
|
||||
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t,
|
||||
8>
|
||||
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
|
||||
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
|
||||
data_level_start_index, data_sampling_loc,
|
||||
data_attn_weight, batch_size, spatial_size, num_heads,
|
||||
channels, num_levels, num_query, num_point, grad_value,
|
||||
grad_sampling_loc, grad_attn_weight);
|
||||
break;
|
||||
case 16:
|
||||
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t,
|
||||
16>
|
||||
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
|
||||
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
|
||||
data_level_start_index, data_sampling_loc,
|
||||
                         data_attn_weight, batch_size, spatial_size, num_heads,
                         channels, num_levels, num_query, num_point, grad_value,
                         grad_sampling_loc, grad_attn_weight);
        break;
      case 32:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t,
                                                                      32>
            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
               stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
                         data_level_start_index, data_sampling_loc,
                         data_attn_weight, batch_size, spatial_size, num_heads,
                         channels, num_levels, num_query, num_point, grad_value,
                         grad_sampling_loc, grad_attn_weight);
        break;
      case 64:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t,
                                                                      64>
            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
               stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
                         data_level_start_index, data_sampling_loc,
                         data_attn_weight, batch_size, spatial_size, num_heads,
                         channels, num_levels, num_query, num_point, grad_value,
                         grad_sampling_loc, grad_attn_weight);
        break;
      case 128:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t,
                                                                      128>
            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
               stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
                         data_level_start_index, data_sampling_loc,
                         data_attn_weight, batch_size, spatial_size, num_heads,
                         channels, num_levels, num_query, num_point, grad_value,
                         grad_sampling_loc, grad_attn_weight);
        break;
      case 256:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t,
                                                                      256>
            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
               stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
                         data_level_start_index, data_sampling_loc,
                         data_attn_weight, batch_size, spatial_size, num_heads,
                         channels, num_levels, num_query, num_point, grad_value,
                         grad_sampling_loc, grad_attn_weight);
        break;
      case 512:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t,
                                                                      512>
            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
               stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
                         data_level_start_index, data_sampling_loc,
                         data_attn_weight, batch_size, spatial_size, num_heads,
                         channels, num_levels, num_query, num_point, grad_value,
                         grad_sampling_loc, grad_attn_weight);
        break;
      case 1024:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t,
                                                                      1024>
            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
               stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
                         data_level_start_index, data_sampling_loc,
                         data_attn_weight, batch_size, spatial_size, num_heads,
                         channels, num_levels, num_query, num_point, grad_value,
                         grad_sampling_loc, grad_attn_weight);
        break;
      default:
        if (channels < 64) {
          ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
              <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
                 num_threads * 3 * sizeof(scalar_t), stream>>>(
                  num_kernels, grad_col, data_value, data_spatial_shapes,
                  data_level_start_index, data_sampling_loc, data_attn_weight,
                  batch_size, spatial_size, num_heads, channels, num_levels,
                  num_query, num_point, grad_value, grad_sampling_loc,
                  grad_attn_weight);
        } else {
          ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
              <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
                 num_threads * 3 * sizeof(scalar_t), stream>>>(
                  num_kernels, grad_col, data_value, data_spatial_shapes,
                  data_level_start_index, data_sampling_loc, data_attn_weight,
                  batch_size, spatial_size, num_heads, channels, num_levels,
                  num_query, num_point, grad_value, grad_sampling_loc,
                  grad_attn_weight);
        }
    }
  }
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
  }
}
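The switch above selects a backward (col2im) kernel from the per-head channel count. As a readability aid only (not part of the extension, and covering only the cases visible in this hunk; the smaller power-of-two cases sit outside it), a rough Python paraphrase of the dispatch rule:

def pick_col2im_kernel(channels):
    # Power-of-two channel counts covered by this hunk get a blocksize-aware
    # shared-memory reduction kernel templated on the channel count.
    if channels == 32:
        return 'shm_blocksize_aware_reduce_v1<scalar_t, 32>'
    if channels in (64, 128, 256, 512, 1024):
        return 'shm_blocksize_aware_reduce_v2<scalar_t, %d>' % channels
    # Any other channel count falls back to the generic shared-memory
    # reduction kernels, split on the channels < 64 boundary.
    return 'shm_reduce_v1' if channels < 64 else 'shm_reduce_v2'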
at::Tensor ms_deform_attn_cuda_forward(const at::Tensor &value,
                                       const at::Tensor &spatial_shapes,
                                       const at::Tensor &level_start_index,
                                       const at::Tensor &sampling_loc,
                                       const at::Tensor &attn_weight,
                                       const int im2col_step) {
  AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
  AT_ASSERTM(spatial_shapes.is_contiguous(),
             "spatial_shapes tensor has to be contiguous");
  AT_ASSERTM(level_start_index.is_contiguous(),
             "level_start_index tensor has to be contiguous");
  AT_ASSERTM(sampling_loc.is_contiguous(),
             "sampling_loc tensor has to be contiguous");
  AT_ASSERTM(attn_weight.is_contiguous(),
             "attn_weight tensor has to be contiguous");

  AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
  AT_ASSERTM(spatial_shapes.type().is_cuda(),
             "spatial_shapes must be a CUDA tensor");
  AT_ASSERTM(level_start_index.type().is_cuda(),
             "level_start_index must be a CUDA tensor");
  AT_ASSERTM(sampling_loc.type().is_cuda(),
             "sampling_loc must be a CUDA tensor");
  AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");

  const int batch = value.size(0);
  const int spatial_size = value.size(1);
  const int num_heads = value.size(2);
  const int channels = value.size(3);

  const int num_levels = spatial_shapes.size(0);

  const int num_query = sampling_loc.size(1);
  const int num_point = sampling_loc.size(4);

  const int im2col_step_ = std::min(batch, im2col_step);

  AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)",
             batch, im2col_step_);

  auto output =
      at::zeros({batch, num_query, num_heads, channels}, value.options());

  const int batch_n = im2col_step_;
  auto output_n = output.view(
      {batch / im2col_step_, batch_n, num_query, num_heads, channels});
  auto per_value_size = spatial_size * num_heads * channels;
  auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
  auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
  for (int n = 0; n < batch / im2col_step_; ++n) {
    auto columns = output_n.select(0, n);
    AT_DISPATCH_FLOATING_TYPES(
        value.type(), "ms_deform_attn_forward_cuda", ([&] {
          ms_deformable_im2col_cuda(
              at::cuda::getCurrentCUDAStream(),
              value.data<scalar_t>() + n * im2col_step_ * per_value_size,
              spatial_shapes.data<int64_t>(), level_start_index.data<int64_t>(),
              sampling_loc.data<scalar_t>() +
                  n * im2col_step_ * per_sample_loc_size,
              attn_weight.data<scalar_t>() +
                  n * im2col_step_ * per_attn_weight_size,
              batch_n, spatial_size, num_heads, channels, num_levels, num_query,
              num_point, columns.data<scalar_t>());
        }));
  }

  output = output.view({batch, num_query, num_heads * channels});

  return output;
}

std::vector<at::Tensor> ms_deform_attn_cuda_backward(
    const at::Tensor &value, const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index, const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight, const at::Tensor &grad_output,
    const int im2col_step) {
  AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
  AT_ASSERTM(spatial_shapes.is_contiguous(),
             "spatial_shapes tensor has to be contiguous");
  AT_ASSERTM(level_start_index.is_contiguous(),
             "level_start_index tensor has to be contiguous");
  AT_ASSERTM(sampling_loc.is_contiguous(),
             "sampling_loc tensor has to be contiguous");
  AT_ASSERTM(attn_weight.is_contiguous(),
             "attn_weight tensor has to be contiguous");
  AT_ASSERTM(grad_output.is_contiguous(),
             "grad_output tensor has to be contiguous");

  AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
  AT_ASSERTM(spatial_shapes.type().is_cuda(),
             "spatial_shapes must be a CUDA tensor");
  AT_ASSERTM(level_start_index.type().is_cuda(),
             "level_start_index must be a CUDA tensor");
  AT_ASSERTM(sampling_loc.type().is_cuda(),
             "sampling_loc must be a CUDA tensor");
  AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
  AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");

  const int batch = value.size(0);
  const int spatial_size = value.size(1);
  const int num_heads = value.size(2);
  const int channels = value.size(3);

  const int num_levels = spatial_shapes.size(0);

  const int num_query = sampling_loc.size(1);
  const int num_point = sampling_loc.size(4);

  const int im2col_step_ = std::min(batch, im2col_step);

  AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)",
             batch, im2col_step_);

  auto grad_value = at::zeros_like(value);
  auto grad_sampling_loc = at::zeros_like(sampling_loc);
  auto grad_attn_weight = at::zeros_like(attn_weight);

  const int batch_n = im2col_step_;
  auto per_value_size = spatial_size * num_heads * channels;
  auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
  auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
  auto grad_output_n = grad_output.view(
      {batch / im2col_step_, batch_n, num_query, num_heads, channels});

  for (int n = 0; n < batch / im2col_step_; ++n) {
    auto grad_output_g = grad_output_n.select(0, n);
    AT_DISPATCH_FLOATING_TYPES(
        value.type(), "ms_deform_attn_backward_cuda", ([&] {
          ms_deformable_col2im_cuda(
              at::cuda::getCurrentCUDAStream(), grad_output_g.data<scalar_t>(),
              value.data<scalar_t>() + n * im2col_step_ * per_value_size,
              spatial_shapes.data<int64_t>(), level_start_index.data<int64_t>(),
              sampling_loc.data<scalar_t>() +
                  n * im2col_step_ * per_sample_loc_size,
              attn_weight.data<scalar_t>() +
                  n * im2col_step_ * per_attn_weight_size,
              batch_n, spatial_size, num_heads, channels, num_levels, num_query,
              num_point,
              grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
              grad_sampling_loc.data<scalar_t>() +
                  n * im2col_step_ * per_sample_loc_size,
              grad_attn_weight.data<scalar_t>() +
                  n * im2col_step_ * per_attn_weight_size);
        }));
  }

  return {grad_value, grad_sampling_loc, grad_attn_weight};
}
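For reference, a minimal Python sketch of the tensor layout these wrappers expect; the shapes follow the size() reads above, and the level_start_index construction mirrors the unit tests added later in this change. The concrete sizes are illustrative assumptions only.

import torch

# two feature levels with (h, w) = (6, 4) and (3, 2); sizes below are made up
spatial_shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
# start offset of every level inside the flattened key dimension of `value`
level_start_index = torch.cat((spatial_shapes.new_zeros(
    (1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
num_keys = int(spatial_shapes.prod(1).sum())  # 6 * 4 + 3 * 2 = 30

bs, num_queries, num_heads, head_dims, num_points = 2, 5, 8, 32, 4
value = torch.rand(bs, num_keys, num_heads, head_dims).cuda()
sampling_loc = torch.rand(bs, num_queries, num_heads,
                          spatial_shapes.size(0), num_points, 2).cuda()
attn_weight = torch.rand(bs, num_queries, num_heads,
                         spatial_shapes.size(0), num_points).cuda()
im2col_step = 2  # clamped to the batch size and must divide it (asserted above)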
@@ -92,6 +92,19 @@ void modulated_deform_conv_backward(
    int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
    const bool with_bias);

Tensor ms_deform_attn_forward(const Tensor &value, const Tensor &spatial_shapes,
                              const Tensor &level_start_index,
                              const Tensor &sampling_loc,
                              const Tensor &attn_weight, const int im2col_step);

std::vector<Tensor> ms_deform_attn_backward(const Tensor &value,
                                            const Tensor &spatial_shapes,
                                            const Tensor &level_start_index,
                                            const Tensor &sampling_loc,
                                            const Tensor &attn_weight,
                                            const Tensor &grad_output,
                                            const int im2col_step);

Tensor nms(Tensor boxes, Tensor scores, float iou_threshold, int offset);

Tensor softnms(Tensor boxes, Tensor scores, Tensor dets, float iou_threshold,

@@ -182,12 +195,12 @@ Tensor nms_rotated(const Tensor dets, const Tensor scores, const Tensor order,
                   const Tensor dets_sorted, const float iou_threshold,
                   const int multi_label);

Tensor upfirdn2d(const Tensor& input, const Tensor& kernel, int up_x, int up_y,
Tensor upfirdn2d(const Tensor &input, const Tensor &kernel, int up_x, int up_y,
                 int down_x, int down_y, int pad_x0, int pad_x1, int pad_y0,
                 int pad_y1);

Tensor fused_bias_leakyrelu(const Tensor& input, const Tensor& bias,
                            const Tensor& refer, int act, int grad, float alpha,
Tensor fused_bias_leakyrelu(const Tensor &input, const Tensor &bias,
                            const Tensor &refer, int act, int grad, float alpha,
                            float scale);

void roi_align_rotated_forward(Tensor input, Tensor rois, Tensor output,

@@ -401,4 +414,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        py::arg("grad_input"), py::arg("pooled_height"),
        py::arg("pooled_width"), py::arg("spatial_scale"),
        py::arg("sample_num"), py::arg("aligned"), py::arg("clockwise"));
  m.def("ms_deform_attn_forward", &ms_deform_attn_forward,
        "forward function of multi-scale deformable attention",
        py::arg("value"), py::arg("value_spatial_shapes"),
        py::arg("value_level_start_index"), py::arg("sampling_locations"),
        py::arg("attention_weights"), py::arg("im2col_step"));
  m.def("ms_deform_attn_backward", &ms_deform_attn_backward,
        "backward function of multi-scale deformable attention",
        py::arg("value"), py::arg("value_spatial_shapes"),
        py::arg("value_level_start_index"), py::arg("sampling_locations"),
        py::arg("attention_weights"), py::arg("grad_output"),
        py::arg("im2col_step"));
}
@@ -0,0 +1,134 @@
import torch
import torch.nn.functional as F
from torch.autograd.function import Function, once_differentiable

from ..utils import ext_loader

ext_module = ext_loader.load_ext(
    '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])


class MultiScaleDeformableAttnFunction(Function):

    @staticmethod
    def forward(ctx, value, value_spatial_shapes, value_level_start_index,
                sampling_locations, attention_weights, im2col_step):
        """GPU version of multi-scale deformable attention.

        Args:
            value (Tensor): The value has shape
                (bs, num_keys, num_heads, embed_dims//num_heads)
            value_spatial_shapes (Tensor): Spatial shape of
                each feature map, has shape (num_levels, 2),
                last dimension 2 represent (h, w)
            value_level_start_index (Tensor): The start index of each level
                inside the flattened value, has shape (num_levels, ).
            sampling_locations (Tensor): The location of sampling points,
                has shape
                (bs, num_queries, num_heads, num_levels, num_points, 2),
                the last dimension 2 represent (x, y).
            attention_weights (Tensor): The weight of sampling points used
                when calculating the attention, has shape
                (bs, num_queries, num_heads, num_levels, num_points).
            im2col_step (int): The step used in image to column.

        Returns:
            Tensor: has shape (bs, num_queries, embed_dims)
        """

        ctx.im2col_step = im2col_step
        output = ext_module.ms_deform_attn_forward(value, value_spatial_shapes,
                                                   value_level_start_index,
                                                   sampling_locations,
                                                   attention_weights,
                                                   ctx.im2col_step)
        ctx.save_for_backward(value, value_spatial_shapes,
                              value_level_start_index, sampling_locations,
                              attention_weights)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """GPU version of backward function.

        Args:
            grad_output (Tensor): Gradient
                of output tensor of forward.

        Returns:
            Tuple[Tensor]: Gradient
                of input tensors in forward.
        """
        value, value_spatial_shapes, value_level_start_index,\
            sampling_locations, attention_weights = ctx.saved_tensors
        grad_value, grad_sampling_loc, grad_attn_weight = \
            ext_module.ms_deform_attn_backward(
                value,
                value_spatial_shapes,
                value_level_start_index,
                sampling_locations,
                attention_weights,
                grad_output,
                ctx.im2col_step)

        return grad_value, None, None, \
            grad_sampling_loc, grad_attn_weight, None


def multi_scale_deformable_attn_pytorch(value, value_spatial_shapes,
                                        sampling_locations, attention_weights):
    """CPU version of multi-scale deformable attention.

    Args:
        value (Tensor): The value has shape
            (bs, num_keys, num_heads, embed_dims//num_heads)
        value_spatial_shapes (Tensor): Spatial shape of
            each feature map, has shape (num_levels, 2),
            last dimension 2 represent (h, w)
        sampling_locations (Tensor): The location of sampling points,
            has shape
            (bs, num_queries, num_heads, num_levels, num_points, 2),
            the last dimension 2 represent (x, y).
        attention_weights (Tensor): The weight of sampling points used
            when calculating the attention, has shape
            (bs, num_queries, num_heads, num_levels, num_points).

    Returns:
        Tensor: has shape (bs, num_queries, embed_dims)
    """

    bs, _, num_heads, embed_dims = value.shape
    _, num_queries, num_heads, num_levels, num_points, _ =\
        sampling_locations.shape
    value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes],
                             dim=1)
    sampling_grids = 2 * sampling_locations - 1
    sampling_value_list = []
    for level, (H_, W_) in enumerate(value_spatial_shapes):
        # bs, H_*W_, num_heads, embed_dims ->
        # bs, H_*W_, num_heads*embed_dims ->
        # bs, num_heads*embed_dims, H_*W_ ->
        # bs*num_heads, embed_dims, H_, W_
        value_l_ = value_list[level].flatten(2).transpose(1, 2).reshape(
            bs * num_heads, embed_dims, H_, W_)
        # bs, num_queries, num_heads, num_points, 2 ->
        # bs, num_heads, num_queries, num_points, 2 ->
        # bs*num_heads, num_queries, num_points, 2
        sampling_grid_l_ = sampling_grids[:, :, :,
                                          level].transpose(1, 2).flatten(0, 1)
        # bs*num_heads, embed_dims, num_queries, num_points
        sampling_value_l_ = F.grid_sample(
            value_l_,
            sampling_grid_l_,
            mode='bilinear',
            padding_mode='zeros',
            align_corners=False)
        sampling_value_list.append(sampling_value_l_)
    # (bs, num_queries, num_heads, num_levels, num_points) ->
    # (bs, num_heads, num_queries, num_levels, num_points) ->
    # (bs*num_heads, 1, num_queries, num_levels*num_points)
    attention_weights = attention_weights.transpose(1, 2).reshape(
        bs * num_heads, 1, num_queries, num_levels * num_points)
    output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) *
              attention_weights).sum(-1).view(bs, num_heads * embed_dims,
                                              num_queries)
    return output.transpose(1, 2).contiguous()
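A minimal usage sketch of the two entry points above, mirroring the shapes used in the unit tests that follow; the sketch itself is not part of the committed module.

import torch
from mmcv.ops.multi_scale_deform_attn import (
    MultiScaleDeformableAttnFunction, multi_scale_deformable_attn_pytorch)

bs, num_queries, num_heads, head_dims, num_points = 1, 2, 2, 2, 2
spatial_shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long)
num_levels = spatial_shapes.size(0)
num_keys = int(spatial_shapes.prod(1).sum())

value = torch.rand(bs, num_keys, num_heads, head_dims) * 0.01
sampling_locations = torch.rand(bs, num_queries, num_heads, num_levels,
                                num_points, 2)
attention_weights = torch.rand(bs, num_queries, num_heads, num_levels,
                               num_points) + 1e-5
attention_weights /= attention_weights.sum(
    -1, keepdim=True).sum(-2, keepdim=True)

# pure PyTorch fallback, runs anywhere
out = multi_scale_deformable_attn_pytorch(value, spatial_shapes,
                                          sampling_locations,
                                          attention_weights)
assert out.shape == (bs, num_queries, num_heads * head_dims)

# CUDA path through the autograd Function
if torch.cuda.is_available():
    level_start_index = torch.cat(
        (spatial_shapes.new_zeros((1, )),
         spatial_shapes.prod(1).cumsum(0)[:-1])).cuda()
    out_cuda = MultiScaleDeformableAttnFunction.apply(
        value.cuda(), spatial_shapes.cuda(), level_start_index,
        sampling_locations.cuda(), attention_weights.cuda(),
        1)  # im2col_step = 1 because the batch size here is 1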
@@ -0,0 +1,125 @@
import pytest
import torch
from torch.autograd import gradcheck

from mmcv.ops.multi_scale_deform_attn import (
    MultiScaleDeformableAttnFunction, multi_scale_deformable_attn_pytorch)


def test_forward_multi_scale_deformable_attn_pytorch():
    N, M, D = 1, 2, 2
    Lq, L, P = 2, 2, 2
    shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long)
    S = sum([(H * W).item() for H, W in shapes])

    torch.manual_seed(3)
    value = torch.rand(N, S, M, D) * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2)
    attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
    attention_weights /= attention_weights.sum(
        -1, keepdim=True).sum(
            -2, keepdim=True)

    multi_scale_deformable_attn_pytorch(value.double(), shapes,
                                        sampling_locations.double(),
                                        attention_weights.double()).detach()


@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='requires CUDA support')
def test_forward_equal_with_pytorch_double():
    N, M, D = 1, 2, 2
    Lq, L, P = 2, 2, 2
    shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
    level_start_index = torch.cat((shapes.new_zeros(
        (1, )), shapes.prod(1).cumsum(0)[:-1]))
    S = sum([(H * W).item() for H, W in shapes])

    torch.manual_seed(3)
    value = torch.rand(N, S, M, D).cuda() * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
    attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
    attention_weights /= attention_weights.sum(
        -1, keepdim=True).sum(
            -2, keepdim=True)
    im2col_step = 2
    output_pytorch = multi_scale_deformable_attn_pytorch(
        value.double(), shapes, sampling_locations.double(),
        attention_weights.double()).detach().cpu()

    output_cuda = MultiScaleDeformableAttnFunction.apply(
        value.double(), shapes, level_start_index, sampling_locations.double(),
        attention_weights.double(), im2col_step).detach().cpu()
    assert torch.allclose(output_cuda, output_pytorch)
    max_abs_err = (output_cuda - output_pytorch).abs().max()
    max_rel_err = ((output_cuda - output_pytorch).abs() /
                   output_pytorch.abs()).max()
    assert max_abs_err < 1e-18
    assert max_rel_err < 1e-15


@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='requires CUDA support')
def test_forward_equal_with_pytorch_float():
    N, M, D = 1, 2, 2
    Lq, L, P = 2, 2, 2
    shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
    level_start_index = torch.cat((shapes.new_zeros(
        (1, )), shapes.prod(1).cumsum(0)[:-1]))
    S = sum([(H * W).item() for H, W in shapes])

    torch.manual_seed(3)
    value = torch.rand(N, S, M, D).cuda() * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
    attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
    attention_weights /= attention_weights.sum(
        -1, keepdim=True).sum(
            -2, keepdim=True)
    im2col_step = 2
    output_pytorch = multi_scale_deformable_attn_pytorch(
        value, shapes, sampling_locations, attention_weights).detach().cpu()

    output_cuda = MultiScaleDeformableAttnFunction.apply(
        value, shapes, level_start_index, sampling_locations,
        attention_weights, im2col_step).detach().cpu()
    assert torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
    max_abs_err = (output_cuda - output_pytorch).abs().max()
    max_rel_err = ((output_cuda - output_pytorch).abs() /
                   output_pytorch.abs()).max()
    assert max_abs_err < 1e-9
    assert max_rel_err < 1e-6


@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='requires CUDA support')
@pytest.mark.parametrize('channels', [4, 30, 32, 64, 71, 1025, 2048, 3096])
def test_gradient_numerical(channels,
                            grad_value=True,
                            grad_sampling_loc=True,
                            grad_attn_weight=True):

    N, M, _ = 1, 2, 2
    Lq, L, P = 2, 2, 2
    shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
    level_start_index = torch.cat((shapes.new_zeros(
        (1, )), shapes.prod(1).cumsum(0)[:-1]))
    S = sum([(H * W).item() for H, W in shapes])

    value = torch.rand(N, S, M, channels).cuda() * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
    attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
    attention_weights /= attention_weights.sum(
        -1, keepdim=True).sum(
            -2, keepdim=True)
    im2col_step = 2

    func = MultiScaleDeformableAttnFunction.apply

    value.requires_grad = grad_value
    sampling_locations.requires_grad = grad_sampling_loc
    attention_weights.requires_grad = grad_attn_weight

    assert gradcheck(
        func,
        (value.double(), shapes, level_start_index,
         sampling_locations.double(), attention_weights.double(), im2col_step))
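As a supplement to the CUDA gradcheck above, the pure-PyTorch fallback can be gradient-checked numerically on CPU in the same style; this is a sketch only and not part of the committed test file.

import torch
from torch.autograd import gradcheck

from mmcv.ops.multi_scale_deform_attn import \
    multi_scale_deformable_attn_pytorch

shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long)
S = sum([(H * W).item() for H, W in shapes])

# double precision and requires_grad, as gradcheck expects
value = (torch.rand(1, S, 2, 4, dtype=torch.double) * 0.01).requires_grad_()
sampling_locations = torch.rand(
    1, 2, 2, 2, 2, 2, dtype=torch.double).requires_grad_()
attention_weights = torch.rand(1, 2, 2, 2, 2, dtype=torch.double) + 1e-5
attention_weights = (attention_weights / attention_weights.sum(
    -1, keepdim=True).sum(-2, keepdim=True)).requires_grad_()

assert gradcheck(
    multi_scale_deformable_attn_pytorch,
    (value, shapes, sampling_locations, attention_weights))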