mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
add TRT8 support (#23)
This commit is contained in:
parent
77080bd931
commit
6ff8e96e71
@ -1,2 +1,2 @@
|
||||
[settings]
|
||||
known_third_party = mmcls,mmcv,mmdet,numpy,onnx,onnxruntime,pytest,setuptools,tensorrt,torch
|
||||
known_third_party = mmcls,mmcv,mmdet,numpy,onnx,onnxruntime,packaging,pytest,setuptools,tensorrt,torch
|
||||
|
@ -63,7 +63,7 @@ endforeach(PLUGIN_ITER)
|
||||
|
||||
list(APPEND BACKEND_OPS_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/common_impl/trt_cuda_helper.cu")
|
||||
|
||||
set(INFER_PLUGIN_LIB ${TENSORRT_LIBRARY})
|
||||
set(INFER_PLUGIN_LIB ${TENSORRT_LIBRARY} cublas cudnn)
|
||||
|
||||
cuda_add_library(${SHARED_TARGET} MODULE ${BACKEND_OPS_SRCS})
|
||||
target_link_libraries(${SHARED_TARGET} ${INFER_PLUGIN_LIB})
|
||||
|
@ -31,11 +31,11 @@ TRTBatchedNMS::TRTBatchedNMS(const std::string& name, const void* data,
|
||||
deserialize_value(&data, &length, &mClipBoxes);
|
||||
}
|
||||
|
||||
int TRTBatchedNMS::getNbOutputs() const { return 2; }
|
||||
int TRTBatchedNMS::getNbOutputs() const TRT_NOEXCEPT { return 2; }
|
||||
|
||||
nvinfer1::DimsExprs TRTBatchedNMS::getOutputDimensions(
|
||||
int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
|
||||
nvinfer1::IExprBuilder& exprBuilder) {
|
||||
nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT {
|
||||
ASSERT(nbInputs == 2);
|
||||
ASSERT(outputIndex >= 0 && outputIndex < this->getNbOutputs());
|
||||
ASSERT(inputs[0].nbDims == 4);
|
||||
@ -61,7 +61,8 @@ nvinfer1::DimsExprs TRTBatchedNMS::getOutputDimensions(
|
||||
|
||||
size_t TRTBatchedNMS::getWorkspaceSize(
|
||||
const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
|
||||
const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const {
|
||||
const nvinfer1::PluginTensorDesc* outputs,
|
||||
int nbOutputs) const TRT_NOEXCEPT {
|
||||
size_t batch_size = inputs[0].dims.d[0];
|
||||
size_t boxes_size =
|
||||
inputs[0].dims.d[1] * inputs[0].dims.d[2] * inputs[0].dims.d[3];
|
||||
@ -77,7 +78,7 @@ size_t TRTBatchedNMS::getWorkspaceSize(
|
||||
int TRTBatchedNMS::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
|
||||
const nvinfer1::PluginTensorDesc* outputDesc,
|
||||
const void* const* inputs, void* const* outputs,
|
||||
void* workSpace, cudaStream_t stream) {
|
||||
void* workSpace, cudaStream_t stream) TRT_NOEXCEPT {
|
||||
const void* const locData = inputs[0];
|
||||
const void* const confData = inputs[1];
|
||||
|
||||
@ -102,12 +103,12 @@ int TRTBatchedNMS::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
|
||||
return 0;
|
||||
}
|
||||
|
||||
size_t TRTBatchedNMS::getSerializationSize() const {
|
||||
size_t TRTBatchedNMS::getSerializationSize() const TRT_NOEXCEPT {
|
||||
// NMSParameters, boxesSize,scoresSize,numPriors
|
||||
return sizeof(NMSParameters) + sizeof(int) * 3 + sizeof(bool);
|
||||
}
|
||||
|
||||
void TRTBatchedNMS::serialize(void* buffer) const {
|
||||
void TRTBatchedNMS::serialize(void* buffer) const TRT_NOEXCEPT {
|
||||
serialize_value(&buffer, param);
|
||||
serialize_value(&buffer, boxesSize);
|
||||
serialize_value(&buffer, scoresSize);
|
||||
@ -117,13 +118,14 @@ void TRTBatchedNMS::serialize(void* buffer) const {
|
||||
|
||||
void TRTBatchedNMS::configurePlugin(
|
||||
const nvinfer1::DynamicPluginTensorDesc* inputs, int nbInputs,
|
||||
const nvinfer1::DynamicPluginTensorDesc* outputs, int nbOutputs) {
|
||||
const nvinfer1::DynamicPluginTensorDesc* outputs,
|
||||
int nbOutputs) TRT_NOEXCEPT {
|
||||
// Validate input arguments
|
||||
}
|
||||
|
||||
bool TRTBatchedNMS::supportsFormatCombination(
|
||||
int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs,
|
||||
int nbOutputs) {
|
||||
int nbOutputs) TRT_NOEXCEPT {
|
||||
if (pos == 3) {
|
||||
return inOut[pos].type == nvinfer1::DataType::kINT32 &&
|
||||
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
|
||||
@ -132,13 +134,15 @@ bool TRTBatchedNMS::supportsFormatCombination(
|
||||
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
|
||||
}
|
||||
|
||||
const char* TRTBatchedNMS::getPluginType() const { return NMS_PLUGIN_NAME; }
|
||||
const char* TRTBatchedNMS::getPluginType() const TRT_NOEXCEPT {
|
||||
return NMS_PLUGIN_NAME;
|
||||
}
|
||||
|
||||
const char* TRTBatchedNMS::getPluginVersion() const {
|
||||
const char* TRTBatchedNMS::getPluginVersion() const TRT_NOEXCEPT {
|
||||
return NMS_PLUGIN_VERSION;
|
||||
}
|
||||
|
||||
IPluginV2DynamicExt* TRTBatchedNMS::clone() const {
|
||||
IPluginV2DynamicExt* TRTBatchedNMS::clone() const TRT_NOEXCEPT {
|
||||
auto* plugin = new TRTBatchedNMS(mLayerName, param);
|
||||
plugin->boxesSize = boxesSize;
|
||||
plugin->scoresSize = scoresSize;
|
||||
@ -149,7 +153,8 @@ IPluginV2DynamicExt* TRTBatchedNMS::clone() const {
|
||||
}
|
||||
|
||||
nvinfer1::DataType TRTBatchedNMS::getOutputDataType(
|
||||
int index, const nvinfer1::DataType* inputTypes, int nbInputs) const {
|
||||
int index, const nvinfer1::DataType* inputTypes,
|
||||
int nbInputs) const TRT_NOEXCEPT {
|
||||
ASSERT(index >= 0 && index < this->getNbOutputs());
|
||||
if (index == 1) {
|
||||
return nvinfer1::DataType::kINT32;
|
||||
@ -181,16 +186,16 @@ TRTBatchedNMSCreator::TRTBatchedNMSCreator() {
|
||||
mFC.fields = mPluginAttributes.data();
|
||||
}
|
||||
|
||||
const char* TRTBatchedNMSCreator::getPluginName() const {
|
||||
const char* TRTBatchedNMSCreator::getPluginName() const TRT_NOEXCEPT {
|
||||
return NMS_PLUGIN_NAME;
|
||||
}
|
||||
|
||||
const char* TRTBatchedNMSCreator::getPluginVersion() const {
|
||||
const char* TRTBatchedNMSCreator::getPluginVersion() const TRT_NOEXCEPT {
|
||||
return NMS_PLUGIN_VERSION;
|
||||
}
|
||||
|
||||
IPluginV2Ext* TRTBatchedNMSCreator::createPlugin(
|
||||
const char* name, const PluginFieldCollection* fc) {
|
||||
const char* name, const PluginFieldCollection* fc) TRT_NOEXCEPT {
|
||||
const PluginField* fields = fc->fields;
|
||||
bool clipBoxes = true;
|
||||
nvinfer1::plugin::NMSParameters params{};
|
||||
@ -228,9 +233,9 @@ IPluginV2Ext* TRTBatchedNMSCreator::createPlugin(
|
||||
return plugin;
|
||||
}
|
||||
|
||||
IPluginV2Ext* TRTBatchedNMSCreator::deserializePlugin(const char* name,
|
||||
const void* serialData,
|
||||
size_t serialLength) {
|
||||
IPluginV2Ext* TRTBatchedNMSCreator::deserializePlugin(
|
||||
const char* name, const void* serialData,
|
||||
size_t serialLength) TRT_NOEXCEPT {
|
||||
// This object will be deleted when the network is destroyed, which will
|
||||
// call NMS::destroy()
|
||||
TRTBatchedNMS* plugin = new TRTBatchedNMS(name, serialData, serialLength);
|
||||
|
@ -13,46 +13,47 @@ class TRTBatchedNMS : public TRTPluginBase {
|
||||
|
||||
TRTBatchedNMS(const std::string& name, const void* data, size_t length);
|
||||
|
||||
~TRTBatchedNMS() override = default;
|
||||
~TRTBatchedNMS() TRT_NOEXCEPT override = default;
|
||||
|
||||
int getNbOutputs() const override;
|
||||
int getNbOutputs() const TRT_NOEXCEPT override;
|
||||
|
||||
nvinfer1::DimsExprs getOutputDimensions(
|
||||
int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
|
||||
nvinfer1::IExprBuilder& exprBuilder) override;
|
||||
nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT override;
|
||||
|
||||
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
|
||||
int nbInputs,
|
||||
const nvinfer1::PluginTensorDesc* outputs,
|
||||
int nbOutputs) const override;
|
||||
int nbOutputs) const TRT_NOEXCEPT override;
|
||||
|
||||
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
|
||||
const nvinfer1::PluginTensorDesc* outputDesc,
|
||||
const void* const* inputs, void* const* outputs, void* workSpace,
|
||||
cudaStream_t stream) override;
|
||||
cudaStream_t stream) TRT_NOEXCEPT override;
|
||||
|
||||
size_t getSerializationSize() const override;
|
||||
size_t getSerializationSize() const TRT_NOEXCEPT override;
|
||||
|
||||
void serialize(void* buffer) const override;
|
||||
void serialize(void* buffer) const TRT_NOEXCEPT override;
|
||||
|
||||
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs,
|
||||
int nbInputs,
|
||||
const nvinfer1::DynamicPluginTensorDesc* outputs,
|
||||
int nbOutputs) override;
|
||||
int nbOutputs) TRT_NOEXCEPT override;
|
||||
|
||||
bool supportsFormatCombination(int pos,
|
||||
const nvinfer1::PluginTensorDesc* inOut,
|
||||
int nbInputs, int nbOutputs) override;
|
||||
int nbInputs,
|
||||
int nbOutputs) TRT_NOEXCEPT override;
|
||||
|
||||
const char* getPluginType() const override;
|
||||
const char* getPluginType() const TRT_NOEXCEPT override;
|
||||
|
||||
const char* getPluginVersion() const override;
|
||||
const char* getPluginVersion() const TRT_NOEXCEPT override;
|
||||
|
||||
nvinfer1::IPluginV2DynamicExt* clone() const override;
|
||||
nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override;
|
||||
|
||||
nvinfer1::DataType getOutputDataType(int index,
|
||||
const nvinfer1::DataType* inputType,
|
||||
int nbInputs) const override;
|
||||
nvinfer1::DataType getOutputDataType(
|
||||
int index, const nvinfer1::DataType* inputType,
|
||||
int nbInputs) const TRT_NOEXCEPT override;
|
||||
|
||||
void setClipParam(bool clip);
|
||||
|
||||
@ -68,18 +69,19 @@ class TRTBatchedNMSCreator : public TRTPluginCreatorBase {
|
||||
public:
|
||||
TRTBatchedNMSCreator();
|
||||
|
||||
~TRTBatchedNMSCreator() override = default;
|
||||
~TRTBatchedNMSCreator() TRT_NOEXCEPT override = default;
|
||||
|
||||
const char* getPluginName() const override;
|
||||
const char* getPluginName() const TRT_NOEXCEPT override;
|
||||
|
||||
const char* getPluginVersion() const override;
|
||||
const char* getPluginVersion() const TRT_NOEXCEPT override;
|
||||
|
||||
nvinfer1::IPluginV2Ext* createPlugin(
|
||||
const char* name, const nvinfer1::PluginFieldCollection* fc) override;
|
||||
const char* name,
|
||||
const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT override;
|
||||
|
||||
nvinfer1::IPluginV2Ext* deserializePlugin(const char* name,
|
||||
const void* serialData,
|
||||
size_t serialLength) override;
|
||||
nvinfer1::IPluginV2Ext* deserializePlugin(
|
||||
const char* name, const void* serialData,
|
||||
size_t serialLength) TRT_NOEXCEPT override;
|
||||
};
|
||||
} // namespace mmlab
|
||||
#endif // TRT_BATCHED_NMS_PLUGIN_CUSTOM_H
|
||||
|
@ -1,26 +1,37 @@
|
||||
#ifndef TRT_PLUGIN_BASE_HPP
|
||||
#define TRT_PLUGIN_BASE_HPP
|
||||
#include "NvInferPlugin.h"
|
||||
#include "NvInferVersion.h"
|
||||
#include "trt_plugin_helper.hpp"
|
||||
|
||||
namespace mmlab {
|
||||
|
||||
#if NV_TENSORRT_MAJOR > 7
|
||||
#define TRT_NOEXCEPT noexcept
|
||||
#else
|
||||
#define TRT_NOEXCEPT
|
||||
#endif
|
||||
|
||||
class TRTPluginBase : public nvinfer1::IPluginV2DynamicExt {
|
||||
public:
|
||||
TRTPluginBase(const std::string &name) : mLayerName(name) {}
|
||||
// IPluginV2 Methods
|
||||
const char *getPluginVersion() const override { return "1"; }
|
||||
int initialize() override { return STATUS_SUCCESS; }
|
||||
void terminate() override {}
|
||||
void destroy() override { delete this; }
|
||||
void setPluginNamespace(const char *pluginNamespace) override {
|
||||
const char *getPluginVersion() const TRT_NOEXCEPT override { return "1"; }
|
||||
int initialize() TRT_NOEXCEPT override { return STATUS_SUCCESS; }
|
||||
void terminate() TRT_NOEXCEPT override {}
|
||||
void destroy() TRT_NOEXCEPT override { delete this; }
|
||||
void setPluginNamespace(const char *pluginNamespace) TRT_NOEXCEPT override {
|
||||
mNamespace = pluginNamespace;
|
||||
}
|
||||
const char *getPluginNamespace() const override { return mNamespace.c_str(); }
|
||||
const char *getPluginNamespace() const TRT_NOEXCEPT override {
|
||||
return mNamespace.c_str();
|
||||
}
|
||||
|
||||
protected:
|
||||
const std::string mLayerName;
|
||||
std::string mNamespace;
|
||||
|
||||
#if NV_TENSORRT_MAJOR < 8
|
||||
protected:
|
||||
// To prevent compiler warnings.
|
||||
using nvinfer1::IPluginV2DynamicExt::canBroadcastInputAcrossBatch;
|
||||
@ -30,21 +41,24 @@ class TRTPluginBase : public nvinfer1::IPluginV2DynamicExt {
|
||||
using nvinfer1::IPluginV2DynamicExt::getWorkspaceSize;
|
||||
using nvinfer1::IPluginV2DynamicExt::isOutputBroadcastAcrossBatch;
|
||||
using nvinfer1::IPluginV2DynamicExt::supportsFormat;
|
||||
#endif
|
||||
};
|
||||
|
||||
class TRTPluginCreatorBase : public nvinfer1::IPluginCreator {
|
||||
public:
|
||||
const char *getPluginVersion() const override { return "1"; };
|
||||
const char *getPluginVersion() const TRT_NOEXCEPT override { return "1"; };
|
||||
|
||||
const nvinfer1::PluginFieldCollection *getFieldNames() override {
|
||||
const nvinfer1::PluginFieldCollection *getFieldNames() TRT_NOEXCEPT override {
|
||||
return &mFC;
|
||||
}
|
||||
|
||||
void setPluginNamespace(const char *pluginNamespace) override {
|
||||
void setPluginNamespace(const char *pluginNamespace) TRT_NOEXCEPT override {
|
||||
mNamespace = pluginNamespace;
|
||||
}
|
||||
|
||||
const char *getPluginNamespace() const override { return mNamespace.c_str(); }
|
||||
const char *getPluginNamespace() const TRT_NOEXCEPT override {
|
||||
return mNamespace.c_str();
|
||||
}
|
||||
|
||||
protected:
|
||||
nvinfer1::PluginFieldCollection mFC;
|
||||
|
@ -31,18 +31,19 @@ TRTInstanceNormalization::TRTInstanceNormalization(const std::string& name,
|
||||
TRTInstanceNormalization::~TRTInstanceNormalization() {}
|
||||
|
||||
// TRTInstanceNormalization returns one output.
|
||||
int TRTInstanceNormalization::getNbOutputs() const { return 1; }
|
||||
int TRTInstanceNormalization::getNbOutputs() const TRT_NOEXCEPT { return 1; }
|
||||
|
||||
DimsExprs TRTInstanceNormalization::getOutputDimensions(
|
||||
int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
|
||||
nvinfer1::IExprBuilder& exprBuilder) {
|
||||
nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT {
|
||||
nvinfer1::DimsExprs output(inputs[0]);
|
||||
return output;
|
||||
}
|
||||
|
||||
size_t TRTInstanceNormalization::getWorkspaceSize(
|
||||
const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
|
||||
const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const {
|
||||
const nvinfer1::PluginTensorDesc* outputs,
|
||||
int nbOutputs) const TRT_NOEXCEPT {
|
||||
int n = inputs[0].dims.d[0];
|
||||
int c = inputs[0].dims.d[1];
|
||||
int elem_size = getElementSize(inputs[1].type);
|
||||
@ -52,7 +53,7 @@ size_t TRTInstanceNormalization::getWorkspaceSize(
|
||||
int TRTInstanceNormalization::enqueue(
|
||||
const nvinfer1::PluginTensorDesc* inputDesc,
|
||||
const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs,
|
||||
void* const* outputs, void* workspace, cudaStream_t stream) {
|
||||
void* const* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT {
|
||||
nvinfer1::Dims input_dims = inputDesc[0].dims;
|
||||
int n = input_dims.d[0];
|
||||
int c = input_dims.d[1];
|
||||
@ -97,47 +98,48 @@ int TRTInstanceNormalization::enqueue(
|
||||
return 0;
|
||||
}
|
||||
|
||||
size_t TRTInstanceNormalization::getSerializationSize() const {
|
||||
size_t TRTInstanceNormalization::getSerializationSize() const TRT_NOEXCEPT {
|
||||
return serialized_size(mEpsilon);
|
||||
}
|
||||
|
||||
void TRTInstanceNormalization::serialize(void* buffer) const {
|
||||
void TRTInstanceNormalization::serialize(void* buffer) const TRT_NOEXCEPT {
|
||||
serialize_value(&buffer, mEpsilon);
|
||||
}
|
||||
|
||||
bool TRTInstanceNormalization::supportsFormatCombination(
|
||||
int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs,
|
||||
int nbOutputs) {
|
||||
int nbOutputs) TRT_NOEXCEPT {
|
||||
return ((inOut[pos].type == nvinfer1::DataType::kFLOAT ||
|
||||
inOut[pos].type == nvinfer1::DataType::kHALF) &&
|
||||
inOut[pos].format == nvinfer1::PluginFormat::kLINEAR &&
|
||||
inOut[pos].type == inOut[0].type);
|
||||
}
|
||||
|
||||
const char* TRTInstanceNormalization::getPluginType() const {
|
||||
const char* TRTInstanceNormalization::getPluginType() const TRT_NOEXCEPT {
|
||||
return PLUGIN_NAME;
|
||||
}
|
||||
|
||||
const char* TRTInstanceNormalization::getPluginVersion() const {
|
||||
const char* TRTInstanceNormalization::getPluginVersion() const TRT_NOEXCEPT {
|
||||
return PLUGIN_VERSION;
|
||||
}
|
||||
|
||||
IPluginV2DynamicExt* TRTInstanceNormalization::clone() const {
|
||||
IPluginV2DynamicExt* TRTInstanceNormalization::clone() const TRT_NOEXCEPT {
|
||||
auto* plugin = new TRTInstanceNormalization{mLayerName, mEpsilon};
|
||||
plugin->setPluginNamespace(mPluginNamespace.c_str());
|
||||
return plugin;
|
||||
}
|
||||
|
||||
nvinfer1::DataType TRTInstanceNormalization::getOutputDataType(
|
||||
int index, const nvinfer1::DataType* inputTypes, int nbInputs) const {
|
||||
int index, const nvinfer1::DataType* inputTypes,
|
||||
int nbInputs) const TRT_NOEXCEPT {
|
||||
return inputTypes[0];
|
||||
}
|
||||
|
||||
// Attach the plugin object to an execution context and grant the plugin the
|
||||
// access to some context resource.
|
||||
void TRTInstanceNormalization::attachToContext(cudnnContext* cudnnContext,
|
||||
cublasContext* cublasContext,
|
||||
IGpuAllocator* gpuAllocator) {
|
||||
void TRTInstanceNormalization::attachToContext(
|
||||
cudnnContext* cudnnContext, cublasContext* cublasContext,
|
||||
IGpuAllocator* gpuAllocator) TRT_NOEXCEPT {
|
||||
_cudnn_handle = cudnnContext;
|
||||
cudnnCreateTensorDescriptor(&_b_desc);
|
||||
cudnnCreateTensorDescriptor(&_x_desc);
|
||||
@ -145,7 +147,7 @@ void TRTInstanceNormalization::attachToContext(cudnnContext* cudnnContext,
|
||||
}
|
||||
|
||||
// Detach the plugin object from its execution context.
|
||||
void TRTInstanceNormalization::detachFromContext() {
|
||||
void TRTInstanceNormalization::detachFromContext() TRT_NOEXCEPT {
|
||||
cudnnDestroyTensorDescriptor(_y_desc);
|
||||
cudnnDestroyTensorDescriptor(_x_desc);
|
||||
cudnnDestroyTensorDescriptor(_b_desc);
|
||||
@ -153,7 +155,7 @@ void TRTInstanceNormalization::detachFromContext() {
|
||||
|
||||
void TRTInstanceNormalization::configurePlugin(
|
||||
const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
|
||||
const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) {}
|
||||
const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) TRT_NOEXCEPT {}
|
||||
|
||||
// TRTInstanceNormalizationCreator methods
|
||||
TRTInstanceNormalizationCreator::TRTInstanceNormalizationCreator() {
|
||||
@ -165,16 +167,18 @@ TRTInstanceNormalizationCreator::TRTInstanceNormalizationCreator() {
|
||||
mFC.fields = mPluginAttributes.data();
|
||||
}
|
||||
|
||||
const char* TRTInstanceNormalizationCreator::getPluginName() const {
|
||||
const char* TRTInstanceNormalizationCreator::getPluginName() const
|
||||
TRT_NOEXCEPT {
|
||||
return PLUGIN_NAME;
|
||||
}
|
||||
|
||||
const char* TRTInstanceNormalizationCreator::getPluginVersion() const {
|
||||
const char* TRTInstanceNormalizationCreator::getPluginVersion() const
|
||||
TRT_NOEXCEPT {
|
||||
return PLUGIN_VERSION;
|
||||
}
|
||||
|
||||
IPluginV2DynamicExt* TRTInstanceNormalizationCreator::createPlugin(
|
||||
const char* name, const nvinfer1::PluginFieldCollection* fc) {
|
||||
const char* name, const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT {
|
||||
float epsilon = 1e-5;
|
||||
const PluginField* fields = fc->fields;
|
||||
for (int i = 0; i < fc->nbFields; ++i) {
|
||||
@ -190,7 +194,8 @@ IPluginV2DynamicExt* TRTInstanceNormalizationCreator::createPlugin(
|
||||
}
|
||||
|
||||
IPluginV2DynamicExt* TRTInstanceNormalizationCreator::deserializePlugin(
|
||||
const char* name, const void* serialData, size_t serialLength) {
|
||||
const char* name, const void* serialData,
|
||||
size_t serialLength) TRT_NOEXCEPT {
|
||||
TRTInstanceNormalization* obj =
|
||||
new TRTInstanceNormalization{name, serialData, serialLength};
|
||||
obj->setPluginNamespace(mNamespace.c_str());
|
||||
|
@ -23,53 +23,55 @@ class TRTInstanceNormalization final : public TRTPluginBase {
|
||||
|
||||
TRTInstanceNormalization() = delete;
|
||||
|
||||
~TRTInstanceNormalization() override;
|
||||
~TRTInstanceNormalization() TRT_NOEXCEPT override;
|
||||
|
||||
int getNbOutputs() const override;
|
||||
int getNbOutputs() const TRT_NOEXCEPT override;
|
||||
|
||||
// DynamicExt plugins returns DimsExprs class instead of Dims
|
||||
nvinfer1::DimsExprs getOutputDimensions(
|
||||
int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
|
||||
nvinfer1::IExprBuilder& exprBuilder) override;
|
||||
nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT override;
|
||||
|
||||
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
|
||||
int nbInputs,
|
||||
const nvinfer1::PluginTensorDesc* outputs,
|
||||
int nbOutputs) const override;
|
||||
int nbOutputs) const TRT_NOEXCEPT override;
|
||||
|
||||
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
|
||||
const nvinfer1::PluginTensorDesc* outputDesc,
|
||||
const void* const* inputs, void* const* outputs, void* workspace,
|
||||
cudaStream_t stream) override;
|
||||
cudaStream_t stream) TRT_NOEXCEPT override;
|
||||
|
||||
size_t getSerializationSize() const override;
|
||||
size_t getSerializationSize() const TRT_NOEXCEPT override;
|
||||
|
||||
void serialize(void* buffer) const override;
|
||||
void serialize(void* buffer) const TRT_NOEXCEPT override;
|
||||
|
||||
// DynamicExt plugin supportsFormat update.
|
||||
bool supportsFormatCombination(int pos,
|
||||
const nvinfer1::PluginTensorDesc* inOut,
|
||||
int nbInputs, int nbOutputs) override;
|
||||
int nbInputs,
|
||||
int nbOutputs) TRT_NOEXCEPT override;
|
||||
|
||||
const char* getPluginType() const override;
|
||||
const char* getPluginType() const TRT_NOEXCEPT override;
|
||||
|
||||
const char* getPluginVersion() const override;
|
||||
const char* getPluginVersion() const TRT_NOEXCEPT override;
|
||||
|
||||
nvinfer1::IPluginV2DynamicExt* clone() const override;
|
||||
nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override;
|
||||
|
||||
nvinfer1::DataType getOutputDataType(int index,
|
||||
const nvinfer1::DataType* inputTypes,
|
||||
int nbInputs) const override;
|
||||
nvinfer1::DataType getOutputDataType(
|
||||
int index, const nvinfer1::DataType* inputTypes,
|
||||
int nbInputs) const TRT_NOEXCEPT override;
|
||||
|
||||
void attachToContext(cudnnContext* cudnn, cublasContext* cublas,
|
||||
nvinfer1::IGpuAllocator* allocator) override;
|
||||
nvinfer1::IGpuAllocator* allocator)
|
||||
TRT_NOEXCEPT override;
|
||||
|
||||
void detachFromContext() override;
|
||||
void detachFromContext() TRT_NOEXCEPT override;
|
||||
|
||||
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
|
||||
int nbInputs,
|
||||
const nvinfer1::DynamicPluginTensorDesc* out,
|
||||
int nbOutputs) override;
|
||||
int nbOutputs) TRT_NOEXCEPT override;
|
||||
|
||||
private:
|
||||
float mEpsilon{};
|
||||
@ -84,15 +86,17 @@ class TRTInstanceNormalizationCreator : public TRTPluginCreatorBase {
|
||||
|
||||
~TRTInstanceNormalizationCreator() override = default;
|
||||
|
||||
const char* getPluginName() const override;
|
||||
const char* getPluginName() const TRT_NOEXCEPT override;
|
||||
|
||||
const char* getPluginVersion() const override;
|
||||
const char* getPluginVersion() const TRT_NOEXCEPT override;
|
||||
|
||||
nvinfer1::IPluginV2DynamicExt* createPlugin(
|
||||
const char* name, const nvinfer1::PluginFieldCollection* fc) override;
|
||||
const char* name,
|
||||
const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT override;
|
||||
|
||||
nvinfer1::IPluginV2DynamicExt* deserializePlugin(
|
||||
const char* name, const void* serialData, size_t serialLength) override;
|
||||
const char* name, const void* serialData,
|
||||
size_t serialLength) TRT_NOEXCEPT override;
|
||||
};
|
||||
} // namespace mmlab
|
||||
#endif // TRT_INSTANCE_NORMALIZATION_HPP
|
||||
|
@ -39,7 +39,8 @@ TRTMultiLevelRoiAlign::TRTMultiLevelRoiAlign(const std::string name,
|
||||
deserialize_value(&data, &length, &mFeatmapStrides);
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2DynamicExt *TRTMultiLevelRoiAlign::clone() const {
|
||||
nvinfer1::IPluginV2DynamicExt *TRTMultiLevelRoiAlign::clone() const
|
||||
TRT_NOEXCEPT {
|
||||
TRTMultiLevelRoiAlign *plugin = new TRTMultiLevelRoiAlign(
|
||||
mLayerName, mAlignedHeight, mAlignedWidth, mSampleNum, mFeatmapStrides,
|
||||
mRoiScaleFactor, mFinestScale, mAligned);
|
||||
@ -50,7 +51,7 @@ nvinfer1::IPluginV2DynamicExt *TRTMultiLevelRoiAlign::clone() const {
|
||||
|
||||
nvinfer1::DimsExprs TRTMultiLevelRoiAlign::getOutputDimensions(
|
||||
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
|
||||
nvinfer1::IExprBuilder &exprBuilder) {
|
||||
nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT {
|
||||
ASSERT(nbInputs == mFeatmapStrides.size() + 1);
|
||||
|
||||
nvinfer1::DimsExprs ret;
|
||||
@ -65,14 +66,15 @@ nvinfer1::DimsExprs TRTMultiLevelRoiAlign::getOutputDimensions(
|
||||
|
||||
bool TRTMultiLevelRoiAlign::supportsFormatCombination(
|
||||
int pos, const nvinfer1::PluginTensorDesc *inOut, int nbInputs,
|
||||
int nbOutputs) {
|
||||
int nbOutputs) TRT_NOEXCEPT {
|
||||
return inOut[pos].type == nvinfer1::DataType::kFLOAT &&
|
||||
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
|
||||
}
|
||||
|
||||
void TRTMultiLevelRoiAlign::configurePlugin(
|
||||
const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
|
||||
const nvinfer1::DynamicPluginTensorDesc *outputs, int nbOutputs) {
|
||||
const nvinfer1::DynamicPluginTensorDesc *outputs,
|
||||
int nbOutputs) TRT_NOEXCEPT {
|
||||
// Validate input arguments
|
||||
ASSERT(nbOutputs == 1);
|
||||
ASSERT(nbInputs == mFeatmapStrides.size() + 1);
|
||||
@ -80,7 +82,8 @@ void TRTMultiLevelRoiAlign::configurePlugin(
|
||||
|
||||
size_t TRTMultiLevelRoiAlign::getWorkspaceSize(
|
||||
const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
|
||||
const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const {
|
||||
const nvinfer1::PluginTensorDesc *outputs,
|
||||
int nbOutputs) const TRT_NOEXCEPT {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -88,7 +91,7 @@ int TRTMultiLevelRoiAlign::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
|
||||
const nvinfer1::PluginTensorDesc *outputDesc,
|
||||
const void *const *inputs,
|
||||
void *const *outputs, void *workSpace,
|
||||
cudaStream_t stream) {
|
||||
cudaStream_t stream) TRT_NOEXCEPT {
|
||||
int num_rois = inputDesc[0].dims.d[0];
|
||||
int batch_size = inputDesc[1].dims.d[0];
|
||||
int channels = inputDesc[1].dims.d[1];
|
||||
@ -118,27 +121,30 @@ int TRTMultiLevelRoiAlign::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
|
||||
}
|
||||
|
||||
nvinfer1::DataType TRTMultiLevelRoiAlign::getOutputDataType(
|
||||
int index, const nvinfer1::DataType *inputTypes, int nbInputs) const {
|
||||
int index, const nvinfer1::DataType *inputTypes,
|
||||
int nbInputs) const TRT_NOEXCEPT {
|
||||
return nvinfer1::DataType::kFLOAT;
|
||||
}
|
||||
|
||||
// IPluginV2 Methods
|
||||
const char *TRTMultiLevelRoiAlign::getPluginType() const { return PLUGIN_NAME; }
|
||||
const char *TRTMultiLevelRoiAlign::getPluginType() const TRT_NOEXCEPT {
|
||||
return PLUGIN_NAME;
|
||||
}
|
||||
|
||||
const char *TRTMultiLevelRoiAlign::getPluginVersion() const {
|
||||
const char *TRTMultiLevelRoiAlign::getPluginVersion() const TRT_NOEXCEPT {
|
||||
return PLUGIN_VERSION;
|
||||
}
|
||||
|
||||
int TRTMultiLevelRoiAlign::getNbOutputs() const { return 1; }
|
||||
int TRTMultiLevelRoiAlign::getNbOutputs() const TRT_NOEXCEPT { return 1; }
|
||||
|
||||
size_t TRTMultiLevelRoiAlign::getSerializationSize() const {
|
||||
size_t TRTMultiLevelRoiAlign::getSerializationSize() const TRT_NOEXCEPT {
|
||||
return serialized_size(mFeatmapStrides) + serialized_size(mAlignedHeight) +
|
||||
serialized_size(mAlignedWidth) + serialized_size(mSampleNum) +
|
||||
serialized_size(mRoiScaleFactor) + serialized_size(mFinestScale) +
|
||||
serialized_size(mAligned);
|
||||
}
|
||||
|
||||
void TRTMultiLevelRoiAlign::serialize(void *buffer) const {
|
||||
void TRTMultiLevelRoiAlign::serialize(void *buffer) const TRT_NOEXCEPT {
|
||||
serialize_value(&buffer, mAlignedHeight);
|
||||
serialize_value(&buffer, mAlignedWidth);
|
||||
serialize_value(&buffer, mSampleNum);
|
||||
@ -161,16 +167,17 @@ TRTMultiLevelRoiAlignCreator::TRTMultiLevelRoiAlignCreator() {
|
||||
mFC.fields = mPluginAttributes.data();
|
||||
}
|
||||
|
||||
const char *TRTMultiLevelRoiAlignCreator::getPluginName() const {
|
||||
const char *TRTMultiLevelRoiAlignCreator::getPluginName() const TRT_NOEXCEPT {
|
||||
return PLUGIN_NAME;
|
||||
}
|
||||
|
||||
const char *TRTMultiLevelRoiAlignCreator::getPluginVersion() const {
|
||||
const char *TRTMultiLevelRoiAlignCreator::getPluginVersion() const
|
||||
TRT_NOEXCEPT {
|
||||
return PLUGIN_VERSION;
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2 *TRTMultiLevelRoiAlignCreator::createPlugin(
|
||||
const char *name, const nvinfer1::PluginFieldCollection *fc) {
|
||||
const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT {
|
||||
int alignedHeight = 7;
|
||||
int alignedWidth = 7;
|
||||
int sampleNum = 2;
|
||||
@ -215,7 +222,8 @@ nvinfer1::IPluginV2 *TRTMultiLevelRoiAlignCreator::createPlugin(
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2 *TRTMultiLevelRoiAlignCreator::deserializePlugin(
|
||||
const char *name, const void *serialData, size_t serialLength) {
|
||||
const char *name, const void *serialData,
|
||||
size_t serialLength) TRT_NOEXCEPT {
|
||||
auto plugin = new TRTMultiLevelRoiAlign(name, serialData, serialLength);
|
||||
plugin->setPluginNamespace(getPluginNamespace());
|
||||
return plugin;
|
||||
|
@ -24,37 +24,38 @@ class TRTMultiLevelRoiAlign : public TRTPluginBase {
|
||||
TRTMultiLevelRoiAlign() = delete;
|
||||
|
||||
// IPluginV2DynamicExt Methods
|
||||
nvinfer1::IPluginV2DynamicExt *clone() const override;
|
||||
nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override;
|
||||
nvinfer1::DimsExprs getOutputDimensions(
|
||||
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
|
||||
nvinfer1::IExprBuilder &exprBuilder) override;
|
||||
nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT override;
|
||||
bool supportsFormatCombination(int pos,
|
||||
const nvinfer1::PluginTensorDesc *inOut,
|
||||
int nbInputs, int nbOutputs) override;
|
||||
int nbInputs,
|
||||
int nbOutputs) TRT_NOEXCEPT override;
|
||||
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in,
|
||||
int nbInputs,
|
||||
const nvinfer1::DynamicPluginTensorDesc *out,
|
||||
int nbOutputs) override;
|
||||
int nbOutputs) TRT_NOEXCEPT override;
|
||||
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
|
||||
int nbInputs,
|
||||
const nvinfer1::PluginTensorDesc *outputs,
|
||||
int nbOutputs) const override;
|
||||
int nbOutputs) const TRT_NOEXCEPT override;
|
||||
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
|
||||
const nvinfer1::PluginTensorDesc *outputDesc,
|
||||
const void *const *inputs, void *const *outputs, void *workspace,
|
||||
cudaStream_t stream) override;
|
||||
cudaStream_t stream) TRT_NOEXCEPT override;
|
||||
|
||||
// IPluginV2Ext Methods
|
||||
nvinfer1::DataType getOutputDataType(int index,
|
||||
const nvinfer1::DataType *inputTypes,
|
||||
int nbInputs) const override;
|
||||
nvinfer1::DataType getOutputDataType(
|
||||
int index, const nvinfer1::DataType *inputTypes,
|
||||
int nbInputs) const TRT_NOEXCEPT override;
|
||||
|
||||
// IPluginV2 Methods
|
||||
const char *getPluginType() const override;
|
||||
const char *getPluginVersion() const override;
|
||||
int getNbOutputs() const override;
|
||||
size_t getSerializationSize() const override;
|
||||
void serialize(void *buffer) const override;
|
||||
const char *getPluginType() const TRT_NOEXCEPT override;
|
||||
const char *getPluginVersion() const TRT_NOEXCEPT override;
|
||||
int getNbOutputs() const TRT_NOEXCEPT override;
|
||||
size_t getSerializationSize() const TRT_NOEXCEPT override;
|
||||
void serialize(void *buffer) const TRT_NOEXCEPT override;
|
||||
|
||||
private:
|
||||
int mAlignedHeight;
|
||||
@ -70,16 +71,17 @@ class TRTMultiLevelRoiAlignCreator : public TRTPluginCreatorBase {
|
||||
public:
|
||||
TRTMultiLevelRoiAlignCreator();
|
||||
|
||||
const char *getPluginName() const override;
|
||||
const char *getPluginName() const TRT_NOEXCEPT override;
|
||||
|
||||
const char *getPluginVersion() const override;
|
||||
const char *getPluginVersion() const TRT_NOEXCEPT override;
|
||||
|
||||
nvinfer1::IPluginV2 *createPlugin(
|
||||
const char *name, const nvinfer1::PluginFieldCollection *fc) override;
|
||||
nvinfer1::IPluginV2 *createPlugin(const char *name,
|
||||
const nvinfer1::PluginFieldCollection *fc)
|
||||
TRT_NOEXCEPT override;
|
||||
|
||||
nvinfer1::IPluginV2 *deserializePlugin(const char *name,
|
||||
const void *serialData,
|
||||
size_t serialLength) override;
|
||||
nvinfer1::IPluginV2 *deserializePlugin(
|
||||
const char *name, const void *serialData,
|
||||
size_t serialLength) TRT_NOEXCEPT override;
|
||||
};
|
||||
} // namespace mmlab
|
||||
#endif // TRT_ROI_ALIGN_HPP
|
||||
|
@ -32,7 +32,7 @@ TRTNMS::TRTNMS(const std::string name, const void *data, size_t length)
|
||||
deserialize_value(&data, &length, &mOffset);
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2DynamicExt *TRTNMS::clone() const {
|
||||
nvinfer1::IPluginV2DynamicExt *TRTNMS::clone() const TRT_NOEXCEPT {
|
||||
TRTNMS *plugin =
|
||||
new TRTNMS(mLayerName, mCenterPointBox, mMaxOutputBoxesPerClass,
|
||||
mIouThreshold, mScoreThreshold, mOffset);
|
||||
@ -43,7 +43,7 @@ nvinfer1::IPluginV2DynamicExt *TRTNMS::clone() const {
|
||||
|
||||
nvinfer1::DimsExprs TRTNMS::getOutputDimensions(
|
||||
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
|
||||
nvinfer1::IExprBuilder &exprBuilder) {
|
||||
nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT {
|
||||
nvinfer1::DimsExprs ret;
|
||||
ret.nbDims = 2;
|
||||
auto num_batches = inputs[0].d[0];
|
||||
@ -65,7 +65,8 @@ nvinfer1::DimsExprs TRTNMS::getOutputDimensions(
|
||||
|
||||
bool TRTNMS::supportsFormatCombination(int pos,
|
||||
const nvinfer1::PluginTensorDesc *inOut,
|
||||
int nbInputs, int nbOutputs) {
|
||||
int nbInputs,
|
||||
int nbOutputs) TRT_NOEXCEPT {
|
||||
if (pos < nbInputs) {
|
||||
switch (pos) {
|
||||
case 0:
|
||||
@ -95,12 +96,12 @@ bool TRTNMS::supportsFormatCombination(int pos,
|
||||
void TRTNMS::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *inputs,
|
||||
int nbInputs,
|
||||
const nvinfer1::DynamicPluginTensorDesc *outputs,
|
||||
int nbOutputs) {}
|
||||
int nbOutputs) TRT_NOEXCEPT {}
|
||||
|
||||
size_t TRTNMS::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
|
||||
int nbInputs,
|
||||
const nvinfer1::PluginTensorDesc *outputs,
|
||||
int nbOutputs) const {
|
||||
int nbOutputs) const TRT_NOEXCEPT {
|
||||
size_t boxes_word_size = mmlab::getElementSize(inputs[0].type);
|
||||
size_t num_batches = inputs[0].dims.d[0];
|
||||
size_t spatial_dimension = inputs[0].dims.d[1];
|
||||
@ -115,7 +116,7 @@ size_t TRTNMS::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
|
||||
int TRTNMS::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
|
||||
const nvinfer1::PluginTensorDesc *outputDesc,
|
||||
const void *const *inputs, void *const *outputs,
|
||||
void *workSpace, cudaStream_t stream) {
|
||||
void *workSpace, cudaStream_t stream) TRT_NOEXCEPT {
|
||||
int num_batches = inputDesc[0].dims.d[0];
|
||||
int spatial_dimension = inputDesc[0].dims.d[1];
|
||||
int num_classes = inputDesc[1].dims.d[1];
|
||||
@ -133,25 +134,28 @@ int TRTNMS::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
|
||||
}
|
||||
|
||||
nvinfer1::DataType TRTNMS::getOutputDataType(
|
||||
int index, const nvinfer1::DataType *inputTypes, int nbInputs) const {
|
||||
int index, const nvinfer1::DataType *inputTypes,
|
||||
int nbInputs) const TRT_NOEXCEPT {
|
||||
return nvinfer1::DataType::kINT32;
|
||||
}
|
||||
|
||||
// IPluginV2 Methods
|
||||
const char *TRTNMS::getPluginType() const { return PLUGIN_NAME; }
|
||||
const char *TRTNMS::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; }
|
||||
|
||||
const char *TRTNMS::getPluginVersion() const { return PLUGIN_VERSION; }
|
||||
const char *TRTNMS::getPluginVersion() const TRT_NOEXCEPT {
|
||||
return PLUGIN_VERSION;
|
||||
}
|
||||
|
||||
int TRTNMS::getNbOutputs() const { return 1; }
|
||||
int TRTNMS::getNbOutputs() const TRT_NOEXCEPT { return 1; }
|
||||
|
||||
size_t TRTNMS::getSerializationSize() const {
|
||||
size_t TRTNMS::getSerializationSize() const TRT_NOEXCEPT {
|
||||
return serialized_size(mCenterPointBox) +
|
||||
serialized_size(mMaxOutputBoxesPerClass) +
|
||||
serialized_size(mIouThreshold) + serialized_size(mScoreThreshold) +
|
||||
serialized_size(mOffset);
|
||||
}
|
||||
|
||||
void TRTNMS::serialize(void *buffer) const {
|
||||
void TRTNMS::serialize(void *buffer) const TRT_NOEXCEPT {
|
||||
serialize_value(&buffer, mCenterPointBox);
|
||||
serialize_value(&buffer, mMaxOutputBoxesPerClass);
|
||||
serialize_value(&buffer, mIouThreshold);
|
||||
@ -171,12 +175,16 @@ TRTNMSCreator::TRTNMSCreator() {
|
||||
mFC.fields = mPluginAttributes.data();
|
||||
}
|
||||
|
||||
const char *TRTNMSCreator::getPluginName() const { return PLUGIN_NAME; }
|
||||
const char *TRTNMSCreator::getPluginName() const TRT_NOEXCEPT {
|
||||
return PLUGIN_NAME;
|
||||
}
|
||||
|
||||
const char *TRTNMSCreator::getPluginVersion() const { return PLUGIN_VERSION; }
|
||||
const char *TRTNMSCreator::getPluginVersion() const TRT_NOEXCEPT {
|
||||
return PLUGIN_VERSION;
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2 *TRTNMSCreator::createPlugin(
|
||||
const char *name, const nvinfer1::PluginFieldCollection *fc) {
|
||||
const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT {
|
||||
int centerPointBox = 0;
|
||||
int maxOutputBoxesPerClass = 0;
|
||||
float iouThreshold = 0.0f;
|
||||
@ -215,9 +223,9 @@ nvinfer1::IPluginV2 *TRTNMSCreator::createPlugin(
|
||||
return plugin;
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2 *TRTNMSCreator::deserializePlugin(const char *name,
|
||||
const void *serialData,
|
||||
size_t serialLength) {
|
||||
nvinfer1::IPluginV2 *TRTNMSCreator::deserializePlugin(
|
||||
const char *name, const void *serialData,
|
||||
size_t serialLength) TRT_NOEXCEPT {
|
||||
auto plugin = new TRTNMS(name, serialData, serialLength);
|
||||
plugin->setPluginNamespace(getPluginNamespace());
|
||||
return plugin;
|
||||
|
@ -19,37 +19,38 @@ class TRTNMS : public TRTPluginBase {
|
||||
TRTNMS() = delete;
|
||||
|
||||
// IPluginV2DynamicExt Methods
|
||||
nvinfer1::IPluginV2DynamicExt *clone() const override;
|
||||
nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override;
|
||||
nvinfer1::DimsExprs getOutputDimensions(
|
||||
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
|
||||
nvinfer1::IExprBuilder &exprBuilder) override;
|
||||
nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT override;
|
||||
bool supportsFormatCombination(int pos,
|
||||
const nvinfer1::PluginTensorDesc *inOut,
|
||||
int nbInputs, int nbOutputs) override;
|
||||
int nbInputs,
|
||||
int nbOutputs) TRT_NOEXCEPT override;
|
||||
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in,
|
||||
int nbInputs,
|
||||
const nvinfer1::DynamicPluginTensorDesc *out,
|
||||
int nbOutputs) override;
|
||||
int nbOutputs) TRT_NOEXCEPT override;
|
||||
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
|
||||
int nbInputs,
|
||||
const nvinfer1::PluginTensorDesc *outputs,
|
||||
int nbOutputs) const override;
|
||||
int nbOutputs) const TRT_NOEXCEPT override;
|
||||
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
|
||||
const nvinfer1::PluginTensorDesc *outputDesc,
|
||||
const void *const *inputs, void *const *outputs, void *workspace,
|
||||
cudaStream_t stream) override;
|
||||
cudaStream_t stream) TRT_NOEXCEPT override;
|
||||
|
||||
// IPluginV2Ext Methods
|
||||
nvinfer1::DataType getOutputDataType(int index,
|
||||
const nvinfer1::DataType *inputTypes,
|
||||
int nbInputs) const override;
|
||||
nvinfer1::DataType getOutputDataType(
|
||||
int index, const nvinfer1::DataType *inputTypes,
|
||||
int nbInputs) const TRT_NOEXCEPT override;
|
||||
|
||||
// IPluginV2 Methods
|
||||
const char *getPluginType() const override;
|
||||
const char *getPluginVersion() const override;
|
||||
int getNbOutputs() const override;
|
||||
size_t getSerializationSize() const override;
|
||||
void serialize(void *buffer) const override;
|
||||
const char *getPluginType() const TRT_NOEXCEPT override;
|
||||
const char *getPluginVersion() const TRT_NOEXCEPT override;
|
||||
int getNbOutputs() const TRT_NOEXCEPT override;
|
||||
size_t getSerializationSize() const TRT_NOEXCEPT override;
|
||||
void serialize(void *buffer) const TRT_NOEXCEPT override;
|
||||
|
||||
private:
|
||||
int mCenterPointBox;
|
||||
@ -63,16 +64,17 @@ class TRTNMSCreator : public TRTPluginCreatorBase {
|
||||
public:
|
||||
TRTNMSCreator();
|
||||
|
||||
const char *getPluginName() const override;
|
||||
const char *getPluginName() const TRT_NOEXCEPT override;
|
||||
|
||||
const char *getPluginVersion() const override;
|
||||
const char *getPluginVersion() const TRT_NOEXCEPT override;
|
||||
|
||||
nvinfer1::IPluginV2 *createPlugin(
|
||||
const char *name, const nvinfer1::PluginFieldCollection *fc) override;
|
||||
nvinfer1::IPluginV2 *createPlugin(const char *name,
|
||||
const nvinfer1::PluginFieldCollection *fc)
|
||||
TRT_NOEXCEPT override;
|
||||
|
||||
nvinfer1::IPluginV2 *deserializePlugin(const char *name,
|
||||
const void *serialData,
|
||||
size_t serialLength) override;
|
||||
nvinfer1::IPluginV2 *deserializePlugin(
|
||||
const char *name, const void *serialData,
|
||||
size_t serialLength) TRT_NOEXCEPT override;
|
||||
};
|
||||
} // namespace mmlab
|
||||
#endif // TRT_NMS_HPP
|
||||
|
@ -36,7 +36,7 @@ TRTRoIAlign::TRTRoIAlign(const std::string name, const void *data,
|
||||
deserialize_value(&data, &length, &mAligned);
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2DynamicExt *TRTRoIAlign::clone() const {
|
||||
nvinfer1::IPluginV2DynamicExt *TRTRoIAlign::clone() const TRT_NOEXCEPT {
|
||||
TRTRoIAlign *plugin =
|
||||
new TRTRoIAlign(mLayerName, mOutWidth, mOutHeight, mSpatialScale,
|
||||
mSampleRatio, mPoolMode, mAligned);
|
||||
@ -47,7 +47,7 @@ nvinfer1::IPluginV2DynamicExt *TRTRoIAlign::clone() const {
|
||||
|
||||
nvinfer1::DimsExprs TRTRoIAlign::getOutputDimensions(
|
||||
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
|
||||
nvinfer1::IExprBuilder &exprBuilder) {
|
||||
nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT {
|
||||
nvinfer1::DimsExprs ret;
|
||||
ret.nbDims = 4;
|
||||
ret.d[0] = inputs[1].d[0];
|
||||
@ -60,19 +60,20 @@ nvinfer1::DimsExprs TRTRoIAlign::getOutputDimensions(
|
||||
|
||||
bool TRTRoIAlign::supportsFormatCombination(
|
||||
int pos, const nvinfer1::PluginTensorDesc *inOut, int nbInputs,
|
||||
int nbOutputs) {
|
||||
int nbOutputs) TRT_NOEXCEPT {
|
||||
return inOut[pos].type == nvinfer1::DataType::kFLOAT &&
|
||||
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
|
||||
}
|
||||
|
||||
void TRTRoIAlign::configurePlugin(
|
||||
const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
|
||||
const nvinfer1::DynamicPluginTensorDesc *outputs, int nbOutputs) {}
|
||||
const nvinfer1::DynamicPluginTensorDesc *outputs,
|
||||
int nbOutputs) TRT_NOEXCEPT {}
|
||||
|
||||
size_t TRTRoIAlign::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
|
||||
int nbInputs,
|
||||
const nvinfer1::PluginTensorDesc *outputs,
|
||||
int nbOutputs) const {
|
||||
int nbOutputs) const TRT_NOEXCEPT {
|
||||
size_t output_size = 0;
|
||||
size_t word_size = 0;
|
||||
switch (mPoolMode) {
|
||||
@ -94,7 +95,7 @@ size_t TRTRoIAlign::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
|
||||
int TRTRoIAlign::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
|
||||
const nvinfer1::PluginTensorDesc *outputDesc,
|
||||
const void *const *inputs, void *const *outputs,
|
||||
void *workSpace, cudaStream_t stream) {
|
||||
void *workSpace, cudaStream_t stream) TRT_NOEXCEPT {
|
||||
int channels = inputDesc[0].dims.d[1];
|
||||
int height = inputDesc[0].dims.d[2];
|
||||
int width = inputDesc[0].dims.d[3];
|
||||
@ -135,24 +136,29 @@ int TRTRoIAlign::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
|
||||
}
|
||||
|
||||
nvinfer1::DataType TRTRoIAlign::getOutputDataType(
|
||||
int index, const nvinfer1::DataType *inputTypes, int nbInputs) const {
|
||||
int index, const nvinfer1::DataType *inputTypes,
|
||||
int nbInputs) const TRT_NOEXCEPT {
|
||||
return inputTypes[0];
|
||||
}
|
||||
|
||||
// IPluginV2 Methods
|
||||
const char *TRTRoIAlign::getPluginType() const { return PLUGIN_NAME; }
|
||||
const char *TRTRoIAlign::getPluginType() const TRT_NOEXCEPT {
|
||||
return PLUGIN_NAME;
|
||||
}
|
||||
|
||||
const char *TRTRoIAlign::getPluginVersion() const { return PLUGIN_VERSION; }
|
||||
const char *TRTRoIAlign::getPluginVersion() const TRT_NOEXCEPT {
|
||||
return PLUGIN_VERSION;
|
||||
}
|
||||
|
||||
int TRTRoIAlign::getNbOutputs() const { return 1; }
|
||||
int TRTRoIAlign::getNbOutputs() const TRT_NOEXCEPT { return 1; }
|
||||
|
||||
size_t TRTRoIAlign::getSerializationSize() const {
|
||||
size_t TRTRoIAlign::getSerializationSize() const TRT_NOEXCEPT {
|
||||
return serialized_size(mOutWidth) + serialized_size(mOutHeight) +
|
||||
serialized_size(mSpatialScale) + serialized_size(mSampleRatio) +
|
||||
serialized_size(mPoolMode) + serialized_size(mAligned);
|
||||
}
|
||||
|
||||
void TRTRoIAlign::serialize(void *buffer) const {
|
||||
void TRTRoIAlign::serialize(void *buffer) const TRT_NOEXCEPT {
|
||||
serialize_value(&buffer, mOutWidth);
|
||||
serialize_value(&buffer, mOutHeight);
|
||||
serialize_value(&buffer, mSpatialScale);
|
||||
@ -172,14 +178,16 @@ TRTRoIAlignCreator::TRTRoIAlignCreator() {
|
||||
mFC.fields = mPluginAttributes.data();
|
||||
}
|
||||
|
||||
const char *TRTRoIAlignCreator::getPluginName() const { return PLUGIN_NAME; }
|
||||
const char *TRTRoIAlignCreator::getPluginName() const TRT_NOEXCEPT {
|
||||
return PLUGIN_NAME;
|
||||
}
|
||||
|
||||
const char *TRTRoIAlignCreator::getPluginVersion() const {
|
||||
const char *TRTRoIAlignCreator::getPluginVersion() const TRT_NOEXCEPT {
|
||||
return PLUGIN_VERSION;
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2 *TRTRoIAlignCreator::createPlugin(
|
||||
const char *name, const nvinfer1::PluginFieldCollection *fc) {
|
||||
const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT {
|
||||
int outWidth = 7;
|
||||
int outHeight = 7;
|
||||
float spatialScale = 1.0;
|
||||
@ -241,7 +249,8 @@ nvinfer1::IPluginV2 *TRTRoIAlignCreator::createPlugin(
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2 *TRTRoIAlignCreator::deserializePlugin(
|
||||
const char *name, const void *serialData, size_t serialLength) {
|
||||
const char *name, const void *serialData,
|
||||
size_t serialLength) TRT_NOEXCEPT {
|
||||
auto plugin = new TRTRoIAlign(name, serialData, serialLength);
|
||||
plugin->setPluginNamespace(getPluginNamespace());
|
||||
return plugin;
|
||||
|
@ -18,37 +18,38 @@ class TRTRoIAlign : public TRTPluginBase {
|
||||
TRTRoIAlign() = delete;
|
||||
|
||||
// IPluginV2DynamicExt Methods
|
||||
nvinfer1::IPluginV2DynamicExt *clone() const override;
|
||||
nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override;
|
||||
nvinfer1::DimsExprs getOutputDimensions(
|
||||
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
|
||||
nvinfer1::IExprBuilder &exprBuilder) override;
|
||||
nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT override;
|
||||
bool supportsFormatCombination(int pos,
|
||||
const nvinfer1::PluginTensorDesc *inOut,
|
||||
int nbInputs, int nbOutputs) override;
|
||||
int nbInputs,
|
||||
int nbOutputs) TRT_NOEXCEPT override;
|
||||
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in,
|
||||
int nbInputs,
|
||||
const nvinfer1::DynamicPluginTensorDesc *out,
|
||||
int nbOutputs) override;
|
||||
int nbOutputs) TRT_NOEXCEPT override;
|
||||
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
|
||||
int nbInputs,
|
||||
const nvinfer1::PluginTensorDesc *outputs,
|
||||
int nbOutputs) const override;
|
||||
int nbOutputs) const TRT_NOEXCEPT override;
|
||||
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
|
||||
const nvinfer1::PluginTensorDesc *outputDesc,
|
||||
const void *const *inputs, void *const *outputs, void *workspace,
|
||||
cudaStream_t stream) override;
|
||||
cudaStream_t stream) TRT_NOEXCEPT override;
|
||||
|
||||
// IPluginV2Ext Methods
|
||||
nvinfer1::DataType getOutputDataType(int index,
|
||||
const nvinfer1::DataType *inputTypes,
|
||||
int nbInputs) const override;
|
||||
nvinfer1::DataType getOutputDataType(
|
||||
int index, const nvinfer1::DataType *inputTypes,
|
||||
int nbInputs) const TRT_NOEXCEPT override;
|
||||
|
||||
// IPluginV2 Methods
|
||||
const char *getPluginType() const override;
|
||||
const char *getPluginVersion() const override;
|
||||
int getNbOutputs() const override;
|
||||
size_t getSerializationSize() const override;
|
||||
void serialize(void *buffer) const override;
|
||||
const char *getPluginType() const TRT_NOEXCEPT override;
|
||||
const char *getPluginVersion() const TRT_NOEXCEPT override;
|
||||
int getNbOutputs() const TRT_NOEXCEPT override;
|
||||
size_t getSerializationSize() const TRT_NOEXCEPT override;
|
||||
void serialize(void *buffer) const TRT_NOEXCEPT override;
|
||||
|
||||
private:
|
||||
int mOutWidth;
|
||||
@ -63,15 +64,16 @@ class TRTRoIAlignCreator : public TRTPluginCreatorBase {
|
||||
public:
|
||||
TRTRoIAlignCreator();
|
||||
|
||||
const char *getPluginName() const override;
|
||||
const char *getPluginName() const TRT_NOEXCEPT override;
|
||||
|
||||
const char *getPluginVersion() const override;
|
||||
nvinfer1::IPluginV2 *createPlugin(
|
||||
const char *name, const nvinfer1::PluginFieldCollection *fc) override;
|
||||
const char *getPluginVersion() const TRT_NOEXCEPT override;
|
||||
nvinfer1::IPluginV2 *createPlugin(const char *name,
|
||||
const nvinfer1::PluginFieldCollection *fc)
|
||||
TRT_NOEXCEPT override;
|
||||
|
||||
nvinfer1::IPluginV2 *deserializePlugin(const char *name,
|
||||
const void *serialData,
|
||||
size_t serialLength) override;
|
||||
nvinfer1::IPluginV2 *deserializePlugin(
|
||||
const char *name, const void *serialData,
|
||||
size_t serialLength) TRT_NOEXCEPT override;
|
||||
};
|
||||
} // namespace mmlab
|
||||
#endif // TRT_ROI_ALIGN_HPP
|
||||
|
@ -20,7 +20,7 @@ TRTScatterND::TRTScatterND(const std::string name, const void *data,
|
||||
size_t length)
|
||||
: TRTPluginBase(name) {}
|
||||
|
||||
nvinfer1::IPluginV2DynamicExt *TRTScatterND::clone() const {
|
||||
nvinfer1::IPluginV2DynamicExt *TRTScatterND::clone() const TRT_NOEXCEPT {
|
||||
TRTScatterND *plugin = new TRTScatterND(mLayerName);
|
||||
plugin->setPluginNamespace(getPluginNamespace());
|
||||
|
||||
@ -29,13 +29,13 @@ nvinfer1::IPluginV2DynamicExt *TRTScatterND::clone() const {
|
||||
|
||||
nvinfer1::DimsExprs TRTScatterND::getOutputDimensions(
|
||||
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
|
||||
nvinfer1::IExprBuilder &exprBuilder) {
|
||||
nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT {
|
||||
return inputs[0];
|
||||
}
|
||||
|
||||
bool TRTScatterND::supportsFormatCombination(
|
||||
int pos, const nvinfer1::PluginTensorDesc *inOut, int nbInputs,
|
||||
int nbOutputs) {
|
||||
int nbOutputs) TRT_NOEXCEPT {
|
||||
if (pos < nbInputs) {
|
||||
switch (pos) {
|
||||
case 0:
|
||||
@ -70,19 +70,20 @@ bool TRTScatterND::supportsFormatCombination(
|
||||
|
||||
void TRTScatterND::configurePlugin(
|
||||
const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
|
||||
const nvinfer1::DynamicPluginTensorDesc *outputs, int nbOutputs) {}
|
||||
const nvinfer1::DynamicPluginTensorDesc *outputs,
|
||||
int nbOutputs) TRT_NOEXCEPT {}
|
||||
|
||||
size_t TRTScatterND::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
|
||||
int nbInputs,
|
||||
const nvinfer1::PluginTensorDesc *outputs,
|
||||
int nbOutputs) const {
|
||||
int nbOutputs) const TRT_NOEXCEPT {
|
||||
return 0;
|
||||
}
|
||||
|
||||
int TRTScatterND::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
|
||||
const nvinfer1::PluginTensorDesc *outputDesc,
|
||||
const void *const *inputs, void *const *outputs,
|
||||
void *workSpace, cudaStream_t stream) {
|
||||
void *workSpace, cudaStream_t stream) TRT_NOEXCEPT {
|
||||
const int *dims = &(inputDesc[0].dims.d[0]);
|
||||
const int *indices_dims = &(inputDesc[1].dims.d[0]);
|
||||
int nbDims = inputDesc[0].dims.nbDims;
|
||||
@ -115,20 +116,25 @@ int TRTScatterND::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
|
||||
}
|
||||
|
||||
nvinfer1::DataType TRTScatterND::getOutputDataType(
|
||||
int index, const nvinfer1::DataType *inputTypes, int nbInputs) const {
|
||||
int index, const nvinfer1::DataType *inputTypes,
|
||||
int nbInputs) const TRT_NOEXCEPT {
|
||||
return inputTypes[0];
|
||||
}
|
||||
|
||||
// IPluginV2 Methods
|
||||
const char *TRTScatterND::getPluginType() const { return PLUGIN_NAME; }
|
||||
const char *TRTScatterND::getPluginType() const TRT_NOEXCEPT {
|
||||
return PLUGIN_NAME;
|
||||
}
|
||||
|
||||
const char *TRTScatterND::getPluginVersion() const { return PLUGIN_VERSION; }
|
||||
const char *TRTScatterND::getPluginVersion() const TRT_NOEXCEPT {
|
||||
return PLUGIN_VERSION;
|
||||
}
|
||||
|
||||
int TRTScatterND::getNbOutputs() const { return 1; }
|
||||
int TRTScatterND::getNbOutputs() const TRT_NOEXCEPT { return 1; }
|
||||
|
||||
size_t TRTScatterND::getSerializationSize() const { return 0; }
|
||||
size_t TRTScatterND::getSerializationSize() const TRT_NOEXCEPT { return 0; }
|
||||
|
||||
void TRTScatterND::serialize(void *buffer) const {}
|
||||
void TRTScatterND::serialize(void *buffer) const TRT_NOEXCEPT {}
|
||||
|
||||
TRTScatterNDCreator::TRTScatterNDCreator() {
|
||||
mPluginAttributes.clear();
|
||||
@ -136,21 +142,24 @@ TRTScatterNDCreator::TRTScatterNDCreator() {
|
||||
mFC.fields = mPluginAttributes.data();
|
||||
}
|
||||
|
||||
const char *TRTScatterNDCreator::getPluginName() const { return PLUGIN_NAME; }
|
||||
const char *TRTScatterNDCreator::getPluginName() const TRT_NOEXCEPT {
|
||||
return PLUGIN_NAME;
|
||||
}
|
||||
|
||||
const char *TRTScatterNDCreator::getPluginVersion() const {
|
||||
const char *TRTScatterNDCreator::getPluginVersion() const TRT_NOEXCEPT {
|
||||
return PLUGIN_VERSION;
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2 *TRTScatterNDCreator::createPlugin(
|
||||
const char *name, const nvinfer1::PluginFieldCollection *fc) {
|
||||
const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT {
|
||||
TRTScatterND *plugin = new TRTScatterND(name);
|
||||
plugin->setPluginNamespace(getPluginNamespace());
|
||||
return plugin;
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2 *TRTScatterNDCreator::deserializePlugin(
|
||||
const char *name, const void *serialData, size_t serialLength) {
|
||||
const char *name, const void *serialData,
|
||||
size_t serialLength) TRT_NOEXCEPT {
|
||||
auto plugin = new TRTScatterND(name, serialData, serialLength);
|
||||
plugin->setPluginNamespace(getPluginNamespace());
|
||||
return plugin;
|
||||
|
@ -18,52 +18,54 @@ class TRTScatterND : public TRTPluginBase {
|
||||
TRTScatterND() = delete;
|
||||
|
||||
// IPluginV2DynamicExt Methods
|
||||
nvinfer1::IPluginV2DynamicExt *clone() const override;
|
||||
nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override;
|
||||
nvinfer1::DimsExprs getOutputDimensions(
|
||||
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
|
||||
nvinfer1::IExprBuilder &exprBuilder) override;
|
||||
nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT override;
|
||||
bool supportsFormatCombination(int pos,
|
||||
const nvinfer1::PluginTensorDesc *inOut,
|
||||
int nbInputs, int nbOutputs) override;
|
||||
int nbInputs,
|
||||
int nbOutputs) TRT_NOEXCEPT override;
|
||||
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in,
|
||||
int nbInputs,
|
||||
const nvinfer1::DynamicPluginTensorDesc *out,
|
||||
int nbOutputs) override;
|
||||
int nbOutputs) TRT_NOEXCEPT override;
|
||||
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
|
||||
int nbInputs,
|
||||
const nvinfer1::PluginTensorDesc *outputs,
|
||||
int nbOutputs) const override;
|
||||
int nbOutputs) const TRT_NOEXCEPT override;
|
||||
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
|
||||
const nvinfer1::PluginTensorDesc *outputDesc,
|
||||
const void *const *inputs, void *const *outputs, void *workspace,
|
||||
cudaStream_t stream) override;
|
||||
cudaStream_t stream) TRT_NOEXCEPT override;
|
||||
|
||||
// IPluginV2Ext Methods
|
||||
nvinfer1::DataType getOutputDataType(int index,
|
||||
const nvinfer1::DataType *inputTypes,
|
||||
int nbInputs) const override;
|
||||
nvinfer1::DataType getOutputDataType(
|
||||
int index, const nvinfer1::DataType *inputTypes,
|
||||
int nbInputs) const TRT_NOEXCEPT override;
|
||||
|
||||
// IPluginV2 Methods
|
||||
const char *getPluginType() const override;
|
||||
const char *getPluginVersion() const override;
|
||||
int getNbOutputs() const override;
|
||||
size_t getSerializationSize() const override;
|
||||
void serialize(void *buffer) const override;
|
||||
const char *getPluginType() const TRT_NOEXCEPT override;
|
||||
const char *getPluginVersion() const TRT_NOEXCEPT override;
|
||||
int getNbOutputs() const TRT_NOEXCEPT override;
|
||||
size_t getSerializationSize() const TRT_NOEXCEPT override;
|
||||
void serialize(void *buffer) const TRT_NOEXCEPT override;
|
||||
};
|
||||
|
||||
class TRTScatterNDCreator : public TRTPluginCreatorBase {
|
||||
public:
|
||||
TRTScatterNDCreator();
|
||||
|
||||
const char *getPluginName() const override;
|
||||
const char *getPluginName() const TRT_NOEXCEPT override;
|
||||
|
||||
const char *getPluginVersion() const override;
|
||||
nvinfer1::IPluginV2 *createPlugin(
|
||||
const char *name, const nvinfer1::PluginFieldCollection *fc) override;
|
||||
const char *getPluginVersion() const TRT_NOEXCEPT override;
|
||||
nvinfer1::IPluginV2 *createPlugin(const char *name,
|
||||
const nvinfer1::PluginFieldCollection *fc)
|
||||
TRT_NOEXCEPT override;
|
||||
|
||||
nvinfer1::IPluginV2 *deserializePlugin(const char *name,
|
||||
const void *serialData,
|
||||
size_t serialLength) override;
|
||||
nvinfer1::IPluginV2 *deserializePlugin(
|
||||
const char *name, const void *serialData,
|
||||
size_t serialLength) TRT_NOEXCEPT override;
|
||||
};
|
||||
} // namespace mmlab
|
||||
#endif // TRT_SCATTERND_HPP
|
||||
|
@ -1,6 +1,7 @@
|
||||
import onnx
|
||||
import tensorrt as trt
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
|
||||
def create_trt_engine(onnx_model,
|
||||
@ -56,7 +57,8 @@ def create_trt_engine(onnx_model,
|
||||
raise RuntimeError(f'parse onnx failed:\n{error_msgs}')
|
||||
|
||||
# config builder
|
||||
builder.max_workspace_size = max_workspace_size
|
||||
if version.parse(trt.__version__) < version.parse('8'):
|
||||
builder.max_workspace_size = max_workspace_size
|
||||
|
||||
config = builder.create_builder_config()
|
||||
config.max_workspace_size = max_workspace_size
|
||||
|
Loading…
x
Reference in New Issue
Block a user