[Feature] Support MsDeformAttnForward with fast kernel (#3157)

dev-mlu290
liuduanhui 2024-08-02 16:04:17 +08:00 committed by GitHub
parent 055a056c21
commit 8c21bf6b77
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 873 additions and 1 deletions

View File

@ -0,0 +1,23 @@
/*************************************************************************
* Copyright (C) [2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef MS_DEFORM_ATTN_FORWARD_FAST_MLU_KERNEL_HPP_
#define MS_DEFORM_ATTN_FORWARD_FAST_MLU_KERNEL_HPP_
void KernelMsDeformAttnForwardFast(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const char *data_value_gdram,
const char *data_spatial_shapes_gdram,
const char *data_level_start_index_gdram,
const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram,
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
const int32_t num_points, char *data_col_gdram);
#endif // MS_DEFORM_ATTN_FORWARD_FAST_MLU_KERNEL_HPP_

View File

@ -0,0 +1,823 @@
/*************************************************************************
* Copyright (C) [2024] by Cambricon, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the
* "Software"), to deal in the Software without restriction, including
* without limitation the rights to use, copy, modify, merge, publish,
* distribute, sublicense, and/or sell copies of the Software, and to
* permit persons to whom the Software is furnished to do so, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "common_mlu_helper.hpp"
#include "ms_deform_attn_fast_mlu_kernel.hpp"
#define NRAM_REMAIN_SIZE (48 * 1024)
#define NRAM_AVALIABLE_SIZE (__MLU_NRAM_SIZE__ * 1024 - NRAM_REMAIN_SIZE)
#define SRAM_REMAIN_SIZE (32 * 1024)
#define SRAM_AVALIABLE_SIZE (__MLU_SRAM_SIZE__ * 1024 - SRAM_REMAIN_SIZE)
#define SRAM_FOR_VALUE_SIZE (SRAM_AVALIABLE_SIZE - 128)
#define MAX_MEMCPY_SEGNUM 65536
__nram__ char nram_buffer[NRAM_AVALIABLE_SIZE];
__mlu_shared__ char sram_buffer[SRAM_AVALIABLE_SIZE];
template <typename T>
__mlu_func__ inline T __mluop_min(T a, T b) {
return a < b ? a : b;
}
template <typename T>
__mlu_func__ inline T __mluop_max(T a, T b) {
return a > b ? a : b;
}
template <typename T>
__mlu_func__ void __mluop_floor(T* dst_ram, T* src_ram, int size) {
if (sizeof(T) == sizeof(float)) {
int16* mid = (int16*)(dst_ram + size / 2);
__bang_float2int16_dn(mid, (float*)src_ram, size, 0);
__bang_int162float((float*)dst_ram, (int16_t*)mid, size, 0);
} else {
__bang_half2int16_dn((int16_t*)dst_ram, (half*)src_ram, size, 0);
__bang_int162half((half*)dst_ram, (int16_t*)dst_ram, size, 0);
}
}
__mlu_func__ void broadcastSpatialHW(
float* spatial_offset_bd_nram, // (num_levels, num_points)
float* spatial_h_bd_nram, // (num_levels, num_points)
float* spatial_w_bd_nram, // (num_levels, num_points)
int32_t* spatial_shapes_nram, // (num_levels, 2)
int32_t* spatial_offset_nram, // (num_levels)
const int32_t num_levels, const int32_t num_points) {
for (int i = 0; i < num_levels * 2; i++) {
((float*)spatial_shapes_nram)[i] = (float)spatial_shapes_nram[i];
}
for (int i = 0; i < num_levels; i++) {
((float*)spatial_offset_nram)[i] = (float)spatial_offset_nram[i];
}
for (int i = 0; i < num_levels; i++) {
__memcpy(spatial_h_bd_nram + i * num_points, spatial_shapes_nram + i * 2,
sizeof(float), NRAM2NRAM, sizeof(float), 0, num_points - 1);
}
for (int i = 0; i < num_levels; i++) {
__memcpy(spatial_w_bd_nram + i * num_points,
spatial_shapes_nram + 1 + i * 2, sizeof(float), NRAM2NRAM,
sizeof(float), 0, num_points - 1);
}
for (int i = 0; i < num_levels; i++) {
__memcpy(spatial_offset_bd_nram + i * num_points, spatial_offset_nram + i,
sizeof(float), NRAM2NRAM, sizeof(float), 0, num_points - 1);
}
}
template <typename T>
__mlu_func__ void getConditionCoordWeight(
int32_t* data_offset_nram, T* weight_polation_nram,
T* cond_point_polation_nram, T* cond_point_valid_nram, T* loc_nram,
T* weight_attn_nram, T* spatial_offset_bd_nram, T* spatial_w_bd_nram,
T* spatial_h_bd_nram, T* buf_nram, const int32_t deal_n,
const int32_t num_levels, const int32_t num_points, const int32_t num_heads,
int32_t pad_num_levels_points) {
int32_t pad_total_points = deal_n * pad_num_levels_points;
int32_t pad_block_points = pad_num_levels_points;
T* buf_x_nram = buf_nram;
T* buf_y_nram = buf_nram + pad_total_points;
T* buf_cond_nram = buf_nram + 2 * pad_total_points;
T* buf_x_floor = buf_nram + 2 * pad_total_points;
T* buf_y_floor = buf_nram + 3 * pad_total_points;
T* buf_x_ceil = buf_nram + 4 * pad_total_points;
T* buf_y_ceil = buf_nram + 5 * pad_total_points;
__bang_write_value(buf_x_nram, pad_total_points, 0);
__bang_write_value(buf_y_nram, pad_total_points, 0);
__bang_write_value(buf_x_floor, pad_total_points, 0);
__bang_write_value(buf_x_ceil, pad_total_points, 0);
__bang_write_value(buf_y_floor, pad_total_points, 0);
__bang_write_value(buf_y_ceil, pad_total_points, 0);
//================================================================================================
__memcpy(buf_x_nram, loc_nram, sizeof(T), NRAM2NRAM, sizeof(T), 2 * sizeof(T),
pad_total_points - 1);
__memcpy(buf_y_nram, loc_nram + 1, sizeof(T), NRAM2NRAM, sizeof(T),
2 * sizeof(T), pad_total_points - 1);
// x = loc_x * spatial_w - 0.5; y = loc_y * spatial_h - 0.5;
__bang_cycle_mul(buf_x_nram, buf_x_nram, spatial_w_bd_nram, pad_total_points,
pad_block_points);
__bang_sub_scalar(buf_x_nram, buf_x_nram, (T)0.5, pad_total_points);
__bang_cycle_mul(buf_y_nram, buf_y_nram, spatial_h_bd_nram, pad_total_points,
pad_block_points);
__bang_sub_scalar(buf_y_nram, buf_y_nram, (T)0.5, pad_total_points);
//================================================================================================
// get point condition. use buf0, buf1, buf2
// (x > -1 && y > -1 && y < spatial_h && x < spatial_w)
__bang_write_value(cond_point_valid_nram, pad_total_points, (T)-1.0);
__bang_gt(cond_point_valid_nram, buf_x_nram, cond_point_valid_nram,
pad_total_points);
__bang_write_value(buf_cond_nram, pad_total_points, (T)-1.0);
__bang_gt(buf_cond_nram, buf_y_nram, buf_cond_nram, pad_total_points);
__bang_and(cond_point_valid_nram, cond_point_valid_nram, buf_cond_nram,
pad_total_points);
__bang_cycle_lt(buf_cond_nram, buf_x_nram, spatial_w_bd_nram,
pad_total_points, pad_block_points);
__bang_and(cond_point_valid_nram, cond_point_valid_nram, buf_cond_nram,
pad_total_points);
__bang_cycle_lt(buf_cond_nram, buf_y_nram, spatial_h_bd_nram,
pad_total_points, pad_block_points);
__bang_and(cond_point_valid_nram, cond_point_valid_nram, buf_cond_nram,
pad_total_points);
//================================================================================================
__mluop_floor(buf_x_floor, buf_x_nram, 2 * pad_total_points);
__bang_add_scalar(buf_x_ceil, buf_x_floor, 1.0, pad_total_points);
__bang_add_scalar(buf_y_ceil, buf_y_floor, 1.0, pad_total_points);
T* cond_point_polation_nram_tl = cond_point_polation_nram;
T* cond_point_polation_nram_bl = cond_point_polation_nram + pad_total_points;
T* cond_point_polation_nram_tr =
cond_point_polation_nram + 2 * pad_total_points;
T* cond_point_polation_nram_br =
cond_point_polation_nram + 3 * pad_total_points;
T* cond_point_polation_nram_cond1 = weight_polation_nram;
T* cond_point_polation_nram_cond2 = weight_polation_nram + pad_total_points;
T* cond_point_polation_nram_cond3 =
weight_polation_nram + 2 * pad_total_points;
T* cond_point_polation_nram_cond4 =
weight_polation_nram + 3 * pad_total_points;
__bang_ge_scalar(cond_point_polation_nram_cond1, buf_x_floor, (T)0,
pad_total_points);
__bang_cycle_lt(cond_point_polation_nram_cond2, buf_x_ceil, spatial_w_bd_nram,
pad_total_points, pad_block_points);
__bang_ge_scalar(cond_point_polation_nram_cond3, buf_y_floor, (T)0,
pad_total_points);
__bang_cycle_lt(cond_point_polation_nram_cond4, buf_y_ceil, spatial_h_bd_nram,
pad_total_points, pad_block_points);
__bang_and(cond_point_polation_nram_tl, cond_point_polation_nram_cond1,
cond_point_polation_nram_cond4, pad_total_points);
__bang_and(cond_point_polation_nram_bl, cond_point_polation_nram_cond1,
cond_point_polation_nram_cond3, pad_total_points);
__bang_and(cond_point_polation_nram_tr, cond_point_polation_nram_cond2,
cond_point_polation_nram_cond4, pad_total_points);
__bang_and(cond_point_polation_nram_br, cond_point_polation_nram_cond2,
cond_point_polation_nram_cond3, pad_total_points);
//================================================================================================
// get polation weight.
T* buf_dx = (T*)data_offset_nram;
T* buf_dy = buf_dx + pad_total_points;
T* buf_dx_1 = buf_dy + pad_total_points;
T* buf_dy_1 = buf_dx_1 + pad_total_points;
// -dx = x_floor-x
// -dy = y_floor-y
// w1 = (1-dx)*dy = (dx-1)*(-dy)
// w2 = (1-dx)*(1-dy) = (dx-1)*(dy-1)
// w3 = dx*dy = (-dx)*(-dy)
// w4 = dx*(1-dy) = (-dx)*(dy-1)
T* weight_polation_nram_1 = weight_polation_nram;
T* weight_polation_nram_2 = weight_polation_nram + 1 * pad_total_points;
T* weight_polation_nram_3 = weight_polation_nram + 2 * pad_total_points;
T* weight_polation_nram_4 = weight_polation_nram + 3 * pad_total_points;
// T* weight_polation_nram_buf = buf_nram + 4 * total_points;
__bang_sub(buf_dx, buf_x_floor, buf_x_nram, pad_total_points);
__bang_sub(buf_dy, buf_y_floor, buf_y_nram, pad_total_points);
__bang_sub(buf_dx_1, buf_x_nram, buf_x_floor, pad_total_points);
__bang_sub_scalar(buf_dx_1, buf_dx_1, (T)1.0, pad_total_points);
__bang_sub(buf_dy_1, buf_y_nram, buf_y_floor, pad_total_points);
__bang_sub_scalar(buf_dy_1, buf_dy_1, (T)1.0, pad_total_points);
__bang_mul(weight_polation_nram_1, buf_dx_1, buf_dy, pad_total_points);
__bang_mul(weight_polation_nram_2, buf_dx_1, buf_dy_1, pad_total_points);
__bang_mul(weight_polation_nram_3, buf_dx, buf_dy, pad_total_points);
__bang_mul(weight_polation_nram_4, buf_dx, buf_dy_1, pad_total_points);
//================================================================================================
// correct the x,y in [0, w-1] and [0, h-1]
T* spatial_w1_bd_nram = buf_nram;
T* spatial_h1_bd_nram = buf_nram + pad_total_points;
__bang_sub_scalar(spatial_w1_bd_nram, spatial_w_bd_nram, (T)1,
pad_total_points);
__bang_sub_scalar(spatial_h1_bd_nram, spatial_h_bd_nram, (T)1,
pad_total_points);
T* maxtemp = (T*)data_offset_nram;
__bang_write_value(maxtemp, pad_total_points, (T)0);
__bang_maxequal(buf_x_floor, buf_x_floor, maxtemp, pad_total_points);
__bang_maxequal(buf_x_ceil, buf_x_ceil, maxtemp, pad_total_points);
__bang_cycle_minequal(buf_x_floor, buf_x_floor, spatial_w1_bd_nram,
pad_total_points, pad_block_points);
__bang_cycle_minequal(buf_x_ceil, buf_x_ceil, spatial_w1_bd_nram,
pad_total_points, pad_block_points);
__bang_maxequal(buf_y_floor, buf_y_floor, maxtemp, pad_total_points);
__bang_maxequal(buf_y_ceil, buf_y_ceil, maxtemp, pad_total_points);
__bang_cycle_minequal(buf_y_floor, buf_y_floor, spatial_h1_bd_nram,
pad_total_points, pad_block_points);
__bang_cycle_minequal(buf_y_ceil, buf_y_ceil, spatial_h1_bd_nram,
pad_total_points, pad_block_points);
//================================================================================================
// offset = y*w + x
T* buf_hw_offset = buf_nram;
T* data_offset_nram_tl = (T*)data_offset_nram;
T* data_offset_nram_bl = data_offset_nram_tl + pad_total_points;
T* data_offset_nram_tr = data_offset_nram_bl + pad_total_points;
T* data_offset_nram_br = data_offset_nram_tr + pad_total_points;
// y_ceil*w + offset + x_floor
__bang_cycle_mul(buf_hw_offset, buf_y_ceil, spatial_w_bd_nram,
pad_total_points, pad_block_points);
__bang_cycle_add(buf_hw_offset, buf_hw_offset, spatial_offset_bd_nram,
pad_total_points, pad_block_points);
__bang_add(data_offset_nram_tl, buf_hw_offset, buf_x_floor, pad_total_points);
// y_ceil*w + offset + x_ceil
__bang_add(data_offset_nram_tr, buf_hw_offset, buf_x_ceil, pad_total_points);
// y_floor*w + offset + x_foor
__bang_cycle_mul(buf_hw_offset, buf_y_floor, spatial_w_bd_nram,
pad_total_points, pad_block_points);
__bang_cycle_add(buf_hw_offset, buf_hw_offset, spatial_offset_bd_nram,
pad_total_points, pad_block_points);
__bang_add(data_offset_nram_bl, buf_hw_offset, buf_x_floor, pad_total_points);
// y_floor*w + offset + x_ceil
__bang_add(data_offset_nram_br, buf_hw_offset, buf_x_ceil, pad_total_points);
__bang_cycle_and(cond_point_polation_nram, cond_point_polation_nram,
cond_point_valid_nram, 4 * pad_total_points,
pad_total_points);
__bang_cycle_mul(weight_polation_nram, weight_polation_nram, weight_attn_nram,
4 * pad_total_points, pad_total_points);
__bang_mul(weight_polation_nram, weight_polation_nram,
cond_point_polation_nram, pad_total_points * 4);
__bang_sub((float*)data_offset_nram_bl, (float*)data_offset_nram_bl,
(float*)data_offset_nram_tl, pad_total_points);
__bang_sub((float*)data_offset_nram_tr, (float*)data_offset_nram_tr,
(float*)data_offset_nram_tl, pad_total_points);
}
/*
shape of each tensor:
output_nram: (channels)
input_nram: (4, valid_num, channels)
input_trans: (channels, 4, valid_num)
weight_selected_base: (4, deal_n, num_levels, num_points)
weight_compute: (4, valid_num)
*/
template <typename T>
__mlu_func__ void reduceLevel(T* output_nram, T* input_nram, T* input_trans,
T* weight_selected_base, T* weight_compute,
const int32_t pad_num_levels_points,
const int32_t pad_channels,
const int32_t pad_sample_stride_3) {
int32_t ci = 4 * pad_num_levels_points;
int32_t co = pad_channels;
__bang_write_value(weight_compute, 4 * pad_num_levels_points, 0);
__memcpy(weight_compute, weight_selected_base,
pad_num_levels_points * sizeof(T), NRAM2NRAM,
pad_num_levels_points * sizeof(T), pad_sample_stride_3 * sizeof(T),
3);
__bang_transpose(input_trans, input_nram, ci, co);
__bang_cycle_mul(input_trans, input_trans, weight_compute, co * ci, ci);
__bang_sumpool(input_nram, input_trans, pad_num_levels_points, pad_channels,
4, 1, 4, 1, 1);
__bang_transpose(input_trans, input_nram, pad_channels,
pad_num_levels_points);
__bang_sumpool(output_nram, input_trans, pad_channels, pad_num_levels_points,
1, pad_num_levels_points, 1, 1, 1);
}
__mlu_func__ void loadNram2Gpr(int32_t& v1, int32_t& v2, int32_t& v3,
int32_t* p1, int32_t* p2, int32_t* p3,
int32_t num_heads, int32_t channels_size,
bool sram_stay, int32_t sram_level_start_index) {
v1 = (int32_t)(*(float*)p1);
v2 = (int32_t)(*(float*)p2);
v3 = (int32_t)(*(float*)p3);
int32_t stride = sram_stay ? channels_size : num_heads * channels_size;
if (sram_stay) {
v1 = (v1 - sram_level_start_index) * stride;
v2 = v2 * stride;
v3 = v3 * stride;
} else {
v1 = v1 * stride;
v2 = v2 * stride;
v3 = v3 * stride;
}
}
/*
Load 4 neighbors use 2 2D-memcpy, just use offset of N1, stride_3_1
and
stride_2_1.
|<- stride_3_1 ->|
N1 N3
^
|
stride_2_1
|
v
N2 N4
Trickly fold the loop as 2.
*/
template <typename T, mluMemcpyDirection_t DIR>
__mlu_func__ void loadDataValueXram2NramAsync(
T* buf_value_nram_1, int32_t* offset_1, int32_t* stride_2_1,
int32_t* stride_3_1, T* value_src, const int32_t pad_num_levels_points,
const int32_t deal_points, const int32_t start_points_index,
const int32_t channel_size, const int32_t num_heads, bool sram_stay,
const int32_t sram_level_start_offset) {
int32_t offset_1_a, stride_2_1_a, stride_3_1_a;
int32_t offset_1_b, stride_2_1_b, stride_3_1_b;
loadNram2Gpr(offset_1_a, stride_2_1_a, stride_3_1_a,
offset_1 + start_points_index, stride_2_1 + start_points_index,
stride_3_1 + start_points_index, num_heads, channel_size,
sram_stay, sram_level_start_offset);
loadNram2Gpr(
offset_1_b, stride_2_1_b, stride_3_1_b, offset_1 + start_points_index + 1,
stride_2_1 + start_points_index + 1, stride_3_1 + start_points_index + 1,
num_heads, channel_size, sram_stay, sram_level_start_offset);
int32_t value_offset = 0;
int32_t next = 0;
int32_t loop_num = deal_points / 2;
int32_t remain = deal_points % 2;
int32_t pad_channels =
PAD_UP(channel_size / sizeof(T), NFU_ALIGN_SIZE / sizeof(T));
int32_t pad_channels_size = pad_channels * sizeof(T);
int32_t pad_data_value_stride = pad_num_levels_points * pad_channels_size;
for (int32_t j = start_points_index; j < start_points_index + loop_num * 2;
j += 2) {
value_offset = j * pad_channels_size;
next = j + 2;
for (int i = 0; i < 2; i++) {
__memcpy_async(
(int8_t*)buf_value_nram_1 + value_offset + pad_data_value_stride * i,
(int8_t*)value_src + offset_1_a + i * stride_2_1_a, channel_size, DIR,
2 * pad_data_value_stride, stride_3_1_a, 1);
}
loadNram2Gpr(offset_1_a, stride_2_1_a, stride_3_1_a, offset_1 + next,
stride_2_1 + next, stride_3_1 + next, num_heads, channel_size,
sram_stay, sram_level_start_offset);
for (int i = 0; i < 2; i++) {
__memcpy_async((int8_t*)buf_value_nram_1 + value_offset +
pad_channels_size + pad_data_value_stride * i,
(int8_t*)value_src + offset_1_b + i * stride_2_1_b,
channel_size, DIR, 2 * pad_data_value_stride, stride_3_1_b,
1);
}
loadNram2Gpr(offset_1_b, stride_2_1_b, stride_3_1_b, offset_1 + next + 1,
stride_2_1 + next + 1, stride_3_1 + next + 1, num_heads,
channel_size, sram_stay, sram_level_start_offset);
}
if (remain > 0) {
value_offset = (start_points_index + loop_num * 2) * pad_channels_size;
for (int i = 0; i < 2; i++) {
__memcpy_async(
(int8_t*)buf_value_nram_1 + value_offset + pad_data_value_stride * i,
(int8_t*)value_src + offset_1_a + i * stride_2_1_a, channel_size, DIR,
2 * pad_data_value_stride, stride_3_1_a, 1);
}
}
}
template <typename T>
__mlu_func__ void loadNeighborPolationAttn(
T* value_output_nram, T* value_gdram, int32_t* data_offset_nram,
T* weight_polation_nram, T* cond_point_polation_nram,
T* cond_point_valid_nram, T* weight_attn_nram, T* buf_nram,
T* compute_buf_nram, const int32_t deal_n, const int32_t num_levels,
const int32_t num_points, const int32_t num_keys, const int32_t channels,
const int32_t num_heads, const int32_t pad_channels,
const int32_t pad_num_levels_points, T* value_sram,
const int32_t sram_level_start_index,
const int32_t sram_level_start_offset) {
int32_t channel_size = channels * sizeof(T);
int32_t pad_sample_stride_3 = deal_n * pad_num_levels_points;
T* buf_value_nram = buf_nram; // (4, num_levels, num_points, channels)
T* buf_value_nram_trans = buf_nram + 4 * pad_num_levels_points * pad_channels;
T* weight_compute_nram = compute_buf_nram; // (4, num_levels, num_points)
int32_t* offset = data_offset_nram;
int32_t* stride_2_1 = offset + pad_sample_stride_3;
int32_t* stride_3_1 = stride_2_1 + pad_sample_stride_3;
T* output_nram = value_output_nram;
int32_t step_offset = 0;
for (int32_t i = 0; i < deal_n; i++) {
__bang_write_value(buf_value_nram, 4 * pad_num_levels_points * pad_channels,
0);
__sync_compute();
if (sram_level_start_index > 0) {
loadDataValueXram2NramAsync<T, GDRAM2NRAM>(
buf_value_nram, offset, stride_2_1, stride_3_1, value_gdram,
pad_num_levels_points, num_points * sram_level_start_index, 0,
channel_size, num_heads, false, sram_level_start_offset);
}
if (sram_level_start_index < num_levels) {
loadDataValueXram2NramAsync<T, SRAM2NRAM>(
buf_value_nram, offset, stride_2_1, stride_3_1, value_sram,
pad_num_levels_points,
num_points * (num_levels - sram_level_start_index),
num_points * sram_level_start_index, channel_size, num_heads, true,
sram_level_start_offset);
}
__sync_io_move_compute();
reduceLevel(output_nram, buf_value_nram, buf_value_nram_trans,
weight_polation_nram + step_offset, weight_compute_nram,
pad_num_levels_points, pad_channels, pad_sample_stride_3);
step_offset += pad_num_levels_points;
offset = data_offset_nram + step_offset;
stride_2_1 = offset + pad_sample_stride_3;
stride_3_1 = stride_2_1 + pad_sample_stride_3;
output_nram += pad_channels;
}
}
template <typename T>
__mlu_func__ void prepareLoop(
int32_t* spatial_offset_nram, int32_t* spatial_hw_nram,
T* spatial_offset_bd_nram, T* spatial_h_bd_nram, T* spatial_w_bd_nram,
const char* data_level_start_index_gdram,
const char* data_spatial_shapes_gdram, const int32_t num_keys,
const int32_t num_levels, const int32_t num_points,
const int32_t max_deal_n, const int32_t channels) {
__memcpy(spatial_offset_nram, data_level_start_index_gdram,
num_levels * sizeof(int32_t), GDRAM2NRAM);
__memcpy(spatial_hw_nram, data_spatial_shapes_gdram,
num_levels * 2 * sizeof(int32_t), GDRAM2NRAM);
broadcastSpatialHW(spatial_offset_bd_nram, spatial_h_bd_nram,
spatial_w_bd_nram, spatial_hw_nram, spatial_offset_nram,
num_levels, num_points);
}
/*
The shape of each tensor:
buf_compute_nram: (8, num_levels, num_points)
spatial_offset_nram: (num_levels)
spatial_hw_nram: (num_levels, 2)
spatial_offset_bd_nram: (num_levels, num_points)
spatial_w_bd_nram: (num_levels, num_points)
spatial_h_bd_nram: (num_levels, num_points)
value_output_nram: (deal_n, channels)
data_offset_nram: (4, deal_n, num_levels, num_points)
weight_polation_nram: (4, deal_n, num_levels, num_points)
cond_point_polation_nram: (4, deal_n, num_levels, num_points)
cond_point_valid_nram: (deal_n, num_levels, num_points)
loc_nram: (deal_n, num_levels, num_points, 2)
weight_attn_nram: (deal_n, num_levels, num_points)
buf_nram: (6, deal_n, num_levels, num_points)
Note: buf_nram is reused in polation computing.
*/
template <typename T>
__mlu_func__ void memPolicyCommon(
T*& buf_compute_nram, T*& value_output_nram, int32_t*& data_offset_nram,
T*& weight_polation_nram, T*& cond_point_polation_nram,
T*& cond_point_valid_nram, T*& loc_nram, T*& weight_attn_nram, T*& buf_nram,
T*& buf_nram_end, T*& spatial_offset_bd_nram, T*& spatial_w_bd_nram,
T*& spatial_h_bd_nram, T*& value_sram, int32_t*& spatial_offset_nram,
int32_t*& spatial_hw_nram, int32_t& max_deal_n, int32_t& pad_channels,
int32_t& pad_num_levels_points, int32_t& pad_total_points,
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
const int32_t num_points) {
pad_channels = PAD_UP(channels, NFU_ALIGN_SIZE / sizeof(T));
int32_t num_levels_points = num_levels * num_points;
pad_num_levels_points = PAD_UP(num_levels_points, NFU_ALIGN_SIZE / sizeof(T));
int32_t pad_num_levels_points_8 = 8 * pad_num_levels_points;
int32_t spatial_info_size =
PAD_UP(3 * num_levels * sizeof(int32_t), NFU_ALIGN_SIZE);
int32_t fix_space_size =
spatial_info_size +
(3 * pad_num_levels_points + pad_num_levels_points) * sizeof(T);
int32_t left_space_size = NRAM_AVALIABLE_SIZE - fix_space_size;
int32_t common_buffer_size_each = 6 * pad_num_levels_points * sizeof(T);
int32_t inter_result_size_each =
17 * pad_num_levels_points * sizeof(T) + pad_channels * sizeof(T);
max_deal_n =
left_space_size / (common_buffer_size_each + inter_result_size_each);
int32_t compute_buffer_size =
(9 * pad_num_levels_points * pad_channels) * sizeof(T);
int32_t common_buffer_size = max_deal_n * common_buffer_size_each;
// make sure buf_nram is large enough for compute
if (compute_buffer_size > common_buffer_size) {
int32_t tmp_deal_n =
(left_space_size - compute_buffer_size) / inter_result_size_each;
max_deal_n = __mluop_min(max_deal_n, tmp_deal_n);
}
pad_total_points = max_deal_n * pad_num_levels_points;
buf_compute_nram = (T*)nram_buffer;
spatial_offset_nram = (int32_t*)(buf_compute_nram + pad_num_levels_points_8);
int32_t pad_3_levels = PAD_UP(3 * num_levels, NFU_ALIGN_SIZE / sizeof(T));
spatial_hw_nram = spatial_offset_nram + num_levels;
spatial_offset_bd_nram = (T*)(spatial_offset_nram + pad_3_levels);
spatial_w_bd_nram = spatial_offset_bd_nram + pad_num_levels_points;
spatial_h_bd_nram = spatial_w_bd_nram + pad_num_levels_points;
value_output_nram = spatial_h_bd_nram + pad_num_levels_points;
data_offset_nram = (int32_t*)(value_output_nram + max_deal_n * pad_channels);
weight_polation_nram = (T*)(data_offset_nram + 4 * pad_total_points);
cond_point_polation_nram = weight_polation_nram + 4 * pad_total_points;
cond_point_valid_nram = cond_point_polation_nram + 4 * pad_total_points;
loc_nram = cond_point_valid_nram + pad_total_points;
weight_attn_nram =
loc_nram +
2 * pad_total_points; // total_coord_pad = 2 * pad_total_points
buf_nram = weight_attn_nram + pad_total_points;
buf_nram_end = buf_nram + 6 * max_deal_n * pad_num_levels_points;
value_sram = (T*)sram_buffer;
}
template <typename T>
__mlu_func__ void loadDataValueGdram2Sram(T* value_sram, T* data_value_gdram,
const int32_t batch_idx,
const int32_t head_idx,
const int32_t sram_num_keys,
const int32_t num_heads,
const int32_t channels,
const int32_t skip_num_key) {
int32_t loop_num =
(sram_num_keys + MAX_MEMCPY_SEGNUM - 1) / MAX_MEMCPY_SEGNUM;
int32_t num_heads_channels = num_heads * channels;
for (int32_t i = 0; i < loop_num; i++) {
int32_t load_num =
__mluop_min(MAX_MEMCPY_SEGNUM, sram_num_keys - i * MAX_MEMCPY_SEGNUM);
size_t src_offset = ((size_t)batch_idx * sram_num_keys + skip_num_key +
i * MAX_MEMCPY_SEGNUM) *
num_heads_channels +
head_idx * channels;
int32_t dst_offset = i * MAX_MEMCPY_SEGNUM * channels;
__memcpy(value_sram + dst_offset, (T*)data_value_gdram + src_offset,
channels * sizeof(T), GDRAM2SRAM, channels * sizeof(T),
num_heads_channels * sizeof(T), load_num - 1);
}
}
template <typename T>
__mlu_func__ void computeSramCacheSizeAndOffset(
int32_t* sram_level_cache_size, int32_t* sram_level_start_index,
int32_t* sram_level_start_offset, const int32_t num_levels,
const int32_t num_keys, const int32_t channels,
const T* data_level_start_index_gdram, const int32_t sram_size) {
for (int32_t level_id = num_levels; level_id > 0; level_id--) {
int current_level_end_index =
level_id == num_levels
? num_keys
: ((int32_t*)data_level_start_index_gdram)[level_id];
int32_t current_level_size =
current_level_end_index -
((int32_t*)data_level_start_index_gdram)[level_id - 1];
if ((*sram_level_cache_size + current_level_size) * channels * sizeof(T) >
sram_size) {
break;
}
*sram_level_cache_size += current_level_size;
*sram_level_start_index = level_id - 1;
*sram_level_start_offset = num_keys - *sram_level_cache_size;
}
}
template <typename T>
__mlu_func__ void MLUKernelMsDeformAttnForwardFastImpl(
const char* data_value_gdram, const char* data_spatial_shapes_gdram,
const char* data_level_start_index_gdram,
const char* data_sampling_loc_gdram, const char* data_attn_weight_gdram,
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
const int32_t num_points, char* data_col_gdram) {
int32_t input_stride_4 = num_queries * num_heads * num_levels * num_points;
int32_t input_stride_3 = num_heads * num_levels * num_points;
int32_t input_stride_2 = num_levels * num_points;
int32_t output_stride_3 = num_queries * num_heads * channels;
int32_t output_stride_2 = num_heads * channels;
int32_t data_value_stride_3 = num_keys * num_heads * channels;
T* value_output_nram = nullptr; // (deal_n, channels)
int32_t* data_offset_nram = nullptr; // (4, deal_n, num_levels, num_points)
T* weight_polation_nram = nullptr; // (4, deal_n, num_levels, num_points)
T* cond_point_polation_nram = nullptr; // (4, deal_n, num_levels, num_points)
T* cond_point_valid_nram = nullptr; // (deal_n, num_levels, num_points)
T* loc_nram = nullptr; // (deal_n, num_levels, num_points, 2)
T* weight_attn_nram = nullptr; // (deal_n, num_levels, num_points)
T* buf_nram = nullptr; // (6, deal_n, num_levels, num_points)
T* buf_nram_end = nullptr;
T* spatial_offset_bd_nram = nullptr; // (num_levels, num_points)
T* spatial_w_bd_nram = nullptr; // (num_levels, num_points)
T* spatial_h_bd_nram = nullptr; // (num_levels, num_points)
int32_t* spatial_offset_nram = nullptr; // (num_levels)
int32_t* spatial_hw_nram = nullptr; // (num_levels, 2)
T* buf_compute_nram = nullptr; // (8, num_levels, num_points)
int32_t max_deal_n = 0;
int32_t pad_channels = 0;
int32_t pad_num_levels_points = 0;
int32_t pad_total_points = 0;
T* value_sram = nullptr;
memPolicyCommon(buf_compute_nram, value_output_nram, data_offset_nram,
weight_polation_nram, cond_point_polation_nram,
cond_point_valid_nram, loc_nram, weight_attn_nram, buf_nram,
buf_nram_end, spatial_offset_bd_nram, spatial_w_bd_nram,
spatial_h_bd_nram, value_sram, spatial_offset_nram,
spatial_hw_nram, max_deal_n, pad_channels,
pad_num_levels_points, pad_total_points, batch_size, num_keys,
num_heads, channels, num_levels, num_queries, num_points);
if (max_deal_n <= 0) {
return;
}
// split batch*head into taskDimY
int32_t batch_head = batch_size * num_heads;
int32_t cluster_avg_batch_head = (batch_head + taskDimY - 1) / taskDimY;
int32_t cluster_begin_batch_head = taskIdY * cluster_avg_batch_head;
int32_t cluster_act_batch_head = __mluop_min(
cluster_avg_batch_head, batch_head - cluster_begin_batch_head);
int32_t cluster_end_batch_head =
cluster_begin_batch_head + cluster_act_batch_head;
// split query into coreDim
int32_t core_avg_query = (num_queries + coreDim - 1) / coreDim;
int32_t core_begin_query = coreId * core_avg_query;
int32_t core_act_query =
__mluop_min(num_queries - core_begin_query, core_avg_query);
int32_t core_loop_num = (core_act_query + max_deal_n - 1) / max_deal_n;
int32_t core_step_query =
core_loop_num > 0 ? (core_act_query + core_loop_num - 1) / core_loop_num
: 0;
int32_t core_remain_query =
core_act_query - (core_loop_num - 1) * core_step_query;
int32_t first_deal_query =
(int)(core_loop_num > 0) *
(core_loop_num > 1 ? core_step_query : core_remain_query);
prepareLoop(spatial_offset_nram, spatial_hw_nram, spatial_offset_bd_nram,
spatial_h_bd_nram, spatial_w_bd_nram,
data_level_start_index_gdram, data_spatial_shapes_gdram, num_keys,
num_levels, num_points, max_deal_n, channels);
int sram_total_size = 0;
int sram_level_start_index = num_levels;
int sram_level_start_offset = 0;
computeSramCacheSizeAndOffset(
&sram_total_size, &sram_level_start_index, &sram_level_start_offset,
num_levels, num_keys, channels, (int32_t*)data_level_start_index_gdram,
SRAM_FOR_VALUE_SIZE);
for (int32_t bh_idx = cluster_begin_batch_head;
bh_idx < cluster_end_batch_head; bh_idx++) {
int32_t b = bh_idx / num_heads;
int32_t head_idx = bh_idx % num_heads;
size_t output_base_offset =
(size_t)b * output_stride_3 + head_idx * channels;
int32_t attn_weight_base_offset =
b * input_stride_4 + head_idx * input_stride_2;
if (__is_mpu() && (sram_level_start_index != num_levels)) {
loadDataValueGdram2Sram(value_sram, (T*)data_value_gdram, b, head_idx,
sram_total_size, num_heads, channels,
sram_level_start_offset);
}
__sync_cluster();
if (__is_ipu()) {
// compute weight, offset and condition
int32_t attn_weight_offset =
attn_weight_base_offset + core_begin_query * input_stride_3;
int32_t loc_offset = attn_weight_offset * 2;
if (first_deal_query > 0) {
__bang_write_value(loc_nram, 2 * pad_total_points, 0);
__bang_write_value(weight_attn_nram, pad_total_points, 0);
__sync_compute();
__memcpy_async(loc_nram, (T*)data_sampling_loc_gdram + loc_offset,
input_stride_2 * 2 * sizeof(T), GDRAM2NRAM,
pad_num_levels_points * 2 * sizeof(T),
input_stride_3 * 2 * sizeof(T), first_deal_query - 1);
__memcpy_async(weight_attn_nram,
(T*)data_attn_weight_gdram + attn_weight_offset,
input_stride_2 * sizeof(T), GDRAM2NRAM,
pad_num_levels_points * sizeof(T),
input_stride_3 * sizeof(T), first_deal_query - 1);
getConditionCoordWeight<T>(
data_offset_nram, weight_polation_nram, cond_point_polation_nram,
cond_point_valid_nram, loc_nram, weight_attn_nram,
spatial_offset_bd_nram, spatial_w_bd_nram, spatial_h_bd_nram,
buf_nram, first_deal_query, num_levels, num_points, num_heads,
pad_num_levels_points);
}
}
for (int32_t i = 0; __is_ipu() && i < core_loop_num; i++) {
__bang_write_value(loc_nram, 2 * pad_total_points, 0);
__bang_write_value(weight_attn_nram, pad_total_points, 0);
int32_t deal_n =
i < core_loop_num - 1 ? core_step_query : core_remain_query;
int32_t load_n =
i < core_loop_num - 2 ? core_step_query : core_remain_query;
// load value and polation
loadNeighborPolationAttn<T>(
value_output_nram,
(T*)data_value_gdram + b * data_value_stride_3 + head_idx * channels,
data_offset_nram, weight_polation_nram, cond_point_polation_nram,
cond_point_valid_nram, weight_attn_nram, buf_nram, buf_compute_nram,
deal_n, num_levels, num_points, num_keys, channels, num_heads,
pad_channels, pad_num_levels_points, value_sram,
sram_level_start_index, sram_level_start_offset);
__sync_io_move_compute();
// load next weight and loc
if (i < core_loop_num - 1) {
int32_t core_query_offset = (i + 1) * core_step_query;
int32_t attn_weight_offset =
attn_weight_base_offset +
(core_begin_query + core_query_offset) * input_stride_3;
int32_t loc_offset = attn_weight_offset * 2;
__memcpy_async(loc_nram, (T*)data_sampling_loc_gdram + loc_offset,
input_stride_2 * 2 * sizeof(T), GDRAM2NRAM,
pad_num_levels_points * 2 * sizeof(T),
input_stride_3 * 2 * sizeof(T), load_n - 1);
__memcpy_async(weight_attn_nram,
(T*)data_attn_weight_gdram + attn_weight_offset,
input_stride_2 * sizeof(T), GDRAM2NRAM,
pad_num_levels_points * sizeof(T),
input_stride_3 * sizeof(T), load_n - 1);
__sync_io_move_compute();
}
// store result
size_t output_offset =
((size_t)core_begin_query + i * core_step_query) * output_stride_2;
__memcpy_async((T*)data_col_gdram + output_base_offset + output_offset,
value_output_nram, channels * sizeof(T), NRAM2GDRAM,
output_stride_2 * sizeof(T), pad_channels * sizeof(T),
deal_n - 1);
// compute cond/weight/offset
if (i < core_loop_num - 1) {
getConditionCoordWeight<T>(
data_offset_nram, weight_polation_nram, cond_point_polation_nram,
cond_point_valid_nram, loc_nram, weight_attn_nram,
spatial_offset_bd_nram, spatial_w_bd_nram, spatial_h_bd_nram,
buf_nram, load_n, num_levels, num_points, num_heads,
pad_num_levels_points);
}
__sync_io_move_compute();
}
__sync_cluster();
}
}
template <typename T>
__mlu_global__ void MLUKernelMsDeformAttnForwardFast(
const char* data_value_gdram, const char* data_spatial_shapes_gdram,
const char* data_level_start_index_gdram,
const char* data_sampling_loc_gdram, const char* data_attn_weight_gdram,
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
const int32_t num_points, char* data_col_gdram) {
MLUKernelMsDeformAttnForwardFastImpl<float>(
data_value_gdram, data_spatial_shapes_gdram, data_level_start_index_gdram,
data_sampling_loc_gdram, data_attn_weight_gdram, batch_size, num_keys,
num_heads, channels, num_levels, num_queries, num_points, data_col_gdram);
}
template __mlu_global__ void MLUKernelMsDeformAttnForwardFast<float>(
const char* data_value_gdram, const char* data_spatial_shapes_gdram,
const char* data_level_start_index_gdram,
const char* data_sampling_loc_gdram, const char* data_attn_weight_gdram,
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
const int32_t num_points, char* data_col_gdram);
void KernelMsDeformAttnForwardFast(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const char* data_value_gdram,
const char* data_spatial_shapes_gdram,
const char* data_level_start_index_gdram,
const char* data_sampling_loc_gdram, const char* data_attn_weight_gdram,
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
const int32_t num_points, char* data_col_gdram) {
MLUKernelMsDeformAttnForwardFast<float><<<k_dim, k_type, queue>>>(
data_value_gdram, data_spatial_shapes_gdram, data_level_start_index_gdram,
data_sampling_loc_gdram, data_attn_weight_gdram, batch_size, num_keys,
num_heads, channels, num_levels, num_queries, num_points, data_col_gdram);
}

View File

@ -11,6 +11,7 @@
*************************************************************************/
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"
#include "ms_deform_attn_fast_mlu_kernel.hpp"
#define MIN(a, b) (((a) < (b)) ? (a) : (b))
@ -20,6 +21,8 @@ typedef enum {
1, /*!< MLUKernelMsDeformAttnForwardDefault */
MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL =
2, /*!< MLUKernelMsDeformAttnForwardSmallChannel */
MS_DEFORM_ATTN_FORWARD_FAST =
3, /*!< MLUKernelMsDeformAttnForwardFast */
} MsDeformAttnForwardPolicy;
void KernelMsDeformAttnForwardDefault(
@ -40,6 +43,15 @@ void KernelMsDeformAttnForwardSmallChannel(
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
const int32_t num_points, char* data_col_gdram);
void KernelMsDeformAttnForwardFast(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const char* data_value_gdram,
const char* data_spatial_shapes_gdram,
const char* data_level_start_index_gdram,
const char* data_sampling_loc_gdram, const char* data_attn_weight_gdram,
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
const int32_t num_points, char* data_col_gdram);
typedef enum {
MS_DEFORM_ATTN_BACKWARD_DEFAULT = 0,
@ -99,7 +111,9 @@ MsDeformAttnForwardPolicy msDeformAttnForwardPolicyFunc(
#endif
int32_t nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore);
if (num_levels * num_points * 3 * sizeof(int32_t) > nram_size) {
if (num_levels * num_points <= 128 && num_levels * num_points * channels <= 8192) {
return MS_DEFORM_ATTN_FORWARD_FAST;
} else if (num_levels * num_points * 3 * sizeof(int32_t) > nram_size) {
return MS_DEFORM_ATTN_FORWARD_DEFAULT;
} else if (channels > nram_size / 12 / sizeof(float)) {
return MS_DEFORM_ATTN_FORWARD_DEFAULT;
@ -310,6 +324,18 @@ Tensor ms_deform_attn_mlu_forward(const Tensor& value,
(char*)output_ptr);
break;
}
case MS_DEFORM_ATTN_FORWARD_FAST: {
CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnForwardFast<<<"
<< k_dim.x << ", " << k_dim.y << ", " << k_dim.z << ">>>";
KernelMsDeformAttnForwardFast(
k_dim, k_type, queue, data_type, (char*)value_ptr,
(char*)spatial_shapes_ptr, (char*)level_start_index_ptr,
(char*)sampling_loc_ptr, (char*)attn_weight_ptr, batch_size, num_keys,
num_heads, channels, num_levels, num_queries, num_points,
(char*)output_ptr);
break;
}
}
output = output.view({batch_size, num_queries, num_heads * channels});