diff --git a/configs/mmrotate/rotated-detection_onnxruntime_dynamic.py b/configs/mmrotate/rotated-detection_onnxruntime_dynamic.py new file mode 100644 index 000000000..5252f11dd --- /dev/null +++ b/configs/mmrotate/rotated-detection_onnxruntime_dynamic.py @@ -0,0 +1,17 @@ +_base_ = ['./rotated-detection_onnxruntime_static.py'] +onnx_config = dict( + dynamic_axes={ + 'input': { + 0: 'batch', + 2: 'height', + 3: 'width' + }, + 'dets': { + 0: 'batch', + 1: 'num_dets', + }, + 'labels': { + 0: 'batch', + 1: 'num_dets', + }, + }, ) diff --git a/configs/mmrotate/rotated-detection_onnxruntime_static.py b/configs/mmrotate/rotated-detection_onnxruntime_static.py new file mode 100644 index 000000000..662608bbe --- /dev/null +++ b/configs/mmrotate/rotated-detection_onnxruntime_static.py @@ -0,0 +1,3 @@ +_base_ = ['./rotated-detection_static.py', '../_base_/backends/onnxruntime.py'] + +onnx_config = dict(output_names=['dets', 'labels'], input_shape=None) diff --git a/configs/mmrotate/rotated-detection_static.py b/configs/mmrotate/rotated-detection_static.py new file mode 100644 index 000000000..324de6f7f --- /dev/null +++ b/configs/mmrotate/rotated-detection_static.py @@ -0,0 +1,9 @@ +_base_ = ['../_base_/onnx_config.py'] +codebase_config = dict( + type='mmrotate', + task='RotatedDetection', + post_processing=dict( + score_threshold=0.05, + iou_threshold=0.1, + pre_top_k=3000, + keep_top_k=2000)) diff --git a/csrc/backend_ops/onnxruntime/nms_rotated/nms_rotated.cpp b/csrc/backend_ops/onnxruntime/nms_rotated/nms_rotated.cpp index 9858772ab..6052195a5 100644 --- a/csrc/backend_ops/onnxruntime/nms_rotated/nms_rotated.cpp +++ b/csrc/backend_ops/onnxruntime/nms_rotated/nms_rotated.cpp @@ -264,6 +264,7 @@ float rotated_boxes_intersection(const RotatedBox& box1, const RotatedBox& box2) NMSRotatedKernel::NMSRotatedKernel(OrtApi api, const OrtKernelInfo* info) : api_(api), ort_(api_), info_(info) { iou_threshold_ = ort_.KernelInfoGetAttribute(info, "iou_threshold"); + score_threshold_ = ort_.KernelInfoGetAttribute(info, "score_threshold"); // create allocator allocator_ = Ort::AllocatorWithDefaultOptions(); @@ -271,6 +272,7 @@ NMSRotatedKernel::NMSRotatedKernel(OrtApi api, const OrtKernelInfo* info) void NMSRotatedKernel::Compute(OrtKernelContext* context) { const float iou_threshold = iou_threshold_; + const float score_threshold = score_threshold_; const OrtValue* boxes = ort_.KernelContext_GetInput(context, 0); const float* boxes_data = reinterpret_cast(ort_.GetTensorData(boxes)); @@ -280,67 +282,77 @@ void NMSRotatedKernel::Compute(OrtKernelContext* context) { OrtTensorDimensions boxes_dim(ort_, boxes); OrtTensorDimensions scores_dim(ort_, scores); - int64_t nboxes = boxes_dim[0]; - assert(boxes_dim[1] == 5); //(cx,cy,w,h,theta) + // loop over batch + int64_t nbatch = boxes_dim[0]; + int64_t nboxes = boxes_dim[1]; + int64_t nclass = scores_dim[1]; + assert(boxes_dim[2] == 5); //(cx,cy,w,h,theta) // allocate tmp memory - float* tmp_boxes = (float*)allocator_.Alloc(sizeof(float) * nboxes * 5); - float* sc = (float*)allocator_.Alloc(sizeof(float) * nboxes); - bool* select = (bool*)allocator_.Alloc(sizeof(bool) * nboxes); - for (int64_t i = 0; i < nboxes; i++) { - select[i] = true; - } + float* tmp_boxes = (float*)allocator_.Alloc(sizeof(float) * nbatch * nboxes * 5); + float* sc = (float*)allocator_.Alloc(sizeof(float) * nbatch * nclass * nboxes); + bool* select = (bool*)allocator_.Alloc(sizeof(bool) * nbatch * nboxes); - memcpy(tmp_boxes, boxes_data, sizeof(float) * nboxes * 5); - memcpy(sc, scores_data, sizeof(float) * nboxes); + memcpy(tmp_boxes, boxes_data, sizeof(float) * nbatch * nboxes * 5); + memcpy(sc, scores_data, sizeof(float) * nbatch * nclass * nboxes); - // sort scores - std::vector tmp_sc; - for (int i = 0; i < nboxes; i++) { - tmp_sc.push_back(sc[i]); - } - std::vector order(tmp_sc.size()); - std::iota(order.begin(), order.end(), 0); - std::sort(order.begin(), order.end(), - [&tmp_sc](int64_t id1, int64_t id2) { return tmp_sc[id1] > tmp_sc[id2]; }); - - for (int64_t _i = 0; _i < nboxes; _i++) { - if (select[_i] == false) continue; - auto i = order[_i]; - - for (int64_t _j = _i + 1; _j < nboxes; _j++) { - if (select[_j] == false) continue; - auto j = order[_j]; - RotatedBox box1, box2; - auto center_shift_x = (tmp_boxes[i * 5] + tmp_boxes[j * 5]) / 2.0; - auto center_shift_y = (tmp_boxes[i * 5 + 1] + tmp_boxes[j * 5 + 1]) / 2.0; - box1.x_ctr = tmp_boxes[i * 5] - center_shift_x; - box1.y_ctr = tmp_boxes[i * 5 + 1] - center_shift_y; - box1.w = tmp_boxes[i * 5 + 2]; - box1.h = tmp_boxes[i * 5 + 3]; - box1.a = tmp_boxes[i * 5 + 4]; - box2.x_ctr = tmp_boxes[j * 5] - center_shift_x; - box2.y_ctr = tmp_boxes[j * 5 + 1] - center_shift_y; - box2.w = tmp_boxes[j * 5 + 2]; - box2.h = tmp_boxes[j * 5 + 3]; - box2.a = tmp_boxes[j * 5 + 4]; - auto area1 = box1.w * box1.h; - auto area2 = box2.w * box2.h; - auto intersection = rotated_boxes_intersection(box1, box2); - float baseS = 1.0; - baseS = (area1 + area2 - intersection); - auto ovr = intersection / baseS; - if (ovr > iou_threshold) select[_j] = false; - } - } + // std::vector> res_order; std::vector res_order; - for (int i = 0; i < nboxes; i++) { - if (select[i]) { - res_order.push_back(order[i]); - } - } + for (int64_t k = 0; k < nbatch; k++) { + for (int64_t g = 0; g < nclass; g++) { + for (int64_t i = 0; i < nboxes; i++) { + select[i] = true; + } + // sort scores + std::vector tmp_sc; + for (int i = 0; i < nboxes; i++) { + tmp_sc.push_back(sc[k * nboxes * nclass + g * nboxes + i]); + } + std::vector order(tmp_sc.size()); + std::iota(order.begin(), order.end(), 0); + std::sort(order.begin(), order.end(), + [&tmp_sc](int64_t id1, int64_t id2) { return tmp_sc[id1] > tmp_sc[id2]; }); + for (int64_t _i = 0; _i < nboxes; _i++) { + if (select[_i] == false) continue; + auto i = order[_i]; + for (int64_t _j = _i + 1; _j < nboxes; _j++) { + if (select[_j] == false) continue; + auto j = order[_j]; + RotatedBox box1, box2; + auto center_shift_x = + (tmp_boxes[k * nboxes * 5 + i * 5] + tmp_boxes[k * nboxes * 5 + j * 5]) / 2.0; + auto center_shift_y = + (tmp_boxes[k * nboxes * 5 + i * 5 + 1] + tmp_boxes[k * nboxes * 5 + j * 5 + 1]) / 2.0; + box1.x_ctr = tmp_boxes[k * nboxes * 5 + i * 5] - center_shift_x; + box1.y_ctr = tmp_boxes[k * nboxes * 5 + i * 5 + 1] - center_shift_y; + box1.w = tmp_boxes[k * nboxes * 5 + i * 5 + 2]; + box1.h = tmp_boxes[k * nboxes * 5 + i * 5 + 3]; + box1.a = tmp_boxes[k * nboxes * 5 + i * 5 + 4]; + box2.x_ctr = tmp_boxes[k * nboxes * 5 + j * 5] - center_shift_x; + box2.y_ctr = tmp_boxes[k * nboxes * 5 + j * 5 + 1] - center_shift_y; + box2.w = tmp_boxes[k * nboxes * 5 + j * 5 + 2]; + box2.h = tmp_boxes[k * nboxes * 5 + j * 5 + 3]; + box2.a = tmp_boxes[k * nboxes * 5 + j * 5 + 4]; + auto area1 = box1.w * box1.h; + auto area2 = box2.w * box2.h; + auto intersection = rotated_boxes_intersection(box1, box2); + float baseS = 1.0; + baseS = (area1 + area2 - intersection); + auto ovr = intersection / baseS; + if (ovr > iou_threshold) select[_j] = false; + } + } + for (int i = 0; i < nboxes; i++) { + if (select[i] & (tmp_sc[order[i]] > score_threshold)) { + res_order.push_back(k); + res_order.push_back(g); + res_order.push_back(order[i]); + } + } + } // class loop + } // batch loop - std::vector inds_dims({(int64_t)res_order.size()}); + std::vector inds_dims({(int64_t)res_order.size() / 3, 3}); OrtValue* res = ort_.KernelContext_GetOutput(context, 0, inds_dims.data(), inds_dims.size()); int64_t* res_data = ort_.GetTensorMutableData(res); diff --git a/csrc/backend_ops/onnxruntime/nms_rotated/nms_rotated.h b/csrc/backend_ops/onnxruntime/nms_rotated/nms_rotated.h index 0c0f273a4..c1a12dad1 100644 --- a/csrc/backend_ops/onnxruntime/nms_rotated/nms_rotated.h +++ b/csrc/backend_ops/onnxruntime/nms_rotated/nms_rotated.h @@ -22,6 +22,7 @@ struct NMSRotatedKernel { const OrtKernelInfo* info_; Ort::AllocatorWithDefaultOptions allocator_; float iou_threshold_; + float score_threshold_; }; struct NMSRotatedOp : Ort::CustomOpBase { diff --git a/docs/en/benchmark.md b/docs/en/benchmark.md index 9e66f19b8..eb7427e67 100644 --- a/docs/en/benchmark.md +++ b/docs/en/benchmark.md @@ -1819,6 +1819,53 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](tut +
+MMRotate +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
MMRotatePytorchONNXRuntimeTensorRTPPLNNOpenVINOModel Config
ModelTaskDatasetMetricsfp32fp32fp32fp16fp16fp32model config file
RotatedRetinaNetRotated DetectionDOTA-v1.0mAP0.6980.698----$MMROTATE_DIR/configs/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le135.py
+
+
+ ### Notes - As some datasets contain images with various resolutions in codebase like MMDet. The speed benchmark is gained through static configs in MMDeploy, while the performance benchmark is gained through dynamic ones. diff --git a/docs/en/codebases/mmrotate.md b/docs/en/codebases/mmrotate.md new file mode 100644 index 000000000..1a7311752 --- /dev/null +++ b/docs/en/codebases/mmrotate.md @@ -0,0 +1,53 @@ +# MMRotate Support + +[MMRotate](https://github.com/open-mmlab/mmrotate) is an open-source toolbox for rotated object detection based on PyTorch. It is a part of the [OpenMMLab](https://openmmlab.com/) project. + +## MMRotate installation tutorial + +Please refer to [official installation guide](https://mmrotate.readthedocs.io/en/latest/install.html) to install the codebase. + +## MMRotate models support + +| Model | Task | ONNX Runtime | TensorRT | NCNN | PPLNN | OpenVINO | Model config | +|:----------------------|:--------------|:------------:|:--------:|:----:|:-----:|:--------:|:-------------------------------------------------------------------------------------------:| +| RotatedRetinaNet | RotatedDetection | Y | N | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/rotated_retinanet/README.md) | + +### Example + +```bash +# convert ort +python tools/deploy.py \ +configs/mmrotate/rotated-detection_onnxruntime_dynamic.py \ +$MMROTATE_DIR/configs/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le135.py \ +$MMROTATE_DIR/checkpoints/rotated_retinanet_obb_r50_fpn_1x_dota_le135-e4131166.pth \ +$MMROTATE_DIR/demo/demo.jpg \ +--work-dir work-dirs/mmrotate/rotated_retinanet/ort \ +--device cpu + +# compute metric +python tools/test.py \ + configs/mmrotate/rotated-detection_onnxruntime_dynamic.py \ + $MMROTATE_DIR/configs/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le135.py \ + --model work-dirs/mmrotate/rotated_retinanet/ort/end2end.onnx \ + --metrics mAP + +# generate submit file +python tools/test.py \ + configs/mmrotate/rotated-detection_onnxruntime_dynamic.py \ + $MMROTATE_DIR/configs/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le135.py \ + --model work-dirs/mmrotate/rotated_retinanet/ort/end2end.onnx \ + --format-only \ + --metric-options submission_dir=work-dirs/mmrotate/rotated_retinanet/ort/Task1_results +``` + +Note + +- Usually, mmrotate models need some extra information for the input image, but we can't get it directly. So, when exporting the model, you can use `$MMROTATE_DIR/demo/demo.jpg` as input. + +## Reminder + +None + +## FAQs + +None diff --git a/docs/zh_cn/benchmark.md b/docs/zh_cn/benchmark.md index 3225c44fd..15bbe7296 100644 --- a/docs/zh_cn/benchmark.md +++ b/docs/zh_cn/benchmark.md @@ -1808,6 +1808,54 @@ GPU: ncnn, TensorRT, PPLNN +
+MMRotate +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
MMRotatePytorchONNXRuntimeTensorRTPPLNNOpenVINOModel Config
ModelTaskDatasetMetricsfp32fp32fp32fp16fp16fp32model config file
RotatedRetinaNetRotated DetectionDOTA-v1.0mAP0.6980.698----$MMROTATE_DIR/configs/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le135.py
+
+
+ + ### 注意 - 由于某些数据集在代码库中包含各种分辨率的图像,例如 MMDet,速度基准是通过 MMDeploy 中的静态配置获得的,而性能基准是通过动态配置获得的。 diff --git a/mmdeploy/codebase/__init__.py b/mmdeploy/codebase/__init__.py index a16c146d0..0311a0c57 100644 --- a/mmdeploy/codebase/__init__.py +++ b/mmdeploy/codebase/__init__.py @@ -4,7 +4,10 @@ import importlib from mmdeploy.utils import Codebase from .base import BaseTask, MMCodebase, get_codebase_class -extra_dependent_library = {Codebase.MMOCR: ['mmdet']} +extra_dependent_library = { + Codebase.MMOCR: ['mmdet'], + Codebase.MMROTATE: ['mmdet'] +} def import_codebase(codebase: Codebase): diff --git a/mmdeploy/codebase/mmrotate/__init__.py b/mmdeploy/codebase/mmrotate/__init__.py new file mode 100644 index 000000000..0d3510f91 --- /dev/null +++ b/mmdeploy/codebase/mmrotate/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .core import * # noqa: F401,F403 +from .deploy import * # noqa: F401,F403 +from .models import * # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmrotate/core/__init__.py b/mmdeploy/codebase/mmrotate/core/__init__.py new file mode 100644 index 000000000..7c0065f72 --- /dev/null +++ b/mmdeploy/codebase/mmrotate/core/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .bbox import * # noqa: F401,F403 +from .post_processing import * # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmrotate/core/bbox/__init__.py b/mmdeploy/codebase/mmrotate/core/bbox/__init__.py new file mode 100644 index 000000000..a42a3cf2d --- /dev/null +++ b/mmdeploy/codebase/mmrotate/core/bbox/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .delta_xywha_rbbox_coder import * # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmrotate/core/bbox/delta_xywha_rbbox_coder.py b/mmdeploy/codebase/mmrotate/core/bbox/delta_xywha_rbbox_coder.py new file mode 100644 index 000000000..3d6fd87e2 --- /dev/null +++ b/mmdeploy/codebase/mmrotate/core/bbox/delta_xywha_rbbox_coder.py @@ -0,0 +1,118 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import torch +from mmrotate.core import norm_angle + +from mmdeploy.core import FUNCTION_REWRITER + + +@FUNCTION_REWRITER.register_rewriter( + func_name='mmrotate.core.bbox.coder.delta_xywha_rbbox_coder.delta2bbox', + backend='default') +def delta2bbox(ctx, + rois, + deltas, + means=(0., 0., 0., 0., 0.), + stds=(1., 1., 1., 1., 1.), + max_shape=None, + wh_ratio_clip=16 / 1000, + add_ctr_clamp=False, + ctr_clamp=32, + angle_range='oc', + norm_factor=None, + edge_swap=False, + proj_xy=False): + """Rewrite `delta2bbox` for default backend. + + Support batch bbox decoder. + + Args: + ctx (ContextCaller): The context with additional information. + rois (Tensor): Boxes to be transformed. Has shape (N, 5). + deltas (Tensor): Encoded offsets relative to each roi. + Has shape (N, num_classes * 5) or (N, 5). Note + N = num_base_anchors * W * H, when rois is a grid of + anchors. Offset encoding follows [1]_. + means (Sequence[float]): Denormalizing means for delta coordinates. + Default (0., 0., 0., 0., 0.). + stds (Sequence[float]): Denormalizing standard deviation for delta + coordinates. Default (1., 1., 1., 1., 1.). + max_shape (tuple[int, int]): Maximum bounds for boxes, specifies + (H, W). Default None. + wh_ratio_clip (float): Maximum aspect ratio for boxes. Default + 16 / 1000. + add_ctr_clamp (bool): Whether to add center clamp. When set to True, + the center of the prediction bounding box will be clamped to + avoid being too far away from the center of the anchor. + Only used by YOLOF. Default False. + ctr_clamp (int): the maximum pixel shift to clamp. Only used by YOLOF. + Default 32. + angle_range (str, optional): Angle representations. Defaults to 'oc'. + norm_factor (None|float, optional): Regularization factor of angle. + edge_swap (bool, optional): Whether swap the edge if w < h. + Defaults to False. + proj_xy (bool, optional): Whether project x and y according to angle. + Defaults to False. + + Return: + bboxes (Tensor): Boxes with shape (N, num_classes * 5) or (N, 5), + where 5 represent cx, cy, w, h, angle. + """ + means = deltas.new_tensor(means).view(1, -1) + stds = deltas.new_tensor(stds).view(1, -1) + delta_shape = deltas.shape + reshaped_deltas = deltas.view(delta_shape[:-1] + (-1, 5)) + denorm_deltas = reshaped_deltas * stds + means + + dx = denorm_deltas[..., 0] + dy = denorm_deltas[..., 1] + dw = denorm_deltas[..., 2] + dh = denorm_deltas[..., 3] + da = denorm_deltas[..., 4] + if norm_factor: + da *= norm_factor * np.pi + # Compute center of each roi + + px = rois[..., None, 0] + py = rois[..., None, 1] + # Compute width/height of each roi + pw = rois[..., None, 2] + ph = rois[..., None, 3] + # Compute rotated angle of each roi + pa = rois[..., None, 4] + dx_width = pw * dx + dy_height = ph * dy + max_ratio = np.abs(np.log(wh_ratio_clip)) + if add_ctr_clamp: + dx_width = torch.clamp(dx_width, max=ctr_clamp, min=-ctr_clamp) + dy_height = torch.clamp(dy_height, max=ctr_clamp, min=-ctr_clamp) + dw = torch.clamp(dw, max=max_ratio) + dh = torch.clamp(dh, max=max_ratio) + else: + dw = dw.clamp(min=-max_ratio, max=max_ratio) + dh = dh.clamp(min=-max_ratio, max=max_ratio) + # Use exp(network energy) to enlarge/shrink each roi + gw = pw * dw.exp() + gh = ph * dh.exp() + # Use network energy to shift the center of each roi + if proj_xy: + gx = dx * pw * torch.cos(pa) - dy * ph * torch.sin(pa) + px + gy = dx * pw * torch.sin(pa) + dy * ph * torch.cos(pa) + py + else: + gx = px + dx_width + gy = py + dy_height + # Compute angle + ga = norm_angle(pa + da, angle_range) + if max_shape is not None: + gx = gx.clamp(min=0, max=max_shape[1] - 1) + gy = gy.clamp(min=0, max=max_shape[0] - 1) + + if edge_swap: + w_regular = torch.where(gw > gh, gw, gh) + h_regular = torch.where(gw > gh, gh, gw) + theta_regular = torch.where(gw > gh, ga, ga + np.pi / 2) + theta_regular = norm_angle(theta_regular, angle_range) + return torch.stack([gx, gy, w_regular, h_regular, theta_regular], + dim=-1).view_as(deltas) + else: + return torch.stack([gx, gy, gw, gh, ga], dim=-1).view(deltas.size()) diff --git a/mmdeploy/codebase/mmrotate/core/post_processing/__init__.py b/mmdeploy/codebase/mmrotate/core/post_processing/__init__.py new file mode 100644 index 000000000..01c3b72a4 --- /dev/null +++ b/mmdeploy/codebase/mmrotate/core/post_processing/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .bbox_nms import multiclass_nms_rotated + +__all__ = ['multiclass_nms_rotated'] diff --git a/mmdeploy/codebase/mmrotate/core/post_processing/bbox_nms.py b/mmdeploy/codebase/mmrotate/core/post_processing/bbox_nms.py new file mode 100644 index 000000000..44753271d --- /dev/null +++ b/mmdeploy/codebase/mmrotate/core/post_processing/bbox_nms.py @@ -0,0 +1,120 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch import Tensor + +from mmdeploy.mmcv.ops import ONNXNMSRotatedOp + + +def select_nms_index(scores: torch.Tensor, + boxes: torch.Tensor, + nms_index: torch.Tensor, + batch_size: int, + keep_top_k: int = -1): + """Transform NMSRotated output. + + Args: + scores (Tensor): The detection scores of shape + [N, num_classes, num_boxes]. + boxes (Tensor): The bounding boxes of shape [N, num_boxes, 6]. + nms_index (Tensor): NMS output of bounding boxes indexing. + batch_size (int): Batch size of the input image. + keep_top_k (int): Number of top K boxes to keep after nms. + Defaults to -1. + + Returns: + tuple[Tensor, Tensor]: (dets, labels), `dets` of shape [N, num_det, 6] + and `labels` of shape [N, num_det]. + """ + batch_inds, cls_inds = nms_index[:, 0], nms_index[:, 1] + box_inds = nms_index[:, 2] + + # index by nms output + scores = scores[batch_inds, cls_inds, box_inds].unsqueeze(1) + boxes = boxes[batch_inds, box_inds, ...] + dets = torch.cat([boxes, scores], dim=1) + + # batch all + batched_dets = dets.unsqueeze(0).repeat(batch_size, 1, 1) + batch_template = torch.arange( + 0, batch_size, dtype=batch_inds.dtype, device=batch_inds.device) + batched_dets = batched_dets.where( + (batch_inds == batch_template.unsqueeze(1)).unsqueeze(-1), + batched_dets.new_zeros(1)) + + batched_labels = cls_inds.unsqueeze(0).repeat(batch_size, 1) + batched_labels = batched_labels.where( + (batch_inds == batch_template.unsqueeze(1)), + batched_labels.new_ones(1) * -1) + + N = batched_dets.shape[0] + + # expand tensor to eliminate [0, ...] tensor + batched_dets = torch.cat((batched_dets, batched_dets.new_zeros((N, 1, 6))), + 1) + batched_labels = torch.cat((batched_labels, batched_labels.new_zeros( + (N, 1))), 1) + + # sort + is_use_topk = keep_top_k > 0 and \ + (torch.onnx.is_in_onnx_export() or keep_top_k < batched_dets.shape[1]) + if is_use_topk: + _, topk_inds = batched_dets[:, :, -1].topk(keep_top_k, dim=1) + else: + _, topk_inds = batched_dets[:, :, -1].sort(dim=1, descending=True) + topk_batch_inds = torch.arange( + batch_size, dtype=topk_inds.dtype, + device=topk_inds.device).view(-1, 1).expand_as(topk_inds) + batched_dets = batched_dets[topk_batch_inds, topk_inds, ...] + batched_labels = batched_labels[topk_batch_inds, topk_inds, ...] + + # slice and recover the tensor + return batched_dets, batched_labels + + +def multiclass_nms_rotated(boxes: Tensor, + scores: Tensor, + iou_threshold: float = 0.1, + score_threshold: float = 0.05, + pre_top_k: int = -1, + keep_top_k: int = -1): + """NMSRotated for multi-class bboxes. + + This function helps exporting to onnx with batch and multiclass NMSRotated + op. It only supports class-agnostic detection results. That is, the scores + is of shape (N, num_bboxes, num_classes) and the boxes is of shape + (N, num_boxes, 5). + + Args: + boxes (Tensor): The bounding boxes of shape [N, num_boxes, 5]. + scores (Tensor): The detection scores of shape + [N, num_boxes, num_classes]. + iou_threshold (float): IOU threshold of nms. Defaults to 0.5. + score_threshold (float): bbox threshold, bboxes with scores lower than + it will not be considered. + pre_top_k (int): Number of top K boxes to keep before nms. + Defaults to -1. + keep_top_k (int): Number of top K boxes to keep after nms. + Defaults to -1. + + Returns: + tuple[Tensor, Tensor]: (dets, labels), `dets` of shape [N, num_det, 6] + and `labels` of shape [N, num_det]. + """ + batch_size = scores.shape[0] + + if pre_top_k > 0: + max_scores, _ = scores.max(-1) + _, topk_inds = max_scores.topk(pre_top_k) + batch_inds = torch.arange(batch_size).view( + -1, 1).expand_as(topk_inds).long() + boxes = boxes[batch_inds, topk_inds, :] + scores = scores[batch_inds, topk_inds, :] + + scores = scores.permute(0, 2, 1) + selected_indices = ONNXNMSRotatedOp.apply(boxes, scores, iou_threshold, + score_threshold) + + dets, labels = select_nms_index( + scores, boxes, selected_indices, batch_size, keep_top_k=keep_top_k) + + return dets, labels diff --git a/mmdeploy/codebase/mmrotate/deploy/__init__.py b/mmdeploy/codebase/mmrotate/deploy/__init__.py new file mode 100644 index 000000000..138054107 --- /dev/null +++ b/mmdeploy/codebase/mmrotate/deploy/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .mmrotate import MMROTATE +from .rotated_detection import RotatedDetection + +__all__ = ['MMROTATE', 'RotatedDetection'] diff --git a/mmdeploy/codebase/mmrotate/deploy/mmrotate.py b/mmdeploy/codebase/mmrotate/deploy/mmrotate.py new file mode 100644 index 000000000..ab46df757 --- /dev/null +++ b/mmdeploy/codebase/mmrotate/deploy/mmrotate.py @@ -0,0 +1,157 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Union + +import mmcv +import torch +from mmcv.utils import Registry +from mmdet.datasets import replace_ImageToTensor +from torch.utils.data import DataLoader, Dataset + +from mmdeploy.codebase.base import CODEBASE, BaseTask, MMCodebase +from mmdeploy.utils import Codebase, get_task_type + + +def __build_mmrotate_task(model_cfg: mmcv.Config, deploy_cfg: mmcv.Config, + device: str, registry: Registry) -> BaseTask: + task = get_task_type(deploy_cfg) + return registry.module_dict[task.value](model_cfg, deploy_cfg, device) + + +MMROTATE_TASK = Registry('mmrotate_tasks', build_func=__build_mmrotate_task) + + +@CODEBASE.register_module(Codebase.MMROTATE.value) +class MMROTATE(MMCodebase): + """mmrotate codebase class.""" + + task_registry = MMROTATE_TASK + + def __init__(self): + super(MMROTATE, self).__init__() + + @staticmethod + def build_task_processor(model_cfg: mmcv.Config, deploy_cfg: mmcv.Config, + device: str): + """The interface to build the task processors of mmrotate. + + Args: + model_cfg (str | mmcv.Config): Model config file or loaded Config + object. + deploy_cfg (str | mmcv.Config): Deployment config file or loaded + Config object. + device (str): A string specifying device type. + + Returns: + BaseTask: A task processor. + """ + return MMROTATE_TASK.build(model_cfg, deploy_cfg, device) + + @staticmethod + def build_dataset(dataset_cfg: Union[str, mmcv.Config], + dataset_type: str = 'val', + **kwargs) -> Dataset: + """Build dataset for mmrotate. + + Args: + dataset_cfg (str | mmcv.Config): The input dataset config. + dataset_type (str): A string represents dataset type, e.g.: 'train' + , 'test', 'val'. Defaults to 'val'. + + Returns: + Dataset: A PyTorch dataset. + """ + from mmrotate.datasets import build_dataset as build_dataset_mmrotate + + # dataset_cfg = load_config(dataset_cfg)[0] + assert dataset_type in dataset_cfg.data + data_cfg = dataset_cfg.data[dataset_type] + # in case the dataset is concatenated + if isinstance(data_cfg, dict): + data_cfg.test_mode = True + samples_per_gpu = data_cfg.pop('samples_per_gpu', 1) + if samples_per_gpu > 1: + # Replace 'ImageToTensor' to 'DefaultFormatBundle' + data_cfg.pipeline = replace_ImageToTensor(data_cfg.pipeline) + elif isinstance(data_cfg, list): + for ds_cfg in data_cfg: + ds_cfg.test_mode = True + samples_per_gpu = max( + [ds_cfg.pop('samples_per_gpu', 1) for ds_cfg in data_cfg]) + if samples_per_gpu > 1: + for ds_cfg in data_cfg: + ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline) + dataset = build_dataset_mmrotate(data_cfg) + return dataset + + @staticmethod + def build_dataloader(dataset: Dataset, + samples_per_gpu: int, + workers_per_gpu: int, + num_gpus: int = 1, + dist: bool = False, + shuffle: bool = False, + seed: Optional[int] = None, + drop_last: bool = False, + persistent_workers: bool = True, + **kwargs) -> DataLoader: + """Build dataloader for mmrotate. + + Args: + dataset (Dataset): Input dataset. + samples_per_gpu (int): Number of training samples on each GPU, i.e. + ,batch size of each GPU. + workers_per_gpu (int): How many subprocesses to use for data + loading for each GPU. + num_gpus (int): Number of GPUs. Only used in non-distributed + training. + dist (bool): Distributed training/test or not. Defaults to `False`. + shuffle (bool): Whether to shuffle the data at every epoch. + Defaults to `False`. + seed (int): An integer set to be seed. Default is `None`. + drop_last (bool): Whether to drop the last incomplete batch in + epoch. Default to `False`. + persistent_workers (bool): If `True`, the data loader will not + shutdown the worker processes after a dataset has been + consumed once. This allows to maintain the workers Dataset + instances alive. The argument also has effect in + PyTorch>=1.7.0. Default is `True`. + kwargs: Any other keyword argument to be used to initialize + DataLoader. + + Returns: + DataLoader: A PyTorch dataloader. + """ + from mmdet.datasets import build_dataloader as build_dataloader_mmdet + return build_dataloader_mmdet( + dataset, + samples_per_gpu, + workers_per_gpu, + num_gpus=num_gpus, + dist=dist, + shuffle=shuffle, + seed=seed, + drop_last=drop_last, + persistent_workers=persistent_workers, + **kwargs) + + @staticmethod + def single_gpu_test(model: torch.nn.Module, + data_loader: DataLoader, + show: bool = False, + out_dir: Optional[str] = None, + **kwargs): + """Run test with single gpu. + + Args: + model (torch.nn.Module): Input model from nn.Module. + data_loader (DataLoader): PyTorch data loader. + show (bool): Specifying whether to show plotted results. Defaults + to `False`. + out_dir (str): A directory to save results, defaults to `None`. + + Returns: + list: The prediction results. + """ + from mmdet.apis import single_gpu_test + outputs = single_gpu_test(model, data_loader, show, out_dir, **kwargs) + return outputs diff --git a/mmdeploy/codebase/mmrotate/deploy/rotated_detection.py b/mmdeploy/codebase/mmrotate/deploy/rotated_detection.py new file mode 100644 index 000000000..3123c670b --- /dev/null +++ b/mmdeploy/codebase/mmrotate/deploy/rotated_detection.py @@ -0,0 +1,343 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Any, Dict, Optional, Sequence, Tuple, Union + +import mmcv +import numpy as np +import torch +from mmcv.parallel import DataContainer, collate, scatter +from torch import nn +from torch.utils.data import Dataset + +from mmdeploy.codebase.base import BaseTask +from mmdeploy.utils import Task, get_input_shape +from .mmrotate import MMROTATE_TASK + + +def process_model_config(model_cfg: mmcv.Config, + imgs: Union[Sequence[str], Sequence[np.ndarray]], + input_shape: Optional[Sequence[int]] = None): + """Process the model config. + + Args: + model_cfg (mmcv.Config): The model config. + imgs (Sequence[str] | Sequence[np.ndarray]): Input image(s), accepted + data type are List[str], List[np.ndarray]. + input_shape (list[int]): A list of two integer in (width, height) + format specifying input shape. Default: None. + + Returns: + mmcv.Config: the model config after processing. + """ + from mmdet.datasets import replace_ImageToTensor + + cfg = model_cfg.copy() + + if isinstance(imgs[0], np.ndarray): + cfg = cfg.copy() + # set loading pipeline type + cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam' + # for static exporting + if input_shape is not None: + cfg.data.test.pipeline[1]['img_scale'] = tuple(input_shape) + transforms = cfg.data.test.pipeline[1]['transforms'] + for trans in transforms: + trans_type = trans['type'] + if trans_type == 'Pad': + trans['size_divisor'] = 1 + + cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline) + return cfg + + +@MMROTATE_TASK.register_module(Task.ROTATED_DETECTION.value) +class RotatedDetection(BaseTask): + """Rotated detection task class. + + Args: + model_cfg (mmcv.Config): Loaded model Config object.. + deploy_cfg (mmcv.Config): Loaded deployment Config object. + device (str): A string represents device type. + """ + + def __init__(self, model_cfg: mmcv.Config, deploy_cfg: mmcv.Config, + device: str): + super(RotatedDetection, self).__init__(model_cfg, deploy_cfg, device) + + def init_backend_model(self, + model_files: Optional[str] = None, + **kwargs) -> torch.nn.Module: + """Initialize backend model. + + Args: + model_files (Sequence[str]): Input model files. + + Returns: + nn.Module: An initialized backend model. + """ + from .rotated_detection_model import build_rotated_detection_model + model = build_rotated_detection_model( + model_files, self.model_cfg, self.deploy_cfg, device=self.device) + return model.eval() + + def init_pytorch_model(self, + model_checkpoint: Optional[str] = None, + cfg_options: Optional[Dict] = None, + **kwargs) -> torch.nn.Module: + """Initialize torch model. + + Args: + model_checkpoint (str): The checkpoint file of torch model, + defaults to `None`. + cfg_options (dict): Optional config key-pair parameters. + + Returns: + nn.Module: An initialized torch model generated by OpenMMLab + codebases. + """ + import warnings + + from mmcv.runner import load_checkpoint + from mmdet.core import get_classes + from mmrotate.models import build_detector + + if isinstance(self.model_cfg, str): + self.model_cfg = mmcv.Config.fromfile(self.model_cfg) + elif not isinstance(self.model_cfg, mmcv.Config): + raise TypeError('config must be a filename or Config object, ' + f'but got {type(self.model_cfg)}') + if cfg_options is not None: + self.model_cfg.merge_from_dict(cfg_options) + self.model_cfg.model.pretrained = None + self.model_cfg.model.train_cfg = None + model = build_detector( + self.model_cfg.model, test_cfg=self.model_cfg.get('test_cfg')) + if model_checkpoint is not None: + map_loc = 'cpu' if self.device == 'cpu' else None + checkpoint = load_checkpoint( + model, model_checkpoint, map_location=map_loc) + if 'CLASSES' in checkpoint.get('meta', {}): + model.CLASSES = checkpoint['meta']['CLASSES'] + else: + warnings.simplefilter('once') + warnings.warn('Class names are not saved in the checkpoint\'s ' + 'meta data, use COCO classes by default.') + model.CLASSES = get_classes('coco') + model.cfg = self.model_cfg + model.to(self.device) + return model.eval() + + def create_input(self, + imgs: Union[str, np.ndarray], + input_shape: Sequence[int] = None) \ + -> Tuple[Dict, torch.Tensor]: + """Create input for rotated object detection. + + Args: + imgs (str | np.ndarray): Input image(s), accepted data type are + `str`, `np.ndarray`. + input_shape (list[int]): A list of two integer in (width, height) + format specifying input shape. Defaults to `None`. + + Returns: + tuple: (data, img), meta information for the input image and input. + """ + if isinstance(imgs, (list, tuple)): + if not isinstance(imgs[0], (np.ndarray, str)): + raise AssertionError('imgs must be strings or numpy arrays') + + elif isinstance(imgs, (np.ndarray, str)): + imgs = [imgs] + else: + raise AssertionError('imgs must be strings or numpy arrays') + cfg = process_model_config(self.model_cfg, imgs, input_shape) + from mmdet.datasets.pipelines import Compose + test_pipeline = Compose(cfg.data.test.pipeline) + + data_list = [] + for img in imgs: + # prepare data + if isinstance(imgs[0], np.ndarray): + # directly add img + data = dict(img=img) + else: + # add information into dict + data = dict(img_info=dict(filename=img), img_prefix=None) + + # build the data pipeline + data = test_pipeline(data) + # get tensor from list to stack for batch mode (rotated detection) + data_list.append(data) + + if isinstance(data_list[0]['img'], list) and len(data_list) > 1: + raise Exception('aug test does not support ' + f'inference with batch size ' + f'{len(data_list)}') + + data = collate(data_list, samples_per_gpu=len(imgs)) + + # process img_metas + if isinstance(data['img_metas'], list): + data['img_metas'] = [ + img_metas.data[0] for img_metas in data['img_metas'] + ] + else: + data['img_metas'] = data['img_metas'].data + + if isinstance(data['img'], list): + data['img'] = [img.data for img in data['img']] + if isinstance(data['img'][0], list): + data['img'] = [img[0] for img in data['img']] + else: + data['img'] = data['img'].data + + if self.device != 'cpu': + data = scatter(data, [self.device])[0] + + return data, data['img'] + + def visualize(self, + model: nn.Module, + image: Union[str, np.ndarray], + result: list, + output_file: str, + window_name: str = '', + show_result: bool = False): + """Visualize predictions of a model. + + Args: + model (nn.Module): Input model. + image (str | np.ndarray): Input image to draw predictions on. + result (list): A list of predictions. + output_file (str): Output file to save drawn image. + window_name (str): The name of visualization window. Defaults to + an empty string. + show_result (bool): Whether to show result in windows, defaults + to `False`. + """ + show_img = mmcv.imread(image) if isinstance(image, str) else image + output_file = None if show_result else output_file + model.show_result( + show_img, + result, + out_file=output_file, + win_name=window_name, + show=show_result) + + @staticmethod + def run_inference(model: nn.Module, + model_inputs: Dict[str, torch.Tensor]) -> list: + """Run inference once for a segmentation model of mmseg. + + Args: + model (nn.Module): Input model. + model_inputs (dict): A dict containing model inputs tensor and + meta info. + + Returns: + list: The predictions of model inference. + """ + return model(**model_inputs, return_loss=False, rescale=True) + + @staticmethod + def get_partition_cfg(partition_type: str) -> Dict: + """Get a certain partition config. + + Args: + partition_type (str): A string specifying partition type. + + Returns: + dict: A dictionary of partition config. + """ + raise NotImplementedError('Not supported yet.') + + @staticmethod + def get_tensor_from_input(input_data: Dict[str, Any]) -> torch.Tensor: + """Get input tensor from input data. + + Args: + input_data (dict): Input data containing meta info and image + tensor. + Returns: + torch.Tensor: An image in `Tensor`. + """ + if isinstance(input_data['img'], DataContainer): + return input_data['img'].data[0] + return input_data['img'][0] + + @staticmethod + def evaluate_outputs(model_cfg, + outputs: Sequence, + dataset: Dataset, + metrics: Optional[str] = None, + out: Optional[str] = None, + metric_options: Optional[dict] = None, + format_only: bool = False, + log_file: Optional[str] = None): + """Perform post-processing to predictions of model. + + Args: + outputs (Sequence): A list of predictions of model inference. + dataset (Dataset): Input dataset to run test. + model_cfg (mmcv.Config): The model config. + metrics (str): Evaluation metrics, which depends on + the codebase and the dataset, e.g., "mAP" for rotated + detection. + out (str): Output result file in pickle format, defaults to `None`. + metric_options (dict): Custom options for evaluation, will be + kwargs for dataset.evaluate() function. Defaults to `None`. + format_only (bool): Format the output results without perform + evaluation. It is useful when you want to format the result + to a specific format and submit it to the test server. Defaults + to `False`. + log_file (str | None): The file to write the evaluation results. + Defaults to `None` and the results will only print on stdout. + """ + from mmcv.utils import get_logger + logger = get_logger('test', log_file=log_file) + + if out: + logger.debug(f'writing results to {out}') + mmcv.dump(outputs, out) + kwargs = {} if metric_options is None else metric_options + if format_only: + dataset.format_results(outputs, **kwargs) + if metrics: + eval_kwargs = model_cfg.get('evaluation', {}).copy() + # hard-code way to remove EvalHook args + for key in [ + 'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best', + 'rule' + ]: + eval_kwargs.pop(key, None) + eval_kwargs.update(dict(metric=metrics, **kwargs)) + logger.info(dataset.evaluate(outputs, **eval_kwargs)) + + def get_preprocess(self) -> Dict: + """Get the preprocess information for SDK. + + Return: + dict: Composed of the preprocess information. + """ + input_shape = get_input_shape(self.deploy_cfg) + model_cfg = process_model_config(self.model_cfg, [''], input_shape) + preprocess = model_cfg.data.test.pipeline + return preprocess + + def get_postprocess(self) -> Dict: + """Get the postprocess information for SDK. + + Return: + dict: Composed of the postprocess information. + """ + postprocess = self.model_cfg.model.bbox_head + return postprocess + + def get_model_name(self) -> str: + """Get the model name. + + Return: + str: the name of the model. + """ + assert 'type' in self.model_cfg.model, 'model config contains no type' + name = self.model_cfg.model.type.lower() + return name diff --git a/mmdeploy/codebase/mmrotate/deploy/rotated_detection_model.py b/mmdeploy/codebase/mmrotate/deploy/rotated_detection_model.py new file mode 100644 index 000000000..b39f73843 --- /dev/null +++ b/mmdeploy/codebase/mmrotate/deploy/rotated_detection_model.py @@ -0,0 +1,198 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Sequence, Union + +import mmcv +import numpy as np +import torch +from mmcv.utils import Registry +from mmrotate.models.detectors import RotatedBaseDetector + +from mmdeploy.codebase.base import BaseBackendModel +from mmdeploy.codebase.mmdet.deploy.object_detection_model import \ + get_classes_from_config +from mmdeploy.utils import (Backend, get_backend, get_codebase_config, + load_config) + + +def __build_backend_model(cls_name: str, registry: Registry, *args, **kwargs): + return registry.module_dict[cls_name](*args, **kwargs) + + +__BACKEND_MODEL = mmcv.utils.Registry( + 'backend_rotated_detectors', build_func=__build_backend_model) + + +@__BACKEND_MODEL.register_module('end2end') +class End2EndModel(BaseBackendModel): + """End to end model for inference of rotated detection. + + Args: + backend (Backend): The backend enum, specifying backend type. + backend_files (Sequence[str]): Paths to all required backend files(e.g. + '.onnx' for ONNX Runtime, '.param' and '.bin' for ncnn). + class_names (Sequence[str]): A list of string specifying class names. + device (str): A string represents device type. + deploy_cfg (str | mmcv.Config): Deployment config file or loaded Config + object. + model_cfg (str | mmcv.Config): Model config file or loaded Config + object. + """ + + def __init__( + self, + backend: Backend, + backend_files: Sequence[str], + class_names: Sequence[str], + device: str, + deploy_cfg: Union[str, mmcv.Config] = None, + model_cfg: Union[str, mmcv.Config] = None, + ): + super(End2EndModel, self).__init__(deploy_cfg=deploy_cfg) + model_cfg, deploy_cfg = load_config(model_cfg, deploy_cfg) + self.CLASSES = class_names + self.deploy_cfg = deploy_cfg + self.device = device + self.show_score = False + self._init_wrapper( + backend=backend, backend_files=backend_files, device=device) + + def _init_wrapper(self, backend: Backend, backend_files: Sequence[str], + device: str): + """Initialize the wrapper of backends. + + Args: + backend (Backend): The backend enum, specifying backend type. + backend_files (Sequence[str]): Paths to all required backend files + (e.g. .onnx' for ONNX Runtime, '.param' and '.bin' for ncnn). + device (str): A string represents device type. + """ + output_names = self.output_names + self.wrapper = BaseBackendModel._build_wrapper( + backend=backend, + backend_files=backend_files, + device=device, + input_names=[self.input_name], + output_names=output_names, + deploy_cfg=self.deploy_cfg) + + def forward(self, img: Sequence[torch.Tensor], + img_metas: Sequence[Sequence[dict]], *args, **kwargs) -> list: + """Run forward inference. + + Args: + img (Sequence[torch.Tensor]): A list contains input image(s) + in [N x C x H x W] format. + img_metas (Sequence[Sequence[dict]]): A list of meta info for + image(s). + + Returns: + list: A list contains predictions. + """ + input_img = img[0].contiguous() + img_metas = img_metas[0] + outputs = self.forward_test(input_img, img_metas, *args, **kwargs) + batch_dets, batch_labels = outputs[:2] + batch_size = input_img.shape[0] + rescale = kwargs.get('rescale', False) + + results = [] + + for i in range(batch_size): + dets, labels = batch_dets[i], batch_labels[i] + if rescale: + scale_factor = img_metas[i]['scale_factor'] + + if isinstance(scale_factor, (list, tuple, np.ndarray)): + assert len(scale_factor) == 4 + scale_factor = np.array(scale_factor)[None, :] # [1,4] + scale_factor = torch.from_numpy(scale_factor).to( + device=dets.device) + dets[:, :4] /= scale_factor + dets = dets.cpu().numpy() + labels = labels.cpu().numpy() + dets_results = [ + dets[labels == i, :] for i in range(len(self.CLASSES)) + ] + results.append(dets_results) + + return results + + def forward_test(self, imgs: torch.Tensor, *args, **kwargs) -> \ + List[torch.Tensor]: + """The interface for forward test. + + Args: + imgs (torch.Tensor): Input image(s) in [N x C x H x W] format. + + Returns: + List[torch.Tensor]: A list of predictions of input images. + """ + outputs = self.wrapper({self.input_name: imgs}) + outputs = self.wrapper.output_to_list(outputs) + return outputs + + def show_result(self, + img: np.ndarray, + result: dict, + win_name: str = '', + show: bool = True, + score_thr: float = 0.3, + out_file: str = None): + """Show predictions of segmentation. + Args: + img: (np.ndarray): Input image to draw predictions. + result (dict): A dict of predictions. + win_name (str): The name of visualization window. + show (bool): Whether to show plotted image in windows. Defaults to + `True`. + score_thr: (float): The thresh of score. Defaults to `0.3`. + out_file (str): Output image file to save drawn predictions. + + Returns: + np.ndarray: Drawn image, only if not `show` or `out_file`. + """ + return RotatedBaseDetector.show_result( + self, + img, + result, + score_thr=score_thr, + show=show, + win_name=win_name, + out_file=out_file) + + +def build_rotated_detection_model(model_files: Sequence[str], + model_cfg: Union[str, mmcv.Config], + deploy_cfg: Union[str, mmcv.Config], + device: str, **kwargs): + """Build rotated detection model for different backends. + + Args: + model_files (Sequence[str]): Input model file(s). + model_cfg (str | mmcv.Config): Input model config file or Config + object. + deploy_cfg (str | mmcv.Config): Input deployment config file or + Config object. + device (str): Device to input model. + + Returns: + BaseBackendModel: Rotated detector for a configured backend. + """ + # load cfg if necessary + deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg) + + backend = get_backend(deploy_cfg) + class_names = get_classes_from_config(model_cfg) + model_type = get_codebase_config(deploy_cfg).get('model_type', 'end2end') + + backend_rotated_detector = __BACKEND_MODEL.build( + model_type, + backend=backend, + backend_files=model_files, + class_names=class_names, + device=device, + deploy_cfg=deploy_cfg, + model_cfg=model_cfg, + **kwargs) + + return backend_rotated_detector diff --git a/mmdeploy/codebase/mmrotate/models/__init__.py b/mmdeploy/codebase/mmrotate/models/__init__.py new file mode 100644 index 000000000..6fe59fd52 --- /dev/null +++ b/mmdeploy/codebase/mmrotate/models/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .rotated_anchor_head import rotated_anchor_head__get_bbox +from .single_stage_rotated_detector import \ + single_stage_rotated_detector__simple_test + +__all__ = [ + 'single_stage_rotated_detector__simple_test', + 'rotated_anchor_head__get_bbox' +] diff --git a/mmdeploy/codebase/mmrotate/models/rotated_anchor_head.py b/mmdeploy/codebase/mmrotate/models/rotated_anchor_head.py new file mode 100644 index 000000000..15ffc8e15 --- /dev/null +++ b/mmdeploy/codebase/mmrotate/models/rotated_anchor_head.py @@ -0,0 +1,138 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from mmdeploy.codebase.mmdet import (get_post_processing_params, + pad_with_value_if_necessary) +from mmdeploy.codebase.mmrotate.core.post_processing import \ + multiclass_nms_rotated +from mmdeploy.core import FUNCTION_REWRITER +from mmdeploy.utils import is_dynamic_shape + + +@FUNCTION_REWRITER.register_rewriter( + func_name='mmrotate.models.dense_heads.rotated_anchor_head.' + 'RotatedAnchorHead.get_bboxes') +def rotated_anchor_head__get_bbox(ctx, + self, + cls_scores, + bbox_preds, + img_metas=None, + cfg=None, + rescale=False, + with_nms=True, + **kwargs): + """Rewrite `get_bboxes` of `RotatedAnchorHead` for default backend. + + Rewrite this function to deploy model, transform network output for a + batch into bbox predictions. + + Args: + ctx (ContextCaller): The context with additional information. + self: The instance of the original class. + cls_scores (list[Tensor]): Classification scores for all + scale levels, each is a 4D-tensor, has shape + (batch_size, num_priors * num_classes, H, W). + bbox_preds (list[Tensor]): Box energies / deltas for all + scale levels, each is a 4D-tensor, has shape + , 1., num_priors * 5, H, W). + img_metas (list[dict], Optional): Image meta info. Default None. + cfg (mmcv.Config, Optional): Test / postprocessing configuration, + if None, test_cfg would be used. Default None. + rescale (bool): If True, return boxes in original image space. + Default False. + with_nms (bool): If True, do nms before return boxes. + Default True. + + Returns: + If with_nms == True: + tuple[Tensor, Tensor]: tuple[Tensor, Tensor]: (dets, labels), + `dets` of shape [N, num_det, 5] and `labels` of shape + [N, num_det]. + Else: + tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes, + batch_mlvl_scores, batch_mlvl_centerness + """ + deploy_cfg = ctx.cfg + is_dynamic_flag = is_dynamic_shape(deploy_cfg) + num_levels = len(cls_scores) + + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + mlvl_priors = self.anchor_generator.grid_priors( + featmap_sizes, dtype=bbox_preds[0].dtype, device=bbox_preds[0].device) + + mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)] + mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)] + assert img_metas is not None + img_shape = img_metas[0]['img_shape'] + + assert len(cls_scores) == len(bbox_preds) == len(mlvl_priors) + batch_size = cls_scores[0].shape[0] + cfg = self.test_cfg + pre_topk = cfg.get('nms_pre', -1) + + mlvl_valid_bboxes = [] + mlvl_valid_scores = [] + mlvl_valid_priors = [] + + for cls_score, bbox_pred, priors in zip(mlvl_cls_scores, mlvl_bbox_preds, + mlvl_priors): + + scores = cls_score.permute(0, 2, 3, 1).reshape(batch_size, -1, + self.cls_out_channels) + if self.use_sigmoid_cls: + scores = scores.sigmoid() + else: + scores = scores.softmax(-1) + bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 5) + if not is_dynamic_flag: + priors = priors.data + priors = priors.expand(batch_size, -1, priors.size(-1)) + if pre_topk > 0: + priors = pad_with_value_if_necessary(priors, 1, pre_topk) + bbox_pred = pad_with_value_if_necessary(bbox_pred, 1, pre_topk) + scores = pad_with_value_if_necessary(scores, 1, pre_topk, 0.) + + nms_pre_score = scores + + # Get maximum scores for foreground classes. + if self.use_sigmoid_cls: + max_scores, _ = nms_pre_score.max(-1) + else: + max_scores, _ = nms_pre_score[..., :-1].max(-1) + _, topk_inds = max_scores.topk(pre_topk) + batch_inds = torch.arange( + batch_size, + device=bbox_pred.device).view(-1, 1).expand_as(topk_inds) + priors = priors[batch_inds, topk_inds, :] + bbox_pred = bbox_pred[batch_inds, topk_inds, :] + scores = scores[batch_inds, topk_inds, :] + + mlvl_valid_bboxes.append(bbox_pred) + mlvl_valid_scores.append(scores) + mlvl_valid_priors.append(priors) + + batch_mlvl_bboxes_pred = torch.cat(mlvl_valid_bboxes, dim=1) + batch_scores = torch.cat(mlvl_valid_scores, dim=1) + batch_priors = torch.cat(mlvl_valid_priors, dim=1) + batch_bboxes = self.bbox_coder.decode( + batch_priors, batch_mlvl_bboxes_pred, max_shape=img_shape) + + if not self.use_sigmoid_cls: + batch_scores = batch_scores[..., :self.num_classes] + + if not with_nms: + return batch_bboxes, batch_scores + + post_params = get_post_processing_params(deploy_cfg) + iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold) + score_threshold = cfg.get('score_thr', post_params.score_threshold) + pre_top_k = post_params.pre_top_k + keep_top_k = cfg.get('max_per_img', post_params.keep_top_k) + + return multiclass_nms_rotated( + batch_bboxes, + batch_scores, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pre_top_k=pre_top_k, + keep_top_k=keep_top_k) diff --git a/mmdeploy/codebase/mmrotate/models/single_stage_rotated_detector.py b/mmdeploy/codebase/mmrotate/models/single_stage_rotated_detector.py new file mode 100644 index 000000000..cfeb5aa9a --- /dev/null +++ b/mmdeploy/codebase/mmrotate/models/single_stage_rotated_detector.py @@ -0,0 +1,34 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmdeploy.core import FUNCTION_REWRITER + + +@FUNCTION_REWRITER.register_rewriter( + func_name='mmrotate.models.detectors.RotatedSingleStageDetector' + '.simple_test') +def single_stage_rotated_detector__simple_test(ctx, + self, + img, + img_metas, + rescale=False): + """Rewrite `simple_test` of RotatedSingleStageDetector for default backend. + + Rewrite this function to early return the results to avoid post processing. + The process is not suitable for exporting to backends and better get + implemented in SDK. + + Args: + ctx (ContextCaller): The context with additional information. + self: The instance of the class + SingleStageTextDetector. + img (Tensor): Input images of shape (N, C, H, W). + Typically these should be mean centered and std scaled. + + Returns: + outs (Tensor): A feature map output from bbox_head. The tensor shape + (N, C, H, W). + """ + x = self.extract_feat(img) + outs = self.bbox_head(x) + outs = self.bbox_head.get_bboxes(*outs, img_metas, rescale=rescale) + + return outs diff --git a/mmdeploy/mmcv/ops/nms_rotated.py b/mmdeploy/mmcv/ops/nms_rotated.py index 3ba763cda..3c1142632 100644 --- a/mmdeploy/mmcv/ops/nms_rotated.py +++ b/mmdeploy/mmcv/ops/nms_rotated.py @@ -7,31 +7,53 @@ class ONNXNMSRotatedOp(torch.autograd.Function): """Create onnx::NMSRotated op.""" @staticmethod - def forward(ctx, boxes: Tensor, scores: Tensor, - iou_threshold: float) -> Tensor: + def forward(ctx, boxes: Tensor, scores: Tensor, iou_threshold: float, + score_threshold: float) -> Tensor: """Get NMS rotated output indices. Args: ctx (Context): The context with meta information. - boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4]. + boxes (Tensor): The bounding boxes of shape [N, num_boxes, 5]. scores (Tensor): The detection scores of shape - [N, num_boxes, num_classes]. + [N, num_classes, num_boxes]. iou_threshold (float): IOU threshold of nms. + score_threshold (float): bbox threshold, bboxes with scores + lower than it will not be considered. Returns: Tensor: Selected indices of boxes. """ from mmcv.utils import ext_loader ext_module = ext_loader.load_ext('_ext', ['nms_rotated']) + batch_size, num_class, _ = scores.shape - _, order = scores.sort(0, descending=True) - dets_sorted = boxes.index_select(0, order) - keep_inds = ext_module.nms_rotated(boxes, scores, order, dets_sorted, - iou_threshold, 0) - return keep_inds + indices = [] + for batch_id in range(batch_size): + for cls_id in range(num_class): + _boxes = boxes[batch_id, ...] + # score_threshold=0 requires scores to be contiguous + _scores = scores[batch_id, cls_id, ...].contiguous() + valid_mask = _scores > score_threshold + _boxes, _scores = _boxes[valid_mask], _scores[valid_mask] + valid_inds = torch.nonzero( + valid_mask, as_tuple=False).squeeze(dim=1) + _, order = _scores.sort(0, descending=True) + dets_sorted = _boxes.index_select(0, order) + box_inds = ext_module.nms_rotated(_boxes, _scores, order, + dets_sorted, iou_threshold, + 0) + box_inds = valid_inds[box_inds] + batch_inds = torch.zeros_like(box_inds) + batch_id + cls_inds = torch.zeros_like(box_inds) + cls_id + indices.append( + torch.stack([batch_inds, cls_inds, box_inds], dim=-1)) + + indices = torch.cat(indices) + return indices @staticmethod - def symbolic(g, boxes: Tensor, scores: Tensor, iou_threshold: float): + def symbolic(g, boxes: Tensor, scores: Tensor, iou_threshold: float, + score_threshold: float): """Symbolic function for onnx::NMSRotated. Args: @@ -40,6 +62,8 @@ class ONNXNMSRotatedOp(torch.autograd.Function): scores (Tensor): The detection scores of shape [N, num_boxes, num_classes]. iou_threshold (float): IOU threshold of nms. + score_threshold (float): bbox threshold, bboxes with scores + lower than it will not be considered. Returns: NMSRotated op for onnx. @@ -48,4 +72,5 @@ class ONNXNMSRotatedOp(torch.autograd.Function): 'mmdeploy::NMSRotated', boxes, scores, - iou_threshold_f=float(iou_threshold)) + iou_threshold_f=float(iou_threshold), + score_threshold_f=float(score_threshold)) diff --git a/mmdeploy/utils/constants.py b/mmdeploy/utils/constants.py index bddd09de8..6e43ccbd8 100644 --- a/mmdeploy/utils/constants.py +++ b/mmdeploy/utils/constants.py @@ -26,6 +26,7 @@ class Task(AdvancedEnum): INSTANCE_SEGMENTATION = 'InstanceSegmentation' VOXEL_DETECTION = 'VoxelDetection' POSE_DETECTION = 'PoseDetection' + ROTATED_DETECTION = 'RotatedDetection' class Codebase(AdvancedEnum): @@ -37,6 +38,7 @@ class Codebase(AdvancedEnum): MMEDIT = 'mmedit' MMDET3D = 'mmdet3d' MMPOSE = 'mmpose' + MMROTATE = 'mmrotate' class IR(AdvancedEnum): diff --git a/tests/test_codebase/test_mmrotate/data/dota_sample/P2805__1024__0___0.txt b/tests/test_codebase/test_mmrotate/data/dota_sample/P2805__1024__0___0.txt new file mode 100644 index 000000000..38b83aee2 --- /dev/null +++ b/tests/test_codebase/test_mmrotate/data/dota_sample/P2805__1024__0___0.txt @@ -0,0 +1,4 @@ +359.0 663.0 369.0 497.0 543.0 509.0 531.0 677.0 plane 0 +540.0 884.0 363.0 862.0 392.0 674.0 570.0 695.0 plane 0 +788.0 844.0 734.0 701.0 916.0 631.0 970.0 762.0 plane 0 +720.0 726.0 668.0 583.0 852.0 494.0 913.0 636.0 plane 0 diff --git a/tests/test_codebase/test_mmrotate/data/model.py b/tests/test_codebase/test_mmrotate/data/model.py new file mode 100644 index 000000000..a67ecfa29 --- /dev/null +++ b/tests/test_codebase/test_mmrotate/data/model.py @@ -0,0 +1,96 @@ +# Copyright (c) OpenMMLab. All rights reserved. +model = dict( + type='RotatedRetinaNet', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + zero_init_residual=False, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch', + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs='on_input', + num_outs=5), + bbox_head=dict( + type='RotatedRetinaHead', + num_classes=15, + in_channels=256, + stacked_convs=4, + feat_channels=256, + assign_by_circumhbbox=None, + anchor_generator=dict( + type='RotatedAnchorGenerator', + octave_base_scale=4, + scales_per_octave=3, + ratios=[1.0, 0.5, 2.0], + strides=[8, 16, 32, 64, 128]), + bbox_coder=dict( + type='DeltaXYWHAOBBoxCoder', + angle_range='le135', + norm_factor=1, + edge_swap=False, + proj_xy=True, + target_means=(0.0, 0.0, 0.0, 0.0, 0.0), + target_stds=(1.0, 1.0, 1.0, 1.0, 1.0)), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0)), + train_cfg=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.4, + min_pos_iou=0, + ignore_iof_thr=-1, + iou_calculator=dict(type='RBboxOverlaps2D')), + allowed_border=-1, + pos_weight=-1, + debug=False), + test_cfg=dict( + nms_pre=2000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(iou_thr=0.1), + max_per_img=2000)) +# dataset settings +dataset_type = 'DOTADataset' +data_root = '.' +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(1024, 1024), + flip=False, + transforms=[ + dict(type='RResize'), + dict( + type='Normalize', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + to_rgb=True), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img']) + ]) +] +data = dict( + samples_per_gpu=2, + workers_per_gpu=2, + test=dict( + type=dataset_type, + ann_file='tests/test_codebase/test_mmrotate/data/dota_sample/', + img_prefix=data_root, + pipeline=test_pipeline, + version='le135')) diff --git a/tests/test_codebase/test_mmrotate/data/single_stage_model.json b/tests/test_codebase/test_mmrotate/data/single_stage_model.json new file mode 100644 index 000000000..8727fa1c0 --- /dev/null +++ b/tests/test_codebase/test_mmrotate/data/single_stage_model.json @@ -0,0 +1,115 @@ +{ + "type": "RotatedRetinaNet", + "backbone": { + "type": "ResNet", + "depth": 50, + "num_stages": 4, + "out_indices": [ + 0, + 1, + 2, + 3 + ], + "frozen_stages": 1, + "norm_cfg": { + "type": "BN", + "requires_grad": true + }, + "norm_eval": true, + "style": "pytorch", + "init_cfg": { + "type": "Pretrained", + "checkpoint": "torchvision://resnet50" + } + }, + "neck": { + "type": "FPN", + "in_channels": [ + 256, + 512, + 1024, + 2048 + ], + "out_channels": 256, + "start_level": 1, + "add_extra_convs": "on_input", + "num_outs": 5 + }, + "bbox_head": { + "type": "RotatedRetinaHead", + "num_classes": 15, + "in_channels": 256, + "stacked_convs": 4, + "feat_channels": 256, + "anchor_generator": { + "type": "RotatedAnchorGenerator", + "octave_base_scale": 4, + "scales_per_octave": 3, + "ratios": [ + 0.5, + 1.0, + 2.0 + ], + "strides": [ + 8, + 16, + 32, + 64, + 128 + ] + }, + "bbox_coder": { + "type": "DeltaXYWHAOBBoxCoder", + "target_means": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ], + "target_stds": [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0 + ] + }, + "loss_cls": { + "type": "FocalLoss", + "use_sigmoid": true, + "gamma": 2.0, + "alpha": 0.25, + "loss_weight": 1.0 + }, + "loss_bbox": { + "type": "L1Loss", + "loss_weight": 1.0 + } + }, + "train_cfg": { + "assigner": { + "type": "MaxIoUAssigner", + "pos_iou_thr": 0.5, + "neg_iou_thr": 0.4, + "min_pos_iou": 0, + "ignore_iof_thr": -1, + "iou_calculator": { + "type": "RBboxOverlaps2D" + } + }, + "allowed_border": -1, + "pos_weight": -1, + "debug": false + }, + "test_cfg": { + "nms_pre": 2000, + "min_bbox_size": 0, + "score_thr": 0.05, + "nms": { + "type": "nms", + "iou_threshold": 0.1 + }, + "max_per_img": 2000 + } +} diff --git a/tests/test_codebase/test_mmrotate/test_mmrotate_core.py b/tests/test_codebase/test_mmrotate/test_mmrotate_core.py new file mode 100644 index 000000000..dad51cd1a --- /dev/null +++ b/tests/test_codebase/test_mmrotate/test_mmrotate_core.py @@ -0,0 +1,183 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import mmcv +import numpy as np +import pytest +import torch + +from mmdeploy.codebase import import_codebase +from mmdeploy.utils import Backend, Codebase +from mmdeploy.utils.test import (WrapFunction, backend_checker, check_backend, + get_onnx_model, get_rewrite_outputs) + +try: + import_codebase(Codebase.MMROTATE) +except ImportError: + pytest.skip( + f'{Codebase.MMROTATE} is not installed.', allow_module_level=True) + + +@backend_checker(Backend.ONNXRUNTIME) +def test_multiclass_nms_rotated(): + from mmdeploy.codebase.mmrotate.core import multiclass_nms_rotated + deploy_cfg = mmcv.Config( + dict( + onnx_config=dict(output_names=None, input_shape=None), + backend_config=dict( + type='onnxruntime', + common_config=dict( + fp16_mode=False, max_workspace_size=1 << 20), + model_inputs=[ + dict( + input_shapes=dict( + boxes=dict( + min_shape=[1, 5, 5], + opt_shape=[1, 5, 5], + max_shape=[1, 5, 5]), + scores=dict( + min_shape=[1, 5, 8], + opt_shape=[1, 5, 8], + max_shape=[1, 5, 8]))) + ]), + codebase_config=dict( + type='mmrotate', + task='RotatedDetection', + post_processing=dict( + score_threshold=0.05, + iou_threshold=0.5, + pre_top_k=-1, + keep_top_k=10, + )))) + + boxes = torch.rand(1, 5, 5) + scores = torch.rand(1, 5, 8) + keep_top_k = 10 + wrapped_func = WrapFunction(multiclass_nms_rotated, keep_top_k=keep_top_k) + rewrite_outputs, _ = get_rewrite_outputs( + wrapped_func, + model_inputs={ + 'boxes': boxes, + 'scores': scores + }, + deploy_cfg=deploy_cfg) + + assert rewrite_outputs is not None, 'Got unexpected rewrite '\ + 'outputs: {}'.format(rewrite_outputs) + + +@backend_checker(Backend.ONNXRUNTIME) +@pytest.mark.parametrize('pre_top_k', [-1, 1000]) +def test_multiclass_nms_rotated_with_keep_top_k(pre_top_k): + backend_type = 'onnxruntime' + + from mmdeploy.codebase.mmrotate.core import multiclass_nms_rotated + keep_top_k = 15 + deploy_cfg = mmcv.Config( + dict( + onnx_config=dict( + output_names=None, + input_shape=None, + dynamic_axes=dict( + boxes={ + 0: 'batch_size', + 1: 'num_boxes' + }, + scores={ + 0: 'batch_size', + 1: 'num_boxes', + 2: 'num_classes' + }, + ), + ), + backend_config=dict(type=backend_type), + codebase_config=dict( + type='mmrotate', + task='RotatedDetection', + post_processing=dict( + score_threshold=0.05, + iou_threshold=0.5, + pre_top_k=pre_top_k, + keep_top_k=keep_top_k, + )))) + + num_classes = 5 + num_boxes = 2 + batch_size = 1 + export_boxes = torch.rand(batch_size, num_boxes, 5) + export_scores = torch.ones(batch_size, num_boxes, num_classes) + model_inputs = {'boxes': export_boxes, 'scores': export_scores} + + wrapped_func = WrapFunction(multiclass_nms_rotated, keep_top_k=keep_top_k) + + onnx_model_path = get_onnx_model( + wrapped_func, model_inputs=model_inputs, deploy_cfg=deploy_cfg) + + num_boxes = 100 + test_boxes = torch.rand(batch_size, num_boxes, 5) + test_scores = torch.ones(batch_size, num_boxes, num_classes) + model_inputs = {'boxes': test_boxes, 'scores': test_scores} + + import mmdeploy.backend.onnxruntime as ort_apis + backend_model = ort_apis.ORTWrapper(onnx_model_path, 'cuda:0', None) + output = backend_model.forward(model_inputs) + output = backend_model.output_to_list(output) + dets = output[0] + + # Subtract 1 dim since we pad the tensors + assert dets.shape[1] - 1 < keep_top_k, \ + 'multiclass_nms_rotated returned more values than "keep_top_k"\n' \ + f'dets.shape: {dets.shape}\n' \ + f'keep_top_k: {keep_top_k}' + + +@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME]) +@pytest.mark.parametrize('add_ctr_clamp', [True, False]) +@pytest.mark.parametrize('max_shape,proj_xy,edge_swap', + [(None, False, False), + (torch.tensor([100, 200]), True, True)]) +def test_delta2bbox(backend_type: Backend, add_ctr_clamp: bool, + max_shape: tuple, proj_xy: bool, edge_swap: bool): + check_backend(backend_type) + deploy_cfg = mmcv.Config( + dict( + onnx_config=dict(output_names=None, input_shape=None), + backend_config=dict(type=backend_type.value, model_inputs=None), + codebase_config=dict(type='mmrotate', task='RotatedDetection'))) + + # wrap function to enable rewrite + def delta2bbox(*args, **kwargs): + import mmrotate + return mmrotate.core.bbox.coder.delta_xywha_rbbox_coder.delta2bbox( + *args, **kwargs) + + rois = torch.rand(5, 5) + deltas = torch.rand(5, 5) + original_outputs = delta2bbox( + rois, + deltas, + max_shape=max_shape, + add_ctr_clamp=add_ctr_clamp, + proj_xy=proj_xy, + edge_swap=edge_swap) + + # wrap function to nn.Module, enable torch.onnx.export + wrapped_func = WrapFunction( + delta2bbox, + max_shape=max_shape, + add_ctr_clamp=add_ctr_clamp, + proj_xy=proj_xy, + edge_swap=edge_swap) + rewrite_outputs, is_backend_output = get_rewrite_outputs( + wrapped_func, + model_inputs={ + 'rois': rois.unsqueeze(0), + 'deltas': deltas.unsqueeze(0) + }, + deploy_cfg=deploy_cfg) + + if is_backend_output: + model_output = original_outputs.squeeze().cpu().numpy() + rewrite_output = rewrite_outputs[0].squeeze().cpu().numpy() + assert np.allclose( + model_output, rewrite_output, rtol=1e-03, atol=1e-05) + else: + assert rewrite_outputs is not None diff --git a/tests/test_codebase/test_mmrotate/test_mmrotate_models.py b/tests/test_codebase/test_mmrotate/test_mmrotate_models.py new file mode 100644 index 000000000..501411ce7 --- /dev/null +++ b/tests/test_codebase/test_mmrotate/test_mmrotate_models.py @@ -0,0 +1,203 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import os +import random +from typing import Dict, List + +import mmcv +import numpy as np +import pytest +import torch + +from mmdeploy.codebase import import_codebase +from mmdeploy.utils import Backend, Codebase +from mmdeploy.utils.config_utils import get_ir_config +from mmdeploy.utils.test import (WrapModel, check_backend, get_model_outputs, + get_rewrite_outputs) + +try: + import_codebase(Codebase.MMROTATE) +except ImportError: + pytest.skip( + f'{Codebase.MMROTATE} is not installed.', allow_module_level=True) + + +def seed_everything(seed=1029): + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.enabled = False + + +def convert_to_list(rewrite_output: Dict, output_names: List[str]) -> List: + """Converts output from a dictionary to a list. + + The new list will contain only those output values, whose names are in list + 'output_names'. + """ + outputs = [ + value for name, value in rewrite_output.items() if name in output_names + ] + return outputs + + +def get_anchor_head_model(): + """AnchorHead Config.""" + test_cfg = mmcv.Config( + dict( + nms_pre=2000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(iou_thr=0.1), + max_per_img=2000)) + + from mmrotate.models.dense_heads import RotatedAnchorHead + model = RotatedAnchorHead(num_classes=4, in_channels=1, test_cfg=test_cfg) + model.requires_grad_(False) + + return model + + +def _replace_r50_with_r18(model): + """Replace ResNet50 with ResNet18 in config.""" + model = copy.deepcopy(model) + if model.backbone.type == 'ResNet': + model.backbone.depth = 18 + model.backbone.base_channels = 2 + model.neck.in_channels = [2, 4, 8, 16] + return model + + +# @pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME]) +# @pytest.mark.parametrize('model_cfg_path', [ +# 'tests/test_codebase/test_mmrotate/data/single_stage_model.json' +# ]) +# def test_forward_of_base_detector(model_cfg_path, backend): +# check_backend(backend) +# deploy_cfg = mmcv.Config( +# dict( +# backend_config=dict(type=backend.value), +# onnx_config=dict( +# output_names=['dets', 'labels'], input_shape=None), +# codebase_config=dict( +# type='mmrotate', +# task='RotatedDetection', +# post_processing=dict( +# score_threshold=0.05, +# iou_threshold=0.5, +# pre_top_k=-1, +# keep_top_k=100, +# )))) + +# model_cfg = mmcv.Config(dict(model=mmcv.load(model_cfg_path))) +# model_cfg.model = _replace_r50_with_r18(model_cfg.model) + +# from mmrotate.models import build_detector + +# model_cfg.model.pretrained = None +# model_cfg.model.train_cfg = None +# model = build_detector( +# model_cfg.model, test_cfg= model_cfg.get('test_cfg')) +# model.cfg = model_cfg +# model.to('cpu') + +# img = torch.randn(1, 3, 64, 64) +# rewrite_inputs = {'img': img} +# rewrite_outputs, _ = get_rewrite_outputs( +# wrapped_model=model, +# model_inputs=rewrite_inputs, +# deploy_cfg=deploy_cfg) + +# assert rewrite_outputs is not None + + +def get_deploy_cfg(backend_type: Backend, ir_type: str): + return mmcv.Config( + dict( + backend_config=dict(type=backend_type.value), + onnx_config=dict( + type=ir_type, + output_names=['dets', 'labels'], + input_shape=None), + codebase_config=dict( + type='mmrotate', + task='RotatedDetection', + post_processing=dict( + score_threshold=0.05, + iou_threshold=0.1, + pre_top_k=2000, + keep_top_k=2000, + )))) + + +@pytest.mark.parametrize('backend_type, ir_type', + [(Backend.ONNXRUNTIME, 'onnx')]) +def test_base_dense_head_get_bboxes(backend_type: Backend, ir_type: str): + """Test get_bboxes rewrite of base dense head.""" + check_backend(backend_type) + anchor_head = get_anchor_head_model() + anchor_head.cpu().eval() + s = 128 + img_metas = [{ + 'scale_factor': np.ones(4), + 'pad_shape': (s, s, 3), + 'img_shape': (s, s, 3) + }] + + deploy_cfg = get_deploy_cfg(backend_type, ir_type) + output_names = get_ir_config(deploy_cfg).get('output_names', None) + + # the cls_score's size: (1, 36, 32, 32), (1, 36, 16, 16), + # (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2). + # the bboxes's size: (1, 36, 32, 32), (1, 36, 16, 16), + # (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2) + seed_everything(1234) + cls_score = [ + torch.rand(1, 36, pow(2, i), pow(2, i)) for i in range(5, 0, -1) + ] + seed_everything(5678) + bboxes = [torch.rand(1, 45, pow(2, i), pow(2, i)) for i in range(5, 0, -1)] + + # to get outputs of pytorch model + model_inputs = { + 'cls_scores': cls_score, + 'bbox_preds': bboxes, + 'img_metas': img_metas + } + model_outputs = get_model_outputs(anchor_head, 'get_bboxes', model_inputs) + + # to get outputs of onnx model after rewrite + img_metas[0]['img_shape'] = torch.Tensor([s, s]) + wrapped_model = WrapModel( + anchor_head, 'get_bboxes', img_metas=img_metas, with_nms=True) + rewrite_inputs = { + 'cls_scores': cls_score, + 'bbox_preds': bboxes, + } + rewrite_outputs, is_backend_output = get_rewrite_outputs( + wrapped_model=wrapped_model, + model_inputs=rewrite_inputs, + deploy_cfg=deploy_cfg) + + if is_backend_output: + if isinstance(rewrite_outputs, dict): + rewrite_outputs = convert_to_list(rewrite_outputs, output_names) + for model_output, rewrite_output in zip(model_outputs[0], + rewrite_outputs): + model_output = model_output.squeeze().cpu().numpy() + rewrite_output = rewrite_output.squeeze() + # hard code to make two tensors with the same shape + # rewrite and original codes applied different nms strategy + assert np.allclose( + model_output[:rewrite_output.shape[0]][:2], + rewrite_output[:2], + rtol=1e-03, + atol=1e-05) + else: + assert rewrite_outputs is not None diff --git a/tests/test_codebase/test_mmrotate/test_rotated_detection.py b/tests/test_codebase/test_mmrotate/test_rotated_detection.py new file mode 100644 index 000000000..34b921a73 --- /dev/null +++ b/tests/test_codebase/test_mmrotate/test_rotated_detection.py @@ -0,0 +1,148 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +from tempfile import NamedTemporaryFile, TemporaryDirectory + +import mmcv +import numpy as np +import pytest +import torch +from torch.utils.data import DataLoader +from torch.utils.data.dataset import Dataset + +import mmdeploy.backend.onnxruntime as ort_apis +from mmdeploy.apis import build_task_processor +from mmdeploy.codebase import import_codebase +from mmdeploy.utils import Codebase, load_config +from mmdeploy.utils.test import DummyModel, SwitchBackendWrapper + +try: + import_codebase(Codebase.MMROTATE) +except ImportError: + pytest.skip( + f'{Codebase.MMROTATE} is not installed.', allow_module_level=True) + +model_cfg_path = 'tests/test_codebase/test_mmrotate/data/model.py' +model_cfg = load_config(model_cfg_path)[0] +deploy_cfg = mmcv.Config( + dict( + backend_config=dict(type='onnxruntime'), + codebase_config=dict( + type='mmrotate', + task='RotatedDetection', + post_processing=dict( + score_threshold=0.05, + iou_threshold=0.1, + pre_top_k=2000, + keep_top_k=2000)), + onnx_config=dict( + type='onnx', + export_params=True, + keep_initializers_as_inputs=False, + opset_version=11, + input_shape=None, + input_names=['input'], + output_names=['dets', 'labels']))) +onnx_file = NamedTemporaryFile(suffix='.onnx').name +task_processor = build_task_processor(model_cfg, deploy_cfg, 'cpu') +img_shape = (32, 32) +img = np.random.rand(*img_shape, 3) + + +def test_init_pytorch_model(): + from mmrotate.models import RotatedBaseDetector + model = task_processor.init_pytorch_model(None) + assert isinstance(model, RotatedBaseDetector) + + +@pytest.fixture +def backend_model(): + from mmdeploy.backend.onnxruntime import ORTWrapper + ort_apis.__dict__.update({'ORTWrapper': ORTWrapper}) + wrapper = SwitchBackendWrapper(ORTWrapper) + wrapper.set(outputs={ + 'dets': torch.rand(1, 10, 6), + 'labels': torch.rand(1, 10) + }) + + yield task_processor.init_backend_model(['']) + + wrapper.recover() + + +def test_init_backend_model(backend_model): + from mmdeploy.codebase.mmrotate.deploy.rotated_detection_model import \ + End2EndModel + assert isinstance(backend_model, End2EndModel) + + +@pytest.mark.parametrize('device', ['cpu']) +def test_create_input(device): + original_device = task_processor.device + task_processor.device = device + inputs = task_processor.create_input(img, input_shape=img_shape) + assert len(inputs) == 2 + task_processor.device = original_device + + +def test_run_inference(backend_model): + torch_model = task_processor.init_pytorch_model(None) + input_dict, _ = task_processor.create_input(img, input_shape=img_shape) + torch_results = task_processor.run_inference(torch_model, input_dict) + backend_results = task_processor.run_inference(backend_model, input_dict) + assert torch_results is not None + assert backend_results is not None + assert len(torch_results[0]) == len(backend_results[0]) + + +def test_visualize(backend_model): + input_dict, _ = task_processor.create_input(img, input_shape=img_shape) + results = task_processor.run_inference(backend_model, input_dict) + with TemporaryDirectory() as dir: + filename = dir + 'tmp.jpg' + task_processor.visualize(backend_model, img, results[0], filename, '') + assert os.path.exists(filename) + + +def test_get_partition_cfg(): + with pytest.raises(NotImplementedError): + _ = task_processor.get_partition_cfg(partition_type='') + + +def test_build_dataset_and_dataloader(): + dataset = task_processor.build_dataset( + dataset_cfg=model_cfg, dataset_type='test') + assert isinstance(dataset, Dataset), 'Failed to build dataset' + dataloader = task_processor.build_dataloader(dataset, 1, 1) + assert isinstance(dataloader, DataLoader), 'Failed to build dataloader' + + +def test_single_gpu_test_and_evaluate(): + from mmcv.parallel import MMDataParallel + + class DummyDataset(Dataset): + + def __getitem__(self, index): + return 0 + + def __len__(self): + return 0 + + def evaluate(self, *args, **kwargs): + return 0 + + def format_results(self, *args, **kwargs): + return 0 + + dataset = DummyDataset() + # Prepare dataloader + dataloader = DataLoader(dataset) + + # Prepare dummy model + model = DummyModel(outputs=[torch.rand([1, 10, 6]), torch.rand([1, 10])]) + model = MMDataParallel(model, device_ids=[0]) + # Run test + outputs = task_processor.single_gpu_test(model, dataloader) + assert isinstance(outputs, list) + output_file = NamedTemporaryFile(suffix='.pkl').name + task_processor.evaluate_outputs( + model_cfg, outputs, dataset, 'bbox', out=output_file, format_only=True) diff --git a/tests/test_codebase/test_mmrotate/test_rotated_detection_model.py b/tests/test_codebase/test_mmrotate/test_rotated_detection_model.py new file mode 100644 index 000000000..d13617488 --- /dev/null +++ b/tests/test_codebase/test_mmrotate/test_rotated_detection_model.py @@ -0,0 +1,109 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from tempfile import NamedTemporaryFile + +import mmcv +import numpy as np +import pytest +import torch + +import mmdeploy.backend.onnxruntime as ort_apis +from mmdeploy.codebase import import_codebase +from mmdeploy.utils import Backend, Codebase, load_config +from mmdeploy.utils.test import SwitchBackendWrapper, backend_checker + +try: + import_codebase(Codebase.MMROTATE) +except ImportError: + pytest.skip( + f'{Codebase.MMROTATE} is not installed.', allow_module_level=True) + +IMAGE_SIZE = 32 + + +@backend_checker(Backend.ONNXRUNTIME) +class TestEnd2EndModel: + + @classmethod + def setup_class(cls): + # force add backend wrapper regardless of plugins + from mmdeploy.backend.onnxruntime import ORTWrapper + ort_apis.__dict__.update({'ORTWrapper': ORTWrapper}) + + # simplify backend inference + cls.wrapper = SwitchBackendWrapper(ORTWrapper) + cls.outputs = { + 'dets': torch.rand(1, 10, 6), + 'labels': torch.rand(1, 10) + } + cls.wrapper.set(outputs=cls.outputs) + deploy_cfg = mmcv.Config( + {'onnx_config': { + 'output_names': ['dets', 'labels'] + }}) + model_cfg_path = 'tests/test_codebase/test_mmrotate/data/model.py' + model_cfg = load_config(model_cfg_path)[0] + + from mmdeploy.codebase.mmrotate.deploy.rotated_detection_model import \ + End2EndModel + cls.end2end_model = End2EndModel( + Backend.ONNXRUNTIME, [''], ['' for i in range(15)], + device='cpu', + deploy_cfg=deploy_cfg, + model_cfg=model_cfg) + + @classmethod + def teardown_class(cls): + cls.wrapper.recover() + + @pytest.mark.parametrize( + 'ori_shape', + [[IMAGE_SIZE, IMAGE_SIZE, 3], [2 * IMAGE_SIZE, 2 * IMAGE_SIZE, 3]]) + def test_forward(self, ori_shape): + imgs = [torch.rand(1, 3, IMAGE_SIZE, IMAGE_SIZE)] + img_metas = [[{ + 'ori_shape': ori_shape, + 'img_shape': [IMAGE_SIZE, IMAGE_SIZE, 3], + 'scale_factor': [1., 1., 1., 1.], + 'filename': '' + }]] + results = self.end2end_model.forward(imgs, img_metas) + assert results is not None, 'failed to get output using '\ + 'End2EndModel' + + def test_forward_test(self): + imgs = torch.rand(2, 3, IMAGE_SIZE, IMAGE_SIZE) + results = self.end2end_model.forward_test(imgs) + assert isinstance(results[0], torch.Tensor) + + def test_show_result(self): + input_img = np.zeros([IMAGE_SIZE, IMAGE_SIZE, 3]) + img_path = NamedTemporaryFile(suffix='.jpg').name + + result = torch.rand(1, 10, 6) + self.end2end_model.show_result( + input_img, result, '', show=False, out_file=img_path) + assert osp.exists(img_path) + + +@backend_checker(Backend.ONNXRUNTIME) +def test_build_rotated_detection_model(): + model_cfg_path = 'tests/test_codebase/test_mmrotate/data/model.py' + model_cfg = load_config(model_cfg_path)[0] + deploy_cfg = mmcv.Config( + dict( + backend_config=dict(type='onnxruntime'), + onnx_config=dict(output_names=['dets', 'labels']), + codebase_config=dict(type='mmrotate'))) + + from mmdeploy.backend.onnxruntime import ORTWrapper + ort_apis.__dict__.update({'ORTWrapper': ORTWrapper}) + + # simplify backend inference + with SwitchBackendWrapper(ORTWrapper) as wrapper: + wrapper.set(model_cfg=model_cfg, deploy_cfg=deploy_cfg) + from mmdeploy.codebase.mmrotate.deploy.rotated_detection_model import ( + End2EndModel, build_rotated_detection_model) + segmentor = build_rotated_detection_model([''], model_cfg, deploy_cfg, + 'cpu') + assert isinstance(segmentor, End2EndModel) diff --git a/tests/test_ops/test_ops.py b/tests/test_ops/test_ops.py index 12596180d..bf191eaf0 100644 --- a/tests/test_ops/test_ops.py +++ b/tests/test_ops/test_ops.py @@ -778,18 +778,24 @@ def test_expand(backend, @pytest.mark.parametrize('backend', [TEST_ONNXRT]) @pytest.mark.parametrize('iou_threshold', [0.1, 0.3]) -def test_nms_rotated(backend, iou_threshold, save_dir=None): +@pytest.mark.parametrize('score_threshold', [0., 0.1]) +def test_nms_rotated(backend, iou_threshold, score_threshold, save_dir=None): backend.check_env() boxes = torch.tensor( - [[60, 75, 20, 50, 0], [65, 80, 10, 40, 0], [30, 30, 40, 40, 0]], + [[[60, 75, 20, 50, 0], [65, 80, 10, 40, 0], [30, 30, 40, 40, 0]], + [[60, 75, 20, 50, 0], [65, 80, 10, 40, 0], [30, 30, 40, 40, 0]]], + dtype=torch.float32) + scores = torch.tensor( + [[[0.5, 0.1, 0.1], [0.1, 0.6, 0.1], [0.1, 0.1, 0.7], [0.1, 0.1, 0.1]], + [[0.1, 0.1, 0.1], [0.7, 0.1, 0.1], [0.1, 0.6, 0.1], [0.1, 0.1, 0.5]]], dtype=torch.float32) - scores = torch.tensor([0.5, 0.6, 0.7], dtype=torch.float32) from mmdeploy.mmcv.ops import ONNXNMSRotatedOp def wrapped_function(torch_boxes, torch_scores): - return ONNXNMSRotatedOp.apply(torch_boxes, torch_scores, iou_threshold) + return ONNXNMSRotatedOp.apply(torch_boxes, torch_scores, iou_threshold, + score_threshold) wrapped_model = WrapFunction(wrapped_function).eval()