From 441d0e27037f76dc4686783ae0806814aa664889 Mon Sep 17 00:00:00 2001 From: AllentDan <41138331+AllentDan@users.noreply.github.com> Date: Wed, 14 Sep 2022 17:48:34 +0800 Subject: [PATCH] fix two stage batch dynamic (#1046) --- mmdeploy/codebase/mmdet/models/roi_heads/test_mixins.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/mmdeploy/codebase/mmdet/models/roi_heads/test_mixins.py b/mmdeploy/codebase/mmdet/models/roi_heads/test_mixins.py index 18dcca633..e21b2acd3 100644 --- a/mmdeploy/codebase/mmdet/models/roi_heads/test_mixins.py +++ b/mmdeploy/codebase/mmdet/models/roi_heads/test_mixins.py @@ -44,7 +44,6 @@ def bbox_test_mixin__simple_test_bboxes(ctx, rois.size(0), rois.size(1), 1) rois = torch.cat([batch_index, rois[..., :4]], dim=-1) batch_size = rois.shape[0] - num_proposals_per_img = rois.shape[1] # Eliminate the batch dimension rois = rois.view(-1, 5) @@ -53,12 +52,10 @@ def bbox_test_mixin__simple_test_bboxes(ctx, bbox_pred = bbox_results['bbox_pred'] # Recover the batch dimension - rois = rois.reshape(batch_size, num_proposals_per_img, rois.size(-1)) - cls_score = cls_score.reshape(batch_size, num_proposals_per_img, - cls_score.size(-1)) + rois = rois.reshape(batch_size, -1, rois.size(-1)) + cls_score = cls_score.reshape(batch_size, -1, cls_score.size(-1)) - bbox_pred = bbox_pred.reshape(batch_size, num_proposals_per_img, - bbox_pred.size(-1)) + bbox_pred = bbox_pred.reshape(batch_size, -1, bbox_pred.size(-1)) det_bboxes, det_labels = self.bbox_head.get_bboxes( rois, cls_score,