mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
* add merge_shape_concate * add some peephole optimization * bug fixing * fix for torch1.9 * add flatten cls head * add subgraph matcher with attribute * add ut * fix lint * remove onnx2ncnn * add opset version * axis name * fix peephole * fix symbol compare * add docs
40 lines
1.3 KiB
C++
40 lines
1.3 KiB
C++
// Copyright (c) OpenMMLab. All rights reserved.
|
|
#include <pybind11/pybind11.h>

#include <stdexcept>
#include <string>

#include "optimizer.h"
#include "passes/onnx/flatten_cls_head.h"
#include "passes/onnx/merge_shape_concate.h"
#include "passes/onnx/onnx_peephole.h"
|
|
|
|
namespace mmdeploy {
|
|
namespace torch_jit {
|
|
|
|
/// Runs the optimization pipeline for the requested IR on `model`, in place.
///
/// @param model   TorchScript module to optimize; it is overwritten with the
///                result of `optimize_for_torchscript` / `optimize_for_onnx`.
/// @param ir      Target IR: "torchscript" or "onnx".
/// @param backend Target backend name. It is not used to select passes here;
///                it is only reported in the error for an unsupported IR.
/// @throws std::invalid_argument if `ir` is not a supported IR. pybind11
///         translates this into a Python `ValueError`.
void optimize_for_backend(torch::jit::Module& model, const std::string& ir = "torchscript",
                          const std::string& backend = "torchscript") {
  if (ir == "torchscript") {
    model = optimize_for_torchscript(model);
  } else if (ir == "onnx") {
    model = optimize_for_onnx(model);
  } else {
    // Throw rather than exit(): this function is called from Python, and
    // exit() would terminate the whole interpreter with no way to recover.
    throw std::invalid_argument("No optimize for combination ir: " + ir +
                                " backend: " + backend);
  }
}
|
|
|
|
PYBIND11_MODULE(ts_optimizer, m) {
  namespace py = pybind11;

  // Top-level entry point: optimize a whole module for the given IR/backend.
  m.def("optimize_for_backend", optimize_for_backend, py::arg("module"),
        py::arg("ir") = std::string("torchscript"),
        py::arg("backend") = std::string("torchscript"));

  // `ts_optimizer.onnx` exposes each ONNX graph pass as its own function.
  // module_::def returns the module itself, so the registrations are chained.
  auto onnx_passes = m.def_submodule("onnx");
  onnx_passes.def("_jit_pass_merge_shape_concate", MergeShapeConcate, py::arg("graph"))
      .def("_jit_pass_onnx_peephole", ONNXPeephole, py::arg("graph"))
      .def("_jit_pass_flatten_cls_head", FlattenClsHead, py::arg("graph"));
}
|
|
|
|
} // namespace torch_jit
|
|
} // namespace mmdeploy
|