[Fix] fix ncnn torch 1.12 dev-1.x (#1431)
* fix ncnn torch 1.12 dev-1.x * remove debug line * fix typo * add docstringpull/1397/head^2
parent
f61b6008e3
commit
2f2ddb3572
|
@ -1,6 +1,34 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
#include "fuse_pass.h"
|
||||
|
||||
void fuse_identity(onnx::GraphProto* mutable_graph,
|
||||
std::map<std::string, onnx::TensorProto>& weights,
|
||||
std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
|
||||
int& reduced_node_count) {
|
||||
// fuse
|
||||
// identity --> op
|
||||
// to
|
||||
// noop_reducencnn --> op
|
||||
const int node_count = mutable_graph->node_size();
|
||||
for (int i = 0; i < node_count; ++i) {
|
||||
onnx::NodeProto* node = mutable_graph->mutable_node(i);
|
||||
for (int j = 0; j < node->input_size(); ++j) {
|
||||
std::string output_name = node->input(j);
|
||||
onnx::NodeProto* last_node = find_node_by_output_name(mutable_graph, output_name);
|
||||
if (last_node && last_node->op_type() == "Identity") {
|
||||
node->set_input(j, last_node->input(0));
|
||||
node_reference[last_node->output(0)] -= 1;
|
||||
node_reference[last_node->input(0)] += 1;
|
||||
if (node_reference[last_node->output(0)] == 0) {
|
||||
last_node->set_op_type("noop_reducedncnn");
|
||||
node_reference[last_node->input(0)] -= 1;
|
||||
reduced_node_count += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void fuse_rewrite_gather(onnx::GraphProto* mutable_graph,
|
||||
std::map<std::string, onnx::TensorProto>& weights,
|
||||
std::map<std::string, int>& node_reference,
|
||||
|
|
|
@ -4,6 +4,11 @@
|
|||
#include "shape_inference.h"
|
||||
#include "utils.h"
|
||||
|
||||
void fuse_identity(onnx::GraphProto* mutable_graph,
|
||||
std::map<std::string, onnx::TensorProto>& weights,
|
||||
std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
|
||||
int& reduced_node_count);
|
||||
|
||||
void fuse_rewrite_gather(onnx::GraphProto* mutable_graph,
|
||||
std::map<std::string, onnx::TensorProto>& weights,
|
||||
std::map<std::string, int>& node_reference,
|
||||
|
|
|
@ -206,6 +206,7 @@ int main(int argc, char** argv) {
|
|||
// op chain fusion
|
||||
int reduced_node_count = 0;
|
||||
{
|
||||
fuse_identity(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
|
||||
fuse_conv_reshape(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
|
||||
fuse_weight_reshape(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
|
||||
fuse_weight_transpose(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
|
||||
|
|
Loading…
Reference in New Issue