fix two stage batch dynamic (#1046)
parent
615668ec63
commit
441d0e2703
|
@ -44,7 +44,6 @@ def bbox_test_mixin__simple_test_bboxes(ctx,
|
|||
rois.size(0), rois.size(1), 1)
|
||||
rois = torch.cat([batch_index, rois[..., :4]], dim=-1)
|
||||
batch_size = rois.shape[0]
|
||||
num_proposals_per_img = rois.shape[1]
|
||||
|
||||
# Eliminate the batch dimension
|
||||
rois = rois.view(-1, 5)
|
||||
|
@ -53,12 +52,10 @@ def bbox_test_mixin__simple_test_bboxes(ctx,
|
|||
bbox_pred = bbox_results['bbox_pred']
|
||||
|
||||
# Recover the batch dimension
|
||||
rois = rois.reshape(batch_size, num_proposals_per_img, rois.size(-1))
|
||||
cls_score = cls_score.reshape(batch_size, num_proposals_per_img,
|
||||
cls_score.size(-1))
|
||||
rois = rois.reshape(batch_size, -1, rois.size(-1))
|
||||
cls_score = cls_score.reshape(batch_size, -1, cls_score.size(-1))
|
||||
|
||||
bbox_pred = bbox_pred.reshape(batch_size, num_proposals_per_img,
|
||||
bbox_pred.size(-1))
|
||||
bbox_pred = bbox_pred.reshape(batch_size, -1, bbox_pred.size(-1))
|
||||
det_bboxes, det_labels = self.bbox_head.get_bboxes(
|
||||
rois,
|
||||
cls_score,
|
||||
|
|
Loading…
Reference in New Issue