diff --git a/projects/FastRT/include/fastrt/model.h b/projects/FastRT/include/fastrt/model.h index a8509a7..ba0bb26 100644 --- a/projects/FastRT/include/fastrt/model.h +++ b/projects/FastRT/include/fastrt/model.h @@ -31,9 +31,9 @@ namespace fastrt { /* Create builder */ auto builder = make_holder(createInferBuilder(gLogger)); - auto config = make_holder(builder->createBuilderConfig()); + /* Create model to populate the network, then set the outputs and create an engine */ - auto engine = createEngine(builder.get(), config.get(), backbone, head); + auto engine = createEngine(builder.get(), backbone, head); TRTASSERT(engine.get()); /* Serialize the engine */ @@ -74,15 +74,16 @@ namespace fastrt { private: template - TensorRTHolder createEngine(IBuilder* builder, IBuilderConfig* config, + TensorRTHolder createEngine(IBuilder* builder, std::function&, ITensor&, const FastreidConfig&)> backbone, std::function&, ITensor&, const FastreidConfig&)> head) { auto network = make_holder(builder->createNetworkV2(0U)); + auto config = make_holder(builder->createBuilderConfig()); auto data = network->addInput(_engineCfg.input_name.c_str(), _dt, Dims3{3, _engineCfg.input_h, _engineCfg.input_w}); TRTASSERT(data); - std::map weightMap = loadWeights(_engineCfg.weights_path); + auto weightMap = loadWeights(_engineCfg.weights_path); /* Preprocessing */ auto pre_input = preprocessing_gpu(network.get(), weightMap, data); @@ -90,7 +91,9 @@ namespace fastrt { /* Modeling */ auto feat_map = backbone(network.get(), weightMap, *pre_input, _reidcfg); + TRTASSERT(feat_map); auto embedding = head(network.get(), weightMap, *feat_map->getOutput(0), _reidcfg); + TRTASSERT(embedding); /* Set output */ embedding->getOutput(0)->setName(_engineCfg.output_name.c_str());