diff --git a/utils/loss.py b/utils/loss.py index 4e08782..3138632 100644 --- a/utils/loss.py +++ b/utils/loss.py @@ -712,7 +712,7 @@ class ComputeLossOTA: pair_wise_iou_loss = -torch.log(pair_wise_iou + 1e-8) - top_k, _ = torch.topk(pair_wise_iou, min(20, pair_wise_iou.shape[1]), dim=1) + top_k, _ = torch.topk(pair_wise_iou, min(10, pair_wise_iou.shape[1]), dim=1) dynamic_ks = torch.clamp(top_k.sum(1).int(), min=1) gt_cls_per_image = (