From fc67350e992bc20903cb2e2321ce6e1db92d06e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=85=95=E6=B9=AE?= Date: Mon, 12 Apr 2021 15:05:21 +0800 Subject: [PATCH] Add python interface by pybind11 and Int8 mode Reviewed by: @TCHeish --- .gitmodules | 3 + projects/FastRT/.gitignore | 2 +- projects/FastRT/CMakeLists.txt | 15 +- projects/FastRT/README.md | 65 ++++++- projects/FastRT/demo/inference.cpp | 0 projects/FastRT/fastrt/CMakeLists.txt | 1 + .../FastRT/fastrt/backbones/sbs_resnet.cpp | 47 +++++ projects/FastRT/fastrt/common/calibrator.cpp | 80 ++++++++ projects/FastRT/fastrt/factory/factory.cpp | 5 + projects/FastRT/fastrt/meta_arch/baseline.cpp | 2 +- projects/FastRT/fastrt/meta_arch/model.cpp | 20 ++ projects/FastRT/include/fastrt/calibrator.h | 39 ++++ projects/FastRT/include/fastrt/cuda_utils.h | 18 ++ projects/FastRT/include/fastrt/model.h | 0 projects/FastRT/include/fastrt/sbs_resnet.h | 10 + projects/FastRT/include/fastrt/struct.h | 3 +- projects/FastRT/include/fastrt/utils.h | 22 +++ .../FastRT/pybind_interface/CMakeLists.txt | 40 ++++ .../docker/trt7cu100/Dockerfile | 17 ++ .../docker/trt7cu102/Dockerfile | 17 ++ .../pybind_interface/market_benchmark.py | 65 +++++++ projects/FastRT/pybind_interface/pybind11 | 1 + projects/FastRT/pybind_interface/reid.cpp | 177 ++++++++++++++++++ projects/FastRT/pybind_interface/test.py | 23 +++ 24 files changed, 664 insertions(+), 8 deletions(-) create mode 100644 .gitmodules mode change 100644 => 100755 projects/FastRT/.gitignore mode change 100644 => 100755 projects/FastRT/CMakeLists.txt mode change 100644 => 100755 projects/FastRT/README.md mode change 100644 => 100755 projects/FastRT/demo/inference.cpp mode change 100644 => 100755 projects/FastRT/fastrt/CMakeLists.txt mode change 100644 => 100755 projects/FastRT/fastrt/backbones/sbs_resnet.cpp create mode 100755 projects/FastRT/fastrt/common/calibrator.cpp mode change 100644 => 100755 projects/FastRT/fastrt/factory/factory.cpp mode change 100644 => 100755 projects/FastRT/fastrt/meta_arch/baseline.cpp mode change 100644 => 100755 projects/FastRT/fastrt/meta_arch/model.cpp create mode 100755 projects/FastRT/include/fastrt/calibrator.h create mode 100755 projects/FastRT/include/fastrt/cuda_utils.h mode change 100644 => 100755 projects/FastRT/include/fastrt/model.h mode change 100644 => 100755 projects/FastRT/include/fastrt/sbs_resnet.h mode change 100644 => 100755 projects/FastRT/include/fastrt/struct.h mode change 100644 => 100755 projects/FastRT/include/fastrt/utils.h create mode 100755 projects/FastRT/pybind_interface/CMakeLists.txt create mode 100755 projects/FastRT/pybind_interface/docker/trt7cu100/Dockerfile create mode 100755 projects/FastRT/pybind_interface/docker/trt7cu102/Dockerfile create mode 100755 projects/FastRT/pybind_interface/market_benchmark.py create mode 160000 projects/FastRT/pybind_interface/pybind11 create mode 100755 projects/FastRT/pybind_interface/reid.cpp create mode 100755 projects/FastRT/pybind_interface/test.py diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..1eb4b19 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "projects/FastRT/pybind_interface/pybind11"] + path = projects/FastRT/pybind_interface/pybind11 + url = https://github.com/pybind/pybind11.git diff --git a/projects/FastRT/.gitignore b/projects/FastRT/.gitignore old mode 100644 new mode 100755 index e6d51cc..54681cd --- a/projects/FastRT/.gitignore +++ b/projects/FastRT/.gitignore @@ -3,4 +3,4 @@ .vscode/ libs/ build/ -data/ +data/ \ No newline at end of file diff --git a/projects/FastRT/CMakeLists.txt b/projects/FastRT/CMakeLists.txt old mode 100644 new mode 100755 index c73d82e..7f5074e --- a/projects/FastRT/CMakeLists.txt +++ b/projects/FastRT/CMakeLists.txt @@ -31,13 +31,17 @@ option(CUDA_USE_STATIC_CUDA_RUNTIME "Use Static CUDA" OFF) option(BUILD_FASTRT_ENGINE "Build FastRT Engine" ON) option(BUILD_DEMO "Build DEMO" ON) option(BUILD_FP16 "Build Engine as FP16" OFF) +option(BUILD_INT8 "Build Engine as INT8" OFF) option(USE_CNUMPY "Include CNPY libs" OFF) +option(BUILD_PYTHON_INTERFACE "Build Python Interface" OFF) if(USE_CNUMPY) add_definitions(-DUSE_CNUMPY) endif() - -if(BUILD_FP16) +if(BUILD_INT8) + add_definitions(-DBUILD_INT8) + message("Build Engine as INT8") +elseif(BUILD_FP16) add_definitions(-DBUILD_FP16) message("Build Engine as FP16") else() @@ -60,3 +64,10 @@ if(BUILD_DEMO) else() message(STATUS "BUILD_DEMO: OFF") endif() + +if(BUILD_PYTHON_INTERFACE) + add_subdirectory(pybind_interface) + message(STATUS "BUILD_PYTHON_INTERFACE: ON") +else() + message(STATUS "BUILD_PYTHON_INTERFACE: OFF") +endif() diff --git a/projects/FastRT/README.md b/projects/FastRT/README.md old mode 100644 new mode 100755 index af9a545..6188d66 --- a/projects/FastRT/README.md +++ b/projects/FastRT/README.md @@ -44,7 +44,6 @@ So we don't use any parsers here. 6. Verify the output with pytorch - 7. (Optional) Once you verify the result, you can set FP16 for speed up ``` mkdir build @@ -55,7 +54,23 @@ So we don't use any parsers here. then go to [step 5](#step5) -8. (Optional) Build tensorrt model as shared libs +8. (Optional) You can use INT8 quantization for speed up + + First, modify the source code to specify your calibrate dataset path. In `FastRT/fastrt/meta_arch/model.cpp`, line 91. + ``` + Int8EntropyCalibrator2* calibrator = new Int8EntropyCalibrator2(1, w, h, PATH_TO_YOUR_DATASET, "int8calib.table", p); + ``` + Then build. + ``` + mkdir build + cd build + cmake -DBUILD_FASTRT_ENGINE=ON -DBUILD_DEMO=ON -DBUILD_INT8=ON .. + make + ``` + + then go to [step 5](#step5) + +9. (Optional) Build tensorrt model as shared libs ``` mkdir build @@ -73,6 +88,29 @@ So we don't use any parsers here. ``` then go to [step 5](#step5) + +10. (Optional) Build tensorrt model with python interface, then you can use FastRT model in python. + First get the pybind lib, run `git submodule update --init --recursive`. + + ``` + mkdir build + cd build + cmake -DBUILD_FASTRT_ENGINE=ON -DBUILD_DEMO=ON -DBUILD_PYTHON_INTERFACE=ON -DPYTHON_EXECUTABLE=$(which python) .. + make + ``` + You should get a so file `FastRT/build/pybind_interface/ReID.cpython-36m-x86_64-linux-gnu.so`. + + Then go to [step 5](#step5) to create engine file. After that you can import this so file in python, and deserialize engine file to infer in python. You can find use example in `pybind_interface/test.py` and `pybind_interface/market_benchmark.py`. + ``` + from PATH_TO_SO_FILE import ReID + model = ReID(GPU_ID) + model.build(PATH_TO_YOUR_ENGINEFILE) + numpy_feature = np.array([model.infer(CV2_FRAME)]) + ``` + + + + ### `Tensorrt Model Config` @@ -160,9 +198,30 @@ static const bool WITH_NL = false; static const int EMBEDDING_DIM = 0; ``` + ++ Ex5.`kd-r18-r101_ibn` +``` +static const std::string WEIGHTS_PATH = "../kd-r18-r101_ibn.wts"; +static const std::string ENGINE_PATH = "./kd_r18_distill.engine"; + +static const int MAX_BATCH_SIZE = 16; +static const int INPUT_H = 384; +static const int INPUT_W = 128; +static const int OUTPUT_SIZE = 512; +static const int DEVICE_ID = 1; + +static const FastreidBackboneType BACKBONE = FastreidBackboneType::r18_distill; +static const FastreidHeadType HEAD = FastreidHeadType::EmbeddingHead; +static const FastreidPoolingType HEAD_POOLING = FastreidPoolingType::gempoolP; +static const int LAST_STRIDE = 1; +static const bool WITH_IBNA = true; +static const bool WITH_NL = false; +static const int EMBEDDING_DIM = 0; +``` + ### Supported conversion -* Backbone: resnet50, resnet34, distill-resnet50, distill-resnet34 +* Backbone: resnet50, resnet34, distill-resnet50, distill-resnet34, distill-resnet18 * Heads: embedding_head * Plugin layers: ibn, non-local * Pooling layers: maxpool, avgpool, GeneralizedMeanPooling, GeneralizedMeanPoolingP diff --git a/projects/FastRT/demo/inference.cpp b/projects/FastRT/demo/inference.cpp old mode 100644 new mode 100755 diff --git a/projects/FastRT/fastrt/CMakeLists.txt b/projects/FastRT/fastrt/CMakeLists.txt old mode 100644 new mode 100755 index e84d86c..b21e0d2 --- a/projects/FastRT/fastrt/CMakeLists.txt +++ b/projects/FastRT/fastrt/CMakeLists.txt @@ -2,6 +2,7 @@ project(FastRTEngine) file(GLOB_RECURSE COMMON_SRC_FILES ${CMAKE_CURRENT_SOURCE_DIR}/common/utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/common/calibrator.cpp ) find_package(CUDA REQUIRED) diff --git a/projects/FastRT/fastrt/backbones/sbs_resnet.cpp b/projects/FastRT/fastrt/backbones/sbs_resnet.cpp old mode 100644 new mode 100755 index 35e2506..0f1775f --- a/projects/FastRT/fastrt/backbones/sbs_resnet.cpp +++ b/projects/FastRT/fastrt/backbones/sbs_resnet.cpp @@ -6,6 +6,53 @@ using namespace trtxapi; namespace fastrt { + ILayer* backbone_sbsR18_distill::topology(INetworkDefinition *network, std::map& weightMap, ITensor& input) { + std::string ibn{""}; + if(_modelCfg.with_ibna) { + ibn = "a"; + } + std::map> ibn_layers{ + {"a", {"a","a","a","a","a","a","",""}}, + {"b", {"","","b","","","","b","","","","","","","","","",}}, + {"", {16,""}}}; + + Weights emptywts{DataType::kFLOAT, nullptr, 0}; + IConvolutionLayer* conv1 = network->addConvolutionNd(input, 64, DimsHW{7, 7}, weightMap["backbone.conv1.weight"], emptywts); + TRTASSERT(conv1); + conv1->setStrideNd(DimsHW{2, 2}); + conv1->setPaddingNd(DimsHW{3, 3}); + + IScaleLayer* bn1{nullptr}; + if (ibn == "b") { + bn1 = addInstanceNorm2d(network, weightMap, *conv1->getOutput(0), "backbone.bn1", 1e-5); + } else { + bn1 = addBatchNorm2d(network, weightMap, *conv1->getOutput(0), "backbone.bn1", 1e-5); + } + IActivationLayer* relu1 = network->addActivation(*bn1->getOutput(0), ActivationType::kRELU); + TRTASSERT(relu1); + + // pytorch: nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True) + IPoolingLayer* pool1 = network->addPoolingNd(*relu1->getOutput(0), PoolingType::kMAX, DimsHW{3, 3}); + TRTASSERT(pool1); + pool1->setStrideNd(DimsHW{2, 2}); + pool1->setPaddingMode(PaddingMode::kEXPLICIT_ROUND_UP); + + ILayer* x = distill_basicBlock_ibn(network, weightMap, *pool1->getOutput(0), 64, 64, 1, "backbone.layer1.0.", ibn_layers[ibn][0]); + x = distill_basicBlock_ibn(network, weightMap, *x->getOutput(0), 64, 64, 1, "backbone.layer1.1.", ibn_layers[ibn][1]); + + x = distill_basicBlock_ibn(network, weightMap, *x->getOutput(0), 64, 128, 2, "backbone.layer2.0.", ibn_layers[ibn][2]); + x = distill_basicBlock_ibn(network, weightMap, *x->getOutput(0), 128, 128, 1, "backbone.layer2.1.", ibn_layers[ibn][3]); + + x = distill_basicBlock_ibn(network, weightMap, *x->getOutput(0), 128, 256, 2, "backbone.layer3.0.", ibn_layers[ibn][4]); + x = distill_basicBlock_ibn(network, weightMap, *x->getOutput(0), 256, 256, 1, "backbone.layer3.1.", ibn_layers[ibn][5]); + + x = distill_basicBlock_ibn(network, weightMap, *x->getOutput(0), 256, 512, _modelCfg.last_stride, "backbone.layer4.0.", ibn_layers[ibn][6]); + x = distill_basicBlock_ibn(network, weightMap, *x->getOutput(0), 512, 512, 1, "backbone.layer4.1.", ibn_layers[ibn][7]); + + IActivationLayer* relu2 = network->addActivation(*x->getOutput(0), ActivationType::kRELU); + TRTASSERT(relu2); + return relu2; + } ILayer* backbone_sbsR34_distill::topology(INetworkDefinition *network, std::map& weightMap, ITensor& input) { std::string ibn{""}; diff --git a/projects/FastRT/fastrt/common/calibrator.cpp b/projects/FastRT/fastrt/common/calibrator.cpp new file mode 100755 index 0000000..ac6fa3e --- /dev/null +++ b/projects/FastRT/fastrt/common/calibrator.cpp @@ -0,0 +1,80 @@ +#include +#include +#include +#include +#include +#include "fastrt/calibrator.h" +#include "fastrt/cuda_utils.h" +#include "fastrt/utils.h" + +Int8EntropyCalibrator2::Int8EntropyCalibrator2(int batchsize, int input_w, int input_h, const char* img_dir, const char* calib_table_name, const char* input_blob_name, bool read_cache) + : batchsize_(batchsize) + , input_w_(input_w) + , input_h_(input_h) + , img_idx_(0) + , img_dir_(img_dir) + , calib_table_name_(calib_table_name) + , input_blob_name_(input_blob_name) + , read_cache_(read_cache) +{ + input_count_ = 3 * input_w * input_h * batchsize; + CUDA_CHECK(cudaMalloc(&device_input_, input_count_ * sizeof(float))); + read_files_in_dir(img_dir, img_files_); +} + +Int8EntropyCalibrator2::~Int8EntropyCalibrator2() +{ + CUDA_CHECK(cudaFree(device_input_)); +} + +int Int8EntropyCalibrator2::getBatchSize() const +{ + return batchsize_; +} + +bool Int8EntropyCalibrator2::getBatch(void* bindings[], const char* names[], int nbBindings) +{ + if (img_idx_ + batchsize_ > (int)img_files_.size()) { + return false; + } + + std::vector input_imgs_; + for (int i = img_idx_; i < img_idx_ + batchsize_; i++) { + std::cout << img_dir_ + img_files_[i] << " " << i << std::endl; + cv::Mat temp = cv::imread(img_dir_ + img_files_[i]); + if (temp.empty()){ + std::cerr << "Fatal error: image cannot open!" << std::endl; + return false; + } + input_imgs_.push_back(temp); + } + img_idx_ += batchsize_; + cv::Mat blob = cv::dnn::blobFromImages(input_imgs_, 1.0, cv::Size(input_w_, input_h_), cv::Scalar(0, 0, 0), true, false); + + CUDA_CHECK(cudaMemcpy(device_input_, blob.ptr(0), input_count_ * sizeof(float), cudaMemcpyHostToDevice)); + assert(!strcmp(names[0], input_blob_name_)); + bindings[0] = device_input_; + return true; +} + +const void* Int8EntropyCalibrator2::readCalibrationCache(size_t& length) +{ + std::cout << "reading calib cache: " << calib_table_name_ << std::endl; + calib_cache_.clear(); + std::ifstream input(calib_table_name_, std::ios::binary); + input >> std::noskipws; + if (read_cache_ && input.good()) + { + std::copy(std::istream_iterator(input), std::istream_iterator(), std::back_inserter(calib_cache_)); + } + length = calib_cache_.size(); + return length ? calib_cache_.data() : nullptr; +} + +void Int8EntropyCalibrator2::writeCalibrationCache(const void* cache, size_t length) +{ + std::cout << "writing calib cache: " << calib_table_name_ << " size: " << length << std::endl; + std::ofstream output(calib_table_name_, std::ios::binary); + output.write(reinterpret_cast(cache), length); +} + diff --git a/projects/FastRT/fastrt/factory/factory.cpp b/projects/FastRT/fastrt/factory/factory.cpp old mode 100644 new mode 100755 index b2f04c7..80d034a --- a/projects/FastRT/fastrt/factory/factory.cpp +++ b/projects/FastRT/fastrt/factory/factory.cpp @@ -29,6 +29,11 @@ namespace fastrt { /* cfg.MODEL.BACKBONE.DEPTH: 34x */ std::cout << "[createBackboneModule]: backbone_sbsR34_distill" << std::endl; return make_unique(modelCfg); + case FastreidBackboneType::r18_distill: + /* cfg.MODEL.META_ARCHITECTURE: Distiller */ + /* cfg.MODEL.BACKBONE.DEPTH: 18x */ + std::cout << "[createBackboneModule]: backbone_sbsR18_distill" << std::endl; + return make_unique(modelCfg); default: std::cerr << "[Backbone is not supported.]" << std::endl; return nullptr; diff --git a/projects/FastRT/fastrt/meta_arch/baseline.cpp b/projects/FastRT/fastrt/meta_arch/baseline.cpp old mode 100644 new mode 100755 index 90c2a81..3511842 --- a/projects/FastRT/fastrt/meta_arch/baseline.cpp +++ b/projects/FastRT/fastrt/meta_arch/baseline.cpp @@ -19,7 +19,7 @@ namespace fastrt { /* Standardization */ static const float mean[3] = {123.675, 116.28, 103.53}; static const float std[3] = {58.395, 57.120000000000005, 57.375}; - return addMeanStd(network, weightMap, input, "", mean, std, true); // true for div 255 + return addMeanStd(network, weightMap, input, "", mean, std, false); // true for div 255 } } \ No newline at end of file diff --git a/projects/FastRT/fastrt/meta_arch/model.cpp b/projects/FastRT/fastrt/meta_arch/model.cpp old mode 100644 new mode 100755 index 369a8ce..73476ba --- a/projects/FastRT/fastrt/meta_arch/model.cpp +++ b/projects/FastRT/fastrt/meta_arch/model.cpp @@ -1,4 +1,6 @@ #include "fastrt/model.h" +#include "fastrt/calibrator.h" + namespace fastrt { @@ -68,9 +70,27 @@ namespace fastrt { /* Build engine */ builder->setMaxBatchSize(_engineCfg.max_batch_size); config->setMaxWorkspaceSize(1 << 20); +#if defined(BUILD_FP16) && defined(BUILD_INT8) + std::cout << "Flag confilct! BUILD_FP16 and BUILD_INT8 can't be both True!" << std::endl; + return null; +#endif #ifdef BUILD_FP16 std::cout << "[Build fp16]" << std::endl; config->setFlag(BuilderFlag::kFP16); +#endif +#ifdef BUILD_INT8 + std::cout << "[Build int8]" << std::endl; + std::cout << "Your platform support int8: " << (builder->platformHasFastInt8() ? "true" : "false") << std::endl; + assert(builder->platformHasFastInt8()); + config->setFlag(BuilderFlag::kINT8); + int w = _engineCfg.input_w; + int h = _engineCfg.input_h; + char*p = (char*)_engineCfg.input_name.data(); + + //path must end with / + Int8EntropyCalibrator2* calibrator = new Int8EntropyCalibrator2(1, w, h, + "/data/person_reid/data/Market-1501-v15.09.15/bounding_box_test/", "int8calib.table", p); + config->setInt8Calibrator(calibrator); #endif auto engine = make_holder(builder->buildEngineWithConfig(*network, *config)); std::cout << "[TRT engine build out]" << std::endl; diff --git a/projects/FastRT/include/fastrt/calibrator.h b/projects/FastRT/include/fastrt/calibrator.h new file mode 100755 index 0000000..1cc9dbb --- /dev/null +++ b/projects/FastRT/include/fastrt/calibrator.h @@ -0,0 +1,39 @@ +#ifndef ENTROPY_CALIBRATOR_H +#define ENTROPY_CALIBRATOR_H + +#include "NvInfer.h" +#include +#include + +//! \class Int8EntropyCalibrator2 +//! +//! \brief Implements Entropy calibrator 2. +//! CalibrationAlgoType is kENTROPY_CALIBRATION_2. +//! +class Int8EntropyCalibrator2 : public nvinfer1::IInt8EntropyCalibrator2 +{ +public: + Int8EntropyCalibrator2(int batchsize, int input_w, int input_h, const char* img_dir, const char* calib_table_name, const char* input_blob_name, bool read_cache = true); + + virtual ~Int8EntropyCalibrator2(); + int getBatchSize() const override; + bool getBatch(void* bindings[], const char* names[], int nbBindings) override; + const void* readCalibrationCache(size_t& length) override; + void writeCalibrationCache(const void* cache, size_t length) override; + +private: + int batchsize_; + int input_w_; + int input_h_; + int img_idx_; + std::string img_dir_; + std::vector img_files_; + size_t input_count_; + std::string calib_table_name_; + const char* input_blob_name_; + bool read_cache_; + void* device_input_; + std::vector calib_cache_; +}; + +#endif // ENTROPY_CALIBRATOR_H diff --git a/projects/FastRT/include/fastrt/cuda_utils.h b/projects/FastRT/include/fastrt/cuda_utils.h new file mode 100755 index 0000000..8fbd319 --- /dev/null +++ b/projects/FastRT/include/fastrt/cuda_utils.h @@ -0,0 +1,18 @@ +#ifndef TRTX_CUDA_UTILS_H_ +#define TRTX_CUDA_UTILS_H_ + +#include + +#ifndef CUDA_CHECK +#define CUDA_CHECK(callstr)\ + {\ + cudaError_t error_code = callstr;\ + if (error_code != cudaSuccess) {\ + std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__;\ + assert(0);\ + }\ + } +#endif // CUDA_CHECK + +#endif // TRTX_CUDA_UTILS_H_ + diff --git a/projects/FastRT/include/fastrt/model.h b/projects/FastRT/include/fastrt/model.h old mode 100644 new mode 100755 diff --git a/projects/FastRT/include/fastrt/sbs_resnet.h b/projects/FastRT/include/fastrt/sbs_resnet.h old mode 100644 new mode 100755 index 5d883c0..96e34ef --- a/projects/FastRT/include/fastrt/sbs_resnet.h +++ b/projects/FastRT/include/fastrt/sbs_resnet.h @@ -7,6 +7,16 @@ using namespace nvinfer1; namespace fastrt { + class backbone_sbsR18_distill : public Module { + private: + FastreidConfig& _modelCfg; + public: + backbone_sbsR18_distill(FastreidConfig& modelCfg) : _modelCfg(modelCfg){} + ~backbone_sbsR18_distill() = default; + ILayer* topology(INetworkDefinition *network, + std::map& weightMap, + ITensor& input) override; + }; class backbone_sbsR34_distill : public Module { private: diff --git a/projects/FastRT/include/fastrt/struct.h b/projects/FastRT/include/fastrt/struct.h old mode 100644 new mode 100755 index 7a896f9..dc67024 --- a/projects/FastRT/include/fastrt/struct.h +++ b/projects/FastRT/include/fastrt/struct.h @@ -28,7 +28,8 @@ namespace fastrt { X(r50, "r50") \ X(r50_distill, "r50_distill") \ X(r34, "r34") \ - X(r34_distill, "r34_distill") + X(r34_distill, "r34_distill") \ + X(r18_distill, "r18_distill") #define X(a, b) a, enum FastreidBackboneType { FASTBACKBONE_TABLE }; diff --git a/projects/FastRT/include/fastrt/utils.h b/projects/FastRT/include/fastrt/utils.h old mode 100644 new mode 100755 index 5a43e14..2835de2 --- a/projects/FastRT/include/fastrt/utils.h +++ b/projects/FastRT/include/fastrt/utils.h @@ -7,7 +7,9 @@ #include #include #include +#include +#include #include "NvInfer.h" #include "cuda_runtime_api.h" #include "fastrt/struct.h" @@ -37,6 +39,26 @@ namespace io { std::vector fileGlob(const std::string& pattern); } +static inline int read_files_in_dir(const char *p_dir_name, std::vector &file_names) { + DIR *p_dir = opendir(p_dir_name); + if (p_dir == nullptr) { + return -1; + } + + struct dirent* p_file = nullptr; + while ((p_file = readdir(p_dir)) != nullptr) { + if (strcmp(p_file->d_name, ".") != 0 && + strcmp(p_file->d_name, "..") != 0) { + + std::string cur_file_name(p_file->d_name); + file_names.push_back(cur_file_name); + } + } + + closedir(p_dir); + return 0; +} + namespace trt { /* * Load weights from files shared with TensorRT samples. diff --git a/projects/FastRT/pybind_interface/CMakeLists.txt b/projects/FastRT/pybind_interface/CMakeLists.txt new file mode 100755 index 0000000..eb92995 --- /dev/null +++ b/projects/FastRT/pybind_interface/CMakeLists.txt @@ -0,0 +1,40 @@ +SET(APP_PROJECT_NAME ReID) + +# pybind +add_subdirectory(pybind11) + +find_package(CUDA REQUIRED) +# include and link dirs of cuda and tensorrt, you need adapt them if yours are different +# cuda +include_directories(/usr/local/cuda/include) +link_directories(/usr/local/cuda/lib64) +# tensorrt +include_directories(/usr/include/x86_64-linux-gnu/) +link_directories(/usr/lib/x86_64-linux-gnu/) + +include_directories(${SOLUTION_DIR}/include) + +pybind11_add_module(${APP_PROJECT_NAME} ${PROJECT_SOURCE_DIR}/pybind_interface/reid.cpp) + +# OpenCV +find_package(OpenCV) +target_include_directories(${APP_PROJECT_NAME} +PUBLIC + ${OpenCV_INCLUDE_DIRS} +) +target_link_libraries(${APP_PROJECT_NAME} +PUBLIC + ${OpenCV_LIBS} +) + +if(BUILD_FASTRT_ENGINE AND BUILD_PYTHON_INTERFACE) + SET(FASTRTENGINE_LIB FastRTEngine) +else() + SET(FASTRTENGINE_LIB ${SOLUTION_DIR}/libs/FastRTEngine/libFastRTEngine.so) +endif() + +target_link_libraries(${APP_PROJECT_NAME} +PRIVATE + ${FASTRTENGINE_LIB} + nvinfer +) \ No newline at end of file diff --git a/projects/FastRT/pybind_interface/docker/trt7cu100/Dockerfile b/projects/FastRT/pybind_interface/docker/trt7cu100/Dockerfile new file mode 100755 index 0000000..7672397 --- /dev/null +++ b/projects/FastRT/pybind_interface/docker/trt7cu100/Dockerfile @@ -0,0 +1,17 @@ +# cuda10.0 +FROM fineyu/tensorrt7:0.0.1 + +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt-get update && apt-get install -y software-properties-common + +RUN add-apt-repository -y ppa:timsc/opencv-3.4 && \ +apt-get update && \ +apt-get install -y cmake \ +libopencv-dev \ +libopencv-dnn-dev \ +libopencv-shape3.4-dbg && \ +rm -rf /var/lib/apt/lists/* + +RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py && python3 get-pip.py && rm get-pip.py && pip3 install torch==1.6.0 torchvision tensorboard matplotlib scipy Pillow numpy prettytable easydict opencv-python \ +scikit-learn pyyaml yacs termcolor tabulate tensorboard opencv-python pyyaml yacs termcolor tabulate gdown faiss-cpu diff --git a/projects/FastRT/pybind_interface/docker/trt7cu102/Dockerfile b/projects/FastRT/pybind_interface/docker/trt7cu102/Dockerfile new file mode 100755 index 0000000..edec2ae --- /dev/null +++ b/projects/FastRT/pybind_interface/docker/trt7cu102/Dockerfile @@ -0,0 +1,17 @@ +# cuda10.2 +FROM nvcr.io/nvidia/tensorrt:20.03-py3 + +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt-get update && apt-get install -y software-properties-common + +RUN add-apt-repository -y ppa:timsc/opencv-3.4 && \ +apt-get update && \ +apt-get install -y cmake \ +libopencv-dev \ +libopencv-dnn-dev \ +libopencv-shape3.4-dbg && \ +rm -rf /var/lib/apt/lists/* + +RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py && python3 get-pip.py && rm get-pip.py && pip3 install torch==1.6.0 torchvision tensorboard matplotlib scipy Pillow numpy prettytable easydict opencv-python \ +scikit-learn pyyaml yacs termcolor tabulate tensorboard opencv-python pyyaml yacs termcolor tabulate gdown faiss-cpu diff --git a/projects/FastRT/pybind_interface/market_benchmark.py b/projects/FastRT/pybind_interface/market_benchmark.py new file mode 100755 index 0000000..52f4ace --- /dev/null +++ b/projects/FastRT/pybind_interface/market_benchmark.py @@ -0,0 +1,65 @@ +import random +import numpy as np +import cv2 +import fs +import argparse +import io +import sys +import torch +import time +import os +import torchvision.transforms as T + +sys.path.append('../../..') +sys.path.append('../') +from fastreid.config import get_cfg +from fastreid.modeling.meta_arch import build_model +from fastreid.utils.file_io import PathManager +from fastreid.utils.checkpoint import Checkpointer +from fastreid.utils.logger import setup_logger +from fastreid.data import build_reid_train_loader, build_reid_test_loader +from fastreid.evaluation.rank import eval_market1501 + +from build.pybind_interface.ReID import ReID + + +FEATURE_DIM = 512 +GPU_ID = 0 + +def map(wrapper): + model = wrapper + cfg = get_cfg() + test_loader, num_query = build_reid_test_loader(cfg, "Market1501", T.Compose([])) + + feats = [] + pids = [] + camids = [] + + for batch in test_loader: + for image_path in batch["img_paths"]: + t = torch.Tensor(np.array([model.infer(cv2.imread(image_path))])) + t.to(torch.device(GPU_ID)) + feats.append(t) + pids.extend(batch["targets"].numpy()) + camids.extend(batch["camids"].numpy()) + + feats = torch.cat(feats, dim=0) + q_feat = feats[:num_query] + g_feat = feats[num_query:] + q_pids = np.asarray(pids[:num_query]) + g_pids = np.asarray(pids[num_query:]) + q_camids = np.asarray(camids[:num_query]) + g_camids = np.asarray(camids[num_query:]) + + + distmat = 1 - torch.mm(q_feat, g_feat.t()) + distmat = distmat.numpy() + all_cmc, all_AP, all_INP = eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, 5) + mAP = np.mean(all_AP) + print("mAP {}, rank-1 {}".format(mAP, all_cmc[0])) + + +if __name__ == '__main__': + infer = ReID(GPU_ID) + infer.build("../build/kd_r18_distill.engine") + map(infer) diff --git a/projects/FastRT/pybind_interface/pybind11 b/projects/FastRT/pybind_interface/pybind11 new file mode 160000 index 0000000..0e01c24 --- /dev/null +++ b/projects/FastRT/pybind_interface/pybind11 @@ -0,0 +1 @@ +Subproject commit 0e01c243c7ffae3a2e52f998bacfe82f56aa96d9 diff --git a/projects/FastRT/pybind_interface/reid.cpp b/projects/FastRT/pybind_interface/reid.cpp new file mode 100755 index 0000000..1156943 --- /dev/null +++ b/projects/FastRT/pybind_interface/reid.cpp @@ -0,0 +1,177 @@ +#include +#include +#include +#include +#include + +#include "fastrt/utils.h" +#include "fastrt/baseline.h" +#include "fastrt/factory.h" +using namespace fastrt; +using namespace nvinfer1; + +namespace py = pybind11; + + +/* Ex1. sbs_R50-ibn */ +static const std::string WEIGHTS_PATH = "../sbs_R50-ibn.wts"; +static const std::string ENGINE_PATH = "./sbs_R50-ibn.engine"; + +static const int MAX_BATCH_SIZE = 4; +static const int INPUT_H = 256; +static const int INPUT_W = 128; +static const int OUTPUT_SIZE = 2048; +static const int DEVICE_ID = 0; + +static const FastreidBackboneType BACKBONE = FastreidBackboneType::r50; +static const FastreidHeadType HEAD = FastreidHeadType::EmbeddingHead; +static const FastreidPoolingType HEAD_POOLING = FastreidPoolingType::gempoolP; +static const int LAST_STRIDE = 1; +static const bool WITH_IBNA = true; +static const bool WITH_NL = true; +static const int EMBEDDING_DIM = 0; + +FastreidConfig reidCfg { + BACKBONE, + HEAD, + HEAD_POOLING, + LAST_STRIDE, + WITH_IBNA, + WITH_NL, + EMBEDDING_DIM}; + +class ReID +{ + +private: + int device; // GPU id + fastrt::Baseline baseline; + +public: + ReID(int a); + int build(const std::string &engine_file); + // std::list infer_test(const std::string &image_file); + std::list infer(py::array_t&); + std::list> batch_infer(std::list>&); + ~ReID(); +}; + +ReID::ReID(int device): baseline(trt::ModelConfig { + WEIGHTS_PATH, + MAX_BATCH_SIZE, + INPUT_H, + INPUT_W, + OUTPUT_SIZE, + device}) +{ + std::cout << "Init on device " << device << std::endl; +} + +int ReID::build(const std::string &engine_file) +{ + if(!baseline.deserializeEngine(engine_file)) { + std::cout << "DeserializeEngine Failed." << std::endl; + return -1; + } + return 0; +} + +ReID::~ReID() +{ + + std::cout << "Destroy engine succeed" << std::endl; +} + +std::list ReID::infer(py::array_t& img) +{ + auto rows = img.shape(0); + auto cols = img.shape(1); + auto type = CV_8UC3; + + cv::Mat img2(rows, cols, type, (unsigned char*)img.data()); + cv::Mat re(INPUT_H, INPUT_W, CV_8UC3); + // std::cout << (int)img2.data[0] << std::endl; + cv::resize(img2, re, re.size(), 0, 0, cv::INTER_CUBIC); /* cv::INTER_LINEAR */ + std::vector input; + input.emplace_back(re); + + if(!baseline.inference(input)) { + std::cout << "Inference Failed." << std::endl; + } + std::list feature; + + float* feat_embedding = baseline.getOutput(); + TRTASSERT(feat_embedding); + for (int dim = 0; dim < baseline.getOutputSize(); ++dim) { + feature.push_back(feat_embedding[dim]); + } + + return feature; +} + + +std::list> ReID::batch_infer(std::list>& imgs) +{ + // auto t1 = Time::now(); + std::vector input; + int count = 0; + while(!imgs.empty()){ + py::array_t& img = imgs.front(); + imgs.pop_front(); + // parse to cvmat + auto rows = img.shape(0); + auto cols = img.shape(1); + auto type = CV_8UC3; + + cv::Mat img2(rows, cols, type, (unsigned char*)img.data()); + cv::Mat re(INPUT_H, INPUT_W, CV_8UC3); + // std::cout << (int)img2.data[0] << std::endl; + cv::resize(img2, re, re.size(), 0, 0, cv::INTER_CUBIC); /* cv::INTER_LINEAR */ + input.emplace_back(re); + + count += 1; + } + // auto t2 = Time::now(); + + if(!baseline.inference(input)) { + std::cout << "Inference Failed." << std::endl; + } + std::list> result; + + float* feat_embedding = baseline.getOutput(); + TRTASSERT(feat_embedding); + + // auto t3 = Time::now(); + for (int index = 0; index < count; index++) + { + std::list feature; + for (int dim = 0; dim < baseline.getOutputSize(); ++dim) { + feature.push_back(feat_embedding[index * baseline.getOutputSize() + dim]); + } + result.push_back(feature); + } + // std::cout << "[Preprocessing]: " << std::chrono::duration_cast(t2 - t1).count() << "ms" + // << "[Infer]: " << std::chrono::duration_cast(t3 - t2).count() << "ms" + // << "[Cast]: " << std::chrono::duration_cast(Time::now() - t3).count() << "ms" + // << std::endl; + return result; +} + + +PYBIND11_MODULE(ReID, m) { + m.doc() = R"pbdoc( + Pybind11 example plugin + )pbdoc"; + py::class_(m, "ReID") + .def(py::init()) + .def("build", &ReID::build) + .def("infer", &ReID::infer, py::return_value_policy::automatic) + .def("batch_infer", &ReID::batch_infer, py::return_value_policy::automatic) + ; + +#ifdef VERSION_INFO + m.attr("__version__") = VERSION_INFO; +#else + m.attr("__version__") = "dev"; +#endif +} diff --git a/projects/FastRT/pybind_interface/test.py b/projects/FastRT/pybind_interface/test.py new file mode 100755 index 0000000..a1a1db3 --- /dev/null +++ b/projects/FastRT/pybind_interface/test.py @@ -0,0 +1,23 @@ +import sys + +sys.path.append("../") +from build.pybind_interface.ReID import ReID +import cv2 +import time + + +if __name__ == '__main__': + iter_ = 20000 + m = ReID(0) + m.build("../build/kd_r18_distill.engine") + print("build done") + + frame = cv2.imread("/data/sunp/algorithm/2020_1015_time/pytorchtotensorrt_reid/test/query/0001/0001_c1s1_001051_00.jpg") + m.infer(frame) + t0 = time.time() + + for i in range(iter_): + m.infer(frame) + + total = time.time() - t0 + print("CPP API fps is {:.1f}, avg infer time is {:.2f}ms".format(iter_ / total, total / iter_ * 1000)) \ No newline at end of file