add TRT8 support (#23)

This commit is contained in:
q.yao 2021-07-28 11:27:07 +08:00 committed by GitHub
parent 77080bd931
commit 6ff8e96e71
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 318 additions and 244 deletions

View File

@ -1,2 +1,2 @@
[settings] [settings]
known_third_party = mmcls,mmcv,mmdet,numpy,onnx,onnxruntime,pytest,setuptools,tensorrt,torch known_third_party = mmcls,mmcv,mmdet,numpy,onnx,onnxruntime,packaging,pytest,setuptools,tensorrt,torch

View File

@ -63,7 +63,7 @@ endforeach(PLUGIN_ITER)
list(APPEND BACKEND_OPS_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/common_impl/trt_cuda_helper.cu") list(APPEND BACKEND_OPS_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/common_impl/trt_cuda_helper.cu")
set(INFER_PLUGIN_LIB ${TENSORRT_LIBRARY}) set(INFER_PLUGIN_LIB ${TENSORRT_LIBRARY} cublas cudnn)
cuda_add_library(${SHARED_TARGET} MODULE ${BACKEND_OPS_SRCS}) cuda_add_library(${SHARED_TARGET} MODULE ${BACKEND_OPS_SRCS})
target_link_libraries(${SHARED_TARGET} ${INFER_PLUGIN_LIB}) target_link_libraries(${SHARED_TARGET} ${INFER_PLUGIN_LIB})

View File

@ -31,11 +31,11 @@ TRTBatchedNMS::TRTBatchedNMS(const std::string& name, const void* data,
deserialize_value(&data, &length, &mClipBoxes); deserialize_value(&data, &length, &mClipBoxes);
} }
int TRTBatchedNMS::getNbOutputs() const { return 2; } int TRTBatchedNMS::getNbOutputs() const TRT_NOEXCEPT { return 2; }
nvinfer1::DimsExprs TRTBatchedNMS::getOutputDimensions( nvinfer1::DimsExprs TRTBatchedNMS::getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
nvinfer1::IExprBuilder& exprBuilder) { nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT {
ASSERT(nbInputs == 2); ASSERT(nbInputs == 2);
ASSERT(outputIndex >= 0 && outputIndex < this->getNbOutputs()); ASSERT(outputIndex >= 0 && outputIndex < this->getNbOutputs());
ASSERT(inputs[0].nbDims == 4); ASSERT(inputs[0].nbDims == 4);
@ -61,7 +61,8 @@ nvinfer1::DimsExprs TRTBatchedNMS::getOutputDimensions(
size_t TRTBatchedNMS::getWorkspaceSize( size_t TRTBatchedNMS::getWorkspaceSize(
const nvinfer1::PluginTensorDesc* inputs, int nbInputs, const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const { const nvinfer1::PluginTensorDesc* outputs,
int nbOutputs) const TRT_NOEXCEPT {
size_t batch_size = inputs[0].dims.d[0]; size_t batch_size = inputs[0].dims.d[0];
size_t boxes_size = size_t boxes_size =
inputs[0].dims.d[1] * inputs[0].dims.d[2] * inputs[0].dims.d[3]; inputs[0].dims.d[1] * inputs[0].dims.d[2] * inputs[0].dims.d[3];
@ -77,7 +78,7 @@ size_t TRTBatchedNMS::getWorkspaceSize(
int TRTBatchedNMS::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, int TRTBatchedNMS::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc, const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs, void* const* outputs, const void* const* inputs, void* const* outputs,
void* workSpace, cudaStream_t stream) { void* workSpace, cudaStream_t stream) TRT_NOEXCEPT {
const void* const locData = inputs[0]; const void* const locData = inputs[0];
const void* const confData = inputs[1]; const void* const confData = inputs[1];
@ -102,12 +103,12 @@ int TRTBatchedNMS::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
return 0; return 0;
} }
size_t TRTBatchedNMS::getSerializationSize() const { size_t TRTBatchedNMS::getSerializationSize() const TRT_NOEXCEPT {
// NMSParameters, boxesSize,scoresSize,numPriors // NMSParameters, boxesSize,scoresSize,numPriors
return sizeof(NMSParameters) + sizeof(int) * 3 + sizeof(bool); return sizeof(NMSParameters) + sizeof(int) * 3 + sizeof(bool);
} }
void TRTBatchedNMS::serialize(void* buffer) const { void TRTBatchedNMS::serialize(void* buffer) const TRT_NOEXCEPT {
serialize_value(&buffer, param); serialize_value(&buffer, param);
serialize_value(&buffer, boxesSize); serialize_value(&buffer, boxesSize);
serialize_value(&buffer, scoresSize); serialize_value(&buffer, scoresSize);
@ -117,13 +118,14 @@ void TRTBatchedNMS::serialize(void* buffer) const {
void TRTBatchedNMS::configurePlugin( void TRTBatchedNMS::configurePlugin(
const nvinfer1::DynamicPluginTensorDesc* inputs, int nbInputs, const nvinfer1::DynamicPluginTensorDesc* inputs, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* outputs, int nbOutputs) { const nvinfer1::DynamicPluginTensorDesc* outputs,
int nbOutputs) TRT_NOEXCEPT {
// Validate input arguments // Validate input arguments
} }
bool TRTBatchedNMS::supportsFormatCombination( bool TRTBatchedNMS::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs,
int nbOutputs) { int nbOutputs) TRT_NOEXCEPT {
if (pos == 3) { if (pos == 3) {
return inOut[pos].type == nvinfer1::DataType::kINT32 && return inOut[pos].type == nvinfer1::DataType::kINT32 &&
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR; inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
@ -132,13 +134,15 @@ bool TRTBatchedNMS::supportsFormatCombination(
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR; inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
} }
const char* TRTBatchedNMS::getPluginType() const { return NMS_PLUGIN_NAME; } const char* TRTBatchedNMS::getPluginType() const TRT_NOEXCEPT {
return NMS_PLUGIN_NAME;
}
const char* TRTBatchedNMS::getPluginVersion() const { const char* TRTBatchedNMS::getPluginVersion() const TRT_NOEXCEPT {
return NMS_PLUGIN_VERSION; return NMS_PLUGIN_VERSION;
} }
IPluginV2DynamicExt* TRTBatchedNMS::clone() const { IPluginV2DynamicExt* TRTBatchedNMS::clone() const TRT_NOEXCEPT {
auto* plugin = new TRTBatchedNMS(mLayerName, param); auto* plugin = new TRTBatchedNMS(mLayerName, param);
plugin->boxesSize = boxesSize; plugin->boxesSize = boxesSize;
plugin->scoresSize = scoresSize; plugin->scoresSize = scoresSize;
@ -149,7 +153,8 @@ IPluginV2DynamicExt* TRTBatchedNMS::clone() const {
} }
nvinfer1::DataType TRTBatchedNMS::getOutputDataType( nvinfer1::DataType TRTBatchedNMS::getOutputDataType(
int index, const nvinfer1::DataType* inputTypes, int nbInputs) const { int index, const nvinfer1::DataType* inputTypes,
int nbInputs) const TRT_NOEXCEPT {
ASSERT(index >= 0 && index < this->getNbOutputs()); ASSERT(index >= 0 && index < this->getNbOutputs());
if (index == 1) { if (index == 1) {
return nvinfer1::DataType::kINT32; return nvinfer1::DataType::kINT32;
@ -181,16 +186,16 @@ TRTBatchedNMSCreator::TRTBatchedNMSCreator() {
mFC.fields = mPluginAttributes.data(); mFC.fields = mPluginAttributes.data();
} }
const char* TRTBatchedNMSCreator::getPluginName() const { const char* TRTBatchedNMSCreator::getPluginName() const TRT_NOEXCEPT {
return NMS_PLUGIN_NAME; return NMS_PLUGIN_NAME;
} }
const char* TRTBatchedNMSCreator::getPluginVersion() const { const char* TRTBatchedNMSCreator::getPluginVersion() const TRT_NOEXCEPT {
return NMS_PLUGIN_VERSION; return NMS_PLUGIN_VERSION;
} }
IPluginV2Ext* TRTBatchedNMSCreator::createPlugin( IPluginV2Ext* TRTBatchedNMSCreator::createPlugin(
const char* name, const PluginFieldCollection* fc) { const char* name, const PluginFieldCollection* fc) TRT_NOEXCEPT {
const PluginField* fields = fc->fields; const PluginField* fields = fc->fields;
bool clipBoxes = true; bool clipBoxes = true;
nvinfer1::plugin::NMSParameters params{}; nvinfer1::plugin::NMSParameters params{};
@ -228,9 +233,9 @@ IPluginV2Ext* TRTBatchedNMSCreator::createPlugin(
return plugin; return plugin;
} }
IPluginV2Ext* TRTBatchedNMSCreator::deserializePlugin(const char* name, IPluginV2Ext* TRTBatchedNMSCreator::deserializePlugin(
const void* serialData, const char* name, const void* serialData,
size_t serialLength) { size_t serialLength) TRT_NOEXCEPT {
// This object will be deleted when the network is destroyed, which will // This object will be deleted when the network is destroyed, which will
// call NMS::destroy() // call NMS::destroy()
TRTBatchedNMS* plugin = new TRTBatchedNMS(name, serialData, serialLength); TRTBatchedNMS* plugin = new TRTBatchedNMS(name, serialData, serialLength);

View File

@ -13,46 +13,47 @@ class TRTBatchedNMS : public TRTPluginBase {
TRTBatchedNMS(const std::string& name, const void* data, size_t length); TRTBatchedNMS(const std::string& name, const void* data, size_t length);
~TRTBatchedNMS() override = default; ~TRTBatchedNMS() TRT_NOEXCEPT override = default;
int getNbOutputs() const override; int getNbOutputs() const TRT_NOEXCEPT override;
nvinfer1::DimsExprs getOutputDimensions( nvinfer1::DimsExprs getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
nvinfer1::IExprBuilder& exprBuilder) override; nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT override;
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nbInputs, int nbInputs,
const nvinfer1::PluginTensorDesc* outputs, const nvinfer1::PluginTensorDesc* outputs,
int nbOutputs) const override; int nbOutputs) const TRT_NOEXCEPT override;
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc, const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs, void* const* outputs, void* workSpace, const void* const* inputs, void* const* outputs, void* workSpace,
cudaStream_t stream) override; cudaStream_t stream) TRT_NOEXCEPT override;
size_t getSerializationSize() const override; size_t getSerializationSize() const TRT_NOEXCEPT override;
void serialize(void* buffer) const override; void serialize(void* buffer) const TRT_NOEXCEPT override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs, void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs,
int nbInputs, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* outputs, const nvinfer1::DynamicPluginTensorDesc* outputs,
int nbOutputs) override; int nbOutputs) TRT_NOEXCEPT override;
bool supportsFormatCombination(int pos, bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc* inOut, const nvinfer1::PluginTensorDesc* inOut,
int nbInputs, int nbOutputs) override; int nbInputs,
int nbOutputs) TRT_NOEXCEPT override;
const char* getPluginType() const override; const char* getPluginType() const TRT_NOEXCEPT override;
const char* getPluginVersion() const override; const char* getPluginVersion() const TRT_NOEXCEPT override;
nvinfer1::IPluginV2DynamicExt* clone() const override; nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override;
nvinfer1::DataType getOutputDataType(int index, nvinfer1::DataType getOutputDataType(
const nvinfer1::DataType* inputType, int index, const nvinfer1::DataType* inputType,
int nbInputs) const override; int nbInputs) const TRT_NOEXCEPT override;
void setClipParam(bool clip); void setClipParam(bool clip);
@ -68,18 +69,19 @@ class TRTBatchedNMSCreator : public TRTPluginCreatorBase {
public: public:
TRTBatchedNMSCreator(); TRTBatchedNMSCreator();
~TRTBatchedNMSCreator() override = default; ~TRTBatchedNMSCreator() TRT_NOEXCEPT override = default;
const char* getPluginName() const override; const char* getPluginName() const TRT_NOEXCEPT override;
const char* getPluginVersion() const override; const char* getPluginVersion() const TRT_NOEXCEPT override;
nvinfer1::IPluginV2Ext* createPlugin( nvinfer1::IPluginV2Ext* createPlugin(
const char* name, const nvinfer1::PluginFieldCollection* fc) override; const char* name,
const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT override;
nvinfer1::IPluginV2Ext* deserializePlugin(const char* name, nvinfer1::IPluginV2Ext* deserializePlugin(
const void* serialData, const char* name, const void* serialData,
size_t serialLength) override; size_t serialLength) TRT_NOEXCEPT override;
}; };
} // namespace mmlab } // namespace mmlab
#endif // TRT_BATCHED_NMS_PLUGIN_CUSTOM_H #endif // TRT_BATCHED_NMS_PLUGIN_CUSTOM_H

View File

@ -1,26 +1,37 @@
#ifndef TRT_PLUGIN_BASE_HPP #ifndef TRT_PLUGIN_BASE_HPP
#define TRT_PLUGIN_BASE_HPP #define TRT_PLUGIN_BASE_HPP
#include "NvInferPlugin.h" #include "NvInferPlugin.h"
#include "NvInferVersion.h"
#include "trt_plugin_helper.hpp" #include "trt_plugin_helper.hpp"
namespace mmlab { namespace mmlab {
#if NV_TENSORRT_MAJOR > 7
#define TRT_NOEXCEPT noexcept
#else
#define TRT_NOEXCEPT
#endif
class TRTPluginBase : public nvinfer1::IPluginV2DynamicExt { class TRTPluginBase : public nvinfer1::IPluginV2DynamicExt {
public: public:
TRTPluginBase(const std::string &name) : mLayerName(name) {} TRTPluginBase(const std::string &name) : mLayerName(name) {}
// IPluginV2 Methods // IPluginV2 Methods
const char *getPluginVersion() const override { return "1"; } const char *getPluginVersion() const TRT_NOEXCEPT override { return "1"; }
int initialize() override { return STATUS_SUCCESS; } int initialize() TRT_NOEXCEPT override { return STATUS_SUCCESS; }
void terminate() override {} void terminate() TRT_NOEXCEPT override {}
void destroy() override { delete this; } void destroy() TRT_NOEXCEPT override { delete this; }
void setPluginNamespace(const char *pluginNamespace) override { void setPluginNamespace(const char *pluginNamespace) TRT_NOEXCEPT override {
mNamespace = pluginNamespace; mNamespace = pluginNamespace;
} }
const char *getPluginNamespace() const override { return mNamespace.c_str(); } const char *getPluginNamespace() const TRT_NOEXCEPT override {
return mNamespace.c_str();
}
protected: protected:
const std::string mLayerName; const std::string mLayerName;
std::string mNamespace; std::string mNamespace;
#if NV_TENSORRT_MAJOR < 8
protected: protected:
// To prevent compiler warnings. // To prevent compiler warnings.
using nvinfer1::IPluginV2DynamicExt::canBroadcastInputAcrossBatch; using nvinfer1::IPluginV2DynamicExt::canBroadcastInputAcrossBatch;
@ -30,21 +41,24 @@ class TRTPluginBase : public nvinfer1::IPluginV2DynamicExt {
using nvinfer1::IPluginV2DynamicExt::getWorkspaceSize; using nvinfer1::IPluginV2DynamicExt::getWorkspaceSize;
using nvinfer1::IPluginV2DynamicExt::isOutputBroadcastAcrossBatch; using nvinfer1::IPluginV2DynamicExt::isOutputBroadcastAcrossBatch;
using nvinfer1::IPluginV2DynamicExt::supportsFormat; using nvinfer1::IPluginV2DynamicExt::supportsFormat;
#endif
}; };
class TRTPluginCreatorBase : public nvinfer1::IPluginCreator { class TRTPluginCreatorBase : public nvinfer1::IPluginCreator {
public: public:
const char *getPluginVersion() const override { return "1"; }; const char *getPluginVersion() const TRT_NOEXCEPT override { return "1"; };
const nvinfer1::PluginFieldCollection *getFieldNames() override { const nvinfer1::PluginFieldCollection *getFieldNames() TRT_NOEXCEPT override {
return &mFC; return &mFC;
} }
void setPluginNamespace(const char *pluginNamespace) override { void setPluginNamespace(const char *pluginNamespace) TRT_NOEXCEPT override {
mNamespace = pluginNamespace; mNamespace = pluginNamespace;
} }
const char *getPluginNamespace() const override { return mNamespace.c_str(); } const char *getPluginNamespace() const TRT_NOEXCEPT override {
return mNamespace.c_str();
}
protected: protected:
nvinfer1::PluginFieldCollection mFC; nvinfer1::PluginFieldCollection mFC;

View File

@ -31,18 +31,19 @@ TRTInstanceNormalization::TRTInstanceNormalization(const std::string& name,
TRTInstanceNormalization::~TRTInstanceNormalization() {} TRTInstanceNormalization::~TRTInstanceNormalization() {}
// TRTInstanceNormalization returns one output. // TRTInstanceNormalization returns one output.
int TRTInstanceNormalization::getNbOutputs() const { return 1; } int TRTInstanceNormalization::getNbOutputs() const TRT_NOEXCEPT { return 1; }
DimsExprs TRTInstanceNormalization::getOutputDimensions( DimsExprs TRTInstanceNormalization::getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
nvinfer1::IExprBuilder& exprBuilder) { nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT {
nvinfer1::DimsExprs output(inputs[0]); nvinfer1::DimsExprs output(inputs[0]);
return output; return output;
} }
size_t TRTInstanceNormalization::getWorkspaceSize( size_t TRTInstanceNormalization::getWorkspaceSize(
const nvinfer1::PluginTensorDesc* inputs, int nbInputs, const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const { const nvinfer1::PluginTensorDesc* outputs,
int nbOutputs) const TRT_NOEXCEPT {
int n = inputs[0].dims.d[0]; int n = inputs[0].dims.d[0];
int c = inputs[0].dims.d[1]; int c = inputs[0].dims.d[1];
int elem_size = getElementSize(inputs[1].type); int elem_size = getElementSize(inputs[1].type);
@ -52,7 +53,7 @@ size_t TRTInstanceNormalization::getWorkspaceSize(
int TRTInstanceNormalization::enqueue( int TRTInstanceNormalization::enqueue(
const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs,
void* const* outputs, void* workspace, cudaStream_t stream) { void* const* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT {
nvinfer1::Dims input_dims = inputDesc[0].dims; nvinfer1::Dims input_dims = inputDesc[0].dims;
int n = input_dims.d[0]; int n = input_dims.d[0];
int c = input_dims.d[1]; int c = input_dims.d[1];
@ -97,47 +98,48 @@ int TRTInstanceNormalization::enqueue(
return 0; return 0;
} }
size_t TRTInstanceNormalization::getSerializationSize() const { size_t TRTInstanceNormalization::getSerializationSize() const TRT_NOEXCEPT {
return serialized_size(mEpsilon); return serialized_size(mEpsilon);
} }
void TRTInstanceNormalization::serialize(void* buffer) const { void TRTInstanceNormalization::serialize(void* buffer) const TRT_NOEXCEPT {
serialize_value(&buffer, mEpsilon); serialize_value(&buffer, mEpsilon);
} }
bool TRTInstanceNormalization::supportsFormatCombination( bool TRTInstanceNormalization::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs,
int nbOutputs) { int nbOutputs) TRT_NOEXCEPT {
return ((inOut[pos].type == nvinfer1::DataType::kFLOAT || return ((inOut[pos].type == nvinfer1::DataType::kFLOAT ||
inOut[pos].type == nvinfer1::DataType::kHALF) && inOut[pos].type == nvinfer1::DataType::kHALF) &&
inOut[pos].format == nvinfer1::PluginFormat::kLINEAR && inOut[pos].format == nvinfer1::PluginFormat::kLINEAR &&
inOut[pos].type == inOut[0].type); inOut[pos].type == inOut[0].type);
} }
const char* TRTInstanceNormalization::getPluginType() const { const char* TRTInstanceNormalization::getPluginType() const TRT_NOEXCEPT {
return PLUGIN_NAME; return PLUGIN_NAME;
} }
const char* TRTInstanceNormalization::getPluginVersion() const { const char* TRTInstanceNormalization::getPluginVersion() const TRT_NOEXCEPT {
return PLUGIN_VERSION; return PLUGIN_VERSION;
} }
IPluginV2DynamicExt* TRTInstanceNormalization::clone() const { IPluginV2DynamicExt* TRTInstanceNormalization::clone() const TRT_NOEXCEPT {
auto* plugin = new TRTInstanceNormalization{mLayerName, mEpsilon}; auto* plugin = new TRTInstanceNormalization{mLayerName, mEpsilon};
plugin->setPluginNamespace(mPluginNamespace.c_str()); plugin->setPluginNamespace(mPluginNamespace.c_str());
return plugin; return plugin;
} }
nvinfer1::DataType TRTInstanceNormalization::getOutputDataType( nvinfer1::DataType TRTInstanceNormalization::getOutputDataType(
int index, const nvinfer1::DataType* inputTypes, int nbInputs) const { int index, const nvinfer1::DataType* inputTypes,
int nbInputs) const TRT_NOEXCEPT {
return inputTypes[0]; return inputTypes[0];
} }
// Attach the plugin object to an execution context and grant the plugin the // Attach the plugin object to an execution context and grant the plugin the
// access to some context resource. // access to some context resource.
void TRTInstanceNormalization::attachToContext(cudnnContext* cudnnContext, void TRTInstanceNormalization::attachToContext(
cublasContext* cublasContext, cudnnContext* cudnnContext, cublasContext* cublasContext,
IGpuAllocator* gpuAllocator) { IGpuAllocator* gpuAllocator) TRT_NOEXCEPT {
_cudnn_handle = cudnnContext; _cudnn_handle = cudnnContext;
cudnnCreateTensorDescriptor(&_b_desc); cudnnCreateTensorDescriptor(&_b_desc);
cudnnCreateTensorDescriptor(&_x_desc); cudnnCreateTensorDescriptor(&_x_desc);
@ -145,7 +147,7 @@ void TRTInstanceNormalization::attachToContext(cudnnContext* cudnnContext,
} }
// Detach the plugin object from its execution context. // Detach the plugin object from its execution context.
void TRTInstanceNormalization::detachFromContext() { void TRTInstanceNormalization::detachFromContext() TRT_NOEXCEPT {
cudnnDestroyTensorDescriptor(_y_desc); cudnnDestroyTensorDescriptor(_y_desc);
cudnnDestroyTensorDescriptor(_x_desc); cudnnDestroyTensorDescriptor(_x_desc);
cudnnDestroyTensorDescriptor(_b_desc); cudnnDestroyTensorDescriptor(_b_desc);
@ -153,7 +155,7 @@ void TRTInstanceNormalization::detachFromContext() {
void TRTInstanceNormalization::configurePlugin( void TRTInstanceNormalization::configurePlugin(
const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) {} const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) TRT_NOEXCEPT {}
// TRTInstanceNormalizationCreator methods // TRTInstanceNormalizationCreator methods
TRTInstanceNormalizationCreator::TRTInstanceNormalizationCreator() { TRTInstanceNormalizationCreator::TRTInstanceNormalizationCreator() {
@ -165,16 +167,18 @@ TRTInstanceNormalizationCreator::TRTInstanceNormalizationCreator() {
mFC.fields = mPluginAttributes.data(); mFC.fields = mPluginAttributes.data();
} }
const char* TRTInstanceNormalizationCreator::getPluginName() const { const char* TRTInstanceNormalizationCreator::getPluginName() const
TRT_NOEXCEPT {
return PLUGIN_NAME; return PLUGIN_NAME;
} }
const char* TRTInstanceNormalizationCreator::getPluginVersion() const { const char* TRTInstanceNormalizationCreator::getPluginVersion() const
TRT_NOEXCEPT {
return PLUGIN_VERSION; return PLUGIN_VERSION;
} }
IPluginV2DynamicExt* TRTInstanceNormalizationCreator::createPlugin( IPluginV2DynamicExt* TRTInstanceNormalizationCreator::createPlugin(
const char* name, const nvinfer1::PluginFieldCollection* fc) { const char* name, const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT {
float epsilon = 1e-5; float epsilon = 1e-5;
const PluginField* fields = fc->fields; const PluginField* fields = fc->fields;
for (int i = 0; i < fc->nbFields; ++i) { for (int i = 0; i < fc->nbFields; ++i) {
@ -190,7 +194,8 @@ IPluginV2DynamicExt* TRTInstanceNormalizationCreator::createPlugin(
} }
IPluginV2DynamicExt* TRTInstanceNormalizationCreator::deserializePlugin( IPluginV2DynamicExt* TRTInstanceNormalizationCreator::deserializePlugin(
const char* name, const void* serialData, size_t serialLength) { const char* name, const void* serialData,
size_t serialLength) TRT_NOEXCEPT {
TRTInstanceNormalization* obj = TRTInstanceNormalization* obj =
new TRTInstanceNormalization{name, serialData, serialLength}; new TRTInstanceNormalization{name, serialData, serialLength};
obj->setPluginNamespace(mNamespace.c_str()); obj->setPluginNamespace(mNamespace.c_str());

View File

@ -23,53 +23,55 @@ class TRTInstanceNormalization final : public TRTPluginBase {
TRTInstanceNormalization() = delete; TRTInstanceNormalization() = delete;
~TRTInstanceNormalization() override; ~TRTInstanceNormalization() TRT_NOEXCEPT override;
int getNbOutputs() const override; int getNbOutputs() const TRT_NOEXCEPT override;
// DynamicExt plugins returns DimsExprs class instead of Dims // DynamicExt plugins returns DimsExprs class instead of Dims
nvinfer1::DimsExprs getOutputDimensions( nvinfer1::DimsExprs getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
nvinfer1::IExprBuilder& exprBuilder) override; nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT override;
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nbInputs, int nbInputs,
const nvinfer1::PluginTensorDesc* outputs, const nvinfer1::PluginTensorDesc* outputs,
int nbOutputs) const override; int nbOutputs) const TRT_NOEXCEPT override;
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc, const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs, void* const* outputs, void* workspace, const void* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream) override; cudaStream_t stream) TRT_NOEXCEPT override;
size_t getSerializationSize() const override; size_t getSerializationSize() const TRT_NOEXCEPT override;
void serialize(void* buffer) const override; void serialize(void* buffer) const TRT_NOEXCEPT override;
// DynamicExt plugin supportsFormat update. // DynamicExt plugin supportsFormat update.
bool supportsFormatCombination(int pos, bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc* inOut, const nvinfer1::PluginTensorDesc* inOut,
int nbInputs, int nbOutputs) override; int nbInputs,
int nbOutputs) TRT_NOEXCEPT override;
const char* getPluginType() const override; const char* getPluginType() const TRT_NOEXCEPT override;
const char* getPluginVersion() const override; const char* getPluginVersion() const TRT_NOEXCEPT override;
nvinfer1::IPluginV2DynamicExt* clone() const override; nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override;
nvinfer1::DataType getOutputDataType(int index, nvinfer1::DataType getOutputDataType(
const nvinfer1::DataType* inputTypes, int index, const nvinfer1::DataType* inputTypes,
int nbInputs) const override; int nbInputs) const TRT_NOEXCEPT override;
void attachToContext(cudnnContext* cudnn, cublasContext* cublas, void attachToContext(cudnnContext* cudnn, cublasContext* cublas,
nvinfer1::IGpuAllocator* allocator) override; nvinfer1::IGpuAllocator* allocator)
TRT_NOEXCEPT override;
void detachFromContext() override; void detachFromContext() TRT_NOEXCEPT override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
int nbInputs, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out, const nvinfer1::DynamicPluginTensorDesc* out,
int nbOutputs) override; int nbOutputs) TRT_NOEXCEPT override;
private: private:
float mEpsilon{}; float mEpsilon{};
@ -84,15 +86,17 @@ class TRTInstanceNormalizationCreator : public TRTPluginCreatorBase {
~TRTInstanceNormalizationCreator() override = default; ~TRTInstanceNormalizationCreator() override = default;
const char* getPluginName() const override; const char* getPluginName() const TRT_NOEXCEPT override;
const char* getPluginVersion() const override; const char* getPluginVersion() const TRT_NOEXCEPT override;
nvinfer1::IPluginV2DynamicExt* createPlugin( nvinfer1::IPluginV2DynamicExt* createPlugin(
const char* name, const nvinfer1::PluginFieldCollection* fc) override; const char* name,
const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT override;
nvinfer1::IPluginV2DynamicExt* deserializePlugin( nvinfer1::IPluginV2DynamicExt* deserializePlugin(
const char* name, const void* serialData, size_t serialLength) override; const char* name, const void* serialData,
size_t serialLength) TRT_NOEXCEPT override;
}; };
} // namespace mmlab } // namespace mmlab
#endif // TRT_INSTANCE_NORMALIZATION_HPP #endif // TRT_INSTANCE_NORMALIZATION_HPP

View File

@ -39,7 +39,8 @@ TRTMultiLevelRoiAlign::TRTMultiLevelRoiAlign(const std::string name,
deserialize_value(&data, &length, &mFeatmapStrides); deserialize_value(&data, &length, &mFeatmapStrides);
} }
nvinfer1::IPluginV2DynamicExt *TRTMultiLevelRoiAlign::clone() const { nvinfer1::IPluginV2DynamicExt *TRTMultiLevelRoiAlign::clone() const
TRT_NOEXCEPT {
TRTMultiLevelRoiAlign *plugin = new TRTMultiLevelRoiAlign( TRTMultiLevelRoiAlign *plugin = new TRTMultiLevelRoiAlign(
mLayerName, mAlignedHeight, mAlignedWidth, mSampleNum, mFeatmapStrides, mLayerName, mAlignedHeight, mAlignedWidth, mSampleNum, mFeatmapStrides,
mRoiScaleFactor, mFinestScale, mAligned); mRoiScaleFactor, mFinestScale, mAligned);
@ -50,7 +51,7 @@ nvinfer1::IPluginV2DynamicExt *TRTMultiLevelRoiAlign::clone() const {
nvinfer1::DimsExprs TRTMultiLevelRoiAlign::getOutputDimensions( nvinfer1::DimsExprs TRTMultiLevelRoiAlign::getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
nvinfer1::IExprBuilder &exprBuilder) { nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT {
ASSERT(nbInputs == mFeatmapStrides.size() + 1); ASSERT(nbInputs == mFeatmapStrides.size() + 1);
nvinfer1::DimsExprs ret; nvinfer1::DimsExprs ret;
@ -65,14 +66,15 @@ nvinfer1::DimsExprs TRTMultiLevelRoiAlign::getOutputDimensions(
bool TRTMultiLevelRoiAlign::supportsFormatCombination( bool TRTMultiLevelRoiAlign::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc *inOut, int nbInputs, int pos, const nvinfer1::PluginTensorDesc *inOut, int nbInputs,
int nbOutputs) { int nbOutputs) TRT_NOEXCEPT {
return inOut[pos].type == nvinfer1::DataType::kFLOAT && return inOut[pos].type == nvinfer1::DataType::kFLOAT &&
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR; inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
} }
void TRTMultiLevelRoiAlign::configurePlugin( void TRTMultiLevelRoiAlign::configurePlugin(
const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs, const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *outputs, int nbOutputs) { const nvinfer1::DynamicPluginTensorDesc *outputs,
int nbOutputs) TRT_NOEXCEPT {
// Validate input arguments // Validate input arguments
ASSERT(nbOutputs == 1); ASSERT(nbOutputs == 1);
ASSERT(nbInputs == mFeatmapStrides.size() + 1); ASSERT(nbInputs == mFeatmapStrides.size() + 1);
@ -80,7 +82,8 @@ void TRTMultiLevelRoiAlign::configurePlugin(
size_t TRTMultiLevelRoiAlign::getWorkspaceSize( size_t TRTMultiLevelRoiAlign::getWorkspaceSize(
const nvinfer1::PluginTensorDesc *inputs, int nbInputs, const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const { const nvinfer1::PluginTensorDesc *outputs,
int nbOutputs) const TRT_NOEXCEPT {
return 0; return 0;
} }
@ -88,7 +91,7 @@ int TRTMultiLevelRoiAlign::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, const void *const *inputs,
void *const *outputs, void *workSpace, void *const *outputs, void *workSpace,
cudaStream_t stream) { cudaStream_t stream) TRT_NOEXCEPT {
int num_rois = inputDesc[0].dims.d[0]; int num_rois = inputDesc[0].dims.d[0];
int batch_size = inputDesc[1].dims.d[0]; int batch_size = inputDesc[1].dims.d[0];
int channels = inputDesc[1].dims.d[1]; int channels = inputDesc[1].dims.d[1];
@ -118,27 +121,30 @@ int TRTMultiLevelRoiAlign::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
} }
nvinfer1::DataType TRTMultiLevelRoiAlign::getOutputDataType( nvinfer1::DataType TRTMultiLevelRoiAlign::getOutputDataType(
int index, const nvinfer1::DataType *inputTypes, int nbInputs) const { int index, const nvinfer1::DataType *inputTypes,
int nbInputs) const TRT_NOEXCEPT {
return nvinfer1::DataType::kFLOAT; return nvinfer1::DataType::kFLOAT;
} }
// IPluginV2 Methods // IPluginV2 Methods
const char *TRTMultiLevelRoiAlign::getPluginType() const { return PLUGIN_NAME; } const char *TRTMultiLevelRoiAlign::getPluginType() const TRT_NOEXCEPT {
return PLUGIN_NAME;
}
const char *TRTMultiLevelRoiAlign::getPluginVersion() const { const char *TRTMultiLevelRoiAlign::getPluginVersion() const TRT_NOEXCEPT {
return PLUGIN_VERSION; return PLUGIN_VERSION;
} }
int TRTMultiLevelRoiAlign::getNbOutputs() const { return 1; } int TRTMultiLevelRoiAlign::getNbOutputs() const TRT_NOEXCEPT { return 1; }
size_t TRTMultiLevelRoiAlign::getSerializationSize() const { size_t TRTMultiLevelRoiAlign::getSerializationSize() const TRT_NOEXCEPT {
return serialized_size(mFeatmapStrides) + serialized_size(mAlignedHeight) + return serialized_size(mFeatmapStrides) + serialized_size(mAlignedHeight) +
serialized_size(mAlignedWidth) + serialized_size(mSampleNum) + serialized_size(mAlignedWidth) + serialized_size(mSampleNum) +
serialized_size(mRoiScaleFactor) + serialized_size(mFinestScale) + serialized_size(mRoiScaleFactor) + serialized_size(mFinestScale) +
serialized_size(mAligned); serialized_size(mAligned);
} }
void TRTMultiLevelRoiAlign::serialize(void *buffer) const { void TRTMultiLevelRoiAlign::serialize(void *buffer) const TRT_NOEXCEPT {
serialize_value(&buffer, mAlignedHeight); serialize_value(&buffer, mAlignedHeight);
serialize_value(&buffer, mAlignedWidth); serialize_value(&buffer, mAlignedWidth);
serialize_value(&buffer, mSampleNum); serialize_value(&buffer, mSampleNum);
@ -161,16 +167,17 @@ TRTMultiLevelRoiAlignCreator::TRTMultiLevelRoiAlignCreator() {
mFC.fields = mPluginAttributes.data(); mFC.fields = mPluginAttributes.data();
} }
const char *TRTMultiLevelRoiAlignCreator::getPluginName() const { const char *TRTMultiLevelRoiAlignCreator::getPluginName() const TRT_NOEXCEPT {
return PLUGIN_NAME; return PLUGIN_NAME;
} }
const char *TRTMultiLevelRoiAlignCreator::getPluginVersion() const { const char *TRTMultiLevelRoiAlignCreator::getPluginVersion() const
TRT_NOEXCEPT {
return PLUGIN_VERSION; return PLUGIN_VERSION;
} }
nvinfer1::IPluginV2 *TRTMultiLevelRoiAlignCreator::createPlugin( nvinfer1::IPluginV2 *TRTMultiLevelRoiAlignCreator::createPlugin(
const char *name, const nvinfer1::PluginFieldCollection *fc) { const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT {
int alignedHeight = 7; int alignedHeight = 7;
int alignedWidth = 7; int alignedWidth = 7;
int sampleNum = 2; int sampleNum = 2;
@ -215,7 +222,8 @@ nvinfer1::IPluginV2 *TRTMultiLevelRoiAlignCreator::createPlugin(
} }
nvinfer1::IPluginV2 *TRTMultiLevelRoiAlignCreator::deserializePlugin( nvinfer1::IPluginV2 *TRTMultiLevelRoiAlignCreator::deserializePlugin(
const char *name, const void *serialData, size_t serialLength) { const char *name, const void *serialData,
size_t serialLength) TRT_NOEXCEPT {
auto plugin = new TRTMultiLevelRoiAlign(name, serialData, serialLength); auto plugin = new TRTMultiLevelRoiAlign(name, serialData, serialLength);
plugin->setPluginNamespace(getPluginNamespace()); plugin->setPluginNamespace(getPluginNamespace());
return plugin; return plugin;

View File

@ -24,37 +24,38 @@ class TRTMultiLevelRoiAlign : public TRTPluginBase {
TRTMultiLevelRoiAlign() = delete; TRTMultiLevelRoiAlign() = delete;
// IPluginV2DynamicExt Methods // IPluginV2DynamicExt Methods
nvinfer1::IPluginV2DynamicExt *clone() const override; nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override;
nvinfer1::DimsExprs getOutputDimensions( nvinfer1::DimsExprs getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
nvinfer1::IExprBuilder &exprBuilder) override; nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT override;
bool supportsFormatCombination(int pos, bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc *inOut, const nvinfer1::PluginTensorDesc *inOut,
int nbInputs, int nbOutputs) override; int nbInputs,
int nbOutputs) TRT_NOEXCEPT override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in,
int nbInputs, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *out, const nvinfer1::DynamicPluginTensorDesc *out,
int nbOutputs) override; int nbOutputs) TRT_NOEXCEPT override;
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
int nbInputs, int nbInputs,
const nvinfer1::PluginTensorDesc *outputs, const nvinfer1::PluginTensorDesc *outputs,
int nbOutputs) const override; int nbOutputs) const TRT_NOEXCEPT override;
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs, void *workspace, const void *const *inputs, void *const *outputs, void *workspace,
cudaStream_t stream) override; cudaStream_t stream) TRT_NOEXCEPT override;
// IPluginV2Ext Methods // IPluginV2Ext Methods
nvinfer1::DataType getOutputDataType(int index, nvinfer1::DataType getOutputDataType(
const nvinfer1::DataType *inputTypes, int index, const nvinfer1::DataType *inputTypes,
int nbInputs) const override; int nbInputs) const TRT_NOEXCEPT override;
// IPluginV2 Methods // IPluginV2 Methods
const char *getPluginType() const override; const char *getPluginType() const TRT_NOEXCEPT override;
const char *getPluginVersion() const override; const char *getPluginVersion() const TRT_NOEXCEPT override;
int getNbOutputs() const override; int getNbOutputs() const TRT_NOEXCEPT override;
size_t getSerializationSize() const override; size_t getSerializationSize() const TRT_NOEXCEPT override;
void serialize(void *buffer) const override; void serialize(void *buffer) const TRT_NOEXCEPT override;
private: private:
int mAlignedHeight; int mAlignedHeight;
@ -70,16 +71,17 @@ class TRTMultiLevelRoiAlignCreator : public TRTPluginCreatorBase {
public: public:
TRTMultiLevelRoiAlignCreator(); TRTMultiLevelRoiAlignCreator();
const char *getPluginName() const override; const char *getPluginName() const TRT_NOEXCEPT override;
const char *getPluginVersion() const override; const char *getPluginVersion() const TRT_NOEXCEPT override;
nvinfer1::IPluginV2 *createPlugin( nvinfer1::IPluginV2 *createPlugin(const char *name,
const char *name, const nvinfer1::PluginFieldCollection *fc) override; const nvinfer1::PluginFieldCollection *fc)
TRT_NOEXCEPT override;
nvinfer1::IPluginV2 *deserializePlugin(const char *name, nvinfer1::IPluginV2 *deserializePlugin(
const void *serialData, const char *name, const void *serialData,
size_t serialLength) override; size_t serialLength) TRT_NOEXCEPT override;
}; };
} // namespace mmlab } // namespace mmlab
#endif // TRT_ROI_ALIGN_HPP #endif // TRT_ROI_ALIGN_HPP

View File

@ -32,7 +32,7 @@ TRTNMS::TRTNMS(const std::string name, const void *data, size_t length)
deserialize_value(&data, &length, &mOffset); deserialize_value(&data, &length, &mOffset);
} }
nvinfer1::IPluginV2DynamicExt *TRTNMS::clone() const { nvinfer1::IPluginV2DynamicExt *TRTNMS::clone() const TRT_NOEXCEPT {
TRTNMS *plugin = TRTNMS *plugin =
new TRTNMS(mLayerName, mCenterPointBox, mMaxOutputBoxesPerClass, new TRTNMS(mLayerName, mCenterPointBox, mMaxOutputBoxesPerClass,
mIouThreshold, mScoreThreshold, mOffset); mIouThreshold, mScoreThreshold, mOffset);
@ -43,7 +43,7 @@ nvinfer1::IPluginV2DynamicExt *TRTNMS::clone() const {
nvinfer1::DimsExprs TRTNMS::getOutputDimensions( nvinfer1::DimsExprs TRTNMS::getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
nvinfer1::IExprBuilder &exprBuilder) { nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT {
nvinfer1::DimsExprs ret; nvinfer1::DimsExprs ret;
ret.nbDims = 2; ret.nbDims = 2;
auto num_batches = inputs[0].d[0]; auto num_batches = inputs[0].d[0];
@ -65,7 +65,8 @@ nvinfer1::DimsExprs TRTNMS::getOutputDimensions(
bool TRTNMS::supportsFormatCombination(int pos, bool TRTNMS::supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc *inOut, const nvinfer1::PluginTensorDesc *inOut,
int nbInputs, int nbOutputs) { int nbInputs,
int nbOutputs) TRT_NOEXCEPT {
if (pos < nbInputs) { if (pos < nbInputs) {
switch (pos) { switch (pos) {
case 0: case 0:
@ -95,12 +96,12 @@ bool TRTNMS::supportsFormatCombination(int pos,
void TRTNMS::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *inputs, void TRTNMS::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *inputs,
int nbInputs, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *outputs, const nvinfer1::DynamicPluginTensorDesc *outputs,
int nbOutputs) {} int nbOutputs) TRT_NOEXCEPT {}
size_t TRTNMS::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, size_t TRTNMS::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
int nbInputs, int nbInputs,
const nvinfer1::PluginTensorDesc *outputs, const nvinfer1::PluginTensorDesc *outputs,
int nbOutputs) const { int nbOutputs) const TRT_NOEXCEPT {
size_t boxes_word_size = mmlab::getElementSize(inputs[0].type); size_t boxes_word_size = mmlab::getElementSize(inputs[0].type);
size_t num_batches = inputs[0].dims.d[0]; size_t num_batches = inputs[0].dims.d[0];
size_t spatial_dimension = inputs[0].dims.d[1]; size_t spatial_dimension = inputs[0].dims.d[1];
@ -115,7 +116,7 @@ size_t TRTNMS::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
int TRTNMS::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, int TRTNMS::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs, const void *const *inputs, void *const *outputs,
void *workSpace, cudaStream_t stream) { void *workSpace, cudaStream_t stream) TRT_NOEXCEPT {
int num_batches = inputDesc[0].dims.d[0]; int num_batches = inputDesc[0].dims.d[0];
int spatial_dimension = inputDesc[0].dims.d[1]; int spatial_dimension = inputDesc[0].dims.d[1];
int num_classes = inputDesc[1].dims.d[1]; int num_classes = inputDesc[1].dims.d[1];
@ -133,25 +134,28 @@ int TRTNMS::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
} }
nvinfer1::DataType TRTNMS::getOutputDataType( nvinfer1::DataType TRTNMS::getOutputDataType(
int index, const nvinfer1::DataType *inputTypes, int nbInputs) const { int index, const nvinfer1::DataType *inputTypes,
int nbInputs) const TRT_NOEXCEPT {
return nvinfer1::DataType::kINT32; return nvinfer1::DataType::kINT32;
} }
// IPluginV2 Methods // IPluginV2 Methods
const char *TRTNMS::getPluginType() const { return PLUGIN_NAME; } const char *TRTNMS::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; }
const char *TRTNMS::getPluginVersion() const { return PLUGIN_VERSION; } const char *TRTNMS::getPluginVersion() const TRT_NOEXCEPT {
return PLUGIN_VERSION;
}
int TRTNMS::getNbOutputs() const { return 1; } int TRTNMS::getNbOutputs() const TRT_NOEXCEPT { return 1; }
size_t TRTNMS::getSerializationSize() const { size_t TRTNMS::getSerializationSize() const TRT_NOEXCEPT {
return serialized_size(mCenterPointBox) + return serialized_size(mCenterPointBox) +
serialized_size(mMaxOutputBoxesPerClass) + serialized_size(mMaxOutputBoxesPerClass) +
serialized_size(mIouThreshold) + serialized_size(mScoreThreshold) + serialized_size(mIouThreshold) + serialized_size(mScoreThreshold) +
serialized_size(mOffset); serialized_size(mOffset);
} }
void TRTNMS::serialize(void *buffer) const { void TRTNMS::serialize(void *buffer) const TRT_NOEXCEPT {
serialize_value(&buffer, mCenterPointBox); serialize_value(&buffer, mCenterPointBox);
serialize_value(&buffer, mMaxOutputBoxesPerClass); serialize_value(&buffer, mMaxOutputBoxesPerClass);
serialize_value(&buffer, mIouThreshold); serialize_value(&buffer, mIouThreshold);
@ -171,12 +175,16 @@ TRTNMSCreator::TRTNMSCreator() {
mFC.fields = mPluginAttributes.data(); mFC.fields = mPluginAttributes.data();
} }
const char *TRTNMSCreator::getPluginName() const { return PLUGIN_NAME; } const char *TRTNMSCreator::getPluginName() const TRT_NOEXCEPT {
return PLUGIN_NAME;
}
const char *TRTNMSCreator::getPluginVersion() const { return PLUGIN_VERSION; } const char *TRTNMSCreator::getPluginVersion() const TRT_NOEXCEPT {
return PLUGIN_VERSION;
}
nvinfer1::IPluginV2 *TRTNMSCreator::createPlugin( nvinfer1::IPluginV2 *TRTNMSCreator::createPlugin(
const char *name, const nvinfer1::PluginFieldCollection *fc) { const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT {
int centerPointBox = 0; int centerPointBox = 0;
int maxOutputBoxesPerClass = 0; int maxOutputBoxesPerClass = 0;
float iouThreshold = 0.0f; float iouThreshold = 0.0f;
@ -215,9 +223,9 @@ nvinfer1::IPluginV2 *TRTNMSCreator::createPlugin(
return plugin; return plugin;
} }
nvinfer1::IPluginV2 *TRTNMSCreator::deserializePlugin(const char *name, nvinfer1::IPluginV2 *TRTNMSCreator::deserializePlugin(
const void *serialData, const char *name, const void *serialData,
size_t serialLength) { size_t serialLength) TRT_NOEXCEPT {
auto plugin = new TRTNMS(name, serialData, serialLength); auto plugin = new TRTNMS(name, serialData, serialLength);
plugin->setPluginNamespace(getPluginNamespace()); plugin->setPluginNamespace(getPluginNamespace());
return plugin; return plugin;

View File

@ -19,37 +19,38 @@ class TRTNMS : public TRTPluginBase {
TRTNMS() = delete; TRTNMS() = delete;
// IPluginV2DynamicExt Methods // IPluginV2DynamicExt Methods
nvinfer1::IPluginV2DynamicExt *clone() const override; nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override;
nvinfer1::DimsExprs getOutputDimensions( nvinfer1::DimsExprs getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
nvinfer1::IExprBuilder &exprBuilder) override; nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT override;
bool supportsFormatCombination(int pos, bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc *inOut, const nvinfer1::PluginTensorDesc *inOut,
int nbInputs, int nbOutputs) override; int nbInputs,
int nbOutputs) TRT_NOEXCEPT override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in,
int nbInputs, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *out, const nvinfer1::DynamicPluginTensorDesc *out,
int nbOutputs) override; int nbOutputs) TRT_NOEXCEPT override;
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
int nbInputs, int nbInputs,
const nvinfer1::PluginTensorDesc *outputs, const nvinfer1::PluginTensorDesc *outputs,
int nbOutputs) const override; int nbOutputs) const TRT_NOEXCEPT override;
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs, void *workspace, const void *const *inputs, void *const *outputs, void *workspace,
cudaStream_t stream) override; cudaStream_t stream) TRT_NOEXCEPT override;
// IPluginV2Ext Methods // IPluginV2Ext Methods
nvinfer1::DataType getOutputDataType(int index, nvinfer1::DataType getOutputDataType(
const nvinfer1::DataType *inputTypes, int index, const nvinfer1::DataType *inputTypes,
int nbInputs) const override; int nbInputs) const TRT_NOEXCEPT override;
// IPluginV2 Methods // IPluginV2 Methods
const char *getPluginType() const override; const char *getPluginType() const TRT_NOEXCEPT override;
const char *getPluginVersion() const override; const char *getPluginVersion() const TRT_NOEXCEPT override;
int getNbOutputs() const override; int getNbOutputs() const TRT_NOEXCEPT override;
size_t getSerializationSize() const override; size_t getSerializationSize() const TRT_NOEXCEPT override;
void serialize(void *buffer) const override; void serialize(void *buffer) const TRT_NOEXCEPT override;
private: private:
int mCenterPointBox; int mCenterPointBox;
@ -63,16 +64,17 @@ class TRTNMSCreator : public TRTPluginCreatorBase {
public: public:
TRTNMSCreator(); TRTNMSCreator();
const char *getPluginName() const override; const char *getPluginName() const TRT_NOEXCEPT override;
const char *getPluginVersion() const override; const char *getPluginVersion() const TRT_NOEXCEPT override;
nvinfer1::IPluginV2 *createPlugin( nvinfer1::IPluginV2 *createPlugin(const char *name,
const char *name, const nvinfer1::PluginFieldCollection *fc) override; const nvinfer1::PluginFieldCollection *fc)
TRT_NOEXCEPT override;
nvinfer1::IPluginV2 *deserializePlugin(const char *name, nvinfer1::IPluginV2 *deserializePlugin(
const void *serialData, const char *name, const void *serialData,
size_t serialLength) override; size_t serialLength) TRT_NOEXCEPT override;
}; };
} // namespace mmlab } // namespace mmlab
#endif // TRT_NMS_HPP #endif // TRT_NMS_HPP

View File

@ -36,7 +36,7 @@ TRTRoIAlign::TRTRoIAlign(const std::string name, const void *data,
deserialize_value(&data, &length, &mAligned); deserialize_value(&data, &length, &mAligned);
} }
nvinfer1::IPluginV2DynamicExt *TRTRoIAlign::clone() const { nvinfer1::IPluginV2DynamicExt *TRTRoIAlign::clone() const TRT_NOEXCEPT {
TRTRoIAlign *plugin = TRTRoIAlign *plugin =
new TRTRoIAlign(mLayerName, mOutWidth, mOutHeight, mSpatialScale, new TRTRoIAlign(mLayerName, mOutWidth, mOutHeight, mSpatialScale,
mSampleRatio, mPoolMode, mAligned); mSampleRatio, mPoolMode, mAligned);
@ -47,7 +47,7 @@ nvinfer1::IPluginV2DynamicExt *TRTRoIAlign::clone() const {
nvinfer1::DimsExprs TRTRoIAlign::getOutputDimensions( nvinfer1::DimsExprs TRTRoIAlign::getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
nvinfer1::IExprBuilder &exprBuilder) { nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT {
nvinfer1::DimsExprs ret; nvinfer1::DimsExprs ret;
ret.nbDims = 4; ret.nbDims = 4;
ret.d[0] = inputs[1].d[0]; ret.d[0] = inputs[1].d[0];
@ -60,19 +60,20 @@ nvinfer1::DimsExprs TRTRoIAlign::getOutputDimensions(
bool TRTRoIAlign::supportsFormatCombination( bool TRTRoIAlign::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc *inOut, int nbInputs, int pos, const nvinfer1::PluginTensorDesc *inOut, int nbInputs,
int nbOutputs) { int nbOutputs) TRT_NOEXCEPT {
return inOut[pos].type == nvinfer1::DataType::kFLOAT && return inOut[pos].type == nvinfer1::DataType::kFLOAT &&
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR; inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
} }
void TRTRoIAlign::configurePlugin( void TRTRoIAlign::configurePlugin(
const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs, const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *outputs, int nbOutputs) {} const nvinfer1::DynamicPluginTensorDesc *outputs,
int nbOutputs) TRT_NOEXCEPT {}
size_t TRTRoIAlign::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, size_t TRTRoIAlign::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
int nbInputs, int nbInputs,
const nvinfer1::PluginTensorDesc *outputs, const nvinfer1::PluginTensorDesc *outputs,
int nbOutputs) const { int nbOutputs) const TRT_NOEXCEPT {
size_t output_size = 0; size_t output_size = 0;
size_t word_size = 0; size_t word_size = 0;
switch (mPoolMode) { switch (mPoolMode) {
@ -94,7 +95,7 @@ size_t TRTRoIAlign::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
int TRTRoIAlign::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, int TRTRoIAlign::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs, const void *const *inputs, void *const *outputs,
void *workSpace, cudaStream_t stream) { void *workSpace, cudaStream_t stream) TRT_NOEXCEPT {
int channels = inputDesc[0].dims.d[1]; int channels = inputDesc[0].dims.d[1];
int height = inputDesc[0].dims.d[2]; int height = inputDesc[0].dims.d[2];
int width = inputDesc[0].dims.d[3]; int width = inputDesc[0].dims.d[3];
@ -135,24 +136,29 @@ int TRTRoIAlign::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
} }
nvinfer1::DataType TRTRoIAlign::getOutputDataType( nvinfer1::DataType TRTRoIAlign::getOutputDataType(
int index, const nvinfer1::DataType *inputTypes, int nbInputs) const { int index, const nvinfer1::DataType *inputTypes,
int nbInputs) const TRT_NOEXCEPT {
return inputTypes[0]; return inputTypes[0];
} }
// IPluginV2 Methods // IPluginV2 Methods
const char *TRTRoIAlign::getPluginType() const { return PLUGIN_NAME; } const char *TRTRoIAlign::getPluginType() const TRT_NOEXCEPT {
return PLUGIN_NAME;
}
const char *TRTRoIAlign::getPluginVersion() const { return PLUGIN_VERSION; } const char *TRTRoIAlign::getPluginVersion() const TRT_NOEXCEPT {
return PLUGIN_VERSION;
}
int TRTRoIAlign::getNbOutputs() const { return 1; } int TRTRoIAlign::getNbOutputs() const TRT_NOEXCEPT { return 1; }
size_t TRTRoIAlign::getSerializationSize() const { size_t TRTRoIAlign::getSerializationSize() const TRT_NOEXCEPT {
return serialized_size(mOutWidth) + serialized_size(mOutHeight) + return serialized_size(mOutWidth) + serialized_size(mOutHeight) +
serialized_size(mSpatialScale) + serialized_size(mSampleRatio) + serialized_size(mSpatialScale) + serialized_size(mSampleRatio) +
serialized_size(mPoolMode) + serialized_size(mAligned); serialized_size(mPoolMode) + serialized_size(mAligned);
} }
void TRTRoIAlign::serialize(void *buffer) const { void TRTRoIAlign::serialize(void *buffer) const TRT_NOEXCEPT {
serialize_value(&buffer, mOutWidth); serialize_value(&buffer, mOutWidth);
serialize_value(&buffer, mOutHeight); serialize_value(&buffer, mOutHeight);
serialize_value(&buffer, mSpatialScale); serialize_value(&buffer, mSpatialScale);
@ -172,14 +178,16 @@ TRTRoIAlignCreator::TRTRoIAlignCreator() {
mFC.fields = mPluginAttributes.data(); mFC.fields = mPluginAttributes.data();
} }
const char *TRTRoIAlignCreator::getPluginName() const { return PLUGIN_NAME; } const char *TRTRoIAlignCreator::getPluginName() const TRT_NOEXCEPT {
return PLUGIN_NAME;
}
const char *TRTRoIAlignCreator::getPluginVersion() const { const char *TRTRoIAlignCreator::getPluginVersion() const TRT_NOEXCEPT {
return PLUGIN_VERSION; return PLUGIN_VERSION;
} }
nvinfer1::IPluginV2 *TRTRoIAlignCreator::createPlugin( nvinfer1::IPluginV2 *TRTRoIAlignCreator::createPlugin(
const char *name, const nvinfer1::PluginFieldCollection *fc) { const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT {
int outWidth = 7; int outWidth = 7;
int outHeight = 7; int outHeight = 7;
float spatialScale = 1.0; float spatialScale = 1.0;
@ -241,7 +249,8 @@ nvinfer1::IPluginV2 *TRTRoIAlignCreator::createPlugin(
} }
nvinfer1::IPluginV2 *TRTRoIAlignCreator::deserializePlugin( nvinfer1::IPluginV2 *TRTRoIAlignCreator::deserializePlugin(
const char *name, const void *serialData, size_t serialLength) { const char *name, const void *serialData,
size_t serialLength) TRT_NOEXCEPT {
auto plugin = new TRTRoIAlign(name, serialData, serialLength); auto plugin = new TRTRoIAlign(name, serialData, serialLength);
plugin->setPluginNamespace(getPluginNamespace()); plugin->setPluginNamespace(getPluginNamespace());
return plugin; return plugin;

View File

@ -18,37 +18,38 @@ class TRTRoIAlign : public TRTPluginBase {
TRTRoIAlign() = delete; TRTRoIAlign() = delete;
// IPluginV2DynamicExt Methods // IPluginV2DynamicExt Methods
nvinfer1::IPluginV2DynamicExt *clone() const override; nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override;
nvinfer1::DimsExprs getOutputDimensions( nvinfer1::DimsExprs getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
nvinfer1::IExprBuilder &exprBuilder) override; nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT override;
bool supportsFormatCombination(int pos, bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc *inOut, const nvinfer1::PluginTensorDesc *inOut,
int nbInputs, int nbOutputs) override; int nbInputs,
int nbOutputs) TRT_NOEXCEPT override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in,
int nbInputs, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *out, const nvinfer1::DynamicPluginTensorDesc *out,
int nbOutputs) override; int nbOutputs) TRT_NOEXCEPT override;
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
int nbInputs, int nbInputs,
const nvinfer1::PluginTensorDesc *outputs, const nvinfer1::PluginTensorDesc *outputs,
int nbOutputs) const override; int nbOutputs) const TRT_NOEXCEPT override;
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs, void *workspace, const void *const *inputs, void *const *outputs, void *workspace,
cudaStream_t stream) override; cudaStream_t stream) TRT_NOEXCEPT override;
// IPluginV2Ext Methods // IPluginV2Ext Methods
nvinfer1::DataType getOutputDataType(int index, nvinfer1::DataType getOutputDataType(
const nvinfer1::DataType *inputTypes, int index, const nvinfer1::DataType *inputTypes,
int nbInputs) const override; int nbInputs) const TRT_NOEXCEPT override;
// IPluginV2 Methods // IPluginV2 Methods
const char *getPluginType() const override; const char *getPluginType() const TRT_NOEXCEPT override;
const char *getPluginVersion() const override; const char *getPluginVersion() const TRT_NOEXCEPT override;
int getNbOutputs() const override; int getNbOutputs() const TRT_NOEXCEPT override;
size_t getSerializationSize() const override; size_t getSerializationSize() const TRT_NOEXCEPT override;
void serialize(void *buffer) const override; void serialize(void *buffer) const TRT_NOEXCEPT override;
private: private:
int mOutWidth; int mOutWidth;
@ -63,15 +64,16 @@ class TRTRoIAlignCreator : public TRTPluginCreatorBase {
public: public:
TRTRoIAlignCreator(); TRTRoIAlignCreator();
const char *getPluginName() const override; const char *getPluginName() const TRT_NOEXCEPT override;
const char *getPluginVersion() const override; const char *getPluginVersion() const TRT_NOEXCEPT override;
nvinfer1::IPluginV2 *createPlugin( nvinfer1::IPluginV2 *createPlugin(const char *name,
const char *name, const nvinfer1::PluginFieldCollection *fc) override; const nvinfer1::PluginFieldCollection *fc)
TRT_NOEXCEPT override;
nvinfer1::IPluginV2 *deserializePlugin(const char *name, nvinfer1::IPluginV2 *deserializePlugin(
const void *serialData, const char *name, const void *serialData,
size_t serialLength) override; size_t serialLength) TRT_NOEXCEPT override;
}; };
} // namespace mmlab } // namespace mmlab
#endif // TRT_ROI_ALIGN_HPP #endif // TRT_ROI_ALIGN_HPP

View File

@ -20,7 +20,7 @@ TRTScatterND::TRTScatterND(const std::string name, const void *data,
size_t length) size_t length)
: TRTPluginBase(name) {} : TRTPluginBase(name) {}
nvinfer1::IPluginV2DynamicExt *TRTScatterND::clone() const { nvinfer1::IPluginV2DynamicExt *TRTScatterND::clone() const TRT_NOEXCEPT {
TRTScatterND *plugin = new TRTScatterND(mLayerName); TRTScatterND *plugin = new TRTScatterND(mLayerName);
plugin->setPluginNamespace(getPluginNamespace()); plugin->setPluginNamespace(getPluginNamespace());
@ -29,13 +29,13 @@ nvinfer1::IPluginV2DynamicExt *TRTScatterND::clone() const {
nvinfer1::DimsExprs TRTScatterND::getOutputDimensions( nvinfer1::DimsExprs TRTScatterND::getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
nvinfer1::IExprBuilder &exprBuilder) { nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT {
return inputs[0]; return inputs[0];
} }
bool TRTScatterND::supportsFormatCombination( bool TRTScatterND::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc *inOut, int nbInputs, int pos, const nvinfer1::PluginTensorDesc *inOut, int nbInputs,
int nbOutputs) { int nbOutputs) TRT_NOEXCEPT {
if (pos < nbInputs) { if (pos < nbInputs) {
switch (pos) { switch (pos) {
case 0: case 0:
@ -70,19 +70,20 @@ bool TRTScatterND::supportsFormatCombination(
void TRTScatterND::configurePlugin( void TRTScatterND::configurePlugin(
const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs, const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *outputs, int nbOutputs) {} const nvinfer1::DynamicPluginTensorDesc *outputs,
int nbOutputs) TRT_NOEXCEPT {}
size_t TRTScatterND::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, size_t TRTScatterND::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
int nbInputs, int nbInputs,
const nvinfer1::PluginTensorDesc *outputs, const nvinfer1::PluginTensorDesc *outputs,
int nbOutputs) const { int nbOutputs) const TRT_NOEXCEPT {
return 0; return 0;
} }
int TRTScatterND::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, int TRTScatterND::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs, const void *const *inputs, void *const *outputs,
void *workSpace, cudaStream_t stream) { void *workSpace, cudaStream_t stream) TRT_NOEXCEPT {
const int *dims = &(inputDesc[0].dims.d[0]); const int *dims = &(inputDesc[0].dims.d[0]);
const int *indices_dims = &(inputDesc[1].dims.d[0]); const int *indices_dims = &(inputDesc[1].dims.d[0]);
int nbDims = inputDesc[0].dims.nbDims; int nbDims = inputDesc[0].dims.nbDims;
@ -115,20 +116,25 @@ int TRTScatterND::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
} }
nvinfer1::DataType TRTScatterND::getOutputDataType( nvinfer1::DataType TRTScatterND::getOutputDataType(
int index, const nvinfer1::DataType *inputTypes, int nbInputs) const { int index, const nvinfer1::DataType *inputTypes,
int nbInputs) const TRT_NOEXCEPT {
return inputTypes[0]; return inputTypes[0];
} }
// IPluginV2 Methods // IPluginV2 Methods
const char *TRTScatterND::getPluginType() const { return PLUGIN_NAME; } const char *TRTScatterND::getPluginType() const TRT_NOEXCEPT {
return PLUGIN_NAME;
}
const char *TRTScatterND::getPluginVersion() const { return PLUGIN_VERSION; } const char *TRTScatterND::getPluginVersion() const TRT_NOEXCEPT {
return PLUGIN_VERSION;
}
int TRTScatterND::getNbOutputs() const { return 1; } int TRTScatterND::getNbOutputs() const TRT_NOEXCEPT { return 1; }
size_t TRTScatterND::getSerializationSize() const { return 0; } size_t TRTScatterND::getSerializationSize() const TRT_NOEXCEPT { return 0; }
void TRTScatterND::serialize(void *buffer) const {} void TRTScatterND::serialize(void *buffer) const TRT_NOEXCEPT {}
TRTScatterNDCreator::TRTScatterNDCreator() { TRTScatterNDCreator::TRTScatterNDCreator() {
mPluginAttributes.clear(); mPluginAttributes.clear();
@ -136,21 +142,24 @@ TRTScatterNDCreator::TRTScatterNDCreator() {
mFC.fields = mPluginAttributes.data(); mFC.fields = mPluginAttributes.data();
} }
const char *TRTScatterNDCreator::getPluginName() const { return PLUGIN_NAME; } const char *TRTScatterNDCreator::getPluginName() const TRT_NOEXCEPT {
return PLUGIN_NAME;
}
const char *TRTScatterNDCreator::getPluginVersion() const { const char *TRTScatterNDCreator::getPluginVersion() const TRT_NOEXCEPT {
return PLUGIN_VERSION; return PLUGIN_VERSION;
} }
nvinfer1::IPluginV2 *TRTScatterNDCreator::createPlugin( nvinfer1::IPluginV2 *TRTScatterNDCreator::createPlugin(
const char *name, const nvinfer1::PluginFieldCollection *fc) { const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT {
TRTScatterND *plugin = new TRTScatterND(name); TRTScatterND *plugin = new TRTScatterND(name);
plugin->setPluginNamespace(getPluginNamespace()); plugin->setPluginNamespace(getPluginNamespace());
return plugin; return plugin;
} }
nvinfer1::IPluginV2 *TRTScatterNDCreator::deserializePlugin( nvinfer1::IPluginV2 *TRTScatterNDCreator::deserializePlugin(
const char *name, const void *serialData, size_t serialLength) { const char *name, const void *serialData,
size_t serialLength) TRT_NOEXCEPT {
auto plugin = new TRTScatterND(name, serialData, serialLength); auto plugin = new TRTScatterND(name, serialData, serialLength);
plugin->setPluginNamespace(getPluginNamespace()); plugin->setPluginNamespace(getPluginNamespace());
return plugin; return plugin;

View File

@ -18,52 +18,54 @@ class TRTScatterND : public TRTPluginBase {
TRTScatterND() = delete; TRTScatterND() = delete;
// IPluginV2DynamicExt Methods // IPluginV2DynamicExt Methods
nvinfer1::IPluginV2DynamicExt *clone() const override; nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override;
nvinfer1::DimsExprs getOutputDimensions( nvinfer1::DimsExprs getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
nvinfer1::IExprBuilder &exprBuilder) override; nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT override;
bool supportsFormatCombination(int pos, bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc *inOut, const nvinfer1::PluginTensorDesc *inOut,
int nbInputs, int nbOutputs) override; int nbInputs,
int nbOutputs) TRT_NOEXCEPT override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in,
int nbInputs, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *out, const nvinfer1::DynamicPluginTensorDesc *out,
int nbOutputs) override; int nbOutputs) TRT_NOEXCEPT override;
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
int nbInputs, int nbInputs,
const nvinfer1::PluginTensorDesc *outputs, const nvinfer1::PluginTensorDesc *outputs,
int nbOutputs) const override; int nbOutputs) const TRT_NOEXCEPT override;
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs, void *workspace, const void *const *inputs, void *const *outputs, void *workspace,
cudaStream_t stream) override; cudaStream_t stream) TRT_NOEXCEPT override;
// IPluginV2Ext Methods // IPluginV2Ext Methods
nvinfer1::DataType getOutputDataType(int index, nvinfer1::DataType getOutputDataType(
const nvinfer1::DataType *inputTypes, int index, const nvinfer1::DataType *inputTypes,
int nbInputs) const override; int nbInputs) const TRT_NOEXCEPT override;
// IPluginV2 Methods // IPluginV2 Methods
const char *getPluginType() const override; const char *getPluginType() const TRT_NOEXCEPT override;
const char *getPluginVersion() const override; const char *getPluginVersion() const TRT_NOEXCEPT override;
int getNbOutputs() const override; int getNbOutputs() const TRT_NOEXCEPT override;
size_t getSerializationSize() const override; size_t getSerializationSize() const TRT_NOEXCEPT override;
void serialize(void *buffer) const override; void serialize(void *buffer) const TRT_NOEXCEPT override;
}; };
class TRTScatterNDCreator : public TRTPluginCreatorBase { class TRTScatterNDCreator : public TRTPluginCreatorBase {
public: public:
TRTScatterNDCreator(); TRTScatterNDCreator();
const char *getPluginName() const override; const char *getPluginName() const TRT_NOEXCEPT override;
const char *getPluginVersion() const override; const char *getPluginVersion() const TRT_NOEXCEPT override;
nvinfer1::IPluginV2 *createPlugin( nvinfer1::IPluginV2 *createPlugin(const char *name,
const char *name, const nvinfer1::PluginFieldCollection *fc) override; const nvinfer1::PluginFieldCollection *fc)
TRT_NOEXCEPT override;
nvinfer1::IPluginV2 *deserializePlugin(const char *name, nvinfer1::IPluginV2 *deserializePlugin(
const void *serialData, const char *name, const void *serialData,
size_t serialLength) override; size_t serialLength) TRT_NOEXCEPT override;
}; };
} // namespace mmlab } // namespace mmlab
#endif // TRT_SCATTERND_HPP #endif // TRT_SCATTERND_HPP

View File

@ -1,6 +1,7 @@
import onnx import onnx
import tensorrt as trt import tensorrt as trt
import torch import torch
from packaging import version
def create_trt_engine(onnx_model, def create_trt_engine(onnx_model,
@ -56,7 +57,8 @@ def create_trt_engine(onnx_model,
raise RuntimeError(f'parse onnx failed:\n{error_msgs}') raise RuntimeError(f'parse onnx failed:\n{error_msgs}')
# config builder # config builder
builder.max_workspace_size = max_workspace_size if version.parse(trt.__version__) < version.parse('8'):
builder.max_workspace_size = max_workspace_size
config = builder.create_builder_config() config = builder.create_builder_config()
config.max_workspace_size = max_workspace_size config.max_workspace_size = max_workspace_size