// Copyright (c) OpenMMLab. All rights reserved. #include "optimizer.h" #include #include #include #include #include #include #include #include #include #if TORCH_VERSION_MINOR >= 9 #include #include #include #endif namespace mmdeploy { using torch::jit::Graph; const std::shared_ptr& required_passes(const std::shared_ptr& graph) { RemoveExpands(graph); CanonicalizeOps(graph); EliminateDeadCode(graph); return graph; } Module optimize_for_torchscript(const Module& model) { auto frozen_model = freeze_module(model); auto graph = frozen_model.get_method("forward").graph(); OptimizeFrozenGraph(graph, true); #if TORCH_VERSION_MINOR >= 9 FuseFrozenConvAddRelu(graph); ConvertFrozenOpsToMKLDNN(graph); FrozenLinearTranspose(graph); #endif graph = required_passes(graph); EliminateCommonSubexpression(graph); PeepholeOptimize(graph); ConstantPropagation(graph); ConstantPooling(graph); // TODO: add more custom passes return frozen_model; } Module optimize_for_onnx(const Module& model) { auto frozen_model = freeze_module(model, {"training"}); auto graph = frozen_model.get_method("forward").graph(); OptimizeFrozenGraph(graph, true); #if TORCH_VERSION_MINOR >= 9 FuseFrozenConvAddRelu(graph); ConvertFrozenOpsToMKLDNN(graph); FrozenLinearTranspose(graph); #endif // TODO: add more custom passes return frozen_model; } // TODO: add optimizer for other backend/onnx } // namespace mmdeploy