From c587e8f467e90f009d71540779be3772231c4ed2 Mon Sep 17 00:00:00 2001 From: "Kin-Yiu, Wong" <102582011@cc.ncu.edu.tw> Date: Sun, 10 Jul 2022 10:00:19 +0800 Subject: [PATCH] main code @liguagua752109150 https://github.com/WongKinYiu/yolov7/issues/33#issuecomment-1178669212 --- utils/loss.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/utils/loss.py b/utils/loss.py index 3138632..12ec230 100644 --- a/utils/loss.py +++ b/utils/loss.py @@ -775,12 +775,20 @@ class ComputeLossOTA: matching_anchs[i].append(all_anch[layer_idx]) for i in range(nl): - matching_bs[i] = torch.cat(matching_bs[i], dim=0) - matching_as[i] = torch.cat(matching_as[i], dim=0) - matching_gjs[i] = torch.cat(matching_gjs[i], dim=0) - matching_gis[i] = torch.cat(matching_gis[i], dim=0) - matching_targets[i] = torch.cat(matching_targets[i], dim=0) - matching_anchs[i] = torch.cat(matching_anchs[i], dim=0) + if matching_gis[i] != []: + matching_bs[i] = torch.cat(matching_bs[i], dim=0) + matching_as[i] = torch.cat(matching_as[i], dim=0) + matching_gjs[i] = torch.cat(matching_gjs[i], dim=0) + matching_gis[i] = torch.cat(matching_gis[i], dim=0) + matching_targets[i] = torch.cat(matching_targets[i], dim=0) + matching_anchs[i] = torch.cat(matching_anchs[i], dim=0) + else: + matching_bs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64) + matching_as[i] = torch.tensor([], device='cuda:0', dtype=torch.int64) + matching_gjs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64) + matching_gis[i] = torch.tensor([], device='cuda:0', dtype=torch.int64) + matching_targets[i] = torch.tensor([], device='cuda:0', dtype=torch.int64) + matching_anchs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64) return matching_bs, matching_as, matching_gjs, matching_gis, matching_targets, matching_anchs