mirror of https://github.com/JDAI-CV/fast-reid.git
refactor: model.h
parent
ebc375e51e
commit
5f7d3d586e
|
@ -31,9 +31,9 @@ namespace fastrt {
|
|||
|
||||
/* Create builder */
|
||||
auto builder = make_holder(createInferBuilder(gLogger));
|
||||
auto config = make_holder(builder->createBuilderConfig());
|
||||
|
||||
/* Create model to populate the network, then set the outputs and create an engine */
|
||||
auto engine = createEngine<B, H>(builder.get(), config.get(), backbone, head);
|
||||
auto engine = createEngine<B, H>(builder.get(), backbone, head);
|
||||
TRTASSERT(engine.get());
|
||||
|
||||
/* Serialize the engine */
|
||||
|
@ -74,15 +74,16 @@ namespace fastrt {
|
|||
|
||||
private:
|
||||
template <typename B, typename H>
|
||||
TensorRTHolder<ICudaEngine> createEngine(IBuilder* builder, IBuilderConfig* config,
|
||||
TensorRTHolder<ICudaEngine> createEngine(IBuilder* builder,
|
||||
std::function<B*(INetworkDefinition*, std::map<std::string, Weights>&, ITensor&, const FastreidConfig&)> backbone,
|
||||
std::function<H*(INetworkDefinition*, std::map<std::string, Weights>&, ITensor&, const FastreidConfig&)> head) {
|
||||
|
||||
auto network = make_holder(builder->createNetworkV2(0U));
|
||||
auto config = make_holder(builder->createBuilderConfig());
|
||||
auto data = network->addInput(_engineCfg.input_name.c_str(), _dt, Dims3{3, _engineCfg.input_h, _engineCfg.input_w});
|
||||
TRTASSERT(data);
|
||||
|
||||
std::map<std::string, Weights> weightMap = loadWeights(_engineCfg.weights_path);
|
||||
auto weightMap = loadWeights(_engineCfg.weights_path);
|
||||
|
||||
/* Preprocessing */
|
||||
auto pre_input = preprocessing_gpu(network.get(), weightMap, data);
|
||||
|
@ -90,7 +91,9 @@ namespace fastrt {
|
|||
|
||||
/* Modeling */
|
||||
auto feat_map = backbone(network.get(), weightMap, *pre_input, _reidcfg);
|
||||
TRTASSERT(feat_map);
|
||||
auto embedding = head(network.get(), weightMap, *feat_map->getOutput(0), _reidcfg);
|
||||
TRTASSERT(embedding);
|
||||
|
||||
/* Set output */
|
||||
embedding->getOutput(0)->setName(_engineCfg.output_name.c_str());
|
||||
|
|
Loading…
Reference in New Issue