Sync rv1126 to dev-1.x by cherry-pick (#1295)

* remove imports (#1207)

* remove imports

* update doc

* detailed docstring

* rephrase

* Add model conversion support to rv1126 (#1203)

* WIP

* fix interpolate

* support yolov3 and retinanet

* support seg

* support ssd

* supports both partition types for retinanet and ssd

* mean std doc

* update doc, add UT

* support FSAF

* rename configs

* update dump info

* update

* python package installation doc

* update doc

* update doc

* doc

* fix

* docstring

* remote partition config
pull/1317/head
AllentDan 2022-11-07 10:19:22 +08:00 committed by GitHub
parent 331292a992
commit c5edb85550
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 482 additions and 64 deletions

View File

@ -1,8 +1,8 @@
backend_config = dict(
type='rknn',
common_config=dict(
mean_values=None,
std_values=None,
target_platform='rk3588',
optimization_level=3),
mean_values=None, # [[103.53, 116.28, 123.675]],
std_values=None, # [[57.375, 57.12, 58.395]],
target_platform='rv1126', # 'rk3588'
optimization_level=1),
quantization_config=dict(do_quantization=False, dataset=None))

View File

@ -0,0 +1,29 @@
_base_ = ['../_base_/base_static.py', '../../_base_/backends/rknn.py']
onnx_config = dict(input_shape=[320, 320])
codebase_config = dict(model_type='rknn')
backend_config = dict(input_size_list=[[3, 320, 320]])
# yolov3, yolox
# partition_config = dict(
# type='rknn', # the partition policy name
# apply_marks=True, # should always be set to True
# partition_cfg=[
# dict(
# save_file='model.onnx', # name to save the partitioned onnx
# start=['detector_forward:input'], # [mark_name:input, ...]
# end=['yolo_head:input']) # [mark_name:output, ...]
# ])
# # retinanet, ssd, fsaf
# partition_config = dict(
# type='rknn', # the partition policy name
# apply_marks=True,
# partition_cfg=[
# dict(
# save_file='model.onnx',
# start='detector_forward:input',
# end=['BaseDenseHead:output'])
# ])

View File

@ -1,17 +0,0 @@
_base_ = ['../_base_/base_static.py', '../../_base_/backends/rknn.py']
onnx_config = dict(input_shape=[640, 640])
codebase_config = dict(model_type='rknn')
backend_config = dict(input_size_list=[[3, 640, 640]])
partition_config = dict(
type='rknn', # the partition policy name
apply_marks=True, # should always be set to True
partition_cfg=[
dict(
save_file='model.onnx', # name to save the partitioned onnx model
start=['detector_forward:input'], # [mark_name:input/output, ...]
end=['yolo_head:input']) # [mark_name:input/output, ...]
])

View File

@ -1,7 +1,7 @@
_base_ = ['./segmentation_static.py', '../_base_/backends/rknn.py']
onnx_config = dict(input_shape=[512, 512])
onnx_config = dict(input_shape=[320, 320])
codebase_config = dict(model_type='rknn')
backend_config = dict(input_size_list=[[3, 512, 512]])
backend_config = dict(input_size_list=[[3, 320, 320]])

View File

@ -1,18 +1,26 @@
# Build for RKNN
This tutorial is based on Linux systems like Ubuntu-18.04 and Rockchip NPU like `rk3588`.
This tutorial is based on Ubuntu-18.04 and Rockchip NPU `rk3588`. For different NPU devices, you may have to use different rknn packages.
Below is a table describing the relationship:
| Device | Python Package | c/c++ SDK |
| -------------------- | ---------------------------------------------------------------- | -------------------------------------------------- |
| RK1808/RK1806 | [rknn-toolkit](https://github.com/rockchip-linux/rknn-toolkit) | [rknpu](https://github.com/rockchip-linux/rknpu) |
| RV1109/RV1126 | [rknn-toolkit](https://github.com/rockchip-linux/rknn-toolkit) | [rknpu](https://github.com/rockchip-linux/rknpu) |
| RK3566/RK3568/RK3588 | [rknn-toolkit2](https://github.com/rockchip-linux/rknn-toolkit2) | [rknpu2](https://github.com/rockchip-linux/rknpu2) |
| RV1103/RV1106 | [rknn-toolkit2](https://github.com/rockchip-linux/rknn-toolkit2) | [rknpu2](https://github.com/rockchip-linux/rknpu2) |
## Installation
It is recommended to create a virtual environment for the project.
1. get RKNN-Toolkit2 through:
1. Get RKNN-Toolkit2 or RKNN-Toolkit through git. RKNN-Toolkit2 for example:
```
git clone git@github.com:rockchip-linux/rknn-toolkit2.git
```
2. install RKNN python package following [official doc](https://github.com/rockchip-linux/rknn-toolkit2/tree/master/doc). In our testing, we used the rknn-toolkit2 1.2.0 with commit id `834ba0b0a1ab8ee27024443d77b02b5ba48b67fc`. When installing rknn-toolkit2, it is better to append `--no-deps` after the commands to avoid dependency conflicts. For example:
2. Install RKNN python package following [rknn-toolkit2 doc](https://github.com/rockchip-linux/rknn-toolkit2/tree/master/doc) or [rknn-toolkit doc](https://github.com/rockchip-linux/rknn-toolkit/tree/master/doc). When installing rknn python package, it is better to append `--no-deps` after the commands to avoid dependency conflicts. RKNN-Toolkit2 package for example:
```
pip install packages/rknn_toolkit2-1.2.0_f7bb160f-cp36-cp36m-linux_x86_64.whl --no-deps
@ -67,17 +75,19 @@ backend_config = dict(
```
The contents of `common_config` are for `rknn.config()`. The contents of `quantization_config` are used to control `rknn.build()`.
The contents of `common_config` are for `rknn.config()`. The contents of `quantization_config` are used to control `rknn.build()`. You may have to modify `target_platform` for your own preference.
## Build SDK with Rockchip NPU
1. get rknpu2 through:
### Build SDK with RKNPU2
1. Get rknpu2 through git:
```
git clone git@github.com:rockchip-linux/rknpu2.git
```
2. for linux, download gcc cross compiler. The download link of the compiler from the official user guide of `rknpu2` was deprecated. You may use another verified [link](https://github.com/Caesar-github/gcc-buildroot-9.3.0-2020.03-x86_64_aarch64-rockchip-linux-gnu). After download and unzip the compiler, you may open the terminal, set `RKNN_TOOL_CHAIN` and `RKNPU2_DEVICE_DIR` by `export RKNN_TOOL_CHAIN=/path/to/gcc/usr;export RKNPU2_DEVICE_DIR=/path/to/rknpu2/runtime/RK3588`.
2. For linux, download gcc cross compiler. The download link of the compiler from the official user guide of `rknpu2` was deprecated. You may use another verified [link](https://github.com/Caesar-github/gcc-buildroot-9.3.0-2020.03-x86_64_aarch64-rockchip-linux-gnu). After download and unzip the compiler, you may open the terminal, set `RKNN_TOOL_CHAIN` and `RKNPU2_DEVICE_DIR` by `export RKNN_TOOL_CHAIN=/path/to/gcc/usr;export RKNPU2_DEVICE_DIR=/path/to/rknpu2/runtime/RK3588`.
3. after the above preparition, run the following commands:
@ -144,4 +154,38 @@ label: 65, score: 0.95
mean=[0, 0, 0], std=[1, 1, 1], to_rgb=True)
```
Besides, the `mean_values` and `std_values` of deploy_cfg should be replaced with original normalization settings of `model_cfg`. Let `mean_values=[123.675, 116.28, 103.53]` and `std_values=[58.395, 57.12, 57.375]`.
Besides, the `mean_values` and `std_values` of deploy_cfg should be replaced with original normalization settings of `model_cfg`. Let `mean_values=[[103.53, 116.28, 123.675]]` and `std_values=[[57.375, 57.12, 58.395]]`.
- MMDet models.
YOLOV3 & YOLOX: you may paste the following partition configuration into [detection_rknn_static.py](https://github.com/open-mmlab/mmdeploy/blob/master/configs/mmdet/detection/detection_rknn_static.py):
```python
# yolov3, yolox
partition_config = dict(
type='rknn', # the partition policy name
apply_marks=True, # should always be set to True
partition_cfg=[
dict(
save_file='model.onnx', # name to save the partitioned onnx
start=['detector_forward:input'], # [mark_name:input, ...]
end=['yolo_head:input']) # [mark_name:output, ...]
])
```
RetinaNet & SSD & FSAF with rknn-toolkit2, you may paste the following partition configuration into [detection_rknn_static.py](https://github.com/open-mmlab/mmdeploy/blob/master/configs/mmdet/detection/detection_rknn_static.py). Users with rknn-toolkit can directly use default config.
```python
# retinanet, ssd
partition_config = dict(
type='rknn', # the partition policy name
apply_marks=True,
partition_cfg=[
dict(
save_file='model.onnx',
start='detector_forward:input',
end=['BaseDenseHead:output'])
])
```
- SDK only supports int8 rknn model, which require `do_quantization=True` when converting models.

View File

@ -4,14 +4,14 @@ The table below lists the models that are guaranteed to be exportable to other b
| Model | Codebase | TorchScript | OnnxRuntime | TensorRT | ncnn | PPLNN | OpenVINO | Ascend | RKNN | Model config |
| :-------------------------- | :--------------- | :---------: | :---------: | :------: | :--: | :---: | :------: | :----: | :--: | :---------------------------------------------------------------------------------------------: |
| RetinaNet | MMDetection | Y | Y | Y | Y | Y | Y | Y | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/retinanet) |
| RetinaNet | MMDetection | Y | Y | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/retinanet) |
| Faster R-CNN | MMDetection | Y | Y | Y | Y | Y | Y | Y | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/faster_rcnn) |
| YOLOv3 | MMDetection | Y | Y | Y | Y | N | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/yolo) |
| YOLOX | MMDetection | Y | Y | Y | Y | N | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/yolox) |
| FCOS | MMDetection | Y | Y | Y | Y | N | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/fcos) |
| FSAF | MMDetection | Y | Y | Y | Y | Y | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/fsaf) |
| FSAF | MMDetection | Y | Y | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/fsaf) |
| Mask R-CNN | MMDetection | Y | Y | Y | N | N | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/mask_rcnn) |
| SSD[\*](#note) | MMDetection | Y | Y | Y | Y | N | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/ssd) |
| SSD[\*](#note) | MMDetection | Y | Y | Y | Y | N | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/ssd) |
| FoveaBox | MMDetection | Y | Y | N | N | N | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/foveabox) |
| ATSS | MMDetection | N | Y | Y | N | N | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/atss) |
| GFL | MMDetection | N | Y | Y | N | ? | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/gfl) |

View File

@ -1,9 +1,9 @@
# Supported RKNN feature
Currently, MMDeploy only tests rk3588 with linux platform.
Currently, MMDeploy only tests rk3588 and rv1126 with linux platform.
The following features cannot be automatically enabled by mmdeploy and you need to manually modify the configuration in MMDeploy like [here](https://github.com/open-mmlab/mmdeploy/blob/master/configs/_base_/backends/rknn.py).
- target_platform other than `3588`
- target_platform other than default
- quantization settings
- optimization level other than 3
- optimization level other than 1

View File

@ -1,18 +1,26 @@
# 支持 RKNN
本教程基于 Ubuntu-18.04 和 Rockchip `rk3588` NPU。
本教程基于 Ubuntu-18.04 和 Rockchip `rk3588` NPU。对于不同的 NPU 设备,您需要使用不同的 rknn 包.
这是设备和安装包的关系表:
| Device | Python Package | c/c++ SDK |
| -------------------- | ---------------------------------------------------------------- | -------------------------------------------------- |
| RK1808/RK1806 | [rknn-toolkit](https://github.com/rockchip-linux/rknn-toolkit) | [rknpu](https://github.com/rockchip-linux/rknpu) |
| RV1109/RV1126 | [rknn-toolkit](https://github.com/rockchip-linux/rknn-toolkit) | [rknpu](https://github.com/rockchip-linux/rknpu) |
| RK3566/RK3568/RK3588 | [rknn-toolkit2](https://github.com/rockchip-linux/rknn-toolkit2) | [rknpu2](https://github.com/rockchip-linux/rknpu2) |
| RV1103/RV1106 | [rknn-toolkit2](https://github.com/rockchip-linux/rknn-toolkit2) | [rknpu2](https://github.com/rockchip-linux/rknpu2) |
## 安装
建议为项目创建一个虚拟环境。
1. 获取 RKNN-Toolkit2:
1. 使用 git 获取 RKNN-Toolkit2 或者 RKNN-Toolkit。以 RKNN-Toolkit2 为例:
```
git clone git@github.com:rockchip-linux/rknn-toolkit2.git
```
2. 通过 [官方文档](https://github.com/rockchip-linux/rknn-toolkit2/tree/master/doc),安装 RKNN python 安装包. 在我们的测试中, 使用的 rknn-toolkit 版本是 1.2.0commit id `834ba0b0a1ab8ee27024443d77b02b5ba48b67fc`。安装 rknn-toolkit2 时,最好在安装命令后添加`--no-deps`,以避免依赖包的冲突。比如:
2. 通过 [rknn-toolkit2 文档](https://github.com/rockchip-linux/rknn-toolkit2/tree/master/doc) 或者 [rknn-toolkit 文档](https://github.com/rockchip-linux/rknn-toolkit/tree/master/doc)安装 RKNN python 安装包。安装 rknn python 包时,最好在安装命令后添加`--no-deps`以避免依赖包的冲突。以rknn-toolkit2为例:
```
pip install packages/rknn_toolkit2-1.2.0_f7bb160f-cp36-cp36m-linux_x86_64.whl --no-deps
@ -71,6 +79,8 @@ backend_config = dict(
## 安装 SDK
### RKNPU2 编译 MMDeploy SDK
1. 获取 rknpu2:
```
@ -144,4 +154,38 @@ label: 65, score: 0.95
mean=[0, 0, 0], std=[1, 1, 1], to_rgb=True)
```
此外, deploy_cfg 的 `mean_values``std_values` 应该被设置为 `model_cfg` 中归一化的设置. 使 `mean_values=[123.675, 116.28, 103.53]` `std_values=[58.395, 57.12, 57.375]`
此外, deploy_cfg 的 `mean_values``std_values` 应该被设置为 `model_cfg` 中归一化的设置. 使 `mean_values=[[103.53, 116.28, 123.675]]`, `std_values=[[57.375, 57.12, 58.395]]`
- MMDet 模型.
YOLOV3 & YOLOX: 将下面的模型拆分配置写入到 [detection_rknn_static.py](https://github.com/open-mmlab/mmdeploy/blob/master/configs/mmdet/detection/detection_rknn_static.py):
```python
# yolov3, yolox
partition_config = dict(
type='rknn', # the partition policy name
apply_marks=True, # should always be set to True
partition_cfg=[
dict(
save_file='model.onnx', # name to save the partitioned onnx
start=['detector_forward:input'], # [mark_name:input, ...]
end=['yolo_head:input']) # [mark_name:output, ...]
])
```
RetinaNet & SSD & FSAF with rknn-toolkit2, 将下面的模型拆分配置写入到 [detection_rknn_static.py](https://github.com/open-mmlab/mmdeploy/blob/master/configs/mmdet/detection/detection_rknn_static.py)。使用 rknn-toolkit 的用户则不用。
```python
# retinanet, ssd
partition_config = dict(
type='rknn', # the partition policy name
apply_marks=True,
partition_cfg=[
dict(
save_file='model.onnx',
start='detector_forward:input',
end=['BaseDenseHead:output'])
])
```
- SDK 只支持 int8 的 rknn 模型,这需要在转换模型时设置 `do_quantization=True`

View File

@ -4,14 +4,14 @@
| Model | Codebase | TorchScript | OnnxRuntime | TensorRT | ncnn | PPLNN | OpenVINO | Ascend | RKNN | Model config |
| :-------------------------- | :--------------- | :---------: | :---------: | :------: | :--: | :---: | :------: | :----: | :--: | :---------------------------------------------------------------------------------------------: |
| RetinaNet | MMDetection | Y | Y | Y | Y | Y | Y | Y | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/retinanet) |
| RetinaNet | MMDetection | Y | Y | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/retinanet) |
| Faster R-CNN | MMDetection | Y | Y | Y | Y | Y | Y | Y | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/faster_rcnn) |
| YOLOv3 | MMDetection | Y | Y | Y | Y | N | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/yolo) |
| YOLOX | MMDetection | Y | Y | Y | Y | N | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/yolox) |
| FCOS | MMDetection | Y | Y | Y | Y | N | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/fcos) |
| FSAF | MMDetection | Y | Y | Y | Y | Y | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/fsaf) |
| FSAF | MMDetection | Y | Y | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/fsaf) |
| Mask R-CNN | MMDetection | Y | Y | Y | N | N | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/mask_rcnn) |
| SSD[\*](#note) | MMDetection | Y | Y | Y | Y | N | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/ssd) |
| SSD[\*](#note) | MMDetection | Y | Y | Y | Y | N | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/ssd) |
| FoveaBox | MMDetection | Y | Y | N | N | N | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/foveabox) |
| ATSS | MMDetection | N | Y | Y | N | N | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/atss) |
| GFL | MMDetection | N | Y | Y | N | ? | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/gfl) |

View File

@ -1,9 +1,9 @@
# 支持的 RKNN 特征
目前, MMDeploy 只在 rk3588 的 linux 平台上测试过.
目前, MMDeploy 只在 rk3588 和 rv1126 的 linux 平台上测试过.
以下特性需要手动在 MMDeploy 自行配置,如[这里](https://github.com/open-mmlab/mmdeploy/blob/master/configs/_base_/backends/rknn.py).
- target_platform = `3588`
- target_platform = default
- quantization settings
- optimization level = 3
- optimization level = 1

View File

@ -8,7 +8,8 @@ import onnx.utils
from mmdeploy.apis.core import PIPELINE_MANAGER
from mmdeploy.core.optimizers import (attribute_to_dict, create_extractor,
get_new_name, parse_extractor_io_string,
remove_identity, rename_value)
remove_identity, remove_imports,
rename_value)
from mmdeploy.utils import get_root_logger
@ -198,6 +199,9 @@ def extract_partition(model: Union[str, onnx.ModelProto],
dim.dim_value = 0
dim.dim_param = f'dim_{idx}'
# remove mmdeploy domain if useless
remove_imports(extracted_model)
# save extract_model if save_file is given
if save_file is not None:
onnx.save(extracted_model, save_file)

View File

@ -132,7 +132,7 @@ def get_inference_info(deploy_cfg: mmengine.Config, model_cfg: mmengine.Config,
name, _ = get_model_name_customs(deploy_cfg, model_cfg, work_dir, device)
ir_config = get_ir_config(deploy_cfg)
backend = get_backend(deploy_cfg=deploy_cfg)
if backend == Backend.TORCHSCRIPT:
if backend in (Backend.TORCHSCRIPT, Backend.RKNN):
output_names = ir_config.get('output_names', None)
input_map = dict(img='#0')
output_map = {name: f'#{i}' for i, name in enumerate(output_names)}
@ -159,6 +159,11 @@ def get_preprocess(deploy_cfg: mmengine.Config, model_cfg: mmengine.Config,
task_processor = build_task_processor(
model_cfg=model_cfg, deploy_cfg=deploy_cfg, device=device)
transforms = task_processor.get_preprocess()
if get_backend(deploy_cfg) == Backend.RKNN:
del transforms[-2]
for transform in transforms:
if transform['type'] == 'Normalize':
transform['to_float'] = False
assert transforms[0]['type'] == 'LoadImageFromFile', 'The first item'\
' type of pipeline should be LoadImageFromFile'
return dict(

View File

@ -10,8 +10,9 @@ from mmengine.model import BaseDataPreprocessor
from mmengine.registry import Registry
from mmdeploy.codebase.base import CODEBASE, BaseTask, MMCodebase
from mmdeploy.utils import Codebase, Task
from mmdeploy.utils.config_utils import get_input_shape, is_dynamic_shape
from mmdeploy.utils import Backend, Codebase, Task
from mmdeploy.utils.config_utils import (get_backend, get_input_shape,
is_dynamic_shape)
MMDET_TASK = Registry('mmdet_tasks')
@ -278,6 +279,14 @@ class ObjectDetection(BaseTask):
if 'mask_thr_binary' in params['rcnn']:
params['mask_thr_binary'] = params['rcnn']['mask_thr_binary']
type = 'ResizeInstanceMask' # for instance-seg
if get_backend(self.deploy_cfg) == Backend.RKNN:
if 'YOLO' in self.model_cfg.model.type:
bbox_head = self.model_cfg.model.bbox_head
type = bbox_head.type
params['anchor_generator'] = bbox_head.get(
'anchor_generator', None)
else: # default using base_dense_head
type = 'BaseDenseHead'
return dict(type=type, params=params)
def get_model_name(self, *args, **kwargs) -> str:

View File

@ -657,10 +657,11 @@ class RKNNModel(End2EndModel):
head_cfg = self.model_cfg._cfg_dict.model.bbox_head
head = build_head(head_cfg)
if head_cfg.type == 'YOLOXHead':
divisor = round(len(outputs) / 3)
ret = head.predict_by_feat(
outputs[:3],
outputs[3:6],
outputs[6:9],
outputs[:divisor],
outputs[divisor:2 * divisor],
outputs[2 * divisor:],
metainfos,
cfg=self.model_cfg._cfg_dict.model.test_cfg,
rescale=True)
@ -670,6 +671,31 @@ class RKNNModel(End2EndModel):
metainfos,
cfg=self.model_cfg._cfg_dict.model.test_cfg,
rescale=True)
elif head_cfg.type in ('RetinaHead', 'SSDHead', 'FSAFHead'):
partition_cfgs = get_partition_config(self.deploy_cfg)
if partition_cfgs is None: # bbox decoding done in rknn model
from mmdet.structures.bbox import scale_boxes
from ..models.layers.bbox_nms import _multiclass_nms
dets, labels = _multiclass_nms(outputs[0], outputs[1])
ret = [InstanceData() for i in range(dets.shape[0])]
for i, instance_data in enumerate(ret):
instance_data.bboxes = dets[i, :, :4]
instance_data.scores = dets[i, :, 4]
instance_data.labels = labels[i]
scale_factor = [
1 / s for s in metainfos[i]['scale_factor']
]
instance_data.bboxes = scale_boxes(instance_data.bboxes,
scale_factor)
return ret
divisor = round(len(outputs) / 2)
ret = head.predict_by_feat(
outputs[:divisor],
outputs[divisor:],
batch_img_metas=metainfos,
rescale=True,
cfg=self.model_cfg._cfg_dict.model.test_cfg)
else:
raise NotImplementedError(f'{head_cfg.type} not supported yet.')
return ret

View File

@ -14,7 +14,7 @@ from mmdeploy.codebase.mmdet import (gather_topk, get_post_processing_params,
pad_with_value_if_necessary)
from mmdeploy.codebase.mmdet.models.layers import multiclass_nms
from mmdeploy.codebase.mmdet.ops import ncnn_detection_output_forward
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.core import FUNCTION_REWRITER, mark
from mmdeploy.utils import Backend, is_dynamic_shape
@ -192,6 +192,132 @@ def base_dense_head__predict_by_feat(
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.dense_heads.base_dense_head.'
'BaseDenseHead.predict_by_feat',
backend=Backend.RKNN.value)
def base_dense_head__predict_by_feat__rknn(
ctx,
self,
cls_scores: List[Tensor],
bbox_preds: List[Tensor],
score_factors: Optional[List[Tensor]] = None,
batch_img_metas: Optional[List[dict]] = None,
cfg: Optional[ConfigDict] = None,
rescale: bool = False,
with_nms: bool = True,
**kwargs):
"""Rewrite `predict_by_feat` of `BaseDenseHead` for default backend.
Rewrite this function to deploy model, transform network output for a
batch into bbox predictions.
Args:
ctx (ContextCaller): The context with additional information.
cls_scores (list[Tensor]): Classification scores for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * num_classes, H, W).
bbox_preds (list[Tensor]): Box energies / deltas for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * 4, H, W).
score_factors (list[Tensor], optional): Score factor for
all scale level, each is a 4D-tensor, has shape
(batch_size, num_priors * 1, H, W). Defaults to None.
batch_img_metas (list[dict], Optional): Batch image meta info.
Defaults to None.
cfg (ConfigDict, optional): Test / postprocessing
configuration, if None, test_cfg would be used.
Defaults to None.
rescale (bool): If True, return boxes in original image space.
Defaults to False.
with_nms (bool): If True, do nms before return boxes.
Defaults to True.
Returns:
If with_nms == True:
tuple[Tensor, Tensor]: tuple[Tensor, Tensor]: (dets, labels),
`dets` of shape [N, num_det, 5] and `labels` of shape
[N, num_det].
Else:
tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes,
batch_mlvl_scores, batch_mlvl_centerness
"""
# mark nodes for partition
@mark('BaseDenseHead', outputs=['BaseDenseHead.cls', 'BaseDenseHead.loc'])
def __mark_dense_head(cls_scores, bbox_preds):
return cls_scores, bbox_preds
cls_scores, bbox_preds = __mark_dense_head(cls_scores, bbox_preds)
deploy_cfg = ctx.cfg
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
num_levels = len(cls_scores)
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
mlvl_priors = self.prior_generator.grid_priors(
featmap_sizes, dtype=bbox_preds[0].dtype, device=bbox_preds[0].device)
mlvl_priors = [priors.unsqueeze(0) for priors in mlvl_priors]
mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)]
mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)]
if score_factors is None:
with_score_factors = False
mlvl_score_factor = [None for _ in range(num_levels)]
else:
with_score_factors = True
mlvl_score_factor = [
score_factors[i].detach() for i in range(num_levels)
]
mlvl_score_factors = []
assert batch_img_metas is not None
img_shape = batch_img_metas[0]['img_shape']
assert len(cls_scores) == len(bbox_preds) == len(mlvl_priors)
batch_size = cls_scores[0].shape[0]
mlvl_valid_bboxes = []
mlvl_valid_scores = []
mlvl_valid_priors = []
for cls_score, bbox_pred, score_factors, priors in zip(
mlvl_cls_scores, mlvl_bbox_preds, mlvl_score_factor, mlvl_priors):
assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
scores = cls_score.permute(0, 2, 3, 1).reshape(batch_size, -1,
self.cls_out_channels)
if self.use_sigmoid_cls:
scores = scores.sigmoid()
else:
scores = scores.softmax(-1)[:, :, :-1]
if with_score_factors:
score_factors = score_factors.permute(0, 2, 3,
1).reshape(batch_size,
-1).sigmoid()
score_factors = score_factors.unsqueeze(2)
bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4)
if not is_dynamic_flag:
priors = priors.data
mlvl_valid_bboxes.append(bbox_pred)
mlvl_valid_scores.append(scores)
mlvl_valid_priors.append(priors)
if with_score_factors:
mlvl_score_factors.append(score_factors)
batch_mlvl_bboxes_pred = torch.cat(mlvl_valid_bboxes, dim=1)
batch_scores = torch.cat(mlvl_valid_scores, dim=1)
batch_priors = torch.cat(mlvl_valid_priors, dim=1)
batch_bboxes = self.bbox_coder.decode(
batch_priors, batch_mlvl_bboxes_pred, max_shape=img_shape)
if with_score_factors:
batch_score_factors = torch.cat(mlvl_score_factors, dim=1)
if not self.use_sigmoid_cls:
batch_scores = batch_scores[..., :self.num_classes]
if with_score_factors:
batch_scores = batch_scores * batch_score_factors
if isinstance(self, PAAHead):
batch_scores = batch_scores.sqrt()
return batch_bboxes, batch_scores
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.dense_heads.base_dense_head.BaseDenseHead'
'.get_bboxes',
backend=Backend.NCNN.value)
def base_dense_head__predict_by_feat__ncnn(
ctx,

View File

@ -20,7 +20,7 @@ def single_stage_text_detector__forward(
Args:
batch_inputs (torch.Tensor): Images of shape (N, C, H, W).
batch_data_samples (list[TextDetDataSample]): A list of N
data_samples (list[TextDetDataSample]): A list of N
datasamples, containing meta information and gold annotations
for each of the images.

View File

@ -16,7 +16,7 @@ def base_decoder__forward(
out_enc: Optional[torch.Tensor] = None,
data_samples: Optional[Sequence[TextRecogDataSample]] = None
) -> Sequence[TextRecogDataSample]:
"""Perform forward propagation of the decoder and postprocessor.
"""Rewrite `predict` of `BaseDecoder` to skip post-process.
Args:
feat (Tensor, optional): Features from the backbone. Defaults

View File

@ -20,13 +20,10 @@ def encoder_decoder_recognizer__forward(ctx, self, batch_inputs: torch.Tensor,
ctx (ContextCaller): The context with additional information.
self: The instance of the class
EncoderDecoderRecognizer.
img (Tensor): Input images of shape (N, C, H, W).
batch_inputs (Tensor): Input images of shape (N, C, H, W).
Typically these should be mean centered and std scaled.
img_metas (list[dict]): A list of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys, see
:class:`mmdet.datasets.pipelines.Collect`.
data_samples (TextRecogDataSample): Containing meta information
and gold annotations for each of the images. Defaults to None.
Returns:
out_dec (Tensor): A feature map output from a decoder. The tensor shape

View File

@ -2,10 +2,10 @@
from .extractor import create_extractor, parse_extractor_io_string
from .function_marker import mark, reset_mark_function_count
from .optimize import (attribute_to_dict, get_new_name, remove_identity,
rename_value)
remove_imports, rename_value)
__all__ = [
'mark', 'reset_mark_function_count', 'create_extractor',
'parse_extractor_io_string', 'remove_identity', 'attribute_to_dict',
'rename_value', 'get_new_name'
'rename_value', 'get_new_name', 'remove_imports'
]

View File

@ -206,3 +206,24 @@ def remove_identity(model: onnx.ModelProto):
pass
remove_nodes(model, is_identity)
def remove_imports(model: onnx.ModelProto):
"""Remove useless imports from an ONNX model.
The domain like `mmdeploy` might influence model conversion for
some backends.
Args:
model (onnx.ModelProto): Input onnx model.
"""
logger = get_root_logger()
dst_domain = ['']
for node in model.graph.node:
if hasattr(node, 'module') and (node.module not in dst_domain):
dst_domain.append(node.module)
src_domains = [oi.domain for oi in model.opset_import]
for i, src_domain in enumerate(src_domains):
if src_domain not in dst_domain:
logger.info(f'remove opset_import {src_domain}')
model.opset_import.pop(i)

View File

@ -40,6 +40,38 @@ def interpolate__ncnn(ctx,
recompute_scale_factor=recompute_scale_factor)
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.nn.functional.interpolate', backend='rknn')
def interpolate__rknn(ctx,
input: torch.Tensor,
size: Optional[Union[int, Tuple[int], Tuple[int, int],
Tuple[int, int, int]]] = None,
scale_factor: Optional[Union[float,
Tuple[float]]] = None,
mode: str = 'nearest',
align_corners: Optional[bool] = None,
recompute_scale_factor: Optional[bool] = None):
"""Rewrite `interpolate` for rknn backend.
rknn require `size` should be constant in ONNX Node. We use `scale_factor`
instead of `size` to avoid dynamic size.
"""
input_size = input.shape
if scale_factor is None:
scale_factor = [(s_out / s_in)
for s_out, s_in in zip(size, input_size[2:])]
if isinstance(scale_factor[0], torch.Tensor):
scale_factor = [i.item() for i in scale_factor]
return ctx.origin_func(
input,
None,
scale_factor,
mode=mode,
align_corners=align_corners,
recompute_scale_factor=recompute_scale_factor)
@FUNCTION_REWRITER.register_rewriter(
'torch.nn.functional.interpolate',
is_pytorch=True,

View File

@ -1821,6 +1821,80 @@ def test_base_dense_head_predict_by_feat__ncnn():
assert rewrite_outputs.shape[-1] == 6
@backend_checker(Backend.RKNN)
def test_base_dense_head_get_bboxes__rknn():
"""Test get_bboxes rewrite of ssd head for rknn."""
ssd_head = get_ssd_head_model()
ssd_head.cpu().eval()
s = 128
img_metas = [{
'scale_factor': np.ones(4),
'pad_shape': (s, s, 3),
'img_shape': (s, s, 3)
}]
output_names = ['output']
input_names = []
for i in range(6):
input_names.append('cls_scores_' + str(i))
input_names.append('bbox_preds_' + str(i))
dynamic_axes = None
deploy_cfg = mmengine.Config(
dict(
backend_config=dict(type=Backend.RKNN.value),
onnx_config=dict(
input_names=input_names,
output_names=output_names,
input_shape=None,
dynamic_axes=dynamic_axes),
codebase_config=dict(
type='mmdet',
task='ObjectDetection',
model_type='rknn',
post_processing=dict(
score_threshold=0.05,
iou_threshold=0.5,
max_output_boxes_per_class=200,
pre_top_k=5000,
keep_top_k=100,
background_label_id=-1,
))))
# For the ssd_head:
# the cls_score's size: (1, 30, 20, 20), (1, 30, 10, 10),
# (1, 30, 5, 5), (1, 30, 3, 3), (1, 30, 2, 2), (1, 30, 1, 1)
# the bboxes's size: (1, 24, 20, 20), (1, 24, 10, 10),
# (1, 24, 5, 5), (1, 24, 3, 3), (1, 24, 2, 2), (1, 24, 1, 1)
feat_shape = [20, 10, 5, 3, 2, 1]
num_prior = 6
seed_everything(1234)
cls_score = [
torch.rand(1, 30, feat_shape[i], feat_shape[i])
for i in range(num_prior)
]
seed_everything(5678)
bboxes = [
torch.rand(1, 24, feat_shape[i], feat_shape[i])
for i in range(num_prior)
]
# to get outputs of onnx model after rewrite
img_metas[0]['img_shape'] = [s, s]
wrapped_model = WrapModel(
ssd_head, 'get_bboxes', img_metas=img_metas, with_nms=True)
rewrite_inputs = {
'cls_scores': cls_score,
'bbox_preds': bboxes,
}
rewrite_outputs, is_backend_output = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg,
run_with_backend=False)
# output should be of shape [1, N, 4]
assert rewrite_outputs[0].shape[-1] == 4
@pytest.mark.parametrize('backend_type, ir_type', [(Backend.OPENVINO, 'onnx')])
def test_reppoints_head_predict_by_feat(backend_type: Backend, ir_type: str):
"""Test predict_by_feat rewrite of base dense head."""

View File

@ -119,6 +119,30 @@ def test_interpolate_static():
assert np.allclose(model_output, rewrite_output[0], rtol=1e-03, atol=1e-05)
@backend_checker(Backend.RKNN)
def test_interpolate__rknn():
input = torch.rand([1, 2, 2, 2])
model_output = F.interpolate(input, scale_factor=[2, 2])
def interpolate_caller(*arg, **kwargs):
return F.interpolate(*arg, **kwargs)
deploy_cfg = Config(
dict(
onnx_config=dict(input_shape=None),
backend_config=dict(type='rknn', model_inputs=None),
codebase_config=dict(type='mmdet', task='ObjectDetection')))
wrapped_func = WrapFunction(interpolate_caller, size=[4, 4])
rewrite_output, _ = get_rewrite_outputs(
wrapped_func,
model_inputs={'input': input},
deploy_cfg=deploy_cfg,
run_with_backend=False)
assert np.allclose(model_output, rewrite_output[0], rtol=1e-03, atol=1e-05)
@backend_checker(Backend.NCNN)
def test_linear_ncnn():
input = torch.rand([1, 2, 2])