diff --git a/detectron2/modeling/meta_arch/rcnn.py b/detectron2/modeling/meta_arch/rcnn.py index 618f3c4..d2600fd 100644 --- a/detectron2/modeling/meta_arch/rcnn.py +++ b/detectron2/modeling/meta_arch/rcnn.py @@ -149,6 +149,8 @@ class GeneralizedRCNN(nn.Module): return self.inference(batched_inputs) images = self.preprocess_image(batched_inputs) + image_id = [x["image_id"] for x in batched_inputs] + ori_image = [x['image'] for x in batched_inputs] if "instances" in batched_inputs[0]: gt_instances = [x["instances"].to(self.device) for x in batched_inputs] else: @@ -157,17 +159,20 @@ class GeneralizedRCNN(nn.Module): features = self.backbone(images.tensor) if self.proposal_generator: - proposals, proposal_losses = self.proposal_generator(images, features, gt_instances) + proposals, _ = self.proposal_generator(images, features, gt_instances, False) else: assert "proposals" in batched_inputs[0] proposals = [x["proposals"].to(self.device) for x in batched_inputs] proposal_losses = {} - _, detector_losses = self.roi_heads(images, features, proposals, gt_instances) - if self.vis_period > 0: - storage = get_event_storage() - if storage.iter % self.vis_period == 0: - self.visualize_training(batched_inputs, proposals) + _, detector_losses, unk_sel_gt = self.roi_heads(images, features, proposals, gt_instances, image_id, ori_image) + + _, proposal_losses = self.proposal_generator(images, features, gt_instances, True, unk_sel_gt) + + # if self.vis_period > 0: + # storage = get_event_storage() + # if storage.iter % self.vis_period == 0: + # self.visualize_training(batched_inputs, proposals) losses = {} losses.update(detector_losses) @@ -203,7 +208,7 @@ class GeneralizedRCNN(nn.Module): assert "proposals" in batched_inputs[0] proposals = [x["proposals"].to(self.device) for x in batched_inputs] - results, _ = self.roi_heads(images, features, proposals, None) + results, _, _ = self.roi_heads(images, features, proposals, None) else: detected_instances = [x.to(self.device) for x in detected_instances] results = self.roi_heads.forward_with_given_boxes(features, detected_instances)