[Feature]: support Multi-Scale-DeformAttention in deformable-detr (#878)

* add c++ ms_deform_atten

* fix cpp lint

* fix cpp lint

* clang format

* remove cmakefile

* google style

* clang-format precommit

* use clang-format-lint-action

* add transformer base class

* add merge

* add docstr

* add pyargs

* fix according to commments

* resiger module

* change to use basemodule

* add _ between build function

* split the name

* fix according to comments

* fix lint and fix unitest

* fix cpp lint

* fix bug of deformdetr_atten

* fix drop out

* fix residual

* use CUDA_1D_KERNEL_LOOP
pull/981/head
ZhangShilong 2021-04-23 16:35:15 +08:00 committed by GitHub
parent 0dd0c49a5b
commit 54a7ebb4ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 1732 additions and 4 deletions

View File

@ -1,11 +1,15 @@
import copy
import math
import warnings
import torch
import torch.nn as nn
from mmcv import ConfigDict
from mmcv.cnn import Linear, build_activation_layer, build_norm_layer
from mmcv.cnn import (Linear, build_activation_layer, build_norm_layer,
constant_init, xavier_init)
from mmcv.ops.multi_scale_deform_attn import (
MultiScaleDeformableAttnFunction, multi_scale_deformable_attn_pytorch)
from mmcv.runner.base_module import BaseModule
from mmcv.utils import build_from_cfg
from .registry import (ATTENTION, POSITIONAL_ENCODING, TRANSFORMER_LAYER,
@ -135,6 +139,201 @@ class MultiheadAttention(BaseModule):
return residual + self.dropout(out)
@ATTENTION.register_module()
class MultiScaleDeformableAttention(BaseModule):
"""An attention module used in Deformable-Detr. `Deformable DETR:
Deformable Transformers for End-to-End Object Detection.
<https://arxiv.org/pdf/2010.04159.pdf>`_.
Args:
embed_dims (int): The embedding dimension of Attention.
Default: 256.
num_heads (int): Parallel attention heads. Default: 64.
num_levels (int): The number of feature map used in
Attention. Default: 4.
num_points (int): The number of sampling points for
each query in each head. Default: 4.
im2col_step (int): The step used in image_to_column.
Default: 64.
dropout (float): A Dropout layer on `inp_residual`.
Default: 0..
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
"""
def __init__(self,
embed_dims=256,
num_heads=8,
num_levels=4,
num_points=4,
im2col_step=64,
dropout=0.1,
norm_cfg=None,
init_cfg=None):
super().__init__(init_cfg)
if embed_dims % num_heads != 0:
raise ValueError(f'embed_dims must be divisible by num_heads, '
f'but got {embed_dims} and {num_heads}')
dim_per_head = embed_dims // num_heads
self.norm_cfg = norm_cfg
self.init_cfg = init_cfg
self.dropout = nn.Dropout(dropout)
# you'd better set dim_per_head to a power of 2
# which is more efficient in the CUDA implementation
def _is_power_of_2(n):
if (not isinstance(n, int)) or (n < 0):
raise ValueError(
'invalid input for _is_power_of_2: {} (type: {})'.format(
n, type(n)))
return (n & (n - 1) == 0) and n != 0
if not _is_power_of_2(dim_per_head):
warnings.warn(
"You'd better set embed_dims in "
'MultiScaleDeformAttention to make '
'the dimension of each attention head a power of 2 '
'which is more efficient in our CUDA implementation.')
self.im2col_step = im2col_step
self.embed_dims = embed_dims
self.num_levels = num_levels
self.num_heads = num_heads
self.num_points = num_points
self.sampling_offsets = nn.Linear(
embed_dims, num_heads * num_levels * num_points * 2)
self.attention_weights = nn.Linear(embed_dims,
num_heads * num_levels * num_points)
self.value_proj = nn.Linear(embed_dims, embed_dims)
self.output_proj = nn.Linear(embed_dims, embed_dims)
self.init_weight()
def init_weight(self):
"""Default initialization for Parameters of Module."""
constant_init(self.sampling_offsets, 0.)
thetas = torch.arange(
self.num_heads,
dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = (grid_init /
grid_init.abs().max(-1, keepdim=True)[0]).view(
self.num_heads, 1, 1,
2).repeat(1, self.num_levels, self.num_points, 1)
for i in range(self.num_points):
grid_init[:, :, i, :] *= i + 1
self.sampling_offsets.bias.data = grid_init.view(-1)
constant_init(self.attention_weights, val=0., bias=0.)
xavier_init(self.value_proj, distribution='uniform', bias=0.)
xavier_init(self.output_proj, distribution='uniform', bias=0.)
def forward(self,
query,
key,
value,
residual=None,
query_pos=None,
key_padding_mask=None,
reference_points=None,
spatial_shapes=None,
level_start_index=None,
**kwargs):
"""Forward Function of MultiScaleDeformAttention.
Args:
query (Tensor): Query of Transformer with shape
(num_query, bs, embed_dims).
key (Tensor): The key tensor with shape
`(num_key, bs, embed_dims)`.
value (Tensor): The value tensor with shape
`(num_key, bs, embed_dims)`.
residual (Tensor): The tensor used for addition, with the
same shape as `x`. Default None. If None, `x` will be used.
query_pos (Tensor): The positional encoding for `query`.
Default: None.
key_pos (Tensor): The positional encoding for `key`. Default
None.
reference_points (Tensor): The normalized reference
points with shape (bs, num_query, num_levels, 2),
all elements is range in [0, 1], top-left (0,0),
bottom-right (1, 1), including padding area.
or (N, Length_{query}, num_levels, 4), add
additional two dimensions is (w, h) to
form reference boxes.
key_padding_mask (Tensor): ByteTensor for `query`, with
shape [bs, num_key].
spatial_shapes (Tensor): Spatial shape of features in
different level. With shape (num_levels, 2),
last dimension represent (h, w).
level_start_index (Tensor): The start index of each level.
A tensor has shape (num_levels) and can be represented
as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
Returns:
Tensor: forwarded results with shape [num_query, bs, embed_dims].
"""
if key is None:
key = query
if value is None:
value = key
if residual is None:
inp_residual = query
if query_pos is not None:
query = query + query_pos
# change to (bs, num_query ,embed_dims)
query = query.permute(1, 0, 2)
value = value.permute(1, 0, 2)
bs, num_query, _ = query.shape
bs, num_key, _ = value.shape
assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_key
value = self.value_proj(value)
if key_padding_mask is not None:
value = value.masked_fill(key_padding_mask[..., None], 0.0)
value = value.view(bs, num_key, self.num_heads, -1)
sampling_offsets = self.sampling_offsets(query).view(
bs, num_query, self.num_heads, self.num_levels, self.num_points, 2)
attention_weights = self.attention_weights(query).view(
bs, num_query, self.num_heads, self.num_levels * self.num_points)
attention_weights = attention_weights.softmax(-1)
attention_weights = attention_weights.view(bs, num_query,
self.num_heads,
self.num_levels,
self.num_points)
if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack(
[spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
sampling_locations = reference_points[:, :, None, :, None, :] \
+ sampling_offsets \
/ offset_normalizer[None, None, None, :, None, :]
elif reference_points.shape[-1] == 4:
sampling_locations = reference_points[:, :, None, :, None, :2] \
+ sampling_offsets / self.num_points \
* reference_points[:, :, None, :, None, 2:] \
* 0.5
else:
raise ValueError(
f'Last dim of reference_points must be'
f' 2 or 4, but get {reference_points.shape[-1]} instead.')
if torch.cuda.is_available():
output = MultiScaleDeformableAttnFunction.apply(
value, spatial_shapes, level_start_index, sampling_locations,
attention_weights, self.im2col_step)
else:
output = multi_scale_deformable_attn_pytorch(
value, spatial_shapes, level_start_index, sampling_locations,
attention_weights, self.im2col_step)
output = self.output_proj(output).permute(1, 0, 2)
# (num_query, bs ,embed_dims)
return self.dropout(output) + inp_residual
class FFN(BaseModule):
"""Implements feed-forward networks (FFNs) with residual connection.

View File

@ -0,0 +1,807 @@
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from
*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#ifndef DEFORM_ATTN_CUDA_KERNEL
#define DEFORM_ATTN_CUDA_KERNEL
#include "common_cuda_helper.hpp"
#include "pytorch_cuda_helper.hpp"
const int CUDA_NUM_THREADS = 1024;
inline int GET_BLOCKS(const int N, const int num_threads) {
return (N + num_threads - 1) / num_threads;
}
template <typename scalar_t>
__device__ scalar_t ms_deform_attn_im2col_bilinear(
const scalar_t *&bottom_data, const int &height, const int &width,
const int &nheads, const int &channels, const scalar_t &h,
const scalar_t &w, const int &m, const int &c) {
const int h_low = floor(h);
const int w_low = floor(w);
const int h_high = h_low + 1;
const int w_high = w_low + 1;
const scalar_t lh = h - h_low;
const scalar_t lw = w - w_low;
const scalar_t hh = 1 - lh, hw = 1 - lw;
const int w_stride = nheads * channels;
const int h_stride = width * w_stride;
const int h_low_ptr_offset = h_low * h_stride;
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int w_low_ptr_offset = w_low * w_stride;
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
const int base_ptr = m * channels + c;
scalar_t v1 = 0;
if (h_low >= 0 && w_low >= 0) {
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
v1 = bottom_data[ptr1];
}
scalar_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1) {
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
v2 = bottom_data[ptr2];
}
scalar_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0) {
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
v3 = bottom_data[ptr3];
}
scalar_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1) {
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
v4 = bottom_data[ptr4];
}
const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
template <typename scalar_t>
__device__ void ms_deform_attn_col2im_bilinear(
const scalar_t *&bottom_data, const int &height, const int &width,
const int &nheads, const int &channels, const scalar_t &h,
const scalar_t &w, const int &m, const int &c, const scalar_t &top_grad,
const scalar_t &attn_weight, scalar_t *&grad_value,
scalar_t *grad_sampling_loc, scalar_t *grad_attn_weight) {
const int h_low = floor(h);
const int w_low = floor(w);
const int h_high = h_low + 1;
const int w_high = w_low + 1;
const scalar_t lh = h - h_low;
const scalar_t lw = w - w_low;
const scalar_t hh = 1 - lh, hw = 1 - lw;
const int w_stride = nheads * channels;
const int h_stride = width * w_stride;
const int h_low_ptr_offset = h_low * h_stride;
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int w_low_ptr_offset = w_low * w_stride;
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
const int base_ptr = m * channels + c;
const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
const scalar_t top_grad_value = top_grad * attn_weight;
scalar_t grad_h_weight = 0, grad_w_weight = 0;
scalar_t v1 = 0;
if (h_low >= 0 && w_low >= 0) {
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
v1 = bottom_data[ptr1];
grad_h_weight -= hw * v1;
grad_w_weight -= hh * v1;
atomicAdd(grad_value + ptr1, w1 * top_grad_value);
}
scalar_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1) {
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
v2 = bottom_data[ptr2];
grad_h_weight -= lw * v2;
grad_w_weight += hh * v2;
atomicAdd(grad_value + ptr2, w2 * top_grad_value);
}
scalar_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0) {
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
v3 = bottom_data[ptr3];
grad_h_weight += hw * v3;
grad_w_weight -= lh * v3;
atomicAdd(grad_value + ptr3, w3 * top_grad_value);
}
scalar_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1) {
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
v4 = bottom_data[ptr4];
grad_h_weight += lw * v4;
grad_w_weight += lh * v4;
atomicAdd(grad_value + ptr4, w4 * top_grad_value);
}
const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
*grad_attn_weight = top_grad * val;
*grad_sampling_loc = width * grad_w_weight * top_grad_value;
*(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
}
template <typename scalar_t>
__device__ void ms_deform_attn_col2im_bilinear_gm(
const scalar_t *&bottom_data, const int &height, const int &width,
const int &nheads, const int &channels, const scalar_t &h,
const scalar_t &w, const int &m, const int &c, const scalar_t &top_grad,
const scalar_t &attn_weight, scalar_t *&grad_value,
scalar_t *grad_sampling_loc, scalar_t *grad_attn_weight) {
const int h_low = floor(h);
const int w_low = floor(w);
const int h_high = h_low + 1;
const int w_high = w_low + 1;
const scalar_t lh = h - h_low;
const scalar_t lw = w - w_low;
const scalar_t hh = 1 - lh, hw = 1 - lw;
const int w_stride = nheads * channels;
const int h_stride = width * w_stride;
const int h_low_ptr_offset = h_low * h_stride;
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int w_low_ptr_offset = w_low * w_stride;
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
const int base_ptr = m * channels + c;
const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
const scalar_t top_grad_value = top_grad * attn_weight;
scalar_t grad_h_weight = 0, grad_w_weight = 0;
scalar_t v1 = 0;
if (h_low >= 0 && w_low >= 0) {
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
v1 = bottom_data[ptr1];
grad_h_weight -= hw * v1;
grad_w_weight -= hh * v1;
atomicAdd(grad_value + ptr1, w1 * top_grad_value);
}
scalar_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1) {
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
v2 = bottom_data[ptr2];
grad_h_weight -= lw * v2;
grad_w_weight += hh * v2;
atomicAdd(grad_value + ptr2, w2 * top_grad_value);
}
scalar_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0) {
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
v3 = bottom_data[ptr3];
grad_h_weight += hw * v3;
grad_w_weight -= lh * v3;
atomicAdd(grad_value + ptr3, w3 * top_grad_value);
}
scalar_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1) {
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
v4 = bottom_data[ptr4];
grad_h_weight += lw * v4;
grad_w_weight += lh * v4;
atomicAdd(grad_value + ptr4, w4 * top_grad_value);
}
const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
atomicAdd(grad_attn_weight, top_grad * val);
atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
}
template <typename scalar_t>
__global__ void ms_deformable_im2col_gpu_kernel(
const int n, const scalar_t *data_value, const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index, const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight, const int batch_size,
const int spatial_size, const int num_heads, const int channels,
const int num_levels, const int num_query, const int num_point,
scalar_t *data_col) {
CUDA_1D_KERNEL_LOOP(index, n) {
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
scalar_t *data_col_ptr = data_col + index;
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
scalar_t col = 0;
for (int l_col = 0; l_col < num_levels; ++l_col) {
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const scalar_t *data_value_ptr =
data_value +
(data_value_ptr_init_offset + level_start_id * qid_stride);
for (int p_col = 0; p_col < num_point; ++p_col) {
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h,
spatial_w, num_heads, channels,
h_im, w_im, m_col, c_col) *
weight;
}
data_weight_ptr += 1;
data_loc_w_ptr += 2;
}
}
*data_col_ptr = col;
}
}
template <typename scalar_t, unsigned int blockSize>
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(
const int n, const scalar_t *grad_col, const scalar_t *data_value,
const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight,
const int batch_size, const int spatial_size, const int num_heads,
const int channels, const int num_levels, const int num_query,
const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight) {
CUDA_1D_KERNEL_LOOP(index, n) {
__shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
__shared__ scalar_t cache_grad_attn_weight[blockSize];
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col = 0; l_col < num_levels; ++l_col) {
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset =
data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col = 0; p_col < num_point; ++p_col) {
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
*(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight + threadIdx.x) = 0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc + (threadIdx.x << 1),
cache_grad_attn_weight + threadIdx.x);
}
__syncthreads();
if (tid == 0) {
scalar_t _grad_w = cache_grad_sampling_loc[0],
_grad_h = cache_grad_sampling_loc[1],
_grad_a = cache_grad_attn_weight[0];
int sid = 2;
for (unsigned int tid = 1; tid < blockSize; ++tid) {
_grad_w += cache_grad_sampling_loc[sid];
_grad_h += cache_grad_sampling_loc[sid + 1];
_grad_a += cache_grad_attn_weight[tid];
sid += 2;
}
*grad_sampling_loc = _grad_w;
*(grad_sampling_loc + 1) = _grad_h;
*grad_attn_weight = _grad_a;
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
template <typename scalar_t, unsigned int blockSize>
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(
const int n, const scalar_t *grad_col, const scalar_t *data_value,
const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight,
const int batch_size, const int spatial_size, const int num_heads,
const int channels, const int num_levels, const int num_query,
const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight) {
CUDA_1D_KERNEL_LOOP(index, n) {
__shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
__shared__ scalar_t cache_grad_attn_weight[blockSize];
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col = 0; l_col < num_levels; ++l_col) {
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset =
data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col = 0; p_col < num_point; ++p_col) {
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
*(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight + threadIdx.x) = 0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc + (threadIdx.x << 1),
cache_grad_attn_weight + threadIdx.x);
}
__syncthreads();
for (unsigned int s = blockSize / 2; s > 0; s >>= 1) {
if (tid < s) {
const unsigned int xid1 = tid << 1;
const unsigned int xid2 = (tid + s) << 1;
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
cache_grad_sampling_loc[xid1 + 1] +=
cache_grad_sampling_loc[xid2 + 1];
}
__syncthreads();
}
if (tid == 0) {
*grad_sampling_loc = cache_grad_sampling_loc[0];
*(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
*grad_attn_weight = cache_grad_attn_weight[0];
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(
const int n, const scalar_t *grad_col, const scalar_t *data_value,
const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight,
const int batch_size, const int spatial_size, const int num_heads,
const int channels, const int num_levels, const int num_query,
const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight) {
CUDA_1D_KERNEL_LOOP(index, n) {
extern __shared__ int _s[];
scalar_t *cache_grad_sampling_loc = reinterpret_cast<scalar_t *>(_s);
scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col = 0; l_col < num_levels; ++l_col) {
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset =
data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col = 0; p_col < num_point; ++p_col) {
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
*(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight + threadIdx.x) = 0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc + (threadIdx.x << 1),
cache_grad_attn_weight + threadIdx.x);
}
__syncthreads();
if (tid == 0) {
scalar_t _grad_w = cache_grad_sampling_loc[0],
_grad_h = cache_grad_sampling_loc[1],
_grad_a = cache_grad_attn_weight[0];
int sid = 2;
for (unsigned int tid = 1; tid < blockDim.x; ++tid) {
_grad_w += cache_grad_sampling_loc[sid];
_grad_h += cache_grad_sampling_loc[sid + 1];
_grad_a += cache_grad_attn_weight[tid];
sid += 2;
}
*grad_sampling_loc = _grad_w;
*(grad_sampling_loc + 1) = _grad_h;
*grad_attn_weight = _grad_a;
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(
const int n, const scalar_t *grad_col, const scalar_t *data_value,
const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight,
const int batch_size, const int spatial_size, const int num_heads,
const int channels, const int num_levels, const int num_query,
const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight) {
CUDA_1D_KERNEL_LOOP(index, n) {
extern __shared__ int _s[];
scalar_t *cache_grad_sampling_loc = reinterpret_cast<scalar_t *>(_s);
scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col = 0; l_col < num_levels; ++l_col) {
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset =
data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col = 0; p_col < num_point; ++p_col) {
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
*(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight + threadIdx.x) = 0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc + (threadIdx.x << 1),
cache_grad_attn_weight + threadIdx.x);
}
__syncthreads();
for (unsigned int s = blockDim.x / 2, spre = blockDim.x; s > 0;
s >>= 1, spre >>= 1) {
if (tid < s) {
const unsigned int xid1 = tid << 1;
const unsigned int xid2 = (tid + s) << 1;
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
cache_grad_sampling_loc[xid1 + 1] +=
cache_grad_sampling_loc[xid2 + 1];
if (tid + (s << 1) < spre) {
cache_grad_attn_weight[tid] +=
cache_grad_attn_weight[tid + (s << 1)];
cache_grad_sampling_loc[xid1] +=
cache_grad_sampling_loc[xid2 + (s << 1)];
cache_grad_sampling_loc[xid1 + 1] +=
cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
}
}
__syncthreads();
}
if (tid == 0) {
*grad_sampling_loc = cache_grad_sampling_loc[0];
*(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
*grad_attn_weight = cache_grad_attn_weight[0];
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(
const int n, const scalar_t *grad_col, const scalar_t *data_value,
const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight,
const int batch_size, const int spatial_size, const int num_heads,
const int channels, const int num_levels, const int num_query,
const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight) {
CUDA_1D_KERNEL_LOOP(index, n) {
extern __shared__ int _s[];
scalar_t *cache_grad_sampling_loc = reinterpret_cast<scalar_t *>(_s);
scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col = 0; l_col < num_levels; ++l_col) {
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset =
data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col = 0; p_col < num_point; ++p_col) {
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
*(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight + threadIdx.x) = 0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc + (threadIdx.x << 1),
cache_grad_attn_weight + threadIdx.x);
}
__syncthreads();
for (unsigned int s = blockDim.x / 2, spre = blockDim.x; s > 0;
s >>= 1, spre >>= 1) {
if (tid < s) {
const unsigned int xid1 = tid << 1;
const unsigned int xid2 = (tid + s) << 1;
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
cache_grad_sampling_loc[xid1 + 1] +=
cache_grad_sampling_loc[xid2 + 1];
if (tid + (s << 1) < spre) {
cache_grad_attn_weight[tid] +=
cache_grad_attn_weight[tid + (s << 1)];
cache_grad_sampling_loc[xid1] +=
cache_grad_sampling_loc[xid2 + (s << 1)];
cache_grad_sampling_loc[xid1 + 1] +=
cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
}
}
__syncthreads();
}
if (tid == 0) {
atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_gm(
const int n, const scalar_t *grad_col, const scalar_t *data_value,
const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight,
const int batch_size, const int spatial_size, const int num_heads,
const int channels, const int num_levels, const int num_query,
const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight) {
CUDA_1D_KERNEL_LOOP(index, n) {
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col = 0; l_col < num_levels; ++l_col) {
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset =
data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col = 0; p_col < num_point; ++p_col) {
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
ms_deform_attn_col2im_bilinear_gm(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
grad_sampling_loc, grad_attn_weight);
}
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
#endif // DEFORM_ATTN_CUDA_KERNEL

View File

@ -0,0 +1,74 @@
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from
*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include "pytorch_cpp_helper.hpp"
#ifdef MMCV_WITH_CUDA
Tensor ms_deform_attn_cuda_forward(const Tensor &value,
const Tensor &spatial_shapes,
const Tensor &level_start_index,
const Tensor &sampling_loc,
const Tensor &attn_weight,
const int im2col_step);
std::vector<Tensor> ms_deform_attn_cuda_backward(
const Tensor &value, const Tensor &spatial_shapes,
const Tensor &level_start_index, const Tensor &sampling_loc,
const Tensor &attn_weight, const Tensor &grad_output,
const int im2col_step);
#endif
Tensor ms_deform_attn_forward(const Tensor &value, const Tensor &spatial_shapes,
const Tensor &level_start_index,
const Tensor &sampling_loc,
const Tensor &attn_weight,
const int im2col_step) {
if (value.type().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(value)
CHECK_CUDA_INPUT(spatial_shapes)
CHECK_CUDA_INPUT(level_start_index)
CHECK_CUDA_INPUT(sampling_loc)
CHECK_CUDA_INPUT(attn_weight)
return ms_deform_attn_cuda_forward(value, spatial_shapes, level_start_index,
sampling_loc, attn_weight, im2col_step);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
}
std::vector<Tensor> ms_deform_attn_backward(const Tensor &value,
const Tensor &spatial_shapes,
const Tensor &level_start_index,
const Tensor &sampling_loc,
const Tensor &attn_weight,
const Tensor &grad_output,
const int im2col_step) {
if (value.type().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(value)
CHECK_CUDA_INPUT(spatial_shapes)
CHECK_CUDA_INPUT(level_start_index)
CHECK_CUDA_INPUT(sampling_loc)
CHECK_CUDA_INPUT(attn_weight)
CHECK_CUDA_INPUT(grad_output)
return ms_deform_attn_cuda_backward(value, spatial_shapes,
level_start_index, sampling_loc,
attn_weight, grad_output, im2col_step);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
}

View File

@ -0,0 +1,365 @@
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from
*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <THC/THCAtomics.cuh>
#include <ms_deform_attn_cuda_kernel.cuh>
#include <vector>
template <typename scalar_t>
void ms_deformable_im2col_cuda(cudaStream_t stream, const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size, const int spatial_size,
const int num_heads, const int channels,
const int num_levels, const int num_query,
const int num_point, scalar_t *data_col) {
const int num_kernels = batch_size * num_query * num_heads * channels;
const int num_actual_kernels = batch_size * num_query * num_heads * channels;
const int num_threads = CUDA_NUM_THREADS;
ms_deformable_im2col_gpu_kernel<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
num_kernels, data_value, data_spatial_shapes, data_level_start_index,
data_sampling_loc, data_attn_weight, batch_size, spatial_size,
num_heads, channels, num_levels, num_query, num_point, data_col);
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
}
}
template <typename scalar_t>
void ms_deformable_col2im_cuda(
cudaStream_t stream, const scalar_t *grad_col, const scalar_t *data_value,
const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight,
const int batch_size, const int spatial_size, const int num_heads,
const int channels, const int num_levels, const int num_query,
const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight) {
const int num_threads =
(channels > CUDA_NUM_THREADS) ? CUDA_NUM_THREADS : channels;
const int num_kernels = batch_size * num_query * num_heads * channels;
const int num_actual_kernels = batch_size * num_query * num_heads * channels;
if (channels > 1024) {
if ((channels & 1023) == 0) {
ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
num_threads * 3 * sizeof(scalar_t), stream>>>(
num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc, data_attn_weight,
batch_size, spatial_size, num_heads, channels, num_levels,
num_query, num_point, grad_value, grad_sampling_loc,
grad_attn_weight);
} else {
ms_deformable_col2im_gpu_kernel_gm<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
}
} else {
switch (channels) {
case 1:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t,
1>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
case 2:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t,
2>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
case 4:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t,
4>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
case 8:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t,
8>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
case 16:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t,
16>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
case 32:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t,
32>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
case 64:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t,
64>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
case 128:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t,
128>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
case 256:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t,
256>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
case 512:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t,
512>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
case 1024:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t,
1024>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
default:
if (channels < 64) {
ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
num_threads * 3 * sizeof(scalar_t), stream>>>(
num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc, data_attn_weight,
batch_size, spatial_size, num_heads, channels, num_levels,
num_query, num_point, grad_value, grad_sampling_loc,
grad_attn_weight);
} else {
ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
num_threads * 3 * sizeof(scalar_t), stream>>>(
num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc, data_attn_weight,
batch_size, spatial_size, num_heads, channels, num_levels,
num_query, num_point, grad_value, grad_sampling_loc,
grad_attn_weight);
}
}
}
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
}
}
at::Tensor ms_deform_attn_cuda_forward(const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step) {
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(),
"spatial_shapes tensor has to be contiguous");
AT_ASSERTM(level_start_index.is_contiguous(),
"level_start_index tensor has to be contiguous");
AT_ASSERTM(sampling_loc.is_contiguous(),
"sampling_loc tensor has to be contiguous");
AT_ASSERTM(attn_weight.is_contiguous(),
"attn_weight tensor has to be contiguous");
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(spatial_shapes.type().is_cuda(),
"spatial_shapes must be a CUDA tensor");
AT_ASSERTM(level_start_index.type().is_cuda(),
"level_start_index must be a CUDA tensor");
AT_ASSERTM(sampling_loc.type().is_cuda(),
"sampling_loc must be a CUDA tensor");
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
const int batch = value.size(0);
const int spatial_size = value.size(1);
const int num_heads = value.size(2);
const int channels = value.size(3);
const int num_levels = spatial_shapes.size(0);
const int num_query = sampling_loc.size(1);
const int num_point = sampling_loc.size(4);
const int im2col_step_ = std::min(batch, im2col_step);
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)",
batch, im2col_step_);
auto output =
at::zeros({batch, num_query, num_heads, channels}, value.options());
const int batch_n = im2col_step_;
auto output_n = output.view(
{batch / im2col_step_, batch_n, num_query, num_heads, channels});
auto per_value_size = spatial_size * num_heads * channels;
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
for (int n = 0; n < batch / im2col_step_; ++n) {
auto columns = output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES(
value.type(), "ms_deform_attn_forward_cuda", ([&] {
ms_deformable_im2col_cuda(
at::cuda::getCurrentCUDAStream(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data<int64_t>(), level_start_index.data<int64_t>(),
sampling_loc.data<scalar_t>() +
n * im2col_step_ * per_sample_loc_size,
attn_weight.data<scalar_t>() +
n * im2col_step_ * per_attn_weight_size,
batch_n, spatial_size, num_heads, channels, num_levels, num_query,
num_point, columns.data<scalar_t>());
}));
}
output = output.view({batch, num_query, num_heads * channels});
return output;
}
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
const at::Tensor &value, const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index, const at::Tensor &sampling_loc,
const at::Tensor &attn_weight, const at::Tensor &grad_output,
const int im2col_step) {
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(),
"spatial_shapes tensor has to be contiguous");
AT_ASSERTM(level_start_index.is_contiguous(),
"level_start_index tensor has to be contiguous");
AT_ASSERTM(sampling_loc.is_contiguous(),
"sampling_loc tensor has to be contiguous");
AT_ASSERTM(attn_weight.is_contiguous(),
"attn_weight tensor has to be contiguous");
AT_ASSERTM(grad_output.is_contiguous(),
"grad_output tensor has to be contiguous");
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(spatial_shapes.type().is_cuda(),
"spatial_shapes must be a CUDA tensor");
AT_ASSERTM(level_start_index.type().is_cuda(),
"level_start_index must be a CUDA tensor");
AT_ASSERTM(sampling_loc.type().is_cuda(),
"sampling_loc must be a CUDA tensor");
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
const int batch = value.size(0);
const int spatial_size = value.size(1);
const int num_heads = value.size(2);
const int channels = value.size(3);
const int num_levels = spatial_shapes.size(0);
const int num_query = sampling_loc.size(1);
const int num_point = sampling_loc.size(4);
const int im2col_step_ = std::min(batch, im2col_step);
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)",
batch, im2col_step_);
auto grad_value = at::zeros_like(value);
auto grad_sampling_loc = at::zeros_like(sampling_loc);
auto grad_attn_weight = at::zeros_like(attn_weight);
const int batch_n = im2col_step_;
auto per_value_size = spatial_size * num_heads * channels;
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
auto grad_output_n = grad_output.view(
{batch / im2col_step_, batch_n, num_query, num_heads, channels});
for (int n = 0; n < batch / im2col_step_; ++n) {
auto grad_output_g = grad_output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES(
value.type(), "ms_deform_attn_backward_cuda", ([&] {
ms_deformable_col2im_cuda(
at::cuda::getCurrentCUDAStream(), grad_output_g.data<scalar_t>(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data<int64_t>(), level_start_index.data<int64_t>(),
sampling_loc.data<scalar_t>() +
n * im2col_step_ * per_sample_loc_size,
attn_weight.data<scalar_t>() +
n * im2col_step_ * per_attn_weight_size,
batch_n, spatial_size, num_heads, channels, num_levels, num_query,
num_point,
grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
grad_sampling_loc.data<scalar_t>() +
n * im2col_step_ * per_sample_loc_size,
grad_attn_weight.data<scalar_t>() +
n * im2col_step_ * per_attn_weight_size);
}));
}
return {grad_value, grad_sampling_loc, grad_attn_weight};
}

View File

@ -92,6 +92,19 @@ void modulated_deform_conv_backward(
int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
const bool with_bias);
Tensor ms_deform_attn_forward(const Tensor &value, const Tensor &spatial_shapes,
const Tensor &level_start_index,
const Tensor &sampling_loc,
const Tensor &attn_weight, const int im2col_step);
std::vector<Tensor> ms_deform_attn_backward(const Tensor &value,
const Tensor &spatial_shapes,
const Tensor &level_start_index,
const Tensor &sampling_loc,
const Tensor &attn_weight,
const Tensor &grad_output,
const int im2col_step);
Tensor nms(Tensor boxes, Tensor scores, float iou_threshold, int offset);
Tensor softnms(Tensor boxes, Tensor scores, Tensor dets, float iou_threshold,
@ -182,12 +195,12 @@ Tensor nms_rotated(const Tensor dets, const Tensor scores, const Tensor order,
const Tensor dets_sorted, const float iou_threshold,
const int multi_label);
Tensor upfirdn2d(const Tensor& input, const Tensor& kernel, int up_x, int up_y,
Tensor upfirdn2d(const Tensor &input, const Tensor &kernel, int up_x, int up_y,
int down_x, int down_y, int pad_x0, int pad_x1, int pad_y0,
int pad_y1);
Tensor fused_bias_leakyrelu(const Tensor& input, const Tensor& bias,
const Tensor& refer, int act, int grad, float alpha,
Tensor fused_bias_leakyrelu(const Tensor &input, const Tensor &bias,
const Tensor &refer, int act, int grad, float alpha,
float scale);
void roi_align_rotated_forward(Tensor input, Tensor rois, Tensor output,
@ -401,4 +414,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("grad_input"), py::arg("pooled_height"),
py::arg("pooled_width"), py::arg("spatial_scale"),
py::arg("sample_num"), py::arg("aligned"), py::arg("clockwise"));
m.def("ms_deform_attn_forward", &ms_deform_attn_forward,
"forward function of multi-scale deformable attention",
py::arg("value"), py::arg("value_spatial_shapes"),
py::arg("value_level_start_index"), py::arg("sampling_locations"),
py::arg("attention_weights"), py::arg("im2col_step"));
m.def("ms_deform_attn_backward", &ms_deform_attn_backward,
"backward function of multi-scale deformable attention",
py::arg("value"), py::arg("value_spatial_shapes"),
py::arg("value_level_start_index"), py::arg("sampling_locations"),
py::arg("attention_weights"), py::arg("grad_output"),
py::arg("im2col_step"));
}

View File

@ -0,0 +1,134 @@
import torch
import torch.nn.functional as F
from torch.autograd.function import Function, once_differentiable
from ..utils import ext_loader
ext_module = ext_loader.load_ext(
'_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
class MultiScaleDeformableAttnFunction(Function):
@staticmethod
def forward(ctx, value, value_spatial_shapes, value_level_start_index,
sampling_locations, attention_weights, im2col_step):
"""GPU version of multi-scale deformable attention.
Args:
value (Tensor): The value has shape
(bs, num_keys, mum_heads, embed_dims//num_heads)
value_spatial_shapes (Tensor): Spatial shape of
each feature map, has shape (num_levels, 2),
last dimension 2 represent (h, w)
sampling_locations (Tensor): The location of sampling points,
has shape
(bs ,num_queries, num_heads, num_levels, num_points, 2),
the last dimension 2 represent (x, y).
attention_weights (Tensor): The weight of sampling points used
when calculate the attention, has shape
(bs ,num_queries, num_heads, num_levels, num_points),
im2col_step (Tensor): The step used in image to column.
Returns:
Tensor: has shape (bs, num_queries, embed_dims)
"""
ctx.im2col_step = im2col_step
output = ext_module.ms_deform_attn_forward(value, value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
ctx.im2col_step)
ctx.save_for_backward(value, value_spatial_shapes,
value_level_start_index, sampling_locations,
attention_weights)
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
"""GPU version of backward function.
Args:
grad_output (Tensor): Gradient
of output tensor of forward.
Returns:
Tuple[Tensor]: Gradient
of input tensors in forward.
"""
value, value_spatial_shapes, value_level_start_index,\
sampling_locations, attention_weights = ctx.saved_tensors
grad_value, grad_sampling_loc, grad_attn_weight = \
ext_module.ms_deform_attn_backward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
grad_output,
ctx.im2col_step)
return grad_value, None, None, \
grad_sampling_loc, grad_attn_weight, None
def multi_scale_deformable_attn_pytorch(value, value_spatial_shapes,
sampling_locations, attention_weights):
"""CPU version of multi-scale deformable attention.
Args:
value (Tensor): The value has shape
(bs, num_keys, mum_heads, embed_dims//num_heads)
value_spatial_shapes (Tensor): Spatial shape of
each feature map, has shape (num_levels, 2),
last dimension 2 represent (h, w)
sampling_locations (Tensor): The location of sampling points,
has shape
(bs ,num_queries, num_heads, num_levels, num_points, 2),
the last dimension 2 represent (x, y).
attention_weights (Tensor): The weight of sampling points used
when calculate the attention, has shape
(bs ,num_queries, num_heads, num_levels, num_points),
Returns:
Tensor: has shape (bs, num_queries, embed_dims)
"""
bs, _, num_heads, embed_dims = value.shape
_, num_queries, num_heads, num_levels, num_points, _ =\
sampling_locations.shape
value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes],
dim=1)
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for level, (H_, W_) in enumerate(value_spatial_shapes):
# bs, H_*W_, num_heads, embed_dims ->
# bs, H_*W_, num_heads*embed_dims ->
# bs, num_heads*embed_dims, H_*W_ ->
# bs*num_heads, embed_dims, H_, W_
value_l_ = value_list[level].flatten(2).transpose(1, 2).reshape(
bs * num_heads, embed_dims, H_, W_)
# bs, num_queries, num_heads, num_points, 2 ->
# bs, num_heads, num_queries, num_points, 2 ->
# bs*num_heads, num_queries, num_points, 2
sampling_grid_l_ = sampling_grids[:, :, :,
level].transpose(1, 2).flatten(0, 1)
# bs*num_heads, embed_dims, num_queries, num_points
sampling_value_l_ = F.grid_sample(
value_l_,
sampling_grid_l_,
mode='bilinear',
padding_mode='zeros',
align_corners=False)
sampling_value_list.append(sampling_value_l_)
# (bs, num_queries, num_heads, num_levels, num_points) ->
# (bs, num_heads, num_queries, num_levels, num_points) ->
# (bs, num_heads, 1, num_queries, num_levels*num_points)
attention_weights = attention_weights.transpose(1, 2).reshape(
bs * num_heads, 1, num_queries, num_levels * num_points)
output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) *
attention_weights).sum(-1).view(bs, num_heads * embed_dims,
num_queries)
return output.transpose(1, 2).contiguous()

View File

@ -0,0 +1,125 @@
import pytest
import torch
from torch.autograd import gradcheck
from mmcv.ops.multi_scale_deform_attn import (
MultiScaleDeformableAttnFunction, multi_scale_deformable_attn_pytorch)
def test_forward_multi_scale_deformable_attn_pytorch():
N, M, D = 1, 2, 2
Lq, L, P = 2, 2, 2
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long)
S = sum([(H * W).item() for H, W in shapes])
torch.manual_seed(3)
value = torch.rand(N, S, M, D) * 0.01
sampling_locations = torch.rand(N, Lq, M, L, P, 2)
attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
attention_weights /= attention_weights.sum(
-1, keepdim=True).sum(
-2, keepdim=True)
multi_scale_deformable_attn_pytorch(value.double(), shapes,
sampling_locations.double(),
attention_weights.double()).detach()
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_forward_equal_with_pytorch_double():
N, M, D = 1, 2, 2
Lq, L, P = 2, 2, 2
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
level_start_index = torch.cat((shapes.new_zeros(
(1, )), shapes.prod(1).cumsum(0)[:-1]))
S = sum([(H * W).item() for H, W in shapes])
torch.manual_seed(3)
value = torch.rand(N, S, M, D).cuda() * 0.01
sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
attention_weights /= attention_weights.sum(
-1, keepdim=True).sum(
-2, keepdim=True)
im2col_step = 2
output_pytorch = multi_scale_deformable_attn_pytorch(
value.double(), shapes, sampling_locations.double(),
attention_weights.double()).detach().cpu()
output_cuda = MultiScaleDeformableAttnFunction.apply(
value.double(), shapes, level_start_index, sampling_locations.double(),
attention_weights.double(), im2col_step).detach().cpu()
assert torch.allclose(output_cuda, output_pytorch)
max_abs_err = (output_cuda - output_pytorch).abs().max()
max_rel_err = ((output_cuda - output_pytorch).abs() /
output_pytorch.abs()).max()
assert max_abs_err < 1e-18
assert max_rel_err < 1e-15
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_forward_equal_with_pytorch_float():
N, M, D = 1, 2, 2
Lq, L, P = 2, 2, 2
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
level_start_index = torch.cat((shapes.new_zeros(
(1, )), shapes.prod(1).cumsum(0)[:-1]))
S = sum([(H * W).item() for H, W in shapes])
torch.manual_seed(3)
value = torch.rand(N, S, M, D).cuda() * 0.01
sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
attention_weights /= attention_weights.sum(
-1, keepdim=True).sum(
-2, keepdim=True)
im2col_step = 2
output_pytorch = multi_scale_deformable_attn_pytorch(
value, shapes, sampling_locations, attention_weights).detach().cpu()
output_cuda = MultiScaleDeformableAttnFunction.apply(
value, shapes, level_start_index, sampling_locations,
attention_weights, im2col_step).detach().cpu()
assert torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
max_abs_err = (output_cuda - output_pytorch).abs().max()
max_rel_err = ((output_cuda - output_pytorch).abs() /
output_pytorch.abs()).max()
assert max_abs_err < 1e-9
assert max_rel_err < 1e-6
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
@pytest.mark.parametrize('channels', [4, 30, 32, 64, 71, 1025, 2048, 3096])
def test_gradient_numerical(channels,
grad_value=True,
grad_sampling_loc=True,
grad_attn_weight=True):
N, M, _ = 1, 2, 2
Lq, L, P = 2, 2, 2
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
level_start_index = torch.cat((shapes.new_zeros(
(1, )), shapes.prod(1).cumsum(0)[:-1]))
S = sum([(H * W).item() for H, W in shapes])
value = torch.rand(N, S, M, channels).cuda() * 0.01
sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
attention_weights /= attention_weights.sum(
-1, keepdim=True).sum(
-2, keepdim=True)
im2col_step = 2
func = MultiScaleDeformableAttnFunction.apply
value.requires_grad = grad_value
sampling_locations.requires_grad = grad_sampling_loc
attention_weights.requires_grad = grad_attn_weight
assert gradcheck(
func,
(value.double(), shapes, level_start_index,
sampling_locations.double(), attention_weights.double(), im2col_step))