fix two stage batch dynamic (#1046)

pull/1050/head
AllentDan 2022-09-14 17:48:34 +08:00 committed by GitHub
parent 615668ec63
commit 441d0e2703
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 6 deletions

View File

@ -44,7 +44,6 @@ def bbox_test_mixin__simple_test_bboxes(ctx,
rois.size(0), rois.size(1), 1) rois.size(0), rois.size(1), 1)
rois = torch.cat([batch_index, rois[..., :4]], dim=-1) rois = torch.cat([batch_index, rois[..., :4]], dim=-1)
batch_size = rois.shape[0] batch_size = rois.shape[0]
num_proposals_per_img = rois.shape[1]
# Eliminate the batch dimension # Eliminate the batch dimension
rois = rois.view(-1, 5) rois = rois.view(-1, 5)
@ -53,12 +52,10 @@ def bbox_test_mixin__simple_test_bboxes(ctx,
bbox_pred = bbox_results['bbox_pred'] bbox_pred = bbox_results['bbox_pred']
# Recover the batch dimension # Recover the batch dimension
rois = rois.reshape(batch_size, num_proposals_per_img, rois.size(-1)) rois = rois.reshape(batch_size, -1, rois.size(-1))
cls_score = cls_score.reshape(batch_size, num_proposals_per_img, cls_score = cls_score.reshape(batch_size, -1, cls_score.size(-1))
cls_score.size(-1))
bbox_pred = bbox_pred.reshape(batch_size, num_proposals_per_img, bbox_pred = bbox_pred.reshape(batch_size, -1, bbox_pred.size(-1))
bbox_pred.size(-1))
det_bboxes, det_labels = self.bbox_head.get_bboxes( det_bboxes, det_labels = self.bbox_head.get_bboxes(
rois, rois,
cls_score, cls_score,