[Enhancement] Add ONNX passes support (#390)

* add merge_shape_concate

* add some peephole optimization

* bug fixing

* fix for torch1.9

* add flatten cls head

* add subgraph matcher with attribute

* add ut

* fix lint

* remove onnx2ncnn

* add opset version

* axis name

* fix peephole

* fix symbol compare

* add docs
q.yao 2022-06-06 21:30:31 +08:00 committed by GitHub
parent 456076c06b
commit 74243dc98b
16 changed files with 1026 additions and 2 deletions

View File

@@ -4,13 +4,19 @@
#include <string>
#include "optimizer.h"
#include "passes/onnx/flatten_cls_head.h"
#include "passes/onnx/merge_shape_concate.h"
#include "passes/onnx/onnx_peephole.h"
namespace mmdeploy {
namespace torch_jit {
void optimize_for_backend(torch::jit::Module& model, const std::string& ir = "torchscript",
const std::string& backend = "torchscript") {
if (ir == "torchscript") {
-    model = mmdeploy::optimize_for_torchscript(model);
+    model = optimize_for_torchscript(model);
} else if (ir == "onnx") {
-    model = mmdeploy::optimize_for_onnx(model);
+    model = optimize_for_onnx(model);
} else {
fprintf(stderr, "No optimization for combination ir: %s backend: %s\n", ir.c_str(),
backend.c_str());
@@ -23,4 +29,11 @@ PYBIND11_MODULE(ts_optimizer, m) {
m.def("optimize_for_backend", optimize_for_backend, py::arg("module"),
py::arg("ir") = std::string("torchscript"),
py::arg("backend") = std::string("torchscript"));
py::module_ onnx_module = m.def_submodule("onnx");
onnx_module.def("_jit_pass_merge_shape_concate", MergeShapeConcate, py::arg("graph"));
onnx_module.def("_jit_pass_onnx_peephole", ONNXPeephole, py::arg("graph"));
onnx_module.def("_jit_pass_flatten_cls_head", FlattenClsHead, py::arg("graph"));
}
} // namespace torch_jit
} // namespace mmdeploy

View File

@@ -0,0 +1,311 @@
// modified from:
// https://github.com/pytorch/pytorch/blob/v1.8.1/torch/csrc/jit/ir/subgraph_matcher.cpp
#include "subgraph_matcher.h"
#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/attributes.h>
#include <torch/csrc/jit/jit_log.h>
#include <regex>
#include <stack>
namespace mmdeploy {
namespace torch_jit {
using torch::jit::AttributeKind;
using torch::jit::ClassType;
using torch::jit::Node;
using torch::jit::Symbol;
using torch::jit::Value;
namespace prim {
using namespace ::c10::prim;
}
namespace attr {
using namespace ::c10::attr;
}
/**
* \brief A class implementing an API for comparing subgraphs.
*/
class SubgraphMatcher::SubgraphMatcherImpl {
public:
explicit SubgraphMatcherImpl(const Graph& pattern, MatchAttribute match_attribute)
: pattern_(pattern), match_attribute_(match_attribute) {}
/**
* \brief Compare matchGraph with the part of the graph denoted by a node \p
* ANCHOR.
*
* The anchor node would be compared against the deepest node in the
* match-graph. A node is considered matching if its number of inputs/outputs
* is the same as in the corresponding matchGraph node, its type is the same,
* and all nodes producing input-values also match.
*/
bool matchesSubgraphFromAnchorNode(Node* anchor);
/** \brief Return match map for nodes. */
std::unordered_map<const Node*, Node*> nodes_map() const { return nodes_map_; }
/** \brief Return match map for values. */
std::unordered_map<const Value*, Value*> values_map() const { return values_map_; }
private:
bool matchValues(const Value* v1, Value* v2);
bool matchNodes(const Node* n1, Node* n2);
bool matchAttributes(const Node* n1, Node* n2);
static bool isInput(const Value* v);
static bool isOutput(const Value* v);
std::unordered_map<const Node*, Node*> nodes_map_;
std::unordered_map<const Value*, Value*> values_map_;
const MatchAttribute match_attribute_;
const Graph& pattern_;
const Node* anchor_ = nullptr;
};
bool SubgraphMatcher::SubgraphMatcherImpl::isInput(const Value* v) {
return v->node()->kind() == prim::Param;
}
bool SubgraphMatcher::SubgraphMatcherImpl::isOutput(const Value* v) {
for (const Value* output : v->owningGraph()->outputs()) {
if (v == output) {
return true;
}
}
return false;
}
/**
* Compare two Values. V1 is from pattern, V2 is from the actual graph.
*
* The values are considered matching if:
* 1) the nodes defining them match
* 2) they have the same number of uses, unless they are entry or exit values.
*/
bool SubgraphMatcher::SubgraphMatcherImpl::matchValues(const Value* v1, Value* v2) {
// Check if we've already visited these values.
if (values_map_.count(v1)) {
if (values_map_.at(v1) != v2) {
GRAPH_DEBUG("Values %", v1->debugName(), " and %", v2->debugName(),
" did not match because %", v1->debugName(), " has already been matched with %",
values_map_.at(v1)->debugName(), ".\n");
return false;
}
return true;
}
// When V2 is ANCHOR, we're comparing exiting values, and when V1->node is
// PARAM, we're comparing entering values - in these two cases the number of
// uses doesn't need to be the same.
if (v1->uses().size() != v2->uses().size() && !isOutput(v1) && !isInput(v1)) {
GRAPH_DEBUG("Values %", v1->debugName(), " and %", v2->debugName(),
" did not match because number of their uses is different.\n");
return false;
}
// Add the values to the map before calling matchNodes to avoid infinite
// recursion.
GRAPH_DEBUG("Values %", v1->debugName(), " and %", v2->debugName(), " matched.\n");
values_map_[v1] = v2;
return matchNodes(v1->node(), v2->node());
}
bool SubgraphMatcher::SubgraphMatcherImpl::matchAttributes(const Node* n1, Node* n2) {
if (match_attribute_ == FORCE_MATCH && n1->numAttributes() != n2->numAttributes()) {
GRAPH_DEBUG("Nodes did not match in number attributes:\n", *n1, *n2);
return false;
}
for (const Symbol& attr_name : n1->attributeNames()) {
if (n1->kindOf(attr_name) != n2->kindOf(attr_name)) {
GRAPH_DEBUG("Nodes did not match because type of attribute '", attr_name.toQualString(),
"' did not match:\n", *n1, *n2);
return false;
}
std::vector<int64_t> n1is, n2is;
std::vector<double> n1fs, n2fs;
switch (n1->kindOf(attr_name)) {
case AttributeKind::s:
if (!std::regex_match(n2->s(attr_name), std::regex(n1->s(attr_name)))) {
GRAPH_DEBUG("Nodes did not match because attribute '", attr_name.toQualString(),
"' did not match: ", n1->s(attr_name), " != ", n2->s(attr_name), " \n", *n1,
*n2);
return false;
}
break;
case AttributeKind::f:
if (n1->f(attr_name) != n2->f(attr_name)) {
GRAPH_DEBUG("Nodes did not match because attribute '", attr_name.toQualString(),
"' did not match:", n1->f(attr_name), " != ", n2->f(attr_name), " \n", *n1,
*n2);
return false;
}
break;
case AttributeKind::i:
if (n1->i(attr_name) != n2->i(attr_name)) {
GRAPH_DEBUG("Nodes did not match because attribute '", attr_name.toQualString(),
"' did not match:", n1->i(attr_name), " != ", n2->i(attr_name), " \n", *n1,
*n2);
return false;
}
break;
case AttributeKind::is:
n1is = n1->is(attr_name);
n2is = n2->is(attr_name);
if (n1is.size() != n2is.size()) return false;
for (size_t i = 0; i < n1is.size(); ++i) {
if (n1is[i] != n2is[i]) return false;
}
break;
case AttributeKind::fs:
n1fs = n1->fs(attr_name);
n2fs = n2->fs(attr_name);
if (n1fs.size() != n2fs.size()) return false;
for (size_t i = 0; i < n1fs.size(); ++i) {
if (n1fs[i] != n2fs[i]) return false;
}
break;
default: {
// Other attribute types are not supported yet
GRAPH_DEBUG("Nodes did not match because type of attribute '", attr_name.toQualString(),
"' is not supported.\n", *n1, *n2);
return false;
}
}
}
return true;
}
static bool endsWith(const std::string& str, const std::string& suffix) {
return str.size() >= suffix.size() &&
0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
}
/**
* Compare two Nodes. N1 is from pattern, N2 is from the actual graph.
*
* The nodes are considered matching if:
* 1) N1 and N2 are of the same kind.
* 2) Number of inputs and outputs is the same.
* 3) All input and output values match.
*
* A special case is when N1 is PARAM - this is considered outside the pattern,
* so it matches everything.
*/
bool SubgraphMatcher::SubgraphMatcherImpl::matchNodes(const Node* n1, Node* n2) {
// Check if we've already visited these nodes.
if (nodes_map_.count(n1)) {
return nodes_map_.at(n1) == n2;
}
// Param node in pattern graph matches everything.
if (n1->kind() == prim::Param) {
GRAPH_DEBUG("Nodes matched:\n", *n1, *n2);
return true;
}
// We don't allow matches to span across blocks, so check if N2 is in the same
// block as the first (anchor) node.
if (n2->owningBlock() != anchor_->owningBlock()) {
GRAPH_DEBUG("Nodes did not match because it is in the different block:\n", *n1, *n2);
return false;
}
// Special handling for matching modules
if (n1->kind() == Symbol::fromQualString("match::module")) {
if (n2->kind() == prim::GetAttr) {
if (!n1->hasAttributeS("name")) {
GRAPH_DEBUG(
"Nodes did not match because special node match::module does not have 'name' "
"attribute:\n",
*n1, *n2);
return false;
}
auto t = n2->output()->type()->expect<ClassType>();
auto real_typename = t->name()->qualifiedName();
auto pattern_typename = n1->s(attr::name);
if (!endsWith(real_typename, pattern_typename)) {
GRAPH_DEBUG("Nodes did not match because expected module type is different:\n");
GRAPH_DEBUG(" actualtype: ", real_typename, "\n");
GRAPH_DEBUG(" expected type: ", pattern_typename, "\n");
GRAPH_DEBUG("Nodes:", *n1, *n2);
return false;
}
}
} else {
if (n1->kind() != n2->kind() || n1->outputs().size() != n2->outputs().size() ||
n1->inputs().size() != n2->inputs().size()) {
GRAPH_DEBUG("Nodes did not match in their kind or number of inputs/outputs:\n", *n1, *n2);
return false;
}
if (match_attribute_ != NO_MATCH) {
if (!matchAttributes(n1, n2)) {
return false;
}
}
}
// Add nodes to the map before calling matchValues to avoid infinite
// recursion.
nodes_map_[n1] = n2;
for (const auto i : c10::irange(n1->outputs().size())) {
if (!matchValues(n1->outputs()[i], n2->outputs()[i])) {
return false;
}
}
for (const auto i : c10::irange(n1->inputs().size())) {
if (!matchValues(n1->inputs()[i], n2->inputs()[i])) {
return false;
}
}
GRAPH_DEBUG("Nodes matched:\n", *n1, *n2);
return true;
}
/**
* Recursively try to match pattern with the actual graph starting from the
* exiting node in the pattern and anchor node in the actual graph.
*/
bool SubgraphMatcher::SubgraphMatcherImpl::matchesSubgraphFromAnchorNode(Node* anchor) {
GRAPH_UPDATE("Starting match from a new anchor: ", *anchor);
nodes_map_.clear();
values_map_.clear();
anchor_ = anchor;
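// nodes().end() is the sentinel of the node list: dereferencing it yields the
// Return node, so input(0) walks back to the bottom-most op of the pattern.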
const Node* bottom_node = *(pattern_.nodes().end());
bottom_node = bottom_node->input(0)->node();
if (!matchNodes(bottom_node, anchor)) {
return false;
}
for (const Value* output : pattern_.outputs()) {
AT_ASSERT(values_map_.count(output));
}
GRAPH_UPDATE("Pattern matched!\n");
return true;
}
SubgraphMatcher::SubgraphMatcher(const Graph& pattern, MatchAttribute match_attribute)
: impl_(new SubgraphMatcher::SubgraphMatcherImpl(pattern, match_attribute)) {}
bool SubgraphMatcher::matchesSubgraphFromAnchorNode(Node* anchor) {
return impl_->matchesSubgraphFromAnchorNode(anchor);
}
std::unordered_map<const Node*, Node*> SubgraphMatcher::nodes_map() const {
return impl_->nodes_map();
}
std::unordered_map<const Value*, Value*> SubgraphMatcher::values_map() const {
return impl_->values_map();
}
} // namespace torch_jit
} // namespace mmdeploy

View File

@@ -0,0 +1,36 @@
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef _SUBGRAPH_MATCHER_H_
#define _SUBGRAPH_MATCHER_H_
#include <torch/script.h>
#include <memory>
namespace mmdeploy {
namespace torch_jit {
using torch::jit::Graph;
using torch::jit::Node;
using torch::jit::Value;
enum MatchAttribute { FORCE_MATCH, TRY_MATCH, NO_MATCH };
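// FORCE_MATCH: attribute counts and values must match exactly.
// TRY_MATCH:   attributes present on the pattern node are compared; string
//              attributes are treated as regular expressions.
// NO_MATCH:    attributes are ignored during matching.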
class SubgraphMatcher {
public:
explicit SubgraphMatcher(const Graph& pattern, MatchAttribute match_attribute = TRY_MATCH);
bool matchesSubgraphFromAnchorNode(Node* anchor);
/** \brief Return match map for nodes. */
std::unordered_map<const Node*, Node*> nodes_map() const;
/** \brief Return match map for values. */
std::unordered_map<const Value*, Value*> values_map() const;
private:
class SubgraphMatcherImpl;
std::unique_ptr<SubgraphMatcherImpl> impl_ = nullptr;
};
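// Usage sketch (illustrative only, not part of this commit):
//
//   auto pattern = std::make_shared<Graph>();
//   torch::jit::parseIR(pattern_ir, pattern.get());
//   SubgraphMatcher matcher(*pattern, TRY_MATCH);
//   for (auto* node : graph->nodes()) {
//     if (matcher.matchesSubgraphFromAnchorNode(node)) {
//       // matcher.nodes_map()/values_map() map pattern nodes/values to the
//       // matched nodes/values in `graph`.
//     }
//   }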
} // namespace torch_jit
} // namespace mmdeploy
#endif

View File

@@ -0,0 +1,119 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include "flatten_cls_head.h"
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <vector>
#include "utils.h"
namespace mmdeploy {
namespace torch_jit {
using c10::Symbol;
using torch::jit::IValue;
using torch::jit::Match;
using torch::jit::TensorType;
using torch::jit::TypeKind;
using torch::jit::Value;
static bool matchClsHead(const Match& match, const std::unordered_map<std::string, Value*>& map) {
// TODO: check whether the value map in newer PyTorch versions can simplify this filter.
// check cat -1
{
// check if the shape of second inputs is 1
auto cat_v1 = match.values_map.at(map.at("cat1"));
if (cat_v1->type()->kind() != TypeKind::TensorType) return false;
auto cat_v1_type = cat_v1->type()->cast<TensorType>();
auto cat_v1_size = cat_v1_type->sizes().concrete_sizes();
if (!cat_v1_size.has_value()) return false;
IValue cat_v1_size_value(cat_v1_size.value());
auto size_list = cat_v1_size_value.toIntList();
if (size_list.size() != 1 || size_list[0] != 1) return false;
}
// check unsqueeze
auto cat_v0 = match.values_map.at(map.at("cat0"));
auto unsqueeze_node = cat_v0->node();
{
if (!is_kind(unsqueeze_node, "onnx::Unsqueeze")) return false;
auto unsqueeze_axes = unsqueeze_node->is(Symbol::attr("axes"));
if (unsqueeze_axes.size() != 1 || unsqueeze_axes[0] != 0) return false;
}
// check gather
auto gather_node = unsqueeze_node->input()->node();
auto gather_inputs = gather_node->inputs();
{
if (!is_kind(gather_node, "onnx::Gather")) return false;
auto gather_axis = gather_node->i(Symbol::attr("axis"));
if (gather_axis != 0) return false;
}
auto x = match.values_map.at(map.at("x"));
// check shape
auto shape_node = gather_inputs[0]->node();
{
if (!is_kind(shape_node, "onnx::Shape")) return false;
if (shape_node->input() != x) return false;
}
// check constant
auto const_node = gather_inputs[1]->node();
{
if (!is_kind(const_node, "onnx::Constant")) return false;
auto ival = const_node->t(Symbol::attr("value"));
if (ival.dim() != 0) return false;
auto ival_dataptr = ival.data_ptr<int64_t>();
if (ival_dataptr[0] != 0) return false;
}
// check if reshape is the output of the graph
auto reshape_pattern = map.at("reshape");
auto reshape_node = match.values_map.at(reshape_pattern);
auto uses = reshape_node->uses();
for (auto use : uses) {
auto user = use.user;
if (is_kind(user, "prim::Return")) return false;
}
return true;
}
// from:
//   x -> shape -> gather -> unsqueeze -> concat
//   |                                       |
//   +--> gap ----------------------------> reshape
//
// to:
//   x -> gap -> flatten
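// The rewrite is valid because %gap has shape (N, C, 1, 1) and the matched
// concat always builds a two-element target shape starting with N, so the
// reshape yields (N, C), exactly what onnx::Flatten produces with its
// default axis=1.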
void FlattenClsHead(std::shared_ptr<Graph>& graph) {
std::string pattern = R"IR(
graph(%x, %cat0, %cat1):
%gap = onnx::GlobalAveragePool(%x)
%cat = onnx::Concat[axis=0](%cat0, %cat1)
%reshape = onnx::Reshape(%gap, %cat)
return (%reshape)
)IR";
std::string replacement = R"IR(
graph(%x, %cat0, %cat1):
%gap = onnx::GlobalAveragePool(%x)
%flatten = onnx::Flatten(%gap)
return (%flatten)
)IR";
torch::jit::SubgraphRewriter subgraph_rewriter;
subgraph_rewriter.RegisterRewritePattern(pattern, replacement);
subgraph_rewriter.runOnGraph(graph, matchClsHead);
torch::jit::EliminateDeadCode(
graph->block(), true,
torch::jit::DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS);
}
} // namespace torch_jit
} // namespace mmdeploy

View File

@@ -0,0 +1,14 @@
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef _FLATTEN_CLS_HEAD_H_
#define _FLATTEN_CLS_HEAD_H_
#include <torch/script.h>
namespace mmdeploy {
namespace torch_jit {
using torch::jit::Graph;
void FlattenClsHead(std::shared_ptr<Graph>& graph);
} // namespace torch_jit
} // namespace mmdeploy
#endif

View File

@@ -0,0 +1,113 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include "merge_shape_concate.h"
#include <vector>
#include "utils.h"
namespace mmdeploy {
namespace torch_jit {
using c10::Symbol;
using torch::jit::Block;
using torch::jit::IValue;
using torch::jit::Node;
using torch::jit::TensorType;
using torch::jit::Value;
void MergeShapeConcate(Node* node) {
auto inputs = node->inputs();
std::vector<int64_t> gather_value;
Value* shape_from = nullptr;
std::vector<Node*> node_to_remove{node};
// check pattern shape->gather->unsqueeze->concate
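// Concat(Unsqueeze(Gather(Shape(x), i0)), Unsqueeze(Gather(Shape(x), i1)), ...)
// collapses into a single Gather(Shape(x), Constant([i0, i1, ...])), provided
// every branch reads the shape of the same tensor.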
for (auto input : inputs) {
auto unsqueeze_node = input->node();
if (!is_kind(unsqueeze_node, "onnx::Unsqueeze") || unsqueeze_node->output()->uses().size() != 1)
return;
auto axes = unsqueeze_node->is(Symbol::attr("axes"));
if (axes.size() != 1 || axes[0] != 0) return;
auto gather_node = unsqueeze_node->input()->node();
if (!is_kind(gather_node, "onnx::Gather") || gather_node->i(Symbol::attr("axis")) != 0 ||
gather_node->output()->uses().size() != 1)
return;
auto gather_inputs = gather_node->inputs();
auto gather_data = gather_inputs[0];
auto gather_indices = gather_inputs[1];
auto shape_node = gather_data->node();
if (!is_kind(shape_node, "onnx::Shape") || shape_node->output()->uses().size() != 1) return;
auto current_shape_from = shape_node->input();
if (!shape_from) {
shape_from = current_shape_from;
} else {
if (shape_from != current_shape_from) return;
}
auto constant_node = gather_indices->node();
if (!is_kind(constant_node, "onnx::Constant")) return;
auto gather_indices_val = constant_node->t(Symbol::attr("value"));
int64_t* data_ptr = gather_indices_val.data_ptr<int64_t>();
if (gather_indices_val.dim() == 0) {
gather_value.push_back(data_ptr[0]);
} else {
int64_t num_indices = gather_indices_val.numel();  // element count, not element_size()
for (int64_t j = 0; j < num_indices; ++j) {
gather_value.push_back(data_ptr[j]);
}
}
node_to_remove.insert(node_to_remove.end(), {unsqueeze_node, gather_node, shape_node});
}
// create constant value
auto graph = node->owningGraph();
auto const_node = graph->create(Symbol::onnx("Constant"));
const_node->t_(Symbol::attr("value"), at::tensor(gather_value));
auto first_node = node->owningGraph()->block()->nodes().front();
if (const_node != first_node) const_node->insertBefore(first_node);
// recreate shape node
auto shape_node = graph->create(Symbol::onnx("Shape"), {shape_from});
shape_node->insertBefore(node);
// create gather node
auto gather_node =
graph->create(Symbol::onnx("Gather"), {shape_node->output(), const_node->output()});
// insert into graph
gather_node->insertAfter(node);
node->output()->replaceAllUsesWith(gather_node->output());
for (auto n : node_to_remove) {
n->destroy();
}
}
void MergeShapeConcate(Block* block) {
auto graph = block->owningGraph();
auto it = block->nodes().begin();
while (it != block->nodes().end()) {
auto node = *it;
++it;
for (auto block : node->blocks()) {
MergeShapeConcate(block);
}
if (is_kind(node, "onnx::Concat")) {
MergeShapeConcate(node);
}
}
}
void MergeShapeConcate(const std::shared_ptr<Graph>& graph) { MergeShapeConcate(graph->block()); }
} // namespace torch_jit
} // namespace mmdeploy

View File

@@ -0,0 +1,14 @@
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef _MERGE_SHAPE_CONCATE_H_
#define _MERGE_SHAPE_CONCATE_H_
#include <torch/script.h>
namespace mmdeploy {
namespace torch_jit {
using torch::jit::Graph;
void MergeShapeConcate(const std::shared_ptr<Graph>& graph);
} // namespace torch_jit
} // namespace mmdeploy
#endif

View File

@@ -0,0 +1,81 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include "onnx_peephole.h"
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <vector>
#include "utils.h"
namespace mmdeploy {
namespace torch_jit {
using c10::Symbol;
using torch::jit::Block;
using torch::jit::IValue;
using torch::jit::Node;
using torch::jit::TensorType;
using torch::jit::Value;
void RemoveReshapeChain(Node* node) {
// reshape->reshape => reshape
auto output = node->output();
if (!(output->hasUses())) {
return;
}
auto uses = output->uses();
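// This Reshape can be bypassed only when every consumer is itself a Reshape
// taking it as the data input (use offset 0), since the downstream Reshape
// discards the intermediate shape anyway.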
for (auto use : uses) {
if (!is_kind(use.user, "onnx::Reshape") || use.offset != 0) {
return;
}
}
auto input = node->inputs()[0];
output->replaceAllUsesWith(input);
node->destroy();
}
void RemoveRedundantCast(Node* node) {
// Cast(type n)->Cast(type n) => Cast(type n)
auto to_type = node->i(Symbol::attr("to"));
auto input = node->input();
auto input_node = input->node();
if (is_kind(input_node, "onnx::Cast") && input_node->i(Symbol::attr("to")) == to_type) {
auto output = node->output();
output->replaceAllUsesWith(input);
node->destroy();
}
}
void ONNXPeephole(Block* block) {
auto graph = block->owningGraph();
auto it = block->nodes().begin();
while (it != block->nodes().end()) {
auto node = *it;
++it;
for (auto block : node->blocks()) {
ONNXPeephole(block);
}
if (is_kind(node, "onnx::Reshape")) {
RemoveReshapeChain(node);
} else if (is_kind(node, "onnx::Cast")) {
RemoveRedundantCast(node);
}
}
}
void ONNXPeephole(const std::shared_ptr<Graph>& graph) {
ONNXPeephole(graph->block());
torch::jit::EliminateDeadCode(
graph->block(), true,
torch::jit::DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS);
}
} // namespace torch_jit
} // namespace mmdeploy

View File

@@ -0,0 +1,15 @@
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef _ONNX_PEEPHOLE_H_
#define _ONNX_PEEPHOLE_H_
#include <torch/script.h>
namespace mmdeploy {
namespace torch_jit {
using torch::jit::Graph;
void ONNXPeephole(const std::shared_ptr<Graph>& graph);
} // namespace torch_jit
} // namespace mmdeploy
#endif

View File

@@ -0,0 +1,20 @@
#ifndef _PASSES_ONNX_UTILS_H_
#define _PASSES_ONNX_UTILS_H_
#include <torch/script.h>
namespace mmdeploy {
namespace torch_jit {
using c10::Symbol;
using torch::jit::Node;
inline bool is_kind(const Node* node, const Symbol& symbol) { return node->kind() == symbol; }
inline bool is_kind(const Node* node, const char* symbol_name) {
return is_kind(node, Symbol::fromQualString(symbol_name));
}
} // namespace torch_jit
} // namespace mmdeploy
#endif

View File

@@ -0,0 +1,50 @@
# ONNX Export Optimizer
This is a tool to optimize ONNX models when exporting from PyTorch.
## Installation
Build MMDeploy with `torchscript` support:
```shell
export Torch_DIR=$(python -c "import torch;print(torch.utils.cmake_prefix_path + '/Torch')")
cmake \
-DTorch_DIR=${Torch_DIR} \
-DMMDEPLOY_TARGET_BACKENDS="${your_backend};torchscript" \
.. # You can also add other build flags if you need
cmake --build . -- -j$(nproc) && cmake --install .
```
## Usage
```python
# import model_to_graph__custom_optimizer so we can hijack torch.onnx.export
from mmdeploy.apis.onnx.optimizer import model_to_graph__custom_optimizer # noqa
from mmdeploy.core import RewriterContext
from mmdeploy.apis.onnx.passes import optimize_onnx
# load your model here
model = create_model()
# export with ONNX Optimizer
x = create_dummy_input()
with RewriterContext({}, onnx_custom_passes=optimize_onnx):
    torch.onnx.export(model, x, output_path)
```
The model will be optimized during export.
You can also define your own optimizer:
```python
# create the optimize callback
def _optimize_onnx(graph, params_dict, torch_out):
    from mmdeploy.backend.torchscript import ts_optimizer
    ts_optimizer.onnx._jit_pass_onnx_peephole(graph)
    return graph, params_dict, torch_out

with RewriterContext({}, onnx_custom_passes=_optimize_onnx):
    # export your model
```
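For reference, here is a self-contained sketch of the full custom-pass flow (the toy model and output file name below are placeholders, not part of MMDeploy):
```python
import torch
import torch.nn as nn

# importing the rewriter registers the hook on torch.onnx.utils._model_to_graph
from mmdeploy.apis.onnx.optimizer import model_to_graph__custom_optimizer  # noqa
from mmdeploy.core import RewriterContext


def _optimize_onnx(graph, params_dict, torch_out):
    from mmdeploy.backend.torchscript import ts_optimizer
    ts_optimizer.onnx._jit_pass_onnx_peephole(graph)
    return graph, params_dict, torch_out


model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())  # placeholder model
x = torch.rand(1, 3, 32, 32)

with RewriterContext({}, onnx_custom_passes=_optimize_onnx):
    torch.onnx.export(model, x, 'end2end.onnx', opset_version=11)
```
When exporting through `mmdeploy.apis.onnx.export`, the bundled passes are applied automatically unless `optimize=False` is given.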

View File

@@ -8,6 +8,8 @@ import torch
from mmdeploy.apis.core import PIPELINE_MANAGER
from mmdeploy.core import RewriterContext, patch_model
from mmdeploy.utils import Backend, get_root_logger
from .optimizer import * # noqa
from .passes import optimize_onnx
@PIPELINE_MANAGER.register_pipeline()
@@ -23,6 +25,7 @@ def export(model: torch.nn.Module,
           dynamic_axes: Optional[Dict] = None,
           verbose: bool = False,
           keep_initializers_as_inputs: Optional[bool] = None,
           optimize: bool = True,
           **kwargs):
    """Export a PyTorch model into ONNX format. This is a wrapper of
    `torch.onnx.export` with some enhancements.
@@ -64,6 +67,7 @@ def export(model: torch.nn.Module,
        verbose (bool): Enable verbose mode in `torch.onnx.export`.
        keep_initializers_as_inputs (bool): Whether we should add inputs for
            each initializer.
        optimize (bool): Perform optimization on the exported model.
"""
output_path = output_path_prefix + '.onnx'
@@ -102,6 +106,9 @@
    # patch model
    patched_model = patch_model(model, cfg=deploy_cfg, backend=backend)

    if 'onnx_custom_passes' not in context_info:
        onnx_custom_passes = optimize_onnx if optimize else None
        context_info['onnx_custom_passes'] = onnx_custom_passes

    with RewriterContext(**context_info), torch.no_grad():
        # patch input_metas
        if input_metas is not None:

View File

@@ -0,0 +1,21 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter('torch.onnx.utils._model_to_graph')
def model_to_graph__custom_optimizer(ctx, *args, **kwargs):
    """Rewriter of `_model_to_graph` that applies custom ONNX passes."""
    graph, params_dict, torch_out = ctx.origin_func(*args, **kwargs)

    custom_passes = getattr(ctx, 'onnx_custom_passes', None)
    if custom_passes is not None:
        assert isinstance(
            custom_passes, Callable
        ), f'Expected a callable onnx_custom_passes, got {type(custom_passes)}.'
        graph, params_dict, torch_out = custom_passes(graph, params_dict,
                                                      torch_out)

    return graph, params_dict, torch_out

View File

@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .optimize_onnx import optimize_onnx
__all__ = ['optimize_onnx']

View File

@@ -0,0 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdeploy.utils import get_root_logger
def optimize_onnx(graph, params_dict, torch_out):
    logger = get_root_logger()
    logger.info('Execute ONNX optimization passes.')

    try:
        from mmdeploy.backend.torchscript import ts_optimizer
        ts_optimizer.onnx._jit_pass_merge_shape_concate(graph)
        ts_optimizer.onnx._jit_pass_onnx_peephole(graph)
        ts_optimizer.onnx._jit_pass_flatten_cls_head(graph)
    except Exception as e:
        # the passes live in an optional torchscript extension; fall back to
        # the unoptimized graph instead of failing the export
        logger.warning(f'Failed to run ONNX optimization passes: {e}')

    return graph, params_dict, torch_out

View File

@@ -0,0 +1,190 @@
# Copyright (c) OpenMMLab. All rights reserved.
import tempfile
from typing import Any, List, Tuple
import onnx
import pytest
import torch
import torch.nn as nn
from mmdeploy.apis.onnx.optimizer import \
    model_to_graph__custom_optimizer  # noqa
from mmdeploy.core import RewriterContext
onnx_file = tempfile.NamedTemporaryFile(suffix='.onnx').name
def _find_next_node(start: int, nodes: List, op_type: str) -> Tuple[Any, int]:
    for idx, n in enumerate(nodes[start:]):
        if n.op_type == op_type:
            # return the absolute index so callers can resume the scan from it
            return n, start + idx
    return None, -1
def test_merge_shape_concate():
    pytest.importorskip('mmdeploy.backend.torchscript.ts_optimizer.onnx')
    try:
        from mmdeploy.backend.torchscript import ts_optimizer
        opt_pass = ts_optimizer.onnx._jit_pass_merge_shape_concate
    except ImportError:
        pytest.skip('pass not found.')

    def _optimize_onnx(graph, params_dict, torch_out):
        opt_pass(graph)
        return graph, params_dict, torch_out

    class TestModel(nn.Module):

        def __init__(self):
            super().__init__()

        def forward(self, x):
            return x.new_zeros(x.shape[-2:])

    model = TestModel()
    x = torch.rand(1, 3, 4, 8)

    with RewriterContext({}, onnx_custom_passes=_optimize_onnx):
        torch.onnx.export(
            model,
            x,
            onnx_file,
            input_names=['input'],
            output_names=['output'],
            dynamic_axes=dict(input={
                2: 'h',
                3: 'w'
            }),
            opset_version=11)

    onnx_model = onnx.load(onnx_file)
    graph = onnx_model.graph
    nodes = graph.node

    shape_idx = 0
    for n in nodes:
        if n.op_type != 'Shape':
            shape_idx += 1
        else:
            break

    assert shape_idx < len(nodes)
    assert nodes[shape_idx + 1].op_type == 'Gather'
    assert nodes[shape_idx + 2].op_type == 'ConstantOfShape'
def test_peephole():
    pytest.importorskip('mmdeploy.backend.torchscript.ts_optimizer.onnx')
    try:
        from mmdeploy.backend.torchscript import ts_optimizer
        opt_pass = ts_optimizer.onnx._jit_pass_onnx_peephole
    except ImportError:
        pytest.skip('pass not found.')

    def _optimize_onnx(graph, params_dict, torch_out):
        opt_pass(graph)
        return graph, params_dict, torch_out

    class TestModel(torch.nn.Module):

        def __init__(self):
            super().__init__()

        def forward(self, x):
            x = x.int()
            x = x.int()
            x = x.float()
            x = x.view(10, -1)
            y = x.view(2, -1)
            z = x.view(3, -1)
            return y, z

    model = TestModel()
    x = torch.rand(2, 3, 5)

    with RewriterContext({}, onnx_custom_passes=_optimize_onnx):
        torch.onnx.export(
            model,
            x,
            onnx_file,
            input_names=['input'],
            output_names=['output1', 'output2'],
            dynamic_axes=dict(input={
                0: 'b',
                1: 'c',
                2: 'w'
            }),
            opset_version=11)

    onnx_model = onnx.load(onnx_file)
    graph = onnx_model.graph
    nodes = graph.node

    node, idx = _find_next_node(0, nodes, 'Cast')
    assert node is not None
    assert node.attribute[0].i == 6

    node, idx = _find_next_node(idx + 1, nodes, 'Cast')
    assert node is not None
    assert node.attribute[0].i == 1

    node, idx = _find_next_node(idx + 1, nodes, 'Reshape')
    assert node is not None

    node, idx = _find_next_node(idx + 1, nodes, 'Reshape')
    assert node is not None
def test_flatten_cls_head():
    pytest.importorskip('mmdeploy.backend.torchscript.ts_optimizer.onnx')
    try:
        from mmdeploy.backend.torchscript import ts_optimizer
        opt_pass = ts_optimizer.onnx._jit_pass_flatten_cls_head
    except ImportError:
        pytest.skip('pass not found.')

    def _optimize_onnx(graph, params_dict, torch_out):
        opt_pass(graph)
        return graph, params_dict, torch_out

    class TestModel(torch.nn.Module):

        def __init__(self) -> None:
            super().__init__()

        def forward(self, x):
            batch = x.size(0)
            gap = nn.functional.adaptive_avg_pool2d(x, (1, 1))
            gap = gap.reshape(batch, -1)
            return gap + 0  # gap should not be the output

    model = TestModel()
    x = torch.rand(1, 4, 8, 8)

    with RewriterContext({}, onnx_custom_passes=_optimize_onnx):
        torch.onnx.export(
            model,
            x,
            onnx_file,
            input_names=['input'],
            output_names=['output'],
            dynamic_axes=dict(input={
                2: 'h',
                3: 'w'
            }),
            opset_version=11)

    onnx_model = onnx.load(onnx_file)
    graph = onnx_model.graph
    nodes = graph.node

    node, idx = _find_next_node(0, nodes, 'GlobalAveragePool')
    assert node is not None

    node, idx = _find_next_node(idx + 1, nodes, 'Flatten')
    assert node is not None