mirror of https://github.com/open-mmlab/mmyolo.git
Fix potential bugs in rtmdet multi-scale training (#439)
* Distinguish variables when training and testing * beautypull/442/head
parent
884330108d
commit
07afc3ee78
|
@ -260,6 +260,9 @@ class RTMDetHead(YOLOv5Head):
|
|||
else:
|
||||
self.sampler = PseudoSampler(context=self)
|
||||
|
||||
self.featmap_sizes_train = None
|
||||
self.flatten_priors_train = None
|
||||
|
||||
def forward(self, x: Tuple[Tensor]) -> Tuple[List]:
|
||||
"""Forward features from the upstream network.
|
||||
|
||||
|
@ -312,12 +315,12 @@ class RTMDetHead(YOLOv5Head):
|
|||
device = cls_scores[0].device
|
||||
|
||||
# If the shape does not equal, generate new one
|
||||
if featmap_sizes != self.featmap_sizes:
|
||||
self.featmap_sizes = featmap_sizes
|
||||
mlvl_priors = self.prior_generator.grid_priors(
|
||||
if featmap_sizes != self.featmap_sizes_train:
|
||||
self.featmap_sizes_train = featmap_sizes
|
||||
mlvl_priors_with_stride = self.prior_generator.grid_priors(
|
||||
featmap_sizes, device=device, with_stride=True)
|
||||
self.flatten_priors = torch.cat(mlvl_priors, dim=0)
|
||||
self.mlvl_priors = [mlvl[:, :2] for mlvl in mlvl_priors]
|
||||
self.flatten_priors_train = torch.cat(
|
||||
mlvl_priors_with_stride, dim=0)
|
||||
|
||||
flatten_cls_scores = torch.cat([
|
||||
cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
|
||||
|
@ -329,13 +332,14 @@ class RTMDetHead(YOLOv5Head):
|
|||
bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
|
||||
for bbox_pred in bbox_preds
|
||||
], 1)
|
||||
flatten_bboxes = flatten_bboxes * self.flatten_priors[..., -1, None]
|
||||
flatten_bboxes = distance2bbox(self.flatten_priors[..., :2],
|
||||
flatten_bboxes = flatten_bboxes * self.flatten_priors_train[..., -1,
|
||||
None]
|
||||
flatten_bboxes = distance2bbox(self.flatten_priors_train[..., :2],
|
||||
flatten_bboxes)
|
||||
|
||||
assigned_result = self.assigner(flatten_bboxes.detach(),
|
||||
flatten_cls_scores.detach(),
|
||||
self.flatten_priors, gt_labels,
|
||||
self.flatten_priors_train, gt_labels,
|
||||
gt_bboxes, pad_bbox_flag)
|
||||
|
||||
labels = assigned_result['assigned_labels'].reshape(-1)
|
||||
|
|
|
@ -111,6 +111,7 @@ def main():
|
|||
|
||||
# make output dir
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
print('Results will save to ', args.output_dir)
|
||||
|
||||
# init visualization image number
|
||||
assert args.show_number > 0
|
||||
|
|
Loading…
Reference in New Issue