diff --git a/utils/loss.py b/utils/loss.py
index 6eb70a2..5fc73ea 100644
--- a/utils/loss.py
+++ b/utils/loss.py
@@ -642,7 +642,7 @@ class ComputeLossOTA:
         #indices, anch = self.find_4_positive(p, targets)
         #indices, anch = self.find_5_positive(p, targets)
         #indices, anch = self.find_9_positive(p, targets)
-
+        device = torch.device(targets.device)
         matching_bs = [[] for pp in p]
         matching_as = [[] for pp in p]
         matching_gjs = [[] for pp in p]
@@ -682,7 +682,7 @@ class ComputeLossOTA:
                 all_gj.append(gj)
                 all_gi.append(gi)
                 all_anch.append(anch[i][idx])
-                from_which_layer.append(torch.ones(size=(len(b),)) * i)
+                from_which_layer.append((torch.ones(size=(len(b),)) * i).to(device)
                 
                 fg_pred = pi[b, a, gj, gi]                
                 p_obj.append(fg_pred[:, 4:5])
@@ -753,7 +753,7 @@ class ComputeLossOTA:
                 _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
                 matching_matrix[:, anchor_matching_gt > 1] *= 0.0
                 matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1.0
-            fg_mask_inboxes = matching_matrix.sum(0) > 0.0
+            fg_mask_inboxes = (matching_matrix.sum(0) > 0.0).to(device)
             matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
         
             from_which_layer = from_which_layer[fg_mask_inboxes]