TensorRT dot product attention ops (#949)

* add detr support

* fix softmax

* add placeholder

* add implement

* add docs and ut

* update testcase

* update docs

* update docs
q.yao 2022-09-05 18:25:39 +08:00 committed by GitHub
parent e21cad84e0
commit 9541be9c0b
8 changed files with 525 additions and 34 deletions

@@ -0,0 +1,183 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "scaled_dot_product_attention.hpp"
#include <assert.h>
#include <chrono>
#include "scaled_dot_product_attention_kernel.hpp"
#include "trt_serialize.hpp"
using namespace nvinfer1;
namespace mmdeploy {
namespace {
static const char *PLUGIN_VERSION{"1"};
static const char *PLUGIN_NAME{"ScaledDotProductAttentionTRT"};
} // namespace
ScaledDotProductAttentionTRT::ScaledDotProductAttentionTRT(const std::string &name)
: TRTPluginBase(name), mask_dim(0) {}
ScaledDotProductAttentionTRT::ScaledDotProductAttentionTRT(const std::string name, const void *data,
size_t length)
: TRTPluginBase(name), mask_dim(0) {}
ScaledDotProductAttentionTRT::~ScaledDotProductAttentionTRT() {}
nvinfer1::IPluginV2DynamicExt *ScaledDotProductAttentionTRT::clone() const TRT_NOEXCEPT {
ScaledDotProductAttentionTRT *plugin = new ScaledDotProductAttentionTRT(mLayerName);
plugin->setPluginNamespace(getPluginNamespace());
return plugin;
}
nvinfer1::DimsExprs ScaledDotProductAttentionTRT::getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT {
// output 0 shares the shape of the query
if (outputIndex == 0) return inputs[0];
// output 1 holds the attention weights with shape (batch * heads, Nt, Ns)
nvinfer1::DimsExprs ret;
ret.nbDims = 3;
ret.d[0] = inputs[0].d[0];
ret.d[1] = inputs[0].d[1];
ret.d[2] = inputs[1].d[1];
return ret;
}
bool ScaledDotProductAttentionTRT::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT {
if (pos == 0) {
return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT &&
ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR);
} else {
return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format;
}
}
// Attach the plugin object to an execution context and grant the plugin the
// access to some context resource.
void ScaledDotProductAttentionTRT::attachToContext(cudnnContext *cudnnContext,
cublasContext *cublasContext,
IGpuAllocator *gpuAllocator) TRT_NOEXCEPT {
_cublas_handle = cublasContext;
_cudnn_handle = cudnnContext;
cudnnCreateTensorDescriptor(&_x_desc);
cudnnCreateTensorDescriptor(&_y_desc);
cudnnCreateTensorDescriptor(&_mask_desc);
}
// Detach the plugin object from its execution context.
void ScaledDotProductAttentionTRT::detachFromContext() TRT_NOEXCEPT {
cudnnDestroyTensorDescriptor(_y_desc);
cudnnDestroyTensorDescriptor(_x_desc);
cudnnDestroyTensorDescriptor(_mask_desc);
}
void ScaledDotProductAttentionTRT::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *out,
int nbOutputs) TRT_NOEXCEPT {
if (nbInputs != 4) {
mask_dim = 0;
} else {
mask_dim = in[3].desc.dims.nbDims;
}
}
int ScaledDotProductAttentionTRT::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs,
void *workSpace, cudaStream_t stream) TRT_NOEXCEPT {
if (CUDNN_STATUS_SUCCESS != cudnnSetStream(_cudnn_handle, stream)) return 1;
if (CUBLAS_STATUS_SUCCESS != cublasSetStream(_cublas_handle, stream)) return 1;
int B = inputDesc[0].dims.d[0]; // batch * heads
int Nt = inputDesc[0].dims.d[1];
int Ns = inputDesc[1].dims.d[1];
int E = inputDesc[0].dims.d[2]; // embedding size
const void *query = inputs[0];
const void *key = inputs[1];
const void *value = inputs[2];
const void *mask = nullptr;
int mask_dims[3];
mask_dims[0] = 0;  // mask_dims[0] == 0 marks "no mask provided"
if (mask_dim > 0) {
mask = inputs[3];
// check whether the mask needs to be broadcast over the batch dimension
if (mask_dim == 2) {
mask_dims[0] = 1;
mask_dims[1] = inputDesc[3].dims.d[0];
mask_dims[2] = inputDesc[3].dims.d[1];
} else {
mask_dims[0] = inputDesc[3].dims.d[0];
mask_dims[1] = inputDesc[3].dims.d[1];
mask_dims[2] = inputDesc[3].dims.d[2];
}
}
void *output = outputs[0];
void *attn = outputs[1];
auto data_type = inputDesc[0].type;
cudnnDataType_t cudnn_dtype{};
convert_trt2cudnn_dtype(data_type, &cudnn_dtype);
switch (data_type) {
case nvinfer1::DataType::kFLOAT:
dot_product_attention_impl<float>((float *)query, (float *)key, (float *)value, (float *)mask,
(float *)attn, (float *)output, B, Nt, Ns, E, &mask_dims[0],
_x_desc, _y_desc, _mask_desc, cudnn_dtype, stream,
_cublas_handle, _cudnn_handle);
break;
default:
return 1;
}
return 0;
}
nvinfer1::DataType ScaledDotProductAttentionTRT::getOutputDataType(
int index, const nvinfer1::DataType *inputTypes, int nbInputs) const TRT_NOEXCEPT {
return inputTypes[0];
}
// IPluginV2 Methods
const char *ScaledDotProductAttentionTRT::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; }
const char *ScaledDotProductAttentionTRT::getPluginVersion() const TRT_NOEXCEPT {
return PLUGIN_VERSION;
}
int ScaledDotProductAttentionTRT::getNbOutputs() const TRT_NOEXCEPT { return 2; }
size_t ScaledDotProductAttentionTRT::getSerializationSize() const TRT_NOEXCEPT { return 0; }
void ScaledDotProductAttentionTRT::serialize(void *buffer) const TRT_NOEXCEPT {}
////////////////////// creator /////////////////////////////
ScaledDotProductAttentionTRTCreator::ScaledDotProductAttentionTRTCreator() {}
const char *ScaledDotProductAttentionTRTCreator::getPluginName() const TRT_NOEXCEPT {
return PLUGIN_NAME;
}
const char *ScaledDotProductAttentionTRTCreator::getPluginVersion() const TRT_NOEXCEPT {
return PLUGIN_VERSION;
}
nvinfer1::IPluginV2 *ScaledDotProductAttentionTRTCreator::createPlugin(
const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT {
ScaledDotProductAttentionTRT *plugin = new ScaledDotProductAttentionTRT(name);
plugin->setPluginNamespace(getPluginNamespace());
return plugin;
}
nvinfer1::IPluginV2 *ScaledDotProductAttentionTRTCreator::deserializePlugin(
const char *name, const void *serialData, size_t serialLength) TRT_NOEXCEPT {
auto plugin = new ScaledDotProductAttentionTRT(name, serialData, serialLength);
plugin->setPluginNamespace(getPluginNamespace());
return plugin;
}
REGISTER_TENSORRT_PLUGIN(ScaledDotProductAttentionTRTCreator);
} // namespace mmdeploy
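For reference, the mask handling in `enqueue` follows the usual additive-mask broadcasting: a 2-D mask of shape (Nt, Ns) is treated as (1, Nt, Ns) so that `cudnnAddTensor` broadcasts it across the batch * heads dimension. A minimal PyTorch sketch of the equivalent computation (example sizes are assumed, not taken from the plugin):

```python
import math
import torch

B, Nt, Ns, E = 6, 4, 4, 8           # B = batch * heads (assumed example sizes)
q, k, v = (torch.rand(B, n, E) for n in (Nt, Ns, Ns))
mask = torch.zeros(Nt, Ns)          # 2-D mask, broadcast over B like the plugin does

attn = torch.bmm(q / math.sqrt(E), k.transpose(-2, -1))  # (B, Nt, Ns)
attn = attn + mask                  # (Nt, Ns) broadcasts to (B, Nt, Ns)
attn = attn.softmax(-1)             # outputs[1] of the plugin
out = torch.bmm(attn, v)            # (B, Nt, E), outputs[0] of the plugin
```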

@@ -0,0 +1,73 @@
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef TRT_SCALED_DOT_PRODUCT_ATTENTION_HPP
#define TRT_SCALED_DOT_PRODUCT_ATTENTION_HPP
#include <cublas_v2.h>
#include <memory>
#include <string>
#include <vector>
#include "trt_plugin_base.hpp"
namespace mmdeploy {
class ScaledDotProductAttentionTRT : public TRTPluginBase {
public:
ScaledDotProductAttentionTRT(const std::string &name);
ScaledDotProductAttentionTRT(const std::string name, const void *data, size_t length);
ScaledDotProductAttentionTRT() = delete;
~ScaledDotProductAttentionTRT() TRT_NOEXCEPT override;
virtual void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *out,
int nbOutputs) TRT_NOEXCEPT override;
// IPluginV2DynamicExt Methods
nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override;
nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs,
int nbInputs, nvinfer1::IExprBuilder &exprBuilder)
TRT_NOEXCEPT override;
bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs,
int nbOutputs) TRT_NOEXCEPT override;
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override;
// IPluginV2Ext Methods
nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes,
int nbInputs) const TRT_NOEXCEPT override;
// IPluginV2 Methods
const char *getPluginType() const TRT_NOEXCEPT override;
const char *getPluginVersion() const TRT_NOEXCEPT override;
int getNbOutputs() const TRT_NOEXCEPT override;
size_t getSerializationSize() const TRT_NOEXCEPT override;
void serialize(void *buffer) const TRT_NOEXCEPT override;
void attachToContext(cudnnContext *cudnn, cublasContext *cublas,
nvinfer1::IGpuAllocator *allocator) TRT_NOEXCEPT override;
void detachFromContext() TRT_NOEXCEPT override;
private:
int mask_dim;
cublasHandle_t _cublas_handle{};
cudnnHandle_t _cudnn_handle{};
cudnnTensorDescriptor_t _x_desc{}, _y_desc{}, _mask_desc{};
};
class ScaledDotProductAttentionTRTCreator : public TRTPluginCreatorBase {
public:
ScaledDotProductAttentionTRTCreator();
const char *getPluginName() const TRT_NOEXCEPT override;
const char *getPluginVersion() const TRT_NOEXCEPT override;
nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc)
TRT_NOEXCEPT override;
nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData,
size_t serialLength) TRT_NOEXCEPT override;
};
} // namespace mmdeploy
#endif // TRT_SCALED_DOT_PRODUCT_ATTENTION_HPP
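`REGISTER_TENSORRT_PLUGIN` runs when the ops library is loaded, so the creator becomes visible in TensorRT's plugin registry under the plugin name and version defined in the source above. A rough Python sketch of checking that (the library path is illustrative, and the plugin namespace may differ from the default assumed here):

```python
import ctypes
import tensorrt as trt

# Loading the compiled ops library executes REGISTER_TENSORRT_PLUGIN
# (the library name/path below is illustrative).
ctypes.CDLL('libmmdeploy_tensorrt_ops.so')

registry = trt.get_plugin_registry()
creator = registry.get_plugin_creator('ScaledDotProductAttentionTRT', '1')
print('plugin found:', creator is not None)
```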

@@ -0,0 +1,103 @@
// Copyright (c) OpenMMLab. All rights reserved
#include <thrust/functional.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/transform.h>
#include <cmath>
#include <vector>
#include "common_cuda_helper.hpp"
#include "scaled_dot_product_attention_kernel.hpp"
#include "trt_plugin_helper.hpp"
template <typename scalar_t>
cublasStatus_t cublasgemmStridedBatchedWrap(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n, int k,
const scalar_t* alpha, const scalar_t* A, int lda,
long long int strideA, const scalar_t* B, int ldb,
long long int strideB, const scalar_t* beta,
scalar_t* C, int ldc, long long int strideC,
int batchCount);
template <>
cublasStatus_t cublasgemmStridedBatchedWrap<float>(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n, int k,
const float* alpha, const float* A, int lda,
long long int strideA, const float* B, int ldb,
long long int strideB, const float* beta,
float* C, int ldc, long long int strideC,
int batchCount) {
return cublasSgemmStridedBatched(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb,
strideB, beta, C, ldc, strideC, batchCount);
}
template <>
cublasStatus_t cublasgemmStridedBatchedWrap<__half>(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n, int k,
const __half* alpha, const __half* A, int lda,
long long int strideA, const __half* B, int ldb,
long long int strideB, const __half* beta,
__half* C, int ldc, long long int strideC,
int batchCount) {
return cublasHgemmStridedBatched(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb,
strideB, beta, C, ldc, strideC, batchCount);
}
template <typename scalar_t>
void dot_product_attention_impl(const scalar_t* query, const scalar_t* key, const scalar_t* value,
const scalar_t* mask, scalar_t* attn, scalar_t* output, int B,
int Nt, int Ns, int E, const int* mask_dims,
cudnnTensorDescriptor_t& x_desc, cudnnTensorDescriptor_t& y_desc,
cudnnTensorDescriptor_t& mask_desc, cudnnDataType_t cudnn_dtype,
cudaStream_t stream, cublasHandle_t cublas_handle,
cudnnHandle_t cudnn_handle) {
{
// Q @ K
const int m = Ns;
const int n = Nt;
const int k = E;
const auto alpha = scalar_t(1.0f / sqrt(float(E)));
const auto beta = scalar_t(0);
cublasgemmStridedBatchedWrap(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, m, n, k, &alpha, key, k,
Ns * E, query, k, Nt * E, &beta, attn, m, Nt * Ns, B);
}
if (mask_dims != nullptr && mask_dims[0] != 0) {
const auto alpha = scalar_t(1);
const auto beta = scalar_t(1);
cudnnSetTensor4dDescriptor(mask_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, mask_dims[0],
mask_dims[1], mask_dims[2]);
cudnnSetTensor4dDescriptor(x_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, B, Nt, Ns);
cudnnAddTensor(cudnn_handle, &alpha, mask_desc, mask, &beta, x_desc, attn);
}
{
// softmax attention
const auto alpha = scalar_t(1);
const auto beta = scalar_t(0);
cudnnSetTensor4dDescriptor(x_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, B * Nt, Ns, 1, 1);
cudnnSetTensor4dDescriptor(y_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, B * Nt, Ns, 1, 1);
cudnnSoftmaxForward(cudnn_handle, CUDNN_SOFTMAX_ACCURATE, CUDNN_SOFTMAX_MODE_INSTANCE, &alpha,
x_desc, attn, &beta, y_desc, attn);
}
{
// attn @ v
const int m = E;
const int n = Nt;
const int k = Ns;
const auto alpha = scalar_t(1);
const auto beta = scalar_t(0);
cublasgemmStridedBatchedWrap(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha, value, m,
Ns * E, (const scalar_t*)(attn), k, Ns * Nt, &beta, output, m,
Nt * E, B);
}
}
template void dot_product_attention_impl<float>(
const float* query, const float* key, const float* value, const float* mask, float* attn,
float* output, int B, int Nt, int Ns, int E, const int* mask_dims,
cudnnTensorDescriptor_t& x_desc, cudnnTensorDescriptor_t& y_desc,
cudnnTensorDescriptor_t& mask_desc, cudnnDataType_t cudnn_dtype, cudaStream_t stream,
cublasHandle_t cublas_handle, cudnnHandle_t cudnn_handle);
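The cuBLAS calls above use the standard column-major trick: a row-major product C = A @ B is obtained by computing the column-major product C^T = B^T @ A^T, which is why `key` is the first operand (with `CUBLAS_OP_T`) when forming the scaled logits and `value` is the first operand in the final projection. A NumPy sketch of what the three stages compute (example sizes are assumed):

```python
import numpy as np

B, Nt, Ns, E = 2, 4, 5, 8                        # assumed example sizes
q = np.random.rand(B, Nt, E).astype(np.float32)
k = np.random.rand(B, Ns, E).astype(np.float32)
v = np.random.rand(B, Ns, E).astype(np.float32)

# stage 1: attn = (Q / sqrt(E)) @ K^T  -> (B, Nt, Ns)
attn = (q / np.sqrt(E)) @ k.transpose(0, 2, 1)
# stage 2: softmax over the last axis (cudnnSoftmaxForward, instance mode)
attn = np.exp(attn - attn.max(-1, keepdims=True))
attn /= attn.sum(-1, keepdims=True)
# stage 3: output = attn @ V           -> (B, Nt, E)
output = attn @ v
```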

@@ -0,0 +1,17 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef TRT_SCALED_DOT_PRODUCT_ATTENTION_KERNEL_HPP
#define TRT_SCALED_DOT_PRODUCT_ATTENTION_KERNEL_HPP
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cudnn.h>
template <typename scalar_t>
void dot_product_attention_impl(const scalar_t* query, const scalar_t* key, const scalar_t* value,
const scalar_t* mask, scalar_t* attn, scalar_t* output, int B,
int Nt, int Ns, int E, const int* mask_dims,
cudnnTensorDescriptor_t& x_desc, cudnnTensorDescriptor_t& y_desc,
cudnnTensorDescriptor_t& mask_desc, cudnnDataType_t cudnn_dtype,
cudaStream_t stream, cublasHandle_t cublas_handle,
cudnnHandle_t cudnn_handle);
#endif

@@ -57,6 +57,12 @@
- [Inputs](#inputs-8)
- [Outputs](#outputs-8)
- [Type Constraints](#type-constraints-8)
- [ScaledDotProductAttentionTRT](#scaleddotproductattentiontrt)
- [Description](#description-9)
- [Parameters](#parameters-9)
- [Inputs](#inputs-9)
- [Outputs](#outputs-9)
- [Type Constraints](#type-constraints-9)
<!-- TOC -->
@@ -405,3 +411,39 @@ Generate the anchors for object detection task.
- T:tensor(float32, Linear)
- TAny: Any
### ScaledDotProductAttentionTRT
#### Description
Dot-product attention used to support multi-head attention. Read [Attention Is All You Need](https://arxiv.org/abs/1706.03762?context=cs) for more details.
#### Parameters
None
#### Inputs
<dl>
<dt><tt>inputs[0]</tt>: T</dt>
<dd>query; 3-D tensor with shape [batch_size, sequence_length, embedding_size].</dd>
<dt><tt>inputs[1]</tt>: T</dt>
<dd>key; 3-D tensor with shape [batch_size, sequence_length, embedding_size].</dd>
<dt><tt>inputs[2]</tt>: T</dt>
<dd>value; 3-D tensor with shape [batch_size, sequence_length, embedding_size].</dd>
<dt><tt>inputs[3]</tt>: T</dt>
<dd>mask; 2-D/3-D tensor with shape [sequence_length, sequence_length] or [batch_size, sequence_length, sequence_length]. optional.</dd>
</dl>
#### Outputs
<dl>
<dt><tt>outputs[0]</tt>: T</dt>
<dd>3-D tensor of shape [batch_size, sequence_length, embedding_size]. `softmax(q@k.T/sqrt(embedding_size))@v`</dd>
<dt><tt>outputs[1]</tt>: T</dt>
<dd>3-D tensor of shape [batch_size, sequence_length, sequence_length]. `softmax(q@k.T/sqrt(embedding_size))`</dd>
</dl>
#### Type Constraints
- T:tensor(float32, Linear)
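A hedged PyTorch reference of the op semantics described above (the scaling by `1/sqrt(embedding_size)` matches the CUDA kernel in this PR; the function and variable names are illustrative):

```python
import math
import torch

def scaled_dot_product_attention_ref(query, key, value, mask=None):
    """Reference of ScaledDotProductAttentionTRT, returning (outputs[0], outputs[1])."""
    embedding_size = query.shape[-1]
    attn = torch.bmm(query, key.transpose(-2, -1)) / math.sqrt(embedding_size)
    if mask is not None:
        attn = attn + mask            # a 2-D mask broadcasts over the batch dimension
    attn = attn.softmax(-1)           # outputs[1]
    output = torch.bmm(attn, value)   # outputs[0]
    return output, attn
```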

@@ -2,7 +2,7 @@
<!-- TOC -->
- [TensorRT Ops](#tensorrt-ops)
- [TRT Custom Ops](#trt-自定义算子)
- [TRTBatchedNMS](#trtbatchednms)
- [Description](#description)
- [Parameters](#parameters)
@@ -57,6 +57,12 @@
- [Inputs](#inputs-8)
- [Outputs](#outputs-8)
- [Type Constraints](#type-constraints-8)
- [ScaledDotProductAttentionTRT](#scaleddotproductattentiontrt)
- [Description](#description-9)
- [Parameters](#parameters-9)
- [Inputs](#inputs-9)
- [Outputs](#outputs-9)
- [Type Constraints](#type-constraints-9)
<!-- TOC -->
@@ -405,3 +411,39 @@ Generate the anchors for object detection task.
- T:tensor(float32, Linear)
- TAny: Any
### ScaledDotProductAttentionTRT
#### Description
Dot-product attention used to support multi-head attention. Read [Attention Is All You Need](https://arxiv.org/abs/1706.03762?context=cs) for more details.
#### Parameters
None
#### Inputs
<dl>
<dt><tt>inputs[0]</tt>: T</dt>
<dd>query; 3-D tensor with shape [batch_size, sequence_length, embedding_size].</dd>
<dt><tt>inputs[1]</tt>: T</dt>
<dd>key; 3-D tensor with shape [batch_size, sequence_length, embedding_size].</dd>
<dt><tt>inputs[2]</tt>: T</dt>
<dd>value; 3-D tensor with shape [batch_size, sequence_length, embedding_size].</dd>
<dt><tt>inputs[3]</tt>: T</dt>
<dd>mask; 2-D/3-D tensor with shape [sequence_length, sequence_length] or [batch_size, sequence_length, sequence_length]. optional.</dd>
</dl>
#### Outputs
<dl>
<dt><tt>outputs[0]</tt>: T</dt>
<dd>3-D tensor of shape [batch_size, sequence_length, embedding_size]. `softmax(q@k.T/sqrt(embedding_size))@v`</dd>
<dt><tt>outputs[1]</tt>: T</dt>
<dd>3-D tensor of shape [batch_size, sequence_length, sequence_length]. `softmax(q@k.T/sqrt(embedding_size))`</dd>
</dl>
#### Type Constraints
- T:tensor(float32, Linear)

@@ -9,10 +9,43 @@ from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils.constants import Backend
class ScaledDotProductAttentionTRT(torch.autograd.Function):
"""Caller of scale dot product attention."""
@staticmethod
def forward(ctx,
q: Tensor,
k: Tensor,
v: Tensor,
attn_mask: Optional[Tensor] = None):
"""forward function."""
B, Nt, E = q.shape
q = q / math.sqrt(E)
# (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
attn = torch.bmm(q, k.transpose(-2, -1))
if attn_mask is not None:
attn += attn_mask
attn = attn.softmax(-1)
# (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
output = torch.bmm(attn, v)
return output, attn
@staticmethod
def symbolic(g, q, k, v, mask):
"""Symbolic function."""
inputs = [q, k, v]
if mask is not None:
inputs += [mask]
return g.op(
'mmdeploy::ScaledDotProductAttentionTRT', *inputs, outputs=2)
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.nn.functional._scaled_dot_product_attention',
backend=Backend.TENSORRT.value)
def _scaled_dot_product_attention__default(
def _scaled_dot_product_attention__tensorrt(
ctx,
q: Tensor,
k: Tensor,
@@ -20,35 +53,5 @@ def _scaled_dot_product_attention__default(
attn_mask: Optional[Tensor] = None,
dropout_p: float = 0.0,
) -> Tuple[Tensor, Tensor]:
"""Rewrite `_scaled_dot_product_attention` to enable softmax."""
B, Nt, E = q.shape
q = q / math.sqrt(E)
# (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
attn = torch.bmm(q, k.transpose(-2, -1))
if attn_mask is not None:
attn += attn_mask
# add slice to enable softmax
# TODO: Find the reason
step = 500
if attn.size(-1) > step:
attn_max = attn[..., :step].max(-1, keepdim=True)[0]
for i in range(step, attn.size(-1), step):
attn_max_new = attn[..., i:i + step].max(-1, keepdim=True)[0]
attn_max = attn_max.where(attn_max > attn_max_new, attn_max_new)
else:
attn_max = attn.max(-1, keepdim=True)[0]
attn = attn - attn_max
attn_exp = attn.exp()
if attn_exp.size(-1) > step:
attn_sum = attn_exp[..., :step].sum(-1, keepdim=True)
for i in range(step, attn_exp.size(-1), step):
attn_sum_new = attn_exp[..., i:i + step].sum(-1, keepdim=True)
attn_sum += attn_sum_new
else:
attn_sum = attn_exp.sum(-1, keepdim=True)
attn = attn_exp / attn_sum
# (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
output = torch.bmm(attn, v)
return output, attn
"""Rewrite for custom ops."""
return ScaledDotProductAttentionTRT.apply(q, k, v, attn_mask)
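With this rewrite registered for the TensorRT backend, exporting a module that reaches `torch.nn.functional._scaled_dot_product_attention` (e.g. `torch.nn.MultiheadAttention`) inside a `RewriterContext` emits a single `mmdeploy::ScaledDotProductAttentionTRT` node instead of the decomposed softmax. A minimal export sketch (the config and file name are illustrative, mirroring the test below):

```python
import torch
from mmcv import Config
from mmdeploy.core import RewriterContext

model = torch.nn.MultiheadAttention(embed_dim=8, num_heads=2).eval()
q = k = v = torch.rand(4, 1, 8)  # (seq_len, batch, embed) layout of MultiheadAttention

with RewriterContext(
        Config({'backend_config': {'type': 'tensorrt'}}),
        backend='tensorrt',
        opset=11):
    torch.onnx.export(
        model, (q, k, v), 'dot_product_attention.onnx',
        input_names=['query', 'key', 'value'],
        output_names=['out', 'attn'],
        opset_version=11)
```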

@@ -1096,3 +1096,31 @@ def test_trt_grid_priors(backend, strides, input_list=None, save_dir=None):
3: 'w'
}),
save_dir=save_dir)
@pytest.mark.parametrize('backend', [TEST_TENSORRT])
def test_dot_product_attention(backend, save_dir=None):
backend.check_env()
B = 2
Nt = 4
Ns = 4
E = 2
query = torch.rand(B, Nt, E).cuda()
key = torch.rand(B, Ns, E).cuda()
value = torch.rand(B, Ns, E).cuda()
model = torch.nn.MultiheadAttention(E, 2).cuda()
with RewriterContext(
Config({'backend_config': {
'type': backend.backend_name
}}),
backend=backend.backend_name,
opset=11):
backend.run_and_validate(
model, [query, key, value],
'dot_product_attention',
input_names=['query', 'key', 'value'],
output_names=['out', 'attn'],
save_dir=save_dir)