TensorRT dot product attention ops (#949)
* add detr support * fix softmax * add placeholder * add implement * add docs and ut * update testcase * update docs * update docs
parent e21cad84e0
commit 9541be9c0b
@@ -0,0 +1,183 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "scaled_dot_product_attention.hpp"

#include <assert.h>

#include <chrono>

#include "scaled_dot_product_attention_kernel.hpp"
#include "trt_serialize.hpp"

using namespace nvinfer1;

namespace mmdeploy {
namespace {
static const char *PLUGIN_VERSION{"1"};
static const char *PLUGIN_NAME{"ScaledDotProductAttentionTRT"};
}  // namespace

ScaledDotProductAttentionTRT::ScaledDotProductAttentionTRT(const std::string &name)
    : TRTPluginBase(name), mask_dim(0) {}

ScaledDotProductAttentionTRT::ScaledDotProductAttentionTRT(const std::string name, const void *data,
                                                           size_t length)
    : TRTPluginBase(name), mask_dim(0) {}

ScaledDotProductAttentionTRT::~ScaledDotProductAttentionTRT() {}

nvinfer1::IPluginV2DynamicExt *ScaledDotProductAttentionTRT::clone() const TRT_NOEXCEPT {
  ScaledDotProductAttentionTRT *plugin = new ScaledDotProductAttentionTRT(mLayerName);
  plugin->setPluginNamespace(getPluginNamespace());
  return plugin;
}

nvinfer1::DimsExprs ScaledDotProductAttentionTRT::getOutputDimensions(
    int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
    nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT {
  if (outputIndex == 0) return inputs[0];
  nvinfer1::DimsExprs ret;
  ret.nbDims = 3;
  ret.d[0] = inputs[0].d[0];
  ret.d[1] = inputs[0].d[1];
  ret.d[2] = inputs[1].d[1];

  return ret;
}

bool ScaledDotProductAttentionTRT::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT {
  if (pos == 0) {
    return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT &&
            ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR);
  } else {
    return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format;
  }
}

// Attach the plugin object to an execution context and grant the plugin
// access to some context resources.
void ScaledDotProductAttentionTRT::attachToContext(cudnnContext *cudnnContext,
                                                   cublasContext *cublasContext,
                                                   IGpuAllocator *gpuAllocator) TRT_NOEXCEPT {
  _cublas_handle = cublasContext;
  _cudnn_handle = cudnnContext;
  cudnnCreateTensorDescriptor(&_x_desc);
  cudnnCreateTensorDescriptor(&_y_desc);
  cudnnCreateTensorDescriptor(&_mask_desc);
}

// Detach the plugin object from its execution context.
void ScaledDotProductAttentionTRT::detachFromContext() TRT_NOEXCEPT {
  cudnnDestroyTensorDescriptor(_y_desc);
  cudnnDestroyTensorDescriptor(_x_desc);
  cudnnDestroyTensorDescriptor(_mask_desc);
}

void ScaledDotProductAttentionTRT::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in,
                                                   int nbInputs,
                                                   const nvinfer1::DynamicPluginTensorDesc *out,
                                                   int nbOutputs) TRT_NOEXCEPT {
  if (nbInputs != 4) {
    mask_dim = 0;
  } else {
    mask_dim = in[3].desc.dims.nbDims;
  }
}

int ScaledDotProductAttentionTRT::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
                                          const nvinfer1::PluginTensorDesc *outputDesc,
                                          const void *const *inputs, void *const *outputs,
                                          void *workSpace, cudaStream_t stream) TRT_NOEXCEPT {
  if (CUDNN_STATUS_SUCCESS != cudnnSetStream(_cudnn_handle, stream)) return 1;
  if (CUBLAS_STATUS_SUCCESS != cublasSetStream(_cublas_handle, stream)) return 1;
  int B = inputDesc[0].dims.d[0];  // batch * heads
  int Nt = inputDesc[0].dims.d[1];
  int Ns = inputDesc[1].dims.d[1];
  int E = inputDesc[0].dims.d[2];  // embedding size

  const void *query = inputs[0];
  const void *key = inputs[1];
  const void *value = inputs[2];
  const void *mask = nullptr;

  int mask_dims[3];
  mask_dims[0] = 0;
  if (mask_dim > 0) {
    mask = inputs[3];
    // check if the mask needs broadcast
    if (mask_dim == 2) {
      mask_dims[0] = 1;
      mask_dims[1] = inputDesc[3].dims.d[0];
      mask_dims[2] = inputDesc[3].dims.d[1];
    } else {
      mask_dims[0] = inputDesc[3].dims.d[0];
      mask_dims[1] = inputDesc[3].dims.d[1];
      mask_dims[2] = inputDesc[3].dims.d[2];
    }
  }

  void *output = outputs[0];
  void *attn = outputs[1];

  auto data_type = inputDesc[0].type;
  cudnnDataType_t cudnn_dtype{};
  convert_trt2cudnn_dtype(data_type, &cudnn_dtype);
  switch (data_type) {
    case nvinfer1::DataType::kFLOAT:
      dot_product_attention_impl<float>((float *)query, (float *)key, (float *)value, (float *)mask,
                                        (float *)attn, (float *)output, B, Nt, Ns, E, &mask_dims[0],
                                        _x_desc, _y_desc, _mask_desc, cudnn_dtype, stream,
                                        _cublas_handle, _cudnn_handle);
      break;
    default:
      return 1;
  }

  return 0;
}

nvinfer1::DataType ScaledDotProductAttentionTRT::getOutputDataType(
    int index, const nvinfer1::DataType *inputTypes, int nbInputs) const TRT_NOEXCEPT {
  return inputTypes[0];
}

// IPluginV2 Methods
const char *ScaledDotProductAttentionTRT::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; }

const char *ScaledDotProductAttentionTRT::getPluginVersion() const TRT_NOEXCEPT {
  return PLUGIN_VERSION;
}

int ScaledDotProductAttentionTRT::getNbOutputs() const TRT_NOEXCEPT { return 2; }

size_t ScaledDotProductAttentionTRT::getSerializationSize() const TRT_NOEXCEPT { return 0; }

void ScaledDotProductAttentionTRT::serialize(void *buffer) const TRT_NOEXCEPT {}

////////////////////// creator /////////////////////////////

ScaledDotProductAttentionTRTCreator::ScaledDotProductAttentionTRTCreator() {}

const char *ScaledDotProductAttentionTRTCreator::getPluginName() const TRT_NOEXCEPT {
  return PLUGIN_NAME;
}

const char *ScaledDotProductAttentionTRTCreator::getPluginVersion() const TRT_NOEXCEPT {
  return PLUGIN_VERSION;
}

nvinfer1::IPluginV2 *ScaledDotProductAttentionTRTCreator::createPlugin(
    const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT {
  ScaledDotProductAttentionTRT *plugin = new ScaledDotProductAttentionTRT(name);
  plugin->setPluginNamespace(getPluginNamespace());
  return plugin;
}

nvinfer1::IPluginV2 *ScaledDotProductAttentionTRTCreator::deserializePlugin(
    const char *name, const void *serialData, size_t serialLength) TRT_NOEXCEPT {
  auto plugin = new ScaledDotProductAttentionTRT(name, serialData, serialLength);
  plugin->setPluginNamespace(getPluginNamespace());
  return plugin;
}
REGISTER_TENSORRT_PLUGIN(ScaledDotProductAttentionTRTCreator);
}  // namespace mmdeploy
@@ -0,0 +1,73 @@
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef TRT_SCALED_DOT_PRODUCT_ATTENTION_HPP
#define TRT_SCALED_DOT_PRODUCT_ATTENTION_HPP
#include <cublas_v2.h>

#include <memory>
#include <string>
#include <vector>

#include "trt_plugin_base.hpp"

namespace mmdeploy {
class ScaledDotProductAttentionTRT : public TRTPluginBase {
 public:
  ScaledDotProductAttentionTRT(const std::string &name);

  ScaledDotProductAttentionTRT(const std::string name, const void *data, size_t length);

  ScaledDotProductAttentionTRT() = delete;

  ~ScaledDotProductAttentionTRT() TRT_NOEXCEPT override;

  virtual void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
                               const nvinfer1::DynamicPluginTensorDesc *out,
                               int nbOutputs) TRT_NOEXCEPT override;
  // IPluginV2DynamicExt Methods
  nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override;
  nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs,
                                          int nbInputs, nvinfer1::IExprBuilder &exprBuilder)
      TRT_NOEXCEPT override;
  bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs,
                                 int nbOutputs) TRT_NOEXCEPT override;
  int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
              const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
              void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override;

  // IPluginV2Ext Methods
  nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes,
                                       int nbInputs) const TRT_NOEXCEPT override;

  // IPluginV2 Methods
  const char *getPluginType() const TRT_NOEXCEPT override;
  const char *getPluginVersion() const TRT_NOEXCEPT override;
  int getNbOutputs() const TRT_NOEXCEPT override;
  size_t getSerializationSize() const TRT_NOEXCEPT override;
  void serialize(void *buffer) const TRT_NOEXCEPT override;
  void attachToContext(cudnnContext *cudnn, cublasContext *cublas,
                       nvinfer1::IGpuAllocator *allocator) TRT_NOEXCEPT override;
  void detachFromContext() TRT_NOEXCEPT override;

 private:
  int mask_dim;
  cublasHandle_t _cublas_handle{};
  cudnnHandle_t _cudnn_handle{};
  cudnnTensorDescriptor_t _x_desc{}, _y_desc{}, _mask_desc{};
};

class ScaledDotProductAttentionTRTCreator : public TRTPluginCreatorBase {
 public:
  ScaledDotProductAttentionTRTCreator();

  const char *getPluginName() const TRT_NOEXCEPT override;

  const char *getPluginVersion() const TRT_NOEXCEPT override;

  nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc)
      TRT_NOEXCEPT override;

  nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData,
                                         size_t serialLength) TRT_NOEXCEPT override;
};
}  // namespace mmdeploy
#endif  // TRT_SCALED_DOT_PRODUCT_ATTENTION_HPP
@@ -0,0 +1,103 @@
// Copyright (c) OpenMMLab. All rights reserved
#include <thrust/functional.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/transform.h>

#include <cmath>
#include <vector>

#include "common_cuda_helper.hpp"
#include "scaled_dot_product_attention_kernel.hpp"
#include "trt_plugin_helper.hpp"

template <typename scalar_t>
cublasStatus_t cublasgemmStridedBatchedWrap(cublasHandle_t handle, cublasOperation_t transa,
                                            cublasOperation_t transb, int m, int n, int k,
                                            const scalar_t* alpha, const scalar_t* A, int lda,
                                            long long int strideA, const scalar_t* B, int ldb,
                                            long long int strideB, const scalar_t* beta,
                                            scalar_t* C, int ldc, long long int strideC,
                                            int batchCount);

template <>
cublasStatus_t cublasgemmStridedBatchedWrap<float>(cublasHandle_t handle, cublasOperation_t transa,
                                                   cublasOperation_t transb, int m, int n, int k,
                                                   const float* alpha, const float* A, int lda,
                                                   long long int strideA, const float* B, int ldb,
                                                   long long int strideB, const float* beta,
                                                   float* C, int ldc, long long int strideC,
                                                   int batchCount) {
  return cublasSgemmStridedBatched(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb,
                                   strideB, beta, C, ldc, strideC, batchCount);
}

template <>
cublasStatus_t cublasgemmStridedBatchedWrap<__half>(cublasHandle_t handle, cublasOperation_t transa,
                                                    cublasOperation_t transb, int m, int n, int k,
                                                    const __half* alpha, const __half* A, int lda,
                                                    long long int strideA, const __half* B, int ldb,
                                                    long long int strideB, const __half* beta,
                                                    __half* C, int ldc, long long int strideC,
                                                    int batchCount) {
  return cublasHgemmStridedBatched(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb,
                                   strideB, beta, C, ldc, strideC, batchCount);
}

template <typename scalar_t>
void dot_product_attention_impl(const scalar_t* query, const scalar_t* key, const scalar_t* value,
                                const scalar_t* mask, scalar_t* attn, scalar_t* output, int B,
                                int Nt, int Ns, int E, const int* mask_dims,
                                cudnnTensorDescriptor_t& x_desc, cudnnTensorDescriptor_t& y_desc,
                                cudnnTensorDescriptor_t& mask_desc, cudnnDataType_t cudnn_dtype,
                                cudaStream_t stream, cublasHandle_t cublas_handle,
                                cudnnHandle_t cudnn_handle) {
  {
    // attn = (Q @ K^T) / sqrt(E), row-major shape [B, Nt, Ns].
    // cuBLAS is column-major, so the same product is expressed below as
    // K * Q^T with m = Ns, n = Nt, k = E.
    const int m = Ns;
    const int n = Nt;
    const int k = E;
    const auto alpha = scalar_t(1.0f / sqrt(float(E)));
    const auto beta = scalar_t(0);
    cublasgemmStridedBatchedWrap(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, m, n, k, &alpha, key, k,
                                 Ns * E, query, k, Nt * E, &beta, attn, m, Nt * Ns, B);
  }

  if (mask_dims != nullptr && mask_dims[0] != 0) {
    // attn += mask, broadcast over the batch dimension when the mask is 2-D.
    const auto alpha = scalar_t(1);
    const auto beta = scalar_t(1);
    cudnnSetTensor4dDescriptor(mask_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, mask_dims[0],
                               mask_dims[1], mask_dims[2]);
    cudnnSetTensor4dDescriptor(x_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, B, Nt, Ns);
    cudnnAddTensor(cudnn_handle, &alpha, mask_desc, mask, &beta, x_desc, attn);
  }

  {
    // softmax over the last dimension: each of the B * Nt rows of length Ns.
    const auto alpha = scalar_t(1);
    const auto beta = scalar_t(0);
    cudnnSetTensor4dDescriptor(x_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, B * Nt, Ns, 1, 1);
    cudnnSetTensor4dDescriptor(y_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, B * Nt, Ns, 1, 1);
    cudnnSoftmaxForward(cudnn_handle, CUDNN_SOFTMAX_ACCURATE, CUDNN_SOFTMAX_MODE_INSTANCE, &alpha,
                        x_desc, attn, &beta, y_desc, attn);
  }

  {
    // output = attn @ V, row-major shape [B, Nt, E]
    // (in column-major form: output^T = V^T * attn^T).
    const int m = E;
    const int n = Nt;
    const int k = Ns;
    const auto alpha = scalar_t(1);
    const auto beta = scalar_t(0);
    cublasgemmStridedBatchedWrap(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha, value, m,
                                 Ns * E, (const scalar_t*)(attn), k, Ns * Nt, &beta, output, m,
                                 Nt * E, B);
  }
}

template void dot_product_attention_impl<float>(
    const float* query, const float* key, const float* value, const float* mask, float* attn,
    float* output, int B, int Nt, int Ns, int E, const int* mask_dims,
    cudnnTensorDescriptor_t& x_desc, cudnnTensorDescriptor_t& y_desc,
    cudnnTensorDescriptor_t& mask_desc, cudnnDataType_t cudnn_dtype, cudaStream_t stream,
    cublasHandle_t cublas_handle, cudnnHandle_t cudnn_handle);
@@ -0,0 +1,17 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef TRT_SCALED_DOT_PRODUCT_ATTENTION_KERNEL_HPP
#define TRT_SCALED_DOT_PRODUCT_ATTENTION_KERNEL_HPP
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cudnn.h>

template <typename scalar_t>
void dot_product_attention_impl(const scalar_t* query, const scalar_t* key, const scalar_t* value,
                                const scalar_t* mask, scalar_t* attn, scalar_t* output, int B,
                                int Nt, int Ns, int E, const int* mask_dims,
                                cudnnTensorDescriptor_t& x_desc, cudnnTensorDescriptor_t& y_desc,
                                cudnnTensorDescriptor_t& mask_desc, cudnnDataType_t cudnn_dtype,
                                cudaStream_t stream, cublasHandle_t cublas_handle,
                                cudnnHandle_t cudnn_handle);

#endif
@@ -57,6 +57,12 @@
  - [Inputs](#inputs-8)
  - [Outputs](#outputs-8)
  - [Type Constraints](#type-constraints-8)
- [ScaledDotProductAttentionTRT](#scaleddotproductattentiontrt)
  - [Description](#description-9)
  - [Parameters](#parameters-9)
  - [Inputs](#inputs-9)
  - [Outputs](#outputs-9)
  - [Type Constraints](#type-constraints-9)

<!-- TOC -->
@@ -405,3 +411,39 @@ Generate the anchors for object detection task.

- T:tensor(float32, Linear)
- TAny: Any

### ScaledDotProductAttentionTRT

#### Description

Dot product attention, used to support multi-head attention. Read [Attention Is All You Need](https://arxiv.org/abs/1706.03762?context=cs) for more detail.

#### Parameters

None

#### Inputs

<dl>
<dt><tt>inputs[0]</tt>: T</dt>
<dd>query; 3-D tensor with shape [batch_size, sequence_length, embedding_size].</dd>
<dt><tt>inputs[1]</tt>: T</dt>
<dd>key; 3-D tensor with shape [batch_size, sequence_length, embedding_size].</dd>
<dt><tt>inputs[2]</tt>: T</dt>
<dd>value; 3-D tensor with shape [batch_size, sequence_length, embedding_size].</dd>
<dt><tt>inputs[3]</tt>: T</dt>
<dd>mask; 2-D/3-D tensor with shape [sequence_length, sequence_length] or [batch_size, sequence_length, sequence_length]. Optional.</dd>
</dl>

#### Outputs

<dl>
<dt><tt>outputs[0]</tt>: T</dt>
<dd>3-D tensor of shape [batch_size, sequence_length, embedding_size]: `softmax(q@k.T)@v`.</dd>
<dt><tt>outputs[1]</tt>: T</dt>
<dd>3-D tensor of shape [batch_size, sequence_length, sequence_length]: `softmax(q@k.T)`.</dd>
</dl>

#### Type Constraints

- T:tensor(float32, Linear)
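For reference, the op computes the equivalent of the following PyTorch snippet (a minimal sketch: per the plugin kernel, queries are scaled by `1/sqrt(embedding_size)` before the matmul, and the function name here is illustrative):

```python
import math

import torch


def scaled_dot_product_attention_reference(q, k, v, mask=None):
    # q: [batch_size, Nt, E]; k, v: [batch_size, Ns, E]
    # mask: [Nt, Ns] or [batch_size, Nt, Ns], added to the attention logits
    attn = torch.bmm(q / math.sqrt(q.shape[-1]), k.transpose(-2, -1))  # [batch_size, Nt, Ns]
    if mask is not None:
        attn = attn + mask       # broadcast over batch when the mask is 2-D
    attn = attn.softmax(-1)      # outputs[1]
    output = torch.bmm(attn, v)  # outputs[0]
    return output, attn
```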
@@ -2,7 +2,7 @@

<!-- TOC -->

- [TensorRT Ops](#tensorrt-ops)
- [TRT 自定义算子](#trt-自定义算子)
- [TRTBatchedNMS](#trtbatchednms)
  - [Description](#description)
  - [Parameters](#parameters)
@@ -57,6 +57,12 @@
  - [Inputs](#inputs-8)
  - [Outputs](#outputs-8)
  - [Type Constraints](#type-constraints-8)
- [ScaledDotProductAttentionTRT](#scaleddotproductattentiontrt)
  - [Description](#description-9)
  - [Parameters](#parameters-9)
  - [Inputs](#inputs-9)
  - [Outputs](#outputs-9)
  - [Type Constraints](#type-constraints-9)

<!-- TOC -->
@@ -405,3 +411,39 @@ Generate the anchors for object detection task.

- T:tensor(float32, Linear)
- TAny: Any

### ScaledDotProductAttentionTRT

#### Description

Dot product attention, used to support multi-head attention. Read [Attention Is All You Need](https://arxiv.org/abs/1706.03762?context=cs) for more detail.

#### Parameters

None

#### Inputs

<dl>
<dt><tt>inputs[0]</tt>: T</dt>
<dd>query; 3-D tensor with shape [batch_size, sequence_length, embedding_size].</dd>
<dt><tt>inputs[1]</tt>: T</dt>
<dd>key; 3-D tensor with shape [batch_size, sequence_length, embedding_size].</dd>
<dt><tt>inputs[2]</tt>: T</dt>
<dd>value; 3-D tensor with shape [batch_size, sequence_length, embedding_size].</dd>
<dt><tt>inputs[3]</tt>: T</dt>
<dd>mask; 2-D/3-D tensor with shape [sequence_length, sequence_length] or [batch_size, sequence_length, sequence_length]. Optional.</dd>
</dl>

#### Outputs

<dl>
<dt><tt>outputs[0]</tt>: T</dt>
<dd>3-D tensor of shape [batch_size, sequence_length, embedding_size]: `softmax(q@k.T)@v`.</dd>
<dt><tt>outputs[1]</tt>: T</dt>
<dd>3-D tensor of shape [batch_size, sequence_length, sequence_length]: `softmax(q@k.T)`.</dd>
</dl>

#### Type Constraints

- T:tensor(float32, Linear)
@@ -9,10 +9,43 @@ from mmdeploy.core import FUNCTION_REWRITER
 from mmdeploy.utils.constants import Backend
 
 
+class ScaledDotProductAttentionTRT(torch.autograd.Function):
+    """Caller of scaled dot product attention."""
+
+    @staticmethod
+    def forward(ctx,
+                q: Tensor,
+                k: Tensor,
+                v: Tensor,
+                attn_mask: Optional[Tensor] = None):
+        """forward function."""
+        B, Nt, E = q.shape
+        q = q / math.sqrt(E)
+        # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
+        attn = torch.bmm(q, k.transpose(-2, -1))
+        if attn_mask is not None:
+            attn += attn_mask
+
+        attn = attn.softmax(-1)
+
+        # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
+        output = torch.bmm(attn, v)
+        return output, attn
+
+    @staticmethod
+    def symbolic(g, q, k, v, mask):
+        """Symbolic function."""
+        inputs = [q, k, v]
+        if mask is not None:
+            inputs += [mask]
+        return g.op(
+            'mmdeploy::ScaledDotProductAttentionTRT', *inputs, outputs=2)
+
+
 @FUNCTION_REWRITER.register_rewriter(
     func_name='torch.nn.functional._scaled_dot_product_attention',
     backend=Backend.TENSORRT.value)
-def _scaled_dot_product_attention__default(
+def _scaled_dot_product_attention__tensorrt(
     ctx,
     q: Tensor,
     k: Tensor,
@@ -20,35 +53,5 @@ def _scaled_dot_product_attention__default(
     attn_mask: Optional[Tensor] = None,
     dropout_p: float = 0.0,
 ) -> Tuple[Tensor, Tensor]:
-    """Rewrite `_scaled_dot_product_attention` to enable softmax."""
-    B, Nt, E = q.shape
-    q = q / math.sqrt(E)
-    # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
-    attn = torch.bmm(q, k.transpose(-2, -1))
-    if attn_mask is not None:
-        attn += attn_mask
-    # add slice to enable softmax
-    # TODO: Find the reason
-    step = 500
-    if attn.size(-1) > step:
-        attn_max = attn[..., :step].max(-1, keepdim=True)[0]
-        for i in range(step, attn.size(-1), step):
-            attn_max_new = attn[..., i:i + step].max(-1, keepdim=True)[0]
-            attn_max = attn_max.where(attn_max > attn_max_new, attn_max_new)
-    else:
-        attn_max = attn.max(-1, keepdim=True)[0]
-
-    attn = attn - attn_max
-    attn_exp = attn.exp()
-    if attn_exp.size(-1) > step:
-        attn_sum = attn_exp[..., :step].sum(-1, keepdim=True)
-        for i in range(step, attn_exp.size(-1), step):
-            attn_sum_new = attn_exp[..., i:i + step].sum(-1, keepdim=True)
-            attn_sum += attn_sum_new
-    else:
-        attn_sum = attn_exp.sum(-1, keepdim=True)
-    attn = attn_exp / attn_sum
-
-    # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
-    output = torch.bmm(attn, v)
-    return output, attn
+    """Rewrite for custom ops."""
+    return ScaledDotProductAttentionTRT.apply(q, k, v, attn_mask)
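For context, a minimal export sketch, assuming mmcv's `Config` and mmdeploy's `RewriterContext` are available as below (the output filename and input shapes are illustrative; the unit test that follows exercises the same path through `backend.run_and_validate`):

```python
import torch
from mmcv import Config

from mmdeploy.core import RewriterContext
from mmdeploy.utils.constants import Backend

model = torch.nn.MultiheadAttention(embed_dim=2, num_heads=2).cuda()
# MultiheadAttention expects (seq_len, batch, embed_dim) inputs by default.
q = k = v = torch.rand(4, 2, 2).cuda()

cfg = Config({'backend_config': {'type': Backend.TENSORRT.value}})
# Inside the context, F._scaled_dot_product_attention is rewritten so that the
# exported graph contains the mmdeploy::ScaledDotProductAttentionTRT op.
with RewriterContext(cfg, backend=Backend.TENSORRT.value, opset=11):
    torch.onnx.export(model, (q, k, v), 'attention.onnx', opset_version=11)
```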
@@ -1096,3 +1096,31 @@ def test_trt_grid_priors(backend, strides, input_list=None, save_dir=None):
                3: 'w'
            }),
        save_dir=save_dir)


@pytest.mark.parametrize('backend', [TEST_TENSORRT])
def test_dot_product_attention(backend, save_dir=None):
    backend.check_env()

    B = 2
    Nt = 4
    Ns = 4
    E = 2
    query = torch.rand(B, Nt, E).cuda()
    key = torch.rand(B, Ns, E).cuda()
    value = torch.rand(B, Ns, E).cuda()

    model = torch.nn.MultiheadAttention(E, 2).cuda()

    with RewriterContext(
            Config({'backend_config': {
                'type': backend.backend_name
            }}),
            backend=backend.backend_name,
            opset=11):
        backend.run_and_validate(
            model, [query, key, value],
            'dot_product_attention',
            input_names=['query', 'key', 'value'],
            output_names=['out', 'attn'],
            save_dir=save_dir)