[Feature]: Add custom operators support for TensorRT in mmcv (#686)

* start trt plugin prototype

* Add test module, modify roialign converter

* finish roi_align trt plugin

* fix conflict of RoiAlign and MMCVRoiAlign

* fix for lint

* fix test tensorrt module

* test_tensorrt move import to test func

* add except error type

* add tensorrt to setup.cfg

* code format with yapf

* fix for clang-format

* move tensorrt_utils to mmcv/tensorrt, add comments, better test module

* fix line endings, docformatter

* isort init, remove trailing whitespace

* add except type

* fix setup.py

* put import extension inside trt setup

* change c++ guard, update pytest script, better setup, etc

* sort import with isort

* sort import with isort

* move init of plugin lib to init_plugins.py

* resolve format and add test dependency: tensorrt

* tensorrt should be installed from source not from pypi

* update naming style and input check

* resolve lint error

Co-authored-by: maningsheng <maningsheng@sensetime.com>
q.yao 2021-01-06 11:05:19 +08:00 committed by GitHub
parent 643009e445
commit 0de9e149c0
19 changed files with 1067 additions and 68 deletions

View File

@ -1,4 +1,5 @@
from .info import is_custom_op_loaded
from .simplify import simplify
from .symbolic import register_extra_symbolics
__all__ = ['register_extra_symbolics', 'simplify']
__all__ = ['register_extra_symbolics', 'simplify', 'is_custom_op_loaded']

18 mmcv/onnx/info.py Normal file
View File

@ -0,0 +1,18 @@
import os
def is_custom_op_loaded():
flag = False
try:
from ..tensorrt import is_tensorrt_plugin_loaded
flag = is_tensorrt_plugin_loaded()
except (ImportError, ModuleNotFoundError):
pass
if not flag:
try:
from ..ops import get_onnxruntime_op_path
ort_lib_path = get_onnxruntime_op_path()
flag = os.path.exists(ort_lib_path)
except (ImportError, ModuleNotFoundError):
pass
return flag
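
A quick way to exercise this helper once either backend library has been built (a minimal sketch, not part of this commit):

import warnings

from mmcv.onnx import is_custom_op_loaded

# True if either the TensorRT plugin library or the onnxruntime op library
# was found; exporters such as RoIAlignFunction.symbolic use this flag to
# decide between mmcv:: custom ops and standard ONNX ops.
if not is_custom_op_loaded():
    warnings.warn('no custom op library found, falling back to ONNX ops')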

View File

@ -11,20 +11,19 @@
struct MMCVRoiAlignKernel {
public:
MMCVRoiAlignKernel(Ort::CustomOpApi ort, const OrtKernelInfo *info)
MMCVRoiAlignKernel(Ort::CustomOpApi ort, const OrtKernelInfo* info)
: ort_(ort) {
aligned_ = ort_.KernelInfoGetAttribute<int64_t>(info, "aligned");
aligned_height_ =
ort_.KernelInfoGetAttribute<int64_t>(info, "aligned_height");
aligned_width_ =
ort_.KernelInfoGetAttribute<int64_t>(info, "aligned_width");
pool_mode_ = ort_.KernelInfoGetAttribute<std::string>(info, "pool_mode");
ort_.KernelInfoGetAttribute<int64_t>(info, "output_height");
aligned_width_ = ort_.KernelInfoGetAttribute<int64_t>(info, "output_width");
pool_mode_ = ort_.KernelInfoGetAttribute<std::string>(info, "mode");
sampling_ratio_ =
ort_.KernelInfoGetAttribute<int64_t>(info, "sampling_ratio");
spatial_scale_ = ort_.KernelInfoGetAttribute<float>(info, "spatial_scale");
}
void Compute(OrtKernelContext *context);
void Compute(OrtKernelContext* context);
private:
Ort::CustomOpApi ort_;
@ -39,10 +38,10 @@ struct MMCVRoiAlignKernel {
struct MMCVRoiAlignCustomOp
: Ort::CustomOpBase<MMCVRoiAlignCustomOp, MMCVRoiAlignKernel> {
void *CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo *info) {
void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) {
return new MMCVRoiAlignKernel(api, info);
}
const char *GetName() const { return "MMCVRoiAlign"; }
const char* GetName() const { return "MMCVRoiAlign"; }
size_t GetInputTypeCount() const { return 2; }
ONNXTensorElementDataType GetInputType(size_t) const {
@ -55,7 +54,7 @@ struct MMCVRoiAlignCustomOp
}
// force cpu
const char *GetExecutionProviderType() const {
const char* GetExecutionProviderType() const {
return "CPUExecutionProvider";
}
};
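
On the Python side, the compiled onnxruntime library that provides this MMCVRoiAlign kernel has to be registered with a session before an exported model using it can run; a hedged sketch (the model path is illustrative):

import onnxruntime as ort

from mmcv.ops import get_onnxruntime_op_path

ort_op_path = get_onnxruntime_op_path()  # '' when the library was not built
session_options = ort.SessionOptions()
if ort_op_path:
    # makes the MMCVRoiAlign custom op above resolvable by onnxruntime
    session_options.register_custom_ops_library(ort_op_path)
session = ort.InferenceSession('model_with_roialign.onnx', session_options)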

View File

@ -1,11 +1,16 @@
#ifndef ROI_ALIGN_CUDA_KERNEL_CUH
#define ROI_ALIGN_CUDA_KERNEL_CUH
#include <float.h>
#ifdef MMCV_WITH_TRT
#include "common_cuda_helper.hpp"
#else // MMCV_WITH_TRT
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#else // MMCV_USE_PARROTS
#include "pytorch_cuda_helper.hpp"
#endif
#endif // MMCV_USE_PARROTS
#endif // MMCV_WITH_TRT
/*** Forward ***/
template <typename T>

View File

@ -0,0 +1,9 @@
#include "trt_plugin.hpp"
#include "trt_roi_align.hpp"
REGISTER_TENSORRT_PLUGIN(RoIAlignPluginDynamicCreator);
extern "C" {
bool initLibMMCVInferPlugins() { return true; }
} // extern "C"

View File

@ -0,0 +1,293 @@
#include "trt_roi_align.hpp"
#include <assert.h>
#include <chrono>
#include "trt_serialize.hpp"
extern void TRTRoIAlignForwardCUDAKernelLauncher_float(
const float *input, const float *rois, float *output, float *argmax_y,
float *argmax_x, int output_size, int channels, int height, int width,
int aligned_height, int aligned_width, float spatial_scale,
int sampling_ratio, int pool_mode, bool aligned, cudaStream_t stream);
namespace {
static const char *PLUGIN_VERSION{"1"};
static const char *PLUGIN_NAME{"MMCVRoiAlign"};
} // namespace
nvinfer1::PluginFieldCollection RoIAlignPluginDynamicCreator::mFC{};
std::vector<nvinfer1::PluginField>
RoIAlignPluginDynamicCreator::mPluginAttributes;
RoIAlignPluginDynamic::RoIAlignPluginDynamic(const std::string &name,
int outWidth, int outHeight,
float spatialScale,
int sampleRatio, int poolMode,
bool aligned)
: mLayerName(name),
mOutWidth(outWidth),
mOutHeight(outHeight),
mSpatialScale(spatialScale),
mSampleRatio(sampleRatio),
mPoolMode(poolMode),
mAligned(aligned) {}
RoIAlignPluginDynamic::RoIAlignPluginDynamic(const std::string name,
const void *data, size_t length)
: mLayerName(name) {
deserialize_value(&data, &length, &mOutWidth);
deserialize_value(&data, &length, &mOutHeight);
deserialize_value(&data, &length, &mSpatialScale);
deserialize_value(&data, &length, &mSampleRatio);
deserialize_value(&data, &length, &mPoolMode);
deserialize_value(&data, &length, &mAligned);
}
nvinfer1::IPluginV2DynamicExt *RoIAlignPluginDynamic::clone() const {
RoIAlignPluginDynamic *plugin = new RoIAlignPluginDynamic(
mLayerName, mOutWidth, mOutHeight, mSpatialScale, mSampleRatio, mPoolMode,
mAligned);
plugin->setPluginNamespace(getPluginNamespace());
return plugin;
}
nvinfer1::DimsExprs RoIAlignPluginDynamic::getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
nvinfer1::IExprBuilder &exprBuilder) {
nvinfer1::DimsExprs ret;
ret.nbDims = 4;
ret.d[0] = inputs[1].d[0];
ret.d[1] = inputs[0].d[1];
ret.d[2] = exprBuilder.constant(mOutHeight);
ret.d[3] = exprBuilder.constant(mOutWidth);
return ret;
}
bool RoIAlignPluginDynamic::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc *inOut, int nbInputs,
int nbOutputs) {
return inOut[pos].type == nvinfer1::DataType::kFLOAT &&
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
}
void RoIAlignPluginDynamic::configurePlugin(
const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *outputs, int nbOutputs) {}
size_t RoIAlignPluginDynamic::getWorkspaceSize(
const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const {
size_t output_size = 0;
size_t word_size = 0;
switch (mPoolMode) {
case 0: // max
output_size = outputs[0].dims.d[0] * outputs[0].dims.d[1] *
outputs[0].dims.d[2] * outputs[0].dims.d[3];
word_size = mmcv::getElementSize(outputs[0].type);
return output_size * word_size * 2;
break;
case 1:
return 0;
break;
default:
return 0;
}
return 0;
}
int RoIAlignPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs,
void *const *outputs, void *workSpace,
cudaStream_t stream) {
int channels = inputDesc[0].dims.d[1];
int height = inputDesc[0].dims.d[2];
int width = inputDesc[0].dims.d[3];
int output_size = outputDesc[0].dims.d[0] * outputDesc[0].dims.d[1] *
outputDesc[0].dims.d[2] * outputDesc[0].dims.d[3];
int word_size = mmcv::getElementSize(outputDesc[0].type);
const void *feat = inputs[0];
const void *rois = inputs[1];
void *output = outputs[0];
void *argmax_y = nullptr;
void *argmax_x = nullptr;
switch (mPoolMode) {
case 0: // max
argmax_y = workSpace;
argmax_x = argmax_y + output_size * word_size;
break;
case 1: // avg
break;
}
switch (outputDesc[0].type) {
case nvinfer1::DataType::kFLOAT:
TRTRoIAlignForwardCUDAKernelLauncher_float(
(const float *)feat, (const float *)rois, (float *)output,
(float *)argmax_y, (float *)argmax_x, output_size, channels, height,
width, mOutHeight, mOutWidth, mSpatialScale, mSampleRatio, mPoolMode,
mAligned, stream);
break;
default:
break;
}
return 0;
}
nvinfer1::DataType RoIAlignPluginDynamic::getOutputDataType(
int index, const nvinfer1::DataType *inputTypes, int nbInputs) const {
return inputTypes[0];
}
// IPluginV2 Methods
const char *RoIAlignPluginDynamic::getPluginType() const { return PLUGIN_NAME; }
const char *RoIAlignPluginDynamic::getPluginVersion() const {
return PLUGIN_VERSION;
}
int RoIAlignPluginDynamic::getNbOutputs() const { return 1; }
int RoIAlignPluginDynamic::initialize() { return 0; }
void RoIAlignPluginDynamic::terminate() {}
size_t RoIAlignPluginDynamic::getSerializationSize() const {
return sizeof(mOutWidth) + sizeof(mOutHeight) + sizeof(mSpatialScale) +
sizeof(mSampleRatio) + sizeof(mPoolMode) + sizeof(mAligned);
}
void RoIAlignPluginDynamic::serialize(void *buffer) const {
serialize_value(&buffer, mOutWidth);
serialize_value(&buffer, mOutHeight);
serialize_value(&buffer, mSpatialScale);
serialize_value(&buffer, mSampleRatio);
serialize_value(&buffer, mPoolMode);
serialize_value(&buffer, mAligned);
}
void RoIAlignPluginDynamic::destroy() {
// This gets called when the network containing plugin is destroyed
delete this;
}
void RoIAlignPluginDynamic::setPluginNamespace(const char *libNamespace) {
mNamespace = libNamespace;
}
const char *RoIAlignPluginDynamic::getPluginNamespace() const {
return mNamespace.c_str();
}
////////////////////// creator /////////////////////////////
RoIAlignPluginDynamicCreator::RoIAlignPluginDynamicCreator() {
mPluginAttributes.emplace_back(nvinfer1::PluginField("output_height"));
mPluginAttributes.emplace_back(nvinfer1::PluginField("output_width"));
mPluginAttributes.emplace_back(nvinfer1::PluginField("spatial_scale"));
mPluginAttributes.emplace_back(nvinfer1::PluginField("sampling_ratio"));
mPluginAttributes.emplace_back(nvinfer1::PluginField("mode"));
mPluginAttributes.emplace_back(nvinfer1::PluginField("aligned"));
mFC.nbFields = mPluginAttributes.size();
mFC.fields = mPluginAttributes.data();
}
const char *RoIAlignPluginDynamicCreator::getPluginName() const {
return PLUGIN_NAME;
}
const char *RoIAlignPluginDynamicCreator::getPluginVersion() const {
return PLUGIN_VERSION;
}
const nvinfer1::PluginFieldCollection *
RoIAlignPluginDynamicCreator::getFieldNames() {
return &mFC;
}
nvinfer1::IPluginV2 *RoIAlignPluginDynamicCreator::createPlugin(
const char *name, const nvinfer1::PluginFieldCollection *fc) {
int outWidth = 7;
int outHeight = 7;
float spatialScale = 1.0;
int sampleRatio = 0;
int poolMode = -1;
bool aligned = true;
for (int i = 0; i < fc->nbFields; i++) {
if (fc->fields[i].data == nullptr) {
continue;
}
std::string field_name(fc->fields[i].name);
if (field_name.compare("output_height") == 0) {
outHeight = static_cast<const int *>(fc->fields[i].data)[0];
}
if (field_name.compare("output_width") == 0) {
outWidth = static_cast<const int *>(fc->fields[i].data)[0];
}
if (field_name.compare("spatial_scale") == 0) {
spatialScale = static_cast<const float *>(fc->fields[i].data)[0];
}
if (field_name.compare("sampling_ratio") == 0) {
sampleRatio = static_cast<const int *>(fc->fields[i].data)[0];
}
if (field_name.compare("mode") == 0) {
int data_size = fc->fields[i].length;
const char *data_start = static_cast<const char *>(fc->fields[i].data);
std::string poolModeStr(data_start, data_size);
if (poolModeStr == "avg") {
poolMode = 1;
} else if (poolModeStr == "max") {
poolMode = 0;
} else {
std::cout << "Unknown pool mode \"" << poolModeStr << "\"."
<< std::endl;
}
assert(poolMode >= 0);
}
if (field_name.compare("aligned") == 0) {
int aligned_int = static_cast<const int *>(fc->fields[i].data)[0];
aligned = aligned_int != 0;
}
}
assert(outHeight > 0);
assert(outWidth > 0);
assert(spatialScale > 0.);
assert(poolMode >= 0);
RoIAlignPluginDynamic *plugin = new RoIAlignPluginDynamic(
name, outWidth, outHeight, spatialScale, sampleRatio, poolMode, aligned);
plugin->setPluginNamespace(getPluginNamespace());
return plugin;
}
nvinfer1::IPluginV2 *RoIAlignPluginDynamicCreator::deserializePlugin(
const char *name, const void *serialData, size_t serialLength) {
auto plugin = new RoIAlignPluginDynamic(name, serialData, serialLength);
plugin->setPluginNamespace(getPluginNamespace());
return plugin;
}
void RoIAlignPluginDynamicCreator::setPluginNamespace(
const char *libNamespace) {
mNamespace = libNamespace;
}
const char *RoIAlignPluginDynamicCreator::getPluginNamespace() const {
return mNamespace.c_str();
}
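
The field names parsed in createPlugin mirror the attributes emitted by the ONNX symbolic function (output_height, output_width, spatial_scale, sampling_ratio, mode, aligned). When building a network without the ONNX parser, the same plugin can be instantiated directly from Python; a sketch under the assumption that the plugin library is already loaded (names and values are illustrative):

import numpy as np
import tensorrt as trt

def make_roi_align_plugin(out_h=7, out_w=7, scale=1.0, ratio=2, mode=b'avg'):
    creator = trt.get_plugin_registry().get_plugin_creator('MMCVRoiAlign', '1')
    fields = trt.PluginFieldCollection([
        trt.PluginField('output_height', np.array([out_h], dtype=np.int32),
                        trt.PluginFieldType.INT32),
        trt.PluginField('output_width', np.array([out_w], dtype=np.int32),
                        trt.PluginFieldType.INT32),
        trt.PluginField('spatial_scale', np.array([scale], dtype=np.float32),
                        trt.PluginFieldType.FLOAT32),
        trt.PluginField('sampling_ratio', np.array([ratio], dtype=np.int32),
                        trt.PluginFieldType.INT32),
        # 'mode' is read as raw chars of length fc->fields[i].length
        trt.PluginField('mode', np.frombuffer(mode, dtype=np.uint8),
                        trt.PluginFieldType.CHAR),
        trt.PluginField('aligned', np.array([1], dtype=np.int32),
                        trt.PluginFieldType.INT32),
    ])
    return creator.create_plugin('roi_align', fields)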

View File

@ -0,0 +1,28 @@
#include "common_cuda_helper.hpp"
#include "roi_align_cuda_kernel.cuh"
template <typename scalar_t>
void TRTRoIAlignForwardCUDAKernelLauncher(
const scalar_t* input, const scalar_t* rois, scalar_t* output,
scalar_t* argmax_y, scalar_t* argmax_x, int output_size, int channels,
int height, int width, int aligned_height, int aligned_width,
scalar_t spatial_scale, int sampling_ratio, int pool_mode, bool aligned,
cudaStream_t stream) {
roi_align_forward_cuda_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, input, rois, output, argmax_y, argmax_x, aligned_height,
aligned_width, static_cast<scalar_t>(spatial_scale), sampling_ratio,
pool_mode, aligned, channels, height, width);
}
void TRTRoIAlignForwardCUDAKernelLauncher_float(
const float* input, const float* rois, float* output, float* argmax_y,
float* argmax_x, int output_size, int channels, int height, int width,
int aligned_height, int aligned_width, float spatial_scale,
int sampling_ratio, int pool_mode, bool aligned, cudaStream_t stream) {
TRTRoIAlignForwardCUDAKernelLauncher<float>(
input, rois, output, argmax_y, argmax_x, output_size, channels, height,
width, aligned_height, aligned_width, spatial_scale, sampling_ratio,
pool_mode, aligned, stream);
}

View File

@ -0,0 +1,7 @@
#ifndef TRT_PLUGIN_HPP
#define TRT_PLUGIN_HPP
extern "C" {
bool initLibMMCVInferPlugins();
} // extern "C"
#endif // TRT_PLUGIN_HPP

View File

@ -0,0 +1,27 @@
#ifndef TRT_PLUGIN_HELPER_HPP
#define TRT_PLUGIN_HELPER_HPP
#include <stdexcept>
#include "NvInferPlugin.h"
namespace mmcv {
inline unsigned int getElementSize(nvinfer1::DataType t) {
switch (t) {
case nvinfer1::DataType::kINT32:
return 4;
case nvinfer1::DataType::kFLOAT:
return 4;
case nvinfer1::DataType::kHALF:
return 2;
// case nvinfer1::DataType::kBOOL:
case nvinfer1::DataType::kINT8:
return 1;
default:
throw std::runtime_error("Invalid DataType.");
}
throw std::runtime_error("Invalid DataType.");
return 0;
}
} // namespace mmcv
#endif // TRT_PLUGIN_HELPER_HPP

View File

@ -0,0 +1,108 @@
#ifndef TRT_ROI_ALIGN_HPP
#define TRT_ROI_ALIGN_HPP
#include <cublas_v2.h>
#include <memory>
#include <string>
#include <vector>
#include "trt_plugin_helper.hpp"
class RoIAlignPluginDynamic : public nvinfer1::IPluginV2DynamicExt {
public:
RoIAlignPluginDynamic(const std::string &name, int outWidth, int outHeight,
float spatialScale, int sampleRatio, int poolMode,
bool aligned);
RoIAlignPluginDynamic(const std::string name, const void *data,
size_t length);
RoIAlignPluginDynamic() = delete;
// IPluginV2DynamicExt Methods
nvinfer1::IPluginV2DynamicExt *clone() const override;
nvinfer1::DimsExprs getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
nvinfer1::IExprBuilder &exprBuilder) override;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc *inOut,
int nbInputs, int nbOutputs) override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *out,
int nbOutputs) override;
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
int nbInputs,
const nvinfer1::PluginTensorDesc *outputs,
int nbOutputs) const override;
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs, void *workspace,
cudaStream_t stream) override;
// IPluginV2Ext Methods
nvinfer1::DataType getOutputDataType(int index,
const nvinfer1::DataType *inputTypes,
int nbInputs) const override;
// IPluginV2 Methods
const char *getPluginType() const override;
const char *getPluginVersion() const override;
int getNbOutputs() const override;
int initialize() override;
void terminate() override;
size_t getSerializationSize() const override;
void serialize(void *buffer) const override;
void destroy() override;
void setPluginNamespace(const char *pluginNamespace) override;
const char *getPluginNamespace() const override;
private:
const std::string mLayerName;
std::string mNamespace;
int mOutWidth;
int mOutHeight;
float mSpatialScale;
int mSampleRatio;
int mPoolMode; // 1:avg 0:max
bool mAligned;
protected:
// To prevent compiler warnings.
using nvinfer1::IPluginV2DynamicExt::canBroadcastInputAcrossBatch;
using nvinfer1::IPluginV2DynamicExt::configurePlugin;
using nvinfer1::IPluginV2DynamicExt::enqueue;
using nvinfer1::IPluginV2DynamicExt::getOutputDimensions;
using nvinfer1::IPluginV2DynamicExt::getWorkspaceSize;
using nvinfer1::IPluginV2DynamicExt::isOutputBroadcastAcrossBatch;
using nvinfer1::IPluginV2DynamicExt::supportsFormat;
};
class RoIAlignPluginDynamicCreator : public nvinfer1::IPluginCreator {
public:
RoIAlignPluginDynamicCreator();
const char *getPluginName() const override;
const char *getPluginVersion() const override;
const nvinfer1::PluginFieldCollection *getFieldNames() override;
nvinfer1::IPluginV2 *createPlugin(
const char *name, const nvinfer1::PluginFieldCollection *fc) override;
nvinfer1::IPluginV2 *deserializePlugin(const char *name,
const void *serialData,
size_t serialLength) override;
void setPluginNamespace(const char *pluginNamespace) override;
const char *getPluginNamespace() const override;
private:
static nvinfer1::PluginFieldCollection mFC;
static std::vector<nvinfer1::PluginField> mPluginAttributes;
std::string mNamespace;
};
#endif // TRT_ROI_ALIGN_HPP

View File

@ -0,0 +1,117 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef TRT_SERIALIZE_HPP
#define TRT_SERIALIZE_HPP
#include <cassert>
#include <cstring>
#include <iostream>
#include <type_traits>
#include <vector>
using std::cerr;
using std::cout;
using std::endl;
template <typename T>
inline void serialize_value(void** buffer, T const& value);
template <typename T>
inline void deserialize_value(void const** buffer, size_t* buffer_size,
T* value);
namespace {
template <typename T, class Enable = void>
struct Serializer {};
template <typename T>
struct Serializer<T, typename std::enable_if<std::is_arithmetic<T>::value ||
std::is_enum<T>::value ||
std::is_pod<T>::value>::type> {
static size_t serialized_size(T const& value) { return sizeof(T); }
static void serialize(void** buffer, T const& value) {
::memcpy(*buffer, &value, sizeof(T));
reinterpret_cast<char*&>(*buffer) += sizeof(T);
}
static void deserialize(void const** buffer, size_t* buffer_size, T* value) {
assert(*buffer_size >= sizeof(T));
::memcpy(value, *buffer, sizeof(T));
reinterpret_cast<char const*&>(*buffer) += sizeof(T);
*buffer_size -= sizeof(T);
}
};
template <>
struct Serializer<const char*> {
static size_t serialized_size(const char* value) { return strlen(value) + 1; }
static void serialize(void** buffer, const char* value) {
::strcpy(static_cast<char*>(*buffer), value);
reinterpret_cast<char*&>(*buffer) += strlen(value) + 1;
}
static void deserialize(void const** buffer, size_t* buffer_size,
const char** value) {
*value = static_cast<char const*>(*buffer);
size_t data_size = strnlen(*value, *buffer_size) + 1;
assert(*buffer_size >= data_size);
reinterpret_cast<char const*&>(*buffer) += data_size;
*buffer_size -= data_size;
}
};
template <typename T>
struct Serializer<std::vector<T>,
typename std::enable_if<std::is_arithmetic<T>::value ||
std::is_enum<T>::value ||
std::is_pod<T>::value>::type> {
static size_t serialized_size(std::vector<T> const& value) {
return sizeof(value.size()) + value.size() * sizeof(T);
}
static void serialize(void** buffer, std::vector<T> const& value) {
serialize_value(buffer, value.size());
size_t nbyte = value.size() * sizeof(T);
::memcpy(*buffer, value.data(), nbyte);
reinterpret_cast<char*&>(*buffer) += nbyte;
}
static void deserialize(void const** buffer, size_t* buffer_size,
std::vector<T>* value) {
size_t size;
deserialize_value(buffer, buffer_size, &size);
value->resize(size);
size_t nbyte = value->size() * sizeof(T);
assert(*buffer_size >= nbyte);
::memcpy(value->data(), *buffer, nbyte);
reinterpret_cast<char const*&>(*buffer) += nbyte;
*buffer_size -= nbyte;
}
};
} // namespace
template <typename T>
inline size_t serialized_size(T const& value) {
return Serializer<T>::serialized_size(value);
}
template <typename T>
inline void serialize_value(void** buffer, T const& value) {
return Serializer<T>::serialize(buffer, value);
}
template <typename T>
inline void deserialize_value(void const** buffer, size_t* buffer_size,
T* value) {
return Serializer<T>::deserialize(buffer, buffer_size, value);
}
#endif // TRT_SERIALIZE_HPP
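
The serializers above write each value back-to-back with memcpy and no padding, and the deserializing constructor must read the values back in exactly the same order. Conceptually (an illustrative Python analogue, not part of the library):

import struct

# RoIAlignPluginDynamic state: two ints, a float, two ints, one bool,
# packed without padding -- the read order must match the write order.
state = (7, 7, 1.0, 2, 1, True)
blob = struct.pack('<iifii?', *state)
assert struct.unpack('<iifii?', blob) == state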

View File

@ -4,6 +4,7 @@ from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair
from ..onnx import is_custom_op_loaded
from ..utils import deprecated_api_warning, ext_loader
ext_module = ext_loader.load_ext('_ext',
@ -15,55 +16,48 @@ class RoIAlignFunction(Function):
@staticmethod
def symbolic(g, input, rois, output_size, spatial_scale, sampling_ratio,
pool_mode, aligned):
has_custom_op = False
try:
import os.path as osp
from mmcv.ops import get_onnxruntime_op_path
ort_op_path = get_onnxruntime_op_path()
has_custom_op = osp.exists(ort_op_path)
except ImportError:
pass
has_custom_op = is_custom_op_loaded()
if has_custom_op:
return g.op(
'mmcv::MMCVRoiAlign',
input,
rois,
aligned_height_i=output_size[0],
aligned_width_i=output_size[1],
output_height_i=output_size[0],
output_width_i=output_size[1],
spatial_scale_f=spatial_scale,
sampling_ratio_i=sampling_ratio,
mode_s=pool_mode,
aligned_i=aligned)
else:
from torch.onnx.symbolic_opset9 import sub, squeeze
from torch.onnx.symbolic_helper import _slice_helper
from torch.onnx import TensorProtoDataType
# batch_indices = rois[:, 0].long()
batch_indices = _slice_helper(
g, rois, axes=[1], starts=[0], ends=[1])
batch_indices = squeeze(g, batch_indices, 1)
batch_indices = g.op(
'Cast', batch_indices, to_i=TensorProtoDataType.INT64)
# rois = rois[:, 1:]
rois = _slice_helper(g, rois, axes=[1], starts=[1], ends=[5])
if aligned:
# rois -= 0.5/spatial_scale
aligned_offset = g.op(
'Constant',
value_t=torch.tensor([0.5 / spatial_scale],
dtype=torch.float32))
rois = sub(g, rois, aligned_offset)
# roi align
return g.op(
'RoiAlign',
input,
rois,
batch_indices,
output_height_i=output_size[0],
output_width_i=output_size[1],
spatial_scale_f=spatial_scale,
sampling_ratio_i=max(0, sampling_ratio),
pool_mode_s=pool_mode,
aligned_i=aligned)
from torch.onnx.symbolic_opset9 import sub, squeeze
from torch.onnx.symbolic_helper import _slice_helper
from torch.onnx import TensorProtoDataType
# batch_indices = rois[:, 0].long()
batch_indices = _slice_helper(g, rois, axes=[1], starts=[0], ends=[1])
batch_indices = squeeze(g, batch_indices, 1)
batch_indices = g.op(
'Cast', batch_indices, to_i=TensorProtoDataType.INT64)
# rois = rois[:, 1:]
rois = _slice_helper(g, rois, axes=[1], starts=[1], ends=[5])
if aligned:
# rois -= 0.5/spatial_scale
aligned_offset = g.op(
'Constant',
value_t=torch.tensor([0.5 / spatial_scale],
dtype=torch.float32))
rois = sub(g, rois, aligned_offset)
# roi align
return g.op(
'RoiAlign',
input,
rois,
batch_indices,
output_height_i=output_size[0],
output_width_i=output_size[1],
spatial_scale_f=spatial_scale,
sampling_ratio_i=max(0, sampling_ratio),
mode_s=pool_mode)
mode_s=pool_mode)
@staticmethod
def forward(ctx,

12 mmcv/tensorrt/__init__.py Normal file
View File

@ -0,0 +1,12 @@
# flake8: noqa
from .init_plugins import is_tensorrt_plugin_loaded, load_tensorrt_plugin
from .tensorrt_utils import (TRTWraper, load_trt_engine, onnx2trt,
save_trt_engine)
# load tensorrt plugin lib
load_tensorrt_plugin()
__all__ = [
'onnx2trt', 'save_trt_engine', 'load_trt_engine', 'TRTWraper',
'is_tensorrt_plugin_loaded'
]

View File

@ -0,0 +1,36 @@
import ctypes
import glob
import os
def get_tensorrt_op_path():
"""Get TensorRT plugins library path."""
wildcard = os.path.join(
os.path.abspath(os.path.dirname(os.path.dirname(__file__))),
'_ext_trt.*.so')
paths = glob.glob(wildcard)
lib_path = paths[0] if len(paths) > 0 else ''
return lib_path
plugin_is_loaded = False
def is_tensorrt_plugin_loaded():
"""Check if TensorRT plugins library is loaded or not.
Returns:
bool: plugin_is_loaded flag
"""
global plugin_is_loaded
return plugin_is_loaded
def load_tensorrt_plugin():
"""load TensorRT plugins library."""
global plugin_is_loaded
lib_path = get_tensorrt_op_path()
if (not plugin_is_loaded) and os.path.exists(lib_path):
ctypes.CDLL(lib_path)
plugin_is_loaded = True

View File

@ -0,0 +1,209 @@
import tensorrt as trt
import torch
def onnx2trt(onnx_model,
opt_shape_dict,
log_level=trt.Logger.ERROR,
fp16_mode=False,
max_workspace_size=0,
device_id=0):
"""Convert onnx model to tensorrt engine.
Arguments:
onnx_model (str or onnx.ModelProto): the onnx model to convert from
opt_shape_dict (dict): the min/opt/max shape of each input
log_level (TensorRT log level): the log level of TensorRT
fp16_mode (bool): enable fp16 mode
max_workspace_size (int): max workspace size of the TensorRT engine;
some tactics and layers need a large workspace.
device_id (int): the device on which to create the engine.
Returns:
tensorrt.ICudaEngine: the TensorRT engine created from onnx_model
Example:
>>> engine = onnx2trt(
>>> "onnx_model.onnx",
>>> {'input': [[1, 3, 160, 160],
>>> [1, 3, 320, 320],
>>> [1, 3, 640, 640]]},
>>> log_level=trt.Logger.WARNING,
>>> fp16_mode=True,
>>> max_workspace_size=1 << 30,
>>> device_id=0)
"""
device = torch.device('cuda:{}'.format(device_id))
# create builder and network
logger = trt.Logger(log_level)
builder = trt.Builder(logger)
EXPLICIT_BATCH = 1 << (int)(
trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
network = builder.create_network(EXPLICIT_BATCH)
# parse onnx
parser = trt.OnnxParser(network, logger)
if isinstance(onnx_model, str):
assert parser.parse_from_file(onnx_model), 'parse onnx failed.'
else:
assert parser.parse(
onnx_model.SerializeToString()), 'parse onnx failed.'
# config builder
builder.max_workspace_size = max_workspace_size
config = builder.create_builder_config()
config.max_workspace_size = max_workspace_size
profile = builder.create_optimization_profile()
for input_name, param in opt_shape_dict.items():
min_shape = tuple(param[0][:])
opt_shape = tuple(param[1][:])
max_shape = tuple(param[2][:])
profile.set_shape(input_name, min_shape, opt_shape, max_shape)
config.add_optimization_profile(profile)
if fp16_mode:
builder.fp16_mode = fp16_mode
config.set_flag(trt.BuilderFlag.FP16)
# create engine
with torch.cuda.device(device):
engine = builder.build_engine(network, config)
return engine
def save_trt_engine(engine, path):
"""Serialize TensorRT engine to disk.
Arguments:
engine (tensorrt.ICudaEngine): TensorRT engine to serialize
path (str): disk path to write the engine
"""
with open(path, mode='wb') as f:
f.write(bytearray(engine.serialize()))
def load_trt_engine(path):
"""Deserialize TensorRT engine from disk.
Arguments:
path (str): disk path to read the engine
Returns:
tensorrt.ICudaEngine: the TensorRT engine loaded from disk
"""
with trt.Logger() as logger, trt.Runtime(logger) as runtime:
with open(path, mode='rb') as f:
engine_bytes = f.read()
engine = runtime.deserialize_cuda_engine(engine_bytes)
return engine
def torch_dtype_from_trt(dtype):
"""Convert pytorch dtype to TensorRT dtype."""
if dtype == trt.bool:
return torch.bool
elif dtype == trt.int8:
return torch.int8
elif dtype == trt.int32:
return torch.int32
elif dtype == trt.float16:
return torch.float16
elif dtype == trt.float32:
return torch.float32
else:
raise TypeError('%s is not supported by torch' % dtype)
def torch_device_from_trt(device):
"""Convert pytorch device to TensorRT device."""
if device == trt.TensorLocation.DEVICE:
return torch.device('cuda')
elif device == trt.TensorLocation.HOST:
return torch.device('cpu')
else:
raise TypeError('%s is not supported by torch' % device)
class TRTWraper(torch.nn.Module):
"""TensorRT engine Wraper.
Arguments:
engine (tensorrt.ICudaEngine): TensorRT engine to wrap
input_names (list[str]): names of each inputs
output_names (list[str]): names of each outputs
Note:
If the engine is converted from an onnx model, the input_names and
output_names should be the same as in the onnx model.
"""
def __init__(self, engine, input_names, output_names):
super(TRTWraper, self).__init__()
self.engine = engine
if isinstance(self.engine, str):
self.engine = load_trt_engine(engine)
if not isinstance(self.engine, trt.ICudaEngine):
raise TypeError('engine should be str or trt.ICudaEngine')
self._register_state_dict_hook(TRTWraper._on_state_dict)
self.context = self.engine.create_execution_context()
self.input_names = input_names
self.output_names = output_names
def _on_state_dict(self, state_dict, prefix, local_metadata):
state_dict[prefix + 'engine'] = bytearray(self.engine.serialize())
state_dict[prefix + 'input_names'] = self.input_names
state_dict[prefix + 'output_names'] = self.output_names
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
engine_bytes = state_dict[prefix + 'engine']
with trt.Logger() as logger, trt.Runtime(logger) as runtime:
self.engine = runtime.deserialize_cuda_engine(engine_bytes)
self.context = self.engine.create_execution_context()
self.input_names = state_dict[prefix + 'input_names']
self.output_names = state_dict[prefix + 'output_names']
def forward(self, inputs):
"""
Arguments:
inputs (dict): dict of input name and tensor pairs
Returns:
dict: dict of output name and tensor pairs
"""
assert self.input_names is not None
assert self.output_names is not None
bindings = [None] * (len(self.input_names) + len(self.output_names))
for input_name, input_tensor in inputs.items():
idx = self.engine.get_binding_index(input_name)
self.context.set_binding_shape(idx, tuple(input_tensor.shape))
bindings[idx] = input_tensor.contiguous().data_ptr()
# create output tensors
outputs = {}
for i, output_name in enumerate(self.output_names):
idx = self.engine.get_binding_index(output_name)
dtype = torch_dtype_from_trt(self.engine.get_binding_dtype(idx))
shape = tuple(self.context.get_binding_shape(idx))
device = torch_device_from_trt(self.engine.get_location(idx))
output = torch.empty(size=shape, dtype=dtype, device=device)
outputs[output_name] = output
bindings[idx] = output.data_ptr()
self.context.execute_async_v2(bindings,
torch.cuda.current_stream().cuda_stream)
return outputs
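
Putting the pieces together, a typical round trip through these helpers looks roughly like this (a sketch; the ONNX file and the binding names 'input'/'output' are illustrative and must match the exported graph):

import torch

from mmcv.tensorrt import TRTWraper, onnx2trt, save_trt_engine

# min/opt/max shape for every input, as expected by onnx2trt above
opt_shape_dict = {'input': [[1, 3, 224, 224]] * 3}
engine = onnx2trt('model.onnx', opt_shape_dict, max_workspace_size=1 << 30)
save_trt_engine(engine, 'model.engine')

trt_model = TRTWraper('model.engine', ['input'], ['output'])
with torch.no_grad():
    outputs = trt_model({'input': torch.randn(1, 3, 224, 224).cuda()})
print(outputs['output'].shape)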

View File

@ -14,6 +14,6 @@ line_length = 79
multi_line_output = 0
known_standard_library = pkg_resources,setuptools
known_first_party = mmcv
known_third_party = addict,cv2,m2r,numpy,onnx,onnxoptimizer,onnxruntime,packaging,pytest,recommonmark,resnet_cifar,torch,torchvision,yaml,yapf
known_third_party = addict,cv2,m2r,numpy,onnx,onnxoptimizer,onnxruntime,packaging,pytest,recommonmark,resnet_cifar,tensorrt,torch,torchvision,yaml,yapf
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY

View File

@ -54,7 +54,6 @@ def parse_requirements(fname='requirements/runtime.txt', with_version=True):
"""
import sys
from os.path import exists
import re
require_fpath = fname
def parse_line(line):
@ -134,6 +133,51 @@ except ImportError:
def get_extensions():
extensions = []
if os.getenv('MMCV_WITH_TRT', '0') != '0':
ext_name = 'mmcv._ext_trt'
from torch.utils.cpp_extension import include_paths, library_paths
library_dirs = []
libraries = []
include_dirs = []
tensorrt_path = os.getenv('TENSORRT_DIR', '0')
tensorrt_lib_path = glob.glob(
os.path.join(tensorrt_path, 'targets', '*', 'lib'))[0]
library_dirs += [tensorrt_lib_path]
libraries += ['nvinfer', 'nvparsers', 'nvinfer_plugin']
libraries += ['cudart']
kwargs = {}
define_macros = []
extra_compile_args = {'cxx': []}
include_path = os.path.abspath('./mmcv/ops/csrc')
include_trt_path = os.path.abspath('./mmcv/ops/csrc/tensorrt')
include_dirs.append(include_path)
include_dirs.append(include_trt_path)
include_dirs.append(os.path.join(tensorrt_path, 'include'))
include_dirs += include_paths(cuda=True)
op_files = glob.glob('./mmcv/ops/csrc/tensorrt/plugins/*')
define_macros += [('MMCV_WITH_CUDA', None)]
define_macros += [('MMCV_WITH_TRT', None)]
cuda_args = os.getenv('MMCV_CUDA_ARGS')
extra_compile_args['nvcc'] = [cuda_args] if cuda_args else []
library_dirs += library_paths(cuda=True)
kwargs['library_dirs'] = library_dirs
kwargs['libraries'] = libraries
from setuptools import Extension
ext_ops = Extension(
name=ext_name,
sources=op_files,
include_dirs=include_dirs,
define_macros=define_macros,
extra_compile_args=extra_compile_args,
language='c++',
library_dirs=library_dirs,
libraries=libraries)
extensions.append(ext_ops)
if os.getenv('MMCV_WITH_OPS', '0') == '0':
return extensions
@ -157,7 +201,8 @@ def get_extensions():
extensions.append(ext_ops)
elif EXT_TYPE == 'pytorch':
ext_name = 'mmcv._ext'
from torch.utils.cpp_extension import (CUDAExtension, CppExtension)
from torch.utils.cpp_extension import CppExtension, CUDAExtension
# prevent ninja from using too many resources
os.environ.setdefault('MAX_JOBS', '4')
define_macros = []
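
In practice this means the TensorRT extension is only compiled when the MMCV_WITH_TRT environment variable is set to a non-zero value and TENSORRT_DIR points at an extracted TensorRT tarball (the targets/*/lib layout used above), e.g. something like MMCV_WITH_TRT=1 TENSORRT_DIR=/path/to/TensorRT pip install -e . (the exact install command is illustrative); otherwise get_extensions falls through to the regular ops build.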

View File

@ -61,9 +61,7 @@ def test_nms():
assert np.allclose(pytorch_score, onnx_score, atol=1e-3)
@pytest.mark.skipif(
not torch.cuda.is_available(),
reason='CUDA is unavailable for test_softnms')
@pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires GPU')
def test_softnms():
from mmcv.ops import get_onnxruntime_op_path, soft_nms
@ -77,7 +75,8 @@ def test_softnms():
'1.5.1'), 'test_softnms should be run with onnxruntime >= 1.5.1'
ort_custom_op_path = get_onnxruntime_op_path()
assert os.path.exists(ort_custom_op_path)
if not os.path.exists(ort_custom_op_path):
pytest.skip('softnms for onnxruntime is not compiled.')
np_boxes = np.array([[6.0, 3.0, 8.0, 7.0], [3.0, 6.0, 9.0, 11.0],
[3.0, 7.0, 10.0, 12.0], [1.0, 4.0, 13.0, 7.0]],
@ -138,13 +137,13 @@ def test_softnms():
def test_roialign():
from mmcv.ops import roi_align
ort_custom_op_path = ''
try:
from mmcv.ops import roi_align
from mmcv.ops import get_onnxruntime_op_path
ort_custom_op_path = get_onnxruntime_op_path()
except ImportError:
pass
except (ImportError, ModuleNotFoundError):
pytest.skip('roi_align op is not successfully compiled')
ort_custom_op_path = get_onnxruntime_op_path()
# roi align config
pool_h = 2
pool_w = 2
@ -208,9 +207,8 @@ def test_roialign():
assert np.allclose(pytorch_output, onnx_output, atol=1e-3)
@pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires GPU')
def test_roipool():
if not torch.cuda.is_available():
return
from mmcv.ops import roi_pool
# roi pool config

View File

@ -0,0 +1,93 @@
import os
import numpy as np
import onnx
import pytest
import torch
onnx_file = 'tmp.onnx'
trt_file = 'tmp.engine'
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='CUDA is required for test_roialign')
def test_roialign():
try:
from mmcv.tensorrt import (TRTWraper, onnx2trt, save_trt_engine,
is_tensorrt_plugin_loaded)
if not is_tensorrt_plugin_loaded():
pytest.skip('test requires TensorRT plugins to be compiled in mmcv')
except (ImportError, ModuleNotFoundError):
pytest.skip('test requires to install TensorRT from source.')
try:
from mmcv.ops import RoIAlign
except (ImportError, ModuleNotFoundError):
pytest.skip('test requires compilation')
# trt config
fp16_mode = False
max_workspace_size = 1 << 30
# roi align config
pool_h = 2
pool_w = 2
spatial_scale = 1.0
sampling_ratio = 2
inputs = [([[[[1., 2.], [3., 4.]]]], [[0., 0., 0., 1., 1.]]),
([[[[1., 2.], [3., 4.]], [[4., 3.],
[2., 1.]]]], [[0., 0., 0., 1., 1.]]),
([[[[1., 2., 5., 6.], [3., 4., 7., 8.], [9., 10., 13., 14.],
[11., 12., 15., 16.]]]], [[0., 0., 0., 3., 3.]])]
wrapped_model = RoIAlign((pool_w, pool_h), spatial_scale, sampling_ratio,
'avg', True).cuda()
for case in inputs:
np_input = np.array(case[0], dtype=np.float32)
np_rois = np.array(case[1], dtype=np.float32)
input = torch.from_numpy(np_input).cuda()
rois = torch.from_numpy(np_rois).cuda()
with torch.no_grad():
torch.onnx.export(
wrapped_model, (input, rois),
onnx_file,
export_params=True,
keep_initializers_as_inputs=True,
input_names=['input', 'rois'],
output_names=['roi_feat'],
opset_version=11)
onnx_model = onnx.load(onnx_file)
# create trt engine and wrapper
opt_shape_dict = {
'input': [list(input.shape),
list(input.shape),
list(input.shape)],
'rois': [list(rois.shape),
list(rois.shape),
list(rois.shape)]
}
trt_engine = onnx2trt(
onnx_model,
opt_shape_dict,
fp16_mode=fp16_mode,
max_workspace_size=max_workspace_size)
save_trt_engine(trt_engine, trt_file)
trt_model = TRTWraper(trt_file, ['input', 'rois'], ['roi_feat'])
with torch.no_grad():
trt_outputs = trt_model({'input': input, 'rois': rois})
trt_roi_feat = trt_outputs['roi_feat']
# compute pytorch_output
with torch.no_grad():
pytorch_roi_feat = wrapped_model(input, rois)
# allclose
if os.path.exists(onnx_file):
os.remove(onnx_file)
if os.path.exists(trt_file):
os.remove(trt_file)
assert torch.allclose(pytorch_roi_feat, trt_roi_feat)