fix indices device bug (#1311)

pull/1317/head
superfast852 2022-12-27 23:40:57 -04:00 committed by GitHub
parent 557e3837af
commit 45f2c1fe55
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 3 deletions

View File

@ -642,7 +642,7 @@ class ComputeLossOTA:
#indices, anch = self.find_4_positive(p, targets)
#indices, anch = self.find_5_positive(p, targets)
#indices, anch = self.find_9_positive(p, targets)
device = torch.device(targets.device)
matching_bs = [[] for pp in p]
matching_as = [[] for pp in p]
matching_gjs = [[] for pp in p]
@ -682,7 +682,7 @@ class ComputeLossOTA:
all_gj.append(gj)
all_gi.append(gi)
all_anch.append(anch[i][idx])
from_which_layer.append(torch.ones(size=(len(b),)) * i)
from_which_layer.append((torch.ones(size=(len(b),)) * i).to(device)
fg_pred = pi[b, a, gj, gi]
p_obj.append(fg_pred[:, 4:5])
@ -753,7 +753,7 @@ class ComputeLossOTA:
_, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
matching_matrix[:, anchor_matching_gt > 1] *= 0.0
matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1.0
fg_mask_inboxes = matching_matrix.sum(0) > 0.0
fg_mask_inboxes = (matching_matrix.sum(0) > 0.0).to(device)
matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
from_which_layer = from_which_layer[fg_mask_inboxes]