mirror of https://github.com/open-mmlab/mmcv.git
[Feature] Support MsDeformAttnForward with fast kernel (#3157)
parent
055a056c21
commit
8c21bf6b77
|
@ -0,0 +1,23 @@
|
|||
/*************************************************************************
|
||||
* Copyright (C) [2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef MS_DEFORM_ATTN_FORWARD_FAST_MLU_KERNEL_HPP_
|
||||
#define MS_DEFORM_ATTN_FORWARD_FAST_MLU_KERNEL_HPP_
|
||||
// Declaration of the entry point for the "fast" multi-scale deformable
// attention forward kernel.
//
// Only the prototype is visible here; the notes below are read off the
// parameter names and types and should be confirmed against the definition:
//   - k_dim / k_type / queue: presumably the CNRT launch dimension, function
//     type and queue used to enqueue the kernel.
//   - d_type: element data type selector for the tensor buffers.
//   - data_*_gdram: untyped (char*) input buffers; data_col_gdram is the only
//     non-const pointer, so it is presumably the output.
//   - batch_size .. num_points: problem dimensions.
void KernelMsDeformAttnForwardFast(
|
||||
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
|
||||
const cnrtDataType_t d_type, const char *data_value_gdram,
|
||||
const char *data_spatial_shapes_gdram,
|
||||
const char *data_level_start_index_gdram,
|
||||
const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram,
|
||||
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
|
||||
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
|
||||
const int32_t num_points, char *data_col_gdram);
|
||||
#endif // MS_DEFORM_ATTN_FORWARD_FAST_MLU_KERNEL_HPP_
|
|
@ -0,0 +1,823 @@
|
|||
/*************************************************************************
|
||||
* Copyright (C) [2024] by Cambricon, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a
|
||||
* copy of this software and associated documentation files (the
|
||||
* "Software"), to deal in the Software without restriction, including
|
||||
* without limitation the rights to use, copy, modify, merge, publish,
|
||||
* distribute, sublicense, and/or sell copies of the Software, and to
|
||||
* permit persons to whom the Software is furnished to do so, subject to
|
||||
* the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included
|
||||
* in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#include "common_mlu_helper.hpp"
|
||||
#include "ms_deform_attn_fast_mlu_kernel.hpp"
|
||||
|
||||
// Bytes of NRAM that are NOT handed to nram_buffer below.
#define NRAM_REMAIN_SIZE (48 * 1024)
|
||||
// Size in bytes of nram_buffer. __MLU_NRAM_SIZE__ is multiplied by 1024, so it
// is presumably expressed in KB (compiler-provided; confirm).
#define NRAM_AVALIABLE_SIZE (__MLU_NRAM_SIZE__ * 1024 - NRAM_REMAIN_SIZE)
|
||||
|
||||
// Bytes of SRAM that are NOT handed to sram_buffer below.
#define SRAM_REMAIN_SIZE (32 * 1024)
|
||||
// Size in bytes of sram_buffer (same KB assumption as for NRAM).
#define SRAM_AVALIABLE_SIZE (__MLU_SRAM_SIZE__ * 1024 - SRAM_REMAIN_SIZE)
|
||||
// sram_buffer size minus a 128-byte margin; presumably the budget for cached
// value data (not referenced in the part of the file shown here).
#define SRAM_FOR_VALUE_SIZE (SRAM_AVALIABLE_SIZE - 128)
|
||||
|
||||
// Upper bound on the number of rows moved by one strided __memcpy in
// loadDataValueGdram2Sram; larger transfers are split into chunks of this size.
#define MAX_MEMCPY_SEGNUM 65536
|
||||
|
||||
// Single NRAM arena; memPolicyCommon carves every NRAM tensor out of it.
__nram__ char nram_buffer[NRAM_AVALIABLE_SIZE];
|
||||
// Single SRAM arena; memPolicyCommon exposes its start as value_sram.
__mlu_shared__ char sram_buffer[SRAM_AVALIABLE_SIZE];
|
||||
|
||||
template <typename T>
|
||||
__mlu_func__ inline T __mluop_min(T a, T b) {
|
||||
return a < b ? a : b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__mlu_func__ inline T __mluop_max(T a, T b) {
|
||||
return a > b ? a : b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__mlu_func__ void __mluop_floor(T* dst_ram, T* src_ram, int size) {
|
||||
if (sizeof(T) == sizeof(float)) {
|
||||
int16* mid = (int16*)(dst_ram + size / 2);
|
||||
__bang_float2int16_dn(mid, (float*)src_ram, size, 0);
|
||||
__bang_int162float((float*)dst_ram, (int16_t*)mid, size, 0);
|
||||
} else {
|
||||
__bang_half2int16_dn((int16_t*)dst_ram, (half*)src_ram, size, 0);
|
||||
__bang_int162half((half*)dst_ram, (int16_t*)dst_ram, size, 0);
|
||||
}
|
||||
}
|
||||
|
||||
// Converts the per-level integer shape/offset tables to float in place and
// expands each level's value num_points times into the broadcast tables.
//
//   spatial_offset_bd_nram  out, (num_levels, num_points): level start offset
//   spatial_h_bd_nram       out, (num_levels, num_points): element 0 of each
//                           shape pair
//   spatial_w_bd_nram       out, (num_levels, num_points): element 1 of each
//                           shape pair
//   spatial_shapes_nram     in/out, (num_levels, 2): int32 on entry, holds
//                           float bit patterns on return
//   spatial_offset_nram     in/out, (num_levels): int32 on entry, holds float
//                           bit patterns on return
__mlu_func__ void broadcastSpatialHW(
    float* spatial_offset_bd_nram,  // (num_levels, num_points)
    float* spatial_h_bd_nram,       // (num_levels, num_points)
    float* spatial_w_bd_nram,       // (num_levels, num_points)
    int32_t* spatial_shapes_nram,   // (num_levels, 2)
    int32_t* spatial_offset_nram,   // (num_levels)
    const int32_t num_levels, const int32_t num_points) {
  // In-place int32 -> float conversion; both views address the same storage.
  float* shapes_as_float = (float*)spatial_shapes_nram;
  for (int idx = 0; idx < num_levels * 2; ++idx) {
    shapes_as_float[idx] = (float)spatial_shapes_nram[idx];
  }

  float* offsets_as_float = (float*)spatial_offset_nram;
  for (int idx = 0; idx < num_levels; ++idx) {
    offsets_as_float[idx] = (float)spatial_offset_nram[idx];
  }

  // Each copy uses a source stride of 0 and a destination stride of one
  // float, with num_points - 1 as the trailing segment argument, so a single
  // source value fills one row of the destination table.
  for (int level = 0; level < num_levels; ++level) {
    const int row_start = level * num_points;
    __memcpy(spatial_h_bd_nram + row_start, spatial_shapes_nram + level * 2,
             sizeof(float), NRAM2NRAM, sizeof(float), 0, num_points - 1);
    __memcpy(spatial_w_bd_nram + row_start,
             spatial_shapes_nram + 1 + level * 2, sizeof(float), NRAM2NRAM,
             sizeof(float), 0, num_points - 1);
    __memcpy(spatial_offset_bd_nram + row_start, spatial_offset_nram + level,
             sizeof(float), NRAM2NRAM, sizeof(float), 0, num_points - 1);
  }
}
|
||||
|
||||
template <typename T>
|
||||
__mlu_func__ void getConditionCoordWeight(
|
||||
int32_t* data_offset_nram, T* weight_polation_nram,
|
||||
T* cond_point_polation_nram, T* cond_point_valid_nram, T* loc_nram,
|
||||
T* weight_attn_nram, T* spatial_offset_bd_nram, T* spatial_w_bd_nram,
|
||||
T* spatial_h_bd_nram, T* buf_nram, const int32_t deal_n,
|
||||
const int32_t num_levels, const int32_t num_points, const int32_t num_heads,
|
||||
int32_t pad_num_levels_points) {
|
||||
int32_t pad_total_points = deal_n * pad_num_levels_points;
|
||||
int32_t pad_block_points = pad_num_levels_points;
|
||||
T* buf_x_nram = buf_nram;
|
||||
T* buf_y_nram = buf_nram + pad_total_points;
|
||||
T* buf_cond_nram = buf_nram + 2 * pad_total_points;
|
||||
T* buf_x_floor = buf_nram + 2 * pad_total_points;
|
||||
T* buf_y_floor = buf_nram + 3 * pad_total_points;
|
||||
T* buf_x_ceil = buf_nram + 4 * pad_total_points;
|
||||
T* buf_y_ceil = buf_nram + 5 * pad_total_points;
|
||||
|
||||
__bang_write_value(buf_x_nram, pad_total_points, 0);
|
||||
__bang_write_value(buf_y_nram, pad_total_points, 0);
|
||||
__bang_write_value(buf_x_floor, pad_total_points, 0);
|
||||
__bang_write_value(buf_x_ceil, pad_total_points, 0);
|
||||
__bang_write_value(buf_y_floor, pad_total_points, 0);
|
||||
__bang_write_value(buf_y_ceil, pad_total_points, 0);
|
||||
|
||||
//================================================================================================
|
||||
__memcpy(buf_x_nram, loc_nram, sizeof(T), NRAM2NRAM, sizeof(T), 2 * sizeof(T),
|
||||
pad_total_points - 1);
|
||||
__memcpy(buf_y_nram, loc_nram + 1, sizeof(T), NRAM2NRAM, sizeof(T),
|
||||
2 * sizeof(T), pad_total_points - 1);
|
||||
|
||||
// x = loc_x * spatial_w - 0.5; y = loc_y * spatial_h - 0.5;
|
||||
__bang_cycle_mul(buf_x_nram, buf_x_nram, spatial_w_bd_nram, pad_total_points,
|
||||
pad_block_points);
|
||||
__bang_sub_scalar(buf_x_nram, buf_x_nram, (T)0.5, pad_total_points);
|
||||
__bang_cycle_mul(buf_y_nram, buf_y_nram, spatial_h_bd_nram, pad_total_points,
|
||||
pad_block_points);
|
||||
__bang_sub_scalar(buf_y_nram, buf_y_nram, (T)0.5, pad_total_points);
|
||||
|
||||
//================================================================================================
|
||||
// get point condition. use buf0, buf1, buf2
|
||||
// (x > -1 && y > -1 && y < spatial_h && x < spatial_w)
|
||||
__bang_write_value(cond_point_valid_nram, pad_total_points, (T)-1.0);
|
||||
__bang_gt(cond_point_valid_nram, buf_x_nram, cond_point_valid_nram,
|
||||
pad_total_points);
|
||||
__bang_write_value(buf_cond_nram, pad_total_points, (T)-1.0);
|
||||
__bang_gt(buf_cond_nram, buf_y_nram, buf_cond_nram, pad_total_points);
|
||||
|
||||
__bang_and(cond_point_valid_nram, cond_point_valid_nram, buf_cond_nram,
|
||||
pad_total_points);
|
||||
__bang_cycle_lt(buf_cond_nram, buf_x_nram, spatial_w_bd_nram,
|
||||
pad_total_points, pad_block_points);
|
||||
__bang_and(cond_point_valid_nram, cond_point_valid_nram, buf_cond_nram,
|
||||
pad_total_points);
|
||||
__bang_cycle_lt(buf_cond_nram, buf_y_nram, spatial_h_bd_nram,
|
||||
pad_total_points, pad_block_points);
|
||||
__bang_and(cond_point_valid_nram, cond_point_valid_nram, buf_cond_nram,
|
||||
pad_total_points);
|
||||
//================================================================================================
|
||||
__mluop_floor(buf_x_floor, buf_x_nram, 2 * pad_total_points);
|
||||
__bang_add_scalar(buf_x_ceil, buf_x_floor, 1.0, pad_total_points);
|
||||
__bang_add_scalar(buf_y_ceil, buf_y_floor, 1.0, pad_total_points);
|
||||
|
||||
T* cond_point_polation_nram_tl = cond_point_polation_nram;
|
||||
T* cond_point_polation_nram_bl = cond_point_polation_nram + pad_total_points;
|
||||
T* cond_point_polation_nram_tr =
|
||||
cond_point_polation_nram + 2 * pad_total_points;
|
||||
T* cond_point_polation_nram_br =
|
||||
cond_point_polation_nram + 3 * pad_total_points;
|
||||
T* cond_point_polation_nram_cond1 = weight_polation_nram;
|
||||
T* cond_point_polation_nram_cond2 = weight_polation_nram + pad_total_points;
|
||||
T* cond_point_polation_nram_cond3 =
|
||||
weight_polation_nram + 2 * pad_total_points;
|
||||
T* cond_point_polation_nram_cond4 =
|
||||
weight_polation_nram + 3 * pad_total_points;
|
||||
__bang_ge_scalar(cond_point_polation_nram_cond1, buf_x_floor, (T)0,
|
||||
pad_total_points);
|
||||
__bang_cycle_lt(cond_point_polation_nram_cond2, buf_x_ceil, spatial_w_bd_nram,
|
||||
pad_total_points, pad_block_points);
|
||||
__bang_ge_scalar(cond_point_polation_nram_cond3, buf_y_floor, (T)0,
|
||||
pad_total_points);
|
||||
__bang_cycle_lt(cond_point_polation_nram_cond4, buf_y_ceil, spatial_h_bd_nram,
|
||||
pad_total_points, pad_block_points);
|
||||
__bang_and(cond_point_polation_nram_tl, cond_point_polation_nram_cond1,
|
||||
cond_point_polation_nram_cond4, pad_total_points);
|
||||
__bang_and(cond_point_polation_nram_bl, cond_point_polation_nram_cond1,
|
||||
cond_point_polation_nram_cond3, pad_total_points);
|
||||
__bang_and(cond_point_polation_nram_tr, cond_point_polation_nram_cond2,
|
||||
cond_point_polation_nram_cond4, pad_total_points);
|
||||
__bang_and(cond_point_polation_nram_br, cond_point_polation_nram_cond2,
|
||||
cond_point_polation_nram_cond3, pad_total_points);
|
||||
//================================================================================================
|
||||
// get polation weight.
|
||||
T* buf_dx = (T*)data_offset_nram;
|
||||
T* buf_dy = buf_dx + pad_total_points;
|
||||
T* buf_dx_1 = buf_dy + pad_total_points;
|
||||
T* buf_dy_1 = buf_dx_1 + pad_total_points;
|
||||
// -dx = x_floor-x
|
||||
// -dy = y_floor-y
|
||||
// w1 = (1-dx)*dy = (dx-1)*(-dy)
|
||||
// w2 = (1-dx)*(1-dy) = (dx-1)*(dy-1)
|
||||
// w3 = dx*dy = (-dx)*(-dy)
|
||||
// w4 = dx*(1-dy) = (-dx)*(dy-1)
|
||||
T* weight_polation_nram_1 = weight_polation_nram;
|
||||
T* weight_polation_nram_2 = weight_polation_nram + 1 * pad_total_points;
|
||||
T* weight_polation_nram_3 = weight_polation_nram + 2 * pad_total_points;
|
||||
T* weight_polation_nram_4 = weight_polation_nram + 3 * pad_total_points;
|
||||
// T* weight_polation_nram_buf = buf_nram + 4 * total_points;
|
||||
__bang_sub(buf_dx, buf_x_floor, buf_x_nram, pad_total_points);
|
||||
__bang_sub(buf_dy, buf_y_floor, buf_y_nram, pad_total_points);
|
||||
|
||||
__bang_sub(buf_dx_1, buf_x_nram, buf_x_floor, pad_total_points);
|
||||
__bang_sub_scalar(buf_dx_1, buf_dx_1, (T)1.0, pad_total_points);
|
||||
|
||||
__bang_sub(buf_dy_1, buf_y_nram, buf_y_floor, pad_total_points);
|
||||
__bang_sub_scalar(buf_dy_1, buf_dy_1, (T)1.0, pad_total_points);
|
||||
|
||||
__bang_mul(weight_polation_nram_1, buf_dx_1, buf_dy, pad_total_points);
|
||||
__bang_mul(weight_polation_nram_2, buf_dx_1, buf_dy_1, pad_total_points);
|
||||
__bang_mul(weight_polation_nram_3, buf_dx, buf_dy, pad_total_points);
|
||||
__bang_mul(weight_polation_nram_4, buf_dx, buf_dy_1, pad_total_points);
|
||||
//================================================================================================
|
||||
// correct the x,y in [0, w-1] and [0, h-1]
|
||||
T* spatial_w1_bd_nram = buf_nram;
|
||||
T* spatial_h1_bd_nram = buf_nram + pad_total_points;
|
||||
__bang_sub_scalar(spatial_w1_bd_nram, spatial_w_bd_nram, (T)1,
|
||||
pad_total_points);
|
||||
__bang_sub_scalar(spatial_h1_bd_nram, spatial_h_bd_nram, (T)1,
|
||||
pad_total_points);
|
||||
T* maxtemp = (T*)data_offset_nram;
|
||||
__bang_write_value(maxtemp, pad_total_points, (T)0);
|
||||
__bang_maxequal(buf_x_floor, buf_x_floor, maxtemp, pad_total_points);
|
||||
__bang_maxequal(buf_x_ceil, buf_x_ceil, maxtemp, pad_total_points);
|
||||
__bang_cycle_minequal(buf_x_floor, buf_x_floor, spatial_w1_bd_nram,
|
||||
pad_total_points, pad_block_points);
|
||||
__bang_cycle_minequal(buf_x_ceil, buf_x_ceil, spatial_w1_bd_nram,
|
||||
pad_total_points, pad_block_points);
|
||||
__bang_maxequal(buf_y_floor, buf_y_floor, maxtemp, pad_total_points);
|
||||
__bang_maxequal(buf_y_ceil, buf_y_ceil, maxtemp, pad_total_points);
|
||||
__bang_cycle_minequal(buf_y_floor, buf_y_floor, spatial_h1_bd_nram,
|
||||
pad_total_points, pad_block_points);
|
||||
__bang_cycle_minequal(buf_y_ceil, buf_y_ceil, spatial_h1_bd_nram,
|
||||
pad_total_points, pad_block_points);
|
||||
//================================================================================================
|
||||
// offset = y*w + x
|
||||
T* buf_hw_offset = buf_nram;
|
||||
T* data_offset_nram_tl = (T*)data_offset_nram;
|
||||
T* data_offset_nram_bl = data_offset_nram_tl + pad_total_points;
|
||||
T* data_offset_nram_tr = data_offset_nram_bl + pad_total_points;
|
||||
T* data_offset_nram_br = data_offset_nram_tr + pad_total_points;
|
||||
// y_ceil*w + offset + x_floor
|
||||
__bang_cycle_mul(buf_hw_offset, buf_y_ceil, spatial_w_bd_nram,
|
||||
pad_total_points, pad_block_points);
|
||||
__bang_cycle_add(buf_hw_offset, buf_hw_offset, spatial_offset_bd_nram,
|
||||
pad_total_points, pad_block_points);
|
||||
__bang_add(data_offset_nram_tl, buf_hw_offset, buf_x_floor, pad_total_points);
|
||||
// y_ceil*w + offset + x_ceil
|
||||
__bang_add(data_offset_nram_tr, buf_hw_offset, buf_x_ceil, pad_total_points);
|
||||
// y_floor*w + offset + x_foor
|
||||
__bang_cycle_mul(buf_hw_offset, buf_y_floor, spatial_w_bd_nram,
|
||||
pad_total_points, pad_block_points);
|
||||
__bang_cycle_add(buf_hw_offset, buf_hw_offset, spatial_offset_bd_nram,
|
||||
pad_total_points, pad_block_points);
|
||||
|
||||
__bang_add(data_offset_nram_bl, buf_hw_offset, buf_x_floor, pad_total_points);
|
||||
// y_floor*w + offset + x_ceil
|
||||
__bang_add(data_offset_nram_br, buf_hw_offset, buf_x_ceil, pad_total_points);
|
||||
__bang_cycle_and(cond_point_polation_nram, cond_point_polation_nram,
|
||||
cond_point_valid_nram, 4 * pad_total_points,
|
||||
pad_total_points);
|
||||
__bang_cycle_mul(weight_polation_nram, weight_polation_nram, weight_attn_nram,
|
||||
4 * pad_total_points, pad_total_points);
|
||||
__bang_mul(weight_polation_nram, weight_polation_nram,
|
||||
cond_point_polation_nram, pad_total_points * 4);
|
||||
__bang_sub((float*)data_offset_nram_bl, (float*)data_offset_nram_bl,
|
||||
(float*)data_offset_nram_tl, pad_total_points);
|
||||
__bang_sub((float*)data_offset_nram_tr, (float*)data_offset_nram_tr,
|
||||
(float*)data_offset_nram_tl, pad_total_points);
|
||||
}
|
||||
|
||||
/*
|
||||
shape of each tensor:
|
||||
output_nram: (channels)
|
||||
input_nram: (4, valid_num, channels)
|
||||
input_trans: (channels, 4, valid_num)
|
||||
weight_selected_base: (4, deal_n, num_levels, num_points)
|
||||
weight_compute: (4, valid_num)
|
||||
*/
|
||||
template <typename T>
|
||||
__mlu_func__ void reduceLevel(T* output_nram, T* input_nram, T* input_trans,
|
||||
T* weight_selected_base, T* weight_compute,
|
||||
const int32_t pad_num_levels_points,
|
||||
const int32_t pad_channels,
|
||||
const int32_t pad_sample_stride_3) {
|
||||
int32_t ci = 4 * pad_num_levels_points;
|
||||
int32_t co = pad_channels;
|
||||
__bang_write_value(weight_compute, 4 * pad_num_levels_points, 0);
|
||||
__memcpy(weight_compute, weight_selected_base,
|
||||
pad_num_levels_points * sizeof(T), NRAM2NRAM,
|
||||
pad_num_levels_points * sizeof(T), pad_sample_stride_3 * sizeof(T),
|
||||
3);
|
||||
__bang_transpose(input_trans, input_nram, ci, co);
|
||||
__bang_cycle_mul(input_trans, input_trans, weight_compute, co * ci, ci);
|
||||
__bang_sumpool(input_nram, input_trans, pad_num_levels_points, pad_channels,
|
||||
4, 1, 4, 1, 1);
|
||||
__bang_transpose(input_trans, input_nram, pad_channels,
|
||||
pad_num_levels_points);
|
||||
__bang_sumpool(output_nram, input_trans, pad_channels, pad_num_levels_points,
|
||||
1, pad_num_levels_points, 1, 1, 1);
|
||||
}
|
||||
|
||||
// Reads three values that are stored as floats at p1 / p2 / p3, truncates
// them to int32 and scales them to byte quantities.
//
//   v1, v2, v3              out: scaled results
//   p1, p2, p3              in: locations holding float values (the pointers
//                           are typed int32_t* but are read through float*)
//   num_heads               heads per key; part of the stride when
//                           !sram_stay
//   channels_size           bytes per (key, head) row
//   sram_stay               true: stride is channels_size and v1 is first
//                           rebased by sram_level_start_index;
//                           false: stride is num_heads * channels_size
//   sram_level_start_index  value subtracted from v1 when sram_stay is set
//                           (the caller in this file passes its
//                           sram_level_start_offset here)
__mlu_func__ void loadNram2Gpr(int32_t& v1, int32_t& v2, int32_t& v3,
                               int32_t* p1, int32_t* p2, int32_t* p3,
                               int32_t num_heads, int32_t channels_size,
                               bool sram_stay, int32_t sram_level_start_index) {
  v1 = (int32_t)(*(float*)p1);
  v2 = (int32_t)(*(float*)p2);
  v3 = (int32_t)(*(float*)p3);

  const int32_t byte_stride =
      sram_stay ? channels_size : num_heads * channels_size;

  // Only the base value is rebased; v2 and v3 are used as relative strides.
  if (sram_stay) {
    v1 -= sram_level_start_index;
  }

  v1 *= byte_stride;
  v2 *= byte_stride;
  v3 *= byte_stride;
}
|
||||
|
||||
/*
|
||||
Load 4 neighbors use 2 2D-memcpy, just use offset of N1, stride_3_1
|
||||
and
|
||||
stride_2_1.
|
||||
|<- stride_3_1 ->|
|
||||
N1 N3
|
||||
^
|
||||
|
|
||||
stride_2_1
|
||||
|
|
||||
v
|
||||
N2 N4
|
||||
|
||||
Trickly fold the loop as 2.
|
||||
*/
|
||||
template <typename T, mluMemcpyDirection_t DIR>
|
||||
__mlu_func__ void loadDataValueXram2NramAsync(
|
||||
T* buf_value_nram_1, int32_t* offset_1, int32_t* stride_2_1,
|
||||
int32_t* stride_3_1, T* value_src, const int32_t pad_num_levels_points,
|
||||
const int32_t deal_points, const int32_t start_points_index,
|
||||
const int32_t channel_size, const int32_t num_heads, bool sram_stay,
|
||||
const int32_t sram_level_start_offset) {
|
||||
int32_t offset_1_a, stride_2_1_a, stride_3_1_a;
|
||||
int32_t offset_1_b, stride_2_1_b, stride_3_1_b;
|
||||
loadNram2Gpr(offset_1_a, stride_2_1_a, stride_3_1_a,
|
||||
offset_1 + start_points_index, stride_2_1 + start_points_index,
|
||||
stride_3_1 + start_points_index, num_heads, channel_size,
|
||||
sram_stay, sram_level_start_offset);
|
||||
loadNram2Gpr(
|
||||
offset_1_b, stride_2_1_b, stride_3_1_b, offset_1 + start_points_index + 1,
|
||||
stride_2_1 + start_points_index + 1, stride_3_1 + start_points_index + 1,
|
||||
num_heads, channel_size, sram_stay, sram_level_start_offset);
|
||||
|
||||
int32_t value_offset = 0;
|
||||
int32_t next = 0;
|
||||
int32_t loop_num = deal_points / 2;
|
||||
int32_t remain = deal_points % 2;
|
||||
int32_t pad_channels =
|
||||
PAD_UP(channel_size / sizeof(T), NFU_ALIGN_SIZE / sizeof(T));
|
||||
int32_t pad_channels_size = pad_channels * sizeof(T);
|
||||
int32_t pad_data_value_stride = pad_num_levels_points * pad_channels_size;
|
||||
for (int32_t j = start_points_index; j < start_points_index + loop_num * 2;
|
||||
j += 2) {
|
||||
value_offset = j * pad_channels_size;
|
||||
next = j + 2;
|
||||
for (int i = 0; i < 2; i++) {
|
||||
__memcpy_async(
|
||||
(int8_t*)buf_value_nram_1 + value_offset + pad_data_value_stride * i,
|
||||
(int8_t*)value_src + offset_1_a + i * stride_2_1_a, channel_size, DIR,
|
||||
2 * pad_data_value_stride, stride_3_1_a, 1);
|
||||
}
|
||||
|
||||
loadNram2Gpr(offset_1_a, stride_2_1_a, stride_3_1_a, offset_1 + next,
|
||||
stride_2_1 + next, stride_3_1 + next, num_heads, channel_size,
|
||||
sram_stay, sram_level_start_offset);
|
||||
|
||||
for (int i = 0; i < 2; i++) {
|
||||
__memcpy_async((int8_t*)buf_value_nram_1 + value_offset +
|
||||
pad_channels_size + pad_data_value_stride * i,
|
||||
(int8_t*)value_src + offset_1_b + i * stride_2_1_b,
|
||||
channel_size, DIR, 2 * pad_data_value_stride, stride_3_1_b,
|
||||
1);
|
||||
}
|
||||
|
||||
loadNram2Gpr(offset_1_b, stride_2_1_b, stride_3_1_b, offset_1 + next + 1,
|
||||
stride_2_1 + next + 1, stride_3_1 + next + 1, num_heads,
|
||||
channel_size, sram_stay, sram_level_start_offset);
|
||||
}
|
||||
|
||||
if (remain > 0) {
|
||||
value_offset = (start_points_index + loop_num * 2) * pad_channels_size;
|
||||
for (int i = 0; i < 2; i++) {
|
||||
__memcpy_async(
|
||||
(int8_t*)buf_value_nram_1 + value_offset + pad_data_value_stride * i,
|
||||
(int8_t*)value_src + offset_1_a + i * stride_2_1_a, channel_size, DIR,
|
||||
2 * pad_data_value_stride, stride_3_1_a, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__mlu_func__ void loadNeighborPolationAttn(
|
||||
T* value_output_nram, T* value_gdram, int32_t* data_offset_nram,
|
||||
T* weight_polation_nram, T* cond_point_polation_nram,
|
||||
T* cond_point_valid_nram, T* weight_attn_nram, T* buf_nram,
|
||||
T* compute_buf_nram, const int32_t deal_n, const int32_t num_levels,
|
||||
const int32_t num_points, const int32_t num_keys, const int32_t channels,
|
||||
const int32_t num_heads, const int32_t pad_channels,
|
||||
const int32_t pad_num_levels_points, T* value_sram,
|
||||
const int32_t sram_level_start_index,
|
||||
const int32_t sram_level_start_offset) {
|
||||
int32_t channel_size = channels * sizeof(T);
|
||||
int32_t pad_sample_stride_3 = deal_n * pad_num_levels_points;
|
||||
|
||||
T* buf_value_nram = buf_nram; // (4, num_levels, num_points, channels)
|
||||
T* buf_value_nram_trans = buf_nram + 4 * pad_num_levels_points * pad_channels;
|
||||
T* weight_compute_nram = compute_buf_nram; // (4, num_levels, num_points)
|
||||
|
||||
int32_t* offset = data_offset_nram;
|
||||
int32_t* stride_2_1 = offset + pad_sample_stride_3;
|
||||
int32_t* stride_3_1 = stride_2_1 + pad_sample_stride_3;
|
||||
T* output_nram = value_output_nram;
|
||||
int32_t step_offset = 0;
|
||||
for (int32_t i = 0; i < deal_n; i++) {
|
||||
__bang_write_value(buf_value_nram, 4 * pad_num_levels_points * pad_channels,
|
||||
0);
|
||||
__sync_compute();
|
||||
if (sram_level_start_index > 0) {
|
||||
loadDataValueXram2NramAsync<T, GDRAM2NRAM>(
|
||||
buf_value_nram, offset, stride_2_1, stride_3_1, value_gdram,
|
||||
pad_num_levels_points, num_points * sram_level_start_index, 0,
|
||||
channel_size, num_heads, false, sram_level_start_offset);
|
||||
}
|
||||
if (sram_level_start_index < num_levels) {
|
||||
loadDataValueXram2NramAsync<T, SRAM2NRAM>(
|
||||
buf_value_nram, offset, stride_2_1, stride_3_1, value_sram,
|
||||
pad_num_levels_points,
|
||||
num_points * (num_levels - sram_level_start_index),
|
||||
num_points * sram_level_start_index, channel_size, num_heads, true,
|
||||
sram_level_start_offset);
|
||||
}
|
||||
__sync_io_move_compute();
|
||||
reduceLevel(output_nram, buf_value_nram, buf_value_nram_trans,
|
||||
weight_polation_nram + step_offset, weight_compute_nram,
|
||||
pad_num_levels_points, pad_channels, pad_sample_stride_3);
|
||||
step_offset += pad_num_levels_points;
|
||||
offset = data_offset_nram + step_offset;
|
||||
stride_2_1 = offset + pad_sample_stride_3;
|
||||
stride_3_1 = stride_2_1 + pad_sample_stride_3;
|
||||
output_nram += pad_channels;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__mlu_func__ void prepareLoop(
|
||||
int32_t* spatial_offset_nram, int32_t* spatial_hw_nram,
|
||||
T* spatial_offset_bd_nram, T* spatial_h_bd_nram, T* spatial_w_bd_nram,
|
||||
const char* data_level_start_index_gdram,
|
||||
const char* data_spatial_shapes_gdram, const int32_t num_keys,
|
||||
const int32_t num_levels, const int32_t num_points,
|
||||
const int32_t max_deal_n, const int32_t channels) {
|
||||
__memcpy(spatial_offset_nram, data_level_start_index_gdram,
|
||||
num_levels * sizeof(int32_t), GDRAM2NRAM);
|
||||
__memcpy(spatial_hw_nram, data_spatial_shapes_gdram,
|
||||
num_levels * 2 * sizeof(int32_t), GDRAM2NRAM);
|
||||
broadcastSpatialHW(spatial_offset_bd_nram, spatial_h_bd_nram,
|
||||
spatial_w_bd_nram, spatial_hw_nram, spatial_offset_nram,
|
||||
num_levels, num_points);
|
||||
}
|
||||
|
||||
/*
|
||||
The shape of each tensor:
|
||||
buf_compute_nram: (8, num_levels, num_points)
|
||||
spatial_offset_nram: (num_levels)
|
||||
spatial_hw_nram: (num_levels, 2)
|
||||
spatial_offset_bd_nram: (num_levels, num_points)
|
||||
spatial_w_bd_nram: (num_levels, num_points)
|
||||
spatial_h_bd_nram: (num_levels, num_points)
|
||||
value_output_nram: (deal_n, channels)
|
||||
data_offset_nram: (4, deal_n, num_levels, num_points)
|
||||
weight_polation_nram: (4, deal_n, num_levels, num_points)
|
||||
cond_point_polation_nram: (4, deal_n, num_levels, num_points)
|
||||
cond_point_valid_nram: (deal_n, num_levels, num_points)
|
||||
loc_nram: (deal_n, num_levels, num_points, 2)
|
||||
weight_attn_nram: (deal_n, num_levels, num_points)
|
||||
buf_nram: (6, deal_n, num_levels, num_points)
|
||||
|
||||
Note: buf_nram is reused in polation computing.
|
||||
*/
|
||||
template <typename T>
|
||||
__mlu_func__ void memPolicyCommon(
|
||||
T*& buf_compute_nram, T*& value_output_nram, int32_t*& data_offset_nram,
|
||||
T*& weight_polation_nram, T*& cond_point_polation_nram,
|
||||
T*& cond_point_valid_nram, T*& loc_nram, T*& weight_attn_nram, T*& buf_nram,
|
||||
T*& buf_nram_end, T*& spatial_offset_bd_nram, T*& spatial_w_bd_nram,
|
||||
T*& spatial_h_bd_nram, T*& value_sram, int32_t*& spatial_offset_nram,
|
||||
int32_t*& spatial_hw_nram, int32_t& max_deal_n, int32_t& pad_channels,
|
||||
int32_t& pad_num_levels_points, int32_t& pad_total_points,
|
||||
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
|
||||
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
|
||||
const int32_t num_points) {
|
||||
pad_channels = PAD_UP(channels, NFU_ALIGN_SIZE / sizeof(T));
|
||||
int32_t num_levels_points = num_levels * num_points;
|
||||
pad_num_levels_points = PAD_UP(num_levels_points, NFU_ALIGN_SIZE / sizeof(T));
|
||||
int32_t pad_num_levels_points_8 = 8 * pad_num_levels_points;
|
||||
int32_t spatial_info_size =
|
||||
PAD_UP(3 * num_levels * sizeof(int32_t), NFU_ALIGN_SIZE);
|
||||
int32_t fix_space_size =
|
||||
spatial_info_size +
|
||||
(3 * pad_num_levels_points + pad_num_levels_points) * sizeof(T);
|
||||
int32_t left_space_size = NRAM_AVALIABLE_SIZE - fix_space_size;
|
||||
int32_t common_buffer_size_each = 6 * pad_num_levels_points * sizeof(T);
|
||||
int32_t inter_result_size_each =
|
||||
17 * pad_num_levels_points * sizeof(T) + pad_channels * sizeof(T);
|
||||
|
||||
max_deal_n =
|
||||
left_space_size / (common_buffer_size_each + inter_result_size_each);
|
||||
|
||||
int32_t compute_buffer_size =
|
||||
(9 * pad_num_levels_points * pad_channels) * sizeof(T);
|
||||
int32_t common_buffer_size = max_deal_n * common_buffer_size_each;
|
||||
// make sure buf_nram is large enough for compute
|
||||
if (compute_buffer_size > common_buffer_size) {
|
||||
int32_t tmp_deal_n =
|
||||
(left_space_size - compute_buffer_size) / inter_result_size_each;
|
||||
max_deal_n = __mluop_min(max_deal_n, tmp_deal_n);
|
||||
}
|
||||
|
||||
pad_total_points = max_deal_n * pad_num_levels_points;
|
||||
buf_compute_nram = (T*)nram_buffer;
|
||||
spatial_offset_nram = (int32_t*)(buf_compute_nram + pad_num_levels_points_8);
|
||||
int32_t pad_3_levels = PAD_UP(3 * num_levels, NFU_ALIGN_SIZE / sizeof(T));
|
||||
spatial_hw_nram = spatial_offset_nram + num_levels;
|
||||
spatial_offset_bd_nram = (T*)(spatial_offset_nram + pad_3_levels);
|
||||
spatial_w_bd_nram = spatial_offset_bd_nram + pad_num_levels_points;
|
||||
spatial_h_bd_nram = spatial_w_bd_nram + pad_num_levels_points;
|
||||
value_output_nram = spatial_h_bd_nram + pad_num_levels_points;
|
||||
data_offset_nram = (int32_t*)(value_output_nram + max_deal_n * pad_channels);
|
||||
weight_polation_nram = (T*)(data_offset_nram + 4 * pad_total_points);
|
||||
cond_point_polation_nram = weight_polation_nram + 4 * pad_total_points;
|
||||
cond_point_valid_nram = cond_point_polation_nram + 4 * pad_total_points;
|
||||
loc_nram = cond_point_valid_nram + pad_total_points;
|
||||
weight_attn_nram =
|
||||
loc_nram +
|
||||
2 * pad_total_points; // total_coord_pad = 2 * pad_total_points
|
||||
buf_nram = weight_attn_nram + pad_total_points;
|
||||
buf_nram_end = buf_nram + 6 * max_deal_n * pad_num_levels_points;
|
||||
value_sram = (T*)sram_buffer;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__mlu_func__ void loadDataValueGdram2Sram(T* value_sram, T* data_value_gdram,
|
||||
const int32_t batch_idx,
|
||||
const int32_t head_idx,
|
||||
const int32_t sram_num_keys,
|
||||
const int32_t num_heads,
|
||||
const int32_t channels,
|
||||
const int32_t skip_num_key) {
|
||||
int32_t loop_num =
|
||||
(sram_num_keys + MAX_MEMCPY_SEGNUM - 1) / MAX_MEMCPY_SEGNUM;
|
||||
int32_t num_heads_channels = num_heads * channels;
|
||||
for (int32_t i = 0; i < loop_num; i++) {
|
||||
int32_t load_num =
|
||||
__mluop_min(MAX_MEMCPY_SEGNUM, sram_num_keys - i * MAX_MEMCPY_SEGNUM);
|
||||
size_t src_offset = ((size_t)batch_idx * sram_num_keys + skip_num_key +
|
||||
i * MAX_MEMCPY_SEGNUM) *
|
||||
num_heads_channels +
|
||||
head_idx * channels;
|
||||
int32_t dst_offset = i * MAX_MEMCPY_SEGNUM * channels;
|
||||
__memcpy(value_sram + dst_offset, (T*)data_value_gdram + src_offset,
|
||||
channels * sizeof(T), GDRAM2SRAM, channels * sizeof(T),
|
||||
num_heads_channels * sizeof(T), load_num - 1);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__mlu_func__ void computeSramCacheSizeAndOffset(
|
||||
int32_t* sram_level_cache_size, int32_t* sram_level_start_index,
|
||||
int32_t* sram_level_start_offset, const int32_t num_levels,
|
||||
const int32_t num_keys, const int32_t channels,
|
||||
const T* data_level_start_index_gdram, const int32_t sram_size) {
|
||||
for (int32_t level_id = num_levels; level_id > 0; level_id--) {
|
||||
int current_level_end_index =
|
||||
level_id == num_levels
|
||||
? num_keys
|
||||
: ((int32_t*)data_level_start_index_gdram)[level_id];
|
||||
int32_t current_level_size =
|
||||
current_level_end_index -
|
||||
((int32_t*)data_level_start_index_gdram)[level_id - 1];
|
||||
if ((*sram_level_cache_size + current_level_size) * channels * sizeof(T) >
|
||||
sram_size) {
|
||||
break;
|
||||
}
|
||||
*sram_level_cache_size += current_level_size;
|
||||
*sram_level_start_index = level_id - 1;
|
||||
*sram_level_start_offset = num_keys - *sram_level_cache_size;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__mlu_func__ void MLUKernelMsDeformAttnForwardFastImpl(
|
||||
const char* data_value_gdram, const char* data_spatial_shapes_gdram,
|
||||
const char* data_level_start_index_gdram,
|
||||
const char* data_sampling_loc_gdram, const char* data_attn_weight_gdram,
|
||||
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
|
||||
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
|
||||
const int32_t num_points, char* data_col_gdram) {
|
||||
int32_t input_stride_4 = num_queries * num_heads * num_levels * num_points;
|
||||
int32_t input_stride_3 = num_heads * num_levels * num_points;
|
||||
int32_t input_stride_2 = num_levels * num_points;
|
||||
int32_t output_stride_3 = num_queries * num_heads * channels;
|
||||
int32_t output_stride_2 = num_heads * channels;
|
||||
int32_t data_value_stride_3 = num_keys * num_heads * channels;
|
||||
|
||||
T* value_output_nram = nullptr; // (deal_n, channels)
|
||||
int32_t* data_offset_nram = nullptr; // (4, deal_n, num_levels, num_points)
|
||||
T* weight_polation_nram = nullptr; // (4, deal_n, num_levels, num_points)
|
||||
T* cond_point_polation_nram = nullptr; // (4, deal_n, num_levels, num_points)
|
||||
T* cond_point_valid_nram = nullptr; // (deal_n, num_levels, num_points)
|
||||
T* loc_nram = nullptr; // (deal_n, num_levels, num_points, 2)
|
||||
T* weight_attn_nram = nullptr; // (deal_n, num_levels, num_points)
|
||||
T* buf_nram = nullptr; // (6, deal_n, num_levels, num_points)
|
||||
T* buf_nram_end = nullptr;
|
||||
T* spatial_offset_bd_nram = nullptr; // (num_levels, num_points)
|
||||
T* spatial_w_bd_nram = nullptr; // (num_levels, num_points)
|
||||
T* spatial_h_bd_nram = nullptr; // (num_levels, num_points)
|
||||
int32_t* spatial_offset_nram = nullptr; // (num_levels)
|
||||
int32_t* spatial_hw_nram = nullptr; // (num_levels, 2)
|
||||
T* buf_compute_nram = nullptr; // (8, num_levels, num_points)
|
||||
int32_t max_deal_n = 0;
|
||||
int32_t pad_channels = 0;
|
||||
int32_t pad_num_levels_points = 0;
|
||||
int32_t pad_total_points = 0;
|
||||
T* value_sram = nullptr;
|
||||
|
||||
memPolicyCommon(buf_compute_nram, value_output_nram, data_offset_nram,
|
||||
weight_polation_nram, cond_point_polation_nram,
|
||||
cond_point_valid_nram, loc_nram, weight_attn_nram, buf_nram,
|
||||
buf_nram_end, spatial_offset_bd_nram, spatial_w_bd_nram,
|
||||
spatial_h_bd_nram, value_sram, spatial_offset_nram,
|
||||
spatial_hw_nram, max_deal_n, pad_channels,
|
||||
pad_num_levels_points, pad_total_points, batch_size, num_keys,
|
||||
num_heads, channels, num_levels, num_queries, num_points);
|
||||
if (max_deal_n <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// split batch*head into taskDimY
|
||||
int32_t batch_head = batch_size * num_heads;
|
||||
int32_t cluster_avg_batch_head = (batch_head + taskDimY - 1) / taskDimY;
|
||||
int32_t cluster_begin_batch_head = taskIdY * cluster_avg_batch_head;
|
||||
int32_t cluster_act_batch_head = __mluop_min(
|
||||
cluster_avg_batch_head, batch_head - cluster_begin_batch_head);
|
||||
int32_t cluster_end_batch_head =
|
||||
cluster_begin_batch_head + cluster_act_batch_head;
|
||||
// split query into coreDim
|
||||
int32_t core_avg_query = (num_queries + coreDim - 1) / coreDim;
|
||||
int32_t core_begin_query = coreId * core_avg_query;
|
||||
int32_t core_act_query =
|
||||
__mluop_min(num_queries - core_begin_query, core_avg_query);
|
||||
int32_t core_loop_num = (core_act_query + max_deal_n - 1) / max_deal_n;
|
||||
int32_t core_step_query =
|
||||
core_loop_num > 0 ? (core_act_query + core_loop_num - 1) / core_loop_num
|
||||
: 0;
|
||||
int32_t core_remain_query =
|
||||
core_act_query - (core_loop_num - 1) * core_step_query;
|
||||
int32_t first_deal_query =
|
||||
(int)(core_loop_num > 0) *
|
||||
(core_loop_num > 1 ? core_step_query : core_remain_query);
|
||||
|
||||
prepareLoop(spatial_offset_nram, spatial_hw_nram, spatial_offset_bd_nram,
|
||||
spatial_h_bd_nram, spatial_w_bd_nram,
|
||||
data_level_start_index_gdram, data_spatial_shapes_gdram, num_keys,
|
||||
num_levels, num_points, max_deal_n, channels);
|
||||
|
||||
int sram_total_size = 0;
|
||||
int sram_level_start_index = num_levels;
|
||||
int sram_level_start_offset = 0;
|
||||
computeSramCacheSizeAndOffset(
|
||||
&sram_total_size, &sram_level_start_index, &sram_level_start_offset,
|
||||
num_levels, num_keys, channels, (int32_t*)data_level_start_index_gdram,
|
||||
SRAM_FOR_VALUE_SIZE);
|
||||
|
||||
for (int32_t bh_idx = cluster_begin_batch_head;
|
||||
bh_idx < cluster_end_batch_head; bh_idx++) {
|
||||
int32_t b = bh_idx / num_heads;
|
||||
int32_t head_idx = bh_idx % num_heads;
|
||||
|
||||
size_t output_base_offset =
|
||||
(size_t)b * output_stride_3 + head_idx * channels;
|
||||
int32_t attn_weight_base_offset =
|
||||
b * input_stride_4 + head_idx * input_stride_2;
|
||||
|
||||
if (__is_mpu() && (sram_level_start_index != num_levels)) {
|
||||
loadDataValueGdram2Sram(value_sram, (T*)data_value_gdram, b, head_idx,
|
||||
sram_total_size, num_heads, channels,
|
||||
sram_level_start_offset);
|
||||
}
|
||||
|
||||
__sync_cluster();
|
||||
|
||||
if (__is_ipu()) {
|
||||
// compute weight, offset and condition
|
||||
int32_t attn_weight_offset =
|
||||
attn_weight_base_offset + core_begin_query * input_stride_3;
|
||||
int32_t loc_offset = attn_weight_offset * 2;
|
||||
if (first_deal_query > 0) {
|
||||
__bang_write_value(loc_nram, 2 * pad_total_points, 0);
|
||||
__bang_write_value(weight_attn_nram, pad_total_points, 0);
|
||||
__sync_compute();
|
||||
__memcpy_async(loc_nram, (T*)data_sampling_loc_gdram + loc_offset,
|
||||
input_stride_2 * 2 * sizeof(T), GDRAM2NRAM,
|
||||
pad_num_levels_points * 2 * sizeof(T),
|
||||
input_stride_3 * 2 * sizeof(T), first_deal_query - 1);
|
||||
__memcpy_async(weight_attn_nram,
|
||||
(T*)data_attn_weight_gdram + attn_weight_offset,
|
||||
input_stride_2 * sizeof(T), GDRAM2NRAM,
|
||||
pad_num_levels_points * sizeof(T),
|
||||
input_stride_3 * sizeof(T), first_deal_query - 1);
|
||||
getConditionCoordWeight<T>(
|
||||
data_offset_nram, weight_polation_nram, cond_point_polation_nram,
|
||||
cond_point_valid_nram, loc_nram, weight_attn_nram,
|
||||
spatial_offset_bd_nram, spatial_w_bd_nram, spatial_h_bd_nram,
|
||||
buf_nram, first_deal_query, num_levels, num_points, num_heads,
|
||||
pad_num_levels_points);
|
||||
}
|
||||
}
|
||||
|
||||
for (int32_t i = 0; __is_ipu() && i < core_loop_num; i++) {
|
||||
__bang_write_value(loc_nram, 2 * pad_total_points, 0);
|
||||
__bang_write_value(weight_attn_nram, pad_total_points, 0);
|
||||
int32_t deal_n =
|
||||
i < core_loop_num - 1 ? core_step_query : core_remain_query;
|
||||
int32_t load_n =
|
||||
i < core_loop_num - 2 ? core_step_query : core_remain_query;
|
||||
// load value and polation
|
||||
loadNeighborPolationAttn<T>(
|
||||
value_output_nram,
|
||||
(T*)data_value_gdram + b * data_value_stride_3 + head_idx * channels,
|
||||
data_offset_nram, weight_polation_nram, cond_point_polation_nram,
|
||||
cond_point_valid_nram, weight_attn_nram, buf_nram, buf_compute_nram,
|
||||
deal_n, num_levels, num_points, num_keys, channels, num_heads,
|
||||
pad_channels, pad_num_levels_points, value_sram,
|
||||
sram_level_start_index, sram_level_start_offset);
|
||||
__sync_io_move_compute();
|
||||
// load next weight and loc
|
||||
if (i < core_loop_num - 1) {
|
||||
int32_t core_query_offset = (i + 1) * core_step_query;
|
||||
int32_t attn_weight_offset =
|
||||
attn_weight_base_offset +
|
||||
(core_begin_query + core_query_offset) * input_stride_3;
|
||||
int32_t loc_offset = attn_weight_offset * 2;
|
||||
__memcpy_async(loc_nram, (T*)data_sampling_loc_gdram + loc_offset,
|
||||
input_stride_2 * 2 * sizeof(T), GDRAM2NRAM,
|
||||
pad_num_levels_points * 2 * sizeof(T),
|
||||
input_stride_3 * 2 * sizeof(T), load_n - 1);
|
||||
__memcpy_async(weight_attn_nram,
|
||||
(T*)data_attn_weight_gdram + attn_weight_offset,
|
||||
input_stride_2 * sizeof(T), GDRAM2NRAM,
|
||||
pad_num_levels_points * sizeof(T),
|
||||
input_stride_3 * sizeof(T), load_n - 1);
|
||||
__sync_io_move_compute();
|
||||
}
|
||||
// store result
|
||||
size_t output_offset =
|
||||
((size_t)core_begin_query + i * core_step_query) * output_stride_2;
|
||||
__memcpy_async((T*)data_col_gdram + output_base_offset + output_offset,
|
||||
value_output_nram, channels * sizeof(T), NRAM2GDRAM,
|
||||
output_stride_2 * sizeof(T), pad_channels * sizeof(T),
|
||||
deal_n - 1);
|
||||
|
||||
// compute cond/weight/offset
|
||||
if (i < core_loop_num - 1) {
|
||||
getConditionCoordWeight<T>(
|
||||
data_offset_nram, weight_polation_nram, cond_point_polation_nram,
|
||||
cond_point_valid_nram, loc_nram, weight_attn_nram,
|
||||
spatial_offset_bd_nram, spatial_w_bd_nram, spatial_h_bd_nram,
|
||||
buf_nram, load_n, num_levels, num_points, num_heads,
|
||||
pad_num_levels_points);
|
||||
}
|
||||
__sync_io_move_compute();
|
||||
}
|
||||
__sync_cluster();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__mlu_global__ void MLUKernelMsDeformAttnForwardFast(
|
||||
const char* data_value_gdram, const char* data_spatial_shapes_gdram,
|
||||
const char* data_level_start_index_gdram,
|
||||
const char* data_sampling_loc_gdram, const char* data_attn_weight_gdram,
|
||||
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
|
||||
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
|
||||
const int32_t num_points, char* data_col_gdram) {
|
||||
MLUKernelMsDeformAttnForwardFastImpl<float>(
|
||||
data_value_gdram, data_spatial_shapes_gdram, data_level_start_index_gdram,
|
||||
data_sampling_loc_gdram, data_attn_weight_gdram, batch_size, num_keys,
|
||||
num_heads, channels, num_levels, num_queries, num_points, data_col_gdram);
|
||||
}
|
||||
|
||||
template __mlu_global__ void MLUKernelMsDeformAttnForwardFast<float>(
|
||||
const char* data_value_gdram, const char* data_spatial_shapes_gdram,
|
||||
const char* data_level_start_index_gdram,
|
||||
const char* data_sampling_loc_gdram, const char* data_attn_weight_gdram,
|
||||
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
|
||||
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
|
||||
const int32_t num_points, char* data_col_gdram);
|
||||
|
||||
/*
 * Host-side launcher of the fast multi-scale deformable attention forward.
 *
 * Enqueues MLUKernelMsDeformAttnForwardFast<float> on `queue` with launch
 * dimensions `k_dim` and function type `k_type`; every tensor pointer and
 * size argument is handed to the kernel as-is.
 *
 * d_type is accepted but not consulted: the float kernel is always launched.
 */
void KernelMsDeformAttnForwardFast(
    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
    const cnrtDataType_t d_type, const char* data_value_gdram,
    const char* data_spatial_shapes_gdram,
    const char* data_level_start_index_gdram,
    const char* data_sampling_loc_gdram, const char* data_attn_weight_gdram,
    const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
    const int32_t channels, const int32_t num_levels, const int32_t num_queries,
    const int32_t num_points, char* data_col_gdram) {
  MLUKernelMsDeformAttnForwardFast<float><<<k_dim, k_type, queue>>>(
      // input tensors
      data_value_gdram,
      data_spatial_shapes_gdram,
      data_level_start_index_gdram,
      data_sampling_loc_gdram,
      data_attn_weight_gdram,
      // problem sizes
      batch_size,
      num_keys,
      num_heads,
      channels,
      num_levels,
      num_queries,
      num_points,
      // output tensor
      data_col_gdram);
}
|
|
@ -11,6 +11,7 @@
|
|||
*************************************************************************/
|
||||
#include "pytorch_device_registry.hpp"
|
||||
#include "pytorch_mlu_helper.hpp"
|
||||
#include "ms_deform_attn_fast_mlu_kernel.hpp"
|
||||
|
||||
#define MIN(a, b) (((a) < (b)) ? (a) : (b))
|
||||
|
||||
|
@ -20,6 +21,8 @@ typedef enum {
|
|||
1, /*!< MLUKernelMsDeformAttnForwardDefault */
|
||||
MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL =
|
||||
2, /*!< MLUKernelMsDeformAttnForwardSmallChannel */
|
||||
MS_DEFORM_ATTN_FORWARD_FAST =
|
||||
3, /*!< MLUKernelMsDeformAttnForwardFast */
|
||||
} MsDeformAttnForwardPolicy;
|
||||
|
||||
void KernelMsDeformAttnForwardDefault(
|
||||
|
@ -40,6 +43,15 @@ void KernelMsDeformAttnForwardSmallChannel(
|
|||
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
|
||||
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
|
||||
const int32_t num_points, char* data_col_gdram);
|
||||
// Host-side launcher of the fast forward kernel; its definition launches
// MLUKernelMsDeformAttnForwardFast<float> on `queue` with `k_dim`/`k_type`.
// NOTE(review): this appears to repeat the prototype already provided by the
// included "ms_deform_attn_fast_mlu_kernel.hpp" - confirm whether both are
// needed.
void KernelMsDeformAttnForwardFast(
|
||||
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
|
||||
const cnrtDataType_t d_type, const char* data_value_gdram,
|
||||
const char* data_spatial_shapes_gdram,
|
||||
const char* data_level_start_index_gdram,
|
||||
const char* data_sampling_loc_gdram, const char* data_attn_weight_gdram,
|
||||
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
|
||||
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
|
||||
const int32_t num_points, char* data_col_gdram);
|
||||
|
||||
typedef enum {
|
||||
MS_DEFORM_ATTN_BACKWARD_DEFAULT = 0,
|
||||
|
@ -99,7 +111,9 @@ MsDeformAttnForwardPolicy msDeformAttnForwardPolicyFunc(
|
|||
#endif
|
||||
|
||||
int32_t nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore);
|
||||
if (num_levels * num_points * 3 * sizeof(int32_t) > nram_size) {
|
||||
if (num_levels * num_points <= 128 && num_levels * num_points * channels <= 8192) {
|
||||
return MS_DEFORM_ATTN_FORWARD_FAST;
|
||||
} else if (num_levels * num_points * 3 * sizeof(int32_t) > nram_size) {
|
||||
return MS_DEFORM_ATTN_FORWARD_DEFAULT;
|
||||
} else if (channels > nram_size / 12 / sizeof(float)) {
|
||||
return MS_DEFORM_ATTN_FORWARD_DEFAULT;
|
||||
|
@ -310,6 +324,18 @@ Tensor ms_deform_attn_mlu_forward(const Tensor& value,
|
|||
(char*)output_ptr);
|
||||
break;
|
||||
}
|
||||
case MS_DEFORM_ATTN_FORWARD_FAST: {
|
||||
CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnForwardFast<<<"
|
||||
<< k_dim.x << ", " << k_dim.y << ", " << k_dim.z << ">>>";
|
||||
KernelMsDeformAttnForwardFast(
|
||||
k_dim, k_type, queue, data_type, (char*)value_ptr,
|
||||
(char*)spatial_shapes_ptr, (char*)level_start_index_ptr,
|
||||
(char*)sampling_loc_ptr, (char*)attn_weight_ptr, batch_size, num_keys,
|
||||
num_heads, channels, num_levels, num_queries, num_points,
|
||||
(char*)output_ptr);
|
||||
break;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
output = output.view({batch_size, num_queries, num_heads * channels});
|
||||
|
|
Loading…
Reference in New Issue