From 3b865248f1b2f67625f2fb8c92f2570ac3428e1c Mon Sep 17 00:00:00 2001
From: RE-OWOD <95522332+RE-OWOD@users.noreply.github.com>
Date: Tue, 11 Jan 2022 11:03:57 +0800
Subject: [PATCH] rpn: label anchors against unknown ground-truth boxes and
 gate the objectness loss by iteration

---
 detectron2/modeling/proposal_generator/rpn.py | 85 +++++++++++++++++++++----
 1 file changed, 71 insertions(+), 14 deletions(-)

diff --git a/detectron2/modeling/proposal_generator/rpn.py b/detectron2/modeling/proposal_generator/rpn.py
index 549fcbc..a5cc9a6 100644
--- a/detectron2/modeling/proposal_generator/rpn.py
+++ b/detectron2/modeling/proposal_generator/rpn.py
@@ -267,7 +267,7 @@ class RPN(nn.Module):
     @torch.jit.unused
     @torch.no_grad()
     def label_and_sample_anchors(
-        self, anchors: List[Boxes], gt_instances: List[Instances]
-    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+        self, anchors: List[Boxes], gt_instances: List[Instances], unk_gt_boxes: Optional[List[Boxes]] = None
+    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
         """
         Args:
@@ -292,11 +292,16 @@ class RPN(nn.Module):
 
         gt_labels = []
         matched_gt_boxes = []
-        for image_size_i, gt_boxes_i in zip(image_sizes, gt_boxes):
+        gt_labels_unk = []
+        matched_gt_boxes_unk = []
+        if unk_gt_boxes is None:
+            # No unknown boxes supplied: use an empty Boxes per image so zip() still works.
+            unk_gt_boxes = [Boxes(b.tensor.new_zeros((0, 4))) for b in gt_boxes]
+        for image_size_i, gt_boxes_i, unk_gt_boxes_i in zip(image_sizes, gt_boxes, unk_gt_boxes):
             """
             image_size_i: (h, w) for the i-th image
             gt_boxes_i: ground-truth boxes for i-th image
             """
 
             match_quality_matrix = retry_if_cuda_oom(pairwise_iou)(gt_boxes_i, anchors)
             matched_idxs, gt_labels_i = retry_if_cuda_oom(self.anchor_matcher)(match_quality_matrix)
@@ -322,7 +327,38 @@ class RPN(nn.Module):
             gt_labels.append(gt_labels_i)  # N,AHW
             matched_gt_boxes.append(matched_gt_boxes_i)
-        return gt_labels, matched_gt_boxes
+
+            # Second pass: repeat the matching with the unknown-object boxes
+            # appended to the known ground-truth boxes for this image.
+            if len(unk_gt_boxes_i):
+                gt_boxes_i = Boxes.cat([gt_boxes_i, unk_gt_boxes_i])
+
+            match_quality_matrix = retry_if_cuda_oom(pairwise_iou)(gt_boxes_i, anchors)
+            matched_idxs, gt_labels_i = retry_if_cuda_oom(self.anchor_matcher)(match_quality_matrix)
+            # Matching is memory-expensive and may result in CPU tensors. But the result is small
+            gt_labels_i = gt_labels_i.to(device=gt_boxes_i.device)
+            del match_quality_matrix
+
+            if self.anchor_boundary_thresh >= 0:
+                # Discard anchors that go out of the boundaries of the image
+                # NOTE: This is legacy functionality that is turned off by default in Detectron2
+                anchors_inside_image = anchors.inside_box(image_size_i, self.anchor_boundary_thresh)
+                gt_labels_i[~anchors_inside_image] = -1
+
+            # A vector of labels (-1, 0, 1) for each anchor
+            gt_labels_i = self._subsample_labels(gt_labels_i)
+
+            if len(gt_boxes_i) == 0:
+                # These values won't be used anyway since the anchor is labeled as background
+                matched_gt_boxes_i = torch.zeros_like(anchors.tensor)
+            else:
+                # TODO wasted indexing computation for ignored boxes
+                matched_gt_boxes_i = gt_boxes_i[matched_idxs].tensor
+
+            gt_labels_unk.append(gt_labels_i)  # N,AHW
+            matched_gt_boxes_unk.append(matched_gt_boxes_i)
+
+        return gt_labels, matched_gt_boxes, gt_labels_unk, matched_gt_boxes_unk
 
     @torch.jit.unused
     def losses(
         self,
@@ -330,8 +366,10 @@ class RPN(nn.Module):
         anchors: List[Boxes],
         pred_objectness_logits: List[torch.Tensor],
         gt_labels: List[torch.Tensor],
+        gt_labels_unk: List[torch.Tensor],
         pred_anchor_deltas: List[torch.Tensor],
         gt_boxes: List[torch.Tensor],
+        gt_boxes_unk: List[torch.Tensor],
     ) -> Dict[str, torch.Tensor]:
         """
         Return the losses from a set of RPN predictions and their associated ground-truth.
@@ -355,6 +393,7 @@ class RPN(nn.Module):
         """
         num_images = len(gt_labels)
         gt_labels = torch.stack(gt_labels)  # (N, sum(Hi*Wi*Ai))
+        gt_labels_unk = torch.stack(gt_labels_unk)  # (N, sum(Hi*Wi*Ai))
 
         # Log the number of positive/negative anchors per-image that's used in training
         pos_mask = gt_labels == 1
@@ -385,12 +424,22 @@ class RPN(nn.Module):
         else:
             raise ValueError(f"Invalid rpn box reg loss type '{self.box_reg_loss_type}'")
 
-        valid_mask = gt_labels >= 0
-        objectness_loss = F.binary_cross_entropy_with_logits(
-            cat(pred_objectness_logits, dim=1)[valid_mask],
-            gt_labels[valid_mask].to(torch.float32),
-            reduction="sum",
-        )
+        # For the first 5000 iterations supervise objectness with labels from known
+        # GT only; afterwards use the labels computed with unknown boxes included.
+        if storage.iter < 5000:
+            valid_mask = gt_labels >= 0
+            objectness_loss = F.binary_cross_entropy_with_logits(
+                cat(pred_objectness_logits, dim=1)[valid_mask],
+                gt_labels[valid_mask].to(torch.float32),
+                reduction="sum",
+            )
+        else:
+            valid_mask = gt_labels_unk >= 0
+            objectness_loss = F.binary_cross_entropy_with_logits(
+                cat(pred_objectness_logits, dim=1)[valid_mask],
+                gt_labels_unk[valid_mask].to(torch.float32),
+                reduction="sum",
+            )
         normalizer = self.batch_size_per_image * num_images
         losses = {
             "loss_rpn_cls": objectness_loss / normalizer,
@@ -404,6 +453,8 @@ class RPN(nn.Module):
         images: ImageList,
         features: Dict[str, torch.Tensor],
         gt_instances: Optional[List[Instances]] = None,
+        loss_flag: bool = False,
+        unk_gt_boxes: Optional[List[Boxes]] = None,
     ):
         """
         Args:
@@ -439,10 +490,16 @@ class RPN(nn.Module):
 
         if self.training:
             assert gt_instances is not None, "RPN requires gt_instances in training!"
-            gt_labels, gt_boxes = self.label_and_sample_anchors(anchors, gt_instances)
-            losses = self.losses(
-                anchors, pred_objectness_logits, gt_labels, pred_anchor_deltas, gt_boxes
-            )
+            if loss_flag:
+                gt_labels, gt_boxes, gt_labels_unk, gt_boxes_unk = self.label_and_sample_anchors(
+                    anchors, gt_instances, unk_gt_boxes
+                )
+                losses = self.losses(
+                    anchors, pred_objectness_logits, gt_labels, gt_labels_unk,
+                    pred_anchor_deltas, gt_boxes, gt_boxes_unk,
+                )
+            else:
+                losses = {}
         else:
             losses = {}
         proposals = self.predict_proposals(
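
Usage note: the patched forward() takes two new arguments, loss_flag and
unk_gt_boxes. A minimal caller sketch follows; the wrapper function and the way
unknown boxes are pulled out of per-image Instances are assumptions for
illustration, not part of this patch:

    # Hypothetical caller (illustration only): enables the unknown-aware loss
    # path and hands the RPN per-image unknown boxes as detectron2 Boxes.
    from typing import List, Optional

    from detectron2.structures import Boxes, ImageList, Instances


    def run_rpn_with_unknowns(rpn, images: ImageList, features,
                              gt_instances: List[Instances],
                              unk_instances: Optional[List[Instances]] = None):
        unk_gt_boxes = None
        if unk_instances is not None:
            # label_and_sample_anchors concatenates these with gt_boxes via
            # Boxes.cat, so plain Boxes (not Instances) are passed in.
            unk_gt_boxes = [x.gt_boxes for x in unk_instances]
        proposals, losses = rpn(images, features, gt_instances,
                                loss_flag=True, unk_gt_boxes=unk_gt_boxes)
        return proposals, losses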
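
The iteration gate in losses() is a hard switch at iteration 5000. A standalone
sketch of the same rule, assuming the labels are the (-1, 0, 1) vectors produced
by label_and_sample_anchors:

    import torch
    import torch.nn.functional as F

    WARMUP_ITERS = 5000  # hard-coded threshold carried over from the patch


    def objectness_loss(logits: torch.Tensor, gt_labels: torch.Tensor,
                        gt_labels_unk: torch.Tensor, cur_iter: int) -> torch.Tensor:
        # Before the switch, supervise with known-GT labels only; after it,
        # use the labels computed with unknown boxes included.
        labels = gt_labels if cur_iter < WARMUP_ITERS else gt_labels_unk
        valid = labels >= 0  # drop anchors marked "ignore" (-1)
        return F.binary_cross_entropy_with_logits(
            logits[valid], labels[valid].to(torch.float32), reduction="sum"
        )

One consequence of reading storage.iter for the gate: resuming training from a
checkpoint saved after iteration 5000 starts directly on the unknown-aware
labels, since detectron2's event storage continues from the restored iteration.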
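
The unknown-aware pass in label_and_sample_anchors repeats the known-GT matching
block line for line. A follow-up could route both passes through one helper
along these lines (a sketch relying on the module's existing imports, not what
this patch implements):

    def _match_and_sample(self, boxes_i, anchors, image_size_i):
        # Shared per-image matching/sampling: boxes_i is a Boxes instance,
        # anchors is the Boxes concatenated over all feature levels.
        match_quality_matrix = retry_if_cuda_oom(pairwise_iou)(boxes_i, anchors)
        matched_idxs, labels_i = retry_if_cuda_oom(self.anchor_matcher)(match_quality_matrix)
        labels_i = labels_i.to(device=boxes_i.device)
        del match_quality_matrix
        if self.anchor_boundary_thresh >= 0:
            inside = anchors.inside_box(image_size_i, self.anchor_boundary_thresh)
            labels_i[~inside] = -1
        labels_i = self._subsample_labels(labels_i)
        if len(boxes_i) == 0:
            # Values are unused when every anchor is labeled as background.
            matched_boxes_i = torch.zeros_like(anchors.tensor)
        else:
            matched_boxes_i = boxes_i[matched_idxs].tensor
        return labels_i, matched_boxes_i

The loop body would then call this once with gt_boxes_i and once with
Boxes.cat([gt_boxes_i, unk_gt_boxes_i]), keeping the two passes in sync.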