// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"

// Transposes [batch_size, num_head, max_len_this_time, head_dim] into
// [token_num, num_head, head_dim], dropping padded positions.
//
// Grid-stride loop over `elem_cnt` elements in packs of VecSize; any 1-D
// grid/block launch is valid. Each output element's flat index is mapped back
// through `padding_offset` to its (batch, seq, head, lane) location in the
// padded input.
//
// Preconditions:
//   - elem_cnt is a multiple of VecSize and input/output pointers are aligned
//     for AlignedVector<T, VecSize> accesses (see launcher below).
//   - padding_offset may be nullptr, in which case token ordering is assumed
//     already unpadded (offset 0 for every token).
//   - seq_lens may be nullptr; when present, batches with seq_lens[b] == 0
//     are skipped entirely.
template <typename T, int VecSize>
__global__ void TransposeRemovingPadding(const T* input_data,
                                         const int* seq_lens,
                                         T* output_data,
                                         const int batch_size,
                                         const int num_head,
                                         const int max_len_this_time,
                                         const int seq_len,
                                         const int head_dim,
                                         const int token_num,
                                         const int elem_cnt,
                                         const int* padding_offset) {
  // transpose and remove padding
  // [batch_size, num_head, max_len_this_time, head_dim] ->
  // [token_num, num_head, head_dim]
  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
  const int dim_embed = num_head * head_dim;
  using LoadT = AlignedVector<T, VecSize>;
  LoadT src_vec;

  for (int32_t linear_index = idx * VecSize,
               step = gridDim.x * blockDim.x * VecSize;
       linear_index < elem_cnt;
       linear_index += step) {
    const int token_idx = linear_index / dim_embed;
    // Map the compact token index back to its padded position.
    const int ori_token_idx =
        token_idx + (padding_offset == nullptr ? 0 : padding_offset[token_idx]);
    const int ori_batch_id = ori_token_idx / seq_len;
    if (seq_lens && seq_lens[ori_batch_id] == 0) continue;
    const int ori_seq_id = ori_token_idx % seq_len;
    const int ori_head_id = (linear_index % dim_embed) / head_dim;
    const int ori_head_lane = (linear_index % dim_embed) % head_dim;
    // Source layout is [batch, head, max_len_this_time, head_dim].
    const int ori_idx = ori_batch_id * num_head * max_len_this_time * head_dim +
                        ori_head_id * max_len_this_time * head_dim +
                        ori_seq_id * head_dim + ori_head_lane;
    Load<T, VecSize>(&input_data[ori_idx], &src_vec);
    Store<T, VecSize>(src_vec, &output_data[linear_index]);
  }
}

// Host-side launcher for TransposeRemovingPadding.
//
// Chooses a vector width so each thread moves 16 bytes per pack, then launches
// enough 128-thread blocks to cover every pack once (the kernel's grid-stride
// loop also tolerates smaller grids).
//
// NOTE(review): assumes token_num * num_head * head_dim is a multiple of
// PackSize (i.e. head_dim * sizeof(T) is a multiple of 16 bytes) — the tail
// elements would otherwise be skipped; confirm against callers.
template <typename T>
void InvokeTransposeRemovePadding(const T* input_data,
                                  const int* seq_lens,
                                  T* output_data,
                                  const int batch_size,
                                  const int num_head,
                                  const int max_len_this_time,
                                  const int seq_len,
                                  const int head_dim,
                                  const int token_num,
                                  const int* padding_offset,
                                  cudaStream_t cu_stream) {
  // [batch_size, num_head, max_len_this_time, head_dim] ->
  // [token_num, num_head, head_dim]
  constexpr int VEC_16B = 16;
  const int elem_cnt = token_num * num_head * head_dim;
  constexpr int PackSize = VEC_16B / sizeof(T);
  const int32_t pack_num = elem_cnt / PackSize;
  const int32_t block_size = 128;
  int32_t grid_size = (pack_num + block_size - 1) / block_size;  // ceil-div
  TransposeRemovingPadding<T, PackSize>
      <<<grid_size, block_size, 0, cu_stream>>>(input_data,
                                                seq_lens,
                                                output_data,
                                                batch_size,
                                                num_head,
                                                max_len_this_time,
                                                seq_len,
                                                head_dim,
                                                token_num,
                                                elem_cnt,
                                                padding_offset);
}

// Typed implementation of the op: allocates the compact output tensor of
// shape [token_num, num_head * dim_head] and runs the transpose-and-unpad
// kernel on the input tensor's stream.
template <paddle::DataType D>
std::vector<paddle::Tensor> apply_transpose_remove_padding(
    const paddle::Tensor& input,
    const paddle::Tensor& seq_lens,
    const paddle::Tensor& padding_offset) {
  typedef PDTraits<D> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  auto cu_stream = input.stream();
  std::vector<int64_t> input_shape = input.shape();
  const int bsz = input_shape[0];
  const int num_head = input_shape[1];
  const int seq_len = input_shape[2];
  const int dim_head = input_shape[3];
  const int token_num = padding_offset.shape()[0];

  auto out = paddle::full({token_num, num_head * dim_head},
                          0,
                          input.dtype(),
                          input.place());
  // seq_len is passed for both max_len_this_time and seq_len: the padded
  // input here has max_len_this_time == seq_len.
  InvokeTransposeRemovePadding<DataType_>(
      reinterpret_cast<DataType_*>(const_cast<data_t*>(input.data<data_t>())),
      seq_lens.data<int>(),
      reinterpret_cast<DataType_*>(out.data<data_t>()),
      bsz,
      num_head,
      seq_len,
      seq_len,
      dim_head,
      token_num,
      padding_offset.data<int>(),
      cu_stream);
  return {out};
}

// Dtype dispatcher: routes to the typed implementation for the supported
// floating-point dtypes and throws for anything else.
std::vector<paddle::Tensor> ApplyTransposeRemovingPadding(
    const paddle::Tensor& input,
    const paddle::Tensor& seq_lens,
    const paddle::Tensor& padding_offset) {
  switch (input.type()) {
    case paddle::DataType::BFLOAT16: {
      return apply_transpose_remove_padding<paddle::DataType::BFLOAT16>(
          input, seq_lens, padding_offset);
    }
    case paddle::DataType::FLOAT16: {
      return apply_transpose_remove_padding<paddle::DataType::FLOAT16>(
          input, seq_lens, padding_offset);
    }
    case paddle::DataType::FLOAT32: {
      return apply_transpose_remove_padding<paddle::DataType::FLOAT32>(
          input, seq_lens, padding_offset);
    }
    default: {
      PD_THROW(
          "NOT supported data type. "
          "Only float16, bfloat16 and float32 are supported. ");
      break;
    }
  }
}

// Output shape: [token_num, num_head * head_dim].
std::vector<std::vector<int64_t>> ApplyTransposeRemovingPaddingInferShape(
    const std::vector<int64_t>& input_shape,
    const std::vector<int64_t>& seq_lens_shape,
    const std::vector<int64_t>& padding_offset_shape) {
  return {{padding_offset_shape[0], input_shape[1] * input_shape[3]}};
}

// Output dtype matches the input dtype.
std::vector<paddle::DataType> ApplyTransposeRemovingPaddingInferDtype(
    const paddle::DataType& input_dtype,
    const paddle::DataType& seq_lens_dtype,
    const paddle::DataType& padding_offset_dtype) {
  return {input_dtype};
}

PD_BUILD_OP(transpose_remove_padding)
    .Inputs({"input", "seq_lens", "padding_offset"})
    .Outputs({"fmha_out"})
    .SetKernelFn(PD_KERNEL(ApplyTransposeRemovingPadding))
    .SetInferShapeFn(PD_INFER_SHAPE(ApplyTransposeRemovingPaddingInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(ApplyTransposeRemovingPaddingInferDtype));