Support single stage rotated detector in MMRotate (#428)
* fix lint
* fix lint
* add mmrotate part
* update
* update
* fix
* remove init_detector
* success run with bs=1
* nms_rotated support batch
* support [batch_id, class_id, box_id]
* fix
* fix
* Create test_mmrotate_core.py
* add ut
* add ut
* Update nms_rotated.py
* fix
* Revert "fix"
This reverts commit f792387fb4
.
* add mmrotate into requirements
* add ut
* update doc
* update
* skip test because mmcv version < 1.4.6
* update
* Update rotated-detection_static.py
* Update rotated-detection_static.py
* Update rotated-detection_static.py
* fix bug of memory leak.
* Update rotated_detection_model.py
pull/422/head
parent
76f6e253bb
commit
42dc5bc316
|
@ -0,0 +1,17 @@
|
|||
# Inherits the static ONNX Runtime rotated-detection config and adds the
# dynamic-axes table for the exported ONNX model.
_base_ = ['./rotated-detection_onnxruntime_static.py']

onnx_config = dict(
    dynamic_axes={
        # Image input: axes 0, 2 and 3 are marked dynamic.
        'input': {0: 'batch', 2: 'height', 3: 'width'},
        # Both outputs share the same dynamic axes: batch and number of
        # detections.
        'dets': {0: 'batch', 1: 'num_dets'},
        'labels': {0: 'batch', 1: 'num_dets'},
    })
|
|
@ -0,0 +1,3 @@
|
|||
# Combines the static rotated-detection task config with the ONNX Runtime
# backend config and names the two exported model outputs.
_base_ = [
    './rotated-detection_static.py',
    '../_base_/backends/onnxruntime.py',
]

onnx_config = dict(
    output_names=['dets', 'labels'],
    input_shape=None,
)
|
|
@ -0,0 +1,9 @@
|
|||
# Base deployment config for the mmrotate RotatedDetection task; builds on
# the shared ONNX export config.
_base_ = ['../_base_/onnx_config.py']

# Post-processing values for the rotated NMS stage.
codebase_config = dict(
    type='mmrotate',
    task='RotatedDetection',
    post_processing=dict(
        score_threshold=0.05,
        iou_threshold=0.1,
        pre_top_k=3000,
        keep_top_k=2000,
    ),
)
|
|
@ -264,6 +264,7 @@ float rotated_boxes_intersection(const RotatedBox& box1, const RotatedBox& box2)
|
|||
NMSRotatedKernel::NMSRotatedKernel(OrtApi api, const OrtKernelInfo* info)
|
||||
: api_(api), ort_(api_), info_(info) {
|
||||
iou_threshold_ = ort_.KernelInfoGetAttribute<float>(info, "iou_threshold");
|
||||
score_threshold_ = ort_.KernelInfoGetAttribute<float>(info, "score_threshold");
|
||||
|
||||
// create allocator
|
||||
allocator_ = Ort::AllocatorWithDefaultOptions();
|
||||
|
@ -271,6 +272,7 @@ NMSRotatedKernel::NMSRotatedKernel(OrtApi api, const OrtKernelInfo* info)
|
|||
|
||||
void NMSRotatedKernel::Compute(OrtKernelContext* context) {
|
||||
const float iou_threshold = iou_threshold_;
|
||||
const float score_threshold = score_threshold_;
|
||||
|
||||
const OrtValue* boxes = ort_.KernelContext_GetInput(context, 0);
|
||||
const float* boxes_data = reinterpret_cast<const float*>(ort_.GetTensorData<float>(boxes));
|
||||
|
@ -280,67 +282,77 @@ void NMSRotatedKernel::Compute(OrtKernelContext* context) {
|
|||
OrtTensorDimensions boxes_dim(ort_, boxes);
|
||||
OrtTensorDimensions scores_dim(ort_, scores);
|
||||
|
||||
int64_t nboxes = boxes_dim[0];
|
||||
assert(boxes_dim[1] == 5); //(cx,cy,w,h,theta)
|
||||
// loop over batch
|
||||
int64_t nbatch = boxes_dim[0];
|
||||
int64_t nboxes = boxes_dim[1];
|
||||
int64_t nclass = scores_dim[1];
|
||||
assert(boxes_dim[2] == 5); //(cx,cy,w,h,theta)
|
||||
|
||||
// allocate tmp memory
|
||||
float* tmp_boxes = (float*)allocator_.Alloc(sizeof(float) * nboxes * 5);
|
||||
float* sc = (float*)allocator_.Alloc(sizeof(float) * nboxes);
|
||||
bool* select = (bool*)allocator_.Alloc(sizeof(bool) * nboxes);
|
||||
for (int64_t i = 0; i < nboxes; i++) {
|
||||
select[i] = true;
|
||||
}
|
||||
float* tmp_boxes = (float*)allocator_.Alloc(sizeof(float) * nbatch * nboxes * 5);
|
||||
float* sc = (float*)allocator_.Alloc(sizeof(float) * nbatch * nclass * nboxes);
|
||||
bool* select = (bool*)allocator_.Alloc(sizeof(bool) * nbatch * nboxes);
|
||||
|
||||
memcpy(tmp_boxes, boxes_data, sizeof(float) * nboxes * 5);
|
||||
memcpy(sc, scores_data, sizeof(float) * nboxes);
|
||||
memcpy(tmp_boxes, boxes_data, sizeof(float) * nbatch * nboxes * 5);
|
||||
memcpy(sc, scores_data, sizeof(float) * nbatch * nclass * nboxes);
|
||||
|
||||
// sort scores
|
||||
std::vector<float> tmp_sc;
|
||||
for (int i = 0; i < nboxes; i++) {
|
||||
tmp_sc.push_back(sc[i]);
|
||||
}
|
||||
std::vector<int64_t> order(tmp_sc.size());
|
||||
std::iota(order.begin(), order.end(), 0);
|
||||
std::sort(order.begin(), order.end(),
|
||||
[&tmp_sc](int64_t id1, int64_t id2) { return tmp_sc[id1] > tmp_sc[id2]; });
|
||||
|
||||
for (int64_t _i = 0; _i < nboxes; _i++) {
|
||||
if (select[_i] == false) continue;
|
||||
auto i = order[_i];
|
||||
|
||||
for (int64_t _j = _i + 1; _j < nboxes; _j++) {
|
||||
if (select[_j] == false) continue;
|
||||
auto j = order[_j];
|
||||
RotatedBox box1, box2;
|
||||
auto center_shift_x = (tmp_boxes[i * 5] + tmp_boxes[j * 5]) / 2.0;
|
||||
auto center_shift_y = (tmp_boxes[i * 5 + 1] + tmp_boxes[j * 5 + 1]) / 2.0;
|
||||
box1.x_ctr = tmp_boxes[i * 5] - center_shift_x;
|
||||
box1.y_ctr = tmp_boxes[i * 5 + 1] - center_shift_y;
|
||||
box1.w = tmp_boxes[i * 5 + 2];
|
||||
box1.h = tmp_boxes[i * 5 + 3];
|
||||
box1.a = tmp_boxes[i * 5 + 4];
|
||||
box2.x_ctr = tmp_boxes[j * 5] - center_shift_x;
|
||||
box2.y_ctr = tmp_boxes[j * 5 + 1] - center_shift_y;
|
||||
box2.w = tmp_boxes[j * 5 + 2];
|
||||
box2.h = tmp_boxes[j * 5 + 3];
|
||||
box2.a = tmp_boxes[j * 5 + 4];
|
||||
auto area1 = box1.w * box1.h;
|
||||
auto area2 = box2.w * box2.h;
|
||||
auto intersection = rotated_boxes_intersection(box1, box2);
|
||||
float baseS = 1.0;
|
||||
baseS = (area1 + area2 - intersection);
|
||||
auto ovr = intersection / baseS;
|
||||
if (ovr > iou_threshold) select[_j] = false;
|
||||
}
|
||||
}
|
||||
// std::vector<std::vector<int64_t>> res_order;
|
||||
std::vector<int64_t> res_order;
|
||||
for (int i = 0; i < nboxes; i++) {
|
||||
if (select[i]) {
|
||||
res_order.push_back(order[i]);
|
||||
}
|
||||
}
|
||||
for (int64_t k = 0; k < nbatch; k++) {
|
||||
for (int64_t g = 0; g < nclass; g++) {
|
||||
for (int64_t i = 0; i < nboxes; i++) {
|
||||
select[i] = true;
|
||||
}
|
||||
// sort scores
|
||||
std::vector<float> tmp_sc;
|
||||
for (int i = 0; i < nboxes; i++) {
|
||||
tmp_sc.push_back(sc[k * nboxes * nclass + g * nboxes + i]);
|
||||
}
|
||||
std::vector<int64_t> order(tmp_sc.size());
|
||||
std::iota(order.begin(), order.end(), 0);
|
||||
std::sort(order.begin(), order.end(),
|
||||
[&tmp_sc](int64_t id1, int64_t id2) { return tmp_sc[id1] > tmp_sc[id2]; });
|
||||
for (int64_t _i = 0; _i < nboxes; _i++) {
|
||||
if (select[_i] == false) continue;
|
||||
auto i = order[_i];
|
||||
for (int64_t _j = _i + 1; _j < nboxes; _j++) {
|
||||
if (select[_j] == false) continue;
|
||||
auto j = order[_j];
|
||||
RotatedBox box1, box2;
|
||||
auto center_shift_x =
|
||||
(tmp_boxes[k * nboxes * 5 + i * 5] + tmp_boxes[k * nboxes * 5 + j * 5]) / 2.0;
|
||||
auto center_shift_y =
|
||||
(tmp_boxes[k * nboxes * 5 + i * 5 + 1] + tmp_boxes[k * nboxes * 5 + j * 5 + 1]) / 2.0;
|
||||
box1.x_ctr = tmp_boxes[k * nboxes * 5 + i * 5] - center_shift_x;
|
||||
box1.y_ctr = tmp_boxes[k * nboxes * 5 + i * 5 + 1] - center_shift_y;
|
||||
box1.w = tmp_boxes[k * nboxes * 5 + i * 5 + 2];
|
||||
box1.h = tmp_boxes[k * nboxes * 5 + i * 5 + 3];
|
||||
box1.a = tmp_boxes[k * nboxes * 5 + i * 5 + 4];
|
||||
box2.x_ctr = tmp_boxes[k * nboxes * 5 + j * 5] - center_shift_x;
|
||||
box2.y_ctr = tmp_boxes[k * nboxes * 5 + j * 5 + 1] - center_shift_y;
|
||||
box2.w = tmp_boxes[k * nboxes * 5 + j * 5 + 2];
|
||||
box2.h = tmp_boxes[k * nboxes * 5 + j * 5 + 3];
|
||||
box2.a = tmp_boxes[k * nboxes * 5 + j * 5 + 4];
|
||||
auto area1 = box1.w * box1.h;
|
||||
auto area2 = box2.w * box2.h;
|
||||
auto intersection = rotated_boxes_intersection(box1, box2);
|
||||
float baseS = 1.0;
|
||||
baseS = (area1 + area2 - intersection);
|
||||
auto ovr = intersection / baseS;
|
||||
if (ovr > iou_threshold) select[_j] = false;
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < nboxes; i++) {
|
||||
if (select[i] & (tmp_sc[order[i]] > score_threshold)) {
|
||||
res_order.push_back(k);
|
||||
res_order.push_back(g);
|
||||
res_order.push_back(order[i]);
|
||||
}
|
||||
}
|
||||
} // class loop
|
||||
} // batch loop
|
||||
|
||||
std::vector<int64_t> inds_dims({(int64_t)res_order.size()});
|
||||
std::vector<int64_t> inds_dims({(int64_t)res_order.size() / 3, 3});
|
||||
|
||||
OrtValue* res = ort_.KernelContext_GetOutput(context, 0, inds_dims.data(), inds_dims.size());
|
||||
int64_t* res_data = ort_.GetTensorMutableData<int64_t>(res);
|
||||
|
|
|
@ -22,6 +22,7 @@ struct NMSRotatedKernel {
|
|||
const OrtKernelInfo* info_;
|
||||
Ort::AllocatorWithDefaultOptions allocator_;
|
||||
float iou_threshold_;
|
||||
float score_threshold_;
|
||||
};
|
||||
|
||||
struct NMSRotatedOp : Ort::CustomOpBase<NMSRotatedOp, NMSRotatedKernel> {
|
||||
|
|
|
@ -1819,6 +1819,53 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](tut
|
|||
</div>
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary style="margin-left: 25px;">MMRotate</summary>
|
||||
<div style="margin-left: 25px;">
|
||||
<table class="docutils">
|
||||
<thead>
|
||||
<tr>
|
||||
<th align="center" colspan="4">MMRotate</th>
|
||||
<th align="center">Pytorch</th>
|
||||
<th align="center">ONNXRuntime</th>
|
||||
<th align="center" colspan="2">TensorRT</th>
|
||||
<th align="center">PPLNN</th>
|
||||
<th align="center">OpenVINO</th>
|
||||
<th align="left">Model Config</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr>
|
||||
<td align="center">Model</td>
|
||||
<td align="center">Task</td>
|
||||
<td align="center">Dataset</td>
|
||||
<td align="center">Metrics</td>
|
||||
<td align="center">fp32</td>
|
||||
<td align="center">fp32</td>
|
||||
<td align="center">fp32</td>
|
||||
<td align="center">fp16</td>
|
||||
<td align="center">fp16</td>
|
||||
<td align="center">fp32</td>
|
||||
<td>model config file</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center" rowspan="2">RotatedRetinaNet</td>
|
||||
<td align="center" rowspan="2">Rotated Detection</td>
|
||||
<td align="center" rowspan="2">DOTA-v1.0</td>
|
||||
<td align="center">mAP</td>
|
||||
<td align="center">0.698</td>
|
||||
<td align="center">0.698</td>
|
||||
<td align="center">-</td>
|
||||
<td align="center">-</td>
|
||||
<td align="center">-</td>
|
||||
<td align="center">-</td>
|
||||
<td rowspan="2">$MMROTATE_DIR/configs/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le135.py</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</details>
|
||||
|
||||
|
||||
### Notes
|
||||
- Because some datasets contain images of various resolutions in codebases like MMDet, the speed benchmark is obtained with static configs in MMDeploy, while the performance benchmark is obtained with dynamic ones.
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
# MMRotate Support
|
||||
|
||||
[MMRotate](https://github.com/open-mmlab/mmrotate) is an open-source toolbox for rotated object detection based on PyTorch. It is a part of the [OpenMMLab](https://openmmlab.com/) project.
|
||||
|
||||
## MMRotate installation tutorial
|
||||
|
||||
Please refer to [official installation guide](https://mmrotate.readthedocs.io/en/latest/install.html) to install the codebase.
|
||||
|
||||
## MMRotate models support
|
||||
|
||||
| Model | Task | ONNX Runtime | TensorRT | NCNN | PPLNN | OpenVINO | Model config |
|
||||
|:----------------------|:--------------|:------------:|:--------:|:----:|:-----:|:--------:|:-------------------------------------------------------------------------------------------:|
|
||||
| RotatedRetinaNet | RotatedDetection | Y | N | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/rotated_retinanet/README.md) |
|
||||
|
||||
### Example
|
||||
|
||||
```bash
|
||||
# convert ort
|
||||
python tools/deploy.py \
|
||||
configs/mmrotate/rotated-detection_onnxruntime_dynamic.py \
|
||||
$MMROTATE_DIR/configs/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le135.py \
|
||||
$MMROTATE_DIR/checkpoints/rotated_retinanet_obb_r50_fpn_1x_dota_le135-e4131166.pth \
|
||||
$MMROTATE_DIR/demo/demo.jpg \
|
||||
--work-dir work-dirs/mmrotate/rotated_retinanet/ort \
|
||||
--device cpu
|
||||
|
||||
# compute metric
|
||||
python tools/test.py \
|
||||
configs/mmrotate/rotated-detection_onnxruntime_dynamic.py \
|
||||
$MMROTATE_DIR/configs/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le135.py \
|
||||
--model work-dirs/mmrotate/rotated_retinanet/ort/end2end.onnx \
|
||||
--metrics mAP
|
||||
|
||||
# generate submit file
|
||||
python tools/test.py \
|
||||
configs/mmrotate/rotated-detection_onnxruntime_dynamic.py \
|
||||
$MMROTATE_DIR/configs/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le135.py \
|
||||
--model work-dirs/mmrotate/rotated_retinanet/ort/end2end.onnx \
|
||||
--format-only \
|
||||
--metric-options submission_dir=work-dirs/mmrotate/rotated_retinanet/ort/Task1_results
|
||||
```
|
||||
|
||||
Note
|
||||
|
||||
- MMRotate models usually need some extra information about the input image that cannot be obtained directly, so use `$MMROTATE_DIR/demo/demo.jpg` as the input image when exporting the model.
|
||||
|
||||
## Reminder
|
||||
|
||||
None
|
||||
|
||||
## FAQs
|
||||
|
||||
None
|
|
@ -1808,6 +1808,54 @@ GPU: ncnn, TensorRT, PPLNN
|
|||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
<summary style="margin-left: 25px;">MMRotate</summary>
|
||||
<div style="margin-left: 25px;">
|
||||
<table class="docutils">
|
||||
<thead>
|
||||
<tr>
|
||||
<th align="center" colspan="4">MMRotate</th>
|
||||
<th align="center">Pytorch</th>
|
||||
<th align="center">ONNXRuntime</th>
|
||||
<th align="center" colspan="2">TensorRT</th>
|
||||
<th align="center">PPLNN</th>
|
||||
<th align="center">OpenVINO</th>
|
||||
<th align="left">Model Config</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr>
|
||||
<td align="center">Model</td>
|
||||
<td align="center">Task</td>
|
||||
<td align="center">Dataset</td>
|
||||
<td align="center">Metrics</td>
|
||||
<td align="center">fp32</td>
|
||||
<td align="center">fp32</td>
|
||||
<td align="center">fp32</td>
|
||||
<td align="center">fp16</td>
|
||||
<td align="center">fp16</td>
|
||||
<td align="center">fp32</td>
|
||||
<td>model config file</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center" rowspan="2">RotatedRetinaNet</td>
|
||||
<td align="center" rowspan="2">Rotated Detection</td>
|
||||
<td align="center" rowspan="2">DOTA-v1.0</td>
|
||||
<td align="center">mAP</td>
|
||||
<td align="center">0.698</td>
|
||||
<td align="center">0.698</td>
|
||||
<td align="center">-</td>
|
||||
<td align="center">-</td>
|
||||
<td align="center">-</td>
|
||||
<td align="center">-</td>
|
||||
<td rowspan="2">$MMROTATE_DIR/configs/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le135.py</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</details>
|
||||
|
||||
|
||||
### 注意
|
||||
- 由于某些数据集在代码库中包含各种分辨率的图像,例如 MMDet,速度基准是通过 MMDeploy 中的静态配置获得的,而性能基准是通过动态配置获得的。
|
||||
|
||||
|
|
|
@ -4,7 +4,10 @@ import importlib
|
|||
from mmdeploy.utils import Codebase
|
||||
from .base import BaseTask, MMCodebase, get_codebase_class
|
||||
|
||||
extra_dependent_library = {Codebase.MMOCR: ['mmdet']}
|
||||
extra_dependent_library = {
|
||||
Codebase.MMOCR: ['mmdet'],
|
||||
Codebase.MMROTATE: ['mmdet']
|
||||
}
|
||||
|
||||
|
||||
def import_codebase(codebase: Codebase):
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .core import * # noqa: F401,F403
|
||||
from .deploy import * # noqa: F401,F403
|
||||
from .models import * # noqa: F401,F403
|
|
@ -0,0 +1,3 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .bbox import * # noqa: F401,F403
|
||||
from .post_processing import * # noqa: F401,F403
|
|
@ -0,0 +1,2 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .delta_xywha_rbbox_coder import * # noqa: F401,F403
|
|
@ -0,0 +1,118 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmrotate.core import norm_angle
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
    func_name='mmrotate.core.bbox.coder.delta_xywha_rbbox_coder.delta2bbox',
    backend='default')
def delta2bbox(ctx,
               rois,
               deltas,
               means=(0., 0., 0., 0., 0.),
               stds=(1., 1., 1., 1., 1.),
               max_shape=None,
               wh_ratio_clip=16 / 1000,
               add_ctr_clamp=False,
               ctr_clamp=32,
               angle_range='oc',
               norm_factor=None,
               edge_swap=False,
               proj_xy=False):
    """Decode rotated-box deltas against their reference boxes.

    Rewriter registered for mmrotate's ``delta2bbox`` on the default backend.
    All indexing is done on the trailing axis (``...``), so leading batch
    dimensions are carried through unchanged.

    Args:
        ctx (ContextCaller): The rewriter context with additional
            information. Not read by this function.
        rois (Tensor): Reference boxes; the last axis holds
            (cx, cy, w, h, angle).
        deltas (Tensor): Encoded offsets relative to each roi; the last axis
            is regrouped into chunks of 5, i.e. it has size 5 or
            num_classes * 5.
        means (Sequence[float]): Denormalizing means for the five delta
            coordinates. Default (0., 0., 0., 0., 0.).
        stds (Sequence[float]): Denormalizing standard deviations for the
            five delta coordinates. Default (1., 1., 1., 1., 1.).
        max_shape (tuple[int, int]): (H, W) bounds used to clamp the decoded
            box centers. Default None (no clamping).
        wh_ratio_clip (float): ``abs(log(wh_ratio_clip))`` bounds the
            width/height deltas before exponentiation. Default 16 / 1000.
        add_ctr_clamp (bool): Whether to clamp the center shift to
            ``[-ctr_clamp, ctr_clamp]``. In this mode the width/height deltas
            are only clamped from above. Only used by YOLOF. Default False.
        ctr_clamp (int): Maximum pixel shift of the center when
            ``add_ctr_clamp`` is set. Only used by YOLOF. Default 32.
        angle_range (str, optional): Angle representation passed to
            ``norm_angle``. Defaults to 'oc'.
        norm_factor (None|float, optional): If truthy, the angle delta is
            multiplied by ``norm_factor * pi``.
        edge_swap (bool, optional): Whether to swap the edges (and rotate
            the angle by pi / 2) wherever the decoded w is not larger than
            h. Defaults to False.
        proj_xy (bool, optional): Whether to project the center offsets
            according to the roi angle. Defaults to False.

    Return:
        Tensor: Decoded boxes with the same shape as ``deltas``; each group
            of 5 values is (cx, cy, w, h, angle).
    """
    # (1, 5) rows created from `deltas` so they broadcast over every box.
    mean_row = deltas.new_tensor(means).view(1, -1)
    std_row = deltas.new_tensor(stds).view(1, -1)
    # Regroup the trailing axis into (-1, 5): one row per encoded box.
    grouped = deltas.view(deltas.shape[:-1] + (-1, 5))
    decoded = grouped * std_row + mean_row

    dx, dy = decoded[..., 0], decoded[..., 1]
    dw, dh = decoded[..., 2], decoded[..., 3]
    da = decoded[..., 4]
    if norm_factor:
        da *= norm_factor * np.pi

    # Reference box components, with an extra axis to line up with the
    # grouped deltas.
    px, py = rois[..., None, 0], rois[..., None, 1]
    pw, ph = rois[..., None, 2], rois[..., None, 3]
    pa = rois[..., None, 4]

    shift_x = pw * dx
    shift_y = ph * dy
    log_ratio_limit = np.abs(np.log(wh_ratio_clip))
    if add_ctr_clamp:
        shift_x = torch.clamp(shift_x, min=-ctr_clamp, max=ctr_clamp)
        shift_y = torch.clamp(shift_y, min=-ctr_clamp, max=ctr_clamp)
        dw = torch.clamp(dw, max=log_ratio_limit)
        dh = torch.clamp(dh, max=log_ratio_limit)
    else:
        dw = torch.clamp(dw, min=-log_ratio_limit, max=log_ratio_limit)
        dh = torch.clamp(dh, min=-log_ratio_limit, max=log_ratio_limit)

    # Scale the reference size by exp(delta).
    gw = pw * dw.exp()
    gh = ph * dh.exp()

    # Move the reference center, optionally rotating the offset by the
    # reference angle.
    if proj_xy:
        gx = dx * pw * torch.cos(pa) - dy * ph * torch.sin(pa) + px
        gy = dx * pw * torch.sin(pa) + dy * ph * torch.cos(pa) + py
    else:
        gx = px + shift_x
        gy = py + shift_y

    ga = norm_angle(pa + da, angle_range)

    if max_shape is not None:
        gx = gx.clamp(min=0, max=max_shape[1] - 1)
        gy = gy.clamp(min=0, max=max_shape[0] - 1)

    if not edge_swap:
        return torch.stack([gx, gy, gw, gh, ga], dim=-1).view(deltas.size())

    # Put the longer edge first and rotate the angle where edges are swapped.
    w_is_longer = gw > gh
    w_regular = torch.where(w_is_longer, gw, gh)
    h_regular = torch.where(w_is_longer, gh, gw)
    theta_regular = torch.where(w_is_longer, ga, ga + np.pi / 2)
    theta_regular = norm_angle(theta_regular, angle_range)
    return torch.stack([gx, gy, w_regular, h_regular, theta_regular],
                       dim=-1).view_as(deltas)
|
|
@ -0,0 +1,4 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .bbox_nms import multiclass_nms_rotated
|
||||
|
||||
__all__ = ['multiclass_nms_rotated']
|
|
@ -0,0 +1,120 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from mmdeploy.mmcv.ops import ONNXNMSRotatedOp
|
||||
|
||||
|
||||
def select_nms_index(scores: torch.Tensor,
                     boxes: torch.Tensor,
                     nms_index: torch.Tensor,
                     batch_size: int,
                     keep_top_k: int = -1):
    """Gather the detections picked by NMSRotated into per-image tensors.

    Args:
        scores (Tensor): The detection scores of shape
            [N, num_classes, num_boxes].
        boxes (Tensor): The bounding boxes of shape [N, num_boxes, 5].
        nms_index (Tensor): NMS output; each row is
            (batch index, class index, box index).
        batch_size (int): Batch size of the input image.
        keep_top_k (int): Number of top K boxes to keep after nms.
            Defaults to -1.

    Returns:
        tuple[Tensor, Tensor]: (dets, labels), `dets` of shape [N, num_det, 6]
            (box followed by its score) and `labels` of shape [N, num_det].
            Entries that do not belong to an image are zero boxes with
            label -1; one extra all-zero row with label 0 is always appended.
    """
    img_ids, class_ids = nms_index[:, 0], nms_index[:, 1]
    anchor_ids = nms_index[:, 2]

    # Pick the kept boxes/scores and join them into rows of 6 values.
    kept_scores = scores[img_ids, class_ids, anchor_ids].unsqueeze(1)
    kept_boxes = boxes[img_ids, anchor_ids, ...]
    dets = torch.cat([kept_boxes, kept_scores], dim=1)

    # Replicate every kept detection for each image, then blank out the ones
    # that belong to a different image.
    image_range = torch.arange(
        0, batch_size, dtype=img_ids.dtype, device=img_ids.device)
    owned = img_ids == image_range.unsqueeze(1)

    batched_dets = dets.unsqueeze(0).repeat(batch_size, 1, 1)
    batched_dets = torch.where(
        owned.unsqueeze(-1), batched_dets, batched_dets.new_zeros(1))

    batched_labels = class_ids.unsqueeze(0).repeat(batch_size, 1)
    batched_labels = torch.where(owned, batched_labels,
                                 batched_labels.new_ones(1) * -1)

    # Append one dummy row so the detection axis is never empty.
    num_imgs = batched_dets.shape[0]
    batched_dets = torch.cat(
        (batched_dets, batched_dets.new_zeros((num_imgs, 1, 6))), 1)
    batched_labels = torch.cat(
        (batched_labels, batched_labels.new_zeros((num_imgs, 1))), 1)

    # Order by score (last column); top-k is used when requested and either
    # exporting to ONNX or when k is smaller than the number of rows.
    det_scores = batched_dets[:, :, -1]
    use_topk = keep_top_k > 0 and (
        torch.onnx.is_in_onnx_export()
        or keep_top_k < batched_dets.shape[1])
    if use_topk:
        _, order = det_scores.topk(keep_top_k, dim=1)
    else:
        _, order = det_scores.sort(dim=1, descending=True)

    row_ids = torch.arange(
        batch_size, dtype=order.dtype,
        device=order.device).view(-1, 1).expand_as(order)
    batched_dets = batched_dets[row_ids, order, ...]
    batched_labels = batched_labels[row_ids, order, ...]

    return batched_dets, batched_labels
|
||||
|
||||
|
||||
def multiclass_nms_rotated(boxes: Tensor,
                           scores: Tensor,
                           iou_threshold: float = 0.1,
                           score_threshold: float = 0.05,
                           pre_top_k: int = -1,
                           keep_top_k: int = -1):
    """NMSRotated for multi-class bboxes.

    This function helps exporting to onnx with batch and multiclass NMSRotated
    op. Boxes are shared by all classes: the scores are of shape
    (N, num_boxes, num_classes) and the boxes are of shape (N, num_boxes, 5).

    Args:
        boxes (Tensor): The bounding boxes of shape [N, num_boxes, 5].
        scores (Tensor): The detection scores of shape
            [N, num_boxes, num_classes].
        iou_threshold (float): IOU threshold of nms. Defaults to 0.1.
        score_threshold (float): bbox threshold, bboxes with scores lower than
            it will not be considered. Defaults to 0.05.
        pre_top_k (int): Number of top K boxes to keep before nms, ranked by
            each box's best class score. Values <= 0 disable the
            pre-selection. Defaults to -1.
        keep_top_k (int): Number of top K boxes to keep after nms.
            Defaults to -1.

    Returns:
        tuple[Tensor, Tensor]: (dets, labels), `dets` of shape [N, num_det, 6]
            and `labels` of shape [N, num_det].
    """
    batch_size = scores.shape[0]

    if pre_top_k > 0:
        # Rank boxes by their best class score and keep the top ones.
        max_scores, _ = scores.max(-1)
        _, topk_inds = max_scores.topk(pre_top_k)
        # Build the batch index on the same device as the top-k indices,
        # matching how select_nms_index creates its index tensors.
        batch_inds = torch.arange(
            batch_size, device=topk_inds.device).view(
                -1, 1).expand_as(topk_inds).long()
        boxes = boxes[batch_inds, topk_inds, :]
        scores = scores[batch_inds, topk_inds, :]

    # The NMS op and select_nms_index index scores as
    # [N, num_classes, num_boxes].
    scores = scores.permute(0, 2, 1)
    selected_indices = ONNXNMSRotatedOp.apply(boxes, scores, iou_threshold,
                                              score_threshold)

    dets, labels = select_nms_index(
        scores, boxes, selected_indices, batch_size, keep_top_k=keep_top_k)

    return dets, labels
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .mmrotate import MMROTATE
|
||||
from .rotated_detection import RotatedDetection
|
||||
|
||||
__all__ = ['MMROTATE', 'RotatedDetection']
|
|
@ -0,0 +1,157 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Optional, Union
|
||||
|
||||
import mmcv
|
||||
import torch
|
||||
from mmcv.utils import Registry
|
||||
from mmdet.datasets import replace_ImageToTensor
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from mmdeploy.codebase.base import CODEBASE, BaseTask, MMCodebase
|
||||
from mmdeploy.utils import Codebase, get_task_type
|
||||
|
||||
|
||||
def __build_mmrotate_task(model_cfg: mmcv.Config, deploy_cfg: mmcv.Config,
                          device: str, registry: Registry) -> BaseTask:
    """Build function of the mmrotate task registry.

    Looks up the class registered under the task type read from
    ``deploy_cfg`` and instantiates it.

    Args:
        model_cfg (mmcv.Config): Model config, forwarded to the task class.
        deploy_cfg (mmcv.Config): Deployment config; its task type selects
            the registered class.
        device (str): Device string, forwarded to the task class.
        registry (Registry): Registry whose ``module_dict`` is searched.

    Returns:
        BaseTask: The instantiated task processor.
    """
    task_key = get_task_type(deploy_cfg).value
    task_cls = registry.module_dict[task_key]
    return task_cls(model_cfg, deploy_cfg, device)


# Registry holding the mmrotate task processors.
MMROTATE_TASK = Registry('mmrotate_tasks', build_func=__build_mmrotate_task)
|
||||
|
||||
|
||||
@CODEBASE.register_module(Codebase.MMROTATE.value)
class MMROTATE(MMCodebase):
    """Codebase class for mmrotate.

    Registered in ``CODEBASE`` under ``Codebase.MMROTATE.value``. Dataset
    building is delegated to mmrotate; dataloader building and single-GPU
    testing are delegated to mmdet.
    """

    task_registry = MMROTATE_TASK

    def __init__(self):
        super().__init__()

    @staticmethod
    def build_task_processor(model_cfg: mmcv.Config, deploy_cfg: mmcv.Config,
                             device: str):
        """Create the mmrotate task processor for a deployment.

        Args:
            model_cfg (str | mmcv.Config): Model config file or loaded Config
                object.
            deploy_cfg (str | mmcv.Config): Deployment config file or loaded
                Config object.
            device (str): A string specifying device type.

        Returns:
            BaseTask: A task processor.
        """
        processor = MMROTATE_TASK.build(model_cfg, deploy_cfg, device)
        return processor

    @staticmethod
    def build_dataset(dataset_cfg: Union[str, mmcv.Config],
                      dataset_type: str = 'val',
                      **kwargs) -> Dataset:
        """Build a test-mode dataset with mmrotate's dataset builder.

        Args:
            dataset_cfg (str | mmcv.Config): The input dataset config. The
                body reads ``dataset_cfg.data`` directly, so a loaded config
                is expected.
            dataset_type (str): Which split of ``dataset_cfg.data`` to build,
                e.g. 'train', 'test', 'val'. Defaults to 'val'.

        Returns:
            Dataset: A PyTorch dataset.
        """
        from mmrotate.datasets import build_dataset as build_dataset_mmrotate

        assert dataset_type in dataset_cfg.data
        split_cfg = dataset_cfg.data[dataset_type]

        if isinstance(split_cfg, dict):
            split_cfg.test_mode = True
            batch = split_cfg.pop('samples_per_gpu', 1)
            if batch > 1:
                # Replace 'ImageToTensor' to 'DefaultFormatBundle'
                split_cfg.pipeline = replace_ImageToTensor(
                    split_cfg.pipeline)
        elif isinstance(split_cfg, list):
            # The split is a concatenation of several dataset configs.
            for sub_cfg in split_cfg:
                sub_cfg.test_mode = True
            per_dataset = [
                sub_cfg.pop('samples_per_gpu', 1) for sub_cfg in split_cfg
            ]
            batch = max(per_dataset)
            if batch > 1:
                for sub_cfg in split_cfg:
                    sub_cfg.pipeline = replace_ImageToTensor(
                        sub_cfg.pipeline)

        return build_dataset_mmrotate(split_cfg)

    @staticmethod
    def build_dataloader(dataset: Dataset,
                         samples_per_gpu: int,
                         workers_per_gpu: int,
                         num_gpus: int = 1,
                         dist: bool = False,
                         shuffle: bool = False,
                         seed: Optional[int] = None,
                         drop_last: bool = False,
                         persistent_workers: bool = True,
                         **kwargs) -> DataLoader:
        """Build a dataloader by forwarding every argument to mmdet.

        Args:
            dataset (Dataset): Input dataset.
            samples_per_gpu (int): Number of training samples on each GPU,
                i.e., batch size of each GPU.
            workers_per_gpu (int): How many subprocesses to use for data
                loading for each GPU.
            num_gpus (int): Number of GPUs. Only used in non-distributed
                training.
            dist (bool): Distributed training/test or not. Defaults to
                `False`.
            shuffle (bool): Whether to shuffle the data at every epoch.
                Defaults to `False`.
            seed (int): An integer set to be seed. Default is `None`.
            drop_last (bool): Whether to drop the last incomplete batch in
                epoch. Default to `False`.
            persistent_workers (bool): If `True`, the data loader will not
                shutdown the worker processes after a dataset has been
                consumed once. This allows to maintain the workers Dataset
                instances alive. The argument also has effect in
                PyTorch>=1.7.0. Default is `True`.
            kwargs: Any other keyword argument to be used to initialize
                DataLoader.

        Returns:
            DataLoader: A PyTorch dataloader.
        """
        from mmdet.datasets import build_dataloader as build_dataloader_mmdet

        loader = build_dataloader_mmdet(
            dataset,
            samples_per_gpu,
            workers_per_gpu,
            num_gpus=num_gpus,
            dist=dist,
            shuffle=shuffle,
            seed=seed,
            drop_last=drop_last,
            persistent_workers=persistent_workers,
            **kwargs)
        return loader

    @staticmethod
    def single_gpu_test(model: torch.nn.Module,
                        data_loader: DataLoader,
                        show: bool = False,
                        out_dir: Optional[str] = None,
                        **kwargs):
        """Run test with single gpu through mmdet's ``single_gpu_test``.

        Args:
            model (torch.nn.Module): Input model from nn.Module.
            data_loader (DataLoader): PyTorch data loader.
            show (bool): Specifying whether to show plotted results. Defaults
                to `False`.
            out_dir (str): A directory to save results, defaults to `None`.

        Returns:
            list: The prediction results.
        """
        from mmdet.apis import single_gpu_test as run_single_gpu_test

        return run_single_gpu_test(model, data_loader, show, out_dir,
                                   **kwargs)
|
|
@ -0,0 +1,343 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Any, Dict, Optional, Sequence, Tuple, Union
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmcv.parallel import DataContainer, collate, scatter
|
||||
from torch import nn
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from mmdeploy.codebase.base import BaseTask
|
||||
from mmdeploy.utils import Task, get_input_shape
|
||||
from .mmrotate import MMROTATE_TASK
|
||||
|
||||
|
||||
def process_model_config(model_cfg: mmcv.Config,
                         imgs: Union[Sequence[str], Sequence[np.ndarray]],
                         input_shape: Optional[Sequence[int]] = None):
    """Process the model config.

    The test pipeline of the returned config is adapted to the input type
    and, for static export, to a fixed input shape. The caller's
    ``model_cfg`` is left untouched.

    Args:
        model_cfg (mmcv.Config): The model config.
        imgs (Sequence[str] | Sequence[np.ndarray]): Input image(s), accepted
            data type are List[str], List[np.ndarray]. Only the type of the
            first element is inspected.
        input_shape (list[int]): A list of two integer in (width, height)
            format specifying input shape. Default: None.

    Returns:
        mmcv.Config: the model config after processing.
    """
    import copy

    from mmdet.datasets import replace_ImageToTensor

    # A deep copy is required here: `model_cfg.copy()` is only a shallow
    # copy, so the nested pipeline dicts edited below would be shared with
    # (and permanently modified in) the caller's config.
    cfg = copy.deepcopy(model_cfg)

    if isinstance(imgs[0], np.ndarray):
        # set loading pipeline type
        cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'
    # for static exporting
    if input_shape is not None:
        cfg.data.test.pipeline[1]['img_scale'] = tuple(input_shape)
        transforms = cfg.data.test.pipeline[1]['transforms']
        for trans in transforms:
            # A fixed input shape must not be enlarged again by padding.
            if trans['type'] == 'Pad':
                trans['size_divisor'] = 1

    cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
    return cfg
|
||||
|
||||
|
||||
@MMROTATE_TASK.register_module(Task.ROTATED_DETECTION.value)
class RotatedDetection(BaseTask):
    """Deployment task for rotated object detection.

    Args:
        model_cfg (mmcv.Config): Loaded model Config object.
        deploy_cfg (mmcv.Config): Loaded deployment Config object.
        device (str): A string representing the device type.
    """

    def __init__(self, model_cfg: mmcv.Config, deploy_cfg: mmcv.Config,
                 device: str):
        super().__init__(model_cfg, deploy_cfg, device)

    def init_backend_model(self,
                           model_files: Optional[str] = None,
                           **kwargs) -> torch.nn.Module:
        """Build the backend model and switch it to eval mode.

        Args:
            model_files (Sequence[str]): Input model files.

        Returns:
            nn.Module: An initialized backend model.
        """
        from .rotated_detection_model import build_rotated_detection_model

        return build_rotated_detection_model(
            model_files, self.model_cfg, self.deploy_cfg,
            device=self.device).eval()

    def init_pytorch_model(self,
                           model_checkpoint: Optional[str] = None,
                           cfg_options: Optional[Dict] = None,
                           **kwargs) -> torch.nn.Module:
        """Build the PyTorch detector described by the model config.

        Args:
            model_checkpoint (str): The checkpoint file of torch model,
                defaults to `None`.
            cfg_options (dict): Optional config key-pair parameters that are
                merged into the model config.

        Returns:
            nn.Module: An initialized torch model generated by OpenMMLab
                codebases, in eval mode.
        """
        import warnings

        from mmcv.runner import load_checkpoint
        from mmdet.core import get_classes
        from mmrotate.models import build_detector

        # Accept either a config path or an already loaded Config.
        if isinstance(self.model_cfg, str):
            self.model_cfg = mmcv.Config.fromfile(self.model_cfg)
        elif not isinstance(self.model_cfg, mmcv.Config):
            raise TypeError('config must be a filename or Config object, '
                            f'but got {type(self.model_cfg)}')

        cfg = self.model_cfg
        if cfg_options is not None:
            cfg.merge_from_dict(cfg_options)
        cfg.model.pretrained = None
        cfg.model.train_cfg = None
        detector = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))

        if model_checkpoint is not None:
            location = 'cpu' if self.device == 'cpu' else None
            ckpt = load_checkpoint(
                detector, model_checkpoint, map_location=location)
            meta = ckpt.get('meta', {})
            if 'CLASSES' in meta:
                detector.CLASSES = meta['CLASSES']
            else:
                warnings.simplefilter('once')
                warnings.warn('Class names are not saved in the checkpoint\'s '
                              'meta data, use COCO classes by default.')
                detector.CLASSES = get_classes('coco')

        detector.cfg = cfg
        detector.to(self.device)
        return detector.eval()

    def create_input(self,
                     imgs: Union[str, np.ndarray],
                     input_shape: Sequence[int] = None) \
            -> Tuple[Dict, torch.Tensor]:
        """Create input for rotated object detection.

        Args:
            imgs (str | np.ndarray): Input image(s), accepted data type are
                `str`, `np.ndarray`, or a list/tuple of them.
            input_shape (list[int]): A list of two integer in (width, height)
                format specifying input shape. Defaults to `None`.

        Returns:
            tuple: (data, img), meta information for the input image and input.
        """
        # Normalise a single image to a one-element list and reject anything
        # that is neither a path nor an array.
        if isinstance(imgs, (np.ndarray, str)):
            imgs = [imgs]
        elif not isinstance(imgs, (list, tuple)) or \
                not isinstance(imgs[0], (np.ndarray, str)):
            raise AssertionError('imgs must be strings or numpy arrays')

        cfg = process_model_config(self.model_cfg, imgs, input_shape)
        from mmdet.datasets.pipelines import Compose
        pipeline = Compose(cfg.data.test.pipeline)

        # The type of the first image decides how every sample is fed in.
        from_arrays = isinstance(imgs[0], np.ndarray)

        def _raw_sample(image):
            if from_arrays:
                # the array itself is the pipeline input
                return dict(img=image)
            # otherwise the pipeline loads the image from its filename
            return dict(img_info=dict(filename=image), img_prefix=None)

        data_list = [pipeline(_raw_sample(image)) for image in imgs]

        if isinstance(data_list[0]['img'], list) and len(data_list) > 1:
            raise Exception('aug test does not support '
                            f'inference with batch size '
                            f'{len(data_list)}')

        data = collate(data_list, samples_per_gpu=len(imgs))

        # unwrap img_metas
        metas = data['img_metas']
        if isinstance(metas, list):
            data['img_metas'] = [meta.data[0] for meta in metas]
        else:
            data['img_metas'] = metas.data

        # unwrap img
        if isinstance(data['img'], list):
            data['img'] = [item.data for item in data['img']]
            if isinstance(data['img'][0], list):
                data['img'] = [item[0] for item in data['img']]
        else:
            data['img'] = data['img'].data

        if self.device != 'cpu':
            data = scatter(data, [self.device])[0]

        return data, data['img']

    def visualize(self,
                  model: nn.Module,
                  image: Union[str, np.ndarray],
                  result: list,
                  output_file: str,
                  window_name: str = '',
                  show_result: bool = False):
        """Draw the predictions of a model on an image.

        Args:
            model (nn.Module): Input model; its `show_result` does the
                drawing.
            image (str | np.ndarray): Input image to draw predictions on.
            result (list): A list of predictions.
            output_file (str): Output file to save drawn image. Ignored when
                `show_result` is `True`.
            window_name (str): The name of visualization window. Defaults to
                an empty string.
            show_result (bool): Whether to show result in windows, defaults
                to `False`.
        """
        if isinstance(image, str):
            canvas = mmcv.imread(image)
        else:
            canvas = image
        # Showing on screen and saving to file are mutually exclusive here.
        target_file = None if show_result else output_file
        model.show_result(
            canvas,
            result,
            out_file=target_file,
            win_name=window_name,
            show=show_result)

    @staticmethod
    def run_inference(model: nn.Module,
                      model_inputs: Dict[str, torch.Tensor]) -> list:
        """Run inference once for a rotated detection model.

        The model is called with `return_loss=False` and `rescale=True`.

        Args:
            model (nn.Module): Input model.
            model_inputs (dict): A dict containing model inputs tensor and
                meta info.

        Returns:
            list: The predictions of model inference.
        """
        predictions = model(**model_inputs, return_loss=False, rescale=True)
        return predictions

    @staticmethod
    def get_partition_cfg(partition_type: str) -> Dict:
        """Get a certain partition config (not available for this task).

        Args:
            partition_type (str): A string specifying partition type.

        Returns:
            dict: A dictionary of partition config.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError('Not supported yet.')

    @staticmethod
    def get_tensor_from_input(input_data: Dict[str, Any]) -> torch.Tensor:
        """Get input tensor from input data.

        Args:
            input_data (dict): Input data containing meta info and image
                tensor.

        Returns:
            torch.Tensor: An image in `Tensor`.
        """
        img = input_data['img']
        if isinstance(img, DataContainer):
            return img.data[0]
        return img[0]

    @staticmethod
    def evaluate_outputs(model_cfg,
                         outputs: Sequence,
                         dataset: Dataset,
                         metrics: Optional[str] = None,
                         out: Optional[str] = None,
                         metric_options: Optional[dict] = None,
                         format_only: bool = False,
                         log_file: Optional[str] = None):
        """Dump, format and/or evaluate the predictions of a model.

        Args:
            model_cfg (mmcv.Config): The model config.
            outputs (Sequence): A list of predictions of model inference.
            dataset (Dataset): Input dataset to run test.
            metrics (str): Evaluation metrics, which depends on
                the codebase and the dataset, e.g., "mAP" for rotated
                detection.
            out (str): Output result file in pickle format, defaults to `None`.
            metric_options (dict): Custom options for evaluation, will be
                kwargs for dataset.evaluate() function. Defaults to `None`.
            format_only (bool): Format the output results without perform
                evaluation. It is useful when you want to format the result
                to a specific format and submit it to the test server. Defaults
                to `False`.
            log_file (str | None): The file to write the evaluation results.
                Defaults to `None` and the results will only print on stdout.
        """
        from mmcv.utils import get_logger
        logger = get_logger('test', log_file=log_file)

        if out:
            logger.debug(f'writing results to {out}')
            mmcv.dump(outputs, out)

        extra_kwargs = metric_options if metric_options is not None else {}
        if format_only:
            dataset.format_results(outputs, **extra_kwargs)

        if metrics:
            eval_kwargs = model_cfg.get('evaluation', {}).copy()
            # hard-coded list of EvalHook-only arguments to drop
            hook_only_keys = ('interval', 'tmpdir', 'start', 'gpu_collect',
                              'save_best', 'rule')
            for hook_key in hook_only_keys:
                eval_kwargs.pop(hook_key, None)
            eval_kwargs.update(dict(metric=metrics, **extra_kwargs))
            logger.info(dataset.evaluate(outputs, **eval_kwargs))

    def get_preprocess(self) -> Dict:
        """Get the preprocess information for SDK.

        Return:
            dict: Composed of the preprocess information.
        """
        # A dummy filename is enough: only the element type is inspected.
        processed_cfg = process_model_config(
            self.model_cfg, [''], get_input_shape(self.deploy_cfg))
        return processed_cfg.data.test.pipeline

    def get_postprocess(self) -> Dict:
        """Get the postprocess information for SDK.

        Return:
            dict: Composed of the postprocess information.
        """
        return self.model_cfg.model.bbox_head

    def get_model_name(self) -> str:
        """Get the model name.

        Return:
            str: the lower-cased `type` of the model config.
        """
        model = self.model_cfg.model
        assert 'type' in model, 'model config contains no type'
        return model.type.lower()
|
|
@ -0,0 +1,198 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Sequence, Union
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmcv.utils import Registry
|
||||
from mmrotate.models.detectors import RotatedBaseDetector
|
||||
|
||||
from mmdeploy.codebase.base import BaseBackendModel
|
||||
from mmdeploy.codebase.mmdet.deploy.object_detection_model import \
|
||||
get_classes_from_config
|
||||
from mmdeploy.utils import (Backend, get_backend, get_codebase_config,
|
||||
load_config)
|
||||
|
||||
|
||||
def __build_backend_model(cls_name: str, registry: Registry, *args, **kwargs):
    """Instantiate the class registered under `cls_name` in `registry`.

    Args:
        cls_name (str): Key of the registered backend model class.
        registry (Registry): Registry whose `module_dict` holds the class.
        args, kwargs: Forwarded unchanged to the class constructor.

    Returns:
        The newly constructed backend model instance.
    """
    backend_cls = registry.module_dict[cls_name]
    return backend_cls(*args, **kwargs)


# Registry of backend model classes for rotated detection; `build` on it
# dispatches through the helper above.
__BACKEND_MODEL = Registry(
    'backend_rotated_detectors', build_func=__build_backend_model)
|
||||
|
||||
|
||||
@__BACKEND_MODEL.register_module('end2end')
class End2EndModel(BaseBackendModel):
    """End to end model for inference of rotated detection.

    Args:
        backend (Backend): The backend enum, specifying backend type.
        backend_files (Sequence[str]): Paths to all required backend files(e.g.
            '.onnx' for ONNX Runtime, '.param' and '.bin' for ncnn).
        class_names (Sequence[str]): A list of string specifying class names.
        device (str): A string represents device type.
        deploy_cfg (str | mmcv.Config): Deployment config file or loaded Config
            object.
        model_cfg (str | mmcv.Config): Model config file or loaded Config
            object.
    """

    def __init__(
        self,
        backend: Backend,
        backend_files: Sequence[str],
        class_names: Sequence[str],
        device: str,
        deploy_cfg: Union[str, mmcv.Config] = None,
        model_cfg: Union[str, mmcv.Config] = None,
    ):
        super().__init__(deploy_cfg=deploy_cfg)
        # Both configs are loaded; only the deploy config is kept.
        _, deploy_cfg = load_config(model_cfg, deploy_cfg)
        self.CLASSES = class_names
        self.deploy_cfg = deploy_cfg
        self.device = device
        self.show_score = False
        self._init_wrapper(
            backend=backend, backend_files=backend_files, device=device)

    def _init_wrapper(self, backend: Backend, backend_files: Sequence[str],
                      device: str):
        """Create the backend wrapper and store it in `self.wrapper`.

        Args:
            backend (Backend): The backend enum, specifying backend type.
            backend_files (Sequence[str]): Paths to all required backend files
                (e.g. '.onnx' for ONNX Runtime, '.param' and '.bin' for ncnn).
            device (str): A string represents device type.
        """
        names_out = self.output_names
        names_in = [self.input_name]
        self.wrapper = BaseBackendModel._build_wrapper(
            backend=backend,
            backend_files=backend_files,
            device=device,
            input_names=names_in,
            output_names=names_out,
            deploy_cfg=self.deploy_cfg)

    def forward(self, img: Sequence[torch.Tensor],
                img_metas: Sequence[Sequence[dict]], *args, **kwargs) -> list:
        """Run forward inference.

        Args:
            img (Sequence[torch.Tensor]): A list contains input image(s)
                in [N x C x H x W] format. Only the first entry is used.
            img_metas (Sequence[Sequence[dict]]): A list of meta info for
                image(s). Only the first entry is used.

        Returns:
            list: One entry per image; each entry is a list with one array
                of detections per class.
        """
        batch_img = img[0].contiguous()
        metas = img_metas[0]
        outputs = self.forward_test(batch_img, metas, *args, **kwargs)
        batch_dets, batch_labels = outputs[:2]
        do_rescale = kwargs.get('rescale', False)
        num_classes = len(self.CLASSES)

        results = []
        for img_idx in range(batch_img.shape[0]):
            dets = batch_dets[img_idx]
            labels = batch_labels[img_idx]
            if do_rescale:
                factor = metas[img_idx]['scale_factor']
                if isinstance(factor, (list, tuple, np.ndarray)):
                    assert len(factor) == 4
                    factor = np.array(factor)[None, :]  # [1,4]
                    factor = torch.from_numpy(factor).to(device=dets.device)
                # In place: the first four box values are divided by the
                # image scale factor.
                dets[:, :4] /= factor
            dets = dets.cpu().numpy()
            labels = labels.cpu().numpy()
            # Group the detections of this image by class id.
            per_class = [
                dets[labels == cls_id, :] for cls_id in range(num_classes)
            ]
            results.append(per_class)

        return results

    def forward_test(self, imgs: torch.Tensor, *args, **kwargs) -> \
            List[torch.Tensor]:
        """The interface for forward test.

        Args:
            imgs (torch.Tensor): Input image(s) in [N x C x H x W] format.

        Returns:
            List[torch.Tensor]: A list of predictions of input images.
        """
        raw_outputs = self.wrapper({self.input_name: imgs})
        return self.wrapper.output_to_list(raw_outputs)

    def show_result(self,
                    img: np.ndarray,
                    result: dict,
                    win_name: str = '',
                    show: bool = True,
                    score_thr: float = 0.3,
                    out_file: str = None):
        """Show predictions of rotated detection.

        Delegates to `RotatedBaseDetector.show_result`.

        Args:
            img: (np.ndarray): Input image to draw predictions.
            result (dict): A dict of predictions.
            win_name (str): The name of visualization window.
            show (bool): Whether to show plotted image in windows. Defaults to
                `True`.
            score_thr: (float): The thresh of score. Defaults to `0.3`.
            out_file (str): Output image file to save drawn predictions.

        Returns:
            The return value of `RotatedBaseDetector.show_result`.
        """
        draw_options = dict(
            score_thr=score_thr,
            show=show,
            win_name=win_name,
            out_file=out_file)
        return RotatedBaseDetector.show_result(self, img, result,
                                               **draw_options)
|
||||
|
||||
|
||||
def build_rotated_detection_model(model_files: Sequence[str],
                                  model_cfg: Union[str, mmcv.Config],
                                  deploy_cfg: Union[str, mmcv.Config],
                                  device: str, **kwargs):
    """Build rotated detection model for different backends.

    Args:
        model_files (Sequence[str]): Input model file(s).
        model_cfg (str | mmcv.Config): Input model config file or Config
            object.
        deploy_cfg (str | mmcv.Config): Input deployment config file or
            Config object.
        device (str): Device to input model.
        kwargs: Extra keyword arguments forwarded to the model constructor.

    Returns:
        BaseBackendModel: Rotated detector for a configured backend.
    """
    # Paths are turned into Config objects; loaded configs pass through.
    deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)

    backend_type = get_backend(deploy_cfg)
    classes = get_classes_from_config(model_cfg)
    # The codebase config may select another registered wrapper class.
    registry_key = get_codebase_config(deploy_cfg).get('model_type',
                                                       'end2end')

    return __BACKEND_MODEL.build(
        registry_key,
        backend=backend_type,
        backend_files=model_files,
        class_names=classes,
        device=device,
        deploy_cfg=deploy_cfg,
        model_cfg=model_cfg,
        **kwargs)
|
|
@ -0,0 +1,9 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .rotated_anchor_head import rotated_anchor_head__get_bbox
|
||||
from .single_stage_rotated_detector import \
|
||||
single_stage_rotated_detector__simple_test
|
||||
|
||||
__all__ = [
|
||||
'single_stage_rotated_detector__simple_test',
|
||||
'rotated_anchor_head__get_bbox'
|
||||
]
|
|
@ -0,0 +1,138 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmdeploy.codebase.mmdet import (get_post_processing_params,
|
||||
pad_with_value_if_necessary)
|
||||
from mmdeploy.codebase.mmrotate.core.post_processing import \
|
||||
multiclass_nms_rotated
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
from mmdeploy.utils import is_dynamic_shape
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
    func_name='mmrotate.models.dense_heads.rotated_anchor_head.'
    'RotatedAnchorHead.get_bboxes')
def rotated_anchor_head__get_bbox(ctx,
                                  self,
                                  cls_scores,
                                  bbox_preds,
                                  img_metas=None,
                                  cfg=None,
                                  rescale=False,
                                  with_nms=True,
                                  **kwargs):
    """Rewrite `get_bboxes` of `RotatedAnchorHead` for default backend.

    Rewrite this function to deploy model, transform network output for a
    batch into bbox predictions.

    Args:
        ctx (ContextCaller): The context with additional information.
        self: The instance of the original class.
        cls_scores (list[Tensor]): Classification scores for all
            scale levels, each is a 4D-tensor, has shape
            (batch_size, num_priors * num_classes, H, W).
        bbox_preds (list[Tensor]): Box energies / deltas for all
            scale levels, each is a 4D-tensor, has shape
            (batch_size, num_priors * 5, H, W).
        img_metas (list[dict]): Image meta info. Must not be None; the
            `img_shape` of the first entry is used for decoding.
        cfg (mmcv.Config, Optional): Test / postprocessing configuration,
            if None, test_cfg would be used. Default None.
        rescale (bool): Kept for interface compatibility; not used by this
            rewrite. Default False.
        with_nms (bool): If True, do nms before return boxes.
            Default True.

    Returns:
        If with_nms == True:
            tuple[Tensor, Tensor]: (dets, labels),
            `dets` of shape [N, num_det, 5] and `labels` of shape
            [N, num_det].
        Else:
            tuple[Tensor, Tensor]: (batch_bboxes, batch_scores), the decoded
            boxes and their class scores before NMS.
    """
    deploy_cfg = ctx.cfg
    is_dynamic_flag = is_dynamic_shape(deploy_cfg)
    num_levels = len(cls_scores)

    featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
    mlvl_priors = self.anchor_generator.grid_priors(
        featmap_sizes, dtype=bbox_preds[0].dtype, device=bbox_preds[0].device)

    mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)]
    mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)]
    assert img_metas is not None
    img_shape = img_metas[0]['img_shape']

    assert len(cls_scores) == len(bbox_preds) == len(mlvl_priors)
    batch_size = cls_scores[0].shape[0]
    # Honor an explicitly passed config; fall back to the head's test_cfg.
    cfg = self.test_cfg if cfg is None else cfg
    pre_topk = cfg.get('nms_pre', -1)

    mlvl_valid_bboxes = []
    mlvl_valid_scores = []
    mlvl_valid_priors = []

    for cls_score, bbox_pred, priors in zip(mlvl_cls_scores, mlvl_bbox_preds,
                                            mlvl_priors):

        scores = cls_score.permute(0, 2, 3, 1).reshape(batch_size, -1,
                                                       self.cls_out_channels)
        if self.use_sigmoid_cls:
            scores = scores.sigmoid()
        else:
            scores = scores.softmax(-1)
        bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 5)
        if not is_dynamic_flag:
            priors = priors.data
        priors = priors.expand(batch_size, -1, priors.size(-1))
        if pre_topk > 0:
            # Pad so that topk below always has at least `pre_topk` entries.
            priors = pad_with_value_if_necessary(priors, 1, pre_topk)
            bbox_pred = pad_with_value_if_necessary(bbox_pred, 1, pre_topk)
            scores = pad_with_value_if_necessary(scores, 1, pre_topk, 0.)

            nms_pre_score = scores

            # Get maximum scores for foreground classes.
            if self.use_sigmoid_cls:
                max_scores, _ = nms_pre_score.max(-1)
            else:
                max_scores, _ = nms_pre_score[..., :-1].max(-1)
            _, topk_inds = max_scores.topk(pre_topk)
            batch_inds = torch.arange(
                batch_size,
                device=bbox_pred.device).view(-1, 1).expand_as(topk_inds)
            priors = priors[batch_inds, topk_inds, :]
            bbox_pred = bbox_pred[batch_inds, topk_inds, :]
            scores = scores[batch_inds, topk_inds, :]

        mlvl_valid_bboxes.append(bbox_pred)
        mlvl_valid_scores.append(scores)
        mlvl_valid_priors.append(priors)

    batch_mlvl_bboxes_pred = torch.cat(mlvl_valid_bboxes, dim=1)
    batch_scores = torch.cat(mlvl_valid_scores, dim=1)
    batch_priors = torch.cat(mlvl_valid_priors, dim=1)
    batch_bboxes = self.bbox_coder.decode(
        batch_priors, batch_mlvl_bboxes_pred, max_shape=img_shape)

    if not self.use_sigmoid_cls:
        batch_scores = batch_scores[..., :self.num_classes]

    if not with_nms:
        return batch_bboxes, batch_scores

    post_params = get_post_processing_params(deploy_cfg)
    # MMRotate test configs store the NMS IoU under `iou_thr`; the
    # `iou_threshold` spelling is still accepted for backward compatibility.
    # The deploy config only provides the final fallback.
    nms_cfg = cfg.nms
    iou_threshold = nms_cfg.get(
        'iou_thr', nms_cfg.get('iou_threshold', post_params.iou_threshold))
    score_threshold = cfg.get('score_thr', post_params.score_threshold)
    pre_top_k = post_params.pre_top_k
    keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)

    return multiclass_nms_rotated(
        batch_bboxes,
        batch_scores,
        iou_threshold=iou_threshold,
        score_threshold=score_threshold,
        pre_top_k=pre_top_k,
        keep_top_k=keep_top_k)
|
|
@ -0,0 +1,34 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
    func_name='mmrotate.models.detectors.RotatedSingleStageDetector'
    '.simple_test')
def single_stage_rotated_detector__simple_test(ctx,
                                               self,
                                               img,
                                               img_metas,
                                               rescale=False):
    """Rewrite `simple_test` of RotatedSingleStageDetector for default backend.

    Runs feature extraction and the bbox head, then returns the output of
    `bbox_head.get_bboxes` directly, without any further post processing.
    That remaining processing is not suitable for exporting to backends and
    is better implemented in SDK.

    Args:
        ctx (ContextCaller): The context with additional information.
        self: The instance of the class RotatedSingleStageDetector.
        img (Tensor): Input images of shape (N, C, H, W).
            Typically these should be mean centered and std scaled.
        img_metas (list[dict]): Image meta info, forwarded to `get_bboxes`.
        rescale (bool): Forwarded to `get_bboxes`. Default False.

    Returns:
        The value returned by `self.bbox_head.get_bboxes`.
    """
    feats = self.extract_feat(img)
    head_outs = self.bbox_head(feats)
    return self.bbox_head.get_bboxes(*head_outs, img_metas, rescale=rescale)
|
|
@ -7,31 +7,53 @@ class ONNXNMSRotatedOp(torch.autograd.Function):
|
|||
"""Create onnx::NMSRotated op."""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, boxes: Tensor, scores: Tensor,
|
||||
iou_threshold: float) -> Tensor:
|
||||
def forward(ctx, boxes: Tensor, scores: Tensor, iou_threshold: float,
|
||||
score_threshold: float) -> Tensor:
|
||||
"""Get NMS rotated output indices.
|
||||
|
||||
Args:
|
||||
ctx (Context): The context with meta information.
|
||||
boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4].
|
||||
boxes (Tensor): The bounding boxes of shape [N, num_boxes, 5].
|
||||
scores (Tensor): The detection scores of shape
|
||||
[N, num_boxes, num_classes].
|
||||
[N, num_classes, num_boxes].
|
||||
iou_threshold (float): IOU threshold of nms.
|
||||
score_threshold (float): bbox threshold, bboxes with scores
|
||||
lower than it will not be considered.
|
||||
|
||||
Returns:
|
||||
Tensor: Selected indices of boxes.
|
||||
"""
|
||||
from mmcv.utils import ext_loader
|
||||
ext_module = ext_loader.load_ext('_ext', ['nms_rotated'])
|
||||
batch_size, num_class, _ = scores.shape
|
||||
|
||||
_, order = scores.sort(0, descending=True)
|
||||
dets_sorted = boxes.index_select(0, order)
|
||||
keep_inds = ext_module.nms_rotated(boxes, scores, order, dets_sorted,
|
||||
iou_threshold, 0)
|
||||
return keep_inds
|
||||
indices = []
|
||||
for batch_id in range(batch_size):
|
||||
for cls_id in range(num_class):
|
||||
_boxes = boxes[batch_id, ...]
|
||||
# score_threshold=0 requires scores to be contiguous
|
||||
_scores = scores[batch_id, cls_id, ...].contiguous()
|
||||
valid_mask = _scores > score_threshold
|
||||
_boxes, _scores = _boxes[valid_mask], _scores[valid_mask]
|
||||
valid_inds = torch.nonzero(
|
||||
valid_mask, as_tuple=False).squeeze(dim=1)
|
||||
_, order = _scores.sort(0, descending=True)
|
||||
dets_sorted = _boxes.index_select(0, order)
|
||||
box_inds = ext_module.nms_rotated(_boxes, _scores, order,
|
||||
dets_sorted, iou_threshold,
|
||||
0)
|
||||
box_inds = valid_inds[box_inds]
|
||||
batch_inds = torch.zeros_like(box_inds) + batch_id
|
||||
cls_inds = torch.zeros_like(box_inds) + cls_id
|
||||
indices.append(
|
||||
torch.stack([batch_inds, cls_inds, box_inds], dim=-1))
|
||||
|
||||
indices = torch.cat(indices)
|
||||
return indices
|
||||
|
||||
@staticmethod
|
||||
def symbolic(g, boxes: Tensor, scores: Tensor, iou_threshold: float):
|
||||
def symbolic(g, boxes: Tensor, scores: Tensor, iou_threshold: float,
|
||||
score_threshold: float):
|
||||
"""Symbolic function for onnx::NMSRotated.
|
||||
|
||||
Args:
|
||||
|
@ -40,6 +62,8 @@ class ONNXNMSRotatedOp(torch.autograd.Function):
|
|||
scores (Tensor): The detection scores of shape
|
||||
[N, num_boxes, num_classes].
|
||||
iou_threshold (float): IOU threshold of nms.
|
||||
score_threshold (float): bbox threshold, bboxes with scores
|
||||
lower than it will not be considered.
|
||||
|
||||
Returns:
|
||||
NMSRotated op for onnx.
|
||||
|
@ -48,4 +72,5 @@ class ONNXNMSRotatedOp(torch.autograd.Function):
|
|||
'mmdeploy::NMSRotated',
|
||||
boxes,
|
||||
scores,
|
||||
iou_threshold_f=float(iou_threshold))
|
||||
iou_threshold_f=float(iou_threshold),
|
||||
score_threshold_f=float(score_threshold))
|
||||
|
|
|
@ -26,6 +26,7 @@ class Task(AdvancedEnum):
|
|||
INSTANCE_SEGMENTATION = 'InstanceSegmentation'
|
||||
VOXEL_DETECTION = 'VoxelDetection'
|
||||
POSE_DETECTION = 'PoseDetection'
|
||||
ROTATED_DETECTION = 'RotatedDetection'
|
||||
|
||||
|
||||
class Codebase(AdvancedEnum):
|
||||
|
@ -37,6 +38,7 @@ class Codebase(AdvancedEnum):
|
|||
MMEDIT = 'mmedit'
|
||||
MMDET3D = 'mmdet3d'
|
||||
MMPOSE = 'mmpose'
|
||||
MMROTATE = 'mmrotate'
|
||||
|
||||
|
||||
class IR(AdvancedEnum):
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
359.0 663.0 369.0 497.0 543.0 509.0 531.0 677.0 plane 0
|
||||
540.0 884.0 363.0 862.0 392.0 674.0 570.0 695.0 plane 0
|
||||
788.0 844.0 734.0 701.0 916.0 631.0 970.0 762.0 plane 0
|
||||
720.0 726.0 668.0 583.0 852.0 494.0 913.0 636.0 plane 0
|
|
@ -0,0 +1,96 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# Model definition: a RotatedRetinaNet built from a ResNet-50 backbone,
# an FPN neck and a RotatedRetinaHead.
model = dict(
    type='RotatedRetinaNet',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        zero_init_residual=False,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        style='pytorch',
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
    neck=dict(
        type='FPN',
        # channel counts of the four backbone outputs selected above
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        start_level=1,
        add_extra_convs='on_input',
        num_outs=5),
    bbox_head=dict(
        type='RotatedRetinaHead',
        num_classes=15,
        in_channels=256,
        stacked_convs=4,
        feat_channels=256,
        assign_by_circumhbbox=None,
        anchor_generator=dict(
            type='RotatedAnchorGenerator',
            octave_base_scale=4,
            scales_per_octave=3,
            ratios=[1.0, 0.5, 2.0],
            # one stride per neck output (num_outs=5)
            strides=[8, 16, 32, 64, 128]),
        bbox_coder=dict(
            type='DeltaXYWHAOBBoxCoder',
            angle_range='le135',
            norm_factor=1,
            edge_swap=False,
            proj_xy=True,
            # five entries each: the coder works on 5-value box deltas
            target_means=(0.0, 0.0, 0.0, 0.0, 0.0),
            target_stds=(1.0, 1.0, 1.0, 1.0, 1.0)),
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
    train_cfg=dict(
        assigner=dict(
            type='MaxIoUAssigner',
            pos_iou_thr=0.5,
            neg_iou_thr=0.4,
            min_pos_iou=0,
            ignore_iof_thr=-1,
            iou_calculator=dict(type='RBboxOverlaps2D')),
        allowed_border=-1,
        pos_weight=-1,
        debug=False),
    # Test-time post-processing settings; note that the NMS IoU threshold
    # is stored under the key `iou_thr`.
    test_cfg=dict(
        nms_pre=2000,
        min_bbox_size=0,
        score_thr=0.05,
        nms=dict(iou_thr=0.1),
        max_per_img=2000))
|
||||
# dataset settings
dataset_type = 'DOTADataset'
data_root = '.'
# Test-time pipeline: load the image, then apply the transforms below at a
# single 1024x1024 scale without flipping.
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1024, 1024),
        flip=False,
        transforms=[
            dict(type='RResize'),
            dict(
                type='Normalize',
                mean=[123.675, 116.28, 103.53],
                std=[58.395, 57.12, 57.375],
                to_rgb=True),
            dict(type='Pad', size_divisor=32),
            dict(type='DefaultFormatBundle'),
            dict(type='Collect', keys=['img'])
        ])
]
# Only a test split is configured; its annotations come from the sample
# data directory inside the repository's tests folder.
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    test=dict(
        type=dataset_type,
        ann_file='tests/test_codebase/test_mmrotate/data/dota_sample/',
        img_prefix=data_root,
        pipeline=test_pipeline,
        # same angle convention as the bbox coder's angle_range
        version='le135'))
|
|
@ -0,0 +1,115 @@
|
|||
{
|
||||
"type": "RotatedRetinaNet",
|
||||
"backbone": {
|
||||
"type": "ResNet",
|
||||
"depth": 50,
|
||||
"num_stages": 4,
|
||||
"out_indices": [
|
||||
0,
|
||||
1,
|
||||
2,
|
||||
3
|
||||
],
|
||||
"frozen_stages": 1,
|
||||
"norm_cfg": {
|
||||
"type": "BN",
|
||||
"requires_grad": true
|
||||
},
|
||||
"norm_eval": true,
|
||||
"style": "pytorch",
|
||||
"init_cfg": {
|
||||
"type": "Pretrained",
|
||||
"checkpoint": "torchvision://resnet50"
|
||||
}
|
||||
},
|
||||
"neck": {
|
||||
"type": "FPN",
|
||||
"in_channels": [
|
||||
256,
|
||||
512,
|
||||
1024,
|
||||
2048
|
||||
],
|
||||
"out_channels": 256,
|
||||
"start_level": 1,
|
||||
"add_extra_convs": "on_input",
|
||||
"num_outs": 5
|
||||
},
|
||||
"bbox_head": {
|
||||
"type": "RotatedRetinaHead",
|
||||
"num_classes": 15,
|
||||
"in_channels": 256,
|
||||
"stacked_convs": 4,
|
||||
"feat_channels": 256,
|
||||
"anchor_generator": {
|
||||
"type": "RotatedAnchorGenerator",
|
||||
"octave_base_scale": 4,
|
||||
"scales_per_octave": 3,
|
||||
"ratios": [
|
||||
0.5,
|
||||
1.0,
|
||||
2.0
|
||||
],
|
||||
"strides": [
|
||||
8,
|
||||
16,
|
||||
32,
|
||||
64,
|
||||
128
|
||||
]
|
||||
},
|
||||
"bbox_coder": {
|
||||
"type": "DeltaXYWHAOBBoxCoder",
|
||||
"target_means": [
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0
|
||||
],
|
||||
"target_stds": [
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0
|
||||
]
|
||||
},
|
||||
"loss_cls": {
|
||||
"type": "FocalLoss",
|
||||
"use_sigmoid": true,
|
||||
"gamma": 2.0,
|
||||
"alpha": 0.25,
|
||||
"loss_weight": 1.0
|
||||
},
|
||||
"loss_bbox": {
|
||||
"type": "L1Loss",
|
||||
"loss_weight": 1.0
|
||||
}
|
||||
},
|
||||
"train_cfg": {
|
||||
"assigner": {
|
||||
"type": "MaxIoUAssigner",
|
||||
"pos_iou_thr": 0.5,
|
||||
"neg_iou_thr": 0.4,
|
||||
"min_pos_iou": 0,
|
||||
"ignore_iof_thr": -1,
|
||||
"iou_calculator": {
|
||||
"type": "RBboxOverlaps2D"
|
||||
}
|
||||
},
|
||||
"allowed_border": -1,
|
||||
"pos_weight": -1,
|
||||
"debug": false
|
||||
},
|
||||
"test_cfg": {
|
||||
"nms_pre": 2000,
|
||||
"min_bbox_size": 0,
|
||||
"score_thr": 0.05,
|
||||
"nms": {
|
||||
"type": "nms",
|
||||
"iou_threshold": 0.1
|
||||
},
|
||||
"max_per_img": 2000
|
||||
}
|
||||
}
|
|
@ -0,0 +1,183 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmdeploy.codebase import import_codebase
|
||||
from mmdeploy.utils import Backend, Codebase
|
||||
from mmdeploy.utils.test import (WrapFunction, backend_checker, check_backend,
|
||||
get_onnx_model, get_rewrite_outputs)
|
||||
|
||||
try:
|
||||
import_codebase(Codebase.MMROTATE)
|
||||
except ImportError:
|
||||
pytest.skip(
|
||||
f'{Codebase.MMROTATE} is not installed.', allow_module_level=True)
|
||||
|
||||
|
||||
@backend_checker(Backend.ONNXRUNTIME)
def test_multiclass_nms_rotated():
    """Check that the rewritten ``multiclass_nms_rotated`` yields outputs.

    Random ``(1, 5, 5)`` boxes and ``(1, 5, 8)`` scores are passed through
    ``get_rewrite_outputs`` with an ONNX Runtime deploy config; the test only
    asserts that something is returned.
    """
    from mmdeploy.codebase.mmrotate.core import multiclass_nms_rotated

    keep_top_k = 10
    boxes_shape = [1, 5, 5]
    scores_shape = [1, 5, 8]

    # Static input shapes: min, opt and max are identical for each input.
    input_shapes = {
        'boxes': {
            'min_shape': list(boxes_shape),
            'opt_shape': list(boxes_shape),
            'max_shape': list(boxes_shape),
        },
        'scores': {
            'min_shape': list(scores_shape),
            'opt_shape': list(scores_shape),
            'max_shape': list(scores_shape),
        },
    }
    backend_config = {
        'type': 'onnxruntime',
        'common_config': {
            'fp16_mode': False,
            'max_workspace_size': 1 << 20,
        },
        'model_inputs': [{'input_shapes': input_shapes}],
    }
    codebase_config = {
        'type': 'mmrotate',
        'task': 'RotatedDetection',
        'post_processing': {
            'score_threshold': 0.05,
            'iou_threshold': 0.5,
            'pre_top_k': -1,
            'keep_top_k': keep_top_k,
        },
    }
    deploy_cfg = mmcv.Config({
        'onnx_config': {'output_names': None, 'input_shape': None},
        'backend_config': backend_config,
        'codebase_config': codebase_config,
    })

    boxes = torch.rand(*boxes_shape)
    scores = torch.rand(*scores_shape)
    wrapped_func = WrapFunction(multiclass_nms_rotated, keep_top_k=keep_top_k)
    rewrite_outputs, _ = get_rewrite_outputs(
        wrapped_func,
        model_inputs=dict(boxes=boxes, scores=scores),
        deploy_cfg=deploy_cfg)

    assert rewrite_outputs is not None, (
        'Got unexpected rewrite '
        'outputs: {}'.format(rewrite_outputs))
|
||||
|
||||
|
||||
@backend_checker(Backend.ONNXRUNTIME)
@pytest.mark.parametrize('pre_top_k', [-1, 1000])
def test_multiclass_nms_rotated_with_keep_top_k(pre_top_k):
    """Check the number of detections against ``keep_top_k`` on a dynamic model.

    ``multiclass_nms_rotated`` is exported to ONNX with 2 boxes and dynamic
    axes, then run through ``ORTWrapper`` with 100 boxes whose scores are all
    one. The second dimension of the first output is compared to
    ``keep_top_k``.
    """
    from mmdeploy.codebase.mmrotate.core import multiclass_nms_rotated

    backend_type = 'onnxruntime'
    keep_top_k = 15
    batch_size = 1
    num_classes = 5

    def make_inputs(num_boxes):
        # Boxes are random; every score is 1 so all candidates pass the
        # score threshold.
        return {
            'boxes': torch.rand(batch_size, num_boxes, 5),
            'scores': torch.ones(batch_size, num_boxes, num_classes),
        }

    dynamic_axes = {
        'boxes': {0: 'batch_size', 1: 'num_boxes'},
        'scores': {0: 'batch_size', 1: 'num_boxes', 2: 'num_classes'},
    }
    deploy_cfg = mmcv.Config({
        'onnx_config': {
            'output_names': None,
            'input_shape': None,
            'dynamic_axes': dynamic_axes,
        },
        'backend_config': {'type': backend_type},
        'codebase_config': {
            'type': 'mmrotate',
            'task': 'RotatedDetection',
            'post_processing': {
                'score_threshold': 0.05,
                'iou_threshold': 0.5,
                'pre_top_k': pre_top_k,
                'keep_top_k': keep_top_k,
            },
        },
    })

    # Export with a small number of boxes ...
    wrapped_func = WrapFunction(multiclass_nms_rotated, keep_top_k=keep_top_k)
    onnx_model_path = get_onnx_model(
        wrapped_func, model_inputs=make_inputs(2), deploy_cfg=deploy_cfg)

    # ... and run inference with a larger one.
    model_inputs = make_inputs(100)

    import mmdeploy.backend.onnxruntime as ort_apis
    backend_model = ort_apis.ORTWrapper(onnx_model_path, 'cuda:0', None)
    output = backend_model.output_to_list(backend_model.forward(model_inputs))
    dets = output[0]

    # Subtract 1 dim since we pad the tensors
    num_dets = dets.shape[1] - 1
    assert num_dets < keep_top_k, \
        'multiclass_nms_rotated returned more values than "keep_top_k"\n' \
        f'dets.shape: {dets.shape}\n' \
        f'keep_top_k: {keep_top_k}'
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
@pytest.mark.parametrize('add_ctr_clamp', [True, False])
@pytest.mark.parametrize('max_shape,proj_xy,edge_swap',
                         [(None, False, False),
                          (torch.tensor([100, 200]), True, True)])
def test_delta2bbox(backend_type: Backend, add_ctr_clamp: bool,
                    max_shape: tuple, proj_xy: bool, edge_swap: bool):
    """Compare mmrotate's ``delta2bbox`` with its rewritten/exported version.

    The function is evaluated directly on random ``(5, 5)`` inputs and then
    through ``get_rewrite_outputs``. If the result comes from the backend the
    two are compared numerically; otherwise the test only checks that the
    rewrite produced something.
    """
    check_backend(backend_type)
    deploy_cfg = mmcv.Config({
        'onnx_config': {'output_names': None, 'input_shape': None},
        'backend_config': {'type': backend_type.value, 'model_inputs': None},
        'codebase_config': {'type': 'mmrotate', 'task': 'RotatedDetection'},
    })

    # wrap function to enable rewrite
    def delta2bbox(*args, **kwargs):
        import mmrotate
        coder_module = mmrotate.core.bbox.coder.delta_xywha_rbbox_coder
        return coder_module.delta2bbox(*args, **kwargs)

    # Keyword arguments shared by the reference call and the wrapped call.
    coder_kwargs = {
        'max_shape': max_shape,
        'add_ctr_clamp': add_ctr_clamp,
        'proj_xy': proj_xy,
        'edge_swap': edge_swap,
    }

    rois = torch.rand(5, 5)
    deltas = torch.rand(5, 5)
    original_outputs = delta2bbox(rois, deltas, **coder_kwargs)

    # wrap function to nn.Module, enable torch.onnx.export
    wrapped_func = WrapFunction(delta2bbox, **coder_kwargs)
    rewrite_outputs, is_backend_output = get_rewrite_outputs(
        wrapped_func,
        model_inputs=dict(
            rois=rois.unsqueeze(0), deltas=deltas.unsqueeze(0)),
        deploy_cfg=deploy_cfg)

    if not is_backend_output:
        assert rewrite_outputs is not None
        return

    model_output = original_outputs.squeeze().cpu().numpy()
    rewrite_output = rewrite_outputs[0].squeeze().cpu().numpy()
    assert np.allclose(model_output, rewrite_output, rtol=1e-03, atol=1e-05)
|
|
@ -0,0 +1,203 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import os
|
||||
import random
|
||||
from typing import Dict, List
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmdeploy.codebase import import_codebase
|
||||
from mmdeploy.utils import Backend, Codebase
|
||||
from mmdeploy.utils.config_utils import get_ir_config
|
||||
from mmdeploy.utils.test import (WrapModel, check_backend, get_model_outputs,
|
||||
get_rewrite_outputs)
|
||||
|
||||
try:
|
||||
import_codebase(Codebase.MMROTATE)
|
||||
except ImportError:
|
||||
pytest.skip(
|
||||
f'{Codebase.MMROTATE} is not installed.', allow_module_level=True)
|
||||
|
||||
|
||||
def seed_everything(seed=1029):
    """Seed Python, NumPy and PyTorch random number generators.

    Also exports ``PYTHONHASHSEED``. When CUDA is available, the CUDA
    generators are seeded as well and cuDNN is disabled with benchmark mode
    off and deterministic mode on.

    Args:
        seed (int): Seed value applied to every generator. Defaults to 1029.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    # Covers every visible device when several GPUs are in use.
    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
    cudnn.enabled = False
|
||||
|
||||
|
||||
def convert_to_list(rewrite_output: Dict, output_names: List[str]) -> List:
    """Converts output from a dictionary to a list.

    The new list will contain only those output values, whose names are in
    list 'output_names'. Values are returned in the order given by
    'output_names' (not in dictionary iteration order), so the result can be
    matched positionally against another sequence of outputs.

    Args:
        rewrite_output (Dict): Mapping from output name to output value.
        output_names (List[str]): Names of the outputs to extract, in the
            desired order. Names missing from 'rewrite_output' are skipped.

    Returns:
        List: The selected output values, ordered like 'output_names'.
    """
    return [
        rewrite_output[name] for name in output_names
        if name in rewrite_output
    ]
|
||||
|
||||
|
||||
def get_anchor_head_model():
    """Build a ``RotatedAnchorHead`` with gradients disabled.

    Returns:
        RotatedAnchorHead: Head created with ``num_classes=4`` and
        ``in_channels=1`` and the test config defined below.
    """
    from mmrotate.models.dense_heads import RotatedAnchorHead

    test_cfg = mmcv.Config({
        'nms_pre': 2000,
        'min_bbox_size': 0,
        'score_thr': 0.05,
        'nms': {'iou_thr': 0.1},
        'max_per_img': 2000,
    })

    head = RotatedAnchorHead(num_classes=4, in_channels=1, test_cfg=test_cfg)
    head.requires_grad_(False)
    return head
|
||||
|
||||
|
||||
def _replace_r50_with_r18(model):
|
||||
"""Replace ResNet50 with ResNet18 in config."""
|
||||
model = copy.deepcopy(model)
|
||||
if model.backbone.type == 'ResNet':
|
||||
model.backbone.depth = 18
|
||||
model.backbone.base_channels = 2
|
||||
model.neck.in_channels = [2, 4, 8, 16]
|
||||
return model
|
||||
|
||||
|
||||
# @pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME])
|
||||
# @pytest.mark.parametrize('model_cfg_path', [
|
||||
# 'tests/test_codebase/test_mmrotate/data/single_stage_model.json'
|
||||
# ])
|
||||
# def test_forward_of_base_detector(model_cfg_path, backend):
|
||||
# check_backend(backend)
|
||||
# deploy_cfg = mmcv.Config(
|
||||
# dict(
|
||||
# backend_config=dict(type=backend.value),
|
||||
# onnx_config=dict(
|
||||
# output_names=['dets', 'labels'], input_shape=None),
|
||||
# codebase_config=dict(
|
||||
# type='mmrotate',
|
||||
# task='RotatedDetection',
|
||||
# post_processing=dict(
|
||||
# score_threshold=0.05,
|
||||
# iou_threshold=0.5,
|
||||
# pre_top_k=-1,
|
||||
# keep_top_k=100,
|
||||
# ))))
|
||||
|
||||
# model_cfg = mmcv.Config(dict(model=mmcv.load(model_cfg_path)))
|
||||
# model_cfg.model = _replace_r50_with_r18(model_cfg.model)
|
||||
|
||||
# from mmrotate.models import build_detector
|
||||
|
||||
# model_cfg.model.pretrained = None
|
||||
# model_cfg.model.train_cfg = None
|
||||
# model = build_detector(
|
||||
# model_cfg.model, test_cfg= model_cfg.get('test_cfg'))
|
||||
# model.cfg = model_cfg
|
||||
# model.to('cpu')
|
||||
|
||||
# img = torch.randn(1, 3, 64, 64)
|
||||
# rewrite_inputs = {'img': img}
|
||||
# rewrite_outputs, _ = get_rewrite_outputs(
|
||||
# wrapped_model=model,
|
||||
# model_inputs=rewrite_inputs,
|
||||
# deploy_cfg=deploy_cfg)
|
||||
|
||||
# assert rewrite_outputs is not None
|
||||
|
||||
|
||||
def get_deploy_cfg(backend_type: Backend, ir_type: str):
    """Create a rotated-detection deploy config.

    Args:
        backend_type (Backend): Backend whose ``value`` is stored as the
            backend type.
        ir_type (str): Value stored as ``onnx_config.type``.

    Returns:
        mmcv.Config: Config with ``backend_config``, ``onnx_config`` and
        ``codebase_config`` sections.
    """
    post_processing = {
        'score_threshold': 0.05,
        'iou_threshold': 0.1,
        'pre_top_k': 2000,
        'keep_top_k': 2000,
    }
    cfg_dict = {
        'backend_config': {'type': backend_type.value},
        'onnx_config': {
            'type': ir_type,
            'output_names': ['dets', 'labels'],
            'input_shape': None,
        },
        'codebase_config': {
            'type': 'mmrotate',
            'task': 'RotatedDetection',
            'post_processing': post_processing,
        },
    }
    return mmcv.Config(cfg_dict)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend_type, ir_type',
                         [(Backend.ONNXRUNTIME, 'onnx')])
def test_base_dense_head_get_bboxes(backend_type: Backend, ir_type: str):
    """Test get_bboxes rewrite of base dense head."""
    check_backend(backend_type)
    anchor_head = get_anchor_head_model()
    anchor_head.cpu().eval()

    s = 128
    img_metas = [
        dict(
            scale_factor=np.ones(4),
            pad_shape=(s, s, 3),
            img_shape=(s, s, 3))
    ]

    deploy_cfg = get_deploy_cfg(backend_type, ir_type)
    output_names = get_ir_config(deploy_cfg).get('output_names', None)

    # Spatial sizes of the five feature levels: 32, 16, 8, 4, 2.
    level_sizes = [pow(2, i) for i in range(5, 0, -1)]

    # cls_score sizes: (1, 36, 32, 32), (1, 36, 16, 16), (1, 36, 8, 8),
    # (1, 36, 4, 4), (1, 36, 2, 2).
    seed_everything(1234)
    cls_score = [torch.rand(1, 36, size, size) for size in level_sizes]

    # bbox_pred sizes: (1, 45, 32, 32), (1, 45, 16, 16), (1, 45, 8, 8),
    # (1, 45, 4, 4), (1, 45, 2, 2).
    seed_everything(5678)
    bboxes = [torch.rand(1, 45, size, size) for size in level_sizes]

    # to get outputs of pytorch model
    model_inputs = dict(
        cls_scores=cls_score, bbox_preds=bboxes, img_metas=img_metas)
    model_outputs = get_model_outputs(anchor_head, 'get_bboxes', model_inputs)

    # to get outputs of onnx model after rewrite
    img_metas[0]['img_shape'] = torch.Tensor([s, s])
    wrapped_model = WrapModel(
        anchor_head, 'get_bboxes', img_metas=img_metas, with_nms=True)
    rewrite_inputs = dict(cls_scores=cls_score, bbox_preds=bboxes)
    rewrite_outputs, is_backend_output = get_rewrite_outputs(
        wrapped_model=wrapped_model,
        model_inputs=rewrite_inputs,
        deploy_cfg=deploy_cfg)

    if not is_backend_output:
        assert rewrite_outputs is not None
        return

    if isinstance(rewrite_outputs, dict):
        rewrite_outputs = convert_to_list(rewrite_outputs, output_names)
    for model_output, rewrite_output in zip(model_outputs[0],
                                            rewrite_outputs):
        model_output = model_output.squeeze().cpu().numpy()
        rewrite_output = rewrite_output.squeeze()
        # hard code to make two tensors with the same shape
        # rewrite and original codes applied different nms strategy
        expected = model_output[:rewrite_output.shape[0]][:2]
        actual = rewrite_output[:2]
        assert np.allclose(expected, actual, rtol=1e-03, atol=1e-05)
|
|
@ -0,0 +1,148 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os
|
||||
from tempfile import NamedTemporaryFile, TemporaryDirectory
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.dataset import Dataset
|
||||
|
||||
import mmdeploy.backend.onnxruntime as ort_apis
|
||||
from mmdeploy.apis import build_task_processor
|
||||
from mmdeploy.codebase import import_codebase
|
||||
from mmdeploy.utils import Codebase, load_config
|
||||
from mmdeploy.utils.test import DummyModel, SwitchBackendWrapper
|
||||
|
||||
try:
|
||||
import_codebase(Codebase.MMROTATE)
|
||||
except ImportError:
|
||||
pytest.skip(
|
||||
f'{Codebase.MMROTATE} is not installed.', allow_module_level=True)
|
||||
|
||||
# Module-level fixtures shared by the tests below.
model_cfg_path = 'tests/test_codebase/test_mmrotate/data/model.py'
model_cfg = load_config(model_cfg_path)[0]
deploy_cfg = mmcv.Config({
    'backend_config': {'type': 'onnxruntime'},
    'codebase_config': {
        'type': 'mmrotate',
        'task': 'RotatedDetection',
        'post_processing': {
            'score_threshold': 0.05,
            'iou_threshold': 0.1,
            'pre_top_k': 2000,
            'keep_top_k': 2000,
        },
    },
    'onnx_config': {
        'type': 'onnx',
        'export_params': True,
        'keep_initializers_as_inputs': False,
        'opset_version': 11,
        'input_shape': None,
        'input_names': ['input'],
        'output_names': ['dets', 'labels'],
    },
})
onnx_file = NamedTemporaryFile(suffix='.onnx').name
task_processor = build_task_processor(model_cfg, deploy_cfg, 'cpu')
# Random 32x32 image with 3 channels used as the test input.
img_shape = (32, 32)
img = np.random.rand(*img_shape, 3)
|
||||
|
||||
|
||||
def test_init_pytorch_model():
    """``init_pytorch_model(None)`` must return a ``RotatedBaseDetector``."""
    from mmrotate.models import RotatedBaseDetector

    torch_model = task_processor.init_pytorch_model(None)
    assert isinstance(torch_model, RotatedBaseDetector)
|
||||
|
||||
|
||||
@pytest.fixture
def backend_model():
    """Yield a backend model whose ``ORTWrapper`` returns fixed outputs.

    ``ORTWrapper`` is switched through ``SwitchBackendWrapper`` and configured
    with random ``dets``/``labels`` tensors; the switch is recovered after the
    test using the fixture has finished.
    """
    from mmdeploy.backend.onnxruntime import ORTWrapper
    ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})

    fake_outputs = {
        'dets': torch.rand(1, 10, 6),
        'labels': torch.rand(1, 10),
    }
    wrapper = SwitchBackendWrapper(ORTWrapper)
    wrapper.set(outputs=fake_outputs)

    yield task_processor.init_backend_model([''])

    wrapper.recover()
|
||||
|
||||
|
||||
def test_init_backend_model(backend_model):
    """The backend model built by the task processor is an ``End2EndModel``."""
    from mmdeploy.codebase.mmrotate.deploy import rotated_detection_model

    assert isinstance(backend_model, rotated_detection_model.End2EndModel)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('device', ['cpu'])
def test_create_input(device):
    """``create_input`` must return a pair of values.

    ``task_processor`` is a module-level object shared by every test in this
    file, so its ``device`` attribute is restored in a ``finally`` clause even
    when ``create_input`` raises or the assertion fails.
    """
    original_device = task_processor.device
    task_processor.device = device
    try:
        inputs = task_processor.create_input(img, input_shape=img_shape)
        assert len(inputs) == 2
    finally:
        task_processor.device = original_device
|
||||
|
||||
|
||||
def test_run_inference(backend_model):
    """Run inference with the PyTorch and the backend model on one input.

    Both results must exist and their first entries must have equal length.
    """
    input_dict, _ = task_processor.create_input(img, input_shape=img_shape)
    torch_model = task_processor.init_pytorch_model(None)

    torch_results = task_processor.run_inference(torch_model, input_dict)
    backend_results = task_processor.run_inference(backend_model, input_dict)

    for results in (torch_results, backend_results):
        assert results is not None
    assert len(torch_results[0]) == len(backend_results[0])
|
||||
|
||||
|
||||
def test_visualize(backend_model):
    """``visualize`` must write the rendered image to the requested path.

    The output path is built with ``os.path.join`` so the image is created
    inside the temporary directory and removed together with it.
    """
    input_dict, _ = task_processor.create_input(img, input_shape=img_shape)
    results = task_processor.run_inference(backend_model, input_dict)
    with TemporaryDirectory() as tmp_dir:
        filename = os.path.join(tmp_dir, 'tmp.jpg')
        task_processor.visualize(backend_model, img, results[0], filename, '')
        assert os.path.exists(filename)
|
||||
|
||||
|
||||
def test_get_partition_cfg():
    """``get_partition_cfg`` is expected to raise ``NotImplementedError``."""
    with pytest.raises(NotImplementedError):
        task_processor.get_partition_cfg(partition_type='')
|
||||
|
||||
|
||||
def test_build_dataset_and_dataloader():
    """Build the test dataset from ``model_cfg`` and wrap it in a dataloader."""
    test_dataset = task_processor.build_dataset(
        dataset_cfg=model_cfg, dataset_type='test')
    assert isinstance(test_dataset, Dataset), 'Failed to build dataset'

    test_loader = task_processor.build_dataloader(test_dataset, 1, 1)
    assert isinstance(test_loader, DataLoader), 'Failed to build dataloader'
|
||||
|
||||
|
||||
def test_single_gpu_test_and_evaluate():
    """Run ``single_gpu_test`` and ``evaluate_outputs`` on dummy objects.

    An empty dummy dataset and a dummy model with fixed outputs are used. The
    ``out`` path passed to ``evaluate_outputs`` lives inside a temporary
    directory so that no file is left behind after the test.
    """
    from mmcv.parallel import MMDataParallel

    class DummyDataset(Dataset):
        """Empty dataset exposing the methods the task processor calls."""

        def __getitem__(self, index):
            return 0

        def __len__(self):
            return 0

        def evaluate(self, *args, **kwargs):
            return 0

        def format_results(self, *args, **kwargs):
            return 0

    dataset = DummyDataset()
    # Prepare dataloader
    dataloader = DataLoader(dataset)

    # Prepare dummy model
    model = DummyModel(outputs=[torch.rand([1, 10, 6]), torch.rand([1, 10])])
    model = MMDataParallel(model, device_ids=[0])
    # Run test
    outputs = task_processor.single_gpu_test(model, dataloader)
    assert isinstance(outputs, list)

    with TemporaryDirectory() as tmp_dir:
        output_file = os.path.join(tmp_dir, 'outputs.pkl')
        task_processor.evaluate_outputs(
            model_cfg,
            outputs,
            dataset,
            'bbox',
            out=output_file,
            format_only=True)
|
|
@ -0,0 +1,109 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
from tempfile import NamedTemporaryFile
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import mmdeploy.backend.onnxruntime as ort_apis
|
||||
from mmdeploy.codebase import import_codebase
|
||||
from mmdeploy.utils import Backend, Codebase, load_config
|
||||
from mmdeploy.utils.test import SwitchBackendWrapper, backend_checker
|
||||
|
||||
try:
|
||||
import_codebase(Codebase.MMROTATE)
|
||||
except ImportError:
|
||||
pytest.skip(
|
||||
f'{Codebase.MMROTATE} is not installed.', allow_module_level=True)
|
||||
|
||||
IMAGE_SIZE = 32
|
||||
|
||||
|
||||
@backend_checker(Backend.ONNXRUNTIME)
class TestEnd2EndModel:
    """Tests for the rotated-detection ``End2EndModel`` on a switched backend.

    ``ORTWrapper`` is switched through ``SwitchBackendWrapper`` for the whole
    class and configured with random ``dets``/``labels`` outputs.
    """

    @classmethod
    def setup_class(cls):
        """Switch the backend wrapper and build the model under test."""
        # force add backend wrapper regardless of plugins
        from mmdeploy.backend.onnxruntime import ORTWrapper
        ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})

        # simplify backend inference
        cls.wrapper = SwitchBackendWrapper(ORTWrapper)
        cls.outputs = dict(
            dets=torch.rand(1, 10, 6), labels=torch.rand(1, 10))
        cls.wrapper.set(outputs=cls.outputs)

        deploy_cfg = mmcv.Config(
            dict(onnx_config=dict(output_names=['dets', 'labels'])))
        model_cfg_path = 'tests/test_codebase/test_mmrotate/data/model.py'
        model_cfg = load_config(model_cfg_path)[0]

        from mmdeploy.codebase.mmrotate.deploy.rotated_detection_model import \
            End2EndModel

        # Fifteen empty strings are passed as the third positional argument.
        class_names = [''] * 15
        cls.end2end_model = End2EndModel(
            Backend.ONNXRUNTIME, [''],
            class_names,
            device='cpu',
            deploy_cfg=deploy_cfg,
            model_cfg=model_cfg)

    @classmethod
    def teardown_class(cls):
        """Undo the backend wrapper switch made in ``setup_class``."""
        cls.wrapper.recover()

    @pytest.mark.parametrize(
        'ori_shape',
        [[IMAGE_SIZE, IMAGE_SIZE, 3], [2 * IMAGE_SIZE, 2 * IMAGE_SIZE, 3]])
    def test_forward(self, ori_shape):
        """``forward`` returns a result for both original image shapes."""
        meta = {
            'ori_shape': ori_shape,
            'img_shape': [IMAGE_SIZE, IMAGE_SIZE, 3],
            'scale_factor': [1., 1., 1., 1.],
            'filename': '',
        }
        imgs = [torch.rand(1, 3, IMAGE_SIZE, IMAGE_SIZE)]
        results = self.end2end_model.forward(imgs, [[meta]])
        assert results is not None, ('failed to get output using '
                                     'End2EndModel')

    def test_forward_test(self):
        """The first entry returned by ``forward_test`` is a tensor."""
        batch = torch.rand(2, 3, IMAGE_SIZE, IMAGE_SIZE)
        results = self.end2end_model.forward_test(batch)
        assert isinstance(results[0], torch.Tensor)

    def test_show_result(self):
        """``show_result`` writes the image to ``out_file``."""
        input_img = np.zeros([IMAGE_SIZE, IMAGE_SIZE, 3])
        img_path = NamedTemporaryFile(suffix='.jpg').name

        result = torch.rand(1, 10, 6)
        self.end2end_model.show_result(
            input_img, result, '', show=False, out_file=img_path)
        assert osp.exists(img_path)
|
||||
|
||||
|
||||
@backend_checker(Backend.ONNXRUNTIME)
def test_build_rotated_detection_model():
    """``build_rotated_detection_model`` must return an ``End2EndModel``."""
    model_cfg_path = 'tests/test_codebase/test_mmrotate/data/model.py'
    model_cfg = load_config(model_cfg_path)[0]
    deploy_cfg = mmcv.Config({
        'backend_config': {'type': 'onnxruntime'},
        'onnx_config': {'output_names': ['dets', 'labels']},
        'codebase_config': {'type': 'mmrotate'},
    })

    from mmdeploy.backend.onnxruntime import ORTWrapper
    ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})

    # simplify backend inference
    with SwitchBackendWrapper(ORTWrapper) as wrapper:
        wrapper.set(model_cfg=model_cfg, deploy_cfg=deploy_cfg)
        from mmdeploy.codebase.mmrotate.deploy.rotated_detection_model import (
            End2EndModel, build_rotated_detection_model)
        detector = build_rotated_detection_model([''], model_cfg, deploy_cfg,
                                                 'cpu')
        assert isinstance(detector, End2EndModel)
|
|
@ -778,18 +778,24 @@ def test_expand(backend,
|
|||
|
||||
@pytest.mark.parametrize('backend', [TEST_ONNXRT])
|
||||
@pytest.mark.parametrize('iou_threshold', [0.1, 0.3])
|
||||
def test_nms_rotated(backend, iou_threshold, save_dir=None):
|
||||
@pytest.mark.parametrize('score_threshold', [0., 0.1])
|
||||
def test_nms_rotated(backend, iou_threshold, score_threshold, save_dir=None):
|
||||
backend.check_env()
|
||||
|
||||
boxes = torch.tensor(
|
||||
[[60, 75, 20, 50, 0], [65, 80, 10, 40, 0], [30, 30, 40, 40, 0]],
|
||||
[[[60, 75, 20, 50, 0], [65, 80, 10, 40, 0], [30, 30, 40, 40, 0]],
|
||||
[[60, 75, 20, 50, 0], [65, 80, 10, 40, 0], [30, 30, 40, 40, 0]]],
|
||||
dtype=torch.float32)
|
||||
scores = torch.tensor(
|
||||
[[[0.5, 0.1, 0.1], [0.1, 0.6, 0.1], [0.1, 0.1, 0.7], [0.1, 0.1, 0.1]],
|
||||
[[0.1, 0.1, 0.1], [0.7, 0.1, 0.1], [0.1, 0.6, 0.1], [0.1, 0.1, 0.5]]],
|
||||
dtype=torch.float32)
|
||||
scores = torch.tensor([0.5, 0.6, 0.7], dtype=torch.float32)
|
||||
|
||||
from mmdeploy.mmcv.ops import ONNXNMSRotatedOp
|
||||
|
||||
def wrapped_function(torch_boxes, torch_scores):
|
||||
return ONNXNMSRotatedOp.apply(torch_boxes, torch_scores, iou_threshold)
|
||||
return ONNXNMSRotatedOp.apply(torch_boxes, torch_scores, iou_threshold,
|
||||
score_threshold)
|
||||
|
||||
wrapped_model = WrapFunction(wrapped_function).eval()
|
||||
|
||||
|
|
Loading…
Reference in New Issue