mirror of https://github.com/JDAI-CV/fast-reid.git
Merge branch 'master' of github.com:L1aoXingyu/fast-reid
commit
254a489eb1
|
@ -4,7 +4,7 @@ set(LIBARARY_NAME "FastRT" CACHE STRING "The Fastreid-tensorrt library name")
|
||||||
|
|
||||||
set(LIBARARY_VERSION_MAJOR "0")
|
set(LIBARARY_VERSION_MAJOR "0")
|
||||||
set(LIBARARY_VERSION_MINOR "0")
|
set(LIBARARY_VERSION_MINOR "0")
|
||||||
set(LIBARARY_VERSION_SINOR "1")
|
set(LIBARARY_VERSION_SINOR "2")
|
||||||
set(LIBARARY_SOVERSION "0")
|
set(LIBARARY_SOVERSION "0")
|
||||||
set(LIBARARY_VERSION "${LIBARARY_VERSION_MAJOR}.${LIBARARY_VERSION_MINOR}.${LIBARARY_VERSION_SINOR}")
|
set(LIBARARY_VERSION "${LIBARARY_VERSION_MAJOR}.${LIBARARY_VERSION_MINOR}.${LIBARARY_VERSION_SINOR}")
|
||||||
project(${LIBARARY_NAME}${LIBARARY_VERSION})
|
project(${LIBARARY_NAME}${LIBARARY_VERSION})
|
||||||
|
|
|
@ -16,7 +16,7 @@ add_executable(${APP_PROJECT_NAME} inference.cpp)
|
||||||
find_package(OpenCV)
|
find_package(OpenCV)
|
||||||
target_include_directories(${APP_PROJECT_NAME}
|
target_include_directories(${APP_PROJECT_NAME}
|
||||||
PUBLIC
|
PUBLIC
|
||||||
OpenCV_INCLUDE_DIRS
|
${OpenCV_INCLUDE_DIRS}
|
||||||
)
|
)
|
||||||
target_link_libraries(${APP_PROJECT_NAME}
|
target_link_libraries(${APP_PROJECT_NAME}
|
||||||
PUBLIC
|
PUBLIC
|
||||||
|
|
|
@ -24,7 +24,7 @@ PUBLIC
|
||||||
find_package(OpenCV)
|
find_package(OpenCV)
|
||||||
target_include_directories(${PROJECT_NAME}
|
target_include_directories(${PROJECT_NAME}
|
||||||
PUBLIC
|
PUBLIC
|
||||||
OpenCV_INCLUDE_DIRS
|
${OpenCV_INCLUDE_DIRS}
|
||||||
)
|
)
|
||||||
|
|
||||||
target_link_libraries(${PROJECT_NAME}
|
target_link_libraries(${PROJECT_NAME}
|
||||||
|
|
|
@ -31,9 +31,9 @@ namespace fastrt {
|
||||||
|
|
||||||
/* Create builder */
|
/* Create builder */
|
||||||
auto builder = make_holder(createInferBuilder(gLogger));
|
auto builder = make_holder(createInferBuilder(gLogger));
|
||||||
auto config = make_holder(builder->createBuilderConfig());
|
|
||||||
/* Create model to populate the network, then set the outputs and create an engine */
|
/* Create model to populate the network, then set the outputs and create an engine */
|
||||||
auto engine = createEngine<B, H>(builder.get(), config.get(), backbone, head);
|
auto engine = createEngine<B, H>(builder.get(), backbone, head);
|
||||||
TRTASSERT(engine.get());
|
TRTASSERT(engine.get());
|
||||||
|
|
||||||
/* Serialize the engine */
|
/* Serialize the engine */
|
||||||
|
@ -74,15 +74,16 @@ namespace fastrt {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
template <typename B, typename H>
|
template <typename B, typename H>
|
||||||
TensorRTHolder<ICudaEngine> createEngine(IBuilder* builder, IBuilderConfig* config,
|
TensorRTHolder<ICudaEngine> createEngine(IBuilder* builder,
|
||||||
std::function<B*(INetworkDefinition*, std::map<std::string, Weights>&, ITensor&, const FastreidConfig&)> backbone,
|
std::function<B*(INetworkDefinition*, std::map<std::string, Weights>&, ITensor&, const FastreidConfig&)> backbone,
|
||||||
std::function<H*(INetworkDefinition*, std::map<std::string, Weights>&, ITensor&, const FastreidConfig&)> head) {
|
std::function<H*(INetworkDefinition*, std::map<std::string, Weights>&, ITensor&, const FastreidConfig&)> head) {
|
||||||
|
|
||||||
auto network = make_holder(builder->createNetworkV2(0U));
|
auto network = make_holder(builder->createNetworkV2(0U));
|
||||||
|
auto config = make_holder(builder->createBuilderConfig());
|
||||||
auto data = network->addInput(_engineCfg.input_name.c_str(), _dt, Dims3{3, _engineCfg.input_h, _engineCfg.input_w});
|
auto data = network->addInput(_engineCfg.input_name.c_str(), _dt, Dims3{3, _engineCfg.input_h, _engineCfg.input_w});
|
||||||
TRTASSERT(data);
|
TRTASSERT(data);
|
||||||
|
|
||||||
std::map<std::string, Weights> weightMap = loadWeights(_engineCfg.weights_path);
|
auto weightMap = loadWeights(_engineCfg.weights_path);
|
||||||
|
|
||||||
/* Preprocessing */
|
/* Preprocessing */
|
||||||
auto pre_input = preprocessing_gpu(network.get(), weightMap, data);
|
auto pre_input = preprocessing_gpu(network.get(), weightMap, data);
|
||||||
|
@ -90,7 +91,9 @@ namespace fastrt {
|
||||||
|
|
||||||
/* Modeling */
|
/* Modeling */
|
||||||
auto feat_map = backbone(network.get(), weightMap, *pre_input, _reidcfg);
|
auto feat_map = backbone(network.get(), weightMap, *pre_input, _reidcfg);
|
||||||
|
TRTASSERT(feat_map);
|
||||||
auto embedding = head(network.get(), weightMap, *feat_map->getOutput(0), _reidcfg);
|
auto embedding = head(network.get(), weightMap, *feat_map->getOutput(0), _reidcfg);
|
||||||
|
TRTASSERT(embedding);
|
||||||
|
|
||||||
/* Set output */
|
/* Set output */
|
||||||
embedding->getOutput(0)->setName(_engineCfg.output_name.c_str());
|
embedding->getOutput(0)->setName(_engineCfg.output_name.c_str());
|
||||||
|
|
Loading…
Reference in New Issue