[BUG] Fix YOLOv6 when get grid priors in `predict_by_feat` same shape will crash (#378)

* [BUG] Fix with_stride

* [BUG] Fix with_stride

* [BUG] Fix with_stride

* [BUG] Fix with_stride

* [BUG] Fix with_stride

* A better to fix

* A better way

* A better way

* A better way

* A better way

* A better way

* A better way
pull/391/head
HinGwenWoong 2022-12-19 14:40:04 +08:00 committed by GitHub
parent d03dc0bfb5
commit e0128c2dbd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 4 additions and 4 deletions

View File

@ -160,7 +160,6 @@ class YOLOv6HeadModule(BaseModule):
return cls_score, bbox_pred
# Training mode is currently not supported
@MODELS.register_module()
class YOLOv6Head(YOLOv5Head):
"""YOLOv6Head head used in `YOLOv6 <https://arxiv.org/pdf/2209.02976>`_.
@ -289,15 +288,16 @@ class YOLOv6Head(YOLOv5Head):
if current_featmap_sizes != self.featmap_sizes:
self.featmap_sizes = current_featmap_sizes
self.mlvl_priors = self.prior_generator.grid_priors(
mlvl_priors = self.prior_generator.grid_priors(
self.featmap_sizes,
dtype=cls_scores[0].dtype,
device=cls_scores[0].device,
with_stride=True)
self.num_level_priors = [len(n) for n in self.mlvl_priors]
self.flatten_priors = torch.cat(self.mlvl_priors, dim=0)
self.num_level_priors = [len(n) for n in mlvl_priors]
self.flatten_priors = torch.cat(mlvl_priors, dim=0)
self.stride_tensor = self.flatten_priors[..., [2]]
self.mlvl_priors = [mlvl[:, :2] for mlvl in mlvl_priors]
# gt info
gt_info = self.gt_instances_preprocess(batch_gt_instances, num_imgs)