mmcv/mmcv/ops/csrc/tensorrt/trt_plugin_helper.hpp
q.yao 0de9e149c0
[Feature]: Add custom operators support for TensorRT in mmcv (#686)
* start trt plugin prototype

* Add test module, modify roialign convertor

* finish roi_align trt plugin

* fix conflict of RoiAlign and MMCVRoiAlign

* fix for lint

* fix test tensorrt module

* test_tensorrt move import to test func

* add except error type

* add tensorrt to setup.cfg

* code format with yapf

* fix for clang-format

* move tensorrt_utils to mmcv/tensorrt, add comments, better test module

* fix line endings, docformatter

* isort init, remove trailing whitespace

* add except type

* fix setup.py

* put import extension inside trt setup

* change c++ guard, update pytest script, better setup, etc

* sort import with isort

* sort import with isort

* move init of plugin lib to init_plugins.py

* resolve format and add test dependency: tensorrt

* tensorrt should be installed from source not from pypi

* update naming style and input check

* resolve lint error

Co-authored-by: maningsheng <maningsheng@sensetime.com>
2021-01-06 11:05:19 +08:00


#ifndef TRT_PLUGIN_HELPER_HPP
#define TRT_PLUGIN_HELPER_HPP
#include <stdexcept>

#include "NvInferPlugin.h"

namespace mmcv {

// Returns the size in bytes of one element of the given TensorRT data type.
// Throws std::runtime_error for any type not handled below.
inline unsigned int getElementSize(nvinfer1::DataType t) {
  switch (t) {
    case nvinfer1::DataType::kINT32:
      return 4;
    case nvinfer1::DataType::kFLOAT:
      return 4;
    case nvinfer1::DataType::kHALF:
      return 2;
    // case nvinfer1::DataType::kBOOL:
    case nvinfer1::DataType::kINT8:
      return 1;
    default:
      throw std::runtime_error("Invalid DataType.");
  }
  // Not reached: every switch path returns or throws. Kept so that compilers
  // that cannot prove this do not warn about a missing return value.
  throw std::runtime_error("Invalid DataType.");
}

}  // namespace mmcv
#endif  // TRT_PLUGIN_HELPER_HPP
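
A helper like getElementSize is typically combined with a tensor's dimensions to work out how many bytes a buffer needs. The sketch below illustrates that use under stated assumptions: it does not include the TensorRT headers, so `sketch::DataType` is a stand-in enum written to mirror `nvinfer1::DataType`, and `bufferSizeInBytes` is a hypothetical function that is not part of mmcv.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

namespace sketch {

// Stand-in for nvinfer1::DataType so the example builds without TensorRT.
// (Assumption: only the enumerators used by the helper are needed here.)
enum class DataType : int { kFLOAT = 0, kHALF = 1, kINT8 = 2, kINT32 = 3 };

// Same logic as mmcv::getElementSize, applied to the stand-in enum.
inline unsigned int getElementSize(DataType t) {
  switch (t) {
    case DataType::kINT32:
    case DataType::kFLOAT:
      return 4;
    case DataType::kHALF:
      return 2;
    case DataType::kINT8:
      return 1;
    default:
      throw std::runtime_error("Invalid DataType.");
  }
}

// Hypothetical helper: bytes needed for a dense tensor with the given
// dimensions, computed as (product of dimensions) * (bytes per element).
inline std::size_t bufferSizeInBytes(const std::vector<int>& dims,
                                     DataType type) {
  std::size_t count = 1;
  for (int d : dims) {
    assert(d >= 0 && "dimension must be non-negative");
    count *= static_cast<std::size_t>(d);
  }
  return count * getElementSize(type);
}

}  // namespace sketch
```

For a 1x3x224x224 tensor, `sketch::bufferSizeInBytes({1, 3, 224, 224}, sketch::DataType::kFLOAT)` gives 602112 bytes, and the same shape in `kHALF` needs half of that. Passing a value outside the handled enumerators throws `std::runtime_error`, matching the behavior of the original helper.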