// Mirror of https://github.com/open-mmlab/mmdeploy.git
// Synced 2025-01-14 08:09:43 +08:00 (original file: 42 lines, 942 B, C++).
#ifndef TRT_PLUGIN_HELPER_HPP
#define TRT_PLUGIN_HELPER_HPP
#include <stdexcept>
#include "NvInferPlugin.h"
namespace mmlab {
|
|
|
|
const int MAXTENSORDIMS = 10;
|
|
|
|
struct TensorDesc {
|
|
int shape[MAXTENSORDIMS];
|
|
int stride[MAXTENSORDIMS];
|
|
int dim;
|
|
};
|
|
|
|
inline unsigned int getElementSize(nvinfer1::DataType t) {
|
|
switch (t) {
|
|
case nvinfer1::DataType::kINT32:
|
|
return 4;
|
|
case nvinfer1::DataType::kFLOAT:
|
|
return 4;
|
|
case nvinfer1::DataType::kHALF:
|
|
return 2;
|
|
// case nvinfer1::DataType::kBOOL:
|
|
case nvinfer1::DataType::kINT8:
|
|
return 1;
|
|
default:
|
|
throw std::runtime_error("Invalid DataType.");
|
|
}
|
|
throw std::runtime_error("Invalid DataType.");
|
|
return 0;
|
|
}
|
|
|
|
inline size_t getAlignedSize(size_t origin_size, size_t aligned_number = 16) {
|
|
return size_t((origin_size + aligned_number - 1) / aligned_number) *
|
|
aligned_number;
|
|
}
|
|
|
|
} // namespace mmlab
#endif // TRT_PLUGIN_HELPER_HPP