// Source: PaddleClas/csrc/generation/transpose_remove_padding.cu
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cassert>

#include "helper.h"
template <typename T, int VecSize>
__global__ void TransposeRemovingPadding(const T* input_data,
                                         const int* seq_lens,
                                         T* output_data,
                                         const int batch_size,
                                         const int num_head,
                                         const int max_len_this_time,
                                         const int seq_len,
                                         const int head_dim,
                                         const int token_num,
                                         const int elem_cnt,
                                         const int* padding_offset)
{
  // Transposes [batch_size, num_head, max_len_this_time, head_dim] into the
  // padding-free layout [token_num, num_head, head_dim], dropping padded
  // positions. Each thread moves VecSize contiguous elements per iteration
  // via a grid-stride loop over the output index space.
  const int hidden_size = num_head * head_dim;
  const int64_t global_tid = blockDim.x * blockIdx.x + threadIdx.x;
  const int32_t grid_stride = gridDim.x * blockDim.x * VecSize;
  using LoadT = AlignedVector<T, VecSize>;
  LoadT vec;
  for (int32_t out_offset = global_tid * VecSize; out_offset < elem_cnt;
       out_offset += grid_stride) {
    const int token_idx = out_offset / hidden_size;
    // Map the compacted token index back to its position in the padded input
    // (padding_offset == nullptr means the input has no removed padding).
    const int padded_token_idx =
        token_idx + (padding_offset == nullptr ? 0 : padding_offset[token_idx]);
    const int batch_id = padded_token_idx / seq_len;
    // Batches with zero sequence length contribute no output.
    if (seq_lens && seq_lens[batch_id] == 0) continue;
    const int seq_id = padded_token_idx % seq_len;
    const int head_id = (out_offset % hidden_size) / head_dim;
    const int dim_id = (out_offset % hidden_size) % head_dim;
    // Source offset in the [batch, head, seq, dim] layout.
    const int in_offset =
        ((batch_id * num_head + head_id) * max_len_this_time + seq_id) *
            head_dim +
        dim_id;
    Load<T, VecSize>(&input_data[in_offset], &vec);
    Store<T, VecSize>(vec, &output_data[out_offset]);
  }
}
template <typename T>
void InvokeTransposeRemovePadding(const T* input_data,
                                  const int* seq_lens,
                                  T* output_data,
                                  const int batch_size,
                                  const int num_head,
                                  const int max_len_this_time,
                                  const int seq_len,
                                  const int head_dim,
                                  const int token_num,
                                  const int* padding_offset,
                                  cudaStream_t cu_stream) {
  // Launches TransposeRemovingPadding on cu_stream:
  // [batch_size, num_head, max_len_this_time, head_dim]
  //   -> [token_num, num_head, head_dim].
  //
  // The kernel uses 16-byte vectorized loads/stores (PackSize elements per
  // thread per step), so head_dim must be divisible by PackSize; otherwise a
  // vector would straddle a row boundary of the padded input and read wrong
  // data.
  constexpr int VEC_16B = 16;
  static_assert(VEC_16B % sizeof(T) == 0,
                "sizeof(T) must divide the 16-byte vector width");
  constexpr int PackSize = VEC_16B / sizeof(T);
  assert(head_dim % PackSize == 0);
  const int elem_cnt = token_num * num_head * head_dim;
  // Nothing to copy (e.g. all sequences empty): skip the launch entirely
  // instead of issuing one with grid_size == 0, which is invalid.
  if (elem_cnt <= 0) return;
  const int32_t pack_num = elem_cnt / PackSize;
  const int32_t block_size = 128;
  // Ceil-divide so the tail pack gets a thread; clamp to at least one block
  // so the launch configuration stays valid for tiny workloads.
  int32_t grid_size = (pack_num + block_size - 1) / block_size;
  if (grid_size < 1) grid_size = 1;
  TransposeRemovingPadding<T, PackSize>
      <<<grid_size, block_size, 0, cu_stream>>>(input_data,
                                                seq_lens,
                                                output_data,
                                                batch_size,
                                                num_head,
                                                max_len_this_time,
                                                seq_len,
                                                head_dim,
                                                token_num,
                                                elem_cnt,
                                                padding_offset);
}
template <paddle::DataType D>
std::vector<paddle::Tensor> apply_transpose_remove_padding(const paddle::Tensor& input,
                                                           const paddle::Tensor& seq_lens,
                                                           const paddle::Tensor& padding_offset) {
  // Typed implementation: compacts the padded attention output
  // [bsz, num_head, seq_len, dim_head] into [token_num, num_head * dim_head].
  typedef PDTraits<D> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;
  auto cu_stream = input.stream();
  std::vector<int64_t> dims = input.shape();
  const int bsz = dims[0];
  const int num_head = dims[1];
  const int seq_len = dims[2];
  const int dim_head = dims[3];
  const int token_num = padding_offset.shape()[0];
  // Zero-initialize so output rows the kernel skips (zero-length sequences)
  // hold zeros rather than garbage.
  auto out = paddle::full({token_num, num_head * dim_head}, 0, input.dtype(), input.place());
  auto* src =
      reinterpret_cast<DataType_*>(const_cast<data_t*>(input.data<data_t>()));
  auto* dst = reinterpret_cast<DataType_*>(out.data<data_t>());
  // seq_len is passed for both max_len_this_time and seq_len: the input here
  // is padded to the full sequence length.
  InvokeTransposeRemovePadding(src,
                               seq_lens.data<int>(),
                               dst,
                               bsz,
                               num_head,
                               seq_len,
                               seq_len,
                               dim_head,
                               token_num,
                               padding_offset.data<int>(),
                               cu_stream);
  return {out};
}
std::vector<paddle::Tensor> ApplyTransposeRemovingPadding(const paddle::Tensor& input,
                                                          const paddle::Tensor& seq_lens,
                                                          const paddle::Tensor& padding_offset) {
  // Dtype dispatcher for the transpose_remove_padding op: routes to the
  // typed implementation for bfloat16/float16/float32, rejects anything else.
  const auto dtype = input.type();
  if (dtype == paddle::DataType::BFLOAT16) {
    return apply_transpose_remove_padding<paddle::DataType::BFLOAT16>(
        input, seq_lens, padding_offset);
  }
  if (dtype == paddle::DataType::FLOAT16) {
    return apply_transpose_remove_padding<paddle::DataType::FLOAT16>(
        input, seq_lens, padding_offset);
  }
  if (dtype == paddle::DataType::FLOAT32) {
    return apply_transpose_remove_padding<paddle::DataType::FLOAT32>(
        input, seq_lens, padding_offset);
  }
  PD_THROW(
      "NOT supported data type. "
      "Only float16, bfloat16 and float32 are supported. ");
}
std::vector<std::vector<int64_t>> ApplyTransposeRemovingPaddingInferShape(
    const std::vector<int64_t>& input_shape,
    const std::vector<int64_t>& seq_lens_shape,
    const std::vector<int64_t>& padding_offset_shape) {
  // Output shape is [token_num, num_head * dim_head]: token count comes from
  // padding_offset, hidden size from input dims 1 and 3.
  const int64_t token_num = padding_offset_shape[0];
  const int64_t hidden_size = input_shape[1] * input_shape[3];
  return {{token_num, hidden_size}};
}
std::vector<paddle::DataType> ApplyTransposeRemovingPaddingInferDtype(
    const paddle::DataType& input_dtype,
    const paddle::DataType& seq_lens_dtype,
    const paddle::DataType& padding_offset_dtype) {
  // The compacted output keeps the input tensor's dtype; the other inputs'
  // dtypes do not affect it.
  (void)seq_lens_dtype;
  (void)padding_offset_dtype;
  return {input_dtype};
}
// Register the custom operator with Paddle. Output "fmha_out" is the
// padding-free result, shaped [token_num, num_head * dim_head] per the
// infer-shape function above, with the same dtype as "input".
PD_BUILD_OP(transpose_remove_padding)
.Inputs({"input", "seq_lens", "padding_offset"})
.Outputs({"fmha_out"})
.SetKernelFn(PD_KERNEL(ApplyTransposeRemovingPadding))
.SetInferShapeFn(PD_INFER_SHAPE(ApplyTransposeRemovingPaddingInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(ApplyTransposeRemovingPaddingInferDtype));