diff --git a/projects/FastRT/fastrt/layers/layers.cpp b/projects/FastRT/fastrt/layers/layers.cpp index 236e054..3f6a467 100644 --- a/projects/FastRT/fastrt/layers/layers.cpp +++ b/projects/FastRT/fastrt/layers/layers.cpp @@ -14,20 +14,26 @@ namespace trtxapi { return clip; } + /* addDiv255: appends an element-wise kDIV layer that divides `input` by a Dims3{3, 1, 1} constant filled with 255.0f and returns that layer's output tensor. The constant's buffer is also stored in `weightMap` under the key lname + ".div". NOTE(review): the local is named Div_225 although it holds 255. */ ITensor* addDiv255(INetworkDefinition* network, std::map& weightMap, ITensor* input, const std::string lname) { + Weights Div_225{ DataType::kFLOAT, nullptr, 3 }; + /* NOTE(review): the malloc result is used without a null check. */ float *wgt = reinterpret_cast(malloc(sizeof(float) * 3)); + std::fill_n(wgt, 3, 255.0f); + Div_225.values = wgt; + /* The malloc'ed buffer is not freed here; presumably whoever owns weightMap releases it later - TODO confirm. */ weightMap[lname + ".div"] = Div_225; + IConstantLayer* d = network->addConstant(Dims3{ 3, 1, 1 }, Div_225); + IElementWiseLayer* div255 = network->addElementWise(*input, *d->getOutput(0), ElementWiseOperation::kDIV); + return div255->getOutput(0); + } + ITensor* addMeanStd(INetworkDefinition* network, std::map& weightMap, ITensor* input, const std::string lname, const float* mean, const float* std, const bool div255) { + /* tensor_holder is the tensor fed to the mean subtraction: `input` itself, or the addDiv255 output when div255 is set; the `input` parameter is no longer reassigned. */ ITensor* tensor_holder{input}; if (div255) { - Weights Div_225{ DataType::kFLOAT, nullptr, 3 }; - float *wgt = reinterpret_cast(malloc(sizeof(float) * 3)); - std::fill_n(wgt, 3, 255.0f); - Div_225.values = wgt; - weightMap[lname + ".div"] = Div_225; - IConstantLayer* d = network->addConstant(Dims3{ 3, 1, 1 }, Div_225); - input = network->addElementWise(*input, *d->getOutput(0), ElementWiseOperation::kDIV)->getOutput(0); + tensor_holder = addDiv255(network, weightMap, input, lname); } Weights Mean{ DataType::kFLOAT, nullptr, 3 }; Mean.values = mean; IConstantLayer* m = network->addConstant(Dims3{ 3, 1, 1 }, Mean); - IElementWiseLayer* sub_mean = network->addElementWise(*input, *m->getOutput(0), ElementWiseOperation::kSUB); + IElementWiseLayer* sub_mean = network->addElementWise(*tensor_holder, *m->getOutput(0), ElementWiseOperation::kSUB); if (std != nullptr) { Weights Std{ DataType::kFLOAT, nullptr, 3 }; Std.values = std; diff --git a/projects/FastRT/include/fastrt/layers.h 
b/projects/FastRT/include/fastrt/layers.h index 57a9e1d..dee21a0 100644 --- a/projects/FastRT/include/fastrt/layers.h +++ b/projects/FastRT/include/fastrt/layers.h @@ -13,6 +13,11 @@ namespace trtxapi { ITensor& input, const float min); + /* Adds a layer dividing `input` element-wise by a three-element constant of 255.0f and returns the resulting tensor; the constant's weights are registered in `weightMap` under lname + ".div". */ ITensor* addDiv255(INetworkDefinition* network, + std::map& weightMap, + ITensor* input, + const std::string lname); + ITensor* addMeanStd(INetworkDefinition* network, std::map& weightMap, ITensor* input,