[Feature] Add RKNN support to SDK (#1145)

* add rknn_net [WIP]

* add cmake

* enable mmcls

* remove toTensor in SDK pipeline

* update doc

* translate to Chinese

* update doc and add tool-chain cmake

* use ::framework

* fix lint

* doc and print log

* data map

* refine install doc

* add rknpu2 workflow

* update gcc yaml

* better cmake file

* update doc link

* use vector instead of array

* better env variable

* use soft link

* release ctx

* name rule
pull/1166/head
AllentDan 2022-10-18 17:52:31 +08:00 committed by GitHub
parent 8e634059a1
commit 3eb60ea584
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 679 additions and 77 deletions

View File

@ -0,0 +1,54 @@
name: build_rknpu2_gcc
on:
push:
paths:
- "csrc/**"
- "demo/csrc/**"
- "CMakeLists.txt"
pull_request:
paths-ignore:
- "csrc/**"
- "demo/csrc/**"
- "CMakeLists.txt"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
build_rknpu2_gcc:
runs-on: ubuntu-18.04
steps:
- name: Checkout repository
uses: actions/checkout@v3
with:
submodules: 'recursive'
- name: rknpu2-gnu-toolchain
run: |
mkdir $GITHUB_WORKSPACE/rknpu2-gnu-toolchain
cd $GITHUB_WORKSPACE/rknpu2-gnu-toolchain
git clone https://github.com/Caesar-github/gcc-buildroot-9.3.0-2020.03-x86_64_aarch64-rockchip-linux-gnu.git
- name: rknpu2
run: |
mkdir $GITHUB_WORKSPACE/rknpu2
cd $GITHUB_WORKSPACE/rknpu2
git clone https://github.com/rockchip-linux/rknpu2.git
- name: build
run: |
export RKNN_TOOL_CHAIN=$GITHUB_WORKSPACE/rknpu2-gnu-toolchain/gcc-buildroot-9.3.0-2020.03-x86_64_aarch64-rockchip-linux-gnu/usr
export LD_LIBRARY_PATH=$RKNN_TOOL_CHAIN/lib64:$LD_LIBRARY_PATH
export RKNPU2_DEVICE_DIR=$GITHUB_WORKSPACE/rknpu2/rknpu2/runtime/RK3588
mkdir build && cd build
cmake .. \
-DCMAKE_TOOLCHAIN_FILE=$(pwd)/../cmake/toolchains/rknpu2-linux-gnu.cmake \
-DMMDEPLOY_BUILD_SDK=ON \
-DMMDEPLOY_SHARED_LIBS=ON \
-DMMDEPLOY_BUILD_EXAMPLES=ON \
-DMMDEPLOY_TARGET_DEVICES="cpu" \
-DMMDEPLOY_TARGET_BACKENDS="rknn" \
-DMMDEPLOY_CODEBASES=all \
-DOpenCV_DIR=$RKNPU2_DEVICE_DIR/../../examples/3rdparty/opencv/opencv-linux-aarch64/share/OpenCV
make -j$(nproc)
make install

View File

@ -128,6 +128,7 @@ if (MMDEPLOY_BUILD_SDK)
mmdeploy_add_deps(pplnn BACKENDS ${MMDEPLOY_TARGET_BACKENDS} DEPS pplnn)
endif ()
mmdeploy_add_deps(snpe BACKENDS ${MMDEPLOY_TARGET_BACKENDS} DEPS snpe)
mmdeploy_add_deps(rknn BACKENDS ${MMDEPLOY_TARGET_BACKENDS} DEPS rknn)
include(CMakePackageConfigHelpers)
# generate the config file that is includes the exports

View File

@ -0,0 +1,23 @@
set(CMAKE_SYSTEM_NAME Linux)
set(CMAKE_SYSTEM_PROCESSOR rockchip)
if(DEFINED ENV{RKNN_TOOL_CHAIN})
file(TO_CMAKE_PATH $ENV{RKNN_TOOL_CHAIN} RKNN_TOOL_CHAIN)
else()
message(FATAL_ERROR "RKNN_TOOL_CHAIN env must be defined")
endif()
set(CMAKE_C_COMPILER ${RKNN_TOOL_CHAIN}/bin/aarch64-rockchip-linux-gnu-gcc)
set(CMAKE_CXX_COMPILER ${RKNN_TOOL_CHAIN}/bin/aarch64-rockchip-linux-gnu-g++)
set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER)
set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY)
set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY)
set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY)
set(CMAKE_C_FLAGS "-Wl,--allow-shlib-undefined")
set(CMAKE_CXX_FLAGS "-Wl,--allow-shlib-undefined")
# cache flags
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}" CACHE STRING "c flags")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}" CACHE STRING "c++ flags")

View File

@ -38,5 +38,9 @@ if ("coreml" IN_LIST MMDEPLOY_TARGET_BACKENDS)
add_subdirectory(coreml)
endif ()
if ("rknn" IN_LIST MMDEPLOY_TARGET_BACKENDS)
add_subdirectory(rknn)
endif ()
mmdeploy_add_module(${PROJECT_NAME} net_module.cpp)
add_library(mmdeploy::net_module ALIAS ${PROJECT_NAME})

View File

@ -0,0 +1,20 @@
# Copyright (c) OpenMMLab. All rights reserved.
project(mmdeploy_rknn_net)
add_library(rknn SHARED IMPORTED)
if(DEFINED ENV{RKNPU2_DEVICE_DIR})
file(TO_CMAKE_PATH $ENV{RKNPU2_DEVICE_DIR} RKNPU2_DEVICE_DIR)
else()
message(FATAL_ERROR "RKNPU2_DEVICE_DIR env must be defined")
endif()
set_target_properties(rknn PROPERTIES
IMPORTED_LOCATION "${RKNPU2_DEVICE_DIR}/Linux/librknn_api/aarch64/librknn_api.so"
INTERFACE_INCLUDE_DIRECTORIES "${RKNPU2_DEVICE_DIR}/Linux/librknn_api/include"
)
mmdeploy_add_module(${PROJECT_NAME} rknn_net.cpp)
target_link_libraries(${PROJECT_NAME} PRIVATE rknn)
add_library(mmdeploy::rknn_net ALIAS ${PROJECT_NAME})

View File

@ -0,0 +1,216 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include "rknn_net.h"
#include <stdio.h>
#include <fstream>
#include "mmdeploy/core/logger.h"
#include "mmdeploy/core/model.h"
#include "mmdeploy/core/utils/filesystem.h"
#include "mmdeploy/core/utils/formatter.h"
namespace mmdeploy::framework {
Result<rknn_tensor_type> GetRKNNDataType(DataType data_type) {
switch (data_type) {
case DataType::kFLOAT:
return RKNN_TENSOR_FLOAT32;
case DataType::kHALF:
return RKNN_TENSOR_FLOAT16;
case DataType::kINT8:
return RKNN_TENSOR_INT8;
case DataType::kINT32:
return RKNN_TENSOR_INT32;
case DataType::kINT64:
return RKNN_TENSOR_INT64;
default:
return Status(eNotSupported);
}
}
Result<DataType> GetMMDeployDataType(rknn_tensor_type data_type) {
switch (data_type) {
case RKNN_TENSOR_FLOAT32:
return DataType::kFLOAT;
case RKNN_TENSOR_FLOAT16:
return DataType::kHALF;
case RKNN_TENSOR_INT8:
return DataType::kINT8;
case RKNN_TENSOR_INT32:
return DataType::kINT32;
case RKNN_TENSOR_INT64:
return DataType::kINT64;
default:
return Status(eNotSupported);
}
}
RKNNNet::~RKNNNet() { rknn_destroy(ctx_); }
void RKNNNet::dump_tensor_attr(rknn_tensor_attr* attr) {
MMDEPLOY_INFO(
" index={}, name={}, n_dims={}, dims=[{}, {}, {}, {}], n_elems={}, size={}, fmt={}, "
"type={}, qnt_type={}, "
"zp={}, scale=%f\n",
attr->index, attr->name, attr->n_dims, attr->dims[0], attr->dims[1], attr->dims[2],
attr->dims[3], attr->n_elems, attr->size, get_format_string(attr->fmt),
get_type_string(attr->type), get_qnt_type_string(attr->qnt_type), attr->zp, attr->scale);
}
Result<void> RKNNNet::Init(const Value& args) {
auto& context = args["context"];
device_ = context["device"].get<Device>();
stream_ = context["stream"].get<Stream>();
if (!device_.is_host()) {
return Status(eNotSupported);
}
auto name = args["name"].get<std::string>();
auto model = context["model"].get<Model>();
OUTCOME_TRY(auto config, model.GetModelConfig(name));
std::string content;
OUTCOME_TRY(content, model.ReadFile(config.net));
char* model_ptr = const_cast<char*>(content.data());
int ret = rknn_init(&ctx_, model_ptr, content.size(), 0, NULL);
if (ret != RKNN_SUCC) {
MMDEPLOY_ERROR("Load .rknn failed! ret= {}", ret);
return Status(eInvalidArgument);
}
// Get Model Input Output Info
rknn_input_output_num io_num;
ret = rknn_query(ctx_, RKNN_QUERY_IN_OUT_NUM, &io_num, sizeof(io_num));
if (ret != RKNN_SUCC) {
MMDEPLOY_INFO("model input num: {}, output num: {}\n", io_num.n_input, io_num.n_output);
MMDEPLOY_ERROR("rknn_query fail! ret= {}", ret);
return Status(eFail);
}
for (int i = 0; i < io_num.n_input; i++) {
rknn_tensor_attr input_attr;
input_attr.index = i;
ret = rknn_query(ctx_, RKNN_QUERY_INPUT_ATTR, &(input_attr), sizeof(rknn_tensor_attr));
if (ret != RKNN_SUCC) {
MMDEPLOY_INFO("input tensors:\n");
dump_tensor_attr(&(input_attr));
MMDEPLOY_ERROR("rknn_query fail! ret= {}", ret);
return Status(eFail);
}
input_attrs_.push_back(input_attr);
OUTCOME_TRY(auto data_type, GetMMDeployDataType(input_attr.type));
input_tensors_.emplace_back(TensorDesc{device_, data_type, {}, input_attr.name});
}
for (int i = 0; i < io_num.n_output; i++) {
rknn_tensor_attr output_attr;
output_attr.index = i;
ret = rknn_query(ctx_, RKNN_QUERY_OUTPUT_ATTR, &(output_attr), sizeof(rknn_tensor_attr));
if (ret != RKNN_SUCC) {
MMDEPLOY_INFO("output tensors:\n");
dump_tensor_attr(&(output_attr));
MMDEPLOY_ERROR("rknn_query fail! ret= {}", ret);
return Status(eFail);
}
output_attrs_.push_back(output_attr);
OUTCOME_TRY(auto data_type, GetMMDeployDataType(output_attr.type));
output_tensors_.emplace_back(TensorDesc{device_, data_type, {}, output_attr.name});
}
return success();
}
Result<void> RKNNNet::ForwardAsync(Event* event) { return Status(eNotSupported); }
Result<void> RKNNNet::Deinit() { return success(); }
Result<Span<Tensor>> RKNNNet::GetInputTensors() { return input_tensors_; }
Result<Span<Tensor>> RKNNNet::GetOutputTensors() { return output_tensors_; }
Result<void> RKNNNet::Reshape(Span<TensorShape> input_shapes) {
for (size_t i = 0; i < input_shapes.size(); ++i) {
input_tensors_[i].Reshape(input_shapes[i]);
}
return success();
}
Result<void> RKNNNet::Forward() {
OUTCOME_TRY(stream_.Wait());
std::vector<rknn_input> inputs;
for (int i = 0; i < input_tensors_.size(); i++) {
rknn_input input;
input.index = i;
input.pass_through = 0;
input.type = input_attrs_[i].type;
input.fmt = input_attrs_[i].fmt;
input.buf = input_tensors_[i].data<float>();
input.size = input_attrs_[i].size;
inputs.push_back(input);
}
// Set input
int ret = rknn_inputs_set(ctx_, input_tensors_.size(), inputs.data());
if (ret < 0) {
MMDEPLOY_ERROR("rknn_input_set fail! ret= {}", ret);
return Status(eFail);
}
// Get output
std::vector<rknn_output> outputs;
for (uint32_t i = 0; i < output_tensors_.size(); ++i) {
rknn_output output;
output.want_float = 1;
output.index = i;
output.is_prealloc = 0;
outputs.push_back(output);
}
ret = rknn_run(ctx_, NULL);
if (ret < 0) {
MMDEPLOY_ERROR("rknn_run fail! ret={}", ret);
return Status(eFail);
}
ret = rknn_outputs_get(ctx_, output_tensors_.size(), outputs.data(), NULL);
if (ret < 0) {
MMDEPLOY_ERROR("rknn_outputs_get fail! ret= {}", ret);
return Status(eFail);
}
for (int i = 0; i < output_tensors_.size(); i++) {
TensorShape tensor_shape;
for (int j = 0; j < output_attrs_[i].n_dims; ++j) {
tensor_shape.push_back(output_attrs_[i].dims[j]);
}
output_tensors_[i].Reshape(tensor_shape);
memcpy(output_tensors_[i].data<float>(), (float*)outputs[i].buf, output_attrs_[i].size);
}
OUTCOME_TRY(stream_.Wait());
return success();
}
class RKNNNetCreator : public Creator<Net> {
public:
const char* GetName() const override { return "rknn"; }
int GetVersion() const override { return 0; }
std::unique_ptr<Net> Create(const Value& args) override {
try {
auto p = std::make_unique<RKNNNet>();
if (auto r = p->Init(args)) {
return p;
} else {
MMDEPLOY_ERROR("error creating RKNNNet: {}", r.error().message().c_str());
return nullptr;
}
} catch (const std::exception& e) {
MMDEPLOY_ERROR("unhandled exception when creating RKNNNet: {}", e.what());
return nullptr;
}
}
};
REGISTER_MODULE(Net, RKNNNetCreator);
} // namespace mmdeploy::framework

View File

@ -0,0 +1,45 @@
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef MMDEPLOY_SRC_NET_RKNN_RKNN_NET_H_
#define MMDEPLOY_SRC_NET_RKNN_RKNN_NET_H_
#include "mmdeploy/core/mpl/span.h"
#include "mmdeploy/core/net.h"
#include "rknn_api.h"
namespace mmdeploy::framework {
class RKNNNet : public Net {
public:
~RKNNNet() override;
Result<void> Init(const Value& args) override;
Result<void> Deinit() override;
Result<void> Reshape(Span<TensorShape> input_shapes) override;
Result<Span<Tensor> > GetInputTensors() override;
Result<Span<Tensor> > GetOutputTensors() override;
Result<void> Forward() override;
Result<void> ForwardAsync(Event* event) override;
private:
void dump_tensor_attr(rknn_tensor_attr* attr);
Device device_;
Stream stream_;
rknn_context ctx_;
std::vector<Tensor> input_tensors_;
std::vector<Tensor> output_tensors_;
std::vector<rknn_tensor_attr> input_attrs_;
std::vector<rknn_tensor_attr> output_attrs_;
static constexpr const auto kHost = Device(0);
};
} // namespace mmdeploy::framework
#endif // MMDEPLOY_SRC_NET_RKNN_RKNN_NET_H_

View File

@ -39,3 +39,4 @@ Please visit the following links to find out how to build MMDeploy according to
- [NVIDIA Jetson](jetsons.md)
- [SNPE](snpe.md)
- [RISC-V](riscv.md)
- [Rockchip](rockchip.md)

View File

@ -0,0 +1,147 @@
# Build for RKNN
This tutorial is based on Linux systems like Ubuntu-18.04 and Rockchip NPU like `rk3588`.
## Installation
It is recommended to create a virtual environment for the project.
1. get RKNN-Toolkit2 through:
```
git clone git@github.com:rockchip-linux/rknn-toolkit2.git
```
2. install RKNN python package following [official doc](https://github.com/rockchip-linux/rknn-toolkit2/tree/master/doc). In our testing, we used the rknn-toolkit2 1.2.0 with commit id `834ba0b0a1ab8ee27024443d77b02b5ba48b67fc`. When installing rknn-toolkit2, it is better to append `--no-deps` after the commands to avoid dependency conflicts. For example:
```
pip install packages/rknn_toolkit2-1.2.0_f7bb160f-cp36-cp36m-linux_x86_64.whl --no-deps
```
3. Install ONNX==1.8.0 before reinstall MMDeploy from source following the [instructions](../01-how-to-build/build_from_source.md). Note that there are conflicts between the pip dependencies of MMDeploy and RKNN. Here is the suggested packages versions for python 3.6:
```
protobuf==3.19.4
onnx==1.8.0
onnxruntime==1.8.0
torch==1.8.0
torchvision==0.9.0
```
4. Install torch and torchvision using conda. For example:
```
conda install pytorch==1.8.0 torchvision==0.9.0 cudatoolkit=11.1 -c pytorch -c conda-forge
```
To work with models from [MMClassification](https://mmclassification.readthedocs.io/en/latest/getting_started.html), you may need to install it additionally.
## Usage
Example:
```bash
python tools/deploy.py \
configs/mmcls/classification_rknn_static.py \
/mmclassification_dir/configs/resnet/resnet50_8xb32_in1k.py \
https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_batch256_imagenet_20200708-cfb998bf.pth \
/mmclassification_dir/demo/demo.JPEG \
--work-dir ../resnet50 \
--device cpu
```
## Deployment config
With the deployment config, you can modify the `backend_config` for your preference. An example `backend_config` of mmclassification is shown as below:
```python
backend_config = dict(
type='rknn',
common_config=dict(
mean_values=None,
std_values=None,
target_platform='rk3588',
optimization_level=3),
quantization_config=dict(do_quantization=False, dataset=None),
input_size_list=[[3, 224, 224]])
```
The contents of `common_config` are for `rknn.config()`. The contents of `quantization_config` are used to control `rknn.build()`.
## Build SDK with Rockchip NPU
1. get rknpu2 through:
```
git clone git@github.com:rockchip-linux/rknpu2.git
```
2. for linux, download gcc cross compiler. The download link of the compiler from the official user guide of `rknpu2` was deprecated. You may use another verified [link](https://github.com/Caesar-github/gcc-buildroot-9.3.0-2020.03-x86_64_aarch64-rockchip-linux-gnu). After download and unzip the compiler, you may open the terminal, set `RKNN_TOOL_CHAIN` and `RKNPU2_DEVICE_DIR` by `export RKNN_TOOL_CHAIN=/path/to/gcc/usr;export RKNPU2_DEVICE_DIR=/path/to/rknpu2/runtime/RK3588`.
3. after the above preparition, run the following commands:
```shell
cd /path/to/mmdeploy
mkdir -p build && rm -rf build/CM* && cd build
export LD_LIBRARY_PATH=$RKNN_TOOL_CHAIN/lib64:$LD_LIBRARY_PATH
cmake \
-DCMAKE_TOOLCHAIN_FILE=/path/to/mmdeploy/cmake/toolchains/rknpu2-linux-gnu.cmake \
-DMMDEPLOY_BUILD_SDK=ON \
-DCMAKE_BUILD_TYPE=Debug \
-DOpenCV_DIR=${RKNPU2_DEVICE_DIR}/../../examples/3rdparty/opencv/opencv-linux-aarch64/share/OpenCV \
-DMMDEPLOY_BUILD_SDK_PYTHON_API=ON \
-DMMDEPLOY_TARGET_DEVICES="cpu" \
-DMMDEPLOY_TARGET_BACKENDS="rknn" \
-DMMDEPLOY_CODEBASES=all \
-DMMDEPLOY_BUILD_TEST=ON \
-DMMDEPLOY_BUILD_EXAMPLES=ON \
..
make && make install
```
## Run the demo with SDK
First make sure that`--dump-info`is used during convert model, so that the working directory has the files required by the SDK such as `pipeline.json`.
`adb push` the model directory, executable file and .so to the device.
```bash
cd /path/to/mmdeploy
adb push resnet50 /data/local/tmp/resnet50
adb push /mmclassification_dir/demo/demo.JPEG /data/local/tmp/resnet50/demo.JPEG
cd build
adb push lib /data/local/tmp/lib
adb push bin/image_classification /data/local/tmp/image_classification
```
Set up environment variable and execute the sample.
```bash
adb shell
cd /data/local/tmp
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/data/local/tmp/lib
./image_classification cpu ./resnet50 ./resnet50/demo.JPEG
..
label: 65, score: 0.95
```
## Troubleshooting
- Quantization fails.
Empirically, RKNN require the inputs not normalized if `do_quantization` is set to `True`. Please modify the settings of `Normalize` in the `model_cfg` from
```python
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
```
to
```python
img_norm_cfg = dict(
mean=[0, 0, 0], std=[1, 1, 1], to_rgb=True)
```
Besides, the `mean_values` and `std_values` of deploy_cfg should be replaced with original normalization settings of `model_cfg`. Let `mean_values=[123.675, 116.28, 103.53]` and `std_values=[58.395, 57.12, 57.375]`.

View File

@ -1,80 +1,9 @@
# RKNN support
# Supported RKNN feature
This tutorial is based on Linux systems like Ubuntu-18.04 and Rockchip NPU like `rk3588`.
Currently, MMDeploy only tests rk3588 with linux platform.
## Installation
The following features cannot be automatically enabled by mmdeploy and you need to manually modify the configuration in MMDeploy like [here](https://github.com/open-mmlab/mmdeploy/blob/master/configs/_base_/backends/rknn.py).
It is recommended to create a virtual environment for the project.
1. get RKNN-Toolkit2 through:
```
git clone https://github.com/rockchip-linux/rknn-toolkit2
```
2. install RKNN python package following [official doc](https://github.com/rockchip-linux/rknn-toolkit2/tree/master/doc). In our testing, we used the rknn-toolkit 1.2.0 with commit id `834ba0b0a1ab8ee27024443d77b02b5ba48b67fc`.
3. reinstall MMDeploy from source following the [instructions](../01-how-to-build/build_from_source.md). Note that there are conflicts between the pip dependencies of MMDeploy and RKNN. Here is the suggested packages versions for python 3.6:
```
protobuf==3.19.4
onnx==1.8.0
onnxruntime==1.8.0
torch==1.8.0
torchvision==0.9.0
```
To work with models from [MMDetection](https://github.com/open-mmlab/mmdetection/blob/master/docs/get_started.md), you may need to install it additionally.
## Usage
Example:
```bash
python tools/deploy.py \
configs/mmdet/detection/detection_rknn_static.py \
/mmdetection_dir/mmdetection/configs/yolo/yolov3_d53_mstrain-608_273e_coco.py \
/tmp/snapshots/yolov3_d53_mstrain-608_273e_coco_20210518_115020-a2c3acb8.pth \
tests/data/tiger.jpeg \
--work-dir ../deploy_result \
--device cpu
```
## Deployment config
With the deployment config, you can modify the `backend_config` for your preference. An example `backend_config` of mmclassification is shown as below:
```python
backend_config = dict(
type='rknn',
common_config=dict(
mean_values=None,
std_values=None,
target_platform='rk3588',
optimization_level=3),
quantization_config=dict(do_quantization=False, dataset=None),
input_size_list=[[3, 224, 224]])
```
The contents of `common_config` are for `rknn.config()`. The contents of `quantization_config` are used to control `rknn.build()`.
## Troubleshooting
- Quantization fails.
Empirically, RKNN require the inputs not normalized if `do_quantization` is set to `False`. Please modify the settings of `Normalize` in the `model_cfg` from
```python
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
```
to
```python
img_norm_cfg = dict(
mean=[0, 0, 0], std=[1, 1, 1], to_rgb=True)
```
Besides, the `mean_values` and `std_values` of deploy_cfg should be replaced with original normalization settings of `model_cfg`. Let `mean_values=[123.675, 116.28, 103.53]` and `std_values=[58.395, 57.12, 57.375]`.
- target_platform other than `3588`
- quantization settings
- optimization level other than 3

View File

@ -61,6 +61,7 @@ You can switch between Chinese and English documents in the lower-left corner of
05-supported-backends/snpe.md
05-supported-backends/tensorrt.md
05-supported-backends/torchscript.md
05-supported-backends/rknn.md
.. toctree::
:maxdepth: 1

View File

@ -42,3 +42,4 @@ git clone -b master git@github.com:open-mmlab/mmdeploy.git --recursive
- [NVIDIA Jetson](jetsons.md)
- [Qcom SNPE](snpe.md)
- [RISC-V](riscv.md)
- [Rockchip](rockchip.md)

View File

@ -0,0 +1,147 @@
# 支持 RKNN
本教程基于 Ubuntu-18.04 和 Rockchip `rk3588` NPU。
## 安装
建议为项目创建一个虚拟环境。
1. 获取 RKNN-Toolkit2:
```
git clone git@github.com:rockchip-linux/rknn-toolkit2.git
```
2. 通过 [官方文档](https://github.com/rockchip-linux/rknn-toolkit2/tree/master/doc),安装 RKNN python 安装包. 在我们的测试中, 使用的 rknn-toolkit 版本是 1.2.0commit id `834ba0b0a1ab8ee27024443d77b02b5ba48b67fc`。安装 rknn-toolkit2 时,最好在安装命令后添加`--no-deps`,以避免依赖包的冲突。比如:
```
pip install packages/rknn_toolkit2-1.2.0_f7bb160f-cp36-cp36m-linux_x86_64.whl --no-deps
```
3. 先安装onnx==1.8.0,跟着 [instructions](../01-how-to-build/build_from_source.md),源码安装 MMDeploy。 需要注意的是, MMDeploy 和 RKNN 依赖的安装包间有冲突的内容. 这里提供建议在 python 3.6 环境中使用的安装包版本:
```
protobuf==3.19.4
onnx==1.8.0
onnxruntime==1.8.0
torch==1.8.0
torchvision==0.9.0
```
4. 使用 conda 安装 torch and torchvision比如:
```
conda install pytorch==1.8.0 torchvision==0.9.0 cudatoolkit=11.1 -c pytorch -c conda-forge
```
如要使用 [MMClassification](https://mmclassification.readthedocs.io/en/latest/getting_started.html) 需要用户自己安装使用。
## 使用
例子:
```bash
python tools/deploy.py \
configs/mmcls/classification_rknn_static.py \
/mmclassification_dir/configs/resnet/resnet50_8xb32_in1k.py \
https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_batch256_imagenet_20200708-cfb998bf.pth \
/mmclassification_dir/demo/demo.JPEG \
--work-dir ../resnet50 \
--device cpu
```
## 部署 config
部署 config你可以根据需要修改 `backend_config` 字段. 一个 mmclassification 的 `backend_config`例子如下:
```python
backend_config = dict(
type='rknn',
common_config=dict(
mean_values=None,
std_values=None,
target_platform='rk3588',
optimization_level=3),
quantization_config=dict(do_quantization=False, dataset=None),
input_size_list=[[3, 224, 224]])
```
`common_config` 的内容服务于 `rknn.config()`. `quantization_config` 的内容服务于 `rknn.build()`
## 安装 SDK
1. 获取 rknpu2:
```
git clone git@github.com:rockchip-linux/rknpu2.git
```
2. 在 linux 系统, 下载 gcc 交叉编译器. `rknpu2` 的官方提供的下载链接无法使用了. 用户可以使用另一个 [链接](https://github.com/Caesar-github/gcc-buildroot-9.3.0-2020.03-x86_64_aarch64-rockchip-linux-gnu). 下载并解压完编译器, 打开终端, 设置 `RKNN_TOOL_CHAIN``RKNPU2_DEVICE_DIR``export RKNN_TOOL_CHAIN=/path/to/gcc/usr;export RKNPU2_DEVICE_DIR=/path/to/rknpu2/runtime/RK3588`
3. 上述准备工作完成后, 运行如下指令安装:
```shell
cd /path/to/mmdeploy
mkdir -p build && rm -rf build/CM* && cd build
export LD_LIBRARY_PATH=$RKNN_TOOL_CHAIN/lib64:$LD_LIBRARY_PATH
cmake \
-DCMAKE_TOOLCHAIN_FILE=/path/to/mmdeploy/cmake/toolchains/rknpu2-linux-gnu.cmake \
-DMMDEPLOY_BUILD_SDK=ON \
-DCMAKE_BUILD_TYPE=Debug \
-DOpenCV_DIR=${RKNPU2_DEVICE_DIR}/../../examples/3rdparty/opencv/opencv-linux-aarch64/share/OpenCV \
-DMMDEPLOY_BUILD_SDK_PYTHON_API=ON \
-DMMDEPLOY_TARGET_DEVICES="cpu" \
-DMMDEPLOY_TARGET_BACKENDS="rknn" \
-DMMDEPLOY_CODEBASES=all \
-DMMDEPLOY_BUILD_TEST=ON \
-DMMDEPLOY_BUILD_EXAMPLES=ON \
..
make && make install
```
## 运行 SDK 的 demo
首先,确保`--dump-info`在转模型的时候调用了, 这样工作目录下包含 SDK 需要的配置文件 `pipeline.json`
使用 `adb push` 将模型路径,执行文件和.so 文件传到板子上。
```bash
cd /path/to/mmdeploy
adb push resnet50 /data/local/tmp/resnet50
adb push /mmclassification_dir/demo/demo.JPEG /data/local/tmp/resnet50/demo.JPEG
cd build
adb push lib /data/local/tmp/lib
adb push bin/image_classification /data/local/tmp/image_classification
```
设置环境变量,运行例子。
```bash
adb shell
cd /data/local/tmp
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/data/local/tmp/lib
./image_classification cpu ./resnet50 ./resnet50/demo.JPEG
..
label: 65, score: 0.95
```
## 问题点
- 量化失败.
经验来说, 如果 `do_quantization` 被设置为 `True`RKNN 需要的输入没有被归一化过。请修改 `Normalize``model_cfg` 的设置,如将
```python
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
```
改为
```python
img_norm_cfg = dict(
mean=[0, 0, 0], std=[1, 1, 1], to_rgb=True)
```
此外, deploy_cfg 的 `mean_values``std_values` 应该被设置为 `model_cfg` 中归一化的设置. 使 `mean_values=[123.675, 116.28, 103.53]` `std_values=[58.395, 57.12, 57.375]`

View File

@ -0,0 +1,9 @@
# 支持的 RKNN 特征
目前, MMDeploy 只在 rk3588 的 linux 平台上测试过.
以下特性需要手动在 MMDeploy 自行配置,如[这里](https://github.com/open-mmlab/mmdeploy/blob/master/configs/_base_/backends/rknn.py).
- target_platform = `3588`
- quantization settings
- optimization level = 3

View File

@ -58,6 +58,7 @@
05-supported-backends/onnxruntime.md
05-supported-backends/openvino.md
05-supported-backends/pplnn.md
05-supported-backends/rknn.md
05-supported-backends/snpe.md
05-supported-backends/tensorrt.md
05-supported-backends/torchscript.md

View File

@ -252,6 +252,9 @@ def get_preprocess(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,
meta_keys += transform[
'meta_keys'] if 'meta_keys' in transform else []
transform['meta_keys'] = list(set(meta_keys))
if get_backend(deploy_cfg) == Backend.RKNN:
del transforms[-2]
assert transforms[0]['type'] == 'LoadImageFromFile', 'The first item type'\
' of pipeline should be LoadImageFromFile'