mmdeploy/csrc/preprocess/cuda/normalize_impl.cpp
lvhan028 36124f6205
Merge sdk (#251)
* check in cmake

* move backend_ops to csrc/backend_ops

* check in preprocess, model, some codebase and their c-apis

* check in CMakeLists.txt

* check in parts of test_csrc

* commit everything else

* add readme

* update core's BUILD_INTERFACE directory

* skip codespell on third_party

* update trt_net and ort_net's CMakeLists

* ignore clion's build directory

* check in pybind11

* add onnx.proto. Remove MMDeploy's dependency on ncnn's source code

* export MMDeployTargets only when MMDEPLOY_BUILD_SDK is ON

* remove useless message

* target include directory is wrong

* change target name from mmdeploy_ppl_net to mmdeploy_pplnn_net

* skip install directory

* update project's cmake

* remove useless code

* set CMAKE_BUILD_TYPE to Release by force if it isn't set by user

* update custom ops CMakeLists

* pass object target's source lists

* fix lint end-of-file

* fix lint: trailing whitespace

* fix codespell hook

* move bicubic_interpolate to csrc/backend_ops/

* set MMDEPLOY_BUILD_SDK OFF

* change custom ops build command

* add spdlog installation command

* update docs on how to checkout pybind11

* move bicubic_interpolate to backend_ops/tensorrt directory

* remove useless code

* correct cmake

* fix typo

* fix typo

* fix install directory

* correct sdk's readme

* set cub dir when cuda version < 11.0

* change directory where clang-format will apply to

* fix build command

* add .clang-format

* change clang-format style from google to file

* reformat csrc/backend_ops

* format sdk's code

* turn off clang-format for some files

* add -Xcompiler=-fno-gnu-unique

* fix trt topk initialization

* check in config for sdk demo

* update cmake script and csrc's readme

* correct config's path

* add cuda include directory; otherwise compilation fails with TensorRT 8.2

* clang-format onnx2ncnn.cpp

Co-authored-by: zhangli <lzhang329@gmail.com>
Co-authored-by: grimoire <yaoqian@sensetime.com>
2021-12-07 10:57:55 +08:00

82 lines
2.8 KiB
C++

// Copyright (c) OpenMMLab. All rights reserved.
#include <cuda_runtime.h>
#include "core/utils/formatter.h"
#include "preprocess/transform/normalize.h"
#include "preprocess/transform/transform_utils.h"
using namespace std;
namespace mmdeploy::cuda {
/// Normalizes a single image on the given CUDA stream.
///
/// Only declared in this translation unit; the definition lives elsewhere
/// (presumably a .cu file that instantiates it for the `T`/`channels` combinations
/// used below — uint8_t and float, with 1 or 3 channels).
///
/// @param src     input pixels, `stride` elements per row
/// @param height  number of rows
/// @param width   number of columns
/// @param stride  row stride of `src`, in elements
/// @param output  destination buffer for the normalized float pixels
/// @param mean    per-channel mean values
/// @param std     per-channel std values
/// @param to_rgb  the transform's `to_rgb` flag, forwarded unchanged
/// @param stream  CUDA stream to enqueue the work on
template <typename T, int channels>
void Normalize(const T* src,
               int height,
               int width,
               int stride,
               float* output,
               const float* mean,
               const float* std,
               bool to_rgb,
               cudaStream_t stream);
class NormalizeImpl : public ::mmdeploy::NormalizeImpl {
public:
explicit NormalizeImpl(const Value& args) : ::mmdeploy::NormalizeImpl(args) {}
protected:
Result<Tensor> NormalizeImage(const Tensor& tensor) override {
OUTCOME_TRY(auto src_tensor, MakeAvailableOnDevice(tensor, device_, stream_));
auto src_desc = src_tensor.desc();
int h = (int)src_desc.shape[1];
int w = (int)src_desc.shape[2];
int c = (int)src_desc.shape[3];
int stride = w * c;
TensorDesc dst_desc{device_, DataType::kFLOAT, src_desc.shape, src_desc.name};
Tensor dst_tensor{dst_desc};
auto output = dst_tensor.data<float>();
auto stream = ::mmdeploy::GetNative<cudaStream_t>(stream_);
if (DataType::kINT8 == src_desc.data_type) {
auto input = src_tensor.data<uint8_t>();
if (3 == c) {
Normalize<uint8_t, 3>(input, h, w, stride, output, arg_.mean.data(), arg_.std.data(),
arg_.to_rgb, stream);
} else if (1 == c) {
Normalize<uint8_t, 1>(input, h, w, stride, output, arg_.mean.data(), arg_.std.data(),
arg_.to_rgb, stream);
} else {
ERROR("unsupported channels {}", c);
return Status(eNotSupported);
}
} else if (DataType::kFLOAT == src_desc.data_type) {
auto input = src_tensor.data<float>();
if (3 == c) {
Normalize<float, 3>(input, h, w, stride, output, arg_.mean.data(), arg_.std.data(),
arg_.to_rgb, stream);
} else if (1 == c) {
Normalize<float, 1>(input, h, w, stride, output, arg_.mean.data(), arg_.std.data(),
arg_.to_rgb, stream);
} else {
ERROR("unsupported channels {}", c);
return Status(eNotSupported);
}
} else {
ERROR("unsupported data type {}", src_desc.data_type);
assert(0);
return Status(eNotSupported);
}
return dst_tensor;
}
};
/// Factory for the CUDA NormalizeImpl; identifies itself as "cuda", version 1.
class NormalizeImplCreator : public Creator<::mmdeploy::NormalizeImpl> {
 public:
  /// Builds a CUDA NormalizeImpl from the transform arguments.
  std::unique_ptr<::mmdeploy::NormalizeImpl> Create(const Value& args) override {
    return std::make_unique<NormalizeImpl>(args);
  }

  /// Version number reported by this creator.
  int GetVersion() const override { return 1; }

  /// Name reported by this creator.
  const char* GetName() const override { return "cuda"; }
};
} // namespace mmdeploy::cuda
// Registers NormalizeImplCreator (the "cuda" creator defined in mmdeploy::cuda) for the
// NormalizeImpl module. The `using` declarations let REGISTER_MODULE receive bare,
// unqualified type names.
// NOTE(review): presumably the macro pastes or stringifies these tokens, so passing
// qualified names directly may not work — confirm against the macro definition first.
using mmdeploy::NormalizeImpl;
using mmdeploy::cuda::NormalizeImplCreator;
REGISTER_MODULE(NormalizeImpl, NormalizeImplCreator);