mirror of https://github.com/WongKinYiu/yolov7.git
main code
@liguagua752109150 https://github.com/WongKinYiu/yolov7/issues/33#issuecomment-1178669212pull/90/head
parent
eef4f2c928
commit
c587e8f467
|
@ -775,12 +775,20 @@ class ComputeLossOTA:
|
|||
matching_anchs[i].append(all_anch[layer_idx])
|
||||
|
||||
for i in range(nl):
|
||||
if matching_gis[i] != []:
|
||||
matching_bs[i] = torch.cat(matching_bs[i], dim=0)
|
||||
matching_as[i] = torch.cat(matching_as[i], dim=0)
|
||||
matching_gjs[i] = torch.cat(matching_gjs[i], dim=0)
|
||||
matching_gis[i] = torch.cat(matching_gis[i], dim=0)
|
||||
matching_targets[i] = torch.cat(matching_targets[i], dim=0)
|
||||
matching_anchs[i] = torch.cat(matching_anchs[i], dim=0)
|
||||
else:
|
||||
matching_bs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
|
||||
matching_as[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
|
||||
matching_gjs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
|
||||
matching_gis[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
|
||||
matching_targets[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
|
||||
matching_anchs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
|
||||
|
||||
return matching_bs, matching_as, matching_gjs, matching_gis, matching_targets, matching_anchs
|
||||
|
||||
|
|
Loading…
Reference in New Issue