mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
* check in cmake * move backend_ops to csrc/backend_ops * check in preprocess, model, some codebase and their c-apis * check in CMakeLists.txt * check in parts of test_csrc * commit everything else * add readme * update core's BUILD_INTERFACE directory * skip codespell on third_party * update trt_net and ort_net's CMakeLists * ignore clion's build directory * check in pybind11 * add onnx.proto. Remove MMDeploy's dependency on ncnn's source code * export MMDeployTargets only when MMDEPLOY_BUILD_SDK is ON * remove useless message * target include directory is wrong * change target name from mmdeploy_ppl_net to mmdeploy_pplnn_net * skip install directory * update project's cmake * remove useless code * set CMAKE_BUILD_TYPE to Release by force if it isn't set by user * update custom ops CMakeLists * pass object target's source lists * fix lint end-of-file * fix lint: trailing whitespace * fix codespell hook * remove bicubic_interpolate to csrc/backend_ops/ * set MMDEPLOY_BUILD_SDK OFF * change custom ops build command * add spdlog installation command * update docs on how to checkout pybind11 * move bicubic_interpolate to backend_ops/tensorrt directory * remove useless code * correct cmake * fix typo * fix typo * fix install directory * correct sdk's readme * set cub dir when cuda version < 11.0 * change directory where clang-format will apply to * fix build command * add .clang-format * change clang-format style from google to file * reformat csrc/backend_ops * format sdk's code * turn off clang-format for some files * add -Xcompiler=-fno-gnu-unique * fix trt topk initialize * check in config for sdk demo * update cmake script and csrc's readme * correct config's path * add cuda include directory, otherwise compile failed in case of tensorrt8.2 * clang-format onnx2ncnn.cpp Co-authored-by: zhangli <lzhang329@gmail.com> Co-authored-by: grimoire <yaoqian@sensetime.com>
119 lines
5.2 KiB
Plaintext
119 lines
5.2 KiB
Plaintext
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
|
// Modified from
|
|
// https://github.com/NVIDIA/TensorRT/tree/master/plugin/batchedNMSPlugin
|
|
#include <vector>
|
|
|
|
#include "kernel.h"
|
|
#include "trt_plugin_helper.hpp"
|
|
|
|
/**
 * Gathers the final keepTopK detections of every image into packed output buffers.
 *
 * Launch: 1-D grid, nthds_per_cta threads per block. Each thread advances by
 * gridDim.x * nthds_per_cta, so any grid size covers all numImages * keepTopK slots.
 *
 * For output slot i (image i / keepTopK, rank i % keepTopK), the entry at
 * imgId * numClasses * topK + rank is read from `indices` and `scores`.
 *   - index == -1 : label -1 and an all-zero 5-value detection.
 *   - otherwise   : label (index % (numClasses * numPredsPerClass)) / numPredsPerClass,
 *                   then 4 box coordinates from bboxData (clamped to [0, 1] when
 *                   clipBoxes is set), followed by the score.
 *
 * If keepTopK > topK the kernel returns immediately and writes nothing.
 */
template <typename T_BBOX, typename T_SCORE, unsigned nthds_per_cta>
__launch_bounds__(nthds_per_cta) __global__
    void gatherNMSOutputs_kernel(const bool shareLocation, const int numImages,
                                 const int numPredsPerClass, const int numClasses, const int topK,
                                 const int keepTopK, const int *indices, const T_SCORE *scores,
                                 const T_BBOX *bboxData, T_BBOX *nmsedDets, int *nmsedLabels,
                                 bool clipBoxes) {
  // Upstream guard: with keepTopK > topK nothing is gathered and outputs stay untouched.
  if (keepTopK > topK) return;

  const int numSlots = numImages * keepTopK;
  for (int i = blockIdx.x * nthds_per_cta + threadIdx.x; i < numSlots;
       i += gridDim.x * nthds_per_cta) {
    const int imgId = i / keepTopK;
    const int rank = i % keepTopK;
    const int srcIdx = imgId * numClasses * topK + rank;
    const int index = indices[srcIdx];
    const T_SCORE score = scores[srcIdx];
    T_BBOX *det = nmsedDets + i * 5;  // [x1, y1, x2, y2, score]

    if (index == -1) {
      // Empty slot.
      nmsedLabels[i] = -1;
      for (int k = 0; k < 5; ++k) det[k] = 0;
      continue;
    }

    // Boxes per image: one set shared by all classes, or one set per class.
    // The same count is both the per-image stride and the modulus for the local box index.
    const int predsPerImage = shareLocation ? numPredsPerClass : numClasses * numPredsPerClass;
    const int bboxId = (index % predsPerImage + imgId * predsPerImage) * 4;

    nmsedLabels[i] = (index % (numClasses * numPredsPerClass)) / numPredsPerClass;
    for (int k = 0; k < 4; ++k) {
      const T_BBOX coord = bboxData[bboxId + k];
      det[k] = clipBoxes ? max(min(coord, T_BBOX(1.)), T_BBOX(0.)) : coord;
    }
    det[4] = score;
  }
}
|
|
|
|
/**
 * Host launcher for gatherNMSOutputs_kernel<T_BBOX, T_SCORE, 32> on `stream`.
 *
 * The void pointers are reinterpreted as:
 *   indices -> const int*,     scores    -> const T_SCORE*, bboxData    -> const T_BBOX*,
 *   nmsedDets -> T_BBOX*,      nmsedLabels -> int*.
 * The launch is asynchronous. Only launch-configuration errors are reported, via
 * cudaGetLastError, in which case STATUS_FAILURE is returned. Otherwise the result is
 * STATUS_SUCCESS. No launch happens when there are no output slots.
 */
template <typename T_BBOX, typename T_SCORE>
pluginStatus_t gatherNMSOutputs_gpu(cudaStream_t stream, const bool shareLocation,
                                    const int numImages, const int numPredsPerClass,
                                    const int numClasses, const int topK, const int keepTopK,
                                    const void *indices, const void *scores, const void *bboxData,
                                    void *nmsedDets, void *nmsedLabels, bool clipBoxes) {
  // Threads per block; also the kernel's nthds_per_cta / __launch_bounds__ value.
  const int BS = 32;
  // Upper bound on the number of blocks; the kernel's strided loop covers any remainder.
  const int MAX_GS = 32;

  const int numSlots = numImages * keepTopK;
  // Nothing to gather. Launching a zero-sized grid would be an invalid configuration.
  if (numSlots <= 0) return STATUS_SUCCESS;

  // Launch only as many blocks as the work needs, up to MAX_GS.
  const int neededBlocks = (numSlots + BS - 1) / BS;
  const int GS = neededBlocks < MAX_GS ? neededBlocks : MAX_GS;

  // Const-correct casts: the inputs are read-only and must not lose their qualifier.
  gatherNMSOutputs_kernel<T_BBOX, T_SCORE, BS><<<GS, BS, 0, stream>>>(
      shareLocation, numImages, numPredsPerClass, numClasses, topK, keepTopK,
      static_cast<const int *>(indices), static_cast<const T_SCORE *>(scores),
      static_cast<const T_BBOX *>(bboxData), static_cast<T_BBOX *>(nmsedDets),
      static_cast<int *>(nmsedLabels), clipBoxes);

  CSC(cudaGetLastError(), STATUS_FAILURE);
  return STATUS_SUCCESS;
}
|
|
|
|
// gatherNMSOutputs LAUNCH CONFIG {{{
// Signature shared by every gatherNMSOutputs_gpu<T_BBOX, T_SCORE> instantiation.
typedef pluginStatus_t (*nmsOutFunc)(cudaStream_t, const bool, const int, const int, const int,
                                     const int, const int, const void *, const void *, const void *,
                                     void *, void *, bool);

// Associates a (bbox dtype, score dtype) pair with the launcher that handles it.
// Equality compares only the two dtypes, so a config built with the two-argument
// constructor can serve as a lookup key.
struct nmsOutLaunchConfig {
  DataType t_bbox;
  DataType t_score;
  // Launcher for this dtype pair; nullptr for lookup-only keys.
  nmsOutFunc function;

  // Lookup-key constructor. `function` is explicitly nulled rather than left indeterminate.
  nmsOutLaunchConfig(DataType t_bbox, DataType t_score)
      : t_bbox(t_bbox), t_score(t_score), function(nullptr) {}
  nmsOutLaunchConfig(DataType t_bbox, DataType t_score, nmsOutFunc function)
      : t_bbox(t_bbox), t_score(t_score), function(function) {}

  // const so it also works on const configs and through const references.
  bool operator==(const nmsOutLaunchConfig &other) const {
    return t_bbox == other.t_bbox && t_score == other.t_score;
  }
};
|
|
|
|
using nvinfer1::DataType;

// Table of (bbox dtype, score dtype) -> launcher entries, searched linearly by
// gatherNMSOutputs.
static std::vector<nmsOutLaunchConfig> nmsOutFuncVec;

// Appends every supported launcher to nmsOutFuncVec: currently only float boxes with
// float scores. Each call appends again, so it is run once, through `initialized` below.
bool nmsOutputInit() {
  const nmsOutLaunchConfig floatBoxesFloatScores(DataType::kFLOAT, DataType::kFLOAT,
                                                 gatherNMSOutputs_gpu<float, float>);
  nmsOutFuncVec.push_back(floatBoxesFloatScores);
  return true;
}

// Fills the table during static initialization. nmsOutFuncVec is defined earlier in this
// translation unit, so it is already constructed at this point.
static bool initialized = nmsOutputInit();

//}}}
|
|
|
|
// Public entry point: finds the launcher registered in nmsOutFuncVec for
// (DT_BBOX, DT_SCORE) and forwards all remaining arguments to it.
// Returns that launcher's status, or STATUS_BAD_PARAM when no entry matches the dtypes.
pluginStatus_t gatherNMSOutputs(cudaStream_t stream, const bool shareLocation, const int numImages,
                                const int numPredsPerClass, const int numClasses, const int topK,
                                const int keepTopK, const DataType DT_BBOX, const DataType DT_SCORE,
                                const void *indices, const void *scores, const void *bboxData,
                                void *nmsedDets, void *nmsedLabels, bool clipBoxes) {
  nmsOutLaunchConfig wanted(DT_BBOX, DT_SCORE);
  for (unsigned k = 0; k < nmsOutFuncVec.size(); ++k) {
    nmsOutLaunchConfig &entry = nmsOutFuncVec[k];
    if (!(wanted == entry)) continue;

    DEBUG_PRINTF("gatherNMSOutputs kernel %d\n", k);
    return entry.function(stream, shareLocation, numImages, numPredsPerClass, numClasses, topK,
                          keepTopK, indices, scores, bboxData, nmsedDets, nmsedLabels, clipBoxes);
  }
  // No launcher registered for this dtype combination.
  return STATUS_BAD_PARAM;
}
|