[Enhancement] Add ONNX passes support (#390)
* add merge_shape_concate * add some peephole optimization * bug fixing * fix for torch1.9 * add flatten cls head * add subgraph matcher with attribute * add ut * fix lint * remove onnx2ncnn * add opset version * axis name * fix peephole * fix symbol compare * add docspull/557/head
parent
456076c06b
commit
74243dc98b
|
@ -4,13 +4,19 @@
|
|||
#include <string>
|
||||
|
||||
#include "optimizer.h"
|
||||
#include "passes/onnx/flatten_cls_head.h"
|
||||
#include "passes/onnx/merge_shape_concate.h"
|
||||
#include "passes/onnx/onnx_peephole.h"
|
||||
|
||||
namespace mmdeploy {
|
||||
namespace torch_jit {
|
||||
|
||||
// Dispatch model optimization by target IR.
//
// \param model    the torchscript module to optimize, updated in place.
// \param ir       target IR, "torchscript" or "onnx".
// \param backend  target backend name; currently only used in diagnostics.
//
// NOTE(review): the diff rendering showed each assignment twice (once with
// and once without the `mmdeploy::` qualifier). A single call per branch is
// kept — inside namespace mmdeploy both spellings resolve identically.
void optimize_for_backend(torch::jit::Module& model, const std::string& ir = "torchscript",
                          const std::string& backend = "torchscript") {
  if (ir == "torchscript") {
    model = optimize_for_torchscript(model);
  } else if (ir == "onnx") {
    model = optimize_for_onnx(model);
  } else {
    fprintf(stderr, "No optimize for combination ir: %s backend: %s\n", ir.c_str(),
            backend.c_str());
  }
}
|
@ -23,4 +29,11 @@ PYBIND11_MODULE(ts_optimizer, m) {
|
|||
m.def("optimize_for_backend", optimize_for_backend, py::arg("module"),
|
||||
py::arg("ir") = std::string("torchscript"),
|
||||
py::arg("backend") = std::string("torchscript"));
|
||||
py::module_ onnx_module = m.def_submodule("onnx");
|
||||
onnx_module.def("_jit_pass_merge_shape_concate", MergeShapeConcate, py::arg("graph"));
|
||||
onnx_module.def("_jit_pass_onnx_peephole", ONNXPeephole, py::arg("graph"));
|
||||
onnx_module.def("_jit_pass_flatten_cls_head", FlattenClsHead, py::arg("graph"));
|
||||
}
|
||||
|
||||
} // namespace torch_jit
|
||||
} // namespace mmdeploy
|
||||
|
|
|
@ -0,0 +1,311 @@
|
|||
// modify from:
|
||||
// https://github.com/pytorch/pytorch/blob/v1.8.1/torch/csrc/jit/ir/subgraph_matcher.cpp
|
||||
#include "subgraph_matcher.h"
|
||||
|
||||
#include <c10/util/irange.h>
|
||||
#include <torch/csrc/jit/ir/attributes.h>
|
||||
#include <torch/csrc/jit/jit_log.h>
|
||||
|
||||
#include <regex>
|
||||
#include <stack>
|
||||
namespace mmdeploy {
|
||||
namespace torch_jit {
|
||||
|
||||
using torch::jit::AttributeKind;
|
||||
using torch::jit::ClassType;
|
||||
using torch::jit::Node;
|
||||
using torch::jit::Symbol;
|
||||
using torch::jit::Value;
|
||||
|
||||
namespace prim {
|
||||
using namespace ::c10::prim;
|
||||
}
|
||||
|
||||
namespace attr {
|
||||
using namespace ::c10::attr;
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief A class implementing an API for comparing subgraphs.
|
||||
*/
|
||||
class SubgraphMatcher::SubgraphMatcherImpl {
|
||||
public:
|
||||
explicit SubgraphMatcherImpl(const Graph& pattern, MatchAttribute match_attribute)
|
||||
: pattern_(pattern), match_attribute_(match_attribute) {}
|
||||
|
||||
/**
|
||||
* \brief Compare matchGraph with the part of the graph denoted by a node \p
|
||||
* ANCHOR.
|
||||
*
|
||||
* The anchor node would be compared against the deepest node in the
|
||||
* match-graph. A node is considered matching if its number of inputs/outputs
|
||||
* is the same as in the corresponding matchGraph node, its type is the same,
|
||||
* and all nodes producing input-values also match.
|
||||
*/
|
||||
bool matchesSubgraphFromAnchorNode(Node* anchor);
|
||||
|
||||
/** \brief Return match map for nodes. */
|
||||
std::unordered_map<const Node*, Node*> nodes_map() const { return nodes_map_; }
|
||||
|
||||
/** \brief Return match map for values. */
|
||||
std::unordered_map<const Value*, Value*> values_map() const { return values_map_; }
|
||||
|
||||
private:
|
||||
bool matchValues(const Value* v1, Value* v2);
|
||||
bool matchNodes(const Node* n1, Node* n2);
|
||||
bool matchAttributes(const Node* n1, Node* n2);
|
||||
|
||||
static bool isInput(const Value* v);
|
||||
static bool isOutput(const Value* v);
|
||||
|
||||
std::unordered_map<const Node*, Node*> nodes_map_;
|
||||
std::unordered_map<const Value*, Value*> values_map_;
|
||||
|
||||
const MatchAttribute match_attribute_;
|
||||
const Graph& pattern_;
|
||||
const Node* anchor_ = nullptr;
|
||||
};
|
||||
|
||||
bool SubgraphMatcher::SubgraphMatcherImpl::isInput(const Value* v) {
|
||||
return v->node()->kind() == prim::Param;
|
||||
}
|
||||
|
||||
bool SubgraphMatcher::SubgraphMatcherImpl::isOutput(const Value* v) {
|
||||
for (const Value* output : v->owningGraph()->outputs()) {
|
||||
if (v == output) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Compare two Values. V1 is from pattern, V2 is from the actual graph.
|
||||
*
|
||||
* The values are considered matching if:
|
||||
* 1) the nodes defining them match
|
||||
* 2) they have the same number of uses, except they are entry or exit nodes.
|
||||
*/
|
||||
bool SubgraphMatcher::SubgraphMatcherImpl::matchValues(const Value* v1, Value* v2) {
|
||||
// Check if we've already visited these values.
|
||||
if (values_map_.count(v1)) {
|
||||
if (values_map_.at(v1) != v2) {
|
||||
GRAPH_DEBUG("Values %", v1->debugName(), " and %", v2->debugName(),
|
||||
" did not match because %", v1->debugName(), " has already been matched with %",
|
||||
values_map_.at(v1)->debugName(), ".\n");
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// When V2 is ANCHOR, we're comparing exiting values, and when V1->node is
|
||||
// PARAM, we're comparing entering values - in these two cases the number of
|
||||
// uses don't need to be the same.
|
||||
if (v1->uses().size() != v2->uses().size() && !isOutput(v1) && !isInput(v1)) {
|
||||
GRAPH_DEBUG("Values %", v1->debugName(), " and %", v2->debugName(),
|
||||
" did not match because number of their uses is different.\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Add the values to the map before calling matchNodes to avoid infinite
|
||||
// recursion.
|
||||
GRAPH_DEBUG("Values %", v1->debugName(), " and %", v2->debugName(), " matched.\n");
|
||||
values_map_[v1] = v2;
|
||||
return matchNodes(v1->node(), v2->node());
|
||||
}
|
||||
|
||||
bool SubgraphMatcher::SubgraphMatcherImpl::matchAttributes(const Node* n1, Node* n2) {
|
||||
if (match_attribute_ == FORCE_MATCH && n1->numAttributes() != n2->numAttributes()) {
|
||||
GRAPH_DEBUG("Nodes did not match in number attributes:\n", *n1, *n2);
|
||||
return false;
|
||||
}
|
||||
for (const Symbol& attr_name : n1->attributeNames()) {
|
||||
if (n1->kindOf(attr_name) != n2->kindOf(attr_name)) {
|
||||
GRAPH_DEBUG("Nodes did not match because type of attribute '", attr_name.toQualString(),
|
||||
"' did not match:\n", *n1, *n2);
|
||||
return false;
|
||||
}
|
||||
std::vector<long int> n1is, n2is;
|
||||
std::vector<double> n1fs, n2fs;
|
||||
switch (n1->kindOf(attr_name)) {
|
||||
case AttributeKind::s:
|
||||
if (!std::regex_match(n2->s(attr_name), std::regex(n1->s(attr_name)))) {
|
||||
GRAPH_DEBUG("Nodes did not match because attribute '", attr_name.toQualString(),
|
||||
"' did not match: ", n1->s(attr_name), " != ", n2->s(attr_name), " \n", *n1,
|
||||
*n2);
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
case AttributeKind::f:
|
||||
if (n1->f(attr_name) != n2->f(attr_name)) {
|
||||
GRAPH_DEBUG("Nodes did not match because attribute '", attr_name.toQualString(),
|
||||
"' did not match:", n1->f(attr_name), " != ", n2->f(attr_name), " \n", *n1,
|
||||
*n2);
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
case AttributeKind::i:
|
||||
if (n1->i(attr_name) != n2->i(attr_name)) {
|
||||
GRAPH_DEBUG("Nodes did not match because attribute '", attr_name.toQualString(),
|
||||
"' did not match:", n1->i(attr_name), " != ", n2->i(attr_name), " \n", *n1,
|
||||
*n2);
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
case AttributeKind::is:
|
||||
n1is = n1->is(attr_name);
|
||||
n2is = n2->is(attr_name);
|
||||
if (n1is.size() != n2is.size()) return false;
|
||||
for (int i = 0; i < n1is.size(); ++i) {
|
||||
if (n1is[i] != n2is[i]) return false;
|
||||
}
|
||||
break;
|
||||
case AttributeKind::fs:
|
||||
n1fs = n1->fs(attr_name);
|
||||
n2fs = n2->fs(attr_name);
|
||||
if (n1fs.size() != n2fs.size()) return false;
|
||||
for (int i = 0; i < n1fs.size(); ++i) {
|
||||
if (n1fs[i] != n2fs[i]) return false;
|
||||
}
|
||||
break;
|
||||
default: {
|
||||
// Other attributes types not supported yet
|
||||
GRAPH_DEBUG("Nodes did not match because type of attribute '", attr_name.toQualString(),
|
||||
"' is not supported.\n", *n1, *n2);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// True iff `str` ends with `suffix` (an empty suffix always matches).
static bool endsWith(const std::string& str, const std::string& suffix) {
  if (suffix.size() > str.size()) {
    return false;
  }
  return std::equal(suffix.rbegin(), suffix.rend(), str.rbegin());
}
|
||||
|
||||
/**
|
||||
* Compare two Nodes. N1 is from pattern, N2 is from the actual graph.
|
||||
*
|
||||
* The nodes are considered matching if:
|
||||
* 1) N1 and N2 are of the same kind.
|
||||
* 2) Number of inputs and outputs is the same.
|
||||
* 3) All input and output values match.
|
||||
*
|
||||
* A special case is when N1 is PARAM - this is considered outside the pattern,
|
||||
* so it matches everything.
|
||||
*/
|
||||
bool SubgraphMatcher::SubgraphMatcherImpl::matchNodes(const Node* n1, Node* n2) {
|
||||
// Check if we've already visited these nodes.
|
||||
if (nodes_map_.count(n1)) {
|
||||
return nodes_map_.at(n1) == n2;
|
||||
}
|
||||
|
||||
// Param node in pattern graph matches everything.
|
||||
if (n1->kind() == prim::Param) {
|
||||
GRAPH_DEBUG("Nodes matched:\n", *n1, *n2);
|
||||
return true;
|
||||
}
|
||||
|
||||
// We don't allow matches to span across blocks, so check if N2 is in the same
|
||||
// block as the first (anchor) node.
|
||||
if (n2->owningBlock() != anchor_->owningBlock()) {
|
||||
GRAPH_DEBUG("Nodes did not match because it is in the different block:\n", *n1, *n2);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Special handling for matching modules
|
||||
if (n1->kind() == Symbol::fromQualString("match::module")) {
|
||||
if (n2->kind() == prim::GetAttr) {
|
||||
if (!n1->hasAttributeS("name")) {
|
||||
GRAPH_DEBUG(
|
||||
"Nodes did not match because special node match::module does not have 'name' "
|
||||
"attribute:\n",
|
||||
*n1, *n2);
|
||||
return false;
|
||||
}
|
||||
auto t = n2->output()->type()->expect<ClassType>();
|
||||
auto real_typename = t->name()->qualifiedName();
|
||||
auto pattern_typename = n1->s(attr::name);
|
||||
if (!endsWith(real_typename, pattern_typename)) {
|
||||
GRAPH_DEBUG("Nodes did not match because expected module type is different:\n");
|
||||
GRAPH_DEBUG(" actualtype: ", real_typename, "\n");
|
||||
GRAPH_DEBUG(" expected type: ", pattern_typename, "\n");
|
||||
GRAPH_DEBUG("Nodes:", *n1, *n2);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (n1->kind() != n2->kind() || n1->outputs().size() != n2->outputs().size() ||
|
||||
n1->inputs().size() != n2->inputs().size()) {
|
||||
GRAPH_DEBUG("Nodes did not match in their kind or number of inputs/outputs:\n", *n1, *n2);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (match_attribute_ != NO_MATCH) {
|
||||
if (!matchAttributes(n1, n2)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add nodes to the map before calling matchValues to avoid infinite
|
||||
// recursion.
|
||||
nodes_map_[n1] = n2;
|
||||
for (const auto i : c10::irange(n1->outputs().size())) {
|
||||
if (!matchValues(n1->outputs()[i], n2->outputs()[i])) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
for (const auto i : c10::irange(n1->inputs().size())) {
|
||||
if (!matchValues(n1->inputs()[i], n2->inputs()[i])) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
GRAPH_DEBUG("Nodes matched:\n", *n1, *n2);
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Recursively try to match pattern with the actual graph starting from the
|
||||
* exiting node in the pattern and anchor node in the actual graph.
|
||||
*/
|
||||
bool SubgraphMatcher::SubgraphMatcherImpl::matchesSubgraphFromAnchorNode(Node* anchor) {
|
||||
GRAPH_UPDATE("Starting match from a new anchor: ", *anchor);
|
||||
nodes_map_.clear();
|
||||
values_map_.clear();
|
||||
anchor_ = anchor;
|
||||
|
||||
const Node* bottom_node = *(pattern_.nodes().end());
|
||||
bottom_node = bottom_node->input(0)->node();
|
||||
|
||||
if (!matchNodes(bottom_node, anchor)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (const Value* output : pattern_.outputs()) {
|
||||
AT_ASSERT(values_map_.count(output));
|
||||
}
|
||||
|
||||
GRAPH_UPDATE("Pattern matched!\n");
|
||||
return true;
|
||||
}
|
||||
|
||||
SubgraphMatcher::SubgraphMatcher(const Graph& pattern, MatchAttribute match_attribute)
|
||||
: impl_(new SubgraphMatcher::SubgraphMatcherImpl(pattern, match_attribute)) {}
|
||||
|
||||
bool SubgraphMatcher::matchesSubgraphFromAnchorNode(Node* anchor) {
|
||||
return impl_->matchesSubgraphFromAnchorNode(anchor);
|
||||
}
|
||||
|
||||
std::unordered_map<const Node*, Node*> SubgraphMatcher::nodes_map() const {
|
||||
return impl_->nodes_map();
|
||||
}
|
||||
|
||||
std::unordered_map<const Value*, Value*> SubgraphMatcher::values_map() const {
|
||||
return impl_->values_map();
|
||||
}
|
||||
|
||||
} // namespace torch_jit
|
||||
} // namespace mmdeploy
|
|
@ -0,0 +1,36 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
// Guard renamed: `_SUBGRAPH_MATCHER_H_` (leading underscore followed by an
// uppercase letter) is reserved for the implementation per [lex.name].
#ifndef MMDEPLOY_SUBGRAPH_MATCHER_H_
#define MMDEPLOY_SUBGRAPH_MATCHER_H_

#include <torch/script.h>

#include <memory>
#include <unordered_map>

namespace mmdeploy {
namespace torch_jit {
using torch::jit::Graph;
using torch::jit::Node;
using torch::jit::Value;

// How node attributes participate in matching: required to match, matched
// when present, or ignored entirely.
enum MatchAttribute { FORCE_MATCH, TRY_MATCH, NO_MATCH };

/**
 * \brief Match a pattern graph against subgraphs of a target graph.
 */
class SubgraphMatcher {
 public:
  explicit SubgraphMatcher(const Graph& pattern, MatchAttribute match_attribute = TRY_MATCH);

  // Try to match the pattern against the subgraph ending at `anchor`.
  bool matchesSubgraphFromAnchorNode(Node* anchor);

  /** \brief Return match map for nodes. */
  std::unordered_map<const Node*, Node*> nodes_map() const;

  /** \brief Return match map for values. */
  std::unordered_map<const Value*, Value*> values_map() const;

 private:
  class SubgraphMatcherImpl;
  // NOTE(review): SubgraphMatcherImpl is incomplete here; destroying a
  // SubgraphMatcher requires the complete type — confirm the destructor is
  // only instantiated in the implementation file.
  std::unique_ptr<SubgraphMatcherImpl> impl_ = nullptr;
};

}  // namespace torch_jit
}  // namespace mmdeploy

#endif  // MMDEPLOY_SUBGRAPH_MATCHER_H_
|
|
@ -0,0 +1,119 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
#include "flatten_cls_head.h"
|
||||
|
||||
#include <torch/csrc/jit/ir/subgraph_matcher.h>
|
||||
#include <torch/csrc/jit/passes/dead_code_elimination.h>
|
||||
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
namespace mmdeploy {
|
||||
namespace torch_jit {
|
||||
|
||||
using c10::Symbol;
|
||||
using torch::jit::IValue;
|
||||
using torch::jit::Match;
|
||||
using torch::jit::TensorType;
|
||||
using torch::jit::TypeKind;
|
||||
using torch::jit::Value;
|
||||
|
||||
static bool matchClsHead(const Match& match, const std::unordered_map<std::string, Value*>& map) {
|
||||
// TODO: check if value map in latest pytorch can ease the filter.
|
||||
|
||||
// check cat -1
|
||||
{
|
||||
// check if the shape of second inputs is 1
|
||||
auto cat_v1 = match.values_map.at(map.at("cat1"));
|
||||
if (cat_v1->type()->kind() != TypeKind::TensorType) return false;
|
||||
auto cat_v1_type = cat_v1->type()->cast<TensorType>();
|
||||
auto cat_v1_size = cat_v1_type->sizes().concrete_sizes();
|
||||
if (!cat_v1_size.has_value()) return false;
|
||||
IValue cat_v1_size_value(cat_v1_size.value());
|
||||
auto size_list = cat_v1_size_value.toIntList();
|
||||
if (size_list.size() != 1 || size_list[0] != 1) return false;
|
||||
}
|
||||
|
||||
// check unsqueeze
|
||||
auto cat_v0 = match.values_map.at(map.at("cat0"));
|
||||
auto unsqueeze_node = cat_v0->node();
|
||||
{
|
||||
if (!is_kind(unsqueeze_node, "onnx::Unsqueeze")) return false;
|
||||
auto unsqueeze_axes = unsqueeze_node->is(Symbol::attr("axes"));
|
||||
if (unsqueeze_axes.size() != 1 || unsqueeze_axes[0] != 0) return false;
|
||||
}
|
||||
|
||||
// check gather
|
||||
auto gather_node = unsqueeze_node->input()->node();
|
||||
auto gather_inputs = gather_node->inputs();
|
||||
{
|
||||
if (!is_kind(gather_node, "onnx::Gather")) return false;
|
||||
auto gather_axis = gather_node->i(Symbol::attr("axis"));
|
||||
if (gather_axis != 0) return false;
|
||||
}
|
||||
|
||||
auto x = match.values_map.at(map.at("x"));
|
||||
// check shape
|
||||
auto shape_node = gather_inputs[0]->node();
|
||||
{
|
||||
if (!is_kind(shape_node, "onnx::Shape")) return false;
|
||||
if (shape_node->input() != x) return false;
|
||||
}
|
||||
|
||||
// check constant
|
||||
auto const_node = gather_inputs[1]->node();
|
||||
{
|
||||
if (!is_kind(const_node, "onnx::Constant")) return false;
|
||||
auto ival = const_node->t(Symbol::attr("value"));
|
||||
if (ival.dim() != 0) return false;
|
||||
auto ival_dataptr = ival.data_ptr<long>();
|
||||
if (ival_dataptr[0] != 0) return false;
|
||||
}
|
||||
|
||||
// check if reshape is the output of the graph
|
||||
auto reshape_pattern = map.at("reshape");
|
||||
auto reshape_node = match.values_map.at(reshape_pattern);
|
||||
auto uses = reshape_node->uses();
|
||||
for (auto use : uses) {
|
||||
auto user = use.user;
|
||||
if (is_kind(user, "prim::Return")) return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// from:
|
||||
// x->shape->gather->unsqueeze->concat
|
||||
// | |
|
||||
// gap--------------------------reshape
|
||||
//
|
||||
// to:
|
||||
// x->gap->flatten
|
||||
void FlattenClsHead(std::shared_ptr<Graph>& graph) {
|
||||
std::string pattern = R"IR(
|
||||
graph(%x, %cat0, %cat1):
|
||||
%gap = onnx::GlobalAveragePool(%x)
|
||||
%cat = onnx::Concat[axis=0](%cat0, %cat1)
|
||||
%reshape = onnx::Reshape(%gap, %cat)
|
||||
return (%reshape)
|
||||
)IR";
|
||||
|
||||
std::string replacement = R"IR(
|
||||
graph(%x, %cat0, %cat1):
|
||||
%gap = onnx::GlobalAveragePool(%x)
|
||||
%flatten = onnx::Flatten(%gap)
|
||||
return (%flatten)
|
||||
)IR";
|
||||
|
||||
torch::jit::SubgraphRewriter subgraph_rewriter;
|
||||
subgraph_rewriter.RegisterRewritePattern(pattern, replacement);
|
||||
subgraph_rewriter.runOnGraph(graph, matchClsHead);
|
||||
|
||||
torch::jit::EliminateDeadCode(
|
||||
graph->block(), true,
|
||||
torch::jit::DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS);
|
||||
}
|
||||
|
||||
} // namespace torch_jit
|
||||
} // namespace mmdeploy
|
|
@ -0,0 +1,14 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
// Guard renamed: `_FLATTEN_CLS_HEAD_H_` (leading underscore + uppercase) is a
// reserved identifier per [lex.name].
#ifndef MMDEPLOY_FLATTEN_CLS_HEAD_H_
#define MMDEPLOY_FLATTEN_CLS_HEAD_H_

#include <torch/script.h>

#include <memory>

namespace mmdeploy {
namespace torch_jit {
using torch::jit::Graph;

// Rewrite the `Shape->Gather->Unsqueeze->Concat->Reshape` classifier-head
// pattern after GlobalAveragePool into a single Flatten node.
void FlattenClsHead(std::shared_ptr<Graph>& graph);
}  // namespace torch_jit
}  // namespace mmdeploy

#endif  // MMDEPLOY_FLATTEN_CLS_HEAD_H_
|
|
@ -0,0 +1,113 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
#include "merge_shape_concate.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
namespace mmdeploy {
|
||||
namespace torch_jit {
|
||||
|
||||
using c10::Symbol;
|
||||
using torch::jit::Block;
|
||||
using torch::jit::IValue;
|
||||
using torch::jit::Node;
|
||||
using torch::jit::TensorType;
|
||||
using torch::jit::Value;
|
||||
|
||||
// Fold a Concat whose every input is `Shape->Gather->Unsqueeze` over the SAME
// source value into a single `Shape->Gather` with a merged index constant.
//
// \param node an onnx::Concat node; the graph is modified in place when the
//             pattern matches, otherwise left untouched.
void MergeShapeConcate(Node* node) {
  auto inputs = node->inputs();

  // Indices gathered from the shared shape tensor, in concat order.
  std::vector<int64_t> gather_value;
  Value* shape_from = nullptr;

  std::vector<Node*> node_to_remove{node};

  // check pattern shape->gather->unsqueeze->concate
  for (auto input : inputs) {
    auto unsqueeze_node = input->node();
    if (!is_kind(unsqueeze_node, "onnx::Unsqueeze") || unsqueeze_node->output()->uses().size() != 1)
      return;

    // The unsqueeze must add exactly one leading dim (axes == [0]).
    // FIX: the original used `&&`, which both inverted the intent and could
    // evaluate axes[0] on an empty vector; the sibling check in
    // flatten_cls_head.cpp uses `||` as done here.
    auto axes = unsqueeze_node->is(Symbol::attr("axes"));
    if (axes.size() != 1 || axes[0] != 0) return;

    auto gather_node = unsqueeze_node->input()->node();
    if (!is_kind(gather_node, "onnx::Gather") || gather_node->i(Symbol::attr("axis")) != 0 ||
        gather_node->output()->uses().size() != 1)
      return;

    auto gather_inputs = gather_node->inputs();
    auto gather_data = gather_inputs[0];
    auto gather_indices = gather_inputs[1];
    auto shape_node = gather_data->node();
    if (!is_kind(shape_node, "onnx::Shape") || shape_node->output()->uses().size() != 1) return;

    // Every branch must take the shape of the same value.
    auto current_shape_from = shape_node->input();
    if (!shape_from) {
      shape_from = current_shape_from;
    } else if (shape_from != current_shape_from) {
      return;
    }

    auto constant_node = gather_indices->node();
    if (!is_kind(constant_node, "onnx::Constant")) return;

    // Collect the constant indices. data_ptr<int64_t> (not `long`) keeps this
    // correct on LLP64 platforms, and numel() — NOT element_size(), which is
    // the per-element byte width — is the element count.
    auto gather_indices_val = constant_node->t(Symbol::attr("value"));
    const int64_t* data_ptr = gather_indices_val.data_ptr<int64_t>();
    if (gather_indices_val.dim() == 0) {
      gather_value.push_back(data_ptr[0]);
    } else {
      const int64_t num_elements = gather_indices_val.numel();
      for (int64_t j = 0; j < num_elements; ++j) {
        gather_value.push_back(data_ptr[j]);
      }
    }

    node_to_remove.insert(node_to_remove.end(), {unsqueeze_node, gather_node, shape_node});
  }

  // create constant value holding the merged indices
  auto graph = node->owningGraph();
  auto const_node = graph->create(Symbol::onnx("Constant"));
  const_node->t_(Symbol::attr("value"), at::tensor(gather_value));
  auto first_node = node->owningGraph()->block()->nodes().front();
  if (const_node != first_node) const_node->insertBefore(first_node);

  // recreate shape node
  auto shape_node = graph->create(Symbol::onnx("Shape"), {shape_from});
  shape_node->insertBefore(node);

  // create gather node
  auto gather_node =
      graph->create(Symbol::onnx("Gather"), {shape_node->output(), const_node->output()});

  // insert into graph and rewire consumers of the old Concat
  gather_node->insertAfter(node);
  node->output()->replaceAllUsesWith(gather_node->output());

  for (auto n : node_to_remove) {
    n->destroy();
  }
}
|
||||
|
||||
void MergeShapeConcate(Block* block) {
|
||||
auto graph = block->owningGraph();
|
||||
auto it = block->nodes().begin();
|
||||
while (it != block->nodes().end()) {
|
||||
auto node = *it;
|
||||
++it;
|
||||
for (auto block : node->blocks()) {
|
||||
MergeShapeConcate(block);
|
||||
}
|
||||
|
||||
if (is_kind(node, "onnx::Concat")) {
|
||||
MergeShapeConcate(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void MergeShapeConcate(const std::shared_ptr<Graph>& graph) { MergeShapeConcate(graph->block()); }
|
||||
|
||||
} // namespace torch_jit
|
||||
} // namespace mmdeploy
|
|
@ -0,0 +1,14 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
// Guard renamed: `_MERGE_SHAPE_CONCATE_H_` (leading underscore + uppercase)
// is a reserved identifier per [lex.name].
#ifndef MMDEPLOY_MERGE_SHAPE_CONCATE_H_
#define MMDEPLOY_MERGE_SHAPE_CONCATE_H_

#include <torch/script.h>

#include <memory>

namespace mmdeploy {
namespace torch_jit {
using torch::jit::Graph;

// Fold `Shape->Gather->Unsqueeze->Concat` fan-ins over a common source value
// into a single `Shape->Gather` pair.
void MergeShapeConcate(const std::shared_ptr<Graph>& graph);
}  // namespace torch_jit
}  // namespace mmdeploy

#endif  // MMDEPLOY_MERGE_SHAPE_CONCATE_H_
|
|
@ -0,0 +1,81 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
#include "onnx_peephole.h"
|
||||
|
||||
#include <torch/csrc/jit/passes/dead_code_elimination.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
namespace mmdeploy {
|
||||
namespace torch_jit {
|
||||
|
||||
using c10::Symbol;
|
||||
using torch::jit::Block;
|
||||
using torch::jit::IValue;
|
||||
using torch::jit::Node;
|
||||
using torch::jit::TensorType;
|
||||
using torch::jit::Value;
|
||||
|
||||
// reshape->reshape => reshape: a Reshape whose only consumers are other
// Reshape ops (as their data input) is redundant — the downstream Reshape
// alone determines the final shape — so forward its input and drop it.
void RemoveReshapeChain(Node* node) {
  auto output = node->output();
  if (!(output->hasUses())) {
    return;
  }
  auto uses = output->uses();

  // Bail out unless EVERY use is a Reshape consuming this value at input
  // offset 0 (the data input). FIX: the original condition lacked the
  // negation on is_kind, which made the pass return early on exactly the
  // nodes it was meant to fold.
  for (auto use : uses) {
    if (!is_kind(use.user, "onnx::Reshape") || use.offset != 0) {
      return;
    }
  }

  auto input = node->inputs()[0];
  output->replaceAllUsesWith(input);

  node->destroy();
}
|
||||
|
||||
// Cast(type n)->Cast(type n) => Cast(type n): a Cast fed by another Cast to
// the very same dtype is a no-op; forward the producer's output and drop it.
void RemoveRedundantCast(Node* node) {
  const auto to_type = node->i(Symbol::attr("to"));
  auto input = node->input();
  auto input_node = input->node();

  const bool redundant =
      is_kind(input_node, "onnx::Cast") && input_node->i(Symbol::attr("to")) == to_type;
  if (redundant) {
    node->output()->replaceAllUsesWith(input);
    node->destroy();
  }
}
|
||||
|
||||
// Apply the peephole rewrites to every node in `block`, recursing into
// sub-blocks first. The unused `graph` local from the original is removed.
void ONNXPeephole(Block* block) {
  auto it = block->nodes().begin();
  while (it != block->nodes().end()) {
    auto node = *it;
    // Advance before rewriting: the passes below may destroy `node`.
    ++it;
    for (auto sub_block : node->blocks()) {
      ONNXPeephole(sub_block);
    }

    if (is_kind(node, "onnx::Reshape")) {
      RemoveReshapeChain(node);
    } else if (is_kind(node, "onnx::Cast")) {
      RemoveRedundantCast(node);
    }
  }
}
|
||||
|
||||
// Entry point: run the peephole rewrites over the whole graph, then sweep
// any nodes the rewrites left dead.
void ONNXPeephole(const std::shared_ptr<Graph>& graph) {
  ONNXPeephole(graph->block());
  torch::jit::EliminateDeadCode(
      graph->block(), true,
      torch::jit::DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS);
}
|
||||
|
||||
} // namespace torch_jit
|
||||
} // namespace mmdeploy
|
|
@ -0,0 +1,15 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
// Guard renamed: `_ONNX_PEEPHOLE_H_` (leading underscore + uppercase) is a
// reserved identifier per [lex.name].
#ifndef MMDEPLOY_ONNX_PEEPHOLE_H_
#define MMDEPLOY_ONNX_PEEPHOLE_H_

#include <torch/script.h>

#include <memory>

namespace mmdeploy {
namespace torch_jit {
using torch::jit::Graph;

// Peephole rewrites for exported ONNX graphs (reshape-chain folding,
// redundant-cast removal) followed by dead-code elimination.
void ONNXPeephole(const std::shared_ptr<Graph>& graph);

}  // namespace torch_jit
}  // namespace mmdeploy

#endif  // MMDEPLOY_ONNX_PEEPHOLE_H_
|
|
@ -0,0 +1,20 @@
|
|||
// Guard renamed: `_PASSES_ONNX_UTILS_H_` (leading underscore + uppercase) is
// a reserved identifier per [lex.name].
#ifndef MMDEPLOY_PASSES_ONNX_UTILS_H_
#define MMDEPLOY_PASSES_ONNX_UTILS_H_

#include <torch/script.h>

namespace mmdeploy {
namespace torch_jit {
using c10::Symbol;
using torch::jit::Node;

// True iff `node` has exactly the given kind symbol.
inline bool is_kind(const Node* node, const Symbol& symbol) { return node->kind() == symbol; }

// Convenience overload taking a qualified name such as "onnx::Reshape".
inline bool is_kind(const Node* node, const char* symbol_name) {
  return is_kind(node, Symbol::fromQualString(symbol_name));
}

}  // namespace torch_jit
}  // namespace mmdeploy

#endif  // MMDEPLOY_PASSES_ONNX_UTILS_H_
|
|
@ -0,0 +1,50 @@
|
|||
# ONNX export Optimizer
|
||||
|
||||
This is a tool to optimize ONNX model when exporting from PyTorch.
|
||||
|
||||
## Installation
|
||||
|
||||
Build MMDeploy with `torchscript` support:
|
||||
|
||||
```shell
|
||||
export Torch_DIR=$(python -c "import torch;print(torch.utils.cmake_prefix_path + '/Torch')")
|
||||
|
||||
cmake \
|
||||
-DTorch_DIR=${Torch_DIR} \
|
||||
-DMMDEPLOY_TARGET_BACKENDS="${your_backend};torchscript" \
|
||||
.. # You can also add other build flags if you need
|
||||
|
||||
cmake --build . -- -j$(nproc) && cmake --install .
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
```python
|
||||
# import model_to_graph__custom_optimizer so we can hijack onnx.export
|
||||
from mmdeploy.apis.onnx.optimizer import model_to_graph__custom_optimizer # noqa
|
||||
from mmdeploy.core import RewriterContext
|
||||
from mmdeploy.apis.onnx.passes import optimize_onnx
|
||||
|
||||
# load your model here
|
||||
model = create_model()
|
||||
|
||||
# export with ONNX Optimizer
|
||||
x = create_dummy_input()
|
||||
with RewriterContext({}, onnx_custom_passes=optimize_onnx):
|
||||
torch.onnx.export(model, x, output_path)
|
||||
```
|
||||
|
||||
The model would be optimized after export.
|
||||
|
||||
You can also define your own optimizer:
|
||||
|
||||
```python
|
||||
# create the optimize callback
|
||||
def _optimize_onnx(graph, params_dict, torch_out):
|
||||
from mmdeploy.backend.torchscript import ts_optimizer
|
||||
ts_optimizer.onnx._jit_pass_onnx_peephole(graph)
|
||||
return graph, params_dict, torch_out
|
||||
|
||||
with RewriterContext({}, onnx_custom_passes=_optimize_onnx):
|
||||
# export your model
|
||||
```
|
|
@ -8,6 +8,8 @@ import torch
|
|||
from mmdeploy.apis.core import PIPELINE_MANAGER
|
||||
from mmdeploy.core import RewriterContext, patch_model
|
||||
from mmdeploy.utils import Backend, get_root_logger
|
||||
from .optimizer import * # noqa
|
||||
from .passes import optimize_onnx
|
||||
|
||||
|
||||
@PIPELINE_MANAGER.register_pipeline()
|
||||
|
@ -23,6 +25,7 @@ def export(model: torch.nn.Module,
|
|||
dynamic_axes: Optional[Dict] = None,
|
||||
verbose: bool = False,
|
||||
keep_initializers_as_inputs: Optional[bool] = None,
|
||||
optimize: bool = True,
|
||||
**kwargs):
|
||||
"""Export a PyTorch model into ONNX format. This is a wrap of
|
||||
`torch.onnx.export` with some enhancement.
|
||||
|
@ -64,6 +67,7 @@ def export(model: torch.nn.Module,
|
|||
verbose (bool): Enable verbose model on `torch.onnx.export`.
|
||||
keep_initializers_as_inputs (bool): Whether we should add inputs for
|
||||
each initializer.
|
||||
optimize (bool): Perform optimize on model.
|
||||
"""
|
||||
output_path = output_path_prefix + '.onnx'
|
||||
|
||||
|
@ -102,6 +106,9 @@ def export(model: torch.nn.Module,
|
|||
# patch model
|
||||
patched_model = patch_model(model, cfg=deploy_cfg, backend=backend)
|
||||
|
||||
if 'onnx_custom_passes' not in context_info:
|
||||
onnx_custom_passes = optimize_onnx if optimize else None
|
||||
context_info['onnx_custom_passes'] = onnx_custom_passes
|
||||
with RewriterContext(**context_info), torch.no_grad():
|
||||
# patch input_metas
|
||||
if input_metas is not None:
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Callable
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter('torch.onnx.utils._model_to_graph')
def model_to_graph__custom_optimizer(ctx, *args, **kwargs):
    """Rewriter of `_model_to_graph` that applies custom ONNX passes.

    Runs the original export, then feeds the resulting graph through the
    callable attached to the rewriter context as ``onnx_custom_passes``
    (if any).

    Returns:
        Tuple of ``(graph, params_dict, torch_out)``, possibly transformed
        by the custom passes.
    """
    graph, params_dict, torch_out = ctx.origin_func(*args, **kwargs)

    # Passes are attached to the RewriterContext via `onnx_custom_passes`.
    custom_passes = getattr(ctx, 'onnx_custom_passes', None)

    if custom_passes is not None:
        # `isinstance(x, typing.Callable)` is discouraged; the builtin
        # `callable` performs the same runtime check directly.
        assert callable(
            custom_passes
        ), f'Expect a callable onnx_custom_passes, get {type(custom_passes)}.'
        graph, params_dict, torch_out = custom_passes(graph, params_dict,
                                                      torch_out)

    return graph, params_dict, torch_out
|
|
@ -0,0 +1,4 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .optimize_onnx import optimize_onnx
|
||||
|
||||
__all__ = ['optimize_onnx']
|
|
@ -0,0 +1,16 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmdeploy.utils import get_root_logger
|
||||
|
||||
|
||||
def optimize_onnx(graph, params_dict, torch_out):
    """Run mmdeploy's custom ONNX passes on an exported graph.

    Best-effort: if the compiled ``ts_optimizer`` extension is unavailable
    or a pass fails, the graph is returned unmodified.

    Args:
        graph: The ONNX graph produced by the ``torch.onnx`` export.
        params_dict: Mapping of parameter names to tensors.
        torch_out: The traced torch outputs.

    Returns:
        Tuple of ``(graph, params_dict, torch_out)``.
    """
    logger = get_root_logger()
    logger.info('Execute onnx optimize passes.')
    try:
        from mmdeploy.backend.torchscript import ts_optimizer
        ts_optimizer.onnx._jit_pass_merge_shape_concate(graph)
        ts_optimizer.onnx._jit_pass_onnx_peephole(graph)
        ts_optimizer.onnx._jit_pass_flatten_cls_head(graph)
    except Exception as e:
        # Keep the export usable, but surface why the passes were skipped
        # instead of swallowing the failure silently.
        logger.warning(f'Skip onnx optimize passes, reason: {e}')

    return graph, params_dict, torch_out
|
|
@ -0,0 +1,190 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import tempfile
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
import onnx
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from mmdeploy.apis.onnx.optimizer import \
|
||||
model_to_graph__custom_optimizer # noqa
|
||||
from mmdeploy.core import RewriterContext
|
||||
|
||||
# Shared temporary file path used as the ONNX export target by every test
# below; only the generated name is kept, torch.onnx.export re-creates it.
onnx_file = tempfile.NamedTemporaryFile(suffix='.onnx').name
|
||||
|
||||
|
||||
def _find_next_node(start: int, nodes: List, op_type: str) -> Tuple[Any, int]:
|
||||
for idx, n in enumerate(nodes[start:]):
|
||||
if n.op_type == op_type:
|
||||
return n, idx
|
||||
return None, -1
|
||||
|
||||
|
||||
def test_merge_shape_concate():
    """Shape ops feeding a concat should be merged into a single Shape."""
    pytest.importorskip('mmdeploy.backend.torchscript.ts_optimizer.onnx')

    try:
        from mmdeploy.backend.torchscript import ts_optimizer
        opt_pass = ts_optimizer.onnx._jit_pass_merge_shape_concate
    except ImportError:
        pytest.skip('pass not found.')

    def _run_pass(graph, params_dict, torch_out):
        opt_pass(graph)
        return graph, params_dict, torch_out

    class ShapeModel(nn.Module):

        def forward(self, x):
            return x.new_zeros(x.shape[-2:])

    inputs = torch.rand(1, 3, 4, 8)

    with RewriterContext({}, onnx_custom_passes=_run_pass):
        torch.onnx.export(
            ShapeModel(),
            inputs,
            onnx_file,
            input_names=['input'],
            output_names=['output'],
            dynamic_axes={'input': {2: 'h', 3: 'w'}},
            opset_version=11)

    node_list = onnx.load(onnx_file).graph.node
    # Index of the first Shape node; len(node_list) if there is none.
    shape_idx = next(
        (i for i, node in enumerate(node_list) if node.op_type == 'Shape'),
        len(node_list))

    assert shape_idx < len(node_list)
    assert node_list[shape_idx + 1].op_type == 'Gather'
    assert node_list[shape_idx + 2].op_type == 'ConstantOfShape'
|
||||
|
||||
|
||||
def test_peephole():
    """Redundant Cast/Reshape chains should be folded by the peephole pass."""
    pytest.importorskip('mmdeploy.backend.torchscript.ts_optimizer.onnx')

    try:
        from mmdeploy.backend.torchscript import ts_optimizer
        opt_pass = ts_optimizer.onnx._jit_pass_onnx_peephole
    except ImportError:
        pytest.skip('pass not found.')

    def _run_pass(graph, params_dict, torch_out):
        opt_pass(graph)
        return graph, params_dict, torch_out

    class CastReshapeModel(torch.nn.Module):

        def forward(self, x):
            # Consecutive casts should collapse to one per target dtype.
            x = x.int()
            x = x.int()
            x = x.float()

            # Reshapes sharing a producer should be deduplicated.
            x = x.view(10, -1)
            y = x.view(2, -1)
            z = x.view(3, -1)

            return y, z

    inputs = torch.rand(2, 3, 5)

    with RewriterContext({}, onnx_custom_passes=_run_pass):
        torch.onnx.export(
            CastReshapeModel(),
            inputs,
            onnx_file,
            input_names=['input'],
            output_names=['output1', 'output2'],
            dynamic_axes={'input': {0: 'b', 1: 'c', 2: 'w'}},
            opset_version=11)

    node_list = onnx.load(onnx_file).graph.node

    cast_node, idx = _find_next_node(0, node_list, 'Cast')
    assert cast_node is not None
    assert cast_node.attribute[0].i == 6  # ONNX TensorProto dtype: INT32

    cast_node, idx = _find_next_node(idx + 1, node_list, 'Cast')
    assert cast_node is not None
    assert cast_node.attribute[0].i == 1  # ONNX TensorProto dtype: FLOAT

    reshape_node, idx = _find_next_node(idx + 1, node_list, 'Reshape')
    assert reshape_node is not None

    reshape_node, idx = _find_next_node(idx + 1, node_list, 'Reshape')
    assert reshape_node is not None
|
||||
|
||||
|
||||
def test_flatten_cls_head():
    """A GAP + reshape classification head should be rewritten as Flatten."""
    pytest.importorskip('mmdeploy.backend.torchscript.ts_optimizer.onnx')

    try:
        from mmdeploy.backend.torchscript import ts_optimizer
        opt_pass = ts_optimizer.onnx._jit_pass_flatten_cls_head
    except ImportError:
        pytest.skip('pass not found.')

    def _run_pass(graph, params_dict, torch_out):
        opt_pass(graph)
        return graph, params_dict, torch_out

    class ClsHeadModel(torch.nn.Module):

        def forward(self, x):
            batch = x.size(0)
            gap = nn.functional.adaptive_avg_pool2d(x, (1, 1))
            gap = gap.reshape(batch, -1)
            return gap + 0  # gap should not be the output

    inputs = torch.rand(1, 4, 8, 8)

    with RewriterContext({}, onnx_custom_passes=_run_pass):
        torch.onnx.export(
            ClsHeadModel(),
            inputs,
            onnx_file,
            input_names=['input'],
            output_names=['output'],
            dynamic_axes={'input': {2: 'h', 3: 'w'}},
            opset_version=11)

    node_list = onnx.load(onnx_file).graph.node

    gap_node, idx = _find_next_node(0, node_list, 'GlobalAveragePool')
    assert gap_node is not None

    flatten_node, idx = _find_next_node(idx + 1, node_list, 'Flatten')
    assert flatten_node is not None
|
Loading…
Reference in New Issue