[Feature] Support rv1126 in sdk (#1238)
* tmp * refine * update ssd-lite * tmp * tmp * 0.1 * 0.1.1 * rename to base_dense_head * remove debug code * wait stream * update cmakelists * add some comments * fix lint * fix ci error * fix according reviewer comments * update params * fix * support normalize with to_float being false * fix lint * support rv1126 build ci * support rv1126 build ci * change debug level * fix ci * update * update doc * fix circleci error * update normalize * update * check in build script * change namepull/1314/head
parent
940fffa075
commit
625593d6f3
|
@ -0,0 +1,44 @@
|
|||
# Cross-compiles the MMDeploy SDK for Rockchip NPUs (rk3588 via rknpu2, rv1126 via rknpu).
name: backend-rknn

on:
  push:
    paths:
      - "csrc/**"
      - "demo/csrc/**"
      - "CMakeLists.txt"

  pull_request:
    # Use the same positive filter as `push`: the cross build only needs to run
    # for pull requests that touch the SDK sources. The previous `paths-ignore`
    # with this identical list skipped exactly those pull requests.
    paths:
      - "csrc/**"
      - "demo/csrc/**"
      - "CMakeLists.txt"

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  build_rknpu2:
    runs-on: ubuntu-18.04
    steps:
      - name: Checkout repository
        uses: actions/checkout@v3
        with:
          submodules: 'recursive'
      - name: update
        run: sudo apt update
      - name: cross compile
        run: |
          sh -x tools/scripts/ubuntu_cross_build_rknn.sh rk3588
  build_rknpu:
    runs-on: ubuntu-18.04
    steps:
      - name: Checkout repository
        uses: actions/checkout@v3
        with:
          submodules: 'recursive'
      - name: update
        run: sudo apt update
      - name: cross compile
        run: |
          sh -x tools/scripts/ubuntu_cross_build_rknn.sh rv1126
|
@ -1,54 +0,0 @@
|
|||
name: build_rknpu2_gcc
|
||||
|
||||
on:
|
||||
push:
|
||||
paths:
|
||||
- "csrc/**"
|
||||
- "demo/csrc/**"
|
||||
- "CMakeLists.txt"
|
||||
|
||||
pull_request:
|
||||
paths-ignore:
|
||||
- "csrc/**"
|
||||
- "demo/csrc/**"
|
||||
- "CMakeLists.txt"
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
build_rknpu2_gcc:
|
||||
runs-on: ubuntu-18.04
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
submodules: 'recursive'
|
||||
- name: rknpu2-gnu-toolchain
|
||||
run: |
|
||||
mkdir $GITHUB_WORKSPACE/rknpu2-gnu-toolchain
|
||||
cd $GITHUB_WORKSPACE/rknpu2-gnu-toolchain
|
||||
git clone https://github.com/Caesar-github/gcc-buildroot-9.3.0-2020.03-x86_64_aarch64-rockchip-linux-gnu.git
|
||||
- name: rknpu2
|
||||
run: |
|
||||
mkdir $GITHUB_WORKSPACE/rknpu2
|
||||
cd $GITHUB_WORKSPACE/rknpu2
|
||||
git clone https://github.com/rockchip-linux/rknpu2.git
|
||||
- name: build
|
||||
run: |
|
||||
export RKNN_TOOL_CHAIN=$GITHUB_WORKSPACE/rknpu2-gnu-toolchain/gcc-buildroot-9.3.0-2020.03-x86_64_aarch64-rockchip-linux-gnu/usr
|
||||
export LD_LIBRARY_PATH=$RKNN_TOOL_CHAIN/lib64:$LD_LIBRARY_PATH
|
||||
export RKNPU2_DEVICE_DIR=$GITHUB_WORKSPACE/rknpu2/rknpu2/runtime/RK3588
|
||||
mkdir build && cd build
|
||||
cmake .. \
|
||||
-DCMAKE_TOOLCHAIN_FILE=$(pwd)/../cmake/toolchains/rknpu2-linux-gnu.cmake \
|
||||
-DMMDEPLOY_BUILD_SDK=ON \
|
||||
-DMMDEPLOY_SHARED_LIBS=ON \
|
||||
-DMMDEPLOY_BUILD_EXAMPLES=ON \
|
||||
-DMMDEPLOY_TARGET_DEVICES="cpu" \
|
||||
-DMMDEPLOY_TARGET_BACKENDS="rknn" \
|
||||
-DMMDEPLOY_CODEBASES=all \
|
||||
-DOpenCV_DIR=$RKNPU2_DEVICE_DIR/../../examples/3rdparty/opencv/opencv-linux-aarch64/share/OpenCV
|
||||
make -j$(nproc)
|
||||
make install
|
|
@ -0,0 +1,16 @@
|
|||
# CMake toolchain file: cross compile for 32-bit ARM Linux with the
# arm-linux-gnueabihf GCC toolchain (armv7-a, hard-float ABI, NEON).
set(CMAKE_SYSTEM_NAME Linux)
set(CMAKE_SYSTEM_PROCESSOR arm)

set(CMAKE_C_COMPILER "arm-linux-gnueabihf-gcc")
set(CMAKE_CXX_COMPILER "arm-linux-gnueabihf-g++")

# Programs are looked up on the host only; libraries and headers are looked up
# only under CMAKE_FIND_ROOT_PATH.
set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER)
set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY)
set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY)

# C and C++ use the same target flags.
set(CMAKE_C_FLAGS "-march=armv7-a -mfloat-abi=hard -mfpu=neon")
set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS}")

# cache flags
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}" CACHE STRING "c flags")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}" CACHE STRING "c++ flags")
|
|
@ -0,0 +1,103 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
#include "base_dense_head.h"
|
||||
|
||||
#include <numeric>
|
||||
|
||||
#include "mmdeploy/core/model.h"
|
||||
#include "mmdeploy/core/utils/device_utils.h"
|
||||
#include "mmdeploy/core/utils/formatter.h"
|
||||
#include "utils.h"
|
||||
|
||||
namespace mmdeploy::mmdet {
|
||||
|
||||
// Builds the head from `cfg`. When `cfg` has a `params` section, the score
// threshold, pre-NMS candidate count, minimum box size and NMS IoU threshold
// are read from it; otherwise the in-class defaults are kept.
BaseDenseHead::BaseDenseHead(const Value& cfg) : MMDetection(cfg) {
  // The model handle is not used below; the lookup is kept so that construction
  // behaves exactly as before when `context.model` is absent.
  // NOTE(review): presumably this acts as a sanity check - confirm or remove.
  [[maybe_unused]] auto model = cfg["context"]["model"].get<Model>();

  if (!cfg.contains("params")) {
    return;
  }

  const auto& params = cfg["params"];
  nms_pre_ = params.value("nms_pre", -1);
  score_thr_ = params.value("score_thr", 0.02f);
  min_bbox_size_ = params.value("min_bbox_size", 0);

  // `iou_threshold` lives in the nested `nms` section; fall back to 0.45.
  if (params.contains("nms")) {
    iou_threshold_ = params["nms"].value("iou_threshold", 0.45f);
  } else {
    iou_threshold_ = 0.45f;
  }
}
|
||||
|
||||
// Post-processes the raw head outputs of one image.
//
// @param prep_res  preprocessing result; its `img_metas` entry is forwarded to GetBBoxes
// @param infer_res inference result; must contain the tensors `dets` and `labels`
// @return the detections converted to a Value, or an error status
Result<Value> BaseDenseHead::operator()(const Value& prep_res, const Value& infer_res) {
  MMDEPLOY_DEBUG("prep_res: {}\ninfer_res: {}", prep_res, infer_res);
  try {
    auto dets = infer_res["dets"].get<Tensor>();
    // The `labels` output is consumed as the score tensor by GetBBoxes.
    auto scores = infer_res["labels"].get<Tensor>();

    // Post-processing runs on the host, so bring both tensors there first.
    const Device kHost{0, 0};
    OUTCOME_TRY(auto _dets, MakeAvailableOnDevice(dets, kHost, stream()));
    OUTCOME_TRY(auto _scores, MakeAvailableOnDevice(scores, kHost, stream()));
    // Wait for the stream before the host code dereferences the tensor data.
    OUTCOME_TRY(stream().Wait());

    OUTCOME_TRY(auto result, GetBBoxes(prep_res["img_metas"], _dets, _scores));
    return to_value(result);
  } catch (...) {
    // Do not fail silently: leave a trace so a missing key or a type mismatch
    // in the inputs can be diagnosed from the log.
    MMDEPLOY_ERROR("BaseDenseHead: exception caught while post-processing 'dets'/'labels'");
    return Status(eFail);
  }
}
|
||||
|
||||
// Turns raw boxes and scores into final detections in original-image coordinates.
//
// Steps: score filtering -> sort by score (descending) -> NMS -> rescale and clip
// to the original image -> drop boxes smaller than `min_bbox_size_`.
//
// @param prep_res image meta data; `ori_shape` is required, `scale_factor` is
//                 optional (defaults to 1 for all four coordinates)
// @param dets     box tensor, read as 4 consecutive floats (x1, y1, x2, y2) per anchor
// @param scores   score tensor passed to FilterScoresAndTopk
// @return the kept detections, or eInvalidArgument if `scale_factor` has fewer
//         than 4 elements
Result<Detections> BaseDenseHead::GetBBoxes(const Value& prep_res, const Tensor& dets,
                                            const Tensor& scores) const {
  MMDEPLOY_DEBUG("dets: {}, {}", dets.shape(), dets.data_type());
  MMDEPLOY_DEBUG("scores: {}, {}", scores.shape(), scores.data_type());

  // Three parallel arrays describing the surviving candidates.
  std::vector<float> probs;
  std::vector<int> label_ids;
  std::vector<int> anchor_idxs;

  FilterScoresAndTopk(scores, score_thr_, nms_pre_, probs, label_ids, anchor_idxs);

  Sort(probs, label_ids, anchor_idxs);

  // NMS marks suppressed candidates by writing -1 into `anchor_idxs`.
  NMS(dets, iou_threshold_, anchor_idxs);

  std::vector<float> scale_factor;
  if (prep_res.contains("scale_factor")) {
    from_value(prep_res["scale_factor"], scale_factor);
  } else {
    scale_factor = {1.f, 1.f, 1.f, 1.f};
  }
  // MapToOriginImage reads scale_factor[0..3]; reject shorter vectors instead of
  // reading past the end of the buffer.
  if (scale_factor.size() < 4) {
    MMDEPLOY_ERROR("'scale_factor' must have 4 elements, but got {}", scale_factor.size());
    return Status(eInvalidArgument);
  }

  const int ori_width = prep_res["ori_shape"][2].get<int>();
  const int ori_height = prep_res["ori_shape"][1].get<int>();
  const auto det_ptr = dets.data<float>();

  Detections objs;
  for (std::size_t i = 0; i < anchor_idxs.size(); ++i) {
    const int j = anchor_idxs[i];
    if (j == -1) {
      continue;  // suppressed by NMS
    }
    const auto x1 = det_ptr[j * 4 + 0];
    const auto y1 = det_ptr[j * 4 + 1];
    const auto x2 = det_ptr[j * 4 + 2];
    const auto y2 = det_ptr[j * 4 + 3];
    const int label_id = label_ids[i];
    const float score = probs[i];

    MMDEPLOY_DEBUG("{}-th box: ({}, {}, {}, {}), {}, {}", i, x1, y1, x2, y2, label_id, score);

    auto rect = MapToOriginImage(x1, y1, x2, y2, scale_factor.data(), 0, 0, ori_width, ori_height);
    if (rect[2] - rect[0] < min_bbox_size_ || rect[3] - rect[1] < min_bbox_size_) {
      MMDEPLOY_DEBUG("ignore small bbox with width '{}' and height '{}", rect[2] - rect[0],
                     rect[3] - rect[1]);
      continue;
    }

    Detection det{};
    // Rank of the candidate after sorting by score, not the anchor index.
    det.index = static_cast<int>(i);
    det.label_id = label_id;
    det.score = score;
    det.bbox = rect;
    objs.push_back(std::move(det));
  }

  return objs;
}
|
||||
|
||||
REGISTER_CODEBASE_COMPONENT(MMDetection, BaseDenseHead);
|
||||
|
||||
} // namespace mmdeploy::mmdet
|
|
@ -0,0 +1,26 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
#ifndef MMDEPLOY_CODEBASE_MMDET_BASE_DENSE_HEAD_H_
|
||||
#define MMDEPLOY_CODEBASE_MMDET_BASE_DENSE_HEAD_H_
|
||||
|
||||
#include "mmdeploy/codebase/mmdet/mmdet.h"
|
||||
#include "mmdeploy/core/tensor.h"
|
||||
|
||||
namespace mmdeploy::mmdet {
|
||||
|
||||
class BaseDenseHead : public MMDetection {
|
||||
public:
|
||||
explicit BaseDenseHead(const Value& cfg);
|
||||
|
||||
Result<Value> operator()(const Value& prep_res, const Value& infer_res);
|
||||
Result<Detections> GetBBoxes(const Value& prep_res, const Tensor& dets,
|
||||
const Tensor& scores) const;
|
||||
|
||||
private:
|
||||
float score_thr_{0.4f};
|
||||
int nms_pre_{1000};
|
||||
float iou_threshold_{0.45f};
|
||||
int min_bbox_size_{0};
|
||||
};
|
||||
} // namespace mmdeploy::mmdet
|
||||
|
||||
#endif // MMDEPLOY_CODEBASE_MMDET_BASE_DENSE_HEAD_H_
|
|
@ -0,0 +1,93 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
#include "utils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
|
||||
using mmdeploy::framework::Tensor;
|
||||
|
||||
namespace mmdeploy::mmdet {
|
||||
|
||||
// Maps a box from network-input coordinates back to the original image.
//
// Each coordinate is divided by its own entry of `scale_factor` (4 floats, in the
// order left, top, right, bottom), shifted by the offset, and then clipped:
// left/top to >= 0, right/bottom to <= (ori_width - 1) / (ori_height - 1).
//
// @return {left, top, right, bottom} in original-image coordinates
std::array<float, 4> MapToOriginImage(float left, float top, float right, float bottom,
                                      const float* scale_factor, float x_offset, float y_offset,
                                      int ori_width, int ori_height) {
  const float max_x = static_cast<float>(ori_width) - 1.f;
  const float max_y = static_cast<float>(ori_height) - 1.f;

  const float mapped_left = left / scale_factor[0] + x_offset;
  const float mapped_top = top / scale_factor[1] + y_offset;
  const float mapped_right = right / scale_factor[2] + x_offset;
  const float mapped_bottom = bottom / scale_factor[3] + y_offset;

  return {std::max(mapped_left, 0.f), std::max(mapped_top, 0.f), std::min(mapped_right, max_x),
          std::min(mapped_bottom, max_y)};
}
|
||||
|
||||
void FilterScoresAndTopk(const Tensor& scores, float score_thr, int topk, std::vector<float>& probs,
|
||||
std::vector<int>& label_ids, std::vector<int>& anchor_idxs) {
|
||||
auto kDets = scores.shape(1);
|
||||
auto kClasses = scores.shape(2);
|
||||
auto score_ptr = scores.data<float>();
|
||||
|
||||
for (auto i = 0; i < kDets; ++i, score_ptr += kClasses) {
|
||||
auto iter = std::max_element(score_ptr, score_ptr + kClasses);
|
||||
auto max_score = *iter;
|
||||
if (*iter < score_thr) {
|
||||
continue;
|
||||
}
|
||||
probs.push_back(*iter);
|
||||
label_ids.push_back(iter - score_ptr);
|
||||
anchor_idxs.push_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
// Intersection-over-union of two axis-aligned boxes given as (xmin, ymin, xmax, ymax).
//
// @return a value in [0, 1]; 0 when the boxes do not overlap or when the union
//         area is not positive (e.g. two degenerate boxes), so the result is
//         never NaN for finite inputs
float IOU(float xmin0, float ymin0, float xmax0, float ymax0, float xmin1, float ymin1, float xmax1,
          float ymax1) {
  const float inter_w = std::max(0.f, std::min(xmax0, xmax1) - std::max(xmin0, xmin1));
  const float inter_h = std::max(0.f, std::min(ymax0, ymax1) - std::max(ymin0, ymin1));
  const float inter_area = inter_w * inter_h;

  const float area_sum = (xmax0 - xmin0) * (ymax0 - ymin0) + (xmax1 - xmin1) * (ymax1 - ymin1);
  const float union_area = area_sum - inter_area;
  // Guard the division: a zero union would yield 0/0 = NaN. The negated
  // comparison also catches a NaN union.
  if (!(union_area > 0.f)) {
    return 0.f;
  }

  const float iou = inter_area / union_area;
  return iou <= 0.f ? 0.f : iou;
}
|
||||
|
||||
void NMS(const Tensor& dets, float iou_threshold, std::vector<int>& keep_idxs) {
|
||||
auto det_ptr = dets.data<float>();
|
||||
for (auto i = 0; i < keep_idxs.size(); ++i) {
|
||||
auto n = keep_idxs[i];
|
||||
for (auto j = i + 1; j < keep_idxs.size(); ++j) {
|
||||
auto m = keep_idxs[j];
|
||||
|
||||
// `delta_xywh_bbox_coder` decode return tl_x, tl_y, br_x, br_y
|
||||
float xmin0 = det_ptr[n * 4 + 0];
|
||||
float ymin0 = det_ptr[n * 4 + 1];
|
||||
float xmax0 = det_ptr[n * 4 + 2];
|
||||
float ymax0 = det_ptr[n * 4 + 3];
|
||||
|
||||
float xmin1 = det_ptr[m * 4 + 0];
|
||||
float ymin1 = det_ptr[m * 4 + 1];
|
||||
float xmax1 = det_ptr[m * 4 + 2];
|
||||
float ymax1 = det_ptr[m * 4 + 3];
|
||||
|
||||
float iou = IOU(xmin0, ymin0, xmax0, ymax0, xmin1, ymin1, xmax1, ymax1);
|
||||
|
||||
if (iou > iou_threshold) {
|
||||
keep_idxs[j] = -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reorders the three parallel vectors so that `probs` is in descending order;
// `label_ids` and `anchor_idxs` follow the same permutation.
void Sort(std::vector<float>& probs, std::vector<int>& label_ids, std::vector<int>& anchor_idxs) {
  // Sort positions instead of values so one permutation can be applied to all vectors.
  std::vector<int> order(probs.size());
  std::iota(order.begin(), order.end(), 0);
  std::sort(order.begin(), order.end(),
            [&](int lhs, int rhs) { return probs[lhs] > probs[rhs]; });

  std::vector<float> sorted_probs;
  std::vector<int> sorted_labels;
  std::vector<int> sorted_anchors;
  for (std::size_t k = 0; k < order.size(); ++k) {
    const int src = order[k];
    sorted_probs.push_back(probs[src]);
    sorted_labels.push_back(label_ids[src]);
    sorted_anchors.push_back(anchor_idxs[src]);
  }

  probs.swap(sorted_probs);
  label_ids.swap(sorted_labels);
  anchor_idxs.swap(sorted_anchors);
}
|
||||
|
||||
} // namespace mmdeploy::mmdet
|
|
@ -0,0 +1,32 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
#ifndef MMDEPLOY_CODEBASE_MMDET_UTILS_H_
|
||||
#define MMDEPLOY_CODEBASE_MMDET_UTILS_H_
|
||||
|
||||
#include <array>
|
||||
#include <vector>
|
||||
|
||||
#include "mmdeploy/core/tensor.h"
|
||||
|
||||
namespace mmdeploy::mmdet {
|
||||
std::array<float, 4> MapToOriginImage(float left, float top, float right, float bottom,
|
||||
const float* scale_factor, float x_offset, float y_offset,
|
||||
int ori_width, int ori_height);
|
||||
// @brief Filter results using score threshold and topk candidates.
|
||||
// scores (Tensor): The scores; dim 1 is read as num_bboxes and dim 2 as K (number of classes).
|
||||
// probs: The scores after being filtered
|
||||
// label_ids: The class labels
|
||||
// anchor_idxs: The anchor indexes
|
||||
void FilterScoresAndTopk(const mmdeploy::framework::Tensor& scores, float score_thr, int topk,
|
||||
std::vector<float>& probs, std::vector<int>& label_ids,
|
||||
std::vector<int>& anchor_idxs);
|
||||
float IOU(float xmin0, float ymin0, float xmax0, float ymax0, float xmin1, float ymin1, float xmax1,
|
||||
float ymax1);
|
||||
|
||||
void Sort(std::vector<float>& probs, std::vector<int>& label_ids, std::vector<int>& anchor_idxs);
|
||||
|
||||
void NMS(const mmdeploy::framework::Tensor& dets, float iou_threshold, std::vector<int>& keep_idxs);
|
||||
|
||||
} // namespace mmdeploy::mmdet
|
||||
|
||||
#endif // MMDEPLOY_CODEBASE_MMDET_UTILS_H_
|
|
@ -28,6 +28,7 @@ class DirectoryModelImpl : public ModelImpl {
|
|||
auto _path = root_ / fs::path(file_path);
|
||||
std::ifstream ifs(_path, std::ios::binary | std::ios::in);
|
||||
if (!ifs.is_open()) {
|
||||
MMDEPLOY_ERROR("read file {} failed", _path.string());
|
||||
return Status(eFail);
|
||||
}
|
||||
ifs.seekg(0, std::ios::end);
|
||||
|
|
|
@ -2,19 +2,36 @@
|
|||
|
||||
project(mmdeploy_rknn_net)
|
||||
|
||||
mmdeploy_add_module(${PROJECT_NAME} rknn_net.cpp)
|
||||
|
||||
add_library(rknn SHARED IMPORTED)
|
||||
|
||||
if(DEFINED ENV{RKNPU2_DEVICE_DIR})
|
||||
if (DEFINED ENV{RKNPU2_DEVICE_DIR})
|
||||
file(TO_CMAKE_PATH $ENV{RKNPU2_DEVICE_DIR} RKNPU2_DEVICE_DIR)
|
||||
else()
|
||||
message(FATAL_ERROR "RKNPU2_DEVICE_DIR env must be defined")
|
||||
endif()
|
||||
endif ()
|
||||
if (DEFINED RKNPU2_DEVICE_DIR)
|
||||
set_target_properties(rknn PROPERTIES
|
||||
IMPORTED_LOCATION "${RKNPU2_DEVICE_DIR}/Linux/librknn_api/aarch64/librknn_api.so"
|
||||
INTERFACE_INCLUDE_DIRECTORIES "${RKNPU2_DEVICE_DIR}/Linux/librknn_api/include"
|
||||
)
|
||||
target_compile_definitions(${PROJECT_NAME} PRIVATE RK_MODELS)
|
||||
target_link_libraries(${PROJECT_NAME} PRIVATE rknn)
|
||||
add_library(mmdeploy::rknn_net ALIAS ${PROJECT_NAME})
|
||||
endif ()
|
||||
|
||||
set_target_properties(rknn PROPERTIES
|
||||
IMPORTED_LOCATION "${RKNPU2_DEVICE_DIR}/Linux/librknn_api/aarch64/librknn_api.so"
|
||||
INTERFACE_INCLUDE_DIRECTORIES "${RKNPU2_DEVICE_DIR}/Linux/librknn_api/include"
|
||||
)
|
||||
if (DEFINED ENV{RKNPU_DEVICE_DIR})
|
||||
file(TO_CMAKE_PATH $ENV{RKNPU_DEVICE_DIR} RKNPU_DEVICE_DIR)
|
||||
endif ()
|
||||
if (DEFINED RKNPU_DEVICE_DIR)
|
||||
set_target_properties(rknn PROPERTIES IMPORTED_CONFIGURATIONS RELEASE
|
||||
IMPORTED_LOCATION_RELEASE "${RKNPU_DEVICE_DIR}/lib/librknn_api.so"
|
||||
INTERFACE_INCLUDE_DIRECTORIES "${RKNPU_DEVICE_DIR}/include"
|
||||
)
|
||||
target_compile_definitions(${PROJECT_NAME} PRIVATE RV_MODELS)
|
||||
target_link_libraries(${PROJECT_NAME} PRIVATE rknn)
|
||||
add_library(mmdeploy::rknn_net ALIAS ${PROJECT_NAME})
|
||||
endif ()
|
||||
|
||||
mmdeploy_add_module(${PROJECT_NAME} rknn_net.cpp)
|
||||
target_link_libraries(${PROJECT_NAME} PRIVATE rknn)
|
||||
add_library(mmdeploy::rknn_net ALIAS ${PROJECT_NAME})
|
||||
if (NOT (DEFINED RKNPU2_DEVICE_DIR OR RKNPU_DEVICE_DIR))
|
||||
message(FATAL_ERROR "RKNPU2_DEVICE_DIR or RKNPU_DEVICE_DIR must be defined")
|
||||
endif ()
|
||||
|
|
|
@ -12,7 +12,54 @@
|
|||
|
||||
namespace mmdeploy::framework {
|
||||
|
||||
Result<rknn_tensor_type> GetRKNNDataType(DataType data_type) {
|
||||
// Returns a short printable name for an RKNN tensor data type ("UNKNOWN" for
// values not listed here). The 32/64-bit integer types are only handled when
// built with RK_MODELS.
static inline const char* const rknn_type(rknn_tensor_type type) {
  if (type == RKNN_TENSOR_FLOAT32) return "FP32";
  if (type == RKNN_TENSOR_FLOAT16) return "FP16";
  if (type == RKNN_TENSOR_INT8) return "INT8";
  if (type == RKNN_TENSOR_UINT8) return "UINT8";
  if (type == RKNN_TENSOR_INT16) return "INT16";
#ifdef RK_MODELS
  if (type == RKNN_TENSOR_INT32) return "INT32";
  if (type == RKNN_TENSOR_INT64) return "INT64";
#endif
  return "UNKNOWN";
}
|
||||
|
||||
// Returns a printable name for an RKNN tensor layout ("UNKNOWN" for any value
// other than NCHW / NHWC).
static inline const char* const rknn_format(rknn_tensor_format fmt) {
  if (fmt == RKNN_TENSOR_NCHW) return "NCHW";
  if (fmt == RKNN_TENSOR_NHWC) return "NHWC";
  return "UNKNOWN";
}
|
||||
|
||||
// Returns a printable name for an RKNN quantization type ("UNKNOWN" for values
// not listed here).
static inline const char* const rknn_qnt_type(rknn_tensor_qnt_type type) {
  if (type == RKNN_TENSOR_QNT_NONE) return "NONE";
  if (type == RKNN_TENSOR_QNT_DFP) return "DFP";
  if (type == RKNN_TENSOR_QNT_AFFINE_ASYMMETRIC) return "AFFINE";
  return "UNKNOWN";
}
|
||||
|
||||
static Result<rknn_tensor_type> GetRKNNDataType(DataType data_type) {
|
||||
switch (data_type) {
|
||||
case DataType::kFLOAT:
|
||||
return RKNN_TENSOR_FLOAT32;
|
||||
|
@ -20,42 +67,50 @@ Result<rknn_tensor_type> GetRKNNDataType(DataType data_type) {
|
|||
return RKNN_TENSOR_FLOAT16;
|
||||
case DataType::kINT8:
|
||||
return RKNN_TENSOR_INT8;
|
||||
#ifdef RK_MODELS
|
||||
case DataType::kINT32:
|
||||
return RKNN_TENSOR_INT32;
|
||||
case DataType::kINT64:
|
||||
return RKNN_TENSOR_INT64;
|
||||
#endif
|
||||
default:
|
||||
return Status(eNotSupported);
|
||||
}
|
||||
}
|
||||
|
||||
Result<DataType> GetMMDeployDataType(rknn_tensor_type data_type) {
|
||||
switch (data_type) {
|
||||
static Result<DataType> GetMMDeployDataType(rknn_tensor_type type) {
|
||||
switch (type) {
|
||||
case RKNN_TENSOR_FLOAT32:
|
||||
return DataType::kFLOAT;
|
||||
case RKNN_TENSOR_FLOAT16:
|
||||
return DataType::kHALF;
|
||||
case RKNN_TENSOR_INT8:
|
||||
case RKNN_TENSOR_INT8: // fall through
|
||||
case RKNN_TENSOR_UINT8:
|
||||
return DataType::kINT8;
|
||||
#ifdef RK_MODELS
|
||||
case RKNN_TENSOR_INT32:
|
||||
return DataType::kINT32;
|
||||
case RKNN_TENSOR_INT64:
|
||||
return DataType::kINT64;
|
||||
#endif
|
||||
default:
|
||||
MMDEPLOY_ERROR("unsupported rknn_tensor_type: {}", rknn_type(type));
|
||||
return Status(eNotSupported);
|
||||
}
|
||||
}
|
||||
|
||||
RKNNNet::~RKNNNet() { rknn_destroy(ctx_); }
|
||||
|
||||
void RKNNNet::dump_tensor_attr(rknn_tensor_attr* attr) {
|
||||
MMDEPLOY_INFO(
|
||||
" index={}, name={}, n_dims={}, dims=[{}, {}, {}, {}], n_elems={}, size={}, fmt={}, "
|
||||
"type={}, qnt_type={}, "
|
||||
"zp={}, scale=%f\n",
|
||||
attr->index, attr->name, attr->n_dims, attr->dims[0], attr->dims[1], attr->dims[2],
|
||||
attr->dims[3], attr->n_elems, attr->size, get_format_string(attr->fmt),
|
||||
get_type_string(attr->type), get_qnt_type_string(attr->qnt_type), attr->zp, attr->scale);
|
||||
void RKNNNet::PrintRKNNTensorAttr(const char* tag, const std::vector<rknn_tensor_attr>& attrs) {
|
||||
MMDEPLOY_INFO("{} tensors: ", tag);
|
||||
for (auto& attr : attrs) {
|
||||
MMDEPLOY_INFO(
|
||||
" - index={}, name={}, type={}, n_dims={}, dims=[{}, {}, {}, {}], n_elems={}, size={},"
|
||||
" fmt={}, qnt_type={}, zp={}, scale={}",
|
||||
attr.index, attr.name, rknn_type(attr.type), attr.n_dims, attr.dims[0], attr.dims[1],
|
||||
attr.dims[2], attr.dims[3], attr.n_elems, attr.size, rknn_format(attr.fmt),
|
||||
rknn_qnt_type(attr.qnt_type), attr.zp, attr.scale);
|
||||
}
|
||||
}
|
||||
|
||||
Result<void> RKNNNet::Init(const Value& args) {
|
||||
|
@ -73,50 +128,74 @@ Result<void> RKNNNet::Init(const Value& args) {
|
|||
std::string content;
|
||||
OUTCOME_TRY(content, model.ReadFile(config.net));
|
||||
char* model_ptr = const_cast<char*>(content.data());
|
||||
#ifdef RK_MODELS
|
||||
int ret = rknn_init(&ctx_, model_ptr, content.size(), 0, NULL);
|
||||
#endif
|
||||
#ifdef RV_MODELS
|
||||
int ret = rknn_init(&ctx_, model_ptr, content.size(), 0);
|
||||
#endif
|
||||
if (ret != RKNN_SUCC) {
|
||||
MMDEPLOY_ERROR("Load .rknn failed! ret= {}", ret);
|
||||
return Status(eInvalidArgument);
|
||||
MMDEPLOY_ERROR("init rknn model with {} failed! ret: {}", config.net, ret);
|
||||
return Status(eFail);
|
||||
}
|
||||
|
||||
// Get Model Input Output Info
|
||||
rknn_input_output_num io_num;
|
||||
ret = rknn_query(ctx_, RKNN_QUERY_IN_OUT_NUM, &io_num, sizeof(io_num));
|
||||
if (ret != RKNN_SUCC) {
|
||||
MMDEPLOY_INFO("model input num: {}, output num: {}\n", io_num.n_input, io_num.n_output);
|
||||
MMDEPLOY_ERROR("rknn_query fail! ret= {}", ret);
|
||||
MMDEPLOY_ERROR("rknn query 'RKNN_QUERY_IN_OUT_NUM' fail! ret: {}", ret);
|
||||
return Status(eFail);
|
||||
}
|
||||
MMDEPLOY_DEBUG("model input num: {}, output num: {}", io_num.n_input, io_num.n_output);
|
||||
|
||||
auto get_tensor_shape = [](rknn_tensor_attr& attr) -> Result<TensorShape> {
|
||||
TensorShape shape;
|
||||
for (int i = 0; i < attr.n_dims; ++i) {
|
||||
shape.push_back(attr.dims[i]);
|
||||
}
|
||||
#ifdef RK_MODELS
|
||||
return shape;
|
||||
#endif
|
||||
#ifdef RV_MODELS
|
||||
std::reverse(shape.begin(), shape.end());
|
||||
return shape;
|
||||
#endif
|
||||
};
|
||||
|
||||
for (int i = 0; i < io_num.n_input; i++) {
|
||||
rknn_tensor_attr input_attr;
|
||||
input_attr.index = i;
|
||||
ret = rknn_query(ctx_, RKNN_QUERY_INPUT_ATTR, &(input_attr), sizeof(rknn_tensor_attr));
|
||||
rknn_tensor_attr attr;
|
||||
attr.index = i;
|
||||
ret = rknn_query(ctx_, RKNN_QUERY_INPUT_ATTR, &(attr), sizeof(rknn_tensor_attr));
|
||||
if (ret != RKNN_SUCC) {
|
||||
MMDEPLOY_INFO("input tensors:\n");
|
||||
dump_tensor_attr(&(input_attr));
|
||||
MMDEPLOY_ERROR("rknn_query fail! ret= {}", ret);
|
||||
MMDEPLOY_ERROR("rknn query 'RKNN_QUERY_INPUT_ATTR' fail! ret: {}", ret);
|
||||
return Status(eFail);
|
||||
}
|
||||
input_attrs_.push_back(input_attr);
|
||||
OUTCOME_TRY(auto data_type, GetMMDeployDataType(input_attr.type));
|
||||
input_tensors_.emplace_back(TensorDesc{device_, data_type, {}, input_attr.name});
|
||||
if (attr.type != RKNN_TENSOR_UINT8) {
|
||||
MMDEPLOY_ERROR("MMDeploy SDK only supports RKNN-INT8 model");
|
||||
return Status(eInvalidArgument);
|
||||
}
|
||||
input_attrs_.push_back(attr);
|
||||
// Only support uint8 input data
|
||||
OUTCOME_TRY(auto data_type, GetMMDeployDataType(RKNN_TENSOR_UINT8));
|
||||
input_tensors_.emplace_back(
|
||||
TensorDesc{device_, data_type, get_tensor_shape(attr).value(), "#" + std::to_string(i)});
|
||||
}
|
||||
PrintRKNNTensorAttr("input", input_attrs_);
|
||||
|
||||
for (int i = 0; i < io_num.n_output; i++) {
|
||||
rknn_tensor_attr output_attr;
|
||||
output_attr.index = i;
|
||||
ret = rknn_query(ctx_, RKNN_QUERY_OUTPUT_ATTR, &(output_attr), sizeof(rknn_tensor_attr));
|
||||
rknn_tensor_attr attr;
|
||||
attr.index = i;
|
||||
ret = rknn_query(ctx_, RKNN_QUERY_OUTPUT_ATTR, &(attr), sizeof(rknn_tensor_attr));
|
||||
if (ret != RKNN_SUCC) {
|
||||
MMDEPLOY_INFO("output tensors:\n");
|
||||
dump_tensor_attr(&(output_attr));
|
||||
MMDEPLOY_ERROR("rknn_query fail! ret= {}", ret);
|
||||
MMDEPLOY_ERROR("rknn query 'RKNN_QUERY_OUTPUT_ATTR' fail! ret: {}", ret);
|
||||
return Status(eFail);
|
||||
}
|
||||
output_attrs_.push_back(output_attr);
|
||||
OUTCOME_TRY(auto data_type, GetMMDeployDataType(output_attr.type));
|
||||
output_tensors_.emplace_back(TensorDesc{device_, data_type, {}, output_attr.name});
|
||||
output_attrs_.push_back(attr);
|
||||
// MMDeploy SDK always make the output data type as float
|
||||
output_tensors_.emplace_back(TensorDesc{
|
||||
device_, DataType::kFLOAT, get_tensor_shape(attr).value(), "#" + std::to_string(i)});
|
||||
}
|
||||
PrintRKNNTensorAttr("output", output_attrs_);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
@ -143,10 +222,11 @@ Result<void> RKNNNet::Forward() {
|
|||
for (int i = 0; i < input_tensors_.size(); i++) {
|
||||
rknn_input input;
|
||||
input.index = i;
|
||||
// '0' let the buf data be converted into an input consistent with the model
|
||||
input.pass_through = 0;
|
||||
input.type = input_attrs_[i].type;
|
||||
input.fmt = input_attrs_[i].fmt;
|
||||
input.buf = input_tensors_[i].data<float>();
|
||||
input.type = RKNN_TENSOR_UINT8; // data type of input buf
|
||||
input.fmt = RKNN_TENSOR_NHWC; // data format of input buf
|
||||
input.buf = input_tensors_[i].data();
|
||||
input.size = input_attrs_[i].size;
|
||||
inputs.push_back(input);
|
||||
}
|
||||
|
@ -158,35 +238,28 @@ Result<void> RKNNNet::Forward() {
|
|||
return Status(eFail);
|
||||
}
|
||||
|
||||
// Get output
|
||||
std::vector<rknn_output> outputs;
|
||||
for (uint32_t i = 0; i < output_tensors_.size(); ++i) {
|
||||
rknn_output output;
|
||||
output.want_float = 1;
|
||||
output.index = i;
|
||||
output.is_prealloc = 0;
|
||||
outputs.push_back(output);
|
||||
}
|
||||
|
||||
// Forward
|
||||
ret = rknn_run(ctx_, NULL);
|
||||
if (ret < 0) {
|
||||
MMDEPLOY_ERROR("rknn_run fail! ret={}", ret);
|
||||
return Status(eFail);
|
||||
}
|
||||
|
||||
ret = rknn_outputs_get(ctx_, output_tensors_.size(), outputs.data(), NULL);
|
||||
// Get output
|
||||
std::vector<rknn_output> outputs(output_tensors_.size());
|
||||
for (uint32_t i = 0; i < output_tensors_.size(); ++i) {
|
||||
outputs[i].want_float = 1;
|
||||
outputs[i].is_prealloc = 1; // use pre-allocated buffer in `output_tensors_`
|
||||
outputs[i].index = 1;
|
||||
outputs[i].buf = output_tensors_[i].data();
|
||||
outputs[i].size = output_tensors_[i].byte_size();
|
||||
}
|
||||
ret = rknn_outputs_get(ctx_, outputs.size(), outputs.data(), NULL);
|
||||
if (ret < 0) {
|
||||
MMDEPLOY_ERROR("rknn_outputs_get fail! ret= {}", ret);
|
||||
return Status(eFail);
|
||||
}
|
||||
for (int i = 0; i < output_tensors_.size(); i++) {
|
||||
TensorShape tensor_shape;
|
||||
for (int j = 0; j < output_attrs_[i].n_dims; ++j) {
|
||||
tensor_shape.push_back(output_attrs_[i].dims[j]);
|
||||
}
|
||||
output_tensors_[i].Reshape(tensor_shape);
|
||||
memcpy(output_tensors_[i].data<float>(), (float*)outputs[i].buf, output_attrs_[i].size);
|
||||
}
|
||||
|
||||
OUTCOME_TRY(stream_.Wait());
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -28,7 +28,7 @@ class RKNNNet : public Net {
|
|||
Result<void> ForwardAsync(Event* event) override;
|
||||
|
||||
private:
|
||||
void dump_tensor_attr(rknn_tensor_attr* attr);
|
||||
void PrintRKNNTensorAttr(const char* tag, const std::vector<rknn_tensor_attr>& attrs);
|
||||
|
||||
Device device_;
|
||||
Stream stream_;
|
||||
|
|
|
@ -25,6 +25,14 @@ class NormalizeImpl : public ::mmdeploy::NormalizeImpl {
|
|||
auto dst_mat = Normalize(mat, arg_.mean, arg_.std, arg_.to_rgb, true);
|
||||
return CVMat2Tensor(dst_mat);
|
||||
}
|
||||
|
||||
// Swaps the channel order of an image tensor from BGR to RGB with OpenCV,
// leaving the element type untouched.
//
// @param tensor input image tensor
// @return a new tensor holding the converted image, or the error from making
//         the input available on this implementation's device
Result<Tensor> ConvertToRGB(const Tensor& tensor) override {
  OUTCOME_TRY(auto src_tensor, MakeAvailableOnDevice(tensor, device_, stream_));

  // NOTE(review): this is an unnamed temporary, so it is destroyed at the end
  // of this statement rather than at scope exit - confirm that is intended.
  SyncOnScopeExit(stream_, src_tensor.buffer() != tensor.buffer(), src_tensor);

  // Convert the tensor that was made available on `device_`; `tensor` itself
  // may live on a different device.
  auto src_mat = Tensor2CVMat(src_tensor);
  auto dst_mat = ColorTransfer(src_mat, PixelFormat::kBGR, PixelFormat::kRGB);
  return CVMat2Tensor(dst_mat);
}
|
||||
};
|
||||
|
||||
class NormalizeImplCreator : public Creator<::mmdeploy::NormalizeImpl> {
|
||||
|
|
|
@ -5,8 +5,10 @@
|
|||
#include "mmdeploy/core/utils/device_utils.h"
|
||||
#include "mmdeploy/core/utils/formatter.h"
|
||||
#include "mmdeploy/preprocess/transform/normalize.h"
|
||||
#include "ppl/cv/cuda/cvtcolor.h"
|
||||
|
||||
using namespace std;
|
||||
using namespace ppl::cv::cuda;
|
||||
|
||||
namespace mmdeploy::cuda {
|
||||
|
||||
|
@ -66,6 +68,25 @@ class NormalizeImpl : public ::mmdeploy::NormalizeImpl {
|
|||
}
|
||||
return dst_tensor;
|
||||
}
|
||||
|
||||
// Swaps the first and third channel of a uint8 image tensor on the GPU.
//
// @param tensor input image tensor, read as (n, h, w, c)
// @return a new kINT8 tensor of the same shape with the channel order swapped,
//         or eFail when the conversion kernel reports an error
//
// NOTE(review): the kernel is a 3-channel conversion; `c` is not validated here.
Result<Tensor> ConvertToRGB(const Tensor& tensor) override {
  OUTCOME_TRY(auto src_tensor, MakeAvailableOnDevice(tensor, device_, stream_));

  SyncOnScopeExit sync(stream_, src_tensor.buffer() != tensor.buffer(), src_tensor);

  auto src_desc = src_tensor.desc();
  int h = static_cast<int>(src_desc.shape[1]);
  int w = static_cast<int>(src_desc.shape[2]);
  int c = static_cast<int>(src_desc.shape[3]);
  int stride = w * c;
  auto stream = ::mmdeploy::GetNative<cudaStream_t>(stream_);

  TensorDesc dst_desc{device_, DataType::kINT8, src_desc.shape, src_desc.name};
  Tensor dst_tensor{dst_desc};
  // Read from the tensor that was made available on `device_`; `tensor` itself
  // may live on a different device. Swapping channels 0 and 2 is the same
  // operation in both directions, hence RGB2BGR.
  auto ret = RGB2BGR<uint8_t>(stream, h, w, stride, src_tensor.data<uint8_t>(), stride,
                              dst_tensor.data<uint8_t>());
  if (ret != 0) {
    MMDEPLOY_ERROR("ppl.cv RGB2BGR failed, ret: {}", ret);
    return Status(eFail);
  }
  return dst_tensor;
}
|
||||
};
|
||||
|
||||
class NormalizeImplCreator : public Creator<::mmdeploy::NormalizeImpl> {
|
||||
|
|
|
@ -2,9 +2,9 @@
|
|||
|
||||
#include "normalize.h"
|
||||
|
||||
#include "mmdeploy/archive/json_archive.h"
|
||||
#include "mmdeploy/core/registry.h"
|
||||
#include "mmdeploy/core/tensor.h"
|
||||
#include "mmdeploy/core/utils/formatter.h"
|
||||
#include "mmdeploy/preprocess/transform/tracer.h"
|
||||
|
||||
using namespace std;
|
||||
|
@ -25,6 +25,16 @@ NormalizeImpl::NormalizeImpl(const Value& args) : TransformImpl(args) {
|
|||
arg_.std.push_back(v.get<float>());
|
||||
}
|
||||
arg_.to_rgb = args.value("to_rgb", true);
|
||||
arg_.to_float = args.value("to_float", true);
|
||||
// assert `mean` is 0 and `std` is 1 when `to_float` is false
|
||||
if (!arg_.to_float) {
|
||||
for (int i = 0; i < arg_.mean.size(); ++i) {
|
||||
if ((int)arg_.mean[i] != 0 || (int)arg_.std[i] != 1) {
|
||||
MMDEPLOY_ERROR("mean {} and std {} are not supported in int8 case", arg_.mean, arg_.std);
|
||||
throw_exception(eInvalidArgument);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -66,8 +76,15 @@ Result<Value> NormalizeImpl::Process(const Value& input) {
|
|||
assert(desc.shape.size() == 4 /*n, h, w, c*/);
|
||||
assert(desc.shape[3] == arg_.mean.size());
|
||||
|
||||
OUTCOME_TRY(auto dst, NormalizeImage(tensor));
|
||||
SetTransformData(output, key, std::move(dst));
|
||||
if (arg_.to_float) {
|
||||
OUTCOME_TRY(auto dst, NormalizeImage(tensor));
|
||||
SetTransformData(output, key, std::move(dst));
|
||||
} else {
|
||||
if (arg_.to_rgb) {
|
||||
OUTCOME_TRY(auto dst, ConvertToRGB(tensor));
|
||||
SetTransformData(output, key, std::move(dst));
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& v : arg_.mean) {
|
||||
output["img_norm_cfg"]["mean"].push_back(v);
|
||||
|
|
|
@ -17,12 +17,14 @@ class MMDEPLOY_API NormalizeImpl : public TransformImpl {
|
|||
|
||||
protected:
|
||||
virtual Result<Tensor> NormalizeImage(const Tensor& img) = 0;
|
||||
virtual Result<Tensor> ConvertToRGB(const Tensor& img) = 0;
|
||||
|
||||
protected:
|
||||
struct normalize_arg_t {
|
||||
std::vector<float> mean;
|
||||
std::vector<float> std;
|
||||
bool to_rgb;
|
||||
bool to_float;
|
||||
};
|
||||
using ArgType = struct normalize_arg_t;
|
||||
ArgType arg_;
|
||||
|
|
|
@ -1,32 +1,69 @@
|
|||
# 支持 RKNN
|
||||
# 瑞芯微 NPU 部署
|
||||
|
||||
本教程基于 Ubuntu-18.04 和 Rockchip `rk3588` NPU。对于不同的 NPU 设备,您需要使用不同的 rknn 包.
|
||||
这是设备和安装包的关系表:
|
||||
- [模型转换](#模型转换)
|
||||
- [安装环境](#安装环境)
|
||||
- [分类模型转换](#分类模型转换)
|
||||
- [检测模型转换](#检测模型转换)
|
||||
- [模型推理](#模型推理)
|
||||
- [Host 交叉编译](#Host-交叉编译)
|
||||
- [Device 执行推理](#Device-执行推理)
|
||||
|
||||
| Device | Python Package | c/c++ SDK |
|
||||
| -------------------- | ---------------------------------------------------------------- | -------------------------------------------------- |
|
||||
| RK1808/RK1806 | [rknn-toolkit](https://github.com/rockchip-linux/rknn-toolkit) | [rknpu](https://github.com/rockchip-linux/rknpu) |
|
||||
| RV1109/RV1126 | [rknn-toolkit](https://github.com/rockchip-linux/rknn-toolkit) | [rknpu](https://github.com/rockchip-linux/rknpu) |
|
||||
| RK3566/RK3568/RK3588 | [rknn-toolkit2](https://github.com/rockchip-linux/rknn-toolkit2) | [rknpu2](https://github.com/rockchip-linux/rknpu2) |
|
||||
| RV1103/RV1106 | [rknn-toolkit2](https://github.com/rockchip-linux/rknn-toolkit2) | [rknpu2](https://github.com/rockchip-linux/rknpu2) |
|
||||
______________________________________________________________________
|
||||
|
||||
## 安装
|
||||
MMDeploy 支持把模型部署到瑞芯微设备上。已支持的芯片:RV1126、RK3588。
|
||||
|
||||
建议为项目创建一个虚拟环境。
|
||||
完整的部署过程包含两个步骤:
|
||||
|
||||
1. 使用 git 获取 RKNN-Toolkit2 或者 RKNN-Toolkit。以 RKNN-Toolkit2 为例:
|
||||
1. 模型转换
|
||||
|
||||
```
|
||||
git clone git@github.com:rockchip-linux/rknn-toolkit2.git
|
||||
```
|
||||
- 在主机上,将 PyTorch 模型转换为 RKNN 模型
|
||||
|
||||
2. 通过 [rknn-toolkit2 文档](https://github.com/rockchip-linux/rknn-toolkit2/tree/master/doc) 或者 [rknn-toolkit 文档](https://github.com/rockchip-linux/rknn-toolkit/tree/master/doc)安装 RKNN python 安装包。安装 rknn python 包时,最好在安装命令后添加`--no-deps`,以避免依赖包的冲突。以rknn-toolkit2为例:
|
||||
2. 模型推理
|
||||
|
||||
- 在主机上, 使用交叉编译工具得到设备所需的 SDK 和 bin
|
||||
- 把转好的模型和编好的 SDK、bin,传到设备,进行推理
|
||||
|
||||
## 模型转换
|
||||
|
||||
### 安装环境
|
||||
|
||||
1. 请参考[快速入门](../get_started.md),创建 conda 虚拟环境,并安装 PyTorch、mmcv-full
|
||||
|
||||
2. 安装 RKNN Toolkit
|
||||
|
||||
如下表所示,瑞芯微提供了 2 套 RKNN Toolkit,对应于不同的芯片型号
|
||||
|
||||
<table>
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Device</th>
|
||||
<th>RKNN-Toolkit</th>
|
||||
<th>Installation Guide</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr>
|
||||
<td>RK1808 / RK1806 / RV1109 / RV1126</td>
|
||||
<td><code>git clone https://github.com/rockchip-linux/rknn-toolkit</code></td>
|
||||
<td><a href="https://github.com/rockchip-linux/rknn-toolkit/tree/master/doc">安装指南</a></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>RK3566 / RK3568 / RK3588 / RV1103 / RV1106</td>
|
||||
<td><code>git clone https://github.com/rockchip-linux/rknn-toolkit2</code></td>
|
||||
<td><a href="https://github.com/rockchip-linux/rknn-toolkit2/tree/master/doc">安装指南</a></td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
2.1 通过 `git clone` 下载和设备匹配的 RKNN Toolkit
|
||||
|
||||
2.2 参考表中的安装指南,安装 RKNN python 安装包。建议在安装时,使用选项 `--no-deps`,以避免依赖包的冲突。以 rknn-toolkit2 为例:
|
||||
|
||||
```
|
||||
pip install packages/rknn_toolkit2-1.2.0_f7bb160f-cp36-cp36m-linux_x86_64.whl --no-deps
|
||||
```
|
||||
|
||||
3. 先安装onnx==1.8.0,跟着 [instructions](../01-how-to-build/build_from_source.md),源码安装 MMDeploy。 需要注意的是, MMDeploy 和 RKNN 依赖的安装包间有冲突的内容. 这里提供建议在 python 3.6 环境中使用的安装包版本:
|
||||
2.3 先安装onnx==1.8.0,跟着 [instructions](../01-how-to-build/build_from_source.md),源码安装 MMDeploy。 需要注意的是, MMDeploy 和 RKNN 依赖的安装包间有冲突的内容. 这里提供建议在 python 3.6 环境中使用的安装包版本:
|
||||
|
||||
```
|
||||
protobuf==3.19.4
|
||||
|
@ -36,29 +73,81 @@
|
|||
torchvision==0.9.0
|
||||
```
|
||||
|
||||
4. 使用 conda 安装 torch and torchvision,比如:
|
||||
### 分类模型转换
|
||||
|
||||
```
|
||||
conda install pytorch==1.8.0 torchvision==0.9.0 cudatoolkit=11.1 -c pytorch -c conda-forge
|
||||
```
|
||||
以 mmclassification 中的 resnet50 为例,模型转换命令如下:
|
||||
|
||||
如要使用 [MMClassification](https://mmclassification.readthedocs.io/en/latest/getting_started.html), 需要用户自己安装使用。
|
||||
```shell
|
||||
# 安装 mmclassification
|
||||
pip install mmcls
|
||||
git clone https://github.com/open-mmlab/mmclassification
|
||||
|
||||
## 使用
|
||||
|
||||
例子:
|
||||
|
||||
```bash
|
||||
# 执行转换命令
|
||||
cd /the/path/of/mmdeploy
|
||||
python tools/deploy.py \
|
||||
configs/mmcls/classification_rknn_static.py \
|
||||
/mmclassification_dir/configs/resnet/resnet50_8xb32_in1k.py \
|
||||
/the/path/of/mmclassification/configs/resnet/resnet50_8xb32_in1k.py \
|
||||
https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_batch256_imagenet_20200708-cfb998bf.pth \
|
||||
/mmclassification_dir/demo/demo.JPEG \
|
||||
--work-dir ../resnet50 \
|
||||
--device cpu
|
||||
/the/path/of/mmclassification/demo/demo.JPEG \
|
||||
--work-dir mmdeploy_models/mmcls/resnet50 \
|
||||
--device cpu \
|
||||
--dump-info
|
||||
```
|
||||
|
||||
## 部署 config
|
||||
```{note}
|
||||
若转换过程中遇到 ModuleNotFoundError,请使用 pip install 安装对应的包
|
||||
```
|
||||
|
||||
### 检测模型转换
|
||||
|
||||
- YOLOV3 & YOLOX
|
||||
|
||||
将下面的模型拆分配置写入到 [detection_rknn_static.py](https://github.com/open-mmlab/mmdeploy/blob/master/configs/mmdet/detection/detection_rknn_static.py)
|
||||
|
||||
```python
|
||||
# yolov3, yolox
|
||||
partition_config = dict(
|
||||
type='rknn', # the partition policy name
|
||||
apply_marks=True, # should always be set to True
|
||||
partition_cfg=[
|
||||
dict(
|
||||
save_file='model.onnx', # name to save the partitioned onnx
|
||||
start=['detector_forward:input'], # [mark_name:input, ...]
|
||||
end=['yolo_head:input']) # [mark_name:output, ...]
|
||||
])
|
||||
```
|
||||
|
||||
执行命令:
|
||||
|
||||
```shell
|
||||
# 安装 mmdet
|
||||
pip install mmdet
|
||||
git clone https://github.com/open-mmlab/mmdetection
|
||||
|
||||
# 执行转换命令
|
||||
python tools/deploy.py \
|
||||
configs/mmdet/detection/detection_rknn_static.py \
|
||||
|
||||
```
|
||||
|
||||
- RetinaNet & SSD & FSAF with rknn-toolkit2
|
||||
|
||||
将下面的模型拆分配置写入到 [detection_rknn_static.py](https://github.com/open-mmlab/mmdeploy/blob/master/configs/mmdet/detection/detection_rknn_static.py)。使用 rknn-toolkit 的用户则不用。
|
||||
|
||||
```python
|
||||
# retinanet, ssd
|
||||
partition_config = dict(
|
||||
type='rknn', # the partition policy name
|
||||
apply_marks=True,
|
||||
partition_cfg=[
|
||||
dict(
|
||||
save_file='model.onnx',
|
||||
start='detector_forward:input',
|
||||
end=['BaseDenseHead:output'])
|
||||
])
|
||||
```
|
||||
|
||||
### 部署 config 说明
|
||||
|
||||
部署 config,你可以根据需要修改 `backend_config` 字段. 一个 mmclassification 的 `backend_config`例子如下:
|
||||
|
||||
|
@ -77,66 +166,7 @@ backend_config = dict(
|
|||
|
||||
`common_config` 的内容服务于 `rknn.config()`. `quantization_config` 的内容服务于 `rknn.build()`。
|
||||
|
||||
## 安装 SDK
|
||||
|
||||
### RKNPU2 编译 MMDeploy SDK
|
||||
|
||||
1. 获取 rknpu2:
|
||||
|
||||
```
|
||||
git clone git@github.com:rockchip-linux/rknpu2.git
|
||||
```
|
||||
|
||||
2. 在 linux 系统, 下载 gcc 交叉编译器. `rknpu2` 的官方提供的下载链接无法使用了. 用户可以使用另一个 [链接](https://github.com/Caesar-github/gcc-buildroot-9.3.0-2020.03-x86_64_aarch64-rockchip-linux-gnu). 下载并解压完编译器, 打开终端, 设置 `RKNN_TOOL_CHAIN` 和 `RKNPU2_DEVICE_DIR` 为 `export RKNN_TOOL_CHAIN=/path/to/gcc/usr;export RKNPU2_DEVICE_DIR=/path/to/rknpu2/runtime/RK3588`。
|
||||
|
||||
3. 上述准备工作完成后, 运行如下指令安装:
|
||||
|
||||
```shell
|
||||
cd /path/to/mmdeploy
|
||||
mkdir -p build && rm -rf build/CM* && cd build
|
||||
export LD_LIBRARY_PATH=$RKNN_TOOL_CHAIN/lib64:$LD_LIBRARY_PATH
|
||||
cmake \
|
||||
-DCMAKE_TOOLCHAIN_FILE=/path/to/mmdeploy/cmake/toolchains/rknpu2-linux-gnu.cmake \
|
||||
-DMMDEPLOY_BUILD_SDK=ON \
|
||||
-DCMAKE_BUILD_TYPE=Debug \
|
||||
-DOpenCV_DIR=${RKNPU2_DEVICE_DIR}/../../examples/3rdparty/opencv/opencv-linux-aarch64/share/OpenCV \
|
||||
-DMMDEPLOY_BUILD_SDK_PYTHON_API=ON \
|
||||
-DMMDEPLOY_TARGET_DEVICES="cpu" \
|
||||
-DMMDEPLOY_TARGET_BACKENDS="rknn" \
|
||||
-DMMDEPLOY_CODEBASES=all \
|
||||
-DMMDEPLOY_BUILD_TEST=ON \
|
||||
-DMMDEPLOY_BUILD_EXAMPLES=ON \
|
||||
..
|
||||
make && make install
|
||||
```
|
||||
|
||||
## 运行 SDK 的 demo
|
||||
|
||||
首先,确保`--dump-info`在转模型的时候调用了, 这样工作目录下包含 SDK 需要的配置文件 `pipeline.json`。
|
||||
|
||||
使用 `adb push` 将模型路径,执行文件和.so 文件传到板子上。
|
||||
|
||||
```bash
|
||||
cd /path/to/mmdeploy
|
||||
adb push resnet50 /data/local/tmp/resnet50
|
||||
adb push /mmclassification_dir/demo/demo.JPEG /data/local/tmp/resnet50/demo.JPEG
|
||||
cd build
|
||||
adb push lib /data/local/tmp/lib
|
||||
adb push bin/image_classification /data/local/tmp/image_classification
|
||||
```
|
||||
|
||||
设置环境变量,运行例子。
|
||||
|
||||
```bash
|
||||
adb shell
|
||||
cd /data/local/tmp
|
||||
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/data/local/tmp/lib
|
||||
./image_classification cpu ./resnet50 ./resnet50/demo.JPEG
|
||||
..
|
||||
label: 65, score: 0.95
|
||||
```
|
||||
|
||||
## 问题点
|
||||
### 问题说明
|
||||
|
||||
- 量化失败.
|
||||
|
||||
|
@ -156,36 +186,117 @@ label: 65, score: 0.95
|
|||
|
||||
此外, deploy_cfg 的 `mean_values` 和 `std_values` 应该被设置为 `model_cfg` 中归一化的设置. 使 `mean_values=[[103.53, 116.28, 123.675]]`, `std_values=[[57.375, 57.12, 58.395]]`。
|
||||
|
||||
- MMDet 模型.
|
||||
|
||||
YOLOV3 & YOLOX: 将下面的模型拆分配置写入到 [detection_rknn_static.py](https://github.com/open-mmlab/mmdeploy/blob/master/configs/mmdet/detection/detection_rknn_static.py):
|
||||
|
||||
```python
|
||||
# yolov3, yolox
|
||||
partition_config = dict(
|
||||
type='rknn', # the partition policy name
|
||||
apply_marks=True, # should always be set to True
|
||||
partition_cfg=[
|
||||
dict(
|
||||
save_file='model.onnx', # name to save the partitioned onnx
|
||||
start=['detector_forward:input'], # [mark_name:input, ...]
|
||||
end=['yolo_head:input']) # [mark_name:output, ...]
|
||||
])
|
||||
```
|
||||
|
||||
RetinaNet & SSD & FSAF with rknn-toolkit2, 将下面的模型拆分配置写入到 [detection_rknn_static.py](https://github.com/open-mmlab/mmdeploy/blob/master/configs/mmdet/detection/detection_rknn_static.py)。使用 rknn-toolkit 的用户则不用。
|
||||
|
||||
```python
|
||||
# retinanet, ssd
|
||||
partition_config = dict(
|
||||
type='rknn', # the partition policy name
|
||||
apply_marks=True,
|
||||
partition_cfg=[
|
||||
dict(
|
||||
save_file='model.onnx',
|
||||
start='detector_forward:input',
|
||||
end=['BaseDenseHead:output'])
|
||||
])
|
||||
```
|
||||
|
||||
- SDK 只支持 int8 的 rknn 模型,这需要在转换模型时设置 `do_quantization=True`。
|
||||
|
||||
## 模型推理
|
||||
|
||||
### Host 交叉编译
|
||||
|
||||
mmdeploy 提供 2 种交叉编译方式:
|
||||
|
||||
1. 执行编译脚本(推荐)
|
||||
|
||||
在 Ubuntu 主机上,执行如下命令:
|
||||
|
||||
```shell
|
||||
bash tools/scripts/ubuntu_cross_build_rknn.sh <model>
|
||||
```
|
||||
|
||||
命令中的参数 model 表示瑞芯微芯片的型号,目前支持 rv1126,rk3588。
|
||||
|
||||
2. 手动配置环境并编译
|
||||
|
||||
如下表所示,瑞芯微提供了 2 套 RKNN API 工具包,对应于不同的芯片型号。而每套 RKNN API 工具包又分别对应不同的 gcc 交叉编译工具。
|
||||
|
||||
| Device | RKNN API |
|
||||
| ------------------------------------------ | -------------------------------------------------- |
|
||||
| RK1808 / RK1806 / RV1109 / RV1126 | [rknpu](https://github.com/rockchip-linux/rknpu) |
|
||||
| RK3566 / RK3568 / RK3588 / RV1103 / RV1106 | [rknpu2](https://github.com/rockchip-linux/rknpu2) |
|
||||
|
||||
以支持的 rv1126 和 rk3588 为例,mmdeploy 在 ubuntu18.04 上的交叉编译过程如下:
|
||||
|
||||
- rv1126
|
||||
|
||||
````shell
|
||||
# 1. 下载 RKNN API 包
|
||||
git clone https://github.com/rockchip-linux/rknpu
|
||||
export RKNPU_DIR=$(pwd)/rknpu
|
||||
|
||||
# 2. 准备 gcc 交叉编译工具
|
||||
sudo apt-get update
|
||||
sudo apt-get install gcc-7-arm-linux-gnueabihf
|
||||
sudo apt-get install g++-7-arm-linux-gnueabihf
|
||||
|
||||
# 3. 下载 OpenCV
|
||||
## 在rknpu2中,有opencv-linux-armhf库。路径是 rknpu2/examples/3rdparty/opencv/opencv-linux-armhf
|
||||
git clone https://github.com/rockchip-linux/rknpu2
|
||||
export RKNPU2_DIR=$(pwd)/rknpu2
|
||||
|
||||
# 4. 编译 mmdeploy SDK
|
||||
cd /path/to/mmdeploy
|
||||
mkdir -p build && rm -rf build/CM* && cd build
|
||||
cmake .. \
|
||||
-DCMAKE_TOOLCHAIN_FILE=../cmake/toolchains/arm-linux-gnueabihf.cmake \
|
||||
-DMMDEPLOY_BUILD_SDK=ON \
|
||||
-DMMDEPLOY_BUILD_SDK_CXX_API=ON \
|
||||
-DMMDEPLOY_BUILD_EXAMPLES=ON \
|
||||
-DMMDEPLOY_TARGET_BACKENDS="rknn" \
|
||||
-DRKNPU_DEVICE_DIR=${RKNPU_DIR}/rknn/rknn_api/librknn_api \
|
||||
-DOpenCV_DIR=${RKNPU2_DIR}/examples/3rdparty/opencv/opencv-linux-armhf/share/OpenCV
|
||||
make -j$(nproc) && make install
|
||||
````
|
||||
|
||||
- rk3588
|
||||
|
||||
```shell
|
||||
# 1. 下载 RKNN API 包
|
||||
git clone https://github.com/rockchip-linux/rknpu2
|
||||
export RKNPU2_DEVICE_DIR=$(pwd)/rknpu2/runtime/RK3588
|
||||
|
||||
# 2. 准备 gcc 交叉编译工具
|
||||
git clone https://github.com/Caesar-github/gcc-buildroot-9.3.0-2020.03-x86_64_aarch64-rockchip-linux-gnu
|
||||
export RKNN_TOOL_CHAIN=$(pwd)/gcc-buildroot-9.3.0-2020.03-x86_64_aarch64-rockchip-linux-gnu
|
||||
export LD_LIBRARY_PATH=$RKNN_TOOL_CHAIN/lib64:$LD_LIBRARY_PATH
|
||||
|
||||
# 3. 编译 mmdeploy SDK
|
||||
cd /path/to/mmdeploy
|
||||
mkdir -p build && rm -rf build/CM* && cd build
|
||||
export LD_LIBRARY_PATH=$RKNN_TOOL_CHAIN/lib64:$LD_LIBRARY_PATH
|
||||
cmake .. \
|
||||
-DCMAKE_TOOLCHAIN_FILE=../cmake/toolchains/rknpu2-linux-gnu.cmake \
|
||||
-DMMDEPLOY_BUILD_SDK=ON \
|
||||
-DMMDEPLOY_BUILD_SDK_CXX_API=ON \
|
||||
-DMMDEPLOY_TARGET_BACKENDS="rknn" \
|
||||
-DMMDEPLOY_BUILD_EXAMPLES=ON \
|
||||
-DOpenCV_DIR=${RKNPU2_DEVICE_DIR}/../../examples/3rdparty/opencv/opencv-linux-aarch64/share/OpenCV
|
||||
make -j $(nproc) && make install
|
||||
```
|
||||
|
||||
### Device 执行推理
|
||||
|
||||
首先,确保`--dump-info`在转模型的时候调用了, 这样工作目录下包含 SDK 需要的配置文件 `pipeline.json`。
|
||||
|
||||
使用 `adb push` 将转好的模型、编好的 SDK 和 bin 文件推到设备上。
|
||||
|
||||
```bash
|
||||
cd {/the/path/to/mmdeploy}
|
||||
adb push mmdeploy_models/mmcls/resnet50 /root/resnet50
|
||||
adb push {/the/path/of/mmclassification}/demo/demo.JPEG /root/demo.JPEG
|
||||
adb push build/install /root/mmdeploy_sdk
|
||||
```
|
||||
|
||||
通过 adb shell,打开设备终端,设置环境变量,运行例子。
|
||||
|
||||
```bash
|
||||
adb shell
|
||||
cd /root/mmdeploy_sdk
|
||||
export LD_LIBRARY_PATH=$(pwd)/lib:${LD_LIBRARY_PATH}
|
||||
./bin/image_classification cpu ../resnet50 ../demo.JPEG
|
||||
```
|
||||
|
||||
结果显示:
|
||||
|
||||
```shell
|
||||
label: 65, score: 0.95
|
||||
```
|
||||
|
|
|
@ -0,0 +1,158 @@
|
|||
#!/bin/bash
|
||||
# set -ex
|
||||
# Report a "polite" parallel job count through the exit status:
# max(1, nproc - 3), capped at 255.
#
# Usage:
#   good_nproc
#   jobs=$?
#
# The result travels in the exit status, which is only 8 bits wide. Without
# the cap, a host with 259 or more cores would wrap around (256 -> 0) and the
# callers would end up running `make -j0`.
good_nproc() {
  num=$(nproc)
  num=$((num - 3))
  if [ "$num" -lt 1 ]; then
    return 1
  fi
  if [ "$num" -gt 255 ]; then
    return 255
  fi
  return "$num"
}
|
||||
|
||||
# Prepare the host for cross compiling to rknpu devices
# (RK1808 / RK1806 / RV1109 / RV1126):
#   - host tools (wget, git, git-lfs)
#   - the arm-linux-gnueabihf gcc/g++ cross compiler
#   - the rknpu sources, cloned into the current directory
#   - cmake 3.22.0 installed through pip
# Exports RKNPU_DIR and prepends ~/.local/bin to PATH.
install_rknpu_toolchain() {
  # Host tools come first: wget and git are needed by the steps below.
  # -y keeps apt from waiting for confirmation in non-interactive runs (CI).
  sudo apt install -y wget git git-lfs

  # install gcc cross compiler
  ubuntu_version=$(cat /etc/issue)
  ubuntu_major_version=$(echo "$ubuntu_version" | grep -oP '\d{2}' | head -n 1)

  if [ "$ubuntu_major_version" -lt 18 ]; then
    echo "ubuntu 18.04 is minimum requirement, but got $ubuntu_version"
    # Older releases do not package the cross compiler; use ARM's prebuilt one.
    wget https://developer.arm.com/-/media/Files/downloads/gnu-a/8.3-2019.03/binrel/gcc-arm-8.3-2019.03-x86_64-arm-linux-gnueabihf.tar.xz
    tar -xvf gcc-arm-8.3-2019.03-x86_64-arm-linux-gnueabihf.tar.xz
    # /usr/bin is root-owned, so the symlinks need sudo like the apt calls do.
    sudo ln -sf "$(pwd)/gcc-arm-8.3-2019.03-x86_64-arm-linux-gnueabihf/bin/arm-linux-gnueabihf-gcc" /usr/bin/arm-linux-gnueabihf-gcc
    sudo ln -sf "$(pwd)/gcc-arm-8.3-2019.03-x86_64-arm-linux-gnueabihf/bin/arm-linux-gnueabihf-g++" /usr/bin/arm-linux-gnueabihf-g++
  else
    sudo apt install -y gcc-7-arm-linux-gnueabihf g++-7-arm-linux-gnueabihf
  fi
  arm-linux-gnueabihf-gcc --version
  arm-linux-gnueabihf-g++ --version

  # install rknpu (skip the clone when a previous run already fetched it)
  if [ ! -e "rknpu" ]; then
    git clone https://github.com/rockchip-linux/rknpu
  fi
  export RKNPU_DIR=$(pwd)/rknpu

  python3 -m pip install cmake==3.22.0

  # pip places the cmake launcher in ~/.local/bin; expose it now and record
  # it in ~/mmdeploy.env for later shells.
  echo 'export PATH=~/.local/bin:${PATH}' >> ~/mmdeploy.env
  export PATH=~/.local/bin:${PATH}
}
|
||||
|
||||
# Fetch the aarch64 gcc cross toolchain and the rknpu2 sources into the
# current directory, then export RKNN_TOOL_CHAIN and RKNPU2_DIR.
# Each clone is skipped when its directory already exists, so the script can
# be re-run without git failing on a non-empty destination (same guard style
# as build_ocv).
install_rknpu2_toolchain() {
  if [ ! -e "gcc-buildroot-9.3.0-2020.03-x86_64_aarch64-rockchip-linux-gnu" ]; then
    git clone https://github.com/Caesar-github/gcc-buildroot-9.3.0-2020.03-x86_64_aarch64-rockchip-linux-gnu.git
  fi
  if [ ! -e "rknpu2" ]; then
    git clone https://github.com/rockchip-linux/rknpu2.git
  fi
  export RKNN_TOOL_CHAIN=$(pwd)/gcc-buildroot-9.3.0-2020.03-x86_64_aarch64-rockchip-linux-gnu
  export RKNPU2_DIR=$(pwd)/rknpu2
}
|
||||
|
||||
# Cross compile a static OpenCV 4.6.0 with OpenCV's own arm-gnueabi toolchain
# file and install it under opencv/build_rknpu/install.
# Exports OPENCV_PACKAGE_DIR (the directory handed to cmake as OpenCV_DIR) and
# returns to the directory it was called from.
build_ocv() {
  # fetch the sources only once
  [ -e "opencv" ] || git clone https://github.com/opencv/opencv --depth=1 --branch=4.6.0 --recursive
  [ -e "opencv/build_rknpu" ] || mkdir -p opencv/build_rknpu

  cd opencv/build_rknpu
  # drop any stale configuration before re-running cmake
  rm -rf CMakeCache.txt
  cmake .. \
    -DCMAKE_INSTALL_PREFIX=$(pwd)/install \
    -DCMAKE_TOOLCHAIN_FILE=../platforms/linux/arm-gnueabi.toolchain.cmake \
    -DBUILD_PERF_TESTS=OFF \
    -DBUILD_SHARED_LIBS=OFF \
    -DBUILD_TESTS=OFF \
    -DCMAKE_BUILD_TYPE=Release

  # good_nproc hands back the job count through its exit status
  good_nproc
  jobs=$?
  make -j${jobs} && make install

  export OPENCV_PACKAGE_DIR=$(pwd)/install/lib/cmake/opencv4
  cd -
}
|
||||
|
||||
# Configure and cross compile the mmdeploy SDK for rknpu devices.
# Must be called from the mmdeploy repository root.
# Reads RKNPU_DIR (set by install_rknpu_toolchain) and OPENCV_PACKAGE_DIR
# (set by build_ocv).
# Returns non-zero when the build or the install step fails.
build_mmdeploy_with_rknpu() {
  git submodule init
  git submodule update

  if [ ! -e "build_rknpu" ]; then
    mkdir build_rknpu
  fi
  cd build_rknpu

  # drop any stale configuration before re-running cmake
  rm -rf CMakeCache.txt
  cmake .. \
    -DCMAKE_TOOLCHAIN_FILE=../cmake/toolchains/arm-linux-gnueabihf.cmake \
    -DMMDEPLOY_BUILD_SDK=ON \
    -DMMDEPLOY_BUILD_SDK_CXX_API=ON \
    -DMMDEPLOY_BUILD_EXAMPLES=ON \
    -DMMDEPLOY_TARGET_BACKENDS="rknn" \
    -DRKNPU_DEVICE_DIR="${RKNPU_DIR}"/rknn/rknn_api/librknn_api \
    -DOpenCV_DIR="${OPENCV_PACKAGE_DIR}"

  # Build exactly once, throttled by good_nproc. An extra `make -j$(nproc)`
  # pass used to run first, which used every core and made the throttled
  # build below a no-op.
  good_nproc
  jobs=$?
  make -j${jobs} || return 1
  make install || return 1

  ls -lah install/bin/*
}
|
||||
|
||||
# Configure and cross compile the mmdeploy SDK for rknpu2 devices.
# Must be called from the mmdeploy repository root.
#   $1 - name of the runtime directory under ${RKNPU2_DIR}/runtime
#        (the driver passes e.g. "RK3588" or "RK356X")
# Reads RKNPU2_DIR (set by install_rknpu2_toolchain).
# Returns non-zero when the build or the install step fails.
build_mmdeploy_with_rknpu2() {
  git submodule init
  git submodule update

  device_model=$1
  if [ ! -e "build_rknpu2" ]; then
    mkdir build_rknpu2
  fi
  cd build_rknpu2

  # drop any stale configuration before re-running cmake
  rm -rf CMakeCache.txt
  cmake .. \
    -DCMAKE_TOOLCHAIN_FILE=../cmake/toolchains/rknpu2-linux-gnu.cmake \
    -DMMDEPLOY_BUILD_SDK=ON \
    -DMMDEPLOY_BUILD_SDK_CXX_API=ON \
    -DMMDEPLOY_BUILD_EXAMPLES=ON \
    -DMMDEPLOY_TARGET_BACKENDS="rknn" \
    -DRKNPU2_DEVICE_DIR="${RKNPU2_DIR}/runtime/${device_model}" \
    -DOpenCV_DIR="${RKNPU2_DIR}"/examples/3rdparty/opencv/opencv-linux-aarch64/share/OpenCV

  # Build exactly once, throttled by good_nproc. An extra `make -j$(nproc)`
  # pass used to run first, which used every core and made the throttled
  # build below a no-op.
  good_nproc
  jobs=$?
  make -j${jobs} || return 1
  make install || return 1

  ls -lah install/bin/*
}
|
||||
|
||||
# Print the closing banner: a rule, the hand-off message, and another rule.
print_success() {
  printf '%s\n' \
    "----------------------------------------------------------------------" \
    "Cross build finished, PLS copy bin/model/test_data to the device.. QVQ" \
    "----------------------------------------------------------------------"
}
|
||||
|
||||
# ---------------------------------------------------------------------------
# Entry point.
# Usage: ubuntu_cross_build_rknn.sh <model>     (run from the mmdeploy root)
#   <model> is a Rockchip chip name, case-insensitive, e.g. rv1126 or rk3588.
# Third-party dependencies are downloaded into ../mmdeploy-dep.
# ---------------------------------------------------------------------------

# Remember where the repository lives so we can come back to it regardless of
# what the checkout directory is called.
repo_dir=$(pwd)

mkdir -p ../mmdeploy-dep
cd ../mmdeploy-dep

# The character classes are quoted so the shell cannot glob-expand them
# against single-letter file names in the current directory.
device_model=$(echo "$1" | tr '[:lower:]' '[:upper:]')
case "$device_model" in
  RK1808|RK1806|RV1109|RV1126)
    install_rknpu_toolchain
    build_ocv
    cd "$repo_dir"
    build_mmdeploy_with_rknpu || exit 1
    ;;
  RK3566|RK3568)
    install_rknpu2_toolchain
    cd "$repo_dir"
    # both chips share the RK356X runtime directory in rknpu2
    build_mmdeploy_with_rknpu2 "RK356X" || exit 1
    ;;
  RK3588|RV1106)
    install_rknpu2_toolchain
    cd "$repo_dir"
    build_mmdeploy_with_rknpu2 "$device_model" || exit 1
    ;;
  *)
    echo "mmdeploy doesn't support rockchip '$1' yet"
    exit 1
    ;;
esac

# Only reached when the selected build function reported success.
print_success
|
Loading…
Reference in New Issue