Add files via upload

branch: main
RE-OWOD authored 2022-01-11 11:03:57 +08:00, committed via GitHub
parent 6257e7bcce
commit 3b865248f1
1 changed file with 67 additions and 13 deletions

@@ -1,5 +1,6 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 from typing import Dict, List, Optional, Tuple, Union
+import sys
 import torch
 import torch.nn.functional as F
 from fvcore.nn import giou_loss, smooth_l1_loss
@@ -267,7 +268,7 @@ class RPN(nn.Module):
     @torch.jit.unused
     @torch.no_grad()
     def label_and_sample_anchors(
-        self, anchors: List[Boxes], gt_instances: List[Instances]
+        self, anchors: List[Boxes], gt_instances: List[Instances], unk_gt_boxes: Optional[List[Instances]] = None
     ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
         """
         Args:
@@ -292,11 +293,15 @@
         gt_labels = []
         matched_gt_boxes = []
-        for image_size_i, gt_boxes_i in zip(image_sizes, gt_boxes):
+        gt_labels_unk = []
+        matched_gt_boxes_unk = []
+        for image_size_i, gt_boxes_i, unk_gt_boxes_i in zip(image_sizes, gt_boxes, unk_gt_boxes):
             """
             image_size_i: (h, w) for the i-th image
             gt_boxes_i: ground-truth boxes for i-th image
             """
+            # print("unk_gt_boxes_i:", unk_gt_boxes_i)
+            # print("gt_boxes_i:", gt_boxes_i)
             match_quality_matrix = retry_if_cuda_oom(pairwise_iou)(gt_boxes_i, anchors)
             matched_idxs, gt_labels_i = retry_if_cuda_oom(self.anchor_matcher)(match_quality_matrix)
@@ -322,7 +327,39 @@ class RPN(nn.Module):
             gt_labels.append(gt_labels_i)  # N,AHW
             matched_gt_boxes.append(matched_gt_boxes_i)
-        return gt_labels, matched_gt_boxes
+
+            if len(unk_gt_boxes_i):
+                gt_boxes_i = Boxes.cat([gt_boxes_i, unk_gt_boxes_i])
+
+            match_quality_matrix = retry_if_cuda_oom(pairwise_iou)(gt_boxes_i, anchors)
+            matched_idxs, gt_labels_i = retry_if_cuda_oom(self.anchor_matcher)(match_quality_matrix)
+            # Matching is memory-expensive and may result in CPU tensors. But the result is small
+            gt_labels_i = gt_labels_i.to(device=gt_boxes_i.device)
+            del match_quality_matrix
+
+            if self.anchor_boundary_thresh >= 0:
+                # Discard anchors that go out of the boundaries of the image
+                # NOTE: This is legacy functionality that is turned off by default in Detectron2
+                anchors_inside_image = anchors.inside_box(image_size_i, self.anchor_boundary_thresh)
+                gt_labels_i[~anchors_inside_image] = -1
+
+            # A vector of labels (-1, 0, 1) for each anchor
+            gt_labels_i = self._subsample_labels(gt_labels_i)
+
+            if len(gt_boxes_i) == 0:
+                # These values won't be used anyway since the anchor is labeled as background
+                matched_gt_boxes_i = torch.zeros_like(anchors.tensor)
+            else:
+                # TODO wasted indexing computation for ignored boxes
+                matched_gt_boxes_i = gt_boxes_i[matched_idxs].tensor
+
+            gt_labels_unk.append(gt_labels_i)  # N,AHW
+            matched_gt_boxes_unk.append(matched_gt_boxes_i)
+
+        return gt_labels, matched_gt_boxes, gt_labels_unk, matched_gt_boxes_unk

     @torch.jit.unused
     def losses(
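In this hunk the commit runs Detectron2's anchor-matching pass a second time, on the concatenation of the known ground-truth boxes and the unknown (pseudo-labelled) boxes, so each image gets a second label set in which anchors covering unknowns also count as foreground. Two details worth flagging from the diff itself: the Tuple[...] return annotation still advertises two lists although four are now returned, and unk_gt_boxes is annotated as List[Instances] while Boxes.cat treats each element as a Boxes. Below is a minimal sketch of how the composed matching step fits together, assuming both inputs are Boxes on the same device; match_with_unknowns is a hypothetical helper for illustration, not part of the commit.

from detectron2.structures import Boxes, pairwise_iou

def match_with_unknowns(known: Boxes, unknown: Boxes, anchors: Boxes, matcher):
    # Merge known and unknown ground truth, then reuse the stock Matcher
    # (e.g. self.anchor_matcher) exactly as the original single pass does.
    merged = Boxes.cat([known, unknown]) if len(unknown) else known
    iou = pairwise_iou(merged, anchors)   # (num_gt, num_anchors)
    matched_idxs, labels = matcher(iou)   # labels in {-1, 0, 1}
    return matched_idxs, labels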
@@ -330,8 +367,10 @@
         anchors: List[Boxes],
         pred_objectness_logits: List[torch.Tensor],
         gt_labels: List[torch.Tensor],
+        gt_labels_unk: List[torch.Tensor],
         pred_anchor_deltas: List[torch.Tensor],
         gt_boxes: List[torch.Tensor],
+        gt_boxes_unk: List[torch.Tensor],
     ) -> Dict[str, torch.Tensor]:
         """
         Return the losses from a set of RPN predictions and their associated ground-truth.
@@ -355,6 +394,7 @@
         """
         num_images = len(gt_labels)
         gt_labels = torch.stack(gt_labels)  # (N, sum(Hi*Wi*Ai))
+        gt_labels_unk = torch.stack(gt_labels_unk)  # (N, sum(Hi*Wi*Ai))

         # Log the number of positive/negative anchors per-image that's used in training
         pos_mask = gt_labels == 1
@ -385,12 +425,20 @@ class RPN(nn.Module):
else: else:
raise ValueError(f"Invalid rpn box reg loss type '{self.box_reg_loss_type}'") raise ValueError(f"Invalid rpn box reg loss type '{self.box_reg_loss_type}'")
if storage.iter < 5000:
valid_mask = gt_labels >= 0 valid_mask = gt_labels >= 0
objectness_loss = F.binary_cross_entropy_with_logits( objectness_loss = F.binary_cross_entropy_with_logits(
cat(pred_objectness_logits, dim=1)[valid_mask], cat(pred_objectness_logits, dim=1)[valid_mask],
gt_labels[valid_mask].to(torch.float32), gt_labels[valid_mask].to(torch.float32),
reduction="sum", reduction="sum",
) )
else:
valid_mask = gt_labels_unk >= 0
objectness_loss = F.binary_cross_entropy_with_logits(
cat(pred_objectness_logits, dim=1)[valid_mask],
gt_labels_unk[valid_mask].to(torch.float32),
reduction="sum",
)
normalizer = self.batch_size_per_image * num_images normalizer = self.batch_size_per_image * num_images
losses = { losses = {
"loss_rpn_cls": objectness_loss / normalizer, "loss_rpn_cls": objectness_loss / normalizer,
@@ -404,6 +452,8 @@
         images: ImageList,
         features: Dict[str, torch.Tensor],
         gt_instances: Optional[List[Instances]] = None,
+        loss_flag: Optional[bool] = False,
+        unk_gt_boxes: Optional[List[Instances]] = None,
     ):
         """
         Args:
@@ -439,12 +489,16 @@
         if self.training:
             assert gt_instances is not None, "RPN requires gt_instances in training!"
-            gt_labels, gt_boxes = self.label_and_sample_anchors(anchors, gt_instances)
-            losses = self.losses(
-                anchors, pred_objectness_logits, gt_labels, pred_anchor_deltas, gt_boxes
-            )
+            if loss_flag:
+                gt_labels, gt_boxes, gt_labels_unk, gt_boxes_unk = self.label_and_sample_anchors(anchors, gt_instances, unk_gt_boxes)
+                losses = self.losses(
+                    anchors, pred_objectness_logits, gt_labels, gt_labels_unk, pred_anchor_deltas, gt_boxes, gt_boxes_unk
+                )
+            else:
+                losses = {}
         else:
             losses = {}
         proposals = self.predict_proposals(
             anchors, pred_objectness_logits, pred_anchor_deltas, images.image_sizes
         )
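With this last hunk, forward() gates RPN loss computation behind loss_flag and threads the unknown boxes through to label_and_sample_anchors; when loss_flag is False the module returns an empty loss dict even in training, matching stock Detectron2's eval-mode behavior. A hedged usage sketch of the new signature follows; the variable names are illustrative, and the real call site (RE-OWOD's meta-architecture) is not part of this diff:

# images: ImageList, features: Dict[str, torch.Tensor],
# gt_instances / unk_boxes: per-image ground truth (hypothetical inputs)
proposals, rpn_losses = rpn(
    images,
    features,
    gt_instances,
    loss_flag=True,          # compute loss_rpn_cls / loss_rpn_loc this step
    unk_gt_boxes=unk_boxes,  # per-image pseudo-labelled unknown boxes
)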