diff --git a/detectron2/modeling/meta_arch/rcnn.py b/detectron2/modeling/meta_arch/rcnn.py
index 618f3c4..d2600fd 100644
--- a/detectron2/modeling/meta_arch/rcnn.py
+++ b/detectron2/modeling/meta_arch/rcnn.py
@@ -149,6 +149,8 @@ class GeneralizedRCNN(nn.Module):
             return self.inference(batched_inputs)
 
         images = self.preprocess_image(batched_inputs)
+        image_id = [x["image_id"] for x in batched_inputs]
+        ori_image = [x['image'] for x in batched_inputs]
         if "instances" in batched_inputs[0]:
             gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
         else:
@@ -157,17 +159,20 @@ class GeneralizedRCNN(nn.Module):
         features = self.backbone(images.tensor)
 
         if self.proposal_generator:
-            proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)
+            proposals, _ = self.proposal_generator(images, features, gt_instances, False)
         else:
             assert "proposals" in batched_inputs[0]
             proposals = [x["proposals"].to(self.device) for x in batched_inputs]
             proposal_losses = {}
 
-        _, detector_losses = self.roi_heads(images, features, proposals, gt_instances)
-        if self.vis_period > 0:
-            storage = get_event_storage()
-            if storage.iter % self.vis_period == 0:
-                self.visualize_training(batched_inputs, proposals)
+        _, detector_losses, unk_sel_gt = self.roi_heads(images, features, proposals, gt_instances, image_id, ori_image)
+
+        _, proposal_losses = self.proposal_generator(images, features, gt_instances, True, unk_sel_gt)
+
+        # if self.vis_period > 0:
+        #     storage = get_event_storage()
+        #     if storage.iter % self.vis_period == 0:
+        #         self.visualize_training(batched_inputs, proposals)
 
         losses = {}
         losses.update(detector_losses)
@@ -203,7 +208,7 @@ class GeneralizedRCNN(nn.Module):
                 assert "proposals" in batched_inputs[0]
                 proposals = [x["proposals"].to(self.device) for x in batched_inputs]
 
-            results, _ = self.roi_heads(images, features, proposals, None)
+            results, _, _ = self.roi_heads(images, features, proposals, None)
         else:
             detected_instances = [x.to(self.device) for x in detected_instances]
             results = self.roi_heads.forward_with_given_boxes(features, detected_instances)