refactor(onnx2ncnn): add test case and simplify code (#436)
* refactor(onnx2ncnn.cpp): split it into shape_inference, fuse pass and utils
* refactor(onnx2ncnn.cpp): split code
* refactor(net_module.cpp): fix build error
* ci(test_onnx2ncnn.py): add model generation and run steps
* ci(onnx2ncnn): add ncnn backend
* ci(test_onnx2ncnn): add converted onnx model
* ci(onnx2ncnn): fix ncnn tar
* ci(backend-ncnn): simplify dependency install
* ci(onnx2ncnn): fix apt install
* Update backend-ncnn.yml
* fix(ci): add include <algorithm>
* Update build.yml
* [Fix] Fix ci (#426): fix ci, add nvidia key, remove torch, recover pytorch
* fix(onnx2ncnn): address review comments
* fix(onnx2ncnn): fix build error

Co-authored-by: q.yao <streetyao@live.com>
parent 6eb83a9daa
commit d04c8dc9c0
.github/scripts/test_onnx2ncnn.py (new file, 97 lines)
@@ -0,0 +1,97 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import subprocess

# list of tuples: (config, pretrained model, onnx filename[, onnx download url])
CONFIGS = [
    (
        'mmclassification/configs/vision_transformer/vit-base-p32_ft-64xb64_in1k-384.py',  # noqa: E501
        'https://download.openmmlab.com/mmclassification/v0/vit/finetune/vit-base-p32_in21k-pre-3rdparty_ft-64xb64_in1k-384_20210928-9cea8599.pth',  # noqa: E501
        'vit.onnx'),
    (
        'mmclassification/configs/resnet/resnet50_8xb32_in1k.py',
        'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth',  # noqa: E501
        'resnet50.onnx',
    ),
    (
        'mmclassification/configs/resnet/resnet18_8xb32_in1k.py',
        'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth',  # noqa: E501
        'resnet18.onnx',
        'https://media.githubusercontent.com/media/tpoisonooo/mmdeploy-onnx2ncnn-testdata/main/resnet18.onnx',  # noqa: E501
    ),
    (
        'mmclassification/configs/mobilenet_v2/mobilenet-v2_8xb32_in1k.py',
        'https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth',  # noqa: E501
        'mobilenet-v2.onnx',
        'https://media.githubusercontent.com/media/tpoisonooo/mmdeploy-onnx2ncnn-testdata/main/mobilenet-v2.onnx',  # noqa: E501
    )
]


def parse_args():
    parser = argparse.ArgumentParser(
        description='MMDeploy onnx2ncnn test tool.')
    # note: argparse's type=bool treats any non-empty string as True
    parser.add_argument('--run', type=bool, help='Execute onnx2ncnn bin.')
    parser.add_argument(
        '--repo-dir', type=str, default='~/', help='mmcls directory.')
    parser.add_argument(
        '--out',
        type=str,
        default='onnx_output',
        help='onnx model output directory.')
    parser.add_argument(
        '--generate-onnx', type=bool, help='Generate onnx model.')
    args = parser.parse_args()
    return args


def generate_onnx(args):
    import mmcv
    mmcv.mkdir_or_exist(args.out)
    for conf in CONFIGS:
        config = os.path.join(args.repo_dir, conf[0])
        model = conf[1]
        convert_cmd = [
            'python3', 'tools/deploy.py',
            'configs/mmcls/classification_ncnn_static.py', config, model,
            'cat-dog.png', '--work-dir', 'work_dir', '--device', 'cpu'
        ]
        print(subprocess.call(convert_cmd))

        move_cmd = [
            'mv', 'work_dir/end2end.onnx',
            os.path.join(args.out, conf[2])
        ]
        print(subprocess.call(move_cmd))


def run(args):
    for conf in CONFIGS:
        # only entries with a download url can be converted here
        if len(conf) < 4:
            continue
        download_url = conf[3]
        filename = conf[2]
        download_cmd = ['wget', download_url]
        # use os.system so wget can show its progress bar
        os.system(' '.join(download_cmd))

        convert_cmd = ['./onnx2ncnn', filename, 'onnx.param', 'onnx.bin']
        subprocess.run(convert_cmd, capture_output=True, check=True)


def main():
    """Test `onnx2ncnn.cpp`.

    First generate an onnx model, then convert it with `onnx2ncnn`.
    """
    args = parse_args()
    if args.generate_onnx:
        generate_onnx(args)
    if args.run:
        run(args)


if __name__ == '__main__':
    main()
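The script exposes two entry points: `--generate-onnx 1` regenerates the ONNX files through mmdeploy's `tools/deploy.py` (and therefore needs an mmclassification checkout under `--repo-dir`), while `--run 1`, the mode exercised by the workflow below, downloads the pre-converted models from the testdata repository and feeds each one to a local `onnx2ncnn` binary, failing the job if the converter exits non-zero (`check=True`).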
.github/workflows/backend-ncnn.yml (new file, 68 lines)
@@ -0,0 +1,68 @@
name: backend

on:
  push:
    paths-ignore:
      - "demo/**"
      - "tools/**"

  pull_request:
    paths-ignore:
      - "demo/**"
      - "tools/**"
      - "docs/**"

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  test_onnx2ncnn:
    runs-on: ubuntu-18.04
    strategy:
      matrix:
        python-version: [3.7]
        torch: [1.9.0]
        mmcv: [1.4.2]
        include:
          - torch: 1.9.0
            torch_version: torch1.9
            torchvision: 0.10.0
    steps:
      - name: Checkout repository
        uses: actions/checkout@v3
        with:
          submodules: 'recursive'
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install unittest dependencies
        run: |
          pip install cmake onnx
      - name: update
        run: sudo apt update
      - name: gcc-multilib
        run: sudo apt install gcc-multilib g++-multilib wget libprotobuf-dev protobuf-compiler
      - name: Install ncnn
        run: |
          wget https://github.com/Tencent/ncnn/archive/refs/tags/20220420.tar.gz
          tar xf 20220420.tar.gz
          pushd ncnn-20220420
          mkdir build && pushd build
          cmake -DCMAKE_INSTALL_PREFIX=$(pwd)/../install -DNCNN_BUILD_TESTS=OFF -DNCNN_BUILD_TOOLS=OFF -DNCNN_BUILD_EXAMPLES=OFF ..
          cmake --build . -j2
          make install
          popd && popd
      - name: Install mmdeploy with ncnn backend
        run: |
          mkdir -p build && pushd build
          export LD_LIBRARY_PATH=/home/runner/work/mmdeploy/mmdeploy/ncnn-20220420/install/lib/:$LD_LIBRARY_PATH
          cmake -DMMDEPLOY_TARGET_BACKENDS=ncnn -Dncnn_DIR=/home/runner/work/mmdeploy/mmdeploy/ncnn-20220420/install/lib/cmake/ncnn/ ..
          make onnx2ncnn -j2
          popd
      - name: Test onnx2ncnn
        run: |
          echo $(pwd)
          ln -s build/bin/onnx2ncnn ./
          python3 .github/scripts/test_onnx2ncnn.py --run 1
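Read top to bottom, the job builds ncnn 20220420 from source with tests, tools, and examples disabled, builds only mmdeploy's `onnx2ncnn` target against that install, symlinks the resulting binary into the repository root, and then runs `test_onnx2ncnn.py --run 1` against it.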
.github/workflows/build.yml (2 changed lines)
@@ -152,4 +152,4 @@ jobs:
         env_vars: OS,PYTHON
         name: codecov-umbrella
         fail_ci_if_error: false
-        gcov_ignore : [".github/scripts/doc_link_checker.py"]
+        gcov_ignore : [".github/scripts/*"]
CMakeLists.txt
@@ -5,7 +5,7 @@ endif ()
 message(STATUS "CMAKE_INSTALL_PREFIX: ${CMAKE_INSTALL_PREFIX}")

 cmake_minimum_required(VERSION 3.14)
-project(MMDeploy VERSION 0.1.0)
+project(MMDeploy VERSION 0.5.0)

 set(CMAKE_CXX_STANDARD 17)
csrc/backend_ops/ncnn/onnx2ncnn/CMakeLists.txt
@@ -7,7 +7,7 @@ find_package(Protobuf)
 if (PROTOBUF_FOUND)
     protobuf_generate_cpp(ONNX_PROTO_SRCS ONNX_PROTO_HDRS
             ${CMAKE_CURRENT_SOURCE_DIR}/onnx.proto)
-    add_executable(onnx2ncnn onnx2ncnn.cpp ${ONNX_PROTO_SRCS} ${ONNX_PROTO_HDRS})
+    add_executable(onnx2ncnn onnx2ncnn.cpp fuse_pass.cpp shape_inference.cpp ${ONNX_PROTO_SRCS} ${ONNX_PROTO_HDRS})
     target_include_directories(onnx2ncnn PRIVATE ${PROTOBUF_INCLUDE_DIR}
             ${CMAKE_CURRENT_BINARY_DIR})
     target_link_libraries(onnx2ncnn PRIVATE ${PROTOBUF_LIBRARIES})
csrc/backend_ops/ncnn/onnx2ncnn/fuse_pass.cpp (new file, 2333 lines)
File diff suppressed because it is too large.
csrc/backend_ops/ncnn/onnx2ncnn/fuse_pass.h (new file, 120 lines)
@@ -0,0 +1,120 @@
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once

#include "shape_inference.h"
#include "utils.h"

void fuse_weight_reshape(onnx::GraphProto* mutable_graph,
                         std::map<std::string, onnx::TensorProto>& weights,
                         std::map<std::string, int>& node_reference,
                         std::set<std::string>& blob_names, int& reduced_node_count);

void fuse_shufflechannel(onnx::GraphProto* mutable_graph,
                         std::map<std::string, onnx::TensorProto>& weights,
                         std::map<std::string, int>& node_reference,
                         std::set<std::string>& blob_names, int& reduced_node_count);

void fuse_shufflechannel_split(onnx::GraphProto* mutable_graph,
                               std::map<std::string, onnx::TensorProto>& weights,
                               std::map<std::string, int>& node_reference,
                               std::set<std::string>& blob_names, int& reduced_node_count);

/**
 * @brief fuse subgraph
 *
 * conv - - - - - - - - - - - -> reshape
 *     \                       /
 *       shape - slice - concat
 *
 * to
 *
 * conv --> reshape
 *
 * @param mutable_graph
 * @param weights
 * @param node_reference
 * @param blob_names
 * @param reduced_node_count
 */
void fuse_conv_reshape(onnx::GraphProto* mutable_graph,
                       std::map<std::string, onnx::TensorProto>& weights,
                       std::map<std::string, int>& node_reference,
                       std::set<std::string>& blob_names, int& reduced_node_count);

void fuse_binaryop_with_scalar(onnx::GraphProto* mutable_graph,
                               std::map<std::string, onnx::TensorProto>& weights,
                               std::map<std::string, int>& node_reference,
                               std::set<std::string>& blob_names, int& reduced_node_count);

void fuse_hardswish(onnx::GraphProto* mutable_graph,
                    std::map<std::string, onnx::TensorProto>& weights,
                    std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
                    int& reduced_node_count);

void fuse_hardsigmoid(onnx::GraphProto* mutable_graph,
                      std::map<std::string, onnx::TensorProto>& weights,
                      std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
                      int& reduced_node_count);

void fuse_batchnorm1d_squeeze_unsqueeze(onnx::GraphProto* mutable_graph,
                                        std::map<std::string, onnx::TensorProto>& weights,
                                        std::map<std::string, int>& node_reference,
                                        std::set<std::string>& blob_names, int& reduced_node_count);

void fuse_unsqueeze_prelu(onnx::GraphProto* mutable_graph,
                          std::map<std::string, onnx::TensorProto>& weights,
                          std::map<std::string, int>& node_reference,
                          std::set<std::string>& blob_names, int& reduced_node_count);

void fuse_normalize(onnx::GraphProto* mutable_graph,
                    std::map<std::string, onnx::TensorProto>& weights,
                    std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
                    int& reduced_node_count);

void fuse_groupnorm(onnx::GraphProto* mutable_graph,
                    std::map<std::string, onnx::TensorProto>& weights,
                    std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
                    int& reduced_node_count);

void fuse_layernorm(onnx::GraphProto* mutable_graph,
                    std::map<std::string, onnx::TensorProto>& weights,
                    std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
                    int& reduced_node_count);

void fuse_flatten(onnx::GraphProto* mutable_graph,
                  std::map<std::string, onnx::TensorProto>& weights,
                  std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
                  int& reduced_node_count);

void fuse_pixelshuffle(onnx::GraphProto* mutable_graph,
                       std::map<std::string, onnx::TensorProto>& weights,
                       std::map<std::string, int>& node_reference,
                       std::set<std::string>& blob_names, int& reduced_node_count);

void fuse_reorg(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights,
                std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
                int& reduced_node_count);

void fuse_expand_broadcast(onnx::GraphProto* mutable_graph,
                           std::map<std::string, onnx::TensorProto>& weights,
                           std::map<std::string, int>& node_reference,
                           std::set<std::string>& blob_names, int& reduced_node_count);

void fuse_lstm_gru_rnn(onnx::GraphProto* mutable_graph,
                       std::map<std::string, onnx::TensorProto>& weights,
                       std::map<std::string, int>& node_reference,
                       std::set<std::string>& blob_names, int& reduced_node_count);

void fuse_multiheadattention(onnx::GraphProto* mutable_graph,
                             std::map<std::string, onnx::TensorProto>& weights,
                             std::map<std::string, int>& node_reference,
                             std::set<std::string>& blob_names, int& reduced_node_count);

void fuse_weight_transpose(onnx::GraphProto* mutable_graph,
                           std::map<std::string, onnx::TensorProto>& weights,
                           std::map<std::string, int>& node_reference,
                           std::set<std::string>& blob_names, int& reduced_node_count);

void fuse_swish(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights,
                std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
                int& reduced_node_count);
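All of these passes share one signature, so they can be driven uniformly from the converter's main routine in onnx2ncnn.cpp (whose diff is suppressed in this view). As a hedged sketch only, not code from this commit: the driver below, including the name `run_fuse_passes` and the fixed-point loop, is an assumption inferred from the shared signature and the `reduced_node_count` out-parameter.

#include "fuse_pass.h"

// Apply fusion passes repeatedly until a full sweep removes no node.
static void run_fuse_passes(onnx::GraphProto* graph,
                            std::map<std::string, onnx::TensorProto>& weights,
                            std::map<std::string, int>& node_reference,
                            std::set<std::string>& blob_names) {
  int reduced;
  do {
    reduced = 0;  // each pass adds the number of nodes it removed
    fuse_weight_reshape(graph, weights, node_reference, blob_names, reduced);
    fuse_shufflechannel(graph, weights, node_reference, blob_names, reduced);
    fuse_conv_reshape(graph, weights, node_reference, blob_names, reduced);
    fuse_hardswish(graph, weights, node_reference, blob_names, reduced);
    // ... remaining passes elided ...
  } while (reduced > 0);
}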
csrc/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp
File diff suppressed because it is too large.
csrc/backend_ops/ncnn/onnx2ncnn/shape_inference.cpp (new file, 168 lines)
@@ -0,0 +1,168 @@
// Copyright (c) OpenMMLab. All rights reserved.

#include "shape_inference.h"

/**
 * @brief query output shape of target node
 *
 * @param mutable_graph
 * @param target
 * @param weights
 * @param context <tensor name, shape>
 * @return std::tuple<bool, std::vector<int>>
 */
std::tuple<bool, std::vector<int>> query_shape(
    onnx::GraphProto* mutable_graph, onnx::NodeProto* target,
    const std::map<std::string, onnx::TensorProto>& weights,
    std::map<std::string, std::vector<int>>& context) {
  // emplace all input nodes
  const int input_count = mutable_graph->input_size();
  for (int i = 0; i < input_count; i++) {
    auto inp = mutable_graph->input(i);
    onnx::TypeProto inp_type = inp.type();
    onnx::TensorShapeProto shape_proto = inp_type.tensor_type().shape();

    auto dim_size = shape_proto.dim_size();
    std::vector<int> shape(dim_size);
    for (int index = 0; index < dim_size; ++index) {
      shape[index] = shape_proto.dim(index).dim_value();
    }

    context.emplace(inp.name(), shape);
  }

  // BFS the tree, `target` as root, onnx::graph inputs and weights as leaf nodes
  std::vector<onnx::NodeProto*> serial = {target};
  {
    std::set<std::string> mark_as_appended = {};
    while (true) {
      int start = 0, end = serial.size();
      for (int i = start; i < end; ++i) {
        auto node_ptr = serial[i];
        auto len = node_ptr->input_size();

        for (int j = 0; j < len; ++j) {
          std::string name = node_ptr->input(j);
          if (context.find(name) != context.end()) {
            // if the input shape is already known, skip
            continue;
          }

          if (weights.find(name) != weights.end()) {
            // if found in weights, extract its shape into the context
            auto weight = weights.at(name);
            std::vector<int> shape;
            for (auto index = 0; index < weight.dims_size(); ++index) {
              shape.emplace_back(weight.dims(index));
            }
            context.emplace(name, shape);
            continue;
          }

          if (mark_as_appended.find(name) != mark_as_appended.end()) {
            // if already marked as appended, skip
            continue;
          }
          // else append it to the serialization list
          auto depend_ptr = find_node_by_output_name(mutable_graph, name);
          if (depend_ptr == nullptr) {
            fprintf(stderr, "cannot find %s from graph !\n", name.c_str());
            return std::make_tuple(false, std::vector<int>{});
          }
          mark_as_appended.insert(name);
          serial.emplace_back(depend_ptr);
        }
      }

      if (serial.size() <= end) {
        // if no new node was added, quit
        break;
      }

      // update start and end position, continue the BFS
      start = end;
      end = serial.size();
    }
  }

  // for each node in the serialization list, calculate the output shape
  {
    std::reverse(serial.begin(), serial.end());
    for (auto node : serial) {
      if (node->op_type() == "Conv") {
        auto inp = context[node->input(0)];
        auto weight = context[node->input(1)];
        assert(inp.size() == 4 and weight.size() == 4);

        int group = get_node_attr_i(*node, "group", 1);
        assert(group == 1);

        // treat multiple spatial attrs as a single one
#define EXTRACT_REPEATED_PARAM(NAME, ATTR, DEFAULT)        \
  int ATTR = DEFAULT;                                      \
  {                                                        \
    std::vector<int> _vec = get_node_attr_ai(*node, NAME); \
    if (not _vec.empty()) {                                \
      ATTR = _vec[0];                                      \
    }                                                      \
  }

        EXTRACT_REPEATED_PARAM("dilations", dilation, 1);
        EXTRACT_REPEATED_PARAM("pads", pad, 0);
        EXTRACT_REPEATED_PARAM("strides", stride, 1);

#undef EXTRACT_REPEATED_PARAM

        int on = inp[0];
        int oc = weight[0];
        int oh = (inp[2] + 2 * pad - weight[2]) / stride + 1;
        int ow = (inp[3] + 2 * pad - weight[3]) / stride + 1;
        context.emplace(node->output(0), std::vector<int>{on, oc, oh, ow});

      } else if (node->op_type() == "Shape") {
        auto inp = context[node->input(0)];
        context.emplace(node->output(0), std::vector<int>{1, inp[1], inp[2], inp[3]});

      } else if (node->op_type() == "Slice") {
        assert(node->input_size() >= 4);

        auto inp = context[node->input(0)];
        int start = get_node_attr_from_input<int>(weights.at(node->input(1)));
        int end = get_node_attr_from_input<int>(weights.at(node->input(2)));
        int axes = get_node_attr_from_input<int>(weights.at(node->input(3)));

        if (axes != 0) {
          fprintf(stderr, "Not support axes=%d !\n", axes);
          return std::make_tuple(false, std::vector<int>{});
        }

        assert(inp.size() >= end - start);
        context.emplace(node->output(0), std::vector<int>{inp.begin() + start, inp.begin() + end});

      } else if (node->op_type() == "Concat") {
        assert(node->input_size() >= 2);

        auto axis = get_node_attr_i(*node, "axis", 0);
        if (axis != 0) {
          fprintf(stderr, "Not support axes=%d !\n", axis);
          return std::make_tuple(false, std::vector<int>{});
        }

        std::vector<int> inp = context[node->input(0)];
        std::vector<int> w_data = get_node_attr_from_input_ai(weights.at(node->input(1)));

        // concat data on axis 0
        inp.insert(inp.end(), w_data.begin(), w_data.end());
        context.emplace(node->output(0), inp);

      } else {
        fprintf(stderr, "Unsupported type %s in query_shape !\n", node->op_type().c_str());
        return std::make_tuple(false, std::vector<int>{});
      }
    }
  }

  assert(context.find(target->output(0)) != context.end());
  auto target_shape = context[target->output(0)];
  return std::make_tuple(true, target_shape);
}
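A quick worked check of the Conv branch above, using ResNet's stem as the assumed input (data {1, 3, 224, 224}, weight {64, 3, 7, 7}, pads 3, strides 2): oh = ow = (224 + 2*3 - 7) / 2 + 1 = 223 / 2 + 1 = 112 under integer division, so the context records {1, 64, 112, 112} for the Conv output. Note that `dilation` is extracted but never enters the formula, so the result is exact only for dilation-1 convolutions.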
csrc/backend_ops/ncnn/onnx2ncnn/shape_inference.h (new file, 20 lines)
@@ -0,0 +1,20 @@
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once
#include <algorithm>

#include "utils.h"

/**
 * @brief query output shape of target node
 *
 * @param mutable_graph
 * @param target
 * @param weights
 * @param context <tensor name, shape>
 * @return std::tuple<bool, std::vector<int>>
 */
std::tuple<bool, std::vector<int>> query_shape(
    onnx::GraphProto* mutable_graph, onnx::NodeProto* target,
    const std::map<std::string, onnx::TensorProto>& weights,
    std::map<std::string, std::vector<int>>& context);
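For illustration, a hedged sketch of a call site; `graph`, `conv_node`, and `weights` are placeholder names assumed to be in scope, not identifiers from this commit.

// Resolve the static output shape of one node, e.g. so that a fusion
// pass can replace a dynamic shape/slice/concat chain with a constant.
std::map<std::string, std::vector<int>> context;  // <tensor name, shape>
bool ok;
std::vector<int> shape;
std::tie(ok, shape) = query_shape(graph, conv_node, weights, context);
if (ok) {
  // for a 4-D Conv output, `shape` now holds {N, C, H, W}
}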
csrc/backend_ops/ncnn/onnx2ncnn/utils.h (new file, 401 lines)
@@ -0,0 +1,401 @@
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once

#include <float.h>
#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/message.h>
#include <google/protobuf/text_format.h>
#include <limits.h>

#include <cstdlib>
#include <fstream>
#include <iostream>

#include "onnx.pb.h"

/**
 * @brief find graph node by output name
 *
 * @param graph
 * @param name
 * @return onnx::NodeProto*
 */
static onnx::NodeProto* find_node_by_output_name(onnx::GraphProto* mutable_graph,
                                                 const std::string& name) {
  const int input_count = mutable_graph->node_size();
  for (int i = 0; i < input_count; ++i) {
    onnx::NodeProto* node = mutable_graph->mutable_node(i);

    for (int j = 0; j < node->output_size(); ++j) {
      auto output = node->output(j);
      if (output == name) {
        return node;
      }
    }
  }

  return nullptr;
}

static bool read_proto_from_binary(const char* filepath, onnx::ModelProto* message) {
  std::ifstream fs(filepath, std::ifstream::in | std::ifstream::binary);
  if (!fs.is_open()) {
    fprintf(stderr, "open failed %s\n", filepath);
    return false;
  }

  google::protobuf::io::IstreamInputStream input(&fs);
  google::protobuf::io::CodedInputStream codedstr(&input);

#if GOOGLE_PROTOBUF_VERSION >= 3011000
  codedstr.SetTotalBytesLimit(INT_MAX);
#else
  codedstr.SetTotalBytesLimit(INT_MAX, INT_MAX / 2);
#endif

  bool success = message->ParseFromCodedStream(&codedstr);

  fs.close();

  return success;
}

static std::vector<int> get_node_attr_ai(const onnx::NodeProto& node, const char* key) {
  std::vector<int> v;

  for (int i = 0; i < node.attribute_size(); i++) {
    const onnx::AttributeProto& attr = node.attribute(i);
    if (attr.name() == key) {
      v.resize(attr.ints_size());
      for (int j = 0; j < attr.ints_size(); j++) {
        v[j] = std::max(std::min(attr.ints(j), (::google::protobuf::int64)INT_MAX),
                        (::google::protobuf::int64)INT_MIN);
      }

      break;
    }
  }

  return v;
}

static void set_node_attr_ai(onnx::NodeProto& node, const char* key,
                             const std::vector<int>& value) {
  onnx::AttributeProto* attr_group = node.add_attribute();
  attr_group->set_name(key);
  for (auto v : value) {
    attr_group->add_ints(v);
  }

  return;
}

static std::vector<float> get_node_attr_af(const onnx::NodeProto& node, const char* key) {
  std::vector<float> v;

  for (int i = 0; i < node.attribute_size(); i++) {
    const onnx::AttributeProto& attr = node.attribute(i);
    if (attr.name() == key) {
      v.resize(attr.floats_size());
      for (int j = 0; j < attr.floats_size(); j++) {
        v[j] = attr.floats(j);
      }

      break;
    }
  }

  return v;
}

static int get_node_attr_i(const onnx::NodeProto& node, const char* key, int def = 0) {
  for (int i = 0; i < node.attribute_size(); i++) {
    const onnx::AttributeProto& attr = node.attribute(i);
    if (attr.name() == key) {
      return std::max(std::min(attr.i(), (::google::protobuf::int64)INT_MAX),
                      (::google::protobuf::int64)INT_MIN);
    }
  }

  return def;
}

static float get_node_attr_f(const onnx::NodeProto& node, const char* key, float def = 0.f) {
  for (int i = 0; i < node.attribute_size(); i++) {
    const onnx::AttributeProto& attr = node.attribute(i);
    if (attr.name() == key) {
      return attr.f();
    }
  }

  return def;
}

static std::string get_node_attr_s(const onnx::NodeProto& node, const char* key,
                                   const std::string& def = std::string()) {
  for (int i = 0; i < node.attribute_size(); i++) {
    const onnx::AttributeProto& attr = node.attribute(i);
    if (attr.name() == key) {
      return attr.s();
    }
  }

  return def;
}

static onnx::TensorProto get_node_attr_tensor(const onnx::NodeProto& node, const char* key) {
  for (int i = 0; i < node.attribute_size(); i++) {
    const onnx::AttributeProto& attr = node.attribute(i);
    if (attr.name() == key) {
      return attr.t();
    }
  }

  return onnx::TensorProto();
}

template <typename T>
static T get_node_attr_from_input(const onnx::TensorProto& tp) {
  T v = 0.f;

  // float
  if (tp.data_type() == 1) {
    const float* shape_data = 0;
    if (tp.has_raw_data()) {
      shape_data = (const float*)tp.raw_data().data();
    } else {
      shape_data = tp.float_data().data();
    }
    v = shape_data[0];
  }
  // double
  else if (tp.data_type() == 11) {
    const double* shape_data = 0;
    if (tp.has_raw_data()) {
      shape_data = (const double*)tp.raw_data().data();
    } else {
      shape_data = tp.double_data().data();
    }
    v = shape_data[0];
  }
  // int64
  else if (tp.data_type() == 7) {
    const int64_t* shape_data = 0;
    if (tp.has_raw_data()) {
      shape_data = (const int64_t*)tp.raw_data().data();
    } else {
      shape_data = tp.int64_data().data();
    }
    v = std::max(std::min(shape_data[0], (::google::protobuf::int64)INT_MAX),
                 (::google::protobuf::int64)INT_MIN);
  }
  // int32
  else if (tp.data_type() == 6) {
    const int32_t* shape_data = 0;
    if (tp.has_raw_data()) {
      shape_data = (const int32_t*)tp.raw_data().data();
    } else {
      shape_data = tp.int32_data().data();
    }
    v = shape_data[0];
  } else {
    // fprintf(stderr, "tp.name: %s\n", tp.name().c_str());
    fprintf(stderr, "Unknown data type %d\n", tp.data_type());
    fprintf(stderr, "get_node_attr_from_input\n");
    abort();
  }

  return v;
}

static std::vector<int> get_node_attr_from_input_ai(const onnx::TensorProto& tp) {
  int size = 0;

  std::vector<int> v;

  // int64
  if (tp.data_type() == 7) {
    const int64_t* shape_data = 0;
    if (tp.has_raw_data()) {
      shape_data = (const int64_t*)tp.raw_data().data();
      size = (int)(tp.raw_data().size() / 8);
    } else {
      shape_data = tp.int64_data().data();
      size = tp.int64_data_size();
    }
    for (int j = 0; j < size; j++) {
      int vi = std::max(std::min(shape_data[j], (::google::protobuf::int64)INT_MAX),
                        (::google::protobuf::int64)INT_MIN);
      v.push_back(vi);
    }
  }
  // int32
  else if (tp.data_type() == 6) {
    const int32_t* shape_data = 0;
    if (tp.has_raw_data()) {
      shape_data = (const int32_t*)tp.raw_data().data();
      size = (int)(tp.raw_data().size() / 4);
    } else {
      shape_data = tp.int32_data().data();
      size = tp.int32_data_size();
    }
    for (int j = 0; j < size; j++) {
      v.push_back(shape_data[j]);
    }
  } else {
    fprintf(stderr, "Unknown data type %d\n", tp.data_type());
  }

  return v;
}

static std::vector<float> get_node_attr_from_input_af(const onnx::TensorProto& tp) {
  int size = 0;

  std::vector<float> v;

  // float
  if (tp.data_type() == 1) {
    const float* shape_data = 0;
    if (tp.has_raw_data()) {
      shape_data = (const float*)tp.raw_data().data();
      size = (int)(tp.raw_data().size() / 4);
    } else {
      shape_data = tp.float_data().data();
      size = tp.float_data_size();
    }
    for (int j = 0; j < size; j++) {
      v.push_back(shape_data[j]);
    }
  }
  // double
  else if (tp.data_type() == 11) {
    const double* shape_data = 0;
    if (tp.has_raw_data()) {
      shape_data = (const double*)tp.raw_data().data();
      size = (int)(tp.raw_data().size() / 8);
    } else {
      shape_data = tp.double_data().data();
      size = tp.double_data_size();
    }
    for (int j = 0; j < size; j++) {
      v.push_back((float)shape_data[j]);
    }
  } else {
    fprintf(stderr, "Unknown data type %d\n", tp.data_type());
  }

  return v;
}

static int get_tensor_proto_data_size(const onnx::TensorProto& tp) {
  if (tp.has_raw_data()) {
    if (tp.data_type() == 1 || tp.data_type() == 6) {
      const std::string& raw_data = tp.raw_data();
      int size = (int)raw_data.size() / 4;
      return size;
    } else if (tp.data_type() == 7 || tp.data_type() == 11) {
      const std::string& raw_data = tp.raw_data();
      int size = (int)raw_data.size() / 8;
      return size;
    } else if (tp.data_type() == 9) {
      const std::string& raw_data = tp.raw_data();
      return 0;
    }
  } else if (tp.data_type() == 1) {
    return tp.float_data_size();
  } else if (tp.data_type() == 7) {
    return tp.int64_data_size();
  } else if (tp.data_type() == 6) {
    return tp.int32_data_size();
  } else if (tp.data_type() == 11) {
    return tp.double_data_size();
  }

  return 0;
}

static void fwrite_tensor_proto_data(const onnx::TensorProto& tp, FILE* bp) {
  int size = get_tensor_proto_data_size(tp);

  if (tp.has_raw_data()) {
    const std::string& raw_data = tp.raw_data();
    fwrite(raw_data.data(), sizeof(float), size, bp);
  } else if (tp.data_type() == 1) {
    fwrite(tp.float_data().data(), sizeof(float), size, bp);
  }
}

static void fwrite_tensor_proto_data_to_float(const onnx::TensorProto& tp, FILE* bp) {
  int size = get_tensor_proto_data_size(tp);
  size_t written_size;
  if (tp.has_raw_data()) {
    const std::string& raw_data = tp.raw_data();
    if (tp.data_type() == 6) {
      int* intdataptr = (int*)raw_data.data();
      float* floatdataptr = (float*)std::malloc(sizeof(float) * size);
      for (int i = 0; i < size; i++) {
        floatdataptr[i] = (float)intdataptr[i];
      }
      written_size = fwrite(floatdataptr, sizeof(float), size, bp);
      std::free(floatdataptr);
    } else if (tp.data_type() == 7) {
      int64_t* intdataptr = (int64_t*)raw_data.data();
      float* floatdataptr = (float*)std::malloc(sizeof(float) * size);
      for (int i = 0; i < size; i++) {
        floatdataptr[i] = (float)intdataptr[i];
      }
      written_size = fwrite(floatdataptr, sizeof(float), size, bp);
      std::free(floatdataptr);
    } else if (tp.data_type() == 9) {
      bool* intdataptr = (bool*)raw_data.data();
      float* floatdataptr = (float*)std::malloc(sizeof(float) * size);
      for (int i = 0; i < size; i++) {
        floatdataptr[i] = (float)intdataptr[i];
      }
      written_size = fwrite(floatdataptr, sizeof(float), size, bp);
      std::free(floatdataptr);
    } else if (tp.data_type() == 11) {
      double* doubledataptr = (double*)raw_data.data();
      float* floatdataptr = (float*)std::malloc(sizeof(float) * size);
      for (int i = 0; i < size; i++) {
        floatdataptr[i] = (float)doubledataptr[i];
      }
      written_size = fwrite(floatdataptr, sizeof(float), size, bp);
      std::free(floatdataptr);
    }
  } else if (tp.data_type() == 6) {
    int* intdataptr = (int*)tp.int32_data().data();
    float* floatdataptr = (float*)std::malloc(sizeof(float) * size);
    for (int i = 0; i < size; i++) {
      floatdataptr[i] = (float)intdataptr[i];
    }
    written_size = fwrite(floatdataptr, sizeof(float), size, bp);
    std::free(floatdataptr);
  } else if (tp.data_type() == 7) {
    int64_t* intdataptr = (int64_t*)tp.int64_data().data();
    float* floatdataptr = (float*)std::malloc(sizeof(float) * size);
    for (int i = 0; i < size; i++) {
      floatdataptr[i] = (float)intdataptr[i];
    }
    written_size = fwrite(floatdataptr, sizeof(float), size, bp);
    std::free(floatdataptr);
  } else if (tp.data_type() == 9) {
    int* intdataptr = (int*)tp.int64_data().data();
    float* floatdataptr = (float*)std::malloc(sizeof(float) * size);
    for (int i = 0; i < size; i++) {
      floatdataptr[i] = (float)intdataptr[i];
    }
    written_size = fwrite(floatdataptr, sizeof(float), size, bp);
    std::free(floatdataptr);
  } else if (tp.data_type() == 11) {
    double* doubledataptr = (double*)tp.double_data().data();
    float* floatdataptr = (float*)std::malloc(sizeof(float) * size);
    for (int i = 0; i < size; i++) {
      floatdataptr[i] = (float)doubledataptr[i];
    }
    written_size = fwrite(floatdataptr, sizeof(float), size, bp);
    std::free(floatdataptr);
  }
}
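A common thread in the readers above: an ONNX TensorProto carries its payload either as raw bytes (`raw_data`) or as a typed repeated field, and the integer codes checked everywhere are TensorProto.DataType values (1 = float, 6 = int32, 7 = int64, 9 = bool, 11 = double), which is why every helper branches on both `has_raw_data()` and `data_type()`. A hedged usage sketch; `reshape_node` and `weights` are placeholder names, not identifiers from this commit.

// Read the constant `shape` input of a Reshape node (from opset 5 on it
// is the node's second input, stored as an int64 initializer).
const onnx::TensorProto& shape_tp = weights.at(reshape_node->input(1));
std::vector<int> target_shape = get_node_attr_from_input_ai(shape_tp);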