lvhan028 36124f6205
Merge sdk (#251)
* check in cmake

* move backend_ops to csrc/backend_ops

* check in preprocess, model, some codebase and their c-apis

* check in CMakeLists.txt

* check in parts of test_csrc

* commit everything else

* add readme

* update core's BUILD_INTERFACE directory

* skip codespell on third_party

* update trt_net and ort_net's CMakeLists

* ignore clion's build directory

* check in pybind11

* add onnx.proto. Remove MMDeploy's dependency on ncnn's source code

* export MMDeployTargets only when MMDEPLOY_BUILD_SDK is ON

* remove useless message

* target include directory is wrong

* change target name from mmdeploy_ppl_net to mmdeploy_pplnn_net

* skip install directory

* update project's cmake

* remove useless code

* set CMAKE_BUILD_TYPE to Release by force if it isn't set by user

* update custom ops CMakeLists

* pass object target's source lists

* fix lint end-of-file

* fix lint: trailing whitespace

* fix codespell hook

* move bicubic_interpolate to csrc/backend_ops/

* set MMDEPLOY_BUILD_SDK OFF

* change custom ops build command

* add spdlog installation command

* update docs on how to checkout pybind11

* move bicubic_interpolate to backend_ops/tensorrt directory

* remove useless code

* correct cmake

* fix typo

* fix typo

* fix install directory

* correct sdk's readme

* set cub dir when cuda version < 11.0

* change directory where clang-format will apply to

* fix build command

* add .clang-format

* change clang-format style from google to file

* reformat csrc/backend_ops

* format sdk's code

* turn off clang-format for some files

* add -Xcompiler=-fno-gnu-unique

* fix trt topk initialize

* check in config for sdk demo

* update cmake script and csrc's readme

* correct config's path

* add cuda include directory; otherwise compilation fails with TensorRT 8.2

* clang-format onnx2ncnn.cpp

Co-authored-by: zhangli <lzhang329@gmail.com>
Co-authored-by: grimoire <yaoqian@sensetime.com>
2021-12-07 10:57:55 +08:00

99 lines
3.2 KiB
Plaintext

// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
// modified from
// https://github.com/NVIDIA/TensorRT/tree/master/plugin/batchedNMSPlugin
#include <stdint.h>
#include <cub/cub.cuh>
#include "cublas_v2.h"
#include "kernel.h"
#include "trt_plugin_helper.hpp"
#define CUDA_MEM_ALIGN 256
// Rounds `ptr` up to the next address that is a multiple of `to`.
// An address that is already a multiple of `to` is returned unchanged.
// `to` must be non-zero: it is used as the divisor of a modulo operation.
int8_t *alignPtr(int8_t *ptr, uintptr_t to) {
  const uintptr_t rawAddress = reinterpret_cast<uintptr_t>(ptr);
  const uintptr_t remainder = rawAddress % to;
  // Only pad when the address is not already on a boundary.
  const uintptr_t alignedAddress = (remainder == 0) ? rawAddress : rawAddress + (to - remainder);
  return reinterpret_cast<int8_t *>(alignedAddress);
}
// Returns the start of the workspace region that follows the one beginning at
// `ptr`: advances by `previousWorkspaceSize` bytes, then rounds the result up
// to a CUDA_MEM_ALIGN boundary.
int8_t *nextWorkspacePtr(int8_t *ptr, uintptr_t previousWorkspaceSize) {
  const uintptr_t endOfPrevious = reinterpret_cast<uintptr_t>(ptr) + previousWorkspaceSize;
  return alignPtr(reinterpret_cast<int8_t *>(endOfPrevious), CUDA_MEM_ALIGN);
}
// Sums the first `count` entries of `workspaces`, rounding each entry up to a
// multiple of CUDA_MEM_ALIGN before adding it to the total.
// Returns 0 when `count` is not positive.
size_t calculateTotalWorkspaceSize(size_t *workspaces, int count) {
  size_t alignedSum = 0;
  int idx = 0;
  while (idx < count) {
    const size_t bytes = workspaces[idx];
    const size_t tail = bytes % CUDA_MEM_ALIGN;
    alignedSum += bytes;
    // Pad the entry out to the next alignment boundary when it does not end on one.
    if (tail != 0) {
      alignedSum += CUDA_MEM_ALIGN - tail;
    }
    ++idx;
  }
  return alignedSum;
}
using nvinfer1::DataType;
// Fills d_offsets[i] = i * offset for every i in [0, num_segments] (inclusive),
// i.e. num_segments + 1 entries, one per thread.
//
// Launch requirements, as implied by the index computation below:
//   - 1D grid and 1D block, with blockDim.x equal to `nthds_per_cta` (the
//     template argument is used in place of blockDim.x);
//   - enough blocks that grid * nthds_per_cta >= num_segments + 1;
//   - d_offsets must have room for at least num_segments + 1 ints.
template <unsigned nthds_per_cta>
__launch_bounds__(nthds_per_cta) __global__
    void setUniformOffsets_kernel(const int num_segments, const int offset, int *d_offsets) {
  const int segment = blockIdx.x * nthds_per_cta + threadIdx.x;
  // Threads past the last boundary (index num_segments) have nothing to write.
  if (segment > num_segments) {
    return;
  }
  d_offsets[segment] = segment * offset;
}
// Enqueues setUniformOffsets_kernel on `stream` so that d_offsets[i] = i * offset
// for i in [0, num_segments]. Uses 32-thread blocks and enough blocks to cover
// num_segments + 1 entries. The launch result is not checked here.
void setUniformOffsets(cudaStream_t stream, const int num_segments, const int offset,
                       int *d_offsets) {
  const int threadsPerBlock = 32;
  // One thread per offset entry; there are num_segments + 1 entries.
  const int entryCount = num_segments + 1;
  const int blockCount = (entryCount + threadsPerBlock - 1) / threadsPerBlock;
  setUniformOffsets_kernel<threadsPerBlock>
      <<<blockCount, threadsPerBlock, 0, stream>>>(num_segments, offset, d_offsets);
}
// Returns the number of bytes occupied by N * C1 float values when DT_BBOX is
// DataType::kFLOAT. For any other data type, prints a message to stdout and
// returns (size_t)-1.
//
// Each factor is widened to size_t before multiplying so that a large N * C1
// product cannot overflow `int` ahead of the conversion.
size_t detectionForwardBBoxDataSize(int N, int C1, DataType DT_BBOX) {
  if (DT_BBOX != DataType::kFLOAT) {
    printf("Only FP32 type bounding boxes are supported.\n");
    return (size_t)-1;
  }
  return static_cast<size_t>(N) * static_cast<size_t>(C1) * sizeof(float);
}
// Returns the byte size for N * C1 float values when DT_BBOX is
// DataType::kFLOAT and `shareLocation` is false; returns 0 when `shareLocation`
// is true. For any other data type, prints a message to stdout and returns
// (size_t)-1.
//
// Each factor is widened to size_t before multiplying so that a large N * C1
// product cannot overflow `int` ahead of the conversion.
size_t detectionForwardBBoxPermuteSize(bool shareLocation, int N, int C1, DataType DT_BBOX) {
  if (DT_BBOX != DataType::kFLOAT) {
    printf("Only FP32 type bounding boxes are supported.\n");
    return (size_t)-1;
  }
  if (shareLocation) {
    return 0;
  }
  return static_cast<size_t>(N) * static_cast<size_t>(C1) * sizeof(float);
}
// Returns the number of bytes occupied by N * C2 four-byte elements.
//
// The size is computed with sizeof(float); the static_assert documents (and
// enforces at compile time) that an int element would need the same space.
// Each factor is widened to size_t before multiplying so that a large N * C2
// product cannot overflow `int` ahead of the conversion.
size_t detectionForwardPreNMSSize(int N, int C2) {
  static_assert(sizeof(float) == sizeof(int), "float and int must have the same size");
  return static_cast<size_t>(N) * static_cast<size_t>(C2) * sizeof(float);
}
// Returns the number of bytes occupied by N * numClasses * topK four-byte
// elements.
//
// The size is computed with sizeof(float); the static_assert documents (and
// enforces at compile time) that an int element would need the same space.
// Each factor is widened to size_t before multiplying so that a large
// three-way product cannot overflow `int` ahead of the conversion.
size_t detectionForwardPostNMSSize(int N, int numClasses, int topK) {
  static_assert(sizeof(float) == sizeof(int), "float and int must have the same size");
  return static_cast<size_t>(N) * static_cast<size_t>(numClasses) * static_cast<size_t>(topK) *
         sizeof(float);
}
// Computes the total workspace size, in bytes, as the aligned sum of seven
// regions (see calculateTotalWorkspaceSize for the per-region rounding):
//   [0]    detectionForwardBBoxDataSize
//   [1]    detectionForwardBBoxPermuteSize
//   [2][3] detectionForwardPreNMSSize  (two regions of equal size)
//   [4][5] detectionForwardPostNMSSize (two regions of equal size)
//   [6]    the larger of the per-class and per-image score-sort workspaces
size_t detectionInferenceWorkspaceSize(bool shareLocation, int N, int C1, int C2, int numClasses,
                                       int numPredsPerClass, int topK, DataType DT_BBOX,
                                       DataType DT_SCORE) {
  const size_t bboxDataBytes = detectionForwardBBoxDataSize(N, C1, DT_BBOX);
  const size_t bboxPermuteBytes = detectionForwardBBoxPermuteSize(shareLocation, N, C1, DT_BBOX);
  const size_t preNmsBytes = detectionForwardPreNMSSize(N, C2);
  const size_t postNmsBytes = detectionForwardPostNMSSize(N, numClasses, topK);
  // A single scratch region is sized for whichever sort needs more space.
  const size_t sortBytes =
      std::max(sortScoresPerClassWorkspaceSize(N, numClasses, numPredsPerClass, DT_SCORE),
               sortScoresPerImageWorkspaceSize(N, numClasses * topK, DT_SCORE));

  size_t regionBytes[7] = {bboxDataBytes, bboxPermuteBytes, preNmsBytes, preNmsBytes,
                           postNmsBytes,  postNmsBytes,     sortBytes};
  return calculateTotalWorkspaceSize(regionBytes, 7);
}