[Enhancement] Support RotatedRetinaNet TensorRT (#422)

* add rotated nms trt plugin

* fix ops output shape

* rebase

* fix lint

* add fp16, benchmark

* format docs

* remove unused definition, add ut

* add docs

* update docs

* add doc
pull/473/head
q.yao 2022-05-24 10:34:22 +08:00 committed by GitHub
parent 4710ab910d
commit e3a8baac4c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 1308 additions and 181 deletions

View File

@ -0,0 +1,3 @@
_base_ = ['./rotated-detection_tensorrt_dynamic-320x320-1024x1024.py']
backend_config = dict(common_config=dict(fp16_mode=True))

View File

@ -0,0 +1,32 @@
_base_ = ['./rotated-detection_static.py', '../_base_/backends/tensorrt.py']
onnx_config = dict(
output_names=['dets', 'labels'],
input_shape=None,
dynamic_axes={
'input': {
0: 'batch',
2: 'height',
3: 'width'
},
'dets': {
0: 'batch',
1: 'num_dets',
},
'labels': {
0: 'batch',
1: 'num_dets',
},
},
)
backend_config = dict(
common_config=dict(max_workspace_size=1 << 30),
model_inputs=[
dict(
input_shapes=dict(
input=dict(
min_shape=[1, 3, 320, 320],
opt_shape=[1, 3, 1024, 1024],
max_shape=[1, 3, 1024, 1024])))
])

View File

@ -1,118 +0,0 @@
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
// modify from
// https://github.com/NVIDIA/TensorRT/tree/master/plugin/batchedNMSPlugin
#include <vector>
#include "kernel.h"
#include "trt_plugin_helper.hpp"
// Gathers the first keepTopK sorted candidates of every image into the final
// outputs. Each output row holds 5 values: four box coordinates followed by
// the score; nmsedLabels receives the class id.
//
// Candidates whose index is -1 are emitted as an all-zero row with label -1.
// Work is distributed with a strided loop over numImages * keepTopK rows, so
// any 1-D grid with nthds_per_cta threads per block covers the whole output.
// Nothing is written when keepTopK > topK.
template <typename T_BBOX, typename T_SCORE, unsigned nthds_per_cta>
__launch_bounds__(nthds_per_cta) __global__
    void gatherNMSOutputs_kernel(const bool shareLocation, const int numImages,
                                 const int numPredsPerClass, const int numClasses, const int topK,
                                 const int keepTopK, const int *indices, const T_SCORE *scores,
                                 const T_BBOX *bboxData, T_BBOX *nmsedDets, int *nmsedLabels,
                                 bool clipBoxes) {
  if (keepTopK > topK) return;

  for (int i = blockIdx.x * nthds_per_cta + threadIdx.x; i < numImages * keepTopK;
       i += gridDim.x * nthds_per_cta) {
    const int imgId = i / keepTopK;
    const int detId = i % keepTopK;

    // Position of this candidate inside the sorted index/score arrays.
    const int sortedPos = imgId * numClasses * topK + detId;
    const int index = indices[sortedPos];
    const T_SCORE score = scores[sortedPos];

    T_BBOX *det = nmsedDets + i * 5;

    if (index == -1) {
      // Empty slot: zero box, zero score, label -1.
      nmsedLabels[i] = -1;
      for (int c = 0; c < 5; ++c) {
        det[c] = 0;
      }
      continue;
    }

    // Number of boxes stored per image depends on whether classes share boxes.
    const int boxesPerImage =
        shareLocation ? numPredsPerClass : (numClasses * numPredsPerClass);
    const int bboxOffset = imgId * boxesPerImage;
    const int bboxId = ((index % boxesPerImage) + bboxOffset) * 4;

    nmsedLabels[i] = (index % (numClasses * numPredsPerClass)) / numPredsPerClass;  // label

    // xmin, ymin, xmax, ymax; optionally clamped to [0, 1].
    for (int c = 0; c < 4; ++c) {
      det[c] = clipBoxes ? max(min(bboxData[bboxId + c], T_BBOX(1.)), T_BBOX(0.))
                         : bboxData[bboxId + c];
    }
    det[4] = score;
  }
}
// Host-side launcher for gatherNMSOutputs_kernel on the given stream.
// Uses a fixed 32x32 launch; the kernel's strided loop handles any output size.
// Returns STATUS_SUCCESS after the launch; a launch error reported by
// cudaGetLastError() is handled by the CSC macro with STATUS_FAILURE.
template <typename T_BBOX, typename T_SCORE>
pluginStatus_t gatherNMSOutputs_gpu(cudaStream_t stream, const bool shareLocation,
                                    const int numImages, const int numPredsPerClass,
                                    const int numClasses, const int topK, const int keepTopK,
                                    const void *indices, const void *scores, const void *bboxData,
                                    void *nmsedDets, void *nmsedLabels, bool clipBoxes) {
  const int kThreadsPerBlock = 32;
  const int kNumBlocks = 32;

  // Give the type-erased buffers their element types.
  const int *typedIndices = static_cast<const int *>(indices);
  const T_SCORE *typedScores = static_cast<const T_SCORE *>(scores);
  const T_BBOX *typedBoxes = static_cast<const T_BBOX *>(bboxData);
  T_BBOX *typedDets = static_cast<T_BBOX *>(nmsedDets);
  int *typedLabels = static_cast<int *>(nmsedLabels);

  gatherNMSOutputs_kernel<T_BBOX, T_SCORE, kThreadsPerBlock>
      <<<kNumBlocks, kThreadsPerBlock, 0, stream>>>(shareLocation, numImages, numPredsPerClass,
                                                    numClasses, topK, keepTopK, typedIndices,
                                                    typedScores, typedBoxes, typedDets,
                                                    typedLabels, clipBoxes);

  CSC(cudaGetLastError(), STATUS_FAILURE);
  return STATUS_SUCCESS;
}
// gatherNMSOutputs LAUNCH CONFIG {{{
typedef pluginStatus_t (*nmsOutFunc)(cudaStream_t, const bool, const int, const int, const int,
const int, const int, const void *, const void *, const void *,
void *, void *, bool);
// One entry of the gatherNMSOutputs dispatch table: a (bbox type, score type)
// key together with the launcher instantiated for that pair.
struct nmsOutLaunchConfig {
  DataType t_bbox;
  DataType t_score;
  nmsOutFunc function;

  // Key-only constructor used for lookups. `function` is set to nullptr so a
  // lookup key never carries an indeterminate pointer.
  nmsOutLaunchConfig(DataType t_bbox, DataType t_score)
      : t_bbox(t_bbox), t_score(t_score), function(nullptr) {}

  // Full table entry.
  nmsOutLaunchConfig(DataType t_bbox, DataType t_score, nmsOutFunc function)
      : t_bbox(t_bbox), t_score(t_score), function(function) {}

  // Two configs match when their data types match; `function` is not compared.
  // Const-qualified so it can be used on const objects and references.
  bool operator==(const nmsOutLaunchConfig &other) const {
    return t_bbox == other.t_bbox && t_score == other.t_score;
  }
};
using nvinfer1::DataType;
static std::vector<nmsOutLaunchConfig> nmsOutFuncVec;
// Registers the available gatherNMSOutputs launchers in nmsOutFuncVec.
// Only the float/float combination is registered. Always returns true so the
// result can initialise a static flag.
bool nmsOutputInit() {
  const nmsOutFunc floatLauncher = gatherNMSOutputs_gpu<float, float>;
  const nmsOutLaunchConfig floatEntry(DataType::kFLOAT, DataType::kFLOAT, floatLauncher);
  nmsOutFuncVec.push_back(floatEntry);
  return true;
}
static bool initialized = nmsOutputInit();
//}}}
// Type-dispatching entry point: looks up the launcher registered for
// (DT_BBOX, DT_SCORE) and forwards all arguments to it.
// Returns the launcher's status, or STATUS_BAD_PARAM when no launcher is
// registered for the requested type pair.
pluginStatus_t gatherNMSOutputs(cudaStream_t stream, const bool shareLocation, const int numImages,
                                const int numPredsPerClass, const int numClasses, const int topK,
                                const int keepTopK, const DataType DT_BBOX, const DataType DT_SCORE,
                                const void *indices, const void *scores, const void *bboxData,
                                void *nmsedDets, void *nmsedLabels, bool clipBoxes) {
  nmsOutLaunchConfig lc(DT_BBOX, DT_SCORE);

  for (unsigned i = 0; i < nmsOutFuncVec.size(); ++i) {
    if (!(lc == nmsOutFuncVec[i])) {
      continue;
    }
    DEBUG_PRINTF("gatherNMSOutputs kernel %d\n", i);
    return nmsOutFuncVec[i].function(stream, shareLocation, numImages, numPredsPerClass,
                                     numClasses, topK, keepTopK, indices, scores, bboxData,
                                     nmsedDets, nmsedLabels, clipBoxes);
  }

  return STATUS_BAD_PARAM;
}

View File

@ -5,8 +5,8 @@
#include <cstring>
#include "kernel.h"
#include "trt_batched_nms_kernel.hpp"
#include "nms/batched_nms_kernel.hpp"
#include "nms/kernel.h"
#include "trt_serialize.hpp"
namespace mmdeploy {
@ -90,11 +90,12 @@ int TRTBatchedNMS::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
int topk =
param.topK > 0 && param.topK <= inputDesc[1].dims.d[1] ? param.topK : inputDesc[1].dims.d[1];
bool rotated = false;
pluginStatus_t status = nmsInference(
stream, batch_size, boxes_size, score_size, shareLocation, param.backgroundLabelId,
num_priors, param.numClasses, topk, param.keepTopK, param.scoreThreshold, param.iouThreshold,
DataType::kFLOAT, locData, DataType::kFLOAT, confData, nmsedDets, nmsedLabels, workSpace,
param.isNormalized, false, mClipBoxes);
param.isNormalized, false, mClipBoxes, rotated);
ASSERT(status == STATUS_SUCCESS);
return 0;

View File

@ -0,0 +1,238 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "trt_batched_rotated_nms.hpp"
#include <cstring>
#include "nms/batched_nms_kernel.hpp"
#include "nms/kernel.h"
#include "trt_serialize.hpp"
namespace mmdeploy {
using namespace nvinfer1;
using nvinfer1::plugin::NMSParameters;
namespace {
static const char* NMS_PLUGIN_VERSION{"1"};
static const char* NMS_PLUGIN_NAME{"TRTBatchedRotatedNMS"};
} // namespace
// Builds a plugin from explicit NMS parameters. The remaining members keep
// their in-class default values.
TRTBatchedRotatedNMS::TRTBatchedRotatedNMS(const std::string& name, NMSParameters params)
: TRTPluginBase(name), param(params) {}
// Rebuilds a plugin from a serialized buffer.
// The fields are read in the same order in which serialize() writes them:
// param, boxesSize, scoresSize, numPriors, mClipBoxes.
TRTBatchedRotatedNMS::TRTBatchedRotatedNMS(const std::string& name, const void* data, size_t length)
: TRTPluginBase(name) {
deserialize_value(&data, &length, &param);
deserialize_value(&data, &length, &boxesSize);
deserialize_value(&data, &length, &scoresSize);
deserialize_value(&data, &length, &numPriors);
deserialize_value(&data, &length, &mClipBoxes);
}
// The plugin has two outputs: index 0 (detections) and index 1 (labels).
int TRTBatchedRotatedNMS::getNbOutputs() const TRT_NOEXCEPT { return 2; }
// Output shapes, expressed symbolically:
//   output 0: [batch, keepTopK, 6]
//   output 1: [batch, keepTopK]
// where batch is taken from dimension 0 of the first input. Expects exactly
// two inputs, of rank 4 and rank 3 respectively.
nvinfer1::DimsExprs TRTBatchedRotatedNMS::getOutputDimensions(
    int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
    nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT {
  ASSERT(nbInputs == 2);
  ASSERT(outputIndex >= 0 && outputIndex < this->getNbOutputs());
  ASSERT(inputs[0].nbDims == 4);
  ASSERT(inputs[1].nbDims == 3);

  nvinfer1::DimsExprs outDims;
  outDims.d[0] = inputs[0].d[0];
  outDims.d[1] = exprBuilder.constant(param.keepTopK);

  if (outputIndex == 0) {
    outDims.nbDims = 3;
    outDims.d[2] = exprBuilder.constant(6);
  } else if (outputIndex == 1) {
    outDims.nbDims = 2;
  }
  return outDims;
}
// Scratch size required by enqueue(), derived from the concrete input shapes
// and delegated to detectionInferenceWorkspaceSize(). The effective top-k is
// param.topK when it lies in (0, inputs[1].dims.d[1]], otherwise
// inputs[1].dims.d[1].
size_t TRTBatchedRotatedNMS::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
                                              int nbInputs,
                                              const nvinfer1::PluginTensorDesc* outputs,
                                              int nbOutputs) const TRT_NOEXCEPT {
  const auto& boxDims = inputs[0].dims;
  const auto& scoreDims = inputs[1].dims;

  size_t batch_size = boxDims.d[0];
  size_t boxes_size = boxDims.d[1] * boxDims.d[2] * boxDims.d[3];
  size_t score_size = scoreDims.d[1] * scoreDims.d[2];
  size_t num_priors = boxDims.d[1];
  bool shareLocation = (boxDims.d[2] == 1);

  int topk = scoreDims.d[1];
  if (param.topK > 0 && param.topK <= scoreDims.d[1]) {
    topk = param.topK;
  }

  return detectionInferenceWorkspaceSize(shareLocation, batch_size, boxes_size, score_size,
                                         param.numClasses, num_priors, topk, DataType::kFLOAT,
                                         DataType::kFLOAT);
}
// Runs batched NMS in rotated mode on the given stream.
// inputs[0] holds the boxes and inputs[1] the scores; outputs[0] receives the
// detections and outputs[1] the labels. Both tensors are processed as float.
// Asserts that nmsInference() succeeded and returns 0.
int TRTBatchedRotatedNMS::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
                                  const nvinfer1::PluginTensorDesc* outputDesc,
                                  const void* const* inputs, void* const* outputs, void* workSpace,
                                  cudaStream_t stream) TRT_NOEXCEPT {
  const void* const locData = inputs[0];
  const void* const confData = inputs[1];
  void* nmsedDets = outputs[0];
  void* nmsedLabels = outputs[1];

  const auto& boxDims = inputDesc[0].dims;
  const auto& scoreDims = inputDesc[1].dims;

  size_t batch_size = boxDims.d[0];
  size_t boxes_size = boxDims.d[1] * boxDims.d[2] * boxDims.d[3];
  size_t score_size = scoreDims.d[1] * scoreDims.d[2];
  size_t num_priors = boxDims.d[1];
  bool shareLocation = (boxDims.d[2] == 1);

  // Fall back to the full candidate count when topK is unset or too large.
  int topk = scoreDims.d[1];
  if (param.topK > 0 && param.topK <= scoreDims.d[1]) {
    topk = param.topK;
  }

  const bool confSigmoid = false;
  const bool rotated = true;
  pluginStatus_t status = nmsInference(
      stream, batch_size, boxes_size, score_size, shareLocation, param.backgroundLabelId,
      num_priors, param.numClasses, topk, param.keepTopK, param.scoreThreshold, param.iouThreshold,
      DataType::kFLOAT, locData, DataType::kFLOAT, confData, nmsedDets, nmsedLabels, workSpace,
      param.isNormalized, confSigmoid, mClipBoxes, rotated);
  ASSERT(status == STATUS_SUCCESS);

  return 0;
}
// Number of bytes written by serialize(): the NMS parameters, the three int
// members and the clip flag.
size_t TRTBatchedRotatedNMS::getSerializationSize() const TRT_NOEXCEPT {
  return sizeof(param) + sizeof(boxesSize) + sizeof(scoresSize) + sizeof(numPriors) +
         sizeof(mClipBoxes);
}
// Writes the plugin state into `buffer`.
// The field order (param, boxesSize, scoresSize, numPriors, mClipBoxes) is the
// order the deserializing constructor reads them back in.
void TRTBatchedRotatedNMS::serialize(void* buffer) const TRT_NOEXCEPT {
serialize_value(&buffer, param);
serialize_value(&buffer, boxesSize);
serialize_value(&buffer, scoresSize);
serialize_value(&buffer, numPriors);
serialize_value(&buffer, mClipBoxes);
}
// Intentionally a no-op: no state is derived from the tensor descriptions and
// no validation of the arguments is performed here.
void TRTBatchedRotatedNMS::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* outputs,
int nbOutputs) TRT_NOEXCEPT {
// Placeholder for input validation (currently none).
}
// Accepts only linear-format tensors. The tensor at position 3 must be INT32;
// every other position must be FLOAT.
bool TRTBatchedRotatedNMS::supportsFormatCombination(int pos,
                                                     const nvinfer1::PluginTensorDesc* ioDesc,
                                                     int nbInputs, int nbOutputs) TRT_NOEXCEPT {
  const nvinfer1::PluginTensorDesc& desc = ioDesc[pos];
  const nvinfer1::DataType requiredType =
      (pos == 3) ? nvinfer1::DataType::kINT32 : nvinfer1::DataType::kFLOAT;
  return desc.type == requiredType && desc.format == nvinfer1::TensorFormat::kLINEAR;
}
// Returns the plugin name constant NMS_PLUGIN_NAME ("TRTBatchedRotatedNMS").
const char* TRTBatchedRotatedNMS::getPluginType() const TRT_NOEXCEPT { return NMS_PLUGIN_NAME; }
// Returns the plugin version constant NMS_PLUGIN_VERSION ("1").
const char* TRTBatchedRotatedNMS::getPluginVersion() const TRT_NOEXCEPT {
return NMS_PLUGIN_VERSION;
}
// Creates a heap-allocated copy carrying the same parameters, cached sizes,
// namespace and clip flag. Ownership of the returned pointer passes to the
// caller.
IPluginV2DynamicExt* TRTBatchedRotatedNMS::clone() const TRT_NOEXCEPT {
  TRTBatchedRotatedNMS* copy = new TRTBatchedRotatedNMS(mLayerName, param);

  copy->setPluginNamespace(mNamespace.c_str());
  copy->setClipParam(mClipBoxes);

  copy->boxesSize = boxesSize;
  copy->scoresSize = scoresSize;
  copy->numPriors = numPriors;

  return copy;
}
// Output 1 is always INT32; output 0 has the data type of the first input.
nvinfer1::DataType TRTBatchedRotatedNMS::getOutputDataType(int index,
                                                           const nvinfer1::DataType* inputTypes,
                                                           int nbInputs) const TRT_NOEXCEPT {
  ASSERT(index >= 0 && index < this->getNbOutputs());
  return (index == 1) ? nvinfer1::DataType::kINT32 : inputTypes[0];
}
// Stores the clip flag that enqueue() forwards to nmsInference() as clipBoxes.
void TRTBatchedRotatedNMS::setClipParam(bool clip) { mClipBoxes = clip; }
// Declares the attributes understood by createPlugin(), in this order:
// background_label_id, num_classes, topk, keep_topk (INT32),
// score_threshold, iou_threshold (FLOAT32), is_normalized, clip_boxes (INT32).
// Each attribute is a single value with no default data attached.
TRTBatchedRotatedNMSCreator::TRTBatchedRotatedNMSCreator() {
  const PluginFieldType kInt = PluginFieldType::kINT32;
  const PluginFieldType kFloat = PluginFieldType::kFLOAT32;

  mPluginAttributes.push_back(PluginField("background_label_id", nullptr, kInt, 1));
  mPluginAttributes.push_back(PluginField("num_classes", nullptr, kInt, 1));
  mPluginAttributes.push_back(PluginField("topk", nullptr, kInt, 1));
  mPluginAttributes.push_back(PluginField("keep_topk", nullptr, kInt, 1));
  mPluginAttributes.push_back(PluginField("score_threshold", nullptr, kFloat, 1));
  mPluginAttributes.push_back(PluginField("iou_threshold", nullptr, kFloat, 1));
  mPluginAttributes.push_back(PluginField("is_normalized", nullptr, kInt, 1));
  mPluginAttributes.push_back(PluginField("clip_boxes", nullptr, kInt, 1));

  // Expose the attribute list through the field collection.
  mFC.nbFields = mPluginAttributes.size();
  mFC.fields = mPluginAttributes.data();
}
// Returns NMS_PLUGIN_NAME, the same name the plugin reports from getPluginType().
const char* TRTBatchedRotatedNMSCreator::getPluginName() const TRT_NOEXCEPT {
return NMS_PLUGIN_NAME;
}
// Returns NMS_PLUGIN_VERSION, the same version the plugin itself reports.
const char* TRTBatchedRotatedNMSCreator::getPluginVersion() const TRT_NOEXCEPT {
return NMS_PLUGIN_VERSION;
}
// Creates a plugin instance from a collection of named attributes.
//
// Recognised attributes: background_label_id, num_classes, topk, keep_topk
// (INT32), score_threshold, iou_threshold (FLOAT32), is_normalized and
// clip_boxes (boolean flags). Unknown attribute names are ignored. Attributes
// that are absent keep their defaults: zero-initialised NMSParameters and
// clip_boxes == true.
//
// The returned plugin is heap-allocated and carries the creator's namespace.
IPluginV2Ext* TRTBatchedRotatedNMSCreator::createPlugin(
    const char* name, const PluginFieldCollection* fc) TRT_NOEXCEPT {
  const PluginField* fields = fc->fields;
  bool clipBoxes = true;
  nvinfer1::plugin::NMSParameters params{};

  // The boolean attributes are advertised as kINT32 by the creator. Reading
  // such a field through a `const bool*` only inspects one byte of the
  // integer, so an INT32 field is read as a full int and compared with zero.
  // Fields of any other type keep the previous bool-pointer interpretation.
  const auto readFlag = [](const PluginField& field) -> bool {
    if (field.type == PluginFieldType::kINT32) {
      return *(static_cast<const int*>(field.data)) != 0;
    }
    return *(static_cast<const bool*>(field.data));
  };

  for (int i = 0; i < fc->nbFields; ++i) {
    const char* attrName = fields[i].name;
    if (!strcmp(attrName, "background_label_id")) {
      ASSERT(fields[i].type == PluginFieldType::kINT32);
      params.backgroundLabelId = *(static_cast<const int*>(fields[i].data));
    } else if (!strcmp(attrName, "num_classes")) {
      ASSERT(fields[i].type == PluginFieldType::kINT32);
      params.numClasses = *(static_cast<const int*>(fields[i].data));
    } else if (!strcmp(attrName, "topk")) {
      ASSERT(fields[i].type == PluginFieldType::kINT32);
      params.topK = *(static_cast<const int*>(fields[i].data));
    } else if (!strcmp(attrName, "keep_topk")) {
      ASSERT(fields[i].type == PluginFieldType::kINT32);
      params.keepTopK = *(static_cast<const int*>(fields[i].data));
    } else if (!strcmp(attrName, "score_threshold")) {
      ASSERT(fields[i].type == PluginFieldType::kFLOAT32);
      params.scoreThreshold = *(static_cast<const float*>(fields[i].data));
    } else if (!strcmp(attrName, "iou_threshold")) {
      ASSERT(fields[i].type == PluginFieldType::kFLOAT32);
      params.iouThreshold = *(static_cast<const float*>(fields[i].data));
    } else if (!strcmp(attrName, "is_normalized")) {
      params.isNormalized = readFlag(fields[i]);
    } else if (!strcmp(attrName, "clip_boxes")) {
      clipBoxes = readFlag(fields[i]);
    }
  }

  TRTBatchedRotatedNMS* plugin = new TRTBatchedRotatedNMS(name, params);
  plugin->setClipParam(clipBoxes);
  plugin->setPluginNamespace(mNamespace.c_str());
  return plugin;
}
// Recreates a plugin from the bytes produced by TRTBatchedRotatedNMS::serialize()
// and tags it with the creator's namespace. The returned object is
// heap-allocated; the caller takes ownership of it.
IPluginV2Ext* TRTBatchedRotatedNMSCreator::deserializePlugin(const char* name,
                                                             const void* serialData,
                                                             size_t serialLength) TRT_NOEXCEPT {
  auto* restored = new TRTBatchedRotatedNMS(name, serialData, serialLength);
  restored->setPluginNamespace(mNamespace.c_str());
  return restored;
}
REGISTER_TENSORRT_PLUGIN(TRTBatchedRotatedNMSCreator);
} // namespace mmdeploy

View File

@ -0,0 +1,81 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef TRT_BATCHED_ROTATED_NMS_HPP
#define TRT_BATCHED_ROTATED_NMS_HPP
#include <string>
#include <vector>
#include "NvInferPluginUtils.h"
#include "trt_plugin_base.hpp"
namespace mmdeploy {
// TensorRT dynamic-shape plugin that runs batched NMS on rotated boxes.
// It takes two inputs (boxes, scores) and produces two outputs
// (detections, labels).
class TRTBatchedRotatedNMS : public TRTPluginBase {
public:
// Construct from explicit NMS parameters.
TRTBatchedRotatedNMS(const std::string& name, nvinfer1::plugin::NMSParameters param);
// Construct from a buffer previously produced by serialize().
TRTBatchedRotatedNMS(const std::string& name, const void* data, size_t length);
~TRTBatchedRotatedNMS() TRT_NOEXCEPT override = default;
int getNbOutputs() const TRT_NOEXCEPT override;
// Symbolic output shape for the given output index.
nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs,
int nbInputs, nvinfer1::IExprBuilder& exprBuilder)
TRT_NOEXCEPT override;
// Scratch memory, in bytes, that enqueue() needs for the given shapes.
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
const nvinfer1::PluginTensorDesc* outputs,
int nbOutputs) const TRT_NOEXCEPT override;
// Executes the NMS on `stream`.
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs,
void* const* outputs, void* workSpace, cudaStream_t stream) TRT_NOEXCEPT override;
// Serialization: getSerializationSize() is the byte count serialize() writes.
size_t getSerializationSize() const TRT_NOEXCEPT override;
void serialize(void* buffer) const TRT_NOEXCEPT override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* outputs,
int nbOutputs) TRT_NOEXCEPT override;
// Reports whether the type/format of tensor `pos` is supported.
bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* ioDesc, int nbInputs,
int nbOutputs) TRT_NOEXCEPT override;
const char* getPluginType() const TRT_NOEXCEPT override;
const char* getPluginVersion() const TRT_NOEXCEPT override;
// Returns a new heap-allocated copy of this plugin.
nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override;
nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputType,
int nbInputs) const TRT_NOEXCEPT override;
// Sets the flag forwarded to the NMS implementation as `clipBoxes`.
void setClipParam(bool clip);
private:
// NMS configuration (thresholds, top-k limits, class count, ...).
nvinfer1::plugin::NMSParameters param{};
// Cached sizes; they are serialized and cloned together with the plugin.
int boxesSize{};
int scoresSize{};
int numPriors{};
// Clip flag set through setClipParam().
bool mClipBoxes{};
};
// Factory for TRTBatchedRotatedNMS: builds plugins either from named
// attributes or from serialized data.
class TRTBatchedRotatedNMSCreator : public TRTPluginCreatorBase {
public:
// Registers the attribute names/types accepted by createPlugin().
TRTBatchedRotatedNMSCreator();
~TRTBatchedRotatedNMSCreator() TRT_NOEXCEPT override = default;
const char* getPluginName() const TRT_NOEXCEPT override;
const char* getPluginVersion() const TRT_NOEXCEPT override;
// Creates a plugin from a collection of attribute fields.
nvinfer1::IPluginV2Ext* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc)
TRT_NOEXCEPT override;
// Recreates a plugin from a serialized byte buffer.
nvinfer1::IPluginV2Ext* deserializePlugin(const char* name, const void* serialData,
size_t serialLength) TRT_NOEXCEPT override;
};
} // namespace mmdeploy
#endif

View File

@ -13,6 +13,7 @@ pluginStatus_t nmsInference(cudaStream_t stream, const int N, const int perBatch
const float scoreThreshold, const float iouThreshold,
const DataType DT_BBOX, const void* locData, const DataType DT_SCORE,
const void* confData, void* nmsedDets, void* nmsedLabels,
void* workspace, bool isNormalized, bool confSigmoid, bool clipBoxes);
void* workspace, bool isNormalized, bool confSigmoid, bool clipBoxes,
bool rotated = false);
#endif

View File

@ -22,17 +22,6 @@ struct Bbox {
Bbox() = default;
};
template <typename T>
struct BboxInfo {
T conf_score;
int label;
int bbox_idx;
bool kept;
BboxInfo(T conf_score, int label, int bbox_idx, bool kept)
: conf_score(conf_score), label(label), bbox_idx(bbox_idx), kept(kept) {}
BboxInfo() = default;
};
size_t get_cuda_arch(int devID);
int8_t* alignPtr(int8_t* ptr, uintptr_t to);
@ -47,6 +36,13 @@ pluginStatus_t allClassNMS(cudaStream_t stream, int num, int num_classes, int nu
void* beforeNMS_scores, void* beforeNMS_index_array,
void* afterNMS_scores, void* afterNMS_index_array, bool flipXY = false);
pluginStatus_t allClassRotatedNMS(cudaStream_t stream, int num, int num_classes,
int num_preds_per_class, int top_k, float nms_threshold,
bool share_location, bool isNormalized, DataType DT_SCORE,
DataType DT_BBOX, void* bbox_data, void* beforeNMS_scores,
void* beforeNMS_index_array, void* afterNMS_scores,
void* afterNMS_index_array, bool flipXY = false);
size_t detectionForwardBBoxDataSize(int N, int C1, DataType DT_BBOX);
size_t detectionForwardBBoxPermuteSize(bool shareLocation, int N, int C1, DataType DT_BBOX);
@ -80,18 +76,10 @@ pluginStatus_t gatherNMSOutputs(cudaStream_t stream, bool shareLocation, int num
int numPredsPerClass, int numClasses, int topK, int keepTopK,
DataType DT_BBOX, DataType DT_SCORE, const void* indices,
const void* scores, const void* bboxData, void* nmsedDets,
void* nmsedLabels, bool clipBoxes = true);
void* nmsedLabels, bool clipBoxes = true, bool rotated = false);
size_t detectionInferenceWorkspaceSize(bool shareLocation, int N, int C1, int C2, int numClasses,
int numPredsPerClass, int topK, DataType DT_BBOX,
DataType DT_SCORE);
pluginStatus_t nmsInference(cudaStream_t stream, int N, int boxesSize, int scoresSize,
bool shareLocation, int backgroundLabelId, int numPredsPerClass,
int numClasses, int topK, int keepTopK, float scoreThreshold,
float iouThreshold, DataType DT_BBOX, const void* locData,
DataType DT_SCORE, const void* confData, void* nmsedDets,
void* nmsedLabels, void* workspace, bool isNormalized = true,
bool confSigmoid = false, bool clipBoxes = true);
#endif

View File

@ -3,7 +3,7 @@
// https://github.com/NVIDIA/TensorRT/tree/master/plugin/batchedNMSPlugin
#include <vector>
#include "kernel.h"
#include "nms/kernel.h"
template <typename T_BBOX>
__device__ T_BBOX bboxSize(const Bbox<T_BBOX> &bbox, const bool normalized, T_BBOX offset) {
@ -63,13 +63,6 @@ __device__ float jaccardOverlap(const Bbox<T_BBOX> &bbox1, const Bbox<T_BBOX> &b
}
}
template <typename T_BBOX>
__device__ void emptyBboxInfo(BboxInfo<T_BBOX> *bbox_info) {
bbox_info->conf_score = T_BBOX(0);
bbox_info->label = -2; // -1 is used for all labels when shared_location is true
bbox_info->bbox_idx = -1;
bbox_info->kept = false;
}
/********** new NMS for only score and index array **********/
template <typename T_SCORE, typename T_BBOX, int TSIZE>
@ -255,7 +248,8 @@ pluginStatus_t allClassNMS(cudaStream_t stream, const int num, const int num_cla
printf("Warning: pre_top_k need to be reduced for devices with arch 7.2, got pre_top_k=%d\n",
top_k);
}
nmsLaunchConfigSSD lc = nmsLaunchConfigSSD(DT_SCORE, DT_BBOX, allClassNMS_gpu<float, float>);
nmsLaunchConfigSSD lc(DT_SCORE, DT_BBOX);
for (unsigned i = 0; i < nmsFuncVec.size(); ++i) {
if (lc == nmsFuncVec[i]) {
DEBUG_PRINTF("all class nms kernel %d\n", i);

View File

@ -0,0 +1,494 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
// modified from
// https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_utils.h
#include <cmath>
#include <vector>
#include "nms/kernel.h"
// A rotated rectangle: centre, extents and rotation angle.
// The angle `a` is handed unchanged to cos()/sin() by get_rotated_vertices().
template <typename T>
struct RotatedBox {
  T x_ctr;
  T y_ctr;
  T w;
  T h;
  T a;
};
// Minimal 2-D point/vector with the arithmetic needed by the rotated-IoU code.
// Usable from both host and device code. Both coordinates default to 0.
template <typename T>
struct Point {
  T x, y;

  __host__ __device__ __forceinline__ Point(const T &px = 0, const T &py = 0) : x(px), y(py) {}

  // Component-wise sum.
  __host__ __device__ __forceinline__ Point operator+(const Point &p) const {
    Point sum(x + p.x, y + p.y);
    return sum;
  }

  // In-place component-wise sum.
  __host__ __device__ __forceinline__ Point &operator+=(const Point &p) {
    x += p.x;
    y += p.y;
    return *this;
  }

  // Component-wise difference.
  __host__ __device__ __forceinline__ Point operator-(const Point &p) const {
    Point diff(x - p.x, y - p.y);
    return diff;
  }

  // Scales both components by `coeff`.
  __host__ __device__ __forceinline__ Point operator*(const T coeff) const {
    Point scaled(x * coeff, y * coeff);
    return scaled;
  }
};
// Dot product of two 2-D vectors: A.x * B.x + A.y * B.y.
template <typename T>
__host__ __device__ __forceinline__ T dot_2d(const Point<T> &A, const Point<T> &B) {
return A.x * B.x + A.y * B.y;
}
// Scalar (z-component) cross product of two 2-D vectors: A.x * B.y - B.x * A.y.
// The result is zero when A and B are parallel.
template <typename T>
__host__ __device__ __forceinline__ T cross_2d(const Point<T> &A, const Point<T> &B) {
return A.x * B.y - B.x * A.y;
}
// Computes the four corner points of a rotated box and stores them in `pts`.
//
// box.a is passed to cos()/sin() as-is; the degree-to-radian conversion of the
// upstream code is commented out below, so the angle is used directly as the
// trigonometric argument. pts[2] and pts[3] are obtained by reflecting pts[0]
// and pts[1] through the box centre.
template <typename T>
__host__ __device__ __forceinline__ void get_rotated_vertices(const RotatedBox<T> &box,
Point<T> (&pts)[4]) {
// M_PI / 180. == 0.01745329251
// double theta = box.a * 0.01745329251;
// MODIFIED
double theta = box.a;
// Half cosine / half sine, so that multiplying by w or h yields half extents.
T cosTheta2 = (T)cos(theta) * 0.5f;
T sinTheta2 = (T)sin(theta) * 0.5f;
// y: top --> down; x: left --> right
pts[0].x = box.x_ctr - sinTheta2 * box.h - cosTheta2 * box.w;
pts[0].y = box.y_ctr + cosTheta2 * box.h - sinTheta2 * box.w;
pts[1].x = box.x_ctr + sinTheta2 * box.h - cosTheta2 * box.w;
pts[1].y = box.y_ctr - cosTheta2 * box.h - sinTheta2 * box.w;
// Opposite corners: mirror through the centre.
pts[2].x = 2 * box.x_ctr - pts[0].x;
pts[2].y = 2 * box.y_ctr - pts[0].y;
pts[3].x = 2 * box.x_ctr - pts[1].x;
pts[3].y = 2 * box.y_ctr - pts[1].y;
}
// Collects the candidate vertices of the intersection polygon of two
// rectangles given by their corner points.
//
// Three sources contribute points, in this order:
//   1. crossings between each edge of rect1 and each edge of rect2,
//   2. corners of rect1 that lie inside (or on the border of) rect2,
//   3. corners of rect2 that lie inside (or on the border of) rect1.
// The points are written to `intersections` without de-duplication and the
// number of points written is returned (at most 4 * 4 + 4 + 4 = 24).
template <typename T>
__host__ __device__ __forceinline__ int get_intersection_points(const Point<T> (&pts1)[4],
const Point<T> (&pts2)[4],
Point<T> (&intersections)[24]) {
// Line vector
// A line from p1 to p2 is: p1 + (p2-p1)*t, t=[0,1]
Point<T> vec1[4], vec2[4];
for (int i = 0; i < 4; i++) {
vec1[i] = pts1[(i + 1) % 4] - pts1[i];
vec2[i] = pts2[(i + 1) % 4] - pts2[i];
}
// Line test - test all line combos for intersection
int num = 0; // number of intersections
for (int i = 0; i < 4; i++) {
for (int j = 0; j < 4; j++) {
// Solve for 2x2 Ax=b
T det = cross_2d<T>(vec2[j], vec1[i]);
// This takes care of parallel lines
if (fabs(det) <= 1e-14) {
continue;
}
auto vec12 = pts2[j] - pts1[i];
// t1 / t2 are the line parameters of the crossing on edge i of rect1 and
// edge j of rect2; both must lie in [0, 1] for the segments to meet.
T t1 = cross_2d<T>(vec2[j], vec12) / det;
T t2 = cross_2d<T>(vec1[i], vec12) / det;
if (t1 >= 0.0f && t1 <= 1.0f && t2 >= 0.0f && t2 <= 1.0f) {
intersections[num++] = pts1[i] + vec1[i] * t1;
}
}
}
// Check for vertices of rect1 inside rect2
{
const auto &AB = vec2[0];
const auto &DA = vec2[3];
auto ABdotAB = dot_2d<T>(AB, AB);
auto ADdotAD = dot_2d<T>(DA, DA);
for (int i = 0; i < 4; i++) {
// assume ABCD is the rectangle, and P is the point to be judged
// P is inside ABCD iff. P's projection on AB lies within AB
// and P's projection on AD lies within AD
auto AP = pts1[i] - pts2[0];
auto APdotAB = dot_2d<T>(AP, AB);
// DA points from D to A, hence the sign flip to project onto AD.
auto APdotAD = -dot_2d<T>(AP, DA);
if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) && (APdotAD <= ADdotAD)) {
intersections[num++] = pts1[i];
}
}
}
// Reverse the check - check for vertices of rect2 inside rect1
{
const auto &AB = vec1[0];
const auto &DA = vec1[3];
auto ABdotAB = dot_2d<T>(AB, AB);
auto ADdotAD = dot_2d<T>(DA, DA);
for (int i = 0; i < 4; i++) {
auto AP = pts2[i] - pts1[0];
auto APdotAB = dot_2d<T>(AP, AB);
auto APdotAD = -dot_2d<T>(AP, DA);
if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) && (APdotAD <= ADdotAD)) {
intersections[num++] = pts2[i];
}
}
}
return num;
}
// Graham-scan convex hull of the first `num_in` points of `p`.
//
// The hull vertices are written to the front of `q` and their count is
// returned. When `shift_to_zero` is true the returned vertices stay expressed
// relative to the starting point (lowest y, then lowest x), which is enough
// for area computations; otherwise they are shifted back to the original
// coordinates. If all points coincide, the hull is the single point p[t] and
// 1 is returned. Requires num_in >= 2 (checked with assert).
template <typename T>
__host__ __device__ __forceinline__ int convex_hull_graham(const Point<T> (&p)[24],
const int &num_in, Point<T> (&q)[24],
bool shift_to_zero = false) {
assert(num_in >= 2);
// Step 1:
// Find point with minimum y
// if more than 1 points have the same minimum y,
// pick the one with the minimum x.
int t = 0;
for (int i = 1; i < num_in; i++) {
if (p[i].y < p[t].y || (p[i].y == p[t].y && p[i].x < p[t].x)) {
t = i;
}
}
auto &start = p[t]; // starting point
// Step 2:
// Subtract starting point from every points (for sorting in the next step)
for (int i = 0; i < num_in; i++) {
q[i] = p[i] - start;
}
// Swap the starting point to position 0
auto tmp = q[0];
q[0] = q[t];
q[t] = tmp;
// Step 3:
// Sort point 1 ~ num_in according to their relative cross-product values
// (essentially sorting according to angles)
// If the angles are the same, sort according to their distance to origin
T dist[24];
for (int i = 0; i < num_in; i++) {
dist[i] = dot_2d<T>(q[i], q[i]);
}
// Simple in-place exchange sort; q and dist are permuted together.
for (int i = 1; i < num_in - 1; i++) {
for (int j = i + 1; j < num_in; j++) {
T crossProduct = cross_2d<T>(q[i], q[j]);
if ((crossProduct < -1e-6) || (fabs(crossProduct) < 1e-6 && dist[i] > dist[j])) {
auto q_tmp = q[i];
q[i] = q[j];
q[j] = q_tmp;
auto dist_tmp = dist[i];
dist[i] = dist[j];
dist[j] = dist_tmp;
}
}
}
// Step 4:
// Make sure there are at least 2 points (that don't overlap with each other)
// in the stack
int k; // index of the non-overlapped second point
for (k = 1; k < num_in; k++) {
if (dist[k] > 1e-8) {
break;
}
}
if (k == num_in) {
// We reach the end, which means the convex hull is just one point
q[0] = p[t];
return 1;
}
q[1] = q[k];
int m = 2; // 2 points in the stack
// Step 5:
// Finally we can start the scanning process.
// When a non-convex relationship between the 3 points is found
// (either concave shape or duplicated points),
// we pop the previous point from the stack
// until the 3-point relationship is convex again, or
// until the stack only contains two points
for (int i = k + 1; i < num_in; i++) {
while (m > 1 && cross_2d<T>(q[i] - q[m - 2], q[m - 1] - q[m - 2]) >= 0) {
m--;
}
q[m++] = q[i];
}
// Step 6 (Optional):
// In general sense we need the original coordinates, so we
// need to shift the points back (reverting Step 2)
// But if we're only interested in getting the area/perimeter of the shape
// We can simply return.
if (!shift_to_zero) {
for (int i = 0; i < m; i++) {
q[i] += start;
}
}
return m;
}
// Area of the polygon formed by the first `m` points of `q`, computed as a
// triangle fan anchored at q[0]. Each triangle contributes the absolute value
// of its cross product. Fewer than three points yield an area of 0.
template <typename T>
__host__ __device__ __forceinline__ T polygon_area(const Point<T> (&q)[24], const int &m) {
  if (m <= 2) {
    return 0;
  }

  const Point<T> &anchor = q[0];
  T twiceArea = 0;
  for (int i = 1; i < m - 1; i++) {
    twiceArea += fabs(cross_2d<T>(q[i] - anchor, q[i + 1] - anchor));
  }
  return twiceArea / 2.0;
}
// Area of the overlap between two rotated boxes.
//
// Pipeline: box corners -> candidate intersection points -> convex hull ->
// polygon area. Two or fewer candidate points cannot span an area, so 0 is
// returned early in that case.
template <typename T>
__host__ __device__ __forceinline__ T rotated_boxes_intersection(const RotatedBox<T> &box1,
                                                                 const RotatedBox<T> &box2) {
  Point<T> pts1[4];
  Point<T> pts2[4];
  get_rotated_vertices<T>(box1, pts1);
  get_rotated_vertices<T>(box2, pts2);

  // get_intersection_points() returns up to 4 x 4 + 4 + 4 = 24 points
  // (including duplicates).
  Point<T> intersectPts[24];
  const int num = get_intersection_points<T>(pts1, pts2, intersectPts);
  if (num <= 2) {
    return 0.0;
  }

  // The hull orders the points; only the area is needed, so the hull may stay
  // shifted to the origin (last argument true).
  Point<T> orderedPts[24];
  const int num_convex = convex_hull_graham<T>(intersectPts, num, orderedPts, true);
  return polygon_area<T>(orderedPts, num_convex);
}
// Intersection-over-union of two rotated boxes.
//
// Each raw box is an array of 5 values read as (x_ctr, y_ctr, w, h, a).
// Returns 0 when either box has an area below 1e-14.
template <typename T>
__host__ __device__ __forceinline__ T single_box_iou_rotated(T const *const box1_raw,
                                                             T const *const box2_raw) {
  // shift center to the middle point to achieve higher precision in result
  auto center_shift_x = (box1_raw[0] + box2_raw[0]) / 2.0;
  auto center_shift_y = (box1_raw[1] + box2_raw[1]) / 2.0;

  RotatedBox<T> box1;
  box1.x_ctr = box1_raw[0] - center_shift_x;
  box1.y_ctr = box1_raw[1] - center_shift_y;
  box1.w = box1_raw[2];
  box1.h = box1_raw[3];
  box1.a = box1_raw[4];

  RotatedBox<T> box2;
  box2.x_ctr = box2_raw[0] - center_shift_x;
  box2.y_ctr = box2_raw[1] - center_shift_y;
  box2.w = box2_raw[2];
  box2.h = box2_raw[3];
  box2.a = box2_raw[4];

  const T area1 = box1.w * box1.h;
  const T area2 = box2.w * box2.h;
  if (area1 < 1e-14 || area2 < 1e-14) {
    return 0.f;
  }

  const T intersection = rotated_boxes_intersection<T>(box1, box2);
  const T unionArea = (area1 + area2 - intersection);
  return intersection / unionArea;
}
/********** new NMS for only score and index array **********/
// Greedy rotated-box NMS over the first top_k candidates of every class.
//
// Layout used by the code below:
//   * blockIdx.x selects the class; images are processed one after another
//     inside the kernel.
//   * Candidate `cur_idx = threadIdx.x + blockDim.x * t` (t < TSIZE) is owned
//     by one thread, so blockDim.x * TSIZE must be at least top_k.
//   * Dynamic shared memory holds one bool per candidate slot
//     (blockDim.x * TSIZE bools).
//   * Every box occupies 5 consecutive values of bbox_data.
//
// For each class the candidates are visited in stored order. The current
// reference box suppresses every later, still-kept candidate whose rotated
// IoU with it exceeds nms_threshold. Suppressed or empty slots are written
// out with score 0 and index -1.
//
// `isNormalized` is accepted for signature parity but is not used here.
template <typename T_SCORE, typename T_BBOX, int TSIZE>
__global__ void allClassRotatedNMS_kernel(const int num, const int num_classes,
                                          const int num_preds_per_class, const int top_k,
                                          const float nms_threshold, const bool share_location,
                                          const bool isNormalized,
                                          T_BBOX *bbox_data,  // bbox_data should be float to
                                                              // preserve location information
                                          T_SCORE *beforeNMS_scores, int *beforeNMS_index_array,
                                          T_SCORE *afterNMS_scores, int *afterNMS_index_array) {
  //__shared__ bool kept_bboxinfo_flag[CAFFE_CUDA_NUM_THREADS * TSIZE];
  extern __shared__ bool kept_bboxinfo_flag[];

  for (int i = 0; i < num; i++) {
    const int offset = i * num_classes * num_preds_per_class + blockIdx.x * num_preds_per_class;
    const int max_idx = offset + top_k;  // put top_k bboxes into NMS calculation
    const int bbox_idx_offset =
        share_location ? (i * num_preds_per_class) : (i * num_classes * num_preds_per_class);

    // local thread data
    int loc_bboxIndex[TSIZE];
    T_BBOX loc_bbox[TSIZE * 5];

    // initialize Bbox, Bboxinfo, kept_bboxinfo_flag
    // Eliminate shared memory RAW hazard
    __syncthreads();
#pragma unroll
    for (int t = 0; t < TSIZE; t++) {
      const int cur_idx = threadIdx.x + blockDim.x * t;
      const int item_idx = offset + cur_idx;

      if (item_idx < max_idx) {
        loc_bboxIndex[t] = beforeNMS_index_array[item_idx];

        if (loc_bboxIndex[t] >= 0)
        // if (loc_bboxIndex[t] != -1)
        {
          const int bbox_data_idx = share_location
                                        ? (loc_bboxIndex[t] % num_preds_per_class + bbox_idx_offset)
                                        : loc_bboxIndex[t];
          memcpy(&loc_bbox[t * 5], &bbox_data[bbox_data_idx * 5], 5 * sizeof(T_BBOX));
          kept_bboxinfo_flag[cur_idx] = true;
        } else {
          kept_bboxinfo_flag[cur_idx] = false;
        }
      } else {
        kept_bboxinfo_flag[cur_idx] = false;
      }
    }

    // filter out overlapped boxes with lower scores
    //
    // The reference index is loaded at the top of the loop, strictly after the
    // range check, so beforeNMS_index_array is never read at max_idx (which
    // may lie past the end of the array for the last class slice).
    // ref_item_idx only depends on block-uniform data, so all threads of the
    // block leave the loop together and the barriers below stay convergent.
    int ref_item_idx = offset;
    while (ref_item_idx < max_idx) {
      const int ref_index = beforeNMS_index_array[ref_item_idx];
      // An empty slot ends the scan. The raw index is tested (not the
      // offset-adjusted one) so the sentinel is also recognised for images
      // after the first when share_location is set; this matches the `>= 0`
      // validity test used when loading the candidates above.
      if (ref_index < 0) {
        break;
      }
      const int ref_bbox_idx =
          share_location ? (ref_index % num_preds_per_class + bbox_idx_offset) : ref_index;

      T_BBOX ref_bbox[5];
      memcpy(&ref_bbox[0], &bbox_data[ref_bbox_idx * 5], 5 * sizeof(T_BBOX));

      // Eliminate shared memory RAW hazard
      __syncthreads();

      for (int t = 0; t < TSIZE; t++) {
        const int cur_idx = threadIdx.x + blockDim.x * t;
        const int item_idx = offset + cur_idx;

        if ((kept_bboxinfo_flag[cur_idx]) && (item_idx > ref_item_idx)) {
          // TODO: may need to add bool normalized as argument, HERE true means
          // normalized
          if (single_box_iou_rotated(&ref_bbox[0], loc_bbox + t * 5) > nms_threshold) {
            kept_bboxinfo_flag[cur_idx] = false;
          }
        }
      }
      __syncthreads();

      // Advance to the next candidate that is still kept (or to max_idx).
      do {
        ref_item_idx++;
      } while (ref_item_idx < max_idx && !kept_bboxinfo_flag[ref_item_idx - offset]);
    }

    // store data
    for (int t = 0; t < TSIZE; t++) {
      const int cur_idx = threadIdx.x + blockDim.x * t;
      const int read_item_idx = offset + cur_idx;
      const int write_item_idx = (i * num_classes * top_k + blockIdx.x * top_k) + cur_idx;
      /*
       * If not keeping the bbox
       * Set the score to 0
       * Set the bounding box index to -1
       */
      if (read_item_idx < max_idx) {
        afterNMS_scores[write_item_idx] =
            kept_bboxinfo_flag[cur_idx] ? beforeNMS_scores[read_item_idx] : 0.0f;
        afterNMS_index_array[write_item_idx] = kept_bboxinfo_flag[cur_idx] ? loc_bboxIndex[t] : -1;
      }
    }
  }
}
/**
 * Launches allClassRotatedNMS_kernel on `stream` with one block per class.
 *
 * Each block has 512 threads and each thread handles ceil(top_k / 512) candidates; the
 * kernel is instantiated for 1..10 candidates per thread, so top_k must be in
 * [1, 5120]. Values outside that range are rejected with STATUS_BAD_PARAM instead of
 * indexing outside the kernel table.
 *
 * The void pointers are reinterpreted as T_BBOX* (bbox_data), T_SCORE* (scores) and
 * int* (index arrays).
 *
 * @return STATUS_SUCCESS when the launch was issued without error, STATUS_BAD_PARAM for
 *         an unsupported top_k, or the failure status reported through CSC.
 */
template <typename T_SCORE, typename T_BBOX>
pluginStatus_t allClassRotatedNMS_gpu(cudaStream_t stream, const int num, const int num_classes,
                                      const int num_preds_per_class, const int top_k,
                                      const float nms_threshold, const bool share_location,
                                      const bool isNormalized, void *bbox_data,
                                      void *beforeNMS_scores, void *beforeNMS_index_array,
                                      void *afterNMS_scores, void *afterNMS_index_array) {
#define P(tsize) allClassRotatedNMS_kernel<T_SCORE, T_BBOX, (tsize)>

  const int kMaxTSize = 10;
  // kernel[t - 1] handles t candidates per thread.
  void (*kernel[kMaxTSize])(const int, const int, const int, const int, const float, const bool,
                            const bool, T_BBOX *, T_SCORE *, int *, T_SCORE *, int *) = {
      P(1), P(2), P(3), P(4), P(5), P(6), P(7), P(8), P(9), P(10),
  };

#undef P

  const int BS = 512;
  const int GS = num_classes;
  const int t_size = (top_k + BS - 1) / BS;
  // top_k <= 0 gives t_size == 0 and top_k > BS * kMaxTSize gives t_size > kMaxTSize;
  // both would index outside the kernel table.
  if (t_size < 1 || t_size > kMaxTSize) {
    return STATUS_BAD_PARAM;
  }

  // Dynamic shared memory: one bool keep flag per candidate slot of the block.
  kernel[t_size - 1]<<<GS, BS, BS * t_size * sizeof(bool), stream>>>(
      num, num_classes, num_preds_per_class, top_k, nms_threshold, share_location, isNormalized,
      (T_BBOX *)bbox_data, (T_SCORE *)beforeNMS_scores, (int *)beforeNMS_index_array,
      (T_SCORE *)afterNMS_scores, (int *)afterNMS_index_array);

  CSC(cudaGetLastError(), STATUS_FAILURE);
  return STATUS_SUCCESS;
}
// allClassRotatedNMS LAUNCH CONFIG

// Signature shared by every type-specialized rotated NMS launcher.
typedef pluginStatus_t (*rotatedNmsFunc)(cudaStream_t, const int, const int, const int, const int,
                                         const float, const bool, const bool, void *, void *,
                                         void *, void *, void *);

// Associates a (score type, bbox type) pair with the launcher instantiated for it.
struct rotatedNmsLaunchConfig {
  DataType t_score;
  DataType t_bbox;
  rotatedNmsFunc function;

  // Lookup key: carries only the type pair, so the launcher pointer is null.
  rotatedNmsLaunchConfig(DataType t_score, DataType t_bbox)
      : t_score(t_score), t_bbox(t_bbox), function(nullptr) {}
  // Registry entry: type pair plus the launcher to call for it.
  rotatedNmsLaunchConfig(DataType t_score, DataType t_bbox, rotatedNmsFunc function)
      : t_score(t_score), t_bbox(t_bbox), function(function) {}
  // Two configs match when their type pairs match; the launcher pointer is ignored.
  bool operator==(const rotatedNmsLaunchConfig &other) const {
    return t_score == other.t_score && t_bbox == other.t_bbox;
  }
};
// Launchers available to allClassRotatedNMS, one entry per supported type pair.
static std::vector<rotatedNmsLaunchConfig> rotatedNmsFuncVec;

// Registers the float/float launcher. Always returns true so the call can be used as
// the initializer of a static flag.
bool rotatedNmsInit() {
  rotatedNmsLaunchConfig float_entry(DataType::kFLOAT, DataType::kFLOAT,
                                     allClassRotatedNMS_gpu<float, float>);
  rotatedNmsFuncVec.push_back(float_entry);
  return true;
}

// Runs rotatedNmsInit() during static initialization of this translation unit.
static bool initialized = rotatedNmsInit();
/**
 * Type-dispatching entry point for rotated all-class NMS.
 *
 * Looks up the launcher registered for (DT_SCORE, DT_BBOX) and forwards all arguments to
 * it. The trailing unnamed bool is accepted but not used. Prints a warning when device 0
 * reports arch 720 and top_k >= 1000.
 *
 * @return the launcher's status, or STATUS_BAD_PARAM when no launcher is registered for
 *         the requested type pair.
 */
pluginStatus_t allClassRotatedNMS(cudaStream_t stream, const int num, const int num_classes,
                                  const int num_preds_per_class, const int top_k,
                                  const float nms_threshold, const bool share_location,
                                  const bool isNormalized, const DataType DT_SCORE,
                                  const DataType DT_BBOX, void *bbox_data, void *beforeNMS_scores,
                                  void *beforeNMS_index_array, void *afterNMS_scores,
                                  void *afterNMS_index_array, bool) {
  const auto device_arch = get_cuda_arch(0);  // assume there is only one arch 7.2 device
  const bool large_top_k_on_sm72 = (device_arch == 720) && (top_k >= 1000);
  if (large_top_k_on_sm72) {
    printf("Warning: pre_top_k need to be reduced for devices with arch 7.2, got pre_top_k=%d\n",
           top_k);
  }

  // Key holding only the requested type pair; compared against each registry entry.
  rotatedNmsLaunchConfig wanted(DT_SCORE, DT_BBOX);
  for (unsigned i = 0; i < rotatedNmsFuncVec.size(); ++i) {
    if (!(wanted == rotatedNmsFuncVec[i])) {
      continue;
    }
    DEBUG_PRINTF("all class rotated nms kernel %d\n", i);
    rotatedNmsFunc launcher = rotatedNmsFuncVec[i].function;
    return launcher(stream, num, num_classes, num_preds_per_class, top_k, nms_threshold,
                    share_location, isNormalized, bbox_data, beforeNMS_scores,
                    beforeNMS_index_array, afterNMS_scores, afterNMS_index_array);
  }
  return STATUS_BAD_PARAM;
}

View File

@ -1,7 +1,7 @@
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
// modify from
// https://github.com/NVIDIA/TensorRT/tree/master/plugin/batchedNMSPlugin
#include "trt_batched_nms_kernel.hpp"
#include "nms/batched_nms_kernel.hpp"
pluginStatus_t nmsInference(cudaStream_t stream, const int N, const int perBatchBoxesSize,
const int perBatchScoresSize, const bool shareLocation,
@ -10,7 +10,8 @@ pluginStatus_t nmsInference(cudaStream_t stream, const int N, const int perBatch
const float scoreThreshold, const float iouThreshold,
const DataType DT_BBOX, const void* locData, const DataType DT_SCORE,
const void* confData, void* nmsedDets, void* nmsedLabels,
void* workspace, bool isNormalized, bool confSigmoid, bool clipBoxes) {
void* workspace, bool isNormalized, bool confSigmoid, bool clipBoxes,
bool rotated) {
const int topKVal = topK < 0 ? numPredsPerClass : topK;
const int keepTopKVal = keepTopK < 0 ? numPredsPerClass : keepTopK;
// locCount = batch_size * number_boxes_per_sample * 4
@ -45,8 +46,8 @@ pluginStatus_t nmsInference(cudaStream_t stream, const int N, const int perBatch
* This is equivalent to swapping axis
*/
if (!shareLocation) {
status = permuteData(stream, locCount, numLocClasses, numPredsPerClass, 4, DataType::kFLOAT,
false, bboxDataRaw, bboxPermute);
status = permuteData(stream, locCount, numLocClasses, numPredsPerClass, rotated ? 5 : 4,
DataType::kFLOAT, false, bboxDataRaw, bboxPermute);
ASSERT_FAILURE(status == STATUS_SUCCESS);
bboxData = bboxPermute;
}
@ -95,9 +96,16 @@ pluginStatus_t nmsInference(cudaStream_t stream, const int N, const int perBatch
// ymax]
bool flipXY = false;
// NMS
status = allClassNMS(stream, N, numClasses, numPredsPerClass, topKVal, iouThreshold,
shareLocation, isNormalized, DataType::kFLOAT, DataType::kFLOAT, bboxData,
scores, indices, postNMSScores, postNMSIndices, flipXY);
if (rotated) {
status = allClassRotatedNMS(stream, N, numClasses, numPredsPerClass, topKVal, iouThreshold,
shareLocation, isNormalized, DataType::kFLOAT, DataType::kFLOAT,
bboxData, scores, indices, postNMSScores, postNMSIndices, flipXY);
} else {
status = allClassNMS(stream, N, numClasses, numPredsPerClass, topKVal, iouThreshold,
shareLocation, isNormalized, DataType::kFLOAT, DataType::kFLOAT, bboxData,
scores, indices, postNMSScores, postNMSIndices, flipXY);
}
ASSERT_FAILURE(status == STATUS_SUCCESS);
// Sort the bounding boxes after NMS using scores
@ -109,7 +117,7 @@ pluginStatus_t nmsInference(cudaStream_t stream, const int N, const int perBatch
// Gather data from the sorted bounding boxes after NMS
status = gatherNMSOutputs(stream, shareLocation, N, numPredsPerClass, numClasses, topKVal,
keepTopKVal, DataType::kFLOAT, DataType::kFLOAT, indices, scores,
bboxData, nmsedDets, nmsedLabels, clipBoxes);
bboxData, nmsedDets, nmsedLabels, clipBoxes, rotated);
ASSERT_FAILURE(status == STATUS_SUCCESS);

View File

@ -0,0 +1,152 @@
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
// modify from
// https://github.com/NVIDIA/TensorRT/tree/master/plugin/batchedNMSPlugin
#include <vector>
#include "nms/kernel.h"
#include "trt_plugin_helper.hpp"
/**
 * Gathers the first keepTopK post-NMS entries of every image into dense outputs.
 *
 * Output row i of nmsedDets holds the box values followed by the score: 5 + 1 values
 * when `rotated` is true, 4 + 1 otherwise. nmsedLabels[i] receives the class id. Entries
 * whose index is -1 produce an all-zero row and label -1.
 *
 * Every thread walks the numImages * keepTopK outputs with a stride of
 * gridDim.x * nthds_per_cta, so any 1-D grid covers all outputs. The kernel does nothing
 * when keepTopK > topK.
 *
 * When clipBoxes is set, every box value (for rotated boxes this includes the fifth
 * one) is clamped to [0, 1].
 */
template <typename T_BBOX, typename T_SCORE, bool rotated, unsigned nthds_per_cta>
__launch_bounds__(nthds_per_cta) __global__
    void gatherNMSOutputs_kernel(const bool shareLocation, const int numImages,
                                 const int numPredsPerClass, const int numClasses, const int topK,
                                 const int keepTopK, const int *indices, const T_SCORE *scores,
                                 const T_BBOX *bboxData, T_BBOX *nmsedDets, int *nmsedLabels,
                                 bool clipBoxes) {
  if (keepTopK > topK) return;

  // Values per input box, and per output row (box values plus the score).
  const int boxDim = rotated ? 5 : 4;
  const int detDim = boxDim + 1;

  for (int i = blockIdx.x * nthds_per_cta + threadIdx.x; i < numImages * keepTopK;
       i += gridDim.x * nthds_per_cta) {
    const int imgId = i / keepTopK;
    const int detId = i % keepTopK;
    const int offset = imgId * numClasses * topK;
    const int index = indices[offset + detId];
    const T_SCORE score = scores[offset + detId];
    T_BBOX *det = nmsedDets + i * detDim;

    if (index == -1) {
      // No detection in this slot: emit label -1 and a zeroed row.
      nmsedLabels[i] = -1;
      for (int c = 0; c < detDim; ++c) {
        det[c] = 0;
      }
      continue;
    }

    const int bboxOffset =
        imgId * (shareLocation ? numPredsPerClass : (numClasses * numPredsPerClass));
    nmsedLabels[i] = (index % (numClasses * numPredsPerClass)) / numPredsPerClass;  // label

    const int bboxId = ((shareLocation ? (index % numPredsPerClass)
                                       : index % (numClasses * numPredsPerClass)) +
                        bboxOffset) *
                       boxDim;
    // Copy the box values, clamping each one to [0, 1] when clipping is requested.
    for (int c = 0; c < boxDim; ++c) {
      det[c] = clipBoxes ? max(min(bboxData[bboxId + c], T_BBOX(1.)), T_BBOX(0.))
                         : bboxData[bboxId + c];
    }
    det[boxDim] = score;
  }
}
/**
 * Launches gatherNMSOutputs_kernel on `stream` with a fixed 32 x 32 configuration and no
 * dynamic shared memory.
 *
 * The untyped pointers are reinterpreted as int (indices, nmsedLabels), T_SCORE (scores)
 * and T_BBOX (bboxData, nmsedDets).
 *
 * @return STATUS_SUCCESS when the launch was issued without error, otherwise the failure
 *         status reported through CSC.
 */
template <typename T_BBOX, typename T_SCORE, bool rotated>
pluginStatus_t gatherNMSOutputs_gpu(cudaStream_t stream, const bool shareLocation,
                                    const int numImages, const int numPredsPerClass,
                                    const int numClasses, const int topK, const int keepTopK,
                                    const void *indices, const void *scores, const void *bboxData,
                                    void *nmsedDets, void *nmsedLabels, bool clipBoxes) {
  const int kThreadsPerBlock = 32;
  const int kBlocks = 32;

  const int *typedIndices = static_cast<const int *>(indices);
  const T_SCORE *typedScores = static_cast<const T_SCORE *>(scores);
  const T_BBOX *typedBoxes = static_cast<const T_BBOX *>(bboxData);
  T_BBOX *typedDets = static_cast<T_BBOX *>(nmsedDets);
  int *typedLabels = static_cast<int *>(nmsedLabels);

  gatherNMSOutputs_kernel<T_BBOX, T_SCORE, rotated, kThreadsPerBlock>
      <<<kBlocks, kThreadsPerBlock, 0, stream>>>(shareLocation, numImages, numPredsPerClass,
                                                 numClasses, topK, keepTopK, typedIndices,
                                                 typedScores, typedBoxes, typedDets, typedLabels,
                                                 clipBoxes);

  CSC(cudaGetLastError(), STATUS_FAILURE);
  return STATUS_SUCCESS;
}
// gatherNMSOutputs LAUNCH CONFIG {{{

// Signature shared by every specialized gatherNMSOutputs launcher.
typedef pluginStatus_t (*nmsOutFunc)(cudaStream_t, const bool, const int, const int, const int,
                                     const int, const int, const void *, const void *, const void *,
                                     void *, void *, bool);

// Associates a (bbox type, score type, rotated) triple with the launcher built for it.
struct nmsOutLaunchConfig {
  DataType t_bbox;
  DataType t_score;
  bool rotated;
  nmsOutFunc function;

  // Lookup key: carries only the triple, so the launcher pointer is null.
  nmsOutLaunchConfig(DataType t_bbox, DataType t_score, bool rotated)
      : t_bbox(t_bbox), t_score(t_score), rotated(rotated), function(nullptr) {}
  // Registry entry: the triple plus the launcher to call for it.
  nmsOutLaunchConfig(DataType t_bbox, DataType t_score, bool rotated, nmsOutFunc function)
      : t_bbox(t_bbox), t_score(t_score), rotated(rotated), function(function) {}
  // Two configs match when their triples match; the launcher pointer is ignored.
  bool operator==(const nmsOutLaunchConfig &other) const {
    return t_bbox == other.t_bbox && t_score == other.t_score && rotated == other.rotated;
  }
};
using nvinfer1::DataType;
// Launchers available to gatherNMSOutputs, one entry per supported configuration.
static std::vector<nmsOutLaunchConfig> nmsOutFuncVec;

// Registers the float/float launchers for axis-aligned and rotated boxes. Always
// returns true so the call can be used as the initializer of a static flag.
bool nmsOutputInit() {
  nmsOutLaunchConfig axis_aligned_entry(DataType::kFLOAT, DataType::kFLOAT, false,
                                        gatherNMSOutputs_gpu<float, float, false>);
  nmsOutLaunchConfig rotated_entry(DataType::kFLOAT, DataType::kFLOAT, true,
                                   gatherNMSOutputs_gpu<float, float, true>);
  nmsOutFuncVec.push_back(axis_aligned_entry);
  nmsOutFuncVec.push_back(rotated_entry);
  return true;
}

// Runs nmsOutputInit() during static initialization of this translation unit.
static bool initialized = nmsOutputInit();
/**
 * Type-dispatching entry point for gathering NMS outputs.
 *
 * Looks up the launcher registered for (DT_BBOX, DT_SCORE, rotated) and forwards the
 * remaining arguments to it.
 *
 * @return the launcher's status, or STATUS_BAD_PARAM when no launcher is registered for
 *         the requested configuration.
 */
pluginStatus_t gatherNMSOutputs(cudaStream_t stream, const bool shareLocation, const int numImages,
                                const int numPredsPerClass, const int numClasses, const int topK,
                                const int keepTopK, const DataType DT_BBOX, const DataType DT_SCORE,
                                const void *indices, const void *scores, const void *bboxData,
                                void *nmsedDets, void *nmsedLabels, bool clipBoxes, bool rotated) {
  // Key holding only the requested configuration; compared against each registry entry.
  nmsOutLaunchConfig wanted(DT_BBOX, DT_SCORE, rotated);
  for (unsigned i = 0; i < nmsOutFuncVec.size(); ++i) {
    if (!(wanted == nmsOutFuncVec[i])) {
      continue;
    }
    DEBUG_PRINTF("gatherNMSOutputs kernel %d\n", i);
    nmsOutFunc launcher = nmsOutFuncVec[i].function;
    return launcher(stream, shareLocation, numImages, numPredsPerClass, numClasses, topK,
                    keepTopK, indices, scores, bboxData, nmsedDets, nmsedLabels, clipBoxes);
  }
  return STATUS_BAD_PARAM;
}

View File

@ -6,7 +6,7 @@
#include <cub/cub.cuh>
#include "cublas_v2.h"
#include "kernel.h"
#include "nms/kernel.h"
#include "trt_plugin_helper.hpp"
#define CUDA_MEM_ALIGN 256

View File

@ -3,7 +3,7 @@
// https://github.com/NVIDIA/TensorRT/tree/master/plugin/batchedNMSPlugin
#include <vector>
#include "kernel.h"
#include "nms/kernel.h"
template <typename Dtype, unsigned nthds_per_cta>
__launch_bounds__(nthds_per_cta) __global__

View File

@ -4,8 +4,8 @@
#include <vector>
#include "cub/cub.cuh"
#include "cub_helper.h"
#include "kernel.h"
#include "nms/cub_helper.h"
#include "nms/kernel.h"
#include "trt_plugin_helper.hpp"
template <typename T_SCORE, unsigned nthds_per_cta>

View File

@ -4,8 +4,8 @@
#include <vector>
#include "cub/cub.cuh"
#include "cub_helper.h"
#include "kernel.h"
#include "nms/cub_helper.h"
#include "nms/kernel.h"
template <typename T_SCORE>
pluginStatus_t sortScoresPerImage_gpu(cudaStream_t stream, const int num_images,

View File

@ -1960,8 +1960,8 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](tut
<td align="center">mAP</td>
<td align="center">0.698</td>
<td align="center">0.698</td>
<td align="center">-</td>
<td align="center">-</td>
<td align="center">0.698</td>
<td align="center">0.697</td>
<td align="center">-</td>
<td align="center">-</td>
<td rowspan="2">$MMROTATE_DIR/configs/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le135.py</td>

View File

@ -10,7 +10,7 @@ Please refer to [official installation guide](https://mmrotate.readthedocs.io/en
| Model | Task | ONNX Runtime | TensorRT | NCNN | PPLNN | OpenVINO | Model config |
|:----------------------|:--------------|:------------:|:--------:|:----:|:-----:|:--------:|:-------------------------------------------------------------------------------------------:|
| RotatedRetinaNet | RotatedDetection | Y | N | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/rotated_retinanet/README.md) |
| RotatedRetinaNet | RotatedDetection | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/rotated_retinanet/README.md) |
### Example

View File

@ -45,6 +45,12 @@
- [Inputs](#inputs-6)
- [Outputs](#outputs-6)
- [Type Constraints](#type-constraints-6)
- [TRTBatchedRotatedNMS](#trtbatchedrotatednms)
- [Description](#description-7)
- [Parameters](#parameters-7)
- [Inputs](#inputs-7)
- [Outputs](#outputs-7)
- [Type Constraints](#type-constraints-7)
<!-- TOC -->
@ -316,3 +322,43 @@ None
#### Type Constraints
- T:tensor(float32, Linear), tensor(int32, Linear)
### TRTBatchedRotatedNMS
#### Description
Batched rotated NMS with a fixed number of output bounding boxes.
#### Parameters
| Type | Parameter | Description |
| ------- | --------------------- | --------------------------------------------------------------------------------------------------------------------------------------- |
| `int` | `background_label_id` | The label ID for the background class. If there is no background class, set it to `-1`. |
| `int` | `num_classes` | The number of classes. |
| `int` | `topK` | The number of bounding boxes to be fed into the NMS step. |
| `int` | `keepTopK` | The number of total bounding boxes to be kept per-image after the NMS step. Should be less than or equal to the `topK` value. |
| `float` | `scoreThreshold` | The scalar threshold for score (low scoring boxes are removed). |
| `float` | `iouThreshold` | The scalar threshold for IoU (new boxes that have high IoU overlap with previously selected boxes are removed). |
| `int` | `isNormalized` | Set to `false` if the box coordinates are not normalized, meaning they are not in the range `[0,1]`. Defaults to `true`. |
| `int` | `clipBoxes` | Forcibly restrict bounding boxes to the normalized range `[0,1]`. Only applicable if `isNormalized` is also `true`. Defaults to `true`. |
#### Inputs
<dl>
<dt><tt>inputs[0]</tt>: T</dt>
<dd>boxes; 4-D tensor of shape (N, num_boxes, num_classes, 5), where N is the batch size; `num_boxes` is the number of boxes; `num_classes` is the number of classes, which could be 1 if the boxes are shared between all classes.</dd>
<dt><tt>inputs[1]</tt>: T</dt>
<dd>scores; 4-D tensor of shape (N, num_boxes, 1, num_classes). </dd>
</dl>
#### Outputs
<dl>
<dt><tt>outputs[0]</tt>: T</dt>
<dd>dets; 3-D tensor of shape (N, valid_num_boxes, 6), `valid_num_boxes` is the number of boxes after NMS. For each row `dets[i,j,:] = [x0, y0, width, height, theta, score]`</dd>
<dt><tt>outputs[1]</tt>: tensor(int32, Linear)</dt>
<dd>labels; 2-D tensor of shape (N, valid_num_boxes). </dd>
</dl>
#### Type Constraints
- T:tensor(float32, Linear)

View File

@ -70,6 +70,7 @@ The table below lists the models that are guaranteed to be exportable to other b
| LiteHRNet | MMPose | N | Y | Y | N | N | Y | [config](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#litehrnet-cvpr-2021) |
| PointPillars | MMDetection3d | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/pointpillars) |
| CenterPoint (pillar) | MMDetection3d | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/centerpoint) |
| RotatedRetinaNet | RotatedDetection | N | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/rotated_retinanet/README.md) |
### Note

View File

@ -1563,8 +1563,8 @@ GPU: ncnn, TensorRT, PPLNN
<td align="center">mAP</td>
<td align="center">0.698</td>
<td align="center">0.698</td>
<td align="center">-</td>
<td align="center">-</td>
<td align="center">0.698</td>
<td align="center">0.697</td>
<td align="center">-</td>
<td align="center">-</td>
</tr>

View File

@ -68,6 +68,7 @@
| LiteHRNet | MMPose | N | Y | Y | N | N | Y | [config](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#litehrnet-cvpr-2021) |
| PointPillars | MMDetection3d | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/pointpillars) |
| CenterPoint (pillar) | MMDetection3d | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/centerpoint) |
| RotatedRetinaNet | RotatedDetection | N | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/rotated_retinanet/README.md) |
### Note

View File

@ -2,7 +2,9 @@
import torch
from torch import Tensor
from mmdeploy.mmcv.ops import ONNXNMSRotatedOp
import mmdeploy
from mmdeploy.core import FUNCTION_REWRITER, mark
from mmdeploy.mmcv.ops import ONNXNMSRotatedOp, TRTBatchedRotatedNMSop
def select_nms_index(scores: torch.Tensor,
@ -63,7 +65,7 @@ def select_nms_index(scores: torch.Tensor,
_, topk_inds = batched_dets[:, :, -1].sort(dim=1, descending=True)
topk_batch_inds = torch.arange(
batch_size, dtype=topk_inds.dtype,
device=topk_inds.device).view(-1, 1).expand_as(topk_inds)
device=topk_inds.device).unsqueeze(1)
batched_dets = batched_dets[topk_batch_inds, topk_inds, ...]
batched_labels = batched_labels[topk_batch_inds, topk_inds, ...]
@ -71,19 +73,18 @@ def select_nms_index(scores: torch.Tensor,
return batched_dets, batched_labels
def multiclass_nms_rotated(boxes: Tensor,
scores: Tensor,
iou_threshold: float = 0.1,
score_threshold: float = 0.05,
pre_top_k: int = -1,
keep_top_k: int = -1):
def _multiclass_nms_rotated(boxes: Tensor,
scores: Tensor,
iou_threshold: float = 0.1,
score_threshold: float = 0.05,
pre_top_k: int = -1,
keep_top_k: int = -1):
"""NMSRotated for multi-class bboxes.
This function helps exporting to onnx with batch and multiclass NMSRotated
op. It only supports class-agnostic detection results. That is, the scores
is of shape (N, num_bboxes, num_classes) and the boxes is of shape
(N, num_boxes, 5).
Args:
boxes (Tensor): The bounding boxes of shape [N, num_boxes, 5].
scores (Tensor): The detection scores of shape
@ -105,8 +106,7 @@ def multiclass_nms_rotated(boxes: Tensor,
if pre_top_k > 0:
max_scores, _ = scores.max(-1)
_, topk_inds = max_scores.topk(pre_top_k)
batch_inds = torch.arange(batch_size).view(
-1, 1).expand_as(topk_inds).long()
batch_inds = torch.arange(batch_size).unsqueeze(1).long()
boxes = boxes[batch_inds, topk_inds, :]
scores = scores[batch_inds, topk_inds, :]
@ -118,3 +118,58 @@ def multiclass_nms_rotated(boxes: Tensor,
scores, boxes, selected_indices, batch_size, keep_top_k=keep_top_k)
return dets, labels
@FUNCTION_REWRITER.register_rewriter(
    func_name='mmdeploy.codebase.mmrotate.core.post_processing.bbox_nms.'
    '_multiclass_nms_rotated',
    backend='tensorrt')
def multiclass_nms_rotated_static(ctx,
                                  boxes: Tensor,
                                  scores: Tensor,
                                  max_output_boxes_per_class: int = 1000,
                                  iou_threshold: float = 0.5,
                                  score_threshold: float = 0.05,
                                  pre_top_k: int = -1,
                                  keep_top_k: int = -1):
    """TensorRT rewrite of `_multiclass_nms_rotated`.

    Replaces the rotated NMS with a single `TRTBatchedRotatedNMSop` call.

    NOTE(review): `max_output_boxes_per_class` sits between `scores` and
    `iou_threshold`, a position the `_multiclass_nms_rotated` signature does
    not have, so thresholds are presumably expected to be passed by keyword.

    Args:
        ctx (ContextCaller): The context with additional information.
        boxes (Tensor): The bounding boxes of shape [N, num_boxes, 5]. A
            tensor that is not 4-D gets an extra axis inserted at dim 2.
        scores (Tensor): The detection scores of shape
            [N, num_boxes, num_classes].
        max_output_boxes_per_class (int): Upper bound applied to
            `keep_top_k`. Defaults to 1000.
        iou_threshold (float): IOU threshold of nms. Defaults to 0.5.
        score_threshold (float): score threshold of nms.
            Defaults to 0.05.
        pre_top_k (int): Number of top K boxes to keep before nms.
            Defaults to -1.
        keep_top_k (int): Number of top K boxes to keep after nms. A
            negative value means `max_output_boxes_per_class`.
            Defaults to -1.

    Returns:
        tuple[Tensor, Tensor]: (dets, labels), `dets` of shape [N, num_det, 6]
            and `labels` of shape [N, num_det].
    """
    if boxes.dim() != 4:
        boxes = boxes.unsqueeze(2)

    if keep_top_k < 0:
        keep_top_k = max_output_boxes_per_class
    else:
        keep_top_k = min(max_output_boxes_per_class, keep_top_k)

    num_classes = int(scores.shape[-1])
    background_label_id = -1
    dets, labels = TRTBatchedRotatedNMSop.apply(boxes, scores, num_classes,
                                                pre_top_k, keep_top_k,
                                                iou_threshold,
                                                score_threshold,
                                                background_label_id)
    return dets, labels
@mark(
    'multiclass_nms_rotated',
    inputs=['boxes', 'scores'],
    outputs=['dets', 'labels'])
def multiclass_nms_rotated(*args, **kwargs):
    """Marked wrapper that forwards to `_multiclass_nms_rotated`.

    The target is looked up through its fully qualified module path on every
    call rather than referenced directly.
    """
    bbox_nms = mmdeploy.codebase.mmrotate.core.post_processing.bbox_nms
    return bbox_nms._multiclass_nms_rotated(*args, **kwargs)

View File

@ -74,3 +74,76 @@ class ONNXNMSRotatedOp(torch.autograd.Function):
scores,
iou_threshold_f=float(iou_threshold),
score_threshold_f=float(score_threshold))
class TRTBatchedRotatedNMSop(torch.autograd.Function):
    """Create mmdeploy::TRTBatchedRotatedNMS op for TensorRT backend.

    NMS in ONNX supports dynamic outputs. This class helps replace
    onnx::NonMaxSuppression with mmdeploy::TRTBatchedRotatedNMS.
    """

    @staticmethod
    def forward(ctx,
                boxes: Tensor,
                scores: Tensor,
                num_classes: int,
                pre_topk: int,
                after_topk: int,
                iou_threshold: float,
                score_threshold: float,
                background_label_id: int = -1):
        """Forward of batched rotated nms.

        No NMS is computed here: the outputs are random placeholders that
        only have the expected shapes.

        Args:
            ctx (Context): The context with meta information.
            boxes (Tensor): The bounding boxes of shape [N, num_boxes, 5].
            scores (Tensor): The detection scores of shape
                [N, num_boxes, num_classes].
            num_classes (int): The number of classes in the network. The
                value actually used is taken from the last dimension of
                `scores`.
            pre_topk (int): The number of bounding boxes to be fed into
                the NMS step.
            after_topk (int): The number of total bounding boxes to be kept
                per-image after the NMS step. Should be less than or equal
                to the pre_topk value.
            iou_threshold (float): IOU threshold of nms.
            score_threshold (float): score threshold of nms.
            background_label_id (int): The label ID for the background class.
                If there is no background class, set it to -1.

        Returns:
            dets (Tensor): Bboxes and scores of the rotated nms results,
                of shape [N, min(num_boxes, after_topk), 6].
            labels (Tensor): Class id of the rotated nms results, of shape
                [N, min(num_boxes, after_topk)].
        """
        batch_size, num_boxes, num_classes = scores.shape
        out_boxes = min(num_boxes, after_topk)
        device = scores.device
        dets = torch.rand(batch_size, out_boxes, 6).to(device)
        labels = torch.randint(0, num_classes,
                               (batch_size, out_boxes)).to(device)
        return dets, labels

    @staticmethod
    def symbolic(g,
                 boxes: Tensor,
                 scores: Tensor,
                 num_classes: int,
                 pre_topk: int,
                 after_topk: int,
                 iou_threshold: float,
                 score_threshold: float,
                 background_label_id: int = -1):
        """Symbolic function for mmdeploy::TRTBatchedRotatedNMS."""
        attributes = dict(
            num_classes_i=num_classes,
            background_label_id_i=background_label_id,
            iou_threshold_f=iou_threshold,
            score_threshold_f=score_threshold,
            topk_i=pre_topk,
            keep_topk_i=after_topk,
            is_normalized_i=False,
            clip_boxes_i=False)
        return g.op(
            'mmdeploy::TRTBatchedRotatedNMS',
            boxes,
            scores,
            outputs=2,
            **attributes)

View File

@ -379,6 +379,83 @@ def test_batched_nms(backend,
save_dir=save_dir)
@pytest.mark.parametrize('backend', [TEST_TENSORRT])
@pytest.mark.parametrize('num_classes,pre_topk,after_topk,iou_threshold,'
                         'score_threshold,background_label_id',
                         [(5, 6, 3, 0.7, 0.1, -1)])
def test_batched_rotated_nms(backend,
                             num_classes,
                             pre_topk,
                             after_topk,
                             iou_threshold,
                             score_threshold,
                             background_label_id,
                             input_list=None,
                             save_dir=None):
    """Validate `TRTBatchedRotatedNMSop` on the backend against the reference
    `_multiclass_nms_rotated` result."""
    backend.check_env()
    pytest.importorskip('mmrotate', reason='mmrotate is not installed.')

    if input_list is None:
        # One image, ten boxes with five values each.
        nms_boxes = torch.tensor(
            [[[291.1746, 316.2263, 343.5029, 347.7312, 1.],
              [288.4846, 315.0447, 343.7267, 346.5630, 2.],
              [288.5307, 318.1989, 341.6425, 349.7222, 3.],
              [918.9102, 83.7463, 933.3920, 164.9041, 4.],
              [895.5786, 78.2361, 907.8049, 172.0883, 5.],
              [292.5816, 316.5563, 340.3462, 352.9989, 6.],
              [609.4592, 83.5447, 631.2532, 144.0749, 7.],
              [917.7308, 85.5870, 933.2839, 168.4530, 8.],
              [895.5138, 79.3596, 908.2865, 171.0418, 9.],
              [291.4747, 318.6987, 347.1208, 349.5754, 10.]]])
        # One score per box and class: shape [1, 10, 5].
        scores = torch.tensor([[[0.9577, 0.9745, 0.3030, 0.6589, 0.2742],
                                [0.1618, 0.7963, 0.5124, 0.6964, 0.6850],
                                [0.8425, 0.4843, 0.9489, 0.8068, 0.7340],
                                [0.7337, 0.4340, 0.9923, 0.0704, 0.4506],
                                [0.3090, 0.5606, 0.6939, 0.3764, 0.6920],
                                [0.0044, 0.7986, 0.2221, 0.2782, 0.4378],
                                [0.7293, 0.2735, 0.8381, 0.0264, 0.6278],
                                [0.7144, 0.1066, 0.4125, 0.4041, 0.8819],
                                [0.4963, 0.7891, 0.6908, 0.1499, 0.5584],
                                [0.4385, 0.6035, 0.0508, 0.0662, 0.5938]]])
    else:
        nms_boxes = torch.tensor(input_list[0], dtype=torch.float32)
        scores = torch.tensor(input_list[1], dtype=torch.float32)

    from mmdeploy.codebase.mmrotate.core.post_processing.bbox_nms import \
        _multiclass_nms_rotated

    # The reference is computed with one extra candidate before and after
    # NMS; the last returned detection is then dropped.
    expected_result = _multiclass_nms_rotated(
        nms_boxes,
        scores,
        iou_threshold=iou_threshold,
        score_threshold=score_threshold,
        pre_top_k=pre_topk + 1,
        keep_top_k=after_topk + 1)
    expected_result = (expected_result[0][:, 0:-1, :],
                       expected_result[1][:, 0:-1])

    # Repeat the boxes along a new class axis (dim 2).
    boxes = nms_boxes.unsqueeze(2).tile(num_classes, 1)

    from mmdeploy.mmcv.ops.nms_rotated import TRTBatchedRotatedNMSop
    batched_rotated_nms = TRTBatchedRotatedNMSop.apply

    def wrapped_function(boxes, scores):
        return batched_rotated_nms(boxes, scores, num_classes, pre_topk,
                                   after_topk, iou_threshold, score_threshold,
                                   background_label_id)

    wrapped_model = WrapFunction(wrapped_function)

    with RewriterContext(cfg={}, backend=backend.backend_name, opset=11):
        backend.run_and_validate(
            wrapped_model, [boxes, scores],
            'batched_rotated_nms',
            input_names=['boxes', 'scores'],
            output_names=['batched_rotated_nms_bboxes', 'inds'],
            expected_result=expected_result,
            save_dir=save_dir)
@pytest.mark.parametrize('backend', [TEST_TENSORRT])
@pytest.mark.parametrize(
'out_size, pool_mode, sampling_ratio,roi_scale_factor,'