mirror of
https://github.com/open-mmlab/mmcv.git
synced 2025-06-03 21:54:52 +08:00
[Feature]: Add custom operators support for TensorRT in mmcv (#686)
* start trt plugin prototype * Add test module, modify roialign convertor * finish roi_align trt plugin * fix conflict of RoiAlign and MMCVRoiAlign * fix for lint * fix test tensorrt module * test_tensorrt move import to test func * add except error type * add tensorrt to setup.cfg * code format with yapf * fix for clang-format * move tensorrt_utils to mmcv/tensorrt, add comments, better test module * fix line endings, docformatter * isort init, remove trailing whitespace * add except type * fix setup.py * put import extension inside trt setup * change c++ guard, update pytest script, better setup, etc * sort import with isort * sort import with isort * move init of plugin lib to init_plugins.py * resolve format and add test dependency: tensorrt * tensorrt should be installed from source not from pypi * update naming style and input check * resolve lint error Co-authored-by: maningsheng <maningsheng@sensetime.com>
This commit is contained in:
parent
643009e445
commit
0de9e149c0
@ -1,4 +1,5 @@
|
||||
from .info import is_custom_op_loaded
|
||||
from .simplify import simplify
|
||||
from .symbolic import register_extra_symbolics
|
||||
|
||||
__all__ = ['register_extra_symbolics', 'simplify']
|
||||
__all__ = ['register_extra_symbolics', 'simplify', 'is_custom_op_loaded']
|
||||
|
18
mmcv/onnx/info.py
Normal file
18
mmcv/onnx/info.py
Normal file
@ -0,0 +1,18 @@
|
||||
import os
|
||||
|
||||
|
||||
def is_custom_op_loaded():
|
||||
flag = False
|
||||
try:
|
||||
from ..tensorrt import is_tensorrt_plugin_loaded
|
||||
flag = is_tensorrt_plugin_loaded()
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
pass
|
||||
if not flag:
|
||||
try:
|
||||
from ..ops import get_onnxruntime_op_path
|
||||
ort_lib_path = get_onnxruntime_op_path()
|
||||
flag = os.path.exists(ort_lib_path)
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
pass
|
||||
return flag
|
@ -11,20 +11,19 @@
|
||||
|
||||
struct MMCVRoiAlignKernel {
|
||||
public:
|
||||
MMCVRoiAlignKernel(Ort::CustomOpApi ort, const OrtKernelInfo *info)
|
||||
MMCVRoiAlignKernel(Ort::CustomOpApi ort, const OrtKernelInfo* info)
|
||||
: ort_(ort) {
|
||||
aligned_ = ort_.KernelInfoGetAttribute<int64_t>(info, "aligned");
|
||||
aligned_height_ =
|
||||
ort_.KernelInfoGetAttribute<int64_t>(info, "aligned_height");
|
||||
aligned_width_ =
|
||||
ort_.KernelInfoGetAttribute<int64_t>(info, "aligned_width");
|
||||
pool_mode_ = ort_.KernelInfoGetAttribute<std::string>(info, "pool_mode");
|
||||
ort_.KernelInfoGetAttribute<int64_t>(info, "output_height");
|
||||
aligned_width_ = ort_.KernelInfoGetAttribute<int64_t>(info, "output_width");
|
||||
pool_mode_ = ort_.KernelInfoGetAttribute<std::string>(info, "mode");
|
||||
sampling_ratio_ =
|
||||
ort_.KernelInfoGetAttribute<int64_t>(info, "sampling_ratio");
|
||||
spatial_scale_ = ort_.KernelInfoGetAttribute<float>(info, "spatial_scale");
|
||||
}
|
||||
|
||||
void Compute(OrtKernelContext *context);
|
||||
void Compute(OrtKernelContext* context);
|
||||
|
||||
private:
|
||||
Ort::CustomOpApi ort_;
|
||||
@ -39,10 +38,10 @@ struct MMCVRoiAlignKernel {
|
||||
|
||||
struct MMCVRoiAlignCustomOp
|
||||
: Ort::CustomOpBase<MMCVRoiAlignCustomOp, MMCVRoiAlignKernel> {
|
||||
void *CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo *info) {
|
||||
void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) {
|
||||
return new MMCVRoiAlignKernel(api, info);
|
||||
}
|
||||
const char *GetName() const { return "MMCVRoiAlign"; }
|
||||
const char* GetName() const { return "MMCVRoiAlign"; }
|
||||
|
||||
size_t GetInputTypeCount() const { return 2; }
|
||||
ONNXTensorElementDataType GetInputType(size_t) const {
|
||||
@ -55,7 +54,7 @@ struct MMCVRoiAlignCustomOp
|
||||
}
|
||||
|
||||
// force cpu
|
||||
const char *GetExecutionProviderType() const {
|
||||
const char* GetExecutionProviderType() const {
|
||||
return "CPUExecutionProvider";
|
||||
}
|
||||
};
|
||||
|
@ -1,11 +1,16 @@
|
||||
#ifndef ROI_ALIGN_CUDA_KERNEL_CUH
|
||||
#define ROI_ALIGN_CUDA_KERNEL_CUH
|
||||
|
||||
#include <float.h>
|
||||
#ifdef MMCV_WITH_TRT
|
||||
#include "common_cuda_helper.hpp"
|
||||
#else // MMCV_WITH_TRT
|
||||
#ifdef MMCV_USE_PARROTS
|
||||
#include "parrots_cuda_helper.hpp"
|
||||
#else
|
||||
#else // MMCV_USE_PARROTS
|
||||
#include "pytorch_cuda_helper.hpp"
|
||||
#endif
|
||||
#endif // MMCV_USE_PARROTS
|
||||
#endif // MMCV_WITH_TRT
|
||||
|
||||
/*** Forward ***/
|
||||
template <typename T>
|
||||
|
9
mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp
Normal file
9
mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp
Normal file
@ -0,0 +1,9 @@
|
||||
#include "trt_plugin.hpp"
|
||||
|
||||
#include "trt_roi_align.hpp"
|
||||
|
||||
REGISTER_TENSORRT_PLUGIN(RoIAlignPluginDynamicCreator);
|
||||
|
||||
extern "C" {
|
||||
bool initLibMMCVInferPlugins() { return true; }
|
||||
} // extern "C"
|
293
mmcv/ops/csrc/tensorrt/plugins/trt_roi_align.cpp
Normal file
293
mmcv/ops/csrc/tensorrt/plugins/trt_roi_align.cpp
Normal file
@ -0,0 +1,293 @@
|
||||
#include "trt_roi_align.hpp"
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include <chrono>
|
||||
|
||||
#include "trt_serialize.hpp"
|
||||
|
||||
extern void TRTRoIAlignForwardCUDAKernelLauncher_float(
|
||||
const float *input, const float *rois, float *output, float *argmax_y,
|
||||
float *argmax_x, int output_size, int channels, int height, int width,
|
||||
int aligned_height, int aligned_width, float spatial_scale,
|
||||
int sampling_ratio, int pool_mode, bool aligned, cudaStream_t stream);
|
||||
|
||||
namespace {
|
||||
static const char *PLUGIN_VERSION{"1"};
|
||||
static const char *PLUGIN_NAME{"MMCVRoiAlign"};
|
||||
} // namespace
|
||||
|
||||
nvinfer1::PluginFieldCollection RoIAlignPluginDynamicCreator::mFC{};
|
||||
std::vector<nvinfer1::PluginField>
|
||||
RoIAlignPluginDynamicCreator::mPluginAttributes;
|
||||
|
||||
RoIAlignPluginDynamic::RoIAlignPluginDynamic(const std::string &name,
|
||||
int outWidth, int outHeight,
|
||||
float spatialScale,
|
||||
int sampleRatio, int poolMode,
|
||||
bool aligned)
|
||||
: mLayerName(name),
|
||||
mOutWidth(outWidth),
|
||||
mOutHeight(outHeight),
|
||||
mSpatialScale(spatialScale),
|
||||
mSampleRatio(sampleRatio),
|
||||
mPoolMode(poolMode),
|
||||
mAligned(aligned) {}
|
||||
|
||||
RoIAlignPluginDynamic::RoIAlignPluginDynamic(const std::string name,
|
||||
const void *data, size_t length)
|
||||
: mLayerName(name) {
|
||||
deserialize_value(&data, &length, &mOutWidth);
|
||||
deserialize_value(&data, &length, &mOutHeight);
|
||||
deserialize_value(&data, &length, &mSpatialScale);
|
||||
deserialize_value(&data, &length, &mSampleRatio);
|
||||
deserialize_value(&data, &length, &mPoolMode);
|
||||
deserialize_value(&data, &length, &mAligned);
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2DynamicExt *RoIAlignPluginDynamic::clone() const {
|
||||
RoIAlignPluginDynamic *plugin = new RoIAlignPluginDynamic(
|
||||
mLayerName, mOutWidth, mOutHeight, mSpatialScale, mSampleRatio, mPoolMode,
|
||||
mAligned);
|
||||
plugin->setPluginNamespace(getPluginNamespace());
|
||||
|
||||
return plugin;
|
||||
}
|
||||
|
||||
nvinfer1::DimsExprs RoIAlignPluginDynamic::getOutputDimensions(
|
||||
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
|
||||
nvinfer1::IExprBuilder &exprBuilder) {
|
||||
nvinfer1::DimsExprs ret;
|
||||
ret.nbDims = 4;
|
||||
ret.d[0] = inputs[1].d[0];
|
||||
ret.d[1] = inputs[0].d[1];
|
||||
ret.d[2] = exprBuilder.constant(mOutHeight);
|
||||
ret.d[3] = exprBuilder.constant(mOutWidth);
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
bool RoIAlignPluginDynamic::supportsFormatCombination(
|
||||
int pos, const nvinfer1::PluginTensorDesc *inOut, int nbInputs,
|
||||
int nbOutputs) {
|
||||
return inOut[pos].type == nvinfer1::DataType::kFLOAT &&
|
||||
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
|
||||
}
|
||||
|
||||
void RoIAlignPluginDynamic::configurePlugin(
|
||||
const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
|
||||
const nvinfer1::DynamicPluginTensorDesc *outputs, int nbOutputs) {}
|
||||
|
||||
size_t RoIAlignPluginDynamic::getWorkspaceSize(
|
||||
const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
|
||||
const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const {
|
||||
size_t output_size = 0;
|
||||
size_t word_size = 0;
|
||||
switch (mPoolMode) {
|
||||
case 0: // max
|
||||
output_size = outputs[0].dims.d[0] * outputs[0].dims.d[1] *
|
||||
outputs[0].dims.d[2] * outputs[0].dims.d[3];
|
||||
word_size = mmcv::getElementSize(outputs[0].type);
|
||||
return output_size * word_size * 2;
|
||||
break;
|
||||
case 1:
|
||||
return 0;
|
||||
break;
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
int RoIAlignPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
|
||||
const nvinfer1::PluginTensorDesc *outputDesc,
|
||||
const void *const *inputs,
|
||||
void *const *outputs, void *workSpace,
|
||||
cudaStream_t stream) {
|
||||
int channels = inputDesc[0].dims.d[1];
|
||||
int height = inputDesc[0].dims.d[2];
|
||||
int width = inputDesc[0].dims.d[3];
|
||||
|
||||
int output_size = outputDesc[0].dims.d[0] * outputDesc[0].dims.d[1] *
|
||||
outputDesc[0].dims.d[2] * outputDesc[0].dims.d[3];
|
||||
int word_size = mmcv::getElementSize(outputDesc[0].type);
|
||||
|
||||
const void *feat = inputs[0];
|
||||
const void *rois = inputs[1];
|
||||
void *output = outputs[0];
|
||||
void *argmax_y = nullptr;
|
||||
void *argmax_x = nullptr;
|
||||
|
||||
switch (mPoolMode) {
|
||||
case 0: // max
|
||||
argmax_y = workSpace;
|
||||
argmax_x = argmax_y + output_size * word_size;
|
||||
break;
|
||||
case 1: // avg
|
||||
break;
|
||||
}
|
||||
|
||||
switch (outputDesc[0].type) {
|
||||
case nvinfer1::DataType::kFLOAT:
|
||||
TRTRoIAlignForwardCUDAKernelLauncher_float(
|
||||
(const float *)feat, (const float *)rois, (float *)output,
|
||||
(float *)argmax_y, (float *)argmax_x, output_size, channels, height,
|
||||
width, mOutHeight, mOutWidth, mSpatialScale, mSampleRatio, mPoolMode,
|
||||
mAligned, stream);
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
nvinfer1::DataType RoIAlignPluginDynamic::getOutputDataType(
|
||||
int index, const nvinfer1::DataType *inputTypes, int nbInputs) const {
|
||||
return inputTypes[0];
|
||||
}
|
||||
|
||||
// IPluginV2 Methods
|
||||
const char *RoIAlignPluginDynamic::getPluginType() const { return PLUGIN_NAME; }
|
||||
|
||||
const char *RoIAlignPluginDynamic::getPluginVersion() const {
|
||||
return PLUGIN_VERSION;
|
||||
}
|
||||
|
||||
int RoIAlignPluginDynamic::getNbOutputs() const { return 1; }
|
||||
|
||||
int RoIAlignPluginDynamic::initialize() { return 0; }
|
||||
|
||||
void RoIAlignPluginDynamic::terminate() {}
|
||||
|
||||
size_t RoIAlignPluginDynamic::getSerializationSize() const {
|
||||
return sizeof(mOutWidth) + sizeof(mOutHeight) + sizeof(mSpatialScale) +
|
||||
sizeof(mSampleRatio) + sizeof(mPoolMode) + sizeof(mAligned);
|
||||
}
|
||||
|
||||
void RoIAlignPluginDynamic::serialize(void *buffer) const {
|
||||
serialize_value(&buffer, mOutWidth);
|
||||
serialize_value(&buffer, mOutHeight);
|
||||
serialize_value(&buffer, mSpatialScale);
|
||||
serialize_value(&buffer, mSampleRatio);
|
||||
serialize_value(&buffer, mPoolMode);
|
||||
serialize_value(&buffer, mAligned);
|
||||
}
|
||||
|
||||
void RoIAlignPluginDynamic::destroy() {
|
||||
// This gets called when the network containing plugin is destroyed
|
||||
delete this;
|
||||
}
|
||||
|
||||
void RoIAlignPluginDynamic::setPluginNamespace(const char *libNamespace) {
|
||||
mNamespace = libNamespace;
|
||||
}
|
||||
|
||||
const char *RoIAlignPluginDynamic::getPluginNamespace() const {
|
||||
return mNamespace.c_str();
|
||||
}
|
||||
|
||||
////////////////////// creator /////////////////////////////
|
||||
|
||||
RoIAlignPluginDynamicCreator::RoIAlignPluginDynamicCreator() {
|
||||
mPluginAttributes.emplace_back(nvinfer1::PluginField("output_height"));
|
||||
mPluginAttributes.emplace_back(nvinfer1::PluginField("output_width"));
|
||||
mPluginAttributes.emplace_back(nvinfer1::PluginField("spatial_scale"));
|
||||
mPluginAttributes.emplace_back(nvinfer1::PluginField("sampling_ratio"));
|
||||
mPluginAttributes.emplace_back(nvinfer1::PluginField("mode"));
|
||||
mPluginAttributes.emplace_back(nvinfer1::PluginField("aligned"));
|
||||
mFC.nbFields = mPluginAttributes.size();
|
||||
mFC.fields = mPluginAttributes.data();
|
||||
}
|
||||
|
||||
const char *RoIAlignPluginDynamicCreator::getPluginName() const {
|
||||
return PLUGIN_NAME;
|
||||
}
|
||||
|
||||
const char *RoIAlignPluginDynamicCreator::getPluginVersion() const {
|
||||
return PLUGIN_VERSION;
|
||||
}
|
||||
|
||||
const nvinfer1::PluginFieldCollection *
|
||||
RoIAlignPluginDynamicCreator::getFieldNames() {
|
||||
return &mFC;
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2 *RoIAlignPluginDynamicCreator::createPlugin(
|
||||
const char *name, const nvinfer1::PluginFieldCollection *fc) {
|
||||
int outWidth = 7;
|
||||
int outHeight = 7;
|
||||
float spatialScale = 1.0;
|
||||
int sampleRatio = 0;
|
||||
int poolMode = -1;
|
||||
bool aligned = true;
|
||||
for (int i = 0; i < fc->nbFields; i++) {
|
||||
if (fc->fields[i].data == nullptr) {
|
||||
continue;
|
||||
}
|
||||
std::string field_name(fc->fields[i].name);
|
||||
|
||||
if (field_name.compare("output_height") == 0) {
|
||||
outHeight = static_cast<const int *>(fc->fields[i].data)[0];
|
||||
}
|
||||
|
||||
if (field_name.compare("output_width") == 0) {
|
||||
outWidth = static_cast<const int *>(fc->fields[i].data)[0];
|
||||
}
|
||||
|
||||
if (field_name.compare("spatial_scale") == 0) {
|
||||
spatialScale = static_cast<const float *>(fc->fields[i].data)[0];
|
||||
}
|
||||
|
||||
if (field_name.compare("sampling_ratio") == 0) {
|
||||
sampleRatio = static_cast<const int *>(fc->fields[i].data)[0];
|
||||
}
|
||||
|
||||
if (field_name.compare("mode") == 0) {
|
||||
int data_size = fc->fields[i].length;
|
||||
const char *data_start = static_cast<const char *>(fc->fields[i].data);
|
||||
std::string poolModeStr(data_start, data_size);
|
||||
if (poolModeStr == "avg") {
|
||||
poolMode = 1;
|
||||
} else if (poolModeStr == "max") {
|
||||
poolMode = 0;
|
||||
} else {
|
||||
std::cout << "Unknown pool mode \"" << poolModeStr << "\"."
|
||||
<< std::endl;
|
||||
}
|
||||
assert(poolMode >= 0);
|
||||
}
|
||||
|
||||
if (field_name.compare("aligned") == 0) {
|
||||
int aligned_int = static_cast<const int *>(fc->fields[i].data)[0];
|
||||
aligned = aligned_int != 0;
|
||||
}
|
||||
}
|
||||
|
||||
assert(outHeight > 0);
|
||||
assert(outWidth > 0);
|
||||
assert(spatialScale > 0.);
|
||||
assert(poolMode >= 0);
|
||||
|
||||
RoIAlignPluginDynamic *plugin = new RoIAlignPluginDynamic(
|
||||
name, outWidth, outHeight, spatialScale, sampleRatio, poolMode, aligned);
|
||||
plugin->setPluginNamespace(getPluginNamespace());
|
||||
return plugin;
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2 *RoIAlignPluginDynamicCreator::deserializePlugin(
|
||||
const char *name, const void *serialData, size_t serialLength) {
|
||||
auto plugin = new RoIAlignPluginDynamic(name, serialData, serialLength);
|
||||
plugin->setPluginNamespace(getPluginNamespace());
|
||||
return plugin;
|
||||
}
|
||||
|
||||
void RoIAlignPluginDynamicCreator::setPluginNamespace(
|
||||
const char *libNamespace) {
|
||||
mNamespace = libNamespace;
|
||||
}
|
||||
|
||||
const char *RoIAlignPluginDynamicCreator::getPluginNamespace() const {
|
||||
return mNamespace.c_str();
|
||||
}
|
28
mmcv/ops/csrc/tensorrt/plugins/trt_roi_align_kernel.cu
Normal file
28
mmcv/ops/csrc/tensorrt/plugins/trt_roi_align_kernel.cu
Normal file
@ -0,0 +1,28 @@
|
||||
#include "common_cuda_helper.hpp"
|
||||
#include "roi_align_cuda_kernel.cuh"
|
||||
|
||||
template <typename scalar_t>
|
||||
void TRTRoIAlignForwardCUDAKernelLauncher(
|
||||
const scalar_t* input, const scalar_t* rois, scalar_t* output,
|
||||
scalar_t* argmax_y, scalar_t* argmax_x, int output_size, int channels,
|
||||
int height, int width, int aligned_height, int aligned_width,
|
||||
scalar_t spatial_scale, int sampling_ratio, int pool_mode, bool aligned,
|
||||
cudaStream_t stream) {
|
||||
roi_align_forward_cuda_kernel<scalar_t>
|
||||
|
||||
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
|
||||
output_size, input, rois, output, argmax_y, argmax_x, aligned_height,
|
||||
aligned_width, static_cast<scalar_t>(spatial_scale), sampling_ratio,
|
||||
pool_mode, aligned, channels, height, width);
|
||||
}
|
||||
|
||||
void TRTRoIAlignForwardCUDAKernelLauncher_float(
|
||||
const float* input, const float* rois, float* output, float* argmax_y,
|
||||
float* argmax_x, int output_size, int channels, int height, int width,
|
||||
int aligned_height, int aligned_width, float spatial_scale,
|
||||
int sampling_ratio, int pool_mode, bool aligned, cudaStream_t stream) {
|
||||
TRTRoIAlignForwardCUDAKernelLauncher<float>(
|
||||
input, rois, output, argmax_y, argmax_x, output_size, channels, height,
|
||||
width, aligned_height, aligned_width, spatial_scale, sampling_ratio,
|
||||
pool_mode, aligned, stream);
|
||||
}
|
7
mmcv/ops/csrc/tensorrt/trt_plugin.hpp
Normal file
7
mmcv/ops/csrc/tensorrt/trt_plugin.hpp
Normal file
@ -0,0 +1,7 @@
|
||||
#ifndef TRT_PLUGIN_HPP
|
||||
#define TRT_PLUGIN_HPP
|
||||
|
||||
extern "C" {
|
||||
bool initLibMMCVInferPlugins();
|
||||
} // extern "C"
|
||||
#endif // TRT_PLUGIN_HPP
|
27
mmcv/ops/csrc/tensorrt/trt_plugin_helper.hpp
Normal file
27
mmcv/ops/csrc/tensorrt/trt_plugin_helper.hpp
Normal file
@ -0,0 +1,27 @@
|
||||
#ifndef TRT_PLUGIN_HELPER_HPP
|
||||
#define TRT_PLUGIN_HELPER_HPP
|
||||
#include <stdexcept>
|
||||
|
||||
#include "NvInferPlugin.h"
|
||||
|
||||
namespace mmcv {
|
||||
|
||||
inline unsigned int getElementSize(nvinfer1::DataType t) {
|
||||
switch (t) {
|
||||
case nvinfer1::DataType::kINT32:
|
||||
return 4;
|
||||
case nvinfer1::DataType::kFLOAT:
|
||||
return 4;
|
||||
case nvinfer1::DataType::kHALF:
|
||||
return 2;
|
||||
// case nvinfer1::DataType::kBOOL:
|
||||
case nvinfer1::DataType::kINT8:
|
||||
return 1;
|
||||
default:
|
||||
throw std::runtime_error("Invalid DataType.");
|
||||
}
|
||||
throw std::runtime_error("Invalid DataType.");
|
||||
return 0;
|
||||
}
|
||||
} // namespace mmcv
|
||||
#endif // TRT_PLUGIN_HELPER_HPP
|
108
mmcv/ops/csrc/tensorrt/trt_roi_align.hpp
Normal file
108
mmcv/ops/csrc/tensorrt/trt_roi_align.hpp
Normal file
@ -0,0 +1,108 @@
|
||||
#ifndef TRT_ROI_ALIGN_HPP
|
||||
#define TRT_ROI_ALIGN_HPP
|
||||
#include <cublas_v2.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "trt_plugin_helper.hpp"
|
||||
|
||||
class RoIAlignPluginDynamic : public nvinfer1::IPluginV2DynamicExt {
|
||||
public:
|
||||
RoIAlignPluginDynamic(const std::string &name, int outWidth, int outHeight,
|
||||
float spatialScale, int sampleRatio, int poolMode,
|
||||
bool aligned);
|
||||
|
||||
RoIAlignPluginDynamic(const std::string name, const void *data,
|
||||
size_t length);
|
||||
|
||||
RoIAlignPluginDynamic() = delete;
|
||||
|
||||
// IPluginV2DynamicExt Methods
|
||||
nvinfer1::IPluginV2DynamicExt *clone() const override;
|
||||
nvinfer1::DimsExprs getOutputDimensions(
|
||||
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
|
||||
nvinfer1::IExprBuilder &exprBuilder) override;
|
||||
bool supportsFormatCombination(int pos,
|
||||
const nvinfer1::PluginTensorDesc *inOut,
|
||||
int nbInputs, int nbOutputs) override;
|
||||
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in,
|
||||
int nbInputs,
|
||||
const nvinfer1::DynamicPluginTensorDesc *out,
|
||||
int nbOutputs) override;
|
||||
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
|
||||
int nbInputs,
|
||||
const nvinfer1::PluginTensorDesc *outputs,
|
||||
int nbOutputs) const override;
|
||||
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
|
||||
const nvinfer1::PluginTensorDesc *outputDesc,
|
||||
const void *const *inputs, void *const *outputs, void *workspace,
|
||||
cudaStream_t stream) override;
|
||||
|
||||
// IPluginV2Ext Methods
|
||||
nvinfer1::DataType getOutputDataType(int index,
|
||||
const nvinfer1::DataType *inputTypes,
|
||||
int nbInputs) const override;
|
||||
|
||||
// IPluginV2 Methods
|
||||
const char *getPluginType() const override;
|
||||
const char *getPluginVersion() const override;
|
||||
int getNbOutputs() const override;
|
||||
int initialize() override;
|
||||
void terminate() override;
|
||||
size_t getSerializationSize() const override;
|
||||
void serialize(void *buffer) const override;
|
||||
void destroy() override;
|
||||
void setPluginNamespace(const char *pluginNamespace) override;
|
||||
const char *getPluginNamespace() const override;
|
||||
|
||||
private:
|
||||
const std::string mLayerName;
|
||||
std::string mNamespace;
|
||||
|
||||
int mOutWidth;
|
||||
int mOutHeight;
|
||||
float mSpatialScale;
|
||||
int mSampleRatio;
|
||||
int mPoolMode; // 1:avg 0:max
|
||||
bool mAligned;
|
||||
|
||||
protected:
|
||||
// To prevent compiler warnings.
|
||||
using nvinfer1::IPluginV2DynamicExt::canBroadcastInputAcrossBatch;
|
||||
using nvinfer1::IPluginV2DynamicExt::configurePlugin;
|
||||
using nvinfer1::IPluginV2DynamicExt::enqueue;
|
||||
using nvinfer1::IPluginV2DynamicExt::getOutputDimensions;
|
||||
using nvinfer1::IPluginV2DynamicExt::getWorkspaceSize;
|
||||
using nvinfer1::IPluginV2DynamicExt::isOutputBroadcastAcrossBatch;
|
||||
using nvinfer1::IPluginV2DynamicExt::supportsFormat;
|
||||
};
|
||||
|
||||
class RoIAlignPluginDynamicCreator : public nvinfer1::IPluginCreator {
|
||||
public:
|
||||
RoIAlignPluginDynamicCreator();
|
||||
|
||||
const char *getPluginName() const override;
|
||||
|
||||
const char *getPluginVersion() const override;
|
||||
|
||||
const nvinfer1::PluginFieldCollection *getFieldNames() override;
|
||||
|
||||
nvinfer1::IPluginV2 *createPlugin(
|
||||
const char *name, const nvinfer1::PluginFieldCollection *fc) override;
|
||||
|
||||
nvinfer1::IPluginV2 *deserializePlugin(const char *name,
|
||||
const void *serialData,
|
||||
size_t serialLength) override;
|
||||
|
||||
void setPluginNamespace(const char *pluginNamespace) override;
|
||||
|
||||
const char *getPluginNamespace() const override;
|
||||
|
||||
private:
|
||||
static nvinfer1::PluginFieldCollection mFC;
|
||||
static std::vector<nvinfer1::PluginField> mPluginAttributes;
|
||||
std::string mNamespace;
|
||||
};
|
||||
#endif // TRT_ROI_ALIGN_HPP
|
117
mmcv/ops/csrc/tensorrt/trt_serialize.hpp
Normal file
117
mmcv/ops/csrc/tensorrt/trt_serialize.hpp
Normal file
@ -0,0 +1,117 @@
|
||||
/*
|
||||
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef TRT_SERIALIZE_HPP
|
||||
#define TRT_SERIALIZE_HPP
|
||||
#include <cassert>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
using std::cerr;
|
||||
using std::cout;
|
||||
using std::endl;
|
||||
|
||||
template <typename T>
|
||||
inline void serialize_value(void** buffer, T const& value);
|
||||
|
||||
template <typename T>
|
||||
inline void deserialize_value(void const** buffer, size_t* buffer_size,
|
||||
T* value);
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T, class Enable = void>
|
||||
struct Serializer {};
|
||||
|
||||
template <typename T>
|
||||
struct Serializer<T, typename std::enable_if<std::is_arithmetic<T>::value ||
|
||||
std::is_enum<T>::value ||
|
||||
std::is_pod<T>::value>::type> {
|
||||
static size_t serialized_size(T const& value) { return sizeof(T); }
|
||||
static void serialize(void** buffer, T const& value) {
|
||||
::memcpy(*buffer, &value, sizeof(T));
|
||||
reinterpret_cast<char*&>(*buffer) += sizeof(T);
|
||||
}
|
||||
static void deserialize(void const** buffer, size_t* buffer_size, T* value) {
|
||||
assert(*buffer_size >= sizeof(T));
|
||||
::memcpy(value, *buffer, sizeof(T));
|
||||
reinterpret_cast<char const*&>(*buffer) += sizeof(T);
|
||||
*buffer_size -= sizeof(T);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Serializer<const char*> {
|
||||
static size_t serialized_size(const char* value) { return strlen(value) + 1; }
|
||||
static void serialize(void** buffer, const char* value) {
|
||||
::strcpy(static_cast<char*>(*buffer), value);
|
||||
reinterpret_cast<char*&>(*buffer) += strlen(value) + 1;
|
||||
}
|
||||
static void deserialize(void const** buffer, size_t* buffer_size,
|
||||
const char** value) {
|
||||
*value = static_cast<char const*>(*buffer);
|
||||
size_t data_size = strnlen(*value, *buffer_size) + 1;
|
||||
assert(*buffer_size >= data_size);
|
||||
reinterpret_cast<char const*&>(*buffer) += data_size;
|
||||
*buffer_size -= data_size;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct Serializer<std::vector<T>,
|
||||
typename std::enable_if<std::is_arithmetic<T>::value ||
|
||||
std::is_enum<T>::value ||
|
||||
std::is_pod<T>::value>::type> {
|
||||
static size_t serialized_size(std::vector<T> const& value) {
|
||||
return sizeof(value.size()) + value.size() * sizeof(T);
|
||||
}
|
||||
static void serialize(void** buffer, std::vector<T> const& value) {
|
||||
serialize_value(buffer, value.size());
|
||||
size_t nbyte = value.size() * sizeof(T);
|
||||
::memcpy(*buffer, value.data(), nbyte);
|
||||
reinterpret_cast<char*&>(*buffer) += nbyte;
|
||||
}
|
||||
static void deserialize(void const** buffer, size_t* buffer_size,
|
||||
std::vector<T>* value) {
|
||||
size_t size;
|
||||
deserialize_value(buffer, buffer_size, &size);
|
||||
value->resize(size);
|
||||
size_t nbyte = value->size() * sizeof(T);
|
||||
assert(*buffer_size >= nbyte);
|
||||
::memcpy(value->data(), *buffer, nbyte);
|
||||
reinterpret_cast<char const*&>(*buffer) += nbyte;
|
||||
*buffer_size -= nbyte;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename T>
|
||||
inline size_t serialized_size(T const& value) {
|
||||
return Serializer<T>::serialized_size(value);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void serialize_value(void** buffer, T const& value) {
|
||||
return Serializer<T>::serialize(buffer, value);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void deserialize_value(void const** buffer, size_t* buffer_size,
|
||||
T* value) {
|
||||
return Serializer<T>::deserialize(buffer, buffer_size, value);
|
||||
}
|
||||
#endif // TRT_SERIALIZE_HPP
|
@ -4,6 +4,7 @@ from torch.autograd import Function
|
||||
from torch.autograd.function import once_differentiable
|
||||
from torch.nn.modules.utils import _pair
|
||||
|
||||
from ..onnx import is_custom_op_loaded
|
||||
from ..utils import deprecated_api_warning, ext_loader
|
||||
|
||||
ext_module = ext_loader.load_ext('_ext',
|
||||
@ -15,55 +16,48 @@ class RoIAlignFunction(Function):
|
||||
@staticmethod
|
||||
def symbolic(g, input, rois, output_size, spatial_scale, sampling_ratio,
|
||||
pool_mode, aligned):
|
||||
has_custom_op = False
|
||||
try:
|
||||
import os.path as osp
|
||||
|
||||
from mmcv.ops import get_onnxruntime_op_path
|
||||
ort_op_path = get_onnxruntime_op_path()
|
||||
has_custom_op = osp.exists(ort_op_path)
|
||||
except ImportError:
|
||||
pass
|
||||
has_custom_op = is_custom_op_loaded()
|
||||
if has_custom_op:
|
||||
return g.op(
|
||||
'mmcv::MMCVRoiAlign',
|
||||
input,
|
||||
rois,
|
||||
aligned_height_i=output_size[0],
|
||||
aligned_width_i=output_size[1],
|
||||
output_height_i=output_size[0],
|
||||
output_width_i=output_size[1],
|
||||
spatial_scale_f=spatial_scale,
|
||||
sampling_ratio_i=sampling_ratio,
|
||||
mode_s=pool_mode,
|
||||
aligned_i=aligned)
|
||||
else:
|
||||
from torch.onnx.symbolic_opset9 import sub, squeeze
|
||||
from torch.onnx.symbolic_helper import _slice_helper
|
||||
from torch.onnx import TensorProtoDataType
|
||||
# batch_indices = rois[:, 0].long()
|
||||
batch_indices = _slice_helper(
|
||||
g, rois, axes=[1], starts=[0], ends=[1])
|
||||
batch_indices = squeeze(g, batch_indices, 1)
|
||||
batch_indices = g.op(
|
||||
'Cast', batch_indices, to_i=TensorProtoDataType.INT64)
|
||||
# rois = rois[:, 1:]
|
||||
rois = _slice_helper(g, rois, axes=[1], starts=[1], ends=[5])
|
||||
if aligned:
|
||||
# rois -= 0.5/spatial_scale
|
||||
aligned_offset = g.op(
|
||||
'Constant',
|
||||
value_t=torch.tensor([0.5 / spatial_scale],
|
||||
dtype=torch.float32))
|
||||
rois = sub(g, rois, aligned_offset)
|
||||
# roi align
|
||||
return g.op(
|
||||
'RoiAlign',
|
||||
input,
|
||||
rois,
|
||||
batch_indices,
|
||||
output_height_i=output_size[0],
|
||||
output_width_i=output_size[1],
|
||||
spatial_scale_f=spatial_scale,
|
||||
sampling_ratio_i=max(0, sampling_ratio),
|
||||
pool_mode_s=pool_mode,
|
||||
aligned_i=aligned)
|
||||
|
||||
from torch.onnx.symbolic_opset9 import sub, squeeze
|
||||
from torch.onnx.symbolic_helper import _slice_helper
|
||||
from torch.onnx import TensorProtoDataType
|
||||
# batch_indices = rois[:, 0].long()
|
||||
batch_indices = _slice_helper(g, rois, axes=[1], starts=[0], ends=[1])
|
||||
batch_indices = squeeze(g, batch_indices, 1)
|
||||
batch_indices = g.op(
|
||||
'Cast', batch_indices, to_i=TensorProtoDataType.INT64)
|
||||
# rois = rois[:, 1:]
|
||||
rois = _slice_helper(g, rois, axes=[1], starts=[1], ends=[5])
|
||||
if aligned:
|
||||
# rois -= 0.5/spatial_scale
|
||||
aligned_offset = g.op(
|
||||
'Constant',
|
||||
value_t=torch.tensor([0.5 / spatial_scale],
|
||||
dtype=torch.float32))
|
||||
rois = sub(g, rois, aligned_offset)
|
||||
# roi align
|
||||
return g.op(
|
||||
'RoiAlign',
|
||||
input,
|
||||
rois,
|
||||
batch_indices,
|
||||
output_height_i=output_size[0],
|
||||
output_width_i=output_size[1],
|
||||
spatial_scale_f=spatial_scale,
|
||||
sampling_ratio_i=max(0, sampling_ratio),
|
||||
mode_s=pool_mode)
|
||||
mode_s=pool_mode)
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx,
|
||||
|
12
mmcv/tensorrt/__init__.py
Normal file
12
mmcv/tensorrt/__init__.py
Normal file
@ -0,0 +1,12 @@
|
||||
# flake8: noqa
|
||||
from .init_plugins import is_tensorrt_plugin_loaded, load_tensorrt_plugin
|
||||
from .tensorrt_utils import (TRTWraper, load_trt_engine, onnx2trt,
|
||||
save_trt_engine)
|
||||
|
||||
# load tensorrt plugin lib
|
||||
load_tensorrt_plugin()
|
||||
|
||||
__all__ = [
|
||||
'onnx2trt', 'save_trt_engine', 'load_trt_engine', 'TRTWraper',
|
||||
'is_tensorrt_plugin_loaded'
|
||||
]
|
36
mmcv/tensorrt/init_plugins.py
Normal file
36
mmcv/tensorrt/init_plugins.py
Normal file
@ -0,0 +1,36 @@
|
||||
import ctypes
|
||||
import glob
|
||||
import os
|
||||
|
||||
|
||||
def get_tensorrt_op_path():
|
||||
"""Get TensorRT plugins library path."""
|
||||
wildcard = os.path.join(
|
||||
os.path.abspath(os.path.dirname(os.path.dirname(__file__))),
|
||||
'_ext_trt.*.so')
|
||||
|
||||
paths = glob.glob(wildcard)
|
||||
lib_path = paths[0] if len(paths) > 0 else ''
|
||||
return lib_path
|
||||
|
||||
|
||||
plugin_is_loaded = False
|
||||
|
||||
|
||||
def is_tensorrt_plugin_loaded():
|
||||
"""Check if TensorRT plugins library is loaded or not.
|
||||
|
||||
Returns:
|
||||
bool: plugin_is_loaded flag
|
||||
"""
|
||||
global plugin_is_loaded
|
||||
return plugin_is_loaded
|
||||
|
||||
|
||||
def load_tensorrt_plugin():
|
||||
"""load TensorRT plugins library."""
|
||||
global plugin_is_loaded
|
||||
lib_path = get_tensorrt_op_path()
|
||||
if (not plugin_is_loaded) and os.path.exists(lib_path):
|
||||
ctypes.CDLL(lib_path)
|
||||
plugin_is_loaded = True
|
209
mmcv/tensorrt/tensorrt_utils.py
Normal file
209
mmcv/tensorrt/tensorrt_utils.py
Normal file
@ -0,0 +1,209 @@
|
||||
import tensorrt as trt
|
||||
import torch
|
||||
|
||||
|
||||
def onnx2trt(onnx_model,
|
||||
opt_shape_dict,
|
||||
log_level=trt.Logger.ERROR,
|
||||
fp16_mode=False,
|
||||
max_workspace_size=0,
|
||||
device_id=0):
|
||||
"""Convert onnx model to tensorrt engine.
|
||||
|
||||
Arguments:
|
||||
onnx_model (str or onnx.ModelProto): the onnx model to convert from
|
||||
opt_shape_dict (dict): the min/opt/max shape of each input
|
||||
log_level (TensorRT log level): the log level of TensorRT
|
||||
fp16_mode (bool): enable fp16 mode
|
||||
max_workspace_size (int): set max workspace size of TensorRT engine.
|
||||
some tactic and layers need large workspace.
|
||||
device_id (int): choice the device to create engine.
|
||||
|
||||
Returns:
|
||||
tensorrt.ICudaEngine: the TensorRT engine created from onnx_model
|
||||
|
||||
Example:
|
||||
>>> engine = onnx2trt(
|
||||
>>> "onnx_model.onnx",
|
||||
>>> {'input': [[1, 3, 160, 160],
|
||||
>>> [1, 3, 320, 320],
|
||||
>>> [1, 3, 640, 640]]},
|
||||
>>> log_level=trt.Logger.WARNING,
|
||||
>>> fp16_mode=True,
|
||||
>>> max_workspace_size=1 << 30,
|
||||
>>> device_id=0)
|
||||
>>> })
|
||||
"""
|
||||
device = torch.device('cuda:{}'.format(device_id))
|
||||
# create builder and network
|
||||
logger = trt.Logger(log_level)
|
||||
builder = trt.Builder(logger)
|
||||
EXPLICIT_BATCH = 1 << (int)(
|
||||
trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
|
||||
network = builder.create_network(EXPLICIT_BATCH)
|
||||
|
||||
# parse onnx
|
||||
parser = trt.OnnxParser(network, logger)
|
||||
|
||||
if isinstance(onnx_model, str):
|
||||
assert parser.parse_from_file(onnx_model), 'parse onnx failed.'
|
||||
else:
|
||||
assert parser.parse(
|
||||
onnx_model.SerializeToString()), 'parse onnx failed.'
|
||||
|
||||
# config builder
|
||||
builder.max_workspace_size = max_workspace_size
|
||||
|
||||
config = builder.create_builder_config()
|
||||
config.max_workspace_size = max_workspace_size
|
||||
profile = builder.create_optimization_profile()
|
||||
|
||||
for input_name, param in opt_shape_dict.items():
|
||||
min_shape = tuple(param[0][:])
|
||||
opt_shape = tuple(param[1][:])
|
||||
max_shape = tuple(param[2][:])
|
||||
profile.set_shape(input_name, min_shape, opt_shape, max_shape)
|
||||
config.add_optimization_profile(profile)
|
||||
|
||||
if fp16_mode:
|
||||
builder.fp16_mode = fp16_mode
|
||||
config.set_flag(trt.BuilderFlag.FP16)
|
||||
|
||||
# create engine
|
||||
with torch.cuda.device(device):
|
||||
engine = builder.build_engine(network, config)
|
||||
|
||||
return engine
|
||||
|
||||
|
||||
def save_trt_engine(engine, path):
|
||||
"""Serialize TensorRT engine to disk.
|
||||
|
||||
Arguments:
|
||||
engine (tensorrt.ICudaEngine): TensorRT engine to serialize
|
||||
path (str): disk path to write the engine
|
||||
"""
|
||||
with open(path, mode='wb') as f:
|
||||
f.write(bytearray(engine.serialize()))
|
||||
|
||||
|
||||
def load_trt_engine(path):
|
||||
"""Deserialize TensorRT engine from disk.
|
||||
|
||||
Arguments:
|
||||
path (str): disk path to read the engine
|
||||
|
||||
Returns:
|
||||
tensorrt.ICudaEngine: the TensorRT engine loaded from disk
|
||||
"""
|
||||
with trt.Logger() as logger, trt.Runtime(logger) as runtime:
|
||||
with open(path, mode='rb') as f:
|
||||
engine_bytes = f.read()
|
||||
engine = runtime.deserialize_cuda_engine(engine_bytes)
|
||||
return engine
|
||||
|
||||
|
||||
def torch_dtype_from_trt(dtype):
|
||||
"""Convert pytorch dtype to TensorRT dtype."""
|
||||
if dtype == trt.bool:
|
||||
return torch.bool
|
||||
elif dtype == trt.int8:
|
||||
return torch.int8
|
||||
elif dtype == trt.int32:
|
||||
return torch.int32
|
||||
elif dtype == trt.float16:
|
||||
return torch.float16
|
||||
elif dtype == trt.float32:
|
||||
return torch.float32
|
||||
else:
|
||||
raise TypeError('%s is not supported by torch' % dtype)
|
||||
|
||||
|
||||
def torch_device_from_trt(device):
|
||||
"""Convert pytorch device to TensorRT device."""
|
||||
if device == trt.TensorLocation.DEVICE:
|
||||
return torch.device('cuda')
|
||||
elif device == trt.TensorLocation.HOST:
|
||||
return torch.device('cpu')
|
||||
else:
|
||||
return TypeError('%s is not supported by torch' % device)
|
||||
|
||||
|
||||
class TRTWraper(torch.nn.Module):
|
||||
"""TensorRT engine Wraper.
|
||||
|
||||
Arguments:
|
||||
engine (tensorrt.ICudaEngine): TensorRT engine to wrap
|
||||
input_names (list[str]): names of each inputs
|
||||
output_names (list[str]): names of each outputs
|
||||
|
||||
Note:
|
||||
If the engine is converted from onnx model. The input_names and
|
||||
output_names should be the same as onnx model.
|
||||
"""
|
||||
|
||||
def __init__(self, engine, input_names, output_names):
|
||||
super(TRTWraper, self).__init__()
|
||||
self.engine = engine
|
||||
if isinstance(self.engine, str):
|
||||
self.engine = load_trt_engine(engine)
|
||||
|
||||
if not isinstance(self.engine, trt.ICudaEngine):
|
||||
raise TypeError('engine should be str or trt.ICudaEngine')
|
||||
|
||||
self._register_state_dict_hook(TRTWraper._on_state_dict)
|
||||
self.context = self.engine.create_execution_context()
|
||||
|
||||
self.input_names = input_names
|
||||
self.output_names = output_names
|
||||
|
||||
def _on_state_dict(self, state_dict, prefix, local_metadata):
|
||||
state_dict[prefix + 'engine'] = bytearray(self.engine.serialize())
|
||||
state_dict[prefix + 'input_names'] = self.input_names
|
||||
state_dict[prefix + 'output_names'] = self.output_names
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
|
||||
missing_keys, unexpected_keys, error_msgs):
|
||||
engine_bytes = state_dict[prefix + 'engine']
|
||||
|
||||
with trt.Logger() as logger, trt.Runtime(logger) as runtime:
|
||||
self.engine = runtime.deserialize_cuda_engine(engine_bytes)
|
||||
self.context = self.engine.create_execution_context()
|
||||
|
||||
self.input_names = state_dict[prefix + 'input_names']
|
||||
self.output_names = state_dict[prefix + 'output_names']
|
||||
|
||||
def forward(self, inputs):
|
||||
"""
|
||||
Arguments:
|
||||
inputs (dict): dict of input name-tensors pair
|
||||
|
||||
Return:
|
||||
dict: dict of output name-tensors pair
|
||||
"""
|
||||
assert self.input_names is not None
|
||||
assert self.output_names is not None
|
||||
bindings = [None] * (len(self.input_names) + len(self.output_names))
|
||||
|
||||
for input_name, input_tensor in inputs.items():
|
||||
idx = self.engine.get_binding_index(input_name)
|
||||
|
||||
self.context.set_binding_shape(idx, tuple(input_tensor.shape))
|
||||
bindings[idx] = input_tensor.contiguous().data_ptr()
|
||||
|
||||
# create output tensors
|
||||
outputs = {}
|
||||
for i, output_name in enumerate(self.output_names):
|
||||
idx = self.engine.get_binding_index(output_name)
|
||||
dtype = torch_dtype_from_trt(self.engine.get_binding_dtype(idx))
|
||||
shape = tuple(self.context.get_binding_shape(idx))
|
||||
|
||||
device = torch_device_from_trt(self.engine.get_location(idx))
|
||||
output = torch.empty(size=shape, dtype=dtype, device=device)
|
||||
outputs[output_name] = output
|
||||
bindings[idx] = output.data_ptr()
|
||||
|
||||
self.context.execute_async_v2(bindings,
|
||||
torch.cuda.current_stream().cuda_stream)
|
||||
|
||||
return outputs
|
@ -14,6 +14,6 @@ line_length = 79
|
||||
multi_line_output = 0
|
||||
known_standard_library = pkg_resources,setuptools
|
||||
known_first_party = mmcv
|
||||
known_third_party = addict,cv2,m2r,numpy,onnx,onnxoptimizer,onnxruntime,packaging,pytest,recommonmark,resnet_cifar,torch,torchvision,yaml,yapf
|
||||
known_third_party = addict,cv2,m2r,numpy,onnx,onnxoptimizer,onnxruntime,packaging,pytest,recommonmark,resnet_cifar,tensorrt,torch,torchvision,yaml,yapf
|
||||
no_lines_before = STDLIB,LOCALFOLDER
|
||||
default_section = THIRDPARTY
|
||||
|
49
setup.py
49
setup.py
@ -54,7 +54,6 @@ def parse_requirements(fname='requirements/runtime.txt', with_version=True):
|
||||
"""
|
||||
import sys
|
||||
from os.path import exists
|
||||
import re
|
||||
require_fpath = fname
|
||||
|
||||
def parse_line(line):
|
||||
@ -134,6 +133,51 @@ except ImportError:
|
||||
def get_extensions():
|
||||
extensions = []
|
||||
|
||||
if os.getenv('MMCV_WITH_TRT', '0') != '0':
|
||||
ext_name = 'mmcv._ext_trt'
|
||||
from torch.utils.cpp_extension import include_paths, library_paths
|
||||
library_dirs = []
|
||||
libraries = []
|
||||
include_dirs = []
|
||||
tensorrt_path = os.getenv('TENSORRT_DIR', '0')
|
||||
tensorrt_lib_path = glob.glob(
|
||||
os.path.join(tensorrt_path, 'targets', '*', 'lib'))[0]
|
||||
library_dirs += [tensorrt_lib_path]
|
||||
libraries += ['nvinfer', 'nvparsers', 'nvinfer_plugin']
|
||||
libraries += ['cudart']
|
||||
kwargs = {}
|
||||
define_macros = []
|
||||
extra_compile_args = {'cxx': []}
|
||||
|
||||
include_path = os.path.abspath('./mmcv/ops/csrc')
|
||||
include_trt_path = os.path.abspath('./mmcv/ops/csrc/tensorrt')
|
||||
include_dirs.append(include_path)
|
||||
include_dirs.append(include_trt_path)
|
||||
include_dirs.append(os.path.join(tensorrt_path, 'include'))
|
||||
include_dirs += include_paths(cuda=True)
|
||||
|
||||
op_files = glob.glob('./mmcv/ops/csrc/tensorrt/plugins/*')
|
||||
define_macros += [('MMCV_WITH_CUDA', None)]
|
||||
define_macros += [('MMCV_WITH_TRT', None)]
|
||||
cuda_args = os.getenv('MMCV_CUDA_ARGS')
|
||||
extra_compile_args['nvcc'] = [cuda_args] if cuda_args else []
|
||||
library_dirs += library_paths(cuda=True)
|
||||
|
||||
kwargs['library_dirs'] = library_dirs
|
||||
kwargs['libraries'] = libraries
|
||||
|
||||
from setuptools import Extension
|
||||
ext_ops = Extension(
|
||||
name=ext_name,
|
||||
sources=op_files,
|
||||
include_dirs=include_dirs,
|
||||
define_macros=define_macros,
|
||||
extra_compile_args=extra_compile_args,
|
||||
language='c++',
|
||||
library_dirs=library_dirs,
|
||||
libraries=libraries)
|
||||
extensions.append(ext_ops)
|
||||
|
||||
if os.getenv('MMCV_WITH_OPS', '0') == '0':
|
||||
return extensions
|
||||
|
||||
@ -157,7 +201,8 @@ def get_extensions():
|
||||
extensions.append(ext_ops)
|
||||
elif EXT_TYPE == 'pytorch':
|
||||
ext_name = 'mmcv._ext'
|
||||
from torch.utils.cpp_extension import (CUDAExtension, CppExtension)
|
||||
from torch.utils.cpp_extension import CppExtension, CUDAExtension
|
||||
|
||||
# prevent ninja from using too many resources
|
||||
os.environ.setdefault('MAX_JOBS', '4')
|
||||
define_macros = []
|
||||
|
@ -61,9 +61,7 @@ def test_nms():
|
||||
assert np.allclose(pytorch_score, onnx_score, atol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(),
|
||||
reason='CUDA is unavailable for test_softnms')
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires GPU')
|
||||
def test_softnms():
|
||||
from mmcv.ops import get_onnxruntime_op_path, soft_nms
|
||||
|
||||
@ -77,7 +75,8 @@ def test_softnms():
|
||||
'1.5.1'), 'test_softnms should be ran with onnxruntime >= 1.5.1'
|
||||
|
||||
ort_custom_op_path = get_onnxruntime_op_path()
|
||||
assert os.path.exists(ort_custom_op_path)
|
||||
if not os.path.exists(ort_custom_op_path):
|
||||
pytest.skip('softnms for onnxruntime is not compiled.')
|
||||
|
||||
np_boxes = np.array([[6.0, 3.0, 8.0, 7.0], [3.0, 6.0, 9.0, 11.0],
|
||||
[3.0, 7.0, 10.0, 12.0], [1.0, 4.0, 13.0, 7.0]],
|
||||
@ -138,13 +137,13 @@ def test_softnms():
|
||||
|
||||
|
||||
def test_roialign():
|
||||
from mmcv.ops import roi_align
|
||||
ort_custom_op_path = ''
|
||||
try:
|
||||
from mmcv.ops import roi_align
|
||||
from mmcv.ops import get_onnxruntime_op_path
|
||||
ort_custom_op_path = get_onnxruntime_op_path()
|
||||
except ImportError:
|
||||
pass
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
pytest.skip('roi_align op is not successfully compiled')
|
||||
|
||||
ort_custom_op_path = get_onnxruntime_op_path()
|
||||
# roi align config
|
||||
pool_h = 2
|
||||
pool_w = 2
|
||||
@ -208,9 +207,8 @@ def test_roialign():
|
||||
assert np.allclose(pytorch_output, onnx_output, atol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires GPU')
|
||||
def test_roipool():
|
||||
if not torch.cuda.is_available():
|
||||
return
|
||||
from mmcv.ops import roi_pool
|
||||
|
||||
# roi pool config
|
||||
|
93
tests/test_ops/test_tensorrt.py
Normal file
93
tests/test_ops/test_tensorrt.py
Normal file
@ -0,0 +1,93 @@
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
onnx_file = 'tmp.onnx'
|
||||
trt_file = 'tmp.engine'
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason='CUDA is required for test_roialign')
|
||||
def test_roialign():
|
||||
try:
|
||||
from mmcv.tensorrt import (TRTWraper, onnx2trt, save_trt_engine,
|
||||
is_tensorrt_plugin_loaded)
|
||||
if not is_tensorrt_plugin_loaded():
|
||||
pytest.skip('test requires to complie TensorRT plugins in mmcv')
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
pytest.skip('test requires to install TensorRT from source.')
|
||||
|
||||
try:
|
||||
from mmcv.ops import RoIAlign
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
pytest.skip('test requires compilation')
|
||||
|
||||
# trt config
|
||||
fp16_mode = False
|
||||
max_workspace_size = 1 << 30
|
||||
|
||||
# roi align config
|
||||
pool_h = 2
|
||||
pool_w = 2
|
||||
spatial_scale = 1.0
|
||||
sampling_ratio = 2
|
||||
|
||||
inputs = [([[[[1., 2.], [3., 4.]]]], [[0., 0., 0., 1., 1.]]),
|
||||
([[[[1., 2.], [3., 4.]], [[4., 3.],
|
||||
[2., 1.]]]], [[0., 0., 0., 1., 1.]]),
|
||||
([[[[1., 2., 5., 6.], [3., 4., 7., 8.], [9., 10., 13., 14.],
|
||||
[11., 12., 15., 16.]]]], [[0., 0., 0., 3., 3.]])]
|
||||
|
||||
wrapped_model = RoIAlign((pool_w, pool_h), spatial_scale, sampling_ratio,
|
||||
'avg', True).cuda()
|
||||
for case in inputs:
|
||||
np_input = np.array(case[0], dtype=np.float32)
|
||||
np_rois = np.array(case[1], dtype=np.float32)
|
||||
input = torch.from_numpy(np_input).cuda()
|
||||
rois = torch.from_numpy(np_rois).cuda()
|
||||
|
||||
with torch.no_grad():
|
||||
torch.onnx.export(
|
||||
wrapped_model, (input, rois),
|
||||
onnx_file,
|
||||
export_params=True,
|
||||
keep_initializers_as_inputs=True,
|
||||
input_names=['input', 'rois'],
|
||||
output_names=['roi_feat'],
|
||||
opset_version=11)
|
||||
onnx_model = onnx.load(onnx_file)
|
||||
|
||||
# create trt engine and wraper
|
||||
opt_shape_dict = {
|
||||
'input': [list(input.shape),
|
||||
list(input.shape),
|
||||
list(input.shape)],
|
||||
'rois': [list(rois.shape),
|
||||
list(rois.shape),
|
||||
list(rois.shape)]
|
||||
}
|
||||
trt_engine = onnx2trt(
|
||||
onnx_model,
|
||||
opt_shape_dict,
|
||||
fp16_mode=fp16_mode,
|
||||
max_workspace_size=max_workspace_size)
|
||||
save_trt_engine(trt_engine, trt_file)
|
||||
trt_model = TRTWraper(trt_file, ['input', 'rois'], ['roi_feat'])
|
||||
|
||||
with torch.no_grad():
|
||||
trt_outputs = trt_model({'input': input, 'rois': rois})
|
||||
trt_roi_feat = trt_outputs['roi_feat']
|
||||
|
||||
# compute pytorch_output
|
||||
with torch.no_grad():
|
||||
pytorch_roi_feat = wrapped_model(input, rois)
|
||||
|
||||
# allclose
|
||||
if os.path.exists(onnx_file):
|
||||
os.remove(onnx_file)
|
||||
if os.path.exists(trt_file):
|
||||
os.remove(trt_file)
|
||||
assert torch.allclose(pytorch_roi_feat, trt_roi_feat)
|
Loading…
x
Reference in New Issue
Block a user