Merge pull request #416 from TCHeish

Reviewed by: l1aoxingyu
pull/424/head
Xingyu Liao 2021-03-04 15:18:47 +08:00 committed by GitHub
commit fcfa6800bb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 93 additions and 80 deletions

View File

@ -4,7 +4,7 @@ set(LIBARARY_NAME "FastRT" CACHE STRING "The Fastreid-tensorrt library name")
set(LIBARARY_VERSION_MAJOR "0") set(LIBARARY_VERSION_MAJOR "0")
set(LIBARARY_VERSION_MINOR "0") set(LIBARARY_VERSION_MINOR "0")
set(LIBARARY_VERSION_SINOR "3") set(LIBARARY_VERSION_SINOR "4")
set(LIBARARY_SOVERSION "0") set(LIBARARY_SOVERSION "0")
set(LIBARARY_VERSION "${LIBARARY_VERSION_MAJOR}.${LIBARARY_VERSION_MINOR}.${LIBARARY_VERSION_SINOR}") set(LIBARARY_VERSION "${LIBARARY_VERSION_MAJOR}.${LIBARARY_VERSION_MINOR}.${LIBARARY_VERSION_SINOR}")
project(${LIBARARY_NAME}${LIBARARY_VERSION}) project(${LIBARARY_NAME}${LIBARARY_VERSION})

View File

@ -48,14 +48,14 @@ int main(int argc, char** argv) {
std::cout << "[ModelConfig]: \n" << modelCfg std::cout << "[ModelConfig]: \n" << modelCfg
<< "\n[FastreidConfig]: \n" << reidCfg << std::endl; << "\n[FastreidConfig]: \n" << reidCfg << std::endl;
Baseline baseline{modelCfg, reidCfg}; Baseline baseline{modelCfg};
if (argc == 2 && std::string(argv[1]) == "-s") { if (argc == 2 && std::string(argv[1]) == "-s") {
ModuleFactory moduleFactory; ModuleFactory moduleFactory;
std::cout << "[Serializling Engine]" << std::endl; std::cout << "[Serializling Engine]" << std::endl;
if (!baseline.serializeEngine(ENGINE_PATH, if (!baseline.serializeEngine(ENGINE_PATH,
{std::move(moduleFactory.createBackbone(reidCfg.backbone)), {std::move(moduleFactory.createBackbone(reidCfg)),
std::move(moduleFactory.createHead(reidCfg.head))})) { std::move(moduleFactory.createHead(reidCfg))})) {
std::cout << "SerializeEngine Failed." << std::endl; std::cout << "SerializeEngine Failed." << std::endl;
return -1; return -1;
} }

View File

@ -7,9 +7,9 @@ using namespace trtxapi;
namespace fastrt { namespace fastrt {
ILayer* backbone_sbsR34_distill::topology(INetworkDefinition *network, std::map<std::string, Weights>& weightMap, ITensor& input, const FastreidConfig& reidCfg) { ILayer* backbone_sbsR34_distill::topology(INetworkDefinition *network, std::map<std::string, Weights>& weightMap, ITensor& input) {
std::string ibn{""}; std::string ibn{""};
if(reidCfg.with_ibna) { if(_modelCfg.with_ibna) {
ibn = "a"; ibn = "a";
} }
std::map<std::string, std::vector<std::string>> ibn_layers{ std::map<std::string, std::vector<std::string>> ibn_layers{
@ -54,7 +54,7 @@ namespace fastrt {
x = distill_basicBlock_ibn(network, weightMap, *x->getOutput(0), 256, 256, 1, "backbone.layer3.4.", ibn_layers[ibn][11]); x = distill_basicBlock_ibn(network, weightMap, *x->getOutput(0), 256, 256, 1, "backbone.layer3.4.", ibn_layers[ibn][11]);
x = distill_basicBlock_ibn(network, weightMap, *x->getOutput(0), 256, 256, 1, "backbone.layer3.5.", ibn_layers[ibn][12]); x = distill_basicBlock_ibn(network, weightMap, *x->getOutput(0), 256, 256, 1, "backbone.layer3.5.", ibn_layers[ibn][12]);
x = distill_basicBlock_ibn(network, weightMap, *x->getOutput(0), 256, 512, reidCfg.last_stride, "backbone.layer4.0.", ibn_layers[ibn][13]); x = distill_basicBlock_ibn(network, weightMap, *x->getOutput(0), 256, 512, _modelCfg.last_stride, "backbone.layer4.0.", ibn_layers[ibn][13]);
x = distill_basicBlock_ibn(network, weightMap, *x->getOutput(0), 512, 512, 1, "backbone.layer4.1.", ibn_layers[ibn][14]); x = distill_basicBlock_ibn(network, weightMap, *x->getOutput(0), 512, 512, 1, "backbone.layer4.1.", ibn_layers[ibn][14]);
x = distill_basicBlock_ibn(network, weightMap, *x->getOutput(0), 512, 512, 1, "backbone.layer4.2.", ibn_layers[ibn][15]); x = distill_basicBlock_ibn(network, weightMap, *x->getOutput(0), 512, 512, 1, "backbone.layer4.2.", ibn_layers[ibn][15]);
@ -63,9 +63,9 @@ namespace fastrt {
return relu2; return relu2;
} }
ILayer* backbone_sbsR50_distill::topology(INetworkDefinition *network, std::map<std::string, Weights>& weightMap, ITensor& input, const FastreidConfig& reidCfg) { ILayer* backbone_sbsR50_distill::topology(INetworkDefinition *network, std::map<std::string, Weights>& weightMap, ITensor& input) {
std::string ibn{""}; std::string ibn{""};
if(reidCfg.with_ibna) { if(_modelCfg.with_ibna) {
ibn = "a"; ibn = "a";
} }
std::map<std::string, std::vector<std::string>> ibn_layers{ std::map<std::string, std::vector<std::string>> ibn_layers{
@ -102,12 +102,12 @@ namespace fastrt {
x = distill_bottleneck_ibn(network, weightMap, *x->getOutput(0), 512, 128, 1, "backbone.layer2.1.", ibn_layers[ibn][4]); x = distill_bottleneck_ibn(network, weightMap, *x->getOutput(0), 512, 128, 1, "backbone.layer2.1.", ibn_layers[ibn][4]);
x = distill_bottleneck_ibn(network, weightMap, *x->getOutput(0), 512, 128, 1, "backbone.layer2.2.", ibn_layers[ibn][5]); x = distill_bottleneck_ibn(network, weightMap, *x->getOutput(0), 512, 128, 1, "backbone.layer2.2.", ibn_layers[ibn][5]);
ILayer* _layer{x}; ILayer* _layer{x};
if(reidCfg.with_nl) { if(_modelCfg.with_nl) {
_layer = Non_local(network, weightMap, *x->getOutput(0), "backbone.NL_2.0."); _layer = Non_local(network, weightMap, *x->getOutput(0), "backbone.NL_2.0.");
} }
x = distill_bottleneck_ibn(network, weightMap, *_layer->getOutput(0), 512, 128, 1, "backbone.layer2.3.", ibn_layers[ibn][6]); x = distill_bottleneck_ibn(network, weightMap, *_layer->getOutput(0), 512, 128, 1, "backbone.layer2.3.", ibn_layers[ibn][6]);
_layer = x; _layer = x;
if(reidCfg.with_nl) { if(_modelCfg.with_nl) {
_layer = Non_local(network, weightMap, *x->getOutput(0), "backbone.NL_2.1."); _layer = Non_local(network, weightMap, *x->getOutput(0), "backbone.NL_2.1.");
} }
@ -116,21 +116,21 @@ namespace fastrt {
x = distill_bottleneck_ibn(network, weightMap, *x->getOutput(0), 1024, 256, 1, "backbone.layer3.2.", ibn_layers[ibn][9]); x = distill_bottleneck_ibn(network, weightMap, *x->getOutput(0), 1024, 256, 1, "backbone.layer3.2.", ibn_layers[ibn][9]);
x = distill_bottleneck_ibn(network, weightMap, *x->getOutput(0), 1024, 256, 1, "backbone.layer3.3.", ibn_layers[ibn][10]); x = distill_bottleneck_ibn(network, weightMap, *x->getOutput(0), 1024, 256, 1, "backbone.layer3.3.", ibn_layers[ibn][10]);
_layer = x; _layer = x;
if(reidCfg.with_nl) { if(_modelCfg.with_nl) {
_layer = Non_local(network, weightMap, *x->getOutput(0), "backbone.NL_3.0."); _layer = Non_local(network, weightMap, *x->getOutput(0), "backbone.NL_3.0.");
} }
x = distill_bottleneck_ibn(network, weightMap, *_layer->getOutput(0), 1024, 256, 1, "backbone.layer3.4.", ibn_layers[ibn][11]); x = distill_bottleneck_ibn(network, weightMap, *_layer->getOutput(0), 1024, 256, 1, "backbone.layer3.4.", ibn_layers[ibn][11]);
_layer = x; _layer = x;
if(reidCfg.with_nl) { if(_modelCfg.with_nl) {
_layer = Non_local(network, weightMap, *x->getOutput(0), "backbone.NL_3.1."); _layer = Non_local(network, weightMap, *x->getOutput(0), "backbone.NL_3.1.");
} }
x = distill_bottleneck_ibn(network, weightMap, *_layer->getOutput(0), 1024, 256, 1, "backbone.layer3.5.", ibn_layers[ibn][12]); x = distill_bottleneck_ibn(network, weightMap, *_layer->getOutput(0), 1024, 256, 1, "backbone.layer3.5.", ibn_layers[ibn][12]);
_layer = x; _layer = x;
if(reidCfg.with_nl) { if(_modelCfg.with_nl) {
_layer = Non_local(network, weightMap, *x->getOutput(0), "backbone.NL_3.2."); _layer = Non_local(network, weightMap, *x->getOutput(0), "backbone.NL_3.2.");
} }
x = distill_bottleneck_ibn(network, weightMap, *_layer->getOutput(0), 1024, 512, reidCfg.last_stride, "backbone.layer4.0.", ibn_layers[ibn][13]); x = distill_bottleneck_ibn(network, weightMap, *_layer->getOutput(0), 1024, 512, _modelCfg.last_stride, "backbone.layer4.0.", ibn_layers[ibn][13]);
x = distill_bottleneck_ibn(network, weightMap, *x->getOutput(0), 2048, 512, 1, "backbone.layer4.1.", ibn_layers[ibn][14]); x = distill_bottleneck_ibn(network, weightMap, *x->getOutput(0), 2048, 512, 1, "backbone.layer4.1.", ibn_layers[ibn][14]);
x = distill_bottleneck_ibn(network, weightMap, *x->getOutput(0), 2048, 512, 1, "backbone.layer4.2.", ibn_layers[ibn][15]); x = distill_bottleneck_ibn(network, weightMap, *x->getOutput(0), 2048, 512, 1, "backbone.layer4.2.", ibn_layers[ibn][15]);
@ -139,9 +139,9 @@ namespace fastrt {
return relu2; return relu2;
} }
ILayer* backbone_sbsR34::topology(INetworkDefinition *network, std::map<std::string, Weights>& weightMap, ITensor& input, const FastreidConfig& reidCfg) { ILayer* backbone_sbsR34::topology(INetworkDefinition *network, std::map<std::string, Weights>& weightMap, ITensor& input) {
std::string ibn{""}; std::string ibn{""};
if(reidCfg.with_ibna) { if(_modelCfg.with_ibna) {
ibn = "a"; ibn = "a";
} }
std::map<std::string, std::vector<std::string>> ibn_layers{ std::map<std::string, std::vector<std::string>> ibn_layers{
@ -186,13 +186,13 @@ namespace fastrt {
x = basicBlock_ibn(network, weightMap, *x->getOutput(0), 256, 256, 1, "backbone.layer3.4.", ibn_layers[ibn][11]); x = basicBlock_ibn(network, weightMap, *x->getOutput(0), 256, 256, 1, "backbone.layer3.4.", ibn_layers[ibn][11]);
x = basicBlock_ibn(network, weightMap, *x->getOutput(0), 256, 256, 1, "backbone.layer3.5.", ibn_layers[ibn][12]); x = basicBlock_ibn(network, weightMap, *x->getOutput(0), 256, 256, 1, "backbone.layer3.5.", ibn_layers[ibn][12]);
x = basicBlock_ibn(network, weightMap, *x->getOutput(0), 256, 512, reidCfg.last_stride, "backbone.layer4.0.", ibn_layers[ibn][13]); x = basicBlock_ibn(network, weightMap, *x->getOutput(0), 256, 512, _modelCfg.last_stride, "backbone.layer4.0.", ibn_layers[ibn][13]);
x = basicBlock_ibn(network, weightMap, *x->getOutput(0), 512, 512, 1, "backbone.layer4.1.", ibn_layers[ibn][14]); x = basicBlock_ibn(network, weightMap, *x->getOutput(0), 512, 512, 1, "backbone.layer4.1.", ibn_layers[ibn][14]);
x = basicBlock_ibn(network, weightMap, *x->getOutput(0), 512, 512, 1, "backbone.layer4.2.", ibn_layers[ibn][15]); x = basicBlock_ibn(network, weightMap, *x->getOutput(0), 512, 512, 1, "backbone.layer4.2.", ibn_layers[ibn][15]);
return x; return x;
} }
ILayer* backbone_sbsR50::topology(INetworkDefinition *network, std::map<std::string, Weights>& weightMap, ITensor& input, const FastreidConfig& reidCfg) { ILayer* backbone_sbsR50::topology(INetworkDefinition *network, std::map<std::string, Weights>& weightMap, ITensor& input) {
/* /*
* Reference: https://github.com/JDAI-CV/fast-reid/blob/master/fastreid/modeling/backbones/resnet.py * Reference: https://github.com/JDAI-CV/fast-reid/blob/master/fastreid/modeling/backbones/resnet.py
* NL layers follow by: nl_layers_per_stage = {'50x': [0, 2, 3, 0],}[depth] * NL layers follow by: nl_layers_per_stage = {'50x': [0, 2, 3, 0],}[depth]
@ -200,7 +200,7 @@ namespace fastrt {
* for nn.MaxPool2d(kernel_size=3, stride=2, padding=1) replace with => pool1->setPaddingNd(DimsHW{1, 1}); * for nn.MaxPool2d(kernel_size=3, stride=2, padding=1) replace with => pool1->setPaddingNd(DimsHW{1, 1});
*/ */
std::string ibn{""}; std::string ibn{""};
if(reidCfg.with_ibna) { if(_modelCfg.with_ibna) {
ibn = "a"; ibn = "a";
} }
std::map<std::string, std::vector<std::string>> ibn_layers{ std::map<std::string, std::vector<std::string>> ibn_layers{
@ -236,12 +236,12 @@ namespace fastrt {
x = bottleneck_ibn(network, weightMap, *x->getOutput(0), 512, 128, 1, "backbone.layer2.1.", ibn_layers[ibn][4]); x = bottleneck_ibn(network, weightMap, *x->getOutput(0), 512, 128, 1, "backbone.layer2.1.", ibn_layers[ibn][4]);
x = bottleneck_ibn(network, weightMap, *x->getOutput(0), 512, 128, 1, "backbone.layer2.2.", ibn_layers[ibn][5]); x = bottleneck_ibn(network, weightMap, *x->getOutput(0), 512, 128, 1, "backbone.layer2.2.", ibn_layers[ibn][5]);
ILayer* _layer{x}; ILayer* _layer{x};
if(reidCfg.with_nl) { if(_modelCfg.with_nl) {
_layer = Non_local(network, weightMap, *x->getOutput(0), "backbone.NL_2.0."); _layer = Non_local(network, weightMap, *x->getOutput(0), "backbone.NL_2.0.");
} }
x = bottleneck_ibn(network, weightMap, *_layer->getOutput(0), 512, 128, 1, "backbone.layer2.3.", ibn_layers[ibn][6]); x = bottleneck_ibn(network, weightMap, *_layer->getOutput(0), 512, 128, 1, "backbone.layer2.3.", ibn_layers[ibn][6]);
_layer = x; _layer = x;
if(reidCfg.with_nl) { if(_modelCfg.with_nl) {
_layer = Non_local(network, weightMap, *x->getOutput(0), "backbone.NL_2.1."); _layer = Non_local(network, weightMap, *x->getOutput(0), "backbone.NL_2.1.");
} }
@ -250,21 +250,21 @@ namespace fastrt {
x = bottleneck_ibn(network, weightMap, *x->getOutput(0), 1024, 256, 1, "backbone.layer3.2.", ibn_layers[ibn][9]); x = bottleneck_ibn(network, weightMap, *x->getOutput(0), 1024, 256, 1, "backbone.layer3.2.", ibn_layers[ibn][9]);
x = bottleneck_ibn(network, weightMap, *x->getOutput(0), 1024, 256, 1, "backbone.layer3.3.", ibn_layers[ibn][10]); x = bottleneck_ibn(network, weightMap, *x->getOutput(0), 1024, 256, 1, "backbone.layer3.3.", ibn_layers[ibn][10]);
_layer = x; _layer = x;
if(reidCfg.with_nl) { if(_modelCfg.with_nl) {
_layer = Non_local(network, weightMap, *x->getOutput(0), "backbone.NL_3.0."); _layer = Non_local(network, weightMap, *x->getOutput(0), "backbone.NL_3.0.");
} }
x = bottleneck_ibn(network, weightMap, *_layer->getOutput(0), 1024, 256, 1, "backbone.layer3.4.", ibn_layers[ibn][11]); x = bottleneck_ibn(network, weightMap, *_layer->getOutput(0), 1024, 256, 1, "backbone.layer3.4.", ibn_layers[ibn][11]);
_layer = x; _layer = x;
if(reidCfg.with_nl) { if(_modelCfg.with_nl) {
_layer = Non_local(network, weightMap, *x->getOutput(0), "backbone.NL_3.1."); _layer = Non_local(network, weightMap, *x->getOutput(0), "backbone.NL_3.1.");
} }
x = bottleneck_ibn(network, weightMap, *_layer->getOutput(0), 1024, 256, 1, "backbone.layer3.5.", ibn_layers[ibn][12]); x = bottleneck_ibn(network, weightMap, *_layer->getOutput(0), 1024, 256, 1, "backbone.layer3.5.", ibn_layers[ibn][12]);
_layer = x; _layer = x;
if(reidCfg.with_nl) { if(_modelCfg.with_nl) {
_layer = Non_local(network, weightMap, *x->getOutput(0), "backbone.NL_3.2."); _layer = Non_local(network, weightMap, *x->getOutput(0), "backbone.NL_3.2.");
} }
x = bottleneck_ibn(network, weightMap, *_layer->getOutput(0), 1024, 512, reidCfg.last_stride, "backbone.layer4.0.", ibn_layers[ibn][13]); x = bottleneck_ibn(network, weightMap, *_layer->getOutput(0), 1024, 512, _modelCfg.last_stride, "backbone.layer4.0.", ibn_layers[ibn][13]);
x = bottleneck_ibn(network, weightMap, *x->getOutput(0), 2048, 512, 1, "backbone.layer4.1.", ibn_layers[ibn][14]); x = bottleneck_ibn(network, weightMap, *x->getOutput(0), 2048, 512, 1, "backbone.layer4.1.", ibn_layers[ibn][14]);
x = bottleneck_ibn(network, weightMap, *x->getOutput(0), 2048, 512, 1, "backbone.layer4.2.", ibn_layers[ibn][15]); x = bottleneck_ibn(network, weightMap, *x->getOutput(0), 2048, 512, 1, "backbone.layer4.2.", ibn_layers[ibn][15]);
return x; return x;

View File

@ -7,40 +7,40 @@
namespace fastrt { namespace fastrt {
std::unique_ptr<Module> ModuleFactory::createBackbone(const FastreidBackboneType& backbonetype) { std::unique_ptr<Module> ModuleFactory::createBackbone(FastreidConfig& modelCfg) {
switch(backbonetype) { switch(modelCfg.backbone) {
case FastreidBackboneType::r50: case FastreidBackboneType::r50:
/* cfg.MODEL.META_ARCHITECTURE: Baseline */ /* cfg.MODEL.META_ARCHITECTURE: Baseline */
/* cfg.MODEL.BACKBONE.DEPTH: 50x */ /* cfg.MODEL.BACKBONE.DEPTH: 50x */
std::cout << "[createBackboneModule]: backbone_sbsR50" << std::endl; std::cout << "[createBackboneModule]: backbone_sbsR50" << std::endl;
return make_unique<backbone_sbsR50>(); return make_unique<backbone_sbsR50>(modelCfg);
case FastreidBackboneType::r50_distill: case FastreidBackboneType::r50_distill:
/* cfg.MODEL.META_ARCHITECTURE: Distiller */ /* cfg.MODEL.META_ARCHITECTURE: Distiller */
/* cfg.MODEL.BACKBONE.DEPTH: 50x */ /* cfg.MODEL.BACKBONE.DEPTH: 50x */
std::cout << "[createBackboneModule]: backbone_sbsR50_distill" << std::endl; std::cout << "[createBackboneModule]: backbone_sbsR50_distill" << std::endl;
return make_unique<backbone_sbsR50_distill>(); return make_unique<backbone_sbsR50_distill>(modelCfg);
case FastreidBackboneType::r34: case FastreidBackboneType::r34:
/* cfg.MODEL.META_ARCHITECTURE: Baseline */ /* cfg.MODEL.META_ARCHITECTURE: Baseline */
/* cfg.MODEL.BACKBONE.DEPTH: 34x */ /* cfg.MODEL.BACKBONE.DEPTH: 34x */
std::cout << "[createBackboneModule]: backbone_sbsR34" << std::endl; std::cout << "[createBackboneModule]: backbone_sbsR34" << std::endl;
return make_unique<backbone_sbsR34>(); return make_unique<backbone_sbsR34>(modelCfg);
case FastreidBackboneType::r34_distill: case FastreidBackboneType::r34_distill:
/* cfg.MODEL.META_ARCHITECTURE: Distiller */ /* cfg.MODEL.META_ARCHITECTURE: Distiller */
/* cfg.MODEL.BACKBONE.DEPTH: 34x */ /* cfg.MODEL.BACKBONE.DEPTH: 34x */
std::cout << "[createBackboneModule]: backbone_sbsR34_distill" << std::endl; std::cout << "[createBackboneModule]: backbone_sbsR34_distill" << std::endl;
return make_unique<backbone_sbsR34_distill>(); return make_unique<backbone_sbsR34_distill>(modelCfg);
default: default:
std::cerr << "[Backbone is not supported.]" << std::endl; std::cerr << "[Backbone is not supported.]" << std::endl;
return nullptr; return nullptr;
} }
} }
std::unique_ptr<Module> ModuleFactory::createHead(const FastreidHeadType& headtype) { std::unique_ptr<Module> ModuleFactory::createHead(FastreidConfig& modelCfg) {
switch(headtype) { switch(modelCfg.head) {
case FastreidHeadType::EmbeddingHead: case FastreidHeadType::EmbeddingHead:
/* cfg.MODEL.HEADS.NAME: EmbeddingHead */ /* cfg.MODEL.HEADS.NAME: EmbeddingHead */
std::cout << "[createHeadModule]: EmbeddingHead" << std::endl; std::cout << "[createHeadModule]: EmbeddingHead" << std::endl;
return make_unique<embedding_head>(); return make_unique<embedding_head>(modelCfg);
default: default:
std::cerr << "[Head is not supported.]" << std::endl; std::cerr << "[Head is not supported.]" << std::endl;
return nullptr; return nullptr;

View File

@ -5,26 +5,28 @@
namespace fastrt { namespace fastrt {
embedding_head::embedding_head() : _layerFactory(make_unique<LayerFactory>()) {} embedding_head::embedding_head(FastreidConfig& modelCfg) :
_modelCfg(modelCfg), _layerFactory(make_unique<LayerFactory>()) {}
embedding_head::embedding_head(std::unique_ptr<LayerFactory> layerFactory) : _layerFactory(std::move(layerFactory)) {} embedding_head::embedding_head(FastreidConfig& modelCfg,
std::unique_ptr<LayerFactory> layerFactory) : _modelCfg(modelCfg), _layerFactory(std::move(layerFactory)) {}
ILayer* embedding_head::topology(INetworkDefinition *network, std::map<std::string, Weights>& weightMap, ITensor& input, const FastreidConfig& reidCfg) { ILayer* embedding_head::topology(INetworkDefinition *network, std::map<std::string, Weights>& weightMap, ITensor& input) {
/* /*
* Reference: https://github.com/JDAI-CV/fast-reid/blob/master/fastreid/modeling/heads/embedding_head.py * Reference: https://github.com/JDAI-CV/fast-reid/blob/master/fastreid/modeling/heads/embedding_head.py
*/ */
ILayer* pooling = _layerFactory->createPoolingLayer(reidCfg.pooling)->addPooling(network, weightMap, input); ILayer* pooling = _layerFactory->createPoolingLayer(_modelCfg.pooling)->addPooling(network, weightMap, input);
TRTASSERT(pooling); TRTASSERT(pooling);
// Hint: It's used to be "heads.bnneck.0" before Sep 10, 2020. (JDAI-CV/fast-reid) // Hint: It's used to be "heads.bnneck.0" before Sep 10, 2020. (JDAI-CV/fast-reid)
std::string bnneck_lname = "heads.bottleneck.0"; std::string bnneck_lname = "heads.bottleneck.0";
ILayer* reduction_neck{pooling}; ILayer* reduction_neck{pooling};
if(reidCfg.embedding_dim > 0) { if(_modelCfg.embedding_dim > 0) {
Weights emptywts{DataType::kFLOAT, nullptr, 0}; Weights emptywts{DataType::kFLOAT, nullptr, 0};
reduction_neck = network->addConvolutionNd(*pooling->getOutput(0), reduction_neck = network->addConvolutionNd(*pooling->getOutput(0),
reidCfg.embedding_dim, _modelCfg.embedding_dim,
DimsHW{1, 1}, DimsHW{1, 1},
weightMap["heads.bottleneck.0.weight"], weightMap["heads.bottleneck.0.weight"],
emptywts); emptywts);

View File

@ -14,20 +14,26 @@ namespace trtxapi {
return clip; return clip;
} }
ITensor* addDiv255(INetworkDefinition* network, std::map<std::string, Weights>& weightMap, ITensor* input, const std::string lname) {
Weights Div_225{ DataType::kFLOAT, nullptr, 3 };
float *wgt = reinterpret_cast<float*>(malloc(sizeof(float) * 3));
std::fill_n(wgt, 3, 255.0f);
Div_225.values = wgt;
weightMap[lname + ".div"] = Div_225;
IConstantLayer* d = network->addConstant(Dims3{ 3, 1, 1 }, Div_225);
IElementWiseLayer* div255 = network->addElementWise(*input, *d->getOutput(0), ElementWiseOperation::kDIV);
return div255->getOutput(0);
}
ITensor* addMeanStd(INetworkDefinition* network, std::map<std::string, Weights>& weightMap, ITensor* input, const std::string lname, const float* mean, const float* std, const bool div255) { ITensor* addMeanStd(INetworkDefinition* network, std::map<std::string, Weights>& weightMap, ITensor* input, const std::string lname, const float* mean, const float* std, const bool div255) {
ITensor* tensor_holder{input};
if (div255) { if (div255) {
Weights Div_225{ DataType::kFLOAT, nullptr, 3 }; tensor_holder = addDiv255(network, weightMap, input, lname);
float *wgt = reinterpret_cast<float*>(malloc(sizeof(float) * 3));
std::fill_n(wgt, 3, 255.0f);
Div_225.values = wgt;
weightMap[lname + ".div"] = Div_225;
IConstantLayer* d = network->addConstant(Dims3{ 3, 1, 1 }, Div_225);
input = network->addElementWise(*input, *d->getOutput(0), ElementWiseOperation::kDIV)->getOutput(0);
} }
Weights Mean{ DataType::kFLOAT, nullptr, 3 }; Weights Mean{ DataType::kFLOAT, nullptr, 3 };
Mean.values = mean; Mean.values = mean;
IConstantLayer* m = network->addConstant(Dims3{ 3, 1, 1 }, Mean); IConstantLayer* m = network->addConstant(Dims3{ 3, 1, 1 }, Mean);
IElementWiseLayer* sub_mean = network->addElementWise(*input, *m->getOutput(0), ElementWiseOperation::kSUB); IElementWiseLayer* sub_mean = network->addElementWise(*tensor_holder, *m->getOutput(0), ElementWiseOperation::kSUB);
if (std != nullptr) { if (std != nullptr) {
Weights Std{ DataType::kFLOAT, nullptr, 3 }; Weights Std{ DataType::kFLOAT, nullptr, 3 };
Std.values = std; Std.values = std;

View File

@ -3,7 +3,8 @@
namespace fastrt { namespace fastrt {
Baseline::Baseline(const trt::ModelConfig &modelcfg, const FastreidConfig& reidcfg) : Model(modelcfg, reidcfg) {} Baseline::Baseline(const trt::ModelConfig &modelcfg, const std::string input_name, const std::string output_name)
: Model(modelcfg, input_name, output_name) {}
void Baseline::preprocessing_cpu(const cv::Mat& img, float* const data, const std::size_t stride) { void Baseline::preprocessing_cpu(const cv::Mat& img, float* const data, const std::size_t stride) {
/* Normalization & BGR->RGB */ /* Normalization & BGR->RGB */

View File

@ -2,8 +2,7 @@
namespace fastrt { namespace fastrt {
Model::Model(const trt::ModelConfig &modelcfg, const FastreidConfig &reidcfg, const std::string input_name, const std::string output_name) : Model::Model(const trt::ModelConfig &modelcfg, const std::string input_name, const std::string output_name) {
_reidcfg(reidcfg) {
_engineCfg.weights_path = modelcfg.weights_path; _engineCfg.weights_path = modelcfg.weights_path;
_engineCfg.max_batch_size = modelcfg.max_batch_size; _engineCfg.max_batch_size = modelcfg.max_batch_size;
@ -57,7 +56,7 @@ namespace fastrt {
/* Modeling */ /* Modeling */
ILayer* output{nullptr}; ILayer* output{nullptr};
for(auto& sequential_module: modules) { for(auto& sequential_module: modules) {
output = sequential_module->topology(network.get(), weightMap, *input, _reidcfg); output = sequential_module->topology(network.get(), weightMap, *input);
TRTASSERT(output); TRTASSERT(output);
input = output->getOutput(0); input = output->getOutput(0);
} }

View File

@ -10,7 +10,9 @@ namespace fastrt {
class Baseline : public Model { class Baseline : public Model {
public: public:
Baseline(const trt::ModelConfig &modelcfg, const FastreidConfig& reidcfg); Baseline(const trt::ModelConfig &modelcfg,
const std::string input_name = "data",
const std::string output_name = "reid_embd");
~Baseline() = default; ~Baseline() = default;
private: private:
@ -18,8 +20,5 @@ namespace fastrt {
ITensor* preprocessing_gpu(INetworkDefinition* network, ITensor* preprocessing_gpu(INetworkDefinition* network,
std::map<std::string, Weights>& weightMap, std::map<std::string, Weights>& weightMap,
ITensor* input); ITensor* input);
private:
std::string _input_name{"data"};
std::string _output_name{"embd"};
}; };
} }

View File

@ -11,17 +11,17 @@ namespace fastrt {
class embedding_head : public Module { class embedding_head : public Module {
private: private:
FastreidConfig& _modelCfg;
std::unique_ptr<LayerFactory> _layerFactory; std::unique_ptr<LayerFactory> _layerFactory;
public: public:
embedding_head(); embedding_head(FastreidConfig& modelCfg);
embedding_head(std::unique_ptr<LayerFactory> layerFactory); embedding_head(FastreidConfig& modelCfg, std::unique_ptr<LayerFactory> layerFactory);
~embedding_head() = default; ~embedding_head() = default;
ILayer* topology(INetworkDefinition *network, ILayer* topology(INetworkDefinition *network,
std::map<std::string, Weights>& weightMap, std::map<std::string, Weights>& weightMap,
ITensor& input, ITensor& input) override;
const FastreidConfig& reidCfg) override;
}; };
} }

View File

@ -11,8 +11,8 @@ namespace fastrt {
ModuleFactory() = default; ModuleFactory() = default;
~ModuleFactory() = default; ~ModuleFactory() = default;
std::unique_ptr<Module> createBackbone(const FastreidBackboneType& backbonetype); std::unique_ptr<Module> createBackbone(FastreidConfig& modelCfg);
std::unique_ptr<Module> createHead(const FastreidHeadType& headtype); std::unique_ptr<Module> createHead(FastreidConfig& modelCfg);
}; };
class LayerFactory { class LayerFactory {

View File

@ -13,6 +13,11 @@ namespace trtxapi {
ITensor& input, ITensor& input,
const float min); const float min);
ITensor* addDiv255(INetworkDefinition* network,
std::map<std::string, Weights>& weightMap,
ITensor* input,
const std::string lname);
ITensor* addMeanStd(INetworkDefinition* network, ITensor* addMeanStd(INetworkDefinition* network,
std::map<std::string, Weights>& weightMap, std::map<std::string, Weights>& weightMap,
ITensor* input, ITensor* input,

View File

@ -19,7 +19,6 @@ namespace fastrt {
class Model { class Model {
public: public:
Model(const trt::ModelConfig &modelcfg, Model(const trt::ModelConfig &modelcfg,
const FastreidConfig &reidcfg,
const std::string input_name="input", const std::string input_name="input",
const std::string output_name="output"); const std::string output_name="output");
@ -67,7 +66,6 @@ namespace fastrt {
private: private:
DataType _dt{DataType::kFLOAT}; DataType _dt{DataType::kFLOAT};
trt::EngineConfig _engineCfg; trt::EngineConfig _engineCfg;
FastreidConfig _reidcfg;
std::unique_ptr<trt::InferenceEngine> _inferEngine{nullptr}; std::unique_ptr<trt::InferenceEngine> _inferEngine{nullptr};
}; };
} }

View File

@ -14,8 +14,7 @@ namespace fastrt {
virtual ILayer* topology(INetworkDefinition *network, virtual ILayer* topology(INetworkDefinition *network,
std::map<std::string, Weights>& weightMap, std::map<std::string, Weights>& weightMap,
ITensor& input, ITensor& input) = 0;
const FastreidConfig& reidCfg) = 0;
}; };
} }

View File

@ -9,43 +9,47 @@ using namespace nvinfer1;
namespace fastrt { namespace fastrt {
class backbone_sbsR34_distill : public Module { class backbone_sbsR34_distill : public Module {
private:
FastreidConfig& _modelCfg;
public: public:
backbone_sbsR34_distill() = default; backbone_sbsR34_distill(FastreidConfig& modelCfg) : _modelCfg(modelCfg) {}
~backbone_sbsR34_distill() = default; ~backbone_sbsR34_distill() = default;
ILayer* topology(INetworkDefinition *network, ILayer* topology(INetworkDefinition *network,
std::map<std::string, Weights>& weightMap, std::map<std::string, Weights>& weightMap,
ITensor& input, ITensor& input) override;
const FastreidConfig& reidCfg) override;
}; };
class backbone_sbsR50_distill : public Module { class backbone_sbsR50_distill : public Module {
private:
FastreidConfig& _modelCfg;
public: public:
backbone_sbsR50_distill() = default; backbone_sbsR50_distill(FastreidConfig& modelCfg) : _modelCfg(modelCfg) {}
~backbone_sbsR50_distill() = default; ~backbone_sbsR50_distill() = default;
ILayer* topology(INetworkDefinition *network, ILayer* topology(INetworkDefinition *network,
std::map<std::string, Weights>& weightMap, std::map<std::string, Weights>& weightMap,
ITensor& input, ITensor& input) override;
const FastreidConfig& reidCfg) override;
}; };
class backbone_sbsR34 : public Module { class backbone_sbsR34 : public Module {
private:
FastreidConfig& _modelCfg;
public: public:
backbone_sbsR34() = default; backbone_sbsR34(FastreidConfig& modelCfg) : _modelCfg(modelCfg) {}
~backbone_sbsR34() = default; ~backbone_sbsR34() = default;
ILayer* topology(INetworkDefinition *network, ILayer* topology(INetworkDefinition *network,
std::map<std::string, Weights>& weightMap, std::map<std::string, Weights>& weightMap,
ITensor& input, ITensor& input) override;
const FastreidConfig& reidCfg) override;
}; };
class backbone_sbsR50 : public Module { class backbone_sbsR50 : public Module {
private:
FastreidConfig& _modelCfg;
public: public:
backbone_sbsR50() = default; backbone_sbsR50(FastreidConfig& modelCfg) : _modelCfg(modelCfg) {}
~backbone_sbsR50() = default; ~backbone_sbsR50() = default;
ILayer* topology(INetworkDefinition *network, ILayer* topology(INetworkDefinition *network,
std::map<std::string, Weights>& weightMap, std::map<std::string, Weights>& weightMap,
ITensor& input, ITensor& input) override;
const FastreidConfig& reidCfg) override;
}; };
} }