[Fix] Fix CascadeRoIHead export when reg_class_agnostic=True in box_head ()

* fix convnext

* fix batch inference

* update docs

* add regression test config

* fix pose_tracker.cpp lint
pull/1944/head
Chen Xin 2023-03-28 20:59:26 +08:00 committed by GitHub
parent d181311dee
commit c7003bb76a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 31 additions and 22 deletions
csrc/mmdeploy/apis/python
docs/en/03-benchmark
mmdeploy/codebase/mmdet/models/roi_heads
tests/regression

View File

@ -30,7 +30,8 @@ std::vector<py::tuple> Apply(mmdeploy::PoseTracker* self,
std::vector<py::tuple> batch_ret;
batch_ret.reserve(frames.size());
for (const auto& rs : results) {
py::array_t<float> keypoints({static_cast<int>(rs.size()), rs.size() > 0 ? rs[0].keypoint_count : 0, 3});
py::array_t<float> keypoints(
{static_cast<int>(rs.size()), rs.size() > 0 ? rs[0].keypoint_count : 0, 3});
py::array_t<float> bboxes({static_cast<int>(rs.size()), 4});
py::array_t<uint32_t> track_ids(static_cast<int>(rs.size()));
auto kpts_ptr = keypoints.mutable_data();

View File

@ -17,6 +17,7 @@ The table below lists the models that are guaranteed to be exportable to other b
| GFL | MMDetection | N | Y | Y | N | ? | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/gfl) |
| Cascade R-CNN | MMDetection | N | Y | Y | N | Y | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) |
| Cascade Mask R-CNN | MMDetection | N | Y | Y | N | N | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) |
| ConvNeXt | MMDetection | N | Y | Y | N | N | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/convnext) |
| Swin Transformer[\*](#note) | MMDetection | N | Y | Y | N | N | N | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/swin) |
| VFNet | MMDetection | N | N | N | N | N | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/vfnet) |
| RepPoints | MMDetection | N | N | Y | N | ? | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/reppoints) |

View File

@ -42,15 +42,18 @@ def cascade_roi_head__simple_test(ctx, self, x, proposals, img_metas,
'while in exporting to ONNX'
# Remove the scores
rois = proposals[..., :-1]
batch_size = rois.shape[0]
num_proposals_per_img = rois.shape[1]
batch_size = rois.shape[0]
# Eliminate the batch dimension
rois = rois.view(-1, 4)
inds = torch.arange(
batch_size, device=rois.device).float().repeat(num_proposals_per_img,
1)
inds = inds.t().reshape(-1, 1)
rois = torch.cat([inds, rois], dim=1)
# Add dummy batch index
rois = torch.cat([rois.new_zeros(rois.shape[0], 1), rois], dim=-1)
max_shape = img_metas[0]['img_shape']
max_shape = None
scale_factor = None
ms_scores = []
rcnn_test_cfg = self.test_cfg
@ -59,24 +62,19 @@ def cascade_roi_head__simple_test(ctx, self, x, proposals, img_metas,
cls_score = bbox_results['cls_score']
bbox_pred = bbox_results['bbox_pred']
# Recover the batch dimension
rois = rois.reshape(batch_size, num_proposals_per_img, rois.size(-1))
cls_score = cls_score.reshape(batch_size, num_proposals_per_img,
cls_score.size(-1))
bbox_pred = bbox_pred.reshape(batch_size, num_proposals_per_img, 4)
ms_scores.append(cls_score)
if i < self.num_stages - 1:
assert self.bbox_head[i].reg_class_agnostic
new_rois = self.bbox_head[i].bbox_coder.decode(
rois[..., 1:], bbox_pred, max_shape=max_shape)
rois = new_rois.reshape(-1, new_rois.shape[-1])
# Add dummy batch index
rois = torch.cat([rois.new_zeros(rois.shape[0], 1), rois], dim=-1)
assert not self.bbox_head[i].custom_activation
bbox_label = cls_score[:, :-1].argmax(dim=1)
rois = self.bbox_head[i].regress_by_class(rois, bbox_label,
bbox_pred, img_metas[0])
cls_score = sum(ms_scores) / float(len(ms_scores))
bbox_pred = bbox_pred.reshape(batch_size, num_proposals_per_img, 4)
rois = rois.reshape(batch_size, num_proposals_per_img, -1)
scale_factor = img_metas[0].get('scale_factor', None)
cls_score = cls_score.reshape(batch_size, -1, cls_score.size(-1))
rois = rois.reshape(batch_size, -1, rois.size(-1))
bbox_pred = bbox_pred.reshape(batch_size, -1, bbox_pred.size(-1))
det_bboxes, det_labels = self.bbox_head[-1].get_bboxes(
rois, cls_score, bbox_pred, max_shape, scale_factor, cfg=rcnn_test_cfg)
@ -85,8 +83,8 @@ def cascade_roi_head__simple_test(ctx, self, x, proposals, img_metas,
else:
batch_index = torch.arange(det_bboxes.size(0),
device=det_bboxes.device). \
float().view(-1, 1, 1).expand(
det_bboxes.size(0), det_bboxes.size(1), 1)
float().view(-1, 1, 1).expand(
det_bboxes.size(0), det_bboxes.size(1), 1)
rois = det_bboxes[..., :4]
mask_rois = torch.cat([batch_index, rois], dim=-1)
mask_rois = mask_rois.view(-1, 5)

View File

@ -320,3 +320,12 @@ models:
pipelines:
- *pipeline_seg_ort_dynamic_fp32
- *pipeline_seg_trt_dynamic_fp32
- name: Convnext
metafile: configs/convnext/metafile.yml
model_configs:
- configs/convnext/mask_rcnn_convnext-t_p4_w7_fpn_fp16_ms-crop_3x_coco.py
pipelines:
- *pipeline_seg_ort_dynamic_fp32
- *pipeline_seg_trt_dynamic_fp32
- *pipeline_seg_openvino_dynamic_fp32