mirror of
https://github.com/open-mmlab/mmcv.git
synced 2025-06-03 21:54:52 +08:00
* start trt plugin prototype
* Add test module, modify roialign convertor
* finish roi_align trt plugin
* fix conflict of RoiAlign and MMCVRoiAlign
* fix for lint
* fix test tensorrt module
* test_tensorrt move import to test func
* add except error type
* add tensorrt to setup.cfg
* code format with yapf
* fix for clang-format
* move tensorrt_utils to mmcv/tensorrt, add comments, better test module
* fix line endings, docformatter
* isort init, remove trailing whitespace
* add except type
* fix setup.py
* put import extension inside trt setup
* change c++ guard, update pytest script, better setup, etc
* sort import with isort
* sort import with isort
* move init of plugin lib to init_plugins.py
* resolve format and add test dependency: tensorrt
* tensorrt should be installed from source not from pypi
* update naming style and input check
* resolve lint error

Co-authored-by: maningsheng <maningsheng@sensetime.com>
28 lines
638 B
C++
28 lines
638 B
C++
#ifndef TRT_PLUGIN_HELPER_HPP
|
|
#define TRT_PLUGIN_HELPER_HPP
|
|
#include <stdexcept>
|
|
|
|
#include "NvInferPlugin.h"
|
|
|
|
namespace mmcv {
|
|
|
|
inline unsigned int getElementSize(nvinfer1::DataType t) {
|
|
switch (t) {
|
|
case nvinfer1::DataType::kINT32:
|
|
return 4;
|
|
case nvinfer1::DataType::kFLOAT:
|
|
return 4;
|
|
case nvinfer1::DataType::kHALF:
|
|
return 2;
|
|
// case nvinfer1::DataType::kBOOL:
|
|
case nvinfer1::DataType::kINT8:
|
|
return 1;
|
|
default:
|
|
throw std::runtime_error("Invalid DataType.");
|
|
}
|
|
throw std::runtime_error("Invalid DataType.");
|
|
return 0;
|
|
}
|
|
} // namespace mmcv
|
|
#endif // TRT_PLUGIN_HELPER_HPP
|