mmdeploy/csrc/backend_ops/torchscript/optimizer/bind.cpp

27 lines
842 B
C++

// Copyright (c) OpenMMLab. All rights reserved.
// pybind11 (and thus Python.h) must be included before any standard headers.
#include <pybind11/pybind11.h>

#include <stdexcept>
#include <string>

#include "optimizer.h"
/// @brief Run the mmdeploy TorchScript optimization pass selected by `ir` on `model`.
///
/// The module is replaced in place with the module returned by the selected pass.
///
/// @param model   Scripted module to optimize; overwritten with the pass result.
/// @param ir      Target IR. "torchscript" runs mmdeploy::optimize_for_torchscript,
///                "onnx" runs mmdeploy::optimize_for_onnx.
/// @param backend Backend name. Currently it is only reported in the error message and
///                does not influence which pass runs.
/// @throws std::invalid_argument if `ir` is not one of the supported values
///         (pybind11 surfaces this to Python as ValueError).
void optimize_for_backend(torch::jit::Module& model, const std::string& ir = "torchscript",
                          const std::string& backend = "torchscript") {
  if (ir == "torchscript") {
    model = mmdeploy::optimize_for_torchscript(model);
  } else if (ir == "onnx") {
    model = mmdeploy::optimize_for_onnx(model);
  } else {
    // Throw instead of exit(): this code runs inside the Python interpreter, and exit()
    // would terminate the whole process rather than give the caller a catchable error.
    throw std::invalid_argument("No optimize for combination ir: " + ir +
                                " backend: " + backend);
  }
}
// Python extension module `ts_optimizer`, exposing optimize_for_backend.
PYBIND11_MODULE(ts_optimizer, m) {
  using pybind11::arg;

  // Keyword-callable from Python as optimize_for_backend(module, ir=..., backend=...);
  // both optional arguments default to "torchscript", mirroring the C++ defaults.
  const std::string default_target("torchscript");
  m.def("optimize_for_backend", &optimize_for_backend,
        arg("module"),
        arg("ir") = default_target,
        arg("backend") = default_target);
}