Add files via upload

main
RE-OWOD committed 2022-01-11 11:03:57 +08:00 (via GitHub)
parent 6257e7bcce
commit 3b865248f1
1 changed file with 67 additions and 13 deletions

@@ -1,5 +1,6 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 from typing import Dict, List, Optional, Tuple, Union
+import sys
 import torch
 import torch.nn.functional as F
 from fvcore.nn import giou_loss, smooth_l1_loss
@@ -267,7 +268,7 @@ class RPN(nn.Module):
     @torch.jit.unused
     @torch.no_grad()
     def label_and_sample_anchors(
-        self, anchors: List[Boxes], gt_instances: List[Instances]
+        self, anchors: List[Boxes], gt_instances: List[Instances], unk_gt_boxes: Optional[List[Instances]] = None
     ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
         """
         Args:
@@ -292,11 +293,15 @@ class RPN(nn.Module):
         gt_labels = []
         matched_gt_boxes = []
-        for image_size_i, gt_boxes_i in zip(image_sizes, gt_boxes):
+        gt_labels_unk = []
+        matched_gt_boxes_unk = []
+        for image_size_i, gt_boxes_i, unk_gt_boxes_i in zip(image_sizes, gt_boxes, unk_gt_boxes):
             """
             image_size_i: (h, w) for the i-th image
             gt_boxes_i: ground-truth boxes for i-th image
             """
+            # print("unk_gt_boxes_i:",unk_gt_boxes_i)
+            # print("gt_boxes_i:",gt_boxes_i)
             match_quality_matrix = retry_if_cuda_oom(pairwise_iou)(gt_boxes_i, anchors)
             matched_idxs, gt_labels_i = retry_if_cuda_oom(self.anchor_matcher)(match_quality_matrix)
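
The new unk_gt_boxes argument is annotated Optional[List[Instances]], but the loop body concatenates each unk_gt_boxes_i with gt_boxes_i via Boxes.cat, so in practice each entry is expected to behave like a per-image Boxes of unknown-object candidates. A minimal sketch of that input layout, assuming it is built by the caller outside this file (the values below are illustrative, not part of the commit):

    # Hypothetical construction, mirroring the per-image layout of gt_boxes.
    from detectron2.structures import Boxes
    import torch

    unk_gt_boxes = [
        Boxes(torch.tensor([[10.0, 10.0, 50.0, 80.0]])),  # image 0: one unknown candidate
        Boxes(torch.empty(0, 4)),                          # image 1: no unknown candidates
    ]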
@@ -322,7 +327,39 @@ class RPN(nn.Module):
             gt_labels.append(gt_labels_i)  # N,AHW
             matched_gt_boxes.append(matched_gt_boxes_i)
-        return gt_labels, matched_gt_boxes
+
+            if len(unk_gt_boxes_i):
+                gt_boxes_i = Boxes.cat([gt_boxes_i, unk_gt_boxes_i])
+            match_quality_matrix = retry_if_cuda_oom(pairwise_iou)(gt_boxes_i, anchors)
+            matched_idxs, gt_labels_i = retry_if_cuda_oom(self.anchor_matcher)(match_quality_matrix)
+            # Matching is memory-expensive and may result in CPU tensors. But the result is small
+            gt_labels_i = gt_labels_i.to(device=gt_boxes_i.device)
+            del match_quality_matrix
+
+            if self.anchor_boundary_thresh >= 0:
+                # Discard anchors that go out of the boundaries of the image
+                # NOTE: This is legacy functionality that is turned off by default in Detectron2
+                anchors_inside_image = anchors.inside_box(image_size_i, self.anchor_boundary_thresh)
+                gt_labels_i[~anchors_inside_image] = -1
+
+            # A vector of labels (-1, 0, 1) for each anchor
+            gt_labels_i = self._subsample_labels(gt_labels_i)
+
+            if len(gt_boxes_i) == 0:
+                # These values won't be used anyway since the anchor is labeled as background
+                matched_gt_boxes_i = torch.zeros_like(anchors.tensor)
+            else:
+                # TODO wasted indexing computation for ignored boxes
+                matched_gt_boxes_i = gt_boxes_i[matched_idxs].tensor
+
+            gt_labels_unk.append(gt_labels_i)  # N,AHW
+            matched_gt_boxes_unk.append(matched_gt_boxes_i)
+
+        return gt_labels, matched_gt_boxes, gt_labels_unk, matched_gt_boxes_unk

     @torch.jit.unused
     def losses(
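
Taken together, label_and_sample_anchors now labels every image twice: once against the known ground truth only (the original behavior, producing gt_labels and matched_gt_boxes), and once against the known boxes concatenated with the unknown candidates (producing gt_labels_unk and matched_gt_boxes_unk), so anchors covering a potential unknown object are also marked positive in the second set. Note that the Tuple[List[torch.Tensor], List[torch.Tensor]] return annotation above is now stale, since four lists are returned. A minimal call sketch, assuming an RPN instance and inputs shaped as in the loop above (the unpacking names are illustrative):

    # Two label sets come back; the *_unk pair also treats unknown boxes as foreground.
    gt_labels, gt_boxes_m, gt_labels_unk, gt_boxes_unk_m = rpn.label_and_sample_anchors(
        anchors, gt_instances, unk_gt_boxes
    )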
@@ -330,8 +367,10 @@ class RPN(nn.Module):
         anchors: List[Boxes],
         pred_objectness_logits: List[torch.Tensor],
         gt_labels: List[torch.Tensor],
+        gt_labels_unk: List[torch.Tensor],
         pred_anchor_deltas: List[torch.Tensor],
         gt_boxes: List[torch.Tensor],
+        gt_boxes_unk: List[torch.Tensor],
     ) -> Dict[str, torch.Tensor]:
         """
         Return the losses from a set of RPN predictions and their associated ground-truth.
@@ -355,6 +394,7 @@ class RPN(nn.Module):
         """
         num_images = len(gt_labels)
         gt_labels = torch.stack(gt_labels)  # (N, sum(Hi*Wi*Ai))
+        gt_labels_unk = torch.stack(gt_labels_unk)  # (N, sum(Hi*Wi*Ai))

         # Log the number of positive/negative anchors per-image that's used in training
         pos_mask = gt_labels == 1
@@ -385,12 +425,20 @@ class RPN(nn.Module):
         else:
             raise ValueError(f"Invalid rpn box reg loss type '{self.box_reg_loss_type}'")

-        valid_mask = gt_labels >= 0
-        objectness_loss = F.binary_cross_entropy_with_logits(
-            cat(pred_objectness_logits, dim=1)[valid_mask],
-            gt_labels[valid_mask].to(torch.float32),
-            reduction="sum",
-        )
+        if storage.iter < 5000:
+            valid_mask = gt_labels >= 0
+            objectness_loss = F.binary_cross_entropy_with_logits(
+                cat(pred_objectness_logits, dim=1)[valid_mask],
+                gt_labels[valid_mask].to(torch.float32),
+                reduction="sum",
+            )
+        else:
+            valid_mask = gt_labels_unk >= 0
+            objectness_loss = F.binary_cross_entropy_with_logits(
+                cat(pred_objectness_logits, dim=1)[valid_mask],
+                gt_labels_unk[valid_mask].to(torch.float32),
+                reduction="sum",
+            )
         normalizer = self.batch_size_per_image * num_images
         losses = {
             "loss_rpn_cls": objectness_loss / normalizer,
@@ -404,6 +452,8 @@ class RPN(nn.Module):
         images: ImageList,
         features: Dict[str, torch.Tensor],
         gt_instances: Optional[List[Instances]] = None,
+        loss_flag: Optional[bool] = False,
+        unk_gt_boxes: Optional[List[Instances]] = None,
     ):
         """
         Args:
@@ -439,12 +489,16 @@ class RPN(nn.Module):
         if self.training:
             assert gt_instances is not None, "RPN requires gt_instances in training!"
-            gt_labels, gt_boxes = self.label_and_sample_anchors(anchors, gt_instances)
-            losses = self.losses(
-                anchors, pred_objectness_logits, gt_labels, pred_anchor_deltas, gt_boxes
-            )
+            if loss_flag:
+                gt_labels, gt_boxes, gt_labels_unk, gt_boxes_unk = self.label_and_sample_anchors(anchors, gt_instances, unk_gt_boxes)
+                losses = self.losses(
+                    anchors, pred_objectness_logits, gt_labels, gt_labels_unk, pred_anchor_deltas, gt_boxes, gt_boxes_unk
+                )
+            else:
+                losses = {}
         else:
             losses = {}
         proposals = self.predict_proposals(
             anchors, pred_objectness_logits, pred_anchor_deltas, images.image_sizes
         )
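
With this wiring, callers opt into the unknown-aware losses through the two new keyword arguments, and when loss_flag is False the RPN skips loss computation even in training mode. A minimal usage sketch, assuming rpn is this module and the inputs are prepared as in stock detectron2:

    # forward() still returns (proposals, losses), as in stock detectron2.
    proposals, losses = rpn(
        images, features, gt_instances,
        loss_flag=True, unk_gt_boxes=unk_gt_boxes,
    )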