diff --git a/csrc/backend_ops/torchscript/optimizer/bind.cpp b/csrc/backend_ops/torchscript/optimizer/bind.cpp index 73594776a..21a691f14 100644 --- a/csrc/backend_ops/torchscript/optimizer/bind.cpp +++ b/csrc/backend_ops/torchscript/optimizer/bind.cpp @@ -4,13 +4,19 @@ #include #include "optimizer.h" +#include "passes/onnx/flatten_cls_head.h" +#include "passes/onnx/merge_shape_concate.h" +#include "passes/onnx/onnx_peephole.h" + +namespace mmdeploy { +namespace torch_jit { void optimize_for_backend(torch::jit::Module& model, const std::string& ir = "torchscript", const std::string& backend = "torchscript") { if (ir == "torchscript") { - model = mmdeploy::optimize_for_torchscript(model); + model = optimize_for_torchscript(model); } else if (ir == "onnx") { - model = mmdeploy::optimize_for_onnx(model); + model = optimize_for_onnx(model); } else { fprintf(stderr, "No optimize for combination ir: %s backend: %s\n", ir.c_str(), backend.c_str()); @@ -23,4 +29,11 @@ PYBIND11_MODULE(ts_optimizer, m) { m.def("optimize_for_backend", optimize_for_backend, py::arg("module"), py::arg("ir") = std::string("torchscript"), py::arg("backend") = std::string("torchscript")); + py::module_ onnx_module = m.def_submodule("onnx"); + onnx_module.def("_jit_pass_merge_shape_concate", MergeShapeConcate, py::arg("graph")); + onnx_module.def("_jit_pass_onnx_peephole", ONNXPeephole, py::arg("graph")); + onnx_module.def("_jit_pass_flatten_cls_head", FlattenClsHead, py::arg("graph")); } + +} // namespace torch_jit +} // namespace mmdeploy diff --git a/csrc/backend_ops/torchscript/optimizer/ir/subgraph_matcher.cpp b/csrc/backend_ops/torchscript/optimizer/ir/subgraph_matcher.cpp new file mode 100644 index 000000000..97425aa5b --- /dev/null +++ b/csrc/backend_ops/torchscript/optimizer/ir/subgraph_matcher.cpp @@ -0,0 +1,311 @@ +// modify from: +// https://github.com/pytorch/pytorch/blob/v1.8.1/torch/csrc/jit/ir/subgraph_matcher.cpp +#include "subgraph_matcher.h" + +#include +#include +#include + +#include +#include +namespace mmdeploy { +namespace torch_jit { + +using torch::jit::AttributeKind; +using torch::jit::ClassType; +using torch::jit::Node; +using torch::jit::Symbol; +using torch::jit::Value; + +namespace prim { +using namespace ::c10::prim; +} + +namespace attr { +using namespace ::c10::attr; +} + +/** + * \brief A class implementing an API for comparing subgraphs. + */ +class SubgraphMatcher::SubgraphMatcherImpl { + public: + explicit SubgraphMatcherImpl(const Graph& pattern, MatchAttribute match_attribute) + : pattern_(pattern), match_attribute_(match_attribute) {} + + /** + * \brief Compare matchGraph with the part of the graph denoted by a node \p + * ANCHOR. + * + * The anchor node would be compared against the deepest node in the + * match-graph. A node is considered matching if its number of inputs/outputs + * is the same as in the corresponding matchGraph node, its type is the same, + * and all nodes producing input-values also match. + */ + bool matchesSubgraphFromAnchorNode(Node* anchor); + + /** \brief Return match map for nodes. */ + std::unordered_map nodes_map() const { return nodes_map_; } + + /** \brief Return match map for values. */ + std::unordered_map values_map() const { return values_map_; } + + private: + bool matchValues(const Value* v1, Value* v2); + bool matchNodes(const Node* n1, Node* n2); + bool matchAttributes(const Node* n1, Node* n2); + + static bool isInput(const Value* v); + static bool isOutput(const Value* v); + + std::unordered_map nodes_map_; + std::unordered_map values_map_; + + const MatchAttribute match_attribute_; + const Graph& pattern_; + const Node* anchor_ = nullptr; +}; + +bool SubgraphMatcher::SubgraphMatcherImpl::isInput(const Value* v) { + return v->node()->kind() == prim::Param; +} + +bool SubgraphMatcher::SubgraphMatcherImpl::isOutput(const Value* v) { + for (const Value* output : v->owningGraph()->outputs()) { + if (v == output) { + return true; + } + } + return false; +} + +/** + * Compare two Values. V1 is from pattern, V2 is from the actual graph. + * + * The values are considered matching if: + * 1) the nodes defining them match + * 2) they have the same number of uses, except they are entry or exit nodes. + */ +bool SubgraphMatcher::SubgraphMatcherImpl::matchValues(const Value* v1, Value* v2) { + // Check if we've already visited these values. + if (values_map_.count(v1)) { + if (values_map_.at(v1) != v2) { + GRAPH_DEBUG("Values %", v1->debugName(), " and %", v2->debugName(), + " did not match because %", v1->debugName(), " has already been matched with %", + values_map_.at(v1)->debugName(), ".\n"); + return false; + } + return true; + } + + // When V2 is ANCHOR, we're comparing exiting values, and when V1->node is + // PARAM, we're comparing entering values - in these two cases the number of + // uses don't need to be the same. + if (v1->uses().size() != v2->uses().size() && !isOutput(v1) && !isInput(v1)) { + GRAPH_DEBUG("Values %", v1->debugName(), " and %", v2->debugName(), + " did not match because number of their uses is different.\n"); + return false; + } + + // Add the values to the map before calling matchNodes to avoid infinite + // recursion. + GRAPH_DEBUG("Values %", v1->debugName(), " and %", v2->debugName(), " matched.\n"); + values_map_[v1] = v2; + return matchNodes(v1->node(), v2->node()); +} + +bool SubgraphMatcher::SubgraphMatcherImpl::matchAttributes(const Node* n1, Node* n2) { + if (match_attribute_ == FORCE_MATCH && n1->numAttributes() != n2->numAttributes()) { + GRAPH_DEBUG("Nodes did not match in number attributes:\n", *n1, *n2); + return false; + } + for (const Symbol& attr_name : n1->attributeNames()) { + if (n1->kindOf(attr_name) != n2->kindOf(attr_name)) { + GRAPH_DEBUG("Nodes did not match because type of attribute '", attr_name.toQualString(), + "' did not match:\n", *n1, *n2); + return false; + } + std::vector n1is, n2is; + std::vector n1fs, n2fs; + switch (n1->kindOf(attr_name)) { + case AttributeKind::s: + if (!std::regex_match(n2->s(attr_name), std::regex(n1->s(attr_name)))) { + GRAPH_DEBUG("Nodes did not match because attribute '", attr_name.toQualString(), + "' did not match: ", n1->s(attr_name), " != ", n2->s(attr_name), " \n", *n1, + *n2); + return false; + } + break; + case AttributeKind::f: + if (n1->f(attr_name) != n2->f(attr_name)) { + GRAPH_DEBUG("Nodes did not match because attribute '", attr_name.toQualString(), + "' did not match:", n1->f(attr_name), " != ", n2->f(attr_name), " \n", *n1, + *n2); + return false; + } + break; + case AttributeKind::i: + if (n1->i(attr_name) != n2->i(attr_name)) { + GRAPH_DEBUG("Nodes did not match because attribute '", attr_name.toQualString(), + "' did not match:", n1->i(attr_name), " != ", n2->i(attr_name), " \n", *n1, + *n2); + return false; + } + break; + case AttributeKind::is: + n1is = n1->is(attr_name); + n2is = n2->is(attr_name); + if (n1is.size() != n2is.size()) return false; + for (int i = 0; i < n1is.size(); ++i) { + if (n1is[i] != n2is[i]) return false; + } + break; + case AttributeKind::fs: + n1fs = n1->fs(attr_name); + n2fs = n2->fs(attr_name); + if (n1fs.size() != n2fs.size()) return false; + for (int i = 0; i < n1fs.size(); ++i) { + if (n1fs[i] != n2fs[i]) return false; + } + break; + default: { + // Other attributes types not supported yet + GRAPH_DEBUG("Nodes did not match because type of attribute '", attr_name.toQualString(), + "' is not supported.\n", *n1, *n2); + return false; + } + } + } + return true; +} + +static bool endsWith(const std::string& str, const std::string& suffix) { + return str.size() >= suffix.size() && + 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix); +} + +/** + * Compare two Nodes. N1 is from pattern, N2 is from the actual graph. + * + * The nodes are considered matching if: + * 1) N1 and N2 are of the same kind. + * 2) Number of inputs and outputs is the same. + * 3) All input and output values match. + * + * A special case is when N1 is PARAM - this is considered outside the pattern, + * so it matches everything. + */ +bool SubgraphMatcher::SubgraphMatcherImpl::matchNodes(const Node* n1, Node* n2) { + // Check if we've already visited these nodes. + if (nodes_map_.count(n1)) { + return nodes_map_.at(n1) == n2; + } + + // Param node in pattern graph matches everything. + if (n1->kind() == prim::Param) { + GRAPH_DEBUG("Nodes matched:\n", *n1, *n2); + return true; + } + + // We don't allow matches to span across blocks, so check if N2 is in the same + // block as the first (anchor) node. + if (n2->owningBlock() != anchor_->owningBlock()) { + GRAPH_DEBUG("Nodes did not match because it is in the different block:\n", *n1, *n2); + return false; + } + + // Special handling for matching modules + if (n1->kind() == Symbol::fromQualString("match::module")) { + if (n2->kind() == prim::GetAttr) { + if (!n1->hasAttributeS("name")) { + GRAPH_DEBUG( + "Nodes did not match because special node match::module does not have 'name' " + "attribute:\n", + *n1, *n2); + return false; + } + auto t = n2->output()->type()->expect(); + auto real_typename = t->name()->qualifiedName(); + auto pattern_typename = n1->s(attr::name); + if (!endsWith(real_typename, pattern_typename)) { + GRAPH_DEBUG("Nodes did not match because expected module type is different:\n"); + GRAPH_DEBUG(" actualtype: ", real_typename, "\n"); + GRAPH_DEBUG(" expected type: ", pattern_typename, "\n"); + GRAPH_DEBUG("Nodes:", *n1, *n2); + return false; + } + } + } else { + if (n1->kind() != n2->kind() || n1->outputs().size() != n2->outputs().size() || + n1->inputs().size() != n2->inputs().size()) { + GRAPH_DEBUG("Nodes did not match in their kind or number of inputs/outputs:\n", *n1, *n2); + return false; + } + + if (match_attribute_ != NO_MATCH) { + if (!matchAttributes(n1, n2)) { + return false; + } + } + } + + // Add nodes to the map before calling matchValues to avoid infinite + // recursion. + nodes_map_[n1] = n2; + for (const auto i : c10::irange(n1->outputs().size())) { + if (!matchValues(n1->outputs()[i], n2->outputs()[i])) { + return false; + } + } + for (const auto i : c10::irange(n1->inputs().size())) { + if (!matchValues(n1->inputs()[i], n2->inputs()[i])) { + return false; + } + } + + GRAPH_DEBUG("Nodes matched:\n", *n1, *n2); + return true; +} + +/** + * Recursively try to match pattern with the actual graph starting from the + * exiting node in the pattern and anchor node in the actual graph. + */ +bool SubgraphMatcher::SubgraphMatcherImpl::matchesSubgraphFromAnchorNode(Node* anchor) { + GRAPH_UPDATE("Starting match from a new anchor: ", *anchor); + nodes_map_.clear(); + values_map_.clear(); + anchor_ = anchor; + + const Node* bottom_node = *(pattern_.nodes().end()); + bottom_node = bottom_node->input(0)->node(); + + if (!matchNodes(bottom_node, anchor)) { + return false; + } + + for (const Value* output : pattern_.outputs()) { + AT_ASSERT(values_map_.count(output)); + } + + GRAPH_UPDATE("Pattern matched!\n"); + return true; +} + +SubgraphMatcher::SubgraphMatcher(const Graph& pattern, MatchAttribute match_attribute) + : impl_(new SubgraphMatcher::SubgraphMatcherImpl(pattern, match_attribute)) {} + +bool SubgraphMatcher::matchesSubgraphFromAnchorNode(Node* anchor) { + return impl_->matchesSubgraphFromAnchorNode(anchor); +} + +std::unordered_map SubgraphMatcher::nodes_map() const { + return impl_->nodes_map(); +} + +std::unordered_map SubgraphMatcher::values_map() const { + return impl_->values_map(); +} + +} // namespace torch_jit +} // namespace mmdeploy diff --git a/csrc/backend_ops/torchscript/optimizer/ir/subgraph_matcher.h b/csrc/backend_ops/torchscript/optimizer/ir/subgraph_matcher.h new file mode 100644 index 000000000..6629b598e --- /dev/null +++ b/csrc/backend_ops/torchscript/optimizer/ir/subgraph_matcher.h @@ -0,0 +1,36 @@ +// Copyright (c) OpenMMLab. All rights reserved. +#ifndef _SUBGRAPH_MATCHER_H_ +#define _SUBGRAPH_MATCHER_H_ + +#include + +#include +namespace mmdeploy { +namespace torch_jit { +using torch::jit::Graph; +using torch::jit::Node; +using torch::jit::Value; + +enum MatchAttribute { FORCE_MATCH, TRY_MATCH, NO_MATCH }; + +class SubgraphMatcher { + public: + explicit SubgraphMatcher(const Graph& pattern, MatchAttribute match_attribute = TRY_MATCH); + + bool matchesSubgraphFromAnchorNode(Node* anchor); + + /** \brief Return match map for nodes. */ + std::unordered_map nodes_map() const; + + /** \brief Return match map for values. */ + std::unordered_map values_map() const; + + private: + class SubgraphMatcherImpl; + std::unique_ptr impl_ = nullptr; +}; + +} // namespace torch_jit +} // namespace mmdeploy + +#endif diff --git a/csrc/backend_ops/torchscript/optimizer/passes/onnx/flatten_cls_head.cpp b/csrc/backend_ops/torchscript/optimizer/passes/onnx/flatten_cls_head.cpp new file mode 100644 index 000000000..5c7082f6a --- /dev/null +++ b/csrc/backend_ops/torchscript/optimizer/passes/onnx/flatten_cls_head.cpp @@ -0,0 +1,119 @@ +// Copyright (c) OpenMMLab. All rights reserved. +#include "flatten_cls_head.h" + +#include +#include +#include + +#include + +#include "utils.h" + +namespace mmdeploy { +namespace torch_jit { + +using c10::Symbol; +using torch::jit::IValue; +using torch::jit::Match; +using torch::jit::TensorType; +using torch::jit::TypeKind; +using torch::jit::Value; + +static bool matchClsHead(const Match& match, const std::unordered_map& map) { + // TODO: check if value map in latest pytorch can ease the filter. + + // check cat -1 + { + // check if the shape of second inputs is 1 + auto cat_v1 = match.values_map.at(map.at("cat1")); + if (cat_v1->type()->kind() != TypeKind::TensorType) return false; + auto cat_v1_type = cat_v1->type()->cast(); + auto cat_v1_size = cat_v1_type->sizes().concrete_sizes(); + if (!cat_v1_size.has_value()) return false; + IValue cat_v1_size_value(cat_v1_size.value()); + auto size_list = cat_v1_size_value.toIntList(); + if (size_list.size() != 1 || size_list[0] != 1) return false; + } + + // check unsqueeze + auto cat_v0 = match.values_map.at(map.at("cat0")); + auto unsqueeze_node = cat_v0->node(); + { + if (!is_kind(unsqueeze_node, "onnx::Unsqueeze")) return false; + auto unsqueeze_axes = unsqueeze_node->is(Symbol::attr("axes")); + if (unsqueeze_axes.size() != 1 || unsqueeze_axes[0] != 0) return false; + } + + // check gather + auto gather_node = unsqueeze_node->input()->node(); + auto gather_inputs = gather_node->inputs(); + { + if (!is_kind(gather_node, "onnx::Gather")) return false; + auto gather_axis = gather_node->i(Symbol::attr("axis")); + if (gather_axis != 0) return false; + } + + auto x = match.values_map.at(map.at("x")); + // check shape + auto shape_node = gather_inputs[0]->node(); + { + if (!is_kind(shape_node, "onnx::Shape")) return false; + if (shape_node->input() != x) return false; + } + + // check constant + auto const_node = gather_inputs[1]->node(); + { + if (!is_kind(const_node, "onnx::Constant")) return false; + auto ival = const_node->t(Symbol::attr("value")); + if (ival.dim() != 0) return false; + auto ival_dataptr = ival.data_ptr(); + if (ival_dataptr[0] != 0) return false; + } + + // check if reshape is the output of the graph + auto reshape_pattern = map.at("reshape"); + auto reshape_node = match.values_map.at(reshape_pattern); + auto uses = reshape_node->uses(); + for (auto use : uses) { + auto user = use.user; + if (is_kind(user, "prim::Return")) return false; + } + + return true; +} + +// from: +// x->shape->gather->unsqueeze->concat +// | | +// gap--------------------------reshape +// +// to: +// x->gap->flatten +void FlattenClsHead(std::shared_ptr& graph) { + std::string pattern = R"IR( + graph(%x, %cat0, %cat1): + %gap = onnx::GlobalAveragePool(%x) + %cat = onnx::Concat[axis=0](%cat0, %cat1) + %reshape = onnx::Reshape(%gap, %cat) + return (%reshape) + )IR"; + + std::string replacement = R"IR( + graph(%x, %cat0, %cat1): + %gap = onnx::GlobalAveragePool(%x) + %flatten = onnx::Flatten(%gap) + return (%flatten) + )IR"; + + torch::jit::SubgraphRewriter subgraph_rewriter; + subgraph_rewriter.RegisterRewritePattern(pattern, replacement); + subgraph_rewriter.runOnGraph(graph, matchClsHead); + + torch::jit::EliminateDeadCode( + graph->block(), true, + torch::jit::DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS); +} + +} // namespace torch_jit +} // namespace mmdeploy diff --git a/csrc/backend_ops/torchscript/optimizer/passes/onnx/flatten_cls_head.h b/csrc/backend_ops/torchscript/optimizer/passes/onnx/flatten_cls_head.h new file mode 100644 index 000000000..b66b700d1 --- /dev/null +++ b/csrc/backend_ops/torchscript/optimizer/passes/onnx/flatten_cls_head.h @@ -0,0 +1,14 @@ +// Copyright (c) OpenMMLab. All rights reserved. +#ifndef _FLATTEN_CLS_HEAD_H_ +#define _FLATTEN_CLS_HEAD_H_ + +#include +namespace mmdeploy { +namespace torch_jit { +using torch::jit::Graph; + +void FlattenClsHead(std::shared_ptr& graph); +} // namespace torch_jit +} // namespace mmdeploy + +#endif diff --git a/csrc/backend_ops/torchscript/optimizer/passes/onnx/merge_shape_concate.cpp b/csrc/backend_ops/torchscript/optimizer/passes/onnx/merge_shape_concate.cpp new file mode 100644 index 000000000..47fc8c205 --- /dev/null +++ b/csrc/backend_ops/torchscript/optimizer/passes/onnx/merge_shape_concate.cpp @@ -0,0 +1,113 @@ +// Copyright (c) OpenMMLab. All rights reserved. +#include "merge_shape_concate.h" + +#include + +#include "utils.h" + +namespace mmdeploy { +namespace torch_jit { + +using c10::Symbol; +using torch::jit::Block; +using torch::jit::IValue; +using torch::jit::Node; +using torch::jit::TensorType; +using torch::jit::Value; + +void MergeShapeConcate(Node* node) { + auto inputs = node->inputs(); + + std::vector gather_value; + Value* shape_from = nullptr; + + std::vector node_to_remove{node}; + + // check pattern shape->gather->unsqueeze->concate + for (auto input : inputs) { + auto unsqueeze_node = input->node(); + if (!is_kind(unsqueeze_node, "onnx::Unsqueeze") || unsqueeze_node->output()->uses().size() != 1) + return; + + auto axes = unsqueeze_node->is(Symbol::attr("axes")); + if (axes.size() != 1 && axes[0] != 0) return; + + auto gather_node = unsqueeze_node->input()->node(); + if (!is_kind(gather_node, "onnx::Gather") || gather_node->i(Symbol::attr("axis")) != 0 || + gather_node->output()->uses().size() != 1) + return; + + auto gather_inputs = gather_node->inputs(); + auto gather_data = gather_inputs[0]; + auto gather_indices = gather_inputs[1]; + auto shape_node = gather_data->node(); + if (!is_kind(shape_node, "onnx::Shape") || shape_node->output()->uses().size() != 1) return; + + auto current_shape_from = shape_node->input(); + if (!shape_from) { + shape_from = current_shape_from; + } else { + if (shape_from != current_shape_from) return; + } + + auto constant_node = gather_indices->node(); + if (!is_kind(constant_node, "onnx::Constant")) return; + + auto gather_indices_val = constant_node->t(Symbol::attr("value")); + long* data_ptr = gather_indices_val.data_ptr(); + if (gather_indices_val.dim() == 0) { + gather_value.push_back(data_ptr[0]); + } else { + int element_size = gather_indices_val.element_size(); + for (int j = 0; j < element_size; ++j) { + gather_value.push_back(data_ptr[j]); + } + } + + node_to_remove.insert(node_to_remove.end(), {unsqueeze_node, gather_node, shape_node}); + } + + // create constant value + auto graph = node->owningGraph(); + auto const_node = graph->create(Symbol::onnx("Constant")); + const_node->t_(Symbol::attr("value"), at::tensor(gather_value)); + auto first_node = node->owningGraph()->block()->nodes().front(); + if (const_node != first_node) const_node->insertBefore(first_node); + + // recreate shape node + auto shape_node = graph->create(Symbol::onnx("Shape"), {shape_from}); + shape_node->insertBefore(node); + + // create gather node + auto gather_node = + graph->create(Symbol::onnx("Gather"), {shape_node->output(), const_node->output()}); + + // insert into graph + gather_node->insertAfter(node); + node->output()->replaceAllUsesWith(gather_node->output()); + + for (auto n : node_to_remove) { + n->destroy(); + } +} + +void MergeShapeConcate(Block* block) { + auto graph = block->owningGraph(); + auto it = block->nodes().begin(); + while (it != block->nodes().end()) { + auto node = *it; + ++it; + for (auto block : node->blocks()) { + MergeShapeConcate(block); + } + + if (is_kind(node, "onnx::Concat")) { + MergeShapeConcate(node); + } + } +} + +void MergeShapeConcate(const std::shared_ptr& graph) { MergeShapeConcate(graph->block()); } + +} // namespace torch_jit +} // namespace mmdeploy diff --git a/csrc/backend_ops/torchscript/optimizer/passes/onnx/merge_shape_concate.h b/csrc/backend_ops/torchscript/optimizer/passes/onnx/merge_shape_concate.h new file mode 100644 index 000000000..8656da63c --- /dev/null +++ b/csrc/backend_ops/torchscript/optimizer/passes/onnx/merge_shape_concate.h @@ -0,0 +1,14 @@ +// Copyright (c) OpenMMLab. All rights reserved. +#ifndef _MERGE_SHAPE_CONCATE_H_ +#define _MERGE_SHAPE_CONCATE_H_ + +#include +namespace mmdeploy { +namespace torch_jit { +using torch::jit::Graph; + +void MergeShapeConcate(const std::shared_ptr& graph); +} // namespace torch_jit +} // namespace mmdeploy + +#endif diff --git a/csrc/backend_ops/torchscript/optimizer/passes/onnx/onnx_peephole.cpp b/csrc/backend_ops/torchscript/optimizer/passes/onnx/onnx_peephole.cpp new file mode 100644 index 000000000..0ba9b9cd1 --- /dev/null +++ b/csrc/backend_ops/torchscript/optimizer/passes/onnx/onnx_peephole.cpp @@ -0,0 +1,81 @@ +// Copyright (c) OpenMMLab. All rights reserved. +#include "onnx_peephole.h" + +#include + +#include + +#include "utils.h" + +namespace mmdeploy { +namespace torch_jit { + +using c10::Symbol; +using torch::jit::Block; +using torch::jit::IValue; +using torch::jit::Node; +using torch::jit::TensorType; +using torch::jit::Value; + +void RemoveReshapeChain(Node* node) { + // reshape->reshape => reshape + auto output = node->output(); + if (!(output->hasUses())) { + return; + } + auto uses = output->uses(); + + for (auto use : uses) { + if (is_kind(use.user, "onnx::Reshape") || use.offset != 0) { + return; + } + } + + auto input = node->inputs()[0]; + output->replaceAllUsesWith(input); + + node->destroy(); +} + +void RemoveRedundantCast(Node* node) { + // Cast(type n)->Cast(type n) => Cast(type n) + + auto to_type = node->i(Symbol::attr("to")); + auto input = node->input(); + + auto input_node = input->node(); + if (is_kind(input_node, "onnx::Cast") && input_node->i(Symbol::attr("to")) == to_type) { + auto output = node->output(); + + output->replaceAllUsesWith(input); + node->destroy(); + } +} + +void ONNXPeephole(Block* block) { + auto graph = block->owningGraph(); + auto it = block->nodes().begin(); + while (it != block->nodes().end()) { + auto node = *it; + ++it; + for (auto block : node->blocks()) { + ONNXPeephole(block); + } + + if (is_kind(node, "onnx::Reshape")) { + RemoveReshapeChain(node); + } else if (is_kind(node, "onnx::Cast")) { + RemoveRedundantCast(node); + } + } +} + +void ONNXPeephole(const std::shared_ptr& graph) { + ONNXPeephole(graph->block()); + torch::jit::EliminateDeadCode( + graph->block(), true, + torch::jit::DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS); +} + +} // namespace torch_jit +} // namespace mmdeploy diff --git a/csrc/backend_ops/torchscript/optimizer/passes/onnx/onnx_peephole.h b/csrc/backend_ops/torchscript/optimizer/passes/onnx/onnx_peephole.h new file mode 100644 index 000000000..f388da1bf --- /dev/null +++ b/csrc/backend_ops/torchscript/optimizer/passes/onnx/onnx_peephole.h @@ -0,0 +1,15 @@ +// Copyright (c) OpenMMLab. All rights reserved. +#ifndef _ONNX_PEEPHOLE_H_ +#define _ONNX_PEEPHOLE_H_ + +#include +namespace mmdeploy { +namespace torch_jit { +using torch::jit::Graph; + +void ONNXPeephole(const std::shared_ptr& graph); + +} // namespace torch_jit +} // namespace mmdeploy + +#endif diff --git a/csrc/backend_ops/torchscript/optimizer/passes/onnx/utils.h b/csrc/backend_ops/torchscript/optimizer/passes/onnx/utils.h new file mode 100644 index 000000000..1c92cd15a --- /dev/null +++ b/csrc/backend_ops/torchscript/optimizer/passes/onnx/utils.h @@ -0,0 +1,20 @@ +#ifndef _PASSES_ONNX_UTILS_H_ +#define _PASSES_ONNX_UTILS_H_ + +#include + +namespace mmdeploy { +namespace torch_jit { +using c10::Symbol; +using torch::jit::Node; + +inline bool is_kind(const Node* node, const Symbol& symbol) { return node->kind() == symbol; } + +inline bool is_kind(const Node* node, const char* symbol_name) { + return is_kind(node, Symbol::fromQualString(symbol_name)); +} + +} // namespace torch_jit +} // namespace mmdeploy + +#endif diff --git a/docs/en/experimental/onnx_optimizer.md b/docs/en/experimental/onnx_optimizer.md new file mode 100644 index 000000000..a40939d18 --- /dev/null +++ b/docs/en/experimental/onnx_optimizer.md @@ -0,0 +1,50 @@ +# ONNX export Optimizer + +This is a tool to optimize ONNX model when exporting from PyTorch. + +## Installation + +Build MMDeploy with `torchscript` support: + +```shell +export Torch_DIR=$(python -c "import torch;print(torch.utils.cmake_prefix_path + '/Torch')") + +cmake \ + -DTorch_DIR=${Torch_DIR} \ + -DMMDEPLOY_TARGET_BACKENDS="${your_backend};torchscript" \ + .. # You can also add other build flags if you need + +cmake --build . -- -j$(nproc) && cmake --install . +``` + +## Usage + +```python +# import model_to_graph_custom_optimizer so we can hijack onnx.export +from mmdeploy.apis.onnx.optimizer import model_to_graph__custom_optimizer # noqa +from mmdeploy.core import RewriterContext +from mmdeploy.apis.onnx.passes import optimize_onnx + +# load you model here +model = create_model() + +# export with ONNX Optimizer +x = create_dummy_input() +with RewriterContext({}, onnx_custom_passes=optimize_onnx): + torch.onnx.export(model, x, output_path) +``` + +The model would be optimized after export. + +You can also define your own optimizer: + +```python +# create the optimize callback +def _optimize_onnx(graph, params_dict, torch_out): + from mmdeploy.backend.torchscript import ts_optimizer + ts_optimizer.onnx._jit_pass_onnx_peephole(graph) + return graph, params_dict, torch_out + +with RewriterContext({}, onnx_custom_passes=_optimize_onnx): + # export your model +``` diff --git a/mmdeploy/apis/onnx/export.py b/mmdeploy/apis/onnx/export.py index ddad7f30c..db3ede9ce 100644 --- a/mmdeploy/apis/onnx/export.py +++ b/mmdeploy/apis/onnx/export.py @@ -8,6 +8,8 @@ import torch from mmdeploy.apis.core import PIPELINE_MANAGER from mmdeploy.core import RewriterContext, patch_model from mmdeploy.utils import Backend, get_root_logger +from .optimizer import * # noqa +from .passes import optimize_onnx @PIPELINE_MANAGER.register_pipeline() @@ -23,6 +25,7 @@ def export(model: torch.nn.Module, dynamic_axes: Optional[Dict] = None, verbose: bool = False, keep_initializers_as_inputs: Optional[bool] = None, + optimize: bool = True, **kwargs): """Export a PyTorch model into ONNX format. This is a wrap of `torch.onnx.export` with some enhancement. @@ -64,6 +67,7 @@ def export(model: torch.nn.Module, verbose (bool): Enable verbose model on `torch.onnx.export`. keep_initializers_as_inputs (bool): Whether we should add inputs for each initializer. + optimize (bool): Perform optimize on model. """ output_path = output_path_prefix + '.onnx' @@ -102,6 +106,9 @@ def export(model: torch.nn.Module, # patch model patched_model = patch_model(model, cfg=deploy_cfg, backend=backend) + if 'onnx_custom_passes' not in context_info: + onnx_custom_passes = optimize_onnx if optimize else None + context_info['onnx_custom_passes'] = onnx_custom_passes with RewriterContext(**context_info), torch.no_grad(): # patch input_metas if input_metas is not None: diff --git a/mmdeploy/apis/onnx/optimizer.py b/mmdeploy/apis/onnx/optimizer.py new file mode 100644 index 000000000..612e9d8ea --- /dev/null +++ b/mmdeploy/apis/onnx/optimizer.py @@ -0,0 +1,21 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Callable + +from mmdeploy.core import FUNCTION_REWRITER + + +@FUNCTION_REWRITER.register_rewriter('torch.onnx.utils._model_to_graph') +def model_to_graph__custom_optimizer(ctx, *args, **kwargs): + """Rewriter of _model_to_graph, add custom passes.""" + graph, params_dict, torch_out = ctx.origin_func(*args, **kwargs) + + custom_passes = getattr(ctx, 'onnx_custom_passes', None) + + if custom_passes is not None: + assert isinstance( + custom_passes, Callable + ), f'Expect a callable onnx_custom_passes, get {type(custom_passes)}.' + graph, params_dict, torch_out = custom_passes(graph, params_dict, + torch_out) + + return graph, params_dict, torch_out diff --git a/mmdeploy/apis/onnx/passes/__init__.py b/mmdeploy/apis/onnx/passes/__init__.py new file mode 100644 index 000000000..130a31823 --- /dev/null +++ b/mmdeploy/apis/onnx/passes/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .optimize_onnx import optimize_onnx + +__all__ = ['optimize_onnx'] diff --git a/mmdeploy/apis/onnx/passes/optimize_onnx.py b/mmdeploy/apis/onnx/passes/optimize_onnx.py new file mode 100644 index 000000000..d413a513e --- /dev/null +++ b/mmdeploy/apis/onnx/passes/optimize_onnx.py @@ -0,0 +1,16 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmdeploy.utils import get_root_logger + + +def optimize_onnx(graph, params_dict, torch_out): + logger = get_root_logger() + logger.info('Execute onnx optimize passes.') + try: + from mmdeploy.backend.torchscript import ts_optimizer + ts_optimizer.onnx._jit_pass_merge_shape_concate(graph) + ts_optimizer.onnx._jit_pass_onnx_peephole(graph) + ts_optimizer.onnx._jit_pass_flatten_cls_head(graph) + except Exception: + pass + + return graph, params_dict, torch_out diff --git a/tests/test_apis/test_onnx_passes.py b/tests/test_apis/test_onnx_passes.py new file mode 100644 index 000000000..c7dc891c5 --- /dev/null +++ b/tests/test_apis/test_onnx_passes.py @@ -0,0 +1,190 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import tempfile +from typing import Any, List, Tuple + +import onnx +import pytest +import torch +import torch.nn as nn + +from mmdeploy.apis.onnx.optimizer import \ + model_to_graph__custom_optimizer # noqa +from mmdeploy.core import RewriterContext + +onnx_file = tempfile.NamedTemporaryFile(suffix='.onnx').name + + +def _find_next_node(start: int, nodes: List, op_type: str) -> Tuple[Any, int]: + for idx, n in enumerate(nodes[start:]): + if n.op_type == op_type: + return n, idx + return None, -1 + + +def test_merge_shape_concate(): + pytest.importorskip('mmdeploy.backend.torchscript.ts_optimizer.onnx') + + try: + from mmdeploy.backend.torchscript import ts_optimizer + opt_pass = ts_optimizer.onnx._jit_pass_merge_shape_concate + except ImportError: + pytest.skip('pass not found.') + + def _optimize_onnx(graph, params_dict, torch_out): + opt_pass(graph) + return graph, params_dict, torch_out + + class TestModel(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, x): + return x.new_zeros(x.shape[-2:]) + + model = TestModel() + x = torch.rand(1, 3, 4, 8) + + with RewriterContext({}, onnx_custom_passes=_optimize_onnx): + torch.onnx.export( + model, + x, + onnx_file, + input_names=['input'], + output_names=['output'], + dynamic_axes=dict(input={ + 2: 'h', + 3: 'w' + }), + opset_version=11) + + onnx_model = onnx.load(onnx_file) + graph = onnx_model.graph + nodes = graph.node + shape_idx = 0 + for n in nodes: + if n.op_type != 'Shape': + shape_idx += 1 + else: + break + + assert shape_idx < len(nodes) + assert nodes[shape_idx + 1].op_type == 'Gather' + assert nodes[shape_idx + 2].op_type == 'ConstantOfShape' + + +def test_peephole(): + pytest.importorskip('mmdeploy.backend.torchscript.ts_optimizer.onnx') + + try: + from mmdeploy.backend.torchscript import ts_optimizer + opt_pass = ts_optimizer.onnx._jit_pass_onnx_peephole + except ImportError: + pytest.skip('pass not found.') + + def _optimize_onnx(graph, params_dict, torch_out): + opt_pass(graph) + return graph, params_dict, torch_out + + class TestModel(torch.nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, x): + + x = x.int() + x = x.int() + x = x.float() + + x = x.view(10, -1) + y = x.view(2, -1) + z = x.view(3, -1) + + return y, z + + model = TestModel() + x = torch.rand(2, 3, 5) + + with RewriterContext({}, onnx_custom_passes=_optimize_onnx): + torch.onnx.export( + model, + x, + onnx_file, + input_names=['input'], + output_names=['output1', 'output2'], + dynamic_axes=dict(input={ + 0: 'b', + 1: 'c', + 2: 'w' + }), + opset_version=11) + + onnx_model = onnx.load(onnx_file) + graph = onnx_model.graph + nodes = graph.node + + node, idx = _find_next_node(0, nodes, 'Cast') + assert node is not None + assert node.attribute[0].i == 6 + + node, idx = _find_next_node(idx + 1, nodes, 'Cast') + assert node is not None + assert node.attribute[0].i == 1 + + node, idx = _find_next_node(idx + 1, nodes, 'Reshape') + assert node is not None + + node, idx = _find_next_node(idx + 1, nodes, 'Reshape') + assert node is not None + + +def test_flatten_cls_head(): + pytest.importorskip('mmdeploy.backend.torchscript.ts_optimizer.onnx') + + try: + from mmdeploy.backend.torchscript import ts_optimizer + opt_pass = ts_optimizer.onnx._jit_pass_flatten_cls_head + except ImportError: + pytest.skip('pass not found.') + + def _optimize_onnx(graph, params_dict, torch_out): + opt_pass(graph) + return graph, params_dict, torch_out + + class TestModel(torch.nn.Module): + + def __init__(self) -> None: + super().__init__() + + def forward(self, x): + batch = x.size(0) + gap = nn.functional.adaptive_avg_pool2d(x, (1, 1)) + gap = gap.reshape(batch, -1) + return gap + 0 # gap should not be the output + + model = TestModel() + x = torch.rand(1, 4, 8, 8) + + with RewriterContext({}, onnx_custom_passes=_optimize_onnx): + torch.onnx.export( + model, + x, + onnx_file, + input_names=['input'], + output_names=['output'], + dynamic_axes=dict(input={ + 2: 'h', + 3: 'w' + }), + opset_version=11) + + onnx_model = onnx.load(onnx_file) + graph = onnx_model.graph + nodes = graph.node + + node, idx = _find_next_node(0, nodes, 'GlobalAveragePool') + assert node is not None + + node, idx = _find_next_node(idx + 1, nodes, 'Flatten') + assert node is not None