[Feature] Add TensorRT batched NMS support (#3)
* add trt batched_nms plugin
* update trt batched nms

parent 6c47ee3d2a
commit 5998d24766
@@ -0,0 +1,3 @@
[submodule "third_party/cub"]
	path = third_party/cub
	url = https://github.com/NVIDIA/cub.git
@@ -1,2 +1,2 @@
 [settings]
-known_third_party = mmcv,mmdet,numpy,setuptools,tensorrt,torch
+known_third_party = mmcv,mmdet,numpy,onnx,setuptools,tensorrt,torch
@@ -41,10 +41,17 @@ if(NOT TENSORRT_FOUND)
 endif()
 INCLUDE_DIRECTORIES(${TENSORRT_INCLUDE_DIR})

+# cub
+if (NOT DEFINED CUB_ROOT_DIR)
+  set(CUB_ROOT_DIR "${PROJECT_SOURCE_DIR}/third_party/cub")
+endif()
+INCLUDE_DIRECTORIES(${CUB_ROOT_DIR})
+
 # add plugin source
 set(PLUGIN_LISTS scatternd
                  nms
-                 roi_align)
+                 roi_align
+                 batched_nms)

 foreach(PLUGIN_ITER ${PLUGIN_LISTS})
   file(GLOB PLUGIN_OPS_SRCS ${PLUGIN_ITER}/*.cpp ${PLUGIN_ITER}/*.cu)
@@ -0,0 +1,289 @@
// modify from
// https://github.com/NVIDIA/TensorRT/tree/master/plugin/batchedNMSPlugin
#include <vector>

#include "kernel.h"

template <typename T_BBOX>
__device__ T_BBOX bboxSize(const Bbox<T_BBOX> &bbox, const bool normalized,
                           T_BBOX offset) {
  if (bbox.xmax < bbox.xmin || bbox.ymax < bbox.ymin) {
    // If bbox is invalid (e.g. xmax < xmin or ymax < ymin), return 0.
    return 0;
  } else {
    T_BBOX width = bbox.xmax - bbox.xmin;
    T_BBOX height = bbox.ymax - bbox.ymin;
    if (normalized) {
      return width * height;
    } else {
      // If bbox is not within range [0, 1].
      return (width + offset) * (height + offset);
    }
  }
}

template <typename T_BBOX>
__device__ void intersectBbox(const Bbox<T_BBOX> &bbox1,
                              const Bbox<T_BBOX> &bbox2,
                              Bbox<T_BBOX> *intersect_bbox) {
  if (bbox2.xmin > bbox1.xmax || bbox2.xmax < bbox1.xmin ||
      bbox2.ymin > bbox1.ymax || bbox2.ymax < bbox1.ymin) {
    // Return [0, 0, 0, 0] if there is no intersection.
    intersect_bbox->xmin = T_BBOX(0);
    intersect_bbox->ymin = T_BBOX(0);
    intersect_bbox->xmax = T_BBOX(0);
    intersect_bbox->ymax = T_BBOX(0);
  } else {
    intersect_bbox->xmin = max(bbox1.xmin, bbox2.xmin);
    intersect_bbox->ymin = max(bbox1.ymin, bbox2.ymin);
    intersect_bbox->xmax = min(bbox1.xmax, bbox2.xmax);
    intersect_bbox->ymax = min(bbox1.ymax, bbox2.ymax);
  }
}

template <typename T_BBOX>
__device__ float jaccardOverlap(const Bbox<T_BBOX> &bbox1,
                                const Bbox<T_BBOX> &bbox2,
                                const bool normalized, T_BBOX offset) {
  Bbox<T_BBOX> intersect_bbox;
  intersectBbox(bbox1, bbox2, &intersect_bbox);
  float intersect_width, intersect_height;
  if (normalized) {
    intersect_width = intersect_bbox.xmax - intersect_bbox.xmin;
    intersect_height = intersect_bbox.ymax - intersect_bbox.ymin;
  } else {
    intersect_width = intersect_bbox.xmax - intersect_bbox.xmin + offset;
    intersect_height = intersect_bbox.ymax - intersect_bbox.ymin + offset;
  }
  if (intersect_width > 0 && intersect_height > 0) {
    float intersect_size = intersect_width * intersect_height;
    float bbox1_size = bboxSize(bbox1, normalized, offset);
    float bbox2_size = bboxSize(bbox2, normalized, offset);
    return intersect_size / (bbox1_size + bbox2_size - intersect_size);
  } else {
    return 0.;
  }
}

template <typename T_BBOX>
__device__ void emptyBboxInfo(BboxInfo<T_BBOX> *bbox_info) {
  bbox_info->conf_score = T_BBOX(0);
  bbox_info->label =
      -2;  // -1 is used for all labels when shared_location is true
  bbox_info->bbox_idx = -1;
  bbox_info->kept = false;
}
/********** new NMS for only score and index array **********/

template <typename T_SCORE, typename T_BBOX, int TSIZE>
__global__ void allClassNMS_kernel(
    const int num, const int num_classes, const int num_preds_per_class,
    const int top_k, const float nms_threshold, const bool share_location,
    const bool isNormalized,
    T_BBOX *bbox_data,  // bbox_data should be float to preserve location
                        // information
    T_SCORE *beforeNMS_scores, int *beforeNMS_index_array,
    T_SCORE *afterNMS_scores, int *afterNMS_index_array, bool flipXY = false) {
  //__shared__ bool kept_bboxinfo_flag[CAFFE_CUDA_NUM_THREADS * TSIZE];
  extern __shared__ bool kept_bboxinfo_flag[];

  for (int i = 0; i < num; i++) {
    const int offset = i * num_classes * num_preds_per_class +
                       blockIdx.x * num_preds_per_class;
    const int max_idx =
        offset + top_k;  // put top_k bboxes into NMS calculation
    const int bbox_idx_offset = share_location
                                    ? (i * num_preds_per_class)
                                    : (i * num_classes * num_preds_per_class);

    // local thread data
    int loc_bboxIndex[TSIZE];
    Bbox<T_BBOX> loc_bbox[TSIZE];

    // initialize Bbox, Bboxinfo, kept_bboxinfo_flag
    // Eliminate shared memory RAW hazard
    __syncthreads();
#pragma unroll
    for (int t = 0; t < TSIZE; t++) {
      const int cur_idx = threadIdx.x + blockDim.x * t;
      const int item_idx = offset + cur_idx;

      if (item_idx < max_idx) {
        loc_bboxIndex[t] = beforeNMS_index_array[item_idx];

        if (loc_bboxIndex[t] >= 0)
        // if (loc_bboxIndex[t] != -1)
        {
          const int bbox_data_idx =
              share_location
                  ? (loc_bboxIndex[t] % num_preds_per_class + bbox_idx_offset)
                  : loc_bboxIndex[t];

          loc_bbox[t].xmin = flipXY ? bbox_data[bbox_data_idx * 4 + 1]
                                    : bbox_data[bbox_data_idx * 4 + 0];
          loc_bbox[t].ymin = flipXY ? bbox_data[bbox_data_idx * 4 + 0]
                                    : bbox_data[bbox_data_idx * 4 + 1];
          loc_bbox[t].xmax = flipXY ? bbox_data[bbox_data_idx * 4 + 3]
                                    : bbox_data[bbox_data_idx * 4 + 2];
          loc_bbox[t].ymax = flipXY ? bbox_data[bbox_data_idx * 4 + 2]
                                    : bbox_data[bbox_data_idx * 4 + 3];
          kept_bboxinfo_flag[cur_idx] = true;
        } else {
          kept_bboxinfo_flag[cur_idx] = false;
        }
      } else {
        kept_bboxinfo_flag[cur_idx] = false;
      }
    }

    // filter out overlapped boxes with lower scores
    int ref_item_idx = offset;
    int ref_bbox_idx =
        share_location
            ? (beforeNMS_index_array[ref_item_idx] % num_preds_per_class +
               bbox_idx_offset)
            : beforeNMS_index_array[ref_item_idx];

    while ((ref_bbox_idx != -1) && ref_item_idx < max_idx) {
      Bbox<T_BBOX> ref_bbox;
      ref_bbox.xmin = flipXY ? bbox_data[ref_bbox_idx * 4 + 1]
                             : bbox_data[ref_bbox_idx * 4 + 0];
      ref_bbox.ymin = flipXY ? bbox_data[ref_bbox_idx * 4 + 0]
                             : bbox_data[ref_bbox_idx * 4 + 1];
      ref_bbox.xmax = flipXY ? bbox_data[ref_bbox_idx * 4 + 3]
                             : bbox_data[ref_bbox_idx * 4 + 2];
      ref_bbox.ymax = flipXY ? bbox_data[ref_bbox_idx * 4 + 2]
                             : bbox_data[ref_bbox_idx * 4 + 3];

      // Eliminate shared memory RAW hazard
      __syncthreads();

      for (int t = 0; t < TSIZE; t++) {
        const int cur_idx = threadIdx.x + blockDim.x * t;
        const int item_idx = offset + cur_idx;

        if ((kept_bboxinfo_flag[cur_idx]) && (item_idx > ref_item_idx)) {
          // TODO: may need to add bool normalized as argument, HERE true means
          // normalized
          if (jaccardOverlap(ref_bbox, loc_bbox[t], isNormalized, T_BBOX(0)) >
              nms_threshold) {
            kept_bboxinfo_flag[cur_idx] = false;
          }
        }
      }
      __syncthreads();

      do {
        ref_item_idx++;
      } while (ref_item_idx < max_idx &&
               !kept_bboxinfo_flag[ref_item_idx - offset]);

      ref_bbox_idx =
          share_location
              ? (beforeNMS_index_array[ref_item_idx] % num_preds_per_class +
                 bbox_idx_offset)
              : beforeNMS_index_array[ref_item_idx];
    }

    // store data
    for (int t = 0; t < TSIZE; t++) {
      const int cur_idx = threadIdx.x + blockDim.x * t;
      const int read_item_idx = offset + cur_idx;
      const int write_item_idx =
          (i * num_classes * top_k + blockIdx.x * top_k) + cur_idx;
      /*
       * If not keeping the bbox:
       * Set the score to 0
       * Set the bounding box index to -1
       */
      if (read_item_idx < max_idx) {
        afterNMS_scores[write_item_idx] = kept_bboxinfo_flag[cur_idx]
                                              ? beforeNMS_scores[read_item_idx]
                                              : 0.0f;
        afterNMS_index_array[write_item_idx] =
            kept_bboxinfo_flag[cur_idx] ? loc_bboxIndex[t] : -1;
      }
    }
  }
}

template <typename T_SCORE, typename T_BBOX>
pluginStatus_t allClassNMS_gpu(
    cudaStream_t stream, const int num, const int num_classes,
    const int num_preds_per_class, const int top_k, const float nms_threshold,
    const bool share_location, const bool isNormalized, void *bbox_data,
    void *beforeNMS_scores, void *beforeNMS_index_array, void *afterNMS_scores,
    void *afterNMS_index_array, bool flipXY = false) {
#define P(tsize) allClassNMS_kernel<T_SCORE, T_BBOX, (tsize)>

  void (*kernel[10])(const int, const int, const int, const int, const float,
                     const bool, const bool, float *, T_SCORE *, int *,
                     T_SCORE *, int *, bool) = {
      P(1), P(2), P(3), P(4), P(5), P(6), P(7), P(8), P(9), P(10),
  };

  const int BS = 512;
  const int GS = num_classes;
  const int t_size = (top_k + BS - 1) / BS;

  kernel[t_size - 1]<<<GS, BS, BS * t_size * sizeof(bool), stream>>>(
      num, num_classes, num_preds_per_class, top_k, nms_threshold,
      share_location, isNormalized, (T_BBOX *)bbox_data,
      (T_SCORE *)beforeNMS_scores, (int *)beforeNMS_index_array,
      (T_SCORE *)afterNMS_scores, (int *)afterNMS_index_array, flipXY);

  CSC(cudaGetLastError(), STATUS_FAILURE);
  return STATUS_SUCCESS;
}

// allClassNMS LAUNCH CONFIG
typedef pluginStatus_t (*nmsFunc)(cudaStream_t, const int, const int,
                                  const int, const int, const float,
                                  const bool, const bool, void *, void *,
                                  void *, void *, void *, bool);

struct nmsLaunchConfigSSD {
  DataType t_score;
  DataType t_bbox;
  nmsFunc function;

  nmsLaunchConfigSSD(DataType t_score, DataType t_bbox)
      : t_score(t_score), t_bbox(t_bbox) {}
  nmsLaunchConfigSSD(DataType t_score, DataType t_bbox, nmsFunc function)
      : t_score(t_score), t_bbox(t_bbox), function(function) {}
  bool operator==(const nmsLaunchConfigSSD &other) {
    return t_score == other.t_score && t_bbox == other.t_bbox;
  }
};

static std::vector<nmsLaunchConfigSSD> nmsFuncVec;

bool nmsInit() {
  nmsFuncVec.push_back(nmsLaunchConfigSSD(DataType::kFLOAT, DataType::kFLOAT,
                                          allClassNMS_gpu<float, float>));
  return true;
}

static bool initialized = nmsInit();

pluginStatus_t allClassNMS(cudaStream_t stream, const int num,
                           const int num_classes, const int num_preds_per_class,
                           const int top_k, const float nms_threshold,
                           const bool share_location, const bool isNormalized,
                           const DataType DT_SCORE, const DataType DT_BBOX,
                           void *bbox_data, void *beforeNMS_scores,
                           void *beforeNMS_index_array, void *afterNMS_scores,
                           void *afterNMS_index_array, bool flipXY) {
  nmsLaunchConfigSSD lc =
      nmsLaunchConfigSSD(DT_SCORE, DT_BBOX, allClassNMS_gpu<float, float>);
  for (unsigned i = 0; i < nmsFuncVec.size(); ++i) {
    if (lc == nmsFuncVec[i]) {
      DEBUG_PRINTF("all class nms kernel %d\n", i);
      return nmsFuncVec[i].function(
          stream, num, num_classes, num_preds_per_class, top_k, nms_threshold,
          share_location, isNormalized, bbox_data, beforeNMS_scores,
          beforeNMS_index_array, afterNMS_scores, afterNMS_index_array, flipXY);
    }
  }
  return STATUS_BAD_PARAM;
}
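jaccardOverlap above computes standard intersection-over-union. As a sanity check on the formula, a minimal host-side C++ sketch of the same computation for normalized boxes (function name and test values are illustrative, not part of the plugin):

#include <algorithm>
#include <cstdio>

// Host-side IoU mirroring jaccardOverlap for normalized boxes (offset = 0).
float iouRef(float x1a, float y1a, float x2a, float y2a,
             float x1b, float y1b, float x2b, float y2b) {
  float iw = std::min(x2a, x2b) - std::max(x1a, x1b);  // intersection width
  float ih = std::min(y2a, y2b) - std::max(y1a, y1b);  // intersection height
  if (iw <= 0 || ih <= 0) return 0.f;                  // no overlap
  float inter = iw * ih;
  float areaA = (x2a - x1a) * (y2a - y1a);
  float areaB = (x2b - x1b) * (y2b - y1b);
  return inter / (areaA + areaB - inter);              // IoU
}

int main() {
  // Two unit squares shifted by half a side: IoU = 0.5 / 1.5 = 0.3333.
  std::printf("%.4f\n", iouRef(0, 0, 1, 1, 0.5f, 0, 1.5f, 1));
}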
@@ -0,0 +1,128 @@
// modify from
// https://github.com/NVIDIA/TensorRT/tree/master/plugin/batchedNMSPlugin
#include "cuda_runtime_api.h"
#include "kernel.h"

pluginStatus_t nmsInference(
    cudaStream_t stream, const int N, const int perBatchBoxesSize,
    const int perBatchScoresSize, const bool shareLocation,
    const int backgroundLabelId, const int numPredsPerClass,
    const int numClasses, const int topK, const int keepTopK,
    const float scoreThreshold, const float iouThreshold,
    const DataType DT_BBOX, const void* locData, const DataType DT_SCORE,
    const void* confData, void* nmsedDets, void* nmsedLabels, void* workspace,
    bool isNormalized, bool confSigmoid, bool clipBoxes) {
  const int topKVal = topK < 0 ? numPredsPerClass : topK;
  const int keepTopKVal = keepTopK < 0 ? numPredsPerClass : keepTopK;
  // locCount = batch_size * number_boxes_per_sample * 4
  const int locCount = N * perBatchBoxesSize;
  /*
   * shareLocation
   * Bounding boxes are shared among all classes, i.e., a bounding box could be
   * classified as any candidate class. Otherwise, bounding boxes are designed
   * for specific classes, i.e., a bounding box could be classified as one
   * certain class or not (binary classification).
   */
  const int numLocClasses = shareLocation ? 1 : numClasses;

  size_t bboxDataSize =
      detectionForwardBBoxDataSize(N, perBatchBoxesSize, DataType::kFLOAT);
  void* bboxDataRaw = workspace;
  cudaMemcpyAsync(bboxDataRaw, locData, bboxDataSize, cudaMemcpyDeviceToDevice,
                  stream);
  pluginStatus_t status;

  /*
   * bboxDataRaw format:
   * [batch size, numPriors (per sample), numLocClasses, 4]
   */
  // float for now
  void* bboxData;
  size_t bboxPermuteSize = detectionForwardBBoxPermuteSize(
      shareLocation, N, perBatchBoxesSize, DataType::kFLOAT);
  void* bboxPermute = nextWorkspacePtr((int8_t*)bboxDataRaw, bboxDataSize);

  /*
   * After permutation, bboxData format:
   * [batch_size, numLocClasses, numPriors (per sample) (numPredsPerClass), 4]
   * This is equivalent to swapping axis
   */
  if (!shareLocation) {
    status = permuteData(stream, locCount, numLocClasses, numPredsPerClass, 4,
                         DataType::kFLOAT, false, bboxDataRaw, bboxPermute);
    ASSERT_FAILURE(status == STATUS_SUCCESS);
    bboxData = bboxPermute;
  }
  /*
   * If shareLocation, numLocClasses = 1
   * No need to permute data on linear memory
   */
  else {
    bboxData = bboxDataRaw;
  }

  /*
   * Conf data format
   * [batch size, numPriors * param.numClasses, 1, 1]
   */
  const int numScores = N * perBatchScoresSize;
  size_t totalScoresSize = detectionForwardPreNMSSize(N, perBatchScoresSize);
  void* scores = nextWorkspacePtr((int8_t*)bboxPermute, bboxPermuteSize);

  // need a conf_scores
  /*
   * After permutation, score format:
   * [batch_size, numClasses, numPredsPerClass, 1]
   */
  status = permuteData(stream, numScores, numClasses, numPredsPerClass, 1,
                       DataType::kFLOAT, confSigmoid, confData, scores);
  ASSERT_FAILURE(status == STATUS_SUCCESS);

  size_t indicesSize = detectionForwardPreNMSSize(N, perBatchScoresSize);
  void* indices = nextWorkspacePtr((int8_t*)scores, totalScoresSize);

  size_t postNMSScoresSize =
      detectionForwardPostNMSSize(N, numClasses, topKVal);
  size_t postNMSIndicesSize =
      detectionForwardPostNMSSize(N, numClasses, topKVal);
  void* postNMSScores = nextWorkspacePtr((int8_t*)indices, indicesSize);
  void* postNMSIndices =
      nextWorkspacePtr((int8_t*)postNMSScores, postNMSScoresSize);

  void* sortingWorkspace =
      nextWorkspacePtr((int8_t*)postNMSIndices, postNMSIndicesSize);
  // Sort the scores so that the following NMS could be applied.

  status = sortScoresPerClass(
      stream, N, numClasses, numPredsPerClass, backgroundLabelId,
      scoreThreshold, DataType::kFLOAT, scores, indices, sortingWorkspace);
  ASSERT_FAILURE(status == STATUS_SUCCESS);

  // flipXY would be set to true if the input bounding boxes were in the
  // [ymin, xmin, ymax, xmax] format; this implementation assumes
  // [xmin, ymin, xmax, ymax], so it stays false.
  bool flipXY = false;
  // NMS
  status = allClassNMS(stream, N, numClasses, numPredsPerClass, topKVal,
                       iouThreshold, shareLocation, isNormalized,
                       DataType::kFLOAT, DataType::kFLOAT, bboxData, scores,
                       indices, postNMSScores, postNMSIndices, flipXY);
  ASSERT_FAILURE(status == STATUS_SUCCESS);

  // Sort the bounding boxes after NMS using scores
  status = sortScoresPerImage(stream, N, numClasses * topKVal, DataType::kFLOAT,
                              postNMSScores, postNMSIndices, scores, indices,
                              sortingWorkspace);

  ASSERT_FAILURE(status == STATUS_SUCCESS);

  // Gather data from the sorted bounding boxes after NMS
  status = gatherNMSOutputs(stream, shareLocation, N, numPredsPerClass,
                            numClasses, topKVal, keepTopKVal, DataType::kFLOAT,
                            DataType::kFLOAT, indices, scores, bboxData,
                            nmsedDets, nmsedLabels, clipBoxes);

  ASSERT_FAILURE(status == STATUS_SUCCESS);

  return STATUS_SUCCESS;
}
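nmsInference carves a single workspace allocation into consecutive 256-byte-aligned buffers via repeated nextWorkspacePtr calls. A hedged host-side sketch of the same carving pattern (buffer names, sizes, and the base address are illustrative):

#include <cstdint>
#include <cstdio>

// Same alignment helpers as kernel.cu, reproduced for a standalone demo.
static int8_t* alignTo(int8_t* p, uintptr_t to) {
  uintptr_t a = (uintptr_t)p;
  if (a % to) a += to - a % to;
  return (int8_t*)a;
}
static int8_t* nextBuf(int8_t* p, uintptr_t prevSize) {
  return alignTo(p + prevSize, 256);  // CUDA_MEM_ALIGN = 256
}

int main() {
  // Pretend 'base' is the device workspace handed to the plugin.
  int8_t* base = (int8_t*)0x10000;
  size_t bboxBytes = 1000, scoreBytes = 900;  // illustrative sizes
  int8_t* bboxData = base;
  int8_t* scores = nextBuf(bboxData, bboxBytes);  // 0x10000+1000 -> 0x10400
  int8_t* indices = nextBuf(scores, scoreBytes);  // -> 0x10800
  std::printf("%p %p %p\n", (void*)bboxData, (void*)scores, (void*)indices);
}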
@@ -0,0 +1,27 @@
/*
 * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "kernel.h"
template <typename KeyT, typename ValueT>
size_t cubSortPairsWorkspaceSize(int num_items, int num_segments) {
  size_t temp_storage_bytes = 0;
  cub::DeviceSegmentedRadixSort::SortPairsDescending(
      (void*)NULL, temp_storage_bytes, (const KeyT*)NULL, (KeyT*)NULL,
      (const ValueT*)NULL, (ValueT*)NULL,
      num_items,     // # items
      num_segments,  // # segments
      (const int*)NULL, (const int*)NULL);
  return temp_storage_bytes;
}
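cubSortPairsWorkspaceSize relies on CUB's two-phase convention: calling a sort with a null temp-storage pointer only reports the bytes required. A hedged sketch of the full query-then-run pattern (function name and sizes are illustrative; the plugin itself carves the storage out of the TensorRT workspace instead of calling cudaMalloc):

#include <cub/cub.cuh>

// Query-then-run: the first call sizes the temp storage, the second sorts.
void sortDemo(const float* d_keys_in, float* d_keys_out,
              const int* d_vals_in, int* d_vals_out,
              int num_items, int num_segments, const int* d_offsets,
              cudaStream_t stream) {
  size_t temp_bytes = 0;
  // Phase 1: d_temp == nullptr, so CUB only writes the required size.
  cub::DeviceSegmentedRadixSort::SortPairsDescending(
      nullptr, temp_bytes, d_keys_in, d_keys_out, d_vals_in, d_vals_out,
      num_items, num_segments, d_offsets, d_offsets + 1, 0,
      sizeof(float) * 8, stream);
  void* d_temp = nullptr;
  cudaMalloc(&d_temp, temp_bytes);
  // Phase 2: the same call with real storage performs the sort.
  cub::DeviceSegmentedRadixSort::SortPairsDescending(
      d_temp, temp_bytes, d_keys_in, d_keys_out, d_vals_in, d_vals_out,
      num_items, num_segments, d_offsets, d_offsets + 1, 0,
      sizeof(float) * 8, stream);
  cudaFree(d_temp);
}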
@@ -0,0 +1,133 @@
// modify from
// https://github.com/NVIDIA/TensorRT/tree/master/plugin/batchedNMSPlugin
#include <vector>

#include "kernel.h"
#include "trt_plugin_helper.hpp"

template <typename T_BBOX, typename T_SCORE, unsigned nthds_per_cta>
__launch_bounds__(nthds_per_cta) __global__
    void gatherNMSOutputs_kernel(const bool shareLocation, const int numImages,
                                 const int numPredsPerClass,
                                 const int numClasses, const int topK,
                                 const int keepTopK, const int *indices,
                                 const T_SCORE *scores, const T_BBOX *bboxData,
                                 T_BBOX *nmsedDets, int *nmsedLabels,
                                 bool clipBoxes) {
  if (keepTopK > topK) return;
  for (int i = blockIdx.x * nthds_per_cta + threadIdx.x;
       i < numImages * keepTopK; i += gridDim.x * nthds_per_cta) {
    const int imgId = i / keepTopK;
    const int detId = i % keepTopK;
    const int offset = imgId * numClasses * topK;
    const int index = indices[offset + detId];
    const T_SCORE score = scores[offset + detId];
    if (index == -1) {
      nmsedLabels[i] = -1;
      nmsedDets[i * 5] = 0;
      nmsedDets[i * 5 + 1] = 0;
      nmsedDets[i * 5 + 2] = 0;
      nmsedDets[i * 5 + 3] = 0;
      nmsedDets[i * 5 + 4] = 0;
    } else {
      const int bboxOffset =
          imgId *
          (shareLocation ? numPredsPerClass : (numClasses * numPredsPerClass));
      const int bboxId =
          ((shareLocation ? (index % numPredsPerClass)
                          : index % (numClasses * numPredsPerClass)) +
           bboxOffset) *
          4;
      nmsedLabels[i] = (index % (numClasses * numPredsPerClass)) /
                       numPredsPerClass;  // label
      // clipped bbox xmin
      nmsedDets[i * 5] =
          clipBoxes ? max(min(bboxData[bboxId], T_BBOX(1.)), T_BBOX(0.))
                    : bboxData[bboxId];
      // clipped bbox ymin
      nmsedDets[i * 5 + 1] =
          clipBoxes ? max(min(bboxData[bboxId + 1], T_BBOX(1.)), T_BBOX(0.))
                    : bboxData[bboxId + 1];
      // clipped bbox xmax
      nmsedDets[i * 5 + 2] =
          clipBoxes ? max(min(bboxData[bboxId + 2], T_BBOX(1.)), T_BBOX(0.))
                    : bboxData[bboxId + 2];
      // clipped bbox ymax
      nmsedDets[i * 5 + 3] =
          clipBoxes ? max(min(bboxData[bboxId + 3], T_BBOX(1.)), T_BBOX(0.))
                    : bboxData[bboxId + 3];
      nmsedDets[i * 5 + 4] = score;
    }
  }
}

template <typename T_BBOX, typename T_SCORE>
pluginStatus_t gatherNMSOutputs_gpu(
    cudaStream_t stream, const bool shareLocation, const int numImages,
    const int numPredsPerClass, const int numClasses, const int topK,
    const int keepTopK, const void *indices, const void *scores,
    const void *bboxData, void *nmsedDets, void *nmsedLabels, bool clipBoxes) {
  const int BS = 32;
  const int GS = 32;
  gatherNMSOutputs_kernel<T_BBOX, T_SCORE, BS><<<GS, BS, 0, stream>>>(
      shareLocation, numImages, numPredsPerClass, numClasses, topK, keepTopK,
      (int *)indices, (T_SCORE *)scores, (T_BBOX *)bboxData,
      (T_BBOX *)nmsedDets, (int *)nmsedLabels, clipBoxes);

  CSC(cudaGetLastError(), STATUS_FAILURE);
  return STATUS_SUCCESS;
}

// gatherNMSOutputs LAUNCH CONFIG {{{
typedef pluginStatus_t (*nmsOutFunc)(cudaStream_t, const bool, const int,
                                     const int, const int, const int,
                                     const int, const void *, const void *,
                                     const void *, void *, void *, bool);
struct nmsOutLaunchConfig {
  DataType t_bbox;
  DataType t_score;
  nmsOutFunc function;

  nmsOutLaunchConfig(DataType t_bbox, DataType t_score)
      : t_bbox(t_bbox), t_score(t_score) {}
  nmsOutLaunchConfig(DataType t_bbox, DataType t_score, nmsOutFunc function)
      : t_bbox(t_bbox), t_score(t_score), function(function) {}
  bool operator==(const nmsOutLaunchConfig &other) {
    return t_bbox == other.t_bbox && t_score == other.t_score;
  }
};

using nvinfer1::DataType;

static std::vector<nmsOutLaunchConfig> nmsOutFuncVec;

bool nmsOutputInit() {
  nmsOutFuncVec.push_back(nmsOutLaunchConfig(
      DataType::kFLOAT, DataType::kFLOAT, gatherNMSOutputs_gpu<float, float>));
  return true;
}

static bool initialized = nmsOutputInit();

//}}}

pluginStatus_t gatherNMSOutputs(cudaStream_t stream, const bool shareLocation,
                                const int numImages, const int numPredsPerClass,
                                const int numClasses, const int topK,
                                const int keepTopK, const DataType DT_BBOX,
                                const DataType DT_SCORE, const void *indices,
                                const void *scores, const void *bboxData,
                                void *nmsedDets, void *nmsedLabels,
                                bool clipBoxes) {
  nmsOutLaunchConfig lc = nmsOutLaunchConfig(DT_BBOX, DT_SCORE);
  for (unsigned i = 0; i < nmsOutFuncVec.size(); ++i) {
    if (lc == nmsOutFuncVec[i]) {
      DEBUG_PRINTF("gatherNMSOutputs kernel %d\n", i);
      return nmsOutFuncVec[i].function(stream, shareLocation, numImages,
                                       numPredsPerClass, numClasses, topK,
                                       keepTopK, indices, scores, bboxData,
                                       nmsedDets, nmsedLabels, clipBoxes);
    }
  }
  return STATUS_BAD_PARAM;
}
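The gather kernel packs each detection as five consecutive floats, [xmin, ymin, xmax, ymax, score], with a parallel int label array where -1 marks an empty slot. A small host-side sketch of reading the outputs back (buffer contents and sizes are illustrative, as if copied from nmsedDets / nmsedLabels):

#include <cstdio>
#include <vector>

int main() {
  const int batch = 1, keepTopK = 2;
  std::vector<float> dets = {0.1f, 0.2f, 0.5f, 0.6f, 0.97f,  // detection 0
                             0.0f, 0.0f, 0.0f, 0.0f, 0.0f};  // padded slot
  std::vector<int> labels = {3, -1};  // -1 marks padding
  for (int n = 0; n < batch; ++n) {
    for (int k = 0; k < keepTopK; ++k) {
      if (labels[n * keepTopK + k] < 0) continue;  // skip empty slots
      const float* d = &dets[(n * keepTopK + k) * 5];
      std::printf("box=(%.2f,%.2f,%.2f,%.2f) score=%.2f label=%d\n",
                  d[0], d[1], d[2], d[3], d[4], labels[n * keepTopK + k]);
    }
  }
}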
@@ -0,0 +1,104 @@
// modify from
// https://github.com/NVIDIA/TensorRT/tree/master/plugin/batchedNMSPlugin
#include <stdint.h>

#include <cub/cub.cuh>

#include "cublas_v2.h"
#include "kernel.h"

#include "trt_plugin_helper.hpp"

#define CUDA_MEM_ALIGN 256

// ALIGNPTR
int8_t *alignPtr(int8_t *ptr, uintptr_t to) {
  uintptr_t addr = (uintptr_t)ptr;
  if (addr % to) {
    addr += to - addr % to;
  }
  return (int8_t *)addr;
}

// NEXTWORKSPACEPTR
int8_t *nextWorkspacePtr(int8_t *ptr, uintptr_t previousWorkspaceSize) {
  uintptr_t addr = (uintptr_t)ptr;
  addr += previousWorkspaceSize;
  return alignPtr((int8_t *)addr, CUDA_MEM_ALIGN);
}

// CALCULATE TOTAL WORKSPACE SIZE
size_t calculateTotalWorkspaceSize(size_t *workspaces, int count) {
  size_t total = 0;
  for (int i = 0; i < count; i++) {
    total += workspaces[i];
    if (workspaces[i] % CUDA_MEM_ALIGN) {
      total += CUDA_MEM_ALIGN - (workspaces[i] % CUDA_MEM_ALIGN);
    }
  }
  return total;
}

using nvinfer1::DataType;

template <unsigned nthds_per_cta>
__launch_bounds__(nthds_per_cta) __global__
    void setUniformOffsets_kernel(const int num_segments, const int offset,
                                  int *d_offsets) {
  const int idx = blockIdx.x * nthds_per_cta + threadIdx.x;
  if (idx <= num_segments) d_offsets[idx] = idx * offset;
}

void setUniformOffsets(cudaStream_t stream, const int num_segments,
                       const int offset, int *d_offsets) {
  const int BS = 32;
  const int GS = (num_segments + 1 + BS - 1) / BS;
  setUniformOffsets_kernel<BS>
      <<<GS, BS, 0, stream>>>(num_segments, offset, d_offsets);
}

size_t detectionForwardBBoxDataSize(int N, int C1, DataType DT_BBOX) {
  if (DT_BBOX == DataType::kFLOAT) {
    return N * C1 * sizeof(float);
  }

  printf("Only FP32 type bounding boxes are supported.\n");
  return (size_t)-1;
}

size_t detectionForwardBBoxPermuteSize(bool shareLocation, int N, int C1,
                                       DataType DT_BBOX) {
  if (DT_BBOX == DataType::kFLOAT) {
    return shareLocation ? 0 : N * C1 * sizeof(float);
  }
  printf("Only FP32 type bounding boxes are supported.\n");
  return (size_t)-1;
}

size_t detectionForwardPreNMSSize(int N, int C2) {
  ASSERT(sizeof(float) == sizeof(int));
  return N * C2 * sizeof(float);
}

size_t detectionForwardPostNMSSize(int N, int numClasses, int topK) {
  ASSERT(sizeof(float) == sizeof(int));
  return N * numClasses * topK * sizeof(float);
}

size_t detectionInferenceWorkspaceSize(bool shareLocation, int N, int C1,
                                       int C2, int numClasses,
                                       int numPredsPerClass, int topK,
                                       DataType DT_BBOX, DataType DT_SCORE) {
  size_t wss[7];
  wss[0] = detectionForwardBBoxDataSize(N, C1, DT_BBOX);
  wss[1] = detectionForwardBBoxPermuteSize(shareLocation, N, C1, DT_BBOX);
  wss[2] = detectionForwardPreNMSSize(N, C2);
  wss[3] = detectionForwardPreNMSSize(N, C2);
  wss[4] = detectionForwardPostNMSSize(N, numClasses, topK);
  wss[5] = detectionForwardPostNMSSize(N, numClasses, topK);
  wss[6] =
      std::max(sortScoresPerClassWorkspaceSize(N, numClasses, numPredsPerClass,
                                               DT_SCORE),
               sortScoresPerImageWorkspaceSize(N, numClasses * topK, DT_SCORE));
  return calculateTotalWorkspaceSize(wss, 7);
}
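alignPtr rounds an address up to the next multiple of the alignment. A quick worked check of the arithmetic (the address is illustrative):

#include <cstdint>
#include <cstdio>

int main() {
  // 0x1234 % 256 = 0x34, so the address is bumped by 256 - 0x34 = 0xCC.
  uintptr_t addr = 0x1234;
  uintptr_t to = 256;
  if (addr % to) addr += to - addr % to;
  std::printf("0x%zx\n", addr);  // prints 0x1300, the next 256-byte boundary
}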
@@ -0,0 +1,112 @@
// modify from
// https://github.com/NVIDIA/TensorRT/tree/master/plugin/batchedNMSPlugin
#ifndef TRT_KERNEL_H
#define TRT_KERNEL_H

#include <cuda_runtime.h>
#include <cassert>
#include <cstdio>

#include "cublas_v2.h"
#include "trt_plugin_helper.hpp"

using namespace nvinfer1;
using namespace nvinfer1::plugin;
#define DEBUG_ENABLE 0

template <typename T>
struct Bbox {
  T xmin, ymin, xmax, ymax;
  Bbox(T xmin, T ymin, T xmax, T ymax)
      : xmin(xmin), ymin(ymin), xmax(xmax), ymax(ymax) {}
  Bbox() = default;
};

template <typename T>
struct BboxInfo {
  T conf_score;
  int label;
  int bbox_idx;
  bool kept;
  BboxInfo(T conf_score, int label, int bbox_idx, bool kept)
      : conf_score(conf_score), label(label), bbox_idx(bbox_idx), kept(kept) {}
  BboxInfo() = default;
};

int8_t* alignPtr(int8_t* ptr, uintptr_t to);

int8_t* nextWorkspacePtr(int8_t* ptr, uintptr_t previousWorkspaceSize);

void setUniformOffsets(cudaStream_t stream, int num_segments, int offset,
                       int* d_offsets);

pluginStatus_t allClassNMS(cudaStream_t stream, int num, int num_classes,
                           int num_preds_per_class, int top_k,
                           float nms_threshold, bool share_location,
                           bool isNormalized, DataType DT_SCORE,
                           DataType DT_BBOX, void* bbox_data,
                           void* beforeNMS_scores, void* beforeNMS_index_array,
                           void* afterNMS_scores, void* afterNMS_index_array,
                           bool flipXY = false);

size_t detectionForwardBBoxDataSize(int N, int C1, DataType DT_BBOX);

size_t detectionForwardBBoxPermuteSize(bool shareLocation, int N, int C1,
                                       DataType DT_BBOX);

size_t sortScoresPerClassWorkspaceSize(int num, int num_classes,
                                       int num_preds_per_class,
                                       DataType DT_CONF);

size_t sortScoresPerImageWorkspaceSize(int num_images, int num_items_per_image,
                                       DataType DT_SCORE);

pluginStatus_t sortScoresPerImage(cudaStream_t stream, int num_images,
                                  int num_items_per_image, DataType DT_SCORE,
                                  void* unsorted_scores,
                                  void* unsorted_bbox_indices,
                                  void* sorted_scores,
                                  void* sorted_bbox_indices, void* workspace);

pluginStatus_t sortScoresPerClass(cudaStream_t stream, int num, int num_classes,
                                  int num_preds_per_class,
                                  int background_label_id,
                                  float confidence_threshold, DataType DT_SCORE,
                                  void* conf_scores_gpu, void* index_array_gpu,
                                  void* workspace);

size_t calculateTotalWorkspaceSize(size_t* workspaces, int count);

pluginStatus_t permuteData(cudaStream_t stream, int nthreads, int num_classes,
                           int num_data, int num_dim, DataType DT_DATA,
                           bool confSigmoid, const void* data, void* new_data);

size_t detectionForwardPreNMSSize(int N, int C2);

size_t detectionForwardPostNMSSize(int N, int numClasses, int topK);

pluginStatus_t gatherNMSOutputs(cudaStream_t stream, bool shareLocation,
                                int numImages, int numPredsPerClass,
                                int numClasses, int topK, int keepTopK,
                                DataType DT_BBOX, DataType DT_SCORE,
                                const void* indices, const void* scores,
                                const void* bboxData, void* nmsedDets,
                                void* nmsedLabels, bool clipBoxes = true);

size_t detectionInferenceWorkspaceSize(bool shareLocation, int N, int C1,
                                       int C2, int numClasses,
                                       int numPredsPerClass, int topK,
                                       DataType DT_BBOX, DataType DT_SCORE);

pluginStatus_t nmsInference(cudaStream_t stream, int N, int boxesSize,
                            int scoresSize, bool shareLocation,
                            int backgroundLabelId, int numPredsPerClass,
                            int numClasses, int topK, int keepTopK,
                            float scoreThreshold, float iouThreshold,
                            DataType DT_BBOX, const void* locData,
                            DataType DT_SCORE, const void* confData,
                            void* nmsedDets, void* nmsedLabels, void* workspace,
                            bool isNormalized = true, bool confSigmoid = false,
                            bool clipBoxes = true);

#endif
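This header is the whole host-facing surface: size the workspace with detectionInferenceWorkspaceSize, allocate it, then call nmsInference. A hedged sketch of that call sequence (all tensor sizes and device pointers are illustrative; error handling is elided):

// Assumes kernel.h plus existing device buffers d_boxes, d_scores, d_dets,
// d_labels; 2 images, 1000 priors, 80 classes, shared box locations.
#include "kernel.h"

void runBatchedNMS(cudaStream_t stream, const void* d_boxes,
                   const void* d_scores, void* d_dets, void* d_labels) {
  const int N = 2, numPriors = 1000, numClasses = 80;
  const int topK = 200, keepTopK = 100;
  const int boxesSize = numPriors * 1 * 4;        // shareLocation -> 1 class
  const int scoresSize = numPriors * numClasses;  // per-image score count

  size_t wsBytes = detectionInferenceWorkspaceSize(
      true, N, boxesSize, scoresSize, numClasses, numPriors, topK,
      DataType::kFLOAT, DataType::kFLOAT);
  void* d_workspace = nullptr;
  cudaMalloc(&d_workspace, wsBytes);

  nmsInference(stream, N, boxesSize, scoresSize, /*shareLocation=*/true,
               /*backgroundLabelId=*/-1, numPriors, numClasses, topK, keepTopK,
               /*scoreThreshold=*/0.05f, /*iouThreshold=*/0.5f,
               DataType::kFLOAT, d_boxes, DataType::kFLOAT, d_scores,
               d_dets, d_labels, d_workspace);
  cudaFree(d_workspace);
}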
@@ -0,0 +1,81 @@
// modify from
// https://github.com/NVIDIA/TensorRT/tree/master/plugin/batchedNMSPlugin
#include <vector>

#include "kernel.h"

template <typename Dtype, unsigned nthds_per_cta>
__launch_bounds__(nthds_per_cta) __global__
    void permuteData_kernel(const int nthreads, const int num_classes,
                            const int num_data, const int num_dim,
                            bool confSigmoid, const Dtype *data,
                            Dtype *new_data) {
  // data format: [batch_size, num_data, num_classes, num_dim]
  for (int index = blockIdx.x * nthds_per_cta + threadIdx.x; index < nthreads;
       index += nthds_per_cta * gridDim.x) {
    const int i = index % num_dim;
    const int c = (index / num_dim) % num_classes;
    const int d = (index / num_dim / num_classes) % num_data;
    const int n = index / num_dim / num_classes / num_data;
    const int new_index = ((n * num_classes + c) * num_data + d) * num_dim + i;
    float result = data[index];
    if (confSigmoid) result = exp(result) / (1 + exp(result));

    new_data[new_index] = result;
  }
  // new data format: [batch_size, num_classes, num_data, num_dim]
}

template <typename Dtype>
pluginStatus_t permuteData_gpu(cudaStream_t stream, const int nthreads,
                               const int num_classes, const int num_data,
                               const int num_dim, bool confSigmoid,
                               const void *data, void *new_data) {
  const int BS = 512;
  const int GS = (nthreads + BS - 1) / BS;
  permuteData_kernel<Dtype, BS><<<GS, BS, 0, stream>>>(
      nthreads, num_classes, num_data, num_dim, confSigmoid,
      (const Dtype *)data, (Dtype *)new_data);
  CSC(cudaGetLastError(), STATUS_FAILURE);
  return STATUS_SUCCESS;
}

// permuteData LAUNCH CONFIG
typedef pluginStatus_t (*pdFunc)(cudaStream_t, const int, const int, const int,
                                 const int, bool, const void *, void *);

struct pdLaunchConfig {
  DataType t_data;
  pdFunc function;

  pdLaunchConfig(DataType t_data) : t_data(t_data) {}
  pdLaunchConfig(DataType t_data, pdFunc function)
      : t_data(t_data), function(function) {}
  bool operator==(const pdLaunchConfig &other) {
    return t_data == other.t_data;
  }
};

static std::vector<pdLaunchConfig> pdFuncVec;

bool permuteDataInit() {
  pdFuncVec.push_back(pdLaunchConfig(DataType::kFLOAT, permuteData_gpu<float>));
  return true;
}

static bool initialized = permuteDataInit();

pluginStatus_t permuteData(cudaStream_t stream, const int nthreads,
                           const int num_classes, const int num_data,
                           const int num_dim, const DataType DT_DATA,
                           bool confSigmoid, const void *data, void *new_data) {
  pdLaunchConfig lc = pdLaunchConfig(DT_DATA);
  for (unsigned i = 0; i < pdFuncVec.size(); ++i) {
    if (lc == pdFuncVec[i]) {
      DEBUG_PRINTF("permuteData kernel %d\n", i);
      return pdFuncVec[i].function(stream, nthreads, num_classes, num_data,
                                   num_dim, confSigmoid, data, new_data);
    }
  }
  return STATUS_BAD_PARAM;
}
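permuteData_kernel turns [batch, num_data, num_classes, num_dim] into [batch, num_classes, num_data, num_dim], i.e. a transpose of the middle two axes. A host-side check of the same index arithmetic (array sizes and values are illustrative):

#include <cstdio>
#include <vector>

int main() {
  const int num_data = 3, num_classes = 2, num_dim = 1;  // one image
  // Input laid out as [num_data, num_classes]; value = 10*d + c.
  std::vector<float> in = {0, 1, 10, 11, 20, 21};
  std::vector<float> out(in.size());
  for (int index = 0; index < (int)in.size(); ++index) {
    // Same decode/encode as permuteData_kernel.
    const int i = index % num_dim;
    const int c = (index / num_dim) % num_classes;
    const int d = (index / num_dim / num_classes) % num_data;
    const int n = index / num_dim / num_classes / num_data;
    const int new_index = ((n * num_classes + c) * num_data + d) * num_dim + i;
    out[new_index] = in[index];
  }
  // Output is [num_classes, num_data]: 0 10 20 | 1 11 21.
  for (float v : out) std::printf("%g ", v);
  std::printf("\n");
}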
@@ -0,0 +1,157 @@
// modify from
// https://github.com/NVIDIA/TensorRT/tree/master/plugin/batchedNMSPlugin
#include <vector>

#include "cub/cub.cuh"
#include "cub_helper.h"
#include "kernel.h"
#include "trt_plugin_helper.hpp"

template <typename T_SCORE, unsigned nthds_per_cta>
__launch_bounds__(nthds_per_cta) __global__
    void prepareSortData(const int num, const int num_classes,
                         const int num_preds_per_class,
                         const int background_label_id,
                         const float confidence_threshold,
                         T_SCORE *conf_scores_gpu, T_SCORE *temp_scores,
                         int *temp_idx, int *d_offsets) {
  // Prepare scores data for sort
  const int cur_idx = blockIdx.x * nthds_per_cta + threadIdx.x;
  const int numPredsPerBatch = num_classes * num_preds_per_class;
  if (cur_idx < numPredsPerBatch) {
    const int class_idx = cur_idx / num_preds_per_class;
    for (int i = 0; i < num; i++) {
      const int targetIdx = i * numPredsPerBatch + cur_idx;
      const T_SCORE score = conf_scores_gpu[targetIdx];

      // "Clear" background labeled score and index
      // Because we do not care about background
      if (class_idx == background_label_id) {
        // Set scores to 0
        // Set label = -1
        temp_scores[targetIdx] = 0.0f;
        temp_idx[targetIdx] = -1;
        conf_scores_gpu[targetIdx] = 0.0f;
      }
      // "Clear" scores lower than threshold
      else {
        if (score > confidence_threshold) {
          temp_scores[targetIdx] = score;
          temp_idx[targetIdx] = cur_idx + i * numPredsPerBatch;
        } else {
          // Set scores to 0
          // Set label = -1
          temp_scores[targetIdx] = 0.0f;
          temp_idx[targetIdx] = -1;
          conf_scores_gpu[targetIdx] = 0.0f;
          // TODO: HERE writing memory too many times
        }
      }

      if ((cur_idx % num_preds_per_class) == 0) {
        const int offset_ct = i * num_classes + cur_idx / num_preds_per_class;
        d_offsets[offset_ct] = offset_ct * num_preds_per_class;
        // set the last element in d_offset
        if (blockIdx.x == 0 && threadIdx.x == 0)
          d_offsets[num * num_classes] = num * numPredsPerBatch;
      }
    }
  }
}

template <typename T_SCORE>
pluginStatus_t sortScoresPerClass_gpu(cudaStream_t stream, const int num,
                                      const int num_classes,
                                      const int num_preds_per_class,
                                      const int background_label_id,
                                      const float confidence_threshold,
                                      void *conf_scores_gpu,
                                      void *index_array_gpu, void *workspace) {
  const int num_segments = num * num_classes;
  void *temp_scores = workspace;
  const int arrayLen = num * num_classes * num_preds_per_class;
  void *temp_idx =
      nextWorkspacePtr((int8_t *)temp_scores, arrayLen * sizeof(T_SCORE));
  void *d_offsets =
      nextWorkspacePtr((int8_t *)temp_idx, arrayLen * sizeof(int));
  size_t cubOffsetSize = (num_segments + 1) * sizeof(int);
  void *cubWorkspace = nextWorkspacePtr((int8_t *)d_offsets, cubOffsetSize);

  const int BS = 512;
  const int GS = (num_classes * num_preds_per_class + BS - 1) / BS;
  prepareSortData<T_SCORE, BS><<<GS, BS, 0, stream>>>(
      num, num_classes, num_preds_per_class, background_label_id,
      confidence_threshold, (T_SCORE *)conf_scores_gpu, (T_SCORE *)temp_scores,
      (int *)temp_idx, (int *)d_offsets);

  size_t temp_storage_bytes =
      cubSortPairsWorkspaceSize<T_SCORE, int>(arrayLen, num_segments);
  cub::DeviceSegmentedRadixSort::SortPairsDescending(
      cubWorkspace, temp_storage_bytes, (const T_SCORE *)(temp_scores),
      (T_SCORE *)(conf_scores_gpu), (const int *)(temp_idx),
      (int *)(index_array_gpu), arrayLen, num_segments, (const int *)d_offsets,
      (const int *)d_offsets + 1, 0, sizeof(T_SCORE) * 8, stream);
  CSC(cudaGetLastError(), STATUS_FAILURE);
  return STATUS_SUCCESS;
}

// sortScoresPerClass LAUNCH CONFIG
typedef pluginStatus_t (*sspcFunc)(cudaStream_t, const int, const int,
                                   const int, const int, const float, void *,
                                   void *, void *);
struct sspcLaunchConfig {
  DataType t_score;
  sspcFunc function;

  sspcLaunchConfig(DataType t_score) : t_score(t_score) {}
  sspcLaunchConfig(DataType t_score, sspcFunc function)
      : t_score(t_score), function(function) {}
  bool operator==(const sspcLaunchConfig &other) {
    return t_score == other.t_score;
  }
};

static std::vector<sspcLaunchConfig> sspcFuncVec;
bool sspcInit() {
  sspcFuncVec.push_back(
      sspcLaunchConfig(DataType::kFLOAT, sortScoresPerClass_gpu<float>));
  return true;
}

static bool initialized = sspcInit();

pluginStatus_t sortScoresPerClass(
    cudaStream_t stream, const int num, const int num_classes,
    const int num_preds_per_class, const int background_label_id,
    const float confidence_threshold, const DataType DT_SCORE,
    void *conf_scores_gpu, void *index_array_gpu, void *workspace) {
  sspcLaunchConfig lc = sspcLaunchConfig(DT_SCORE);
  for (unsigned i = 0; i < sspcFuncVec.size(); ++i) {
    if (lc == sspcFuncVec[i]) {
      DEBUG_PRINTF("sortScoresPerClass kernel %d\n", i);
      return sspcFuncVec[i].function(
          stream, num, num_classes, num_preds_per_class, background_label_id,
          confidence_threshold, conf_scores_gpu, index_array_gpu, workspace);
    }
  }
  return STATUS_BAD_PARAM;
}

size_t sortScoresPerClassWorkspaceSize(const int num, const int num_classes,
                                       const int num_preds_per_class,
                                       const DataType DT_CONF) {
  size_t wss[4];
  const int arrayLen = num * num_classes * num_preds_per_class;
  wss[0] = arrayLen * mmlab::getElementSize(DT_CONF);  // temp scores
  wss[1] = arrayLen * sizeof(int);                     // temp indices
  wss[2] = (num * num_classes + 1) * sizeof(int);      // offsets
  if (DT_CONF == DataType::kFLOAT) {
    wss[3] = cubSortPairsWorkspaceSize<float, int>(
        arrayLen, num * num_classes);  // cub workspace
  } else {
    printf("SCORE type not supported\n");
    return (size_t)-1;
  }

  return calculateTotalWorkspaceSize(wss, 4);
}
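The segmented sort treats every (image, class) pair as one segment; d_offsets holds num * num_classes + 1 boundaries, so segment s spans [d_offsets[s], d_offsets[s+1]). A small host-side illustration of the offsets prepareSortData produces (sizes are illustrative):

#include <cstdio>

int main() {
  const int num = 2, num_classes = 3, num_preds_per_class = 4;
  const int num_segments = num * num_classes;
  int d_offsets[num_segments + 1];
  // Mirrors prepareSortData: offset_ct * num_preds_per_class, plus sentinel.
  for (int s = 0; s <= num_segments; ++s)
    d_offsets[s] = s * num_preds_per_class;
  // Segments: [0,4) [4,8) [8,12) [12,16) [16,20) [20,24)
  for (int s = 0; s < num_segments; ++s)
    std::printf("segment %d: [%d, %d)\n", s, d_offsets[s], d_offsets[s + 1]);
}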
@@ -0,0 +1,89 @@
// modify from
// https://github.com/NVIDIA/TensorRT/tree/master/plugin/batchedNMSPlugin
#include <vector>

#include "cub/cub.cuh"
#include "cub_helper.h"
#include "kernel.h"

template <typename T_SCORE>
pluginStatus_t sortScoresPerImage_gpu(
    cudaStream_t stream, const int num_images, const int num_items_per_image,
    void *unsorted_scores, void *unsorted_bbox_indices, void *sorted_scores,
    void *sorted_bbox_indices, void *workspace) {
  void *d_offsets = workspace;
  void *cubWorkspace =
      nextWorkspacePtr((int8_t *)d_offsets, (num_images + 1) * sizeof(int));

  setUniformOffsets(stream, num_images, num_items_per_image, (int *)d_offsets);

  const int arrayLen = num_images * num_items_per_image;
  size_t temp_storage_bytes =
      cubSortPairsWorkspaceSize<T_SCORE, int>(arrayLen, num_images);
  cub::DeviceSegmentedRadixSort::SortPairsDescending(
      cubWorkspace, temp_storage_bytes, (const T_SCORE *)(unsorted_scores),
      (T_SCORE *)(sorted_scores), (const int *)(unsorted_bbox_indices),
      (int *)(sorted_bbox_indices), arrayLen, num_images,
      (const int *)d_offsets, (const int *)d_offsets + 1, 0,
      sizeof(T_SCORE) * 8, stream);
  CSC(cudaGetLastError(), STATUS_FAILURE);
  return STATUS_SUCCESS;
}

// sortScoresPerImage LAUNCH CONFIG
typedef pluginStatus_t (*sspiFunc)(cudaStream_t, const int, const int, void *,
                                   void *, void *, void *, void *);
struct sspiLaunchConfig {
  DataType t_score;
  sspiFunc function;

  sspiLaunchConfig(DataType t_score) : t_score(t_score) {}
  sspiLaunchConfig(DataType t_score, sspiFunc function)
      : t_score(t_score), function(function) {}
  bool operator==(const sspiLaunchConfig &other) {
    return t_score == other.t_score;
  }
};

static std::vector<sspiLaunchConfig> sspiFuncVec;
bool sspiInit() {
  sspiFuncVec.push_back(
      sspiLaunchConfig(DataType::kFLOAT, sortScoresPerImage_gpu<float>));
  return true;
}

static bool initialized = sspiInit();

pluginStatus_t sortScoresPerImage(
    cudaStream_t stream, const int num_images, const int num_items_per_image,
    const DataType DT_SCORE, void *unsorted_scores, void *unsorted_bbox_indices,
    void *sorted_scores, void *sorted_bbox_indices, void *workspace) {
  sspiLaunchConfig lc = sspiLaunchConfig(DT_SCORE);
  for (unsigned i = 0; i < sspiFuncVec.size(); ++i) {
    if (lc == sspiFuncVec[i]) {
      DEBUG_PRINTF("sortScoresPerImage kernel %d\n", i);
      return sspiFuncVec[i].function(
          stream, num_images, num_items_per_image, unsorted_scores,
          unsorted_bbox_indices, sorted_scores, sorted_bbox_indices, workspace);
    }
  }
  return STATUS_BAD_PARAM;
}

size_t sortScoresPerImageWorkspaceSize(const int num_images,
                                       const int num_items_per_image,
                                       const DataType DT_SCORE) {
  const int arrayLen = num_images * num_items_per_image;
  size_t wss[2];
  wss[0] = (num_images + 1) * sizeof(int);  // offsets
  if (DT_SCORE == DataType::kFLOAT) {
    wss[1] = cubSortPairsWorkspaceSize<float, int>(arrayLen,
                                                   num_images);  // cub workspace
  } else {
    printf("SCORE type not supported.\n");
    return (size_t)-1;
  }

  return calculateTotalWorkspaceSize(wss, 2);
}
@@ -0,0 +1,270 @@
// modify from
// https://github.com/NVIDIA/TensorRT/tree/master/plugin/batchedNMSPlugin

#include "trt_batched_nms.hpp"

#include <cstring>

#include "kernel.h"
#include "trt_serialize.hpp"

using namespace nvinfer1;
using nvinfer1::plugin::NMSParameters;

namespace {
static const char* NMS_PLUGIN_VERSION{"1"};
static const char* NMS_PLUGIN_NAME{"TRTBatchedNMS"};
}  // namespace

PluginFieldCollection TRTBatchedNMSPluginDynamicCreator::mFC{};
std::vector<PluginField> TRTBatchedNMSPluginDynamicCreator::mPluginAttributes;

TRTBatchedNMSPluginDynamic::TRTBatchedNMSPluginDynamic(NMSParameters params)
    : param(params) {}

TRTBatchedNMSPluginDynamic::TRTBatchedNMSPluginDynamic(const void* data,
                                                       size_t length) {
  deserialize_value(&data, &length, &param);
  deserialize_value(&data, &length, &boxesSize);
  deserialize_value(&data, &length, &scoresSize);
  deserialize_value(&data, &length, &numPriors);
  deserialize_value(&data, &length, &mClipBoxes);
}

int TRTBatchedNMSPluginDynamic::getNbOutputs() const { return 2; }

int TRTBatchedNMSPluginDynamic::initialize() { return STATUS_SUCCESS; }

void TRTBatchedNMSPluginDynamic::terminate() {}

nvinfer1::DimsExprs TRTBatchedNMSPluginDynamic::getOutputDimensions(
    int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
    nvinfer1::IExprBuilder& exprBuilder) {
  ASSERT(nbInputs == 2);
  ASSERT(outputIndex >= 0 && outputIndex < this->getNbOutputs());
  ASSERT(inputs[0].nbDims == 4);
  ASSERT(inputs[1].nbDims == 3);

  nvinfer1::DimsExprs ret;
  ret.d[0] = inputs[0].d[0];
  ret.d[1] = exprBuilder.constant(param.keepTopK);
  switch (outputIndex) {
    case 0:
      ret.nbDims = 3;
      ret.d[2] = exprBuilder.constant(5);
      break;
    case 1:
      ret.nbDims = 2;
      break;
    default:
      break;
  }

  return ret;
}

size_t TRTBatchedNMSPluginDynamic::getWorkspaceSize(
    const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
    const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const {
  size_t batch_size = inputs[0].dims.d[0];
  size_t boxes_size =
      inputs[0].dims.d[1] * inputs[0].dims.d[2] * inputs[0].dims.d[3];
  size_t score_size = inputs[1].dims.d[1] * inputs[1].dims.d[2];
  size_t num_priors = inputs[0].dims.d[1];
  bool shareLocation = (inputs[0].dims.d[2] == 1);
  int topk = param.topK > 0 ? param.topK : inputs[1].dims.d[1];
  return detectionInferenceWorkspaceSize(
      shareLocation, batch_size, boxes_size, score_size, param.numClasses,
      num_priors, topk, DataType::kFLOAT, DataType::kFLOAT);
}

int TRTBatchedNMSPluginDynamic::enqueue(
    const nvinfer1::PluginTensorDesc* inputDesc,
    const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs,
    void* const* outputs, void* workSpace, cudaStream_t stream) {
  const void* const locData = inputs[0];
  const void* const confData = inputs[1];

  void* nmsedDets = outputs[0];
  void* nmsedLabels = outputs[1];

  size_t batch_size = inputDesc[0].dims.d[0];
  size_t boxes_size =
      inputDesc[0].dims.d[1] * inputDesc[0].dims.d[2] * inputDesc[0].dims.d[3];
  size_t score_size = inputDesc[1].dims.d[1] * inputDesc[1].dims.d[2];
  size_t num_priors = inputDesc[0].dims.d[1];
  bool shareLocation = (inputDesc[0].dims.d[2] == 1);

  pluginStatus_t status = nmsInference(
      stream, batch_size, boxes_size, score_size, shareLocation,
      param.backgroundLabelId, num_priors, param.numClasses, param.topK,
      param.keepTopK, param.scoreThreshold, param.iouThreshold,
      DataType::kFLOAT, locData, DataType::kFLOAT, confData, nmsedDets,
      nmsedLabels, workSpace, param.isNormalized, false, mClipBoxes);
  ASSERT(status == STATUS_SUCCESS);

  return 0;
}

size_t TRTBatchedNMSPluginDynamic::getSerializationSize() const {
  // NMSParameters, boxesSize, scoresSize, numPriors, mClipBoxes
  return sizeof(NMSParameters) + sizeof(int) * 3 + sizeof(bool);
}

void TRTBatchedNMSPluginDynamic::serialize(void* buffer) const {
  serialize_value(&buffer, param);
  serialize_value(&buffer, boxesSize);
  serialize_value(&buffer, scoresSize);
  serialize_value(&buffer, numPriors);
  serialize_value(&buffer, mClipBoxes);
}

void TRTBatchedNMSPluginDynamic::configurePlugin(
    const nvinfer1::DynamicPluginTensorDesc* inputs, int nbInputs,
    const nvinfer1::DynamicPluginTensorDesc* outputs, int nbOutputs) {
  // Validate input arguments
}

bool TRTBatchedNMSPluginDynamic::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs,
    int nbOutputs) {
  if (pos == 3) {
    return inOut[pos].type == nvinfer1::DataType::kINT32 &&
           inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
  }
  return inOut[pos].type == nvinfer1::DataType::kFLOAT &&
         inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
}

const char* TRTBatchedNMSPluginDynamic::getPluginType() const {
  return NMS_PLUGIN_NAME;
}

const char* TRTBatchedNMSPluginDynamic::getPluginVersion() const {
  return NMS_PLUGIN_VERSION;
}

void TRTBatchedNMSPluginDynamic::destroy() { delete this; }

IPluginV2DynamicExt* TRTBatchedNMSPluginDynamic::clone() const {
  auto* plugin = new TRTBatchedNMSPluginDynamic(param);
  plugin->boxesSize = boxesSize;
  plugin->scoresSize = scoresSize;
  plugin->numPriors = numPriors;
  plugin->setPluginNamespace(mNamespace.c_str());
  plugin->setClipParam(mClipBoxes);
  return plugin;
}

void TRTBatchedNMSPluginDynamic::setPluginNamespace(
    const char* pluginNamespace) {
  mNamespace = pluginNamespace;
}

const char* TRTBatchedNMSPluginDynamic::getPluginNamespace() const {
  return mNamespace.c_str();
}

nvinfer1::DataType TRTBatchedNMSPluginDynamic::getOutputDataType(
    int index, const nvinfer1::DataType* inputTypes, int nbInputs) const {
  ASSERT(index >= 0 && index < this->getNbOutputs());
  if (index == 1) {
    return nvinfer1::DataType::kINT32;
  }
  return inputTypes[0];
}

void TRTBatchedNMSPluginDynamic::setClipParam(bool clip) { mClipBoxes = clip; }

TRTBatchedNMSPluginDynamicCreator::TRTBatchedNMSPluginDynamicCreator()
    : params{} {
  mPluginAttributes.emplace_back(
      PluginField("background_label_id", nullptr, PluginFieldType::kINT32, 1));
  mPluginAttributes.emplace_back(
      PluginField("num_classes", nullptr, PluginFieldType::kINT32, 1));
  mPluginAttributes.emplace_back(
      PluginField("topk", nullptr, PluginFieldType::kINT32, 1));
  mPluginAttributes.emplace_back(
      PluginField("keep_topk", nullptr, PluginFieldType::kINT32, 1));
  mPluginAttributes.emplace_back(
      PluginField("score_threshold", nullptr, PluginFieldType::kFLOAT32, 1));
  mPluginAttributes.emplace_back(
      PluginField("iou_threshold", nullptr, PluginFieldType::kFLOAT32, 1));
  mPluginAttributes.emplace_back(
      PluginField("is_normalized", nullptr, PluginFieldType::kINT32, 1));
  mPluginAttributes.emplace_back(
      PluginField("clip_boxes", nullptr, PluginFieldType::kINT32, 1));

  mFC.nbFields = mPluginAttributes.size();
  mFC.fields = mPluginAttributes.data();
}

const char* TRTBatchedNMSPluginDynamicCreator::getPluginName() const {
  return NMS_PLUGIN_NAME;
}

const char* TRTBatchedNMSPluginDynamicCreator::getPluginVersion() const {
  return NMS_PLUGIN_VERSION;
}

const PluginFieldCollection*
TRTBatchedNMSPluginDynamicCreator::getFieldNames() {
  return &mFC;
}

IPluginV2Ext* TRTBatchedNMSPluginDynamicCreator::createPlugin(
    const char* name, const PluginFieldCollection* fc) {
  const PluginField* fields = fc->fields;
  bool clipBoxes = true;

  for (int i = 0; i < fc->nbFields; ++i) {
    const char* attrName = fields[i].name;
    if (!strcmp(attrName, "background_label_id")) {
      ASSERT(fields[i].type == PluginFieldType::kINT32);
      params.backgroundLabelId = *(static_cast<const int*>(fields[i].data));
    } else if (!strcmp(attrName, "num_classes")) {
      ASSERT(fields[i].type == PluginFieldType::kINT32);
      params.numClasses = *(static_cast<const int*>(fields[i].data));
    } else if (!strcmp(attrName, "topk")) {
      ASSERT(fields[i].type == PluginFieldType::kINT32);
      params.topK = *(static_cast<const int*>(fields[i].data));
    } else if (!strcmp(attrName, "keep_topk")) {
      ASSERT(fields[i].type == PluginFieldType::kINT32);
      params.keepTopK = *(static_cast<const int*>(fields[i].data));
    } else if (!strcmp(attrName, "score_threshold")) {
      ASSERT(fields[i].type == PluginFieldType::kFLOAT32);
      params.scoreThreshold = *(static_cast<const float*>(fields[i].data));
    } else if (!strcmp(attrName, "iou_threshold")) {
      ASSERT(fields[i].type == PluginFieldType::kFLOAT32);
      params.iouThreshold = *(static_cast<const float*>(fields[i].data));
    } else if (!strcmp(attrName, "is_normalized")) {
      params.isNormalized = *(static_cast<const bool*>(fields[i].data));
    } else if (!strcmp(attrName, "clip_boxes")) {
      clipBoxes = *(static_cast<const bool*>(fields[i].data));
    }
  }

  TRTBatchedNMSPluginDynamic* plugin = new TRTBatchedNMSPluginDynamic(params);
  plugin->setClipParam(clipBoxes);
  plugin->setPluginNamespace(mNamespace.c_str());
  return plugin;
}

IPluginV2Ext* TRTBatchedNMSPluginDynamicCreator::deserializePlugin(
    const char* name, const void* serialData, size_t serialLength) {
  // This object will be deleted when the network is destroyed, which will
  // call NMS::destroy()
  TRTBatchedNMSPluginDynamic* plugin =
      new TRTBatchedNMSPluginDynamic(serialData, serialLength);
  plugin->setPluginNamespace(mNamespace.c_str());
  return plugin;
}

void TRTBatchedNMSPluginDynamicCreator::setPluginNamespace(
    const char* libNamespace) {
  mNamespace = libNamespace;
}

const char* TRTBatchedNMSPluginDynamicCreator::getPluginNamespace() const {
  return mNamespace.c_str();
}
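createPlugin reads its configuration from a PluginFieldCollection, so the plugin can also be instantiated directly from host code. A hedged sketch of building the fields and creating the plugin (values and the helper name are illustrative; how the creator gets registered with TensorRT is outside this diff):

#include <vector>
#include "trt_batched_nms.hpp"

nvinfer1::IPluginV2Ext* makeBatchedNMS() {
  static TRTBatchedNMSPluginDynamicCreator creator;
  int numClasses = 80, topk = 200, keepTopk = 100, bg = -1;
  float scoreThr = 0.05f, iouThr = 0.5f;
  // Field names must match the strings parsed in createPlugin above.
  std::vector<nvinfer1::PluginField> f = {
      {"background_label_id", &bg, nvinfer1::PluginFieldType::kINT32, 1},
      {"num_classes", &numClasses, nvinfer1::PluginFieldType::kINT32, 1},
      {"topk", &topk, nvinfer1::PluginFieldType::kINT32, 1},
      {"keep_topk", &keepTopk, nvinfer1::PluginFieldType::kINT32, 1},
      {"score_threshold", &scoreThr, nvinfer1::PluginFieldType::kFLOAT32, 1},
      {"iou_threshold", &iouThr, nvinfer1::PluginFieldType::kFLOAT32, 1},
  };
  nvinfer1::PluginFieldCollection fc{static_cast<int>(f.size()), f.data()};
  return creator.createPlugin("batched_nms", &fc);
}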
@ -0,0 +1,118 @@
// modify from
// https://github.com/NVIDIA/TensorRT/tree/master/plugin/batchedNMSPlugin
#ifndef TRT_BATCHED_NMS_PLUGIN_CUSTOM_H
#define TRT_BATCHED_NMS_PLUGIN_CUSTOM_H
#include <string>
#include <vector>

#include "trt_plugin_helper.hpp"

class TRTBatchedNMSPluginDynamic : public nvinfer1::IPluginV2DynamicExt {
 public:
  TRTBatchedNMSPluginDynamic(nvinfer1::plugin::NMSParameters param);

  TRTBatchedNMSPluginDynamic(const void* data, size_t length);

  ~TRTBatchedNMSPluginDynamic() override = default;

  int getNbOutputs() const override;

  nvinfer1::DimsExprs getOutputDimensions(
      int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
      nvinfer1::IExprBuilder& exprBuilder) override;

  int initialize() override;

  void terminate() override;

  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
                          int nbInputs,
                          const nvinfer1::PluginTensorDesc* outputs,
                          int nbOutputs) const override;

  int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
              const nvinfer1::PluginTensorDesc* outputDesc,
              const void* const* inputs, void* const* outputs, void* workSpace,
              cudaStream_t stream) override;

  size_t getSerializationSize() const override;

  void serialize(void* buffer) const override;

  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs,
                       int nbInputs,
                       const nvinfer1::DynamicPluginTensorDesc* outputs,
                       int nbOutputs) override;

  bool supportsFormatCombination(int pos,
                                 const nvinfer1::PluginTensorDesc* inOut,
                                 int nbInputs, int nbOutputs) override;

  const char* getPluginType() const override;

  const char* getPluginVersion() const override;

  void destroy() override;

  nvinfer1::IPluginV2DynamicExt* clone() const override;

  nvinfer1::DataType getOutputDataType(int index,
                                       const nvinfer1::DataType* inputType,
                                       int nbInputs) const override;

  void setPluginNamespace(const char* libNamespace) override;

  const char* getPluginNamespace() const override;

  void setClipParam(bool clip);

 private:
  nvinfer1::plugin::NMSParameters param{};
  int boxesSize{};
  int scoresSize{};
  int numPriors{};
  std::string mNamespace;
  bool mClipBoxes{};

 protected:
  // To prevent compiler warnings.
  using nvinfer1::IPluginV2DynamicExt::canBroadcastInputAcrossBatch;
  using nvinfer1::IPluginV2DynamicExt::configurePlugin;
  using nvinfer1::IPluginV2DynamicExt::enqueue;
  using nvinfer1::IPluginV2DynamicExt::getOutputDimensions;
  using nvinfer1::IPluginV2DynamicExt::getWorkspaceSize;
  using nvinfer1::IPluginV2DynamicExt::isOutputBroadcastAcrossBatch;
  using nvinfer1::IPluginV2DynamicExt::supportsFormat;
};

class TRTBatchedNMSPluginDynamicCreator : public nvinfer1::IPluginCreator {
 public:
  TRTBatchedNMSPluginDynamicCreator();

  ~TRTBatchedNMSPluginDynamicCreator() override = default;

  const char* getPluginName() const override;

  const char* getPluginVersion() const override;

  const nvinfer1::PluginFieldCollection* getFieldNames() override;

  nvinfer1::IPluginV2Ext* createPlugin(
      const char* name, const nvinfer1::PluginFieldCollection* fc) override;

  nvinfer1::IPluginV2Ext* deserializePlugin(const char* name,
                                            const void* serialData,
                                            size_t serialLength) override;

  void setPluginNamespace(const char* libNamespace) override;

  const char* getPluginNamespace() const override;

 private:
  static nvinfer1::PluginFieldCollection mFC;
  nvinfer1::plugin::NMSParameters params;
  static std::vector<nvinfer1::PluginField> mPluginAttributes;
  std::string mNamespace;
};

#endif  // TRT_BATCHED_NMS_PLUGIN_CUSTOM_H
@ -4,6 +4,118 @@
#include "NvInferPlugin.h"

// Enumerator for status
typedef enum {
  STATUS_SUCCESS = 0,
  STATUS_FAILURE = 1,
  STATUS_BAD_PARAM = 2,
  STATUS_NOT_SUPPORTED = 3,
  STATUS_NOT_INITIALIZED = 4
} pluginStatus_t;

#define ASSERT(assertion)                                                \
  {                                                                      \
    if (!(assertion)) {                                                  \
      std::cerr << "assertion failed: " #assertion << ", " << __FILE__  \
                << "," << __LINE__ << std::endl;                        \
      abort();                                                          \
    }                                                                   \
  }

#define CUASSERT(status_)                                              \
  {                                                                    \
    auto s_ = status_;                                                 \
    if (s_ != cudaSuccess) {                                           \
      std::cerr << __FILE__ << ", " << __LINE__ << ", " << s_ << ", " \
                << cudaGetErrorString(s_) << std::endl;                \
    }                                                                  \
  }
#define CUBLASASSERT(status_)                                               \
  {                                                                         \
    auto s_ = status_;                                                      \
    if (s_ != CUBLAS_STATUS_SUCCESS) {                                      \
      std::cerr << __FILE__ << ", " << __LINE__ << ", " << s_ << std::endl; \
    }                                                                       \
  }
#define CUERRORMSG(status_)                                                 \
  {                                                                         \
    auto s_ = status_;                                                      \
    if (s_ != 0)                                                            \
      std::cerr << __FILE__ << ", " << __LINE__ << ", " << s_ << std::endl; \
  }

#ifndef DEBUG

#define CHECK(status)         \
  do {                        \
    if (status != 0) abort(); \
  } while (0)

#define ASSERT_PARAM(exp)                \
  do {                                   \
    if (!(exp)) return STATUS_BAD_PARAM; \
  } while (0)

#define ASSERT_FAILURE(exp)            \
  do {                                 \
    if (!(exp)) return STATUS_FAILURE; \
  } while (0)

#define CSC(call, err)               \
  do {                               \
    cudaError_t cudaStatus = call;   \
    if (cudaStatus != cudaSuccess) { \
      return err;                    \
    }                                \
  } while (0)

#define DEBUG_PRINTF(...) \
  do {                    \
  } while (0)

#else

#define ASSERT_PARAM(exp)                                                   \
  do {                                                                      \
    if (!(exp)) {                                                           \
      fprintf(stderr, "Bad param - " #exp ", %s:%d\n", __FILE__, __LINE__); \
      return STATUS_BAD_PARAM;                                              \
    }                                                                       \
  } while (0)

#define ASSERT_FAILURE(exp)                                               \
  do {                                                                    \
    if (!(exp)) {                                                         \
      fprintf(stderr, "Failure - " #exp ", %s:%d\n", __FILE__, __LINE__); \
      return STATUS_FAILURE;                                              \
    }                                                                     \
  } while (0)

#define CSC(call, err)                                   \
  do {                                                   \
    cudaError_t cudaStatus = call;                       \
    if (cudaStatus != cudaSuccess) {                     \
      printf("%s %d CUDA FAIL %s\n", __FILE__, __LINE__, \
             cudaGetErrorString(cudaStatus));            \
      return err;                                        \
    }                                                    \
  } while (0)

#define CHECK(status)                                           \
  {                                                             \
    if (status != 0) {                                          \
      DEBUG_PRINTF("%s %d CUDA FAIL %s\n", __FILE__, __LINE__,  \
                   cudaGetErrorString(status));                 \
      abort();                                                  \
    }                                                           \
  }

#define DEBUG_PRINTF(...) \
  do {                    \
    printf(__VA_ARGS__);  \
  } while (0)

#endif

namespace mmlab {

const int MAXTENSORDIMS = 10;
@ -1,7 +1,9 @@
#include "batched_nms/trt_batched_nms.hpp"
#include "nms/trt_nms.hpp"
#include "roi_align/trt_roi_align.hpp"
#include "scatternd/trt_scatternd.hpp"

REGISTER_TENSORRT_PLUGIN(TRTBatchedNMSPluginDynamicCreator);
REGISTER_TENSORRT_PLUGIN(NonMaxSuppressionDynamicCreator);
REGISTER_TENSORRT_PLUGIN(ONNXScatterNDDynamicCreator);
REGISTER_TENSORRT_PLUGIN(RoIAlignPluginDynamicCreator);
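REGISTER_TENSORRT_PLUGIN registers each creator through a static initializer, so the creators only become visible once this shared library is loaded into the process. A minimal sketch of doing that from Python; the library filename and location are assumptions here (get_tensorrt_op_path() in the next file resolves the real path):

import ctypes
import glob
import os

# hypothetical build output directory; adjust to the actual CMake build tree
lib_dir = os.path.join('build', 'lib')
candidates = glob.glob(os.path.join(lib_dir, '*tensorrt*.so'))
assert candidates, 'compiled TensorRT plugin library not found'

# loading the .so runs the REGISTER_TENSORRT_PLUGIN static initializers;
# afterwards the creators appear in trt.get_plugin_registry()
ctypes.CDLL(candidates[0])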
@ -1,7 +1,7 @@
import ctypes
import glob
import logging
import os


def get_tensorrt_op_path():
@ -1,15 +1,30 @@
import logging
import os.path as osp
from typing import Optional, Union

import mmcv
import onnx
import tensorrt as trt
import torch.multiprocessing as mp

from .tensorrt_utils import onnx2trt, save_trt_engine


def get_trt_loglevel():
    logger = logging.getLogger()
    level = logger.level

    if level == logging.INFO:
        return trt.Logger.INFO
    elif level == logging.ERROR or level == logging.CRITICAL:
        return trt.Logger.ERROR
    elif level == logging.WARNING:
        return trt.Logger.WARNING
    else:
        print('unsupported logging level {}, falling back to '
              'trt.Logger.INFO'.format(level))
        return trt.Logger.INFO
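Since the mapping above is driven entirely by the root logger's level, a quick sketch of the behaviour, including the fallback for levels the function does not handle:

import logging

logging.getLogger().setLevel(logging.WARNING)
assert get_trt_loglevel() == trt.Logger.WARNING

logging.getLogger().setLevel(logging.DEBUG)  # no DEBUG branch above
assert get_trt_loglevel() == trt.Logger.INFO  # falls back to INFO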
def onnx2tensorrt(work_dir: str,
                  save_file: str,
                  deploy_cfg: Union[str, mmcv.Config],
@ -39,7 +54,7 @@ def onnx2tensorrt(work_dir: str,
    engine = onnx2trt(
        onnx_model,
        opt_shape_dict=tensorrt_param['opt_shape_dict'],
        log_level=tensorrt_param.get('log_level', get_trt_loglevel()),
        fp16_mode=tensorrt_param.get('fp16_mode', False),
        max_workspace_size=tensorrt_param.get('max_workspace_size', 0),
        device_id=device_id)
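For orientation, a minimal sketch of the tensorrt_param block this call reads from the deploy config. The input name and shape triplets are hypothetical, and the [min, opt, max] layout of opt_shape_dict follows the convention of onnx2trt as used here:

import tensorrt as trt

tensorrt_param = dict(
    # optional; when omitted, get_trt_loglevel() derives a level from the
    # root Python logger
    log_level=trt.Logger.INFO,
    fp16_mode=False,
    max_workspace_size=1 << 30,  # builder workspace in bytes (1 GiB)
    # hypothetical input name and [min, opt, max] shapes
    opt_shape_dict=dict(
        input=[[1, 3, 320, 320], [1, 3, 800, 1344], [1, 3, 1344, 1344]]))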
@ -74,3 +74,41 @@ def nms_tensorrt(symbolic_wrapper, g, boxes, scores,
        score_threshold_f=score_threshold,
        center_point_box_i=0,
        offset_i=0)


class TRTBatchedNMSop(torch.autograd.Function):

    @staticmethod
    def forward(ctx,
                boxes,
                scores,
                num_classes,
                pre_topk,
                after_topk,
                iou_threshold,
                score_threshold,
                background_label_id=-1):
        # scores: [batch, num_boxes, num_classes]; the unpack intentionally
        # shadows the num_classes argument
        batch_size, num_boxes, num_classes = scores.shape

        # Dummy forward: only the output shapes and dtypes matter at export
        # time; the real computation runs inside the TensorRT plugin.
        out_boxes = min(num_boxes, after_topk)
        dets = torch.rand(batch_size, out_boxes, 5).to(scores.device)
        labels = torch.randint(
            0, num_classes, (batch_size, out_boxes)).to(scores.device)
        return dets, labels

    @staticmethod
    def symbolic(g, boxes, scores, num_classes, pre_topk, after_topk,
                 iou_threshold, score_threshold, background_label_id):
        return g.op(
            'mmcv::TRTBatchedNMS',
            boxes,
            scores,
            num_classes_i=num_classes,
            background_label_id_i=background_label_id,
            iou_threshold_f=iou_threshold,
            score_threshold_f=score_threshold,
            topk_i=pre_topk,
            keep_topk_i=after_topk,
            is_normalized_i=False,
            clip_boxes_i=False,
            outputs=2)
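A small usage sketch of the op at export time. The shapes follow the dummy forward above; the tensor contents are placeholders, since the real NMS only runs inside the TensorRT plugin:

import torch

batch_size, num_boxes, num_cls = 1, 500, 80
# boxes are 4-D: [batch, num_boxes, num_classes or 1 (shared boxes), 4]
boxes = torch.rand(batch_size, num_boxes, 1, 4)
scores = torch.rand(batch_size, num_boxes, num_cls)

dets, labels = TRTBatchedNMSop.apply(boxes, scores, num_cls, 1000, 100,
                                     0.5, 0.05, -1)
assert dets.shape == (batch_size, 100, 5)  # 4 box coordinates + score
assert labels.shape == (batch_size, 100)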
@ -1,6 +1,7 @@
import torch

from mmdeploy.mmcv.ops import DummyONNXNMSop, TRTBatchedNMSop
from mmdeploy.utils import FUNCTION_REWRITERS


def dynamic_clip_for_onnx(x1, y1, x2, y2, max_shape):
@ -153,3 +154,25 @@ def add_dummy_nms_for_onnx(boxes,
    scores = scores.unsqueeze(2)
    dets = torch.cat([boxes, scores], dim=2)
    return dets, labels


@FUNCTION_REWRITERS.register_rewriter(
    func_name='mmdeploy.mmdet.core.export.add_dummy_nms_for_onnx',
    backend='tensorrt')
def add_dummy_nms_for_onnx_tensorrt(rewriter,
                                    boxes,
                                    scores,
                                    max_output_boxes_per_class=1000,
                                    iou_threshold=0.5,
                                    score_threshold=0.05,
                                    pre_top_k=-1,
                                    after_top_k=-1,
                                    labels=None):
    # TRTBatchedNMS expects 4-D boxes: [batch, num_boxes, num_classes or 1, 4]
    boxes = boxes if boxes.dim() == 4 else boxes.unsqueeze(2)
    after_top_k = max_output_boxes_per_class if after_top_k < 0 else min(
        max_output_boxes_per_class, after_top_k)
    dets, labels = TRTBatchedNMSop.apply(boxes, scores, int(scores.shape[-1]),
                                         pre_top_k, after_top_k, iou_threshold,
                                         score_threshold, -1)

    return dets, labels
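Calling the rewriter directly shows the shape handling; the rewriter handle is unused by the body above, so passing None is enough for this sketch:

import torch

boxes = torch.rand(1, 500, 4)  # 3-D boxes; the rewriter unsqueezes to 4-D
scores = torch.rand(1, 500, 80)

dets, labels = add_dummy_nms_for_onnx_tensorrt(
    None,  # rewriter handle, not used in the function body
    boxes,
    scores,
    max_output_boxes_per_class=1000,
    iou_threshold=0.5,
    score_threshold=0.05,
    pre_top_k=1000,
    after_top_k=100)
# dets: [1, 100, 5], labels: [1, 100]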
@ -1,7 +1,7 @@
import torch
import torch.nn as nn

import mmdeploy
from mmdeploy.utils import MODULE_REWRITERS
@ -105,7 +105,11 @@ class AnchorHead(nn.Module):
        iou_threshold = cfg.nms.get('iou_threshold', 0.5)
        score_threshold = cfg.score_thr
        nms_pre = cfg.get('deploy_nms_pre', -1)
        # call through the fully qualified mmdeploy path so that
        # FUNCTION_REWRITERS can redirect to the backend-specific
        # implementation (e.g. the TensorRT rewriter earlier in this diff)
        return mmdeploy.mmdet.core.export.add_dummy_nms_for_onnx(
            batch_mlvl_bboxes,
            batch_mlvl_scores,
            max_output_boxes_per_class,
            iou_threshold=iou_threshold,
            score_threshold=score_threshold,
            pre_top_k=nms_pre,
            after_top_k=cfg.max_per_img)
@ -0,0 +1 @@
Subproject commit c3cceac115c072fb63df1836ff46d8c60d9eb304