refactor(onnx2ncnn): add test case and simplify code (#436)

* refactor(onnx2ncnn.cpp): split it into shape_inference, pass and utils

* refactor(onnx2ncnn.cpp): split code

* refactor(net_module.cpp): fix build error

* ci(test_onnx2ncnn.py): add model generation and run steps

* ci(onnx2ncnn): add ncnn backend

* ci(test_onnx2ncnn): add converted onnx model

* ci(onnx2ncnn): fix ncnn tar

* ci(backend-ncnn): simplify dependency install

* ci(onnx2ncnn): fix apt install

* Update backend-ncnn.yml

* fix(ci): add missing <algorithm> include

* Update build.yml

* parent aa85760531
author q.yao <streetyao@live.com> 1651287879 +0800
committer tpoisonooo <khj.application@aliyun.com> 1652169959 +0800

[Fix] Fix ci (#426)

* fix ci

* add nvidia key

* remove torch

* recover pytorch

refactor(onnx2ncnn.cpp): split it into shape_inference, pass and utils

* fix(onnx2ncnn): review

* fix(onnx2ncnn): build error

Co-authored-by: q.yao <streetyao@live.com>
pull/442/head
tpoisonooo 2022-05-16 10:36:25 +08:00 committed by GitHub
parent 6eb83a9daa
commit d04c8dc9c0
11 changed files with 3251 additions and 2925 deletions


@@ -0,0 +1,97 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import subprocess

# list of tuples: (config, pretrained model, onnx filename, and optionally the
# URL of a pre-converted onnx model used by the `--run` step)
CONFIGS = [
    (
        'mmclassification/configs/vision_transformer/vit-base-p32_ft-64xb64_in1k-384.py',  # noqa: E501
        'https://download.openmmlab.com/mmclassification/v0/vit/finetune/vit-base-p32_in21k-pre-3rdparty_ft-64xb64_in1k-384_20210928-9cea8599.pth',  # noqa: E501
        'vit.onnx'),
    (
        'mmclassification/configs/resnet/resnet50_8xb32_in1k.py',
        'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth',  # noqa: E501
        'resnet50.onnx',
    ),
    (
        'mmclassification/configs/resnet/resnet18_8xb32_in1k.py',
        'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth',  # noqa: E501
        'resnet18.onnx',
        'https://media.githubusercontent.com/media/tpoisonooo/mmdeploy-onnx2ncnn-testdata/main/resnet18.onnx',  # noqa: E501
    ),
    (
        'mmclassification/configs/mobilenet_v2/mobilenet-v2_8xb32_in1k.py',
        'https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth',  # noqa: E501
        'mobilenet-v2.onnx',
        'https://media.githubusercontent.com/media/tpoisonooo/mmdeploy-onnx2ncnn-testdata/main/mobilenet-v2.onnx',  # noqa: E501
    )
]


def parse_args():
    parser = argparse.ArgumentParser(
        description='MMDeploy onnx2ncnn test tool.')
    # Note: `type=bool` treats any non-empty string as True, e.g. `--run 1`.
    parser.add_argument('--run', type=bool, help='Execute onnx2ncnn bin.')
    parser.add_argument(
        '--repo-dir', type=str, default='~/', help='mmcls directory.')
    parser.add_argument(
        '--out',
        type=str,
        default='onnx_output',
        help='onnx model output directory.')
    parser.add_argument(
        '--generate-onnx', type=bool, help='Generate onnx model.')
    args = parser.parse_args()
    return args


def generate_onnx(args):
    import mmcv
    mmcv.mkdir_or_exist(args.out)
    for conf in CONFIGS:
        config = os.path.join(args.repo_dir, conf[0])
        model = conf[1]
        convert_cmd = [
            'python3', 'tools/deploy.py',
            'configs/mmcls/classification_ncnn_static.py', config, model,
            'cat-dog.png', '--work-dir', 'work_dir', '--device', 'cpu'
        ]
        print(subprocess.call(convert_cmd))

        move_cmd = [
            'mv', 'work_dir/end2end.onnx',
            os.path.join(args.out, conf[2])
        ]
        print(subprocess.call(move_cmd))


def run(args):
    for conf in CONFIGS:
        # only entries that provide a pre-converted onnx model URL are tested
        if len(conf) < 4:
            continue
        download_url = conf[3]
        filename = conf[2]
        download_cmd = ['wget', download_url]
        # show progress bar
        os.system(' '.join(download_cmd))

        convert_cmd = ['./onnx2ncnn', filename, 'onnx.param', 'onnx.bin']
        subprocess.run(convert_cmd, capture_output=True, check=True)


def main():
    """Test `onnx2ncnn.cpp`.

    First generate onnx models, then convert them with `onnx2ncnn`.
    """
    args = parse_args()
    if args.generate_onnx:
        generate_onnx(args)
    if args.run:
        run(args)


if __name__ == '__main__':
    main()
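
For reference, a minimal sketch of how the two phases of this script can be exercised outside CI; the working-directory assumptions (mmdeploy repository root, mmclassification checked out, and an onnx2ncnn binary available, as in the CI job below) are mine, not part of this PR.

# Sketch: drive both phases of .github/scripts/test_onnx2ncnn.py locally.
import subprocess

# Phase 1: export the ONNX models listed in CONFIGS via tools/deploy.py.
subprocess.run(
    ['python3', '.github/scripts/test_onnx2ncnn.py', '--generate-onnx', '1'],
    check=True)

# Phase 2: download the pre-converted models and convert them with ./onnx2ncnn.
subprocess.run(
    ['python3', '.github/scripts/test_onnx2ncnn.py', '--run', '1'],
    check=True)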


@@ -0,0 +1,68 @@
name: backend

on:
  push:
    paths-ignore:
      - "demo/**"
      - "tools/**"

  pull_request:
    paths-ignore:
      - "demo/**"
      - "tools/**"
      - "docs/**"

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  test_onnx2ncnn:
    runs-on: ubuntu-18.04
    strategy:
      matrix:
        python-version: [3.7]
        torch: [1.9.0]
        mmcv: [1.4.2]
        include:
          - torch: 1.9.0
            torch_version: torch1.9
            torchvision: 0.10.0
    steps:
      - name: Checkout repository
        uses: actions/checkout@v3
        with:
          submodules: 'recursive'
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install unittest dependencies
        run: |
          pip install cmake onnx
      - name: update
        run: sudo apt update
      - name: gcc-multilib
        run: sudo apt install gcc-multilib g++-multilib wget libprotobuf-dev protobuf-compiler
      - name: Install ncnn
        run: |
          wget https://github.com/Tencent/ncnn/archive/refs/tags/20220420.tar.gz
          tar xf 20220420.tar.gz
          pushd ncnn-20220420
          mkdir build && pushd build
          cmake -DCMAKE_INSTALL_PREFIX=$(pwd)/../install -DNCNN_BUILD_TESTS=OFF -DNCNN_BUILD_TOOLS=OFF -DNCNN_BUILD_EXAMPLES=OFF ..
          cmake --build . -j2
          make install
          popd && popd
      - name: Install mmdeploy with ncnn backend
        run: |
          mkdir -p build && pushd build
          export LD_LIBRARY_PATH=/home/runner/work/mmdeploy/mmdeploy/ncnn-20220420/install/lib/:$LD_LIBRARY_PATH
          cmake -DMMDEPLOY_TARGET_BACKENDS=ncnn -Dncnn_DIR=/home/runner/work/mmdeploy/mmdeploy/ncnn-20220420/install/lib/cmake/ncnn/ ..
          make onnx2ncnn -j2
          popd
      - name: Test onnx2ncnn
        run: |
          echo $(pwd)
          ln -s build/bin/onnx2ncnn ./
          python3 .github/scripts/test_onnx2ncnn.py --run 1


@@ -152,4 +152,4 @@ jobs:
env_vars: OS,PYTHON
name: codecov-umbrella
fail_ci_if_error: false
-gcov_ignore : [".github/scripts/doc_link_checker.py"]
+gcov_ignore : [".github/scripts/*"]


@@ -5,7 +5,7 @@ endif ()
message(STATUS "CMAKE_INSTALL_PREFIX: ${CMAKE_INSTALL_PREFIX}")
cmake_minimum_required(VERSION 3.14)
-project(MMDeploy VERSION 0.1.0)
+project(MMDeploy VERSION 0.5.0)
set(CMAKE_CXX_STANDARD 17)


@@ -7,7 +7,7 @@ find_package(Protobuf)
if (PROTOBUF_FOUND)
protobuf_generate_cpp(ONNX_PROTO_SRCS ONNX_PROTO_HDRS
${CMAKE_CURRENT_SOURCE_DIR}/onnx.proto)
-add_executable(onnx2ncnn onnx2ncnn.cpp ${ONNX_PROTO_SRCS} ${ONNX_PROTO_HDRS})
+add_executable(onnx2ncnn onnx2ncnn.cpp fuse_pass.cpp shape_inference.cpp ${ONNX_PROTO_SRCS} ${ONNX_PROTO_HDRS})
target_include_directories(onnx2ncnn PRIVATE ${PROTOBUF_INCLUDE_DIR}
${CMAKE_CURRENT_BINARY_DIR})
target_link_libraries(onnx2ncnn PRIVATE ${PROTOBUF_LIBRARIES})

File diff suppressed because it is too large


@@ -0,0 +1,120 @@
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once
#include "shape_inference.h"
#include "utils.h"
void fuse_weight_reshape(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count);
void fuse_shufflechannel(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count);
void fuse_shufflechannel_split(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count);
/**
* @brief fuse subgraph
*
*   conv - - - - - - - - - - - -> reshape
*     \                         /
*      shape - slice - concat
*
* to
*
*   conv --> reshape
*
* @param mutable_graph
* @param weights
* @param node_reference
* @param blob_names
* @param reduced_node_count
*/
void fuse_conv_reshape(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count);
void fuse_binaryop_with_scalar(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count);
void fuse_hardswish(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
int& reduced_node_count);
void fuse_hardsigmoid(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
int& reduced_node_count);
void fuse_batchnorm1d_squeeze_unsqueeze(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count);
void fuse_unsqueeze_prelu(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count);
void fuse_normalize(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
int& reduced_node_count);
void fuse_groupnorm(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
int& reduced_node_count);
void fuse_layernorm(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
int& reduced_node_count);
void fuse_flatten(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
int& reduced_node_count);
void fuse_pixelshuffle(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count);
void fuse_reorg(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
int& reduced_node_count);
void fuse_expand_broadcast(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count);
void fuse_lstm_gru_rnn(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count);
void fuse_multiheadattention(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count);
void fuse_weight_transpose(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count);
void fuse_swish(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
int& reduced_node_count);
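
To make the subgraph documented above for fuse_conv_reshape concrete, the matched pattern can be reproduced with the onnx Python helpers; a minimal sketch, assuming the onnx and numpy packages, with illustrative tensor names and shapes that are not taken from this PR:

# Sketch: build the Conv -> (Shape -> Slice -> Concat) -> Reshape pattern that
# fuse_conv_reshape collapses into Conv -> Reshape.
import numpy as np
import onnx
from onnx import TensorProto, helper

conv_w = helper.make_tensor('conv_w', TensorProto.FLOAT, [8, 3, 3, 3],
                            np.zeros((8, 3, 3, 3), np.float32).tobytes(),
                            raw=True)
starts = helper.make_tensor('starts', TensorProto.INT64, [1], [0])
ends = helper.make_tensor('ends', TensorProto.INT64, [1], [2])
axes = helper.make_tensor('axes', TensorProto.INT64, [1], [0])
tail = helper.make_tensor('tail', TensorProto.INT64, [1], [-1])

nodes = [
    helper.make_node('Conv', ['x', 'conv_w'], ['conv_out'], pads=[1, 1, 1, 1]),
    # dynamic branch: take N, C from the conv output shape and append -1 ...
    helper.make_node('Shape', ['conv_out'], ['shape']),
    helper.make_node('Slice', ['shape', 'starts', 'ends', 'axes'], ['shape_nc']),
    helper.make_node('Concat', ['shape_nc', 'tail'], ['new_shape'], axis=0),
    # ... then reshape the conv output with the assembled shape
    helper.make_node('Reshape', ['conv_out', 'new_shape'], ['y']),
]
graph = helper.make_graph(
    nodes, 'conv_reshape',
    [helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 3, 16, 16])],
    [helper.make_tensor_value_info('y', TensorProto.FLOAT, None)],
    initializer=[conv_w, starts, ends, axes, tail])
onnx.save(helper.make_model(graph), 'conv_reshape.onnx')

When the conv output shape can be resolved statically, the Shape/Slice/Concat chain contributes nothing at runtime, which is why the pass can rewrite the Reshape to use a constant target shape; resolving that shape is what the query_shape helper from shape_inference is used for.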

File diff suppressed because it is too large


@@ -0,0 +1,168 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include "shape_inference.h"
/**
* @brief query output shape of target node
*
* @param mutable_graph
* @param target
* @param weights
* @param context <tensor name, shape>
* @return std::tuple<bool, std::vector<int>>
*/
std::tuple<bool, std::vector<int>> query_shape(
onnx::GraphProto* mutable_graph, onnx::NodeProto* target,
const std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, std::vector<int>>& context) {
// emplace all input nodes
const int input_count = mutable_graph->input_size();
for (int i = 0; i < input_count; i++) {
auto inp = mutable_graph->input(i);
onnx::TypeProto inp_type = inp.type();
onnx::TensorShapeProto shape_proto = inp_type.tensor_type().shape();
auto dim_size = shape_proto.dim_size();
std::vector<int> shape(dim_size);
for (int index = 0; index < dim_size; ++index) {
shape[index] = shape_proto.dim(index).dim_value();
}
context.emplace(inp.name(), shape);
}
// BFS the graph with `target` as the root; graph inputs and weights are the leaf nodes
std::vector<onnx::NodeProto*> serial = {target};
{
std::set<std::string> mark_as_appended = {};
while (true) {
int start = 0, end = serial.size();
for (int i = start; i < end; ++i) {
auto node_ptr = serial[i];
auto len = node_ptr->input_size();
for (int j = 0; j < len; ++j) {
std::string name = node_ptr->input(j);
if (context.find(name) != context.end()) {
// if the input is already in the context, skip
continue;
}
if (weights.find(name) != weights.end()) {
// if found in weights, extract its shape into the context
auto weight = weights.at(name);
std::vector<int> shape;
for (auto index = 0; index < weight.dims_size(); ++index) {
shape.emplace_back(weight.dims(index));
}
context.emplace(name, shape);
continue;
}
if (mark_as_appended.find(name) != mark_as_appended.end()) {
// if already marked as appended, skip
continue;
}
// else append it to serialization list
auto depend_ptr = find_node_by_output_name(mutable_graph, name);
if (depend_ptr == nullptr) {
fprintf(stderr, "cannot find %s from graph !\n", name.c_str());
return std::make_tuple(false, std::vector<int>{});
}
mark_as_appended.insert(name);
serial.emplace_back(depend_ptr);
}
}
if (serial.size() <= end) {
// if no new node was added, stop
break;
}
// update the start and end positions and continue the BFS
start = end;
end = serial.size();
}
}
// for each node in serialization list, calculate the output shape
{
std::reverse(serial.begin(), serial.end());
for (auto node : serial) {
if (node->op_type() == "Conv") {
auto inp = context[node->input(0)];
auto weight = context[node->input(1)];
assert(inp.size() == 4 and weight.size() == 4);
int group = get_node_attr_i(*node, "group", 1);
assert(group == 1);
// treat repeated spatial attributes as a single value
#define EXTRACT_REPEATED_PARAM(NAME, ATTR, DEFAULT) \
int ATTR = DEFAULT; \
{ \
std::vector<int> _vec = get_node_attr_ai(*node, NAME); \
if (not _vec.empty()) { \
ATTR = _vec[0]; \
} \
}
EXTRACT_REPEATED_PARAM("dilations", dilation, 1);
EXTRACT_REPEATED_PARAM("pads", pad, 0);
EXTRACT_REPEATED_PARAM("strides", stride, 1);
#undef EXTRACT_REPEATED_PARAM
int on = inp[0];
int oc = weight[0];
int oh = (inp[2] + 2 * pad - weight[2]) / stride + 1;
int ow = (inp[3] + 2 * pad - weight[3]) / stride + 1;
context.emplace(node->output(0), std::vector<int>{on, oc, oh, ow});
} else if (node->op_type() == "Shape") {
auto inp = context[node->input(0)];
context.emplace(node->output(0), std::vector<int>{1, inp[1], inp[2], inp[3]});
} else if (node->op_type() == "Slice") {
assert(node->input_size() >= 4);
auto inp = context[node->input(0)];
int start = get_node_attr_from_input<int>(weights.at(node->input(1)));
int end = get_node_attr_from_input<int>(weights.at(node->input(2)));
int axes = get_node_attr_from_input<int>(weights.at(node->input(3)));
if (axes != 0) {
fprintf(stderr, "Not support axes=%d !\n", axes);
return std::make_tuple(false, std::vector<int>{});
}
assert(inp.size() >= end - start);
context.emplace(node->output(0), std::vector<int>{inp.begin() + start, inp.begin() + end});
} else if (node->op_type() == "Concat") {
assert(node->input_size() >= 2);
auto axis = get_node_attr_i(*node, "axis", 0);
if (axis != 0) {
fprintf(stderr, "Not support axes=%d !\n", axis);
return std::make_tuple(false, std::vector<int>{});
}
std::vector<int> inp = context[node->input(0)];
std::vector<int> w_data = get_node_attr_from_input_ai(weights.at(node->input(1)));
// concat data on axis 0
inp.insert(inp.end(), w_data.begin(), w_data.end());
context.emplace(node->output(0), inp);
} else {
fprintf(stderr, "Unsupported type %s in query_shape !\n", node->op_type().c_str());
return std::make_tuple(false, std::vector<int>{});
}
}
}
assert(context.find(target->output(0)) != context.end());
auto target_shape = context[target->output(0)];
return std::make_tuple(true, target_shape);
}
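
The Conv case above applies the usual convolution output-size rule, with the first element of the repeated pads/strides attributes standing in for all spatial dimensions; a small illustration of the arithmetic (the concrete shapes are examples, not taken from this PR):

# Output-size rule used in the Conv branch of query_shape:
# out = (in + 2 * pad - kernel) // stride + 1  (integer division).
def conv_out_size(in_size, kernel, pad, stride):
    return (in_size + 2 * pad - kernel) // stride + 1

# e.g. an input of [1, 3, 224, 224] with a weight of [64, 3, 7, 7],
# pad=3 and stride=2 maps to [1, 64, 112, 112]:
assert conv_out_size(224, kernel=7, pad=3, stride=2) == 112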


@@ -0,0 +1,20 @@
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once
#include <algorithm>
#include "utils.h"
/**
* @brief query output shape of target node
*
* @param mutable_graph
* @param target
* @param weights
* @param context <tensor name, shape>
* @return std::tuple<bool, std::vector<int>>
*/
std::tuple<bool, std::vector<int>> query_shape(
onnx::GraphProto* mutable_graph, onnx::NodeProto* target,
const std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, std::vector<int>>& context);


@@ -0,0 +1,401 @@
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once
#include <float.h>
#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/message.h>
#include <google/protobuf/text_format.h>
#include <limits.h>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include "onnx.pb.h"
/**
* @brief find graph node by output name
*
* @param mutable_graph
* @param name
* @return onnx::NodeProto*
*/
static onnx::NodeProto* find_node_by_output_name(onnx::GraphProto* mutable_graph,
const std::string& name) {
const int node_count = mutable_graph->node_size();
for (int i = 0; i < node_count; ++i) {
onnx::NodeProto* node = mutable_graph->mutable_node(i);
for (int j = 0; j < node->output_size(); ++j) {
auto output = node->output(j);
if (output == name) {
return node;
}
}
}
return nullptr;
}
static bool read_proto_from_binary(const char* filepath, onnx::ModelProto* message) {
std::ifstream fs(filepath, std::ifstream::in | std::ifstream::binary);
if (!fs.is_open()) {
fprintf(stderr, "open failed %s\n", filepath);
return false;
}
google::protobuf::io::IstreamInputStream input(&fs);
google::protobuf::io::CodedInputStream codedstr(&input);
#if GOOGLE_PROTOBUF_VERSION >= 3011000
codedstr.SetTotalBytesLimit(INT_MAX);
#else
codedstr.SetTotalBytesLimit(INT_MAX, INT_MAX / 2);
#endif
bool success = message->ParseFromCodedStream(&codedstr);
fs.close();
return success;
}
static std::vector<int> get_node_attr_ai(const onnx::NodeProto& node, const char* key) {
std::vector<int> v;
for (int i = 0; i < node.attribute_size(); i++) {
const onnx::AttributeProto& attr = node.attribute(i);
if (attr.name() == key) {
v.resize(attr.ints_size());
for (int j = 0; j < attr.ints_size(); j++) {
v[j] = std::max(std::min(attr.ints(j), (::google::protobuf::int64)INT_MAX),
(::google::protobuf::int64)INT_MIN);
}
break;
}
}
return v;
}
static void set_node_attr_ai(onnx::NodeProto& node, const char* key,
const std::vector<int>& value) {
onnx::AttributeProto* attr_group = node.add_attribute();
attr_group->set_name(key);
for (auto v : value) {
attr_group->add_ints(v);
}
return;
}
static std::vector<float> get_node_attr_af(const onnx::NodeProto& node, const char* key) {
std::vector<float> v;
for (int i = 0; i < node.attribute_size(); i++) {
const onnx::AttributeProto& attr = node.attribute(i);
if (attr.name() == key) {
v.resize(attr.floats_size());
for (int j = 0; j < attr.floats_size(); j++) {
v[j] = attr.floats(j);
}
break;
}
}
return v;
}
static int get_node_attr_i(const onnx::NodeProto& node, const char* key, int def = 0) {
for (int i = 0; i < node.attribute_size(); i++) {
const onnx::AttributeProto& attr = node.attribute(i);
if (attr.name() == key) {
return std::max(std::min(attr.i(), (::google::protobuf::int64)INT_MAX),
(::google::protobuf::int64)INT_MIN);
}
}
return def;
}
static float get_node_attr_f(const onnx::NodeProto& node, const char* key, float def = 0.f) {
for (int i = 0; i < node.attribute_size(); i++) {
const onnx::AttributeProto& attr = node.attribute(i);
if (attr.name() == key) {
return attr.f();
}
}
return def;
}
static std::string get_node_attr_s(const onnx::NodeProto& node, const char* key,
const std::string& def = std::string()) {
for (int i = 0; i < node.attribute_size(); i++) {
const onnx::AttributeProto& attr = node.attribute(i);
if (attr.name() == key) {
return attr.s();
}
}
return def;
}
static onnx::TensorProto get_node_attr_tensor(const onnx::NodeProto& node, const char* key) {
for (int i = 0; i < node.attribute_size(); i++) {
const onnx::AttributeProto& attr = node.attribute(i);
if (attr.name() == key) {
return attr.t();
}
}
return onnx::TensorProto();
}
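// The numeric data_type codes checked below follow onnx::TensorProto::DataType:
// 1 = FLOAT, 6 = INT32, 7 = INT64, 9 = BOOL, 11 = DOUBLE.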
template <typename T>
static T get_node_attr_from_input(const onnx::TensorProto& tp) {
T v = 0.f;
// float
if (tp.data_type() == 1) {
const float* shape_data = 0;
if (tp.has_raw_data()) {
shape_data = (const float*)tp.raw_data().data();
} else {
shape_data = tp.float_data().data();
}
v = shape_data[0];
}
// double
else if (tp.data_type() == 11) {
const double* shape_data = 0;
if (tp.has_raw_data()) {
shape_data = (const double*)tp.raw_data().data();
} else {
shape_data = tp.double_data().data();
}
v = shape_data[0];
}
// int64
else if (tp.data_type() == 7) {
const int64_t* shape_data = 0;
if (tp.has_raw_data()) {
shape_data = (const int64_t*)tp.raw_data().data();
} else {
shape_data = tp.int64_data().data();
}
v = std::max(std::min(shape_data[0], (::google::protobuf::int64)INT_MAX),
(::google::protobuf::int64)INT_MIN);
}
// int32
else if (tp.data_type() == 6) {
const int32_t* shape_data = 0;
if (tp.has_raw_data()) {
shape_data = (const int32_t*)tp.raw_data().data();
} else {
shape_data = tp.int32_data().data();
}
v = shape_data[0];
} else {
// fprintf(stderr, "tp.name: %s\n", tp.name().c_str());
fprintf(stderr, "Unknown data type %d\n", tp.data_type());
fprintf(stderr, "get_node_attr_from_input\n");
abort();
}
return v;
}
static std::vector<int> get_node_attr_from_input_ai(const onnx::TensorProto& tp) {
int size = 0;
std::vector<int> v;
// int64
if (tp.data_type() == 7) {
const int64_t* shape_data = 0;
if (tp.has_raw_data()) {
shape_data = (const int64_t*)tp.raw_data().data();
size = (int)(tp.raw_data().size() / 8);
} else {
shape_data = tp.int64_data().data();
size = tp.int64_data_size();
}
for (int j = 0; j < size; j++) {
int vi = std::max(std::min(shape_data[j], (::google::protobuf::int64)INT_MAX),
(::google::protobuf::int64)INT_MIN);
v.push_back(vi);
}
}
// int32
else if (tp.data_type() == 6) {
const int32_t* shape_data = 0;
if (tp.has_raw_data()) {
shape_data = (const int32_t*)tp.raw_data().data();
size = (int)(tp.raw_data().size() / 4);
} else {
shape_data = tp.int32_data().data();
size = tp.int32_data_size();
}
for (int j = 0; j < size; j++) {
v.push_back(shape_data[j]);
}
} else {
fprintf(stderr, "Unknown data type %d\n", tp.data_type());
}
return v;
}
static std::vector<float> get_node_attr_from_input_af(const onnx::TensorProto& tp) {
int size = 0;
std::vector<float> v;
// float
if (tp.data_type() == 1) {
const float* shape_data = 0;
if (tp.has_raw_data()) {
shape_data = (const float*)tp.raw_data().data();
size = (int)(tp.raw_data().size() / 4);
} else {
shape_data = tp.float_data().data();
size = tp.float_data_size();
}
for (int j = 0; j < size; j++) {
v.push_back(shape_data[j]);
}
}
// double
else if (tp.data_type() == 11) {
const double* shape_data = 0;
if (tp.has_raw_data()) {
shape_data = (const double*)tp.raw_data().data();
size = (int)(tp.raw_data().size() / 8);
} else {
shape_data = tp.double_data().data();
size = tp.double_data_size();
}
for (int j = 0; j < size; j++) {
v.push_back((float)shape_data[j]);
}
} else {
fprintf(stderr, "Unknown data type %d\n", tp.data_type());
}
return v;
}
static int get_tensor_proto_data_size(const onnx::TensorProto& tp) {
if (tp.has_raw_data()) {
if (tp.data_type() == 1 || tp.data_type() == 6) {
const std::string& raw_data = tp.raw_data();
int size = (int)raw_data.size() / 4;
return size;
} else if (tp.data_type() == 7 || tp.data_type() == 11) {
const std::string& raw_data = tp.raw_data();
int size = (int)raw_data.size() / 8;
return size;
} else if (tp.data_type() == 9) {
const std::string& raw_data = tp.raw_data();
return 0;
}
} else if (tp.data_type() == 1) {
return tp.float_data_size();
} else if (tp.data_type() == 7) {
return tp.int64_data_size();
} else if (tp.data_type() == 6) {
return tp.int32_data_size();
} else if (tp.data_type() == 11) {
return tp.double_data_size();
}
return 0;
}
static void fwrite_tensor_proto_data(const onnx::TensorProto& tp, FILE* bp) {
int size = get_tensor_proto_data_size(tp);
if (tp.has_raw_data()) {
const std::string& raw_data = tp.raw_data();
fwrite(raw_data.data(), sizeof(float), size, bp);
} else if (tp.data_type() == 1) {
fwrite(tp.float_data().data(), sizeof(float), size, bp);
}
}
static void fwrite_tensor_proto_data_to_float(const onnx::TensorProto& tp, FILE* bp) {
int size = get_tensor_proto_data_size(tp);
size_t written_size;
if (tp.has_raw_data()) {
const std::string& raw_data = tp.raw_data();
if (tp.data_type() == 6) {
int* intdataptr = (int*)raw_data.data();
float* floatdataptr = (float*)std::malloc(sizeof(float) * size);
for (int i = 0; i < size; i++) {
floatdataptr[i] = (float)intdataptr[i];
}
written_size = fwrite(floatdataptr, sizeof(float), size, bp);
std::free(floatdataptr);
} else if (tp.data_type() == 7) {
int64_t* intdataptr = (int64_t*)raw_data.data();
float* floatdataptr = (float*)std::malloc(sizeof(float) * size);
for (int i = 0; i < size; i++) {
floatdataptr[i] = (float)intdataptr[i];
}
written_size = fwrite(floatdataptr, sizeof(float), size, bp);
std::free(floatdataptr);
} else if (tp.data_type() == 9) {
bool* intdataptr = (bool*)raw_data.data();
float* floatdataptr = (float*)std::malloc(sizeof(float) * size);
for (int i = 0; i < size; i++) {
floatdataptr[i] = (float)intdataptr[i];
}
written_size = fwrite(floatdataptr, sizeof(float), size, bp);
std::free(floatdataptr);
} else if (tp.data_type() == 11) {
double* doubledataptr = (double*)raw_data.data();
float* floatdataptr = (float*)std::malloc(sizeof(float) * size);
for (int i = 0; i < size; i++) {
floatdataptr[i] = (float)doubledataptr[i];
}
written_size = fwrite(floatdataptr, sizeof(float), size, bp);
std::free(floatdataptr);
}
} else if (tp.data_type() == 6) {
int* intdataptr = (int*)tp.int32_data().data();
float* floatdataptr = (float*)std::malloc(sizeof(float) * size);
for (int i = 0; i < size; i++) {
floatdataptr[i] = (float)intdataptr[i];
}
written_size = fwrite(floatdataptr, sizeof(float), size, bp);
std::free(floatdataptr);
} else if (tp.data_type() == 7) {
int64_t* intdataptr = (int64_t*)tp.int64_data().data();
float* floatdataptr = (float*)std::malloc(sizeof(float) * size);
for (int i = 0; i < size; i++) {
floatdataptr[i] = (float)intdataptr[i];
}
written_size = fwrite(floatdataptr, sizeof(float), size, bp);
std::free(floatdataptr);
} else if (tp.data_type() == 9) {
int* intdataptr = (int*)tp.int64_data().data();
float* floatdataptr = (float*)std::malloc(sizeof(float) * size);
for (int i = 0; i < size; i++) {
floatdataptr[i] = (float)intdataptr[i];
}
written_size = fwrite(floatdataptr, sizeof(float), size, bp);
std::free(floatdataptr);
} else if (tp.data_type() == 11) {
double* doubledataptr = (double*)tp.double_data().data();
float* floatdataptr = (float*)std::malloc(sizeof(float) * size);
for (int i = 0; i < size; i++) {
floatdataptr[i] = (float)doubledataptr[i];
}
written_size = fwrite(floatdataptr, sizeof(float), size, bp);
std::free(floatdataptr);
}
}