diff --git a/mmyolo/models/dense_heads/rtmdet_head.py b/mmyolo/models/dense_heads/rtmdet_head.py index b88147c3..a6651bd0 100644 --- a/mmyolo/models/dense_heads/rtmdet_head.py +++ b/mmyolo/models/dense_heads/rtmdet_head.py @@ -260,6 +260,9 @@ class RTMDetHead(YOLOv5Head): else: self.sampler = PseudoSampler(context=self) + self.featmap_sizes_train = None + self.flatten_priors_train = None + def forward(self, x: Tuple[Tensor]) -> Tuple[List]: """Forward features from the upstream network. @@ -312,12 +315,12 @@ class RTMDetHead(YOLOv5Head): device = cls_scores[0].device # If the shape does not equal, generate new one - if featmap_sizes != self.featmap_sizes: - self.featmap_sizes = featmap_sizes - mlvl_priors = self.prior_generator.grid_priors( + if featmap_sizes != self.featmap_sizes_train: + self.featmap_sizes_train = featmap_sizes + mlvl_priors_with_stride = self.prior_generator.grid_priors( featmap_sizes, device=device, with_stride=True) - self.flatten_priors = torch.cat(mlvl_priors, dim=0) - self.mlvl_priors = [mlvl[:, :2] for mlvl in mlvl_priors] + self.flatten_priors_train = torch.cat( + mlvl_priors_with_stride, dim=0) flatten_cls_scores = torch.cat([ cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1, @@ -329,13 +332,14 @@ class RTMDetHead(YOLOv5Head): bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4) for bbox_pred in bbox_preds ], 1) - flatten_bboxes = flatten_bboxes * self.flatten_priors[..., -1, None] - flatten_bboxes = distance2bbox(self.flatten_priors[..., :2], + flatten_bboxes = flatten_bboxes * self.flatten_priors_train[..., -1, + None] + flatten_bboxes = distance2bbox(self.flatten_priors_train[..., :2], flatten_bboxes) assigned_result = self.assigner(flatten_bboxes.detach(), flatten_cls_scores.detach(), - self.flatten_priors, gt_labels, + self.flatten_priors_train, gt_labels, gt_bboxes, pad_bbox_flag) labels = assigned_result['assigned_labels'].reshape(-1) diff --git a/projects/assigner_visualization/assigner_visualization.py b/projects/assigner_visualization/assigner_visualization.py index df489c43..0086985f 100644 --- a/projects/assigner_visualization/assigner_visualization.py +++ b/projects/assigner_visualization/assigner_visualization.py @@ -111,6 +111,7 @@ def main(): # make output dir os.makedirs(args.output_dir, exist_ok=True) + print('Results will save to ', args.output_dir) # init visualization image number assert args.show_number > 0