[Fix] add bounds to avoid large resource usage of nms operator on jetson (#1686)
* fix trt nms jetson * update-for-comment * clang formatpull/1704/head
parent
99d6fb3190
commit
5fdf00324b
csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms
|
@ -5,6 +5,8 @@
|
|||
|
||||
#include "nms/kernel.h"
|
||||
|
||||
const static int BS = 512;
|
||||
|
||||
template <typename T_BBOX>
|
||||
__device__ T_BBOX bboxSize(const Bbox<T_BBOX> &bbox, const bool normalized, T_BBOX offset) {
|
||||
if (bbox.xmax < bbox.xmin || bbox.ymax < bbox.ymin) {
|
||||
|
@ -65,18 +67,25 @@ __device__ float jaccardOverlap(const Bbox<T_BBOX> &bbox1, const Bbox<T_BBOX> &b
|
|||
|
||||
/********** new NMS for only score and index array **********/
|
||||
|
||||
// clang-format off
|
||||
template <typename T_SCORE, typename T_BBOX, int TSIZE>
|
||||
__global__ void allClassNMS_kernel(const int num, const int num_classes,
|
||||
const int num_preds_per_class, const int top_k,
|
||||
const float nms_threshold, const bool share_location,
|
||||
const bool isNormalized,
|
||||
T_BBOX *bbox_data, // bbox_data should be float to preserve
|
||||
// location information
|
||||
T_SCORE *beforeNMS_scores, int *beforeNMS_index_array,
|
||||
T_SCORE *afterNMS_scores, int *afterNMS_index_array,
|
||||
bool flipXY = false) {
|
||||
__global__ void
|
||||
#ifdef __CUDA_ARCH__
|
||||
#if __CUDA_ARCH__ == 620 || __CUDA_ARCH__ == 530
|
||||
__launch_bounds__(512)
|
||||
#endif
|
||||
#endif
|
||||
allClassNMS_kernel(const int num, const int num_classes, const int num_preds_per_class,
|
||||
const int top_k, const float nms_threshold, const bool share_location,
|
||||
const bool isNormalized,
|
||||
T_BBOX *bbox_data, // bbox_data should be float to preserve
|
||||
// location information
|
||||
T_SCORE *beforeNMS_scores, int *beforeNMS_index_array,
|
||||
T_SCORE *afterNMS_scores, int *afterNMS_index_array, bool flipXY = false) {
|
||||
// clang-format on
|
||||
|
||||
//__shared__ bool kept_bboxinfo_flag[CAFFE_CUDA_NUM_THREADS * TSIZE];
|
||||
extern __shared__ bool kept_bboxinfo_flag[];
|
||||
__shared__ bool kept_bboxinfo_flag[TSIZE * BS];
|
||||
for (int i = 0; i < num; i++) {
|
||||
const int offset = i * num_classes * num_preds_per_class + blockIdx.x * num_preds_per_class;
|
||||
const int max_idx = offset + top_k; // put top_k bboxes into NMS calculation
|
||||
|
@ -196,29 +205,18 @@ pluginStatus_t allClassNMS_gpu(cudaStream_t stream, const int num, const int num
|
|||
P(1), P(2), P(3), P(4), P(5), P(6), P(7), P(8), P(9), P(10),
|
||||
};
|
||||
|
||||
const int BS = 512;
|
||||
const int GS = num_classes;
|
||||
const int t_size = (top_k + BS - 1) / BS;
|
||||
|
||||
ASSERT(t_size <= 10);
|
||||
kernel[t_size - 1]<<<GS, BS, BS * t_size * sizeof(bool), stream>>>(
|
||||
kernel[t_size - 1]<<<GS, BS, 0, stream>>>(
|
||||
num, num_classes, num_preds_per_class, top_k, nms_threshold, share_location, isNormalized,
|
||||
(T_BBOX *)bbox_data, (T_SCORE *)beforeNMS_scores, (int *)beforeNMS_index_array,
|
||||
(T_SCORE *)afterNMS_scores, (int *)afterNMS_index_array, flipXY);
|
||||
|
||||
cudaError_t code = cudaGetLastError();
|
||||
if (code != cudaSuccess) {
|
||||
// Verify if cuda dev0 requires top_k to be reduced;
|
||||
// sm_53 (Jetson Nano) and sm_62 (Jetson TX2) requires reduced top_k < 1000
|
||||
auto __cuda_arch__ = get_cuda_arch(0);
|
||||
if ((__cuda_arch__ == 530 || __cuda_arch__ == 620) && top_k >= 1000) {
|
||||
printf(
|
||||
"Warning: pre_top_k need to be reduced for devices with arch 5.3, 6.2, got "
|
||||
"pre_top_k=%d\n",
|
||||
top_k);
|
||||
}
|
||||
}
|
||||
CSC(cudaGetLastError(), STATUS_FAILURE);
|
||||
CUASSERT(code);
|
||||
CSC(code, STATUS_FAILURE);
|
||||
return STATUS_SUCCESS;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue