[Feature] Add RKNN support to SDK (#1145)
* add rknn_net [WIP] * add cmake * enable mmcls * remove toTensor in SDK pipeline * update doc * translate to Chinese * update doc and add tool-chain cmake * use ::framework * fix lint * doc and print log * data map * refine install doc * add rknpu2 workflow * update gcc yaml * better cmake file * update doc link * use vector instead of array * better env variable * use soft link * release ctx * name rulepull/1166/head
parent
8e634059a1
commit
3eb60ea584
|
@ -0,0 +1,54 @@
|
|||
name: build_rknpu2_gcc
|
||||
|
||||
on:
|
||||
push:
|
||||
paths:
|
||||
- "csrc/**"
|
||||
- "demo/csrc/**"
|
||||
- "CMakeLists.txt"
|
||||
|
||||
pull_request:
|
||||
paths-ignore:
|
||||
- "csrc/**"
|
||||
- "demo/csrc/**"
|
||||
- "CMakeLists.txt"
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
build_rknpu2_gcc:
|
||||
runs-on: ubuntu-18.04
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
submodules: 'recursive'
|
||||
- name: rknpu2-gnu-toolchain
|
||||
run: |
|
||||
mkdir $GITHUB_WORKSPACE/rknpu2-gnu-toolchain
|
||||
cd $GITHUB_WORKSPACE/rknpu2-gnu-toolchain
|
||||
git clone https://github.com/Caesar-github/gcc-buildroot-9.3.0-2020.03-x86_64_aarch64-rockchip-linux-gnu.git
|
||||
- name: rknpu2
|
||||
run: |
|
||||
mkdir $GITHUB_WORKSPACE/rknpu2
|
||||
cd $GITHUB_WORKSPACE/rknpu2
|
||||
git clone https://github.com/rockchip-linux/rknpu2.git
|
||||
- name: build
|
||||
run: |
|
||||
export RKNN_TOOL_CHAIN=$GITHUB_WORKSPACE/rknpu2-gnu-toolchain/gcc-buildroot-9.3.0-2020.03-x86_64_aarch64-rockchip-linux-gnu/usr
|
||||
export LD_LIBRARY_PATH=$RKNN_TOOL_CHAIN/lib64:$LD_LIBRARY_PATH
|
||||
export RKNPU2_DEVICE_DIR=$GITHUB_WORKSPACE/rknpu2/rknpu2/runtime/RK3588
|
||||
mkdir build && cd build
|
||||
cmake .. \
|
||||
-DCMAKE_TOOLCHAIN_FILE=$(pwd)/../cmake/toolchains/rknpu2-linux-gnu.cmake \
|
||||
-DMMDEPLOY_BUILD_SDK=ON \
|
||||
-DMMDEPLOY_SHARED_LIBS=ON \
|
||||
-DMMDEPLOY_BUILD_EXAMPLES=ON \
|
||||
-DMMDEPLOY_TARGET_DEVICES="cpu" \
|
||||
-DMMDEPLOY_TARGET_BACKENDS="rknn" \
|
||||
-DMMDEPLOY_CODEBASES=all \
|
||||
-DOpenCV_DIR=$RKNPU2_DEVICE_DIR/../../examples/3rdparty/opencv/opencv-linux-aarch64/share/OpenCV
|
||||
make -j$(nproc)
|
||||
make install
|
|
@ -128,6 +128,7 @@ if (MMDEPLOY_BUILD_SDK)
|
|||
mmdeploy_add_deps(pplnn BACKENDS ${MMDEPLOY_TARGET_BACKENDS} DEPS pplnn)
|
||||
endif ()
|
||||
mmdeploy_add_deps(snpe BACKENDS ${MMDEPLOY_TARGET_BACKENDS} DEPS snpe)
|
||||
mmdeploy_add_deps(rknn BACKENDS ${MMDEPLOY_TARGET_BACKENDS} DEPS rknn)
|
||||
|
||||
include(CMakePackageConfigHelpers)
|
||||
# generate the config file that is includes the exports
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
set(CMAKE_SYSTEM_NAME Linux)
|
||||
set(CMAKE_SYSTEM_PROCESSOR rockchip)
|
||||
|
||||
if(DEFINED ENV{RKNN_TOOL_CHAIN})
|
||||
file(TO_CMAKE_PATH $ENV{RKNN_TOOL_CHAIN} RKNN_TOOL_CHAIN)
|
||||
else()
|
||||
message(FATAL_ERROR "RKNN_TOOL_CHAIN env must be defined")
|
||||
endif()
|
||||
|
||||
set(CMAKE_C_COMPILER ${RKNN_TOOL_CHAIN}/bin/aarch64-rockchip-linux-gnu-gcc)
|
||||
set(CMAKE_CXX_COMPILER ${RKNN_TOOL_CHAIN}/bin/aarch64-rockchip-linux-gnu-g++)
|
||||
|
||||
set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER)
|
||||
set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY)
|
||||
set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY)
|
||||
set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY)
|
||||
|
||||
set(CMAKE_C_FLAGS "-Wl,--allow-shlib-undefined")
|
||||
set(CMAKE_CXX_FLAGS "-Wl,--allow-shlib-undefined")
|
||||
|
||||
# cache flags
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}" CACHE STRING "c flags")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}" CACHE STRING "c++ flags")
|
|
@ -38,5 +38,9 @@ if ("coreml" IN_LIST MMDEPLOY_TARGET_BACKENDS)
|
|||
add_subdirectory(coreml)
|
||||
endif ()
|
||||
|
||||
if ("rknn" IN_LIST MMDEPLOY_TARGET_BACKENDS)
|
||||
add_subdirectory(rknn)
|
||||
endif ()
|
||||
|
||||
mmdeploy_add_module(${PROJECT_NAME} net_module.cpp)
|
||||
add_library(mmdeploy::net_module ALIAS ${PROJECT_NAME})
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
project(mmdeploy_rknn_net)
|
||||
|
||||
add_library(rknn SHARED IMPORTED)
|
||||
|
||||
if(DEFINED ENV{RKNPU2_DEVICE_DIR})
|
||||
file(TO_CMAKE_PATH $ENV{RKNPU2_DEVICE_DIR} RKNPU2_DEVICE_DIR)
|
||||
else()
|
||||
message(FATAL_ERROR "RKNPU2_DEVICE_DIR env must be defined")
|
||||
endif()
|
||||
|
||||
set_target_properties(rknn PROPERTIES
|
||||
IMPORTED_LOCATION "${RKNPU2_DEVICE_DIR}/Linux/librknn_api/aarch64/librknn_api.so"
|
||||
INTERFACE_INCLUDE_DIRECTORIES "${RKNPU2_DEVICE_DIR}/Linux/librknn_api/include"
|
||||
)
|
||||
|
||||
mmdeploy_add_module(${PROJECT_NAME} rknn_net.cpp)
|
||||
target_link_libraries(${PROJECT_NAME} PRIVATE rknn)
|
||||
add_library(mmdeploy::rknn_net ALIAS ${PROJECT_NAME})
|
|
@ -0,0 +1,216 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
#include "rknn_net.h"
|
||||
|
||||
#include <stdio.h>
|
||||
|
||||
#include <fstream>
|
||||
|
||||
#include "mmdeploy/core/logger.h"
|
||||
#include "mmdeploy/core/model.h"
|
||||
#include "mmdeploy/core/utils/filesystem.h"
|
||||
#include "mmdeploy/core/utils/formatter.h"
|
||||
|
||||
namespace mmdeploy::framework {
|
||||
|
||||
Result<rknn_tensor_type> GetRKNNDataType(DataType data_type) {
|
||||
switch (data_type) {
|
||||
case DataType::kFLOAT:
|
||||
return RKNN_TENSOR_FLOAT32;
|
||||
case DataType::kHALF:
|
||||
return RKNN_TENSOR_FLOAT16;
|
||||
case DataType::kINT8:
|
||||
return RKNN_TENSOR_INT8;
|
||||
case DataType::kINT32:
|
||||
return RKNN_TENSOR_INT32;
|
||||
case DataType::kINT64:
|
||||
return RKNN_TENSOR_INT64;
|
||||
default:
|
||||
return Status(eNotSupported);
|
||||
}
|
||||
}
|
||||
|
||||
Result<DataType> GetMMDeployDataType(rknn_tensor_type data_type) {
|
||||
switch (data_type) {
|
||||
case RKNN_TENSOR_FLOAT32:
|
||||
return DataType::kFLOAT;
|
||||
case RKNN_TENSOR_FLOAT16:
|
||||
return DataType::kHALF;
|
||||
case RKNN_TENSOR_INT8:
|
||||
return DataType::kINT8;
|
||||
case RKNN_TENSOR_INT32:
|
||||
return DataType::kINT32;
|
||||
case RKNN_TENSOR_INT64:
|
||||
return DataType::kINT64;
|
||||
default:
|
||||
return Status(eNotSupported);
|
||||
}
|
||||
}
|
||||
|
||||
RKNNNet::~RKNNNet() { rknn_destroy(ctx_); }
|
||||
|
||||
void RKNNNet::dump_tensor_attr(rknn_tensor_attr* attr) {
|
||||
MMDEPLOY_INFO(
|
||||
" index={}, name={}, n_dims={}, dims=[{}, {}, {}, {}], n_elems={}, size={}, fmt={}, "
|
||||
"type={}, qnt_type={}, "
|
||||
"zp={}, scale=%f\n",
|
||||
attr->index, attr->name, attr->n_dims, attr->dims[0], attr->dims[1], attr->dims[2],
|
||||
attr->dims[3], attr->n_elems, attr->size, get_format_string(attr->fmt),
|
||||
get_type_string(attr->type), get_qnt_type_string(attr->qnt_type), attr->zp, attr->scale);
|
||||
}
|
||||
|
||||
Result<void> RKNNNet::Init(const Value& args) {
|
||||
auto& context = args["context"];
|
||||
device_ = context["device"].get<Device>();
|
||||
stream_ = context["stream"].get<Stream>();
|
||||
if (!device_.is_host()) {
|
||||
return Status(eNotSupported);
|
||||
}
|
||||
|
||||
auto name = args["name"].get<std::string>();
|
||||
auto model = context["model"].get<Model>();
|
||||
OUTCOME_TRY(auto config, model.GetModelConfig(name));
|
||||
|
||||
std::string content;
|
||||
OUTCOME_TRY(content, model.ReadFile(config.net));
|
||||
char* model_ptr = const_cast<char*>(content.data());
|
||||
int ret = rknn_init(&ctx_, model_ptr, content.size(), 0, NULL);
|
||||
if (ret != RKNN_SUCC) {
|
||||
MMDEPLOY_ERROR("Load .rknn failed! ret= {}", ret);
|
||||
return Status(eInvalidArgument);
|
||||
}
|
||||
|
||||
// Get Model Input Output Info
|
||||
rknn_input_output_num io_num;
|
||||
ret = rknn_query(ctx_, RKNN_QUERY_IN_OUT_NUM, &io_num, sizeof(io_num));
|
||||
if (ret != RKNN_SUCC) {
|
||||
MMDEPLOY_INFO("model input num: {}, output num: {}\n", io_num.n_input, io_num.n_output);
|
||||
MMDEPLOY_ERROR("rknn_query fail! ret= {}", ret);
|
||||
return Status(eFail);
|
||||
}
|
||||
|
||||
for (int i = 0; i < io_num.n_input; i++) {
|
||||
rknn_tensor_attr input_attr;
|
||||
input_attr.index = i;
|
||||
ret = rknn_query(ctx_, RKNN_QUERY_INPUT_ATTR, &(input_attr), sizeof(rknn_tensor_attr));
|
||||
if (ret != RKNN_SUCC) {
|
||||
MMDEPLOY_INFO("input tensors:\n");
|
||||
dump_tensor_attr(&(input_attr));
|
||||
MMDEPLOY_ERROR("rknn_query fail! ret= {}", ret);
|
||||
return Status(eFail);
|
||||
}
|
||||
input_attrs_.push_back(input_attr);
|
||||
OUTCOME_TRY(auto data_type, GetMMDeployDataType(input_attr.type));
|
||||
input_tensors_.emplace_back(TensorDesc{device_, data_type, {}, input_attr.name});
|
||||
}
|
||||
|
||||
for (int i = 0; i < io_num.n_output; i++) {
|
||||
rknn_tensor_attr output_attr;
|
||||
output_attr.index = i;
|
||||
ret = rknn_query(ctx_, RKNN_QUERY_OUTPUT_ATTR, &(output_attr), sizeof(rknn_tensor_attr));
|
||||
if (ret != RKNN_SUCC) {
|
||||
MMDEPLOY_INFO("output tensors:\n");
|
||||
dump_tensor_attr(&(output_attr));
|
||||
MMDEPLOY_ERROR("rknn_query fail! ret= {}", ret);
|
||||
return Status(eFail);
|
||||
}
|
||||
output_attrs_.push_back(output_attr);
|
||||
OUTCOME_TRY(auto data_type, GetMMDeployDataType(output_attr.type));
|
||||
output_tensors_.emplace_back(TensorDesc{device_, data_type, {}, output_attr.name});
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
Result<void> RKNNNet::ForwardAsync(Event* event) { return Status(eNotSupported); }
|
||||
|
||||
Result<void> RKNNNet::Deinit() { return success(); }
|
||||
|
||||
Result<Span<Tensor>> RKNNNet::GetInputTensors() { return input_tensors_; }
|
||||
|
||||
Result<Span<Tensor>> RKNNNet::GetOutputTensors() { return output_tensors_; }
|
||||
|
||||
Result<void> RKNNNet::Reshape(Span<TensorShape> input_shapes) {
|
||||
for (size_t i = 0; i < input_shapes.size(); ++i) {
|
||||
input_tensors_[i].Reshape(input_shapes[i]);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
Result<void> RKNNNet::Forward() {
|
||||
OUTCOME_TRY(stream_.Wait());
|
||||
|
||||
std::vector<rknn_input> inputs;
|
||||
for (int i = 0; i < input_tensors_.size(); i++) {
|
||||
rknn_input input;
|
||||
input.index = i;
|
||||
input.pass_through = 0;
|
||||
input.type = input_attrs_[i].type;
|
||||
input.fmt = input_attrs_[i].fmt;
|
||||
input.buf = input_tensors_[i].data<float>();
|
||||
input.size = input_attrs_[i].size;
|
||||
inputs.push_back(input);
|
||||
}
|
||||
|
||||
// Set input
|
||||
int ret = rknn_inputs_set(ctx_, input_tensors_.size(), inputs.data());
|
||||
if (ret < 0) {
|
||||
MMDEPLOY_ERROR("rknn_input_set fail! ret= {}", ret);
|
||||
return Status(eFail);
|
||||
}
|
||||
|
||||
// Get output
|
||||
std::vector<rknn_output> outputs;
|
||||
for (uint32_t i = 0; i < output_tensors_.size(); ++i) {
|
||||
rknn_output output;
|
||||
output.want_float = 1;
|
||||
output.index = i;
|
||||
output.is_prealloc = 0;
|
||||
outputs.push_back(output);
|
||||
}
|
||||
|
||||
ret = rknn_run(ctx_, NULL);
|
||||
if (ret < 0) {
|
||||
MMDEPLOY_ERROR("rknn_run fail! ret={}", ret);
|
||||
return Status(eFail);
|
||||
}
|
||||
|
||||
ret = rknn_outputs_get(ctx_, output_tensors_.size(), outputs.data(), NULL);
|
||||
if (ret < 0) {
|
||||
MMDEPLOY_ERROR("rknn_outputs_get fail! ret= {}", ret);
|
||||
return Status(eFail);
|
||||
}
|
||||
for (int i = 0; i < output_tensors_.size(); i++) {
|
||||
TensorShape tensor_shape;
|
||||
for (int j = 0; j < output_attrs_[i].n_dims; ++j) {
|
||||
tensor_shape.push_back(output_attrs_[i].dims[j]);
|
||||
}
|
||||
output_tensors_[i].Reshape(tensor_shape);
|
||||
memcpy(output_tensors_[i].data<float>(), (float*)outputs[i].buf, output_attrs_[i].size);
|
||||
}
|
||||
OUTCOME_TRY(stream_.Wait());
|
||||
return success();
|
||||
}
|
||||
|
||||
class RKNNNetCreator : public Creator<Net> {
|
||||
public:
|
||||
const char* GetName() const override { return "rknn"; }
|
||||
int GetVersion() const override { return 0; }
|
||||
std::unique_ptr<Net> Create(const Value& args) override {
|
||||
try {
|
||||
auto p = std::make_unique<RKNNNet>();
|
||||
if (auto r = p->Init(args)) {
|
||||
return p;
|
||||
} else {
|
||||
MMDEPLOY_ERROR("error creating RKNNNet: {}", r.error().message().c_str());
|
||||
return nullptr;
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
MMDEPLOY_ERROR("unhandled exception when creating RKNNNet: {}", e.what());
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_MODULE(Net, RKNNNetCreator);
|
||||
|
||||
} // namespace mmdeploy::framework
|
|
@ -0,0 +1,45 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
#ifndef MMDEPLOY_SRC_NET_RKNN_RKNN_NET_H_
|
||||
#define MMDEPLOY_SRC_NET_RKNN_RKNN_NET_H_
|
||||
|
||||
#include "mmdeploy/core/mpl/span.h"
|
||||
#include "mmdeploy/core/net.h"
|
||||
#include "rknn_api.h"
|
||||
|
||||
namespace mmdeploy::framework {
|
||||
|
||||
class RKNNNet : public Net {
|
||||
public:
|
||||
~RKNNNet() override;
|
||||
|
||||
Result<void> Init(const Value& args) override;
|
||||
|
||||
Result<void> Deinit() override;
|
||||
|
||||
Result<void> Reshape(Span<TensorShape> input_shapes) override;
|
||||
|
||||
Result<Span<Tensor> > GetInputTensors() override;
|
||||
|
||||
Result<Span<Tensor> > GetOutputTensors() override;
|
||||
|
||||
Result<void> Forward() override;
|
||||
|
||||
Result<void> ForwardAsync(Event* event) override;
|
||||
|
||||
private:
|
||||
void dump_tensor_attr(rknn_tensor_attr* attr);
|
||||
|
||||
Device device_;
|
||||
Stream stream_;
|
||||
rknn_context ctx_;
|
||||
std::vector<Tensor> input_tensors_;
|
||||
std::vector<Tensor> output_tensors_;
|
||||
std::vector<rknn_tensor_attr> input_attrs_;
|
||||
std::vector<rknn_tensor_attr> output_attrs_;
|
||||
static constexpr const auto kHost = Device(0);
|
||||
};
|
||||
|
||||
} // namespace mmdeploy::framework
|
||||
|
||||
#endif // MMDEPLOY_SRC_NET_RKNN_RKNN_NET_H_
|
|
@ -39,3 +39,4 @@ Please visit the following links to find out how to build MMDeploy according to
|
|||
- [NVIDIA Jetson](jetsons.md)
|
||||
- [SNPE](snpe.md)
|
||||
- [RISC-V](riscv.md)
|
||||
- [Rockchip](rockchip.md)
|
||||
|
|
|
@ -0,0 +1,147 @@
|
|||
# Build for RKNN
|
||||
|
||||
This tutorial is based on Linux systems like Ubuntu-18.04 and Rockchip NPU like `rk3588`.
|
||||
|
||||
## Installation
|
||||
|
||||
It is recommended to create a virtual environment for the project.
|
||||
|
||||
1. get RKNN-Toolkit2 through:
|
||||
|
||||
```
|
||||
git clone git@github.com:rockchip-linux/rknn-toolkit2.git
|
||||
```
|
||||
|
||||
2. install RKNN python package following [official doc](https://github.com/rockchip-linux/rknn-toolkit2/tree/master/doc). In our testing, we used the rknn-toolkit2 1.2.0 with commit id `834ba0b0a1ab8ee27024443d77b02b5ba48b67fc`. When installing rknn-toolkit2, it is better to append `--no-deps` after the commands to avoid dependency conflicts. For example:
|
||||
|
||||
```
|
||||
pip install packages/rknn_toolkit2-1.2.0_f7bb160f-cp36-cp36m-linux_x86_64.whl --no-deps
|
||||
```
|
||||
|
||||
3. Install ONNX==1.8.0 before reinstall MMDeploy from source following the [instructions](../01-how-to-build/build_from_source.md). Note that there are conflicts between the pip dependencies of MMDeploy and RKNN. Here is the suggested packages versions for python 3.6:
|
||||
|
||||
```
|
||||
protobuf==3.19.4
|
||||
onnx==1.8.0
|
||||
onnxruntime==1.8.0
|
||||
torch==1.8.0
|
||||
torchvision==0.9.0
|
||||
```
|
||||
|
||||
4. Install torch and torchvision using conda. For example:
|
||||
|
||||
```
|
||||
conda install pytorch==1.8.0 torchvision==0.9.0 cudatoolkit=11.1 -c pytorch -c conda-forge
|
||||
```
|
||||
|
||||
To work with models from [MMClassification](https://mmclassification.readthedocs.io/en/latest/getting_started.html), you may need to install it additionally.
|
||||
|
||||
## Usage
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
python tools/deploy.py \
|
||||
configs/mmcls/classification_rknn_static.py \
|
||||
/mmclassification_dir/configs/resnet/resnet50_8xb32_in1k.py \
|
||||
https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_batch256_imagenet_20200708-cfb998bf.pth \
|
||||
/mmclassification_dir/demo/demo.JPEG \
|
||||
--work-dir ../resnet50 \
|
||||
--device cpu
|
||||
```
|
||||
|
||||
## Deployment config
|
||||
|
||||
With the deployment config, you can modify the `backend_config` for your preference. An example `backend_config` of mmclassification is shown as below:
|
||||
|
||||
```python
|
||||
backend_config = dict(
|
||||
type='rknn',
|
||||
common_config=dict(
|
||||
mean_values=None,
|
||||
std_values=None,
|
||||
target_platform='rk3588',
|
||||
optimization_level=3),
|
||||
quantization_config=dict(do_quantization=False, dataset=None),
|
||||
input_size_list=[[3, 224, 224]])
|
||||
|
||||
```
|
||||
|
||||
The contents of `common_config` are for `rknn.config()`. The contents of `quantization_config` are used to control `rknn.build()`.
|
||||
|
||||
## Build SDK with Rockchip NPU
|
||||
|
||||
1. get rknpu2 through:
|
||||
|
||||
```
|
||||
git clone git@github.com:rockchip-linux/rknpu2.git
|
||||
```
|
||||
|
||||
2. for linux, download gcc cross compiler. The download link of the compiler from the official user guide of `rknpu2` was deprecated. You may use another verified [link](https://github.com/Caesar-github/gcc-buildroot-9.3.0-2020.03-x86_64_aarch64-rockchip-linux-gnu). After download and unzip the compiler, you may open the terminal, set `RKNN_TOOL_CHAIN` and `RKNPU2_DEVICE_DIR` by `export RKNN_TOOL_CHAIN=/path/to/gcc/usr;export RKNPU2_DEVICE_DIR=/path/to/rknpu2/runtime/RK3588`.
|
||||
|
||||
3. after the above preparition, run the following commands:
|
||||
|
||||
```shell
|
||||
cd /path/to/mmdeploy
|
||||
mkdir -p build && rm -rf build/CM* && cd build
|
||||
export LD_LIBRARY_PATH=$RKNN_TOOL_CHAIN/lib64:$LD_LIBRARY_PATH
|
||||
cmake \
|
||||
-DCMAKE_TOOLCHAIN_FILE=/path/to/mmdeploy/cmake/toolchains/rknpu2-linux-gnu.cmake \
|
||||
-DMMDEPLOY_BUILD_SDK=ON \
|
||||
-DCMAKE_BUILD_TYPE=Debug \
|
||||
-DOpenCV_DIR=${RKNPU2_DEVICE_DIR}/../../examples/3rdparty/opencv/opencv-linux-aarch64/share/OpenCV \
|
||||
-DMMDEPLOY_BUILD_SDK_PYTHON_API=ON \
|
||||
-DMMDEPLOY_TARGET_DEVICES="cpu" \
|
||||
-DMMDEPLOY_TARGET_BACKENDS="rknn" \
|
||||
-DMMDEPLOY_CODEBASES=all \
|
||||
-DMMDEPLOY_BUILD_TEST=ON \
|
||||
-DMMDEPLOY_BUILD_EXAMPLES=ON \
|
||||
..
|
||||
make && make install
|
||||
```
|
||||
|
||||
## Run the demo with SDK
|
||||
|
||||
First make sure that`--dump-info`is used during convert model, so that the working directory has the files required by the SDK such as `pipeline.json`.
|
||||
|
||||
`adb push` the model directory, executable file and .so to the device.
|
||||
|
||||
```bash
|
||||
cd /path/to/mmdeploy
|
||||
adb push resnet50 /data/local/tmp/resnet50
|
||||
adb push /mmclassification_dir/demo/demo.JPEG /data/local/tmp/resnet50/demo.JPEG
|
||||
cd build
|
||||
adb push lib /data/local/tmp/lib
|
||||
adb push bin/image_classification /data/local/tmp/image_classification
|
||||
```
|
||||
|
||||
Set up environment variable and execute the sample.
|
||||
|
||||
```bash
|
||||
adb shell
|
||||
cd /data/local/tmp
|
||||
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/data/local/tmp/lib
|
||||
./image_classification cpu ./resnet50 ./resnet50/demo.JPEG
|
||||
..
|
||||
label: 65, score: 0.95
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
- Quantization fails.
|
||||
|
||||
Empirically, RKNN require the inputs not normalized if `do_quantization` is set to `True`. Please modify the settings of `Normalize` in the `model_cfg` from
|
||||
|
||||
```python
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
```
|
||||
|
||||
to
|
||||
|
||||
```python
|
||||
img_norm_cfg = dict(
|
||||
mean=[0, 0, 0], std=[1, 1, 1], to_rgb=True)
|
||||
```
|
||||
|
||||
Besides, the `mean_values` and `std_values` of deploy_cfg should be replaced with original normalization settings of `model_cfg`. Let `mean_values=[123.675, 116.28, 103.53]` and `std_values=[58.395, 57.12, 57.375]`.
|
|
@ -1,80 +1,9 @@
|
|||
# RKNN support
|
||||
# Supported RKNN feature
|
||||
|
||||
This tutorial is based on Linux systems like Ubuntu-18.04 and Rockchip NPU like `rk3588`.
|
||||
Currently, MMDeploy only tests rk3588 with linux platform.
|
||||
|
||||
## Installation
|
||||
The following features cannot be automatically enabled by mmdeploy and you need to manually modify the configuration in MMDeploy like [here](https://github.com/open-mmlab/mmdeploy/blob/master/configs/_base_/backends/rknn.py).
|
||||
|
||||
It is recommended to create a virtual environment for the project.
|
||||
|
||||
1. get RKNN-Toolkit2 through:
|
||||
|
||||
```
|
||||
git clone https://github.com/rockchip-linux/rknn-toolkit2
|
||||
```
|
||||
|
||||
2. install RKNN python package following [official doc](https://github.com/rockchip-linux/rknn-toolkit2/tree/master/doc). In our testing, we used the rknn-toolkit 1.2.0 with commit id `834ba0b0a1ab8ee27024443d77b02b5ba48b67fc`.
|
||||
|
||||
3. reinstall MMDeploy from source following the [instructions](../01-how-to-build/build_from_source.md). Note that there are conflicts between the pip dependencies of MMDeploy and RKNN. Here is the suggested packages versions for python 3.6:
|
||||
|
||||
```
|
||||
protobuf==3.19.4
|
||||
onnx==1.8.0
|
||||
onnxruntime==1.8.0
|
||||
torch==1.8.0
|
||||
torchvision==0.9.0
|
||||
```
|
||||
|
||||
To work with models from [MMDetection](https://github.com/open-mmlab/mmdetection/blob/master/docs/get_started.md), you may need to install it additionally.
|
||||
|
||||
## Usage
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
python tools/deploy.py \
|
||||
configs/mmdet/detection/detection_rknn_static.py \
|
||||
/mmdetection_dir/mmdetection/configs/yolo/yolov3_d53_mstrain-608_273e_coco.py \
|
||||
/tmp/snapshots/yolov3_d53_mstrain-608_273e_coco_20210518_115020-a2c3acb8.pth \
|
||||
tests/data/tiger.jpeg \
|
||||
--work-dir ../deploy_result \
|
||||
--device cpu
|
||||
```
|
||||
|
||||
## Deployment config
|
||||
|
||||
With the deployment config, you can modify the `backend_config` for your preference. An example `backend_config` of mmclassification is shown as below:
|
||||
|
||||
```python
|
||||
backend_config = dict(
|
||||
type='rknn',
|
||||
common_config=dict(
|
||||
mean_values=None,
|
||||
std_values=None,
|
||||
target_platform='rk3588',
|
||||
optimization_level=3),
|
||||
quantization_config=dict(do_quantization=False, dataset=None),
|
||||
input_size_list=[[3, 224, 224]])
|
||||
|
||||
```
|
||||
|
||||
The contents of `common_config` are for `rknn.config()`. The contents of `quantization_config` are used to control `rknn.build()`.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
- Quantization fails.
|
||||
|
||||
Empirically, RKNN require the inputs not normalized if `do_quantization` is set to `False`. Please modify the settings of `Normalize` in the `model_cfg` from
|
||||
|
||||
```python
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
```
|
||||
|
||||
to
|
||||
|
||||
```python
|
||||
img_norm_cfg = dict(
|
||||
mean=[0, 0, 0], std=[1, 1, 1], to_rgb=True)
|
||||
```
|
||||
|
||||
Besides, the `mean_values` and `std_values` of deploy_cfg should be replaced with original normalization settings of `model_cfg`. Let `mean_values=[123.675, 116.28, 103.53]` and `std_values=[58.395, 57.12, 57.375]`.
|
||||
- target_platform other than `3588`
|
||||
- quantization settings
|
||||
- optimization level other than 3
|
||||
|
|
|
@ -61,6 +61,7 @@ You can switch between Chinese and English documents in the lower-left corner of
|
|||
05-supported-backends/snpe.md
|
||||
05-supported-backends/tensorrt.md
|
||||
05-supported-backends/torchscript.md
|
||||
05-supported-backends/rknn.md
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
|
|
|
@ -42,3 +42,4 @@ git clone -b master git@github.com:open-mmlab/mmdeploy.git --recursive
|
|||
- [NVIDIA Jetson](jetsons.md)
|
||||
- [Qcom SNPE](snpe.md)
|
||||
- [RISC-V](riscv.md)
|
||||
- [Rockchip](rockchip.md)
|
||||
|
|
|
@ -0,0 +1,147 @@
|
|||
# 支持 RKNN
|
||||
|
||||
本教程基于 Ubuntu-18.04 和 Rockchip `rk3588` NPU。
|
||||
|
||||
## 安装
|
||||
|
||||
建议为项目创建一个虚拟环境。
|
||||
|
||||
1. 获取 RKNN-Toolkit2:
|
||||
|
||||
```
|
||||
git clone git@github.com:rockchip-linux/rknn-toolkit2.git
|
||||
```
|
||||
|
||||
2. 通过 [官方文档](https://github.com/rockchip-linux/rknn-toolkit2/tree/master/doc),安装 RKNN python 安装包. 在我们的测试中, 使用的 rknn-toolkit 版本是 1.2.0,commit id `834ba0b0a1ab8ee27024443d77b02b5ba48b67fc`。安装 rknn-toolkit2 时,最好在安装命令后添加`--no-deps`,以避免依赖包的冲突。比如:
|
||||
|
||||
```
|
||||
pip install packages/rknn_toolkit2-1.2.0_f7bb160f-cp36-cp36m-linux_x86_64.whl --no-deps
|
||||
```
|
||||
|
||||
3. 先安装onnx==1.8.0,跟着 [instructions](../01-how-to-build/build_from_source.md),源码安装 MMDeploy。 需要注意的是, MMDeploy 和 RKNN 依赖的安装包间有冲突的内容. 这里提供建议在 python 3.6 环境中使用的安装包版本:
|
||||
|
||||
```
|
||||
protobuf==3.19.4
|
||||
onnx==1.8.0
|
||||
onnxruntime==1.8.0
|
||||
torch==1.8.0
|
||||
torchvision==0.9.0
|
||||
```
|
||||
|
||||
4. 使用 conda 安装 torch and torchvision,比如:
|
||||
|
||||
```
|
||||
conda install pytorch==1.8.0 torchvision==0.9.0 cudatoolkit=11.1 -c pytorch -c conda-forge
|
||||
```
|
||||
|
||||
如要使用 [MMClassification](https://mmclassification.readthedocs.io/en/latest/getting_started.html), 需要用户自己安装使用。
|
||||
|
||||
## 使用
|
||||
|
||||
例子:
|
||||
|
||||
```bash
|
||||
python tools/deploy.py \
|
||||
configs/mmcls/classification_rknn_static.py \
|
||||
/mmclassification_dir/configs/resnet/resnet50_8xb32_in1k.py \
|
||||
https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_batch256_imagenet_20200708-cfb998bf.pth \
|
||||
/mmclassification_dir/demo/demo.JPEG \
|
||||
--work-dir ../resnet50 \
|
||||
--device cpu
|
||||
```
|
||||
|
||||
## 部署 config
|
||||
|
||||
部署 config,你可以根据需要修改 `backend_config` 字段. 一个 mmclassification 的 `backend_config`例子如下:
|
||||
|
||||
```python
|
||||
backend_config = dict(
|
||||
type='rknn',
|
||||
common_config=dict(
|
||||
mean_values=None,
|
||||
std_values=None,
|
||||
target_platform='rk3588',
|
||||
optimization_level=3),
|
||||
quantization_config=dict(do_quantization=False, dataset=None),
|
||||
input_size_list=[[3, 224, 224]])
|
||||
|
||||
```
|
||||
|
||||
`common_config` 的内容服务于 `rknn.config()`. `quantization_config` 的内容服务于 `rknn.build()`。
|
||||
|
||||
## 安装 SDK
|
||||
|
||||
1. 获取 rknpu2:
|
||||
|
||||
```
|
||||
git clone git@github.com:rockchip-linux/rknpu2.git
|
||||
```
|
||||
|
||||
2. 在 linux 系统, 下载 gcc 交叉编译器. `rknpu2` 的官方提供的下载链接无法使用了. 用户可以使用另一个 [链接](https://github.com/Caesar-github/gcc-buildroot-9.3.0-2020.03-x86_64_aarch64-rockchip-linux-gnu). 下载并解压完编译器, 打开终端, 设置 `RKNN_TOOL_CHAIN` 和 `RKNPU2_DEVICE_DIR` 为 `export RKNN_TOOL_CHAIN=/path/to/gcc/usr;export RKNPU2_DEVICE_DIR=/path/to/rknpu2/runtime/RK3588`。
|
||||
|
||||
3. 上述准备工作完成后, 运行如下指令安装:
|
||||
|
||||
```shell
|
||||
cd /path/to/mmdeploy
|
||||
mkdir -p build && rm -rf build/CM* && cd build
|
||||
export LD_LIBRARY_PATH=$RKNN_TOOL_CHAIN/lib64:$LD_LIBRARY_PATH
|
||||
cmake \
|
||||
-DCMAKE_TOOLCHAIN_FILE=/path/to/mmdeploy/cmake/toolchains/rknpu2-linux-gnu.cmake \
|
||||
-DMMDEPLOY_BUILD_SDK=ON \
|
||||
-DCMAKE_BUILD_TYPE=Debug \
|
||||
-DOpenCV_DIR=${RKNPU2_DEVICE_DIR}/../../examples/3rdparty/opencv/opencv-linux-aarch64/share/OpenCV \
|
||||
-DMMDEPLOY_BUILD_SDK_PYTHON_API=ON \
|
||||
-DMMDEPLOY_TARGET_DEVICES="cpu" \
|
||||
-DMMDEPLOY_TARGET_BACKENDS="rknn" \
|
||||
-DMMDEPLOY_CODEBASES=all \
|
||||
-DMMDEPLOY_BUILD_TEST=ON \
|
||||
-DMMDEPLOY_BUILD_EXAMPLES=ON \
|
||||
..
|
||||
make && make install
|
||||
```
|
||||
|
||||
## 运行 SDK 的 demo
|
||||
|
||||
首先,确保`--dump-info`在转模型的时候调用了, 这样工作目录下包含 SDK 需要的配置文件 `pipeline.json`。
|
||||
|
||||
使用 `adb push` 将模型路径,执行文件和.so 文件传到板子上。
|
||||
|
||||
```bash
|
||||
cd /path/to/mmdeploy
|
||||
adb push resnet50 /data/local/tmp/resnet50
|
||||
adb push /mmclassification_dir/demo/demo.JPEG /data/local/tmp/resnet50/demo.JPEG
|
||||
cd build
|
||||
adb push lib /data/local/tmp/lib
|
||||
adb push bin/image_classification /data/local/tmp/image_classification
|
||||
```
|
||||
|
||||
设置环境变量,运行例子。
|
||||
|
||||
```bash
|
||||
adb shell
|
||||
cd /data/local/tmp
|
||||
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/data/local/tmp/lib
|
||||
./image_classification cpu ./resnet50 ./resnet50/demo.JPEG
|
||||
..
|
||||
label: 65, score: 0.95
|
||||
```
|
||||
|
||||
## 问题点
|
||||
|
||||
- 量化失败.
|
||||
|
||||
经验来说, 如果 `do_quantization` 被设置为 `True`,RKNN 需要的输入没有被归一化过。请修改 `Normalize` 在 `model_cfg` 的设置,如将
|
||||
|
||||
```python
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
```
|
||||
|
||||
改为
|
||||
|
||||
```python
|
||||
img_norm_cfg = dict(
|
||||
mean=[0, 0, 0], std=[1, 1, 1], to_rgb=True)
|
||||
```
|
||||
|
||||
此外, deploy_cfg 的 `mean_values` 和 `std_values` 应该被设置为 `model_cfg` 中归一化的设置. 使 `mean_values=[123.675, 116.28, 103.53]`, `std_values=[58.395, 57.12, 57.375]`。
|
|
@ -0,0 +1,9 @@
|
|||
# 支持的 RKNN 特征
|
||||
|
||||
目前, MMDeploy 只在 rk3588 的 linux 平台上测试过.
|
||||
|
||||
以下特性需要手动在 MMDeploy 自行配置,如[这里](https://github.com/open-mmlab/mmdeploy/blob/master/configs/_base_/backends/rknn.py).
|
||||
|
||||
- target_platform != `3588`
|
||||
- quantization settings
|
||||
- optimization level != 3
|
|
@ -58,6 +58,7 @@
|
|||
05-supported-backends/onnxruntime.md
|
||||
05-supported-backends/openvino.md
|
||||
05-supported-backends/pplnn.md
|
||||
05-supported-backends/rknn.md
|
||||
05-supported-backends/snpe.md
|
||||
05-supported-backends/tensorrt.md
|
||||
05-supported-backends/torchscript.md
|
||||
|
|
|
@ -252,6 +252,9 @@ def get_preprocess(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,
|
|||
meta_keys += transform[
|
||||
'meta_keys'] if 'meta_keys' in transform else []
|
||||
transform['meta_keys'] = list(set(meta_keys))
|
||||
|
||||
if get_backend(deploy_cfg) == Backend.RKNN:
|
||||
del transforms[-2]
|
||||
assert transforms[0]['type'] == 'LoadImageFromFile', 'The first item type'\
|
||||
' of pipeline should be LoadImageFromFile'
|
||||
|
||||
|
|
Loading…
Reference in New Issue