mirror of https://github.com/open-mmlab/mmyolo.git
[BUG] Fix YOLOv6 when get grid priors in `predict_by_feat` same shape will crash (#378)
* [BUG] Fix with_stride * [BUG] Fix with_stride * [BUG] Fix with_stride * [BUG] Fix with_stride * [BUG] Fix with_stride * A better to fix * A better way * A better way * A better way * A better way * A better way * A better waypull/391/head
parent
d03dc0bfb5
commit
e0128c2dbd
|
@ -160,7 +160,6 @@ class YOLOv6HeadModule(BaseModule):
|
|||
return cls_score, bbox_pred
|
||||
|
||||
|
||||
# Training mode is currently not supported
|
||||
@MODELS.register_module()
|
||||
class YOLOv6Head(YOLOv5Head):
|
||||
"""YOLOv6Head head used in `YOLOv6 <https://arxiv.org/pdf/2209.02976>`_.
|
||||
|
@ -289,15 +288,16 @@ class YOLOv6Head(YOLOv5Head):
|
|||
if current_featmap_sizes != self.featmap_sizes:
|
||||
self.featmap_sizes = current_featmap_sizes
|
||||
|
||||
self.mlvl_priors = self.prior_generator.grid_priors(
|
||||
mlvl_priors = self.prior_generator.grid_priors(
|
||||
self.featmap_sizes,
|
||||
dtype=cls_scores[0].dtype,
|
||||
device=cls_scores[0].device,
|
||||
with_stride=True)
|
||||
|
||||
self.num_level_priors = [len(n) for n in self.mlvl_priors]
|
||||
self.flatten_priors = torch.cat(self.mlvl_priors, dim=0)
|
||||
self.num_level_priors = [len(n) for n in mlvl_priors]
|
||||
self.flatten_priors = torch.cat(mlvl_priors, dim=0)
|
||||
self.stride_tensor = self.flatten_priors[..., [2]]
|
||||
self.mlvl_priors = [mlvl[:, :2] for mlvl in mlvl_priors]
|
||||
|
||||
# gt info
|
||||
gt_info = self.gt_instances_preprocess(batch_gt_instances, num_imgs)
|
||||
|
|
Loading…
Reference in New Issue