mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
[Fix] Fix onnx2ncnn.cpp bugs. (#1518)
* fix onnx2ncnn * remove unused debugging info
This commit is contained in:
parent
d6fdb3e860
commit
7f2e8f7ce0
@ -44,15 +44,11 @@ int main(int argc, char** argv) {
|
|||||||
fprintf(stderr, "read_proto_from_binary failed\n");
|
fprintf(stderr, "read_proto_from_binary failed\n");
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
FILE* pp = fopen(ncnn_prototxt, "wb");
|
FILE* pp = fopen(ncnn_prototxt, "wb");
|
||||||
FILE* bp = fopen(ncnn_modelbin, "wb");
|
FILE* bp = fopen(ncnn_modelbin, "wb");
|
||||||
|
|
||||||
// magic
|
// magic
|
||||||
fprintf(pp, "7767517\n");
|
fprintf(pp, "7767517\n");
|
||||||
|
|
||||||
onnx::GraphProto* mutable_graph = model.mutable_graph();
|
onnx::GraphProto* mutable_graph = model.mutable_graph();
|
||||||
|
|
||||||
int node_count = mutable_graph->node_size();
|
int node_count = mutable_graph->node_size();
|
||||||
|
|
||||||
// node reference
|
// node reference
|
||||||
@ -60,7 +56,6 @@ int main(int argc, char** argv) {
|
|||||||
|
|
||||||
// weight node and weight reshape node
|
// weight node and weight reshape node
|
||||||
std::map<std::string, onnx::TensorProto> weights;
|
std::map<std::string, onnx::TensorProto> weights;
|
||||||
|
|
||||||
for (int j = 0; j < mutable_graph->initializer_size(); j++) {
|
for (int j = 0; j < mutable_graph->initializer_size(); j++) {
|
||||||
const onnx::TensorProto& initializer = mutable_graph->initializer(j);
|
const onnx::TensorProto& initializer = mutable_graph->initializer(j);
|
||||||
|
|
||||||
@ -69,7 +64,6 @@ int main(int argc, char** argv) {
|
|||||||
|
|
||||||
weights[initializer.name()] = initializer;
|
weights[initializer.name()] = initializer;
|
||||||
}
|
}
|
||||||
|
|
||||||
// topological sort
|
// topological sort
|
||||||
{
|
{
|
||||||
// name -> producer node index
|
// name -> producer node index
|
||||||
@ -138,7 +132,6 @@ int main(int argc, char** argv) {
|
|||||||
*nodeq = tmp;
|
*nodeq = tmp;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// global definition line
|
// global definition line
|
||||||
// [layer count] [blob count]
|
// [layer count] [blob count]
|
||||||
std::set<std::string> blob_names;
|
std::set<std::string> blob_names;
|
||||||
@ -184,7 +177,6 @@ int main(int argc, char** argv) {
|
|||||||
node_reference[output_name] = 0;
|
node_reference[output_name] = 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// include Input node
|
// include Input node
|
||||||
int input_node_count = 0;
|
int input_node_count = 0;
|
||||||
for (int j = 0; j < mutable_graph->input_size(); j++) {
|
for (int j = 0; j < mutable_graph->input_size(); j++) {
|
||||||
@ -232,7 +224,6 @@ int main(int argc, char** argv) {
|
|||||||
reduced_node_count);
|
reduced_node_count);
|
||||||
fuse_rewrite_gather(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
|
fuse_rewrite_gather(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
|
||||||
}
|
}
|
||||||
|
|
||||||
// reduce common const weight node_reference
|
// reduce common const weight node_reference
|
||||||
for (int i = 0; i < node_count; i++) {
|
for (int i = 0; i < node_count; i++) {
|
||||||
const onnx::NodeProto& node = mutable_graph->node(i);
|
const onnx::NodeProto& node = mutable_graph->node(i);
|
||||||
@ -275,10 +266,12 @@ int main(int argc, char** argv) {
|
|||||||
int transB = get_node_attr_i(node, "transB", 0);
|
int transB = get_node_attr_i(node, "transB", 0);
|
||||||
|
|
||||||
if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1) {
|
if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1) {
|
||||||
// InnerProduct-like A * B + C
|
// InnerProduct-like A * B + C, C is optional.
|
||||||
node_reference[node.input(1)] -= 1;
|
node_reference[node.input(1)] -= 1;
|
||||||
|
if (node.input_size() == 3) {
|
||||||
node_reference[node.input(2)] -= 1;
|
node_reference[node.input(2)] -= 1;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
} else if (op == "GroupNorm") {
|
} else if (op == "GroupNorm") {
|
||||||
int affine = get_node_attr_i(node, "affine", 1);
|
int affine = get_node_attr_i(node, "affine", 1);
|
||||||
if (affine) {
|
if (affine) {
|
||||||
@ -530,7 +523,6 @@ int main(int argc, char** argv) {
|
|||||||
|
|
||||||
for (int i = 0; i < node_count; i++) {
|
for (int i = 0; i < node_count; i++) {
|
||||||
const onnx::NodeProto& node = mutable_graph->node(i);
|
const onnx::NodeProto& node = mutable_graph->node(i);
|
||||||
|
|
||||||
const std::string& op = node.op_type();
|
const std::string& op = node.op_type();
|
||||||
|
|
||||||
// fprintf(stderr, "op = %s\n", op.c_str());
|
// fprintf(stderr, "op = %s\n", op.c_str());
|
||||||
@ -1317,17 +1309,23 @@ int main(int argc, char** argv) {
|
|||||||
if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1) {
|
if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1) {
|
||||||
// InnerProduct-like A * B + C
|
// InnerProduct-like A * B + C
|
||||||
const onnx::TensorProto& B = weights[node.input(1)];
|
const onnx::TensorProto& B = weights[node.input(1)];
|
||||||
const onnx::TensorProto& C = weights[node.input(2)];
|
// B has transposed.
|
||||||
|
int num_output = B.dims(0);
|
||||||
fprintf(pp, " 0=%d", get_tensor_proto_data_size(C));
|
fprintf(pp, " 0=%d", num_output);
|
||||||
|
if (node.input_size() == 3) {
|
||||||
fprintf(pp, " 1=1");
|
fprintf(pp, " 1=1");
|
||||||
|
} else {
|
||||||
|
fprintf(pp, " 1=0");
|
||||||
|
}
|
||||||
fprintf(pp, " 2=%d", get_tensor_proto_data_size(B));
|
fprintf(pp, " 2=%d", get_tensor_proto_data_size(B));
|
||||||
|
|
||||||
int quantize_tag = 0;
|
int quantize_tag = 0;
|
||||||
fwrite(&quantize_tag, sizeof(int), 1, bp);
|
fwrite(&quantize_tag, sizeof(int), 1, bp);
|
||||||
|
|
||||||
fwrite_tensor_proto_data(B, bp);
|
fwrite_tensor_proto_data(B, bp);
|
||||||
|
if (node.input_size() == 3) {
|
||||||
|
const onnx::TensorProto& C = weights[node.input(2)];
|
||||||
fwrite_tensor_proto_data(C, bp);
|
fwrite_tensor_proto_data(C, bp);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// gemm
|
// gemm
|
||||||
fprintf(pp, " 0=%e", alpha);
|
fprintf(pp, " 0=%e", alpha);
|
||||||
@ -2062,7 +2060,6 @@ int main(int argc, char** argv) {
|
|||||||
} else {
|
} else {
|
||||||
pads = get_node_attr_from_input_ai(weights[node.input(1)]);
|
pads = get_node_attr_from_input_ai(weights[node.input(1)]);
|
||||||
}
|
}
|
||||||
|
|
||||||
int type = 0;
|
int type = 0;
|
||||||
if (mode == "constant") {
|
if (mode == "constant") {
|
||||||
type = 0;
|
type = 0;
|
||||||
@ -2193,15 +2190,16 @@ int main(int argc, char** argv) {
|
|||||||
fprintf(pp, " 1=%d", 0);
|
fprintf(pp, " 1=%d", 0);
|
||||||
fprintf(pp, " -23303=%zu", axes.size());
|
fprintf(pp, " -23303=%zu", axes.size());
|
||||||
for (size_t j = 0; j < axes.size(); j++) {
|
for (size_t j = 0; j < axes.size(); j++) {
|
||||||
if (axes[j] == 0 || axes[j] > 3 || axes[j] < -3)
|
if (axes[j] == 0 || axes[j] > 4 || axes[j] < -3)
|
||||||
fprintf(stderr, "Unsupported reduction axes !\n");
|
fprintf(stderr, "Unsupported reduction axes !\n");
|
||||||
fprintf(pp, ",%d", axes[j]);
|
fprintf(pp, ",%d", axes[j] > 0 ? axes[j] - 1 : axes[j]);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// if axes not set, reduce all axes by default
|
// if axes not set, reduce all axes by default
|
||||||
fprintf(pp, " 1=%d", 1);
|
fprintf(pp, " 1=%d", 1);
|
||||||
}
|
}
|
||||||
fprintf(pp, " 4=%d", keepdims);
|
fprintf(pp, " 4=%d", keepdims);
|
||||||
|
fprintf(pp, " 5=1");
|
||||||
} else if (op == "Reorg") {
|
} else if (op == "Reorg") {
|
||||||
int stride = get_node_attr_i(node, "stride", 1);
|
int stride = get_node_attr_i(node, "stride", 1);
|
||||||
fprintf(pp, " 0=%d", stride);
|
fprintf(pp, " 0=%d", stride);
|
||||||
@ -2717,7 +2715,6 @@ int main(int argc, char** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fprintf(pp, "\n");
|
fprintf(pp, "\n");
|
||||||
|
|
||||||
for (int j = 0; j < output_size; j++) {
|
for (int j = 0; j < output_size; j++) {
|
||||||
const std::string& output_name = node.output(j);
|
const std::string& output_name = node.output(j);
|
||||||
if (node_reference.find(output_name) != node_reference.end()) {
|
if (node_reference.find(output_name) != node_reference.end()) {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user