Add RV1126 YOLOv3 support to SDK (#1280)
* add yolov3 head to SDK
* add yolov5 head to SDK
* fix export-info and lint, add reverse check
* fix lint
* fix export info for yolo heads
* add output_names to partition_config
* fix typo
* config
* normalize config
* fix
* refactor config
* fix lint and doc
* c++ form
* resolve comments
* fix CI
* fix CI
* fix CI
* float strides anchors
* refine pipeline of rknn-int8
* config
* rename func
* refactor
* rknn wrapper dict and fix typo
* rknn wrapper output update, mmcls use end2end type
* fix typo
parent 522fcc0635
commit 4dd4d4851b
@@ -1,8 +1,6 @@
 backend_config = dict(
     type='rknn',
     common_config=dict(
-        mean_values=None,  # [[103.53, 116.28, 123.675]],
-        std_values=None,  # [[57.375, 57.12, 58.395]],
         target_platform='rv1126',  # 'rk3588'
         optimization_level=1),
-    quantization_config=dict(do_quantization=False, dataset=None))
+    quantization_config=dict(do_quantization=True, dataset=None))
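For context, other configs in this PR reuse this shared backend file through mmcv's `_base_` inheritance instead of editing it. A minimal sketch of that override pattern (the `_base_` path mirrors the repo layout; the override itself is illustrative):

```python
# Hypothetical deploy config composing the rknn backend file shown above.
# mmcv merges this dict over the inherited backend_config, so quantization
# can be flipped off per model without touching the shared file.
_base_ = ['../_base_/backends/rknn.py']

backend_config = dict(
    quantization_config=dict(do_quantization=False, dataset=None))
```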
@@ -0,0 +1,7 @@
+_base_ = ['./classification_static.py', '../_base_/backends/rknn.py']
+
+onnx_config = dict(input_shape=[224, 224])
+codebase_config = dict(model_type='end2end')
+backend_config = dict(
+    input_size_list=[[3, 224, 224]],
+    quantization_config=dict(do_quantization=False))
@@ -1,5 +1,5 @@
 _base_ = ['./classification_static.py', '../_base_/backends/rknn.py']
 
 onnx_config = dict(input_shape=[224, 224])
-codebase_config = dict(model_type='rknn')
+codebase_config = dict(model_type='end2end')
 backend_config = dict(input_size_list=[[3, 224, 224]])
@@ -0,0 +1,34 @@
+_base_ = ['../_base_/base_static.py', '../../_base_/backends/rknn.py']
+
+onnx_config = dict(input_shape=[320, 320])
+
+codebase_config = dict(model_type='rknn')
+
+backend_config = dict(
+    input_size_list=[[3, 320, 320]],
+    quantization_config=dict(do_quantization=False))
+
+# # yolov3, yolox for rknn-toolkit and rknn-toolkit2
+# partition_config = dict(
+#     type='rknn',  # the partition policy name
+#     apply_marks=True,  # should always be set to True
+#     partition_cfg=[
+#         dict(
+#             save_file='model.onnx',  # name to save the partitioned onnx
+#             start=['detector_forward:input'],  # [mark_name:input, ...]
+#             end=['yolo_head:input'],  # [mark_name:output, ...]
+#             output_names=[f'pred_maps.{i}' for i in range(3)])  # out names
+#     ])
+
+# # retinanet, ssd, fsaf for rknn-toolkit2
+# partition_config = dict(
+#     type='rknn',  # the partition policy name
+#     apply_marks=True,
+#     partition_cfg=[
+#         dict(
+#             save_file='model.onnx',
+#             start='detector_forward:input',
+#             end=['BaseDenseHead:output'],
+#             output_names=[f'BaseDenseHead.cls.{i}' for i in range(5)] +
+#             [f'BaseDenseHead.loc.{i}' for i in range(5)])
+#     ])
@@ -6,7 +6,7 @@ codebase_config = dict(model_type='rknn')
 
 backend_config = dict(input_size_list=[[3, 320, 320]])
 
-# # yolov3, yolox
+# # yolov3, yolox for rknn-toolkit and rknn-toolkit2
 # partition_config = dict(
 #     type='rknn',  # the partition policy name
 #     apply_marks=True,  # should always be set to True
@@ -14,10 +14,11 @@ backend_config = dict(input_size_list=[[3, 320, 320]])
 #         dict(
 #             save_file='model.onnx',  # name to save the partitioned onnx
 #             start=['detector_forward:input'],  # [mark_name:input, ...]
-#             end=['yolo_head:input'])  # [mark_name:output, ...]
+#             end=['yolo_head:input'],  # [mark_name:output, ...]
+#             output_names=[f'pred_maps.{i}' for i in range(3)])  # out names
 #     ])
 
-# # retinanet, ssd, fsaf
+# # retinanet, ssd, fsaf for rknn-toolkit2
 # partition_config = dict(
 #     type='rknn',  # the partition policy name
 #     apply_marks=True,
@@ -25,5 +26,7 @@ backend_config = dict(input_size_list=[[3, 320, 320]])
 #         dict(
 #             save_file='model.onnx',
 #             start='detector_forward:input',
-#             end=['BaseDenseHead:output'])
+#             end=['BaseDenseHead:output'],
+#             output_names=[f'BaseDenseHead.cls.{i}' for i in range(5)] +
+#             [f'BaseDenseHead.loc.{i}' for i in range(5)])
 #     ])
@@ -8,5 +8,6 @@ partition_config = dict(
         dict(
             save_file='yolov3.onnx',
             start=['detector_forward:input'],
-            end=['yolo_head:input'])
+            end=['yolo_head:input'],
+            output_names=[f'pred_maps.{i}' for i in range(3)])
     ])
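For clarity, the comprehension in `output_names` just pins explicit names on the three partitioned YOLO feature-map outputs; evaluated in plain Python:

```python
output_names = [f'pred_maps.{i}' for i in range(3)]
print(output_names)  # ['pred_maps.0', 'pred_maps.1', 'pred_maps.2']
```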
@@ -0,0 +1,9 @@
+_base_ = ['./segmentation_static.py', '../_base_/backends/rknn.py']
+
+onnx_config = dict(input_shape=[320, 320])
+
+codebase_config = dict(model_type='rknn')
+
+backend_config = dict(
+    input_size_list=[[3, 320, 320]],
+    quantization_config=dict(do_quantization=False))
@@ -0,0 +1,228 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+#include "yolo_head.h"
+
+#include <math.h>
+
+#include <algorithm>
+#include <numeric>
+
+#include "mmdeploy/core/model.h"
+#include "mmdeploy/core/utils/device_utils.h"
+#include "mmdeploy/core/utils/formatter.h"
+#include "utils.h"
+
+namespace mmdeploy::mmdet {
+
+YOLOHead::YOLOHead(const Value& cfg) : MMDetection(cfg) {
+  auto init = [&]() -> Result<void> {
+    auto model = cfg["context"]["model"].get<Model>();
+    if (cfg.contains("params")) {
+      nms_pre_ = cfg["params"].value("nms_pre", -1);
+      score_thr_ = cfg["params"].value("score_thr", 0.02f);
+      min_bbox_size_ = cfg["params"].value("min_bbox_size", 0);
+      iou_threshold_ = cfg["params"].contains("nms")
+                           ? cfg["params"]["nms"].value("iou_threshold", 0.45f)
+                           : 0.45f;
+      if (cfg["params"].contains("anchor_generator")) {
+        from_value(cfg["params"]["anchor_generator"]["base_sizes"], anchors_);
+        from_value(cfg["params"]["anchor_generator"]["strides"], strides_);
+      }
+    }
+    return success();
+  };
+  init().value();
+}
+
+Result<Value> YOLOHead::operator()(const Value& prep_res, const Value& infer_res) {
+  MMDEPLOY_DEBUG("prep_res: {}\ninfer_res: {}", prep_res, infer_res);
+  try {
+    const Device kHost{0, 0};
+    std::vector<Tensor> pred_maps;
+    for (auto iter = infer_res.begin(); iter != infer_res.end(); iter++) {
+      auto pred_map = iter->get<Tensor>();
+      OUTCOME_TRY(auto _pred_map, MakeAvailableOnDevice(pred_map, kHost, stream()));
+      pred_maps.push_back(_pred_map);
+    }
+    OUTCOME_TRY(stream().Wait());
+    // reorder pred_maps according to strides and anchors, mainly for rknpu yolov3
+    if ((pred_maps.size() > 1) &&
+        !((strides_[0] < strides_[1]) ^ (pred_maps[0].shape(3) < pred_maps[1].shape(3)))) {
+      std::reverse(pred_maps.begin(), pred_maps.end());
+    }
+    OUTCOME_TRY(auto result, GetBBoxes(prep_res["img_metas"], pred_maps));
+    return to_value(result);
+  } catch (...) {
+    return Status(eFail);
+  }
+}
+
+inline static int clamp(float val, int min, int max) {
+  return val > min ? (val < max ? val : max) : min;
+}
+
+static float sigmoid(float x) { return 1.0 / (1.0 + expf(-x)); }
+
+static float unsigmoid(float y) { return -1.0 * logf((1.0 / y) - 1.0); }
+
+int YOLOHead::YOLOFeatDecode(const Tensor& feat_map, const std::vector<std::vector<float>>& anchor,
+                             int grid_h, int grid_w, int height, int width, int stride,
+                             std::vector<float>& boxes, std::vector<float>& obj_probs,
+                             std::vector<int>& class_id, float threshold) const {
+  auto input = const_cast<float*>(feat_map.data<float>());
+  auto prop_box_size = feat_map.shape(1) / anchor.size();
+  const int kClasses = prop_box_size - 5;
+  int valid_count = 0;
+  int grid_len = grid_h * grid_w;
+  // compare raw confidences against unsigmoid(threshold) so the sigmoid is
+  // only computed for cells that pass the test
+  float thres = unsigmoid(threshold);
+  for (int a = 0; a < anchor.size(); a++) {
+    for (int i = 0; i < grid_h; i++) {
+      for (int j = 0; j < grid_w; j++) {
+        float box_confidence = input[(prop_box_size * a + 4) * grid_len + i * grid_w + j];
+        if (box_confidence >= thres) {
+          int offset = (prop_box_size * a) * grid_len + i * grid_w + j;
+          float* in_ptr = input + offset;
+
+          float box_x = sigmoid(*in_ptr);
+          float box_y = sigmoid(in_ptr[grid_len]);
+          float box_w = in_ptr[2 * grid_len];
+          float box_h = in_ptr[3 * grid_len];
+          auto box = yolo_decode(box_x, box_y, box_w, box_h, stride, anchor, j, i, a);
+
+          box_x = box[0];
+          box_y = box[1];
+          box_w = box[2];
+          box_h = box[3];
+
+          box_x -= (box_w / 2.0);
+          box_y -= (box_h / 2.0);
+          boxes.push_back(box_x);
+          boxes.push_back(box_y);
+          boxes.push_back(box_x + box_w);
+          boxes.push_back(box_y + box_h);
+
+          float max_class_probs = in_ptr[5 * grid_len];
+          int max_class_id = 0;
+          for (int k = 1; k < kClasses; ++k) {
+            float prob = in_ptr[(5 + k) * grid_len];
+            if (prob > max_class_probs) {
+              max_class_id = k;
+              max_class_probs = prob;
+            }
+          }
+          obj_probs.push_back(sigmoid(max_class_probs) * sigmoid(box_confidence));
+          class_id.push_back(max_class_id);
+          valid_count++;
+        }
+      }
+    }
+  }
+  return valid_count;
+}
+
+Result<Detections> YOLOHead::GetBBoxes(const Value& prep_res,
+                                       const std::vector<Tensor>& pred_maps) const {
+  std::vector<float> filter_boxes;
+  std::vector<float> obj_probs;
+  std::vector<int> class_id;
+
+  int model_in_h = prep_res["img_shape"][1].get<int>();
+  int model_in_w = prep_res["img_shape"][2].get<int>();
+
+  for (int i = 0; i < pred_maps.size(); i++) {
+    int stride = strides_[i];
+    int grid_h = model_in_h / stride;
+    int grid_w = model_in_w / stride;
+    YOLOFeatDecode(pred_maps[i], anchors_[i], grid_h, grid_w, model_in_h, model_in_w, stride,
+                   filter_boxes, obj_probs, class_id, score_thr_);
+  }
+
+  std::vector<int> indexArray;
+  for (int i = 0; i < obj_probs.size(); ++i) {
+    indexArray.push_back(i);
+  }
+  Sort(obj_probs, class_id, indexArray);
+
+  Tensor dets(TensorDesc{Device{0, 0}, DataType::kFLOAT,
+                         TensorShape{int(filter_boxes.size() / 4), 4}, "dets"});
+  std::copy(filter_boxes.begin(), filter_boxes.end(), dets.data<float>());
+  NMS(dets, iou_threshold_, indexArray);
+
+  Detections objs;
+  std::vector<float> scale_factor;
+  if (prep_res.contains("scale_factor")) {
+    from_value(prep_res["scale_factor"], scale_factor);
+  } else {
+    scale_factor = {1.f, 1.f, 1.f, 1.f};
+  }
+  int ori_width = prep_res["ori_shape"][2].get<int>();
+  int ori_height = prep_res["ori_shape"][1].get<int>();
+  auto det_ptr = dets.data<float>();
+  for (int i = 0; i < indexArray.size(); ++i) {
+    if (indexArray[i] == -1) {
+      continue;
+    }
+    int j = indexArray[i];
+    auto x1 = clamp(det_ptr[j * 4 + 0], 0, model_in_w);
+    auto y1 = clamp(det_ptr[j * 4 + 1], 0, model_in_h);
+    auto x2 = clamp(det_ptr[j * 4 + 2], 0, model_in_w);
+    auto y2 = clamp(det_ptr[j * 4 + 3], 0, model_in_h);
+    int label_id = class_id[i];
+    float score = obj_probs[i];
+
+    MMDEPLOY_DEBUG("{}-th box: ({}, {}, {}, {}), {}, {}", i, x1, y1, x2, y2, label_id, score);
+
+    auto rect = MapToOriginImage(x1, y1, x2, y2, scale_factor.data(), 0, 0, ori_width, ori_height);
+    if (rect[2] - rect[0] < min_bbox_size_ || rect[3] - rect[1] < min_bbox_size_) {
+      MMDEPLOY_DEBUG("ignore small bbox with width '{}' and height '{}'", rect[2] - rect[0],
+                     rect[3] - rect[1]);
+      continue;
+    }
+    Detection det{};
+    det.index = i;
+    det.label_id = label_id;
+    det.score = score;
+    det.bbox = rect;
+    objs.push_back(std::move(det));
+  }
+
+  return objs;
+}
+
+Result<Value> YOLOV3Head::operator()(const Value& prep_res, const Value& infer_res) {
+  return YOLOHead::operator()(prep_res, infer_res);
+}
+
+std::array<float, 4> YOLOV3Head::yolo_decode(float box_x, float box_y, float box_w, float box_h,
+                                             float stride,
+                                             const std::vector<std::vector<float>>& anchor, int j,
+                                             int i, int a) const {
+  box_x = (box_x + j) * stride;
+  box_y = (box_y + i) * stride;
+  box_w = expf(box_w) * anchor[a][0];
+  box_h = expf(box_h) * anchor[a][1];
+  return std::array<float, 4>{box_x, box_y, box_w, box_h};
+}
+
+Result<Value> YOLOV5Head::operator()(const Value& prep_res, const Value& infer_res) {
+  return YOLOHead::operator()(prep_res, infer_res);
+}
+
+std::array<float, 4> YOLOV5Head::yolo_decode(float box_x, float box_y, float box_w, float box_h,
+                                             float stride,
+                                             const std::vector<std::vector<float>>& anchor, int j,
+                                             int i, int a) const {
+  box_x = box_x * 2 - 0.5;
+  box_y = box_y * 2 - 0.5;
+  box_w = box_w * 2 - 0.5;
+  box_h = box_h * 2 - 0.5;
+  box_x = (box_x + j) * stride;
+  box_y = (box_y + i) * stride;
+  box_w = box_w * box_w * anchor[a][0];
+  box_h = box_h * box_h * anchor[a][1];
+  return std::array<float, 4>{box_x, box_y, box_w, box_h};
+}
+
+MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMDetection, YOLOV3Head);
+MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMDetection, YOLOV5Head);
+
+}  // namespace mmdeploy::mmdet
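As a sanity check on the decode arithmetic in `YOLOV3Head::yolo_decode` and `YOLOV5Head::yolo_decode`, here is a small Python mirror of both formulas (for verification only; it mirrors the C++ as written, where `box_x`/`box_y` arrive sigmoid-activated from `YOLOFeatDecode` and the width/height activations depend on what the exported model emits):

```python
import math

def yolov3_decode(box_x, box_y, box_w, box_h, stride, anchor, j, i, a):
    # cell offset plus exponential anchor scaling
    return ((box_x + j) * stride, (box_y + i) * stride,
            math.exp(box_w) * anchor[a][0], math.exp(box_h) * anchor[a][1])

def yolov5_decode(box_x, box_y, box_w, box_h, stride, anchor, j, i, a):
    # YOLOv5 squares a rescaled value instead of exponentiating, which keeps
    # the decoded size bounded
    box_x, box_y = box_x * 2 - 0.5, box_y * 2 - 0.5
    box_w, box_h = box_w * 2 - 0.5, box_h * 2 - 0.5
    return ((box_x + j) * stride, (box_y + i) * stride,
            box_w * box_w * anchor[a][0], box_h * box_h * anchor[a][1])

# A centered, zero-offset prediction in cell (j=5, i=7) at stride 32 with a
# 116x90 anchor recovers the anchor box at the cell center:
print(yolov3_decode(0.5, 0.5, 0.0, 0.0, 32, [[116, 90]], 5, 7, 0))
# (176.0, 240.0, 116.0, 90.0)
```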
@@ -0,0 +1,53 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+#ifndef MMDEPLOY_CODEBASE_MMDET_YOLO_HEAD_H_
+#define MMDEPLOY_CODEBASE_MMDET_YOLO_HEAD_H_
+
+#include "mmdeploy/codebase/mmdet/mmdet.h"
+#include "mmdeploy/core/tensor.h"
+
+namespace mmdeploy::mmdet {
+
+class YOLOHead : public MMDetection {
+ public:
+  explicit YOLOHead(const Value& cfg);
+  Result<Value> operator()(const Value& prep_res, const Value& infer_res);
+  int YOLOFeatDecode(const Tensor& feat_map, const std::vector<std::vector<float>>& anchor,
+                     int grid_h, int grid_w, int height, int width, int stride,
+                     std::vector<float>& boxes, std::vector<float>& obj_probs,
+                     std::vector<int>& class_id, float threshold) const;
+  Result<Detections> GetBBoxes(const Value& prep_res, const std::vector<Tensor>& pred_maps) const;
+  virtual std::array<float, 4> yolo_decode(float box_x, float box_y, float box_w, float box_h,
+                                           float stride,
+                                           const std::vector<std::vector<float>>& anchor, int j,
+                                           int i, int a) const = 0;
+
+ private:
+  float score_thr_{0.4f};
+  int nms_pre_{1000};
+  float iou_threshold_{0.45f};
+  int min_bbox_size_{0};
+  std::vector<std::vector<std::vector<float>>> anchors_;
+  std::vector<float> strides_;
+};
+
+class YOLOV3Head : public YOLOHead {
+ public:
+  using YOLOHead::YOLOHead;
+  Result<Value> operator()(const Value& prep_res, const Value& infer_res);
+  std::array<float, 4> yolo_decode(float box_x, float box_y, float box_w, float box_h, float stride,
+                                   const std::vector<std::vector<float>>& anchor, int j, int i,
+                                   int a) const override;
+};
+
+class YOLOV5Head : public YOLOHead {
+ public:
+  using YOLOHead::YOLOHead;
+  Result<Value> operator()(const Value& prep_res, const Value& infer_res);
+  std::array<float, 4> yolo_decode(float box_x, float box_y, float box_w, float box_h, float stride,
+                                   const std::vector<std::vector<float>>& anchor, int j, int i,
+                                   int a) const override;
+};
+
+}  // namespace mmdeploy::mmdet
+
+#endif  // MMDEPLOY_CODEBASE_MMDET_YOLO_HEAD_H_
@@ -138,30 +138,12 @@ label: 65, score: 0.95
 
 ## Troubleshooting
 
-- Quantization fails.
-
-  Empirically, RKNN requires the inputs not to be normalized if `do_quantization` is set to `True`. Please modify the settings of `Normalize` in the `model_cfg` from
-
-  ```python
-  img_norm_cfg = dict(
-      mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
-  ```
-
-  to
-
-  ```python
-  img_norm_cfg = dict(
-      mean=[0, 0, 0], std=[1, 1, 1], to_rgb=True)
-  ```
-
-  Besides, the `mean_values` and `std_values` of deploy_cfg should be replaced with the original normalization settings of `model_cfg`, i.e. `mean_values=[[103.53, 116.28, 123.675]]` and `std_values=[[57.375, 57.12, 58.395]]`.
-
 - MMDet models.
 
   YOLOV3 & YOLOX: you may paste the following partition configuration into [detection_rknn_static.py](https://github.com/open-mmlab/mmdeploy/blob/master/configs/mmdet/detection/detection_rknn_static.py):
 
   ```python
-  # yolov3, yolox
+  # yolov3, yolox for rknn-toolkit and rknn-toolkit2
   partition_config = dict(
       type='rknn',  # the partition policy name
       apply_marks=True,  # should always be set to True
@@ -169,7 +151,8 @@ label: 65, score: 0.95
       dict(
           save_file='model.onnx',  # name to save the partitioned onnx
           start=['detector_forward:input'],  # [mark_name:input, ...]
-          end=['yolo_head:input'])  # [mark_name:output, ...]
+          end=['yolo_head:input'],  # [mark_name:output, ...]
+          output_names=[f'pred_maps.{i}' for i in range(3)])  # output names
   ])
   ```
 
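After running the partition step with this config, the renamed outputs can be verified directly on the generated ONNX file (a quick check using the `onnx` package; the file name follows `save_file` above):

```python
import onnx

model = onnx.load('model.onnx')
print([out.name for out in model.graph.output])
# expected: ['pred_maps.0', 'pred_maps.1', 'pred_maps.2']
```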
@@ -184,7 +167,9 @@ label: 65, score: 0.95
       dict(
           save_file='model.onnx',
           start='detector_forward:input',
-          end=['BaseDenseHead:output'])
+          end=['BaseDenseHead:output'],
+          output_names=[f'BaseDenseHead.cls.{i}' for i in range(5)] +
+          [f'BaseDenseHead.loc.{i}' for i in range(5)])
   ])
   ```
 
@@ -66,7 +66,8 @@ partition_config = dict(
     dict(
         save_file='yolov3.onnx',  # filename to save the partitioned onnx model
         start=['detector_forward:input'],  # [mark_name:input/output, ...]
-        end=['yolo_head:input'])  # [mark_name:input/output, ...]
+        end=['yolo_head:input'],  # [mark_name:input/output, ...]
+        output_names=[f'pred_maps.{i}' for i in range(3)])  # output names
 ])
 
 ```
@@ -105,7 +105,7 @@ python tools/deploy.py \
 Write the following partition configuration into [detection_rknn_static.py](https://github.com/open-mmlab/mmdeploy/blob/master/configs/mmdet/detection/detection_rknn_static.py):
 
 ```python
-# yolov3, yolox
+# yolov3, yolox for rknn-toolkit and rknn-toolkit2
 partition_config = dict(
     type='rknn',  # the partition policy name
     apply_marks=True,  # should always be set to True
@@ -113,7 +113,8 @@ partition_config = dict(
     dict(
         save_file='model.onnx',  # name to save the partitioned onnx
         start=['detector_forward:input'],  # [mark_name:input, ...]
-        end=['yolo_head:input'])  # [mark_name:output, ...]
+        end=['yolo_head:input'],  # [mark_name:output, ...]
+        output_names=[f'pred_maps.{i}' for i in range(3)])  # output names
 ])
 ```
 
@@ -143,7 +144,9 @@ partition_config = dict(
     dict(
         save_file='model.onnx',
         start='detector_forward:input',
-        end=['BaseDenseHead:output'])
+        end=['BaseDenseHead:output'],
+        output_names=[f'BaseDenseHead.cls.{i}' for i in range(5)] +
+        [f'BaseDenseHead.loc.{i}' for i in range(5)])
 ])
 ```
 
@@ -168,24 +171,6 @@ backend_config = dict(
 
 ### Troubleshooting
 
-- Quantization fails.
-
-  Empirically, RKNN requires the inputs not to be normalized if `do_quantization` is set to `True`. Please modify the `Normalize` settings in `model_cfg`, e.g. change
-
-  ```python
-  img_norm_cfg = dict(
-      mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
-  ```
-
-  to
-
-  ```python
-  img_norm_cfg = dict(
-      mean=[0, 0, 0], std=[1, 1, 1], to_rgb=True)
-  ```
-
-  Besides, the `mean_values` and `std_values` of deploy_cfg should be set to the original normalization settings of `model_cfg`, i.e. `mean_values=[[103.53, 116.28, 123.675]]` and `std_values=[[57.375, 57.12, 58.395]]`.
-
+- The SDK only supports int8 RKNN models, which requires setting `do_quantization=True` when converting the model.
 
 ## Model Inference
 
@@ -64,7 +64,8 @@ partition_config = dict(
     dict(
         save_file='yolov3.onnx',  # filename to save the partitioned onnx model
         start=['detector_forward:input'],  # [mark_name:input/output, ...]
-        end=['yolo_head:input'])  # [mark_name:input/output, ...]
+        end=['yolo_head:input'],  # [mark_name:input/output, ...]
+        output_names=[f'pred_maps.{i}' for i in range(3)])  # output names
 ])
 
 ```
@@ -64,10 +64,12 @@ def visualize_model(model_cfg: Union[str, mmcv.Config],
 
     if backend == Backend.PYTORCH:
         model = task_processor.init_pytorch_model(model[0])
+        model_inputs, _ = task_processor.create_input(img, input_shape)
     else:
        model = task_processor.init_backend_model(model, **kwargs)
+        model_inputs, _ = task_processor.create_input(
+            img, input_shape, task_processor.update_test_pipeline)
 
-    model_inputs, _ = task_processor.create_input(img, input_shape)
     with torch.no_grad():
         result = task_processor.run_inference(model, model_inputs)[0]
@@ -4,15 +4,17 @@ from typing import Optional, Union
 import mmcv
 from rknn.api import RKNN
 
-from mmdeploy.utils import (get_common_config, get_onnx_config,
+from mmdeploy.utils import (get_backend_config, get_common_config,
+                            get_normalization, get_onnx_config,
                             get_partition_config, get_quantization_config,
-                            get_root_logger, load_config)
-from mmdeploy.utils.config_utils import get_backend_config
+                            get_rknn_quantization, get_root_logger,
+                            load_config)
 
 
 def onnx2rknn(onnx_model: str,
               output_path: str,
               deploy_cfg: Union[str, mmcv.Config],
+              model_cfg: Union[str, mmcv.Config],
               dataset_file: Optional[str] = None,
               **kwargs):
     """Convert ONNX to RKNN.
@@ -40,6 +42,14 @@ def onnx2rknn(onnx_model: str,
     output_names = onnx_params.get('output_names', None)
     input_size_list = get_backend_config(deploy_cfg).get(
         'input_size_list', None)
+    # update norm value
+    if get_rknn_quantization(deploy_cfg) is True:
+        transform = get_normalization(model_cfg)
+        common_params.update(
+            dict(
+                mean_values=[transform['mean']],
+                std_values=[transform['std']]))
+
     # update output_names for partition models
     if get_partition_config(deploy_cfg) is not None:
         import onnx
@@ -62,7 +72,7 @@ def onnx2rknn(onnx_model: str,
     if dataset_cfg is None and dataset_file is None:
         do_quantization = False
         logger.warning('no dataset passed in, quantization is skipped')
-    if dataset_file is None:
+    if dataset_cfg is not None:
         dataset_file = dataset_cfg
     ret = rknn.build(do_quantization=do_quantization, dataset=dataset_file)
     if ret != 0:
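For reference, the `dataset` file that `rknn.build()` consumes is, per the rknn-toolkit convention, a plain-text list of calibration images, one path per line. A hypothetical example (file name and image paths are illustrative):

```python
# Write a minimal calibration list for int8 quantization.
with open('dataset.txt', 'w') as f:
    f.write('data/calib/img_000.jpg\n')
    f.write('data/calib/img_001.jpg\n')
```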
@@ -48,7 +48,7 @@ class RKNNWrapper(BaseWrapper):
         super().__init__(output_names)
 
     def forward(self, inputs: Dict[str,
-                                   torch.Tensor]) -> Sequence[torch.Tensor]:
+                                   torch.Tensor]) -> Dict[str, torch.Tensor]:
         """Run forward inference. Note that the shape of the input tensor is
         NxCxHxW while RKNN only accepts the numpy inputs of NxHxWxC. There is a
         permute operation outside RKNN inference.
@@ -57,11 +57,14 @@ class RKNNWrapper(BaseWrapper):
             inputs (Dict[str, torch.Tensor]): Input name and tensor pairs.
 
         Return:
-            Sequence[torch.Tensor]: The output tensors.
+            Dict[str, torch.Tensor]: The output tensors.
         """
         rknn_out = self.__rknnnn_execute(
             [i.permute(0, 2, 3, 1).cpu().numpy() for i in inputs.values()])
-        return [torch.from_numpy(out) for out in rknn_out]
+        rknn_out = [torch.from_numpy(out) for out in rknn_out]
+        if self.output_names is not None:
+            return dict(zip(self.output_names, rknn_out))
+        return {'#' + str(i): x for i, x in enumerate(rknn_out)}
 
     @TimeCounter.count_time(Backend.RKNN.value)
     def __rknnnn_execute(self, inputs: Sequence[np.array]):
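A self-contained sketch of the new output-naming behavior (mock data only; the real `forward()` additionally permutes NCHW to NHWC and calls into RKNN):

```python
import numpy as np
import torch

def name_outputs(rknn_out, output_names=None):
    # Mirrors the updated RKNNWrapper.forward() return logic shown above:
    # named outputs when output_names is configured, positional '#i' keys
    # otherwise.
    tensors = [torch.from_numpy(out) for out in rknn_out]
    if output_names is not None:
        return dict(zip(output_names, tensors))
    return {'#' + str(i): x for i, x in enumerate(tensors)}

outs = [np.zeros((1, 255, 10, 10), np.float32) for _ in range(3)]
print(list(name_outputs(outs, ['pred_maps.0', 'pred_maps.1', 'pred_maps.2'])))
# ['pred_maps.0', 'pred_maps.1', 'pred_maps.2']
print(list(name_outputs(outs)))  # ['#0', '#1', '#2']
```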
@@ -7,7 +7,8 @@ import mmcv
 
 from mmdeploy.apis import build_task_processor
 from mmdeploy.utils import (Backend, Task, get_backend, get_codebase,
-                            get_common_config, get_ir_config, get_root_logger,
+                            get_common_config, get_ir_config,
+                            get_partition_config, get_root_logger,
                             get_task_type, is_dynamic_batch, load_config)
 from mmdeploy.utils.constants import SDK_TASK_MAP as task_map
 from .tracer import add_transform_tag, get_transform_static
@@ -94,6 +95,9 @@ def get_models(deploy_cfg: Union[str, mmcv.Config],
     name, _ = get_model_name_customs(deploy_cfg, model_cfg, work_dir, device)
     precision = 'FP32'
     ir_name = get_ir_config(deploy_cfg)['save_file']
+    if get_partition_config(deploy_cfg) is not None:
+        ir_name = get_partition_config(
+            deploy_cfg)['partition_cfg'][0]['save_file']
     net = ir_name
     weights = ''
     backend = get_backend(deploy_cfg=deploy_cfg)
@@ -185,6 +189,9 @@ def get_inference_info(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,
     backend = get_backend(deploy_cfg=deploy_cfg)
     if backend in (Backend.TORCHSCRIPT, Backend.RKNN):
         output_names = ir_config.get('output_names', None)
+        if get_partition_config(deploy_cfg) is not None:
+            output_names = get_partition_config(
+                deploy_cfg)['partition_cfg'][0]['output_names']
         input_map = dict(img='#0')
         output_map = {name: f'#{i}' for i, name in enumerate(output_names)}
     else:
@@ -258,6 +265,8 @@ def get_preprocess(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,
     for transform in transforms:
         if transform['type'] == 'Normalize':
             transform['to_float'] = False
+            transform['mean'] = [0, 0, 0]
+            transform['std'] = [1, 1, 1]
 
     if transforms[0]['type'] != 'Lift':
         assert transforms[0]['type'] == 'LoadImageFromFile', \
@@ -299,14 +308,15 @@ def get_postprocess(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,
         task = Task.INSTANCE_SEGMENTATION
 
     component = task_map[task]['component']
-    if get_backend(deploy_cfg) == Backend.RKNN:
-        if 'YOLO' in task_processor.model_cfg.model.type:
-            bbox_head = task_processor.model_cfg.model.bbox_head
-            component = bbox_head.type
-            params['anchor_generator'] = bbox_head.get('anchor_generator',
-                                                       None)
-        else:  # default using base_dense_head
-            component = 'BaseDenseHead'
+    if task == Task.OBJECT_DETECTION:
+        if get_backend(deploy_cfg) == Backend.RKNN:
+            if 'YOLO' in task_processor.model_cfg.model.type:
+                bbox_head = task_processor.model_cfg.model.bbox_head
+                component = bbox_head.type
+                params['anchor_generator'] = bbox_head.get(
+                    'anchor_generator', None)
+            else:  # default using base_dense_head
+                component = 'BaseDenseHead'
 
     if task != Task.SUPER_RESOLUTION and task != Task.SEGMENTATION:
         if 'type' in params:
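Under this logic, the postprocess entry the SDK export writes for a YOLOv3 model would look roughly like the following (a sketch; the field values are illustrative, with `base_sizes` and `strides` taken from the mmdet YOLOv3 defaults):

```python
postprocess = dict(
    component='YOLOV3Head',  # bbox_head.type picked up from model_cfg
    params=dict(
        score_thr=0.05,
        nms=dict(type='nms', iou_threshold=0.45),
        anchor_generator=dict(
            type='YOLOAnchorGenerator',
            base_sizes=[[(116, 90), (156, 198), (373, 326)],
                        [(30, 61), (62, 45), (59, 119)],
                        [(10, 13), (16, 30), (33, 23)]],
            strides=[32, 16, 8])))
```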
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from abc import ABCMeta, abstractmethod
-from typing import Any, Dict, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
 
 import mmcv
 import numpy as np
@@ -8,7 +8,8 @@ import torch
 from torch.utils.data import DataLoader, Dataset
 
 from mmdeploy.utils import (get_backend_config, get_codebase,
-                            get_codebase_config, get_root_logger)
+                            get_codebase_config, get_rknn_quantization,
+                            get_root_logger)
 from mmdeploy.utils.dataset import is_can_sort_dataset, sort_dataset
 
 
@@ -139,18 +140,49 @@ class BaseTask(metaclass=ABCMeta):
         return self.codebase_class.single_gpu_test(model, data_loader, show,
                                                    out_dir, **kwargs)
 
+    @staticmethod
+    def update_test_pipeline(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config):
+        """Update preprocess pipeline.
+
+        Args:
+            model_cfg (str | mmcv.Config): Model config file.
+            deploy_cfg (str | mmcv.Config): Deployment config file.
+
+        Returns:
+            cfg (mmcv.Config): Updated model_cfg.
+        """
+        cfg = model_cfg.deepcopy()
+        if get_rknn_quantization(deploy_cfg):
+            pipelines = cfg.data.test.pipeline
+            for i, pipeline in enumerate(pipelines):
+                if pipeline['type'] == 'MultiScaleFlipAug':
+                    assert 'transforms' in pipeline
+                    for trans in pipeline['transforms']:
+                        if trans['type'] == 'Normalize':
+                            trans['mean'] = [0, 0, 0]
+                            trans['std'] = [1, 1, 1]
+                else:
+                    if pipeline['type'] == 'Normalize':
+                        pipeline['mean'] = [0, 0, 0]
+                        pipeline['std'] = [1, 1, 1]
+            cfg.data.test.pipeline = pipelines
+        return cfg
+
     @abstractmethod
     def create_input(self,
                      imgs: Union[str, np.ndarray, Sequence],
-                     input_shape: Sequence[int] = None,
+                     input_shape: Optional[Sequence[int]] = None,
+                     pipeline_updater: Optional[Callable] = None,
                      **kwargs) -> Tuple[Dict, torch.Tensor]:
         """Create input for model.
 
         Args:
             imgs (str | np.ndarray | Sequence): Input image(s),
                 accepted data types are `str`, `np.ndarray`.
-            input_shape (list[int]): Input shape of image in (width, height)
-                format, defaults to `None`.
+            input_shape (Sequence[int] | None): Input shape of image in
+                (width, height) format, defaults to `None`.
+            pipeline_updater (function | None): A function to get a new
+                pipeline.
 
         Returns:
             tuple: (data, img), meta information for the input image and input
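The effect on a typical test pipeline when RKNN int8 quantization is enabled (an illustrative before/after; the deep-copied config keeps every other step unchanged):

```python
# before update_test_pipeline(...)
dict(type='Normalize', mean=[123.675, 116.28, 103.53],
     std=[58.395, 57.12, 57.375], to_rgb=True)
# after: normalization becomes a no-op on the host, since onnx2rknn() moves
# the mean/std into RKNN's common_config and the device applies them instead
dict(type='Normalize', mean=[0, 0, 0], std=[1, 1, 1], to_rgb=True)
```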
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Any, Dict, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
 
 import mmcv
 import numpy as np
@@ -147,15 +147,18 @@ class VideoRecognition(BaseTask):
 
     def create_input(self,
                      imgs: Union[str, np.ndarray, Sequence],
-                     input_shape: Sequence[int] = None) \
-            -> Tuple[Dict, torch.Tensor]:
+                     input_shape: Optional[Sequence[int]] = None,
+                     pipeline_updater: Optional[Callable] = None,
+                     **kwargs) -> Tuple[Dict, torch.Tensor]:
         """Create input for recognizer.
 
         Args:
             imgs (Any): Input image(s), accepted data type are `str`,
                 `np.ndarray`, `torch.Tensor`.
-            input_shape (list[int]): A list of two integer in (width, height)
-                format specifying input shape. Defaults to `None`.
+            input_shape (Sequence[int] | None): Input shape of image in
+                (width, height) format, defaults to `None`.
+            pipeline_updater (function | None): A function to get a new
+                pipeline.
 
         Returns:
             tuple: (data, img), meta information for the input image and input.
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import logging
-from typing import Any, Dict, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
 
 import mmcv
 import numpy as np
@@ -8,8 +8,7 @@ import torch
 from torch.utils.data import Dataset
 
 from mmdeploy.codebase.base import BaseTask
-from mmdeploy.utils import Task, get_root_logger
-from mmdeploy.utils.config_utils import get_input_shape
+from mmdeploy.utils import Task, get_input_shape, get_root_logger
 from .mmclassification import MMCLS_TASK
 
 
@@ -35,6 +34,7 @@ def process_model_config(model_cfg: mmcv.Config,
     else:
         if cfg.data.test.pipeline[0]['type'] == 'LoadImageFromFile':
             cfg.data.test.pipeline.pop(0)
+
     # check whether input_shape is valid
     if input_shape is not None:
         if 'crop_size' in cfg.data.test.pipeline[2]:
@@ -111,15 +111,18 @@ class Classification(BaseTask):
 
     def create_input(self,
                      imgs: Union[str, np.ndarray, Sequence],
-                     input_shape: Optional[Sequence[int]] = None) \
+                     input_shape: Optional[Sequence[int]] = None,
+                     pipeline_updater: Optional[Callable] = None, **kwargs) \
             -> Tuple[Dict, torch.Tensor]:
         """Create input for classifier.
 
         Args:
             imgs (Union[str, np.ndarray, Sequence]): Input image(s),
                 accepted data type are `str`, `np.ndarray`, Sequence.
-            input_shape (list[int]): A list of two integer in (width, height)
-                format specifying input shape. Default: None.
+            input_shape (Sequence[int] | None): Input shape of image in
+                (width, height) format, defaults to `None`.
+            pipeline_updater (function | None): A function to get a new
+                pipeline.
 
         Returns:
             tuple: (data, img), meta information for the input image and input.
@@ -128,7 +131,10 @@ class Classification(BaseTask):
         from mmcv.parallel import collate, scatter
         if isinstance(imgs, (str, np.ndarray)):
             imgs = [imgs]
-        cfg = process_model_config(self.model_cfg, imgs, input_shape)
+        model_cfg = self.model_cfg
+        if pipeline_updater is not None:
+            model_cfg = pipeline_updater(self.deploy_cfg, model_cfg)
+        cfg = process_model_config(model_cfg, imgs, input_shape)
         data_list = []
         test_pipeline = Compose(cfg.data.test.pipeline)
         for img in imgs:
@@ -276,7 +282,8 @@ class Classification(BaseTask):
             dict: Composed of the preprocess information.
         """
         input_shape = get_input_shape(self.deploy_cfg)
-        cfg = process_model_config(self.model_cfg, [''], input_shape)
+        cfg = self.update_test_pipeline(self.deploy_cfg, self.model_cfg)
+        cfg = process_model_config(cfg, [''], input_shape)
         preprocess = cfg.data.test.pipeline
         return preprocess
 
@@ -144,25 +144,6 @@ class SDKEnd2EndModel(End2EndModel):
         return pred[np.argsort(pred[:, 0])][np.newaxis, :, 1]
 
 
-@__BACKEND_MODEL.register_module('rknn')
-class RKNNEnd2EndModel(End2EndModel):
-    """RKNN inference class, converts RKNN output to mmcls format."""
-
-    def forward_test(self, imgs: torch.Tensor, *args, **kwargs) -> \
-            List[np.ndarray]:
-        """The interface for forward test.
-
-        Args:
-            imgs (torch.Tensor): Input image(s) in [N x C x H x W] format.
-
-        Returns:
-            List[np.ndarray]: A list of classification prediction.
-        """
-        outputs = self.wrapper({self.input_name: imgs})
-        outputs = [out.numpy() for out in outputs]
-        return outputs
-
-
 def get_classes_from_config(model_cfg: Union[str, mmcv.Config]):
     """Get class name from config.
 
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Any, Dict, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
 
 import mmcv
 import numpy as np
@@ -7,8 +7,7 @@ import torch
 from mmcv.parallel import DataContainer
 from torch.utils.data import Dataset
 
-from mmdeploy.utils import Task
-from mmdeploy.utils.config_utils import get_input_shape, is_dynamic_shape
+from mmdeploy.utils import Task, get_input_shape, is_dynamic_shape
 from ...base import BaseTask
 from .mmdetection import MMDET_TASK
 
@@ -103,15 +102,18 @@ class ObjectDetection(BaseTask):
 
     def create_input(self,
                      imgs: Union[str, np.ndarray, Sequence],
-                     input_shape: Sequence[int] = None) \
+                     input_shape: Optional[Sequence[int]] = None,
+                     pipeline_updater: Optional[Callable] = None, **kwargs) \
             -> Tuple[Dict, torch.Tensor]:
         """Create input for detector.
 
         Args:
             imgs (str | np.ndarray): Input image(s), accepted data type are
                 `str`, `np.ndarray`.
-            input_shape (list[int]): A list of two integer in (width, height)
-                format specifying input shape. Defaults to `None`.
+            input_shape (Sequence[int] | None): Input shape of image in
+                (width, height) format, defaults to `None`.
+            pipeline_updater (function | None): A function to get a new
+                pipeline.
 
         Returns:
             tuple: (data, img), meta information for the input image and input.
@@ -121,7 +123,10 @@ class ObjectDetection(BaseTask):
         if isinstance(imgs, (str, np.ndarray)):
             imgs = [imgs]
         dynamic_flag = is_dynamic_shape(self.deploy_cfg)
-        cfg = process_model_config(self.model_cfg, imgs, input_shape)
+        model_cfg = self.model_cfg
+        if pipeline_updater is not None:
+            model_cfg = pipeline_updater(self.deploy_cfg, model_cfg)
+        cfg = process_model_config(model_cfg, imgs, input_shape)
         # Drop pad_to_square when static shape. Because static shape should
         # ensure the shape before input image.
         if not dynamic_flag:
@@ -291,7 +296,8 @@ class ObjectDetection(BaseTask):
             dict: Composed of the preprocess information.
         """
         input_shape = get_input_shape(self.deploy_cfg)
-        model_cfg = process_model_config(self.model_cfg, [''], input_shape)
+        cfg = self.update_test_pipeline(self.deploy_cfg, self.model_cfg)
+        model_cfg = process_model_config(cfg, [''], input_shape)
         preprocess = model_cfg.data.test.pipeline
         return preprocess
 
@@ -722,6 +722,7 @@ class RKNNModel(End2EndModel):
             class labels of shape [N, num_det].
         """
         outputs = self.wrapper({self.input_name: imgs})
+        outputs = [i for i in outputs.values()]
         ret = self._get_bboxes(outputs, img_metas)
         return ret
 
@@ -1,7 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from copy import deepcopy
 from os import path as osp
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 
 import mmcv
 import numpy as np
@@ -65,12 +65,19 @@ class MonocularDetection(BaseTask):
 
     def create_input(self,
                      imgs: Union[str, np.ndarray],
-                     input_shape: Sequence[int] = None) \
+                     input_shape: Optional[Sequence[int]] = None,
+                     pipeline_updater: Optional[Callable] = None, **kwargs) \
             -> Tuple[Dict, torch.Tensor]:
         """Create input for detector.
 
         Args:
             pcd (str): Input pcd file path.
-            input_shape (Sequence[int] | None): Input shape of image in
-                (width, height) format, defaults to `None`.
+            input_shape (Sequence[int] | None): Input shape of image in
+                (width, height) format, defaults to `None`.
+            pipeline_updater (function | None): A function to get a new
+                pipeline.
 
         Returns:
             tuple: (data, img), meta information for the input image and input.
         """
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import warnings
-from typing import Any, Dict, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
 
 import mmcv
 import numpy as np
@@ -117,13 +117,16 @@ class SuperResolution(BaseTask):
     def create_input(self,
                      imgs: Union[str, np.ndarray],
                      input_shape: Optional[Sequence[int]] = None,
+                     pipeline_updater: Optional[Callable] = None,
                      **kwargs) -> Tuple[Dict, torch.Tensor]:
         """Create input for editing processor.
 
         Args:
             imgs (str | np.ndarray): Input image(s).
-            input_shape (Sequence[int] | None): A list of two integer in
-                (width, height) format specifying input shape. Defaults to `None`.
+            input_shape (Sequence[int] | None): Input shape of image in
+                (width, height) format, defaults to `None`.
+            pipeline_updater (function | None): A function to get a new
+                pipeline.
 
         Returns:
             tuple: (data, img), meta information for the input image and input.
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Any, Dict, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
 
 import mmcv
 import numpy as np
@@ -108,15 +108,18 @@ class TextDetection(BaseTask):
 
     def create_input(self,
                      imgs: Union[str, np.ndarray, Sequence],
-                     input_shape: Sequence[int] = None) \
+                     input_shape: Optional[Sequence[int]] = None,
+                     pipeline_updater: Optional[Callable] = None, **kwargs) \
             -> Tuple[Dict, torch.Tensor]:
         """Create input for segmentor.
 
         Args:
             imgs (str | np.ndarray): Input image(s), accepted data type are
                 `str`, `np.ndarray`.
-            input_shape (list[int]): A list of two integer in (width, height)
-                format specifying input shape. Defaults to `None`.
+            input_shape (Sequence[int] | None): Input shape of image in
+                (width, height) format, defaults to `None`.
+            pipeline_updater (function | None): A function to get a new
+                pipeline.
 
         Returns:
             tuple: (data, img), meta information for the input image and input.
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Any, Dict, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
 
 import mmcv
 import numpy as np
@@ -114,15 +114,18 @@ class TextRecognition(BaseTask):
 
     def create_input(self,
                      imgs: Union[str, np.ndarray],
-                     input_shape: Sequence[int] = None) \
+                     input_shape: Optional[Sequence[int]] = None,
+                     pipeline_updater: Optional[Callable] = None, **kwargs) \
             -> Tuple[Dict, torch.Tensor]:
         """Create input for segmentor.
 
         Args:
             imgs (str | np.ndarray): Input image(s), accepted data type are
                 `str`, `np.ndarray`.
-            input_shape (list[int]): A list of two integer in (width, height)
-                format specifying input shape. Defaults to `None`.
+            input_shape (Sequence[int] | None): Input shape of image in
+                (width, height) format, defaults to `None`.
+            pipeline_updater (function | None): A function to get a new
+                pipeline.
 
         Returns:
             tuple: (data, img), meta information for the input image and input.
@@ -3,7 +3,7 @@
 import copy
 import logging
 import os
-from typing import Any, Dict, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
 
 import mmcv
 import numpy as np
@@ -130,15 +130,18 @@ class PoseDetection(BaseTask):
 
     def create_input(self,
                      imgs: Union[str, np.ndarray, Sequence],
-                     input_shape: Sequence[int] = None,
+                     input_shape: Optional[Sequence[int]] = None,
+                     pipeline_updater: Optional[Callable] = None,
                      **kwargs) -> Tuple[Dict, torch.Tensor]:
         """Create input for pose detection.
 
         Args:
             imgs (Any): Input image(s), accepted data type are ``str``,
                 ``np.ndarray``.
-            input_shape (list[int]): A list of two integer in (width, height)
-                format specifying input shape. Defaults to ``None``.
+            input_shape (Sequence[int] | None): Input shape of image in
+                (width, height) format, defaults to `None`.
+            pipeline_updater (function | None): A function to get a new
+                pipeline.
 
         Returns:
             tuple: (data, img), meta information for the input image and input.
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
-from typing import Any, Dict, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
 
 import mmcv
 import numpy as np
@@ -150,15 +150,18 @@ class RotatedDetection(BaseTask):
 
     def create_input(self,
                      imgs: Union[str, np.ndarray],
-                     input_shape: Sequence[int] = None) \
+                     input_shape: Optional[Sequence[int]] = None,
+                     pipeline_updater: Optional[Callable] = None, **kwargs) \
             -> Tuple[Dict, torch.Tensor]:
         """Create input for rotated object detection.
 
         Args:
             imgs (str | np.ndarray): Input image(s), accepted data type are
                 `str`, `np.ndarray`.
-            input_shape (list[int]): A list of two integer in (width, height)
-                format specifying input shape. Defaults to `None`.
+            input_shape (Sequence[int] | None): Input shape of image in
+                (width, height) format, defaults to `None`.
+            pipeline_updater (function | None): A function to get a new
+                pipeline.
 
         Returns:
             tuple: (data, img), meta information for the input image and input.
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Any, Dict, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
 
 import mmcv
 import numpy as np
@@ -33,6 +33,7 @@ def process_model_config(model_cfg: mmcv.Config,
     if isinstance(imgs[0], np.ndarray):
         # set loading pipeline type
         cfg.data.test.pipeline[0] = LoadImage()
+
     # for static exporting
     if input_shape is not None:
         for pipeline in cfg.data.test.pipeline[1:]:
@@ -107,15 +108,18 @@ class Segmentation(BaseTask):
 
     def create_input(self,
                      imgs: Union[str, np.ndarray, Sequence],
-                     input_shape: Sequence[int] = None) \
+                     input_shape: Optional[Sequence[int]] = None,
+                     pipeline_updater: Optional[Callable] = None, **kwargs) \
             -> Tuple[Dict, torch.Tensor]:
         """Create input for segmentor.
 
         Args:
             imgs (Any): Input image(s), accepted data type are `str`,
                 `np.ndarray`, `torch.Tensor`.
-            input_shape (list[int]): A list of two integer in (width, height)
-                format specifying input shape. Defaults to `None`.
+            input_shape (Sequence[int] | None): Input shape of image in
+                (width, height) format, defaults to `None`.
+            pipeline_updater (function | None): A function to get a new
+                pipeline.
 
         Returns:
             tuple: (data, img), meta information for the input image and input.
@@ -125,7 +129,10 @@ class Segmentation(BaseTask):
         if isinstance(imgs, (str, np.ndarray)):
             imgs = [imgs]
         imgs = [mmcv.imread(_) for _ in imgs]
-        cfg = process_model_config(self.model_cfg, imgs, input_shape)
+        model_cfg = self.model_cfg
+        if pipeline_updater is not None:
+            model_cfg = pipeline_updater(self.deploy_cfg, model_cfg)
+        cfg = process_model_config(model_cfg, imgs, input_shape)
         test_pipeline = Compose(cfg.data.test.pipeline)
         data_list = []
         for img in imgs:
@@ -260,7 +267,8 @@ class Segmentation(BaseTask):
         """
         input_shape = get_input_shape(self.deploy_cfg)
         load_from_file = self.model_cfg.data.test.pipeline[0]
-        model_cfg = process_model_config(self.model_cfg, [''], input_shape)
+        cfg = self.update_test_pipeline(self.deploy_cfg, self.model_cfg)
+        model_cfg = process_model_config(cfg, [''], input_shape)
         preprocess = model_cfg.data.test.pipeline
         preprocess[0] = load_from_file
         return preprocess
@@ -162,7 +162,9 @@ class RKNNModel(End2EndModel):
             List[np.ndarray]: A list of segmentation map.
         """
         outputs = self.wrapper({self.input_name: imgs})
-        outputs = [output.argmax(dim=1, keepdim=True) for output in outputs]
+        outputs = [
+            output.argmax(dim=1, keepdim=True) for output in outputs.values()
+        ]
         outputs = [out.detach().cpu().numpy() for out in outputs]
         return outputs
 
@@ -21,8 +21,9 @@ if importlib.util.find_spec('mmcv') is not None:
                                     get_codebase_config, get_common_config,
                                     get_dynamic_axes, get_input_shape,
                                     get_ir_config, get_model_inputs,
-                                    get_onnx_config, get_partition_config,
-                                    get_quantization_config, get_task_type,
+                                    get_normalization, get_onnx_config,
+                                    get_partition_config, get_quantization_config,
+                                    get_rknn_quantization, get_task_type,
                                     is_dynamic_batch, is_dynamic_shape, load_config)
 
     # yapf: enable
@@ -33,5 +34,6 @@ if importlib.util.find_spec('mmcv') is not None:
         'get_codebase_config', 'get_common_config', 'get_dynamic_axes',
         'get_input_shape', 'get_ir_config', 'get_model_inputs',
         'get_onnx_config', 'get_partition_config', 'get_quantization_config',
-        'get_task_type', 'is_dynamic_batch', 'is_dynamic_shape', 'load_config'
+        'get_task_type', 'is_dynamic_batch', 'is_dynamic_shape', 'load_config',
+        'get_rknn_quantization', 'get_normalization'
     ]
@@ -393,3 +393,40 @@ def get_dynamic_axes(
         raise KeyError('No names were found to define dynamic axes.')
     dynamic_axes = dict(zip(axes_names, dynamic_axes))
     return dynamic_axes
+
+
+def get_normalization(model_cfg: Union[str, mmcv.Config]):
+    """Get the Normalize transform from model config.
+
+    Args:
+        model_cfg (mmcv.Config): The content of config.
+
+    Returns:
+        dict: The Normalize transform.
+    """
+    model_cfg = load_config(model_cfg)[0]
+    pipelines = model_cfg.data.test.pipeline
+    for i, pipeline in enumerate(pipelines):
+        if pipeline['type'] == 'MultiScaleFlipAug':
+            assert 'transforms' in pipeline
+            for trans in pipeline['transforms']:
+                if trans['type'] == 'Normalize':
+                    return trans
+        else:
+            if pipeline['type'] == 'Normalize':
+                return pipeline
+
+
+def get_rknn_quantization(deploy_cfg: mmcv.Config):
+    """Get the flag of `do_quantization` for rknn backend.
+
+    Args:
+        deploy_cfg (mmcv.Config): The content of config.
+
+    Returns:
+        bool: Do quantization or not.
+    """
+    if get_backend(deploy_cfg) == Backend.RKNN:
+        return get_backend_config(
+            deploy_cfg)['quantization_config']['do_quantization']
+    return False
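How the two new helpers combine during conversion (a sketch; the config contents are illustrative, and `get_rknn_quantization` simply reads `quantization_config.do_quantization` when the backend is RKNN):

```python
import mmcv

from mmdeploy.utils import get_normalization, get_rknn_quantization

deploy_cfg = mmcv.Config(
    dict(
        backend_config=dict(
            type='rknn',
            quantization_config=dict(do_quantization=True, dataset=None))))
model_cfg = mmcv.Config(
    dict(
        data=dict(test=dict(pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='Normalize', mean=[123.675, 116.28, 103.53],
                 std=[58.395, 57.12, 57.375], to_rgb=True),
        ]))))

if get_rknn_quantization(deploy_cfg):
    norm = get_normalization(model_cfg)
    # onnx2rknn() forwards these to RKNN's common_config as
    # mean_values=[norm['mean']] and std_values=[norm['std']].
    print(norm['mean'], norm['std'])
```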
@@ -372,6 +372,7 @@ def main():
                 onnx_path,
                 output_path,
                 deploy_cfg_path,
+                model_cfg_path,
                 dataset_file=dataset_file)
 
             backend_files.append(output_path)