diff --git a/mmyolo/models/dense_heads/yolov6_head.py b/mmyolo/models/dense_heads/yolov6_head.py index b2581ef5..56710d39 100644 --- a/mmyolo/models/dense_heads/yolov6_head.py +++ b/mmyolo/models/dense_heads/yolov6_head.py @@ -160,7 +160,6 @@ class YOLOv6HeadModule(BaseModule): return cls_score, bbox_pred -# Training mode is currently not supported @MODELS.register_module() class YOLOv6Head(YOLOv5Head): """YOLOv6Head head used in `YOLOv6 `_. @@ -289,15 +288,16 @@ class YOLOv6Head(YOLOv5Head): if current_featmap_sizes != self.featmap_sizes: self.featmap_sizes = current_featmap_sizes - self.mlvl_priors = self.prior_generator.grid_priors( + mlvl_priors = self.prior_generator.grid_priors( self.featmap_sizes, dtype=cls_scores[0].dtype, device=cls_scores[0].device, with_stride=True) - self.num_level_priors = [len(n) for n in self.mlvl_priors] - self.flatten_priors = torch.cat(self.mlvl_priors, dim=0) + self.num_level_priors = [len(n) for n in mlvl_priors] + self.flatten_priors = torch.cat(mlvl_priors, dim=0) self.stride_tensor = self.flatten_priors[..., [2]] + self.mlvl_priors = [mlvl[:, :2] for mlvl in mlvl_priors] # gt info gt_info = self.gt_instances_preprocess(batch_gt_instances, num_imgs)