refactor: model.h

pull/399/head
darrenhsieh 2021-01-31 18:12:35 +08:00
parent ebc375e51e
commit 5f7d3d586e
1 changed files with 7 additions and 4 deletions

View File

@ -31,9 +31,9 @@ namespace fastrt {
/* Create builder */
auto builder = make_holder(createInferBuilder(gLogger));
auto config = make_holder(builder->createBuilderConfig());
/* Create model to populate the network, then set the outputs and create an engine */
auto engine = createEngine<B, H>(builder.get(), config.get(), backbone, head);
auto engine = createEngine<B, H>(builder.get(), backbone, head);
TRTASSERT(engine.get());
/* Serialize the engine */
@ -74,15 +74,16 @@ namespace fastrt {
private:
template <typename B, typename H>
TensorRTHolder<ICudaEngine> createEngine(IBuilder* builder, IBuilderConfig* config,
TensorRTHolder<ICudaEngine> createEngine(IBuilder* builder,
std::function<B*(INetworkDefinition*, std::map<std::string, Weights>&, ITensor&, const FastreidConfig&)> backbone,
std::function<H*(INetworkDefinition*, std::map<std::string, Weights>&, ITensor&, const FastreidConfig&)> head) {
auto network = make_holder(builder->createNetworkV2(0U));
auto config = make_holder(builder->createBuilderConfig());
auto data = network->addInput(_engineCfg.input_name.c_str(), _dt, Dims3{3, _engineCfg.input_h, _engineCfg.input_w});
TRTASSERT(data);
std::map<std::string, Weights> weightMap = loadWeights(_engineCfg.weights_path);
auto weightMap = loadWeights(_engineCfg.weights_path);
/* Preprocessing */
auto pre_input = preprocessing_gpu(network.get(), weightMap, data);
@ -90,7 +91,9 @@ namespace fastrt {
/* Modeling */
auto feat_map = backbone(network.get(), weightMap, *pre_input, _reidcfg);
TRTASSERT(feat_map);
auto embedding = head(network.get(), weightMap, *feat_map->getOutput(0), _reidcfg);
TRTASSERT(embedding);
/* Set output */
embedding->getOutput(0)->setName(_engineCfg.output_name.c_str());