diff --git a/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp b/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp index df59d68b6..ca8cd628a 100644 --- a/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp +++ b/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp @@ -44,15 +44,11 @@ int main(int argc, char** argv) { fprintf(stderr, "read_proto_from_binary failed\n"); return -1; } - FILE* pp = fopen(ncnn_prototxt, "wb"); FILE* bp = fopen(ncnn_modelbin, "wb"); - // magic fprintf(pp, "7767517\n"); - onnx::GraphProto* mutable_graph = model.mutable_graph(); - int node_count = mutable_graph->node_size(); // node reference @@ -60,7 +56,6 @@ int main(int argc, char** argv) { // weight node and weight reshape node std::map<std::string, onnx::TensorProto> weights; - for (int j = 0; j < mutable_graph->initializer_size(); j++) { const onnx::TensorProto& initializer = mutable_graph->initializer(j); @@ -69,7 +64,6 @@ int main(int argc, char** argv) { weights[initializer.name()] = initializer; } - // topological sort { // name -> producer node index @@ -138,7 +132,6 @@ int main(int argc, char** argv) { *nodeq = tmp; } } - // global definition line // [layer count] [blob count] std::set<std::string> blob_names; @@ -184,7 +177,6 @@ int main(int argc, char** argv) { node_reference[output_name] = 0; } } - // include Input node int input_node_count = 0; for (int j = 0; j < mutable_graph->input_size(); j++) { @@ -232,7 +224,6 @@ int main(int argc, char** argv) { reduced_node_count); fuse_rewrite_gather(mutable_graph, weights, node_reference, blob_names, reduced_node_count); } - // reduce common const weight node_reference for (int i = 0; i < node_count; i++) { const onnx::NodeProto& node = mutable_graph->node(i); @@ -275,9 +266,11 @@ int main(int argc, char** argv) { int transB = get_node_attr_i(node, "transB", 0); if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1) { - // InnerProduct-like A * B + C + // InnerProduct-like A * B + C, C is optional. node_reference[node.input(1)] -= 1; - node_reference[node.input(2)] -= 1; + if (node.input_size() == 3) { + node_reference[node.input(2)] -= 1; + } } } else if (op == "GroupNorm") { int affine = get_node_attr_i(node, "affine", 1); @@ -530,7 +523,6 @@ int main(int argc, char** argv) { for (int i = 0; i < node_count; i++) { const onnx::NodeProto& node = mutable_graph->node(i); - const std::string& op = node.op_type(); // fprintf(stderr, "op = %s\n", op.c_str()); @@ -1317,17 +1309,23 @@ int main(int argc, char** argv) { if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1) { // InnerProduct-like A * B + C const onnx::TensorProto& B = weights[node.input(1)]; - const onnx::TensorProto& C = weights[node.input(2)]; - - fprintf(pp, " 0=%d", get_tensor_proto_data_size(C)); - fprintf(pp, " 1=1"); + // B has transposed. + int num_output = B.dims(0); + fprintf(pp, " 0=%d", num_output); + if (node.input_size() == 3) { + fprintf(pp, " 1=1"); + } else { + fprintf(pp, " 1=0"); + } fprintf(pp, " 2=%d", get_tensor_proto_data_size(B)); int quantize_tag = 0; fwrite(&quantize_tag, sizeof(int), 1, bp); - fwrite_tensor_proto_data(B, bp); - fwrite_tensor_proto_data(C, bp); + if (node.input_size() == 3) { + const onnx::TensorProto& C = weights[node.input(2)]; + fwrite_tensor_proto_data(C, bp); + } } else { // gemm fprintf(pp, " 0=%e", alpha); @@ -2062,7 +2060,6 @@ int main(int argc, char** argv) { } else { pads = get_node_attr_from_input_ai(weights[node.input(1)]); } - int type = 0; if (mode == "constant") { type = 0; @@ -2193,15 +2190,16 @@ int main(int argc, char** argv) { fprintf(pp, " 1=%d", 0); fprintf(pp, " -23303=%zu", axes.size()); for (size_t j = 0; j < axes.size(); j++) { - if (axes[j] == 0 || axes[j] > 3 || axes[j] < -3) + if (axes[j] == 0 || axes[j] > 4 || axes[j] < -3) fprintf(stderr, "Unsupported reduction axes !\n"); - fprintf(pp, ",%d", axes[j]); + fprintf(pp, ",%d", axes[j] > 0 ? axes[j] - 1 : axes[j]); } } else { // if axes not set, reduce all axes by default fprintf(pp, " 1=%d", 1); } fprintf(pp, " 4=%d", keepdims); + fprintf(pp, " 5=1"); } else if (op == "Reorg") { int stride = get_node_attr_i(node, "stride", 1); fprintf(pp, " 0=%d", stride); @@ -2717,7 +2715,6 @@ int main(int argc, char** argv) { } fprintf(pp, "\n"); - for (int j = 0; j < output_size; j++) { const std::string& output_name = node.output(j); if (node_reference.find(output_name) != node_reference.end()) {