mirror of https://github.com/JDAI-CV/fast-reid.git
parent
1dce15efad
commit
fc67350e99
projects/FastRT
demo
fastrt
backbones
common
factory
meta_arch
pybind_interface
docker
trt7cu100
trt7cu102
|
@ -0,0 +1,3 @@
|
|||
[submodule "projects/FastRT/pybind_interface/pybind11"]
|
||||
path = projects/FastRT/pybind_interface/pybind11
|
||||
url = https://github.com/pybind/pybind11.git
|
|
@ -3,4 +3,4 @@
|
|||
.vscode/
|
||||
libs/
|
||||
build/
|
||||
data/
|
||||
data/
|
|
@ -31,13 +31,17 @@ option(CUDA_USE_STATIC_CUDA_RUNTIME "Use Static CUDA" OFF)
|
|||
option(BUILD_FASTRT_ENGINE "Build FastRT Engine" ON)
|
||||
option(BUILD_DEMO "Build DEMO" ON)
|
||||
option(BUILD_FP16 "Build Engine as FP16" OFF)
|
||||
option(BUILD_INT8 "Build Engine as INT8" OFF)
|
||||
option(USE_CNUMPY "Include CNPY libs" OFF)
|
||||
option(BUILD_PYTHON_INTERFACE "Build Python Interface" OFF)
|
||||
|
||||
if(USE_CNUMPY)
|
||||
add_definitions(-DUSE_CNUMPY)
|
||||
endif()
|
||||
|
||||
if(BUILD_FP16)
|
||||
if(BUILD_INT8)
|
||||
add_definitions(-DBUILD_INT8)
|
||||
message("Build Engine as INT8")
|
||||
elseif(BUILD_FP16)
|
||||
add_definitions(-DBUILD_FP16)
|
||||
message("Build Engine as FP16")
|
||||
else()
|
||||
|
@ -60,3 +64,10 @@ if(BUILD_DEMO)
|
|||
else()
|
||||
message(STATUS "BUILD_DEMO: OFF")
|
||||
endif()
|
||||
|
||||
if(BUILD_PYTHON_INTERFACE)
|
||||
add_subdirectory(pybind_interface)
|
||||
message(STATUS "BUILD_PYTHON_INTERFACE: ON")
|
||||
else()
|
||||
message(STATUS "BUILD_PYTHON_INTERFACE: OFF")
|
||||
endif()
|
||||
|
|
|
@ -44,7 +44,6 @@ So we don't use any parsers here.
|
|||
|
||||
6. Verify the output with pytorch
|
||||
|
||||
|
||||
7. (Optional) Once you verify the result, you can set FP16 for speed up
|
||||
```
|
||||
mkdir build
|
||||
|
@ -55,7 +54,23 @@ So we don't use any parsers here.
|
|||
|
||||
then go to [step 5](#step5)
|
||||
|
||||
8. (Optional) Build tensorrt model as shared libs
|
||||
8. (Optional) You can use INT8 quantization for speed up
|
||||
|
||||
First, modify the source code to specify your calibrate dataset path. In `FastRT/fastrt/meta_arch/model.cpp`, line 91.
|
||||
```
|
||||
Int8EntropyCalibrator2* calibrator = new Int8EntropyCalibrator2(1, w, h, PATH_TO_YOUR_DATASET, "int8calib.table", p);
|
||||
```
|
||||
Then build.
|
||||
```
|
||||
mkdir build
|
||||
cd build
|
||||
cmake -DBUILD_FASTRT_ENGINE=ON -DBUILD_DEMO=ON -DBUILD_INT8=ON ..
|
||||
make
|
||||
```
|
||||
|
||||
then go to [step 5](#step5)
|
||||
|
||||
9. (Optional) Build tensorrt model as shared libs
|
||||
|
||||
```
|
||||
mkdir build
|
||||
|
@ -73,6 +88,29 @@ So we don't use any parsers here.
|
|||
```
|
||||
|
||||
then go to [step 5](#step5)
|
||||
|
||||
10. (Optional) Build tensorrt model with python interface, then you can use FastRT model in python.
|
||||
First get the pybind lib, run `git submodule update --init --recursive`.
|
||||
|
||||
```
|
||||
mkdir build
|
||||
cd build
|
||||
cmake -DBUILD_FASTRT_ENGINE=ON -DBUILD_DEMO=ON -DBUILD_PYTHON_INTERFACE=ON -DPYTHON_EXECUTABLE=$(which python) ..
|
||||
make
|
||||
```
|
||||
You should get a so file `FastRT/build/pybind_interface/ReID.cpython-36m-x86_64-linux-gnu.so`.
|
||||
|
||||
Then go to [step 5](#step5) to create engine file. After that you can import this so file in python, and deserialize engine file to infer in python. You can find use example in `pybind_interface/test.py` and `pybind_interface/market_benchmark.py`.
|
||||
```
|
||||
from PATH_TO_SO_FILE import ReID
|
||||
model = ReID(GPU_ID)
|
||||
model.build(PATH_TO_YOUR_ENGINEFILE)
|
||||
numpy_feature = np.array([model.infer(CV2_FRAME)])
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
### <a name="ConfigSection"></a>`Tensorrt Model Config`
|
||||
|
||||
|
@ -160,9 +198,30 @@ static const bool WITH_NL = false;
|
|||
static const int EMBEDDING_DIM = 0;
|
||||
```
|
||||
|
||||
|
||||
+ Ex5.`kd-r18-r101_ibn`
|
||||
```
|
||||
static const std::string WEIGHTS_PATH = "../kd-r18-r101_ibn.wts";
|
||||
static const std::string ENGINE_PATH = "./kd_r18_distill.engine";
|
||||
|
||||
static const int MAX_BATCH_SIZE = 16;
|
||||
static const int INPUT_H = 384;
|
||||
static const int INPUT_W = 128;
|
||||
static const int OUTPUT_SIZE = 512;
|
||||
static const int DEVICE_ID = 1;
|
||||
|
||||
static const FastreidBackboneType BACKBONE = FastreidBackboneType::r18_distill;
|
||||
static const FastreidHeadType HEAD = FastreidHeadType::EmbeddingHead;
|
||||
static const FastreidPoolingType HEAD_POOLING = FastreidPoolingType::gempoolP;
|
||||
static const int LAST_STRIDE = 1;
|
||||
static const bool WITH_IBNA = true;
|
||||
static const bool WITH_NL = false;
|
||||
static const int EMBEDDING_DIM = 0;
|
||||
```
|
||||
|
||||
### Supported conversion
|
||||
|
||||
* Backbone: resnet50, resnet34, distill-resnet50, distill-resnet34
|
||||
* Backbone: resnet50, resnet34, distill-resnet50, distill-resnet34, distill-resnet18
|
||||
* Heads: embedding_head
|
||||
* Plugin layers: ibn, non-local
|
||||
* Pooling layers: maxpool, avgpool, GeneralizedMeanPooling, GeneralizedMeanPoolingP
|
||||
|
|
|
@ -2,6 +2,7 @@ project(FastRTEngine)
|
|||
|
||||
file(GLOB_RECURSE COMMON_SRC_FILES
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/common/utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/common/calibrator.cpp
|
||||
)
|
||||
|
||||
find_package(CUDA REQUIRED)
|
||||
|
|
|
@ -6,6 +6,53 @@
|
|||
using namespace trtxapi;
|
||||
|
||||
namespace fastrt {
|
||||
ILayer* backbone_sbsR18_distill::topology(INetworkDefinition *network, std::map<std::string, Weights>& weightMap, ITensor& input) {
|
||||
std::string ibn{""};
|
||||
if(_modelCfg.with_ibna) {
|
||||
ibn = "a";
|
||||
}
|
||||
std::map<std::string, std::vector<std::string>> ibn_layers{
|
||||
{"a", {"a","a","a","a","a","a","",""}},
|
||||
{"b", {"","","b","","","","b","","","","","","","","","",}},
|
||||
{"", {16,""}}};
|
||||
|
||||
Weights emptywts{DataType::kFLOAT, nullptr, 0};
|
||||
IConvolutionLayer* conv1 = network->addConvolutionNd(input, 64, DimsHW{7, 7}, weightMap["backbone.conv1.weight"], emptywts);
|
||||
TRTASSERT(conv1);
|
||||
conv1->setStrideNd(DimsHW{2, 2});
|
||||
conv1->setPaddingNd(DimsHW{3, 3});
|
||||
|
||||
IScaleLayer* bn1{nullptr};
|
||||
if (ibn == "b") {
|
||||
bn1 = addInstanceNorm2d(network, weightMap, *conv1->getOutput(0), "backbone.bn1", 1e-5);
|
||||
} else {
|
||||
bn1 = addBatchNorm2d(network, weightMap, *conv1->getOutput(0), "backbone.bn1", 1e-5);
|
||||
}
|
||||
IActivationLayer* relu1 = network->addActivation(*bn1->getOutput(0), ActivationType::kRELU);
|
||||
TRTASSERT(relu1);
|
||||
|
||||
// pytorch: nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)
|
||||
IPoolingLayer* pool1 = network->addPoolingNd(*relu1->getOutput(0), PoolingType::kMAX, DimsHW{3, 3});
|
||||
TRTASSERT(pool1);
|
||||
pool1->setStrideNd(DimsHW{2, 2});
|
||||
pool1->setPaddingMode(PaddingMode::kEXPLICIT_ROUND_UP);
|
||||
|
||||
ILayer* x = distill_basicBlock_ibn(network, weightMap, *pool1->getOutput(0), 64, 64, 1, "backbone.layer1.0.", ibn_layers[ibn][0]);
|
||||
x = distill_basicBlock_ibn(network, weightMap, *x->getOutput(0), 64, 64, 1, "backbone.layer1.1.", ibn_layers[ibn][1]);
|
||||
|
||||
x = distill_basicBlock_ibn(network, weightMap, *x->getOutput(0), 64, 128, 2, "backbone.layer2.0.", ibn_layers[ibn][2]);
|
||||
x = distill_basicBlock_ibn(network, weightMap, *x->getOutput(0), 128, 128, 1, "backbone.layer2.1.", ibn_layers[ibn][3]);
|
||||
|
||||
x = distill_basicBlock_ibn(network, weightMap, *x->getOutput(0), 128, 256, 2, "backbone.layer3.0.", ibn_layers[ibn][4]);
|
||||
x = distill_basicBlock_ibn(network, weightMap, *x->getOutput(0), 256, 256, 1, "backbone.layer3.1.", ibn_layers[ibn][5]);
|
||||
|
||||
x = distill_basicBlock_ibn(network, weightMap, *x->getOutput(0), 256, 512, _modelCfg.last_stride, "backbone.layer4.0.", ibn_layers[ibn][6]);
|
||||
x = distill_basicBlock_ibn(network, weightMap, *x->getOutput(0), 512, 512, 1, "backbone.layer4.1.", ibn_layers[ibn][7]);
|
||||
|
||||
IActivationLayer* relu2 = network->addActivation(*x->getOutput(0), ActivationType::kRELU);
|
||||
TRTASSERT(relu2);
|
||||
return relu2;
|
||||
}
|
||||
|
||||
ILayer* backbone_sbsR34_distill::topology(INetworkDefinition *network, std::map<std::string, Weights>& weightMap, ITensor& input) {
|
||||
std::string ibn{""};
|
||||
|
|
|
@ -0,0 +1,80 @@
|
|||
#include <iostream>
|
||||
#include <iterator>
|
||||
#include <fstream>
|
||||
#include <opencv2/opencv.hpp>
|
||||
#include <opencv2/dnn/dnn.hpp>
|
||||
#include "fastrt/calibrator.h"
|
||||
#include "fastrt/cuda_utils.h"
|
||||
#include "fastrt/utils.h"
|
||||
|
||||
Int8EntropyCalibrator2::Int8EntropyCalibrator2(int batchsize, int input_w, int input_h, const char* img_dir, const char* calib_table_name, const char* input_blob_name, bool read_cache)
|
||||
: batchsize_(batchsize)
|
||||
, input_w_(input_w)
|
||||
, input_h_(input_h)
|
||||
, img_idx_(0)
|
||||
, img_dir_(img_dir)
|
||||
, calib_table_name_(calib_table_name)
|
||||
, input_blob_name_(input_blob_name)
|
||||
, read_cache_(read_cache)
|
||||
{
|
||||
input_count_ = 3 * input_w * input_h * batchsize;
|
||||
CUDA_CHECK(cudaMalloc(&device_input_, input_count_ * sizeof(float)));
|
||||
read_files_in_dir(img_dir, img_files_);
|
||||
}
|
||||
|
||||
Int8EntropyCalibrator2::~Int8EntropyCalibrator2()
|
||||
{
|
||||
CUDA_CHECK(cudaFree(device_input_));
|
||||
}
|
||||
|
||||
int Int8EntropyCalibrator2::getBatchSize() const
|
||||
{
|
||||
return batchsize_;
|
||||
}
|
||||
|
||||
bool Int8EntropyCalibrator2::getBatch(void* bindings[], const char* names[], int nbBindings)
|
||||
{
|
||||
if (img_idx_ + batchsize_ > (int)img_files_.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<cv::Mat> input_imgs_;
|
||||
for (int i = img_idx_; i < img_idx_ + batchsize_; i++) {
|
||||
std::cout << img_dir_ + img_files_[i] << " " << i << std::endl;
|
||||
cv::Mat temp = cv::imread(img_dir_ + img_files_[i]);
|
||||
if (temp.empty()){
|
||||
std::cerr << "Fatal error: image cannot open!" << std::endl;
|
||||
return false;
|
||||
}
|
||||
input_imgs_.push_back(temp);
|
||||
}
|
||||
img_idx_ += batchsize_;
|
||||
cv::Mat blob = cv::dnn::blobFromImages(input_imgs_, 1.0, cv::Size(input_w_, input_h_), cv::Scalar(0, 0, 0), true, false);
|
||||
|
||||
CUDA_CHECK(cudaMemcpy(device_input_, blob.ptr<float>(0), input_count_ * sizeof(float), cudaMemcpyHostToDevice));
|
||||
assert(!strcmp(names[0], input_blob_name_));
|
||||
bindings[0] = device_input_;
|
||||
return true;
|
||||
}
|
||||
|
||||
const void* Int8EntropyCalibrator2::readCalibrationCache(size_t& length)
|
||||
{
|
||||
std::cout << "reading calib cache: " << calib_table_name_ << std::endl;
|
||||
calib_cache_.clear();
|
||||
std::ifstream input(calib_table_name_, std::ios::binary);
|
||||
input >> std::noskipws;
|
||||
if (read_cache_ && input.good())
|
||||
{
|
||||
std::copy(std::istream_iterator<char>(input), std::istream_iterator<char>(), std::back_inserter(calib_cache_));
|
||||
}
|
||||
length = calib_cache_.size();
|
||||
return length ? calib_cache_.data() : nullptr;
|
||||
}
|
||||
|
||||
void Int8EntropyCalibrator2::writeCalibrationCache(const void* cache, size_t length)
|
||||
{
|
||||
std::cout << "writing calib cache: " << calib_table_name_ << " size: " << length << std::endl;
|
||||
std::ofstream output(calib_table_name_, std::ios::binary);
|
||||
output.write(reinterpret_cast<const char*>(cache), length);
|
||||
}
|
||||
|
|
@ -29,6 +29,11 @@ namespace fastrt {
|
|||
/* cfg.MODEL.BACKBONE.DEPTH: 34x */
|
||||
std::cout << "[createBackboneModule]: backbone_sbsR34_distill" << std::endl;
|
||||
return make_unique<backbone_sbsR34_distill>(modelCfg);
|
||||
case FastreidBackboneType::r18_distill:
|
||||
/* cfg.MODEL.META_ARCHITECTURE: Distiller */
|
||||
/* cfg.MODEL.BACKBONE.DEPTH: 18x */
|
||||
std::cout << "[createBackboneModule]: backbone_sbsR18_distill" << std::endl;
|
||||
return make_unique<backbone_sbsR18_distill>(modelCfg);
|
||||
default:
|
||||
std::cerr << "[Backbone is not supported.]" << std::endl;
|
||||
return nullptr;
|
||||
|
|
|
@ -19,7 +19,7 @@ namespace fastrt {
|
|||
/* Standardization */
|
||||
static const float mean[3] = {123.675, 116.28, 103.53};
|
||||
static const float std[3] = {58.395, 57.120000000000005, 57.375};
|
||||
return addMeanStd(network, weightMap, input, "", mean, std, true); // true for div 255
|
||||
return addMeanStd(network, weightMap, input, "", mean, std, false); // true for div 255
|
||||
}
|
||||
|
||||
}
|
|
@ -1,4 +1,6 @@
|
|||
#include "fastrt/model.h"
|
||||
#include "fastrt/calibrator.h"
|
||||
|
||||
|
||||
namespace fastrt {
|
||||
|
||||
|
@ -68,9 +70,27 @@ namespace fastrt {
|
|||
/* Build engine */
|
||||
builder->setMaxBatchSize(_engineCfg.max_batch_size);
|
||||
config->setMaxWorkspaceSize(1 << 20);
|
||||
#if defined(BUILD_FP16) && defined(BUILD_INT8)
|
||||
std::cout << "Flag confilct! BUILD_FP16 and BUILD_INT8 can't be both True!" << std::endl;
|
||||
return null;
|
||||
#endif
|
||||
#ifdef BUILD_FP16
|
||||
std::cout << "[Build fp16]" << std::endl;
|
||||
config->setFlag(BuilderFlag::kFP16);
|
||||
#endif
|
||||
#ifdef BUILD_INT8
|
||||
std::cout << "[Build int8]" << std::endl;
|
||||
std::cout << "Your platform support int8: " << (builder->platformHasFastInt8() ? "true" : "false") << std::endl;
|
||||
assert(builder->platformHasFastInt8());
|
||||
config->setFlag(BuilderFlag::kINT8);
|
||||
int w = _engineCfg.input_w;
|
||||
int h = _engineCfg.input_h;
|
||||
char*p = (char*)_engineCfg.input_name.data();
|
||||
|
||||
//path must end with /
|
||||
Int8EntropyCalibrator2* calibrator = new Int8EntropyCalibrator2(1, w, h,
|
||||
"/data/person_reid/data/Market-1501-v15.09.15/bounding_box_test/", "int8calib.table", p);
|
||||
config->setInt8Calibrator(calibrator);
|
||||
#endif
|
||||
auto engine = make_holder(builder->buildEngineWithConfig(*network, *config));
|
||||
std::cout << "[TRT engine build out]" << std::endl;
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
#ifndef ENTROPY_CALIBRATOR_H
|
||||
#define ENTROPY_CALIBRATOR_H
|
||||
|
||||
#include "NvInfer.h"
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
//! \class Int8EntropyCalibrator2
|
||||
//!
|
||||
//! \brief Implements Entropy calibrator 2.
|
||||
//! CalibrationAlgoType is kENTROPY_CALIBRATION_2.
|
||||
//!
|
||||
class Int8EntropyCalibrator2 : public nvinfer1::IInt8EntropyCalibrator2
|
||||
{
|
||||
public:
|
||||
Int8EntropyCalibrator2(int batchsize, int input_w, int input_h, const char* img_dir, const char* calib_table_name, const char* input_blob_name, bool read_cache = true);
|
||||
|
||||
virtual ~Int8EntropyCalibrator2();
|
||||
int getBatchSize() const override;
|
||||
bool getBatch(void* bindings[], const char* names[], int nbBindings) override;
|
||||
const void* readCalibrationCache(size_t& length) override;
|
||||
void writeCalibrationCache(const void* cache, size_t length) override;
|
||||
|
||||
private:
|
||||
int batchsize_;
|
||||
int input_w_;
|
||||
int input_h_;
|
||||
int img_idx_;
|
||||
std::string img_dir_;
|
||||
std::vector<std::string> img_files_;
|
||||
size_t input_count_;
|
||||
std::string calib_table_name_;
|
||||
const char* input_blob_name_;
|
||||
bool read_cache_;
|
||||
void* device_input_;
|
||||
std::vector<char> calib_cache_;
|
||||
};
|
||||
|
||||
#endif // ENTROPY_CALIBRATOR_H
|
|
@ -0,0 +1,18 @@
|
|||
#ifndef TRTX_CUDA_UTILS_H_
|
||||
#define TRTX_CUDA_UTILS_H_
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
|
||||
#ifndef CUDA_CHECK
|
||||
#define CUDA_CHECK(callstr)\
|
||||
{\
|
||||
cudaError_t error_code = callstr;\
|
||||
if (error_code != cudaSuccess) {\
|
||||
std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__;\
|
||||
assert(0);\
|
||||
}\
|
||||
}
|
||||
#endif // CUDA_CHECK
|
||||
|
||||
#endif // TRTX_CUDA_UTILS_H_
|
||||
|
|
@ -7,6 +7,16 @@
|
|||
using namespace nvinfer1;
|
||||
|
||||
namespace fastrt {
|
||||
class backbone_sbsR18_distill : public Module {
|
||||
private:
|
||||
FastreidConfig& _modelCfg;
|
||||
public:
|
||||
backbone_sbsR18_distill(FastreidConfig& modelCfg) : _modelCfg(modelCfg){}
|
||||
~backbone_sbsR18_distill() = default;
|
||||
ILayer* topology(INetworkDefinition *network,
|
||||
std::map<std::string, Weights>& weightMap,
|
||||
ITensor& input) override;
|
||||
};
|
||||
|
||||
class backbone_sbsR34_distill : public Module {
|
||||
private:
|
||||
|
|
|
@ -28,7 +28,8 @@ namespace fastrt {
|
|||
X(r50, "r50") \
|
||||
X(r50_distill, "r50_distill") \
|
||||
X(r34, "r34") \
|
||||
X(r34_distill, "r34_distill")
|
||||
X(r34_distill, "r34_distill") \
|
||||
X(r18_distill, "r18_distill")
|
||||
|
||||
#define X(a, b) a,
|
||||
enum FastreidBackboneType { FASTBACKBONE_TABLE };
|
||||
|
|
|
@ -7,7 +7,9 @@
|
|||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <cassert>
|
||||
#include <string.h>
|
||||
|
||||
#include <dirent.h>
|
||||
#include "NvInfer.h"
|
||||
#include "cuda_runtime_api.h"
|
||||
#include "fastrt/struct.h"
|
||||
|
@ -37,6 +39,26 @@ namespace io {
|
|||
std::vector<std::string> fileGlob(const std::string& pattern);
|
||||
}
|
||||
|
||||
static inline int read_files_in_dir(const char *p_dir_name, std::vector<std::string> &file_names) {
|
||||
DIR *p_dir = opendir(p_dir_name);
|
||||
if (p_dir == nullptr) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
struct dirent* p_file = nullptr;
|
||||
while ((p_file = readdir(p_dir)) != nullptr) {
|
||||
if (strcmp(p_file->d_name, ".") != 0 &&
|
||||
strcmp(p_file->d_name, "..") != 0) {
|
||||
|
||||
std::string cur_file_name(p_file->d_name);
|
||||
file_names.push_back(cur_file_name);
|
||||
}
|
||||
}
|
||||
|
||||
closedir(p_dir);
|
||||
return 0;
|
||||
}
|
||||
|
||||
namespace trt {
|
||||
/*
|
||||
* Load weights from files shared with TensorRT samples.
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
SET(APP_PROJECT_NAME ReID)
|
||||
|
||||
# pybind
|
||||
add_subdirectory(pybind11)
|
||||
|
||||
find_package(CUDA REQUIRED)
|
||||
# include and link dirs of cuda and tensorrt, you need adapt them if yours are different
|
||||
# cuda
|
||||
include_directories(/usr/local/cuda/include)
|
||||
link_directories(/usr/local/cuda/lib64)
|
||||
# tensorrt
|
||||
include_directories(/usr/include/x86_64-linux-gnu/)
|
||||
link_directories(/usr/lib/x86_64-linux-gnu/)
|
||||
|
||||
include_directories(${SOLUTION_DIR}/include)
|
||||
|
||||
pybind11_add_module(${APP_PROJECT_NAME} ${PROJECT_SOURCE_DIR}/pybind_interface/reid.cpp)
|
||||
|
||||
# OpenCV
|
||||
find_package(OpenCV)
|
||||
target_include_directories(${APP_PROJECT_NAME}
|
||||
PUBLIC
|
||||
${OpenCV_INCLUDE_DIRS}
|
||||
)
|
||||
target_link_libraries(${APP_PROJECT_NAME}
|
||||
PUBLIC
|
||||
${OpenCV_LIBS}
|
||||
)
|
||||
|
||||
if(BUILD_FASTRT_ENGINE AND BUILD_PYTHON_INTERFACE)
|
||||
SET(FASTRTENGINE_LIB FastRTEngine)
|
||||
else()
|
||||
SET(FASTRTENGINE_LIB ${SOLUTION_DIR}/libs/FastRTEngine/libFastRTEngine.so)
|
||||
endif()
|
||||
|
||||
target_link_libraries(${APP_PROJECT_NAME}
|
||||
PRIVATE
|
||||
${FASTRTENGINE_LIB}
|
||||
nvinfer
|
||||
)
|
|
@ -0,0 +1,17 @@
|
|||
# cuda10.0
|
||||
FROM fineyu/tensorrt7:0.0.1
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
RUN apt-get update && apt-get install -y software-properties-common
|
||||
|
||||
RUN add-apt-repository -y ppa:timsc/opencv-3.4 && \
|
||||
apt-get update && \
|
||||
apt-get install -y cmake \
|
||||
libopencv-dev \
|
||||
libopencv-dnn-dev \
|
||||
libopencv-shape3.4-dbg && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py && python3 get-pip.py && rm get-pip.py && pip3 install torch==1.6.0 torchvision tensorboard matplotlib scipy Pillow numpy prettytable easydict opencv-python \
|
||||
scikit-learn pyyaml yacs termcolor tabulate tensorboard opencv-python pyyaml yacs termcolor tabulate gdown faiss-cpu
|
|
@ -0,0 +1,17 @@
|
|||
# cuda10.2
|
||||
FROM nvcr.io/nvidia/tensorrt:20.03-py3
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
RUN apt-get update && apt-get install -y software-properties-common
|
||||
|
||||
RUN add-apt-repository -y ppa:timsc/opencv-3.4 && \
|
||||
apt-get update && \
|
||||
apt-get install -y cmake \
|
||||
libopencv-dev \
|
||||
libopencv-dnn-dev \
|
||||
libopencv-shape3.4-dbg && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py && python3 get-pip.py && rm get-pip.py && pip3 install torch==1.6.0 torchvision tensorboard matplotlib scipy Pillow numpy prettytable easydict opencv-python \
|
||||
scikit-learn pyyaml yacs termcolor tabulate tensorboard opencv-python pyyaml yacs termcolor tabulate gdown faiss-cpu
|
|
@ -0,0 +1,65 @@
|
|||
import random
|
||||
import numpy as np
|
||||
import cv2
|
||||
import fs
|
||||
import argparse
|
||||
import io
|
||||
import sys
|
||||
import torch
|
||||
import time
|
||||
import os
|
||||
import torchvision.transforms as T
|
||||
|
||||
sys.path.append('../../..')
|
||||
sys.path.append('../')
|
||||
from fastreid.config import get_cfg
|
||||
from fastreid.modeling.meta_arch import build_model
|
||||
from fastreid.utils.file_io import PathManager
|
||||
from fastreid.utils.checkpoint import Checkpointer
|
||||
from fastreid.utils.logger import setup_logger
|
||||
from fastreid.data import build_reid_train_loader, build_reid_test_loader
|
||||
from fastreid.evaluation.rank import eval_market1501
|
||||
|
||||
from build.pybind_interface.ReID import ReID
|
||||
|
||||
|
||||
FEATURE_DIM = 512
|
||||
GPU_ID = 0
|
||||
|
||||
def map(wrapper):
|
||||
model = wrapper
|
||||
cfg = get_cfg()
|
||||
test_loader, num_query = build_reid_test_loader(cfg, "Market1501", T.Compose([]))
|
||||
|
||||
feats = []
|
||||
pids = []
|
||||
camids = []
|
||||
|
||||
for batch in test_loader:
|
||||
for image_path in batch["img_paths"]:
|
||||
t = torch.Tensor(np.array([model.infer(cv2.imread(image_path))]))
|
||||
t.to(torch.device(GPU_ID))
|
||||
feats.append(t)
|
||||
pids.extend(batch["targets"].numpy())
|
||||
camids.extend(batch["camids"].numpy())
|
||||
|
||||
feats = torch.cat(feats, dim=0)
|
||||
q_feat = feats[:num_query]
|
||||
g_feat = feats[num_query:]
|
||||
q_pids = np.asarray(pids[:num_query])
|
||||
g_pids = np.asarray(pids[num_query:])
|
||||
q_camids = np.asarray(camids[:num_query])
|
||||
g_camids = np.asarray(camids[num_query:])
|
||||
|
||||
|
||||
distmat = 1 - torch.mm(q_feat, g_feat.t())
|
||||
distmat = distmat.numpy()
|
||||
all_cmc, all_AP, all_INP = eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, 5)
|
||||
mAP = np.mean(all_AP)
|
||||
print("mAP {}, rank-1 {}".format(mAP, all_cmc[0]))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
infer = ReID(GPU_ID)
|
||||
infer.build("../build/kd_r18_distill.engine")
|
||||
map(infer)
|
|
@ -0,0 +1 @@
|
|||
Subproject commit 0e01c243c7ffae3a2e52f998bacfe82f56aa96d9
|
|
@ -0,0 +1,177 @@
|
|||
#include <iostream>
|
||||
#include <opencv2/opencv.hpp>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/numpy.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include "fastrt/utils.h"
|
||||
#include "fastrt/baseline.h"
|
||||
#include "fastrt/factory.h"
|
||||
using namespace fastrt;
|
||||
using namespace nvinfer1;
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
|
||||
/* Ex1. sbs_R50-ibn */
|
||||
static const std::string WEIGHTS_PATH = "../sbs_R50-ibn.wts";
|
||||
static const std::string ENGINE_PATH = "./sbs_R50-ibn.engine";
|
||||
|
||||
static const int MAX_BATCH_SIZE = 4;
|
||||
static const int INPUT_H = 256;
|
||||
static const int INPUT_W = 128;
|
||||
static const int OUTPUT_SIZE = 2048;
|
||||
static const int DEVICE_ID = 0;
|
||||
|
||||
static const FastreidBackboneType BACKBONE = FastreidBackboneType::r50;
|
||||
static const FastreidHeadType HEAD = FastreidHeadType::EmbeddingHead;
|
||||
static const FastreidPoolingType HEAD_POOLING = FastreidPoolingType::gempoolP;
|
||||
static const int LAST_STRIDE = 1;
|
||||
static const bool WITH_IBNA = true;
|
||||
static const bool WITH_NL = true;
|
||||
static const int EMBEDDING_DIM = 0;
|
||||
|
||||
FastreidConfig reidCfg {
|
||||
BACKBONE,
|
||||
HEAD,
|
||||
HEAD_POOLING,
|
||||
LAST_STRIDE,
|
||||
WITH_IBNA,
|
||||
WITH_NL,
|
||||
EMBEDDING_DIM};
|
||||
|
||||
class ReID
|
||||
{
|
||||
|
||||
private:
|
||||
int device; // GPU id
|
||||
fastrt::Baseline baseline;
|
||||
|
||||
public:
|
||||
ReID(int a);
|
||||
int build(const std::string &engine_file);
|
||||
// std::list<float> infer_test(const std::string &image_file);
|
||||
std::list<float> infer(py::array_t<uint8_t>&);
|
||||
std::list<std::list<float>> batch_infer(std::list<py::array_t<uint8_t>>&);
|
||||
~ReID();
|
||||
};
|
||||
|
||||
ReID::ReID(int device): baseline(trt::ModelConfig {
|
||||
WEIGHTS_PATH,
|
||||
MAX_BATCH_SIZE,
|
||||
INPUT_H,
|
||||
INPUT_W,
|
||||
OUTPUT_SIZE,
|
||||
device})
|
||||
{
|
||||
std::cout << "Init on device " << device << std::endl;
|
||||
}
|
||||
|
||||
int ReID::build(const std::string &engine_file)
|
||||
{
|
||||
if(!baseline.deserializeEngine(engine_file)) {
|
||||
std::cout << "DeserializeEngine Failed." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
ReID::~ReID()
|
||||
{
|
||||
|
||||
std::cout << "Destroy engine succeed" << std::endl;
|
||||
}
|
||||
|
||||
std::list<float> ReID::infer(py::array_t<uint8_t>& img)
|
||||
{
|
||||
auto rows = img.shape(0);
|
||||
auto cols = img.shape(1);
|
||||
auto type = CV_8UC3;
|
||||
|
||||
cv::Mat img2(rows, cols, type, (unsigned char*)img.data());
|
||||
cv::Mat re(INPUT_H, INPUT_W, CV_8UC3);
|
||||
// std::cout << (int)img2.data[0] << std::endl;
|
||||
cv::resize(img2, re, re.size(), 0, 0, cv::INTER_CUBIC); /* cv::INTER_LINEAR */
|
||||
std::vector<cv::Mat> input;
|
||||
input.emplace_back(re);
|
||||
|
||||
if(!baseline.inference(input)) {
|
||||
std::cout << "Inference Failed." << std::endl;
|
||||
}
|
||||
std::list<float> feature;
|
||||
|
||||
float* feat_embedding = baseline.getOutput();
|
||||
TRTASSERT(feat_embedding);
|
||||
for (int dim = 0; dim < baseline.getOutputSize(); ++dim) {
|
||||
feature.push_back(feat_embedding[dim]);
|
||||
}
|
||||
|
||||
return feature;
|
||||
}
|
||||
|
||||
|
||||
std::list<std::list<float>> ReID::batch_infer(std::list<py::array_t<uint8_t>>& imgs)
|
||||
{
|
||||
// auto t1 = Time::now();
|
||||
std::vector<cv::Mat> input;
|
||||
int count = 0;
|
||||
while(!imgs.empty()){
|
||||
py::array_t<uint8_t>& img = imgs.front();
|
||||
imgs.pop_front();
|
||||
// parse to cvmat
|
||||
auto rows = img.shape(0);
|
||||
auto cols = img.shape(1);
|
||||
auto type = CV_8UC3;
|
||||
|
||||
cv::Mat img2(rows, cols, type, (unsigned char*)img.data());
|
||||
cv::Mat re(INPUT_H, INPUT_W, CV_8UC3);
|
||||
// std::cout << (int)img2.data[0] << std::endl;
|
||||
cv::resize(img2, re, re.size(), 0, 0, cv::INTER_CUBIC); /* cv::INTER_LINEAR */
|
||||
input.emplace_back(re);
|
||||
|
||||
count += 1;
|
||||
}
|
||||
// auto t2 = Time::now();
|
||||
|
||||
if(!baseline.inference(input)) {
|
||||
std::cout << "Inference Failed." << std::endl;
|
||||
}
|
||||
std::list<std::list<float>> result;
|
||||
|
||||
float* feat_embedding = baseline.getOutput();
|
||||
TRTASSERT(feat_embedding);
|
||||
|
||||
// auto t3 = Time::now();
|
||||
for (int index = 0; index < count; index++)
|
||||
{
|
||||
std::list<float> feature;
|
||||
for (int dim = 0; dim < baseline.getOutputSize(); ++dim) {
|
||||
feature.push_back(feat_embedding[index * baseline.getOutputSize() + dim]);
|
||||
}
|
||||
result.push_back(feature);
|
||||
}
|
||||
// std::cout << "[Preprocessing]: " << std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1).count() << "ms"
|
||||
// << "[Infer]: " << std::chrono::duration_cast<std::chrono::milliseconds>(t3 - t2).count() << "ms"
|
||||
// << "[Cast]: " << std::chrono::duration_cast<std::chrono::milliseconds>(Time::now() - t3).count() << "ms"
|
||||
// << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
PYBIND11_MODULE(ReID, m) {
|
||||
m.doc() = R"pbdoc(
|
||||
Pybind11 example plugin
|
||||
)pbdoc";
|
||||
py::class_<ReID>(m, "ReID")
|
||||
.def(py::init<int>())
|
||||
.def("build", &ReID::build)
|
||||
.def("infer", &ReID::infer, py::return_value_policy::automatic)
|
||||
.def("batch_infer", &ReID::batch_infer, py::return_value_policy::automatic)
|
||||
;
|
||||
|
||||
#ifdef VERSION_INFO
|
||||
m.attr("__version__") = VERSION_INFO;
|
||||
#else
|
||||
m.attr("__version__") = "dev";
|
||||
#endif
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
import sys
|
||||
|
||||
sys.path.append("../")
|
||||
from build.pybind_interface.ReID import ReID
|
||||
import cv2
|
||||
import time
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
iter_ = 20000
|
||||
m = ReID(0)
|
||||
m.build("../build/kd_r18_distill.engine")
|
||||
print("build done")
|
||||
|
||||
frame = cv2.imread("/data/sunp/algorithm/2020_1015_time/pytorchtotensorrt_reid/test/query/0001/0001_c1s1_001051_00.jpg")
|
||||
m.infer(frame)
|
||||
t0 = time.time()
|
||||
|
||||
for i in range(iter_):
|
||||
m.infer(frame)
|
||||
|
||||
total = time.time() - t0
|
||||
print("CPP API fps is {:.1f}, avg infer time is {:.2f}ms".format(iter_ / total, total / iter_ * 1000))
|
Loading…
Reference in New Issue