[Feature] Add TensorRT batched NMS support (#3)

* add trt batched_nms plugin

* update trt batched nms
q.yao 2021-06-25 19:31:16 +08:00 committed by GitHub
parent 6c47ee3d2a
commit 5998d24766
22 changed files with 1725 additions and 12 deletions

.gitmodules vendored 100644

@ -0,0 +1,3 @@
[submodule "third_party/cub"]
path = third_party/cub
url = https://github.com/NVIDIA/cub.git


@ -1,2 +1,2 @@
[settings]
known_third_party = mmcv,mmdet,numpy,setuptools,tensorrt,torch
known_third_party = mmcv,mmdet,numpy,onnx,setuptools,tensorrt,torch


@ -41,10 +41,17 @@ if(NOT TENSORRT_FOUND)
endif()
INCLUDE_DIRECTORIES(${TENSORRT_INCLUDE_DIR})
# cub
if (NOT DEFINED CUB_ROOT_DIR)
set(CUB_ROOT_DIR "${PROJECT_SOURCE_DIR}/third_party/cub")
endif()
INCLUDE_DIRECTORIES(${CUB_ROOT_DIR})
# add plugin source
set(PLUGIN_LISTS scatternd
nms
roi_align)
roi_align
batched_nms)
foreach(PLUGIN_ITER ${PLUGIN_LISTS})
file(GLOB PLUGIN_OPS_SRCS ${PLUGIN_ITER}/*.cpp ${PLUGIN_ITER}/*.cu)


@ -0,0 +1,289 @@
// modify from
// https://github.com/NVIDIA/TensorRT/tree/master/plugin/batchedNMSPlugin
#include <vector>
#include "kernel.h"
template <typename T_BBOX>
__device__ T_BBOX bboxSize(const Bbox<T_BBOX> &bbox, const bool normalized,
T_BBOX offset) {
if (bbox.xmax < bbox.xmin || bbox.ymax < bbox.ymin) {
// If bbox is invalid (e.g. xmax < xmin or ymax < ymin), return 0.
return 0;
} else {
T_BBOX width = bbox.xmax - bbox.xmin;
T_BBOX height = bbox.ymax - bbox.ymin;
if (normalized) {
return width * height;
} else {
// If the bbox is not in the normalized [0, 1] range, add offset to width/height.
return (width + offset) * (height + offset);
}
}
}
template <typename T_BBOX>
__device__ void intersectBbox(const Bbox<T_BBOX> &bbox1,
const Bbox<T_BBOX> &bbox2,
Bbox<T_BBOX> *intersect_bbox) {
if (bbox2.xmin > bbox1.xmax || bbox2.xmax < bbox1.xmin ||
bbox2.ymin > bbox1.ymax || bbox2.ymax < bbox1.ymin) {
// Return [0, 0, 0, 0] if there is no intersection.
intersect_bbox->xmin = T_BBOX(0);
intersect_bbox->ymin = T_BBOX(0);
intersect_bbox->xmax = T_BBOX(0);
intersect_bbox->ymax = T_BBOX(0);
} else {
intersect_bbox->xmin = max(bbox1.xmin, bbox2.xmin);
intersect_bbox->ymin = max(bbox1.ymin, bbox2.ymin);
intersect_bbox->xmax = min(bbox1.xmax, bbox2.xmax);
intersect_bbox->ymax = min(bbox1.ymax, bbox2.ymax);
}
}
template <typename T_BBOX>
__device__ float jaccardOverlap(const Bbox<T_BBOX> &bbox1,
const Bbox<T_BBOX> &bbox2,
const bool normalized, T_BBOX offset) {
Bbox<T_BBOX> intersect_bbox;
intersectBbox(bbox1, bbox2, &intersect_bbox);
float intersect_width, intersect_height;
if (normalized) {
intersect_width = intersect_bbox.xmax - intersect_bbox.xmin;
intersect_height = intersect_bbox.ymax - intersect_bbox.ymin;
} else {
intersect_width = intersect_bbox.xmax - intersect_bbox.xmin + offset;
intersect_height = intersect_bbox.ymax - intersect_bbox.ymin + offset;
}
if (intersect_width > 0 && intersect_height > 0) {
float intersect_size = intersect_width * intersect_height;
float bbox1_size = bboxSize(bbox1, normalized, offset);
float bbox2_size = bboxSize(bbox2, normalized, offset);
return intersect_size / (bbox1_size + bbox2_size - intersect_size);
} else {
return 0.;
}
}
template <typename T_BBOX>
__device__ void emptyBboxInfo(BboxInfo<T_BBOX> *bbox_info) {
bbox_info->conf_score = T_BBOX(0);
bbox_info->label =
-2;  // -1 is used for all labels when shared_location is true
bbox_info->bbox_idx = -1;
bbox_info->kept = false;
}
/********** new NMS for only score and index array **********/
template <typename T_SCORE, typename T_BBOX, int TSIZE>
__global__ void allClassNMS_kernel(
const int num, const int num_classes, const int num_preds_per_class,
const int top_k, const float nms_threshold, const bool share_location,
const bool isNormalized,
T_BBOX *bbox_data, // bbox_data should be float to preserve location
// information
T_SCORE *beforeNMS_scores, int *beforeNMS_index_array,
T_SCORE *afterNMS_scores, int *afterNMS_index_array, bool flipXY = false) {
//__shared__ bool kept_bboxinfo_flag[CAFFE_CUDA_NUM_THREADS * TSIZE];
extern __shared__ bool kept_bboxinfo_flag[];
for (int i = 0; i < num; i++) {
const int offset = i * num_classes * num_preds_per_class +
blockIdx.x * num_preds_per_class;
const int max_idx =
offset + top_k; // put top_k bboxes into NMS calculation
const int bbox_idx_offset = share_location
? (i * num_preds_per_class)
: (i * num_classes * num_preds_per_class);
// local thread data
int loc_bboxIndex[TSIZE];
Bbox<T_BBOX> loc_bbox[TSIZE];
// initialize Bbox, Bboxinfo, kept_bboxinfo_flag
// Eliminate shared memory RAW hazard
__syncthreads();
#pragma unroll
for (int t = 0; t < TSIZE; t++) {
const int cur_idx = threadIdx.x + blockDim.x * t;
const int item_idx = offset + cur_idx;
if (item_idx < max_idx) {
loc_bboxIndex[t] = beforeNMS_index_array[item_idx];
if (loc_bboxIndex[t] >= 0)
// if (loc_bboxIndex[t] != -1)
{
const int bbox_data_idx =
share_location
? (loc_bboxIndex[t] % num_preds_per_class + bbox_idx_offset)
: loc_bboxIndex[t];
loc_bbox[t].xmin = flipXY ? bbox_data[bbox_data_idx * 4 + 1]
: bbox_data[bbox_data_idx * 4 + 0];
loc_bbox[t].ymin = flipXY ? bbox_data[bbox_data_idx * 4 + 0]
: bbox_data[bbox_data_idx * 4 + 1];
loc_bbox[t].xmax = flipXY ? bbox_data[bbox_data_idx * 4 + 3]
: bbox_data[bbox_data_idx * 4 + 2];
loc_bbox[t].ymax = flipXY ? bbox_data[bbox_data_idx * 4 + 2]
: bbox_data[bbox_data_idx * 4 + 3];
kept_bboxinfo_flag[cur_idx] = true;
} else {
kept_bboxinfo_flag[cur_idx] = false;
}
} else {
kept_bboxinfo_flag[cur_idx] = false;
}
}
// filter out overlapped boxes with lower scores
int ref_item_idx = offset;
int ref_bbox_idx =
share_location
? (beforeNMS_index_array[ref_item_idx] % num_preds_per_class +
bbox_idx_offset)
: beforeNMS_index_array[ref_item_idx];
while ((ref_bbox_idx != -1) && ref_item_idx < max_idx) {
Bbox<T_BBOX> ref_bbox;
ref_bbox.xmin = flipXY ? bbox_data[ref_bbox_idx * 4 + 1]
: bbox_data[ref_bbox_idx * 4 + 0];
ref_bbox.ymin = flipXY ? bbox_data[ref_bbox_idx * 4 + 0]
: bbox_data[ref_bbox_idx * 4 + 1];
ref_bbox.xmax = flipXY ? bbox_data[ref_bbox_idx * 4 + 3]
: bbox_data[ref_bbox_idx * 4 + 2];
ref_bbox.ymax = flipXY ? bbox_data[ref_bbox_idx * 4 + 2]
: bbox_data[ref_bbox_idx * 4 + 3];
// Eliminate shared memory RAW hazard
__syncthreads();
for (int t = 0; t < TSIZE; t++) {
const int cur_idx = threadIdx.x + blockDim.x * t;
const int item_idx = offset + cur_idx;
if ((kept_bboxinfo_flag[cur_idx]) && (item_idx > ref_item_idx)) {
// isNormalized is forwarded to jaccardOverlap; the offset is fixed at 0 here
if (jaccardOverlap(ref_bbox, loc_bbox[t], isNormalized, T_BBOX(0)) >
nms_threshold) {
kept_bboxinfo_flag[cur_idx] = false;
}
}
}
__syncthreads();
do {
ref_item_idx++;
} while (ref_item_idx < max_idx &&
!kept_bboxinfo_flag[ref_item_idx - offset]);
ref_bbox_idx =
share_location
? (beforeNMS_index_array[ref_item_idx] % num_preds_per_class +
bbox_idx_offset)
: beforeNMS_index_array[ref_item_idx];
}
// store data
for (int t = 0; t < TSIZE; t++) {
const int cur_idx = threadIdx.x + blockDim.x * t;
const int read_item_idx = offset + cur_idx;
const int write_item_idx =
(i * num_classes * top_k + blockIdx.x * top_k) + cur_idx;
/*
* If not keeping the bbox
* Set the score to 0
* Set the bounding box index to -1
*/
if (read_item_idx < max_idx) {
afterNMS_scores[write_item_idx] = kept_bboxinfo_flag[cur_idx]
? beforeNMS_scores[read_item_idx]
: 0.0f;
afterNMS_index_array[write_item_idx] =
kept_bboxinfo_flag[cur_idx] ? loc_bboxIndex[t] : -1;
}
}
}
}
template <typename T_SCORE, typename T_BBOX>
pluginStatus_t allClassNMS_gpu(
cudaStream_t stream, const int num, const int num_classes,
const int num_preds_per_class, const int top_k, const float nms_threshold,
const bool share_location, const bool isNormalized, void *bbox_data,
void *beforeNMS_scores, void *beforeNMS_index_array, void *afterNMS_scores,
void *afterNMS_index_array, bool flipXY = false) {
#define P(tsize) allClassNMS_kernel<T_SCORE, T_BBOX, (tsize)>
void (*kernel[10])(const int, const int, const int, const int, const float,
const bool, const bool, T_BBOX *, T_SCORE *, int *,
T_SCORE *, int *, bool) = {
P(1), P(2), P(3), P(4), P(5), P(6), P(7), P(8), P(9), P(10),
};
const int BS = 512;
const int GS = num_classes;
const int t_size = (top_k + BS - 1) / BS;
kernel[t_size - 1]<<<GS, BS, BS * t_size * sizeof(bool), stream>>>(
num, num_classes, num_preds_per_class, top_k, nms_threshold,
share_location, isNormalized, (T_BBOX *)bbox_data,
(T_SCORE *)beforeNMS_scores, (int *)beforeNMS_index_array,
(T_SCORE *)afterNMS_scores, (int *)afterNMS_index_array, flipXY);
CSC(cudaGetLastError(), STATUS_FAILURE);
return STATUS_SUCCESS;
}
// allClassNMS LAUNCH CONFIG
typedef pluginStatus_t (*nmsFunc)(cudaStream_t, const int, const int, const int,
const int, const float, const bool,
const bool, void *, void *, void *, void *,
void *, bool);
struct nmsLaunchConfigSSD {
DataType t_score;
DataType t_bbox;
nmsFunc function;
nmsLaunchConfigSSD(DataType t_score, DataType t_bbox)
: t_score(t_score), t_bbox(t_bbox) {}
nmsLaunchConfigSSD(DataType t_score, DataType t_bbox, nmsFunc function)
: t_score(t_score), t_bbox(t_bbox), function(function) {}
bool operator==(const nmsLaunchConfigSSD &other) {
return t_score == other.t_score && t_bbox == other.t_bbox;
}
};
static std::vector<nmsLaunchConfigSSD> nmsFuncVec;
bool nmsInit() {
nmsFuncVec.push_back(nmsLaunchConfigSSD(DataType::kFLOAT, DataType::kFLOAT,
allClassNMS_gpu<float, float>));
return true;
}
static bool initialized = nmsInit();
pluginStatus_t allClassNMS(cudaStream_t stream, const int num,
const int num_classes, const int num_preds_per_class,
const int top_k, const float nms_threshold,
const bool share_location, const bool isNormalized,
const DataType DT_SCORE, const DataType DT_BBOX,
void *bbox_data, void *beforeNMS_scores,
void *beforeNMS_index_array, void *afterNMS_scores,
void *afterNMS_index_array, bool flipXY) {
nmsLaunchConfigSSD lc =
nmsLaunchConfigSSD(DT_SCORE, DT_BBOX, allClassNMS_gpu<float, float>);
for (unsigned i = 0; i < nmsFuncVec.size(); ++i) {
if (lc == nmsFuncVec[i]) {
DEBUG_PRINTF("all class nms kernel %d\n", i);
return nmsFuncVec[i].function(
stream, num, num_classes, num_preds_per_class, top_k, nms_threshold,
share_location, isNormalized, bbox_data, beforeNMS_scores,
beforeNMS_index_array, afterNMS_scores, afterNMS_index_array, flipXY);
}
}
return STATUS_BAD_PARAM;
}
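For orientation, the launch configuration computed in allClassNMS_gpu above amounts to: one thread block of 512 threads per class, the kernel instance whose TSIZE template argument equals ceil(top_k / 512) (so top_k may not exceed 5120 with the ten instances in the table), and one bool keep-flag per candidate slot in shared memory. A small Python sketch of that arithmetic, for illustration only:

import math

def nms_launch_config(num_classes, top_k, block_size=512):
    t_size = math.ceil(top_k / block_size)  # selects allClassNMS_kernel<..., t_size>
    assert 1 <= t_size <= 10                # the kernel table has ten instances
    grid_size = num_classes                 # one block per class
    shared_mem_bytes = block_size * t_size  # one bool keep-flag per candidate slot
    return grid_size, block_size, shared_mem_bytes

print(nms_launch_config(num_classes=80, top_k=1000))  # (80, 512, 1024)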


@ -0,0 +1,128 @@
// modify from
// https://github.com/NVIDIA/TensorRT/tree/master/plugin/batchedNMSPlugin
#include "cuda_runtime_api.h"
#include "kernel.h"
pluginStatus_t nmsInference(
cudaStream_t stream, const int N, const int perBatchBoxesSize,
const int perBatchScoresSize, const bool shareLocation,
const int backgroundLabelId, const int numPredsPerClass,
const int numClasses, const int topK, const int keepTopK,
const float scoreThreshold, const float iouThreshold,
const DataType DT_BBOX, const void* locData, const DataType DT_SCORE,
const void* confData, void* nmsedDets, void* nmsedLabels, void* workspace,
bool isNormalized, bool confSigmoid, bool clipBoxes) {
const int topKVal = topK < 0 ? numPredsPerClass : topK;
const int keepTopKVal = keepTopK < 0 ? numPredsPerClass : keepTopK;
// locCount = batch_size * number_boxes_per_sample * 4
const int locCount = N * perBatchBoxesSize;
/*
* shareLocation
* Bounding boxes are shared among all classes, i.e., a bounding box can be
* classified as any candidate class. Otherwise, each class has its own set of
* bounding boxes, i.e., a bounding box can only be classified as one specific
* class or not (binary classification).
*/
const int numLocClasses = shareLocation ? 1 : numClasses;
size_t bboxDataSize =
detectionForwardBBoxDataSize(N, perBatchBoxesSize, DataType::kFLOAT);
void* bboxDataRaw = workspace;
cudaMemcpyAsync(bboxDataRaw, locData, bboxDataSize, cudaMemcpyDeviceToDevice,
stream);
pluginStatus_t status;
/*
* bboxDataRaw format:
* [batch size, numPriors (per sample), numLocClasses, 4]
*/
// float for now
void* bboxData;
size_t bboxPermuteSize = detectionForwardBBoxPermuteSize(
shareLocation, N, perBatchBoxesSize, DataType::kFLOAT);
void* bboxPermute = nextWorkspacePtr((int8_t*)bboxDataRaw, bboxDataSize);
/*
* After permutation, bboxData format:
* [batch_size, numLocClasses, numPriors (per sample) (numPredsPerClass), 4]
* This is equivalent to swapping axis
*/
if (!shareLocation) {
status = permuteData(stream, locCount, numLocClasses, numPredsPerClass, 4,
DataType::kFLOAT, false, bboxDataRaw, bboxPermute);
ASSERT_FAILURE(status == STATUS_SUCCESS);
bboxData = bboxPermute;
}
/*
* If shareLocation, numLocClasses = 1
* No need to permute data on linear memory
*/
else {
bboxData = bboxDataRaw;
}
/*
* Conf data format
* [batch size, numPriors * param.numClasses, 1, 1]
*/
const int numScores = N * perBatchScoresSize;
size_t totalScoresSize = detectionForwardPreNMSSize(N, perBatchScoresSize);
void* scores = nextWorkspacePtr((int8_t*)bboxPermute, bboxPermuteSize);
// need a conf_scores
/*
* After permutation, score data format:
* [batch_size, numClasses, numPredsPerClass, 1]
*/
status = permuteData(stream, numScores, numClasses, numPredsPerClass, 1,
DataType::kFLOAT, confSigmoid, confData, scores);
ASSERT_FAILURE(status == STATUS_SUCCESS);
size_t indicesSize = detectionForwardPreNMSSize(N, perBatchScoresSize);
void* indices = nextWorkspacePtr((int8_t*)scores, totalScoresSize);
size_t postNMSScoresSize =
detectionForwardPostNMSSize(N, numClasses, topKVal);
size_t postNMSIndicesSize =
detectionForwardPostNMSSize(N, numClasses, topKVal);
void* postNMSScores = nextWorkspacePtr((int8_t*)indices, indicesSize);
void* postNMSIndices =
nextWorkspacePtr((int8_t*)postNMSScores, postNMSScoresSize);
void* sortingWorkspace =
nextWorkspacePtr((int8_t*)postNMSIndices, postNMSIndicesSize);
// Sort the scores so that the following NMS could be applied.
status = sortScoresPerClass(
stream, N, numClasses, numPredsPerClass, backgroundLabelId,
scoreThreshold, DataType::kFLOAT, scores, indices, sortingWorkspace);
ASSERT_FAILURE(status == STATUS_SUCCESS);
// flipXY is false because the input bounding boxes are already in
// [xmin, ymin, xmax, ymax] order; set it to true if they were
// [ymin, xmin, ymax, xmax].
bool flipXY = false;
// NMS
status = allClassNMS(stream, N, numClasses, numPredsPerClass, topKVal,
iouThreshold, shareLocation, isNormalized,
DataType::kFLOAT, DataType::kFLOAT, bboxData, scores,
indices, postNMSScores, postNMSIndices, flipXY);
ASSERT_FAILURE(status == STATUS_SUCCESS);
// Sort the bounding boxes after NMS using scores
status = sortScoresPerImage(stream, N, numClasses * topKVal, DataType::kFLOAT,
postNMSScores, postNMSIndices, scores, indices,
sortingWorkspace);
ASSERT_FAILURE(status == STATUS_SUCCESS);
// Gather data from the sorted bounding boxes after NMS
status = gatherNMSOutputs(stream, shareLocation, N, numPredsPerClass,
numClasses, topKVal, keepTopKVal, DataType::kFLOAT,
DataType::kFLOAT, indices, scores, bboxData,
nmsedDets, nmsedLabels, clipBoxes);
ASSERT_FAILURE(status == STATUS_SUCCESS);
return STATUS_SUCCESS;
}
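In short, nmsInference permutes the class scores to [N, numClasses, numPredsPerClass], sorts them per class, runs per-class NMS over at most topKVal candidates, re-sorts the numClasses * topKVal survivors per image, and gathers the best keepTopKVal into the two outputs. A small Python sketch of the shapes involved, for orientation only:

# Shapes of the intermediate buffers as nmsInference moves from permuted
# scores to the final detections, for a batch of N images.
def nms_pipeline_shapes(N, num_classes, num_preds_per_class, top_k, keep_top_k):
    return {
        'permuted_scores':  (N, num_classes, num_preds_per_class),
        'per_class_sorted': (N, num_classes, num_preds_per_class),
        'post_nms_scores':  (N, num_classes, top_k),
        'per_image_sorted': (N, num_classes * top_k),
        'nmsed_dets':       (N, keep_top_k, 5),   # x1, y1, x2, y2, score
        'nmsed_labels':     (N, keep_top_k),
    }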


@ -0,0 +1,27 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel.h"
template <typename KeyT, typename ValueT>
size_t cubSortPairsWorkspaceSize(int num_items, int num_segments) {
size_t temp_storage_bytes = 0;
cub::DeviceSegmentedRadixSort::SortPairsDescending(
(void*)NULL, temp_storage_bytes, (const KeyT*)NULL, (KeyT*)NULL,
(const ValueT*)NULL, (ValueT*)NULL,
num_items, // # items
num_segments, // # segments
(const int*)NULL, (const int*)NULL);
return temp_storage_bytes;
}


@ -0,0 +1,133 @@
// modify from
// https://github.com/NVIDIA/TensorRT/tree/master/plugin/batchedNMSPlugin
#include <vector>
#include "kernel.h"
#include "trt_plugin_helper.hpp"
template <typename T_BBOX, typename T_SCORE, unsigned nthds_per_cta>
__launch_bounds__(nthds_per_cta) __global__
void gatherNMSOutputs_kernel(const bool shareLocation, const int numImages,
const int numPredsPerClass,
const int numClasses, const int topK,
const int keepTopK, const int *indices,
const T_SCORE *scores, const T_BBOX *bboxData,
T_BBOX *nmsedDets, int *nmsedLabels,
bool clipBoxes) {
if (keepTopK > topK) return;
for (int i = blockIdx.x * nthds_per_cta + threadIdx.x;
i < numImages * keepTopK; i += gridDim.x * nthds_per_cta) {
const int imgId = i / keepTopK;
const int detId = i % keepTopK;
const int offset = imgId * numClasses * topK;
const int index = indices[offset + detId];
const T_SCORE score = scores[offset + detId];
if (index == -1) {
nmsedLabels[i] = -1;
nmsedDets[i * 5] = 0;
nmsedDets[i * 5 + 1] = 0;
nmsedDets[i * 5 + 2] = 0;
nmsedDets[i * 5 + 3] = 0;
nmsedDets[i * 5 + 4] = 0;
} else {
const int bboxOffset =
imgId *
(shareLocation ? numPredsPerClass : (numClasses * numPredsPerClass));
const int bboxId =
((shareLocation ? (index % numPredsPerClass)
: index % (numClasses * numPredsPerClass)) +
bboxOffset) *
4;
nmsedLabels[i] = (index % (numClasses * numPredsPerClass)) /
numPredsPerClass; // label
// clipped bbox xmin
nmsedDets[i * 5] =
clipBoxes ? max(min(bboxData[bboxId], T_BBOX(1.)), T_BBOX(0.))
: bboxData[bboxId];
// clipped bbox ymin
nmsedDets[i * 5 + 1] =
clipBoxes ? max(min(bboxData[bboxId + 1], T_BBOX(1.)), T_BBOX(0.))
: bboxData[bboxId + 1];
// clipped bbox xmax
nmsedDets[i * 5 + 2] =
clipBoxes ? max(min(bboxData[bboxId + 2], T_BBOX(1.)), T_BBOX(0.))
: bboxData[bboxId + 2];
// clipped bbox ymax
nmsedDets[i * 5 + 3] =
clipBoxes ? max(min(bboxData[bboxId + 3], T_BBOX(1.)), T_BBOX(0.))
: bboxData[bboxId + 3];
nmsedDets[i * 5 + 4] = score;
}
}
}
template <typename T_BBOX, typename T_SCORE>
pluginStatus_t gatherNMSOutputs_gpu(
cudaStream_t stream, const bool shareLocation, const int numImages,
const int numPredsPerClass, const int numClasses, const int topK,
const int keepTopK, const void *indices, const void *scores,
const void *bboxData, void *nmsedDets, void *nmsedLabels, bool clipBoxes) {
const int BS = 32;
const int GS = 32;
gatherNMSOutputs_kernel<T_BBOX, T_SCORE, BS><<<GS, BS, 0, stream>>>(
shareLocation, numImages, numPredsPerClass, numClasses, topK, keepTopK,
(int *)indices, (T_SCORE *)scores, (T_BBOX *)bboxData,
(T_BBOX *)nmsedDets, (int *)nmsedLabels, clipBoxes);
CSC(cudaGetLastError(), STATUS_FAILURE);
return STATUS_SUCCESS;
}
// gatherNMSOutputs LAUNCH CONFIG {{{
typedef pluginStatus_t (*nmsOutFunc)(cudaStream_t, const bool, const int,
const int, const int, const int, const int,
const void *, const void *, const void *,
void *, void *, bool);
struct nmsOutLaunchConfig {
DataType t_bbox;
DataType t_score;
nmsOutFunc function;
nmsOutLaunchConfig(DataType t_bbox, DataType t_score)
: t_bbox(t_bbox), t_score(t_score) {}
nmsOutLaunchConfig(DataType t_bbox, DataType t_score, nmsOutFunc function)
: t_bbox(t_bbox), t_score(t_score), function(function) {}
bool operator==(const nmsOutLaunchConfig &other) {
return t_bbox == other.t_bbox && t_score == other.t_score;
}
};
using nvinfer1::DataType;
static std::vector<nmsOutLaunchConfig> nmsOutFuncVec;
bool nmsOutputInit() {
nmsOutFuncVec.push_back(nmsOutLaunchConfig(
DataType::kFLOAT, DataType::kFLOAT, gatherNMSOutputs_gpu<float, float>));
return true;
}
static bool initialized = nmsOutputInit();
//}}}
pluginStatus_t gatherNMSOutputs(cudaStream_t stream, const bool shareLocation,
const int numImages, const int numPredsPerClass,
const int numClasses, const int topK,
const int keepTopK, const DataType DT_BBOX,
const DataType DT_SCORE, const void *indices,
const void *scores, const void *bboxData,
void *nmsedDets, void *nmsedLabels,
bool clipBoxes) {
nmsOutLaunchConfig lc = nmsOutLaunchConfig(DT_BBOX, DT_SCORE);
for (unsigned i = 0; i < nmsOutFuncVec.size(); ++i) {
if (lc == nmsOutFuncVec[i]) {
DEBUG_PRINTF("gatherNMSOutputs kernel %d\n", i);
return nmsOutFuncVec[i].function(stream, shareLocation, numImages,
numPredsPerClass, numClasses, topK,
keepTopK, indices, scores, bboxData,
nmsedDets, nmsedLabels, clipBoxes);
}
}
return STATUS_BAD_PARAM;
}


@ -0,0 +1,104 @@
// modify from
// https://github.com/NVIDIA/TensorRT/tree/master/plugin/batchedNMSPlugin
#include <stdint.h>
#include <cub/cub.cuh>
#include "cublas_v2.h"
#include "kernel.h"
#include "trt_plugin_helper.hpp"
#define CUDA_MEM_ALIGN 256
// ALIGNPTR
int8_t *alignPtr(int8_t *ptr, uintptr_t to) {
uintptr_t addr = (uintptr_t)ptr;
if (addr % to) {
addr += to - addr % to;
}
return (int8_t *)addr;
}
// NEXTWORKSPACEPTR
int8_t *nextWorkspacePtr(int8_t *ptr, uintptr_t previousWorkspaceSize) {
uintptr_t addr = (uintptr_t)ptr;
addr += previousWorkspaceSize;
return alignPtr((int8_t *)addr, CUDA_MEM_ALIGN);
}
// CALCULATE TOTAL WORKSPACE SIZE
size_t calculateTotalWorkspaceSize(size_t *workspaces, int count) {
size_t total = 0;
for (int i = 0; i < count; i++) {
total += workspaces[i];
if (workspaces[i] % CUDA_MEM_ALIGN) {
total += CUDA_MEM_ALIGN - (workspaces[i] % CUDA_MEM_ALIGN);
}
}
return total;
}
using nvinfer1::DataType;
template <unsigned nthds_per_cta>
__launch_bounds__(nthds_per_cta) __global__
void setUniformOffsets_kernel(const int num_segments, const int offset,
int *d_offsets) {
const int idx = blockIdx.x * nthds_per_cta + threadIdx.x;
if (idx <= num_segments) d_offsets[idx] = idx * offset;
}
void setUniformOffsets(cudaStream_t stream, const int num_segments,
const int offset, int *d_offsets) {
const int BS = 32;
const int GS = (num_segments + 1 + BS - 1) / BS;
setUniformOffsets_kernel<BS>
<<<GS, BS, 0, stream>>>(num_segments, offset, d_offsets);
}
size_t detectionForwardBBoxDataSize(int N, int C1, DataType DT_BBOX) {
if (DT_BBOX == DataType::kFLOAT) {
return N * C1 * sizeof(float);
}
printf("Only FP32 type bounding boxes are supported.\n");
return (size_t)-1;
}
size_t detectionForwardBBoxPermuteSize(bool shareLocation, int N, int C1,
DataType DT_BBOX) {
if (DT_BBOX == DataType::kFLOAT) {
return shareLocation ? 0 : N * C1 * sizeof(float);
}
printf("Only FP32 type bounding boxes are supported.\n");
return (size_t)-1;
}
size_t detectionForwardPreNMSSize(int N, int C2) {
ASSERT(sizeof(float) == sizeof(int));
return N * C2 * sizeof(float);
}
size_t detectionForwardPostNMSSize(int N, int numClasses, int topK) {
ASSERT(sizeof(float) == sizeof(int));
return N * numClasses * topK * sizeof(float);
}
size_t detectionInferenceWorkspaceSize(bool shareLocation, int N, int C1,
int C2, int numClasses,
int numPredsPerClass, int topK,
DataType DT_BBOX, DataType DT_SCORE) {
size_t wss[7];
wss[0] = detectionForwardBBoxDataSize(N, C1, DT_BBOX);
wss[1] = detectionForwardBBoxPermuteSize(shareLocation, N, C1, DT_BBOX);
wss[2] = detectionForwardPreNMSSize(N, C2);
wss[3] = detectionForwardPreNMSSize(N, C2);
wss[4] = detectionForwardPostNMSSize(N, numClasses, topK);
wss[5] = detectionForwardPostNMSSize(N, numClasses, topK);
wss[6] =
std::max(sortScoresPerClassWorkspaceSize(N, numClasses, numPredsPerClass,
DT_SCORE),
sortScoresPerImageWorkspaceSize(N, numClasses * topK, DT_SCORE));
return calculateTotalWorkspaceSize(wss, 7);
}
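detectionInferenceWorkspaceSize adds up the seven buffers that nmsInference carves out of the single workspace pointer, padding each one to the 256-byte CUDA_MEM_ALIGN boundary. A rough Python sketch of that bookkeeping follows (N is the batch size, C1 the per-batch boxes size, C2 the per-batch scores size); the last entry stands in for the cub radix-sort scratch, which the plugin obtains from a cub query (cubSortPairsWorkspaceSize) rather than a closed-form expression:

CUDA_MEM_ALIGN = 256
FLOAT = INT = 4  # the CUDA code asserts sizeof(float) == sizeof(int)

def align(nbytes):
    return (nbytes + CUDA_MEM_ALIGN - 1) // CUDA_MEM_ALIGN * CUDA_MEM_ALIGN

def workspace_size(share_location, N, C1, C2, num_classes, top_k, cub_bytes=0):
    wss = [
        N * C1 * FLOAT,                           # bbox copy (bboxDataRaw)
        0 if share_location else N * C1 * FLOAT,  # permuted bboxes
        N * C2 * FLOAT,                           # pre-NMS scores
        N * C2 * INT,                             # pre-NMS indices
        N * num_classes * top_k * FLOAT,          # post-NMS scores
        N * num_classes * top_k * INT,            # post-NMS indices
        cub_bytes,                                # segmented-sort scratch (query-dependent)
    ]
    return sum(align(w) for w in wss)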


@ -0,0 +1,112 @@
// modify from
// https://github.com/NVIDIA/TensorRT/tree/master/plugin/batchedNMSPlugin
#ifndef TRT_KERNEL_H
#define TRT_KERNEL_H
#include <cuda_runtime.h>
#include <cassert>
#include <cstdio>
#include "cublas_v2.h"
#include "trt_plugin_helper.hpp"
using namespace nvinfer1;
using namespace nvinfer1::plugin;
#define DEBUG_ENABLE 0
template <typename T>
struct Bbox {
T xmin, ymin, xmax, ymax;
Bbox(T xmin, T ymin, T xmax, T ymax)
: xmin(xmin), ymin(ymin), xmax(xmax), ymax(ymax) {}
Bbox() = default;
};
template <typename T>
struct BboxInfo {
T conf_score;
int label;
int bbox_idx;
bool kept;
BboxInfo(T conf_score, int label, int bbox_idx, bool kept)
: conf_score(conf_score), label(label), bbox_idx(bbox_idx), kept(kept) {}
BboxInfo() = default;
};
int8_t* alignPtr(int8_t* ptr, uintptr_t to);
int8_t* nextWorkspacePtr(int8_t* ptr, uintptr_t previousWorkspaceSize);
void setUniformOffsets(cudaStream_t stream, int num_segments, int offset,
int* d_offsets);
pluginStatus_t allClassNMS(cudaStream_t stream, int num, int num_classes,
int num_preds_per_class, int top_k,
float nms_threshold, bool share_location,
bool isNormalized, DataType DT_SCORE,
DataType DT_BBOX, void* bbox_data,
void* beforeNMS_scores, void* beforeNMS_index_array,
void* afterNMS_scores, void* afterNMS_index_array,
bool flipXY = false);
size_t detectionForwardBBoxDataSize(int N, int C1, DataType DT_BBOX);
size_t detectionForwardBBoxPermuteSize(bool shareLocation, int N, int C1,
DataType DT_BBOX);
size_t sortScoresPerClassWorkspaceSize(int num, int num_classes,
int num_preds_per_class,
DataType DT_CONF);
size_t sortScoresPerImageWorkspaceSize(int num_images, int num_items_per_image,
DataType DT_SCORE);
pluginStatus_t sortScoresPerImage(cudaStream_t stream, int num_images,
int num_items_per_image, DataType DT_SCORE,
void* unsorted_scores,
void* unsorted_bbox_indices,
void* sorted_scores,
void* sorted_bbox_indices, void* workspace);
pluginStatus_t sortScoresPerClass(cudaStream_t stream, int num, int num_classes,
int num_preds_per_class,
int background_label_id,
float confidence_threshold, DataType DT_SCORE,
void* conf_scores_gpu, void* index_array_gpu,
void* workspace);
size_t calculateTotalWorkspaceSize(size_t* workspaces, int count);
pluginStatus_t permuteData(cudaStream_t stream, int nthreads, int num_classes,
int num_data, int num_dim, DataType DT_DATA,
bool confSigmoid, const void* data, void* new_data);
size_t detectionForwardPreNMSSize(int N, int C2);
size_t detectionForwardPostNMSSize(int N, int numClasses, int topK);
pluginStatus_t gatherNMSOutputs(cudaStream_t stream, bool shareLocation,
int numImages, int numPredsPerClass,
int numClasses, int topK, int keepTopK,
DataType DT_BBOX, DataType DT_SCORE,
const void* indices, const void* scores,
const void* bboxData, void* nmsedDets,
void* nmsedLabels, bool clipBoxes = true);
size_t detectionInferenceWorkspaceSize(bool shareLocation, int N, int C1,
int C2, int numClasses,
int numPredsPerClass, int topK,
DataType DT_BBOX, DataType DT_SCORE);
pluginStatus_t nmsInference(cudaStream_t stream, int N, int boxesSize,
int scoresSize, bool shareLocation,
int backgroundLabelId, int numPredsPerClass,
int numClasses, int topK, int keepTopK,
float scoreThreshold, float iouThreshold,
DataType DT_BBOX, const void* locData,
DataType DT_SCORE, const void* confData,
void* nmsedDets, void* nmsedLabels, void* workspace,
bool isNormalized = true, bool confSigmoid = false,
bool clipBoxes = true);
#endif


@ -0,0 +1,81 @@
// modify from
// https://github.com/NVIDIA/TensorRT/tree/master/plugin/batchedNMSPlugin
#include <vector>
#include "kernel.h"
template <typename Dtype, unsigned nthds_per_cta>
__launch_bounds__(nthds_per_cta) __global__
void permuteData_kernel(const int nthreads, const int num_classes,
const int num_data, const int num_dim,
bool confSigmoid, const Dtype *data,
Dtype *new_data) {
// data format: [batch_size, num_data, num_classes, num_dim]
for (int index = blockIdx.x * nthds_per_cta + threadIdx.x; index < nthreads;
index += nthds_per_cta * gridDim.x) {
const int i = index % num_dim;
const int c = (index / num_dim) % num_classes;
const int d = (index / num_dim / num_classes) % num_data;
const int n = index / num_dim / num_classes / num_data;
const int new_index = ((n * num_classes + c) * num_data + d) * num_dim + i;
float result = data[index];
if (confSigmoid) result = exp(result) / (1 + exp(result));
new_data[new_index] = result;
}
// new data format: [batch_size, num_classes, num_data, num_dim]
}
template <typename Dtype>
pluginStatus_t permuteData_gpu(cudaStream_t stream, const int nthreads,
const int num_classes, const int num_data,
const int num_dim, bool confSigmoid,
const void *data, void *new_data) {
const int BS = 512;
const int GS = (nthreads + BS - 1) / BS;
permuteData_kernel<Dtype, BS><<<GS, BS, 0, stream>>>(
nthreads, num_classes, num_data, num_dim, confSigmoid,
(const Dtype *)data, (Dtype *)new_data);
CSC(cudaGetLastError(), STATUS_FAILURE);
return STATUS_SUCCESS;
}
// permuteData LAUNCH CONFIG
typedef pluginStatus_t (*pdFunc)(cudaStream_t, const int, const int, const int,
const int, bool, const void *, void *);
struct pdLaunchConfig {
DataType t_data;
pdFunc function;
pdLaunchConfig(DataType t_data) : t_data(t_data) {}
pdLaunchConfig(DataType t_data, pdFunc function)
: t_data(t_data), function(function) {}
bool operator==(const pdLaunchConfig &other) {
return t_data == other.t_data;
}
};
static std::vector<pdLaunchConfig> pdFuncVec;
bool permuteDataInit() {
pdFuncVec.push_back(pdLaunchConfig(DataType::kFLOAT, permuteData_gpu<float>));
return true;
}
static bool initialized = permuteDataInit();
pluginStatus_t permuteData(cudaStream_t stream, const int nthreads,
const int num_classes, const int num_data,
const int num_dim, const DataType DT_DATA,
bool confSigmoid, const void *data, void *new_data) {
pdLaunchConfig lc = pdLaunchConfig(DT_DATA);
for (unsigned i = 0; i < pdFuncVec.size(); ++i) {
if (lc == pdFuncVec[i]) {
DEBUG_PRINTF("permuteData kernel %d\n", i);
return pdFuncVec[i].function(stream, nthreads, num_classes, num_data,
num_dim, confSigmoid, data, new_data);
}
}
return STATUS_BAD_PARAM;
}


@ -0,0 +1,157 @@
// modify from
// https://github.com/NVIDIA/TensorRT/tree/master/plugin/batchedNMSPlugin
#include <vector>
#include "cub/cub.cuh"
#include "cub_helper.h"
#include "kernel.h"
#include "trt_plugin_helper.hpp"
template <typename T_SCORE, unsigned nthds_per_cta>
__launch_bounds__(nthds_per_cta) __global__
void prepareSortData(const int num, const int num_classes,
const int num_preds_per_class,
const int background_label_id,
const float confidence_threshold,
T_SCORE *conf_scores_gpu, T_SCORE *temp_scores,
int *temp_idx, int *d_offsets) {
// Prepare scores data for sort
const int cur_idx = blockIdx.x * nthds_per_cta + threadIdx.x;
const int numPredsPerBatch = num_classes * num_preds_per_class;
if (cur_idx < numPredsPerBatch) {
const int class_idx = cur_idx / num_preds_per_class;
for (int i = 0; i < num; i++) {
const int targetIdx = i * numPredsPerBatch + cur_idx;
const T_SCORE score = conf_scores_gpu[targetIdx];
// "Clear" background labeled score and index
// Because we do not care about background
if (class_idx == background_label_id) {
// Set scores to 0
// Set label = -1
temp_scores[targetIdx] = 0.0f;
temp_idx[targetIdx] = -1;
conf_scores_gpu[targetIdx] = 0.0f;
}
// "Clear" scores lower than threshold
else {
if (score > confidence_threshold) {
temp_scores[targetIdx] = score;
temp_idx[targetIdx] = cur_idx + i * numPredsPerBatch;
} else {
// Set scores to 0
// Set label = -1
temp_scores[targetIdx] = 0.0f;
temp_idx[targetIdx] = -1;
conf_scores_gpu[targetIdx] = 0.0f;
// TODO: this clears the same memory locations more times than necessary
}
}
if ((cur_idx % num_preds_per_class) == 0) {
const int offset_ct = i * num_classes + cur_idx / num_preds_per_class;
d_offsets[offset_ct] = offset_ct * num_preds_per_class;
// set the last element in d_offset
if (blockIdx.x == 0 && threadIdx.x == 0)
d_offsets[num * num_classes] = num * numPredsPerBatch;
}
}
}
}
template <typename T_SCORE>
pluginStatus_t sortScoresPerClass_gpu(cudaStream_t stream, const int num,
const int num_classes,
const int num_preds_per_class,
const int background_label_id,
const float confidence_threshold,
void *conf_scores_gpu,
void *index_array_gpu, void *workspace) {
const int num_segments = num * num_classes;
void *temp_scores = workspace;
const int arrayLen = num * num_classes * num_preds_per_class;
void *temp_idx =
nextWorkspacePtr((int8_t *)temp_scores, arrayLen * sizeof(T_SCORE));
void *d_offsets =
nextWorkspacePtr((int8_t *)temp_idx, arrayLen * sizeof(int));
size_t cubOffsetSize = (num_segments + 1) * sizeof(int);
void *cubWorkspace = nextWorkspacePtr((int8_t *)d_offsets, cubOffsetSize);
const int BS = 512;
const int GS = (num_classes * num_preds_per_class + BS - 1) / BS;
prepareSortData<T_SCORE, BS><<<GS, BS, 0, stream>>>(
num, num_classes, num_preds_per_class, background_label_id,
confidence_threshold, (T_SCORE *)conf_scores_gpu, (T_SCORE *)temp_scores,
(int *)temp_idx, (int *)d_offsets);
size_t temp_storage_bytes =
cubSortPairsWorkspaceSize<T_SCORE, int>(arrayLen, num_segments);
cub::DeviceSegmentedRadixSort::SortPairsDescending(
cubWorkspace, temp_storage_bytes, (const T_SCORE *)(temp_scores),
(T_SCORE *)(conf_scores_gpu), (const int *)(temp_idx),
(int *)(index_array_gpu), arrayLen, num_segments, (const int *)d_offsets,
(const int *)d_offsets + 1, 0, sizeof(T_SCORE) * 8, stream);
CSC(cudaGetLastError(), STATUS_FAILURE);
return STATUS_SUCCESS;
}
// sortScoresPerClass LAUNCH CONFIG
typedef pluginStatus_t (*sspcFunc)(cudaStream_t, const int, const int,
const int, const int, const float, void *,
void *, void *);
struct sspcLaunchConfig {
DataType t_score;
sspcFunc function;
sspcLaunchConfig(DataType t_score) : t_score(t_score) {}
sspcLaunchConfig(DataType t_score, sspcFunc function)
: t_score(t_score), function(function) {}
bool operator==(const sspcLaunchConfig &other) {
return t_score == other.t_score;
}
};
static std::vector<sspcLaunchConfig> sspcFuncVec;
bool sspcInit() {
sspcFuncVec.push_back(
sspcLaunchConfig(DataType::kFLOAT, sortScoresPerClass_gpu<float>));
return true;
}
static bool initialized = sspcInit();
pluginStatus_t sortScoresPerClass(
cudaStream_t stream, const int num, const int num_classes,
const int num_preds_per_class, const int background_label_id,
const float confidence_threshold, const DataType DT_SCORE,
void *conf_scores_gpu, void *index_array_gpu, void *workspace) {
sspcLaunchConfig lc = sspcLaunchConfig(DT_SCORE);
for (unsigned i = 0; i < sspcFuncVec.size(); ++i) {
if (lc == sspcFuncVec[i]) {
DEBUG_PRINTF("sortScoresPerClass kernel %d\n", i);
return sspcFuncVec[i].function(
stream, num, num_classes, num_preds_per_class, background_label_id,
confidence_threshold, conf_scores_gpu, index_array_gpu, workspace);
}
}
return STATUS_BAD_PARAM;
}
size_t sortScoresPerClassWorkspaceSize(const int num, const int num_classes,
const int num_preds_per_class,
const DataType DT_CONF) {
size_t wss[4];
const int arrayLen = num * num_classes * num_preds_per_class;
wss[0] = arrayLen * mmlab::getElementSize(DT_CONF); // temp scores
wss[1] = arrayLen * sizeof(int); // temp indices
wss[2] = (num * num_classes + 1) * sizeof(int); // offsets
if (DT_CONF == DataType::kFLOAT) {
wss[3] = cubSortPairsWorkspaceSize<float, int>(
arrayLen, num * num_classes); // cub workspace
} else {
printf("SCORE type not supported\n");
return (size_t)-1;
}
return calculateTotalWorkspaceSize(wss, 4);
}
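The per-class sort treats the flattened score array as num * num_classes segments of num_preds_per_class entries each, and prepareSortData fills d_offsets with one boundary per segment plus a closing offset. A short Python sketch of the offsets it produces, for illustration:

# One segment per (image, class) pair, each of length num_preds_per_class,
# plus a closing offset at the very end.
def per_class_sort_offsets(num, num_classes, num_preds_per_class):
    num_segments = num * num_classes
    return [seg * num_preds_per_class for seg in range(num_segments + 1)]

# e.g. 2 images, 3 classes, 4 predictions per class
print(per_class_sort_offsets(2, 3, 4))
# [0, 4, 8, 12, 16, 20, 24]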


@ -0,0 +1,89 @@
// modify from
// https://github.com/NVIDIA/TensorRT/tree/master/plugin/batchedNMSPlugin
#include <vector>
#include "cub/cub.cuh"
#include "cub_helper.h"
#include "kernel.h"
template <typename T_SCORE>
pluginStatus_t sortScoresPerImage_gpu(
cudaStream_t stream, const int num_images, const int num_items_per_image,
void *unsorted_scores, void *unsorted_bbox_indices, void *sorted_scores,
void *sorted_bbox_indices, void *workspace) {
void *d_offsets = workspace;
void *cubWorkspace =
nextWorkspacePtr((int8_t *)d_offsets, (num_images + 1) * sizeof(int));
setUniformOffsets(stream, num_images, num_items_per_image, (int *)d_offsets);
const int arrayLen = num_images * num_items_per_image;
size_t temp_storage_bytes =
cubSortPairsWorkspaceSize<T_SCORE, int>(arrayLen, num_images);
cub::DeviceSegmentedRadixSort::SortPairsDescending(
cubWorkspace, temp_storage_bytes, (const T_SCORE *)(unsorted_scores),
(T_SCORE *)(sorted_scores), (const int *)(unsorted_bbox_indices),
(int *)(sorted_bbox_indices), arrayLen, num_images,
(const int *)d_offsets, (const int *)d_offsets + 1, 0,
sizeof(T_SCORE) * 8, stream);
CSC(cudaGetLastError(), STATUS_FAILURE);
return STATUS_SUCCESS;
}
// sortScoresPerImage LAUNCH CONFIG
typedef pluginStatus_t (*sspiFunc)(cudaStream_t, const int, const int, void *,
void *, void *, void *, void *);
struct sspiLaunchConfig {
DataType t_score;
sspiFunc function;
sspiLaunchConfig(DataType t_score) : t_score(t_score) {}
sspiLaunchConfig(DataType t_score, sspiFunc function)
: t_score(t_score), function(function) {}
bool operator==(const sspiLaunchConfig &other) {
return t_score == other.t_score;
}
};
static std::vector<sspiLaunchConfig> sspiFuncVec;
bool sspiInit() {
sspiFuncVec.push_back(
sspiLaunchConfig(DataType::kFLOAT, sortScoresPerImage_gpu<float>));
return true;
}
static bool initialized = sspiInit();
pluginStatus_t sortScoresPerImage(
cudaStream_t stream, const int num_images, const int num_items_per_image,
const DataType DT_SCORE, void *unsorted_scores, void *unsorted_bbox_indices,
void *sorted_scores, void *sorted_bbox_indices, void *workspace) {
sspiLaunchConfig lc = sspiLaunchConfig(DT_SCORE);
for (unsigned i = 0; i < sspiFuncVec.size(); ++i) {
if (lc == sspiFuncVec[i]) {
DEBUG_PRINTF("sortScoresPerImage kernel %d\n", i);
return sspiFuncVec[i].function(
stream, num_images, num_items_per_image, unsorted_scores,
unsorted_bbox_indices, sorted_scores, sorted_bbox_indices, workspace);
}
}
return STATUS_BAD_PARAM;
}
size_t sortScoresPerImageWorkspaceSize(const int num_images,
const int num_items_per_image,
const DataType DT_SCORE) {
const int arrayLen = num_images * num_items_per_image;
size_t wss[2];
wss[0] = (num_images + 1) * sizeof(int); // offsets
if (DT_SCORE == DataType::kFLOAT) {
wss[1] =
cubSortPairsWorkspaceSize<float, int>(arrayLen,
num_images); // cub workspace
} else {
printf("SCORE type not supported.\n");
return (size_t)-1;
}
return calculateTotalWorkspaceSize(wss, 2);
}


@ -0,0 +1,270 @@
// modify from
// https://github.com/NVIDIA/TensorRT/tree/master/plugin/batchedNMSPlugin
#include "trt_batched_nms.hpp"
#include <cstring>
#include "kernel.h"
#include "trt_serialize.hpp"
using namespace nvinfer1;
using nvinfer1::plugin::NMSParameters;
namespace {
static const char* NMS_PLUGIN_VERSION{"1"};
static const char* NMS_PLUGIN_NAME{"TRTBatchedNMS"};
} // namespace
PluginFieldCollection TRTBatchedNMSPluginDynamicCreator::mFC{};
std::vector<PluginField> TRTBatchedNMSPluginDynamicCreator::mPluginAttributes;
TRTBatchedNMSPluginDynamic::TRTBatchedNMSPluginDynamic(NMSParameters params)
: param(params) {}
TRTBatchedNMSPluginDynamic::TRTBatchedNMSPluginDynamic(const void* data,
size_t length) {
deserialize_value(&data, &length, &param);
deserialize_value(&data, &length, &boxesSize);
deserialize_value(&data, &length, &scoresSize);
deserialize_value(&data, &length, &numPriors);
deserialize_value(&data, &length, &mClipBoxes);
}
int TRTBatchedNMSPluginDynamic::getNbOutputs() const { return 2; }
int TRTBatchedNMSPluginDynamic::initialize() { return STATUS_SUCCESS; }
void TRTBatchedNMSPluginDynamic::terminate() {}
nvinfer1::DimsExprs TRTBatchedNMSPluginDynamic::getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
nvinfer1::IExprBuilder& exprBuilder) {
ASSERT(nbInputs == 2);
ASSERT(outputIndex >= 0 && outputIndex < this->getNbOutputs());
ASSERT(inputs[0].nbDims == 4);
ASSERT(inputs[1].nbDims == 3);
nvinfer1::DimsExprs ret;
ret.d[0] = inputs[0].d[0];
ret.d[1] = exprBuilder.constant(param.keepTopK);
switch (outputIndex) {
case 0:
ret.nbDims = 3;
ret.d[2] = exprBuilder.constant(5);
break;
case 1:
ret.nbDims = 2;
break;
default:
break;
}
return ret;
}
size_t TRTBatchedNMSPluginDynamic::getWorkspaceSize(
const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const {
size_t batch_size = inputs[0].dims.d[0];
size_t boxes_size =
inputs[0].dims.d[1] * inputs[0].dims.d[2] * inputs[0].dims.d[3];
size_t score_size = inputs[1].dims.d[1] * inputs[1].dims.d[2];
size_t num_priors = inputs[0].dims.d[1];
bool shareLocation = (inputs[0].dims.d[2] == 1);
int topk = param.topK > 0 ? param.topK : inputs[1].dims.d[1];
return detectionInferenceWorkspaceSize(
shareLocation, batch_size, boxes_size, score_size, param.numClasses,
num_priors, topk, DataType::kFLOAT, DataType::kFLOAT);
}
int TRTBatchedNMSPluginDynamic::enqueue(
const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs,
void* const* outputs, void* workSpace, cudaStream_t stream) {
const void* const locData = inputs[0];
const void* const confData = inputs[1];
void* nmsedDets = outputs[0];
void* nmsedLabels = outputs[1];
size_t batch_size = inputDesc[0].dims.d[0];
size_t boxes_size =
inputDesc[0].dims.d[1] * inputDesc[0].dims.d[2] * inputDesc[0].dims.d[3];
size_t score_size = inputDesc[1].dims.d[1] * inputDesc[1].dims.d[2];
size_t num_priors = inputDesc[0].dims.d[1];
bool shareLocation = (inputDesc[0].dims.d[2] == 1);
pluginStatus_t status = nmsInference(
stream, batch_size, boxes_size, score_size, shareLocation,
param.backgroundLabelId, num_priors, param.numClasses, param.topK,
param.keepTopK, param.scoreThreshold, param.iouThreshold,
DataType::kFLOAT, locData, DataType::kFLOAT, confData, nmsedDets,
nmsedLabels, workSpace, param.isNormalized, false, mClipBoxes);
ASSERT(status == STATUS_SUCCESS);
return 0;
}
size_t TRTBatchedNMSPluginDynamic::getSerializationSize() const {
// NMSParameters, boxesSize, scoresSize, numPriors, mClipBoxes
return sizeof(NMSParameters) + sizeof(int) * 3 + sizeof(bool);
}
void TRTBatchedNMSPluginDynamic::serialize(void* buffer) const {
serialize_value(&buffer, param);
serialize_value(&buffer, boxesSize);
serialize_value(&buffer, scoresSize);
serialize_value(&buffer, numPriors);
serialize_value(&buffer, mClipBoxes);
}
void TRTBatchedNMSPluginDynamic::configurePlugin(
const nvinfer1::DynamicPluginTensorDesc* inputs, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* outputs, int nbOutputs) {
// Validate input arguments
}
bool TRTBatchedNMSPluginDynamic::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs,
int nbOutputs) {
if (pos == 3) {
return inOut[pos].type == nvinfer1::DataType::kINT32 &&
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
}
return inOut[pos].type == nvinfer1::DataType::kFLOAT &&
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
}
const char* TRTBatchedNMSPluginDynamic::getPluginType() const {
return NMS_PLUGIN_NAME;
}
const char* TRTBatchedNMSPluginDynamic::getPluginVersion() const {
return NMS_PLUGIN_VERSION;
}
void TRTBatchedNMSPluginDynamic::destroy() { delete this; }
IPluginV2DynamicExt* TRTBatchedNMSPluginDynamic::clone() const {
auto* plugin = new TRTBatchedNMSPluginDynamic(param);
plugin->boxesSize = boxesSize;
plugin->scoresSize = scoresSize;
plugin->numPriors = numPriors;
plugin->setPluginNamespace(mNamespace.c_str());
plugin->setClipParam(mClipBoxes);
return plugin;
}
void TRTBatchedNMSPluginDynamic::setPluginNamespace(
const char* pluginNamespace) {
mNamespace = pluginNamespace;
}
const char* TRTBatchedNMSPluginDynamic::getPluginNamespace() const {
return mNamespace.c_str();
}
nvinfer1::DataType TRTBatchedNMSPluginDynamic::getOutputDataType(
int index, const nvinfer1::DataType* inputTypes, int nbInputs) const {
ASSERT(index >= 0 && index < this->getNbOutputs());
if (index == 1) {
return nvinfer1::DataType::kINT32;
}
return inputTypes[0];
}
void TRTBatchedNMSPluginDynamic::setClipParam(bool clip) { mClipBoxes = clip; }
TRTBatchedNMSPluginDynamicCreator::TRTBatchedNMSPluginDynamicCreator()
: params{} {
mPluginAttributes.emplace_back(
PluginField("background_label_id", nullptr, PluginFieldType::kINT32, 1));
mPluginAttributes.emplace_back(
PluginField("num_classes", nullptr, PluginFieldType::kINT32, 1));
mPluginAttributes.emplace_back(
PluginField("topk", nullptr, PluginFieldType::kINT32, 1));
mPluginAttributes.emplace_back(
PluginField("keep_topk", nullptr, PluginFieldType::kINT32, 1));
mPluginAttributes.emplace_back(
PluginField("score_threshold", nullptr, PluginFieldType::kFLOAT32, 1));
mPluginAttributes.emplace_back(
PluginField("iou_threshold", nullptr, PluginFieldType::kFLOAT32, 1));
mPluginAttributes.emplace_back(
PluginField("is_normalized", nullptr, PluginFieldType::kINT32, 1));
mPluginAttributes.emplace_back(
PluginField("clip_boxes", nullptr, PluginFieldType::kINT32, 1));
mFC.nbFields = mPluginAttributes.size();
mFC.fields = mPluginAttributes.data();
}
const char* TRTBatchedNMSPluginDynamicCreator::getPluginName() const {
return NMS_PLUGIN_NAME;
}
const char* TRTBatchedNMSPluginDynamicCreator::getPluginVersion() const {
return NMS_PLUGIN_VERSION;
}
const PluginFieldCollection*
TRTBatchedNMSPluginDynamicCreator::getFieldNames() {
return &mFC;
}
IPluginV2Ext* TRTBatchedNMSPluginDynamicCreator::createPlugin(
const char* name, const PluginFieldCollection* fc) {
const PluginField* fields = fc->fields;
bool clipBoxes = true;
for (int i = 0; i < fc->nbFields; ++i) {
const char* attrName = fields[i].name;
if (!strcmp(attrName, "background_label_id")) {
ASSERT(fields[i].type == PluginFieldType::kINT32);
params.backgroundLabelId = *(static_cast<const int*>(fields[i].data));
} else if (!strcmp(attrName, "num_classes")) {
ASSERT(fields[i].type == PluginFieldType::kINT32);
params.numClasses = *(static_cast<const int*>(fields[i].data));
} else if (!strcmp(attrName, "topk")) {
ASSERT(fields[i].type == PluginFieldType::kINT32);
params.topK = *(static_cast<const int*>(fields[i].data));
} else if (!strcmp(attrName, "keep_topk")) {
ASSERT(fields[i].type == PluginFieldType::kINT32);
params.keepTopK = *(static_cast<const int*>(fields[i].data));
} else if (!strcmp(attrName, "score_threshold")) {
ASSERT(fields[i].type == PluginFieldType::kFLOAT32);
params.scoreThreshold = *(static_cast<const float*>(fields[i].data));
} else if (!strcmp(attrName, "iou_threshold")) {
ASSERT(fields[i].type == PluginFieldType::kFLOAT32);
params.iouThreshold = *(static_cast<const float*>(fields[i].data));
} else if (!strcmp(attrName, "is_normalized")) {
params.isNormalized = *(static_cast<const bool*>(fields[i].data));
} else if (!strcmp(attrName, "clip_boxes")) {
clipBoxes = *(static_cast<const bool*>(fields[i].data));
}
}
TRTBatchedNMSPluginDynamic* plugin = new TRTBatchedNMSPluginDynamic(params);
plugin->setClipParam(clipBoxes);
plugin->setPluginNamespace(mNamespace.c_str());
return plugin;
}
IPluginV2Ext* TRTBatchedNMSPluginDynamicCreator::deserializePlugin(
const char* name, const void* serialData, size_t serialLength) {
// This object will be deleted when the network is destroyed, which will
// call TRTBatchedNMSPluginDynamic::destroy()
TRTBatchedNMSPluginDynamic* plugin =
new TRTBatchedNMSPluginDynamic(serialData, serialLength);
plugin->setPluginNamespace(mNamespace.c_str());
return plugin;
}
void TRTBatchedNMSPluginDynamicCreator::setPluginNamespace(
const char* libNamespace) {
mNamespace = libNamespace;
}
const char* TRTBatchedNMSPluginDynamicCreator::getPluginNamespace() const {
return mNamespace.c_str();
}


@ -0,0 +1,118 @@
// modify from
// https://github.com/NVIDIA/TensorRT/tree/master/plugin/batchedNMSPlugin
#ifndef TRT_BATCHED_NMS_PLUGIN_CUSTOM_H
#define TRT_BATCHED_NMS_PLUGIN_CUSTOM_H
#include <string>
#include <vector>
#include "trt_plugin_helper.hpp"
class TRTBatchedNMSPluginDynamic : public nvinfer1::IPluginV2DynamicExt {
public:
TRTBatchedNMSPluginDynamic(nvinfer1::plugin::NMSParameters param);
TRTBatchedNMSPluginDynamic(const void* data, size_t length);
~TRTBatchedNMSPluginDynamic() override = default;
int getNbOutputs() const override;
nvinfer1::DimsExprs getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
nvinfer1::IExprBuilder& exprBuilder) override;
int initialize() override;
void terminate() override;
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nbInputs,
const nvinfer1::PluginTensorDesc* outputs,
int nbOutputs) const override;
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs, void* const* outputs, void* workSpace,
cudaStream_t stream) override;
size_t getSerializationSize() const override;
void serialize(void* buffer) const override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* outputs,
int nbOutputs) override;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc* inOut,
int nbInputs, int nbOutputs) override;
const char* getPluginType() const override;
const char* getPluginVersion() const override;
void destroy() override;
nvinfer1::IPluginV2DynamicExt* clone() const override;
nvinfer1::DataType getOutputDataType(int index,
const nvinfer1::DataType* inputType,
int nbInputs) const override;
void setPluginNamespace(const char* libNamespace) override;
const char* getPluginNamespace() const override;
void setClipParam(bool clip);
private:
nvinfer1::plugin::NMSParameters param{};
int boxesSize{};
int scoresSize{};
int numPriors{};
std::string mNamespace;
bool mClipBoxes{};
protected:
// To prevent compiler warnings.
using nvinfer1::IPluginV2DynamicExt::canBroadcastInputAcrossBatch;
using nvinfer1::IPluginV2DynamicExt::configurePlugin;
using nvinfer1::IPluginV2DynamicExt::enqueue;
using nvinfer1::IPluginV2DynamicExt::getOutputDimensions;
using nvinfer1::IPluginV2DynamicExt::getWorkspaceSize;
using nvinfer1::IPluginV2DynamicExt::isOutputBroadcastAcrossBatch;
using nvinfer1::IPluginV2DynamicExt::supportsFormat;
};
class TRTBatchedNMSPluginDynamicCreator : public nvinfer1::IPluginCreator {
public:
TRTBatchedNMSPluginDynamicCreator();
~TRTBatchedNMSPluginDynamicCreator() override = default;
const char* getPluginName() const override;
const char* getPluginVersion() const override;
const nvinfer1::PluginFieldCollection* getFieldNames() override;
nvinfer1::IPluginV2Ext* createPlugin(
const char* name, const nvinfer1::PluginFieldCollection* fc) override;
nvinfer1::IPluginV2Ext* deserializePlugin(const char* name,
const void* serialData,
size_t serialLength) override;
void setPluginNamespace(const char* libNamespace) override;
const char* getPluginNamespace() const override;
private:
static nvinfer1::PluginFieldCollection mFC;
nvinfer1::plugin::NMSParameters params;
static std::vector<nvinfer1::PluginField> mPluginAttributes;
std::string mNamespace;
};
#endif // TRT_BATCHED_NMS_PLUGIN_CUSTOM_H


@ -4,6 +4,118 @@
#include "NvInferPlugin.h"
// Enumerator for status
typedef enum {
STATUS_SUCCESS = 0,
STATUS_FAILURE = 1,
STATUS_BAD_PARAM = 2,
STATUS_NOT_SUPPORTED = 3,
STATUS_NOT_INITIALIZED = 4
} pluginStatus_t;
#define ASSERT(assertion) \
{ \
if (!(assertion)) { \
std::cerr << "Assertion failed: " #assertion " " << __FILE__ << "," << __LINE__ << std::endl; \
abort(); \
} \
}
#define CUASSERT(status_) \
{ \
auto s_ = status_; \
if (s_ != cudaSuccess) { \
std::cerr << __FILE__ << ", " << __LINE__ << ", " << s_ << ", " \
<< cudaGetErrorString(s_) << std::endl; \
} \
}
#define CUBLASASSERT(status_) \
{ \
auto s_ = status_; \
if (s_ != CUBLAS_STATUS_SUCCESS) { \
std::cerr << __FILE__ << ", " << __LINE__ << ", " << s_ << std::endl; \
} \
}
#define CUERRORMSG(status_) \
{ \
auto s_ = status_; \
if (s_ != 0) \
std::cerr << __FILE__ << ", " << __LINE__ << ", " << s_ << std::endl; \
}
#ifndef DEBUG
#define CHECK(status) \
do { \
if (status != 0) abort(); \
} while (0)
#define ASSERT_PARAM(exp) \
do { \
if (!(exp)) return STATUS_BAD_PARAM; \
} while (0)
#define ASSERT_FAILURE(exp) \
do { \
if (!(exp)) return STATUS_FAILURE; \
} while (0)
#define CSC(call, err) \
do { \
cudaError_t cudaStatus = call; \
if (cudaStatus != cudaSuccess) { \
return err; \
} \
} while (0)
#define DEBUG_PRINTF(...) \
do { \
} while (0)
#else
#define ASSERT_PARAM(exp) \
do { \
if (!(exp)) { \
fprintf(stderr, "Bad param - " #exp ", %s:%d\n", __FILE__, __LINE__); \
return STATUS_BAD_PARAM; \
} \
} while (0)
#define ASSERT_FAILURE(exp) \
do { \
if (!(exp)) { \
fprintf(stderr, "Failure - " #exp ", %s:%d\n", __FILE__, __LINE__); \
return STATUS_FAILURE; \
} \
} while (0)
#define CSC(call, err) \
do { \
cudaError_t cudaStatus = call; \
if (cudaStatus != cudaSuccess) { \
printf("%s %d CUDA FAIL %s\n", __FILE__, __LINE__, \
cudaGetErrorString(cudaStatus)); \
return err; \
} \
} while (0)
#define CHECK(status) \
{ \
if (status != 0) { \
DEBUG_PRINTF("%s %d CUDA FAIL %s\n", __FILE__, __LINE__, \
cudaGetErrorString(status)); \
abort(); \
} \
}
#define DEBUG_PRINTF(...) \
do { \
printf(__VA_ARGS__); \
} while (0)
#endif
namespace mmlab {
const int MAXTENSORDIMS = 10;


@ -1,7 +1,9 @@
#include "batched_nms/trt_batched_nms.hpp"
#include "nms/trt_nms.hpp"
#include "roi_align/trt_roi_align.hpp"
#include "scatternd/trt_scatternd.hpp"
REGISTER_TENSORRT_PLUGIN(TRTBatchedNMSPluginDynamicCreator);
REGISTER_TENSORRT_PLUGIN(NonMaxSuppressionDynamicCreator);
REGISTER_TENSORRT_PLUGIN(ONNXScatterNDDynamicCreator);
REGISTER_TENSORRT_PLUGIN(RoIAlignPluginDynamicCreator);
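Because REGISTER_TENSORRT_PLUGIN registers these creators when the shared library is loaded, it is enough to load the compiled ops library before building or deserializing an engine. A minimal Python sketch follows; the library name is a placeholder, and in practice the path comes from the get_tensorrt_op_path() helper shown later in this patch:

import ctypes

import tensorrt as trt

ctypes.CDLL('libmmdeploy_tensorrt_ops.so')  # placeholder name for the built plugin library

logger = trt.Logger(trt.Logger.WARNING)
trt.init_libnvinfer_plugins(logger, '')     # also registers TensorRT's built-in plugins
# The plugin registry should now contain a creator named "TRTBatchedNMS" (version "1").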


@ -1,7 +1,7 @@
import ctypes
import glob
import os
import logging
import os
def get_tensorrt_op_path():


@ -1,15 +1,30 @@
import logging
import os.path as osp
from typing import Optional, Union
import tensorrt as trt
import mmcv
import onnx
import tensorrt as trt
import torch.multiprocessing as mp
from .tensorrt_utils import onnx2trt, save_trt_engine
def get_trt_loglevel():
logger = logging.getLogger()
level = logger.level
if level == logging.INFO:
return trt.Logger.INFO
elif level == logging.ERROR or level == logging.CRITICAL:
return trt.Logger.ERROR
elif level == logging.WARNING:
return trt.Logger.WARNING
else:
print('Unsupported logging level {}; using trt.Logger.INFO instead.'.format(level))
return trt.Logger.INFO
def onnx2tensorrt(work_dir: str,
save_file: str,
deploy_cfg: Union[str, mmcv.Config],
@ -39,7 +54,7 @@ def onnx2tensorrt(work_dir: str,
engine = onnx2trt(
onnx_model,
opt_shape_dict=tensorrt_param['opt_shape_dict'],
log_level=tensorrt_param.get('log_level', trt.Logger.WARNING),
log_level=tensorrt_param.get('log_level', get_trt_loglevel()),
fp16_mode=tensorrt_param.get('fp16_mode', False),
max_workspace_size=tensorrt_param.get('max_workspace_size', 0),
device_id=device_id)


@ -74,3 +74,41 @@ def nms_tensorrt(symbolic_wrapper, g, boxes, scores,
score_threshold_f=score_threshold,
center_point_box_i=0,
offset_i=0)
class TRTBatchedNMSop(torch.autograd.Function):
@staticmethod
def forward(ctx,
boxes,
scores,
num_classes,
pre_topk,
after_topk,
iou_threshold,
score_threshold,
background_label_id=-1):
batch_size, num_boxes, num_classes = scores.shape
out_boxes = min(num_boxes, after_topk)
return torch.rand(batch_size, out_boxes,
5).to(scores.device), torch.randint(
0, num_classes,
(batch_size, out_boxes)).to(scores.device)
@staticmethod
def symbolic(g, boxes, scores, num_classes, pre_topk, after_topk,
iou_threshold, score_threshold, background_label_id):
return g.op(
'mmcv::TRTBatchedNMS',
boxes,
scores,
num_classes_i=num_classes,
background_label_id_i=background_label_id,
iou_threshold_f=iou_threshold,
score_threshold_f=score_threshold,
topk_i=pre_topk,
keep_topk_i=after_topk,
is_normalized_i=False,
clip_boxes_i=False,
outputs=2)


@ -1,6 +1,7 @@
import torch
from mmdeploy.mmcv.ops import DummyONNXNMSop
from mmdeploy.mmcv.ops import DummyONNXNMSop, TRTBatchedNMSop
from mmdeploy.utils import FUNCTION_REWRITERS
def dynamic_clip_for_onnx(x1, y1, x2, y2, max_shape):
@ -153,3 +154,25 @@ def add_dummy_nms_for_onnx(boxes,
scores = scores.unsqueeze(2)
dets = torch.cat([boxes, scores], dim=2)
return dets, labels
@FUNCTION_REWRITERS.register_rewriter(
func_name='mmdeploy.mmdet.core.export.add_dummy_nms_for_onnx',
backend='tensorrt')
def add_dummy_nms_for_onnx_tensorrt(rewriter,
boxes,
scores,
max_output_boxes_per_class=1000,
iou_threshold=0.5,
score_threshold=0.05,
pre_top_k=-1,
after_top_k=-1,
labels=None):
boxes = boxes if boxes.dim() == 4 else boxes.unsqueeze(2)
after_top_k = max_output_boxes_per_class if after_top_k < 0 else min(
max_output_boxes_per_class, after_top_k)
dets, labels = TRTBatchedNMSop.apply(boxes, scores, int(scores.shape[-1]),
pre_top_k, after_top_k, iou_threshold,
score_threshold, -1)
return dets, labels
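With shared-location boxes of shape [batch, num_boxes, 1, 4] and scores of shape [batch, num_boxes, num_classes], the dummy forward above returns detections of shape [batch, after_topk, 5] and integer labels of shape [batch, after_topk], mirroring the two outputs of the TensorRT plugin. A quick shape check, for illustration (assumes an environment where mmdeploy is importable):

import torch

from mmdeploy.mmcv.ops import TRTBatchedNMSop

boxes = torch.rand(1, 200, 1, 4)   # [batch, num_boxes, 1, 4], shared location
scores = torch.rand(1, 200, 80)    # [batch, num_boxes, num_classes]

dets, labels = TRTBatchedNMSop.apply(
    boxes, scores,
    80,     # num_classes
    1000,   # pre_topk
    100,    # after_topk
    0.5,    # iou_threshold
    0.05,   # score_threshold
    -1)     # background_label_id
assert dets.shape == (1, 100, 5) and labels.shape == (1, 100)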


@ -1,7 +1,7 @@
import torch
import torch.nn as nn
from mmdeploy.mmdet.core.export import add_dummy_nms_for_onnx
import mmdeploy
from mmdeploy.utils import MODULE_REWRITERS
@ -105,7 +105,11 @@ class AnchorHead(nn.Module):
iou_threshold = cfg.nms.get('iou_threshold', 0.5)
score_threshold = cfg.score_thr
nms_pre = cfg.get('deploy_nms_pre', -1)
return add_dummy_nms_for_onnx(batch_mlvl_bboxes, batch_mlvl_scores,
max_output_boxes_per_class,
iou_threshold, score_threshold, nms_pre,
cfg.max_per_img)
return mmdeploy.mmdet.core.export.add_dummy_nms_for_onnx(
batch_mlvl_bboxes,
batch_mlvl_scores,
max_output_boxes_per_class,
iou_threshold=iou_threshold,
score_threshold=score_threshold,
pre_top_k=nms_pre,
after_top_k=cfg.max_per_img)

third_party/cub vendored 160000

@ -0,0 +1 @@
Subproject commit c3cceac115c072fb63df1836ff46d8c60d9eb304