add trt api div255

pull/416/head
darrenhsieh 2021-02-27 16:45:29 +08:00
parent 69eb044b81
commit b9bda486f0
2 changed files with 19 additions and 8 deletions
projects/FastRT
fastrt/layers
include/fastrt

View File

@ -14,20 +14,26 @@ namespace trtxapi {
return clip;
}
ITensor* addDiv255(INetworkDefinition* network, std::map<std::string, Weights>& weightMap, ITensor* input, const std::string lname) {
Weights Div_225{ DataType::kFLOAT, nullptr, 3 };
float *wgt = reinterpret_cast<float*>(malloc(sizeof(float) * 3));
std::fill_n(wgt, 3, 255.0f);
Div_225.values = wgt;
weightMap[lname + ".div"] = Div_225;
IConstantLayer* d = network->addConstant(Dims3{ 3, 1, 1 }, Div_225);
IElementWiseLayer* div255 = network->addElementWise(*input, *d->getOutput(0), ElementWiseOperation::kDIV);
return div255->getOutput(0);
}
ITensor* addMeanStd(INetworkDefinition* network, std::map<std::string, Weights>& weightMap, ITensor* input, const std::string lname, const float* mean, const float* std, const bool div255) {
ITensor* tensor_holder{input};
if (div255) {
Weights Div_225{ DataType::kFLOAT, nullptr, 3 };
float *wgt = reinterpret_cast<float*>(malloc(sizeof(float) * 3));
std::fill_n(wgt, 3, 255.0f);
Div_225.values = wgt;
weightMap[lname + ".div"] = Div_225;
IConstantLayer* d = network->addConstant(Dims3{ 3, 1, 1 }, Div_225);
input = network->addElementWise(*input, *d->getOutput(0), ElementWiseOperation::kDIV)->getOutput(0);
tensor_holder = addDiv255(network, weightMap, input, lname);
}
Weights Mean{ DataType::kFLOAT, nullptr, 3 };
Mean.values = mean;
IConstantLayer* m = network->addConstant(Dims3{ 3, 1, 1 }, Mean);
IElementWiseLayer* sub_mean = network->addElementWise(*input, *m->getOutput(0), ElementWiseOperation::kSUB);
IElementWiseLayer* sub_mean = network->addElementWise(*tensor_holder, *m->getOutput(0), ElementWiseOperation::kSUB);
if (std != nullptr) {
Weights Std{ DataType::kFLOAT, nullptr, 3 };
Std.values = std;

View File

@ -13,6 +13,11 @@ namespace trtxapi {
ITensor& input,
const float min);
ITensor* addDiv255(INetworkDefinition* network,
std::map<std::string, Weights>& weightMap,
ITensor* input,
const std::string lname);
ITensor* addMeanStd(INetworkDefinition* network,
std::map<std::string, Weights>& weightMap,
ITensor* input,