Add rv1126 yolov3 support to sdk (#1280)

* add yolov3 head to SDK

* add yolov5 head to SDK

* fix export-info and lint, add reverse check

* fix lint

* fix export info for yolo heads

* add output_names to partition_config

* fix typo

* config

* normalize config

* fix

* refactor config

* fix lint and doc

* c++ form

* resolve comments

* fix CI

* fix CI

* fix CI

* float strides anchors

* refine pipeline of rknn-int8

* config

* rename func

* refactor

* rknn wrapper dict and fix typo

* rknn wrapper output update, mmcls uses end2end type

* fix typo
AllentDan 2022-11-22 20:16:22 +08:00 committed by GitHub
parent 522fcc0635
commit 4dd4d4851b
35 changed files with 579 additions and 147 deletions

View File

@ -1,8 +1,6 @@
backend_config = dict(
type='rknn',
common_config=dict(
mean_values=None, # [[103.53, 116.28, 123.675]],
std_values=None, # [[57.375, 57.12, 58.395]],
target_platform='rv1126', # 'rk3588'
optimization_level=1),
quantization_config=dict(do_quantization=False, dataset=None))
quantization_config=dict(do_quantization=True, dataset=None))
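With quantization now on by default in the base backend config, task configs that need a float model override it explicitly. A minimal sketch of such an override, following the pattern of the classification and detection configs added below:

```python
# Hypothetical child config: opt back out of the new do_quantization=True default.
_base_ = ['../_base_/backends/rknn.py']
backend_config = dict(quantization_config=dict(do_quantization=False))
```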

View File

@ -0,0 +1,7 @@
_base_ = ['./classification_static.py', '../_base_/backends/rknn.py']
onnx_config = dict(input_shape=[224, 224])
codebase_config = dict(model_type='end2end')
backend_config = dict(
input_size_list=[[3, 224, 224]],
quantization_config=dict(do_quantization=False))

View File

@ -1,5 +1,5 @@
_base_ = ['./classification_static.py', '../_base_/backends/rknn.py']
onnx_config = dict(input_shape=[224, 224])
codebase_config = dict(model_type='rknn')
codebase_config = dict(model_type='end2end')
backend_config = dict(input_size_list=[[3, 224, 224]])

View File

@ -0,0 +1,34 @@
_base_ = ['../_base_/base_static.py', '../../_base_/backends/rknn.py']
onnx_config = dict(input_shape=[320, 320])
codebase_config = dict(model_type='rknn')
backend_config = dict(
input_size_list=[[3, 320, 320]],
quantization_config=dict(do_quantization=False))
# # yolov3, yolox for rknn-toolkit and rknn-toolkit2
# partition_config = dict(
# type='rknn', # the partition policy name
# apply_marks=True, # should always be set to True
# partition_cfg=[
# dict(
# save_file='model.onnx', # name to save the partitioned onnx
# start=['detector_forward:input'], # [mark_name:input, ...]
# end=['yolo_head:input'], # [mark_name:output, ...]
# output_names=[f'pred_maps.{i}' for i in range(3)]) # out names
# ])
# # retinanet, ssd, fsaf for rknn-toolkit2
# partition_config = dict(
# type='rknn', # the partition policy name
# apply_marks=True,
# partition_cfg=[
# dict(
# save_file='model.onnx',
# start='detector_forward:input',
# end=['BaseDenseHead:output'],
# output_names=[f'BaseDenseHead.cls.{i}' for i in range(5)] +
# [f'BaseDenseHead.loc.{i}' for i in range(5)])
# ])

View File

@ -6,7 +6,7 @@ codebase_config = dict(model_type='rknn')
backend_config = dict(input_size_list=[[3, 320, 320]])
# # yolov3, yolox
# # yolov3, yolox for rknn-toolkit and rknn-toolkit2
# partition_config = dict(
# type='rknn', # the partition policy name
# apply_marks=True, # should always be set to True
@ -14,10 +14,11 @@ backend_config = dict(input_size_list=[[3, 320, 320]])
# dict(
# save_file='model.onnx', # name to save the partitioned onnx
# start=['detector_forward:input'], # [mark_name:input, ...]
# end=['yolo_head:input']) # [mark_name:output, ...]
# end=['yolo_head:input'], # [mark_name:output, ...]
# output_names=[f'pred_maps.{i}' for i in range(3)]) # out names
# ])
# # retinanet, ssd, fsaf
# # retinanet, ssd, fsaf for rknn-toolkit2
# partition_config = dict(
# type='rknn', # the partition policy name
# apply_marks=True,
@ -25,5 +26,7 @@ backend_config = dict(input_size_list=[[3, 320, 320]])
# dict(
# save_file='model.onnx',
# start='detector_forward:input',
# end=['BaseDenseHead:output'])
# end=['BaseDenseHead:output'],
# output_names=[f'BaseDenseHead.cls.{i}' for i in range(5)] +
# [f'BaseDenseHead.loc.{i}' for i in range(5)])
# ])

View File

@ -8,5 +8,6 @@ partition_config = dict(
dict(
save_file='yolov3.onnx',
start=['detector_forward:input'],
end=['yolo_head:input'])
end=['yolo_head:input'],
output_names=[f'pred_maps.{i}' for i in range(3)])
])

View File

@ -0,0 +1,9 @@
_base_ = ['./segmentation_static.py', '../_base_/backends/rknn.py']
onnx_config = dict(input_shape=[320, 320])
codebase_config = dict(model_type='rknn')
backend_config = dict(
input_size_list=[[3, 320, 320]],
quantization_config=dict(do_quantization=False))

View File

@ -0,0 +1,228 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include "yolo_head.h"
#include <math.h>
#include <algorithm>
#include <numeric>
#include "mmdeploy/core/model.h"
#include "mmdeploy/core/utils/device_utils.h"
#include "mmdeploy/core/utils/formatter.h"
#include "utils.h"
namespace mmdeploy::mmdet {
YOLOHead::YOLOHead(const Value& cfg) : MMDetection(cfg) {
auto init = [&]() -> Result<void> {
auto model = cfg["context"]["model"].get<Model>();
if (cfg.contains("params")) {
nms_pre_ = cfg["params"].value("nms_pre", -1);
score_thr_ = cfg["params"].value("score_thr", 0.02f);
min_bbox_size_ = cfg["params"].value("min_bbox_size", 0);
iou_threshold_ = cfg["params"].contains("nms")
? cfg["params"]["nms"].value("iou_threshold", 0.45f)
: 0.45f;
if (cfg["params"].contains("anchor_generator")) {
from_value(cfg["params"]["anchor_generator"]["base_sizes"], anchors_);
from_value(cfg["params"]["anchor_generator"]["strides"], strides_);
}
}
return success();
};
init().value();
}
Result<Value> YOLOHead::operator()(const Value& prep_res, const Value& infer_res) {
MMDEPLOY_DEBUG("prep_res: {}\ninfer_res: {}", prep_res, infer_res);
try {
const Device kHost{0, 0};
std::vector<Tensor> pred_maps;
for (auto iter = infer_res.begin(); iter != infer_res.end(); iter++) {
auto pred_map = iter->get<Tensor>();
OUTCOME_TRY(auto _pred_map, MakeAvailableOnDevice(pred_map, kHost, stream()));
pred_maps.push_back(_pred_map);
}
OUTCOME_TRY(stream().Wait());
// reorder pred_maps according to strides and anchors, mainly for rknpu yolov3
if ((pred_maps.size() > 1) &&
!((strides_[0] < strides_[1]) ^ (pred_maps[0].shape(3) < pred_maps[1].shape(3)))) {
std::reverse(pred_maps.begin(), pred_maps.end());
}
OUTCOME_TRY(auto result, GetBBoxes(prep_res["img_metas"], pred_maps));
return to_value(result);
} catch (...) {
return Status(eFail);
}
}
inline static int clamp(float val, int min, int max) {
return val > min ? (val < max ? val : max) : min;
}
static float sigmoid(float x) { return 1.0 / (1.0 + expf(-x)); }
static float unsigmoid(float y) { return -1.0 * logf((1.0 / y) - 1.0); }
int YOLOHead::YOLOFeatDecode(const Tensor& feat_map, const std::vector<std::vector<float>>& anchor,
int grid_h, int grid_w, int height, int width, int stride,
std::vector<float>& boxes, std::vector<float>& obj_probs,
std::vector<int>& class_id, float threshold) const {
auto input = const_cast<float*>(feat_map.data<float>());
auto prop_box_size = feat_map.shape(1) / anchor.size();
const int kClasses = prop_box_size - 5;
int valid_count = 0;
int grid_len = grid_h * grid_w;
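// Applying the inverse sigmoid to the threshold once lets the loop below
// compare raw logits directly, avoiding a sigmoid per anchor and grid cell.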
float thres = unsigmoid(threshold);
for (int a = 0; a < anchor.size(); a++) {
for (int i = 0; i < grid_h; i++) {
for (int j = 0; j < grid_w; j++) {
float box_confidence = input[(prop_box_size * a + 4) * grid_len + i * grid_w + j];
if (box_confidence >= thres) {
int offset = (prop_box_size * a) * grid_len + i * grid_w + j;
float* in_ptr = input + offset;
float box_x = sigmoid(*in_ptr);
float box_y = sigmoid(in_ptr[grid_len]);
float box_w = in_ptr[2 * grid_len];
float box_h = in_ptr[3 * grid_len];
auto box = yolo_decode(box_x, box_y, box_w, box_h, stride, anchor, j, i, a);
box_x = box[0];
box_y = box[1];
box_w = box[2];
box_h = box[3];
box_x -= (box_w / 2.0);
box_y -= (box_h / 2.0);
boxes.push_back(box_x);
boxes.push_back(box_y);
boxes.push_back(box_x + box_w);
boxes.push_back(box_y + box_h);
float max_class_probs = in_ptr[5 * grid_len];
int max_class_id = 0;
for (int k = 1; k < kClasses; ++k) {
float prob = in_ptr[(5 + k) * grid_len];
if (prob > max_class_probs) {
max_class_id = k;
max_class_probs = prob;
}
}
obj_probs.push_back(sigmoid(max_class_probs) * sigmoid(box_confidence));
class_id.push_back(max_class_id);
valid_count++;
}
}
}
}
return valid_count;
}
Result<Detections> YOLOHead::GetBBoxes(const Value& prep_res,
const std::vector<Tensor>& pred_maps) const {
std::vector<float> filter_boxes;
std::vector<float> obj_probs;
std::vector<int> class_id;
int model_in_h = prep_res["img_shape"][1].get<int>();
int model_in_w = prep_res["img_shape"][2].get<int>();
for (int i = 0; i < pred_maps.size(); i++) {
int stride = strides_[i];
int grid_h = model_in_h / stride;
int grid_w = model_in_w / stride;
YOLOFeatDecode(pred_maps[i], anchors_[i], grid_h, grid_w, model_in_h, model_in_w, stride,
filter_boxes, obj_probs, class_id, score_thr_);
}
std::vector<int> indexArray;
for (int i = 0; i < obj_probs.size(); ++i) {
indexArray.push_back(i);
}
Sort(obj_probs, class_id, indexArray);
Tensor dets(TensorDesc{Device{0, 0}, DataType::kFLOAT,
TensorShape{int(filter_boxes.size() / 4), 4}, "dets"});
std::copy(filter_boxes.begin(), filter_boxes.end(), dets.data<float>());
NMS(dets, iou_threshold_, indexArray);
Detections objs;
std::vector<float> scale_factor;
if (prep_res.contains("scale_factor")) {
from_value(prep_res["scale_factor"], scale_factor);
} else {
scale_factor = {1.f, 1.f, 1.f, 1.f};
}
int ori_width = prep_res["ori_shape"][2].get<int>();
int ori_height = prep_res["ori_shape"][1].get<int>();
auto det_ptr = dets.data<float>();
for (int i = 0; i < indexArray.size(); ++i) {
if (indexArray[i] == -1) {
continue;
}
int j = indexArray[i];
auto x1 = clamp(det_ptr[j * 4 + 0], 0, model_in_w);
auto y1 = clamp(det_ptr[j * 4 + 1], 0, model_in_h);
auto x2 = clamp(det_ptr[j * 4 + 2], 0, model_in_w);
auto y2 = clamp(det_ptr[j * 4 + 3], 0, model_in_h);
int label_id = class_id[i];
float score = obj_probs[i];
MMDEPLOY_DEBUG("{}-th box: ({}, {}, {}, {}), {}, {}", i, x1, y1, x2, y2, label_id, score);
auto rect = MapToOriginImage(x1, y1, x2, y2, scale_factor.data(), 0, 0, ori_width, ori_height);
if (rect[2] - rect[0] < min_bbox_size_ || rect[3] - rect[1] < min_bbox_size_) {
MMDEPLOY_DEBUG("ignore small bbox with width '{}' and height '{}", rect[2] - rect[0],
rect[3] - rect[1]);
continue;
}
Detection det{};
det.index = i;
det.label_id = label_id;
det.score = score;
det.bbox = rect;
objs.push_back(std::move(det));
}
return objs;
}
Result<Value> YOLOV3Head::operator()(const Value& prep_res, const Value& infer_res) {
return YOLOHead::operator()(prep_res, infer_res);
}
std::array<float, 4> YOLOV3Head::yolo_decode(float box_x, float box_y, float box_w, float box_h,
float stride,
const std::vector<std::vector<float>>& anchor, int j,
int i, int a) const {
box_x = (box_x + j) * stride;
box_y = (box_y + i) * stride;
box_w = expf(box_w) * anchor[a][0];
box_h = expf(box_h) * anchor[a][1];
return std::array<float, 4>{box_x, box_y, box_w, box_h};
}
Result<Value> YOLOV5Head::operator()(const Value& prep_res, const Value& infer_res) {
return YOLOHead::operator()(prep_res, infer_res);
}
std::array<float, 4> YOLOV5Head::yolo_decode(float box_x, float box_y, float box_w, float box_h,
float stride,
const std::vector<std::vector<float>>& anchor, int j,
int i, int a) const {
box_x = box_x * 2 - 0.5;
box_y = box_y * 2 - 0.5;
box_w = box_w * 2 - 0.5;
box_h = box_h * 2 - 0.5;
box_x = (box_x + j) * stride;
box_y = (box_y + i) * stride;
box_w = box_w * box_w * anchor[a][0];
box_h = box_h * box_h * anchor[a][1];
return std::array<float, 4>{box_x, box_y, box_w, box_h};
}
MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMDetection, YOLOV3Head);
MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMDetection, YOLOV5Head);
} // namespace mmdeploy::mmdet
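For readers following the decode above, here is a minimal NumPy sketch of the YOLOv3 branch (an illustration, not code from this PR). It assumes one raw feature map of shape `(A*(5+C), H, W)` with `A` anchors and `C` classes, mirroring `YOLOFeatDecode` plus `YOLOV3Head::yolo_decode`; it is simplified in that it thresholds the final score rather than the raw objectness:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def decode_yolov3(feat, anchors, stride, score_thr):
    """feat: (A*(5+C), H, W) raw map; anchors: (A, 2) anchor sizes in pixels."""
    anchors = np.asarray(anchors, dtype=np.float32)
    A = len(anchors)
    C = feat.shape[0] // A - 5
    H, W = feat.shape[1], feat.shape[2]
    feat = feat.reshape(A, 5 + C, H, W)
    ys, xs = np.mgrid[0:H, 0:W]
    cx = (sigmoid(feat[:, 0]) + xs) * stride            # box centers, as in yolo_decode
    cy = (sigmoid(feat[:, 1]) + ys) * stride
    w = np.exp(feat[:, 2]) * anchors[:, 0, None, None]  # exp(tw) * anchor width
    h = np.exp(feat[:, 3]) * anchors[:, 1, None, None]
    conf = sigmoid(feat[:, 4])                          # objectness
    cls = sigmoid(feat[:, 5:]).max(axis=1)              # best class probability
    scores = conf * cls                                 # same product as GetBBoxes
    keep = scores > score_thr
    boxes = np.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], axis=-1)
    return boxes[keep], scores[keep]
```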

View File

@ -0,0 +1,53 @@
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef MMDEPLOY_CODEBASE_MMDET_YOLO_HEAD_H_
#define MMDEPLOY_CODEBASE_MMDET_YOLO_HEAD_H_
#include "mmdeploy/codebase/mmdet/mmdet.h"
#include "mmdeploy/core/tensor.h"
namespace mmdeploy::mmdet {
class YOLOHead : public MMDetection {
public:
explicit YOLOHead(const Value& cfg);
Result<Value> operator()(const Value& prep_res, const Value& infer_res);
int YOLOFeatDecode(const Tensor& feat_map, const std::vector<std::vector<float>>& anchor,
int grid_h, int grid_w, int height, int width, int stride,
std::vector<float>& boxes, std::vector<float>& obj_probs,
std::vector<int>& class_id, float threshold) const;
Result<Detections> GetBBoxes(const Value& prep_res, const std::vector<Tensor>& pred_maps) const;
virtual std::array<float, 4> yolo_decode(float box_x, float box_y, float box_w, float box_h,
float stride,
const std::vector<std::vector<float>>& anchor, int j,
int i, int a) const = 0;
private:
float score_thr_{0.4f};
int nms_pre_{1000};
float iou_threshold_{0.45f};
int min_bbox_size_{0};
std::vector<std::vector<std::vector<float>>> anchors_;
std::vector<float> strides_;
};
class YOLOV3Head : public YOLOHead {
public:
using YOLOHead::YOLOHead;
Result<Value> operator()(const Value& prep_res, const Value& infer_res);
std::array<float, 4> yolo_decode(float box_x, float box_y, float box_w, float box_h, float stride,
const std::vector<std::vector<float>>& anchor, int j, int i,
int a) const override;
};
class YOLOV5Head : public YOLOHead {
public:
using YOLOHead::YOLOHead;
Result<Value> operator()(const Value& prep_res, const Value& infer_res);
std::array<float, 4> yolo_decode(float box_x, float box_y, float box_w, float box_h, float stride,
const std::vector<std::vector<float>>& anchor, int j, int i,
int a) const override;
};
} // namespace mmdeploy::mmdet
#endif // MMDEPLOY_CODEBASE_MMDET_YOLO_HEAD_H_

View File

@ -138,30 +138,12 @@ label: 65, score: 0.95
## Troubleshooting
- Quantization fails.
Empirically, RKNN requires unnormalized inputs when `do_quantization` is set to `True`. Please modify the `Normalize` settings in the `model_cfg` from
```python
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
```
to
```python
img_norm_cfg = dict(
mean=[0, 0, 0], std=[1, 1, 1], to_rgb=True)
```
Besides, the `mean_values` and `std_values` of the deploy_cfg should be replaced with the original normalization settings of the `model_cfg`, i.e. `mean_values=[[103.53, 116.28, 123.675]]` and `std_values=[[57.375, 57.12, 58.395]]`.
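For instance, with the ImageNet statistics above, the backend section of the deploy config would carry the normalization (a sketch following the `rknn.py` base config shown at the top of this PR):
```python
backend_config = dict(
    type='rknn',
    common_config=dict(
        mean_values=[[103.53, 116.28, 123.675]],  # BGR order, as in the model_cfg
        std_values=[[57.375, 57.12, 58.395]],
        target_platform='rv1126'),
    quantization_config=dict(do_quantization=True, dataset=None))
```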
- MMDet models.
YOLOV3 & YOLOX: you may paste the following partition configuration into [detection_rknn_static.py](https://github.com/open-mmlab/mmdeploy/blob/master/configs/mmdet/detection/detection_rknn_static.py):
```python
# yolov3, yolox
# yolov3, yolox for rknn-toolkit and rknn-toolkit2
partition_config = dict(
type='rknn', # the partition policy name
apply_marks=True, # should always be set to True
@ -169,7 +151,8 @@ label: 65, score: 0.95
dict(
save_file='model.onnx', # name to save the partitioned onnx
start=['detector_forward:input'], # [mark_name:input, ...]
end=['yolo_head:input']) # [mark_name:output, ...]
end=['yolo_head:input'], # [mark_name:output, ...]
output_names=[f'pred_maps.{i}' for i in range(3)]) # output names
])
```
@ -184,7 +167,9 @@ label: 65, score: 0.95
dict(
save_file='model.onnx',
start='detector_forward:input',
end=['BaseDenseHead:output'])
end=['BaseDenseHead:output'],
output_names=[f'BaseDenseHead.cls.{i}' for i in range(5)] +
[f'BaseDenseHead.loc.{i}' for i in range(5)])
])
```

View File

@ -66,7 +66,8 @@ partition_config = dict(
dict(
save_file='yolov3.onnx', # filename to save the partitioned onnx model
start=['detector_forward:input'], # [mark_name:input/output, ...]
end=['yolo_head:input']) # [mark_name:input/output, ...]
end=['yolo_head:input'], # [mark_name:input/output, ...]
output_names=[f'pred_maps.{i}' for i in range(3)]) # output names
])
```

View File

@ -105,7 +105,7 @@ python tools/deploy.py \
Write the following partition configuration into [detection_rknn_static.py](https://github.com/open-mmlab/mmdeploy/blob/master/configs/mmdet/detection/detection_rknn_static.py):
```python
# yolov3, yolox
# yolov3, yolox for rknn-toolkit and rknn-toolkit2
partition_config = dict(
type='rknn', # the partition policy name
apply_marks=True, # should always be set to True
@ -113,7 +113,8 @@ partition_config = dict(
dict(
save_file='model.onnx', # name to save the partitioned onnx
start=['detector_forward:input'], # [mark_name:input, ...]
end=['yolo_head:input']) # [mark_name:output, ...]
end=['yolo_head:input'], # [mark_name:output, ...]
output_names=[f'pred_maps.{i}' for i in range(3)]) # output names
])
```
@ -143,7 +144,9 @@ partition_config = dict(
dict(
save_file='model.onnx',
start='detector_forward:input',
end=['BaseDenseHead:output'])
end=['BaseDenseHead:output'],
output_names=[f'BaseDenseHead.cls.{i}' for i in range(5)] +
[f'BaseDenseHead.loc.{i}' for i in range(5)])
])
```
@ -168,24 +171,6 @@ backend_config = dict(
### Troubleshooting
- Quantization fails.
Empirically, RKNN requires unnormalized inputs when `do_quantization` is set to `True`. Please modify the `Normalize` settings in the `model_cfg` from
```python
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
```
to
```python
img_norm_cfg = dict(
mean=[0, 0, 0], std=[1, 1, 1], to_rgb=True)
```
In addition, the `mean_values` and `std_values` of the deploy_cfg should be replaced with the original normalization settings of the `model_cfg`, i.e. `mean_values=[[103.53, 116.28, 123.675]]` and `std_values=[[57.375, 57.12, 58.395]]`.
- The SDK only supports int8 RKNN models, which requires setting `do_quantization=True` when converting the model.
## Model Inference

View File

@ -64,7 +64,8 @@ partition_config = dict(
dict(
save_file='yolov3.onnx', # filename to save the partitioned onnx model
start=['detector_forward:input'], # [mark_name:input/output, ...]
end=['yolo_head:input']) # [mark_name:input/output, ...]
end=['yolo_head:input'], # [mark_name:input/output, ...]
output_names=[f'pred_maps.{i}' for i in range(3)]) # output names
])
```

View File

@ -64,10 +64,12 @@ def visualize_model(model_cfg: Union[str, mmcv.Config],
if backend == Backend.PYTORCH:
model = task_processor.init_pytorch_model(model[0])
model_inputs, _ = task_processor.create_input(img, input_shape)
else:
model = task_processor.init_backend_model(model, **kwargs)
model_inputs, _ = task_processor.create_input(
img, input_shape, task_processor.update_test_pipeline)
model_inputs, _ = task_processor.create_input(img, input_shape)
with torch.no_grad():
result = task_processor.run_inference(model, model_inputs)[0]

View File

@ -4,15 +4,17 @@ from typing import Optional, Union
import mmcv
from rknn.api import RKNN
from mmdeploy.utils import (get_common_config, get_onnx_config,
from mmdeploy.utils import (get_backend_config, get_common_config,
get_normalization, get_onnx_config,
get_partition_config, get_quantization_config,
get_root_logger, load_config)
from mmdeploy.utils.config_utils import get_backend_config
get_rknn_quantization, get_root_logger,
load_config)
def onnx2rknn(onnx_model: str,
output_path: str,
deploy_cfg: Union[str, mmcv.Config],
model_cfg: Union[str, mmcv.Config],
dataset_file: Optional[str] = None,
**kwargs):
"""Convert ONNX to RKNN.
@ -40,6 +42,14 @@ def onnx2rknn(onnx_model: str,
output_names = onnx_params.get('output_names', None)
input_size_list = get_backend_config(deploy_cfg).get(
'input_size_list', None)
# update norm value
if get_rknn_quantization(deploy_cfg) is True:
transform = get_normalization(model_cfg)
common_params.update(
dict(
mean_values=[transform['mean']],
std_values=[transform['std']]))
# update output_names for partition models
if get_partition_config(deploy_cfg) is not None:
import onnx
@ -62,7 +72,7 @@ def onnx2rknn(onnx_model: str,
if dataset_cfg is None and dataset_file is None:
do_quantization = False
logger.warning('no dataset passed in, quantization is skipped')
if dataset_file is None:
if dataset_cfg is not None:
dataset_file = dataset_cfg
ret = rknn.build(do_quantization=do_quantization, dataset=dataset_file)
if ret != 0:
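A hedged usage sketch of the updated signature (the file paths and module path are illustrative, not taken from this PR):
```python
from mmdeploy.backend.rknn.onnx2rknn import onnx2rknn  # module path assumed

onnx2rknn(
    'work_dir/model.onnx',   # partitioned ONNX produced by tools/deploy.py
    'work_dir/model.rknn',
    deploy_cfg='configs/mmdet/detection/detection_rknn_static.py',
    model_cfg='yolov3_model_cfg.py',  # hypothetical model config
    dataset_file='calib_images.txt')  # calibration list; quantization is skipped if absent
```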

View File

@ -48,7 +48,7 @@ class RKNNWrapper(BaseWrapper):
super().__init__(output_names)
def forward(self, inputs: Dict[str,
torch.Tensor]) -> Sequence[torch.Tensor]:
torch.Tensor]) -> Dict[str, torch.Tensor]:
"""Run forward inference. Note that the shape of the input tensor is
NxCxHxW while RKNN only accepts the numpy inputs of NxHxWxC. There is a
permute operation outside RKNN inference.
@ -57,11 +57,14 @@ class RKNNWrapper(BaseWrapper):
inputs (Dict[str, torch.Tensor]): Input name and tensor pairs.
Return:
Sequence[torch.Tensor]: The output tensors.
Dict[str, torch.Tensor]: The output tensors.
"""
rknn_out = self.__rknnnn_execute(
[i.permute(0, 2, 3, 1).cpu().numpy() for i in inputs.values()])
return [torch.from_numpy(out) for out in rknn_out]
rknn_out = [torch.from_numpy(out) for out in rknn_out]
if self.output_names is not None:
return dict(zip(self.output_names, rknn_out))
return {'#' + str(i): x for i, x in enumerate(rknn_out)}
@TimeCounter.count_time(Backend.RKNN.value)
def __rknnnn_execute(self, inputs: Sequence[np.array]):

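Since `forward` now returns a name-to-tensor dict, callers index outputs by name; a short sketch assuming the partition output names used throughout this PR:
```python
outputs = wrapper({'input': img})  # Dict[str, torch.Tensor]
pred_maps = [outputs[f'pred_maps.{i}'] for i in range(3)]  # names from partition_config
```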
View File

@ -7,7 +7,8 @@ import mmcv
from mmdeploy.apis import build_task_processor
from mmdeploy.utils import (Backend, Task, get_backend, get_codebase,
get_common_config, get_ir_config, get_root_logger,
get_common_config, get_ir_config,
get_partition_config, get_root_logger,
get_task_type, is_dynamic_batch, load_config)
from mmdeploy.utils.constants import SDK_TASK_MAP as task_map
from .tracer import add_transform_tag, get_transform_static
@ -94,6 +95,9 @@ def get_models(deploy_cfg: Union[str, mmcv.Config],
name, _ = get_model_name_customs(deploy_cfg, model_cfg, work_dir, device)
precision = 'FP32'
ir_name = get_ir_config(deploy_cfg)['save_file']
if get_partition_config(deploy_cfg) is not None:
ir_name = get_partition_config(
deploy_cfg)['partition_cfg'][0]['save_file']
net = ir_name
weights = ''
backend = get_backend(deploy_cfg=deploy_cfg)
@ -185,6 +189,9 @@ def get_inference_info(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,
backend = get_backend(deploy_cfg=deploy_cfg)
if backend in (Backend.TORCHSCRIPT, Backend.RKNN):
output_names = ir_config.get('output_names', None)
if get_partition_config(deploy_cfg) is not None:
output_names = get_partition_config(
deploy_cfg)['partition_cfg'][0]['output_names']
input_map = dict(img='#0')
output_map = {name: f'#{i}' for i, name in enumerate(output_names)}
else:
@ -258,6 +265,8 @@ def get_preprocess(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,
for transform in transforms:
if transform['type'] == 'Normalize':
transform['to_float'] = False
transform['mean'] = [0, 0, 0]
transform['std'] = [1, 1, 1]
if transforms[0]['type'] != 'Lift':
assert transforms[0]['type'] == 'LoadImageFromFile', \
@ -299,14 +308,15 @@ def get_postprocess(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,
task = Task.INSTANCE_SEGMENTATION
component = task_map[task]['component']
if get_backend(deploy_cfg) == Backend.RKNN:
if 'YOLO' in task_processor.model_cfg.model.type:
bbox_head = task_processor.model_cfg.model.bbox_head
component = bbox_head.type
params['anchor_generator'] = bbox_head.get('anchor_generator',
None)
else: # default using base_dense_head
component = 'BaseDenseHead'
if task == Task.OBJECT_DETECTION:
if get_backend(deploy_cfg) == Backend.RKNN:
if 'YOLO' in task_processor.model_cfg.model.type:
bbox_head = task_processor.model_cfg.model.bbox_head
component = bbox_head.type
params['anchor_generator'] = bbox_head.get(
'anchor_generator', None)
else: # default using base_dense_head
component = 'BaseDenseHead'
if task != Task.SUPER_RESOLUTION and task != Task.SEGMENTATION:
if 'type' in params:

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import Any, Dict, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
import mmcv
import numpy as np
@ -8,7 +8,8 @@ import torch
from torch.utils.data import DataLoader, Dataset
from mmdeploy.utils import (get_backend_config, get_codebase,
get_codebase_config, get_root_logger)
get_codebase_config, get_rknn_quantization,
get_root_logger)
from mmdeploy.utils.dataset import is_can_sort_dataset, sort_dataset
@ -139,18 +140,49 @@ class BaseTask(metaclass=ABCMeta):
return self.codebase_class.single_gpu_test(model, data_loader, show,
out_dir, **kwargs)
@staticmethod
def update_test_pipeline(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config):
"""Update preprocess pipeline.
Args:
model_cfg (str | mmcv.Config): Model config file.
deploy_cfg (str | mmcv.Config): Deployment config file.
Returns:
cfg (mmcv.Config): Updated model_cfg.
"""
cfg = model_cfg.deepcopy()
if get_rknn_quantization(deploy_cfg):
pipelines = cfg.data.test.pipeline
for i, pipeline in enumerate(pipelines):
if pipeline['type'] == 'MultiScaleFlipAug':
assert 'transforms' in pipeline
for trans in pipeline['transforms']:
if trans['type'] == 'Normalize':
trans['mean'] = [0, 0, 0]
trans['std'] = [1, 1, 1]
else:
if pipeline['type'] == 'Normalize':
pipeline['mean'] = [0, 0, 0]
pipeline['std'] = [1, 1, 1]
cfg.data.test.pipeline = pipelines
return cfg
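A quick sketch of the effect, assuming a deploy config with RKNN quantization enabled:
```python
cfg = task.update_test_pipeline(deploy_cfg, model_cfg)
# Every Normalize step in cfg.data.test.pipeline now has mean=[0, 0, 0] and
# std=[1, 1, 1], so quantized RKNN models receive unnormalized images.
```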
@abstractmethod
def create_input(self,
imgs: Union[str, np.ndarray, Sequence],
input_shape: Sequence[int] = None,
input_shape: Optional[Sequence[int]] = None,
pipeline_updater: Optional[Callable] = None,
**kwargs) -> Tuple[Dict, torch.Tensor]:
"""Create input for model.
Args:
imgs (str | np.ndarray | Sequence): Input image(s),
accepted data types are `str`, `np.ndarray`.
input_shape (list[int]): Input shape of image in (width, height)
format, defaults to `None`.
input_shape (Sequence[int] | None): Input shape of image in
(width, height) format, defaults to `None`.
pipeline_updater (function | None): A function to get a new
pipeline.
Returns:
tuple: (data, img), meta information for the input image and input

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Dict, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
import mmcv
import numpy as np
@ -147,15 +147,18 @@ class VideoRecognition(BaseTask):
def create_input(self,
imgs: Union[str, np.ndarray, Sequence],
input_shape: Sequence[int] = None) \
-> Tuple[Dict, torch.Tensor]:
input_shape: Optional[Sequence[int]] = None,
pipeline_updater: Optional[Callable] = None,
**kwargs) -> Tuple[Dict, torch.Tensor]:
"""Create input for recognizer.
Args:
imgs (Any): Input image(s), accepted data types are `str`,
`np.ndarray`, `torch.Tensor`.
input_shape (list[int]): A list of two integer in (width, height)
format specifying input shape. Defaults to `None`.
input_shape (Sequence[int] | None): Input shape of image in
(width, height) format, defaults to `None`.
pipeline_updater (function | None): A function to get a new
pipeline.
Returns:
tuple: (data, img), meta information for the input image and input.

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import Any, Dict, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
import mmcv
import numpy as np
@ -8,8 +8,7 @@ import torch
from torch.utils.data import Dataset
from mmdeploy.codebase.base import BaseTask
from mmdeploy.utils import Task, get_root_logger
from mmdeploy.utils.config_utils import get_input_shape
from mmdeploy.utils import Task, get_input_shape, get_root_logger
from .mmclassification import MMCLS_TASK
@ -35,6 +34,7 @@ def process_model_config(model_cfg: mmcv.Config,
else:
if cfg.data.test.pipeline[0]['type'] == 'LoadImageFromFile':
cfg.data.test.pipeline.pop(0)
# check whether input_shape is valid
if input_shape is not None:
if 'crop_size' in cfg.data.test.pipeline[2]:
@ -111,15 +111,18 @@ class Classification(BaseTask):
def create_input(self,
imgs: Union[str, np.ndarray, Sequence],
input_shape: Optional[Sequence[int]] = None) \
input_shape: Optional[Sequence[int]] = None,
pipeline_updater: Optional[Callable] = None, **kwargs) \
-> Tuple[Dict, torch.Tensor]:
"""Create input for classifier.
Args:
imgs (Union[str, np.ndarray, Sequence]): Input image(s),
accepted data types are `str`, `np.ndarray`, Sequence.
input_shape (list[int]): A list of two integer in (width, height)
format specifying input shape. Default: None.
input_shape (Sequence[int] | None): Input shape of image in
(width, height) format, defaults to `None`.
pipeline_updater (function | None): A function to get a new
pipeline.
Returns:
tuple: (data, img), meta information for the input image and input.
@ -128,7 +131,10 @@ class Classification(BaseTask):
from mmcv.parallel import collate, scatter
if isinstance(imgs, (str, np.ndarray)):
imgs = [imgs]
cfg = process_model_config(self.model_cfg, imgs, input_shape)
model_cfg = self.model_cfg
if pipeline_updater is not None:
model_cfg = pipeline_updater(self.deploy_cfg, model_cfg)
cfg = process_model_config(model_cfg, imgs, input_shape)
data_list = []
test_pipeline = Compose(cfg.data.test.pipeline)
for img in imgs:
@ -276,7 +282,8 @@ class Classification(BaseTask):
dict: Composed of the preprocess information.
"""
input_shape = get_input_shape(self.deploy_cfg)
cfg = process_model_config(self.model_cfg, [''], input_shape)
cfg = self.update_test_pipeline(self.deploy_cfg, self.model_cfg)
cfg = process_model_config(cfg, [''], input_shape)
preprocess = cfg.data.test.pipeline
return preprocess

View File

@ -144,25 +144,6 @@ class SDKEnd2EndModel(End2EndModel):
return pred[np.argsort(pred[:, 0])][np.newaxis, :, 1]
@__BACKEND_MODEL.register_module('rknn')
class RKNNEnd2EndModel(End2EndModel):
"""RKNN inference class, converts RKNN output to mmcls format."""
def forward_test(self, imgs: torch.Tensor, *args, **kwargs) -> \
List[np.ndarray]:
"""The interface for forward test.
Args:
imgs (torch.Tensor): Input image(s) in [N x C x H x W] format.
Returns:
List[np.ndarray]: A list of classification prediction.
"""
outputs = self.wrapper({self.input_name: imgs})
outputs = [out.numpy() for out in outputs]
return outputs
def get_classes_from_config(model_cfg: Union[str, mmcv.Config]):
"""Get class name from config.

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Dict, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
import mmcv
import numpy as np
@ -7,8 +7,7 @@ import torch
from mmcv.parallel import DataContainer
from torch.utils.data import Dataset
from mmdeploy.utils import Task
from mmdeploy.utils.config_utils import get_input_shape, is_dynamic_shape
from mmdeploy.utils import Task, get_input_shape, is_dynamic_shape
from ...base import BaseTask
from .mmdetection import MMDET_TASK
@ -103,15 +102,18 @@ class ObjectDetection(BaseTask):
def create_input(self,
imgs: Union[str, np.ndarray, Sequence],
input_shape: Sequence[int] = None) \
input_shape: Optional[Sequence[int]] = None,
pipeline_updater: Optional[Callable] = None, **kwargs) \
-> Tuple[Dict, torch.Tensor]:
"""Create input for detector.
Args:
imgs (str|np.ndarray): Input image(s), accepted data types are
`str`, `np.ndarray`.
input_shape (list[int]): A list of two integer in (width, height)
format specifying input shape. Defaults to `None`.
input_shape (Sequence[int] | None): Input shape of image in
(width, height) format, defaults to `None`.
pipeline_updater (function | None): A function to get a new
pipeline.
Returns:
tuple: (data, img), meta information for the input image and input.
@ -121,7 +123,10 @@ class ObjectDetection(BaseTask):
if isinstance(imgs, (str, np.ndarray)):
imgs = [imgs]
dynamic_flag = is_dynamic_shape(self.deploy_cfg)
cfg = process_model_config(self.model_cfg, imgs, input_shape)
model_cfg = self.model_cfg
if pipeline_updater is not None:
model_cfg = pipeline_updater(self.deploy_cfg, model_cfg)
cfg = process_model_config(model_cfg, imgs, input_shape)
# Drop pad_to_square when static shape. Because static shape should
# ensure the shape before input image.
if not dynamic_flag:
@ -291,7 +296,8 @@ class ObjectDetection(BaseTask):
dict: Composed of the preprocess information.
"""
input_shape = get_input_shape(self.deploy_cfg)
model_cfg = process_model_config(self.model_cfg, [''], input_shape)
cfg = self.update_test_pipeline(self.deploy_cfg, self.model_cfg)
model_cfg = process_model_config(cfg, [''], input_shape)
preprocess = model_cfg.data.test.pipeline
return preprocess

View File

@ -722,6 +722,7 @@ class RKNNModel(End2EndModel):
class labels of shape [N, num_det].
"""
outputs = self.wrapper({self.input_name: imgs})
outputs = [i for i in outputs.values()]
ret = self._get_bboxes(outputs, img_metas)
return ret

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from os import path as osp
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
import mmcv
import numpy as np
@ -65,12 +65,19 @@ class MonocularDetection(BaseTask):
def create_input(self,
imgs: Union[str, np.ndarray],
input_shape: Sequence[int] = None) \
input_shape: Optional[Sequence[int]] = None,
pipeline_updater: Optional[Callable] = None, **kwargs) \
-> Tuple[Dict, torch.Tensor]:
"""Create input for detector.
Args:
pcd (str): Input pcd file path.
input_shape (Sequence[int] | None): Input shape of image in
(width, height) format, defaults to `None`.
input_shape (Sequence[int] | None): Input shape of image in
(width, height) format, defaults to `None`.
pipeline_updater (function | None): A function to get a new
pipeline.
Returns:
tuple: (data, img), meta information for the input image and input.
"""

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Any, Dict, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
import mmcv
import numpy as np
@ -117,13 +117,16 @@ class SuperResolution(BaseTask):
def create_input(self,
imgs: Union[str, np.ndarray],
input_shape: Optional[Sequence[int]] = None,
pipeline_updater: Optional[Callable] = None,
**kwargs) -> Tuple[Dict, torch.Tensor]:
"""Create input for editing processor.
Args:
imgs (str | np.ndarray): Input image(s).
input_shape (Sequence[int] | None): A list of two integer in
(width, height) format specifying input shape. Defaults to `None`.
input_shape (Sequence[int] | None): Input shape of image in
(width, height) format, defaults to `None`.
pipeline_updater (function | None): A function to get a new
pipeline.
Returns:
tuple: (data, img), meta information for the input image and input.

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Dict, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
import mmcv
import numpy as np
@ -108,15 +108,18 @@ class TextDetection(BaseTask):
def create_input(self,
imgs: Union[str, np.ndarray, Sequence],
input_shape: Sequence[int] = None) \
input_shape: Optional[Sequence[int]] = None,
pipeline_updater: Optional[Callable] = None, **kwargs) \
-> Tuple[Dict, torch.Tensor]:
"""Create input for segmentor.
Args:
imgs (str | np.ndarray): Input image(s), accepted data types are
`str`, `np.ndarray`.
input_shape (list[int]): A list of two integer in (width, height)
format specifying input shape. Defaults to `None`.
input_shape (Sequence[int] | None): Input shape of image in
(width, height) format, defaults to `None`.
pipeline_updater (function | None): A function to get a new
pipeline.
Returns:
tuple: (data, img), meta information for the input image and input.

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Dict, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
import mmcv
import numpy as np
@ -114,15 +114,18 @@ class TextRecognition(BaseTask):
def create_input(self,
imgs: Union[str, np.ndarray],
input_shape: Sequence[int] = None) \
input_shape: Optional[Sequence[int]] = None,
pipeline_updater: Optional[Callable] = None, **kwargs) \
-> Tuple[Dict, torch.Tensor]:
"""Create input for segmentor.
Args:
imgs (str | np.ndarray): Input image(s), accepted data types are
`str`, `np.ndarray`.
input_shape (list[int]): A list of two integer in (width, height)
format specifying input shape. Defaults to `None`.
input_shape (Sequence[int] | None): Input shape of image in
(width, height) format, defaults to `None`.
pipeline_updater (function | None): A function to get a new
pipeline.
Returns:
tuple: (data, img), meta information for the input image and input.

View File

@ -3,7 +3,7 @@
import copy
import logging
import os
from typing import Any, Dict, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
import mmcv
import numpy as np
@ -130,15 +130,18 @@ class PoseDetection(BaseTask):
def create_input(self,
imgs: Union[str, np.ndarray, Sequence],
input_shape: Sequence[int] = None,
input_shape: Optional[Sequence[int]] = None,
pipeline_updater: Optional[Callable] = None,
**kwargs) -> Tuple[Dict, torch.Tensor]:
"""Create input for pose detection.
Args:
imgs (Any): Input image(s), accepted data types are ``str``,
``np.ndarray``.
input_shape (list[int]): A list of two integer in (width, height)
format specifying input shape. Defaults to ``None``.
input_shape (Sequence[int] | None): Input shape of image in
(width, height) format, defaults to `None`.
pipeline_updater (function | None): A function to get a new
pipeline.
Returns:
tuple: (data, img), meta information for the input image and input.

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Any, Dict, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
import mmcv
import numpy as np
@ -150,15 +150,18 @@ class RotatedDetection(BaseTask):
def create_input(self,
imgs: Union[str, np.ndarray],
input_shape: Sequence[int] = None) \
input_shape: Optional[Sequence[int]] = None,
pipeline_updater: Optional[Callable] = None, **kwargs) \
-> Tuple[Dict, torch.Tensor]:
"""Create input for rotated object detection.
Args:
imgs (str | np.ndarray): Input image(s), accepted data types are
`str`, `np.ndarray`.
input_shape (list[int]): A list of two integer in (width, height)
format specifying input shape. Defaults to `None`.
input_shape (Sequence[int] | None): Input shape of image in
(width, height) format, defaults to `None`.
pipeline_updater (function | None): A function to get a new
pipeline.
Returns:
tuple: (data, img), meta information for the input image and input.

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Dict, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
import mmcv
import numpy as np
@ -33,6 +33,7 @@ def process_model_config(model_cfg: mmcv.Config,
if isinstance(imgs[0], np.ndarray):
# set loading pipeline type
cfg.data.test.pipeline[0] = LoadImage()
# for static exporting
if input_shape is not None:
for pipeline in cfg.data.test.pipeline[1:]:
@ -107,15 +108,18 @@ class Segmentation(BaseTask):
def create_input(self,
imgs: Union[str, np.ndarray, Sequence],
input_shape: Sequence[int] = None) \
input_shape: Optional[Sequence[int]] = None,
pipeline_updater: Optional[Callable] = None, **kwargs) \
-> Tuple[Dict, torch.Tensor]:
"""Create input for segmentor.
Args:
imgs (Any): Input image(s), accepted data types are `str`,
`np.ndarray`, `torch.Tensor`.
input_shape (list[int]): A list of two integer in (width, height)
format specifying input shape. Defaults to `None`.
input_shape (Sequence[int] | None): Input shape of image in
(width, height) format, defaults to `None`.
pipeline_updater (function | None): A function to get a new
pipeline.
Returns:
tuple: (data, img), meta information for the input image and input.
@ -125,7 +129,10 @@ class Segmentation(BaseTask):
if isinstance(imgs, (str, np.ndarray)):
imgs = [imgs]
imgs = [mmcv.imread(_) for _ in imgs]
cfg = process_model_config(self.model_cfg, imgs, input_shape)
model_cfg = self.model_cfg
if pipeline_updater is not None:
model_cfg = pipeline_updater(self.deploy_cfg, model_cfg)
cfg = process_model_config(model_cfg, imgs, input_shape)
test_pipeline = Compose(cfg.data.test.pipeline)
data_list = []
for img in imgs:
@ -260,7 +267,8 @@ class Segmentation(BaseTask):
"""
input_shape = get_input_shape(self.deploy_cfg)
load_from_file = self.model_cfg.data.test.pipeline[0]
model_cfg = process_model_config(self.model_cfg, [''], input_shape)
cfg = self.update_test_pipeline(self.deploy_cfg, self.model_cfg)
model_cfg = process_model_config(cfg, [''], input_shape)
preprocess = model_cfg.data.test.pipeline
preprocess[0] = load_from_file
return preprocess

View File

@ -162,7 +162,9 @@ class RKNNModel(End2EndModel):
List[np.ndarray]: A list of segmentation map.
"""
outputs = self.wrapper({self.input_name: imgs})
outputs = [output.argmax(dim=1, keepdim=True) for output in outputs]
outputs = [
output.argmax(dim=1, keepdim=True) for output in outputs.values()
]
outputs = [out.detach().cpu().numpy() for out in outputs]
return outputs

View File

@ -21,8 +21,9 @@ if importlib.util.find_spec('mmcv') is not None:
get_codebase_config, get_common_config,
get_dynamic_axes, get_input_shape,
get_ir_config, get_model_inputs,
get_onnx_config, get_partition_config,
get_quantization_config, get_task_type,
get_normalization, get_onnx_config,
get_partition_config, get_quantization_config,
get_rknn_quantization, get_task_type,
is_dynamic_batch, is_dynamic_shape, load_config)
# yapf: enable
@ -33,5 +34,6 @@ if importlib.util.find_spec('mmcv') is not None:
'get_codebase_config', 'get_common_config', 'get_dynamic_axes',
'get_input_shape', 'get_ir_config', 'get_model_inputs',
'get_onnx_config', 'get_partition_config', 'get_quantization_config',
'get_task_type', 'is_dynamic_batch', 'is_dynamic_shape', 'load_config'
'get_task_type', 'is_dynamic_batch', 'is_dynamic_shape', 'load_config',
'get_rknn_quantization', 'get_normalization'
]

View File

@ -393,3 +393,40 @@ def get_dynamic_axes(
raise KeyError('No names were found to define dynamic axes.')
dynamic_axes = dict(zip(axes_names, dynamic_axes))
return dynamic_axes
def get_normalization(model_cfg: Union[str, mmcv.Config]):
"""Get the Normalize transform from model config.
Args:
model_cfg (mmcv.Config): The content of config.
Returns:
dict: The Normalize transform.
"""
model_cfg = load_config(model_cfg)[0]
pipelines = model_cfg.data.test.pipeline
for i, pipeline in enumerate(pipelines):
if pipeline['type'] == 'MultiScaleFlipAug':
assert 'transforms' in pipeline
for trans in pipeline['transforms']:
if trans['type'] == 'Normalize':
return trans
else:
if pipeline['type'] == 'Normalize':
return pipeline
def get_rknn_quantization(deploy_cfg: mmcv.Config):
"""Get the flag of `do_quantization` for rknn backend.
Args:
deploy_cfg (mmcv.Config): The content of config.
Returns:
bool: Do quantization or not.
"""
if get_backend(deploy_cfg) == Backend.RKNN:
return get_backend_config(
deploy_cfg)['quantization_config']['do_quantization']
return False
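Together these helpers drive the `onnx2rknn` change above; a hedged sketch of the flow:
```python
from mmdeploy.utils import (get_normalization, get_rknn_quantization,
                            load_config)  # exports added in this PR

deploy_cfg = load_config('deploy_cfg.py')[0]
model_cfg = load_config('model_cfg.py')[0]
if get_rknn_quantization(deploy_cfg):
    norm = get_normalization(model_cfg)  # the Normalize transform dict
    mean_values, std_values = [norm['mean']], [norm['std']]
```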

View File

@ -372,6 +372,7 @@ def main():
onnx_path,
output_path,
deploy_cfg_path,
model_cfg_path,
dataset_file=dataset_file)
backend_files.append(output_path)