diff --git a/mmcv/ops/csrc/common/mlu/ms_deform_attn_mlu_kernel.mlu b/mmcv/ops/csrc/common/mlu/ms_deform_attn_mlu_kernel.mlu
index 6aab1dae2..40ad6396a 100644
--- a/mmcv/ops/csrc/common/mlu/ms_deform_attn_mlu_kernel.mlu
+++ b/mmcv/ops/csrc/common/mlu/ms_deform_attn_mlu_kernel.mlu
@@ -32,6 +32,7 @@
 /****************************************************************************************
  *
  * NRAM partition backward:
+ * default kernel
  * | grad_output_nram | grad_output_nram_temp | grad_weight       |
  * | grad_h_weight    | grad_w_weight         | top_grad          |
  * | top_grad_temp    | spatial_shapes_nram   | sampling_loc_nram |
@@ -39,11 +40,26 @@
 * | deal_size        | deal_size             | deal_size         |
 * | deal_size        | deal_size             | 64bytes           |
 *
+ * small channel kernel
+ * | nram_grad_output_tl    | nram_grad_output_tr    | nram_grad_output_bl    |
+ * | nram_grad_output_br    | grad_temp1             | grad_temp2             |
+ * | grad_temp3             | grad_temp4             | nram_loc_w             |
+ * | nram_loc_h             | nram_h_low             | nram_w_low             |
+ * | nram_h_high            | nram_w_high            | nram_h_low_temp        |
+ * | nram_h_high_temp       | nram_hw                | nram_hh                |
+ * | nram_lw                | nram_lh                | nram_h_low_ptr_offset  |
+ * | nram_h_high_ptr_offset | nram_w_low_ptr_offset  | nram_w_high_ptr_offset |
+ * | nram_w1                | nram_w2                | nram_w3                |
+ * | nram_w4                | nram_grad_weight       | nram_base_ptr          |
+ * | nram_offset_temp       | nram_offset1           | nram_offset2           |
+ * | nram_offset3           | nram_offset4           | nram_w_low_temp        |
+ * | nram_spatial_shapes    | nram_level_start_index | nram_h_stride          |
 ****************************************************************************************/

 #define TWELVE_SPLIT 12
 #define ALIGN_NUM 32
 #define ALIGN_NUM_FOR_REDUCE 32
+#define ELE_COUNT 32
 #define LEN_FLOAT sizeof(float)

 __nram__ char nram_buffer[MAX_NRAM_SIZE];
@@ -540,6 +556,17 @@ __mlu_global__ void MLUKernelMsDeformAttnForwardDefault(
   return;
 }

+__mlu_func__ void genMask0101(float *mask_ram, int32_t size) {
+  int32_t align_num = NFU_ALIGN_SIZE / sizeof(float);
+  for (int32_t i = 0; i < align_num; ++i) {
+    mask_ram[i] = i % 2;
+  }
+  __asm__ volatile("sync;");
+  __memcpy(mask_ram + align_num, mask_ram, NFU_ALIGN_SIZE, NRAM2NRAM,
+           NFU_ALIGN_SIZE, 0, size / align_num - 2);
+  __asm__ volatile("sync;");
+}
+
 template <typename T>
 __mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel(
     const char *data_value_gdram, const char *data_spatial_shapes_gdram,
@@ -548,467 +575,471 @@ __mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel(
     const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
     const int32_t channels, const int32_t num_levels,
     const int32_t num_queries, const int32_t num_points,
     char *data_col_gdram) {
+#if __BANG_ARCH__ >= 300
   if (coreId == 0x80) {
     return;
   }
+  size_t block_num_per_core, batch_start, deal_g, offset_g;
+  size_t block_num_rem = 0;
+  const size_t grid_total = num_queries * num_heads * num_levels * num_points;
+  if (batch_size >= taskDim) {
+    block_num_rem = batch_size % taskDim;
+    block_num_per_core = taskId < block_num_rem ? batch_size / taskDim + 1
+                                                : batch_size / taskDim;
+    batch_start = taskId < block_num_rem
+                      ? taskId * block_num_per_core
+                      : taskId * block_num_per_core + block_num_rem;
+    deal_g = grid_total;
+    offset_g = 0;
+  } else {
+    size_t skip_n = taskDim / batch_size;
+    batch_start = taskId / skip_n;
+    block_num_per_core = batch_start >= batch_size ? 0 : 1;
+    deal_g = PAD_UP(grid_total / skip_n, num_levels * num_points);
+    size_t id = taskId % skip_n;
+    offset_g = id * deal_g;
+    deal_g = id < (skip_n - 1) ? deal_g
+                               : grid_total - deal_g * (skip_n - 1);
+  }
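Reviewer note (illustrative, not part of the patch): the split above assigns whole
batches per core when batch_size >= taskDim, and otherwise lets skip_n cores share
one batch by slicing its (query, head, level, point) grid. A host-side C++ sketch of
the same arithmetic (taskId/taskDim are stand-ins for the BANG built-ins):

  #include <cstddef>
  struct Split { size_t batch_start, batch_count, grid_offset, grid_count; };
  constexpr size_t pad_up(size_t x, size_t n) { return (x + n - 1) / n * n; }
  // lp = num_levels * num_points, the granularity the grid slice is padded to.
  Split split_work(size_t task_id, size_t task_dim, size_t batch_size,
                   size_t grid_total, size_t lp) {
    if (batch_size >= task_dim) {
      size_t rem = batch_size % task_dim;
      size_t per = batch_size / task_dim + (task_id < rem ? 1 : 0);
      size_t start = task_id < rem ? task_id * per : task_id * per + rem;
      return {start, per, 0, grid_total};
    }
    size_t skip_n = task_dim / batch_size;
    size_t batch = task_id / skip_n, id = task_id % skip_n;
    size_t deal = pad_up(grid_total / skip_n, lp);
    size_t cnt = id < skip_n - 1 ? deal : grid_total - deal * (skip_n - 1);
    return {batch, batch < batch_size ? size_t(1) : size_t(0), id * deal, cnt};
  }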
+
+  const int32_t float_align = NFU_ALIGN_SIZE / sizeof(float);
+  int32_t deal_num;
+  int32_t cut_channel_iter = 2;
+  const size_t spatial_size =
+      PAD_UP(num_levels * 2 * sizeof(int32_t), NFU_ALIGN_SIZE);
   const size_t level_start_index_size =
       PAD_UP(num_levels * sizeof(int32_t), NFU_ALIGN_SIZE);
-  size_t sampling_loc_size =
-      PAD_UP(num_levels * num_points * 2 * sizeof(T), NFU_ALIGN_SIZE);
-  size_t attn_weight_size =
-      PAD_UP(num_levels * num_points * sizeof(T), NFU_ALIGN_SIZE);
-  size_t span_num_deal =
-      PAD_DOWN((MAX_NRAM_SIZE - spatial_size - level_start_index_size -
-                sampling_loc_size - attn_weight_size) /
-                   TWELVE_SPLIT / sizeof(T),
-               NFU_ALIGN_SIZE);
-  const int32_t channels_seg_num = channels / span_num_deal;
-  const size_t channels_rem = channels % span_num_deal;
-  int32_t load_loc_weight_idx = 0;
-  int32_t load_loc_weight_seg = 1;
-  if (channels_seg_num == 0) {
-    span_num_deal = PAD_UP(channels, NFU_ALIGN_SIZE);
-    attn_weight_size =
-        PAD_DOWN((MAX_NRAM_SIZE - spatial_size - level_start_index_size -
-                  TWELVE_SPLIT * span_num_deal * sizeof(T)) /
-                     3,
-                 num_levels * num_points * sizeof(T));
-    attn_weight_size = PAD_DOWN(attn_weight_size, NFU_ALIGN_SIZE);
-    sampling_loc_size = attn_weight_size * 2;
-    load_loc_weight_seg =
-        attn_weight_size / (num_levels * num_points * sizeof(T));
-  }
-#if __BANG_ARCH__ < 322
-  const size_t align_num = NFU_ALIGN_SIZE;
-  const size_t channels_align_rem = CEIL_ALIGN(channels_rem, align_num);
-#endif
+  int32_t channel = channels;
+  int32_t mult;
+  while (true) {
+    deal_num = (MAX_NRAM_SIZE - spatial_size - level_start_index_size) /
+               (8 * channel + 7) / sizeof(T);
+    deal_num = PAD_DOWN(deal_num, float_align);
+    deal_num = PAD_DOWN(deal_num, num_levels * num_points);
+    if (deal_num > 0) {
+      break;
+    } else {
+      channel = channels / cut_channel_iter;
+      cut_channel_iter += 2;
+    }
+  }
+  mult = channel;
+
+  const int32_t c_rep = channels / channel;
+  const int32_t c_rem = channels % channel;
+
+  const int32_t g_rep = deal_g / deal_num;
+  const int32_t g_rem = deal_g % deal_num;
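Reviewer note (illustrative, not part of the patch): the (8 * channel + 7) divisor
comes from the NRAM layout allocated below: eight buffers of deal_num * channel
elements (input_tl/tr/bl/br and weight_tl/tr/bl/br) plus seven of deal_num elements
(mask_tl/tr/bl/br, point_ram, index_tl, index_bl). A sketch of the search, with a
hypothetical MAX_NRAM_SIZE stand-in:

  #include <cstdint>
  constexpr int64_t kMaxNramSize = 384 * 1024;  // assumption for illustration
  constexpr int64_t kAlign = 128;               // NFU_ALIGN_SIZE stand-in
  int64_t pad_down(int64_t x, int64_t n) { return x / n * n; }
  // Shrink `channel` until one aligned group of levels * points fits in NRAM.
  int64_t solve_deal_num(int64_t reserved, int64_t channels, int64_t levels,
                         int64_t points, int64_t elem_size) {
    int64_t channel = channels, cut = 2, deal_num = 0;
    while (true) {
      deal_num = (kMaxNramSize - reserved) / (8 * channel + 7) / elem_size;
      deal_num = pad_down(deal_num, kAlign / elem_size);
      deal_num = pad_down(deal_num, levels * points);
      if (deal_num > 0) break;
      channel = channels / cut;
      cut += 2;
    }
    return deal_num;
  }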
+
+  // nram buffer alloc
   char *data_spatial_shapes_nram = nram_buffer;
   char *data_level_start_index_nram = data_spatial_shapes_nram + spatial_size;
-  char *data_sampling_loc_nram =
-      data_level_start_index_nram + level_start_index_size;
-  char *data_attn_weight_nram = data_sampling_loc_nram + sampling_loc_size;
-  char *ping_data_value_p1_nram = data_attn_weight_nram + attn_weight_size;
-  char *ping_data_value_p2_nram =
-      ping_data_value_p1_nram + span_num_deal * sizeof(T);
-  char *ping_data_value_p3_nram =
-      ping_data_value_p2_nram + span_num_deal * sizeof(T);
-  char *ping_data_value_p4_nram =
-      ping_data_value_p3_nram + span_num_deal * sizeof(T);
-  char *ping_data_col_nram =
-      ping_data_value_p4_nram + span_num_deal * sizeof(T);
-  char *pong_data_value_p1_nram =
-      ping_data_col_nram + span_num_deal * sizeof(T);
-  char *pong_data_value_p2_nram =
-      pong_data_value_p1_nram + span_num_deal * sizeof(T);
-  char *pong_data_value_p3_nram =
-      pong_data_value_p2_nram + span_num_deal * sizeof(T);
-  char *pong_data_value_p4_nram =
-      pong_data_value_p3_nram + span_num_deal * sizeof(T);
-  char *pong_data_col_nram =
-      pong_data_value_p4_nram + span_num_deal * sizeof(T);
-  char *auxiliary_a = pong_data_col_nram + span_num_deal * sizeof(T);
-  char *auxiliary_b = auxiliary_a + span_num_deal * sizeof(T);
-  const size_t ping_pong_gap = 5 * span_num_deal * sizeof(T);
-  size_t data_col_ping_pong_idx = 0;
+  char *input_tl = data_level_start_index_nram + level_start_index_size;
+  char *input_tr = input_tl + deal_num * mult * sizeof(T);
+  char *input_bl = input_tr + deal_num * mult * sizeof(T);
+  char *input_br = input_bl + deal_num * mult * sizeof(T);
+  char *weight_tl = input_tl + 4 * deal_num * mult * sizeof(T);
+  char *weight_tr = weight_tl + deal_num * mult * sizeof(T);
+  char *weight_bl = weight_tr + deal_num * mult * sizeof(T);
+  char *weight_br = weight_bl + deal_num * mult * sizeof(T);
+  char *mask_tl = weight_br + deal_num * mult * sizeof(T);
+  char *mask_tr = mask_tl + deal_num * sizeof(T);
+  char *mask_bl = mask_tr + deal_num * sizeof(T);
+  char *mask_br = mask_bl + deal_num * sizeof(T);
+  char *point_ram = mask_br + deal_num * sizeof(T);
+  char *index_tl = point_ram + deal_num * sizeof(T);
+  char *index_bl = index_tl + deal_num * sizeof(T);
-  const int32_t block_num_rem =
-      (batch_size * num_queries * num_heads) % taskDim;
-  const int32_t block_num_per_core =
-      taskId < block_num_rem
-          ? (batch_size * num_queries * num_heads) / taskDim + 1
-          : (batch_size * num_queries * num_heads) / taskDim;
-  const int32_t idx_start = taskId < block_num_rem
-                                ? taskId * block_num_per_core
-                                : taskId * block_num_per_core + block_num_rem;
+
+  // nram space reuse
+  char *grid_ram = weight_tl;
+  char *mask_ram = weight_bl;
+  char *coord_x = input_bl;
+  char *coord_y = coord_x + deal_num * sizeof(T);
+  char *coord_x_low = input_tl;
+  char *coord_y_low = coord_x_low + deal_num * sizeof(T);
+  char *coord_x_low_int = weight_tl;
+  char *coord_y_low_int = weight_tr;
+  char *spatial_x = mask_tl;
+  char *spatial_y = mask_tr;
+  char *spatial_x_float = weight_bl;
+  char *spatial_y_float = weight_br;
+  char *spatial_x_temp = mask_bl;
+  char *spatial_y_temp = mask_br;
+  char *base_ptr_offset = weight_tl;
+  char *auxiliary_a = point_ram;
+  char *auxiliary_b = weight_bl;

   __memcpy_async(data_spatial_shapes_nram, data_spatial_shapes_gdram,
                  num_levels * 2 * sizeof(int32_t), GDRAM2NRAM);
   __memcpy_async(data_level_start_index_nram, data_level_start_index_gdram,
                  num_levels * sizeof(int32_t), GDRAM2NRAM);
+  __asm__ volatile("sync;");

-  for (int32_t cur_idx = idx_start; cur_idx < idx_start + block_num_per_core;
-       ++cur_idx) {
-    // cur_idx = batch_idx * num_queries * num_heads + query_idx * num_heads +
-    // head_idx
-    const int32_t head_idx = cur_idx % num_heads;
-    const int32_t batch_idx = (cur_idx / num_heads) / num_queries;
-
-    const char *data_value_gdram_start =
-        data_value_gdram +
-        batch_idx * num_keys * num_heads * channels * sizeof(T);
-    char *data_col_gdram_start =
-        data_col_gdram + cur_idx * channels * sizeof(T);
-
-    if (load_loc_weight_seg == 1 ||
-        (load_loc_weight_idx % load_loc_weight_seg) == 0) {
-      const char *data_sampling_loc_gdram_start =
-          data_sampling_loc_gdram +
-          cur_idx * num_levels * num_points * 2 * sizeof(T);
-      const char *data_attn_weight_gdram_start =
-          data_attn_weight_gdram +
-          cur_idx * num_levels * num_points * sizeof(T);
-      const int32_t load_loc_weight_size =
-          (block_num_per_core - load_loc_weight_idx) < load_loc_weight_seg
-              ? block_num_per_core - load_loc_weight_idx
-              : load_loc_weight_seg;
-      __memcpy_async(
-          data_sampling_loc_nram, data_sampling_loc_gdram_start,
-          load_loc_weight_size * num_levels * num_points * 2 * sizeof(T),
-          GDRAM2NRAM);
-      __memcpy_async(data_attn_weight_nram, data_attn_weight_gdram_start,
-                     load_loc_weight_size * num_levels * num_points * sizeof(T),
-                     GDRAM2NRAM);
-      __asm__ volatile("sync;");
-    }
-    const int32_t load_loc_weight_offset =
-        (load_loc_weight_idx % load_loc_weight_seg) * num_levels * num_points;
-
-    for (int32_t c_seg_idx = 0; c_seg_idx < channels_seg_num; ++c_seg_idx) {
-      __bang_write_value(
-          (T *)(ping_data_col_nram + data_col_ping_pong_idx * ping_pong_gap),
-          span_num_deal, (T)0);
-      // load data
-      // level_idx = 0, point_idx = 0
-      int32_t spatial_h = ((int32_t *)data_spatial_shapes_nram)[0];
-      int32_t spatial_w = ((int32_t *)data_spatial_shapes_nram)[1];
-      const char *data_value_ptr =
-          data_value_gdram_start + c_seg_idx * span_num_deal * sizeof(T);
-      T loc_w = ((T *)data_sampling_loc_nram)[load_loc_weight_offset * 2];
-      T loc_h = ((T *)data_sampling_loc_nram)[load_loc_weight_offset * 2 + 1];
-      T weight = ((T *)data_attn_weight_nram)[load_loc_weight_offset];
-      T x = loc_w * spatial_w - 0.5;
-      T y = loc_h * spatial_h - 0.5;
-      if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) {
-        loadNeighborPointsData(
-            (T *)data_value_ptr, (T *)ping_data_value_p1_nram,
-            (T *)ping_data_value_p2_nram, (T *)ping_data_value_p3_nram,
-            (T *)ping_data_value_p4_nram, span_num_deal, spatial_w, spatial_h,
-            num_heads, channels, x, y, head_idx);
-      }
-      T spatial_h_next_point = 0;
-      T spatial_w_next_point = 0;
-      T weight_next_point = 0;
-      T x_next_point = 0;
-      T y_next_point = 0;
-      __asm__ volatile("sync;");
-
-      for (int32_t level_idx = 0; level_idx < num_levels; ++level_idx) {
-        for (int32_t point_idx = 0; point_idx < num_points; ++point_idx) {
-          // load data
-          if (point_idx == num_points - 1 && level_idx == num_levels - 1) {
-            // last point no need to load data, continue to compute
-          } else if (point_idx == num_points - 1) {
-            const int32_t level_start_id =
-                ((int32_t *)data_level_start_index_nram)[level_idx + 1];
-            const int32_t spatial_h_ptr = (level_idx + 1) << 1;
-            spatial_h_next_point =
-                ((int32_t *)data_spatial_shapes_nram)[spatial_h_ptr];
-            spatial_w_next_point =
-                ((int32_t *)data_spatial_shapes_nram)[spatial_h_ptr + 1];
-            data_value_ptr = data_value_gdram_start +
-                             (level_start_id * num_heads * channels +
-                              c_seg_idx * span_num_deal) *
-                                 sizeof(T);
-            loc_w = ((T *)data_sampling_loc_nram)[(load_loc_weight_offset +
-                level_idx * num_points + point_idx + 1) * 2];
-            loc_h = ((T *)data_sampling_loc_nram)[(load_loc_weight_offset +
-                level_idx * num_points + point_idx + 1) * 2 + 1];
-            weight_next_point =
-                ((T *)data_attn_weight_nram)[load_loc_weight_offset +
-                    level_idx * num_points + point_idx + 1];
-            x_next_point = loc_w * spatial_w_next_point - 0.5;
-            y_next_point = loc_h * spatial_h_next_point - 0.5;
-            if (y_next_point > -1 && x_next_point > -1 &&
-                y_next_point < spatial_h_next_point &&
-                x_next_point < spatial_w_next_point) {
-              loadNeighborPointsData(
-                  (T *)data_value_ptr,
-                  (T *)(ping_data_value_p1_nram +
-                        ((level_idx * num_points + point_idx + 1) % 2) *
-                            ping_pong_gap),
-                  (T *)(ping_data_value_p2_nram +
-                        ((level_idx * num_points + point_idx + 1) % 2) *
-                            ping_pong_gap),
-                  (T *)(ping_data_value_p3_nram +
-                        ((level_idx * num_points + point_idx + 1) % 2) *
-                            ping_pong_gap),
-                  (T *)(ping_data_value_p4_nram +
-                        ((level_idx * num_points + point_idx + 1) % 2) *
-                            ping_pong_gap),
-                  span_num_deal, spatial_w_next_point, spatial_h_next_point,
-                  num_heads, channels, x_next_point, y_next_point, head_idx);
-            }
-          } else {
-            spatial_h_next_point = spatial_h;
-            spatial_w_next_point = spatial_w;
-            loc_w = ((T *)data_sampling_loc_nram)[(load_loc_weight_offset +
-                level_idx * num_points + point_idx + 1) * 2];
-            loc_h = ((T *)data_sampling_loc_nram)[(load_loc_weight_offset +
-                level_idx * num_points + point_idx + 1) * 2 + 1];
-            weight_next_point =
-                ((T *)data_attn_weight_nram)[load_loc_weight_offset +
-                    level_idx * num_points + point_idx + 1];
-            x_next_point = loc_w * spatial_w - 0.5;
-            y_next_point = loc_h * spatial_h - 0.5;
-            if (y_next_point > -1 && x_next_point > -1 &&
-                y_next_point < spatial_h && x_next_point < spatial_w) {
-              loadNeighborPointsData(
-                  (T *)data_value_ptr,
-                  (T *)(ping_data_value_p1_nram +
-                        ((level_idx * num_points + point_idx + 1) % 2) *
-                            ping_pong_gap),
-                  (T *)(ping_data_value_p2_nram +
-                        ((level_idx * num_points + point_idx + 1) % 2) *
-                            ping_pong_gap),
-                  (T *)(ping_data_value_p3_nram +
-                        ((level_idx * num_points + point_idx + 1) % 2) *
-                            ping_pong_gap),
-                  (T *)(ping_data_value_p4_nram +
-                        ((level_idx * num_points + point_idx + 1) % 2) *
-                            ping_pong_gap),
-                  span_num_deal, spatial_w, spatial_h, num_heads, channels,
-                  x_next_point, y_next_point, head_idx);
-            }
-          }
-
-          // compute
-          if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) {
-            computeMsDeformAttn(
-                (T *)(ping_data_value_p1_nram +
-                      ((level_idx * num_points + point_idx) % 2) *
-                          ping_pong_gap),
-                (T *)(ping_data_value_p2_nram +
-                      ((level_idx * num_points + point_idx) % 2) *
-                          ping_pong_gap),
-                (T *)(ping_data_value_p3_nram +
-                      ((level_idx * num_points + point_idx) % 2) *
-                          ping_pong_gap),
-                (T *)(ping_data_value_p4_nram +
-                      ((level_idx * num_points + point_idx) % 2) *
-                          ping_pong_gap),
-                (T *)auxiliary_a, (T *)auxiliary_b,
-                (T *)(ping_data_col_nram +
-                      data_col_ping_pong_idx * ping_pong_gap),
-                weight, span_num_deal, spatial_w, spatial_h, x, y);
-          }
-
-          spatial_w = spatial_w_next_point;
-          spatial_h = spatial_h_next_point;
-          weight = weight_next_point;
-          x = x_next_point;
-          y = y_next_point;
-          __asm__ volatile("sync;");
+  for (int32_t batch_idx = batch_start;
+       batch_idx < batch_start + block_num_per_core; ++batch_idx) {
+    for (int32_t grid_iter = 0; grid_iter <= g_rep; ++grid_iter) {
+      int32_t io_data_num = deal_num;
+      const int32_t grid_off_base =
+          batch_idx * grid_total + offset_g + grid_iter * deal_num;
+      if (grid_iter == g_rep) {
+        if (g_rem == 0) {
+          continue;
+        } else {
+          io_data_num = g_rem;
+        }
+      }
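Reviewer note (illustrative, not part of the patch): grid_iter runs g_rep + 1 times
and the extra pass covers the g_rem tail, skipped when deal_g divides evenly; the
same pattern reappears below for the channel loop (c_rep/c_rem). The tiling idiom,
over a plain array:

  #include <cstdint>
  #include <vector>
  void tiled_walk(const std::vector<float> &v, int32_t deal_num) {
    const int32_t g_rep = v.size() / deal_num;
    const int32_t g_rem = v.size() % deal_num;
    for (int32_t it = 0; it <= g_rep; ++it) {
      int32_t n = deal_num;
      if (it == g_rep) {
        if (g_rem == 0) continue;  // no tail tile
        n = g_rem;
      }
      // process v[it * deal_num, it * deal_num + n) ...
    }
  }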
-      // store
-      __memcpy_async(
-          data_col_gdram_start + c_seg_idx * span_num_deal * sizeof(T),
-          ping_data_col_nram + data_col_ping_pong_idx * ping_pong_gap,
-          span_num_deal * sizeof(T), NRAM2GDRAM);
-      data_col_ping_pong_idx = (data_col_ping_pong_idx + 1) % 2;
-    }
-    if (channels_rem > 0) {
-#if __BANG_ARCH__ >= 322
-      __bang_write_value(
-          (T *)(ping_data_col_nram + data_col_ping_pong_idx * ping_pong_gap),
-          channels_rem, (T)0);
-#else
-      __bang_write_value(
-          (T *)(ping_data_col_nram + data_col_ping_pong_idx * ping_pong_gap),
-          channels_align_rem, (T)0);
-#endif
-      // load data
-      // level_idx = 0, point_idx = 0
-      int32_t spatial_h = ((int32_t *)data_spatial_shapes_nram)[0];
-      int32_t spatial_w = ((int32_t *)data_spatial_shapes_nram)[1];
-      const char *data_value_ptr =
-          data_value_gdram_start + channels_seg_num * span_num_deal * sizeof(T);
-      T loc_w = ((T *)data_sampling_loc_nram)[load_loc_weight_offset * 2];
-      T loc_h = ((T *)data_sampling_loc_nram)[load_loc_weight_offset * 2 + 1];
-      T weight = ((T *)data_attn_weight_nram)[load_loc_weight_offset];
-      T x = loc_w * spatial_w - 0.5;
-      T y = loc_h * spatial_h - 0.5;
-      if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) {
-        loadNeighborPointsData(
-            (T *)data_value_ptr, (T *)ping_data_value_p1_nram,
-            (T *)ping_data_value_p2_nram, (T *)ping_data_value_p3_nram,
-            (T *)ping_data_value_p4_nram, channels_rem, spatial_w, spatial_h,
-            num_heads, channels, x, y, head_idx);
-      }
-      T spatial_h_next_point = 0;
-      T spatial_w_next_point = 0;
-      T weight_next_point = 0;
-      T x_next_point = 0;
-      T y_next_point = 0;
+      char *data_col_gdram_start =
+          data_col_gdram + (batch_idx * num_queries * num_heads * channels +
+                            (offset_g + grid_iter * deal_num) /
+                                (num_levels * num_points) * channels) *
+                               sizeof(float);
+
+      // load data_sampling_loc
+      __memcpy_async(
+          grid_ram, data_sampling_loc_gdram + grid_off_base * 2 * sizeof(float),
+          io_data_num * 2 * sizeof(float), GDRAM2NRAM);
+      genMask0101((float *)mask_ram, deal_num * 2);
       __asm__ volatile("sync;");

-      for (int32_t level_idx = 0; level_idx < num_levels; ++level_idx) {
-        for (int32_t point_idx = 0; point_idx < num_points; ++point_idx) {
-          // load data
-          if (point_idx == num_points - 1 && level_idx == num_levels - 1) {
-            // last point no need to load data, continue to compute
-          } else if (point_idx == num_points - 1) {
-            const int32_t level_start_id =
-                ((int32_t *)data_level_start_index_nram)[level_idx + 1];
-            const int32_t spatial_h_ptr = (level_idx + 1) << 1;
-            spatial_h_next_point =
-                ((int32_t *)data_spatial_shapes_nram)[spatial_h_ptr];
-            spatial_w_next_point =
-                ((int32_t *)data_spatial_shapes_nram)[spatial_h_ptr + 1];
-            data_value_ptr = data_value_gdram_start +
-                             (level_start_id * num_heads * channels +
-                              channels_seg_num * span_num_deal) *
-                                 sizeof(T);
-            loc_w = ((T *)data_sampling_loc_nram)[(load_loc_weight_offset +
-                level_idx * num_points + point_idx + 1) * 2];
-            loc_h = ((T *)data_sampling_loc_nram)[(load_loc_weight_offset +
-                level_idx * num_points + point_idx + 1) * 2 + 1];
-            weight_next_point =
-                ((T *)data_attn_weight_nram)[load_loc_weight_offset +
-                    level_idx * num_points + point_idx + 1];
-            x_next_point = loc_w * spatial_w_next_point - 0.5;
-            y_next_point = loc_h * spatial_h_next_point - 0.5;
-            if (y_next_point > -1 && x_next_point > -1 &&
-                y_next_point < spatial_h_next_point &&
-                x_next_point < spatial_w_next_point) {
-              loadNeighborPointsData(
-                  (T *)data_value_ptr,
-                  (T *)(ping_data_value_p1_nram +
-                        ((level_idx * num_points + point_idx + 1) % 2) *
-                            ping_pong_gap),
-                  (T *)(ping_data_value_p2_nram +
-                        ((level_idx * num_points + point_idx + 1) % 2) *
-                            ping_pong_gap),
-                  (T *)(ping_data_value_p3_nram +
-                        ((level_idx * num_points + point_idx + 1) % 2) *
-                            ping_pong_gap),
-                  (T *)(ping_data_value_p4_nram +
-                        ((level_idx * num_points + point_idx + 1) % 2) *
-                            ping_pong_gap),
-                  channels_rem, spatial_w_next_point, spatial_h_next_point,
-                  num_heads, channels, x_next_point, y_next_point, head_idx);
-            }
-          } else {
-            spatial_w_next_point = spatial_w;
-            spatial_h_next_point = spatial_h;
-            loc_w = ((T *)data_sampling_loc_nram)[(load_loc_weight_offset +
-                level_idx * num_points + point_idx + 1) * 2];
-            loc_h = ((T *)data_sampling_loc_nram)[(load_loc_weight_offset +
-                level_idx * num_points + point_idx + 1) * 2 + 1];
-            weight_next_point =
-                ((T *)data_attn_weight_nram)[load_loc_weight_offset +
-                    level_idx * num_points + point_idx + 1];
-            x_next_point = loc_w * spatial_w - 0.5;
-            y_next_point = loc_h * spatial_h - 0.5;
-            if (y_next_point > -1 && x_next_point > -1 &&
-                y_next_point < spatial_h && x_next_point < spatial_w) {
-              loadNeighborPointsData(
-                  (T *)data_value_ptr,
-                  (T *)(ping_data_value_p1_nram +
-                        ((level_idx * num_points + point_idx + 1) % 2) *
-                            ping_pong_gap),
-                  (T *)(ping_data_value_p2_nram +
-                        ((level_idx * num_points + point_idx + 1) % 2) *
-                            ping_pong_gap),
-                  (T *)(ping_data_value_p3_nram +
-                        ((level_idx * num_points + point_idx + 1) % 2) *
-                            ping_pong_gap),
-                  (T *)(ping_data_value_p4_nram +
-                        ((level_idx * num_points + point_idx + 1) % 2) *
-                            ping_pong_gap),
-                  channels_rem, spatial_w, spatial_h, num_heads, channels,
-                  x_next_point, y_next_point, head_idx);
-            }
-          }
+      // generate x and y coordinate vector
+      // generate spatial_x and spatial_y spatial vector
+      __bang_collect((float *)coord_y, (float *)grid_ram, (float *)mask_ram,
+                     deal_num * 2);  // y
+      __bang_collect((float *)spatial_x_temp, (float *)data_spatial_shapes_nram,
+                     (float *)mask_ram,
+                     num_levels * 2);  // spatial_x
+      __bang_not((float *)mask_ram, (float *)mask_ram, deal_num * 2);
+      __bang_collect((float *)coord_x, (float *)grid_ram, (float *)mask_ram,
+                     deal_num * 2);  // x
+      __bang_collect((float *)spatial_y_temp, (float *)data_spatial_shapes_nram,
+                     (float *)mask_ram,
+                     num_levels * 2);  // spatial_y

-          // compute
-          if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) {
-#if __BANG_ARCH__ >= 322
-            computeMsDeformAttn(
-                (T *)(ping_data_value_p1_nram +
-                      ((level_idx * num_points + point_idx) % 2) *
-                          ping_pong_gap),
-                (T *)(ping_data_value_p2_nram +
-                      ((level_idx * num_points + point_idx) % 2) *
-                          ping_pong_gap),
-                (T *)(ping_data_value_p3_nram +
-                      ((level_idx * num_points + point_idx) % 2) *
-                          ping_pong_gap),
-                (T *)(ping_data_value_p4_nram +
-                      ((level_idx * num_points + point_idx) % 2) *
-                          ping_pong_gap),
-                (T *)auxiliary_a, (T *)auxiliary_b,
-                (T *)(ping_data_col_nram +
-                      data_col_ping_pong_idx * ping_pong_gap),
-                weight, channels_rem, spatial_w, spatial_h, x, y);
-#else
-            computeMsDeformAttn(
-                (T *)(ping_data_value_p1_nram +
-                      ((level_idx * num_points + point_idx) % 2) *
-                          ping_pong_gap),
-                (T *)(ping_data_value_p2_nram +
-                      ((level_idx * num_points + point_idx) % 2) *
-                          ping_pong_gap),
-                (T *)(ping_data_value_p3_nram +
-                      ((level_idx * num_points + point_idx) % 2) *
-                          ping_pong_gap),
-                (T *)(ping_data_value_p4_nram +
-                      ((level_idx * num_points + point_idx) % 2) *
-                          ping_pong_gap),
-                (T *)auxiliary_a, (T *)auxiliary_b,
-                (T *)(ping_data_col_nram +
-                      data_col_ping_pong_idx * ping_pong_gap),
-                weight, channels_align_rem, spatial_w, spatial_h, x, y);
-#endif
-          }
+      for (int32_t i = 0; i < num_levels; i++) {
+        __bang_write_value((int32_t *)spatial_x + i * num_points, num_points,
+                           ((int32_t *)spatial_x_temp)[i]);
+        __bang_write_value((int32_t *)spatial_y + i * num_points, num_points,
+                           ((int32_t *)spatial_y_temp)[i]);
+      }

-          spatial_w = spatial_w_next_point;
-          spatial_h = spatial_h_next_point;
-          weight = weight_next_point;
-          x = x_next_point;
-          y = y_next_point;
-          __asm__ volatile("sync;");
+      __bang_int322float_rd((float *)spatial_x_float, (int32_t *)spatial_x,
+                            num_levels * num_points, 0);
+      __bang_int322float_rd((float *)spatial_y_float, (int32_t *)spatial_y,
+                            num_levels * num_points, 0);
+
+      // map x from [0, 1] to [0, spatial_x]; map y from [0, 1] to [0,
+      // spatial_y]
+      __bang_cycle_mul((float *)coord_x, (float *)coord_x,
+                       (float *)spatial_x_float, deal_num,
+                       num_levels * num_points);
+      __bang_sub_scalar((float *)coord_x, (float *)coord_x, (float)0.5,
+                        deal_num);
+      __bang_cycle_mul((float *)coord_y, (float *)coord_y,
+                       (float *)spatial_y_float, deal_num,
+                       num_levels * num_points);
+      __bang_sub_scalar((float *)coord_y, (float *)coord_y, (float)0.5,
+                        deal_num);
+
+      __bang_floor((float *)coord_x_low, (float *)coord_x, deal_num);
+      __bang_floor((float *)coord_y_low, (float *)coord_y, deal_num);
+
+      // calc index_tl
+      const int32_t w_stride = num_heads * channels;
+      __bang_float2int32_rd((int32_t *)coord_x_low_int, (float *)coord_x_low,
+                            deal_num, 0);
+      __bang_float2int32_rd((int32_t *)coord_y_low_int, (float *)coord_y_low,
+                            deal_num, 0);
+      __bang_cycle_mul((int32_t *)index_tl, (int32_t *)coord_y_low_int,
+                       (int32_t *)spatial_x, deal_num, num_levels * num_points);
+      __bang_add((int32_t *)index_tl, (int32_t *)index_tl,
+                 (int32_t *)coord_x_low_int, deal_num);
+      __bang_mul_scalar((int32_t *)index_tl, (int32_t *)index_tl, w_stride,
+                        deal_num);
+
+      const int32_t deal_lp_num = deal_num / (num_levels * num_points);
+      const int32_t h_rep = deal_lp_num / num_heads;
+      const int32_t h_rem = deal_lp_num % num_heads;
+      const int32_t head_start =
+          ((offset_g + grid_iter * deal_num) / (num_levels * num_points)) %
+          num_heads;
+      for (int32_t iter = 0; iter < num_heads; ++iter) {
+        ((int32_t *)base_ptr_offset)[iter] =
+            ((head_start + iter) % num_heads) * channels;
+      }
+      if (h_rep > 0) {
+        __memcpy((int32_t *)base_ptr_offset + num_heads,
+                 (int32_t *)base_ptr_offset, num_heads * sizeof(int32_t),
+                 NRAM2NRAM, num_heads * sizeof(int32_t), 0, h_rep - 1);
+      }
+      if (h_rep > 0 && h_rem > 0) {
+        __memcpy((int32_t *)base_ptr_offset + h_rep * num_heads,
+                 (int32_t *)base_ptr_offset, h_rem * sizeof(int32_t),
+                 NRAM2NRAM);
+      }
+      __bang_transpose((int32_t *)auxiliary_a, (int32_t *)index_tl, deal_lp_num,
+                       num_levels * num_points);
+      __bang_cycle_add((int32_t *)auxiliary_a, (int32_t *)auxiliary_a,
+                       (int32_t *)base_ptr_offset, deal_num, deal_lp_num);
+      __bang_transpose((int32_t *)index_tl, (int32_t *)auxiliary_a,
+                       num_levels * num_points, deal_lp_num);
+
+      // calc index_bl
+      __bang_mul_scalar((int32_t *)auxiliary_a, (int32_t *)spatial_x, w_stride,
+                        deal_num);
+      __bang_cycle_add((int32_t *)index_bl, (int32_t *)index_tl,
+                       (int32_t *)auxiliary_a, deal_num,
+                       num_levels * num_points);
+
+      // calc mask_tl, mask_tr, mask_bl, mask_br
+      __bang_sub_scalar((float *)spatial_x_float, (float *)spatial_x_float,
+                        (float)1.0, deal_num);
+      __bang_sub_scalar((float *)spatial_y_float, (float *)spatial_y_float,
+                        (float)1.0, deal_num);
+      // mask_tl : 0 <= coord_x_low < spatial_x && 0 <= coord_y_low < spatial_y
+      __bang_ge_scalar((float *)mask_bl, (float *)coord_x_low, (float)0,
+                       deal_num);
+      __bang_cycle_le((float *)mask_br, (float *)coord_x_low,
+                      (float *)spatial_x_float, deal_num,
+                      num_levels * num_points);
+      __bang_and((float *)mask_bl, (float *)mask_bl, (float *)mask_br,
+                 deal_num);
+
+      __bang_ge_scalar((float *)mask_tr, (float *)coord_y_low, (float)0,
+                       deal_num);
+      __bang_cycle_le((float *)mask_br, (float *)coord_y_low,
+                      (float *)spatial_y_float, deal_num,
+                      num_levels * num_points);
+      __bang_and((float *)mask_tr, (float *)mask_tr, (float *)mask_br,
+                 deal_num);
+      __bang_and((float *)mask_tl, (float *)mask_tr, (float *)mask_bl,
+                 deal_num);
+
+      // mask_tr : 0 <= coord_x_high < spatial_x && 0 <= coord_y_low < spatial_y
+      __bang_ge_scalar((float *)mask_br, (float *)coord_x_low, (float)(-1.0),
+                       deal_num);
+      __bang_cycle_lt((float *)auxiliary_a, (float *)coord_x_low,
+                      (float *)spatial_x_float, deal_num,
+                      num_levels * num_points);
+      __bang_and((float *)mask_br, (float *)mask_br, (float *)auxiliary_a,
+                 deal_num);
+      __bang_and((float *)mask_tr, (float *)mask_tr, (float *)mask_br,
+                 deal_num);
+
+      // mask_bl : 0 <= coord_x_low < spatial_x && 0 <= coord_y_high < spatial_y
+      __bang_ge_scalar((float *)auxiliary_a, (float *)coord_y_low,
+                       (float)(-1.0), deal_num);
+      __bang_cycle_lt((float *)auxiliary_b, (float *)coord_y_low,
+                      (float *)spatial_y_float, deal_num,
+                      num_levels * num_points);
+      __bang_and((float *)auxiliary_a, (float *)auxiliary_a,
+                 (float *)auxiliary_b, deal_num);
+      __bang_and((float *)mask_bl, (float *)mask_bl, (float *)auxiliary_a,
+                 deal_num);
+
+      // mask_br : 0 <= coord_x_high < spatial_x && 0 <= coord_y_high <
+      // spatial_y
+      __bang_and((float *)mask_br, (float *)mask_br, (float *)auxiliary_a,
+                 deal_num);
+
+      // calc inner point num
+      __bang_mul_scalar((float *)weight_tl, (float *)mask_tl, (float)7.0,
+                        deal_num);
+      __bang_mul_scalar((float *)weight_tr, (float *)mask_tr, (float)5.0,
+                        deal_num);
+      __bang_add((float *)weight_tl, (float *)weight_tl, (float *)weight_tr,
+                 deal_num);
+      __bang_mul_scalar((float *)weight_tr, (float *)mask_bl, (float)3.0,
+                        deal_num);
+      __bang_add((float *)point_ram, (float *)weight_tr, (float *)mask_br,
+                 deal_num);
+      __bang_add((float *)point_ram, (float *)point_ram, (float *)weight_tl,
+                 deal_num);
+
+      // calc interpolation weight
+      __bang_sub((float *)weight_bl, (float *)coord_x_low, (float *)coord_x,
+                 deal_num);
+      __bang_sub((float *)weight_br, (float *)coord_y_low, (float *)coord_y,
+                 deal_num);
+      __bang_add_scalar((float *)weight_bl, (float *)weight_bl, (float)1.0,
+                        deal_num);
+      __bang_add_scalar((float *)weight_br, (float *)weight_br, (float)1.0,
+                        deal_num);
+
+      __bang_sub((float *)weight_tl, (float *)coord_x, (float *)coord_x_low,
+                 deal_num);
+      __bang_sub((float *)weight_tr, (float *)coord_y, (float *)coord_y_low,
+                 deal_num);
+      __bang_mul((float *)input_tl, (float *)weight_bl, (float *)weight_br,
+                 deal_num);
+      __bang_mul((float *)input_tl + deal_num, (float *)weight_br,
+                 (float *)weight_tl, deal_num);
+      __bang_mul((float *)input_tl + 2 * deal_num, (float *)weight_bl,
+                 (float *)weight_tr, deal_num);
+      __bang_mul((float *)input_tl + 3 * deal_num, (float *)weight_tl,
+                 (float *)weight_tr, deal_num);
+
+      __asm__ volatile("sync;");
+
+      // extend weight
+      const int32_t w_rep = channel / ELE_COUNT * ELE_COUNT;
+      const int32_t w_rem = channel % ELE_COUNT;
+      if (w_rem != 0) {
+        const int32_t data_sz = 1 * sizeof(float);
+        const int32_t dst_str = channel * sizeof(float);
+        for (int32_t iter = w_rep; iter < channel; ++iter) {
+          __memcpy_async((float *)weight_tl + iter, (float *)input_tl, data_sz,
+                         NRAM2NRAM, dst_str, data_sz, 4 * deal_num - 1);
+        }
+      }
+      if (w_rep != 0) {
+        for (int32_t i = 0; i < 4 * deal_num; i++) {
+          __bang_write_value((float *)weight_tl + i * channel, w_rep,
+                             ((float *)input_tl)[i]);
+        }
+      }
+
+      __asm__ volatile("sync;");
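Reviewer note (illustrative, not part of the patch): the 7/5/3/1 weighting packs the
four corner-validity masks into one scalar so the per-point load below can switch on
a single value: 16 = all four corners valid, 12 = top row only, 4 = bottom row only,
10 = left column, 6 = right column, 7/5/3/1 = one corner, 0 = none. In scalar form:

  #include <cstdint>
  int32_t encode_corners(bool tl, bool tr, bool bl, bool br) {
    // Corner masks derive from independent x-low/x-high/y-low/y-high tests,
    // so only the ten values listed above can occur.
    return 7 * tl + 5 * tr + 3 * bl + 1 * br;
  }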
+
+      const char *data_value_gdram_start =
+          data_value_gdram +
+          batch_idx * num_keys * num_heads * channels * sizeof(float);
+      const int32_t c_str = deal_num * channel * sizeof(float);
+      const int32_t cs_str = num_heads * channels * sizeof(float);
+
+      for (int32_t c_iter = 0; c_iter <= c_rep; ++c_iter) {
+        int32_t c_real_num = channel;
+        if (c_iter == c_rep) {
+          if (c_rem == 0) {
+            continue;
+          } else {
+            c_real_num = c_rem;
+          }
+        }
+
+        __bang_write_zero((float *)input_tl, 4 * deal_num * channel);
+        __asm__ volatile("sync;");
+
+        // load data_value
+        for (int32_t p_idx = 0; p_idx < io_data_num; ++p_idx) {
+          const int32_t inner_point_num = (int32_t)((float *)point_ram)[p_idx];
+          const int32_t tl_offset = ((int32_t *)index_tl)[p_idx];
+          const int32_t bl_offset = ((int32_t *)index_bl)[p_idx];
+          const int32_t level_start_id =
+              ((int32_t *)data_level_start_index_nram)[(p_idx / num_points) %
+                                                       num_levels];
+          const char *data_value_ptr =
+              data_value_gdram_start +
+              (level_start_id * num_heads * channels + c_iter * channel) *
+                  sizeof(float);
+
+          switch (inner_point_num) {
+            case 16:  // 4 points are cached.
+              __memcpy_async((float *)input_tl + p_idx * channel,
+                             (float *)data_value_ptr + tl_offset,
+                             c_real_num * sizeof(float), GDRAM2NRAM, c_str,
+                             cs_str, 1);
+              __memcpy_async((float *)input_bl + p_idx * channel,
+                             (float *)data_value_ptr + bl_offset,
+                             c_real_num * sizeof(float), GDRAM2NRAM, c_str,
+                             cs_str, 1);
+              break;
+            case 12:  // 2 points are cached. (top_left, top_right)
+              __memcpy_async((float *)input_tl + p_idx * channel,
+                             (float *)data_value_ptr + tl_offset,
+                             c_real_num * sizeof(float), GDRAM2NRAM, c_str,
+                             cs_str, 1);
+              break;
+            case 4:  // 2 points are cached. (bottom_left, bottom_right)
+              __memcpy_async((float *)input_bl + p_idx * channel,
+                             (float *)data_value_ptr + bl_offset,
+                             c_real_num * sizeof(float), GDRAM2NRAM, c_str,
+                             cs_str, 1);
+              break;
+            case 10:  // 2 points are cached. (top_left, bottom_left)
+              __memcpy_async((float *)input_tl + p_idx * channel,
+                             (float *)data_value_ptr + tl_offset,
+                             c_real_num * sizeof(float), GDRAM2NRAM);
+              __memcpy_async((float *)input_bl + p_idx * channel,
+                             (float *)data_value_ptr + bl_offset,
+                             c_real_num * sizeof(float), GDRAM2NRAM);
+              break;
+            case 6:  // 2 points are cached. (top_right, bottom_right)
+              __memcpy_async(
+                  (float *)input_tr + p_idx * channel,
+                  (float *)data_value_ptr + tl_offset + num_heads * channels,
+                  c_real_num * sizeof(float), GDRAM2NRAM);
+              __memcpy_async(
+                  (float *)input_br + p_idx * channel,
+                  (float *)data_value_ptr + bl_offset + num_heads * channels,
+                  c_real_num * sizeof(float), GDRAM2NRAM);
+              break;
+            case 7:  // 1 point is cached. (top_left)
+              __memcpy_async((float *)input_tl + p_idx * channel,
+                             (float *)data_value_ptr + tl_offset,
+                             c_real_num * sizeof(float), GDRAM2NRAM);
+              break;
+            case 5:  // 1 point is cached. (top_right)
+              __memcpy_async(
+                  (float *)input_tr + p_idx * channel,
+                  (float *)data_value_ptr + tl_offset + num_heads * channels,
+                  c_real_num * sizeof(float), GDRAM2NRAM);
+              break;
+            case 3:  // 1 point is cached. (bottom_left)
+              __memcpy_async((float *)input_bl + p_idx * channel,
+                             (float *)data_value_ptr + bl_offset,
+                             c_real_num * sizeof(float), GDRAM2NRAM);
+              break;
+            case 1:  // 1 point is cached. (bottom_right)
+              __memcpy_async(
+                  (float *)input_br + p_idx * channel,
+                  (float *)data_value_ptr + bl_offset + num_heads * channels,
+                  c_real_num * sizeof(float), GDRAM2NRAM);
+              break;
+            default:
+              continue;
+          }
+        }
+
+        __asm__ volatile("sync;");
+
+        // interpolation
+        __bang_mul((float *)input_tl, (float *)input_tl, (float *)weight_tl,
+                   4 * deal_num * channel);
+        __bang_add((float *)input_tl, (float *)input_tl, (float *)input_bl,
+                   2 * deal_num * channel);
+        __bang_add((float *)input_tl, (float *)input_tl, (float *)input_tr,
+                   deal_num * channel);
+
+        // load attention weight
+        void *attn_weight = mask_tl;
+        __memcpy((float *)attn_weight,
+                 (float *)data_attn_weight_gdram + grid_off_base,
+                 io_data_num * sizeof(float), GDRAM2NRAM);
+
+        // calc data_col, muladd attention weight
+        __bang_transpose((float *)input_tr, (float *)input_tl, deal_num,
+                         channel);
+        __bang_cycle_mul((float *)input_tr, (float *)input_tr,
+                         (float *)attn_weight, deal_num * channel, deal_num);
+        __bang_transpose((float *)input_tl, (float *)input_tr, channel,
+                         deal_num);
+        __bang_sumpool((float *)input_bl, (float *)input_tl, channel, 1,
+                       io_data_num, 1, num_levels * num_points,
+                       num_levels * num_points, 1);
+
+        // store
+        __memcpy((float *)data_col_gdram_start + c_iter * channel,
+                 (float *)input_bl, c_real_num * sizeof(float), NRAM2GDRAM,
+                 channels * sizeof(float), channel * sizeof(float),
+                 (io_data_num / (num_levels * num_points)) - 1);
+      }
     }
-    load_loc_weight_idx += 1;
   }
   __asm__ volatile("sync;");
+#endif
   return;
 }
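Reviewer note (illustrative, not part of the patch): the transpose, __bang_cycle_mul
and __bang_sumpool sequence above computes, per query/head output, the weighted sum
over the num_levels * num_points sampled values. Written out scalar-wise in C++
(names are illustrative; n is the number of outputs, LP = levels * points):

  #include <vector>
  std::vector<float> reduce_points(const std::vector<float> &sampled,  // [n*LP][C]
                                   const std::vector<float> &attn,     // [n*LP]
                                   int n, int LP, int C) {
    std::vector<float> out(n * C, 0.f);
    for (int q = 0; q < n; ++q)
      for (int s = 0; s < LP; ++s)
        for (int c = 0; c < C; ++c)
          out[q * C + c] += attn[q * LP + s] * sampled[(q * LP + s) * C + c];
    return out;
  }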

@@ -1316,294 +1347,496 @@ __mlu_global__ void MLUUnion1KernelMsDeformAttnBackwarDefaultKernel(
   }
 }

-template <typename T>
-void __mlu_func__
-loadData(const int32_t &h_low, const int32_t &w_low, const int32_t &h_high,
-         const int32_t &w_high, T *grad_output_nram_tl, T *grad_output_nram_tr,
-         T *grad_output_nram_bl, T *grad_output_nram_br,
-         const T *data_value_ptr, const int32_t &width, const int32_t &height,
-         const int32_t &deal_num_real, const int32_t &h_low_ptr_offset,
-         const int32_t &w_low_ptr_offset, const int32_t &w_high_ptr_offset,
-         const int32_t &h_high_ptr_offset, const int32_t &base_ptr) {
-#if __BANG_ARCH__ > 322
-  if (h_low >= 0 && w_low >= 0)
+void __mlu_func__ computeGridMaskAndOffset(
+    float *nram_grad_output_tl, float *nram_grad_output_tr, float *nram_loc_w,
+    float *nram_loc_h, float *nram_h_stride, int32_t *nram_spatial_shapes,
+    float *nram_w_low_temp, float *nram_h_high_temp, float *nram_w_low,
+    float *nram_h_low, float *nram_h_high, float *nram_w_high, float *nram_lh,
+    float *nram_lw, float *nram_hh, float *nram_hw,
+    float *nram_h_low_ptr_offset, float *nram_h_high_ptr_offset,
+    float *nram_w_low_ptr_offset, float *nram_w_high_ptr_offset, float *nram_w1,
+    float *nram_w2, float *nram_w3, float *nram_w4, float *nram_offset_temp,
+    float *nram_offset1, float *nram_offset2, float *nram_offset3,
+    float *nram_offset4, float *nram_base_ptr, float *nram_h_low_temp,
+    int32_t num_deal_grid, int32_t num_per_time_real, const int32_t num_heads,
+    const int32_t num_levels, const int32_t num_points, const int32_t w_stride,
+    const int32_t qid_stride) {
+#if __BANG_ARCH__ >= 322
+  // [num_levels, 2] --> [2, num_levels]
+  __bang_transpose(nram_grad_output_tl, nram_loc_w, num_deal_grid, 2);
+  __bang_transpose(nram_loc_w, nram_grad_output_tl,
+                   num_per_time_real * num_heads * num_levels, num_points);
+  __bang_transpose(nram_loc_h, nram_grad_output_tl + num_deal_grid,
+                   num_per_time_real * num_heads * num_levels, num_points);
+  __bang_int322float((float *)nram_spatial_shapes,
+                     (int32_t *)nram_spatial_shapes, num_levels * 2, 0);
+  __bang_transpose(nram_grad_output_tr, (float *)nram_spatial_shapes,
+                   num_levels, 2);
+  __bang_mul_scalar(nram_h_stride, nram_grad_output_tr + num_levels, w_stride,
+                    num_levels);
+  __memcpy_async(nram_spatial_shapes, nram_grad_output_tr,
+                 num_levels * 2 * sizeof(float), NRAM2NRAM);
+  __bang_cycle_mul(nram_loc_w, nram_loc_w,
+                   (float *)nram_spatial_shapes + num_levels, num_deal_grid,
+                   num_levels);
+  __bang_cycle_mul(nram_loc_h, nram_loc_h, (float *)(nram_spatial_shapes),
+                   num_deal_grid, num_levels);
+  __bang_sub_scalar(nram_loc_w, nram_loc_w, 0.5, num_deal_grid);
+  __bang_sub_scalar(nram_loc_h, nram_loc_h, 0.5, num_deal_grid);
+  // get mask. (h_im > -1 && w_im > -1 &&
+  // h_im < spatial_h && w_im < spatial_w)
+  __bang_cycle_lt(nram_w_low_temp, nram_loc_w,
+                  (float *)(nram_spatial_shapes + num_levels), num_deal_grid,
+                  num_levels);
+  __bang_cycle_lt(nram_h_high_temp, nram_loc_h, (float *)(nram_spatial_shapes),
+                  num_deal_grid, num_levels);
+  __bang_and(nram_w_low_temp, nram_w_low_temp, nram_h_high_temp, num_deal_grid);
+  __bang_gt_scalar(nram_h_high_temp, nram_loc_h, -1, num_deal_grid);
+  __bang_and(nram_h_high_temp, nram_h_high_temp, nram_w_low_temp,
+             num_deal_grid);
+  __bang_gt_scalar(nram_w_low_temp, nram_loc_w, -1, num_deal_grid);
+  __bang_and(nram_h_high_temp, nram_h_high_temp, nram_w_low_temp,
+             num_deal_grid);
+  __bang_transpose(nram_w_low_temp, nram_h_high_temp, num_points,
+                   num_per_time_real * num_heads * num_levels);
+  __memcpy_async(nram_h_high_temp, nram_w_low_temp,
+                 num_deal_grid * sizeof(float), NRAM2NRAM);
+  __bang_transpose(nram_grad_output_tl, nram_loc_w, num_points,
+                   num_per_time_real * num_heads * num_levels);
+  __memcpy_async(nram_loc_w, nram_grad_output_tl, num_deal_grid * sizeof(float),
+                 NRAM2NRAM);
+  __bang_transpose(nram_grad_output_tl, nram_loc_h, num_points,
+                   num_per_time_real * num_heads * num_levels);
+  __memcpy_async(nram_loc_h, nram_grad_output_tl, num_deal_grid * sizeof(float),
+                 NRAM2NRAM);
+  __bang_floor(nram_w_low, nram_loc_w, num_deal_grid);
+  __bang_floor(nram_h_low, nram_loc_h, num_deal_grid);
+  __bang_add_scalar(nram_h_high, nram_h_low, 1, num_deal_grid);
+  __bang_add_scalar(nram_w_high, nram_w_low, 1, num_deal_grid);
+  __bang_sub(nram_lh, nram_loc_h, nram_h_low, num_deal_grid);
+  __bang_sub(nram_lw, nram_loc_w, nram_w_low, num_deal_grid);
+  __bang_fusion(FUSION_FMA, nram_hh, nram_lh, (float)(-1), 1, num_deal_grid);
+  __bang_fusion(FUSION_FMA, nram_hw, nram_lw, (float)(-1), 1, num_deal_grid);
+  __bang_transpose(nram_h_low_ptr_offset, nram_h_low,
+                   num_per_time_real * num_heads * num_levels, num_points);
+  __bang_cycle_mul(nram_h_low_ptr_offset, nram_h_low_ptr_offset, nram_h_stride,
+                   num_deal_grid, num_levels);
+  __bang_cycle_add(nram_h_high_ptr_offset, nram_h_low_ptr_offset, nram_h_stride,
+                   num_deal_grid, num_levels);
+  __bang_transpose(nram_w_low_ptr_offset, nram_h_low_ptr_offset, num_points,
+                   num_per_time_real * num_heads * num_levels);
+  __memcpy_async(nram_h_low_ptr_offset, nram_w_low_ptr_offset,
+                 num_deal_grid * sizeof(float), NRAM2NRAM);
+  __bang_transpose(nram_w_low_ptr_offset, nram_h_high_ptr_offset, num_points,
+                   num_per_time_real * num_heads * num_levels);
+  __memcpy_async(nram_h_high_ptr_offset, nram_w_low_ptr_offset,
+                 num_deal_grid * sizeof(float), NRAM2NRAM);
+  __bang_mul_scalar(nram_w_low_ptr_offset, nram_w_low, qid_stride,
+                    num_deal_grid);
+  __bang_add_scalar(nram_w_high_ptr_offset, nram_w_low_ptr_offset, qid_stride,
+                    num_deal_grid);
+  __bang_mul(nram_w1, nram_hh, nram_hw, num_deal_grid);
+  __bang_mul(nram_w2, nram_hh, nram_lw, num_deal_grid);
+  __bang_mul(nram_w3, nram_lh, nram_hw, num_deal_grid);
+  __bang_mul(nram_w4, nram_lh, nram_lw, num_deal_grid);
+  __bang_add(nram_offset1, nram_h_low_ptr_offset, nram_w_low_ptr_offset,
+             num_deal_grid);
+  __bang_transpose(nram_offset_temp, nram_offset1,
+                   num_per_time_real * num_heads, num_levels * num_points);
+  __bang_cycle_add(nram_offset_temp, nram_offset_temp, nram_base_ptr,
+                   num_deal_grid, num_heads);
+  __bang_transpose(nram_offset1, nram_offset_temp, num_levels * num_points,
+                   num_per_time_real * num_heads);
+  __bang_add(nram_offset2, nram_h_low_ptr_offset, nram_w_high_ptr_offset,
+             num_deal_grid);
+  __bang_transpose(nram_offset_temp, nram_offset2,
+                   num_per_time_real * num_heads, num_levels * num_points);
+  __bang_cycle_add(nram_offset_temp, nram_offset_temp, nram_base_ptr,
+                   num_deal_grid, num_heads);
+  __bang_transpose(nram_offset2, nram_offset_temp, num_levels * num_points,
+                   num_per_time_real * num_heads);
+  __bang_add(nram_offset3, nram_h_high_ptr_offset, nram_w_low_ptr_offset,
+             num_deal_grid);
+  __bang_transpose(nram_offset_temp, nram_offset3,
+                   num_per_time_real * num_heads, num_levels * num_points);
+  __bang_cycle_add(nram_offset_temp, nram_offset_temp, nram_base_ptr,
+                   num_deal_grid, num_heads);
+  __bang_transpose(nram_offset3, nram_offset_temp, num_levels * num_points,
+                   num_per_time_real * num_heads);
+  __bang_add(nram_offset4, nram_h_high_ptr_offset, nram_w_high_ptr_offset,
+             num_deal_grid);
+  __bang_transpose(nram_offset_temp, nram_offset4,
+                   num_per_time_real * num_heads, num_levels * num_points);
+  __bang_cycle_add(nram_offset_temp, nram_offset_temp, nram_base_ptr,
+                   num_deal_grid, num_heads);
+  __bang_transpose(nram_offset4, nram_offset_temp, num_levels * num_points,
+                   num_per_time_real * num_heads);
+  // h_low >= 0 && w_low >= 0 mask2
+  float *mask1 = nram_h_low_ptr_offset;
+  float *mask2 = nram_h_high_ptr_offset;
+  float *mask3 = nram_w_low_ptr_offset;
+  float *mask4 = nram_w_high_ptr_offset;
+  __bang_ge_scalar(mask1, nram_h_low, 0, num_deal_grid);
+  __bang_ge_scalar(mask2, nram_w_low, 0, num_deal_grid);
+  __bang_and(mask2, mask1, mask2, num_deal_grid);
+  __bang_and(mask2, nram_h_high_temp, mask2, num_deal_grid);
+  // h_low >= 0 && w_high <= width - 1 mask1
+  __bang_transpose(mask3, nram_w_high,
+                   num_per_time_real * num_heads * num_levels, num_points);
+  __bang_sub_scalar(nram_spatial_shapes, nram_spatial_shapes, 1,
+                    num_levels * 2);
+  __bang_cycle_le(mask3, mask3, (float *)(nram_spatial_shapes + num_levels),
+                  num_deal_grid, num_levels);
+  __bang_transpose(mask4, mask3, num_points,
+                   num_per_time_real * num_heads * num_levels);
+  __bang_and(mask1, mask1, mask4, num_deal_grid);
+  __bang_and(mask1, nram_h_high_temp, mask1, num_deal_grid);
+  // h_high <= height - 1 && w_high <= width - 1 mask3
+  __bang_transpose(mask3, nram_h_high,
+                   num_per_time_real * num_heads * num_levels, num_points);
+  __bang_cycle_le(mask3, mask3, (float *)(nram_spatial_shapes), num_deal_grid,
+                  num_levels);
-  {
-    int32_t offset1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
-    __memcpy_async(grad_output_nram_tl, data_value_ptr + offset1,
-                   deal_num_real * sizeof(T), GDRAM2NRAM);
-  }
-  if (h_low >= 0 && w_high <= width - 1)
-
-  {
-    int32_t offset2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
-    __memcpy_async(grad_output_nram_tr, data_value_ptr + offset2,
-                   deal_num_real * sizeof(T), GDRAM2NRAM);
-  }
-  if (h_high <= height - 1 && w_low >= 0)
-
-  {
-    int32_t offset3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
-    __memcpy_async(grad_output_nram_bl, data_value_ptr + offset3,
-                   deal_num_real * sizeof(T), GDRAM2NRAM);
-  }
-  if (h_high <= height - 1 && w_high <= width - 1)
-
-  {
-    int32_t offset4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
-    __memcpy_async(grad_output_nram_br, data_value_ptr + offset4,
-                   deal_num_real * sizeof(T), GDRAM2NRAM);
-  }
-  __sync_io();
+  __bang_transpose(nram_h_low_temp, mask3, num_points,
+                   num_per_time_real * num_heads * num_levels);
+  __bang_and(mask4, mask4, nram_h_low_temp, num_deal_grid);
+  __bang_and(mask3, mask4, nram_h_high_temp, num_deal_grid);
+  // h_high <= height - 1 && w_low >= 0 mask4
+  __bang_ge_scalar(nram_w_low_temp, nram_w_low, 0, num_deal_grid);
+  __bang_and(mask4, nram_h_low_temp, nram_w_low_temp, num_deal_grid);
+  __bang_and(mask4, mask4, nram_h_high_temp, num_deal_grid);
 #endif
 }
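Reviewer note (illustrative, not part of the patch): the four masks built above, in
scalar form. v stands for the global in-range test (h_im > -1 && w_im > -1 &&
h_im < H && w_im < W); the cycle_le/cycle_lt calls compare against spatial - 1 since
nram_spatial_shapes is decremented in place. A C++ sketch:

  struct CornerMasks { bool tl, tr, bl, br; };
  CornerMasks corner_masks(int h_low, int w_low, int H, int W, bool v) {
    const int h_high = h_low + 1, w_high = w_low + 1;
    CornerMasks m;
    m.tl = v && h_low >= 0 && w_low >= 0;            // mask2 in the code
    m.tr = v && h_low >= 0 && w_high <= W - 1;       // mask1
    m.br = v && h_high <= H - 1 && w_high <= W - 1;  // mask3
    m.bl = v && h_high <= H - 1 && w_low >= 0;       // mask4
    return m;
  }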

-template <typename T>
-void __mlu_func__ computeData(
-    const int32_t &h_low, const int32_t &w_low, const int32_t &h_high,
-    const int32_t &w_high, T *grad_output_nram_tl, T *grad_output_nram_tr,
-    T *grad_output_nram_bl, T *grad_output_nram_br, T *grad_output_nram_tl_temp,
-    T *grad_output_nram_tr_temp, T *grad_output_nram_bl_temp,
-    T *grad_output_nram_br_temp, const int32_t &width, const int32_t &height,
-    const int32_t &deal_num_real, T *grad_h_weight, T *grad_w_weight,
-    T *top_grad_temp, T *top_grad, const T &data_attn_weight, const T &hw,
-    const T &hh, const T &lw, const T &lh, const T &w1, const T &w2,
-    const T &w3, const T &w4) {
-#if __BANG_ARCH__ > 322
-  __bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num_real);
-  if (h_low >= 0 && w_low >= 0) {
-    __bang_fusion(FUSION_FMA, grad_h_weight, grad_output_nram_tl, (float)(-hw),
-                  grad_h_weight, deal_num_real, deal_num_real);
-    __bang_fusion(FUSION_FMA, grad_w_weight, grad_output_nram_tl, (float)(-hh),
-                  grad_w_weight, deal_num_real, deal_num_real);
-    __bang_mul_scalar(grad_output_nram_tl_temp, top_grad_temp, w1,
-                      deal_num_real);
-    // for calc grad_attn_weight
-    __bang_mul_scalar(grad_output_nram_tl, grad_output_nram_tl, w1,
-                      deal_num_real);
-  }
-  if (h_low >= 0 && w_high <= width - 1) {
-    __bang_fusion(FUSION_FMA, grad_h_weight, grad_output_nram_tr, (float)(-lw),
-                  grad_h_weight, deal_num_real, deal_num_real);
-    __bang_fusion(FUSION_FMA, grad_w_weight, grad_output_nram_tr, (float)(hh),
-                  grad_w_weight, deal_num_real, deal_num_real);
-    __bang_mul_scalar(grad_output_nram_tr_temp, top_grad_temp, w2,
-                      deal_num_real);
-    __bang_mul_scalar(grad_output_nram_tr, grad_output_nram_tr, w2,
-                      deal_num_real);
-    __bang_add(grad_output_nram_tl, grad_output_nram_tl, grad_output_nram_tr,
-               deal_num_real);
-  }
-  if (h_high <= height - 1 && w_low >= 0) {
-    __bang_fusion(FUSION_FMA, grad_h_weight, grad_output_nram_bl, (float)(hw),
-                  grad_h_weight, deal_num_real, deal_num_real);
-    __bang_fusion(FUSION_FMA, grad_w_weight, grad_output_nram_bl, (float)(-lh),
-                  grad_w_weight, deal_num_real, deal_num_real);
-    __bang_mul_scalar(grad_output_nram_bl_temp, top_grad_temp, w3,
-                      deal_num_real);
-    // for calc grad_attn_weight
-    __bang_mul_scalar(grad_output_nram_bl, grad_output_nram_bl, w3,
-                      deal_num_real);
-    __bang_add(grad_output_nram_tl, grad_output_nram_tl, grad_output_nram_bl,
-               deal_num_real);
-  }
-  if (h_high <= height - 1 && w_high <= width - 1) {
-    __bang_fusion(FUSION_FMA, grad_h_weight, grad_output_nram_br, (float)(lw),
-                  grad_h_weight, deal_num_real, deal_num_real);
-    __bang_fusion(FUSION_FMA, grad_w_weight, grad_output_nram_br, (float)(lh),
-                  grad_w_weight, deal_num_real, deal_num_real);
-    __bang_mul_scalar(grad_output_nram_br_temp, top_grad_temp, w4,
-                      deal_num_real);
-    // for calc grad_attn_weight
-    __bang_mul_scalar(grad_output_nram_br, grad_output_nram_br, w4,
-                      deal_num_real);
-    __bang_add(grad_output_nram_tl, grad_output_nram_tl, grad_output_nram_br,
-               deal_num_real);
-  }
-  __bang_mul(grad_output_nram_tl, grad_output_nram_tl, top_grad, deal_num_real);
-  recursiveSumPool(grad_output_nram_tl, 1, deal_num_real, ALIGN_NUM_FOR_REDUCE);
-  __bang_mul_scalar(grad_w_weight, grad_w_weight, width, deal_num_real);
-  __bang_mul(grad_w_weight, grad_w_weight, top_grad_temp, deal_num_real);
-
-  recursiveSumPool(grad_w_weight, 1, deal_num_real, ALIGN_NUM_FOR_REDUCE);
-  __bang_mul_scalar(grad_h_weight, grad_h_weight, height, deal_num_real);
-  __bang_mul(grad_h_weight, grad_h_weight, top_grad_temp, deal_num_real);
-  recursiveSumPool(grad_h_weight, 1, deal_num_real, ALIGN_NUM_FOR_REDUCE);
-#endif
-}
-
-template <typename T>
-void __mlu_func__ storeData(
-    const int32_t &h_low, const int32_t &w_low, const int32_t &h_high,
-    const int32_t &w_high, T *grad_output_nram_tl, T *grad_output_nram_tl_temp,
-    T *grad_output_nram_tr_temp, T *grad_output_nram_bl_temp,
-    T *grad_output_nram_br_temp, const int32_t &width, const int32_t &height,
-    const int32_t &deal_num_real, const int32_t &h_low_ptr_offset,
-    const int32_t &w_low_ptr_offset, const int32_t &w_high_ptr_offset,
-    const int32_t &h_high_ptr_offset, const int32_t &base_ptr, T *grad_value,
-    T *grad_w_weight, T *grad_h_weight, T *grad_sampling_loc,
-    T *grad_attn_weight) {
-#if __BANG_ARCH__ > 322
-  if (h_low >= 0 && w_low >= 0)
-
-  {
-    int32_t offset1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
-    __bang_atomic_add((T *)grad_output_nram_tl_temp,
-                      (T *)(grad_value + offset1),
-                      (T *)grad_output_nram_tl_temp, deal_num_real);
-  }
-  if (h_low >= 0 && w_high <= width - 1)
-
-  {
-    int32_t offset2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
-    __bang_atomic_add((T *)grad_output_nram_tr_temp,
-                      (T *)(grad_value + offset2),
-                      (T *)grad_output_nram_tr_temp, deal_num_real);
-  }
-  if (h_high <= height - 1 && w_low >= 0)
-
-  {
-    int32_t offset3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
-    __bang_atomic_add((T *)grad_output_nram_bl_temp,
-                      (T *)(grad_value + offset3),
-                      (T *)grad_output_nram_bl_temp, deal_num_real);
-  }
-  if (h_high <= height - 1 && w_high <= width - 1)
-
-  {
-    int32_t offset4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
-    __bang_atomic_add((T *)grad_output_nram_br_temp,
-                      (T *)(grad_value + offset4),
-                      (T *)grad_output_nram_br_temp, deal_num_real);
-  }
-  __bang_atomic_add((T *)grad_output_nram_tl, (T *)grad_attn_weight,
-                    (T *)grad_output_nram_tl, 1);
-  __bang_atomic_add((T *)grad_w_weight, (T *)(grad_sampling_loc),
-                    (T *)grad_w_weight, 1);
-  __bang_atomic_add((T *)grad_h_weight, (T *)(grad_sampling_loc + 1),
-                    (T *)grad_h_weight, 1);
-#endif
-}
-
-template <typename T>
-void __mlu_func__ msDeformAttnCol2imBilinearSmallChannels(
-    T *top_grad_temp, const int32_t &height, const int32_t &width, const T &w1,
-    const T &w2, const T &w3, const T &w4, const int32_t &h_low,
-    const int32_t &w_low, const int32_t &h_high, const int32_t &w_high,
-    const int32_t &base_ptr, const int32_t &h_low_ptr_offset,
-    const int32_t &w_low_ptr_offset, const int32_t &h_high_ptr_offset,
-    const int32_t &w_high_ptr_offset, const T &hh, const T &hw, const T &lh,
-    const T &lw, T *top_grad, const T &data_attn_weight, T *grad_h_weight,
-    T *grad_w_weight, T *grad_value, T *grad_output_nram_tl,
-    T *grad_output_nram_tr, T *grad_output_nram_bl, T *grad_output_nram_br,
-    T *grad_output_nram_tl_temp, T *grad_output_nram_tr_temp,
-    T *grad_output_nram_bl_temp, T *grad_output_nram_br_temp,
-    T *grad_sampling_loc, T *grad_attn_weight, const int32_t &deal_num_real,
-    const T *data_value_ptr)
-
-{
-  loadData(h_low, w_low, h_high, w_high, grad_output_nram_tl,
-           grad_output_nram_tr, grad_output_nram_bl, grad_output_nram_br,
-           data_value_ptr, width, height, deal_num_real, h_low_ptr_offset,
-           w_low_ptr_offset, w_high_ptr_offset, h_high_ptr_offset, base_ptr);
-  computeData(h_low, w_low, h_high, w_high, grad_output_nram_tl,
-              grad_output_nram_tr, grad_output_nram_bl, grad_output_nram_br,
-              grad_output_nram_tl_temp, grad_output_nram_tr_temp,
-              grad_output_nram_bl_temp, grad_output_nram_br_temp, width, height,
-              deal_num_real, grad_h_weight, grad_w_weight, top_grad_temp,
-              top_grad, data_attn_weight, hw, hh, lw, lh, w1, w2, w3, w4);
-  storeData(h_low, w_low, h_high, w_high, grad_output_nram_tl,
-            grad_output_nram_tl_temp, grad_output_nram_tr_temp,
-            grad_output_nram_bl_temp, grad_output_nram_br_temp, width, height,
-            deal_num_real, h_low_ptr_offset, w_low_ptr_offset,
-            w_high_ptr_offset, h_high_ptr_offset, base_ptr, grad_value,
-            grad_w_weight, grad_h_weight, grad_sampling_loc, grad_attn_weight);
-}
-
-template <typename T>
-void __mlu_func__ msDeformAttnCol2imImpl(
-    T *top_grad_temp, T *top_grad, T *grad_h_weight, T *grad_w_weight,
-    T *grad_value, T *grad_output_nram_tl, T *grad_output_nram_tr,
-    T *grad_output_nram_bl, T *grad_output_nram_br, T *grad_output_nram_tl_temp,
-    T *grad_output_nram_tr_temp, T *grad_output_nram_bl_temp,
-    T *grad_output_nram_br_temp, T *grad_sampling_loc, T *grad_attn_weight,
-    T *nram_sampling_loc, T *nram_attn_weight, const int32_t &load_num,
-    const int32_t &tail, const int32_t &i_repeat, const int32_t &num_points,
-    const int32_t &start_per_core, const int32_t &num_levels,
-    const int32_t &num_heads, const int32_t &num_query,
-    const int32_t &spatial_size, const int32_t &qid_stride,
-    int32_t *level_start_index_nram, const int32_t &channels,
-    const T *data_value, const T *grad_output, int32_t *spatial_shapes_nram) {
-#if __BANG_ARCH__ > 322
-  int32_t weight_pos = 0;
-  int32_t sampling_loc_pos = 0;
-  for (int32_t p = 0; p < tail; ++p) {
-    int32_t grid_offset = start_per_core + i_repeat * load_num + p;
-    const int32_t l_col = grid_offset % num_levels;
-    const int32_t m_col = grid_offset / num_levels % num_heads;
-    const int32_t q_col = grid_offset / num_levels / num_heads % num_query;
-    const int32_t b_col = grid_offset / num_query / num_heads / num_levels;
-    const int32_t value_offset = b_col * spatial_size * qid_stride;
-    const int32_t level_start_id = level_start_index_nram[l_col];
-    const int32_t grad_attn_weight_out = grid_offset * num_points;
-    const int32_t spatial_h_ptr = l_col << 1;
-    const int32_t grad_output_offset =
-        b_col * num_query * qid_stride + q_col * qid_stride + m_col * channels;
-    __memcpy(top_grad, grad_output + grad_output_offset, channels * LEN_FLOAT,
-             GDRAM2NRAM);
-    const int32_t spatial_h = spatial_shapes_nram[spatial_h_ptr];
-    const int32_t spatial_w = spatial_shapes_nram[spatial_h_ptr + 1];
-    const int32_t h_stride = spatial_w * qid_stride;
-    const int32_t value_ptr_offset = value_offset + level_start_id * qid_stride;
-    const float *data_value_ptr = data_value + value_ptr_offset;
-    float *grad_value_ptr = grad_value + value_ptr_offset;
-    const int32_t grad_sampling_loc_out = grid_offset * num_points << 1;
-
-    for (int32_t p_col = 0; p_col < num_points; ++p_col) {
-      const float loc_w = nram_sampling_loc[sampling_loc_pos];
-      const float loc_h = nram_sampling_loc[sampling_loc_pos + 1];
-      const float weight = nram_attn_weight[weight_pos];
-      const float h_im = loc_h * spatial_h - 0.5;
-      const float w_im = loc_w * spatial_w - 0.5;
-
-      if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
-        const int32_t h_low = floorf(h_im);
-        const int32_t w_low = floorf(w_im);
-        const int32_t h_high = h_low + 1;
-        const int32_t w_high = w_low + 1;
-        const float lh = h_im - h_low;
-        const float lw = w_im - w_low;
-        const float hh = 1.0 - lh;
-        const float hw = 1.0 - lw;
-        const int32_t h_low_ptr_offset = h_low * h_stride;
-        const int32_t h_high_ptr_offset = h_low_ptr_offset + h_stride;
-        const int32_t w_low_ptr_offset = w_low * qid_stride;
-        const int32_t w_high_ptr_offset = w_low_ptr_offset + qid_stride;
-        const float w1 = hh * hw;
-        const float w2 = hh * lw;
-        const float w3 = lh * hw;
-        const float w4 = lh * lw;
-        const int32_t base_ptr = m_col * channels;
-        __bang_write_zero(grad_h_weight, PAD_UP(channels, ALIGN_NUM));
-        __bang_write_zero(grad_w_weight, PAD_UP(channels, ALIGN_NUM));
-        __bang_write_zero(grad_output_nram_tl, PAD_UP(channels, ALIGN_NUM));
-        msDeformAttnCol2imBilinearSmallChannels(
-            top_grad_temp, spatial_h, spatial_w, w1, w2, w3, w4, h_low, w_low,
-            h_high, w_high, base_ptr, h_low_ptr_offset, w_low_ptr_offset,
-            h_high_ptr_offset, w_high_ptr_offset, hh, hw, lh, lw, top_grad,
-            weight, grad_h_weight, grad_w_weight, grad_value_ptr,
-            grad_output_nram_tl, grad_output_nram_tr, grad_output_nram_bl,
-            grad_output_nram_br, grad_output_nram_tl_temp,
-            grad_output_nram_tr_temp, grad_output_nram_bl_temp,
-            grad_output_nram_br_temp,
-            grad_sampling_loc + grad_sampling_loc_out + (p_col << 1),
-            grad_attn_weight + grad_attn_weight_out + p_col, channels,
-            data_value_ptr);
-      }
-      weight_pos += 1;
-      sampling_loc_pos += 2;
-    }
-  }
+void __mlu_func__ loadValue(
+    float *nram_grad_output_tl, float *nram_grad_output_tr,
+    float *nram_grad_output_bl, float *nram_grad_output_br,
+    const float *data_value, const float *grad_output, float *grad_temp1,
+    float *grad_temp2, float *mask1, float *mask2, float *mask3, float *mask4,
+    float *nram_offset1, float *nram_offset2, float *nram_offset3,
+    float *nram_offset4, float *nram_grad_weight,
+    int32_t *nram_level_start_index, int32_t offset_nram,
+    int32_t start_per_core, int32_t grid_loop, int32_t num_per_time_theory,
+    int32_t num_heads, int32_t deal_num_real, int32_t num_per_time_real,
+    int32_t num_deal_grid, const int32_t num_query, const int32_t num_levels,
+    const int32_t num_points, int32_t grid_offset, const int32_t spatial_size,
+    const int32_t qid_stride) {
+#if __BANG_ARCH__ >= 322
+  int32_t value_offset_temp = 0;
+  __bang_write_zero(nram_grad_output_tl, 4 * offset_nram);
+  __sync_io_move_compute();
+  __memcpy_async(
+      grad_temp2,
+      grad_output + (start_per_core + grid_loop * num_per_time_theory) *
+                        num_heads * deal_num_real,
+      num_per_time_real * num_heads * deal_num_real * sizeof(float),
+      GDRAM2NRAM);
+  for (int32_t loop = 0; loop < num_deal_grid; ++loop) {
+    const int32_t b_col =
+        (grid_offset + loop) / num_query / num_heads / num_levels / num_points;
+    const int32_t l_col = (grid_offset + loop) / num_points % num_levels;
+    const int32_t level_start_id = nram_level_start_index[l_col];
+    value_offset_temp =
+        b_col * spatial_size * qid_stride + level_start_id * qid_stride;
+    if (mask2[loop]) {
+      __memcpy_async(
+          nram_grad_output_tl + loop * deal_num_real,
+          data_value + value_offset_temp + int32_t(nram_offset1[loop]),
+          deal_num_real * sizeof(float), GDRAM2NRAM);
+    }
+    if (mask1[loop]) {
+      __memcpy_async(
+          nram_grad_output_tr + loop * deal_num_real,
+          data_value + value_offset_temp + int32_t(nram_offset2[loop]),
+          deal_num_real * sizeof(float), GDRAM2NRAM);
+    }
+    if (mask4[loop]) {
+      __memcpy_async(
+          nram_grad_output_bl + loop * deal_num_real,
+          data_value + value_offset_temp + int32_t(nram_offset3[loop]),
+          deal_num_real * sizeof(float), GDRAM2NRAM);
+    }
+    if (mask3[loop]) {
+      __memcpy_async(
+          nram_grad_output_br + loop * deal_num_real,
+          data_value + value_offset_temp + int32_t(nram_offset4[loop]),
+          deal_num_real * sizeof(float), GDRAM2NRAM);
+    }
+  }
+  for (int32_t m = 0; m < deal_num_real; ++m) {
+    __memcpy_async(grad_temp1 + m * num_deal_grid, nram_grad_weight,
+                   num_deal_grid * sizeof(float), NRAM2NRAM);
+  }
+  __sync_io_move_compute();
+#endif
+}
+
+void __mlu_func__ computeGradValue(
+    float *grad_temp1, float *grad_temp2, float *grad_temp3, float *grad_temp4,
+    float *mask1, float *mask2, float *mask3, float *mask4, float *nram_offset1,
+    float *nram_offset2, float *nram_offset3, float *nram_offset4,
+    int32_t *nram_level_start_index, int32_t deal_num_real,
+    const float *grad_value, float *nram_w1, float *nram_w2, float *nram_w3,
+    float *nram_w4, int32_t num_per_time_real, const int32_t num_heads,
+    const int32_t num_levels, const int32_t num_points, const int32_t num_query,
+    int32_t num_deal_grid, int32_t grid_offset, const int32_t spatial_size,
+    const int32_t qid_stride, float *nram_grid_offset1,
+    float *nram_grid_offset2) {
+#if __BANG_ARCH__ >= 322
+  __bang_transpose(grad_temp3, grad_temp1,
+                   deal_num_real * num_per_time_real * num_heads,
+                   num_levels * num_points);
+  __bang_transpose(grad_temp1, grad_temp2, num_per_time_real * num_heads,
+                   deal_num_real);
+  __bang_cycle_mul(grad_temp3, grad_temp3, grad_temp1,
+                   num_deal_grid * deal_num_real,
+                   deal_num_real * num_per_time_real * num_heads);
+  __bang_transpose(grad_temp4, grad_temp3, num_levels * num_points,
+                   deal_num_real * num_per_time_real * num_heads);
+  __bang_cycle_mul(grad_temp1, grad_temp4, nram_w1,
+                   num_deal_grid * deal_num_real, num_deal_grid);
+  __bang_transpose(grad_temp3, grad_temp1, deal_num_real, num_deal_grid);
+  for (int32_t loop = 0; loop < num_deal_grid; ++loop) {
+    nram_grid_offset1[loop] = ((loop + grid_offset) / num_query / num_heads /
+                               num_levels / num_points) *
+                              spatial_size * qid_stride;
+  }
+  __bang_transpose(nram_grid_offset2, nram_grid_offset1,
+                   num_per_time_real * num_heads * num_levels, num_points);
+  __bang_int322float((float *)nram_level_start_index, nram_level_start_index,
+                     num_levels, 0);
+  __bang_mul_scalar(nram_grid_offset1, (float *)nram_level_start_index,
+                    qid_stride, num_levels);
+  __bang_cycle_add(nram_grid_offset2, nram_grid_offset2, nram_grid_offset1,
+                   num_deal_grid, num_levels);
+  __bang_transpose(nram_grid_offset1, nram_grid_offset2, num_points,
+                   num_per_time_real * num_heads * num_levels);
+  __bang_add(nram_offset1, nram_offset1, nram_grid_offset1, num_deal_grid);
+  __bang_add(nram_offset2, nram_offset2, nram_grid_offset1, num_deal_grid);
+  __bang_add(nram_offset3, nram_offset3, nram_grid_offset1, num_deal_grid);
+  __bang_add(nram_offset4, nram_offset4, nram_grid_offset1, num_deal_grid);
+  for (int32_t loop = 0; loop < num_deal_grid; ++loop) {
+    if (mask2[loop]) {
+      __bang_atomic_add((float *)(grad_temp3 + loop * deal_num_real),
+                        (float *)(grad_value + int32_t(nram_offset1[loop])),
+                        (float *)(grad_temp3 + loop * deal_num_real),
+                        deal_num_real);
+    }
+  }
+  __bang_cycle_mul(grad_temp1, 
grad_temp4, nram_w2,
+                   num_deal_grid * deal_num_real, num_deal_grid);
+  __bang_transpose(grad_temp3, grad_temp1, deal_num_real, num_deal_grid);
+  // scatter the w2-weighted gradient into grad_value at the top-right
+  // neighbor offsets; mask1 gates grid points whose top-right neighbor is
+  // valid
+  for (int32_t loop = 0; loop < num_deal_grid; ++loop) {
+    if (mask1[loop]) {
+      __bang_atomic_add((float *)(grad_temp3 + loop * deal_num_real),
+                        (float *)(grad_value + int32_t(nram_offset2[loop])),
+                        (float *)(grad_temp3 + loop * deal_num_real),
+                        deal_num_real);
+    }
+  }
+  __bang_cycle_mul(grad_temp1, grad_temp4, nram_w3,
+                   num_deal_grid * deal_num_real, num_deal_grid);
+  __bang_transpose(grad_temp3, grad_temp1, deal_num_real, num_deal_grid);
+  for (int32_t loop = 0; loop < num_deal_grid; ++loop) {
+    if (mask4[loop]) {
+      __bang_atomic_add((float *)(grad_temp3 + loop * deal_num_real),
+                        (float *)(grad_value + int32_t(nram_offset3[loop])),
+                        (float *)(grad_temp3 + loop * deal_num_real),
+                        deal_num_real);
+    }
+  }
+
+  __bang_cycle_mul(grad_temp1, grad_temp4, nram_w4,
+                   num_deal_grid * deal_num_real, num_deal_grid);
+  __bang_transpose(grad_temp3, grad_temp1, deal_num_real, num_deal_grid);
+  for (int32_t loop = 0; loop < num_deal_grid; ++loop) {
+    if (mask3[loop]) {
+      __bang_atomic_add((float *)(grad_temp3 + loop * deal_num_real),
+                        (float *)(grad_value + int32_t(nram_offset4[loop])),
+                        (float *)(grad_temp3 + loop * deal_num_real),
+                        deal_num_real);
+    }
+  }
+#endif
+}
+
+void __mlu_func__ computeGradAttnWeight(
+    float *grad_w_weight, float *grad_weight, float *nram_grad_output_tl,
+    float *nram_grad_output_tr, float *nram_grad_output_bl,
+    float *nram_grad_output_br, float *grad_temp1, float *grad_temp2,
+    const float *grad_attn_weight, float *nram_hw, float *nram_hh,
+    float *nram_lw, float *nram_lh, float *grad_h_weight, float *nram_w1,
+    float *nram_w2, float *nram_w3, float *nram_w4, int32_t offset_nram,
+    int32_t num_deal_grid, int32_t deal_num_real, int32_t num_per_time_real,
+    const int32_t num_heads, const int32_t num_levels, const int32_t num_points,
+    int32_t grid_offset, float *nram_h_high_temp) {
+#if __BANG_ARCH__ >= 322
+  __bang_write_zero(grad_w_weight, 2 * offset_nram);
+
+  // nram_grad_output_tl
+  __bang_transpose(grad_weight, nram_grad_output_tl, num_deal_grid,
+                   deal_num_real);
+  __bang_cycle_mul(nram_grad_output_tl, grad_weight, nram_hw,
+                   num_deal_grid * deal_num_real, num_deal_grid);
+  __bang_sub(grad_h_weight, grad_h_weight, nram_grad_output_tl,
+             num_deal_grid * deal_num_real);
+  __bang_cycle_mul(nram_grad_output_tl, grad_weight, nram_hh,
+                   num_deal_grid * deal_num_real, num_deal_grid);
+  __bang_sub(grad_w_weight, grad_w_weight, nram_grad_output_tl,
+             num_deal_grid * deal_num_real);
+  __bang_cycle_mul(nram_grad_output_tl, grad_weight, nram_w1,
+                   num_deal_grid * deal_num_real, num_deal_grid);
+  // nram_grad_output_tr
+  __bang_transpose(grad_weight, nram_grad_output_tr, num_deal_grid,
+                   deal_num_real);
+  __bang_cycle_mul(nram_grad_output_tr, grad_weight, nram_lw,
+                   num_deal_grid * deal_num_real, num_deal_grid);
+  __bang_sub(grad_h_weight, grad_h_weight, nram_grad_output_tr,
+             num_deal_grid * deal_num_real);
+  __bang_cycle_mul(nram_grad_output_tr, grad_weight, nram_hh,
+                   num_deal_grid * deal_num_real, num_deal_grid);
+  __bang_add(grad_w_weight, grad_w_weight, nram_grad_output_tr,
+             num_deal_grid * deal_num_real);
+  __bang_cycle_mul(nram_grad_output_tr, grad_weight, nram_w2,
+                   num_deal_grid * deal_num_real, num_deal_grid);
+  __bang_add(nram_grad_output_tl, nram_grad_output_tl, nram_grad_output_tr,
+             num_deal_grid * deal_num_real);
+  // nram_grad_output_bl
+  __bang_transpose(grad_weight, nram_grad_output_bl, num_deal_grid,
+                   deal_num_real);
+  __bang_cycle_mul(nram_grad_output_bl, grad_weight, nram_hw,
+                   num_deal_grid * deal_num_real, num_deal_grid);
+  __bang_add(grad_h_weight, grad_h_weight, nram_grad_output_bl,
+             num_deal_grid * deal_num_real);
+  __bang_cycle_mul(nram_grad_output_bl, grad_weight, nram_lh,
+                   num_deal_grid * deal_num_real, num_deal_grid);
+  __bang_sub(grad_w_weight, grad_w_weight, nram_grad_output_bl,
+             num_deal_grid * deal_num_real);
+  __bang_cycle_mul(nram_grad_output_bl, grad_weight, nram_w3,
+                   num_deal_grid * deal_num_real, num_deal_grid);
+  __bang_add(nram_grad_output_tl, nram_grad_output_tl, nram_grad_output_bl,
+             num_deal_grid * deal_num_real);
+  // nram_grad_output_br
+  __bang_transpose(grad_weight, nram_grad_output_br, num_deal_grid,
+                   deal_num_real);
+  __bang_cycle_mul(nram_grad_output_br, grad_weight, nram_lw,
+                   num_deal_grid * deal_num_real, num_deal_grid);
+  __bang_add(grad_h_weight, grad_h_weight, nram_grad_output_br,
+             num_deal_grid * deal_num_real);
+  __bang_cycle_mul(nram_grad_output_br, grad_weight, nram_lh,
+                   num_deal_grid * deal_num_real, num_deal_grid);
+  __bang_add(grad_w_weight, grad_w_weight, nram_grad_output_br,
+             num_deal_grid * deal_num_real);
+  __bang_cycle_mul(nram_grad_output_br, grad_weight, nram_w4,
+                   num_deal_grid * deal_num_real, num_deal_grid);
+  __bang_add(nram_grad_output_tl, nram_grad_output_tl, nram_grad_output_br,
+             num_deal_grid * deal_num_real);
+  __bang_transpose(nram_grad_output_br, nram_grad_output_tl, deal_num_real,
+                   num_deal_grid);
+  __bang_transpose(nram_grad_output_tr, nram_grad_output_br,
+                   num_per_time_real * num_heads,
+                   num_points * num_levels * deal_num_real);
+  __bang_transpose(grad_temp1, grad_temp2, num_per_time_real * num_heads,
+                   deal_num_real);
+  __bang_cycle_mul(nram_grad_output_tr, nram_grad_output_tr, grad_temp1,
+                   num_deal_grid * deal_num_real,
+                   num_per_time_real * num_heads * deal_num_real);
+  __bang_transpose(nram_grad_output_br, nram_grad_output_tr,
+                   num_points * num_levels * deal_num_real,
+                   num_per_time_real * num_heads);
+
+  __bang_transpose((float *)nram_grad_output_tr, (float *)nram_grad_output_br,
+                   num_deal_grid, deal_num_real);
+  recursiveSumPool(nram_grad_output_tr, num_deal_grid, deal_num_real,
+                   ALIGN_NUM);
+  // turn nram_h_high_temp into a 0 / 0xffffffff mask and zero the gradient
+  // of sampling points that fell outside the feature map
+  __bang_float2int32((int *)nram_h_high_temp, nram_h_high_temp, num_deal_grid,
+                     0);
+  __nram__ int table[2] = {0, (int)0xffffffff};
+  __bang_lut_s32((int *)nram_h_high_temp, (int *)nram_h_high_temp, (int *)table,
+                 num_deal_grid, 64);
+  __bang_band((char *)nram_grad_output_tr, (char *)nram_grad_output_tr,
+              (char *)nram_h_high_temp, num_deal_grid * sizeof(float));
+
+  __bang_atomic_add((float *)nram_grad_output_tr,
+                    (float *)grad_attn_weight + grid_offset,
+                    (float *)nram_grad_output_tr, num_deal_grid);
+#endif
+}
+
+void __mlu_func__ computeGradSamplingLoc(
+    const float *grad_sampling_loc, float *nram_grad_output_tl,
+    float *nram_grad_output_tr, float *grad_h_weight, float *grad_w_weight,
+    int32_t *nram_spatial_shapes, float *grad_temp1, float *grad_temp2,
+    float *nram_grad_weight, int32_t num_deal_grid, int32_t deal_num_real,
+    int32_t num_per_time_real, const int32_t num_heads,
+    const int32_t num_levels, const int32_t num_points, int32_t grid_offset,
+    float *nram_h_high_temp) {
+#if __BANG_ARCH__ >= 322
+  // d(h_im)/d(loc_h) = spatial_h in the forward pass, so scale the height
+  // gradient by the per-level spatial height before reducing over channels
+  __bang_transpose(nram_grad_output_tl, grad_h_weight,
+                   num_per_time_real * num_heads * num_levels * deal_num_real,
+                   num_points);
+  __bang_cycle_mul(nram_grad_output_tl, nram_grad_output_tl,
+                   (float *)nram_spatial_shapes, num_deal_grid * deal_num_real,
+                   num_levels);
+  __bang_transpose(grad_h_weight,
nram_grad_output_tl, + num_points * deal_num_real, + num_per_time_real * num_heads * num_levels); + for (int32_t m = 0; m < deal_num_real; ++m) { + __memcpy_async(grad_temp1 + m * num_deal_grid, nram_grad_weight, + num_deal_grid * sizeof(float), NRAM2NRAM); + } + __sync_io_move_compute(); + __bang_transpose(nram_grad_output_tr, grad_temp1, + deal_num_real * num_per_time_real * num_heads, + num_levels * num_points); + __bang_transpose(grad_temp1, grad_temp2, num_per_time_real * num_heads, + deal_num_real); + __bang_cycle_mul(nram_grad_output_tr, nram_grad_output_tr, grad_temp1, + num_deal_grid * deal_num_real, + deal_num_real * num_per_time_real * num_heads); + __bang_transpose(grad_temp1, nram_grad_output_tr, + num_levels * num_points * deal_num_real, + num_per_time_real * num_heads); + __bang_mul(grad_h_weight, grad_h_weight, grad_temp1, + num_deal_grid * deal_num_real); + __bang_transpose(nram_grad_output_tl, grad_h_weight, num_deal_grid, + deal_num_real); + __memcpy_async(grad_h_weight, nram_grad_output_tl, + num_deal_grid * deal_num_real * sizeof(float), NRAM2NRAM); + recursiveSumPool(grad_h_weight, num_deal_grid, deal_num_real, ALIGN_NUM); + __nram__ int table[2] = {0, (int)0xffffffff}; + __bang_lut_s32((int *)nram_h_high_temp, (int *)nram_h_high_temp, (int *)table, + num_deal_grid, 64); + __bang_band((char *)grad_h_weight, (char *)grad_h_weight, + (char *)nram_h_high_temp, num_deal_grid * sizeof(float)); + __bang_transpose(nram_grad_output_tl, grad_w_weight, + num_per_time_real * num_heads * num_levels * deal_num_real, + num_points); + __bang_cycle_mul(nram_grad_output_tl, nram_grad_output_tl, + (float *)(nram_spatial_shapes + num_levels), + num_deal_grid * deal_num_real, num_levels); + __bang_transpose(grad_w_weight, nram_grad_output_tl, + num_points * deal_num_real, + num_per_time_real * num_heads * num_levels); + __bang_mul(grad_w_weight, grad_w_weight, grad_temp1, + num_deal_grid * deal_num_real); + __bang_transpose(nram_grad_output_tl, grad_w_weight, num_deal_grid, + deal_num_real); + __memcpy(grad_w_weight, nram_grad_output_tl, + num_deal_grid * deal_num_real * sizeof(float), NRAM2NRAM); + recursiveSumPool(grad_w_weight, num_deal_grid, deal_num_real, ALIGN_NUM); + __bang_lut_s32((int *)nram_h_high_temp, (int *)nram_h_high_temp, (int *)table, + num_deal_grid, 64); + __bang_band((char *)grad_w_weight, (char *)grad_w_weight, + (char *)nram_h_high_temp, num_deal_grid * sizeof(float)); + + __memcpy(grad_w_weight + num_deal_grid, grad_h_weight, + num_deal_grid * sizeof(float), NRAM2NRAM); + __bang_transpose(nram_grad_output_tl, grad_w_weight, 2, num_deal_grid); + __bang_atomic_add((float *)nram_grad_output_tl, + (float *)grad_sampling_loc + grid_offset * 2, + (float *)nram_grad_output_tl, 2 * num_deal_grid); + #endif } @@ -1616,117 +1849,195 @@ __mlu_global__ void MLUUnion1KernelMsDeformAttnBackwardSmallChannelsKernel( const int32_t num_points, float *grad_value, float *grad_sampling_loc, float *grad_attn_weight) { #if __BANG_ARCH__ > 322 - const int32_t split_num = 12; + const int32_t split_grid_num = 28; + const int32_t split_num_c = 8; const int32_t C_align = PAD_UP(channels, ALIGN_NUM); - float *grad_output_nram_tl = (float *)nram_buffer; - float *grad_output_nram_tr = (float *)nram_buffer + C_align; - float *grad_output_nram_bl = (float *)nram_buffer + 2 * C_align; - float *grad_output_nram_br = (float *)nram_buffer + 3 * C_align; - float *grad_output_nram_tl_temp = (float *)nram_buffer + 4 * C_align; - float *grad_output_nram_tr_temp = (float *)nram_buffer + 5 * C_align; - 
float *grad_output_nram_bl_temp = (float *)nram_buffer + 6 * C_align; - float *grad_output_nram_br_temp = (float *)nram_buffer + 7 * C_align; - float *grad_h_weight = (float *)nram_buffer + 8 * C_align; - float *grad_w_weight = (float *)nram_buffer + 9 * C_align; - float *top_grad_temp = (float *)nram_buffer + 10 * C_align; - float *top_grad = (float *)nram_buffer + 11 * C_align; + const int32_t num_hlp = num_heads * num_levels * num_points; + int32_t num_per_time_theory = (MAX_NRAM_SIZE - num_levels * sizeof(float) - + 3 * num_levels * sizeof(int32_t)) / + sizeof(float) / + (split_num_c * C_align + split_grid_num) / + PAD_UP((num_hlp), ALIGN_NUM); - int32_t *spatial_shapes_nram = - (int32_t *)((float *)nram_buffer + split_num * C_align); - int32_t *level_start_index_nram = - (int32_t *)(spatial_shapes_nram + PAD_UP(num_levels * 2, ALIGN_NUM)); - float *nram_remain = (float *)((int32_t *)level_start_index_nram + - PAD_UP(num_levels, ALIGN_NUM)); + int32_t deal_grid_num_theory = num_per_time_theory * num_hlp; - // calc load num - const int32_t weight_num2nram = - (MAX_NRAM_SIZE / LEN_FLOAT - split_num * C_align - - 3 * PAD_UP(num_levels, ALIGN_NUM)) / - 3 / num_points; - int32_t load_num = weight_num2nram; - const int32_t total_num = batch * num_query * num_heads * num_levels; + const int32_t offset_nram = num_per_time_theory * C_align * num_hlp; + const int32_t offset_nram_calc = PAD_UP(deal_grid_num_theory, ALIGN_NUM); + float *nram_grad_output_tl = (float *)nram_buffer; + float *nram_grad_output_tr = (float *)nram_buffer + offset_nram; + float *nram_grad_output_bl = (float *)nram_buffer + 2 * offset_nram; + float *nram_grad_output_br = (float *)nram_buffer + 3 * offset_nram; + + float *grad_temp1 = (float *)nram_buffer + 4 * offset_nram; + float *grad_temp2 = (float *)nram_buffer + 5 * offset_nram; + float *grad_temp3 = (float *)nram_buffer + 6 * offset_nram; + float *grad_temp4 = (float *)nram_buffer + 7 * offset_nram; + + float *nram_loc_w = (float *)nram_buffer + split_num_c * offset_nram; + float *nram_loc_h = + (float *)nram_buffer + split_num_c * offset_nram + offset_nram_calc; + float *nram_h_low = + (float *)nram_buffer + split_num_c * offset_nram + 2 * offset_nram_calc; + float *nram_w_low = + (float *)nram_buffer + split_num_c * offset_nram + 3 * offset_nram_calc; + float *nram_h_high = + (float *)nram_buffer + split_num_c * offset_nram + 4 * offset_nram_calc; + float *nram_w_high = + (float *)nram_buffer + split_num_c * offset_nram + 5 * offset_nram_calc; + float *nram_h_low_temp = + (float *)nram_buffer + split_num_c * offset_nram + 6 * offset_nram_calc; + float *nram_h_high_temp = + (float *)nram_buffer + split_num_c * offset_nram + 7 * offset_nram_calc; + + float *nram_hw = + (float *)nram_buffer + split_num_c * offset_nram + 8 * offset_nram_calc; + float *nram_hh = + (float *)nram_buffer + split_num_c * offset_nram + 9 * offset_nram_calc; + float *nram_lw = + (float *)nram_buffer + split_num_c * offset_nram + 10 * offset_nram_calc; + float *nram_lh = + (float *)nram_buffer + split_num_c * offset_nram + 11 * offset_nram_calc; + + float *nram_h_low_ptr_offset = + (float *)nram_buffer + split_num_c * offset_nram + 12 * offset_nram_calc; + float *nram_h_high_ptr_offset = + (float *)nram_buffer + split_num_c * offset_nram + 13 * offset_nram_calc; + float *nram_w_low_ptr_offset = + (float *)nram_buffer + split_num_c * offset_nram + 14 * offset_nram_calc; + float *nram_w_high_ptr_offset = + (float *)nram_buffer + split_num_c * offset_nram + 15 * offset_nram_calc; + + float 
*nram_w1 = + (float *)nram_buffer + split_num_c * offset_nram + 16 * offset_nram_calc; + float *nram_w2 = + (float *)nram_buffer + split_num_c * offset_nram + 17 * offset_nram_calc; + float *nram_w3 = + (float *)nram_buffer + split_num_c * offset_nram + 18 * offset_nram_calc; + float *nram_w4 = + (float *)nram_buffer + split_num_c * offset_nram + 19 * offset_nram_calc; + + float *nram_grad_weight = + (float *)nram_buffer + split_num_c * offset_nram + 20 * offset_nram_calc; + float *nram_base_ptr = + (float *)nram_buffer + split_num_c * offset_nram + 21 * offset_nram_calc; + float *nram_offset_temp = + (float *)nram_buffer + split_num_c * offset_nram + 22 * offset_nram_calc; + + float *nram_offset1 = + (float *)nram_buffer + split_num_c * offset_nram + 23 * offset_nram_calc; + float *nram_offset2 = + (float *)nram_buffer + split_num_c * offset_nram + 24 * offset_nram_calc; + float *nram_offset3 = + (float *)nram_buffer + split_num_c * offset_nram + 25 * offset_nram_calc; + float *nram_offset4 = + (float *)nram_buffer + split_num_c * offset_nram + 26 * offset_nram_calc; + + float *nram_w_low_temp = + (float *)nram_buffer + split_num_c * offset_nram + 27 * offset_nram_calc; + int32_t *nram_spatial_shapes = + (int32_t *)((float *)nram_buffer + split_num_c * offset_nram + + 28 * offset_nram_calc); + int32_t *nram_level_start_index = + (int32_t *)(nram_spatial_shapes + 2 * num_levels); + float *nram_h_stride = (float *)(nram_level_start_index + 3 * num_levels); + const int32_t total_num = batch * num_query; int32_t num_per_core = total_num / taskDim; int32_t num_rem = total_num % taskDim; num_per_core = num_per_core + int32_t(taskId < num_rem); - if (num_per_core == 0) { - return; - } - const int32_t start_per_core = num_rem > taskId - ? (taskId * num_per_core) - : (num_rem + taskId * num_per_core); + num_per_time_theory = + num_per_core > num_per_time_theory ? num_per_time_theory : num_per_core; + int32_t num_deal_grid = num_per_time_theory * num_hlp; + + if (num_per_core == 0) return; + int32_t start_per_core = num_rem > taskId ? 
(taskId * num_per_core) + : (num_rem + taskId * num_per_core); + const int32_t qid_stride = num_heads * channels; + int32_t deal_num_real = channels; - // load spatial_shapes anddata_level_start_index to nram - __memcpy_async(spatial_shapes_nram, spatial_shapes, - num_levels * 2 * sizeof(int32_t), GDRAM2NRAM); - __memcpy_async(level_start_index_nram, data_level_start_index, - num_levels * sizeof(int32_t), GDRAM2NRAM); + const int32_t repeat_times = num_per_core / num_per_time_theory; + const int32_t tail_num = num_per_core % num_per_time_theory; - const int32_t start_l_col = start_per_core % num_levels; - const int32_t start_m_col = start_per_core / num_levels % num_heads; - const int32_t start_q_col = - start_per_core / num_levels / num_heads % num_query; - const int32_t start_b_col = - start_per_core / num_query / num_heads / num_levels; + int32_t num_per_time_real = num_per_time_theory; - const int32_t repeat = num_per_core / load_num; - const int32_t tail = num_per_core % load_num; - float *nram_sampling_loc = nram_remain; - float *nram_attn_weight = nram_sampling_loc + 2 * load_num * num_points; - - const int32_t attn_weight_offset = - start_b_col * num_query * num_heads * num_levels * num_points + - start_q_col * num_heads * num_levels * num_points + - start_m_col * num_levels * num_points + start_l_col * num_points; - const int32_t sampling_loc_offset = - start_b_col * num_query * num_heads * num_levels * num_points * 2 + - start_q_col * num_heads * num_levels * num_points * 2 + - start_m_col * num_levels * num_points * 2 + start_l_col * num_points * 2; - if (repeat > 0) { - for (int32_t i_repeat = 0; i_repeat < repeat; ++i_repeat) - - { // load weight and sampling_loc to nram - __memcpy_async(nram_sampling_loc, - data_sampling_loc + sampling_loc_offset + - i_repeat * load_num * 2 * num_points, - 2 * load_num * num_points * LEN_FLOAT, GDRAM2NRAM); - __memcpy(nram_attn_weight, - data_attn_weight + attn_weight_offset + - i_repeat * load_num * num_points, - load_num * num_points * LEN_FLOAT, GDRAM2NRAM); - msDeformAttnCol2imImpl( - top_grad_temp, top_grad, grad_h_weight, grad_w_weight, grad_value, - grad_output_nram_tl, grad_output_nram_tr, grad_output_nram_bl, - grad_output_nram_br, grad_output_nram_tl_temp, - grad_output_nram_tr_temp, grad_output_nram_bl_temp, - grad_output_nram_br_temp, grad_sampling_loc, grad_attn_weight, - nram_sampling_loc, nram_attn_weight, load_num, load_num, i_repeat, - num_points, start_per_core, num_levels, num_heads, num_query, - spatial_size, qid_stride, level_start_index_nram, channels, - data_value, grad_output, spatial_shapes_nram); - } + for (int32_t loop = 0; loop < num_heads; ++loop) { + nram_base_ptr[loop] = loop * channels; } - if (tail > 0) + const int32_t w_stride = num_heads * channels; + for (int32_t grid_loop = 0; grid_loop < repeat_times + 1; grid_loop += 1) { + int32_t grid_offset = + (start_per_core + grid_loop * num_per_time_theory) * num_hlp; + if (grid_loop == repeat_times) { + if (tail_num == 0) { + continue; + } else { + grid_offset = + (start_per_core + repeat_times * num_per_time_theory) * num_hlp; + num_per_time_real = tail_num; + num_deal_grid = tail_num * num_hlp; + } + } - { // load weight and sampling_loc to nram - __memcpy_async(nram_sampling_loc, - data_sampling_loc + sampling_loc_offset + - repeat * load_num * 2 * num_points, - tail * num_points * 2 * LEN_FLOAT, GDRAM2NRAM); - __memcpy( - nram_attn_weight, - data_attn_weight + attn_weight_offset + repeat * load_num * num_points, - tail * num_points * LEN_FLOAT, GDRAM2NRAM); - 
msDeformAttnCol2imImpl(
-        top_grad_temp, top_grad, grad_h_weight, grad_w_weight, grad_value,
-        grad_output_nram_tl, grad_output_nram_tr, grad_output_nram_bl,
-        grad_output_nram_br, grad_output_nram_tl_temp, grad_output_nram_tr_temp,
-        grad_output_nram_bl_temp, grad_output_nram_br_temp, grad_sampling_loc,
-        grad_attn_weight, nram_sampling_loc, nram_attn_weight, load_num, tail,
-        repeat, num_points, start_per_core, num_levels, num_heads, num_query,
-        spatial_size, qid_stride, level_start_index_nram, channels, data_value,
-        grad_output, spatial_shapes_nram);
+    __memcpy_async(nram_spatial_shapes, spatial_shapes,
+                   num_levels * 2 * sizeof(int32_t), GDRAM2NRAM);
+    __memcpy_async(nram_level_start_index, data_level_start_index,
+                   num_levels * sizeof(int32_t), GDRAM2NRAM);
+    __memcpy_async(nram_loc_w, data_sampling_loc + grid_offset * 2,
+                   num_deal_grid * 2 * sizeof(float), GDRAM2NRAM);
+    __memcpy(nram_grad_weight, data_attn_weight + grid_offset,
+             num_deal_grid * sizeof(float), GDRAM2NRAM);
+    computeGridMaskAndOffset(
+        nram_grad_output_tl, nram_grad_output_tr, nram_loc_w, nram_loc_h,
+        nram_h_stride, nram_spatial_shapes, nram_w_low_temp, nram_h_high_temp,
+        nram_w_low, nram_h_low, nram_h_high, nram_w_high, nram_lh, nram_lw,
+        nram_hh, nram_hw, nram_h_low_ptr_offset, nram_h_high_ptr_offset,
+        nram_w_low_ptr_offset, nram_w_high_ptr_offset, nram_w1, nram_w2,
+        nram_w3, nram_w4, nram_offset_temp, nram_offset1, nram_offset2,
+        nram_offset3, nram_offset4, nram_base_ptr, nram_h_low_temp,
+        num_deal_grid, num_per_time_real, num_heads, num_levels, num_points,
+        w_stride, qid_stride);
+    // once the offsets have been consumed, the ptr-offset buffers are
+    // reused as per-neighbor validity masks
+    float *mask1 = nram_h_low_ptr_offset;
+    float *mask2 = nram_h_high_ptr_offset;
+    float *mask3 = nram_w_low_ptr_offset;
+    float *mask4 = nram_w_high_ptr_offset;
+    loadValue(nram_grad_output_tl, nram_grad_output_tr, nram_grad_output_bl,
+              nram_grad_output_br, data_value, grad_output, grad_temp1,
+              grad_temp2, mask1, mask2, mask3, mask4, nram_offset1,
+              nram_offset2, nram_offset3, nram_offset4, nram_grad_weight,
+              nram_level_start_index, offset_nram, start_per_core, grid_loop,
+              num_per_time_theory, num_heads, deal_num_real, num_per_time_real,
+              num_deal_grid, num_query, num_levels, num_points, grid_offset,
+              spatial_size, qid_stride);
+    float *nram_grid_offset1 = nram_loc_h;
+    float *nram_grid_offset2 = nram_loc_w;
+    computeGradValue(
+        grad_temp1, grad_temp2, grad_temp3, grad_temp4, mask1, mask2, mask3,
+        mask4, nram_offset1, nram_offset2, nram_offset3, nram_offset4,
+        nram_level_start_index, deal_num_real, grad_value, nram_w1, nram_w2,
+        nram_w3, nram_w4, num_per_time_real, num_heads, num_levels, num_points,
+        num_query, num_deal_grid, grid_offset, spatial_size, qid_stride,
+        nram_grid_offset1, nram_grid_offset2);
+
+    // compute grad_attn_weight
+    float *grad_weight = grad_temp1;
+    float *grad_h_weight = grad_temp4;
+    float *grad_w_weight = grad_temp3;
+    computeGradAttnWeight(
+        grad_w_weight, grad_weight, nram_grad_output_tl, nram_grad_output_tr,
+        nram_grad_output_bl, nram_grad_output_br, grad_temp1, grad_temp2,
+        grad_attn_weight, nram_hw, nram_hh, nram_lw, nram_lh, grad_h_weight,
+        nram_w1, nram_w2, nram_w3, nram_w4, offset_nram, num_deal_grid,
+        deal_num_real, num_per_time_real, num_heads, num_levels, num_points,
+        grid_offset, nram_h_high_temp);
+
+    // compute grad_sampling_loc
+    computeGradSamplingLoc(grad_sampling_loc, nram_grad_output_tl,
+                           nram_grad_output_tr, grad_h_weight, grad_w_weight,
+                           nram_spatial_shapes, grad_temp1, grad_temp2,
+                           nram_grad_weight, num_deal_grid, deal_num_real,
+                           num_per_time_real, num_heads,
num_levels, num_points, + grid_offset, nram_h_high_temp); } #endif } @@ -1739,6 +2050,7 @@ __mlu_global__ void MLUUnion1KernelMsDeformAttnBackwarDefaultKernel( const int32_t channels, const int32_t num_levels, const int32_t num_query, const int32_t num_points, float *grad_value, float *grad_sampling_loc, float *grad_attn_weight); + __mlu_global__ void MLUUnion1KernelMsDeformAttnBackwardSmallChannelsKernel( const float *data_value, const int32_t *spatial_shapes, const int32_t *data_level_start_index, const float *data_sampling_loc, diff --git a/mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp index 845465ae4..f8e884d97 100644 --- a/mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp +++ b/mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp @@ -47,17 +47,17 @@ typedef enum { } MsDeformAttnBackwardKernelPolicy; MsDeformAttnBackwardKernelPolicy msDeformAttnBackwardPolicyFunc( - const int32_t channels, const int32_t num_levels, - const int32_t num_points) { + const int32_t channels, const int32_t num_levels, const int32_t num_points, + const int32_t num_heads) { const int32_t nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore); - const uint64_t max_num = nram_size / sizeof(float); - const uint64_t deal_num = - 12 * PAD_UP(channels, 32) + 3 * PAD_UP(num_levels, 32) + 3 * num_points; - - if (max_num >= deal_num) { + const int num_hlp = num_heads * num_levels * num_points; + int num_per_time_theory = (nram_size - num_levels * sizeof(float) - + 3 * num_levels * sizeof(int32_t)) / + sizeof(float) / (8 * PAD_UP(channels, 32) + 28) / + PAD_UP((num_hlp), 32); + if (num_per_time_theory >= 1) { return MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL; } - return MS_DEFORM_ATTN_BACKWARD_DEFAULT; } @@ -101,7 +101,8 @@ MsDeformAttnForwardPolicy msDeformAttnForwardPolicyFunc( int32_t nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore); if (num_levels * num_points * 3 * sizeof(int32_t) > nram_size) { return MS_DEFORM_ATTN_FORWARD_DEFAULT; - } else if (channels > nram_size / 12 / sizeof(float)) { + } else if (channels > nram_size / 12 / sizeof(float) || channels > 96 || + channels < 16) { return MS_DEFORM_ATTN_FORWARD_DEFAULT; } else { return MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL; @@ -472,7 +473,8 @@ void ms_deform_attn_mlu_backward( CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnBackward<<<" << k_dim.x << ", " << k_dim.y << ", " << k_dim.z << ">>>"; MsDeformAttnBackwardKernelPolicy kernelPolicy = - msDeformAttnBackwardPolicyFunc(channels, num_levels, num_points); + msDeformAttnBackwardPolicyFunc(channels, num_levels, num_points, + num_heads); switch (kernelPolicy) { default: { VLOG(5) << "NotImplemented.";