Add python interface by pybind11 and Int8 mode

Reviewed by: @TCHeish
pull/457/head
慕湮 2021-04-12 15:05:21 +08:00 committed by GitHub
parent 1dce15efad
commit fc67350e99
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 664 additions and 8 deletions

3
.gitmodules vendored 100644
View File

@ -0,0 +1,3 @@
[submodule "projects/FastRT/pybind_interface/pybind11"]
path = projects/FastRT/pybind_interface/pybind11
url = https://github.com/pybind/pybind11.git

0
projects/FastRT/.gitignore vendored 100644 → 100755
View File

15
projects/FastRT/CMakeLists.txt 100644 → 100755
View File

@ -31,13 +31,17 @@ option(CUDA_USE_STATIC_CUDA_RUNTIME "Use Static CUDA" OFF)
option(BUILD_FASTRT_ENGINE "Build FastRT Engine" ON) option(BUILD_FASTRT_ENGINE "Build FastRT Engine" ON)
option(BUILD_DEMO "Build DEMO" ON) option(BUILD_DEMO "Build DEMO" ON)
option(BUILD_FP16 "Build Engine as FP16" OFF) option(BUILD_FP16 "Build Engine as FP16" OFF)
option(BUILD_INT8 "Build Engine as INT8" OFF)
option(USE_CNUMPY "Include CNPY libs" OFF) option(USE_CNUMPY "Include CNPY libs" OFF)
option(BUILD_PYTHON_INTERFACE "Build Python Interface" OFF)
if(USE_CNUMPY) if(USE_CNUMPY)
add_definitions(-DUSE_CNUMPY) add_definitions(-DUSE_CNUMPY)
endif() endif()
if(BUILD_INT8)
if(BUILD_FP16) add_definitions(-DBUILD_INT8)
message("Build Engine as INT8")
elseif(BUILD_FP16)
add_definitions(-DBUILD_FP16) add_definitions(-DBUILD_FP16)
message("Build Engine as FP16") message("Build Engine as FP16")
else() else()
@ -60,3 +64,10 @@ if(BUILD_DEMO)
else() else()
message(STATUS "BUILD_DEMO: OFF") message(STATUS "BUILD_DEMO: OFF")
endif() endif()
if(BUILD_PYTHON_INTERFACE)
add_subdirectory(pybind_interface)
message(STATUS "BUILD_PYTHON_INTERFACE: ON")
else()
message(STATUS "BUILD_PYTHON_INTERFACE: OFF")
endif()

65
projects/FastRT/README.md 100644 → 100755
View File

@ -44,7 +44,6 @@ So we don't use any parsers here.
6. Verify the output with pytorch 6. Verify the output with pytorch
7. (Optional) Once you verify the result, you can set FP16 for speed up 7. (Optional) Once you verify the result, you can set FP16 for speed up
``` ```
mkdir build mkdir build
@ -55,7 +54,23 @@ So we don't use any parsers here.
then go to [step 5](#step5) then go to [step 5](#step5)
8. (Optional) Build tensorrt model as shared libs 8. (Optional) You can use INT8 quantization for speed up
First, modify the source code to specify your calibrate dataset path. In `FastRT/fastrt/meta_arch/model.cpp`, line 91.
```
Int8EntropyCalibrator2* calibrator = new Int8EntropyCalibrator2(1, w, h, PATH_TO_YOUR_DATASET, "int8calib.table", p);
```
Then build.
```
mkdir build
cd build
cmake -DBUILD_FASTRT_ENGINE=ON -DBUILD_DEMO=ON -DBUILD_INT8=ON ..
make
```
then go to [step 5](#step5)
9. (Optional) Build tensorrt model as shared libs
``` ```
mkdir build mkdir build
@ -74,6 +89,29 @@ So we don't use any parsers here.
then go to [step 5](#step5) then go to [step 5](#step5)
10. (Optional) Build tensorrt model with python interface, then you can use FastRT model in python.
First get the pybind lib, run `git submodule update --init --recursive`.
```
mkdir build
cd build
cmake -DBUILD_FASTRT_ENGINE=ON -DBUILD_DEMO=ON -DBUILD_PYTHON_INTERFACE=ON -DPYTHON_EXECUTABLE=$(which python) ..
make
```
You should get a so file `FastRT/build/pybind_interface/ReID.cpython-36m-x86_64-linux-gnu.so`.
Then go to [step 5](#step5) to create engine file. After that you can import this so file in python, and deserialize engine file to infer in python. You can find use example in `pybind_interface/test.py` and `pybind_interface/market_benchmark.py`.
```
from PATH_TO_SO_FILE import ReID
model = ReID(GPU_ID)
model.build(PATH_TO_YOUR_ENGINEFILE)
numpy_feature = np.array([model.infer(CV2_FRAME)])
```
### <a name="ConfigSection"></a>`Tensorrt Model Config` ### <a name="ConfigSection"></a>`Tensorrt Model Config`
Edit `FastRT/demo/inference.cpp`, according to your model config Edit `FastRT/demo/inference.cpp`, according to your model config
@ -160,9 +198,30 @@ static const bool WITH_NL = false;
static const int EMBEDDING_DIM = 0; static const int EMBEDDING_DIM = 0;
``` ```
+ Ex5.`kd-r18-r101_ibn`
```
static const std::string WEIGHTS_PATH = "../kd-r18-r101_ibn.wts";
static const std::string ENGINE_PATH = "./kd_r18_distill.engine";
static const int MAX_BATCH_SIZE = 16;
static const int INPUT_H = 384;
static const int INPUT_W = 128;
static const int OUTPUT_SIZE = 512;
static const int DEVICE_ID = 1;
static const FastreidBackboneType BACKBONE = FastreidBackboneType::r18_distill;
static const FastreidHeadType HEAD = FastreidHeadType::EmbeddingHead;
static const FastreidPoolingType HEAD_POOLING = FastreidPoolingType::gempoolP;
static const int LAST_STRIDE = 1;
static const bool WITH_IBNA = true;
static const bool WITH_NL = false;
static const int EMBEDDING_DIM = 0;
```
### Supported conversion ### Supported conversion
* Backbone: resnet50, resnet34, distill-resnet50, distill-resnet34 * Backbone: resnet50, resnet34, distill-resnet50, distill-resnet34, distill-resnet18
* Heads: embedding_head * Heads: embedding_head
* Plugin layers: ibn, non-local * Plugin layers: ibn, non-local
* Pooling layers: maxpool, avgpool, GeneralizedMeanPooling, GeneralizedMeanPoolingP * Pooling layers: maxpool, avgpool, GeneralizedMeanPooling, GeneralizedMeanPoolingP

View File

View File

@ -2,6 +2,7 @@ project(FastRTEngine)
file(GLOB_RECURSE COMMON_SRC_FILES file(GLOB_RECURSE COMMON_SRC_FILES
${CMAKE_CURRENT_SOURCE_DIR}/common/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/common/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/common/calibrator.cpp
) )
find_package(CUDA REQUIRED) find_package(CUDA REQUIRED)

View File

@ -6,6 +6,53 @@
using namespace trtxapi; using namespace trtxapi;
namespace fastrt { namespace fastrt {
ILayer* backbone_sbsR18_distill::topology(INetworkDefinition *network, std::map<std::string, Weights>& weightMap, ITensor& input) {
std::string ibn{""};
if(_modelCfg.with_ibna) {
ibn = "a";
}
std::map<std::string, std::vector<std::string>> ibn_layers{
{"a", {"a","a","a","a","a","a","",""}},
{"b", {"","","b","","","","b","","","","","","","","","",}},
{"", {16,""}}};
Weights emptywts{DataType::kFLOAT, nullptr, 0};
IConvolutionLayer* conv1 = network->addConvolutionNd(input, 64, DimsHW{7, 7}, weightMap["backbone.conv1.weight"], emptywts);
TRTASSERT(conv1);
conv1->setStrideNd(DimsHW{2, 2});
conv1->setPaddingNd(DimsHW{3, 3});
IScaleLayer* bn1{nullptr};
if (ibn == "b") {
bn1 = addInstanceNorm2d(network, weightMap, *conv1->getOutput(0), "backbone.bn1", 1e-5);
} else {
bn1 = addBatchNorm2d(network, weightMap, *conv1->getOutput(0), "backbone.bn1", 1e-5);
}
IActivationLayer* relu1 = network->addActivation(*bn1->getOutput(0), ActivationType::kRELU);
TRTASSERT(relu1);
// pytorch: nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)
IPoolingLayer* pool1 = network->addPoolingNd(*relu1->getOutput(0), PoolingType::kMAX, DimsHW{3, 3});
TRTASSERT(pool1);
pool1->setStrideNd(DimsHW{2, 2});
pool1->setPaddingMode(PaddingMode::kEXPLICIT_ROUND_UP);
ILayer* x = distill_basicBlock_ibn(network, weightMap, *pool1->getOutput(0), 64, 64, 1, "backbone.layer1.0.", ibn_layers[ibn][0]);
x = distill_basicBlock_ibn(network, weightMap, *x->getOutput(0), 64, 64, 1, "backbone.layer1.1.", ibn_layers[ibn][1]);
x = distill_basicBlock_ibn(network, weightMap, *x->getOutput(0), 64, 128, 2, "backbone.layer2.0.", ibn_layers[ibn][2]);
x = distill_basicBlock_ibn(network, weightMap, *x->getOutput(0), 128, 128, 1, "backbone.layer2.1.", ibn_layers[ibn][3]);
x = distill_basicBlock_ibn(network, weightMap, *x->getOutput(0), 128, 256, 2, "backbone.layer3.0.", ibn_layers[ibn][4]);
x = distill_basicBlock_ibn(network, weightMap, *x->getOutput(0), 256, 256, 1, "backbone.layer3.1.", ibn_layers[ibn][5]);
x = distill_basicBlock_ibn(network, weightMap, *x->getOutput(0), 256, 512, _modelCfg.last_stride, "backbone.layer4.0.", ibn_layers[ibn][6]);
x = distill_basicBlock_ibn(network, weightMap, *x->getOutput(0), 512, 512, 1, "backbone.layer4.1.", ibn_layers[ibn][7]);
IActivationLayer* relu2 = network->addActivation(*x->getOutput(0), ActivationType::kRELU);
TRTASSERT(relu2);
return relu2;
}
ILayer* backbone_sbsR34_distill::topology(INetworkDefinition *network, std::map<std::string, Weights>& weightMap, ITensor& input) { ILayer* backbone_sbsR34_distill::topology(INetworkDefinition *network, std::map<std::string, Weights>& weightMap, ITensor& input) {
std::string ibn{""}; std::string ibn{""};

View File

@ -0,0 +1,80 @@
#include <iostream>
#include <iterator>
#include <fstream>
#include <opencv2/opencv.hpp>
#include <opencv2/dnn/dnn.hpp>
#include "fastrt/calibrator.h"
#include "fastrt/cuda_utils.h"
#include "fastrt/utils.h"
Int8EntropyCalibrator2::Int8EntropyCalibrator2(int batchsize, int input_w, int input_h, const char* img_dir, const char* calib_table_name, const char* input_blob_name, bool read_cache)
: batchsize_(batchsize)
, input_w_(input_w)
, input_h_(input_h)
, img_idx_(0)
, img_dir_(img_dir)
, calib_table_name_(calib_table_name)
, input_blob_name_(input_blob_name)
, read_cache_(read_cache)
{
input_count_ = 3 * input_w * input_h * batchsize;
CUDA_CHECK(cudaMalloc(&device_input_, input_count_ * sizeof(float)));
read_files_in_dir(img_dir, img_files_);
}
Int8EntropyCalibrator2::~Int8EntropyCalibrator2()
{
CUDA_CHECK(cudaFree(device_input_));
}
int Int8EntropyCalibrator2::getBatchSize() const
{
return batchsize_;
}
bool Int8EntropyCalibrator2::getBatch(void* bindings[], const char* names[], int nbBindings)
{
if (img_idx_ + batchsize_ > (int)img_files_.size()) {
return false;
}
std::vector<cv::Mat> input_imgs_;
for (int i = img_idx_; i < img_idx_ + batchsize_; i++) {
std::cout << img_dir_ + img_files_[i] << " " << i << std::endl;
cv::Mat temp = cv::imread(img_dir_ + img_files_[i]);
if (temp.empty()){
std::cerr << "Fatal error: image cannot open!" << std::endl;
return false;
}
input_imgs_.push_back(temp);
}
img_idx_ += batchsize_;
cv::Mat blob = cv::dnn::blobFromImages(input_imgs_, 1.0, cv::Size(input_w_, input_h_), cv::Scalar(0, 0, 0), true, false);
CUDA_CHECK(cudaMemcpy(device_input_, blob.ptr<float>(0), input_count_ * sizeof(float), cudaMemcpyHostToDevice));
assert(!strcmp(names[0], input_blob_name_));
bindings[0] = device_input_;
return true;
}
const void* Int8EntropyCalibrator2::readCalibrationCache(size_t& length)
{
std::cout << "reading calib cache: " << calib_table_name_ << std::endl;
calib_cache_.clear();
std::ifstream input(calib_table_name_, std::ios::binary);
input >> std::noskipws;
if (read_cache_ && input.good())
{
std::copy(std::istream_iterator<char>(input), std::istream_iterator<char>(), std::back_inserter(calib_cache_));
}
length = calib_cache_.size();
return length ? calib_cache_.data() : nullptr;
}
void Int8EntropyCalibrator2::writeCalibrationCache(const void* cache, size_t length)
{
std::cout << "writing calib cache: " << calib_table_name_ << " size: " << length << std::endl;
std::ofstream output(calib_table_name_, std::ios::binary);
output.write(reinterpret_cast<const char*>(cache), length);
}

View File

@ -29,6 +29,11 @@ namespace fastrt {
/* cfg.MODEL.BACKBONE.DEPTH: 34x */ /* cfg.MODEL.BACKBONE.DEPTH: 34x */
std::cout << "[createBackboneModule]: backbone_sbsR34_distill" << std::endl; std::cout << "[createBackboneModule]: backbone_sbsR34_distill" << std::endl;
return make_unique<backbone_sbsR34_distill>(modelCfg); return make_unique<backbone_sbsR34_distill>(modelCfg);
case FastreidBackboneType::r18_distill:
/* cfg.MODEL.META_ARCHITECTURE: Distiller */
/* cfg.MODEL.BACKBONE.DEPTH: 18x */
std::cout << "[createBackboneModule]: backbone_sbsR18_distill" << std::endl;
return make_unique<backbone_sbsR18_distill>(modelCfg);
default: default:
std::cerr << "[Backbone is not supported.]" << std::endl; std::cerr << "[Backbone is not supported.]" << std::endl;
return nullptr; return nullptr;

View File

@ -19,7 +19,7 @@ namespace fastrt {
/* Standardization */ /* Standardization */
static const float mean[3] = {123.675, 116.28, 103.53}; static const float mean[3] = {123.675, 116.28, 103.53};
static const float std[3] = {58.395, 57.120000000000005, 57.375}; static const float std[3] = {58.395, 57.120000000000005, 57.375};
return addMeanStd(network, weightMap, input, "", mean, std, true); // true for div 255 return addMeanStd(network, weightMap, input, "", mean, std, false); // true for div 255
} }
} }

View File

@ -1,4 +1,6 @@
#include "fastrt/model.h" #include "fastrt/model.h"
#include "fastrt/calibrator.h"
namespace fastrt { namespace fastrt {
@ -68,9 +70,27 @@ namespace fastrt {
/* Build engine */ /* Build engine */
builder->setMaxBatchSize(_engineCfg.max_batch_size); builder->setMaxBatchSize(_engineCfg.max_batch_size);
config->setMaxWorkspaceSize(1 << 20); config->setMaxWorkspaceSize(1 << 20);
#if defined(BUILD_FP16) && defined(BUILD_INT8)
std::cout << "Flag confilct! BUILD_FP16 and BUILD_INT8 can't be both True!" << std::endl;
return null;
#endif
#ifdef BUILD_FP16 #ifdef BUILD_FP16
std::cout << "[Build fp16]" << std::endl; std::cout << "[Build fp16]" << std::endl;
config->setFlag(BuilderFlag::kFP16); config->setFlag(BuilderFlag::kFP16);
#endif
#ifdef BUILD_INT8
std::cout << "[Build int8]" << std::endl;
std::cout << "Your platform support int8: " << (builder->platformHasFastInt8() ? "true" : "false") << std::endl;
assert(builder->platformHasFastInt8());
config->setFlag(BuilderFlag::kINT8);
int w = _engineCfg.input_w;
int h = _engineCfg.input_h;
char*p = (char*)_engineCfg.input_name.data();
//path must end with /
Int8EntropyCalibrator2* calibrator = new Int8EntropyCalibrator2(1, w, h,
"/data/person_reid/data/Market-1501-v15.09.15/bounding_box_test/", "int8calib.table", p);
config->setInt8Calibrator(calibrator);
#endif #endif
auto engine = make_holder(builder->buildEngineWithConfig(*network, *config)); auto engine = make_holder(builder->buildEngineWithConfig(*network, *config));
std::cout << "[TRT engine build out]" << std::endl; std::cout << "[TRT engine build out]" << std::endl;

View File

@ -0,0 +1,39 @@
#ifndef ENTROPY_CALIBRATOR_H
#define ENTROPY_CALIBRATOR_H
#include "NvInfer.h"
#include <string>
#include <vector>
//! \class Int8EntropyCalibrator2
//!
//! \brief Implements Entropy calibrator 2.
//! CalibrationAlgoType is kENTROPY_CALIBRATION_2.
//!
class Int8EntropyCalibrator2 : public nvinfer1::IInt8EntropyCalibrator2
{
public:
Int8EntropyCalibrator2(int batchsize, int input_w, int input_h, const char* img_dir, const char* calib_table_name, const char* input_blob_name, bool read_cache = true);
virtual ~Int8EntropyCalibrator2();
int getBatchSize() const override;
bool getBatch(void* bindings[], const char* names[], int nbBindings) override;
const void* readCalibrationCache(size_t& length) override;
void writeCalibrationCache(const void* cache, size_t length) override;
private:
int batchsize_;
int input_w_;
int input_h_;
int img_idx_;
std::string img_dir_;
std::vector<std::string> img_files_;
size_t input_count_;
std::string calib_table_name_;
const char* input_blob_name_;
bool read_cache_;
void* device_input_;
std::vector<char> calib_cache_;
};
#endif // ENTROPY_CALIBRATOR_H

View File

@ -0,0 +1,18 @@
#ifndef TRTX_CUDA_UTILS_H_
#define TRTX_CUDA_UTILS_H_
#include <cuda_runtime_api.h>
#ifndef CUDA_CHECK
#define CUDA_CHECK(callstr)\
{\
cudaError_t error_code = callstr;\
if (error_code != cudaSuccess) {\
std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__;\
assert(0);\
}\
}
#endif // CUDA_CHECK
#endif // TRTX_CUDA_UTILS_H_

View File

View File

@ -7,6 +7,16 @@
using namespace nvinfer1; using namespace nvinfer1;
namespace fastrt { namespace fastrt {
class backbone_sbsR18_distill : public Module {
private:
FastreidConfig& _modelCfg;
public:
backbone_sbsR18_distill(FastreidConfig& modelCfg) : _modelCfg(modelCfg){}
~backbone_sbsR18_distill() = default;
ILayer* topology(INetworkDefinition *network,
std::map<std::string, Weights>& weightMap,
ITensor& input) override;
};
class backbone_sbsR34_distill : public Module { class backbone_sbsR34_distill : public Module {
private: private:

View File

@ -28,7 +28,8 @@ namespace fastrt {
X(r50, "r50") \ X(r50, "r50") \
X(r50_distill, "r50_distill") \ X(r50_distill, "r50_distill") \
X(r34, "r34") \ X(r34, "r34") \
X(r34_distill, "r34_distill") X(r34_distill, "r34_distill") \
X(r18_distill, "r18_distill")
#define X(a, b) a, #define X(a, b) a,
enum FastreidBackboneType { FASTBACKBONE_TABLE }; enum FastreidBackboneType { FASTBACKBONE_TABLE };

View File

@ -7,7 +7,9 @@
#include <fstream> #include <fstream>
#include <iostream> #include <iostream>
#include <cassert> #include <cassert>
#include <string.h>
#include <dirent.h>
#include "NvInfer.h" #include "NvInfer.h"
#include "cuda_runtime_api.h" #include "cuda_runtime_api.h"
#include "fastrt/struct.h" #include "fastrt/struct.h"
@ -37,6 +39,26 @@ namespace io {
std::vector<std::string> fileGlob(const std::string& pattern); std::vector<std::string> fileGlob(const std::string& pattern);
} }
static inline int read_files_in_dir(const char *p_dir_name, std::vector<std::string> &file_names) {
DIR *p_dir = opendir(p_dir_name);
if (p_dir == nullptr) {
return -1;
}
struct dirent* p_file = nullptr;
while ((p_file = readdir(p_dir)) != nullptr) {
if (strcmp(p_file->d_name, ".") != 0 &&
strcmp(p_file->d_name, "..") != 0) {
std::string cur_file_name(p_file->d_name);
file_names.push_back(cur_file_name);
}
}
closedir(p_dir);
return 0;
}
namespace trt { namespace trt {
/* /*
* Load weights from files shared with TensorRT samples. * Load weights from files shared with TensorRT samples.

View File

@ -0,0 +1,40 @@
SET(APP_PROJECT_NAME ReID)
# pybind
add_subdirectory(pybind11)
find_package(CUDA REQUIRED)
# include and link dirs of cuda and tensorrt, you need adapt them if yours are different
# cuda
include_directories(/usr/local/cuda/include)
link_directories(/usr/local/cuda/lib64)
# tensorrt
include_directories(/usr/include/x86_64-linux-gnu/)
link_directories(/usr/lib/x86_64-linux-gnu/)
include_directories(${SOLUTION_DIR}/include)
pybind11_add_module(${APP_PROJECT_NAME} ${PROJECT_SOURCE_DIR}/pybind_interface/reid.cpp)
# OpenCV
find_package(OpenCV)
target_include_directories(${APP_PROJECT_NAME}
PUBLIC
${OpenCV_INCLUDE_DIRS}
)
target_link_libraries(${APP_PROJECT_NAME}
PUBLIC
${OpenCV_LIBS}
)
if(BUILD_FASTRT_ENGINE AND BUILD_PYTHON_INTERFACE)
SET(FASTRTENGINE_LIB FastRTEngine)
else()
SET(FASTRTENGINE_LIB ${SOLUTION_DIR}/libs/FastRTEngine/libFastRTEngine.so)
endif()
target_link_libraries(${APP_PROJECT_NAME}
PRIVATE
${FASTRTENGINE_LIB}
nvinfer
)

View File

@ -0,0 +1,17 @@
# cuda10.0
FROM fineyu/tensorrt7:0.0.1
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && apt-get install -y software-properties-common
RUN add-apt-repository -y ppa:timsc/opencv-3.4 && \
apt-get update && \
apt-get install -y cmake \
libopencv-dev \
libopencv-dnn-dev \
libopencv-shape3.4-dbg && \
rm -rf /var/lib/apt/lists/*
RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py && python3 get-pip.py && rm get-pip.py && pip3 install torch==1.6.0 torchvision tensorboard matplotlib scipy Pillow numpy prettytable easydict opencv-python \
scikit-learn pyyaml yacs termcolor tabulate tensorboard opencv-python pyyaml yacs termcolor tabulate gdown faiss-cpu

View File

@ -0,0 +1,17 @@
# cuda10.2
FROM nvcr.io/nvidia/tensorrt:20.03-py3
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && apt-get install -y software-properties-common
RUN add-apt-repository -y ppa:timsc/opencv-3.4 && \
apt-get update && \
apt-get install -y cmake \
libopencv-dev \
libopencv-dnn-dev \
libopencv-shape3.4-dbg && \
rm -rf /var/lib/apt/lists/*
RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py && python3 get-pip.py && rm get-pip.py && pip3 install torch==1.6.0 torchvision tensorboard matplotlib scipy Pillow numpy prettytable easydict opencv-python \
scikit-learn pyyaml yacs termcolor tabulate tensorboard opencv-python pyyaml yacs termcolor tabulate gdown faiss-cpu

View File

@ -0,0 +1,65 @@
import random
import numpy as np
import cv2
import fs
import argparse
import io
import sys
import torch
import time
import os
import torchvision.transforms as T
sys.path.append('../../..')
sys.path.append('../')
from fastreid.config import get_cfg
from fastreid.modeling.meta_arch import build_model
from fastreid.utils.file_io import PathManager
from fastreid.utils.checkpoint import Checkpointer
from fastreid.utils.logger import setup_logger
from fastreid.data import build_reid_train_loader, build_reid_test_loader
from fastreid.evaluation.rank import eval_market1501
from build.pybind_interface.ReID import ReID
FEATURE_DIM = 512
GPU_ID = 0
def map(wrapper):
model = wrapper
cfg = get_cfg()
test_loader, num_query = build_reid_test_loader(cfg, "Market1501", T.Compose([]))
feats = []
pids = []
camids = []
for batch in test_loader:
for image_path in batch["img_paths"]:
t = torch.Tensor(np.array([model.infer(cv2.imread(image_path))]))
t.to(torch.device(GPU_ID))
feats.append(t)
pids.extend(batch["targets"].numpy())
camids.extend(batch["camids"].numpy())
feats = torch.cat(feats, dim=0)
q_feat = feats[:num_query]
g_feat = feats[num_query:]
q_pids = np.asarray(pids[:num_query])
g_pids = np.asarray(pids[num_query:])
q_camids = np.asarray(camids[:num_query])
g_camids = np.asarray(camids[num_query:])
distmat = 1 - torch.mm(q_feat, g_feat.t())
distmat = distmat.numpy()
all_cmc, all_AP, all_INP = eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, 5)
mAP = np.mean(all_AP)
print("mAP {}, rank-1 {}".format(mAP, all_cmc[0]))
if __name__ == '__main__':
infer = ReID(GPU_ID)
infer.build("../build/kd_r18_distill.engine")
map(infer)

@ -0,0 +1 @@
Subproject commit 0e01c243c7ffae3a2e52f998bacfe82f56aa96d9

View File

@ -0,0 +1,177 @@
#include <iostream>
#include <opencv2/opencv.hpp>
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#include "fastrt/utils.h"
#include "fastrt/baseline.h"
#include "fastrt/factory.h"
using namespace fastrt;
using namespace nvinfer1;
namespace py = pybind11;
/* Ex1. sbs_R50-ibn */
static const std::string WEIGHTS_PATH = "../sbs_R50-ibn.wts";
static const std::string ENGINE_PATH = "./sbs_R50-ibn.engine";
static const int MAX_BATCH_SIZE = 4;
static const int INPUT_H = 256;
static const int INPUT_W = 128;
static const int OUTPUT_SIZE = 2048;
static const int DEVICE_ID = 0;
static const FastreidBackboneType BACKBONE = FastreidBackboneType::r50;
static const FastreidHeadType HEAD = FastreidHeadType::EmbeddingHead;
static const FastreidPoolingType HEAD_POOLING = FastreidPoolingType::gempoolP;
static const int LAST_STRIDE = 1;
static const bool WITH_IBNA = true;
static const bool WITH_NL = true;
static const int EMBEDDING_DIM = 0;
FastreidConfig reidCfg {
BACKBONE,
HEAD,
HEAD_POOLING,
LAST_STRIDE,
WITH_IBNA,
WITH_NL,
EMBEDDING_DIM};
class ReID
{
private:
int device; // GPU id
fastrt::Baseline baseline;
public:
ReID(int a);
int build(const std::string &engine_file);
// std::list<float> infer_test(const std::string &image_file);
std::list<float> infer(py::array_t<uint8_t>&);
std::list<std::list<float>> batch_infer(std::list<py::array_t<uint8_t>>&);
~ReID();
};
ReID::ReID(int device): baseline(trt::ModelConfig {
WEIGHTS_PATH,
MAX_BATCH_SIZE,
INPUT_H,
INPUT_W,
OUTPUT_SIZE,
device})
{
std::cout << "Init on device " << device << std::endl;
}
int ReID::build(const std::string &engine_file)
{
if(!baseline.deserializeEngine(engine_file)) {
std::cout << "DeserializeEngine Failed." << std::endl;
return -1;
}
return 0;
}
ReID::~ReID()
{
std::cout << "Destroy engine succeed" << std::endl;
}
std::list<float> ReID::infer(py::array_t<uint8_t>& img)
{
auto rows = img.shape(0);
auto cols = img.shape(1);
auto type = CV_8UC3;
cv::Mat img2(rows, cols, type, (unsigned char*)img.data());
cv::Mat re(INPUT_H, INPUT_W, CV_8UC3);
// std::cout << (int)img2.data[0] << std::endl;
cv::resize(img2, re, re.size(), 0, 0, cv::INTER_CUBIC); /* cv::INTER_LINEAR */
std::vector<cv::Mat> input;
input.emplace_back(re);
if(!baseline.inference(input)) {
std::cout << "Inference Failed." << std::endl;
}
std::list<float> feature;
float* feat_embedding = baseline.getOutput();
TRTASSERT(feat_embedding);
for (int dim = 0; dim < baseline.getOutputSize(); ++dim) {
feature.push_back(feat_embedding[dim]);
}
return feature;
}
std::list<std::list<float>> ReID::batch_infer(std::list<py::array_t<uint8_t>>& imgs)
{
// auto t1 = Time::now();
std::vector<cv::Mat> input;
int count = 0;
while(!imgs.empty()){
py::array_t<uint8_t>& img = imgs.front();
imgs.pop_front();
// parse to cvmat
auto rows = img.shape(0);
auto cols = img.shape(1);
auto type = CV_8UC3;
cv::Mat img2(rows, cols, type, (unsigned char*)img.data());
cv::Mat re(INPUT_H, INPUT_W, CV_8UC3);
// std::cout << (int)img2.data[0] << std::endl;
cv::resize(img2, re, re.size(), 0, 0, cv::INTER_CUBIC); /* cv::INTER_LINEAR */
input.emplace_back(re);
count += 1;
}
// auto t2 = Time::now();
if(!baseline.inference(input)) {
std::cout << "Inference Failed." << std::endl;
}
std::list<std::list<float>> result;
float* feat_embedding = baseline.getOutput();
TRTASSERT(feat_embedding);
// auto t3 = Time::now();
for (int index = 0; index < count; index++)
{
std::list<float> feature;
for (int dim = 0; dim < baseline.getOutputSize(); ++dim) {
feature.push_back(feat_embedding[index * baseline.getOutputSize() + dim]);
}
result.push_back(feature);
}
// std::cout << "[Preprocessing]: " << std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1).count() << "ms"
// << "[Infer]: " << std::chrono::duration_cast<std::chrono::milliseconds>(t3 - t2).count() << "ms"
// << "[Cast]: " << std::chrono::duration_cast<std::chrono::milliseconds>(Time::now() - t3).count() << "ms"
// << std::endl;
return result;
}
PYBIND11_MODULE(ReID, m) {
m.doc() = R"pbdoc(
Pybind11 example plugin
)pbdoc";
py::class_<ReID>(m, "ReID")
.def(py::init<int>())
.def("build", &ReID::build)
.def("infer", &ReID::infer, py::return_value_policy::automatic)
.def("batch_infer", &ReID::batch_infer, py::return_value_policy::automatic)
;
#ifdef VERSION_INFO
m.attr("__version__") = VERSION_INFO;
#else
m.attr("__version__") = "dev";
#endif
}

View File

@ -0,0 +1,23 @@
import sys
sys.path.append("../")
from build.pybind_interface.ReID import ReID
import cv2
import time
if __name__ == '__main__':
iter_ = 20000
m = ReID(0)
m.build("../build/kd_r18_distill.engine")
print("build done")
frame = cv2.imread("/data/sunp/algorithm/2020_1015_time/pytorchtotensorrt_reid/test/query/0001/0001_c1s1_001051_00.jpg")
m.infer(frame)
t0 = time.time()
for i in range(iter_):
m.infer(frame)
total = time.time() - t0
print("CPP API fps is {:.1f}, avg infer time is {:.2f}ms".format(iter_ / total, total / iter_ * 1000))