mmdeploy/backend_ops/tensorrt/common/trt_plugin_helper.hpp
2021-06-15 20:44:45 +08:00

42 lines
942 B
C++

#ifndef TRT_PLUGIN_HELPER_HPP
#define TRT_PLUGIN_HELPER_HPP
#include <stdexcept>
#include "NvInferPlugin.h"
namespace mmlab {
// Upper bound on the number of dimensions a TensorDesc can describe; it sizes
// the fixed arrays inside the struct.
constexpr int MAXTENSORDIMS = 10;

// Plain aggregate describing a tensor with fixed-capacity arrays, so it holds
// no pointers and performs no allocation.
struct TensorDesc {
  // Per-dimension entries; presumably the extent of each dimension -- confirm
  // against the code that fills this in.
  int shape[MAXTENSORDIMS];
  // Per-dimension entries; presumably the stride of each dimension (unit not
  // visible here) -- confirm against the code that fills this in.
  int stride[MAXTENSORDIMS];
  // Presumably the number of leading entries of shape/stride that are in use
  // -- confirm against callers.
  int dim;
};
/// Returns the size in bytes of one element of the TensorRT data type `t`.
///
/// Supported types: kINT32 and kFLOAT (4 bytes), kHALF (2 bytes) and
/// kINT8 (1 byte).
///
/// @param t TensorRT element data type.
/// @return Size of a single element, in bytes.
/// @throws std::runtime_error if `t` is not one of the supported types.
inline unsigned int getElementSize(nvinfer1::DataType t) {
  switch (t) {
    case nvinfer1::DataType::kINT32:
    case nvinfer1::DataType::kFLOAT:
      return 4;
    case nvinfer1::DataType::kHALF:
      return 2;
    // NOTE(review): kBOOL is left disabled on purpose here; presumably some
    // supported TensorRT versions do not define it -- confirm before enabling.
    // case nvinfer1::DataType::kBOOL:
    case nvinfer1::DataType::kINT8:
      return 1;
    default:
      break;
  }
  // Reached only for unsupported types. Throwing after the switch gives the
  // function a single error path and leaves no unreachable statements behind.
  throw std::runtime_error("Invalid DataType.");
}
/// Rounds `origin_size` up to the nearest multiple of `aligned_number`.
///
/// A value that is already a multiple of `aligned_number` (including 0) is
/// returned unchanged.
///
/// @param origin_size    Size to round up.
/// @param aligned_number Alignment granularity; must be non-zero. Defaults to 16.
/// @return The smallest multiple of `aligned_number` that is >= `origin_size`.
///         If that multiple does not fit in size_t the result wraps around.
/// @throws std::invalid_argument if `aligned_number` is 0.
inline size_t getAlignedSize(size_t origin_size, size_t aligned_number = 16) {
  // A zero alignment would otherwise be a division by zero (undefined
  // behaviour), so reject it explicitly.
  if (aligned_number == 0) {
    throw std::invalid_argument("aligned_number must be non-zero.");
  }
  // Work from the remainder instead of `origin_size + aligned_number - 1`:
  // that sum can overflow even when the aligned result itself is representable.
  const size_t remainder = origin_size % aligned_number;
  if (remainder == 0) {
    return origin_size;
  }
  return origin_size + (aligned_number - remainder);
}
} // namespace mmlab
#endif // TRT_PLUGIN_HELPER_HPP