[Fix] fix ncnn torch 1.12 dev-1.x (#1431)

* fix ncnn torch 1.12 dev-1.x

* remove debug line

* fix typo

* add docstring
pull/1397/head^2
hanrui1sensetime 2022-11-28 10:38:53 +08:00 committed by GitHub
parent f61b6008e3
commit 2f2ddb3572
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 34 additions and 0 deletions

View File

@ -1,6 +1,34 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include "fuse_pass.h"
void fuse_identity(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
int& reduced_node_count) {
// fuse
// identity --> op
// to
// noop_reducencnn --> op
const int node_count = mutable_graph->node_size();
for (int i = 0; i < node_count; ++i) {
onnx::NodeProto* node = mutable_graph->mutable_node(i);
for (int j = 0; j < node->input_size(); ++j) {
std::string output_name = node->input(j);
onnx::NodeProto* last_node = find_node_by_output_name(mutable_graph, output_name);
if (last_node && last_node->op_type() == "Identity") {
node->set_input(j, last_node->input(0));
node_reference[last_node->output(0)] -= 1;
node_reference[last_node->input(0)] += 1;
if (node_reference[last_node->output(0)] == 0) {
last_node->set_op_type("noop_reducedncnn");
node_reference[last_node->input(0)] -= 1;
reduced_node_count += 1;
}
}
}
}
}
void fuse_rewrite_gather(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,

View File

@ -4,6 +4,11 @@
#include "shape_inference.h"
#include "utils.h"
void fuse_identity(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
int& reduced_node_count);
void fuse_rewrite_gather(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,

View File

@ -206,6 +206,7 @@ int main(int argc, char** argv) {
// op chain fusion
int reduced_node_count = 0;
{
fuse_identity(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
fuse_conv_reshape(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
fuse_weight_reshape(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
fuse_weight_transpose(mutable_graph, weights, node_reference, blob_names, reduced_node_count);