mmdeploy/csrc/net/trt/trt_net.h
lvhan028 36124f6205
Merge sdk (#251)
* check in cmake

* move backend_ops to csrc/backend_ops

* check in preprocess, model, some codebase and their c-apis

* check in CMakeLists.txt

* check in parts of test_csrc

* commit everything else

* add readme

* update core's BUILD_INTERFACE directory

* skip codespell on third_party

* update trt_net and ort_net's CMakeLists

* ignore clion's build directory

* check in pybind11

* add onnx.proto. Remove MMDeploy's dependency on ncnn's source code

* export MMDeployTargets only when MMDEPLOY_BUILD_SDK is ON

* remove useless message

* target include directory is wrong

* change target name from mmdeploy_ppl_net to mmdeploy_pplnn_net

* skip install directory

* update project's cmake

* remove useless code

* set CMAKE_BUILD_TYPE to Release by force if it isn't set by user

* update custom ops CMakeLists

* pass object target's source lists

* fix lint end-of-file

* fix lint: trailing whitespace

* fix codespell hook

* move bicubic_interpolate to csrc/backend_ops/

* set MMDEPLOY_BUILD_SDK OFF

* change custom ops build command

* add spdlog installation command

* update docs on how to checkout pybind11

* move bicubic_interpolate to backend_ops/tensorrt directory

* remove useless code

* correct cmake

* fix typo

* fix typo

* fix install directory

* correct sdk's readme

* set cub dir when cuda version < 11.0

* change directory where clang-format will apply to

* fix build command

* add .clang-format

* change clang-format style from google to file

* reformat csrc/backend_ops

* format sdk's code

* turn off clang-format for some files

* add -Xcompiler=-fno-gnu-unique

* fix trt topk initialize

* check in config for sdk demo

* update cmake script and csrc's readme

* correct config's path

* add cuda include directory, otherwise compile failed in case of tensorrt8.2

* clang-format onnx2ncnn.cpp

Co-authored-by: zhangli <lzhang329@gmail.com>
Co-authored-by: grimoire <yaoqian@sensetime.com>
2021-12-07 10:57:55 +08:00

81 lines
2.0 KiB
C++

// Copyright (c) OpenMMLab. All rights reserved.
#ifndef MMDEPLOY_SRC_NET_TRT_TRT_NET_H_
#define MMDEPLOY_SRC_NET_TRT_TRT_NET_H_
#include <core/mpl/span.h>
#include "NvInferRuntime.h"
#include "core/net.h"
namespace mmdeploy {
namespace trt_detail {
// RAII owner for TensorRT library objects (e.g. ICudaEngine, IExecutionContext).
// TensorRT < 8 requires releasing objects via obj->destroy(); TensorRT >= 8
// deprecates destroy() in favor of plain `delete` — reset() handles both.
template <typename T>
class TRTWrapper {
 public:
  TRTWrapper() = default;
  TRTWrapper(T* ptr) : ptr_(ptr) {}  // NOLINT: implicit wrap of a raw pointer is intended
  ~TRTWrapper() { reset(); }
  TRTWrapper(const TRTWrapper&) = delete;
  TRTWrapper& operator=(const TRTWrapper&) = delete;
  // Take ownership directly in the member initializer. Delegating to the move
  // assignment operator here (as `*this = std::move(other)`) would make
  // reset() read this->ptr_ while it is still indeterminate, which is UB and
  // can destroy a garbage pointer.
  TRTWrapper(TRTWrapper&& other) noexcept : ptr_(std::exchange(other.ptr_, nullptr)) {}
  TRTWrapper& operator=(TRTWrapper&& other) noexcept {
    // Safe for self-move: other.ptr_ is nulled out before reset() runs.
    reset(std::exchange(other.ptr_, nullptr));
    return *this;
  }
  T& operator*() { return *ptr_; }
  T* operator->() { return ptr_; }
  // Destroy the currently owned object (if any) and take ownership of `p`.
  void reset(T* p = nullptr) {
    if (auto old = std::exchange(ptr_, p)) {  // NOLINT
#if NV_TENSORRT_MAJOR < 8
      old->destroy();
#else
      delete old;
#endif
    }
  }
  explicit operator bool() const noexcept { return ptr_ != nullptr; }

 private:
  // In-class initializer guarantees ptr_ is never read uninitialized.
  T* ptr_{nullptr};
};
// clang-format off
template <typename T>
explicit TRTWrapper(T*) -> TRTWrapper<T>;
// clang-format on
} // namespace trt_detail
// TensorRT implementation of the Net inference interface. Owns the CUDA
// engine and its execution context via TRTWrapper (RAII).
class TRTNet : public Net {
 public:
  ~TRTNet() override;
  Result<void> Init(const Value& cfg) override;
  Result<void> Deinit() override;
  Result<Span<Tensor>> GetInputTensors() override;
  Result<Span<Tensor>> GetOutputTensors() override;
  // Rebinds I/O to the given input shapes (dynamic-shape engines) — see trt_net.cpp.
  Result<void> Reshape(Span<TensorShape> input_shapes) override;
  Result<void> Forward() override;
  Result<void> ForwardAsync(Event* event) override;

 private:  // was duplicated twice; a single access specifier suffices
  // context_ is created from engine_; declaring engine_ first means context_
  // is destroyed before engine_ on teardown.
  trt_detail::TRTWrapper<nvinfer1::ICudaEngine> engine_;
  trt_detail::TRTWrapper<nvinfer1::IExecutionContext> context_;
  std::vector<int> input_ids_;   // presumably TensorRT binding indices — confirm in trt_net.cpp
  std::vector<int> output_ids_;  // presumably TensorRT binding indices — confirm in trt_net.cpp
  std::vector<std::string> input_names_;
  std::vector<std::string> output_names_;
  std::vector<Tensor> input_tensors_;
  std::vector<Tensor> output_tensors_;
  Device device_;
  Stream stream_;
  Event event_;
};
} // namespace mmdeploy
#endif // MMDEPLOY_SRC_NET_TRT_TRT_NET_H_