fast-reid/projects/FastRT/pybind_interface/reid.cpp

178 lines
4.9 KiB
C++

#include <iostream>
#include <opencv2/opencv.hpp>
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#include "fastrt/utils.h"
#include "fastrt/baseline.h"
#include "fastrt/factory.h"
using namespace fastrt;
using namespace nvinfer1;
namespace py = pybind11;
/* Ex1. sbs_R50-ibn */
static const std::string WEIGHTS_PATH = "../sbs_R50-ibn.wts";
static const std::string ENGINE_PATH = "./sbs_R50-ibn.engine";
static const int MAX_BATCH_SIZE = 4;
static const int INPUT_H = 384;
static const int INPUT_W = 128;
static const int OUTPUT_SIZE = 2048;
static const int DEVICE_ID = 0;
static const FastreidBackboneType BACKBONE = FastreidBackboneType::r50;
static const FastreidHeadType HEAD = FastreidHeadType::EmbeddingHead;
static const FastreidPoolingType HEAD_POOLING = FastreidPoolingType::gempoolP;
static const int LAST_STRIDE = 1;
static const bool WITH_IBNA = true;
static const bool WITH_NL = true;
static const int EMBEDDING_DIM = 0;
FastreidConfig reidCfg {
BACKBONE,
HEAD,
HEAD_POOLING,
LAST_STRIDE,
WITH_IBNA,
WITH_NL,
EMBEDDING_DIM};
class ReID
{
private:
int device; // GPU id
fastrt::Baseline baseline;
public:
ReID(int a);
int build(const std::string &engine_file);
// std::list<float> infer_test(const std::string &image_file);
std::list<float> infer(py::array_t<uint8_t>&);
std::list<std::list<float>> batch_infer(std::list<py::array_t<uint8_t>>&);
~ReID();
};
ReID::ReID(int device): baseline(trt::ModelConfig {
WEIGHTS_PATH,
MAX_BATCH_SIZE,
INPUT_H,
INPUT_W,
OUTPUT_SIZE,
device})
{
std::cout << "Init on device " << device << std::endl;
}
int ReID::build(const std::string &engine_file)
{
if(!baseline.deserializeEngine(engine_file)) {
std::cout << "DeserializeEngine Failed." << std::endl;
return -1;
}
return 0;
}
ReID::~ReID()
{
std::cout << "Destroy engine succeed" << std::endl;
}
std::list<float> ReID::infer(py::array_t<uint8_t>& img)
{
auto rows = img.shape(0);
auto cols = img.shape(1);
auto type = CV_8UC3;
cv::Mat img2(rows, cols, type, (unsigned char*)img.data());
cv::Mat re(INPUT_H, INPUT_W, CV_8UC3);
// std::cout << (int)img2.data[0] << std::endl;
cv::resize(img2, re, re.size(), 0, 0, cv::INTER_CUBIC); /* cv::INTER_LINEAR */
std::vector<cv::Mat> input;
input.emplace_back(re);
if(!baseline.inference(input)) {
std::cout << "Inference Failed." << std::endl;
}
std::list<float> feature;
float* feat_embedding = baseline.getOutput();
TRTASSERT(feat_embedding);
for (int dim = 0; dim < baseline.getOutputSize(); ++dim) {
feature.push_back(feat_embedding[dim]);
}
return feature;
}
std::list<std::list<float>> ReID::batch_infer(std::list<py::array_t<uint8_t>>& imgs)
{
// auto t1 = Time::now();
std::vector<cv::Mat> input;
int count = 0;
while(!imgs.empty()){
py::array_t<uint8_t>& img = imgs.front();
imgs.pop_front();
// parse to cvmat
auto rows = img.shape(0);
auto cols = img.shape(1);
auto type = CV_8UC3;
cv::Mat img2(rows, cols, type, (unsigned char*)img.data());
cv::Mat re(INPUT_H, INPUT_W, CV_8UC3);
// std::cout << (int)img2.data[0] << std::endl;
cv::resize(img2, re, re.size(), 0, 0, cv::INTER_CUBIC); /* cv::INTER_LINEAR */
input.emplace_back(re);
count += 1;
}
// auto t2 = Time::now();
if(!baseline.inference(input)) {
std::cout << "Inference Failed." << std::endl;
}
std::list<std::list<float>> result;
float* feat_embedding = baseline.getOutput();
TRTASSERT(feat_embedding);
// auto t3 = Time::now();
for (int index = 0; index < count; index++)
{
std::list<float> feature;
for (int dim = 0; dim < baseline.getOutputSize(); ++dim) {
feature.push_back(feat_embedding[index * baseline.getOutputSize() + dim]);
}
result.push_back(feature);
}
// std::cout << "[Preprocessing]: " << std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1).count() << "ms"
// << "[Infer]: " << std::chrono::duration_cast<std::chrono::milliseconds>(t3 - t2).count() << "ms"
// << "[Cast]: " << std::chrono::duration_cast<std::chrono::milliseconds>(Time::now() - t3).count() << "ms"
// << std::endl;
return result;
}
PYBIND11_MODULE(ReID, m) {
m.doc() = R"pbdoc(
Pybind11 example plugin
)pbdoc";
py::class_<ReID>(m, "ReID")
.def(py::init<int>())
.def("build", &ReID::build)
.def("infer", &ReID::infer, py::return_value_policy::automatic)
.def("batch_infer", &ReID::batch_infer, py::return_value_policy::automatic)
;
#ifdef VERSION_INFO
m.attr("__version__") = VERSION_INFO;
#else
m.attr("__version__") = "dev";
#endif
}