mirror of https://github.com/open-mmlab/mmcv.git
[Feature] add cummax/cummin tensorrt plugin (#1031)
* add cummax/cummin tensorrt plugin * fix isort * fix with clang-format * fix with clang-format again * add documentpull/1041/head^2
parent
55b4847a41
commit
9d1436fb6c
|
@ -33,6 +33,18 @@
|
|||
- [Inputs](#inputs-4)
|
||||
- [Outputs](#outputs-4)
|
||||
- [Type Constraints](#type-constraints-4)
|
||||
- [cummax](#cummax)
|
||||
- [Description](#description-5)
|
||||
- [Parameters](#parameters-5)
|
||||
- [Inputs](#inputs-5)
|
||||
- [Outputs](#outputs-5)
|
||||
- [Type Constraints](#type-constraints-5)
|
||||
- [cummin](#cummin)
|
||||
- [Description](#description-6)
|
||||
- [Parameters](#parameters-6)
|
||||
- [Inputs](#inputs-6)
|
||||
- [Outputs](#outputs-6)
|
||||
- [Type Constraints](#type-constraints-6)
|
||||
|
||||
<!-- TOC -->
|
||||
|
||||
|
@ -227,3 +239,67 @@ Perform sample from `input` with pixel locations from `grid`.
|
|||
### Type Constraints
|
||||
|
||||
- T:tensor(float32, Linear)
|
||||
|
||||
## cummax
|
||||
|
||||
### Description
|
||||
|
||||
Returns a namedtuple (`values`, `indices`), where `values` is the cumulative maximum of the elements of `input` along the dimension `dim`, and `indices` is the index location of each maximum value found along the dimension `dim`.
|
||||
|
||||
### Parameters
|
||||
|
||||
| Type | Parameter | Description |
|
||||
| ----- | --------- | --------------------------------------- |
|
||||
| `int` | `dim` | The dimension to do the operation over. |
|
||||
|
||||
### Inputs
|
||||
|
||||
<dl>
|
||||
<dt><tt>inputs[0]</tt>: T</dt>
|
||||
<dd>The input tensor.</dd>
|
||||
</dl>
|
||||
|
||||
### Outputs
|
||||
|
||||
<dl>
|
||||
<dt><tt>outputs[0]</tt>: T</dt>
|
||||
<dd>Output values.</dd>
|
||||
<dt><tt>outputs[1]</tt>: tensor(int32, Linear)</dt>
|
||||
<dd>Output indices.</dd>
|
||||
</dl>
|
||||
|
||||
### Type Constraints
|
||||
|
||||
- T:tensor(float32, Linear)
|
||||
|
||||
## cummin
|
||||
|
||||
### Description
|
||||
|
||||
Returns a namedtuple (`values`, `indices`), where `values` is the cumulative minimum of the elements of `input` along the dimension `dim`, and `indices` is the index location of each minimum value found along the dimension `dim`.
|
||||
|
||||
### Parameters
|
||||
|
||||
| Type | Parameter | Description |
|
||||
| ----- | --------- | --------------------------------------- |
|
||||
| `int` | `dim` | The dimension to do the operation over. |
|
||||
|
||||
### Inputs
|
||||
|
||||
<dl>
|
||||
<dt><tt>inputs[0]</tt>: T</dt>
|
||||
<dd>The input tensor.</dd>
|
||||
</dl>
|
||||
|
||||
### Outputs
|
||||
|
||||
<dl>
|
||||
<dt><tt>outputs[0]</tt>: T</dt>
|
||||
<dd>Output values.</dd>
|
||||
<dt><tt>outputs[1]</tt>: tensor(int32, Linear)</dt>
|
||||
<dd>Output indices.</dd>
|
||||
</dl>
|
||||
|
||||
### Type Constraints
|
||||
|
||||
- T:tensor(float32, Linear)
|
||||
|
|
|
@ -30,7 +30,9 @@ To ease the deployment of trained models with custom operators from `mmcv.ops` u
|
|||
| ScatterND | [ScatterND](./tensorrt_custom_ops.md#scatternd) | 1.2.6 |
|
||||
| NonMaxSuppression | [NonMaxSuppression](./tensorrt_custom_ops.md#nonmaxsuppression) | 1.3.0 |
|
||||
| MMCVDeformConv2d | [MMCVDeformConv2d](./tensorrt_custom_ops.md#mmcvdeformconv2d) | 1.3.0 |
|
||||
| grid_sampler | [grid_sampler](./tensorrt_custom_ops.md#grid-sampler) | master |
|
||||
| grid_sampler | [grid_sampler](./tensorrt_custom_ops.md#grid-sampler) | 1.3.1 |
|
||||
| cummax | [cummax](./tensorrt_custom_ops.md#cummax) | master |
|
||||
| cummin | [cummin](./tensorrt_custom_ops.md#cummin) | master |
|
||||
|
||||
Notes
|
||||
|
||||
|
|
|
@ -0,0 +1,241 @@
|
|||
#include "trt_cummaxmin.hpp"
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include "trt_serialize.hpp"
|
||||
|
||||
void CumMaxMinForwardLauncher_float(const float *input, float *output_value,
|
||||
int *output_index, const int *dims,
|
||||
int nbDims, int cum_dim, int cum_type,
|
||||
cudaStream_t stream);
|
||||
|
||||
void CumMaxMinForwardLauncher_int32(const int *input, int *output_value,
|
||||
int *output_index, const int *dims,
|
||||
int nbDims, int cum_dim, int cum_type,
|
||||
cudaStream_t stream);
|
||||
|
||||
namespace {
// Version string reported by the plugin and by its creators.
const char *PLUGIN_VERSION = "1";
// Name reported by the shared creator base class.
const char *CUMMAXMIN_PLUGIN_NAME = "cummaxmin";
// Names reported for the max and the min flavour respectively.
const char *CUMMAX_PLUGIN_NAME = "cummax";
const char *CUMMIN_PLUGIN_NAME = "cummin";
}  // namespace
|
||||
|
||||
CumMaxMinPluginDynamic::CumMaxMinPluginDynamic(const std::string &name, int dim,
|
||||
TRT_CUMCMPTYPE cumType)
|
||||
: mLayerName(name), mDim(dim), mCumType(cumType) {}
|
||||
|
||||
// Rebuilds a plugin from a serialized buffer. The fields are read back in
// the order serialize() writes them: the scan axis first, then the
// max/min selector.
CumMaxMinPluginDynamic::CumMaxMinPluginDynamic(const std::string name,
                                               const void *data, size_t length)
    : mLayerName(name) {
  deserialize_value(&data, &length, &mDim);
  deserialize_value(&data, &length, &mCumType);
}
|
||||
|
||||
// No manual cleanup is performed; the members release themselves.
CumMaxMinPluginDynamic::~CumMaxMinPluginDynamic() = default;
|
||||
|
||||
// Returns a heap-allocated copy carrying the same layer name, scan axis,
// max/min selector and plugin namespace as this instance.
nvinfer1::IPluginV2DynamicExt *CumMaxMinPluginDynamic::clone() const {
  auto *copy = new CumMaxMinPluginDynamic(mLayerName, mDim, mCumType);
  copy->setPluginNamespace(getPluginNamespace());
  return copy;
}
|
||||
|
||||
// Every output has the same dimensions as the first (and only used) input,
// regardless of `outputIndex`.
nvinfer1::DimsExprs CumMaxMinPluginDynamic::getOutputDimensions(
    int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
    nvinfer1::IExprBuilder &exprBuilder) {
  const nvinfer1::DimsExprs &input_dims = inputs[0];
  return input_dims;
}
|
||||
|
||||
bool CumMaxMinPluginDynamic::supportsFormatCombination(
|
||||
int pos, const nvinfer1::PluginTensorDesc *inOut, int nbInputs,
|
||||
int nbOutputs) {
|
||||
switch (pos) {
|
||||
// input[0]
|
||||
case 0:
|
||||
return (inOut[pos].type == nvinfer1::DataType::kFLOAT ||
|
||||
inOut[pos].type == nvinfer1::DataType::kINT32) &&
|
||||
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
|
||||
// output[0]
|
||||
case 1:
|
||||
return inOut[pos].type == inOut[0].type &&
|
||||
inOut[pos].format == inOut[0].format;
|
||||
// output[1]
|
||||
case 2:
|
||||
return inOut[pos].type == nvinfer1::DataType::kINT32 &&
|
||||
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
void CumMaxMinPluginDynamic::configurePlugin(
|
||||
const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
|
||||
const nvinfer1::DynamicPluginTensorDesc *outputs, int nbOutputs) {}
|
||||
|
||||
size_t CumMaxMinPluginDynamic::getWorkspaceSize(
|
||||
const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
|
||||
const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const {
|
||||
int sizeof_dtype = mmcv::getElementSize(outputs[0].type);
|
||||
}
|
||||
|
||||
int CumMaxMinPluginDynamic::enqueue(
|
||||
const nvinfer1::PluginTensorDesc *inputDesc,
|
||||
const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
|
||||
void *const *outputs, void *workSpace, cudaStream_t stream) {
|
||||
const void *input = inputs[0];
|
||||
void *output_value = outputs[0];
|
||||
int *output_index = (int *)outputs[1];
|
||||
|
||||
const int *dims = &(inputDesc[0].dims.d[0]);
|
||||
int nbDims = inputDesc[0].dims.nbDims;
|
||||
|
||||
switch (inputDesc[0].type) {
|
||||
case nvinfer1::DataType::kFLOAT:
|
||||
CumMaxMinForwardLauncher_float((float *)input, (float *)output_value,
|
||||
output_index, dims, nbDims, mDim,
|
||||
int(mCumType), stream);
|
||||
break;
|
||||
case nvinfer1::DataType::kINT32:
|
||||
CumMaxMinForwardLauncher_int32((int *)input, (int *)output_value,
|
||||
output_index, dims, nbDims, mDim,
|
||||
int(mCumType), stream);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Returns the element type of output `index`.
//
// Output 1 (the indices) is always int32. Output 0 (the values) has the
// element type of the input. Any other index also yields the input type so
// that every path returns a value; the previous `default: break;` fell off
// the end of a non-void function, which is undefined behaviour.
nvinfer1::DataType CumMaxMinPluginDynamic::getOutputDataType(
    int index, const nvinfer1::DataType *inputTypes, int nbInputs) const {
  if (index == 1) {
    return nvinfer1::DataType::kINT32;
  }
  return inputTypes[0];
}
|
||||
|
||||
// IPluginV2 Methods
|
||||
const char *CumMaxMinPluginDynamic::getPluginType() const {
|
||||
switch (mCumType) {
|
||||
case TRT_CUMCMPTYPE::TRT_CUMMAX:
|
||||
return CUMMAX_PLUGIN_NAME;
|
||||
case TRT_CUMCMPTYPE::TRT_CUMMIN:
|
||||
return CUMMIN_PLUGIN_NAME;
|
||||
default:
|
||||
return "UnknownCumType";
|
||||
}
|
||||
}
|
||||
|
||||
const char *CumMaxMinPluginDynamic::getPluginVersion() const {
|
||||
return PLUGIN_VERSION;
|
||||
}
|
||||
|
||||
int CumMaxMinPluginDynamic::getNbOutputs() const { return 2; }
|
||||
|
||||
int CumMaxMinPluginDynamic::initialize() { return 0; }
|
||||
|
||||
void CumMaxMinPluginDynamic::terminate() {}
|
||||
|
||||
size_t CumMaxMinPluginDynamic::getSerializationSize() const {
|
||||
return sizeof(mDim) + sizeof(mCumType);
|
||||
}
|
||||
|
||||
// Writes the plugin attributes into `buffer`: first the scan axis, then the
// comparison selector. The deserializing constructor reads them back in the
// same order.
void CumMaxMinPluginDynamic::serialize(void *buffer) const {
  serialize_value(&buffer, mDim);
  serialize_value(&buffer, mCumType);
}
|
||||
|
||||
// Invoked when the network that contains this plugin is destroyed; the
// plugin frees itself.
void CumMaxMinPluginDynamic::destroy() {
  delete this;
}
|
||||
|
||||
void CumMaxMinPluginDynamic::setPluginNamespace(const char *libNamespace) {
|
||||
mNamespace = libNamespace;
|
||||
}
|
||||
|
||||
const char *CumMaxMinPluginDynamic::getPluginNamespace() const {
|
||||
return mNamespace.c_str();
|
||||
}
|
||||
|
||||
// Remembers which comparison the created plugins perform and advertises the
// single supported attribute, "dim", through the field collection.
CumMaxMinPluginDynamicCreator::CumMaxMinPluginDynamicCreator(
    TRT_CUMCMPTYPE cumType)
    : mCumType(cumType) {
  // Replace any previous contents with exactly one field descriptor.
  mPluginAttributes.assign(1, nvinfer1::PluginField("dim"));

  // The collection points into the vector's storage.
  mFC.fields = mPluginAttributes.data();
  mFC.nbFields = mPluginAttributes.size();
}
|
||||
|
||||
const char *CumMaxMinPluginDynamicCreator::getPluginName() const {
|
||||
return CUMMAXMIN_PLUGIN_NAME;
|
||||
}
|
||||
|
||||
const char *CumMaxMinPluginDynamicCreator::getPluginVersion() const {
|
||||
return PLUGIN_VERSION;
|
||||
}
|
||||
|
||||
const nvinfer1::PluginFieldCollection *
|
||||
CumMaxMinPluginDynamicCreator::getFieldNames() {
|
||||
return &mFC;
|
||||
}
|
||||
|
||||
// Creates a plugin for layer `name` from the attribute collection `fc`.
//
// Only the "dim" field is consulted; its first int is used as the scan
// axis. Fields without data are skipped, and the axis defaults to 0 when
// "dim" is absent. The new plugin inherits this creator's comparison mode
// and namespace.
nvinfer1::IPluginV2 *CumMaxMinPluginDynamicCreator::createPlugin(
    const char *name, const nvinfer1::PluginFieldCollection *fc) {
  int dim = 0;

  for (int i = 0; i < fc->nbFields; ++i) {
    const nvinfer1::PluginField &field = fc->fields[i];
    if (field.data == nullptr) {
      continue;
    }

    const std::string field_name(field.name);
    if (field_name == "dim") {
      dim = *static_cast<const int *>(field.data);
    }
  }

  auto *plugin = new CumMaxMinPluginDynamic(name, dim, mCumType);
  plugin->setPluginNamespace(getPluginNamespace());
  return plugin;
}
|
||||
|
||||
// Recreates a plugin from `serialData`. The returned object is released
// through CumMaxMinPluginDynamic::destroy(), which is called when the
// network that contains it is destroyed.
nvinfer1::IPluginV2 *CumMaxMinPluginDynamicCreator::deserializePlugin(
    const char *name, const void *serialData, size_t serialLength) {
  CumMaxMinPluginDynamic *restored =
      new CumMaxMinPluginDynamic(name, serialData, serialLength);
  restored->setPluginNamespace(getPluginNamespace());
  return restored;
}
|
||||
|
||||
void CumMaxMinPluginDynamicCreator::setPluginNamespace(
|
||||
const char *libNamespace) {
|
||||
mNamespace = libNamespace;
|
||||
}
|
||||
|
||||
const char *CumMaxMinPluginDynamicCreator::getPluginNamespace() const {
|
||||
return mNamespace.c_str();
|
||||
}
|
||||
|
||||
CumMaxPluginDynamicCreator::CumMaxPluginDynamicCreator()
|
||||
: CumMaxMinPluginDynamicCreator(TRT_CUMCMPTYPE::TRT_CUMMAX) {}
|
||||
|
||||
const char *CumMaxPluginDynamicCreator::getPluginName() const {
|
||||
return CUMMAX_PLUGIN_NAME;
|
||||
}
|
||||
|
||||
CumMinPluginDynamicCreator::CumMinPluginDynamicCreator()
|
||||
: CumMaxMinPluginDynamicCreator(TRT_CUMCMPTYPE::TRT_CUMMIN) {}
|
||||
|
||||
const char *CumMinPluginDynamicCreator::getPluginName() const {
|
||||
return CUMMIN_PLUGIN_NAME;
|
||||
}
|
|
@ -0,0 +1,89 @@
|
|||
|
||||
#include "common_cuda_helper.hpp"
|
||||
#include "trt_cuda_helper.cuh"
|
||||
#include "trt_plugin_helper.hpp"
|
||||
|
||||
using mmcv::TensorDesc;
|
||||
|
||||
// Computes the running maximum or minimum, and the index at which it was
// attained, along axis `cum_dim` of a tensor described by `tensor_desc`.
//
// Each loop iteration handles one 1-D slice along `cum_dim`: it walks the
// slice from its first element, carrying the best value seen so far and its
// position, and writes both at every step.
//
// cum_type: 0 selects max, 1 selects min; any other value leaves the first
// element of the slice as the running result for the whole slice.
//
// NOTE(review): the comparisons are strict, so on ties the earliest index is
// kept, and a NaN candidate never replaces the running value — confirm this
// matches the behaviour expected from the framework op being mirrored.
template <typename scalar_t>
__global__ void cummaxmin_kernel(const scalar_t *input, scalar_t *output_value,
                                 int *output_index, TensorDesc tensor_desc,
                                 int cum_dim, int cum_type) {
  const size_t cum_size = tensor_desc.shape[cum_dim];
  const size_t cum_stride = tensor_desc.stride[cum_dim];
  // Number of independent slices: total element count / slice length.
  const size_t data_size =
      tensor_desc.stride[0] * tensor_desc.shape[0] / cum_size;

  CUDA_1D_KERNEL_LOOP(index, data_size) {
    // Offset of the first element of the slice handled in this iteration.
    size_t offset =
        index / cum_stride * (cum_size * cum_stride) + index % cum_stride;

    scalar_t best_value = input[offset];
    int best_index = 0;
    output_value[offset] = best_value;
    output_index[offset] = best_index;

    for (size_t step = 1; step < cum_size; ++step) {
      offset += cum_stride;
      const scalar_t candidate = input[offset];

      bool replace = false;
      if (cum_type == 0) {  // max
        replace = candidate > best_value;
      } else if (cum_type == 1) {  // min
        replace = candidate < best_value;
      }

      if (replace) {
        best_value = candidate;
        best_index = step;
      }

      output_value[offset] = best_value;
      output_index[offset] = best_index;
    }
  }
}
|
||||
|
||||
// Describes the input as a densely packed row-major tensor and launches
// cummaxmin_kernel on `stream`.
//
// dims / nbDims : shape of the input tensor
// cum_dim       : axis to scan; negative values count from the last axis
// cum_type      : 0 for cumulative max, 1 for cumulative min
//
// The function returns without launching anything when there is nothing
// valid to compute: a tensor with no dimensions, an axis outside
// [-nbDims, nbDims - 1], or a tensor with zero elements. Previously those
// cases indexed the shape array out of bounds, divided by zero, or launched
// a grid with zero blocks.
//
// NOTE(review): nbDims is assumed to fit the fixed-size arrays inside
// TensorDesc — confirm against its definition.
template <typename scalar_t>
void CumMaxMinForwardLauncher(const scalar_t *input, scalar_t *output_value,
                              int *output_index, const int *dims, int nbDims,
                              int cum_dim, int cum_type, cudaStream_t stream) {
  if (nbDims <= 0) {
    return;
  }

  // Fill the tensor description: shapes are copied, strides are derived for
  // a contiguous layout whose last axis has stride 1.
  TensorDesc tensor_desc;
  memset((void *)&tensor_desc, 0, sizeof(TensorDesc));
  tensor_desc.dim = nbDims;
  tensor_desc.shape[nbDims - 1] = dims[nbDims - 1];
  tensor_desc.stride[nbDims - 1] = 1;
  for (int i = nbDims - 2; i >= 0; --i) {
    tensor_desc.shape[i] = dims[i];
    tensor_desc.stride[i] = dims[i + 1] * tensor_desc.stride[i + 1];
  }

  // Map a negative axis onto its non-negative equivalent.
  cum_dim = cum_dim >= 0 ? cum_dim : (nbDims + cum_dim);
  if (cum_dim < 0 || cum_dim >= nbDims) {
    return;
  }

  // A zero here means some axis (possibly the scanned one) is empty, so the
  // division below would be by zero or the grid would have no blocks.
  const int num_elements = tensor_desc.stride[0] * tensor_desc.shape[0];
  if (num_elements <= 0) {
    return;
  }

  // One unit of work per 1-D slice along the scanned axis.
  const int data_size = num_elements / tensor_desc.shape[cum_dim];
  const int col_block = DIVUP(data_size, THREADS_PER_BLOCK);

  cummaxmin_kernel<scalar_t><<<col_block, THREADS_PER_BLOCK, 0, stream>>>(
      input, output_value, output_index, tensor_desc, cum_dim, cum_type);
}
|
||||
|
||||
// Non-template entry point for float32 tensors; forwards every argument to
// the templated launcher.
void CumMaxMinForwardLauncher_float(const float *input, float *output_value,
                                    int *output_index, const int *dims,
                                    int nbDims, int cum_dim, int cum_type,
                                    cudaStream_t stream) {
  CumMaxMinForwardLauncher<float>(input, output_value, output_index, dims,
                                  nbDims, cum_dim, cum_type, stream);
}
|
||||
|
||||
// Non-template entry point for int32 tensors; forwards every argument to
// the templated launcher.
void CumMaxMinForwardLauncher_int32(const int *input, int *output_value,
                                    int *output_index, const int *dims,
                                    int nbDims, int cum_dim, int cum_type,
                                    cudaStream_t stream) {
  CumMaxMinForwardLauncher<int>(input, output_value, output_index, dims,
                                nbDims, cum_dim, cum_type, stream);
}
|
|
@ -1,11 +1,14 @@
|
|||
#include "trt_plugin.hpp"
|
||||
|
||||
#include "trt_cummaxmin.hpp"
|
||||
#include "trt_deform_conv.hpp"
|
||||
#include "trt_grid_sampler.hpp"
|
||||
#include "trt_nms.hpp"
|
||||
#include "trt_roi_align.hpp"
|
||||
#include "trt_scatternd.hpp"
|
||||
|
||||
REGISTER_TENSORRT_PLUGIN(CumMaxPluginDynamicCreator);
|
||||
REGISTER_TENSORRT_PLUGIN(CumMinPluginDynamicCreator);
|
||||
REGISTER_TENSORRT_PLUGIN(GridSamplerDynamicCreator);
|
||||
REGISTER_TENSORRT_PLUGIN(DeformableConvPluginDynamicCreator);
|
||||
REGISTER_TENSORRT_PLUGIN(NonMaxSuppressionDynamicCreator);
|
||||
|
|
|
@ -0,0 +1,122 @@
|
|||
#ifndef TRT_CUMMAXMIN_HPP
|
||||
#define TRT_CUMMAXMIN_HPP
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "trt_plugin_helper.hpp"
|
||||
|
||||
// Comparison performed by a plugin instance. The numeric values are handed
// to the CUDA launchers as plain ints, so they must stay 0 (max) and 1 (min).
enum TRT_CUMCMPTYPE {
  TRT_CUMMAX = 0,
  TRT_CUMMIN = 1
};
|
||||
|
||||
// implement of cummax and cummin
|
||||
class CumMaxMinPluginDynamic : public nvinfer1::IPluginV2DynamicExt {
|
||||
public:
|
||||
CumMaxMinPluginDynamic(const std::string &name, int dim,
|
||||
TRT_CUMCMPTYPE cumType);
|
||||
|
||||
CumMaxMinPluginDynamic(const std::string name, const void *data,
|
||||
size_t length);
|
||||
|
||||
CumMaxMinPluginDynamic() = delete;
|
||||
|
||||
~CumMaxMinPluginDynamic();
|
||||
|
||||
// IPluginV2DynamicExt Methods
|
||||
nvinfer1::IPluginV2DynamicExt *clone() const override;
|
||||
nvinfer1::DimsExprs getOutputDimensions(
|
||||
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
|
||||
nvinfer1::IExprBuilder &exprBuilder) override;
|
||||
bool supportsFormatCombination(int pos,
|
||||
const nvinfer1::PluginTensorDesc *inOut,
|
||||
int nbInputs, int nbOutputs) override;
|
||||
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in,
|
||||
int nbInputs,
|
||||
const nvinfer1::DynamicPluginTensorDesc *out,
|
||||
int nbOutputs) override;
|
||||
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
|
||||
int nbInputs,
|
||||
const nvinfer1::PluginTensorDesc *outputs,
|
||||
int nbOutputs) const override;
|
||||
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
|
||||
const nvinfer1::PluginTensorDesc *outputDesc,
|
||||
const void *const *inputs, void *const *outputs, void *workspace,
|
||||
cudaStream_t stream) override;
|
||||
|
||||
// IPluginV2Ext Methods
|
||||
nvinfer1::DataType getOutputDataType(int index,
|
||||
const nvinfer1::DataType *inputTypes,
|
||||
int nbInputs) const override;
|
||||
|
||||
// IPluginV2 Methods
|
||||
const char *getPluginType() const override;
|
||||
const char *getPluginVersion() const override;
|
||||
int getNbOutputs() const override;
|
||||
int initialize() override;
|
||||
void terminate() override;
|
||||
size_t getSerializationSize() const override;
|
||||
void serialize(void *buffer) const override;
|
||||
void destroy() override;
|
||||
void setPluginNamespace(const char *pluginNamespace) override;
|
||||
const char *getPluginNamespace() const override;
|
||||
|
||||
protected:
|
||||
const std::string mLayerName;
|
||||
std::string mNamespace;
|
||||
|
||||
int mDim;
|
||||
TRT_CUMCMPTYPE mCumType;
|
||||
|
||||
protected:
|
||||
// To prevent compiler warnings.
|
||||
using nvinfer1::IPluginV2DynamicExt::canBroadcastInputAcrossBatch;
|
||||
using nvinfer1::IPluginV2DynamicExt::configurePlugin;
|
||||
using nvinfer1::IPluginV2DynamicExt::enqueue;
|
||||
using nvinfer1::IPluginV2DynamicExt::getOutputDimensions;
|
||||
using nvinfer1::IPluginV2DynamicExt::getWorkspaceSize;
|
||||
using nvinfer1::IPluginV2DynamicExt::isOutputBroadcastAcrossBatch;
|
||||
using nvinfer1::IPluginV2DynamicExt::supportsFormat;
|
||||
};
|
||||
|
||||
// cummax and cummin creator
|
||||
class CumMaxMinPluginDynamicCreator : public nvinfer1::IPluginCreator {
|
||||
public:
|
||||
CumMaxMinPluginDynamicCreator(TRT_CUMCMPTYPE cumType);
|
||||
|
||||
const char *getPluginName() const override;
|
||||
|
||||
const char *getPluginVersion() const override;
|
||||
|
||||
const nvinfer1::PluginFieldCollection *getFieldNames() override;
|
||||
|
||||
nvinfer1::IPluginV2 *createPlugin(
|
||||
const char *name, const nvinfer1::PluginFieldCollection *fc) override;
|
||||
|
||||
nvinfer1::IPluginV2 *deserializePlugin(const char *name,
|
||||
const void *serialData,
|
||||
size_t serialLength) override;
|
||||
|
||||
void setPluginNamespace(const char *pluginNamespace) override;
|
||||
|
||||
const char *getPluginNamespace() const override;
|
||||
|
||||
protected:
|
||||
TRT_CUMCMPTYPE mCumType;
|
||||
nvinfer1::PluginFieldCollection mFC;
|
||||
std::vector<nvinfer1::PluginField> mPluginAttributes;
|
||||
std::string mNamespace;
|
||||
};
|
||||
|
||||
// cummax creator
|
||||
class CumMaxPluginDynamicCreator : public CumMaxMinPluginDynamicCreator {
|
||||
public:
|
||||
CumMaxPluginDynamicCreator();
|
||||
const char *getPluginName() const override;
|
||||
};
|
||||
|
||||
// cummin creator
|
||||
class CumMinPluginDynamicCreator : public CumMaxMinPluginDynamicCreator {
|
||||
public:
|
||||
CumMinPluginDynamicCreator();
|
||||
const char *getPluginName() const override;
|
||||
};
|
||||
|
||||
#endif  // TRT_CUMMAXMIN_HPP
|
|
@ -238,7 +238,7 @@ class TRTWraper(torch.nn.Module):
|
|||
output_names should be the same as onnx model.
|
||||
"""
|
||||
|
||||
def __init__(self, engine, input_names, output_names):
|
||||
def __init__(self, engine, input_names=None, output_names=None):
|
||||
super(TRTWraper, self).__init__()
|
||||
self.engine = engine
|
||||
if isinstance(self.engine, str):
|
||||
|
@ -250,6 +250,11 @@ class TRTWraper(torch.nn.Module):
|
|||
self._register_state_dict_hook(TRTWraper._on_state_dict)
|
||||
self.context = self.engine.create_execution_context()
|
||||
|
||||
# get input and output names from engine
|
||||
if input_names is None or output_names is None:
|
||||
names = [_ for _ in self.engine]
|
||||
input_names = list(filter(self.engine.binding_is_input, names))
|
||||
output_names = list(set(names) - set(input_names))
|
||||
self.input_names = input_names
|
||||
self.output_names = output_names
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import os
|
||||
from functools import partial
|
||||
from typing import Callable
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
|
@ -478,3 +479,99 @@ def test_grid_sample(mode, padding_mode, align_corners):
|
|||
if os.path.exists(trt_file):
|
||||
os.remove(trt_file)
|
||||
assert torch.allclose(pytorch_results, trt_results)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('func', [torch.cummax, torch.cummin])
def test_cummin_cummax(func: Callable):
    """Compare the TensorRT cummax/cummin plugin with the PyTorch op.

    For every valid ``dim`` of a few input shapes, the wrapped ``func`` is
    exported to ONNX, converted to a TensorRT engine, run, and both outputs
    (values and indices) are compared with the PyTorch results.
    """
    # Note generally `cummax` or `cummin` is exportable to ONNX
    # as long as the pytorch version >= 1.5.0, since `torch.cummax`
    # is only supported with torch >= 1.5.0.
    # But when `cummax` or `cummin` serves as an intermediate component
    # whose outputs are used as inputs for other modules, it's expected
    # that pytorch version must be >= 1.7.0. Otherwise an error appears like:
    # `RuntimeError: tuple appears in op that does not forward tuples,
    # unsupported 'kind: prim::PythonOp`.
    from packaging import version
    if version.parse(torch.__version__) < version.parse('1.7.0'):
        pytest.skip('test_cummax_cummin should be ran with pytorch >= 1.7.0')

    opset = 11
    # register custom op `mmcv::cummax` and `mmcv::cummin`
    from mmcv.onnx.symbolic import register_extra_symbolics
    register_extra_symbolics(opset)

    input_list = [
        # arbitrary shape, e.g. 1-D, 2-D, 3-D, ...
        torch.rand((2, 3, 4, 1, 5)).cuda(),
        torch.rand((1)).cuda()
    ]

    input_names = ['input']
    output_names = ['output', 'indices']

    # trt config
    fp16_mode = False
    max_workspace_size = 1 << 30

    for input_tensor in input_list:
        ndims = input_tensor.dim()
        # valid dim range is [-ndims, ndims-1]
        # test for all `dim` value which is valid
        for dim in range(-ndims, ndims):
            cum_func = partial(func, dim=dim)
            wrapped_model = WrapFunction(cum_func).eval().cuda()

            try:
                with torch.no_grad():
                    torch.onnx.export(
                        wrapped_model,
                        input_tensor,
                        onnx_file,
                        export_params=True,
                        keep_initializers_as_inputs=False,
                        input_names=input_names,
                        output_names=output_names,
                        opset_version=opset)

                onnx_model = onnx.load(onnx_file)

                # create trt engine and wraper; min/opt/max shapes are all
                # the shape of the test input
                opt_shape_dict = {
                    'input': [
                        list(input_tensor.shape),
                        list(input_tensor.shape),
                        list(input_tensor.shape)
                    ]
                }

                trt_engine = onnx2trt(
                    onnx_model,
                    opt_shape_dict,
                    fp16_mode=fp16_mode,
                    max_workspace_size=max_workspace_size)

                # save TensorRT model
                save_trt_engine(trt_engine, trt_file)

                # load and wrap TensorRT model
                trt_model = TRTWraper(trt_file)
            finally:
                # remove the temporary ONNX / TensorRT files even when
                # export, conversion or loading fails
                for tmp_file in (onnx_file, trt_file):
                    if os.path.exists(tmp_file):
                        os.remove(tmp_file)

            # compute trt output
            with torch.no_grad():
                trt_results = trt_model(
                    {'input': input_tensor.contiguous().clone()})
                trt_output = trt_results['output']
                trt_indices = trt_results['indices']

            # compute pytorch output
            with torch.no_grad():
                pytorch_results = wrapped_model(input_tensor.clone())
                pytorch_output = pytorch_results[0]
                pytorch_indices = pytorch_results[1]

            torch.testing.assert_allclose(trt_output, pytorch_output)
            torch.testing.assert_allclose(trt_indices, pytorch_indices)
|
||||
|
|
Loading…
Reference in New Issue