refactor(onnx2ncnn): add test case and simplify code (#436)
* refactor(onnx2ncnn.cpp): split it to shape_inference, pass and utils
* refactor(onnx2ncnn.cpp): split code
* refactor(net_module.cpp): fix build error
* ci(test_onnx2ncnn.py): add generate model and run
* ci(onnx2ncnn): add ncnn backend
* ci(test_onnx2ncnn): add converted onnx model
* ci(onnx2ncnn): fix ncnn tar
* ci(backend-ncnn): simplify dependency install
* ci(onnx2ncnn): fix apt install
* Update backend-ncnn.yml
* Update backend-ncnn.yml
* Update backend-ncnn.yml
* Update backend-ncnn.yml
* Update backend-ncnn.yml
* Update backend-ncnn.yml
* Update backend-ncnn.yml
* Update backend-ncnn.yml
* Update backend-ncnn.yml
* Update backend-ncnn.yml
* Update backend-ncnn.yml
* fix(ci): add include algorithm
* Update build.yml
* parent aa85760531
author q.yao <streetyao@live.com> 1651287879 +0800
committer tpoisonooo <khj.application@aliyun.com> 1652169959 +0800
[Fix] Fix ci (#426)
* fix ci
* add nvidia key
* remove torch
* recover pytorch
refactor(onnx2ncnn.cpp): split it to shape_inference, pass and utils
* fix(onnx2ncnn): review
* fix(onnx2ncnn): build error
Co-authored-by: q.yao <streetyao@live.com>
commit d04c8dc9c0 (parent 6eb83a9daa, pull/442/head)
.github/scripts/test_onnx2ncnn.py
@@ -0,0 +1,97 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import subprocess

# list of tuples: config, pretrained checkpoint, onnx filename and, for the
# entries exercised in CI, the download url of a pre-converted onnx model
CONFIGS = [
    (
        'mmclassification/configs/vision_transformer/vit-base-p32_ft-64xb64_in1k-384.py',  # noqa: E501
        'https://download.openmmlab.com/mmclassification/v0/vit/finetune/vit-base-p32_in21k-pre-3rdparty_ft-64xb64_in1k-384_20210928-9cea8599.pth',  # noqa: E501
        'vit.onnx'),
    (
        'mmclassification/configs/resnet/resnet50_8xb32_in1k.py',
        'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth',  # noqa: E501
        'resnet50.onnx',
    ),
    (
        'mmclassification/configs/resnet/resnet18_8xb32_in1k.py',
        'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth',  # noqa: E501
        'resnet18.onnx',
        'https://media.githubusercontent.com/media/tpoisonooo/mmdeploy-onnx2ncnn-testdata/main/resnet18.onnx',  # noqa: E501
    ),
    (
        'mmclassification/configs/mobilenet_v2/mobilenet-v2_8xb32_in1k.py',
        'https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth',  # noqa: E501
        'mobilenet-v2.onnx',
        'https://media.githubusercontent.com/media/tpoisonooo/mmdeploy-onnx2ncnn-testdata/main/mobilenet-v2.onnx',  # noqa: E501
    )
]


def parse_args():
    parser = argparse.ArgumentParser(
        description='MMDeploy onnx2ncnn test tool.')
    parser.add_argument('--run', type=bool, help='Execute onnx2ncnn bin.')
    parser.add_argument(
        '--repo-dir', type=str, default='~/', help='mmcls directory.')
    parser.add_argument(
        '--out',
        type=str,
        default='onnx_output',
        help='onnx model output directory.')
    parser.add_argument(
        '--generate-onnx', type=bool, help='Generate onnx model.')
    args = parser.parse_args()
    return args


def generate_onnx(args):
    import mmcv
    mmcv.mkdir_or_exist(args.out)
    for conf in CONFIGS:
        config = os.path.join(args.repo_dir, conf[0])
        model = conf[1]
        convert_cmd = [
            'python3', 'tools/deploy.py',
            'configs/mmcls/classification_ncnn_static.py', config, model,
            'cat-dog.png', '--work-dir', 'work_dir', '--device', 'cpu'
        ]
        print(subprocess.call(convert_cmd))

        move_cmd = [
            'mv', 'work_dir/end2end.onnx',
            os.path.join(args.out, conf[2])
        ]
        print(subprocess.call(move_cmd))


def run(args):
    for conf in CONFIGS:
        if len(conf) < 4:
            continue
        download_url = conf[3]
        filename = conf[2]
        download_cmd = ['wget', download_url]
        # use os.system so wget can show its progress bar
        os.system(' '.join(download_cmd))

        convert_cmd = ['./onnx2ncnn', filename, 'onnx.param', 'onnx.bin']
        subprocess.run(convert_cmd, capture_output=True, check=True)


def main():
    """Test `onnx2ncnn.cpp`.

    First generate an onnx model, then convert it with `onnx2ncnn`.
    """
    args = parse_args()
    if args.generate_onnx:
        generate_onnx(args)
    if args.run:
        run(args)


if __name__ == '__main__':
    main()
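The workflow below drives this script against real mmcls checkpoints. For quick local iteration, the freshly built converter can also be smoke-tested on a tiny hand-built model instead. A minimal sketch, assuming `onnx` and `numpy` are installed and `./onnx2ncnn` has been built; the single-Conv graph and the file name are illustrative and not part of the script above:

# Hypothetical helper, not part of test_onnx2ncnn.py: build a one-Conv
# model with onnx.helper and run the converter on it.
import subprocess

import numpy as np
import onnx
from onnx import helper, numpy_helper


def make_tiny_conv(filename='tiny_conv.onnx'):
    weight = numpy_helper.from_array(
        np.random.rand(8, 3, 3, 3).astype(np.float32), name='W')
    conv = helper.make_node('Conv', ['X', 'W'], ['Y'], pads=[1, 1, 1, 1])
    graph = helper.make_graph(
        [conv], 'tiny',
        [helper.make_tensor_value_info('X', onnx.TensorProto.FLOAT,
                                       [1, 3, 32, 32])],
        [helper.make_tensor_value_info('Y', onnx.TensorProto.FLOAT,
                                       [1, 8, 32, 32])],
        initializer=[weight])
    onnx.save(helper.make_model(graph), filename)
    return filename


model = make_tiny_conv()
subprocess.run(['./onnx2ncnn', model, 'onnx.param', 'onnx.bin'], check=True)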
.github/workflows/backend-ncnn.yml
@@ -0,0 +1,68 @@
name: backend

on:
  push:
    paths-ignore:
      - "demo/**"
      - "tools/**"

  pull_request:
    paths-ignore:
      - "demo/**"
      - "tools/**"
      - "docs/**"

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  test_onnx2ncnn:
    runs-on: ubuntu-18.04
    strategy:
      matrix:
        python-version: [3.7]
        torch: [1.9.0]
        mmcv: [1.4.2]
        include:
          - torch: 1.9.0
            torch_version: torch1.9
            torchvision: 0.10.0
    steps:
      - name: Checkout repository
        uses: actions/checkout@v3
        with:
          submodules: 'recursive'
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install unittest dependencies
        run: |
          pip install cmake onnx
      - name: update
        run: sudo apt update
      - name: gcc-multilib
        run: sudo apt install gcc-multilib g++-multilib wget libprotobuf-dev protobuf-compiler
      - name: Install ncnn
        run: |
          wget https://github.com/Tencent/ncnn/archive/refs/tags/20220420.tar.gz
          tar xf 20220420.tar.gz
          pushd ncnn-20220420
          mkdir build && pushd build
          cmake -DCMAKE_INSTALL_PREFIX=$(pwd)/../install -DNCNN_BUILD_TESTS=OFF -DNCNN_BUILD_TOOLS=OFF -DNCNN_BUILD_EXAMPLES=OFF ..
          cmake --build . -j2
          make install
          popd && popd
      - name: Install mmdeploy with ncnn backend
        run: |
          mkdir -p build && pushd build
          export LD_LIBRARY_PATH=/home/runner/work/mmdeploy/mmdeploy/ncnn-20220420/install/lib/:$LD_LIBRARY_PATH
          cmake -DMMDEPLOY_TARGET_BACKENDS=ncnn -Dncnn_DIR=/home/runner/work/mmdeploy/mmdeploy/ncnn-20220420/install/lib/cmake/ncnn/ ..
          make onnx2ncnn -j2
          popd
      - name: Test onnx2ncnn
        run: |
          echo $(pwd)
          ln -s build/bin/onnx2ncnn ./
          python3 .github/scripts/test_onnx2ncnn.py --run 1
.github/workflows/build.yml
@@ -152,4 +152,4 @@ jobs:
         env_vars: OS,PYTHON
         name: codecov-umbrella
         fail_ci_if_error: false
-        gcov_ignore : [".github/scripts/doc_link_checker.py"]
+        gcov_ignore : [".github/scripts/*"]
CMakeLists.txt
@@ -5,7 +5,7 @@ endif ()
 message(STATUS "CMAKE_INSTALL_PREFIX: ${CMAKE_INSTALL_PREFIX}")

 cmake_minimum_required(VERSION 3.14)
-project(MMDeploy VERSION 0.1.0)
+project(MMDeploy VERSION 0.5.0)

 set(CMAKE_CXX_STANDARD 17)
CMakeLists.txt (onnx2ncnn tool)
@@ -7,7 +7,7 @@ find_package(Protobuf)
 if (PROTOBUF_FOUND)
   protobuf_generate_cpp(ONNX_PROTO_SRCS ONNX_PROTO_HDRS
     ${CMAKE_CURRENT_SOURCE_DIR}/onnx.proto)
-  add_executable(onnx2ncnn onnx2ncnn.cpp ${ONNX_PROTO_SRCS} ${ONNX_PROTO_HDRS})
+  add_executable(onnx2ncnn onnx2ncnn.cpp fuse_pass.cpp shape_inference.cpp ${ONNX_PROTO_SRCS} ${ONNX_PROTO_HDRS})
   target_include_directories(onnx2ncnn PRIVATE ${PROTOBUF_INCLUDE_DIR}
                                                ${CMAKE_CURRENT_BINARY_DIR})
   target_link_libraries(onnx2ncnn PRIVATE ${PROTOBUF_LIBRARIES})
(File diff suppressed because it is too large.)
fuse_pass.h
@@ -0,0 +1,120 @@
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once

#include "shape_inference.h"
#include "utils.h"

void fuse_weight_reshape(onnx::GraphProto* mutable_graph,
                         std::map<std::string, onnx::TensorProto>& weights,
                         std::map<std::string, int>& node_reference,
                         std::set<std::string>& blob_names, int& reduced_node_count);

void fuse_shufflechannel(onnx::GraphProto* mutable_graph,
                         std::map<std::string, onnx::TensorProto>& weights,
                         std::map<std::string, int>& node_reference,
                         std::set<std::string>& blob_names, int& reduced_node_count);

void fuse_shufflechannel_split(onnx::GraphProto* mutable_graph,
                               std::map<std::string, onnx::TensorProto>& weights,
                               std::map<std::string, int>& node_reference,
                               std::set<std::string>& blob_names, int& reduced_node_count);

/**
 * @brief fuse the subgraph
 *
 *   conv - - - - - - - - - - - -> reshape
 *       \                       /
 *        shape - slice - concat
 *
 * to
 *
 *   conv --> reshape
 *
 * @param mutable_graph
 * @param weights
 * @param node_reference
 * @param blob_names
 * @param reduced_node_count
 */
void fuse_conv_reshape(onnx::GraphProto* mutable_graph,
                       std::map<std::string, onnx::TensorProto>& weights,
                       std::map<std::string, int>& node_reference,
                       std::set<std::string>& blob_names, int& reduced_node_count);

void fuse_binaryop_with_scalar(onnx::GraphProto* mutable_graph,
                               std::map<std::string, onnx::TensorProto>& weights,
                               std::map<std::string, int>& node_reference,
                               std::set<std::string>& blob_names, int& reduced_node_count);

void fuse_hardswish(onnx::GraphProto* mutable_graph,
                    std::map<std::string, onnx::TensorProto>& weights,
                    std::map<std::string, int>& node_reference,
                    std::set<std::string>& blob_names, int& reduced_node_count);

void fuse_hardsigmoid(onnx::GraphProto* mutable_graph,
                      std::map<std::string, onnx::TensorProto>& weights,
                      std::map<std::string, int>& node_reference,
                      std::set<std::string>& blob_names, int& reduced_node_count);

void fuse_batchnorm1d_squeeze_unsqueeze(onnx::GraphProto* mutable_graph,
                                        std::map<std::string, onnx::TensorProto>& weights,
                                        std::map<std::string, int>& node_reference,
                                        std::set<std::string>& blob_names,
                                        int& reduced_node_count);

void fuse_unsqueeze_prelu(onnx::GraphProto* mutable_graph,
                          std::map<std::string, onnx::TensorProto>& weights,
                          std::map<std::string, int>& node_reference,
                          std::set<std::string>& blob_names, int& reduced_node_count);

void fuse_normalize(onnx::GraphProto* mutable_graph,
                    std::map<std::string, onnx::TensorProto>& weights,
                    std::map<std::string, int>& node_reference,
                    std::set<std::string>& blob_names, int& reduced_node_count);

void fuse_groupnorm(onnx::GraphProto* mutable_graph,
                    std::map<std::string, onnx::TensorProto>& weights,
                    std::map<std::string, int>& node_reference,
                    std::set<std::string>& blob_names, int& reduced_node_count);

void fuse_layernorm(onnx::GraphProto* mutable_graph,
                    std::map<std::string, onnx::TensorProto>& weights,
                    std::map<std::string, int>& node_reference,
                    std::set<std::string>& blob_names, int& reduced_node_count);

void fuse_flatten(onnx::GraphProto* mutable_graph,
                  std::map<std::string, onnx::TensorProto>& weights,
                  std::map<std::string, int>& node_reference,
                  std::set<std::string>& blob_names, int& reduced_node_count);

void fuse_pixelshuffle(onnx::GraphProto* mutable_graph,
                       std::map<std::string, onnx::TensorProto>& weights,
                       std::map<std::string, int>& node_reference,
                       std::set<std::string>& blob_names, int& reduced_node_count);

void fuse_reorg(onnx::GraphProto* mutable_graph,
                std::map<std::string, onnx::TensorProto>& weights,
                std::map<std::string, int>& node_reference,
                std::set<std::string>& blob_names, int& reduced_node_count);

void fuse_expand_broadcast(onnx::GraphProto* mutable_graph,
                           std::map<std::string, onnx::TensorProto>& weights,
                           std::map<std::string, int>& node_reference,
                           std::set<std::string>& blob_names, int& reduced_node_count);

void fuse_lstm_gru_rnn(onnx::GraphProto* mutable_graph,
                       std::map<std::string, onnx::TensorProto>& weights,
                       std::map<std::string, int>& node_reference,
                       std::set<std::string>& blob_names, int& reduced_node_count);

void fuse_multiheadattention(onnx::GraphProto* mutable_graph,
                             std::map<std::string, onnx::TensorProto>& weights,
                             std::map<std::string, int>& node_reference,
                             std::set<std::string>& blob_names, int& reduced_node_count);

void fuse_weight_transpose(onnx::GraphProto* mutable_graph,
                           std::map<std::string, onnx::TensorProto>& weights,
                           std::map<std::string, int>& node_reference,
                           std::set<std::string>& blob_names, int& reduced_node_count);

void fuse_swish(onnx::GraphProto* mutable_graph,
                std::map<std::string, onnx::TensorProto>& weights,
                std::map<std::string, int>& node_reference,
                std::set<std::string>& blob_names, int& reduced_node_count);
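To see what fuse_conv_reshape is looking for, the subgraph from its doc comment can be reproduced as a small ONNX model, e.g. as a unit-test input. A hedged Python sketch with the onnx helpers; all tensor names and the 1x1-Conv dimensions are illustrative, chosen only so the shapes line up:

# Hypothetical repro of the fuse_conv_reshape pattern: the Conv result
# feeds both the Reshape and a Shape-Slice-Concat chain that rebuilds
# the target shape.
import numpy as np
import onnx
from onnx import helper, numpy_helper

w = numpy_helper.from_array(np.zeros((8, 3, 1, 1), np.float32), name='W')
starts = numpy_helper.from_array(np.array([0], np.int64), name='starts')
ends = numpy_helper.from_array(np.array([2], np.int64), name='ends')
axes = numpy_helper.from_array(np.array([0], np.int64), name='axes')
tail = numpy_helper.from_array(np.array([-1], np.int64), name='tail')

nodes = [
    helper.make_node('Conv', ['X', 'W'], ['conv_out']),
    helper.make_node('Shape', ['conv_out'], ['shape']),
    # keep dims [0:2] of the conv output, i.e. (N, C)
    helper.make_node('Slice', ['shape', 'starts', 'ends', 'axes'], ['sliced']),
    # append -1 so the reshape flattens the spatial dims
    helper.make_node('Concat', ['sliced', 'tail'], ['new_shape'], axis=0),
    helper.make_node('Reshape', ['conv_out', 'new_shape'], ['Y']),
]
graph = helper.make_graph(
    nodes, 'conv_reshape',
    [helper.make_tensor_value_info('X', onnx.TensorProto.FLOAT, [1, 3, 4, 4])],
    [helper.make_tensor_value_info('Y', onnx.TensorProto.FLOAT, [1, 8, 16])],
    initializer=[w, starts, ends, axes, tail])
onnx.save(helper.make_model(graph), 'conv_reshape.onnx')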
(File diff suppressed because it is too large.)
shape_inference.cpp
@@ -0,0 +1,168 @@
// Copyright (c) OpenMMLab. All rights reserved.

#include "shape_inference.h"

/**
 * @brief query output shape of target node
 *
 * @param mutable_graph
 * @param target
 * @param weights
 * @param context <tensor name, shape>
 * @return std::tuple<bool, std::vector<int>>
 */
std::tuple<bool, std::vector<int>> query_shape(
    onnx::GraphProto* mutable_graph, onnx::NodeProto* target,
    const std::map<std::string, onnx::TensorProto>& weights,
    std::map<std::string, std::vector<int>>& context) {
  // emplace all input nodes
  const int input_count = mutable_graph->input_size();
  for (int i = 0; i < input_count; i++) {
    auto inp = mutable_graph->input(i);
    onnx::TypeProto inp_type = inp.type();
    onnx::TensorShapeProto shape_proto = inp_type.tensor_type().shape();

    auto dim_size = shape_proto.dim_size();
    std::vector<int> shape(dim_size);
    for (int index = 0; index < dim_size; ++index) {
      shape[index] = shape_proto.dim(index).dim_value();
    }

    context.emplace(inp.name(), shape);
  }

  // BFS the tree: `target` is the root, onnx::graph inputs and weights are the leaves
  std::vector<onnx::NodeProto*> serial = {target};
  {
    std::set<std::string> mark_as_appended = {};
    int start = 0, end = static_cast<int>(serial.size());
    while (true) {
      for (int i = start; i < end; ++i) {
        auto node_ptr = serial[i];
        auto len = node_ptr->input_size();

        for (int j = 0; j < len; ++j) {
          std::string name = node_ptr->input(j);
          if (context.find(name) != context.end()) {
            // input already known, skip
            continue;
          }

          if (weights.find(name) != weights.end()) {
            // found in weights, extract its shape into the context
            auto weight = weights.at(name);
            std::vector<int> shape;
            for (auto index = 0; index < weight.dims_size(); ++index) {
              shape.emplace_back(weight.dims(index));
            }
            context.emplace(name, shape);
            continue;
          }

          if (mark_as_appended.find(name) != mark_as_appended.end()) {
            // already appended, skip
            continue;
          }
          // otherwise append the producer node to the serialization list
          auto depend_ptr = find_node_by_output_name(mutable_graph, name);
          if (depend_ptr == nullptr) {
            fprintf(stderr, "cannot find %s from graph !\n", name.c_str());
            return std::make_tuple(false, std::vector<int>{});
          }
          mark_as_appended.insert(name);
          serial.emplace_back(depend_ptr);
        }
      }

      if (static_cast<int>(serial.size()) <= end) {
        // no new node was added, quit
        break;
      }

      // update start and end positions, continue the BFS
      start = end;
      end = static_cast<int>(serial.size());
    }
  }

  // for each node in the serialization list, calculate its output shape
  {
    std::reverse(serial.begin(), serial.end());
    for (auto node : serial) {
      if (node->op_type() == "Conv") {
        auto inp = context[node->input(0)];
        auto weight = context[node->input(1)];
        assert(inp.size() == 4 and weight.size() == 4);

        int group = get_node_attr_i(*node, "group", 1);
        assert(group == 1);

        // treat multiple spatial attrs as a single one
#define EXTRACT_REPEATED_PARAM(NAME, ATTR, DEFAULT)        \
  int ATTR = DEFAULT;                                      \
  {                                                        \
    std::vector<int> _vec = get_node_attr_ai(*node, NAME); \
    if (not _vec.empty()) {                                \
      ATTR = _vec[0];                                      \
    }                                                      \
  }

        EXTRACT_REPEATED_PARAM("dilations", dilation, 1);
        EXTRACT_REPEATED_PARAM("pads", pad, 0);
        EXTRACT_REPEATED_PARAM("strides", stride, 1);

#undef EXTRACT_REPEATED_PARAM

        int on = inp[0];
        int oc = weight[0];
        int oh = (inp[2] + 2 * pad - weight[2]) / stride + 1;
        int ow = (inp[3] + 2 * pad - weight[3]) / stride + 1;
        context.emplace(node->output(0), std::vector<int>{on, oc, oh, ow});

      } else if (node->op_type() == "Shape") {
        auto inp = context[node->input(0)];
        context.emplace(node->output(0), std::vector<int>{1, inp[1], inp[2], inp[3]});

      } else if (node->op_type() == "Slice") {
        assert(node->input_size() >= 4);

        auto inp = context[node->input(0)];
        int start = get_node_attr_from_input<int>(weights.at(node->input(1)));
        int end = get_node_attr_from_input<int>(weights.at(node->input(2)));
        int axes = get_node_attr_from_input<int>(weights.at(node->input(3)));

        if (axes != 0) {
          fprintf(stderr, "Not support axes=%d !\n", axes);
          return std::make_tuple(false, std::vector<int>{});
        }

        assert(static_cast<int>(inp.size()) >= end - start);
        context.emplace(node->output(0),
                        std::vector<int>(inp.begin() + start, inp.begin() + end));

      } else if (node->op_type() == "Concat") {
        assert(node->input_size() >= 2);

        auto axis = get_node_attr_i(*node, "axis", 0);
        if (axis != 0) {
          fprintf(stderr, "Not support axis=%d !\n", axis);
          return std::make_tuple(false, std::vector<int>{});
        }

        std::vector<int> inp = context[node->input(0)];
        std::vector<int> w_data = get_node_attr_from_input_ai(weights.at(node->input(1)));

        // concat data on axis 0
        inp.insert(inp.end(), w_data.begin(), w_data.end());
        context.emplace(node->output(0), inp);

      } else {
        fprintf(stderr, "Unsupported type %s in query_shape !\n", node->op_type().c_str());
        return std::make_tuple(false, std::vector<int>{});
      }
    }
  }

  assert(context.find(target->output(0)) != context.end());
  auto target_shape = context[target->output(0)];
  return std::make_tuple(true, target_shape);
}
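As a worked check of the Conv arithmetic above (dilation is extracted but effectively assumed to be 1), a small hypothetical Python snippet; the ResNet stem numbers are just a convenient known case:

# Check of the formula used in the Conv branch: out = (in + 2p - k) / s + 1
def conv_out(size, kernel, pad, stride):
    return (size + 2 * pad - kernel) // stride + 1

# ResNet stem: 224x224 input, 7x7 kernel, pad 3, stride 2 -> 112x112
assert conv_out(224, kernel=7, pad=3, stride=2) == 112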
shape_inference.h
@@ -0,0 +1,20 @@
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once
#include <algorithm>

#include "utils.h"

/**
 * @brief query output shape of target node
 *
 * @param mutable_graph
 * @param target
 * @param weights
 * @param context <tensor name, shape>
 * @return std::tuple<bool, std::vector<int>>
 */
std::tuple<bool, std::vector<int>> query_shape(
    onnx::GraphProto* mutable_graph, onnx::NodeProto* target,
    const std::map<std::string, onnx::TensorProto>& weights,
    std::map<std::string, std::vector<int>>& context);
utils.h
@@ -0,0 +1,401 @@
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once

#include <float.h>
#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/message.h>
#include <google/protobuf/text_format.h>
#include <limits.h>

#include <algorithm>
#include <cstdlib>
#include <fstream>
#include <iostream>

#include "onnx.pb.h"

/**
 * @brief find graph node by output name
 *
 * @param mutable_graph
 * @param name
 * @return onnx::NodeProto*
 */
static onnx::NodeProto* find_node_by_output_name(onnx::GraphProto* mutable_graph,
                                                 const std::string& name) {
  const int node_count = mutable_graph->node_size();
  for (int i = 0; i < node_count; ++i) {
    onnx::NodeProto* node = mutable_graph->mutable_node(i);

    for (int j = 0; j < node->output_size(); ++j) {
      auto output = node->output(j);
      if (output == name) {
        return node;
      }
    }
  }

  return nullptr;
}

static bool read_proto_from_binary(const char* filepath, onnx::ModelProto* message) {
  std::ifstream fs(filepath, std::ifstream::in | std::ifstream::binary);
  if (!fs.is_open()) {
    fprintf(stderr, "open failed %s\n", filepath);
    return false;
  }

  google::protobuf::io::IstreamInputStream input(&fs);
  google::protobuf::io::CodedInputStream codedstr(&input);

#if GOOGLE_PROTOBUF_VERSION >= 3011000
  codedstr.SetTotalBytesLimit(INT_MAX);
#else
  codedstr.SetTotalBytesLimit(INT_MAX, INT_MAX / 2);
#endif

  bool success = message->ParseFromCodedStream(&codedstr);

  fs.close();

  return success;
}

static std::vector<int> get_node_attr_ai(const onnx::NodeProto& node, const char* key) {
  std::vector<int> v;

  for (int i = 0; i < node.attribute_size(); i++) {
    const onnx::AttributeProto& attr = node.attribute(i);
    if (attr.name() == key) {
      v.resize(attr.ints_size());
      for (int j = 0; j < attr.ints_size(); j++) {
        v[j] = std::max(std::min(attr.ints(j), (::google::protobuf::int64)INT_MAX),
                        (::google::protobuf::int64)INT_MIN);
      }

      break;
    }
  }

  return v;
}

static void set_node_attr_ai(onnx::NodeProto& node, const char* key,
                             const std::vector<int>& value) {
  onnx::AttributeProto* attr_group = node.add_attribute();
  attr_group->set_name(key);
  for (auto v : value) {
    attr_group->add_ints(v);
  }
}

static std::vector<float> get_node_attr_af(const onnx::NodeProto& node, const char* key) {
  std::vector<float> v;

  for (int i = 0; i < node.attribute_size(); i++) {
    const onnx::AttributeProto& attr = node.attribute(i);
    if (attr.name() == key) {
      v.resize(attr.floats_size());
      for (int j = 0; j < attr.floats_size(); j++) {
        v[j] = attr.floats(j);
      }

      break;
    }
  }

  return v;
}

static int get_node_attr_i(const onnx::NodeProto& node, const char* key, int def = 0) {
  for (int i = 0; i < node.attribute_size(); i++) {
    const onnx::AttributeProto& attr = node.attribute(i);
    if (attr.name() == key) {
      return std::max(std::min(attr.i(), (::google::protobuf::int64)INT_MAX),
                      (::google::protobuf::int64)INT_MIN);
    }
  }

  return def;
}

static float get_node_attr_f(const onnx::NodeProto& node, const char* key, float def = 0.f) {
  for (int i = 0; i < node.attribute_size(); i++) {
    const onnx::AttributeProto& attr = node.attribute(i);
    if (attr.name() == key) {
      return attr.f();
    }
  }

  return def;
}

static std::string get_node_attr_s(const onnx::NodeProto& node, const char* key,
                                   const std::string& def = std::string()) {
  for (int i = 0; i < node.attribute_size(); i++) {
    const onnx::AttributeProto& attr = node.attribute(i);
    if (attr.name() == key) {
      return attr.s();
    }
  }

  return def;
}

static onnx::TensorProto get_node_attr_tensor(const onnx::NodeProto& node, const char* key) {
  for (int i = 0; i < node.attribute_size(); i++) {
    const onnx::AttributeProto& attr = node.attribute(i);
    if (attr.name() == key) {
      return attr.t();
    }
  }

  return onnx::TensorProto();
}

template <typename T>
static T get_node_attr_from_input(const onnx::TensorProto& tp) {
  T v = 0.f;

  // float
  if (tp.data_type() == 1) {
    const float* shape_data = 0;
    if (tp.has_raw_data()) {
      shape_data = (const float*)tp.raw_data().data();
    } else {
      shape_data = tp.float_data().data();
    }
    v = shape_data[0];
  }
  // double
  else if (tp.data_type() == 11) {
    const double* shape_data = 0;
    if (tp.has_raw_data()) {
      shape_data = (const double*)tp.raw_data().data();
    } else {
      shape_data = tp.double_data().data();
    }
    v = shape_data[0];
  }
  // int64
  else if (tp.data_type() == 7) {
    const int64_t* shape_data = 0;
    if (tp.has_raw_data()) {
      shape_data = (const int64_t*)tp.raw_data().data();
    } else {
      shape_data = tp.int64_data().data();
    }
    v = std::max(std::min(shape_data[0], (::google::protobuf::int64)INT_MAX),
                 (::google::protobuf::int64)INT_MIN);
  }
  // int32
  else if (tp.data_type() == 6) {
    const int32_t* shape_data = 0;
    if (tp.has_raw_data()) {
      shape_data = (const int32_t*)tp.raw_data().data();
    } else {
      shape_data = tp.int32_data().data();
    }
    v = shape_data[0];
  } else {
    // fprintf(stderr, "tp.name: %s\n", tp.name().c_str());
    fprintf(stderr, "Unknown data type %d\n", tp.data_type());
    fprintf(stderr, "get_node_attr_from_input\n");
    abort();
  }

  return v;
}

static std::vector<int> get_node_attr_from_input_ai(const onnx::TensorProto& tp) {
  int size = 0;

  std::vector<int> v;

  // int64
  if (tp.data_type() == 7) {
    const int64_t* shape_data = 0;
    if (tp.has_raw_data()) {
      shape_data = (const int64_t*)tp.raw_data().data();
      size = (int)(tp.raw_data().size() / 8);
    } else {
      shape_data = tp.int64_data().data();
      size = tp.int64_data_size();
    }
    for (int j = 0; j < size; j++) {
      int vi = std::max(std::min(shape_data[j], (::google::protobuf::int64)INT_MAX),
                        (::google::protobuf::int64)INT_MIN);
      v.push_back(vi);
    }
  }
  // int32
  else if (tp.data_type() == 6) {
    const int32_t* shape_data = 0;
    if (tp.has_raw_data()) {
      shape_data = (const int32_t*)tp.raw_data().data();
      size = (int)(tp.raw_data().size() / 4);
    } else {
      shape_data = tp.int32_data().data();
      size = tp.int32_data_size();
    }
    for (int j = 0; j < size; j++) {
      v.push_back(shape_data[j]);
    }
  } else {
    fprintf(stderr, "Unknown data type %d\n", tp.data_type());
  }

  return v;
}

static std::vector<float> get_node_attr_from_input_af(const onnx::TensorProto& tp) {
  int size = 0;

  std::vector<float> v;

  // float
  if (tp.data_type() == 1) {
    const float* shape_data = 0;
    if (tp.has_raw_data()) {
      shape_data = (const float*)tp.raw_data().data();
      size = (int)(tp.raw_data().size() / 4);
    } else {
      shape_data = tp.float_data().data();
      size = tp.float_data_size();
    }
    for (int j = 0; j < size; j++) {
      v.push_back(shape_data[j]);
    }
  }
  // double
  else if (tp.data_type() == 11) {
    const double* shape_data = 0;
    if (tp.has_raw_data()) {
      shape_data = (const double*)tp.raw_data().data();
      size = (int)(tp.raw_data().size() / 8);
    } else {
      shape_data = tp.double_data().data();
      size = tp.double_data_size();
    }
    for (int j = 0; j < size; j++) {
      v.push_back((float)shape_data[j]);
    }
  } else {
    fprintf(stderr, "Unknown data type %d\n", tp.data_type());
  }

  return v;
}

static int get_tensor_proto_data_size(const onnx::TensorProto& tp) {
  if (tp.has_raw_data()) {
    if (tp.data_type() == 1 || tp.data_type() == 6) {
      const std::string& raw_data = tp.raw_data();
      int size = (int)raw_data.size() / 4;
      return size;
    } else if (tp.data_type() == 7 || tp.data_type() == 11) {
      const std::string& raw_data = tp.raw_data();
      int size = (int)raw_data.size() / 8;
      return size;
    } else if (tp.data_type() == 9) {
      return 0;
    }
  } else if (tp.data_type() == 1) {
    return tp.float_data_size();
  } else if (tp.data_type() == 7) {
    return tp.int64_data_size();
  } else if (tp.data_type() == 6) {
    return tp.int32_data_size();
  } else if (tp.data_type() == 11) {
    return tp.double_data_size();
  }

  return 0;
}

static void fwrite_tensor_proto_data(const onnx::TensorProto& tp, FILE* bp) {
  int size = get_tensor_proto_data_size(tp);

  if (tp.has_raw_data()) {
    const std::string& raw_data = tp.raw_data();
    fwrite(raw_data.data(), sizeof(float), size, bp);
  } else if (tp.data_type() == 1) {
    fwrite(tp.float_data().data(), sizeof(float), size, bp);
  }
}

static void fwrite_tensor_proto_data_to_float(const onnx::TensorProto& tp, FILE* bp) {
  int size = get_tensor_proto_data_size(tp);
  size_t written_size;
  if (tp.has_raw_data()) {
    const std::string& raw_data = tp.raw_data();
    if (tp.data_type() == 6) {
      int* intdataptr = (int*)raw_data.data();
      float* floatdataptr = (float*)std::malloc(sizeof(float) * size);
      for (int i = 0; i < size; i++) {
        floatdataptr[i] = (float)intdataptr[i];
      }
      written_size = fwrite(floatdataptr, sizeof(float), size, bp);
      std::free(floatdataptr);
    } else if (tp.data_type() == 7) {
      int64_t* intdataptr = (int64_t*)raw_data.data();
      float* floatdataptr = (float*)std::malloc(sizeof(float) * size);
      for (int i = 0; i < size; i++) {
        floatdataptr[i] = (float)intdataptr[i];
      }
      written_size = fwrite(floatdataptr, sizeof(float), size, bp);
      std::free(floatdataptr);
    } else if (tp.data_type() == 9) {
      bool* intdataptr = (bool*)raw_data.data();
      float* floatdataptr = (float*)std::malloc(sizeof(float) * size);
      for (int i = 0; i < size; i++) {
        floatdataptr[i] = (float)intdataptr[i];
      }
      written_size = fwrite(floatdataptr, sizeof(float), size, bp);
      std::free(floatdataptr);
    } else if (tp.data_type() == 11) {
      double* doubledataptr = (double*)raw_data.data();
      float* floatdataptr = (float*)std::malloc(sizeof(float) * size);
      for (int i = 0; i < size; i++) {
        floatdataptr[i] = (float)doubledataptr[i];
      }
      written_size = fwrite(floatdataptr, sizeof(float), size, bp);
      std::free(floatdataptr);
    }
  } else if (tp.data_type() == 6) {
    int* intdataptr = (int*)tp.int32_data().data();
    float* floatdataptr = (float*)std::malloc(sizeof(float) * size);
    for (int i = 0; i < size; i++) {
      floatdataptr[i] = (float)intdataptr[i];
    }
    written_size = fwrite(floatdataptr, sizeof(float), size, bp);
    std::free(floatdataptr);
  } else if (tp.data_type() == 7) {
    int64_t* intdataptr = (int64_t*)tp.int64_data().data();
    float* floatdataptr = (float*)std::malloc(sizeof(float) * size);
    for (int i = 0; i < size; i++) {
      floatdataptr[i] = (float)intdataptr[i];
    }
    written_size = fwrite(floatdataptr, sizeof(float), size, bp);
    std::free(floatdataptr);
  } else if (tp.data_type() == 9) {
    int* intdataptr = (int*)tp.int32_data().data();
    float* floatdataptr = (float*)std::malloc(sizeof(float) * size);
    for (int i = 0; i < size; i++) {
      floatdataptr[i] = (float)intdataptr[i];
    }
    written_size = fwrite(floatdataptr, sizeof(float), size, bp);
    std::free(floatdataptr);
  } else if (tp.data_type() == 11) {
    double* doubledataptr = (double*)tp.double_data().data();
    float* floatdataptr = (float*)std::malloc(sizeof(float) * size);
    for (int i = 0; i < size; i++) {
      floatdataptr[i] = (float)doubledataptr[i];
    }
    written_size = fwrite(floatdataptr, sizeof(float), size, bp);
    std::free(floatdataptr);
  }
}
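The readers above branch on has_raw_data() because ONNX can serialize the same initializer either into typed repeated fields (float_data, int64_data, ...) or into packed raw bytes. A small hypothetical Python illustration of the two encodings, assuming the onnx package is available:

# Same [2, 3] int64 tensor, two serializations; numpy_helper reads both.
import numpy as np
from onnx import TensorProto, numpy_helper

arr = np.array([2, 3], dtype=np.int64)

packed = numpy_helper.from_array(arr, name='shape')  # packs into raw_data
typed = TensorProto(name='shape', data_type=TensorProto.INT64,
                    dims=[2], int64_data=[2, 3])     # typed repeated field

assert packed.HasField('raw_data') and not typed.HasField('raw_data')
assert numpy_helper.to_array(packed).tolist() == \
       numpy_helper.to_array(typed).tolist() == [2, 3]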