Fix potential bugs in rtmdet multi-scale training (#439)

* Distinguish variables when training and testing

* beauty
pull/442/head
Nioolek 2023-01-06 19:03:14 +08:00 committed by GitHub
parent 884330108d
commit 07afc3ee78
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 8 deletions

View File

@ -260,6 +260,9 @@ class RTMDetHead(YOLOv5Head):
else:
self.sampler = PseudoSampler(context=self)
self.featmap_sizes_train = None
self.flatten_priors_train = None
def forward(self, x: Tuple[Tensor]) -> Tuple[List]:
"""Forward features from the upstream network.
@ -312,12 +315,12 @@ class RTMDetHead(YOLOv5Head):
device = cls_scores[0].device
# If the shape does not equal, generate new one
if featmap_sizes != self.featmap_sizes:
self.featmap_sizes = featmap_sizes
mlvl_priors = self.prior_generator.grid_priors(
if featmap_sizes != self.featmap_sizes_train:
self.featmap_sizes_train = featmap_sizes
mlvl_priors_with_stride = self.prior_generator.grid_priors(
featmap_sizes, device=device, with_stride=True)
self.flatten_priors = torch.cat(mlvl_priors, dim=0)
self.mlvl_priors = [mlvl[:, :2] for mlvl in mlvl_priors]
self.flatten_priors_train = torch.cat(
mlvl_priors_with_stride, dim=0)
flatten_cls_scores = torch.cat([
cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
@ -329,13 +332,14 @@ class RTMDetHead(YOLOv5Head):
bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
for bbox_pred in bbox_preds
], 1)
flatten_bboxes = flatten_bboxes * self.flatten_priors[..., -1, None]
flatten_bboxes = distance2bbox(self.flatten_priors[..., :2],
flatten_bboxes = flatten_bboxes * self.flatten_priors_train[..., -1,
None]
flatten_bboxes = distance2bbox(self.flatten_priors_train[..., :2],
flatten_bboxes)
assigned_result = self.assigner(flatten_bboxes.detach(),
flatten_cls_scores.detach(),
self.flatten_priors, gt_labels,
self.flatten_priors_train, gt_labels,
gt_bboxes, pad_bbox_flag)
labels = assigned_result['assigned_labels'].reshape(-1)

View File

@ -111,6 +111,7 @@ def main():
# make output dir
os.makedirs(args.output_dir, exist_ok=True)
print('Results will save to ', args.output_dir)
# init visualization image number
assert args.show_number > 0