[Fix] Fix CascadeRoIHead export when reg_class_agnostic=True in box_head (#1900)
* fix convnext * fix batch inference * update docs * add regression test config * fix pose_tracker.cpp lintpull/1944/head
parent
d181311dee
commit
c7003bb76a
csrc/mmdeploy/apis/python
docs/en/03-benchmark
mmdeploy/codebase/mmdet/models/roi_heads
tests/regression
|
@ -30,7 +30,8 @@ std::vector<py::tuple> Apply(mmdeploy::PoseTracker* self,
|
|||
std::vector<py::tuple> batch_ret;
|
||||
batch_ret.reserve(frames.size());
|
||||
for (const auto& rs : results) {
|
||||
py::array_t<float> keypoints({static_cast<int>(rs.size()), rs.size() > 0 ? rs[0].keypoint_count : 0, 3});
|
||||
py::array_t<float> keypoints(
|
||||
{static_cast<int>(rs.size()), rs.size() > 0 ? rs[0].keypoint_count : 0, 3});
|
||||
py::array_t<float> bboxes({static_cast<int>(rs.size()), 4});
|
||||
py::array_t<uint32_t> track_ids(static_cast<int>(rs.size()));
|
||||
auto kpts_ptr = keypoints.mutable_data();
|
||||
|
|
|
@ -17,6 +17,7 @@ The table below lists the models that are guaranteed to be exportable to other b
|
|||
| GFL | MMDetection | N | Y | Y | N | ? | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/gfl) |
|
||||
| Cascade R-CNN | MMDetection | N | Y | Y | N | Y | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) |
|
||||
| Cascade Mask R-CNN | MMDetection | N | Y | Y | N | N | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) |
|
||||
| ConvNeXt | MMDetection | N | Y | Y | N | N | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/convnext) |
|
||||
| Swin Transformer[\*](#note) | MMDetection | N | Y | Y | N | N | N | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/swin) |
|
||||
| VFNet | MMDetection | N | N | N | N | N | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/vfnet) |
|
||||
| RepPoints | MMDetection | N | N | Y | N | ? | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/reppoints) |
|
||||
|
|
|
@ -42,15 +42,18 @@ def cascade_roi_head__simple_test(ctx, self, x, proposals, img_metas,
|
|||
'while in exporting to ONNX'
|
||||
# Remove the scores
|
||||
rois = proposals[..., :-1]
|
||||
batch_size = rois.shape[0]
|
||||
num_proposals_per_img = rois.shape[1]
|
||||
batch_size = rois.shape[0]
|
||||
# Eliminate the batch dimension
|
||||
rois = rois.view(-1, 4)
|
||||
inds = torch.arange(
|
||||
batch_size, device=rois.device).float().repeat(num_proposals_per_img,
|
||||
1)
|
||||
inds = inds.t().reshape(-1, 1)
|
||||
rois = torch.cat([inds, rois], dim=1)
|
||||
|
||||
# Add dummy batch index
|
||||
rois = torch.cat([rois.new_zeros(rois.shape[0], 1), rois], dim=-1)
|
||||
|
||||
max_shape = img_metas[0]['img_shape']
|
||||
max_shape = None
|
||||
scale_factor = None
|
||||
ms_scores = []
|
||||
rcnn_test_cfg = self.test_cfg
|
||||
|
||||
|
@ -59,24 +62,19 @@ def cascade_roi_head__simple_test(ctx, self, x, proposals, img_metas,
|
|||
|
||||
cls_score = bbox_results['cls_score']
|
||||
bbox_pred = bbox_results['bbox_pred']
|
||||
# Recover the batch dimension
|
||||
rois = rois.reshape(batch_size, num_proposals_per_img, rois.size(-1))
|
||||
cls_score = cls_score.reshape(batch_size, num_proposals_per_img,
|
||||
cls_score.size(-1))
|
||||
bbox_pred = bbox_pred.reshape(batch_size, num_proposals_per_img, 4)
|
||||
|
||||
ms_scores.append(cls_score)
|
||||
if i < self.num_stages - 1:
|
||||
assert self.bbox_head[i].reg_class_agnostic
|
||||
new_rois = self.bbox_head[i].bbox_coder.decode(
|
||||
rois[..., 1:], bbox_pred, max_shape=max_shape)
|
||||
rois = new_rois.reshape(-1, new_rois.shape[-1])
|
||||
# Add dummy batch index
|
||||
rois = torch.cat([rois.new_zeros(rois.shape[0], 1), rois], dim=-1)
|
||||
assert not self.bbox_head[i].custom_activation
|
||||
bbox_label = cls_score[:, :-1].argmax(dim=1)
|
||||
rois = self.bbox_head[i].regress_by_class(rois, bbox_label,
|
||||
bbox_pred, img_metas[0])
|
||||
|
||||
cls_score = sum(ms_scores) / float(len(ms_scores))
|
||||
bbox_pred = bbox_pred.reshape(batch_size, num_proposals_per_img, 4)
|
||||
rois = rois.reshape(batch_size, num_proposals_per_img, -1)
|
||||
scale_factor = img_metas[0].get('scale_factor', None)
|
||||
cls_score = cls_score.reshape(batch_size, -1, cls_score.size(-1))
|
||||
rois = rois.reshape(batch_size, -1, rois.size(-1))
|
||||
bbox_pred = bbox_pred.reshape(batch_size, -1, bbox_pred.size(-1))
|
||||
|
||||
det_bboxes, det_labels = self.bbox_head[-1].get_bboxes(
|
||||
rois, cls_score, bbox_pred, max_shape, scale_factor, cfg=rcnn_test_cfg)
|
||||
|
||||
|
@ -85,8 +83,8 @@ def cascade_roi_head__simple_test(ctx, self, x, proposals, img_metas,
|
|||
else:
|
||||
batch_index = torch.arange(det_bboxes.size(0),
|
||||
device=det_bboxes.device). \
|
||||
float().view(-1, 1, 1).expand(
|
||||
det_bboxes.size(0), det_bboxes.size(1), 1)
|
||||
float().view(-1, 1, 1).expand(
|
||||
det_bboxes.size(0), det_bboxes.size(1), 1)
|
||||
rois = det_bboxes[..., :4]
|
||||
mask_rois = torch.cat([batch_index, rois], dim=-1)
|
||||
mask_rois = mask_rois.view(-1, 5)
|
||||
|
|
|
@ -320,3 +320,12 @@ models:
|
|||
pipelines:
|
||||
- *pipeline_seg_ort_dynamic_fp32
|
||||
- *pipeline_seg_trt_dynamic_fp32
|
||||
|
||||
- name: Convnext
|
||||
metafile: configs/convnext/metafile.yml
|
||||
model_configs:
|
||||
- configs/convnext/mask_rcnn_convnext-t_p4_w7_fpn_fp16_ms-crop_3x_coco.py
|
||||
pipelines:
|
||||
- *pipeline_seg_ort_dynamic_fp32
|
||||
- *pipeline_seg_trt_dynamic_fp32
|
||||
- *pipeline_seg_openvino_dynamic_fp32
|
||||
|
|
Loading…
Reference in New Issue