Fix batch index for cascade roi head (#2078)
parent
5fd0e8957f
commit
389a146212
mmdeploy/codebase/mmdet/models/roi_heads
|
@ -69,7 +69,7 @@ def cascade_roi_head__predict_bbox(self,
|
|||
new_rois = get_box_tensor(new_rois)
|
||||
rois = new_rois.reshape(-1, new_rois.shape[-1])
|
||||
# Add dummy batch index
|
||||
rois = torch.cat([rois.new_zeros(rois.shape[0], 1), rois], dim=-1)
|
||||
rois = torch.cat([batch_index.flatten(0, 1), rois], dim=-1)
|
||||
cls_scores = sum(ms_scores) / float(len(ms_scores))
|
||||
bbox_preds = bbox_pred.reshape(batch_size, num_proposals_per_img, -1)
|
||||
rois = rois.reshape(batch_size, num_proposals_per_img, -1)
|
||||
|
|
Loading…
Reference in New Issue