mirror of https://github.com/WongKinYiu/yolov7.git
fix indices device bug (#1311)
parent
557e3837af
commit
45f2c1fe55
|
@ -642,7 +642,7 @@ class ComputeLossOTA:
|
|||
#indices, anch = self.find_4_positive(p, targets)
|
||||
#indices, anch = self.find_5_positive(p, targets)
|
||||
#indices, anch = self.find_9_positive(p, targets)
|
||||
|
||||
device = torch.device(targets.device)
|
||||
matching_bs = [[] for pp in p]
|
||||
matching_as = [[] for pp in p]
|
||||
matching_gjs = [[] for pp in p]
|
||||
|
@ -682,7 +682,7 @@ class ComputeLossOTA:
|
|||
all_gj.append(gj)
|
||||
all_gi.append(gi)
|
||||
all_anch.append(anch[i][idx])
|
||||
from_which_layer.append(torch.ones(size=(len(b),)) * i)
|
||||
from_which_layer.append((torch.ones(size=(len(b),)) * i).to(device)
|
||||
|
||||
fg_pred = pi[b, a, gj, gi]
|
||||
p_obj.append(fg_pred[:, 4:5])
|
||||
|
@ -753,7 +753,7 @@ class ComputeLossOTA:
|
|||
_, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
|
||||
matching_matrix[:, anchor_matching_gt > 1] *= 0.0
|
||||
matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1.0
|
||||
fg_mask_inboxes = matching_matrix.sum(0) > 0.0
|
||||
fg_mask_inboxes = (matching_matrix.sum(0) > 0.0).to(device)
|
||||
matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
|
||||
|
||||
from_which_layer = from_which_layer[fg_mask_inboxes]
|
||||
|
|
Loading…
Reference in New Issue