feat: add fused vision transformer (#3034)
* add paddleclas_ops csrc
* add fused_vit
* remove "wint4", not supported
* add wint4
* add fused_vit README.md
* fit static graph
* fix layernorm dtype
* add static graph infer & performance data
* update weight only scale dtype
* update readme
* update readme fused_vit
parent 178eede1db
commit 1c7c773854
@@ -0,0 +1,15 @@
# PaddleClas Custom OPs

This document describes how to compile and install the PaddleClas custom OPs.

## Install pip dependencies

```shell
pip install -r requirements.txt
```

## Compile the CUDA operators

```shell
python setup_cuda.py install
```
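After the build, a quick import check confirms the extension is usable; a minimal sketch that only assumes the two ops registered by the sources below:

```python
# Smoke test for the freshly built extension.
from paddleclas_ops import qkv_transpose_split, transpose_remove_padding

print("paddleclas_ops loaded:", qkv_transpose_split, transpose_remove_padding)
```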
@@ -0,0 +1,103 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/extension.h"
#include <cub/cub.cuh>
#include <curand_kernel.h>

constexpr int kBlockSize = 256;
constexpr int kNumWaves = 16;

// Pick enough blocks to cover `n` work items at kBlockSize threads each, but cap
// the grid at kNumWaves "waves" of resident blocks per SM so very large inputs
// do not oversubscribe the device.
inline cudaError_t GetNumBlocks(int64_t n, int* num_blocks) {
  int dev;
  {
    cudaError_t err = cudaGetDevice(&dev);
    if (err != cudaSuccess) { return err; }
  }
  int sm_count;
  {
    cudaError_t err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);
    if (err != cudaSuccess) { return err; }
  }
  int tpm;
  {
    cudaError_t err = cudaDeviceGetAttribute(&tpm, cudaDevAttrMaxThreadsPerMultiProcessor, dev);
    if (err != cudaSuccess) { return err; }
  }
  *num_blocks = std::max<int>(1, std::min<int64_t>((n + kBlockSize - 1) / kBlockSize,
                                                   sm_count * tpm / kBlockSize * kNumWaves));
  return cudaSuccess;
}

template<typename T>
__device__ T max_func(const T a, const T b) {
  return a > b ? a : b;
}

template<typename T>
struct MaxOp {
  __device__ __forceinline__ T operator()(const T& a, const T& b) const {
    return max_func(a, b);
  }
};

template <paddle::DataType D>
class PDTraits;

template <>
class PDTraits<paddle::DataType::FLOAT32> {
public:
  typedef float DataType;
  typedef float data_t;
};

template <>
class PDTraits<paddle::DataType::FLOAT16> {
public:
  typedef half DataType;
  typedef paddle::float16 data_t;
};

template <>
class PDTraits<paddle::DataType::BFLOAT16> {
public:
  typedef __nv_bfloat16 DataType;
  typedef paddle::bfloat16 data_t;
};

// A size/alignment-matched wrapper so kernels can move `Size` elements with a
// single vectorized load/store (16 bytes when Size == VEC_16B / sizeof(T)).
template <typename T, int Size>
struct alignas(sizeof(T) * Size) AlignedVector {
  T val[Size];

  HOSTDEVICE inline const T& operator[](int i) const { return val[i]; }
  HOSTDEVICE inline T& operator[](int i) { return val[i]; }
};

template <typename T, int Size>
HOSTDEVICE inline void Load(const T* addr, AlignedVector<T, Size>* vec) {
  const AlignedVector<T, Size>* addr_vec =
      reinterpret_cast<const AlignedVector<T, Size>*>(addr);
  *vec = *addr_vec;
}

template <typename T, int Size>
HOSTDEVICE inline void Store(const AlignedVector<T, Size>& vec, T* addr) {
  AlignedVector<T, Size>* addr_vec =
      reinterpret_cast<AlignedVector<T, Size>*>(addr);
  *addr_vec = vec;
}

constexpr int VEC_16B = 16;
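For orientation, these helpers compose into the launch pattern both kernels below follow; a minimal sketch (a hypothetical `CopyKernel`, not part of this PR, which assumes `elem_cnt` is a multiple of `VecSize`, just like the real kernels):

```cuda
// Hypothetical grid-stride, vectorized copy built from the helpers above.
template <typename T, int VecSize>
__global__ void CopyKernel(const T* in, T* out, int64_t elem_cnt) {
  using LoadT = AlignedVector<T, VecSize>;
  LoadT vec;
  for (int64_t i = (static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x) * VecSize;
       i < elem_cnt;
       i += static_cast<int64_t>(gridDim.x) * blockDim.x * VecSize) {
    Load<T, VecSize>(&in[i], &vec);   // one 16-byte transaction instead of VecSize scalar loads
    Store<T, VecSize>(vec, &out[i]);
  }
}

template <typename T>
void LaunchCopy(const T* in, T* out, int64_t elem_cnt, cudaStream_t stream) {
  constexpr int VecSize = VEC_16B / sizeof(T);    // 8 halves or 4 floats per vector
  int num_blocks = 1;
  GetNumBlocks(elem_cnt / VecSize, &num_blocks);  // grid capped at kNumWaves waves per SM
  CopyKernel<T, VecSize><<<num_blocks, kBlockSize, 0, stream>>>(in, out, elem_cnt);
}
```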
@@ -0,0 +1,193 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"

template <typename T, int VecSize>
__global__ void fusedQKV_transpose_split_kernel(
    T *q_buf,
    T *k_buf,
    T *v_buf,
    const T *qkv,
    const int *padding_offset,
    const int *seq_lens,
    const int32_t elem_cnt,
    const int batch_size,
    const int max_len_this_time,
    const int seq_len,
    const int token_num,
    const int head_num,
    const int size_per_head) {
  const int32_t offset = batch_size * max_len_this_time * head_num * size_per_head;
  const int32_t hidden_size = head_num * size_per_head;
  const int32_t fused_hidden_size = 3 * hidden_size;
  int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
  using LoadT = AlignedVector<T, VecSize>;
  LoadT src_vec;
  LoadT bias_vec;

  for (int32_t linear_index = global_thread_idx * VecSize,
               step = gridDim.x * blockDim.x * VecSize;
       linear_index < elem_cnt;
       linear_index += step) {
    Load<T, VecSize>(&qkv[linear_index], &src_vec);
    int32_t bias_idx = linear_index % fused_hidden_size;
    const int32_t token_idx = linear_index / fused_hidden_size;
    const int32_t ori_token_idx =
        token_idx + (padding_offset == nullptr ? 0 : padding_offset[token_idx]);
    const int32_t target_batch_id = ori_token_idx / seq_len;
    if (seq_lens[target_batch_id] == 0) continue;
    const int32_t seq_id = ori_token_idx % seq_len;

    // equal to:
    // const int qkv_id = (linear_index % fused_hidden_size) / hidden_size;
    const int32_t qkv_id = bias_idx / hidden_size;
    const int32_t head_id = (linear_index % hidden_size) / size_per_head;
    const int32_t size_id = linear_index % size_per_head;

    if (qkv_id == 0) {
      Store<T, VecSize>(
          src_vec,
          &q_buf[target_batch_id * head_num * max_len_this_time * size_per_head +
                 head_id * max_len_this_time * size_per_head + seq_id * size_per_head +
                 size_id]);
    } else if (qkv_id == 1) {
      Store<T, VecSize>(
          src_vec,
          &k_buf[target_batch_id * head_num * max_len_this_time * size_per_head +
                 head_id * max_len_this_time * size_per_head + seq_id * size_per_head +
                 size_id]);
    } else {
      Store<T, VecSize>(
          src_vec,
          &v_buf[target_batch_id * head_num * max_len_this_time * size_per_head +
                 head_id * max_len_this_time * size_per_head + seq_id * size_per_head +
                 size_id]);
    }
  }
}

template <paddle::DataType D>
std::vector<paddle::Tensor> qkv_transpose_split(const paddle::Tensor& qkv, // [token_num, dim_embed]
                                                const paddle::Tensor& padding_offset, // [token_num]
                                                const paddle::Tensor& seq_lens,
                                                const paddle::Tensor& input_ids,
                                                int num_head,
                                                int head_size) {
  typedef PDTraits<D> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  auto cu_stream = qkv.stream();
  std::vector<int64_t> qkv_shape = qkv.shape();
  const int token_num = qkv_shape[0];
  const int bsz = seq_lens.shape()[0];
  const int max_seq_len = input_ids.shape()[1]; //max_seq_len_tensor.copy_to(paddle::CPUPlace(), false).data<int>()[0];
  auto q_out = paddle::full({bsz, num_head, max_seq_len, head_size}, 0, qkv.dtype(), qkv.place());
  auto k_out = paddle::full({bsz, num_head, max_seq_len, head_size}, 0, qkv.dtype(), qkv.place());
  auto v_out = paddle::full({bsz, num_head, max_seq_len, head_size}, 0, qkv.dtype(), qkv.place());
  constexpr int PackSize = VEC_16B / sizeof(DataType_);
  const int elem_cnt = token_num * num_head * head_size * 3;
  const int pack_num = elem_cnt / PackSize;
  const int blocksize = 128;
  const int grid_size = (pack_num + blocksize - 1) / blocksize;
  fusedQKV_transpose_split_kernel<DataType_, PackSize>
      <<<grid_size, blocksize, 0, qkv.stream()>>>(
          reinterpret_cast<DataType_*>(q_out.data<data_t>()),
          reinterpret_cast<DataType_*>(k_out.data<data_t>()),
          reinterpret_cast<DataType_*>(v_out.data<data_t>()),
          reinterpret_cast<DataType_*>(const_cast<data_t*>(qkv.data<data_t>())),
          padding_offset.data<int>(),
          seq_lens.data<int>(),
          elem_cnt,
          bsz,
          max_seq_len,
          max_seq_len,
          token_num,
          num_head,
          head_size);
  return {q_out, k_out, v_out};
}

std::vector<paddle::Tensor> QKVTransposeSplit(const paddle::Tensor& qkv,
                                              const paddle::Tensor& padding_offset,
                                              const paddle::Tensor& seq_lens,
                                              const paddle::Tensor& input_ids,
                                              int num_head,
                                              int head_size) {
  switch (qkv.type()) {
    case paddle::DataType::BFLOAT16: {
      return qkv_transpose_split<paddle::DataType::BFLOAT16>(
          qkv,
          padding_offset,
          seq_lens,
          input_ids,
          num_head,
          head_size
      );
    }
    case paddle::DataType::FLOAT16: {
      return qkv_transpose_split<paddle::DataType::FLOAT16>(
          qkv,
          padding_offset,
          seq_lens,
          input_ids,
          num_head,
          head_size
      );
    }
    case paddle::DataType::FLOAT32: {
      return qkv_transpose_split<paddle::DataType::FLOAT32>(
          qkv,
          padding_offset,
          seq_lens,
          input_ids,
          num_head,
          head_size
      );
    }
    default: {
      PD_THROW(
          "NOT supported data type. "
          "Only float16, bfloat16 and float32 are supported. ");
      break;
    }
  }
}

std::vector<std::vector<int64_t>> QKVTransposeSplitInferShape(const std::vector<int64_t>& qkv_shape,
                                                              const std::vector<int64_t>& padding_offset_shape,
                                                              const std::vector<int64_t>& seq_lens_shape,
                                                              const std::vector<int64_t>& input_ids_shape,
                                                              int num_head,
                                                              int head_size) {
  int64_t bsz = seq_lens_shape[0];
  return {{bsz, num_head, -1, head_size}, {bsz, num_head, -1, head_size}, {bsz, num_head, -1, head_size}};
}

std::vector<paddle::DataType> QKVTransposeSplitInferDtype(const paddle::DataType& qkv_dtype,
                                                          const paddle::DataType& padding_offset_dtype,
                                                          const paddle::DataType& seq_lens_dtype,
                                                          const paddle::DataType& input_ids_dtype) {
  return {qkv_dtype, qkv_dtype, qkv_dtype};
}

PD_BUILD_OP(qkv_transpose_split)
    .Inputs({"qkv", "padding_offset", "seq_lens", "input_ids"})
    .Outputs({"q_out", "k_out", "v_out"})
    .Attrs({"num_head: int",
            "head_size: int"})
    .SetKernelFn(PD_KERNEL(QKVTransposeSplit))
    .SetInferShapeFn(PD_INFER_SHAPE(QKVTransposeSplitInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(QKVTransposeSplitInferDtype));
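For reference, the Python-side contract of this op in the dense (unpadded) case the fused model uses; a sketch with placeholder shapes, assuming the extension is installed and a GPU is available:

```python
import paddle
from paddleclas_ops import qkv_transpose_split

bsz, seq_len, num_head, head_size = 2, 145, 12, 64
token_num = bsz * seq_len

qkv = paddle.randn([token_num, 3 * num_head * head_size], dtype="float16")
padding_offset = paddle.zeros([token_num], dtype="int32")  # dense batch: no padding
seq_lens = paddle.full([bsz], seq_len, dtype="int32")
input_ids = paddle.zeros([bsz, seq_len], dtype="int32")    # only shape[1] is read (max_seq_len)

q, k, v = qkv_transpose_split(qkv, padding_offset, seq_lens, input_ids, num_head, head_size)
print(q.shape)  # [2, 12, 145, 64] == [bsz, num_head, max_seq_len, head_size]
```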
@@ -0,0 +1,177 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"

template <typename T, int VecSize>
__global__ void TransposeRemovingPadding(const T* input_data,
                                         const int* seq_lens,
                                         T* output_data,
                                         const int batch_size,
                                         const int num_head,
                                         const int max_len_this_time,
                                         const int seq_len,
                                         const int head_dim,
                                         const int token_num,
                                         const int elem_cnt,
                                         const int* padding_offset) {
  // transpose and remove padding
  // [batch_size, num_head, max_len_this_time, head_dim] -> [token_num, num_head,
  // head_dim]
  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
  const int dim_embed = num_head * head_dim;
  using LoadT = AlignedVector<T, VecSize>;
  LoadT src_vec;

  for (int32_t linear_index = idx * VecSize,
               step = gridDim.x * blockDim.x * VecSize;
       linear_index < elem_cnt;
       linear_index += step) {
    const int token_idx = linear_index / dim_embed;
    const int ori_token_idx =
        token_idx + (padding_offset == nullptr ? 0 : padding_offset[token_idx]);
    const int ori_batch_id = ori_token_idx / seq_len;
    if (seq_lens && seq_lens[ori_batch_id] == 0) continue;
    const int ori_seq_id = ori_token_idx % seq_len;
    const int ori_head_id = (linear_index % dim_embed) / head_dim;
    const int ori_head_lane = (linear_index % dim_embed) % head_dim;
    const int ori_idx = ori_batch_id * num_head * max_len_this_time * head_dim +
                        ori_head_id * max_len_this_time * head_dim +
                        ori_seq_id * head_dim + ori_head_lane;
    Load<T, VecSize>(&input_data[ori_idx], &src_vec);
    Store<T, VecSize>(src_vec, &output_data[linear_index]);
  }
}

template <typename T>
void InvokeTransposeRemovePadding(const T* input_data,
                                  const int* seq_lens,
                                  T* output_data,
                                  const int batch_size,
                                  const int num_head,
                                  const int max_len_this_time,
                                  const int seq_len,
                                  const int head_dim,
                                  const int token_num,
                                  const int* padding_offset,
                                  cudaStream_t cu_stream) {
  // [batch_size, num_head, max_len_this_time, head_dim] -> [token_num, num_head,
  // head_dim]
  constexpr int VEC_16B = 16;
  const int elem_cnt = token_num * num_head * head_dim;
  constexpr int PackSize = VEC_16B / sizeof(T);
  const int32_t pack_num = elem_cnt / PackSize;
  const int32_t block_size = 128;
  int32_t grid_size = (pack_num + block_size - 1) / block_size;
  TransposeRemovingPadding<T, PackSize>
      <<<grid_size, block_size, 0, cu_stream>>>(input_data,
                                                seq_lens,
                                                output_data,
                                                batch_size,
                                                num_head,
                                                max_len_this_time,
                                                seq_len,
                                                head_dim,
                                                token_num,
                                                elem_cnt,
                                                padding_offset);
}

template <paddle::DataType D>
std::vector<paddle::Tensor> apply_transpose_remove_padding(const paddle::Tensor& input,
                                                           const paddle::Tensor& seq_lens,
                                                           const paddle::Tensor& padding_offset) {
  typedef PDTraits<D> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  auto cu_stream = input.stream();
  std::vector<int64_t> input_shape = input.shape();
  const int bsz = input_shape[0];
  const int num_head = input_shape[1];
  const int seq_len = input_shape[2];
  const int dim_head = input_shape[3];
  const int token_num = padding_offset.shape()[0];

  auto out = paddle::full({token_num, num_head * dim_head}, 0, input.dtype(), input.place());
  InvokeTransposeRemovePadding(
      reinterpret_cast<DataType_*>(const_cast<data_t*>(input.data<data_t>())),
      seq_lens.data<int>(),
      reinterpret_cast<DataType_*>(out.data<data_t>()),
      bsz,
      num_head,
      seq_len,
      seq_len,
      dim_head,
      token_num,
      padding_offset.data<int>(),
      cu_stream
  );
  return {out};
}

std::vector<paddle::Tensor> ApplyTransposeRemovingPadding(const paddle::Tensor& input,
                                                          const paddle::Tensor& seq_lens,
                                                          const paddle::Tensor& padding_offset) {
  switch (input.type()) {
    case paddle::DataType::BFLOAT16: {
      return apply_transpose_remove_padding<paddle::DataType::BFLOAT16>(
          input,
          seq_lens,
          padding_offset
      );
    }
    case paddle::DataType::FLOAT16: {
      return apply_transpose_remove_padding<paddle::DataType::FLOAT16>(
          input,
          seq_lens,
          padding_offset
      );
    }
    case paddle::DataType::FLOAT32: {
      return apply_transpose_remove_padding<paddle::DataType::FLOAT32>(
          input,
          seq_lens,
          padding_offset
      );
    }
    default: {
      PD_THROW(
          "NOT supported data type. "
          "Only float16, bfloat16 and float32 are supported. ");
      break;
    }
  }
}

std::vector<std::vector<int64_t>> ApplyTransposeRemovingPaddingInferShape(
    const std::vector<int64_t>& input_shape,
    const std::vector<int64_t>& seq_lens_shape,
    const std::vector<int64_t>& padding_offset_shape) {
  return {{padding_offset_shape[0], input_shape[1] * input_shape[3]}};
}

std::vector<paddle::DataType> ApplyTransposeRemovingPaddingInferDtype(
    const paddle::DataType& input_dtype,
    const paddle::DataType& seq_lens_dtype,
    const paddle::DataType& padding_offset_dtype) {
  return {input_dtype};
}

PD_BUILD_OP(transpose_remove_padding)
    .Inputs({"input", "seq_lens", "padding_offset"})
    .Outputs({"fmha_out"})
    .SetKernelFn(PD_KERNEL(ApplyTransposeRemovingPadding))
    .SetInferShapeFn(PD_INFER_SHAPE(ApplyTransposeRemovingPaddingInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(ApplyTransposeRemovingPaddingInferDtype));
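This op has the inverse contract of `qkv_transpose_split`; again a dense-case sketch with placeholder shapes:

```python
import paddle
from paddleclas_ops import transpose_remove_padding

bsz, num_head, seq_len, head_dim = 2, 12, 145, 64
token_num = bsz * seq_len

fmha_out = paddle.randn([bsz, num_head, seq_len, head_dim], dtype="float16")
seq_lens = paddle.full([bsz], seq_len, dtype="int32")
padding_offset = paddle.zeros([token_num], dtype="int32")

out = transpose_remove_padding(fmha_out, seq_lens, padding_offset)
print(out.shape)  # [290, 768] == [token_num, num_head * head_dim]
```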
@@ -0,0 +1,2 @@
cupy-cuda116
pybind11
@@ -0,0 +1,25 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.utils.cpp_extension import CUDAExtension, setup

setup(
    name="paddleclas_ops",
    ext_modules=CUDAExtension(
        sources=[
            "./generation/transpose_remove_padding.cu",
            "./generation/qkv_transpose_split.cu",
        ]
    ),
)
@@ -0,0 +1,331 @@
# Fused Vision Transformer High-Performance Inference

PaddleClas now includes a high-performance inference implementation with the following support matrix:

| Model | FP16 | Wint8 | Wint4 | PTQ |
|-------------------------------------------------------------------------------------------------|------|-------|-------|-----|
| [Fused Vision Transformer](../../../ppcls/arch/backbone/model_zoo/fused_vision_transformer.py) | ✅ | ✅ | ✅ | ❌ |

* The following `fused_vit` variants are supported:
  * `Fused_ViT_small_patch16_224`
  * `Fused_ViT_base_patch16_224`
  * `Fused_ViT_base_patch16_384`
  * `Fused_ViT_base_patch32_384`
  * `Fused_ViT_large_patch16_224`
  * `Fused_ViT_large_patch16_384`
  * `Fused_ViT_large_patch32_384`
* Pretrained weights come from the corresponding Vision Transformer checkpoints.
## Installing the custom operator library

PaddleClas ships purpose-built high-performance custom operators for the Fused Vision Transformer family, which improve model performance during inference and decoding.

```shell
cd ./PaddleClas/csrc
pip install -r requirements.txt
python setup_cuda.py install
```
## Static-graph inference

* Model export

```python
from paddleclas import (
    Fused_ViT_large_patch16_224,
    Fused_ViT_large_patch32_384
)
import paddle

if __name__ == "__main__":
    dtype = "float16"
    paddle.set_default_dtype(dtype)
    path = "/your/path/fused_384_fp16/static_model"
    model = Fused_ViT_large_patch32_384(pretrained=True, class_num=1000)
    model.eval()
    model = paddle.jit.to_static(
        model,
        input_spec=[
            paddle.static.InputSpec(
                shape=[None] + [3, 384, 384],
                dtype=dtype
            )
        ]
    )
    paddle.jit.save(model, path)
```
* Model inference

```python
from paddle.inference import create_predictor
from paddle.inference import PrecisionType
from paddle.inference import Config
from paddleclas_ops import (
    qkv_transpose_split,
    transpose_remove_padding
)
import paddle
import numpy as np

from paddleclas import (
    Fused_ViT_large_patch32_384,
)

def run(predictor, img):
    # copy img data to input tensor
    input_names = predictor.get_input_names()
    for i, name in enumerate(input_names):
        input_tensor = predictor.get_input_handle(name)
        input_tensor.reshape(img[i].shape)
        input_tensor.copy_from_cpu(img[i])

    # do the inference
    predictor.run()

    results = []
    # get out data from output tensor
    output_names = predictor.get_output_names()
    for i, name in enumerate(output_names):
        output_tensor = predictor.get_output_handle(name)
        output_data = output_tensor.copy_to_cpu()
        results.append(output_data)
    return results

def static_infer(model_file, params_file, images):
    config = Config(model_file, params_file)
    config.enable_memory_optim()
    config.enable_use_gpu(1000, 0)

    predictor = create_predictor(config)

    output = run(predictor, [images])

    return output

def main_fp16():
    dtype = "float16"
    N, C, H, W = (1, 3, 384, 384)
    images = np.random.rand(N, C, H, W).astype(dtype)

    # fp16 static infer
    model_file = "/your/path/fused_384_fp16/static_model.pdmodel"
    params_file = "/your/path/fused_384_fp16/static_model.pdiparams"
    static_fp16_output = static_infer(model_file, params_file, images)

if __name__ == "__main__":
    main_fp16()
```
## Dynamic-graph inference

### FP16

* `fused_vit` takes the `weight` dtype from `paddle.set_default_dtype`

```python
import paddle

from paddleclas import (
    Fused_ViT_large_patch32_384,
)

if __name__ == '__main__':
    dtype = "float16"
    N, C, H, W = (1, 3, 384, 384)
    images = paddle.randn([N, C, H, W]).cast(dtype)
    paddle.set_default_dtype(dtype)

    # ----- Fused Model -----
    fused_model = Fused_ViT_large_patch32_384(pretrained=True, class_num=1000)
    fused_output = fused_model(images)
    print(fused_output)
```
### Weight-Only Int8/Int4 inference

> weight-only int4 currently has accuracy issues

* Parameters:
  * `use_weight_only`: enables weight-only inference; defaults to False
  * `quant_type`: the weight-only quantization type; defaults to `weight_only_int8`, with `weight_only_int4` as an option

```python
import paddle

from paddleclas import (
    Fused_ViT_large_patch32_384,
)

if __name__ == '__main__':
    dtype = "float16"
    N, C, H, W = (1, 3, 384, 384)
    images = paddle.randn([N, C, H, W]).cast(dtype)
    paddle.set_default_dtype(dtype)

    # ----- 8-bit Quantized Model -----
    quanted_model_8 = Fused_ViT_large_patch32_384(pretrained=True, class_num=1000, use_weight_only=True)
    quanted_output_8 = quanted_model_8(images)
    print(quanted_output_8)

    # ----- 4-bit Quantized Model -----
    quanted_model_4 = Fused_ViT_large_patch32_384(pretrained=True, class_num=1000, use_weight_only=True, quant_type="weight_only_int4")
    quanted_output_4 = quanted_model_4(images)
    print(quanted_output_4)
```
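Under the hood, `use_weight_only=True` pairs `paddle.nn.quant.weight_quantize` with `weight_only_linear` for every projection (see `fused_vision_transformer.py` in this PR). A standalone sketch of that pattern, with made-up shapes:

```python
import paddle
from paddle.nn.quant import weight_quantize, weight_only_linear

w = paddle.randn([1024, 4096], dtype="float16")           # an fp16 [in, out] weight
qw, scale = weight_quantize(w, algo="weight_only_int8")   # int8 weight + per-channel scales

x = paddle.randn([16, 1024], dtype="float16")
y = weight_only_linear(x, weight=qw, weight_scale=scale, weight_dtype="int8")
print(y.shape)  # [16, 4096]
```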
## Performance data

### Benchmark code

```python
from paddle.inference import create_predictor
from paddle.inference import PrecisionType
from paddle.inference import Config
from paddleclas_ops import (
    qkv_transpose_split,
    transpose_remove_padding
)
import paddle
import numpy as np
import time

from paddleclas import (
    Fused_ViT_large_patch16_224,
    Fused_ViT_large_patch32_384,
    ViT_large_patch16_224,
    ViT_large_patch32_384,
)

paddle.seed(42)
np.random.seed(42)

warmup_time = 10
test_time = 100

def run(predictor, img):
    # copy img data to input tensor
    input_names = predictor.get_input_names()
    for i, name in enumerate(input_names):
        input_tensor = predictor.get_input_handle(name)
        input_tensor.reshape(img[i].shape)
        input_tensor.copy_from_cpu(img[i])

    # do the inference
    predictor.run()

    results = []
    # get out data from output tensor
    output_names = predictor.get_output_names()
    for i, name in enumerate(output_names):
        output_tensor = predictor.get_output_handle(name)
        output_data = output_tensor.copy_to_cpu()
        results.append(output_data)
    return results

def static_infer(model_file, params_file, images):
    config = Config(model_file, params_file)
    config.enable_memory_optim()
    config.enable_use_gpu(1000, 0)

    predictor = create_predictor(config)

    # warmup
    for i in range(warmup_time):
        result = run(predictor, [images])

    # test
    paddle.device.cuda.synchronize()
    time_begin = time.time()
    for i in range(test_time):
        output = run(predictor, [images])
    paddle.device.cuda.synchronize()
    time_end = time.time()
    print(f"input size: {images.shape}, dtype: {images.dtype}, Description: static model, Avg Time: {(time_end - time_begin) / test_time * 1000} ms")
    return output

def dynamic_infer(model, images, description):
    # warmup
    for i in range(warmup_time):
        output = model(images)

    # test
    paddle.device.cuda.synchronize()
    time_begin = time.time()
    for i in range(test_time):
        output = model(images)
    paddle.device.cuda.synchronize()
    time_end = time.time()
    print(f"input size: {images.shape}, dtype: {images.dtype}, Description: {description}, Avg Time: {(time_end - time_begin) / test_time * 1000} ms")
    return output

def main_fp32():
    N, C, H, W = (1, 3, 384, 384)
    # fp32
    dtype = "float32"
    paddle.set_default_dtype(dtype)
    images = np.random.rand(N, C, H, W).astype(dtype)
    images_tensor = paddle.to_tensor(images, dtype=dtype)

    # fp32 origin
    origin_model = ViT_large_patch32_384(pretrained=True, class_num=1000)
    origin_output = dynamic_infer(origin_model, images_tensor, "Origin")
    # print(origin_output)

    # fp32 fused
    fused_fp32_model = Fused_ViT_large_patch32_384(pretrained=True, class_num=1000)
    fused_fp32_output = dynamic_infer(fused_fp32_model, images_tensor, "Fused fp32")
    # print(fused_fp32_output)

    # fp32 static infer
    model_file = "/your/path/fused_384_fp32/static_model.pdmodel"
    params_file = "/your/path/fused_384_fp32/static_model.pdiparams"
    static_fp32_output = static_infer(model_file, params_file, images)
    # print(static_fp32_output)

def main_fp16():
    N, C, H, W = (1, 3, 384, 384)
    # fp16
    dtype = "float16"
    paddle.set_default_dtype(dtype)
    images = np.random.rand(N, C, H, W).astype(dtype)
    images_tensor = paddle.to_tensor(images, dtype=dtype)

    # fp16 origin
    # needs a code change in /paddleclas/ppcls/utils/save_load.py (load_dygraph_pretrain)
    # origin_model = ViT_large_patch32_384(pretrained=True, class_num=1000)
    # origin_output = dynamic_infer(origin_model, images_tensor, "Origin")
    # print(origin_output)

    # fp16 fused
    fused_fp16_model = Fused_ViT_large_patch32_384(pretrained=True, class_num=1000)
    fused_fp16_output = dynamic_infer(fused_fp16_model, images_tensor, "Fused fp16")
    # print(fused_fp16_output)

    # fp16 static infer
    model_file = "/your/path/fused_384_fp16/static_model.pdmodel"
    params_file = "/your/path/fused_384_fp16/static_model.pdiparams"
    static_fp16_output = static_infer(model_file, params_file, images)
    # print(static_fp16_output)

    # wint8
    quanted_8_model = Fused_ViT_large_patch32_384(pretrained=True, class_num=1000, use_weight_only=True)
    quanted_8_output = dynamic_infer(quanted_8_model, images_tensor, "8bits Fused Quanted")
    # print(quanted_8_output)

if __name__ == "__main__":
    main_fp32()
    main_fp16()
```
### Performance data (dynamic graph)

<img src="imgs/performance_dynamic.jpg" alt="performance_dynamic" style="zoom:50%;" />

* Speedups here are relative to the `naive vit` implementation at the matching precision
* The `int8` numbers use the `fp16` implementation as their baseline

### Performance data (static graph)

<img src="imgs/performance_static.jpg" alt="performance_static" style="zoom:50%;" />

* Speedups here are relative to `fused vit fp32`
Binary file not shown. After: 348 KiB
Binary file not shown. After: 196 KiB
@@ -55,6 +55,7 @@ from .model_zoo.darknet import DarkNet53
from .model_zoo.regnet import RegNetX_200MF, RegNetX_400MF, RegNetX_600MF, RegNetX_800MF, RegNetX_1600MF, RegNetX_3200MF, RegNetX_4GF, RegNetX_6400MF, RegNetX_8GF, RegNetX_12GF, RegNetX_16GF, RegNetX_32GF
from .model_zoo.vision_transformer import ViT_small_patch16_224, ViT_base_patch16_224, ViT_base_patch16_384, ViT_base_patch32_384, ViT_large_patch16_224, ViT_large_patch16_384, ViT_large_patch32_384
from .model_zoo.distilled_vision_transformer import DeiT_tiny_patch16_224, DeiT_small_patch16_224, DeiT_base_patch16_224, DeiT_tiny_distilled_patch16_224, DeiT_small_distilled_patch16_224, DeiT_base_distilled_patch16_224, DeiT_base_patch16_384, DeiT_base_distilled_patch16_384
from .model_zoo.fused_vision_transformer import Fused_ViT_small_patch16_224, Fused_ViT_base_patch16_224, Fused_ViT_base_patch16_384, Fused_ViT_base_patch32_384, Fused_ViT_large_patch16_224, Fused_ViT_large_patch16_384, Fused_ViT_large_patch32_384
from .legendary_models.swin_transformer import SwinTransformer_tiny_patch4_window7_224, SwinTransformer_small_patch4_window7_224, SwinTransformer_base_patch4_window7_224, SwinTransformer_base_patch4_window12_384, SwinTransformer_large_patch4_window7_224, SwinTransformer_large_patch4_window12_384
from .model_zoo.swin_transformer_v2 import SwinTransformerV2_tiny_patch4_window8_256, SwinTransformerV2_small_patch4_window8_256, SwinTransformerV2_base_patch4_window8_256, SwinTransformerV2_tiny_patch4_window16_256, SwinTransformerV2_small_patch4_window16_256, SwinTransformerV2_base_patch4_window16_256, SwinTransformerV2_base_patch4_window24_384, SwinTransformerV2_large_patch4_window16_256, SwinTransformerV2_large_patch4_window24_384
from .model_zoo.cswin_transformer import CSWinTransformer_tiny_224, CSWinTransformer_small_224, CSWinTransformer_base_224, CSWinTransformer_large_224, CSWinTransformer_base_384, CSWinTransformer_large_384
@@ -0,0 +1,802 @@
# copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Code was based on https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
# reference: https://arxiv.org/abs/2010.11929

import paddle
import paddle.nn as nn
from paddle.framework import LayerHelper, in_dynamic_mode
from paddle.nn.initializer import TruncatedNormal, Constant, Normal
from paddle.incubate.nn.functional import (
    fused_layer_norm,
    fused_linear,
    variable_length_memory_efficient_attention
)
from paddle.nn.quant import weight_quantize, weight_only_linear
from ....utils.save_load import get_pretrain_state_dict, get_pretrain_state_dict_from_url
from ....utils.import_utils import is_paddleclas_ops_available

if is_paddleclas_ops_available():
    from paddleclas_ops import (
        qkv_transpose_split,
        transpose_remove_padding
    )
else:
    raise RuntimeError(
        "paddleclas_ops is not installed. You can read the docs and install it "
        "by hand; refer to csrc/README.md."
    )
MODEL_URLS = {
    "Fused_ViT_small_patch16_224":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_small_patch16_224_pretrained.pdparams",
    "Fused_ViT_base_patch16_224":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_base_patch16_224_pretrained.pdparams",
    "Fused_ViT_base_patch16_384":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_base_patch16_384_pretrained.pdparams",
    "Fused_ViT_base_patch32_384":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_base_patch32_384_pretrained.pdparams",
    "Fused_ViT_large_patch16_224":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_large_patch16_224_pretrained.pdparams",
    "Fused_ViT_large_patch16_384":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_large_patch16_384_pretrained.pdparams",
    "Fused_ViT_large_patch32_384":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_large_patch32_384_pretrained.pdparams",
}

__all__ = list(MODEL_URLS.keys())

trunc_normal_ = TruncatedNormal(std=.02)
normal_ = Normal
zeros_ = Constant(value=0.)
ones_ = Constant(value=1.)

def to_2tuple(x):
    return tuple([x] * 2)
def fused_act_bias_wrapper(
    x,
    bias=None,
    dequant_scales=None,
    shift=None,
    smooth=None,
    act_method="gelu",
    compute_dtype="default",
    quant_scale=-1,
    quant_round_type=0,
    quant_max_bound=0,
    quant_min_bound=0,
):
    # Dynamic graph: call the fused kernel directly; static graph: build the op
    # through LayerHelper so jit.to_static can trace it.
    if in_dynamic_mode():
        return paddle._C_ops.fused_bias_act(
            x,
            bias,
            dequant_scales,
            shift,
            smooth,
            act_method,
            compute_dtype,
            quant_scale,
            quant_round_type,
            quant_max_bound,
            quant_min_bound,
        )
    helper = LayerHelper("fused_bias_act")
    if x.dtype == "int32":
        if compute_dtype == "bf16":
            dtype = "uint16"
        elif compute_dtype == "fp16":
            dtype = "float16"
        elif compute_dtype == "fp32":
            dtype = "float32"
        out = helper.create_variable_for_type_inference(dtype=dtype)
    else:
        out = helper.create_variable_for_type_inference(dtype=x.dtype)

    inputs = {}
    inputs["x"] = x
    if bias is not None:
        inputs["bias"] = bias
    if dequant_scales is not None:
        inputs["dequant_scales"] = dequant_scales

    if shift is not None:
        inputs["shift"] = shift

    if smooth is not None:
        inputs["smooth"] = smooth

    attrs = {
        "act_method": act_method,
        "compute_dtype": compute_dtype,
        "quant_scale": quant_scale,
        "quant_round_type": quant_round_type,
        "quant_max_bound": quant_max_bound,
        "quant_min_bound": quant_min_bound,
    }

    helper.append_op(
        type="fused_bias_act",
        inputs=inputs,
        outputs={"out": out},
        attrs=attrs,
    )
    return out
class FusedVisionTransformer(nn.Layer):
    """ Fused Vision Transformer with support for patch input
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_chans=3,
                 class_num=1000,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 mlp_ratio=4,
                 qkv_bias=False,
                 qk_scale=None,
                 norm_layer='nn.LayerNorm',
                 epsilon=1e-5,
                 use_weight_only=False,
                 quant_type="weight_only_int8",
                 **kwargs):
        super().__init__()
        self.dtype = self._helper.get_default_dtype()

        self.class_num = class_num
        self.num_features = self.embed_dim = embed_dim
        self.epsilon = epsilon
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.depth = depth
        self.scale = qk_scale or self.head_dim**-0.5
        self.norm_func = fused_layer_norm
        self.linear = fused_linear

        self.use_weight_only = use_weight_only
        self.quant_type = quant_type
        self.create_params_type = self.get_weight_create_dtype()
        self._norm_weight_dtype = "float32"

        if self.use_weight_only:
            assert (
                self.quant_type == "weight_only_int8" or self.quant_type == "weight_only_int4"
            ), "Expected quant_type equal to 'weight_only_int8' or 'weight_only_int4' \
but received quant_type: {}".format(
                self.quant_type
            )
            self.quant_bits = int(self.quant_type[-1])
            self.weight_dtype = "int" + str(self.quant_bits)

        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        self.patch_embed_proj = nn.Conv2D(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        num_patches = (self.img_size[1] // self.patch_size[1]) * \
            (self.img_size[0] // self.patch_size[0])

        self.pos_embed = self.create_parameter(
            shape=(1, num_patches + 1, embed_dim), default_initializer=trunc_normal_
        )
        self.add_parameter("pos_embed", self.pos_embed)
        self.cls_token = self.create_parameter(
            shape=(1, 1, embed_dim), default_initializer=trunc_normal_
        )
        self.add_parameter("cls_token", self.cls_token)

        self.norm1_weights, self.norm1_biases = [], []
        self.attn_qkv_weights, self.attn_qkv_biases = [], []
        self.attn_proj_weights, self.attn_proj_biases = [], []
        self.norm2_weights, self.norm2_biases = [], []
        self.mlp_fc1_weights, self.mlp_fc1_biases = [], []
        self.mlp_fc2_weights, self.mlp_fc2_biases = [], []

        if self.use_weight_only:
            self.attn_qkv_weights_scale = []
            self.attn_proj_weights_scale = []
            self.mlp_fc1_weights_scale = []
            self.mlp_fc2_weights_scale = []

        mlp_hidden_dim = int(embed_dim * mlp_ratio)
        self._init_weight_shape(mlp_hidden_dim)

        for i in range(self.depth):
            norm1_weight = self.create_parameter(
                shape=self.norm1_weight_shape,
                default_initializer=ones_,
                dtype=self._norm_weight_dtype
            )
            norm1_bias = self.create_parameter(
                shape=self.norm1_bias_shape,
                default_initializer=zeros_,
                is_bias=True,
                dtype=self._norm_weight_dtype
            )

            attn_qkv_weight = self.create_parameter(
                shape=self.attn_qkv_weight_shape,
                default_initializer=ones_,
                dtype=self.create_params_type
            )
            attn_qkv_bias = self.create_parameter(
                shape=self.attn_qkv_bias_shape,
                default_initializer=zeros_,
                is_bias=True,
                dtype=self.dtype
            )

            attn_proj_weight = self.create_parameter(
                shape=self.attn_proj_weight_shape,
                default_initializer=ones_,
                dtype=self.create_params_type
            )
            attn_proj_bias = self.create_parameter(
                shape=self.attn_proj_bias_shape,
                default_initializer=zeros_,
                is_bias=True,
                dtype=self.dtype
            )

            norm2_weight = self.create_parameter(
                shape=self.norm2_weight_shape,
                default_initializer=ones_,
                dtype=self._norm_weight_dtype
            )
            norm2_bias = self.create_parameter(
                shape=self.norm2_bias_shape,
                default_initializer=zeros_,
                is_bias=True,
                dtype=self._norm_weight_dtype
            )

            mlp_fc1_weight = self.create_parameter(
                shape=self.mlp_fc1_weight_shape,
                default_initializer=ones_,
                dtype=self.create_params_type
            )
            mlp_fc1_bias = self.create_parameter(
                shape=self.mlp_fc1_bias_shape,
                default_initializer=zeros_,
                is_bias=True,
                dtype=self.dtype
            )

            mlp_fc2_weight = self.create_parameter(
                shape=self.mlp_fc2_weight_shape,
                default_initializer=ones_,
                dtype=self.create_params_type
            )
            mlp_fc2_bias = self.create_parameter(
                shape=self.mlp_fc2_bias_shape,
                default_initializer=zeros_,
                is_bias=True,
                dtype=self.dtype
            )

            self.norm1_weights.append(norm1_weight)
            self.norm1_biases.append(norm1_bias)
            self.attn_qkv_weights.append(attn_qkv_weight)
            self.attn_qkv_biases.append(attn_qkv_bias)
            self.attn_proj_weights.append(attn_proj_weight)
            self.attn_proj_biases.append(attn_proj_bias)
            self.norm2_weights.append(norm2_weight)
            self.norm2_biases.append(norm2_bias)
            self.mlp_fc1_weights.append(mlp_fc1_weight)
            self.mlp_fc1_biases.append(mlp_fc1_bias)
            self.mlp_fc2_weights.append(mlp_fc2_weight)
            self.mlp_fc2_biases.append(mlp_fc2_bias)

            self.add_parameter("blocks_{}_norm1_weight".format(i), norm1_weight)
            self.add_parameter("blocks_{}_norm1_bias".format(i), norm1_bias)
            self.add_parameter("blocks_{}_attn_qkv_weight".format(i), attn_qkv_weight)
            self.add_parameter("blocks_{}_attn_qkv_bias".format(i), attn_qkv_bias)
            self.add_parameter("blocks_{}_attn_proj_weight".format(i), attn_proj_weight)
            self.add_parameter("blocks_{}_attn_proj_bias".format(i), attn_proj_bias)
            self.add_parameter("blocks_{}_norm2_weight".format(i), norm2_weight)
            self.add_parameter("blocks_{}_norm2_bias".format(i), norm2_bias)
            self.add_parameter("blocks_{}_mlp_fc1_weight".format(i), mlp_fc1_weight)
            self.add_parameter("blocks_{}_mlp_fc1_bias".format(i), mlp_fc1_bias)
            self.add_parameter("blocks_{}_mlp_fc2_weight".format(i), mlp_fc2_weight)
            self.add_parameter("blocks_{}_mlp_fc2_bias".format(i), mlp_fc2_bias)

            if self.use_weight_only:
                attn_qkv_weight_scale = self.create_parameter(
                    shape=[3 * self.num_heads * self.head_dim],
                    default_initializer=zeros_,
                    dtype=self.dtype,
                    is_bias=False
                )
                attn_proj_weight_scale = self.create_parameter(
                    shape=[self.embed_dim],
                    default_initializer=zeros_,
                    dtype=self.dtype,
                    is_bias=False
                )
                mlp_fc1_weight_scale = self.create_parameter(
                    shape=[mlp_hidden_dim],
                    default_initializer=zeros_,
                    dtype=self.dtype,
                    is_bias=False
                )
                mlp_fc2_weight_scale = self.create_parameter(
                    shape=[self.embed_dim],
                    default_initializer=zeros_,
                    dtype=self.dtype,
                    is_bias=False
                )

                self.attn_qkv_weights_scale.append(attn_qkv_weight_scale)
                self.attn_proj_weights_scale.append(attn_proj_weight_scale)
                self.mlp_fc1_weights_scale.append(mlp_fc1_weight_scale)
                self.mlp_fc2_weights_scale.append(mlp_fc2_weight_scale)

                self.add_parameter("blocks_{}_attn_qkv_weight_scale".format(i), attn_qkv_weight_scale)
                self.add_parameter("blocks_{}_attn_proj_weight_scale".format(i), attn_proj_weight_scale)
                self.add_parameter("blocks_{}_mlp_fc1_weight_scale".format(i), mlp_fc1_weight_scale)
                self.add_parameter("blocks_{}_mlp_fc2_weight_scale".format(i), mlp_fc2_weight_scale)

        self.norm_weight = self.create_parameter(
            shape=[embed_dim],
            default_initializer=ones_,
            dtype=self._norm_weight_dtype
        )
        self.norm_bias = self.create_parameter(
            shape=[embed_dim],
            is_bias=True,
            default_initializer=zeros_,
            dtype=self._norm_weight_dtype
        )
        self.head_weight = self.create_parameter(
            shape=[embed_dim, class_num],
            default_initializer=ones_,
            dtype=self.dtype
        )
        self.head_bias = self.create_parameter(
            shape=[class_num],
            is_bias=True,
            default_initializer=zeros_,
            dtype=self.dtype
        )
    def _init_weight_shape(self, mlp_hidden_dim):
        self.norm1_weight_shape = [self.embed_dim]
        self.norm1_bias_shape = [self.embed_dim]
        self.attn_qkv_weight_shape = (
            [3 * self.num_heads * self.head_dim, self.embed_dim]
            if self.use_weight_only
            else [self.embed_dim, 3 * self.num_heads * self.head_dim]
        )
        self.attn_qkv_bias_shape = [3 * self.num_heads * self.head_dim]
        self.attn_proj_weight_shape = (
            [self.embed_dim, self.num_heads * self.head_dim]
            if self.use_weight_only
            else [self.num_heads * self.head_dim, self.embed_dim]
        )
        self.attn_proj_bias_shape = [self.num_heads * self.head_dim]
        self.norm2_weight_shape = [self.embed_dim]
        self.norm2_bias_shape = [self.embed_dim]
        self.mlp_fc1_weight_shape = (
            [mlp_hidden_dim, self.embed_dim]
            if self.use_weight_only
            else [self.embed_dim, mlp_hidden_dim]
        )
        self.mlp_fc1_bias_shape = [mlp_hidden_dim]
        self.mlp_fc2_weight_shape = (
            [self.embed_dim, mlp_hidden_dim]
            if self.use_weight_only
            else [mlp_hidden_dim, self.embed_dim]
        )
        self.mlp_fc2_bias_shape = [self.embed_dim]

        if self.use_weight_only and self.quant_bits == 4:
            self.attn_qkv_weight_shape[0] //= 2
            self.attn_proj_weight_shape[0] //= 2
            self.mlp_fc1_weight_shape[0] //= 2
            self.mlp_fc2_weight_shape[0] //= 2
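    # Note on the `//= 2` above (illustrative, not from the original sources): for
    # weight_only_int4, Paddle packs two 4-bit values into each int8 byte, so the
    # quantized buffer, which is created with dtype "int8" either way (see
    # get_weight_create_dtype below), needs only half as many rows.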
    def get_weight_create_dtype(self):
        if self.use_weight_only:
            return "int8"
        else:
            return self.dtype
    @paddle.no_grad()
    def set_state_dict(self, state_dict):
        self.pos_embed.set_value(state_dict["pos_embed"].astype(self.dtype))
        self.cls_token.set_value(state_dict["cls_token"].astype(self.dtype))
        self.patch_embed_proj.weight.set_value(state_dict["patch_embed.proj.weight"].astype(self.dtype))
        self.patch_embed_proj.bias.set_value(state_dict["patch_embed.proj.bias"].astype(self.dtype))
        for i in range(self.depth):
            self.norm1_weights[i].set_value(state_dict["blocks.{}.norm1.weight".format(i)].astype(self._norm_weight_dtype))
            self.norm1_biases[i].set_value(state_dict["blocks.{}.norm1.bias".format(i)].astype(self._norm_weight_dtype))

            if self.use_weight_only:
                attn_qkv_weight_tensor = paddle.to_tensor(state_dict["blocks.{}.attn.qkv.weight".format(i)].astype(self.dtype))
                attn_qkv_quanted_weight_tensor, attn_qkv_weight_scale_tensor = weight_quantize(
                    attn_qkv_weight_tensor, algo=self.quant_type
                )
                self.attn_qkv_weights[i].set_value(attn_qkv_quanted_weight_tensor)
                self.attn_qkv_weights_scale[i].set_value(attn_qkv_weight_scale_tensor)
            else:
                self.attn_qkv_weights[i].set_value(state_dict["blocks.{}.attn.qkv.weight".format(i)].astype(self.dtype))
            self.attn_qkv_biases[i].set_value(state_dict["blocks.{}.attn.qkv.bias".format(i)].astype(self.dtype))

            if self.use_weight_only:
                attn_proj_weight_tensor = paddle.to_tensor(state_dict["blocks.{}.attn.proj.weight".format(i)].astype(self.dtype))
                attn_proj_quanted_weight_tensor, attn_proj_weight_scale_tensor = weight_quantize(
                    attn_proj_weight_tensor, algo=self.quant_type
                )
                self.attn_proj_weights[i].set_value(attn_proj_quanted_weight_tensor)
                self.attn_proj_weights_scale[i].set_value(attn_proj_weight_scale_tensor)
            else:
                self.attn_proj_weights[i].set_value(state_dict["blocks.{}.attn.proj.weight".format(i)].astype(self.dtype))
            self.attn_proj_biases[i].set_value(state_dict["blocks.{}.attn.proj.bias".format(i)].astype(self.dtype))

            self.norm2_weights[i].set_value(state_dict["blocks.{}.norm2.weight".format(i)].astype(self._norm_weight_dtype))
            self.norm2_biases[i].set_value(state_dict["blocks.{}.norm2.bias".format(i)].astype(self._norm_weight_dtype))

            if self.use_weight_only:
                mlp_fc1_weight_tensor = paddle.to_tensor(state_dict["blocks.{}.mlp.fc1.weight".format(i)].astype(self.dtype))
                mlp_fc1_quanted_weight_tensor, mlp_fc1_weight_scale_tensor = weight_quantize(
                    mlp_fc1_weight_tensor, algo=self.quant_type
                )
                self.mlp_fc1_weights[i].set_value(mlp_fc1_quanted_weight_tensor)
                self.mlp_fc1_weights_scale[i].set_value(mlp_fc1_weight_scale_tensor)
            else:
                self.mlp_fc1_weights[i].set_value(state_dict["blocks.{}.mlp.fc1.weight".format(i)].astype(self.dtype))
            self.mlp_fc1_biases[i].set_value(state_dict["blocks.{}.mlp.fc1.bias".format(i)].astype(self.dtype))

            if self.use_weight_only:
                mlp_fc2_weight_tensor = paddle.to_tensor(state_dict["blocks.{}.mlp.fc2.weight".format(i)].astype(self.dtype))
                mlp_fc2_quanted_weight_tensor, mlp_fc2_weight_scale_tensor = weight_quantize(
                    mlp_fc2_weight_tensor, algo=self.quant_type
                )
                self.mlp_fc2_weights[i].set_value(mlp_fc2_quanted_weight_tensor)
                self.mlp_fc2_weights_scale[i].set_value(mlp_fc2_weight_scale_tensor)
            else:
                self.mlp_fc2_weights[i].set_value(state_dict["blocks.{}.mlp.fc2.weight".format(i)].astype(self.dtype))
            self.mlp_fc2_biases[i].set_value(state_dict["blocks.{}.mlp.fc2.bias".format(i)].astype(self.dtype))

        self.norm_weight.set_value(state_dict["norm.weight"].astype(self._norm_weight_dtype))
        self.norm_bias.set_value(state_dict["norm.bias"].astype(self._norm_weight_dtype))
        self.head_weight.set_value(state_dict["head.weight"].astype(self.dtype))
        self.head_bias.set_value(state_dict["head.bias"].astype(self.dtype))
    def compute_layernorm_before_qkv(self, src, i):
        # Block 0 must normalize its raw input; for every later block, the fused
        # residual-add + layernorm at the end of the previous block has already
        # produced the normalized activations.
        if i == 0:
            ln_out = self.norm_func(src, self.norm1_weights[i], self.norm1_biases[i], self.epsilon)
        else:
            ln_out = src

        return ln_out

    def compute_qkv_linear(self, ln_out, i):
        if self.use_weight_only:
            return weight_only_linear(
                ln_out,
                weight=self.attn_qkv_weights[i],
                bias=self.attn_qkv_biases[i],
                weight_scale=self.attn_qkv_weights_scale[i],
                weight_dtype=self.weight_dtype
            )

        if float(paddle.version.cuda()) < 11.6:
            qkv_out = paddle.matmul(ln_out, self.attn_qkv_weights[i])
            if self.attn_qkv_biases[i] is not None:
                qkv_out = paddle.add(qkv_out, self.attn_qkv_biases[i])
            return qkv_out
        else:
            return self.linear(ln_out, self.attn_qkv_weights[i], self.attn_qkv_biases[i])

    def compute_qkv(self, src, residual_input, i):
        ln_out = self.compute_layernorm_before_qkv(src, i)
        qkv_out = self.compute_qkv_linear(ln_out, i)
        return qkv_out, residual_input

    def compute_fmha(self, qkv_out, padding_offset, seq_lens, input_ids, i):
        q_out, k_out, v_out = qkv_transpose_split(
            qkv_out, padding_offset, seq_lens, input_ids, self.num_heads, self.head_dim
        )
        # cutlass fmha
        qktv_out = variable_length_memory_efficient_attention(
            q_out,
            k_out,
            v_out,
            seq_lens,
            seq_lens,
            None,
            scale=self.scale
        )
        return transpose_remove_padding(qktv_out, seq_lens, padding_offset)

    def compute_out_linear(self, fmha_out, i):
        if self.use_weight_only:
            return weight_only_linear(
                fmha_out,
                weight=self.attn_proj_weights[i],
                weight_scale=self.attn_proj_weights_scale[i],
                weight_dtype=self.weight_dtype
            )

        return paddle.matmul(fmha_out, self.attn_proj_weights[i])

    def compute_attn(self, qkv_out, padding_offset, seq_lens, input_ids, i):
        fmha_out = self.compute_fmha(qkv_out, padding_offset, seq_lens, input_ids, i)
        out_linear_out = self.compute_out_linear(fmha_out, i)
        return out_linear_out

    def compute_ffn_layernorm(self, out_linear_out, residual_input, i):
        """
        tmp_out = layernorm(out_linear_out + attn_proj_biases[i] + residual_input)
        """
        norm_out = self.norm_func(
            out_linear_out,
            norm_weight=self.norm2_weights[i],
            norm_bias=self.norm2_biases[i],
            epsilon=self.epsilon,
            bias=self.attn_proj_biases[i],
            residual=residual_input,
        )
        tmp_out, residual_input = norm_out[0], norm_out[1]
        return tmp_out, residual_input

    def compute_ffn1(self, tmp_out, i):
        if self.use_weight_only:
            return weight_only_linear(
                tmp_out,
                weight=self.mlp_fc1_weights[i],
                weight_scale=self.mlp_fc1_weights_scale[i],
                weight_dtype=self.weight_dtype,
            )

        return paddle.matmul(tmp_out, self.mlp_fc1_weights[i])

    def compute_ffn2(self, ffn1_out, i):
        if self.use_weight_only:
            return weight_only_linear(
                ffn1_out,
                weight=self.mlp_fc2_weights[i],
                weight_scale=self.mlp_fc2_weights_scale[i],
                weight_dtype=self.weight_dtype,
            )

        return paddle.matmul(ffn1_out, self.mlp_fc2_weights[i])

    def compute_bias_residual_layernorm(self, ffn2_out, residual_input, i, num_layers):
        if i != num_layers - 1:
            norm_out = self.norm_func(
                ffn2_out,
                norm_weight=self.norm1_weights[i + 1],
                norm_bias=self.norm1_biases[i + 1],
                epsilon=self.epsilon,
                bias=self.mlp_fc2_biases[i],
                residual=residual_input
            )
            tmp_out, residual_input = norm_out[0], norm_out[1]
        else:
            tmp_out = self.norm_func(
                ffn2_out,
                norm_weight=self.norm_weight,
                norm_bias=self.norm_bias,
                epsilon=self.epsilon,
                bias=self.mlp_fc2_biases[i],
                residual=residual_input
            )[0]
        return tmp_out, residual_input

    def compute_head_linear(self, ln_out):
        if float(paddle.version.cuda()) < 11.6:
            qkv_out = paddle.matmul(ln_out, self.head_weight)
            if self.head_bias is not None:
                qkv_out = paddle.add(qkv_out, self.head_bias)
            return qkv_out
        else:
            return self.linear(ln_out, self.head_weight, self.head_bias)

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.patch_embed_proj(x).flatten(2).transpose((0, 2, 1))

        cls_tokens = self.cls_token.expand((B, -1, -1))
        x = paddle.concat((cls_tokens, x), axis=1)
        x = x + self.pos_embed

        # Dense batch: every image contributes exactly seq_len tokens, so
        # padding_offset is all zeros and seq_lens is uniform.
        batch, seq_len, _ = x.shape
        padding_offset = paddle.zeros([seq_len * batch], dtype='int32')
        seq_lens = paddle.full([batch], seq_len, dtype='int32')
        input_ids = paddle.full([batch, seq_len], 0, dtype='int32')

        # Flatten [batch, seq_len, embed_dim] -> [token_num, embed_dim],
        # the layout the custom ops expect.
        x = x.reshape([-1, x.shape[-1]])
        residual_input = x
        for i in range(self.depth):
            qkv_out, residual_input = self.compute_qkv(x, residual_input, i)
            out_linear_out = self.compute_attn(
                qkv_out,
                padding_offset,
                seq_lens,
                input_ids,
                i
            )

            # attn proj bias + residual add + layernorm2, fused
            tmp_out, residual_input = self.compute_ffn_layernorm(out_linear_out, residual_input, i)

            # mlp ffn1 matmul
            ffn1_out = self.compute_ffn1(tmp_out, i)
            ffn1_out = fused_act_bias_wrapper(ffn1_out, self.mlp_fc1_biases[i])

            # mlp ffn2 matmul
            ffn2_out = self.compute_ffn2(ffn1_out, i)

            # ffn2 bias + residual add + next block's layernorm1 (or final norm), fused
            tmp_out, residual_input = self.compute_bias_residual_layernorm(ffn2_out, residual_input, i, self.depth)
            x = tmp_out
        x = x.reshape((batch, seq_len, -1))
        # Classification feature: the cls token at position 0.
        index = paddle.zeros([1], dtype="int32")
        x = paddle.index_select(x, index, axis=1).reshape((batch, self.embed_dim))
        x = self.compute_head_linear(x)

        return x
def _load_pretrained(pretrained, model, model_url, use_ssld=False):
    if pretrained is False:
        pass
    elif pretrained is True:
        weight_state_dict = get_pretrain_state_dict_from_url(model_url, use_ssld=use_ssld)
        model.set_state_dict(weight_state_dict)
    elif isinstance(pretrained, str):
        weight_state_dict = get_pretrain_state_dict(pretrained)
        model.set_state_dict(weight_state_dict)
    else:
        raise RuntimeError(
            "pretrained type is not available. Please use `string` or `boolean` type."
        )


def Fused_ViT_small_patch16_224(pretrained=False, use_ssld=False, **kwargs):
    model = FusedVisionTransformer(
        patch_size=16,
        embed_dim=768,
        depth=8,
        num_heads=8,
        mlp_ratio=3,
        qk_scale=768**-0.5,
        **kwargs)
    _load_pretrained(
        pretrained,
        model,
        MODEL_URLS["Fused_ViT_small_patch16_224"],
        use_ssld=use_ssld)
    return model


def Fused_ViT_base_patch16_224(pretrained=False, use_ssld=False, **kwargs):
    model = FusedVisionTransformer(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        epsilon=1e-6,
        **kwargs)
    _load_pretrained(
        pretrained,
        model,
        MODEL_URLS["Fused_ViT_base_patch16_224"],
        use_ssld=use_ssld)
    return model


def Fused_ViT_base_patch16_384(pretrained=False, use_ssld=False, **kwargs):
    model = FusedVisionTransformer(
        img_size=384,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        epsilon=1e-6,
        **kwargs)
    _load_pretrained(
        pretrained,
        model,
        MODEL_URLS["Fused_ViT_base_patch16_384"],
        use_ssld=use_ssld)
    return model


def Fused_ViT_base_patch32_384(pretrained=False, use_ssld=False, **kwargs):
    model = FusedVisionTransformer(
        img_size=384,
        patch_size=32,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        epsilon=1e-6,
        **kwargs)
    _load_pretrained(
        pretrained,
        model,
        MODEL_URLS["Fused_ViT_base_patch32_384"],
        use_ssld=use_ssld)
    return model


def Fused_ViT_large_patch16_224(pretrained=False, use_ssld=False, **kwargs):
    model = FusedVisionTransformer(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        epsilon=1e-6,
        **kwargs)
    _load_pretrained(
        pretrained,
        model,
        MODEL_URLS["Fused_ViT_large_patch16_224"],
        use_ssld=use_ssld)
    return model


def Fused_ViT_large_patch16_384(pretrained=False, use_ssld=False, **kwargs):
    model = FusedVisionTransformer(
        img_size=384,
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        epsilon=1e-6,
        **kwargs)
    _load_pretrained(
        pretrained,
        model,
        MODEL_URLS["Fused_ViT_large_patch16_384"],
        use_ssld=use_ssld)
    return model


def Fused_ViT_large_patch32_384(pretrained=False, use_ssld=False, **kwargs):
    model = FusedVisionTransformer(
        img_size=384,
        patch_size=32,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        epsilon=1e-6,
        **kwargs)
    _load_pretrained(
        pretrained,
        model,
        MODEL_URLS["Fused_ViT_large_patch32_384"],
        use_ssld=use_ssld)
    return model
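To make the token bookkeeping in `forward()` concrete, a worked shape example for `Fused_ViT_large_patch32_384` (arithmetic only, no model needed):

```python
# Worked shapes for Fused_ViT_large_patch32_384 (illustrative).
img, patch, embed_dim, heads = 384, 32, 1024, 16

num_patches = (img // patch) ** 2   # 144 patches per image
seq_len = num_patches + 1           # 145 tokens, including the cls token
head_dim = embed_dim // heads       # 64

batch = 4
token_num = batch * seq_len         # 580 rows after x.reshape([-1, embed_dim])
assert (token_num, 3 * embed_dim) == (580, 3072)  # qkv projection output shape
```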
@@ -0,0 +1,33 @@
# copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib.util

def is_package_available(package_name: str) -> bool:
    """Check if the package is available.

    Args:
        package_name (str): the installed package name

    Returns:
        bool: the existence of the installed package
    """
    package_spec = importlib.util.find_spec(package_name)
    return package_spec is not None and package_spec.has_location


def is_paddleclas_ops_available() -> bool:
    """Check if `paddleclas_ops` is available.

    Returns:
        bool: whether `paddleclas_ops` is available
    """
    return is_package_available("paddleclas_ops")
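A sketch of how this probe is consumed, mirroring the guard at the top of `fused_vision_transformer.py`:

```python
from ppcls.utils.import_utils import is_paddleclas_ops_available

if is_paddleclas_ops_available():
    from paddleclas_ops import qkv_transpose_split, transpose_remove_padding
else:
    print("paddleclas_ops missing; see csrc/README.md")
```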
@@ -117,6 +117,23 @@ def load_distillation_model(model, pretrained_model):
            pretrained_model))


def get_pretrain_state_dict(path=None):
    if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
        raise ValueError("Model pretrain path {}.pdparams does not "
                         "exist.".format(path))
    param_state_dict = paddle.load(path + ".pdparams")
    return param_state_dict


def get_pretrain_state_dict_from_url(pretrained_url, use_ssld=False):
    if use_ssld:
        pretrained_url = pretrained_url.replace("_pretrained",
                                                "_ssld_pretrained")
    local_weight_path = get_weights_path_from_url(pretrained_url).replace(
        ".pdparams", "")
    return get_pretrain_state_dict(path=local_weight_path)


def init_model(config,
               net,
               optimizer=None,