// Copyright (c) OpenMMLab. All rights reserved. #include "fuse_pass.h" void fuse_weight_reshape(onnx::GraphProto* mutable_graph, std::map& weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count) { int node_count = mutable_graph->node_size(); for (int i = 0; i < node_count; i++) { onnx::NodeProto* node = mutable_graph->mutable_node(i); // weight <= Reshape(weight) if (node->op_type() == "Reshape") { // check weight if (weights.find(node->input(0)) == weights.end()) continue; weights[node->output(0)] = weights[node->input(0)]; // set weight shape directly std::vector shape; if (node->input_size() == 1) { shape = get_node_attr_ai(*node, "shape"); } else if (node->input_size() == 2) { // opset 5 shape = get_node_attr_from_input_ai(weights[node->input(1)]); } weights[node->output(0)].clear_dims(); for (int j = 0; j < shape.size(); j++) { weights[node->output(0)].add_dims(shape[j]); } // reduce node->set_op_type("noop_reducedncnn"); node_reference[node->input(0)] -= 1; if (node->input_size() == 2) { node_reference[node->input(1)] -= 1; } reduced_node_count += 1; i += 1; } } } void fuse_weight_transpose(onnx::GraphProto* mutable_graph, std::map& weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count) { int node_count = mutable_graph->node_size(); for (int i = 0; i < node_count; i++) { onnx::NodeProto* node = mutable_graph->mutable_node(i); // weight <= Transpose(weight) if (node->op_type() == "Transpose") { // check weight if (weights.find(node->input(0)) == weights.end()) continue; if (weights[node->input(0)].dims_size() != 2) continue; // perm = (1, 0) std::vector perm = get_node_attr_ai(*node, "perm"); if (perm.size() != 2) continue; if (perm[0] != 1 || perm[1] != 0) continue; weights[node->output(0)] = weights[node->input(0)]; // permute weight { onnx::TensorProto& B = weights[node->output(0)]; const int h = B.dims(0); const int w = B.dims(1); std::vector permuted_data; permuted_data.reserve((size_t)h * w); const float* bptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data(); for (int j = 0; j < w; j++) { for (int k = 0; k < h; k++) { float vb = bptr[k * w + j]; permuted_data.push_back(vb); } } B.set_dims(0, w); B.set_dims(1, h); if (B.has_raw_data()) { B.set_raw_data(permuted_data.data(), permuted_data.size() * sizeof(float)); } else { for (int j = 0; j < (int)permuted_data.size(); j++) B.set_float_data(j, permuted_data[j]); } } // reduce node->set_op_type("noop_reducedncnn"); node_reference[node->input(0)] -= 1; reduced_node_count += 1; i += 1; } } } void fuse_shufflechannel(onnx::GraphProto* mutable_graph, std::map& weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count) { int node_count = mutable_graph->node_size(); for (int i = 0; i < node_count; i++) { onnx::NodeProto* node = mutable_graph->mutable_node(i); // ShuffleChannel <= Reshape - Transpose - Reshape // ShuffleChannel <= Reshape - Transpose - Constant - Reshape if (node->op_type() == "Reshape") { if (node_reference[node->output(0)] != 1) continue; std::vector shape; if (node->input_size() == 1) { shape = get_node_attr_ai(*node, "shape"); } else { // skip weight reshape if (weights.find(node->input(1)) == weights.end()) continue; shape = get_node_attr_from_input_ai(weights[node->input(1)]); } // 1 groups channels_per_group, height, width // reverse style = channels_per_group, groups, height * width if (shape.size() != 5 && shape.size() != 3) continue; if (shape.size() == 5 && shape[0] != 1) continue; if (i + 2 >= node_count) continue; onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); if (node3->op_type() == "Constant") { if (i + 3 >= node_count) continue; node3 = mutable_graph->mutable_node(i + 3); } if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape") continue; if (node_reference[node2->output(0)] != 1) continue; // 0 2 1 3 4 // reverse style = 1 0 2 std::vector perm = get_node_attr_ai(*node2, "perm"); if (perm.size() != 5 && perm.size() != 3) continue; if (perm.size() == 5 && (perm[0] != 0 || perm[1] != 2 || perm[2] != 1 || perm[3] != 3 || perm[4] != 4)) continue; if (perm.size() == 3 && (perm[0] != 1 || perm[1] != 0 || perm[2] != 2)) continue; std::vector shape3; if (node3->input_size() == 1) { shape3 = get_node_attr_ai(*node3, "shape"); } else { // skip weight reshape if (weights.find(node3->input(1)) == weights.end()) continue; shape3 = get_node_attr_from_input_ai(weights[node3->input(1)]); } // 1, -1, height, width // reverse style = group, -1, channels_per_group, height, width if (shape3.size() != 4 && shape3.size() != 5) continue; if (shape3.size() == 4 && (shape3[0] != 1 || (shape3[1] != -1 && shape3[1] != shape[1] * shape[2]))) continue; if (shape3.size() == 5 && (shape3[0] != shape[1] || shape3[2] != shape[0] || shape3[3] * shape3[4] != shape[2])) continue; // reduce node->set_op_type("noop_reducedncnn"); node2->set_op_type("noop_reducedncnn"); if (node->input_size() == 2) { node_reference[node->input(1)] -= 1; } node_reference[node->output(0)] -= 1; node_reference[node2->output(0)] -= 1; if (node3->input_size() == 2) { node_reference[node3->input(1)] -= 1; } blob_names.erase(node->output(0)); blob_names.erase(node2->output(0)); node3->set_op_type("ShuffleChannel"); node3->set_input(0, node->input(0)); onnx::AttributeProto* attr_group = node3->add_attribute(); attr_group->set_name("group"); attr_group->set_i(shape[1]); onnx::AttributeProto* attr_reverse = node3->add_attribute(); attr_reverse->set_name("reverse"); attr_reverse->set_i(shape.size() == 3); reduced_node_count += 2; i += 2; } } } void fuse_shufflechannel_split(onnx::GraphProto* mutable_graph, std::map& weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count) { int node_count = mutable_graph->node_size(); for (int i = 0; i < node_count; i++) { onnx::NodeProto* node = mutable_graph->mutable_node(i); // Split <= ShuffleChannel(reverse type) - Gather(0) - Gather(1) if (node->op_type() == "ShuffleChannel") { // reverse = 1 int reverse = get_node_attr_i(*node, "reverse"); if (reverse != 1) continue; if (i + 2 >= node_count) continue; onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); if (node2->op_type() != "Gather" || node3->op_type() != "Gather") continue; if (node2->input(0) != node->output(0) || node3->input(0) != node->output(0)) continue; // axis = 0 int gather2_axis = get_node_attr_i(*node2, "axis"); if (gather2_axis != 0) continue; // indices = 0 if (weights.find(node2->input(1)) == weights.end()) continue; std::vector gather2_indices = get_node_attr_from_input_ai(weights[node2->input(1)]); if (gather2_indices.size() != 1 || gather2_indices[0] != 0) continue; // axis = 0 int gather3_axis = get_node_attr_i(*node3, "axis"); if (gather3_axis != 0) continue; // indices = 1 if (weights.find(node3->input(1)) == weights.end()) continue; std::vector gather3_indices = get_node_attr_from_input_ai(weights[node3->input(1)]); if (gather3_indices.size() != 1 || gather3_indices[0] != 1) continue; // reduce node2->set_op_type("noop_reducedncnn"); node_reference[node->output(0)] -= 2; node_reference[node2->input(1)] -= 1; node_reference[node3->input(1)] -= 1; node3->set_op_type("Split"); node3->clear_input(); node3->add_input(node->output(0)); node3->add_output(node3->output(0)); node3->set_output(0, node2->output(0)); node3->clear_attribute(); onnx::AttributeProto* attr_axis = node3->add_attribute(); attr_axis->set_name("axis"); attr_axis->set_i(1); reduced_node_count += 1; i += 1; } } } /** * @brief fuse subgraph * * conv - - - - - - - - - - - -> reshape * \ / * shape - slice - concat * * to * * conv --> reshape * * @param mutable_graph * @param weights * @param node_reference * @param blob_names * @param reduced_node_count */ void fuse_conv_reshape(onnx::GraphProto* mutable_graph, std::map& weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count) { std::map> shape_context; const int node_count = mutable_graph->node_size(); for (int i = 0; i < node_count; i++) { onnx::NodeProto* conv = mutable_graph->mutable_node(i); if (conv->op_type() != "Conv") { continue; } if (i + 4 >= node_count) { continue; } onnx::NodeProto *shape = nullptr, *slice = nullptr, *concat = nullptr, *reshape = nullptr; // match [Shape ... Slice, Concat ... Reshape] from near sequence, skip useless Constant std::vector> candidates = { {"Shape", &shape}, {"Slice", &slice}, {"Concat", &concat}, {"Reshape", &reshape}}; int MAX = std::min(10, node_count - i - 1); int pos_candidate = 0; for (int j = 0; j < MAX; ++j) { auto node_ptr = mutable_graph->mutable_node(j + i + 1); if (node_ptr->op_type() == "Constant") { continue; } if (node_ptr->op_type() == std::get<0>(candidates[pos_candidate])) { *(std::get<1>(candidates[pos_candidate])) = node_ptr; pos_candidate++; } } if (pos_candidate != candidates.size()) { // not match the sequence continue; } if (node_reference[conv->output(0)] != 2 || node_reference[shape->output(0)] != 1 || node_reference[slice->output(0)] != 1 || node_reference[concat->output(0)] != 1 || node_reference[reshape->output(0)] != 1) { continue; } // check the connections if (shape->input(0) != conv->output(0) || reshape->input(0) != conv->output(0)) { continue; } if (slice->input(0) != shape->output(0)) { continue; } if (concat->input(0) != slice->output(0)) { continue; } if (reshape->input(0) != conv->output(0) || reshape->input(1) != concat->output(0)) { continue; } // add reshape attr auto result = query_shape(mutable_graph, concat, weights, shape_context); if (!std::get<0>(result)) { continue; } set_node_attr_ai(*reshape, "shape", std::get<1>(result)); // reconstruct graph { // remove reference node_reference[reshape->input(1)] -= 1; node_reference[concat->input(0)] -= 1; node_reference[slice->input(0)] -= 1; node_reference[shape->input(0)] -= 1; // remove tensor/blob on edge blob_names.erase(slice->input(0)); blob_names.erase(slice->input(1)); blob_names.erase(slice->input(2)); blob_names.erase(slice->input(3)); weights.erase(slice->input(1)); weights.erase(slice->input(2)); weights.erase(slice->input(3)); blob_names.erase(concat->input(0)); blob_names.erase(concat->input(1)); weights.erase(concat->input(1)); blob_names.erase(reshape->input(0)); // update edge shape->clear_input(); reshape->clear_input(); reshape->add_input(conv->output(0)); shape->set_op_type("noop_reducedncnn"); slice->set_op_type("noop_reducedncnn"); concat->set_op_type("noop_reducedncnn"); reduced_node_count += 3; } i += 3; } } void fuse_binaryop_with_scalar(onnx::GraphProto* mutable_graph, std::map& weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count) { int node_count = mutable_graph->node_size(); for (int i = 0; i < node_count; i++) { onnx::NodeProto* node = mutable_graph->mutable_node(i); // Add/Sub/Mul/Div/Min/Max/Pow if (node->op_type() == "Add" || node->op_type() == "Sub" || node->op_type() == "Mul" || node->op_type() == "Div" || node->op_type() == "Max" || node->op_type() == "Min" || node->op_type() == "Pow") { if (weights.find(node->input(1)) == weights.end()) continue; const onnx::TensorProto& scalar_b = weights[node->input(1)]; if (scalar_b.dims_size() != 0 || get_tensor_proto_data_size(scalar_b) != 1) continue; float b = get_node_attr_from_input(scalar_b); node_reference[node->input(1)] -= 1; std::string input = node->input(0); node->clear_input(); node->add_input(input); onnx::AttributeProto* attr_with_scalar = node->add_attribute(); attr_with_scalar->set_name("with_scalar"); attr_with_scalar->set_i(1); onnx::AttributeProto* attr_b = node->add_attribute(); attr_b->set_name("b"); attr_b->set_f(b); } } } void fuse_hardswish(onnx::GraphProto* mutable_graph, std::map& weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count) { int node_count = mutable_graph->node_size(); for (int i = 0; i < node_count; i++) { onnx::NodeProto* node = mutable_graph->mutable_node(i); // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Div(/6) // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Mul(*(1/6)) // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Constant - Div(/6) // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Constant - Mul(*(1/6)) // out = x * F.relu6(x + 3, inplace=True) / 6 if (node->op_type() == "Add") { if (node_reference[node->output(0)] != 1) continue; if (i + 3 >= node_count) continue; if (weights.find(node->input(1)) == weights.end()) continue; const onnx::TensorProto& add_three = weights[node->input(1)]; if (add_three.dims_size() != 0 || get_tensor_proto_data_size(add_three) != 1) continue; float constant_add_three = get_node_attr_from_input(add_three); if (constant_add_three != 3.f) continue; onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3); if (node4->op_type() == "Constant") { if (i + 4 >= node_count) continue; node4 = mutable_graph->mutable_node(i + 4); } if (node2->op_type() != "Clip" || node3->op_type() != "Mul" || (node4->op_type() != "Div" && node4->op_type() != "Mul")) continue; if (node_reference[node2->output(0)] != 1) continue; float relu6_min; float relu6_max; if (node2->input_size() == 1) { relu6_min = get_node_attr_f(*node2, "min", -FLT_MAX); relu6_max = get_node_attr_f(*node2, "max", FLT_MAX); } else { const onnx::TensorProto& min_tp = weights[node2->input(1)]; const onnx::TensorProto& max_tp = weights[node2->input(2)]; relu6_min = get_node_attr_from_input(min_tp); relu6_max = get_node_attr_from_input(max_tp); } if (relu6_min != 0.f || relu6_max != 6.f) continue; if (node_reference[node3->output(0)] != 1) continue; if (node3->input(0) != node->input(0) || node3->input(1) != node2->output(0)) continue; if (weights.find(node4->input(1)) == weights.end()) continue; const onnx::TensorProto& div_six = weights[node4->input(1)]; if (div_six.dims_size() != 0 || get_tensor_proto_data_size(div_six) != 1) continue; float constant_div_six = get_node_attr_from_input(div_six); if (node4->op_type() == "Div" && constant_div_six != 6.f) continue; if (node4->op_type() == "Mul" && constant_div_six != 1 / 6.f) continue; // reduce node->set_op_type("noop_reducedncnn"); node2->set_op_type("noop_reducedncnn"); node3->set_op_type("noop_reducedncnn"); node_reference[node->input(0)] -= 1; node_reference[node->input(1)] -= 1; node_reference[node->output(0)] -= 1; if (node2->input_size() == 3) { node_reference[node2->input(1)] -= 1; node_reference[node2->input(2)] -= 1; } node_reference[node2->output(0)] -= 1; node_reference[node3->output(0)] -= 1; node_reference[node4->input(1)] -= 1; blob_names.erase(node->output(0)); blob_names.erase(node2->output(0)); blob_names.erase(node3->output(0)); node4->set_op_type("HardSwish"); node4->clear_input(); node4->add_input(node->input(0)); onnx::AttributeProto* attr_alpha = node4->add_attribute(); attr_alpha->set_name("alpha"); attr_alpha->set_f(1.f / 6.f); onnx::AttributeProto* attr_beta = node4->add_attribute(); attr_beta->set_name("beta"); attr_beta->set_f(3.f / 6.f); reduced_node_count += 3; i += 3; } } for (int i = 0; i < node_count; i++) { onnx::NodeProto* node = mutable_graph->mutable_node(i); // HardSwish <= HardSigmoid - Mul // out = x * hsigmoid(x) if (node->op_type() == "HardSigmoid") { if (node_reference[node->output(0)] != 1) continue; float alpha = get_node_attr_f(*node, "alpha", 0.2f); float beta = get_node_attr_f(*node, "beta", 0.5f); if (i + 1 >= node_count) continue; onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); if (node2->op_type() != "Mul") continue; if (node2->input(0) != node->input(0) || node2->input(1) != node->output(0)) continue; // reduce node->set_op_type("noop_reducedncnn"); node_reference[node->input(0)] -= 1; node_reference[node->output(0)] -= 1; blob_names.erase(node->output(0)); node2->set_op_type("HardSwish"); node2->clear_input(); node2->add_input(node->input(0)); onnx::AttributeProto* attr_alpha = node2->add_attribute(); attr_alpha->set_name("alpha"); attr_alpha->set_f(alpha); onnx::AttributeProto* attr_beta = node2->add_attribute(); attr_beta->set_name("beta"); attr_beta->set_f(beta); reduced_node_count += 1; i += 1; } } } void fuse_hardsigmoid(onnx::GraphProto* mutable_graph, std::map& weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count) { int node_count = mutable_graph->node_size(); for (int i = 0; i < node_count; i++) { onnx::NodeProto* node = mutable_graph->mutable_node(i); // HardSigmoid <= Add(+3) - Clip(0,6) - Div(/6) // HardSigmoid <= Add(+3) - Clip(0,6) - Mul(*(1/6)) // HardSigmoid <= Add(+3) - Clip(0,6) - Constant - Div(/6) // HardSigmoid <= Add(+3) - Clip(0,6) - Constant - Mul(*(1/6)) // out = F.relu6(x + 3, inplace=True) / 6 if (node->op_type() == "Add") { if (node_reference[node->output(0)] != 1) continue; if (i + 2 >= node_count) continue; if (weights.find(node->input(1)) == weights.end()) continue; const onnx::TensorProto& add_three = weights[node->input(1)]; if (add_three.dims_size() != 0 || get_tensor_proto_data_size(add_three) != 1) continue; float constant_add_three = get_node_attr_from_input(add_three); if (constant_add_three != 3.f) continue; onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); if (node3->op_type() == "Constant") { if (i + 3 >= node_count) continue; node3 = mutable_graph->mutable_node(i + 3); } if (node2->op_type() != "Clip" || (node3->op_type() != "Div" && node3->op_type() != "Mul")) continue; if (node_reference[node2->output(0)] != 1) continue; float relu6_min; float relu6_max; if (node2->input_size() == 1) { relu6_min = get_node_attr_f(*node2, "min", -FLT_MAX); relu6_max = get_node_attr_f(*node2, "max", FLT_MAX); } else { const onnx::TensorProto& min_tp = weights[node2->input(1)]; const onnx::TensorProto& max_tp = weights[node2->input(2)]; relu6_min = get_node_attr_from_input(min_tp); relu6_max = get_node_attr_from_input(max_tp); } if (relu6_min != 0.f || relu6_max != 6.f) continue; if (weights.find(node3->input(1)) == weights.end()) continue; const onnx::TensorProto& div_six = weights[node3->input(1)]; if (div_six.dims_size() != 0 || get_tensor_proto_data_size(div_six) != 1) continue; float constant_div_six = get_node_attr_from_input(div_six); if (node3->op_type() == "Div" && constant_div_six != 6.f) continue; if (node3->op_type() == "Mul" && constant_div_six != 1 / 6.f) continue; // reduce node->set_op_type("noop_reducedncnn"); node2->set_op_type("noop_reducedncnn"); node_reference[node->input(1)] -= 1; node_reference[node->output(0)] -= 1; if (node2->input_size() == 3) { node_reference[node2->input(1)] -= 1; node_reference[node2->input(2)] -= 1; } node_reference[node2->output(0)] -= 1; node_reference[node3->input(1)] -= 1; blob_names.erase(node->output(0)); blob_names.erase(node2->output(0)); node3->set_op_type("HardSigmoid"); node3->clear_input(); node3->add_input(node->input(0)); onnx::AttributeProto* attr_alpha = node3->add_attribute(); attr_alpha->set_name("alpha"); attr_alpha->set_f(1.f / 6.f); onnx::AttributeProto* attr_beta = node3->add_attribute(); attr_beta->set_name("beta"); attr_beta->set_f(3.f / 6.f); reduced_node_count += 2; i += 2; } } } void fuse_swish(onnx::GraphProto* mutable_graph, std::map& weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count) { int node_count = mutable_graph->node_size(); for (int i = 0; i < node_count; i++) { onnx::NodeProto* node = mutable_graph->mutable_node(i); // Swish <= Sigmoid - Mul // x * torch.sigmoid(x) if (node->op_type() == "Sigmoid") { if (node_reference[node->output(0)] != 1) continue; if (i + 1 >= node_count) continue; onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); if (node2->op_type() != "Mul") continue; if (node2->input(0) != node->input(0) || node2->input(1) != node->output(0)) continue; // reduce node->set_op_type("noop_reducedncnn"); node_reference[node->input(0)] -= 1; node_reference[node->output(0)] -= 1; blob_names.erase(node->output(0)); node2->set_op_type("Swish"); node2->clear_input(); node2->add_input(node->input(0)); reduced_node_count += 1; i += 1; } } } void fuse_batchnorm1d_squeeze_unsqueeze(onnx::GraphProto* mutable_graph, std::map& weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count) { int node_count = mutable_graph->node_size(); for (int i = 0; i < node_count; i++) { onnx::NodeProto* node = mutable_graph->mutable_node(i); // BatchNormalization <= Unsqueeze - BatchNormalization - Squeeze if (node->op_type() == "Unsqueeze") { if (node_reference[node->output(0)] != 1) continue; if (i + 2 >= node_count) continue; onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); if (node2->op_type() != "BatchNormalization" || node3->op_type() != "Squeeze") continue; if (node_reference[node2->output(0)] != 1) continue; if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0)) continue; // reduce node->set_op_type("noop_reducedncnn"); node3->set_op_type("noop_reducedncnn"); node_reference[node->output(0)] -= 1; node_reference[node2->output(0)] -= 1; blob_names.erase(node->output(0)); blob_names.erase(node2->output(0)); node2->set_input(0, node->input(0)); node2->set_output(0, node3->output(0)); reduced_node_count += 2; i += 2; } } } void fuse_unsqueeze_prelu(onnx::GraphProto* mutable_graph, std::map& weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count) { int node_count = mutable_graph->node_size(); for (int i = 0; i < node_count; i++) { onnx::NodeProto* node = mutable_graph->mutable_node(i); // PReLU <= Unsqueeze - PReLU if (node->op_type() == "Unsqueeze") { // check weight if (weights.find(node->input(0)) == weights.end()) continue; onnx::TensorProto& B = weights[node->input(0)]; if (B.dims_size() != 1) continue; if (node_reference[node->output(0)] != 1) continue; // axes = (1, 2) std::vector axes = get_node_attr_ai(*node, "axes"); if (axes.size() != 2) continue; if (axes[0] != 1 || axes[1] != 2) continue; if (i + 1 >= node_count) continue; onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); if (node2->op_type() != "PRelu") continue; if (node2->input(1) != node->output(0)) continue; // reduce node->set_op_type("noop_reducedncnn"); node_reference[node->output(0)] -= 1; blob_names.erase(node->output(0)); node2->set_input(1, node->input(0)); reduced_node_count += 1; i += 1; } } } void fuse_normalize(onnx::GraphProto* mutable_graph, std::map& weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count) { int node_count = mutable_graph->node_size(); for (int i = 0; i < node_count; i++) { onnx::NodeProto* node = mutable_graph->mutable_node(i); // Normalize <= X - ReduceL2 - Clip - Expand - Div // Normalize <= X - ReduceL2 - Clip - Shape - Expand - Div if (node->op_type() == "ReduceL2") { if (node_reference[node->output(0)] != 1) continue; // axes = (1) std::vector axes = get_node_attr_ai(*node, "axes"); if (axes.size() != 1) continue; if (axes[0] != 1) continue; if (i + 3 >= node_count) continue; onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3); bool has_shape_node = node3->op_type() == "Shape"; onnx::NodeProto* node_shape = 0; if (has_shape_node) { if (i + 4 >= node_count) continue; node_shape = node3; node3 = mutable_graph->mutable_node(i + 3); node4 = mutable_graph->mutable_node(i + 4); } if (node2->op_type() != "Clip" || node3->op_type() != "Expand" || node4->op_type() != "Div") continue; if (node_reference[node2->output(0)] != 1) continue; if (node_reference[node3->output(0)] != 1) continue; if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0) || node4->input(0) != node->input(0) || node4->input(1) != node3->output(0)) continue; if (has_shape_node) { if (node_shape->input(0) != node->input(0) || node3->input(1) != node_shape->output(0)) continue; } // +eps float clip_min; if (node2->input_size() == 1) { clip_min = get_node_attr_f(*node2, "min", -FLT_MAX); } else { const onnx::TensorProto& min_tp = weights[node2->input(1)]; clip_min = get_node_attr_from_input(min_tp); } // reduce node->set_op_type("noop_reducedncnn"); node2->set_op_type("noop_reducedncnn"); if (has_shape_node) { node_shape->set_op_type("noop_reducedncnn"); } node3->set_op_type("noop_reducedncnn"); node_reference[node->input(0)] -= has_shape_node ? 2 : 1; node_reference[node->output(0)] -= 1; node_reference[node2->output(0)] -= 1; if (has_shape_node) { node_reference[node_shape->output(0)] -= 1; } node_reference[node3->output(0)] -= 1; if (node3->input_size() == 2) { node_reference[node3->input(1)] -= 1; } blob_names.erase(node->output(0)); blob_names.erase(node2->output(0)); if (has_shape_node) { blob_names.erase(node_shape->output(0)); } blob_names.erase(node3->output(0)); node4->set_op_type("Normalize"); node4->clear_input(); node4->add_input(node->input(0)); onnx::AttributeProto* attr_alpha = node4->add_attribute(); attr_alpha->set_name("eps"); attr_alpha->set_f(clip_min); reduced_node_count += has_shape_node ? 4 : 3; i += has_shape_node ? 4 : 3; } } } void fuse_groupnorm(onnx::GraphProto* mutable_graph, std::map& weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count) { int node_count = mutable_graph->node_size(); for (int i = 0; i < node_count; i++) { onnx::NodeProto* node = mutable_graph->mutable_node(i); // GroupNorm <= X - Reshape - InstanceNormalization - Reshape - Mul - Add if (node->op_type() == "Reshape") { if (node_reference[node->output(0)] != 1) continue; std::vector shape; if (node->input_size() == 1) { shape = get_node_attr_ai(*node, "shape"); } else { // skip weight reshape if (weights.find(node->input(1)) == weights.end()) continue; shape = get_node_attr_from_input_ai(weights[node->input(1)]); } // 0, group, -1 if (shape.size() != 3) continue; if (shape[0] != 0 || shape[2] != -1) continue; int groups = shape[1]; if (i + 4 >= node_count) continue; onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3); onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4); if (node2->op_type() != "InstanceNormalization" || node3->op_type() != "Reshape" || node4->op_type() != "Mul" || node5->op_type() != "Add") continue; if (node_reference[node2->output(0)] != 1) continue; if (node_reference[node3->output(0)] != 1) continue; if (node_reference[node4->output(0)] != 1) continue; if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0) || node4->input(0) != node3->output(0) || node5->input(0) != node4->output(0)) continue; // +eps float eps = get_node_attr_f(*node2, "epsilon", 1e-05f); // InstanceNormalization S=1 B=0 std::vector S = get_node_attr_from_input_af(weights[node2->input(1)]); std::vector B = get_node_attr_from_input_af(weights[node2->input(2)]); if ((int)S.size() != groups || (int)B.size() != groups) continue; bool instancenorm_affine = false; for (int j = 0; j < groups; j++) { if (S[j] != 1.f || B[j] != 0.f) { instancenorm_affine = true; break; } } if (instancenorm_affine) continue; std::vector shape2; if (node3->input_size() == 1) { shape2 = get_node_attr_ai(*node3, "shape"); } else { // skip weight reshape if (weights.find(node3->input(1)) == weights.end()) continue; shape2 = get_node_attr_from_input_ai(weights[node3->input(1)]); } // 1, channels, w, h if (shape2.size() != 4) continue; if (shape2[0] != 1) continue; int channels = shape2[1]; // affine std::vector affine_S = get_node_attr_from_input_af(weights[node4->input(1)]); std::vector affine_B = get_node_attr_from_input_af(weights[node5->input(1)]); if (affine_S.size() == 1 && affine_S[0] == 1.f && affine_B.size() == 1 && affine_B[0] == 0.f) { // no affine } else if ((int)affine_S.size() != channels && (int)affine_B.size() != channels) { // we only allow per-channel affine continue; } // reduce node->set_op_type("noop_reducedncnn"); node2->set_op_type("noop_reducedncnn"); node3->set_op_type("noop_reducedncnn"); node4->set_op_type("noop_reducedncnn"); if (node->input_size() == 2) { node_reference[node->input(1)] -= 1; } node_reference[node->output(0)] -= 1; node_reference[node2->input(1)] -= 1; node_reference[node2->input(2)] -= 1; node_reference[node2->output(0)] -= 1; if (node3->input_size() == 2) { node_reference[node3->input(1)] -= 1; } node_reference[node3->output(0)] -= 1; node_reference[node4->output(0)] -= 1; blob_names.erase(node->output(0)); blob_names.erase(node2->output(0)); blob_names.erase(node3->output(0)); blob_names.erase(node4->output(0)); std::string affine_scale = node4->input(1); std::string affine_bias = node5->input(1); node5->set_op_type("GroupNorm"); node5->clear_input(); node5->add_input(node->input(0)); node5->add_input(affine_scale); node5->add_input(affine_bias); onnx::AttributeProto* attr_groups = node5->add_attribute(); attr_groups->set_name("groups"); attr_groups->set_i(groups); onnx::AttributeProto* attr_channels = node5->add_attribute(); attr_channels->set_name("channels"); attr_channels->set_i(channels); onnx::AttributeProto* attr_eps = node5->add_attribute(); attr_eps->set_name("epsilon"); attr_eps->set_f(eps); onnx::AttributeProto* attr_affine = node5->add_attribute(); attr_affine->set_name("affine"); attr_affine->set_i(1); reduced_node_count += 4; i += 4; } } } void fuse_layernorm(onnx::GraphProto* mutable_graph, std::map& weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count) { int node_count = mutable_graph->node_size(); for (int i = 0; i < node_count; i++) { onnx::NodeProto* node = mutable_graph->mutable_node(i); // LayerNorm <= X - ReduceMean - Sub - Pow - ReduceMean - Add - Sqrt - Div // LayerNorm <= X - ReduceMean - Sub - Pow - ReduceMean - Add - Sqrt - Div - // Mul - Add if (node->op_type() == "ReduceMean") { if (node_reference[node->output(0)] != 1) continue; std::vector axes = get_node_attr_ai(*node, "axes"); // -1 // -2 -1 if (axes.size() != 1 && axes.size() != 2) continue; int normed_axes = (int)axes.size(); if (normed_axes == 1 && axes[0] != -1) continue; if (normed_axes == 2 && (axes[0] != -2 || axes[1] != -1)) continue; if (i + 6 >= node_count) continue; onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3); onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4); onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5); onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6); if (node2->op_type() != "Sub" || node3->op_type() != "Pow" || node4->op_type() != "ReduceMean" || node5->op_type() != "Add" || node6->op_type() != "Sqrt" || node7->op_type() != "Div") continue; if (node_reference[node2->output(0)] != 2) continue; if (node_reference[node3->output(0)] != 1) continue; if (node_reference[node4->output(0)] != 1) continue; if (node_reference[node5->output(0)] != 1) continue; if (node_reference[node6->output(0)] != 1) continue; if (node2->input(0) != node->input(0) || node2->input(1) != node->output(0) || node3->input(0) != node2->output(0) || node4->input(0) != node3->output(0) || node5->input(0) != node4->output(0) || node6->input(0) != node5->output(0) || node7->input(0) != node2->output(0) || node7->input(1) != node6->output(0)) continue; if (weights.find(node3->input(1)) == weights.end()) continue; const onnx::TensorProto& pow_two = weights[node3->input(1)]; if (pow_two.dims_size() != 0 || get_tensor_proto_data_size(pow_two) != 1) continue; float constant_pow_two = get_node_attr_from_input(pow_two); if (constant_pow_two != 2.f) continue; std::vector axes4 = get_node_attr_ai(*node4, "axes"); // -1 // -2 -1 if ((int)axes4.size() != normed_axes) continue; if (normed_axes == 1 && axes4[0] != -1) continue; if (normed_axes == 2 && (axes4[0] != -2 || axes4[1] != -1)) continue; if (weights.find(node5->input(1)) == weights.end()) continue; const onnx::TensorProto& add_eps = weights[node5->input(1)]; if (add_eps.dims_size() != 0 || get_tensor_proto_data_size(add_eps) != 1) continue; float eps = get_node_attr_from_input(add_eps); int affine = 0; while (i + 8 < node_count) { onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7); onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8); if (node8->op_type() != "Mul" || node9->op_type() != "Add") break; if (node_reference[node7->output(0)] != 1) break; if (node_reference[node8->output(0)] != 1) break; if (node8->input(0) != node7->output(0) || node9->input(0) != node8->output(0)) break; // affine std::vector affine_S = get_node_attr_from_input_af(weights[node8->input(1)]); std::vector affine_B = get_node_attr_from_input_af(weights[node9->input(1)]); if (affine_S.size() != affine_B.size()) break; affine = 1; break; } // reduce node->set_op_type("noop_reducedncnn"); node2->set_op_type("noop_reducedncnn"); node3->set_op_type("noop_reducedncnn"); node4->set_op_type("noop_reducedncnn"); node5->set_op_type("noop_reducedncnn"); node6->set_op_type("noop_reducedncnn"); node_reference[node->input(0)] -= 1; node_reference[node2->input(0)] -= 1; node_reference[node2->input(1)] -= 1; node_reference[node3->input(0)] -= 1; node_reference[node3->input(1)] -= 1; node_reference[node4->input(0)] -= 1; node_reference[node5->input(0)] -= 1; node_reference[node5->input(1)] -= 1; node_reference[node6->input(0)] -= 1; node_reference[node7->input(0)] -= 1; node_reference[node7->input(1)] -= 1; blob_names.erase(node->output(0)); blob_names.erase(node2->output(0)); blob_names.erase(node3->output(0)); blob_names.erase(node4->output(0)); blob_names.erase(node5->output(0)); blob_names.erase(node6->output(0)); node_reference[node->input(0)] += 1; if (affine == 0) { node7->set_op_type("LayerNorm"); node7->clear_input(); node7->add_input(node->input(0)); onnx::AttributeProto* attr_eps = node7->add_attribute(); attr_eps->set_name("epsilon"); attr_eps->set_f(eps); onnx::AttributeProto* attr_affine = node7->add_attribute(); attr_affine->set_name("affine"); attr_affine->set_i(affine); reduced_node_count += 6; i += 6; } else // if (affine == 1) { onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7); onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8); node7->set_op_type("noop_reducedncnn"); node8->set_op_type("noop_reducedncnn"); node_reference[node8->input(0)] -= 1; node_reference[node9->input(0)] -= 1; blob_names.erase(node7->output(0)); blob_names.erase(node8->output(0)); std::string affine_scale = node8->input(1); std::string affine_bias = node9->input(1); node9->set_op_type("LayerNorm"); node9->clear_input(); node9->add_input(node->input(0)); node9->add_input(affine_scale); node9->add_input(affine_bias); onnx::AttributeProto* attr_eps = node9->add_attribute(); attr_eps->set_name("epsilon"); attr_eps->set_f(eps); onnx::AttributeProto* attr_affine = node9->add_attribute(); attr_affine->set_name("affine"); attr_affine->set_i(affine); reduced_node_count += 8; i += 8; } } } } void fuse_flatten(onnx::GraphProto* mutable_graph, std::map& weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count) { int node_count = mutable_graph->node_size(); for (int i = 0; i < node_count; i++) { onnx::NodeProto* node = mutable_graph->mutable_node(i); // Flatten <= X - Shape - Gather - Constant - Unsqueeze - Unsqueeze - Concat // - Reshape if (node->op_type() == "Shape") { if (node_reference[node->output(0)] != 1) continue; if (i + 6 >= node_count) continue; onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3); onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4); onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5); onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6); if (node2->op_type() != "Gather" || node3->op_type() != "Constant" || node4->op_type() != "Unsqueeze" || node5->op_type() != "Unsqueeze" || node6->op_type() != "Concat" || node7->op_type() != "Reshape") continue; if (node_reference[node2->output(0)] != 1) continue; // if (node_reference[node3->output(0)] != 1) // continue; if (node_reference[node4->output(0)] != 1) continue; if (node_reference[node5->output(0)] != 1) continue; if (node_reference[node6->output(0)] != 1) continue; if (node2->input(0) != node->output(0) || node4->input(0) != node2->output(0) || node5->input(0) != node3->output(0) || node6->input(0) != node4->output(0) || node6->input(1) != node5->output(0) || node7->input(0) != node->input(0) || node7->input(1) != node6->output(0)) continue; // axis = 0 int gather_axis = get_node_attr_i(*node2, "axis"); if (gather_axis != 0) continue; // indices = 0 if (weights.find(node2->input(1)) == weights.end()) continue; std::vector gather_indices = get_node_attr_from_input_ai(weights[node2->input(1)]); if (gather_indices.size() != 1 || gather_indices[0] != 0) continue; // axes = (0) std::vector unsqueeze_axes = get_node_attr_ai(*node4, "axes"); if (unsqueeze_axes.size() != 1) continue; if (unsqueeze_axes[0] != 0) continue; // axes = (0) std::vector unsqueeze2_axes = get_node_attr_ai(*node5, "axes"); if (unsqueeze2_axes.size() != 1) continue; if (unsqueeze2_axes[0] != 0) continue; // data = -1 if (weights.find(node5->input(0)) == weights.end()) continue; std::vector unsqueeze2_data = get_node_attr_from_input_ai(weights[node5->input(0)]); if (unsqueeze2_data.size() != 1 || unsqueeze2_data[0] != -1) continue; // axis = 0 int concat_axis = get_node_attr_i(*node6, "axis"); if (concat_axis != 0) continue; // reduce node->set_op_type("noop_reducedncnn"); node2->set_op_type("noop_reducedncnn"); // node3->set_op_type("noop_reducedncnn"); node4->set_op_type("noop_reducedncnn"); node5->set_op_type("noop_reducedncnn"); node6->set_op_type("noop_reducedncnn"); node_reference[node->input(0)] -= 1; node_reference[node->output(0)] -= 1; node_reference[node2->input(1)] -= 1; node_reference[node2->output(0)] -= 1; // node_reference[node3->output(0)] -= 1; node_reference[node4->output(0)] -= 1; node_reference[node5->input(0)] -= 1; node_reference[node5->output(0)] -= 1; node_reference[node6->output(0)] -= 1; blob_names.erase(node->output(0)); blob_names.erase(node2->output(0)); // blob_names.erase(node3->output(0)); blob_names.erase(node4->output(0)); blob_names.erase(node5->output(0)); blob_names.erase(node6->output(0)); node7->set_op_type("Flatten"); node7->clear_input(); node7->add_input(node->input(0)); reduced_node_count += 5; i += 5; } } } void fuse_pixelshuffle(onnx::GraphProto* mutable_graph, std::map& weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count) { int node_count = mutable_graph->node_size(); for (int i = 0; i < node_count; i++) { onnx::NodeProto* node = mutable_graph->mutable_node(i); // PixelShuffle <= Reshape - Transpose - Reshape // PixelShuffle <= Reshape - Transpose - Constant - Reshape if (node->op_type() == "Reshape") { if (node_reference[node->output(0)] != 1) continue; std::vector shape; if (node->input_size() == 1) { shape = get_node_attr_ai(*node, "shape"); } else { // skip weight reshape if (weights.find(node->input(1)) == weights.end()) continue; shape = get_node_attr_from_input_ai(weights[node->input(1)]); } // -1, 3, upscale_factor, upscale_factor, height, width if (shape.size() != 6) continue; if (shape[0] != 1 && shape[0] != -1) continue; if (shape[2] != shape[3]) continue; if (i + 2 >= node_count) continue; onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); if (node3->op_type() == "Constant") { if (i + 3 >= node_count) continue; node3 = mutable_graph->mutable_node(i + 3); } if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape") continue; if (node_reference[node2->output(0)] != 1) continue; // 0 1 4 2 5 3 std::vector perm = get_node_attr_ai(*node2, "perm"); if (perm.size() != 6) continue; if (perm[0] != 0 || perm[1] != 1 || perm[2] != 4 || perm[3] != 2 || perm[4] != 5 || perm[5] != 3) continue; std::vector shape3; if (node3->input_size() == 1) { shape3 = get_node_attr_ai(*node3, "shape"); } else { // skip weight reshape if (weights.find(node3->input(1)) == weights.end()) continue; shape3 = get_node_attr_from_input_ai(weights[node3->input(1)]); } // -1, 3, height, width if (shape3.size() != 4) continue; if (shape3[0] != 1 && shape3[0] != -1) continue; if (shape3[1] != shape[1] || shape3[2] != shape[2] * shape[4] || shape3[3] != shape[3] * shape[5]) continue; // reduce node->set_op_type("noop_reducedncnn"); node2->set_op_type("noop_reducedncnn"); if (node->input_size() == 2) { node_reference[node->input(1)] -= 1; } node_reference[node->output(0)] -= 1; node_reference[node2->output(0)] -= 1; if (node3->input_size() == 2) { node_reference[node3->input(1)] -= 1; } blob_names.erase(node->output(0)); blob_names.erase(node2->output(0)); node3->set_op_type("PixelShuffle"); node3->set_input(0, node->input(0)); onnx::AttributeProto* attr_group = node3->add_attribute(); attr_group->set_name("scale_factor"); attr_group->set_i(shape[2]); reduced_node_count += 2; i += 2; } } } void fuse_reorg(onnx::GraphProto* mutable_graph, std::map& weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count) { int node_count = mutable_graph->node_size(); for (int i = 0; i < node_count; i++) { onnx::NodeProto* node = mutable_graph->mutable_node(i); // PixelShuffle <= Reshape - Transpose - Reshape // PixelShuffle <= Reshape - Transpose - Constant - Reshape if (node->op_type() == "Reshape") { if (node_reference[node->output(0)] != 1) continue; std::vector shape; if (node->input_size() == 1) { shape = get_node_attr_ai(*node, "shape"); } else { // skip weight reshape if (weights.find(node->input(1)) == weights.end()) continue; shape = get_node_attr_from_input_ai(weights[node->input(1)]); } // -1, 3, out_height, block_size, out_width, block_size if (shape.size() != 6) continue; if (shape[0] != 1 && shape[0] != -1) continue; if (shape[3] != shape[5]) continue; if (i + 2 >= node_count) continue; onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); if (node3->op_type() == "Constant") { if (i + 3 >= node_count) continue; node3 = mutable_graph->mutable_node(i + 3); } if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape") continue; if (node_reference[node2->output(0)] != 1) continue; // 0 1 3 5 2 4 std::vector perm = get_node_attr_ai(*node2, "perm"); if (perm.size() != 6) continue; if (perm[0] != 0 || perm[1] != 1 || perm[2] != 3 || perm[3] != 5 || perm[4] != 2 || perm[5] != 4) continue; std::vector shape3; if (node3->input_size() == 1) { shape3 = get_node_attr_ai(*node3, "shape"); } else { // skip weight reshape if (weights.find(node3->input(1)) == weights.end()) continue; shape3 = get_node_attr_from_input_ai(weights[node3->input(1)]); } // -1, out_channels, out_height, out_width if (shape3.size() != 4) continue; if (shape3[0] != 1 && shape3[0] != -1) continue; if (shape3[1] != shape[1] * shape[3] * shape[5] || shape3[2] != shape[2] || shape3[3] != shape[4]) continue; // reduce node->set_op_type("noop_reducedncnn"); node2->set_op_type("noop_reducedncnn"); if (node->input_size() == 2) { node_reference[node->input(1)] -= 1; } node_reference[node->output(0)] -= 1; node_reference[node2->output(0)] -= 1; if (node3->input_size() == 2) { node_reference[node3->input(1)] -= 1; } blob_names.erase(node->output(0)); blob_names.erase(node2->output(0)); node3->set_op_type("Reorg"); node3->set_input(0, node->input(0)); onnx::AttributeProto* attr_group = node3->add_attribute(); attr_group->set_name("stride"); attr_group->set_i(shape[3]); reduced_node_count += 2; i += 2; } } } void fuse_expand_broadcast(onnx::GraphProto* mutable_graph, std::map& weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count) { int node_count = mutable_graph->node_size(); for (int i = 0; i < node_count; i++) { onnx::NodeProto* node = mutable_graph->mutable_node(i); // Add/Sub/Mul/Div/Min/Max <= Expand - Add/Sub/Mul/Div/Min/Max if (node->op_type() == "Expand") { if (node_reference[node->output(0)] != 1) continue; if (i + 1 >= node_count) continue; onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); if (node2->op_type() != "Add" && node2->op_type() != "Sub" && node2->op_type() != "Mul" && node2->op_type() != "Div" && node2->op_type() != "Min" && node2->op_type() != "Max") continue; if (node2->input(1) != node->output(0) && node2->input(0) != node->output(0)) continue; // reduce node->set_op_type("noop_reducedncnn"); node_reference[node->output(0)] -= 1; if (node->input_size() == 2) { node_reference[node->input(1)] -= 1; } blob_names.erase(node->output(0)); if (node2->input(0) == node->output(0)) { node2->set_input(0, node->input(0)); } else { node2->set_input(1, node->input(0)); } reduced_node_count += 1; i += 1; } } } void fuse_lstm_gru_rnn(onnx::GraphProto* mutable_graph, std::map& weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count) { int node_count = mutable_graph->node_size(); for (int i = 0; i < node_count; i++) { onnx::NodeProto* node = mutable_graph->mutable_node(i); // LSTM(bi) <= LSTM(bi) - Transpose - Reshape - Transpose // or LSTM(bi) <= LSTM(bi) - Transpose Constant - Reshape - Transpose if (node->op_type() == "LSTM" || node->op_type() == "GRU" || node->op_type() == "RNN") { if (node_reference[node->output(0)] != 1) continue; if (i + 2 >= node_count) continue; onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); // skip if second ops is constant if (node3->op_type() == "Constant") { if (i + 3 >= node_count) continue; node3 = mutable_graph->mutable_node(i + 3); i += 1; } if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape") continue; if (node_reference[node2->output(0)] != 1) continue; if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0)) continue; std::string direction = get_node_attr_s(*node, "direction"); if (direction != "bidirectional") continue; // 0 2 1 3 std::vector perm = get_node_attr_ai(*node2, "perm"); if (perm.size() != 4) continue; if (perm[0] != 0 || perm[1] != 2 || perm[2] != 1 || perm[3] != 3) continue; std::vector shape; if (node3->input_size() == 1) { shape = get_node_attr_ai(*node3, "shape"); } else { // skip weight reshape if (weights.find(node3->input(1)) == weights.end()) continue; shape = get_node_attr_from_input_ai(weights[node3->input(1)]); } // 0 0 -1 if (shape.size() != 3) continue; if (shape[0] != 0 || shape[1] != 0 || shape[2] != -1) continue; // reduce node2->set_op_type("noop_reducedncnn"); node3->set_op_type("noop_reducedncnn"); node_reference[node->output(0)] -= 1; node_reference[node2->output(0)] -= 1; if (node3->input_size() == 2) { node_reference[node3->input(1)] -= 1; } blob_names.erase(node->output(0)); blob_names.erase(node2->output(0)); node->set_output(0, node3->output(0)); reduced_node_count += 2; i += 2; if (i + 1 < node_count) { if (node_reference[node3->output(0)] != 1) continue; onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 1); if (node4->op_type() != "Transpose") continue; if (node4->input(0) != node->output(0)) continue; // 1 0 2 std::vector perm4 = get_node_attr_ai(*node4, "perm"); if (perm4.size() != 3) continue; if (perm4[0] != 1 || perm4[1] != 0 || perm4[2] != 2) continue; // reduce node4->set_op_type("noop_reducedncnn"); node_reference[node->output(0)] -= 1; blob_names.erase(node->output(0)); node->set_output(0, node4->output(0)); reduced_node_count += 1; i += 1; } } } for (int i = 0; i < node_count; i++) { onnx::NodeProto* node = mutable_graph->mutable_node(i); // LSTM(uni) <= LSTM(uni) - Squeeze - Transpose if (node->op_type() == "LSTM" || node->op_type() == "GRU" || node->op_type() == "RNN") { if (node_reference[node->output(0)] != 1) continue; if (i + 1 >= node_count) continue; onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); if (node2->op_type() != "Squeeze") continue; if (node2->input(0) != node->output(0)) continue; std::string direction = get_node_attr_s(*node, "direction"); if (direction == "bidirectional") continue; // 1 std::vector axes = get_node_attr_ai(*node2, "axes"); if (axes.size() != 1) continue; if (axes[0] != 1) continue; // reduce node2->set_op_type("noop_reducedncnn"); node_reference[node->output(0)] -= 1; blob_names.erase(node->output(0)); node->set_output(0, node2->output(0)); reduced_node_count += 1; i += 1; if (i + 1 < node_count) { if (node_reference[node2->output(0)] != 1) continue; onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 1); if (node3->op_type() != "Transpose") continue; if (node3->input(0) != node->output(0)) continue; // 1 0 2 std::vector perm4 = get_node_attr_ai(*node3, "perm"); if (perm4.size() != 3) continue; if (perm4[0] != 1 || perm4[1] != 0 || perm4[2] != 2) continue; // reduce node3->set_op_type("noop_reducedncnn"); node_reference[node->output(0)] -= 1; blob_names.erase(node->output(0)); node->set_output(0, node3->output(0)); reduced_node_count += 1; i += 1; } } } for (int i = 0; i < node_count; i++) { onnx::NodeProto* node = mutable_graph->mutable_node(i); // LSTM <= Transpose - LSTM if (node->op_type() == "Transpose") { if (node_reference[node->output(0)] != 1) continue; // 1 0 2 std::vector perm = get_node_attr_ai(*node, "perm"); if (perm.size() != 3) continue; if (perm[0] != 1 || perm[1] != 0 || perm[2] != 2) continue; if (i + 1 >= node_count) continue; onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); if (node2->op_type() != "LSTM" && node->op_type() != "GRU" && node->op_type() != "RNN") continue; if (node2->input(0) != node->output(0)) continue; // reduce node->set_op_type("noop_reducedncnn"); node_reference[node->output(0)] -= 1; blob_names.erase(node->output(0)); node2->set_input(0, node->input(0)); reduced_node_count += 1; i += 1; } } } void fuse_multiheadattention(onnx::GraphProto* mutable_graph, std::map& weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count) { int node_count = mutable_graph->node_size(); for (int i = 0; i < node_count; i++) { onnx::NodeProto* node = mutable_graph->mutable_node(i); // MultiHeadAttention <= MatMul(q) - Add // - MatMul(k) - Add // - MatMul(v) - Add // - Mul // - Reshape - Transpose // - Reshape - Reshape - Transpose - Transpose // - Gemm - Softmax - Gemm - Transpose - Reshape - // MatMul - Add if (node->op_type() == "MatMul") { if (i + 19 >= node_count) continue; if (node_reference[node->output(0)] != 1) continue; onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3); onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4); onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5); onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6); onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7); onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8); onnx::NodeProto* node10 = mutable_graph->mutable_node(i + 9); onnx::NodeProto* node11 = mutable_graph->mutable_node(i + 10); onnx::NodeProto* node12 = mutable_graph->mutable_node(i + 11); onnx::NodeProto* node13 = mutable_graph->mutable_node(i + 12); onnx::NodeProto* node14 = mutable_graph->mutable_node(i + 13); onnx::NodeProto* node15 = mutable_graph->mutable_node(i + 14); onnx::NodeProto* node16 = mutable_graph->mutable_node(i + 15); onnx::NodeProto* node17 = mutable_graph->mutable_node(i + 16); onnx::NodeProto* node18 = mutable_graph->mutable_node(i + 17); onnx::NodeProto* node19 = mutable_graph->mutable_node(i + 18); onnx::NodeProto* node20 = mutable_graph->mutable_node(i + 19); if (node2->op_type() != "Add" || node3->op_type() != "MatMul" || node4->op_type() != "Add" || node5->op_type() != "MatMul" || node6->op_type() != "Add" || node7->op_type() != "Mul" || node8->op_type() != "Reshape" || node9->op_type() != "Transpose" || node10->op_type() != "Reshape" || node11->op_type() != "Reshape" || node12->op_type() != "Transpose" || node13->op_type() != "Transpose" || node14->op_type() != "MatMul" || node15->op_type() != "Softmax" || node16->op_type() != "MatMul" || node17->op_type() != "Transpose" || node18->op_type() != "Reshape" || node19->op_type() != "MatMul" || node20->op_type() != "Add") continue; if (node_reference[node2->output(0)] != 1 || node_reference[node3->output(0)] != 1 || node_reference[node4->output(0)] != 1 || node_reference[node5->output(0)] != 1 || node_reference[node6->output(0)] != 1 || node_reference[node7->output(0)] != 1 || node_reference[node8->output(0)] != 1 || node_reference[node9->output(0)] != 1 || node_reference[node10->output(0)] != 1 || node_reference[node11->output(0)] != 1 || node_reference[node12->output(0)] != 1 || node_reference[node13->output(0)] != 1 || node_reference[node14->output(0)] != 1 || node_reference[node15->output(0)] != 1 || node_reference[node16->output(0)] != 1 || node_reference[node17->output(0)] != 1 || node_reference[node18->output(0)] != 1 || node_reference[node19->output(0)] != 1) continue; if (node2->input(0) != node->output(0) || node4->input(0) != node3->output(0) || node6->input(0) != node5->output(0) || node7->input(0) != node2->output(0) || node8->input(0) != node7->output(0) || node9->input(0) != node8->output(0) || node10->input(0) != node4->output(0) || node11->input(0) != node6->output(0) || node12->input(0) != node11->output(0) || node13->input(0) != node10->output(0) || node14->input(0) != node9->output(0) || node14->input(1) != node13->output(0) || node15->input(0) != node14->output(0) || node16->input(0) != node15->output(0) || node16->input(1) != node12->output(0) || node17->input(0) != node16->output(0) || node18->input(0) != node17->output(0) || node19->input(0) != node18->output(0) || node20->input(0) != node19->output(0)) continue; std::vector q_B = get_node_attr_from_input_af(weights[node2->input(1)]); std::vector k_B = get_node_attr_from_input_af(weights[node4->input(1)]); std::vector v_B = get_node_attr_from_input_af(weights[node6->input(1)]); std::vector o_B = get_node_attr_from_input_af(weights[node20->input(1)]); if (q_B.size() != k_B.size() || q_B.size() != v_B.size() || q_B.size() != o_B.size()) continue; int embed_dim = q_B.size(); // 1 0 2 std::vector perm9 = get_node_attr_ai(*node9, "perm"); std::vector perm12 = get_node_attr_ai(*node12, "perm"); if (perm9.size() != 3 || perm12.size() != 3) continue; if (perm9[0] != 1 || perm9[1] != 0 || perm9[2] != 2 || perm12[0] != 1 || perm12[1] != 0 || perm12[2] != 2) continue; // 1 2 0 std::vector perm13 = get_node_attr_ai(*node13, "perm"); if (perm13.size() != 3) continue; if (perm13[0] != 1 || perm13[1] != 2 || perm13[2] != 0) continue; // 1 0 2 std::vector perm17 = get_node_attr_ai(*node17, "perm"); if (perm17.size() != 3) continue; if (perm17[0] != 1 || perm17[1] != 0 || perm17[2] != 2) continue; int softmax_axis = get_node_attr_i(*node15, "axis"); if (softmax_axis != 2) continue; // 1/-1, seqlen * num_heads, embed_dim / num_heads std::vector shape8; std::vector shape10; std::vector shape11; if (node8->input_size() == 1) { shape8 = get_node_attr_ai(*node8, "shape"); } else { // skip weight reshape if (weights.find(node8->input(1)) == weights.end()) continue; shape8 = get_node_attr_from_input_ai(weights[node8->input(1)]); } if (node10->input_size() == 1) { shape10 = get_node_attr_ai(*node10, "shape"); } else { // skip weight reshape if (weights.find(node10->input(1)) == weights.end()) continue; shape10 = get_node_attr_from_input_ai(weights[node10->input(1)]); } if (node11->input_size() == 1) { shape11 = get_node_attr_ai(*node11, "shape"); } else { // skip weight reshape if (weights.find(node11->input(1)) == weights.end()) continue; shape11 = get_node_attr_from_input_ai(weights[node11->input(1)]); } if (shape8.size() != 3 || shape10.size() != 3 || shape11.size() != 3) continue; if (shape8[1] != shape10[1] || shape8[1] != shape11[1] || shape8[2] != shape10[2] || shape8[2] != shape11[2]) continue; int num_heads = embed_dim / shape8[2]; // 1, seqlen, embed_dim std::vector shape18; if (node18->input_size() == 1) { shape18 = get_node_attr_ai(*node18, "shape"); } else { // skip weight reshape if (weights.find(node18->input(1)) == weights.end()) continue; shape18 = get_node_attr_from_input_ai(weights[node18->input(1)]); } if (shape18.size() != 3) continue; if (shape18[2] != embed_dim || shape18[1] * num_heads != shape8[1]) continue; // reduce node->set_op_type("noop_reducedncnn"); node2->set_op_type("noop_reducedncnn"); node3->set_op_type("noop_reducedncnn"); node4->set_op_type("noop_reducedncnn"); node5->set_op_type("noop_reducedncnn"); node6->set_op_type("noop_reducedncnn"); node7->set_op_type("noop_reducedncnn"); node8->set_op_type("noop_reducedncnn"); node9->set_op_type("noop_reducedncnn"); node10->set_op_type("noop_reducedncnn"); node11->set_op_type("noop_reducedncnn"); node12->set_op_type("noop_reducedncnn"); node13->set_op_type("noop_reducedncnn"); node14->set_op_type("noop_reducedncnn"); node15->set_op_type("noop_reducedncnn"); node16->set_op_type("noop_reducedncnn"); node17->set_op_type("noop_reducedncnn"); node18->set_op_type("noop_reducedncnn"); node19->set_op_type("noop_reducedncnn"); node_reference[node2->input(0)] -= 1; node_reference[node4->input(0)] -= 1; node_reference[node6->input(0)] -= 1; node_reference[node7->input(0)] -= 1; node_reference[node7->input(1)] -= 1; node_reference[node8->input(0)] -= 1; if (node8->input_size() == 2) { node_reference[node8->input(1)] -= 1; } node_reference[node9->input(0)] -= 1; node_reference[node10->input(0)] -= 1; if (node10->input_size() == 2) { node_reference[node10->input(1)] -= 1; } node_reference[node11->input(0)] -= 1; if (node11->input_size() == 2) { node_reference[node11->input(1)] -= 1; } node_reference[node12->input(0)] -= 1; node_reference[node13->input(0)] -= 1; node_reference[node14->input(0)] -= 1; node_reference[node14->input(1)] -= 1; node_reference[node15->input(0)] -= 1; node_reference[node16->input(0)] -= 1; node_reference[node16->input(1)] -= 1; node_reference[node17->input(0)] -= 1; node_reference[node18->input(0)] -= 1; if (node18->input_size() == 2) { node_reference[node18->input(1)] -= 1; } node_reference[node19->input(0)] -= 1; node_reference[node20->input(0)] -= 1; blob_names.erase(node->output(0)); blob_names.erase(node2->output(0)); blob_names.erase(node3->output(0)); blob_names.erase(node4->output(0)); blob_names.erase(node5->output(0)); blob_names.erase(node6->output(0)); blob_names.erase(node7->output(0)); blob_names.erase(node8->output(0)); blob_names.erase(node9->output(0)); blob_names.erase(node10->output(0)); blob_names.erase(node11->output(0)); blob_names.erase(node12->output(0)); blob_names.erase(node13->output(0)); blob_names.erase(node14->output(0)); blob_names.erase(node15->output(0)); blob_names.erase(node16->output(0)); blob_names.erase(node17->output(0)); blob_names.erase(node18->output(0)); blob_names.erase(node19->output(0)); std::string qw = node->input(1); std::string qb = node2->input(1); std::string kw = node3->input(1); std::string kb = node4->input(1); std::string vw = node5->input(1); std::string vb = node6->input(1); std::string ow = node19->input(1); std::string ob = node20->input(1); node20->set_op_type("MultiHeadAttention"); node20->clear_input(); node20->add_input(node->input(0)); node20->add_input(node3->input(0)); node20->add_input(node5->input(0)); // q node20->add_input(qw); node20->add_input(qb); // k node20->add_input(kw); node20->add_input(kb); // v node20->add_input(vw); node20->add_input(vb); // out linear node20->add_input(ow); node20->add_input(ob); onnx::AttributeProto* attr_embed_dim = node20->add_attribute(); attr_embed_dim->set_name("embed_dim"); attr_embed_dim->set_i(embed_dim); onnx::AttributeProto* attr_num_heads = node20->add_attribute(); attr_num_heads->set_name("num_heads"); attr_num_heads->set_i(num_heads); reduced_node_count += 19; i += 19; } } for (int i = 0; i < node_count; i++) { onnx::NodeProto* node = mutable_graph->mutable_node(i); // MultiHeadAttention <= MatMul(qkv) - Add - Split // - Mul // - Reshape - Transpose // - Reshape - Reshape - Transpose - Transpose // - Gemm - Softmax - Gemm - Transpose - Reshape - // MatMul - Add if (node->op_type() == "MatMul") { if (i + 16 >= node_count) continue; if (node_reference[node->output(0)] != 1) continue; onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3); onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4); onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5); onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6); onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7); onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8); onnx::NodeProto* node10 = mutable_graph->mutable_node(i + 9); onnx::NodeProto* node11 = mutable_graph->mutable_node(i + 10); onnx::NodeProto* node12 = mutable_graph->mutable_node(i + 11); onnx::NodeProto* node13 = mutable_graph->mutable_node(i + 12); onnx::NodeProto* node14 = mutable_graph->mutable_node(i + 13); onnx::NodeProto* node15 = mutable_graph->mutable_node(i + 14); onnx::NodeProto* node16 = mutable_graph->mutable_node(i + 15); onnx::NodeProto* node17 = mutable_graph->mutable_node(i + 16); if (node2->op_type() != "Add" || node3->op_type() != "Split" || node4->op_type() != "Mul" || node5->op_type() != "Reshape" || node6->op_type() != "Transpose" || node7->op_type() != "Reshape" || node8->op_type() != "Reshape" || node9->op_type() != "Transpose" || node10->op_type() != "Transpose" || node11->op_type() != "MatMul" || node12->op_type() != "Softmax" || node13->op_type() != "MatMul" || node14->op_type() != "Transpose" || node15->op_type() != "Reshape" || node16->op_type() != "MatMul" || node17->op_type() != "Add") continue; if (node_reference[node2->output(0)] != 1 || node_reference[node3->output(0)] != 1 || node_reference[node3->output(1)] != 1 || node_reference[node3->output(2)] != 1 || node_reference[node4->output(0)] != 1 || node_reference[node5->output(0)] != 1 || node_reference[node6->output(0)] != 1 || node_reference[node7->output(0)] != 1 || node_reference[node8->output(0)] != 1 || node_reference[node9->output(0)] != 1 || node_reference[node10->output(0)] != 1 || node_reference[node11->output(0)] != 1 || node_reference[node12->output(0)] != 1 || node_reference[node13->output(0)] != 1 || node_reference[node14->output(0)] != 1 || node_reference[node15->output(0)] != 1 || node_reference[node16->output(0)] != 1) continue; if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0) || node4->input(0) != node3->output(0) || node5->input(0) != node4->output(0) || node6->input(0) != node5->output(0) || node7->input(0) != node3->output(1) || node8->input(0) != node3->output(2) || node9->input(0) != node8->output(0) || node10->input(0) != node7->output(0) || node11->input(0) != node6->output(0) || node11->input(1) != node10->output(0) || node12->input(0) != node11->output(0) || node13->input(0) != node12->output(0) || node13->input(1) != node9->output(0) || node14->input(0) != node13->output(0) || node15->input(0) != node14->output(0) || node16->input(0) != node15->output(0) || node17->input(0) != node16->output(0)) continue; std::vector qkv_B = get_node_attr_from_input_af(weights[node2->input(1)]); std::vector o_B = get_node_attr_from_input_af(weights[node17->input(1)]); if (qkv_B.size() != o_B.size() * 3) continue; int embed_dim = o_B.size(); // 1 0 2 std::vector perm6 = get_node_attr_ai(*node6, "perm"); std::vector perm9 = get_node_attr_ai(*node9, "perm"); if (perm6.size() != 3 || perm9.size() != 3) continue; if (perm6[0] != 1 || perm6[1] != 0 || perm6[2] != 2 || perm9[0] != 1 || perm9[1] != 0 || perm9[2] != 2) continue; // 1 2 0 std::vector perm10 = get_node_attr_ai(*node10, "perm"); if (perm10.size() != 3) continue; if (perm10[0] != 1 || perm10[1] != 2 || perm10[2] != 0) continue; // 1 0 2 std::vector perm14 = get_node_attr_ai(*node14, "perm"); if (perm14.size() != 3) continue; if (perm14[0] != 1 || perm14[1] != 0 || perm14[2] != 2) continue; int softmax_axis = get_node_attr_i(*node12, "axis"); if (softmax_axis != 2) continue; // 1/-1, seqlen * num_heads, embed_dim / num_heads std::vector shape5; std::vector shape7; std::vector shape8; if (node5->input_size() == 1) { shape5 = get_node_attr_ai(*node5, "shape"); } else { // skip weight reshape if (weights.find(node5->input(1)) == weights.end()) continue; shape5 = get_node_attr_from_input_ai(weights[node5->input(1)]); } if (node7->input_size() == 1) { shape7 = get_node_attr_ai(*node7, "shape"); } else { // skip weight reshape if (weights.find(node7->input(1)) == weights.end()) continue; shape7 = get_node_attr_from_input_ai(weights[node7->input(1)]); } if (node8->input_size() == 1) { shape8 = get_node_attr_ai(*node8, "shape"); } else { // skip weight reshape if (weights.find(node8->input(1)) == weights.end()) continue; shape8 = get_node_attr_from_input_ai(weights[node8->input(1)]); } if (shape5.size() != 3 || shape7.size() != 3 || shape8.size() != 3) continue; if (shape5[1] != shape7[1] || shape5[1] != shape8[1] || shape5[2] != shape7[2] || shape5[2] != shape8[2]) continue; int num_heads = embed_dim / shape5[2]; // 1, seqlen, embed_dim std::vector shape15; if (node15->input_size() == 1) { shape15 = get_node_attr_ai(*node15, "shape"); } else { // skip weight reshape if (weights.find(node15->input(1)) == weights.end()) continue; shape15 = get_node_attr_from_input_ai(weights[node15->input(1)]); } if (shape15.size() != 3) continue; if (shape15[2] != embed_dim || shape15[1] * num_heads != shape8[1]) continue; // reduce node->set_op_type("noop_reducedncnn"); node2->set_op_type("noop_reducedncnn"); node3->set_op_type("noop_reducedncnn"); node4->set_op_type("noop_reducedncnn"); node5->set_op_type("noop_reducedncnn"); node6->set_op_type("noop_reducedncnn"); node7->set_op_type("noop_reducedncnn"); node8->set_op_type("noop_reducedncnn"); node9->set_op_type("noop_reducedncnn"); node10->set_op_type("noop_reducedncnn"); node11->set_op_type("noop_reducedncnn"); node12->set_op_type("noop_reducedncnn"); node13->set_op_type("noop_reducedncnn"); node14->set_op_type("noop_reducedncnn"); node15->set_op_type("noop_reducedncnn"); node16->set_op_type("noop_reducedncnn"); node_reference[node2->input(0)] -= 1; node_reference[node3->input(0)] -= 1; node_reference[node4->input(0)] -= 1; node_reference[node4->input(1)] -= 1; node_reference[node5->input(0)] -= 1; if (node5->input_size() == 2) { node_reference[node5->input(1)] -= 1; } node_reference[node6->input(0)] -= 1; node_reference[node7->input(0)] -= 1; if (node7->input_size() == 2) { node_reference[node7->input(1)] -= 1; } node_reference[node8->input(0)] -= 1; if (node8->input_size() == 2) { node_reference[node8->input(1)] -= 1; } node_reference[node9->input(0)] -= 1; node_reference[node10->input(0)] -= 1; node_reference[node11->input(0)] -= 1; node_reference[node11->input(1)] -= 1; node_reference[node12->input(0)] -= 1; node_reference[node13->input(0)] -= 1; node_reference[node13->input(1)] -= 1; node_reference[node14->input(0)] -= 1; node_reference[node15->input(0)] -= 1; if (node15->input_size() == 2) { node_reference[node15->input(1)] -= 1; } node_reference[node16->input(0)] -= 1; node_reference[node17->input(0)] -= 1; blob_names.erase(node->output(0)); blob_names.erase(node2->output(0)); blob_names.erase(node3->output(0)); blob_names.erase(node3->output(1)); blob_names.erase(node3->output(2)); blob_names.erase(node4->output(0)); blob_names.erase(node5->output(0)); blob_names.erase(node6->output(0)); blob_names.erase(node7->output(0)); blob_names.erase(node8->output(0)); blob_names.erase(node9->output(0)); blob_names.erase(node10->output(0)); blob_names.erase(node11->output(0)); blob_names.erase(node12->output(0)); blob_names.erase(node13->output(0)); blob_names.erase(node14->output(0)); blob_names.erase(node15->output(0)); blob_names.erase(node16->output(0)); std::string qkvw = node->input(1); std::string qkvb = node2->input(1); std::string ow = node16->input(1); std::string ob = node17->input(1); node17->set_op_type("MultiHeadAttention"); node17->clear_input(); node17->add_input(node->input(0)); // qkv node17->add_input(qkvw); node17->add_input(qkvb); // out linear node17->add_input(ow); node17->add_input(ob); onnx::AttributeProto* attr_embed_dim = node17->add_attribute(); attr_embed_dim->set_name("embed_dim"); attr_embed_dim->set_i(embed_dim); onnx::AttributeProto* attr_num_heads = node17->add_attribute(); attr_num_heads->set_name("num_heads"); attr_num_heads->set_i(num_heads); reduced_node_count += 16; i += 16; } } }