#ifndef TRT_PLUGIN_HELPER_HPP #define TRT_PLUGIN_HELPER_HPP #include #include "NvInferPlugin.h" namespace mmlab { const int MAXTENSORDIMS = 10; struct TensorDesc { int shape[MAXTENSORDIMS]; int stride[MAXTENSORDIMS]; int dim; }; inline unsigned int getElementSize(nvinfer1::DataType t) { switch (t) { case nvinfer1::DataType::kINT32: return 4; case nvinfer1::DataType::kFLOAT: return 4; case nvinfer1::DataType::kHALF: return 2; // case nvinfer1::DataType::kBOOL: case nvinfer1::DataType::kINT8: return 1; default: throw std::runtime_error("Invalid DataType."); } throw std::runtime_error("Invalid DataType."); return 0; } inline size_t getAlignedSize(size_t origin_size, size_t aligned_number = 16) { return size_t((origin_size + aligned_number - 1) / aligned_number) * aligned_number; } } // namespace mmlab #endif // TRT_PLUGIN_HELPER_HPP