diff --git a/docs/tensorrt_custom_ops.md b/docs/tensorrt_custom_ops.md
index da696f03e..7bf369cfb 100644
--- a/docs/tensorrt_custom_ops.md
+++ b/docs/tensorrt_custom_ops.md
@@ -33,6 +33,18 @@
- [Inputs](#inputs-4)
- [Outputs](#outputs-4)
- [Type Constraints](#type-constraints-4)
+ - [cummax](#cummax)
+ - [Description](#description-5)
+ - [Parameters](#parameters-5)
+ - [Inputs](#inputs-5)
+ - [Outputs](#outputs-5)
+ - [Type Constraints](#type-constraints-5)
+ - [cummin](#cummin)
+ - [Description](#description-6)
+ - [Parameters](#parameters-6)
+ - [Inputs](#inputs-6)
+ - [Outputs](#outputs-6)
+ - [Type Constraints](#type-constraints-6)
@@ -227,3 +239,67 @@ Perform sample from `input` with pixel locations from `grid`.
### Type Constraints
- T:tensor(float32, Linear)
+
+## cummax
+
+### Description
+
+Returns a namedtuple (`values`, `indices`) where `values` is the cumulative maximum of the elements of `input` along the dimension `dim`, and `indices` is the index location of each maximum value found in that dimension.
+
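+For intuition, here is a minimal PyTorch sketch of the same semantics (the TensorRT
+plugin is expected to mirror `torch.cummax`):
+
+```python
+import torch
+
+x = torch.tensor([1., 3., 2., 4., 0.])
+values, indices = torch.cummax(x, dim=0)
+# values:  tensor([1., 3., 3., 4., 4.])
+# indices: tensor([0, 1, 1, 3, 3])
+```
+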
+### Parameters
+
+| Type | Parameter | Description |
+| ----- | --------- | --------------------------------------- |
+| `int` | `dim` | The dimension to do the operation over. |
+
+### Inputs
+
+- inputs[0]: T
+  - The input tensor.
+
+### Outputs
+
+- outputs[0]: T
+  - Output values.
+- outputs[1]: tensor(int32, Linear)
+  - Output indices.
+
+### Type Constraints
+
+- T:tensor(float32, Linear)
+
+## cummin
+
+### Description
+
+Returns a namedtuple (`values`, `indices`) where `values` is the cumulative minimum of the elements of `input` along the dimension `dim`, and `indices` is the index location of each minimum value found in that dimension.
+
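+A matching sketch with `torch.cummin`:
+
+```python
+import torch
+
+x = torch.tensor([3., 1., 2., 0., 4.])
+values, indices = torch.cummin(x, dim=0)
+# values:  tensor([3., 1., 1., 0., 0.])
+# indices: tensor([0, 1, 1, 3, 3])
+```
+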
+### Parameters
+
+| Type | Parameter | Description |
+| ----- | --------- | --------------------------------------- |
+| `int` | `dim` | The dimension to do the operation over. |
+
+### Inputs
+
+- inputs[0]: T
+  - The input tensor.
+
+### Outputs
+
+- outputs[0]: T
+  - Output values.
+- outputs[1]: tensor(int32, Linear)
+  - Output indices.
+
+### Type Constraints
+
+- T:tensor(float32, Linear)
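+
+Both ops are produced by exporting `torch.cummax` / `torch.cummin` through the
+extra symbolics that mmcv registers. A minimal export sketch (assuming
+`register_extra_symbolics` from `mmcv.onnx.symbolic`, as exercised in
+`tests/test_ops/test_tensorrt.py`; `tmp.onnx` is a placeholder path):
+
+```python
+import torch
+
+from mmcv.onnx.symbolic import register_extra_symbolics
+
+opset = 11
+register_extra_symbolics(opset)  # registers mmcv::cummax and mmcv::cummin
+
+
+class CumMax(torch.nn.Module):
+
+    def forward(self, x):
+        return torch.cummax(x, dim=1)
+
+
+x = torch.rand(2, 3, 4)
+torch.onnx.export(CumMax().eval(), x, 'tmp.onnx', opset_version=opset)
+# the resulting ONNX model can then be built into a TensorRT engine with
+# mmcv.tensorrt.onnx2trt, as shown in tests/test_ops/test_tensorrt.py
+```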
diff --git a/docs/tensorrt_plugin.md b/docs/tensorrt_plugin.md
index 5ed62d1ba..63b530000 100644
--- a/docs/tensorrt_plugin.md
+++ b/docs/tensorrt_plugin.md
@@ -30,7 +30,9 @@ To ease the deployment of trained models with custom operators from `mmcv.ops` u
| ScatterND | [ScatterND](./tensorrt_custom_ops.md#scatternd) | 1.2.6 |
| NonMaxSuppression | [NonMaxSuppression](./tensorrt_custom_ops.md#nonmaxsuppression) | 1.3.0 |
| MMCVDeformConv2d | [MMCVDeformConv2d](./tensorrt_custom_ops.md#mmcvdeformconv2d) | 1.3.0 |
-| grid_sampler | [grid_sampler](./tensorrt_custom_ops.md#grid-sampler) | master |
+| grid_sampler | [grid_sampler](./tensorrt_custom_ops.md#grid-sampler) | 1.3.1 |
+| cummax | [cummax](./tensorrt_custom_ops.md#cummax) | master |
+| cummin | [cummin](./tensorrt_custom_ops.md#cummin) | master |
Notes
diff --git a/mmcv/ops/csrc/tensorrt/plugins/trt_cummaxmin.cpp b/mmcv/ops/csrc/tensorrt/plugins/trt_cummaxmin.cpp
new file mode 100644
index 000000000..2e920cfed
--- /dev/null
+++ b/mmcv/ops/csrc/tensorrt/plugins/trt_cummaxmin.cpp
@@ -0,0 +1,241 @@
+#include "trt_cummaxmin.hpp"
+
+#include <assert.h>
+
+#include "trt_serialize.hpp"
+
+void CumMaxMinForwardLauncher_float(const float *input, float *output_value,
+ int *output_index, const int *dims,
+ int nbDims, int cum_dim, int cum_type,
+ cudaStream_t stream);
+
+void CumMaxMinForwardLauncher_int32(const int *input, int *output_value,
+ int *output_index, const int *dims,
+ int nbDims, int cum_dim, int cum_type,
+ cudaStream_t stream);
+
+namespace {
+static const char *PLUGIN_VERSION{"1"};
+static const char *CUMMAXMIN_PLUGIN_NAME{"cummaxmin"};
+static const char *CUMMAX_PLUGIN_NAME{"cummax"};
+static const char *CUMMIN_PLUGIN_NAME{"cummin"};
+} // namespace
+
+CumMaxMinPluginDynamic::CumMaxMinPluginDynamic(const std::string &name, int dim,
+ TRT_CUMCMPTYPE cumType)
+ : mLayerName(name), mDim(dim), mCumType(cumType) {}
+
+CumMaxMinPluginDynamic::CumMaxMinPluginDynamic(const std::string name,
+ const void *data, size_t length)
+ : mLayerName(name) {
+ deserialize_value(&data, &length, &mDim);
+ deserialize_value(&data, &length, &mCumType);
+}
+
+CumMaxMinPluginDynamic::~CumMaxMinPluginDynamic() {}
+
+nvinfer1::IPluginV2DynamicExt *CumMaxMinPluginDynamic::clone() const {
+ CumMaxMinPluginDynamic *plugin =
+ new CumMaxMinPluginDynamic(mLayerName, mDim, mCumType);
+ plugin->setPluginNamespace(getPluginNamespace());
+
+ return plugin;
+}
+
+nvinfer1::DimsExprs CumMaxMinPluginDynamic::getOutputDimensions(
+ int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
+ nvinfer1::IExprBuilder &exprBuilder) {
+ return inputs[0];
+}
+
+bool CumMaxMinPluginDynamic::supportsFormatCombination(
+ int pos, const nvinfer1::PluginTensorDesc *inOut, int nbInputs,
+ int nbOutputs) {
+ switch (pos) {
+ // input[0]
+ case 0:
+ return (inOut[pos].type == nvinfer1::DataType::kFLOAT ||
+ inOut[pos].type == nvinfer1::DataType::kINT32) &&
+ inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
+ // output[0]
+ case 1:
+ return inOut[pos].type == inOut[0].type &&
+ inOut[pos].format == inOut[0].format;
+ // output[1]
+ case 2:
+ return inOut[pos].type == nvinfer1::DataType::kINT32 &&
+ inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
+ default:
+ return false;
+ }
+}
+
+void CumMaxMinPluginDynamic::configurePlugin(
+ const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
+ const nvinfer1::DynamicPluginTensorDesc *outputs, int nbOutputs) {}
+
+size_t CumMaxMinPluginDynamic::getWorkspaceSize(
+ const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
+ const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const {
+  return 0;  // no extra device workspace is required
+}
+
+int CumMaxMinPluginDynamic::enqueue(
+ const nvinfer1::PluginTensorDesc *inputDesc,
+ const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
+ void *const *outputs, void *workSpace, cudaStream_t stream) {
+ const void *input = inputs[0];
+ void *output_value = outputs[0];
+ int *output_index = (int *)outputs[1];
+
+ const int *dims = &(inputDesc[0].dims.d[0]);
+ int nbDims = inputDesc[0].dims.nbDims;
+
+ switch (inputDesc[0].type) {
+ case nvinfer1::DataType::kFLOAT:
+ CumMaxMinForwardLauncher_float((float *)input, (float *)output_value,
+ output_index, dims, nbDims, mDim,
+ int(mCumType), stream);
+ break;
+ case nvinfer1::DataType::kINT32:
+ CumMaxMinForwardLauncher_int32((int *)input, (int *)output_value,
+ output_index, dims, nbDims, mDim,
+ int(mCumType), stream);
+ break;
+ default:
+ break;
+ }
+
+ return 0;
+}
+
+nvinfer1::DataType CumMaxMinPluginDynamic::getOutputDataType(
+ int index, const nvinfer1::DataType *inputTypes, int nbInputs) const {
+  switch (index) {
+    case 0:
+      return inputTypes[0];
+    case 1:
+      return nvinfer1::DataType::kINT32;
+    default:
+      // falling off the end of a non-void function is undefined behavior
+      return inputTypes[0];
+  }
+}
+
+// IPluginV2 Methods
+const char *CumMaxMinPluginDynamic::getPluginType() const {
+ switch (mCumType) {
+ case TRT_CUMCMPTYPE::TRT_CUMMAX:
+ return CUMMAX_PLUGIN_NAME;
+ case TRT_CUMCMPTYPE::TRT_CUMMIN:
+ return CUMMIN_PLUGIN_NAME;
+ default:
+ return "UnknownCumType";
+ }
+}
+
+const char *CumMaxMinPluginDynamic::getPluginVersion() const {
+ return PLUGIN_VERSION;
+}
+
+int CumMaxMinPluginDynamic::getNbOutputs() const { return 2; }
+
+int CumMaxMinPluginDynamic::initialize() { return 0; }
+
+void CumMaxMinPluginDynamic::terminate() {}
+
+size_t CumMaxMinPluginDynamic::getSerializationSize() const {
+ return sizeof(mDim) + sizeof(mCumType);
+}
+
+void CumMaxMinPluginDynamic::serialize(void *buffer) const {
+ serialize_value(&buffer, mDim);
+ serialize_value(&buffer, mCumType);
+}
+
+void CumMaxMinPluginDynamic::destroy() {
+  // This gets called when the network containing the plugin is destroyed
+ delete this;
+}
+
+void CumMaxMinPluginDynamic::setPluginNamespace(const char *libNamespace) {
+ mNamespace = libNamespace;
+}
+
+const char *CumMaxMinPluginDynamic::getPluginNamespace() const {
+ return mNamespace.c_str();
+}
+
+CumMaxMinPluginDynamicCreator::CumMaxMinPluginDynamicCreator(
+ TRT_CUMCMPTYPE cumType)
+ : mCumType(cumType) {
+ mPluginAttributes.clear();
+ mPluginAttributes.emplace_back(nvinfer1::PluginField("dim"));
+ mFC.nbFields = mPluginAttributes.size();
+ mFC.fields = mPluginAttributes.data();
+}
+
+const char *CumMaxMinPluginDynamicCreator::getPluginName() const {
+ return CUMMAXMIN_PLUGIN_NAME;
+}
+
+const char *CumMaxMinPluginDynamicCreator::getPluginVersion() const {
+ return PLUGIN_VERSION;
+}
+
+const nvinfer1::PluginFieldCollection *
+CumMaxMinPluginDynamicCreator::getFieldNames() {
+ return &mFC;
+}
+
+nvinfer1::IPluginV2 *CumMaxMinPluginDynamicCreator::createPlugin(
+ const char *name, const nvinfer1::PluginFieldCollection *fc) {
+ int dim = 0;
+
+ for (int i = 0; i < fc->nbFields; i++) {
+ if (fc->fields[i].data == nullptr) {
+ continue;
+ }
+ std::string field_name(fc->fields[i].name);
+
+ if (field_name.compare("dim") == 0) {
+      dim = static_cast<const int *>(fc->fields[i].data)[0];
+ }
+ }
+
+ CumMaxMinPluginDynamic *plugin =
+ new CumMaxMinPluginDynamic(name, dim, mCumType);
+ plugin->setPluginNamespace(getPluginNamespace());
+ return plugin;
+}
+
+nvinfer1::IPluginV2 *CumMaxMinPluginDynamicCreator::deserializePlugin(
+ const char *name, const void *serialData, size_t serialLength) {
+  // This object will be deleted when the network is destroyed, which will
+  // call CumMaxMinPluginDynamic::destroy()
+ auto plugin = new CumMaxMinPluginDynamic(name, serialData, serialLength);
+ plugin->setPluginNamespace(getPluginNamespace());
+ return plugin;
+}
+
+void CumMaxMinPluginDynamicCreator::setPluginNamespace(
+ const char *libNamespace) {
+ mNamespace = libNamespace;
+}
+
+const char *CumMaxMinPluginDynamicCreator::getPluginNamespace() const {
+ return mNamespace.c_str();
+}
+
+CumMaxPluginDynamicCreator::CumMaxPluginDynamicCreator()
+ : CumMaxMinPluginDynamicCreator(TRT_CUMCMPTYPE::TRT_CUMMAX) {}
+
+const char *CumMaxPluginDynamicCreator::getPluginName() const {
+ return CUMMAX_PLUGIN_NAME;
+}
+
+CumMinPluginDynamicCreator::CumMinPluginDynamicCreator()
+ : CumMaxMinPluginDynamicCreator(TRT_CUMCMPTYPE::TRT_CUMMIN) {}
+
+const char *CumMinPluginDynamicCreator::getPluginName() const {
+ return CUMMIN_PLUGIN_NAME;
+}
diff --git a/mmcv/ops/csrc/tensorrt/plugins/trt_cummaxmin_kernel.cu b/mmcv/ops/csrc/tensorrt/plugins/trt_cummaxmin_kernel.cu
new file mode 100644
index 000000000..753104071
--- /dev/null
+++ b/mmcv/ops/csrc/tensorrt/plugins/trt_cummaxmin_kernel.cu
@@ -0,0 +1,89 @@
+
+#include "common_cuda_helper.hpp"
+#include "trt_cuda_helper.cuh"
+#include "trt_plugin_helper.hpp"
+
+using mmcv::TensorDesc;
+
+template <typename scalar_t>
+__global__ void cummaxmin_kernel(const scalar_t *input, scalar_t *output_value,
+ int *output_index, TensorDesc tensor_desc,
+ int cum_dim, int cum_type) {
+ const size_t cum_size = tensor_desc.shape[cum_dim];
+ const size_t cum_stride = tensor_desc.stride[cum_dim];
+ const size_t data_size =
+ tensor_desc.stride[0] * tensor_desc.shape[0] / cum_size;
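+  // the tensor decomposes into data_size independent 1-D slices along
+  // cum_dim; each thread below scans one slice sequentially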
+ CUDA_1D_KERNEL_LOOP(index, data_size) {
+ size_t cum_offset =
+ index / cum_stride * (cum_size * cum_stride) + index % cum_stride;
+ int cum_index = 0;
+ auto cum_value = input[cum_offset];
+ output_value[cum_offset] = cum_value;
+ output_index[cum_offset] = cum_index;
+
+ for (size_t cum_index_current = 1; cum_index_current < cum_size;
+ ++cum_index_current) {
+ cum_offset += cum_stride;
+ const auto cum_value_current = input[cum_offset];
+ switch (cum_type) {
+ case 0: // max
+ if (cum_value_current > cum_value) {
+ cum_value = cum_value_current;
+ cum_index = cum_index_current;
+ }
+ break;
+ case 1: // min
+ if (cum_value_current < cum_value) {
+ cum_value = cum_value_current;
+ cum_index = cum_index_current;
+ }
+ break;
+ }
+ output_value[cum_offset] = cum_value;
+ output_index[cum_offset] = cum_index;
+ }
+ }
+}
+
+template <typename scalar_t>
+void CumMaxMinForwardLauncher(const scalar_t *input, scalar_t *output_value,
+ int *output_index, const int *dims, int nbDims,
+ int cum_dim, int cum_type, cudaStream_t stream) {
+  // fill the tensor descriptor (shape and stride) from the raw dims
+ TensorDesc tensor_desc;
+ memset((void *)&tensor_desc, 0, sizeof(TensorDesc));
+ tensor_desc.dim = nbDims;
+ tensor_desc.shape[nbDims - 1] = dims[nbDims - 1];
+ tensor_desc.stride[nbDims - 1] = 1;
+ for (int i = nbDims - 2; i >= 0; --i) {
+ tensor_desc.shape[i] = dims[i];
+ tensor_desc.stride[i] = dims[i + 1] * tensor_desc.stride[i + 1];
+ }
+
+  // convert a negative cum_dim to its non-negative equivalent
+ cum_dim = cum_dim >= 0 ? cum_dim : (nbDims + cum_dim);
+
+ const int data_size =
+ tensor_desc.stride[0] * tensor_desc.shape[0] / tensor_desc.shape[cum_dim];
+
+ const int col_block = DIVUP(data_size, THREADS_PER_BLOCK);
+
+  cummaxmin_kernel<scalar_t><<<col_block, THREADS_PER_BLOCK, 0, stream>>>(
+ input, output_value, output_index, tensor_desc, cum_dim, cum_type);
+}
+
+void CumMaxMinForwardLauncher_float(const float *input, float *output_value,
+ int *output_index, const int *dims,
+ int nbDims, int cum_dim, int cum_type,
+ cudaStream_t stream) {
+  CumMaxMinForwardLauncher<float>(input, output_value, output_index, dims,
+                                  nbDims, cum_dim, cum_type, stream);
+}
+
+void CumMaxMinForwardLauncher_int32(const int *input, int *output_value,
+ int *output_index, const int *dims,
+ int nbDims, int cum_dim, int cum_type,
+ cudaStream_t stream) {
+  CumMaxMinForwardLauncher<int>(input, output_value, output_index, dims,
+                                nbDims, cum_dim, cum_type, stream);
+}
diff --git a/mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp b/mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp
index 06d034c36..ab4ee11e8 100644
--- a/mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp
+++ b/mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp
@@ -1,11 +1,14 @@
#include "trt_plugin.hpp"
+#include "trt_cummaxmin.hpp"
#include "trt_deform_conv.hpp"
#include "trt_grid_sampler.hpp"
#include "trt_nms.hpp"
#include "trt_roi_align.hpp"
#include "trt_scatternd.hpp"
+REGISTER_TENSORRT_PLUGIN(CumMaxPluginDynamicCreator);
+REGISTER_TENSORRT_PLUGIN(CumMinPluginDynamicCreator);
REGISTER_TENSORRT_PLUGIN(GridSamplerDynamicCreator);
REGISTER_TENSORRT_PLUGIN(DeformableConvPluginDynamicCreator);
REGISTER_TENSORRT_PLUGIN(NonMaxSuppressionDynamicCreator);
diff --git a/mmcv/ops/csrc/tensorrt/trt_cummaxmin.hpp b/mmcv/ops/csrc/tensorrt/trt_cummaxmin.hpp
new file mode 100644
index 000000000..5b856b02f
--- /dev/null
+++ b/mmcv/ops/csrc/tensorrt/trt_cummaxmin.hpp
@@ -0,0 +1,122 @@
+#ifndef TRT_CUMMAXMIN_HPP
+#define TRT_CUMMAXMIN_HPP
+#include <string>
+#include <vector>
+
+#include "trt_plugin_helper.hpp"
+
+enum TRT_CUMCMPTYPE { TRT_CUMMAX = 0, TRT_CUMMIN = 1 };
+
+// implementation of the cummax and cummin TensorRT plugin
+class CumMaxMinPluginDynamic : public nvinfer1::IPluginV2DynamicExt {
+ public:
+ CumMaxMinPluginDynamic(const std::string &name, int dim,
+ TRT_CUMCMPTYPE cumType);
+
+ CumMaxMinPluginDynamic(const std::string name, const void *data,
+ size_t length);
+
+ CumMaxMinPluginDynamic() = delete;
+
+ ~CumMaxMinPluginDynamic();
+
+ // IPluginV2DynamicExt Methods
+ nvinfer1::IPluginV2DynamicExt *clone() const override;
+ nvinfer1::DimsExprs getOutputDimensions(
+ int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
+ nvinfer1::IExprBuilder &exprBuilder) override;
+ bool supportsFormatCombination(int pos,
+ const nvinfer1::PluginTensorDesc *inOut,
+ int nbInputs, int nbOutputs) override;
+ void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in,
+ int nbInputs,
+ const nvinfer1::DynamicPluginTensorDesc *out,
+ int nbOutputs) override;
+ size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
+ int nbInputs,
+ const nvinfer1::PluginTensorDesc *outputs,
+ int nbOutputs) const override;
+ int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
+ const nvinfer1::PluginTensorDesc *outputDesc,
+ const void *const *inputs, void *const *outputs, void *workspace,
+ cudaStream_t stream) override;
+
+ // IPluginV2Ext Methods
+ nvinfer1::DataType getOutputDataType(int index,
+ const nvinfer1::DataType *inputTypes,
+ int nbInputs) const override;
+
+ // IPluginV2 Methods
+ const char *getPluginType() const override;
+ const char *getPluginVersion() const override;
+ int getNbOutputs() const override;
+ int initialize() override;
+ void terminate() override;
+ size_t getSerializationSize() const override;
+ void serialize(void *buffer) const override;
+ void destroy() override;
+ void setPluginNamespace(const char *pluginNamespace) override;
+ const char *getPluginNamespace() const override;
+
+ protected:
+ const std::string mLayerName;
+ std::string mNamespace;
+
+ int mDim;
+ TRT_CUMCMPTYPE mCumType;
+
+ protected:
+ // To prevent compiler warnings.
+ using nvinfer1::IPluginV2DynamicExt::canBroadcastInputAcrossBatch;
+ using nvinfer1::IPluginV2DynamicExt::configurePlugin;
+ using nvinfer1::IPluginV2DynamicExt::enqueue;
+ using nvinfer1::IPluginV2DynamicExt::getOutputDimensions;
+ using nvinfer1::IPluginV2DynamicExt::getWorkspaceSize;
+ using nvinfer1::IPluginV2DynamicExt::isOutputBroadcastAcrossBatch;
+ using nvinfer1::IPluginV2DynamicExt::supportsFormat;
+};
+
+// cummax and cummin creator
+class CumMaxMinPluginDynamicCreator : public nvinfer1::IPluginCreator {
+ public:
+ CumMaxMinPluginDynamicCreator(TRT_CUMCMPTYPE cumType);
+
+ const char *getPluginName() const override;
+
+ const char *getPluginVersion() const override;
+
+ const nvinfer1::PluginFieldCollection *getFieldNames() override;
+
+ nvinfer1::IPluginV2 *createPlugin(
+ const char *name, const nvinfer1::PluginFieldCollection *fc) override;
+
+ nvinfer1::IPluginV2 *deserializePlugin(const char *name,
+ const void *serialData,
+ size_t serialLength) override;
+
+ void setPluginNamespace(const char *pluginNamespace) override;
+
+ const char *getPluginNamespace() const override;
+
+ protected:
+ TRT_CUMCMPTYPE mCumType;
+ nvinfer1::PluginFieldCollection mFC;
+  std::vector<nvinfer1::PluginField> mPluginAttributes;
+ std::string mNamespace;
+};
+
+// cummax creator
+class CumMaxPluginDynamicCreator : public CumMaxMinPluginDynamicCreator {
+ public:
+ CumMaxPluginDynamicCreator();
+ const char *getPluginName() const override;
+};
+
+// cummin creator
+class CumMinPluginDynamicCreator : public CumMaxMinPluginDynamicCreator {
+ public:
+ CumMinPluginDynamicCreator();
+ const char *getPluginName() const override;
+};
+
+#endif  // TRT_CUMMAXMIN_HPP
diff --git a/mmcv/tensorrt/tensorrt_utils.py b/mmcv/tensorrt/tensorrt_utils.py
index 5966881df..22e1e8860 100644
--- a/mmcv/tensorrt/tensorrt_utils.py
+++ b/mmcv/tensorrt/tensorrt_utils.py
@@ -238,7 +238,7 @@ class TRTWraper(torch.nn.Module):
     output_names should be the same as onnx model.
     """
-    def __init__(self, engine, input_names, output_names):
+    def __init__(self, engine, input_names=None, output_names=None):
         super(TRTWraper, self).__init__()
         self.engine = engine
         if isinstance(self.engine, str):
@@ -250,6 +250,11 @@ class TRTWraper(torch.nn.Module):
         self._register_state_dict_hook(TRTWraper._on_state_dict)
         self.context = self.engine.create_execution_context()
+        # if names are not given, infer them from the engine bindings
+        if input_names is None or output_names is None:
+            names = [name for name in self.engine]
+            input_names = list(filter(self.engine.binding_is_input, names))
+            output_names = [name for name in names if name not in input_names]
         self.input_names = input_names
         self.output_names = output_names
diff --git a/tests/test_ops/test_tensorrt.py b/tests/test_ops/test_tensorrt.py
index 3f8fe473c..ddfa68165 100644
--- a/tests/test_ops/test_tensorrt.py
+++ b/tests/test_ops/test_tensorrt.py
@@ -1,5 +1,6 @@
import os
from functools import partial
+from typing import Callable
import numpy as np
import onnx
@@ -478,3 +479,99 @@ def test_grid_sample(mode, padding_mode, align_corners):
     if os.path.exists(trt_file):
         os.remove(trt_file)
     assert torch.allclose(pytorch_results, trt_results)
+
+
+@pytest.mark.parametrize('func', [torch.cummax, torch.cummin])
+def test_cummin_cummax(func: Callable):
+    # Note: `cummax` and `cummin` are exportable to ONNX as long as the
+    # pytorch version is >= 1.5.0, since `torch.cummax` was only added
+    # in torch 1.5.0. But when `cummax` or `cummin` serves as an
+    # intermediate component whose outputs are used as inputs to other
+    # modules, the pytorch version must be >= 1.7.0. Otherwise an error
+    # appears like:
+    # `RuntimeError: tuple appears in op that does not forward tuples,
+    # unsupported kind: prim::PythonOp`.
+    from packaging import version
+    if version.parse(torch.__version__) < version.parse('1.7.0'):
+        pytest.skip('test_cummax_cummin should be run with pytorch >= 1.7.0')
+
+    opset = 11
+    # register custom ops `mmcv::cummax` and `mmcv::cummin`
+    from mmcv.onnx.symbolic import register_extra_symbolics
+    register_extra_symbolics(opset)
+
+    input_list = [
+        # arbitrary shapes, e.g. 1-D, 2-D, 3-D, ...
+        torch.rand((2, 3, 4, 1, 5)).cuda(),
+        torch.rand((1)).cuda()
+    ]
+
+    input_names = ['input']
+    output_names = ['output', 'indices']
+
+    for input in input_list:
+        ndims = input.dim()
+        # valid dim range is [-ndims, ndims-1],
+        # so test every valid `dim` value
+        for dim in range(-ndims, ndims):
+            cummax_func = partial(func, dim=dim)
+            wrapped_model = WrapFunction(cummax_func).eval().cuda()
+
+            with torch.no_grad():
+                torch.onnx.export(
+                    wrapped_model,
+                    input,
+                    onnx_file,
+                    export_params=True,
+                    keep_initializers_as_inputs=False,
+                    input_names=input_names,
+                    output_names=output_names,
+                    opset_version=opset)
+
+            onnx_model = onnx.load(onnx_file)
+
+            # create trt engine and wraper
+            opt_shape_dict = {
+                'input':
+                [list(input.shape),
+                 list(input.shape),
+                 list(input.shape)]
+            }
+            # trt config
+            fp16_mode = False
+            max_workspace_size = 1 << 30
+
+            trt_engine = onnx2trt(
+                onnx_model,
+                opt_shape_dict,
+                fp16_mode=fp16_mode,
+                max_workspace_size=max_workspace_size)
+
+            # remove the ONNX model after conversion
+            if os.path.exists(onnx_file):
+                os.remove(onnx_file)
+
+            # save the TensorRT model
+            save_trt_engine(trt_engine, trt_file)
+
+            # load and wrap the TensorRT model
+            trt_model = TRTWraper(trt_file)
+
+            # remove the trt model after loading
+            if os.path.exists(trt_file):
+                os.remove(trt_file)
+
+            # compute trt output
+            with torch.no_grad():
+                trt_results = trt_model({'input': input.contiguous().clone()})
+                trt_output = trt_results['output']
+                trt_indices = trt_results['indices']
+
+            # compute pytorch output
+            with torch.no_grad():
+                pytorch_results = wrapped_model(input.clone())
+                pytorch_output = pytorch_results[0]
+                pytorch_indices = pytorch_results[1]
+
+            torch.testing.assert_allclose(trt_output, pytorch_output)
+            torch.testing.assert_allclose(trt_indices, pytorch_indices)