mirror of https://github.com/JDAI-CV/fast-reid.git
add trt api div255
parent
69eb044b81
commit
b9bda486f0
projects/FastRT
fastrt/layers
include/fastrt
|
@ -14,20 +14,26 @@ namespace trtxapi {
|
|||
return clip;
|
||||
}
|
||||
|
||||
ITensor* addDiv255(INetworkDefinition* network, std::map<std::string, Weights>& weightMap, ITensor* input, const std::string lname) {
|
||||
Weights Div_225{ DataType::kFLOAT, nullptr, 3 };
|
||||
float *wgt = reinterpret_cast<float*>(malloc(sizeof(float) * 3));
|
||||
std::fill_n(wgt, 3, 255.0f);
|
||||
Div_225.values = wgt;
|
||||
weightMap[lname + ".div"] = Div_225;
|
||||
IConstantLayer* d = network->addConstant(Dims3{ 3, 1, 1 }, Div_225);
|
||||
IElementWiseLayer* div255 = network->addElementWise(*input, *d->getOutput(0), ElementWiseOperation::kDIV);
|
||||
return div255->getOutput(0);
|
||||
}
|
||||
|
||||
ITensor* addMeanStd(INetworkDefinition* network, std::map<std::string, Weights>& weightMap, ITensor* input, const std::string lname, const float* mean, const float* std, const bool div255) {
|
||||
ITensor* tensor_holder{input};
|
||||
if (div255) {
|
||||
Weights Div_225{ DataType::kFLOAT, nullptr, 3 };
|
||||
float *wgt = reinterpret_cast<float*>(malloc(sizeof(float) * 3));
|
||||
std::fill_n(wgt, 3, 255.0f);
|
||||
Div_225.values = wgt;
|
||||
weightMap[lname + ".div"] = Div_225;
|
||||
IConstantLayer* d = network->addConstant(Dims3{ 3, 1, 1 }, Div_225);
|
||||
input = network->addElementWise(*input, *d->getOutput(0), ElementWiseOperation::kDIV)->getOutput(0);
|
||||
tensor_holder = addDiv255(network, weightMap, input, lname);
|
||||
}
|
||||
Weights Mean{ DataType::kFLOAT, nullptr, 3 };
|
||||
Mean.values = mean;
|
||||
IConstantLayer* m = network->addConstant(Dims3{ 3, 1, 1 }, Mean);
|
||||
IElementWiseLayer* sub_mean = network->addElementWise(*input, *m->getOutput(0), ElementWiseOperation::kSUB);
|
||||
IElementWiseLayer* sub_mean = network->addElementWise(*tensor_holder, *m->getOutput(0), ElementWiseOperation::kSUB);
|
||||
if (std != nullptr) {
|
||||
Weights Std{ DataType::kFLOAT, nullptr, 3 };
|
||||
Std.values = std;
|
||||
|
|
|
@ -13,6 +13,11 @@ namespace trtxapi {
|
|||
ITensor& input,
|
||||
const float min);
|
||||
|
||||
ITensor* addDiv255(INetworkDefinition* network,
|
||||
std::map<std::string, Weights>& weightMap,
|
||||
ITensor* input,
|
||||
const std::string lname);
|
||||
|
||||
ITensor* addMeanStd(INetworkDefinition* network,
|
||||
std::map<std::string, Weights>& weightMap,
|
||||
ITensor* input,
|
||||
|
|
Loading…
Reference in New Issue