mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
* check in cmake * move backend_ops to csrc/backend_ops * check in preprocess, model, some codebase and their c-apis * check in CMakeLists.txt * check in parts of test_csrc * commit everything else * add readme * update core's BUILD_INTERFACE directory * skip codespell on third_party * update trt_net and ort_net's CMakeLists * ignore clion's build directory * check in pybind11 * add onnx.proto. Remove MMDeploy's dependency on ncnn's source code * export MMDeployTargets only when MMDEPLOY_BUILD_SDK is ON * remove useless message * target include directory is wrong * change target name from mmdeploy_ppl_net to mmdeploy_pplnn_net * skip install directory * update project's cmake * remove useless code * set CMAKE_BUILD_TYPE to Release by force if it isn't set by user * update custom ops CMakeLists * pass object target's source lists * fix lint end-of-file * fix lint: trailing whitespace * fix codespell hook * remove bicubic_interpolate to csrc/backend_ops/ * set MMDEPLOY_BUILD_SDK OFF * change custom ops build command * add spdlog installation command * update docs on how to checkout pybind11 * move bicubic_interpolate to backend_ops/tensorrt directory * remove useless code * correct cmake * fix typo * fix typo * fix install directory * correct sdk's readme * set cub dir when cuda version < 11.0 * change directory where clang-format will apply to * fix build command * add .clang-format * change clang-format style from google to file * reformat csrc/backend_ops * format sdk's code * turn off clang-format for some files * add -Xcompiler=-fno-gnu-unique * fix trt topk initialize * check in config for sdk demo * update cmake script and csrc's readme * correct config's path * add cuda include directory, otherwise compile failed in case of tensorrt8.2 * clang-format onnx2ncnn.cpp Co-authored-by: zhangli <lzhang329@gmail.com> Co-authored-by: grimoire <yaoqian@sensetime.com>
77 lines
3.1 KiB
Plaintext
77 lines
3.1 KiB
Plaintext
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
|
// modify from
|
|
// https://github.com/NVIDIA/TensorRT/tree/master/plugin/batchedNMSPlugin
|
|
#include <cmath>
#include <vector>

#include "kernel.h"
|
|
|
|
// Logistic sigmoid in the numerically stable form 1 / (1 + e^-x).
// The algebraically equal form e^x / (1 + e^x) evaluates inf / inf = NaN as soon
// as e^x overflows float (x > ~88.7); this form saturates to 1 for large x and
// to 0 for very negative x instead. Usable on host and device.
static __host__ __device__ __forceinline__ float permuteDataSigmoid(float x) {
  return 1.0f / (1.0f + expf(-x));
}

// Transposes per-class data from [batch_size, num_data, num_classes, num_dim]
// to [batch_size, num_classes, num_data, num_dim], optionally applying a sigmoid
// to every element while copying.
//
// Launch contract: 1-D grid with blockDim.x == nthds_per_cta. The index math
// uses the template constant rather than blockDim.x, so a smaller block would
// skip elements. The loop advances by nthds_per_cta * gridDim.x, so any grid
// size covers all `nthreads` elements.
//
//   nthreads    - number of elements to process; expected to equal
//                 batch_size * num_data * num_classes * num_dim (TODO confirm at caller)
//   num_classes - size of the class axis
//   num_data    - size of the data (box) axis
//   num_dim     - size of the innermost axis
//   confSigmoid - if true, write sigmoid(x) instead of x
//   data        - input buffer, layout [batch_size, num_data, num_classes, num_dim]
//   new_data    - output buffer, layout [batch_size, num_classes, num_data, num_dim];
//                 must not alias `data`
template <typename Dtype, unsigned nthds_per_cta>
__launch_bounds__(nthds_per_cta) __global__
    void permuteData_kernel(const int nthreads, const int num_classes, const int num_data,
                            const int num_dim, bool confSigmoid, const Dtype *data,
                            Dtype *new_data) {
  for (int index = blockIdx.x * nthds_per_cta + threadIdx.x; index < nthreads;
       index += nthds_per_cta * gridDim.x) {
    // Decompose the source index as [n, d, c, i] with i varying fastest.
    const int i = index % num_dim;
    const int c = (index / num_dim) % num_classes;
    const int d = (index / num_dim / num_classes) % num_data;
    const int n = index / num_dim / num_classes / num_data;
    // The destination swaps the data and class axes: [n, c, d, i].
    const int new_index = ((n * num_classes + c) * num_data + d) * num_dim + i;

    float result = data[index];
    if (confSigmoid) result = permuteDataSigmoid(result);
    new_data[new_index] = result;
  }
}
|
|
|
|
// Launches permuteData_kernel<Dtype> asynchronously on `stream`.
//
// Returns STATUS_SUCCESS when the launch was accepted (an empty input,
// nthreads == 0, is a no-op success). Returns STATUS_BAD_PARAM for a negative
// element count. Returns STATUS_FAILURE if cudaGetLastError() reports a launch
// error. Errors raised while the kernel executes are not observed here, because
// no synchronization is performed.
template <typename Dtype>
pluginStatus_t permuteData_gpu(cudaStream_t stream, const int nthreads, const int num_classes,
                               const int num_data, const int num_dim, bool confSigmoid,
                               const void *data, void *new_data) {
  if (nthreads < 0) return STATUS_BAD_PARAM;
  // A zero-block grid is an invalid launch configuration, so skip the launch
  // entirely when there is nothing to permute.
  if (nthreads == 0) return STATUS_SUCCESS;

  // Block size must match the kernel's nthds_per_cta template argument.
  const int BS = 512;
  // ceil(nthreads / BS) without the int overflow of (nthreads + BS - 1).
  const int GS = nthreads / BS + (nthreads % BS != 0 ? 1 : 0);
  permuteData_kernel<Dtype, BS><<<GS, BS, 0, stream>>>(nthreads, num_classes, num_data, num_dim,
                                                       confSigmoid, (const Dtype *)data,
                                                       (Dtype *)new_data);
  CSC(cudaGetLastError(), STATUS_FAILURE);
  return STATUS_SUCCESS;
}
|
|
|
|
// permuteData LAUNCH CONFIG

// Host launcher signature shared by every permuteData_gpu<Dtype> instantiation:
// (stream, nthreads, num_classes, num_data, num_dim, confSigmoid, data, new_data).
typedef pluginStatus_t (*pdFunc)(cudaStream_t, const int, const int, const int, const int, bool,
                                 const void *, void *);

// Associates an input DataType with the launcher that handles it.
// Equality compares only t_data, so a key built with the one-argument
// constructor can be matched against fully populated registry entries.
struct pdLaunchConfig {
  DataType t_data;
  pdFunc function;

  // Lookup key only. `function` is set to null rather than left indeterminate.
  explicit pdLaunchConfig(DataType t_data) : t_data(t_data), function(nullptr) {}
  pdLaunchConfig(DataType t_data, pdFunc function) : t_data(t_data), function(function) {}

  // const-qualified so it also works on const objects and const references.
  bool operator==(const pdLaunchConfig &other) const { return t_data == other.t_data; }
};
|
|
|
|
// Table of permuteData launchers keyed by input DataType. It is filled by
// permuteDataInit() during static initialization of this translation unit.
static std::vector<pdLaunchConfig> pdFuncVec;

// Registers every supported (DataType, launcher) pair; only float input is
// registered. Returns true so it can seed the static initializer below.
bool permuteDataInit() {
  const pdFunc floatLauncher = permuteData_gpu<float>;
  const pdLaunchConfig floatEntry(DataType::kFLOAT, floatLauncher);
  pdFuncVec.push_back(floatEntry);
  return true;
}

// Evaluated at load time only for its side effect of calling permuteDataInit().
// It is defined after pdFuncVec, so the vector is constructed first (within one
// translation unit, dynamic initialization follows definition order).
static bool initialized = permuteDataInit();
|
|
|
|
// Public entry point: forwards to the launcher registered in pdFuncVec for
// DT_DATA. The first matching entry is used.
// Returns that launcher's status, or STATUS_BAD_PARAM when no launcher is
// registered for DT_DATA.
pluginStatus_t permuteData(cudaStream_t stream, const int nthreads, const int num_classes,
                           const int num_data, const int num_dim, const DataType DT_DATA,
                           bool confSigmoid, const void *data, void *new_data) {
  pdLaunchConfig wanted(DT_DATA);
  const int entryCount = static_cast<int>(pdFuncVec.size());
  for (int k = 0; k < entryCount; ++k) {
    const pdLaunchConfig &candidate = pdFuncVec[k];
    if (!(wanted == candidate)) continue;

    DEBUG_PRINTF("permuteData kernel %d\n", k);
    return candidate.function(stream, nthreads, num_classes, num_data, num_dim, confSigmoid,
                              data, new_data);
  }
  // Unsupported data type.
  return STATUS_BAD_PARAM;
}
|