mmdeploy/backend_ops/tensorrt/batched_nms/allClassNMS.cu

// modify from
// https://github.com/NVIDIA/TensorRT/tree/master/plugin/batchedNMSPlugin
#include <vector>
#include "kernel.h"
template <typename T_BBOX>
__device__ T_BBOX bboxSize(const Bbox<T_BBOX> &bbox, const bool normalized,
                           T_BBOX offset) {
  if (bbox.xmax < bbox.xmin || bbox.ymax < bbox.ymin) {
    // If bbox is invalid (e.g. xmax < xmin or ymax < ymin), return 0.
    return 0;
  } else {
    T_BBOX width = bbox.xmax - bbox.xmin;
    T_BBOX height = bbox.ymax - bbox.ymin;
    if (normalized) {
      return width * height;
    } else {
      // If bbox is not within range [0, 1].
      return (width + offset) * (height + offset);
    }
  }
}
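// Intersection of two boxes; writes [0, 0, 0, 0] when the boxes do not overlap.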
template <typename T_BBOX>
__device__ void intersectBbox(const Bbox<T_BBOX> &bbox1,
                              const Bbox<T_BBOX> &bbox2,
                              Bbox<T_BBOX> *intersect_bbox) {
  if (bbox2.xmin > bbox1.xmax || bbox2.xmax < bbox1.xmin ||
      bbox2.ymin > bbox1.ymax || bbox2.ymax < bbox1.ymin) {
    // Return [0, 0, 0, 0] if there is no intersection.
    intersect_bbox->xmin = T_BBOX(0);
    intersect_bbox->ymin = T_BBOX(0);
    intersect_bbox->xmax = T_BBOX(0);
    intersect_bbox->ymax = T_BBOX(0);
  } else {
    intersect_bbox->xmin = max(bbox1.xmin, bbox2.xmin);
    intersect_bbox->ymin = max(bbox1.ymin, bbox2.ymin);
    intersect_bbox->xmax = min(bbox1.xmax, bbox2.xmax);
    intersect_bbox->ymax = min(bbox1.ymax, bbox2.ymax);
  }
}
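// Jaccard overlap (IoU): intersection_area / (area1 + area2 - intersection_area),
// or 0 when the intersection is empty.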
template <typename T_BBOX>
__device__ float jaccardOverlap(const Bbox<T_BBOX> &bbox1,
                                const Bbox<T_BBOX> &bbox2,
                                const bool normalized, T_BBOX offset) {
  Bbox<T_BBOX> intersect_bbox;
  intersectBbox(bbox1, bbox2, &intersect_bbox);
  float intersect_width, intersect_height;
  if (normalized) {
    intersect_width = intersect_bbox.xmax - intersect_bbox.xmin;
    intersect_height = intersect_bbox.ymax - intersect_bbox.ymin;
  } else {
    intersect_width = intersect_bbox.xmax - intersect_bbox.xmin + offset;
    intersect_height = intersect_bbox.ymax - intersect_bbox.ymin + offset;
  }
  if (intersect_width > 0 && intersect_height > 0) {
    float intersect_size = intersect_width * intersect_height;
    float bbox1_size = bboxSize(bbox1, normalized, offset);
    float bbox2_size = bboxSize(bbox2, normalized, offset);
    return intersect_size / (bbox1_size + bbox2_size - intersect_size);
  } else {
    return 0.;
  }
}
template <typename T_BBOX>
__device__ void emptyBboxInfo(BboxInfo<T_BBOX> *bbox_info) {
  bbox_info->conf_score = T_BBOX(0);
  bbox_info->label =
      -2;  // -1 is used for all labels when share_location is true
  bbox_info->bbox_idx = -1;
  bbox_info->kept = false;
}
/********** new NMS for only score and index array **********/
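// Parallelization: the kernel is launched with one block per class and loops
// over the batch. Each thread owns TSIZE candidates (cur_idx = threadIdx.x +
// blockDim.x * t); per-candidate "kept" flags live in dynamic shared memory
// and are cleared whenever a candidate overlaps an earlier, still-kept
// reference box by more than nms_threshold.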
template <typename T_SCORE, typename T_BBOX, int TSIZE>
__global__ void allClassNMS_kernel(
    const int num, const int num_classes, const int num_preds_per_class,
    const int top_k, const float nms_threshold, const bool share_location,
    const bool isNormalized,
    T_BBOX *bbox_data,  // bbox_data should be float to preserve location
                        // information
    T_SCORE *beforeNMS_scores, int *beforeNMS_index_array,
    T_SCORE *afterNMS_scores, int *afterNMS_index_array, bool flipXY = false) {
  //__shared__ bool kept_bboxinfo_flag[CAFFE_CUDA_NUM_THREADS * TSIZE];
  extern __shared__ bool kept_bboxinfo_flag[];
  for (int i = 0; i < num; i++) {
    const int offset = i * num_classes * num_preds_per_class +
                       blockIdx.x * num_preds_per_class;
    const int max_idx =
        offset + top_k;  // put top_k bboxes into NMS calculation
    const int bbox_idx_offset = share_location
                                    ? (i * num_preds_per_class)
                                    : (i * num_classes * num_preds_per_class);
    // local thread data
    int loc_bboxIndex[TSIZE];
    Bbox<T_BBOX> loc_bbox[TSIZE];
    // initialize Bbox, Bboxinfo, kept_bboxinfo_flag
    // Eliminate shared memory RAW hazard
    __syncthreads();
#pragma unroll
    for (int t = 0; t < TSIZE; t++) {
      const int cur_idx = threadIdx.x + blockDim.x * t;
      const int item_idx = offset + cur_idx;
      if (item_idx < max_idx) {
        loc_bboxIndex[t] = beforeNMS_index_array[item_idx];
        if (loc_bboxIndex[t] >= 0)
        // if (loc_bboxIndex[t] != -1)
        {
          const int bbox_data_idx =
              share_location
                  ? (loc_bboxIndex[t] % num_preds_per_class + bbox_idx_offset)
                  : loc_bboxIndex[t];
          loc_bbox[t].xmin = flipXY ? bbox_data[bbox_data_idx * 4 + 1]
                                    : bbox_data[bbox_data_idx * 4 + 0];
          loc_bbox[t].ymin = flipXY ? bbox_data[bbox_data_idx * 4 + 0]
                                    : bbox_data[bbox_data_idx * 4 + 1];
          loc_bbox[t].xmax = flipXY ? bbox_data[bbox_data_idx * 4 + 3]
                                    : bbox_data[bbox_data_idx * 4 + 2];
          loc_bbox[t].ymax = flipXY ? bbox_data[bbox_data_idx * 4 + 2]
                                    : bbox_data[bbox_data_idx * 4 + 3];
          kept_bboxinfo_flag[cur_idx] = true;
        } else {
          kept_bboxinfo_flag[cur_idx] = false;
        }
      } else {
        kept_bboxinfo_flag[cur_idx] = false;
      }
    }
    // filter out overlapped boxes with lower scores
    int ref_item_idx = offset;
    int ref_bbox_idx =
        share_location
            ? (beforeNMS_index_array[ref_item_idx] % num_preds_per_class +
               bbox_idx_offset)
            : beforeNMS_index_array[ref_item_idx];
    while ((ref_bbox_idx != -1) && ref_item_idx < max_idx) {
      Bbox<T_BBOX> ref_bbox;
      ref_bbox.xmin = flipXY ? bbox_data[ref_bbox_idx * 4 + 1]
                             : bbox_data[ref_bbox_idx * 4 + 0];
      ref_bbox.ymin = flipXY ? bbox_data[ref_bbox_idx * 4 + 0]
                             : bbox_data[ref_bbox_idx * 4 + 1];
      ref_bbox.xmax = flipXY ? bbox_data[ref_bbox_idx * 4 + 3]
                             : bbox_data[ref_bbox_idx * 4 + 2];
      ref_bbox.ymax = flipXY ? bbox_data[ref_bbox_idx * 4 + 2]
                             : bbox_data[ref_bbox_idx * 4 + 3];
      // Eliminate shared memory RAW hazard
      __syncthreads();
      for (int t = 0; t < TSIZE; t++) {
        const int cur_idx = threadIdx.x + blockDim.x * t;
        const int item_idx = offset + cur_idx;
        if ((kept_bboxinfo_flag[cur_idx]) && (item_idx > ref_item_idx)) {
          // TODO: may need to add bool normalized as argument, HERE true means
          // normalized
          if (jaccardOverlap(ref_bbox, loc_bbox[t], isNormalized, T_BBOX(0)) >
              nms_threshold) {
            kept_bboxinfo_flag[cur_idx] = false;
          }
        }
      }
      __syncthreads();
      do {
        ref_item_idx++;
      } while (ref_item_idx < max_idx &&
               !kept_bboxinfo_flag[ref_item_idx - offset]);
      ref_bbox_idx =
          share_location
              ? (beforeNMS_index_array[ref_item_idx] % num_preds_per_class +
                 bbox_idx_offset)
              : beforeNMS_index_array[ref_item_idx];
    }
    // store data
    for (int t = 0; t < TSIZE; t++) {
      const int cur_idx = threadIdx.x + blockDim.x * t;
      const int read_item_idx = offset + cur_idx;
      const int write_item_idx =
          (i * num_classes * top_k + blockIdx.x * top_k) + cur_idx;
      /*
       * If not keeping the bbox
       * Set the score to 0
       * Set the bounding box index to -1
       */
      if (read_item_idx < max_idx) {
        afterNMS_scores[write_item_idx] = kept_bboxinfo_flag[cur_idx]
                                              ? beforeNMS_scores[read_item_idx]
                                              : 0.0f;
        afterNMS_index_array[write_item_idx] =
            kept_bboxinfo_flag[cur_idx] ? loc_bboxIndex[t] : -1;
      }
    }
  }
}
template <typename T_SCORE, typename T_BBOX>
pluginStatus_t allClassNMS_gpu(
    cudaStream_t stream, const int num, const int num_classes,
    const int num_preds_per_class, const int top_k, const float nms_threshold,
    const bool share_location, const bool isNormalized, void *bbox_data,
    void *beforeNMS_scores, void *beforeNMS_index_array, void *afterNMS_scores,
    void *afterNMS_index_array, bool flipXY = false) {
#define P(tsize) allClassNMS_kernel<T_SCORE, T_BBOX, (tsize)>
  void (*kernel[10])(const int, const int, const int, const int, const float,
                     const bool, const bool, float *, T_SCORE *, int *,
                     T_SCORE *, int *, bool) = {
      P(1), P(2), P(3), P(4), P(5), P(6), P(7), P(8), P(9), P(10),
  };
  const int BS = 512;
  const int GS = num_classes;
  // Each thread processes t_size candidate boxes; kernel[] has 10 variants,
  // so top_k must not exceed 10 * BS.
  const int t_size = (top_k + BS - 1) / BS;
  kernel[t_size - 1]<<<GS, BS, BS * t_size * sizeof(bool), stream>>>(
      num, num_classes, num_preds_per_class, top_k, nms_threshold,
      share_location, isNormalized, (T_BBOX *)bbox_data,
      (T_SCORE *)beforeNMS_scores, (int *)beforeNMS_index_array,
      (T_SCORE *)afterNMS_scores, (int *)afterNMS_index_array, flipXY);
  CSC(cudaGetLastError(), STATUS_FAILURE);
  return STATUS_SUCCESS;
}
// allClassNMS LAUNCH CONFIG
typedef pluginStatus_t (*nmsFunc)(cudaStream_t, const int, const int,
                                  const int, const int, const float,
                                  const bool, const bool, void *, void *,
                                  void *, void *, void *, bool);
struct nmsLaunchConfigSSD {
  DataType t_score;
  DataType t_bbox;
  nmsFunc function;
  nmsLaunchConfigSSD(DataType t_score, DataType t_bbox)
      : t_score(t_score), t_bbox(t_bbox) {}
  nmsLaunchConfigSSD(DataType t_score, DataType t_bbox, nmsFunc function)
      : t_score(t_score), t_bbox(t_bbox), function(function) {}
  bool operator==(const nmsLaunchConfigSSD &other) {
    return t_score == other.t_score && t_bbox == other.t_bbox;
  }
};
static std::vector<nmsLaunchConfigSSD> nmsFuncVec;
bool nmsInit() {
  nmsFuncVec.push_back(nmsLaunchConfigSSD(DataType::kFLOAT, DataType::kFLOAT,
                                          allClassNMS_gpu<float, float>));
  return true;
}
static bool initialized = nmsInit();
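// allClassNMS() below looks the requested (score, bbox) data types up in
// nmsFuncVec and forwards to the matching launcher; only the float/float
// specialization is registered by nmsInit() above.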
pluginStatus_t allClassNMS(cudaStream_t stream, const int num,
                           const int num_classes, const int num_preds_per_class,
                           const int top_k, const float nms_threshold,
                           const bool share_location, const bool isNormalized,
                           const DataType DT_SCORE, const DataType DT_BBOX,
                           void *bbox_data, void *beforeNMS_scores,
                           void *beforeNMS_index_array, void *afterNMS_scores,
                           void *afterNMS_index_array, bool flipXY) {
  nmsLaunchConfigSSD lc =
      nmsLaunchConfigSSD(DT_SCORE, DT_BBOX, allClassNMS_gpu<float, float>);
  for (unsigned i = 0; i < nmsFuncVec.size(); ++i) {
    if (lc == nmsFuncVec[i]) {
      DEBUG_PRINTF("all class nms kernel %d\n", i);
      return nmsFuncVec[i].function(
          stream, num, num_classes, num_preds_per_class, top_k, nms_threshold,
          share_location, isNormalized, bbox_data, beforeNMS_scores,
          beforeNMS_index_array, afterNMS_scores, afterNMS_index_array, flipXY);
    }
  }
  return STATUS_BAD_PARAM;
}
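// Usage sketch (illustrative only; the buffer names below are assumptions and
// not part of this file). A caller that has already produced per-class sorted
// scores and indices would invoke the dispatcher roughly like this:
//
//   pluginStatus_t status = allClassNMS(
//       stream, batch_size, num_classes, num_preds_per_class, top_k,
//       iou_threshold, /*share_location=*/true, /*isNormalized=*/true,
//       DataType::kFLOAT, DataType::kFLOAT, bbox_dev, sorted_scores_dev,
//       sorted_indices_dev, nms_scores_dev, nms_indices_dev,
//       /*flipXY=*/false);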