mmdeploy/csrc/preprocess/cuda/normalize.cu
lzhangzz 640aa03538
Support Windows (#106)
* minor changes

* support windows

* fix GCC build

* fix lint

* reformat

* fix Windows build

* fix GCC build

* search backend ops for onnxruntime

* fix lint

* fix lint

* code clean-up

* code clean-up

* fix clang build

* fix trt support

* fix cmake for ncnn

* fix cmake for openvino

* fix SDK Python API

* handle ops for other backends (ncnn, trt)

* handle SDK Python API library location

* robustify linkage

* fix cuda

* minor fix for openvino & ncnn

* use CMAKE_CUDA_ARCHITECTURES if set

* fix cuda preprocessor

* fix misc

* fix pplnn & pplcv, drop support for pplcv<0.6.0

* robustify cmake

* update build.md (#2)

* build dynamic modules as module library & fix demo (partially)

* fix candidate path for mmdeploy_python

* move "enable CUDA" to cmake config for demo

* refine demo cmake

* add comment

* fix ubuntu build

* revert docs/en/build.md

* fix C API

* fix lint

* Windows build doc (#3)

* check in docs related to mmdeploy build on windows

* update build guide on windows platform

* update build guide on windows platform

* make path of thirdparty libraries consistent

* make paths consistent

* correct build command for custom ops

* correct build command for sdk

* update sdk build instructions

* update doc

* correct build command

* fix lint

* correct build command and fix lint

Co-authored-by: lvhan <lvhan@pjlab.org>

* remove trailing whitespace (#4)

* minor fix

* fix sr sdk model

* fix type deduction

* fix cudaFree being called after the driver has shut down

* update ppl.cv installation warning (#5)

* fix device allocator threshold & fix lint

* update doc (#6)

* update ppl.cv installation warning

* missing 'git clone'

Co-authored-by: chenxin <chenxin2@sensetime.com>
Co-authored-by: zhangli <zhangli@sensetime.com>
Co-authored-by: lvhan028 <lvhan_028@163.com>
Co-authored-by: lvhan <lvhan@pjlab.org>
2022-02-24 20:08:44 +08:00

62 lines
2.6 KiB
Plaintext

// Copyright (c) OpenMMLab. All rights reserved.
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>
namespace mmdeploy {
namespace cuda {

namespace {

// Returns component `c` of `v` (0 -> x, 1 -> y, any other value -> z).
// Replaces the former `(&v.x)[c]`, which performs pointer arithmetic past a single
// struct member (undefined behavior in C++). Callers pass loop indices of a
// fixed-trip loop, so the selects are expected to fold away after unrolling.
__host__ __device__ inline float channel_at(const float3& v, int c) {
  return c == 0 ? v.x : (c == 1 ? v.y : v.z);
}

// Normalizes one pixel of `channels` interleaved values:
//   dst[c] = (src[k] - mean[c]) * inv_std[c]
// where k = channels - 1 - c when `to_rgb` is set (channel order reversed) and k = c
// otherwise. Every input channel is loaded before any output is written, so in-place
// use (src and dst pointing at the same storage) does not read clobbered values.
template <typename T, int channels>
__host__ __device__ inline void normalize_pixel(const T* src, float* dst, const float3& mean,
                                                const float3& inv_std, bool to_rgb) {
  float px[channels];
  for (int c = 0; c < channels; ++c) {
    px[c] = static_cast<float>(src[to_rgb ? channels - 1 - c : c]);
  }
  for (int c = 0; c < channels; ++c) {
    dst[c] = (px[c] - channel_at(mean, c)) * channel_at(inv_std, c);
  }
}

}  // namespace

// Per-pixel normalization kernel: one thread per pixel over a 2-D launch that covers at
// least width x height threads (threads outside the image exit early).
//
// `src` holds `height` rows of `width` pixels, each with `channels` interleaved values;
// consecutive rows start `stride` elements of T apart. `output` is addressed with the
// same element offsets, so it is expected to use the same row stride as `src`.
// `std` must already hold the reciprocals of the standard deviations (Normalize below
// computes them), so the kernel multiplies instead of dividing.
template <typename T, int channels>
__global__ void normalize(const T* src, int height, int width, int stride, float* output,
                          const float3 mean, const float3 std, bool to_rgb) {
  static_assert(channels >= 1 && channels <= 3, "mean/std are packed into a float3");
  const int x = static_cast<int>(blockIdx.x * blockDim.x + threadIdx.x);
  const int y = static_cast<int>(blockIdx.y * blockDim.y + threadIdx.y);
  if (x >= width || y >= height) {
    return;
  }
  // 64-bit offset: y * stride can exceed INT_MAX for very large images.
  const int64_t loc = static_cast<int64_t>(y) * stride + static_cast<int64_t>(x) * channels;
  normalize_pixel<T, channels>(src + loc, output + loc, mean, std, to_rgb);
}

// Launches `normalize` asynchronously on `stream` using 16x16 thread blocks.
//
// Computes output = (src - mean) * (1 / std) per channel, reversing the channel order of
// `src` first when `to_rgb` is set. `mean` and `std` are read on the host; only their
// first `channels` entries are accessed. Returns without launching when `height` or
// `width` is not positive (a zero-sized grid would fail with an invalid-configuration
// error). Launch errors are not checked here.
// NOTE(review): presumably the caller checks cudaGetLastError / stream status — confirm.
template <typename T, int channels>
void Normalize(const T* src, int height, int width, int stride, float* output, const float* mean,
               const float* std, bool to_rgb, cudaStream_t stream) {
  static_assert(channels >= 1 && channels <= 3, "mean/std are packed into a float3");
  if (height <= 0 || width <= 0) {
    return;
  }
  // Unused float3 lanes get neutral values; the kernel only reads the first `channels`.
  float m[3] = {0.f, 0.f, 0.f};
  float inv_s[3] = {1.f, 1.f, 1.f};
  for (int c = 0; c < channels; ++c) {
    m[c] = mean[c];
    // Reciprocal taken in double on the host (as before) so the kernel only multiplies.
    inv_s[c] = float(1. / std[c]);
  }
  const float3 _mean{m[0], m[1], m[2]};
  const float3 _inv_std{inv_s[0], inv_s[1], inv_s[2]};

  const dim3 thread_block(16, 16);
  const dim3 num_blocks((width + thread_block.x - 1) / thread_block.x,
                        (height + thread_block.y - 1) / thread_block.y);
  normalize<T, channels><<<num_blocks, thread_block, 0, stream>>>(src, height, width, stride,
                                                                   output, _mean, _inv_std, to_rgb);
}

template void Normalize<uint8_t, 3>(const uint8_t* src, int height, int width, int stride,
                                    float* output, const float* mean, const float* std, bool to_rgb,
                                    cudaStream_t stream);
template void Normalize<uint8_t, 1>(const uint8_t* src, int height, int width, int stride,
                                    float* output, const float* mean, const float* std, bool to_rgb,
                                    cudaStream_t stream);
template void Normalize<float, 3>(const float* src, int height, int width, int stride,
                                  float* output, const float* mean, const float* std, bool to_rgb,
                                  cudaStream_t stream);
template void Normalize<float, 1>(const float* src, int height, int width, int stride,
                                  float* output, const float* mean, const float* std, bool to_rgb,
                                  cudaStream_t stream);

}  // namespace cuda
}  // namespace mmdeploy