mmdeploy/csrc/preprocess/cuda/normalize.cu
lvhan028 36124f6205
Merge sdk (#251)
* check in cmake

* move backend_ops to csrc/backend_ops

* check in preprocess, model, some codebase and their c-apis

* check in CMakeLists.txt

* check in parts of test_csrc

* commit everything else

* add readme

* update core's BUILD_INTERFACE directory

* skip codespell on third_party

* update trt_net and ort_net's CMakeLists

* ignore clion's build directory

* check in pybind11

* add onnx.proto. Remove MMDeploy's dependency on ncnn's source code

* export MMDeployTargets only when MMDEPLOY_BUILD_SDK is ON

* remove useless message

* fix wrong target include directory

* change target name from mmdeploy_ppl_net to mmdeploy_pplnn_net

* skip install directory

* update project's cmake

* remove useless code

* set CMAKE_BUILD_TYPE to Release by force if it isn't set by user

* update custom ops CMakeLists

* pass object target's source lists

* fix lint end-of-file

* fix lint: trailing whitespace

* fix codespell hook

* move bicubic_interpolate to csrc/backend_ops/

* set MMDEPLOY_BUILD_SDK OFF

* change custom ops build command

* add spdlog installation command

* update docs on how to checkout pybind11

* move bicubic_interpolate to backend_ops/tensorrt directory

* remove useless code

* correct cmake

* fix typo

* fix typo

* fix install directory

* correct sdk's readme

* set cub dir when cuda version < 11.0

* change directory where clang-format will apply to

* fix build command

* add .clang-format

* change clang-format style from google to file

* reformat csrc/backend_ops

* format sdk's code

* turn off clang-format for some files

* add -Xcompiler=-fno-gnu-unique

* fix trt topk initialize

* check in config for sdk demo

* update cmake script and csrc's readme

* correct config's path

* add cuda include directory; otherwise compilation fails with TensorRT 8.2

* clang-format onnx2ncnn.cpp

Co-authored-by: zhangli <lzhang329@gmail.com>
Co-authored-by: grimoire <yaoqian@sensetime.com>
2021-12-07 10:57:55 +08:00

60 lines
2.5 KiB
Plaintext

// Copyright (c) OpenMMLab. All rights reserved.
#include <cstdint>
#include <cstdio>
namespace mmdeploy {
namespace cuda {

// Per-pixel normalization kernel: output = (src - mean) / std, optionally reversing channel order.
//
// Launch layout: 2-D grid with one thread per pixel. x indexes columns in [0, width) and y indexes
// rows in [0, height); threads outside that range exit immediately.
// Memory layout: pixels are stored interleaved (all `channels` values of a pixel are contiguous)
// and rows are `stride` elements apart. `stride` is counted in elements, not bytes. `output` uses
// the same offsets as `src`, so it must have the same row stride.
// `to_rgb` reverses the channel order of each pixel (for 3 channels it swaps the first and third).
// Presumably this converts BGR input to RGB — confirm with callers. mean/std are indexed by the
// OUTPUT channel, i.e. after any reordering.
// Supports 1 to 3 channels, because mean/std arrive as float3.
template <typename T, int channels>
__global__ void normalize(const T* src, int height, int width, int stride, float* output,
                          const float3 mean, const float3 std, bool to_rgb) {
  static_assert(channels >= 1 && channels <= 3, "normalize supports 1 to 3 channels");
  const int x = static_cast<int>(blockIdx.x * blockDim.x + threadIdx.x);
  const int y = static_cast<int>(blockIdx.y * blockDim.y + threadIdx.y);
  if (x >= width || y >= height) {
    return;
  }
  // 64-bit offset: y * stride can exceed INT_MAX for very large tensors.
  const int64_t loc = static_cast<int64_t>(y) * stride + static_cast<int64_t>(x) * channels;
  // Copy into arrays rather than indexing through &mean.x, which is only valid for the .x member.
  const float mean_v[3] = {mean.x, mean.y, mean.z};
  const float std_v[3] = {std.x, std.y, std.z};
  // Load the whole pixel before writing anything. An in-place call (output == src, T = float)
  // with to_rgb must not read a channel this thread has already overwritten.
  float pixel[channels];
#pragma unroll
  for (int c = 0; c < channels; ++c) {
    pixel[c] = static_cast<float>(src[loc + c]);
  }
#pragma unroll
  for (int c = 0; c < channels; ++c) {
    // After unrolling, both indices are compile-time constants, so this is a register select.
    const float value = to_rgb ? pixel[channels - 1 - c] : pixel[c];
    output[loc + c] = (value - mean_v[c]) / std_v[c];
  }
}

namespace detail {

// Packs the first `channels` entries of the host array `values` into a float3. The remaining
// lanes are set to `fill`, so callers only need to supply one value per channel.
template <int channels>
inline float3 PackChannelParams(const float* values, float fill) {
  static_assert(channels >= 1 && channels <= 3, "PackChannelParams supports 1 to 3 channels");
  float packed[3] = {fill, fill, fill};
  for (int c = 0; c < channels; ++c) {
    packed[c] = values[c];
  }
  return float3{packed[0], packed[1], packed[2]};
}

}  // namespace detail

// Launches `normalize` over a height x width image on `stream`.
//
// src/output: device-accessible buffers laid out as described on the kernel, sharing `stride`
//   (in elements).
// mean/std: HOST pointers (they are dereferenced here, before the launch), each holding at least
//   `channels` values. Only the first `channels` entries are read.
// The launch is asynchronous. This function neither synchronizes nor checks launch errors, so
// callers should query cudaGetLastError()/cudaPeekAtLastError() themselves if needed.
// An empty image (height <= 0 or width <= 0) is a no-op, and no kernel is launched.
template <typename T, int channels>
void Normalize(const T* src, int height, int width, int stride, float* output, const float* mean,
               const float* std, bool to_rgb, cudaStream_t stream) {
  // A zero-sized grid is an invalid launch configuration and would leave an error behind for
  // the next unrelated error check; there is simply nothing to do.
  if (height <= 0 || width <= 0) {
    return;
  }
  const dim3 thread_block(16, 16);
  const dim3 num_blocks((width + thread_block.x - 1) / thread_block.x,
                        (height + thread_block.y - 1) / thread_block.y);
  // Read only the `channels` values the caller must provide. Unused lanes get neutral values
  // (never read by the kernel for the current channel count).
  const float3 mean3 = detail::PackChannelParams<channels>(mean, 0.f);
  const float3 std3 = detail::PackChannelParams<channels>(std, 1.f);
  normalize<T, channels><<<num_blocks, thread_block, 0, stream>>>(src, height, width, stride,
                                                                  output, mean3, std3, to_rgb);
}

template void Normalize<uint8_t, 3>(const uint8_t* src, int height, int width, int stride,
                                    float* output, const float* mean, const float* std, bool to_rgb,
                                    cudaStream_t stream);
template void Normalize<uint8_t, 1>(const uint8_t* src, int height, int width, int stride,
                                    float* output, const float* mean, const float* std, bool to_rgb,
                                    cudaStream_t stream);
template void Normalize<float, 3>(const float* src, int height, int width, int stride,
                                  float* output, const float* mean, const float* std, bool to_rgb,
                                  cudaStream_t stream);
template void Normalize<float, 1>(const float* src, int height, int width, int stride,
                                  float* output, const float* mean, const float* std, bool to_rgb,
                                  cudaStream_t stream);

}  // namespace cuda
}  // namespace mmdeploy