From c7003bb76ad2d848aaa9b9bb41d712536987a7df Mon Sep 17 00:00:00 2001 From: Chen Xin Date: Tue, 28 Mar 2023 20:59:26 +0800 Subject: [PATCH] [Fix] Fix CascadeRoIHead export when reg_class_agnostic=True in box_head (#1900) * fix convnext * fix batch inference * update docs * add regression test config * fix pose_tracker.cpp lint --- csrc/mmdeploy/apis/python/pose_tracker.cpp | 3 +- docs/en/03-benchmark/supported_models.md | 1 + .../models/roi_heads/cascade_roi_head.py | 40 +++++++++---------- tests/regression/mmdet.yml | 9 +++++ 4 files changed, 31 insertions(+), 22 deletions(-) diff --git a/csrc/mmdeploy/apis/python/pose_tracker.cpp b/csrc/mmdeploy/apis/python/pose_tracker.cpp index 16c79bbfa..035ce3cdd 100644 --- a/csrc/mmdeploy/apis/python/pose_tracker.cpp +++ b/csrc/mmdeploy/apis/python/pose_tracker.cpp @@ -30,7 +30,8 @@ std::vector Apply(mmdeploy::PoseTracker* self, std::vector batch_ret; batch_ret.reserve(frames.size()); for (const auto& rs : results) { - py::array_t keypoints({static_cast(rs.size()), rs.size() > 0 ? rs[0].keypoint_count : 0, 3}); + py::array_t keypoints( + {static_cast(rs.size()), rs.size() > 0 ? rs[0].keypoint_count : 0, 3}); py::array_t bboxes({static_cast(rs.size()), 4}); py::array_t track_ids(static_cast(rs.size())); auto kpts_ptr = keypoints.mutable_data(); diff --git a/docs/en/03-benchmark/supported_models.md b/docs/en/03-benchmark/supported_models.md index c955be8a1..1283e033a 100644 --- a/docs/en/03-benchmark/supported_models.md +++ b/docs/en/03-benchmark/supported_models.md @@ -17,6 +17,7 @@ The table below lists the models that are guaranteed to be exportable to other b | GFL | MMDetection | N | Y | Y | N | ? | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/gfl) | | Cascade R-CNN | MMDetection | N | Y | Y | N | Y | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) | | Cascade Mask R-CNN | MMDetection | N | Y | Y | N | N | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) | +| ConvNeXt | MMDetection | N | Y | Y | N | N | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/convnext) | | Swin Transformer[\*](#note) | MMDetection | N | Y | Y | N | N | N | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/swin) | | VFNet | MMDetection | N | N | N | N | N | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/vfnet) | | RepPoints | MMDetection | N | N | Y | N | ? | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/reppoints) | diff --git a/mmdeploy/codebase/mmdet/models/roi_heads/cascade_roi_head.py b/mmdeploy/codebase/mmdet/models/roi_heads/cascade_roi_head.py index 179643abb..f6ede1d0d 100644 --- a/mmdeploy/codebase/mmdet/models/roi_heads/cascade_roi_head.py +++ b/mmdeploy/codebase/mmdet/models/roi_heads/cascade_roi_head.py @@ -42,15 +42,18 @@ def cascade_roi_head__simple_test(ctx, self, x, proposals, img_metas, 'while in exporting to ONNX' # Remove the scores rois = proposals[..., :-1] - batch_size = rois.shape[0] num_proposals_per_img = rois.shape[1] + batch_size = rois.shape[0] # Eliminate the batch dimension rois = rois.view(-1, 4) + inds = torch.arange( + batch_size, device=rois.device).float().repeat(num_proposals_per_img, + 1) + inds = inds.t().reshape(-1, 1) + rois = torch.cat([inds, rois], dim=1) - # Add dummy batch index - rois = torch.cat([rois.new_zeros(rois.shape[0], 1), rois], dim=-1) - - max_shape = img_metas[0]['img_shape'] + max_shape = None + scale_factor = None ms_scores = [] rcnn_test_cfg = self.test_cfg @@ -59,24 +62,19 @@ def cascade_roi_head__simple_test(ctx, self, x, proposals, img_metas, cls_score = bbox_results['cls_score'] bbox_pred = bbox_results['bbox_pred'] - # Recover the batch dimension - rois = rois.reshape(batch_size, num_proposals_per_img, rois.size(-1)) - cls_score = cls_score.reshape(batch_size, num_proposals_per_img, - cls_score.size(-1)) - bbox_pred = bbox_pred.reshape(batch_size, num_proposals_per_img, 4) + ms_scores.append(cls_score) if i < self.num_stages - 1: - assert self.bbox_head[i].reg_class_agnostic - new_rois = self.bbox_head[i].bbox_coder.decode( - rois[..., 1:], bbox_pred, max_shape=max_shape) - rois = new_rois.reshape(-1, new_rois.shape[-1]) - # Add dummy batch index - rois = torch.cat([rois.new_zeros(rois.shape[0], 1), rois], dim=-1) + assert not self.bbox_head[i].custom_activation + bbox_label = cls_score[:, :-1].argmax(dim=1) + rois = self.bbox_head[i].regress_by_class(rois, bbox_label, + bbox_pred, img_metas[0]) cls_score = sum(ms_scores) / float(len(ms_scores)) - bbox_pred = bbox_pred.reshape(batch_size, num_proposals_per_img, 4) - rois = rois.reshape(batch_size, num_proposals_per_img, -1) - scale_factor = img_metas[0].get('scale_factor', None) + cls_score = cls_score.reshape(batch_size, -1, cls_score.size(-1)) + rois = rois.reshape(batch_size, -1, rois.size(-1)) + bbox_pred = bbox_pred.reshape(batch_size, -1, bbox_pred.size(-1)) + det_bboxes, det_labels = self.bbox_head[-1].get_bboxes( rois, cls_score, bbox_pred, max_shape, scale_factor, cfg=rcnn_test_cfg) @@ -85,8 +83,8 @@ def cascade_roi_head__simple_test(ctx, self, x, proposals, img_metas, else: batch_index = torch.arange(det_bboxes.size(0), device=det_bboxes.device). \ - float().view(-1, 1, 1).expand( - det_bboxes.size(0), det_bboxes.size(1), 1) + float().view(-1, 1, 1).expand( + det_bboxes.size(0), det_bboxes.size(1), 1) rois = det_bboxes[..., :4] mask_rois = torch.cat([batch_index, rois], dim=-1) mask_rois = mask_rois.view(-1, 5) diff --git a/tests/regression/mmdet.yml b/tests/regression/mmdet.yml index 2dc8408cc..bf3c7597b 100644 --- a/tests/regression/mmdet.yml +++ b/tests/regression/mmdet.yml @@ -320,3 +320,12 @@ models: pipelines: - *pipeline_seg_ort_dynamic_fp32 - *pipeline_seg_trt_dynamic_fp32 + + - name: Convnext + metafile: configs/convnext/metafile.yml + model_configs: + - configs/convnext/mask_rcnn_convnext-t_p4_w7_fpn_fp16_ms-crop_3x_coco.py + pipelines: + - *pipeline_seg_ort_dynamic_fp32 + - *pipeline_seg_trt_dynamic_fp32 + - *pipeline_seg_openvino_dynamic_fp32