mirror of https://github.com/WongKinYiu/yolov7.git
main code
@albertfaromatics https://github.com/WongKinYiu/yolov7/issues/35#issuecomment-1178800685pull/90/head
parent
c587e8f467
commit
54627aa3ac
|
@ -501,7 +501,7 @@ class ComputeLoss:
|
|||
# Build targets for compute_loss(), input targets(image,class,x,y,w,h)
|
||||
na, nt = self.na, targets.shape[0] # number of anchors, targets
|
||||
tcls, tbox, indices, anch = [], [], [], []
|
||||
gain = torch.ones(7, device=targets.device) # normalized to gridspace gain
|
||||
gain = torch.ones(7, device=targets.device).long() # normalized to gridspace gain
|
||||
ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt)
|
||||
targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2) # append anchor indices
|
||||
|
||||
|
@ -775,7 +775,7 @@ class ComputeLossOTA:
|
|||
matching_anchs[i].append(all_anch[layer_idx])
|
||||
|
||||
for i in range(nl):
|
||||
if matching_gis[i] != []:
|
||||
if matching_targets[i] != []:
|
||||
matching_bs[i] = torch.cat(matching_bs[i], dim=0)
|
||||
matching_as[i] = torch.cat(matching_as[i], dim=0)
|
||||
matching_gjs[i] = torch.cat(matching_gjs[i], dim=0)
|
||||
|
@ -796,7 +796,7 @@ class ComputeLossOTA:
|
|||
# Build targets for compute_loss(), input targets(image,class,x,y,w,h)
|
||||
na, nt = self.na, targets.shape[0] # number of anchors, targets
|
||||
indices, anch = [], []
|
||||
gain = torch.ones(7, device=targets.device) # normalized to gridspace gain
|
||||
gain = torch.ones(7, device=targets.device).long() # normalized to gridspace gain
|
||||
ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt)
|
||||
targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2) # append anchor indices
|
||||
|
||||
|
|
Loading…
Reference in New Issue