q.yao 74243dc98b
[Enhancement] Add ONNX passes support (#390)
* add merge_shape_concate

* add some peephole optimization

* bug fixing

* fix for torch1.9

* add flatten cls head

* add subgraph matcher with attribute

* add ut

* fix lint

* remove onnx2ncnn

* add opset version

* axis name

* fix peephole

* fix symbol compare

* add docs
2022-06-06 21:30:31 +08:00

40 lines
1.3 KiB
C++

// Copyright (c) OpenMMLab. All rights reserved.
#include <pybind11/pybind11.h>

#include <stdexcept>
#include <string>

#include "optimizer.h"
#include "passes/onnx/flatten_cls_head.h"
#include "passes/onnx/merge_shape_concate.h"
#include "passes/onnx/onnx_peephole.h"
namespace mmdeploy {
namespace torch_jit {

/// @brief Optimize @p model for the intermediate representation selected by @p ir.
///
/// The optimized module is written back into @p model.
///
/// @param model   Module to optimize; replaced with the optimized result.
/// @param ir      Target IR. "torchscript" runs optimize_for_torchscript(),
///                "onnx" runs optimize_for_onnx(). Any other value is rejected.
/// @param backend Does not currently select a pipeline; it is only reported in
///                the error message when @p ir is unsupported.
/// @throws std::invalid_argument if @p ir is not "torchscript" or "onnx".
void optimize_for_backend(torch::jit::Module& model, const std::string& ir = "torchscript",
                          const std::string& backend = "torchscript") {
  if (ir == "torchscript") {
    model = optimize_for_torchscript(model);
  } else if (ir == "onnx") {
    model = optimize_for_onnx(model);
  } else {
    // Throw rather than exit(): this runs inside the Python interpreter via the
    // binding below, and terminating the process would kill the caller's whole
    // session. pybind11 converts std::invalid_argument into a Python ValueError.
    throw std::invalid_argument("No optimize for combination ir: " + ir +
                                " backend: " + backend);
  }
}

/// Python extension module `ts_optimizer`.
///
/// Exposes optimize_for_backend() at the top level (Python keyword `module`
/// maps to the C++ `model` parameter), and the individual ONNX graph passes
/// under the `onnx` submodule, each taking a single `graph` argument.
PYBIND11_MODULE(ts_optimizer, m) {
  namespace py = pybind11;
  m.def("optimize_for_backend", optimize_for_backend, py::arg("module"),
        py::arg("ir") = std::string("torchscript"),
        py::arg("backend") = std::string("torchscript"));

  // Graph-level passes intended to be run on the ONNX-export path.
  py::module_ onnx_module = m.def_submodule("onnx");
  onnx_module.def("_jit_pass_merge_shape_concate", MergeShapeConcate, py::arg("graph"));
  onnx_module.def("_jit_pass_onnx_peephole", ONNXPeephole, py::arg("graph"));
  onnx_module.def("_jit_pass_flatten_cls_head", FlattenClsHead, py::arg("graph"));
}

}  // namespace torch_jit
}  // namespace mmdeploy