Fix batch index for cascade roi head ()

pull/2064/head^2
q.yao 2023-05-15 17:14:06 +08:00 committed by GitHub
parent 5fd0e8957f
commit 389a146212
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 1 deletions
mmdeploy/codebase/mmdet/models/roi_heads

View File

@ -69,7 +69,7 @@ def cascade_roi_head__predict_bbox(self,
new_rois = get_box_tensor(new_rois)
rois = new_rois.reshape(-1, new_rois.shape[-1])
# Add dummy batch index
rois = torch.cat([rois.new_zeros(rois.shape[0], 1), rois], dim=-1)
rois = torch.cat([batch_index.flatten(0, 1), rois], dim=-1)
cls_scores = sum(ms_scores) / float(len(ms_scores))
bbox_preds = bbox_pred.reshape(batch_size, num_proposals_per_img, -1)
rois = rois.reshape(batch_size, num_proposals_per_img, -1)