mirror of https://github.com/RE-OWOD/RE-OWOD
Add files via upload
parent
6257e7bcce
commit
3b865248f1
|
@ -1,5 +1,6 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
import sys
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from fvcore.nn import giou_loss, smooth_l1_loss
|
||||
|
@ -267,7 +268,7 @@ class RPN(nn.Module):
|
|||
@torch.jit.unused
|
||||
@torch.no_grad()
|
||||
def label_and_sample_anchors(
|
||||
self, anchors: List[Boxes], gt_instances: List[Instances]
|
||||
self, anchors: List[Boxes], gt_instances: List[Instances], unk_gt_boxes: Optional[List[Instances]] = None
|
||||
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
"""
|
||||
Args:
|
||||
|
@ -292,11 +293,15 @@ class RPN(nn.Module):
|
|||
|
||||
gt_labels = []
|
||||
matched_gt_boxes = []
|
||||
for image_size_i, gt_boxes_i in zip(image_sizes, gt_boxes):
|
||||
gt_labels_unk = []
|
||||
matched_gt_boxes_unk = []
|
||||
for image_size_i, gt_boxes_i, unk_gt_boxes_i in zip(image_sizes, gt_boxes, unk_gt_boxes):
|
||||
"""
|
||||
image_size_i: (h, w) for the i-th image
|
||||
gt_boxes_i: ground-truth boxes for i-th image
|
||||
"""
|
||||
# print("unk_gt_boxes_i:",unk_gt_boxes_i)
|
||||
# print("gt_boxes_i:",gt_boxes_i)
|
||||
|
||||
match_quality_matrix = retry_if_cuda_oom(pairwise_iou)(gt_boxes_i, anchors)
|
||||
matched_idxs, gt_labels_i = retry_if_cuda_oom(self.anchor_matcher)(match_quality_matrix)
|
||||
|
@ -322,7 +327,39 @@ class RPN(nn.Module):
|
|||
|
||||
gt_labels.append(gt_labels_i) # N,AHW
|
||||
matched_gt_boxes.append(matched_gt_boxes_i)
|
||||
return gt_labels, matched_gt_boxes
|
||||
|
||||
|
||||
|
||||
|
||||
if len(unk_gt_boxes_i):
|
||||
gt_boxes_i = Boxes.cat([gt_boxes_i, unk_gt_boxes_i])
|
||||
|
||||
match_quality_matrix = retry_if_cuda_oom(pairwise_iou)(gt_boxes_i, anchors)
|
||||
matched_idxs, gt_labels_i = retry_if_cuda_oom(self.anchor_matcher)(match_quality_matrix)
|
||||
# Matching is memory-expensive and may result in CPU tensors. But the result is small
|
||||
gt_labels_i = gt_labels_i.to(device=gt_boxes_i.device)
|
||||
del match_quality_matrix
|
||||
|
||||
if self.anchor_boundary_thresh >= 0:
|
||||
# Discard anchors that go out of the boundaries of the image
|
||||
# NOTE: This is legacy functionality that is turned off by default in Detectron2
|
||||
anchors_inside_image = anchors.inside_box(image_size_i, self.anchor_boundary_thresh)
|
||||
gt_labels_i[~anchors_inside_image] = -1
|
||||
|
||||
# A vector of labels (-1, 0, 1) for each anchor
|
||||
gt_labels_i = self._subsample_labels(gt_labels_i)
|
||||
|
||||
if len(gt_boxes_i) == 0:
|
||||
# These values won't be used anyway since the anchor is labeled as background
|
||||
matched_gt_boxes_i = torch.zeros_like(anchors.tensor)
|
||||
else:
|
||||
# TODO wasted indexing computation for ignored boxes
|
||||
matched_gt_boxes_i = gt_boxes_i[matched_idxs].tensor
|
||||
|
||||
gt_labels_unk.append(gt_labels_i) # N,AHW
|
||||
matched_gt_boxes_unk.append(matched_gt_boxes_i)
|
||||
|
||||
return gt_labels, matched_gt_boxes, gt_labels_unk, matched_gt_boxes_unk
|
||||
|
||||
@torch.jit.unused
|
||||
def losses(
|
||||
|
@ -330,8 +367,10 @@ class RPN(nn.Module):
|
|||
anchors: List[Boxes],
|
||||
pred_objectness_logits: List[torch.Tensor],
|
||||
gt_labels: List[torch.Tensor],
|
||||
gt_labels_unk: List[torch.Tensor],
|
||||
pred_anchor_deltas: List[torch.Tensor],
|
||||
gt_boxes: List[torch.Tensor],
|
||||
gt_boxes_unk: List[torch.Tensor],
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Return the losses from a set of RPN predictions and their associated ground-truth.
|
||||
|
@ -355,6 +394,7 @@ class RPN(nn.Module):
|
|||
"""
|
||||
num_images = len(gt_labels)
|
||||
gt_labels = torch.stack(gt_labels) # (N, sum(Hi*Wi*Ai))
|
||||
gt_labels_unk = torch.stack(gt_labels_unk) # (N, sum(Hi*Wi*Ai))
|
||||
|
||||
# Log the number of positive/negative anchors per-image that's used in training
|
||||
pos_mask = gt_labels == 1
|
||||
|
@ -385,12 +425,20 @@ class RPN(nn.Module):
|
|||
else:
|
||||
raise ValueError(f"Invalid rpn box reg loss type '{self.box_reg_loss_type}'")
|
||||
|
||||
if storage.iter < 5000:
|
||||
valid_mask = gt_labels >= 0
|
||||
objectness_loss = F.binary_cross_entropy_with_logits(
|
||||
cat(pred_objectness_logits, dim=1)[valid_mask],
|
||||
gt_labels[valid_mask].to(torch.float32),
|
||||
reduction="sum",
|
||||
)
|
||||
else:
|
||||
valid_mask = gt_labels_unk >= 0
|
||||
objectness_loss = F.binary_cross_entropy_with_logits(
|
||||
cat(pred_objectness_logits, dim=1)[valid_mask],
|
||||
gt_labels_unk[valid_mask].to(torch.float32),
|
||||
reduction="sum",
|
||||
)
|
||||
normalizer = self.batch_size_per_image * num_images
|
||||
losses = {
|
||||
"loss_rpn_cls": objectness_loss / normalizer,
|
||||
|
@ -404,6 +452,8 @@ class RPN(nn.Module):
|
|||
images: ImageList,
|
||||
features: Dict[str, torch.Tensor],
|
||||
gt_instances: Optional[List[Instances]] = None,
|
||||
loss_flag: Optional[bool] = False,
|
||||
unk_gt_boxes : Optional[List[Instances]] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
|
@ -439,12 +489,16 @@ class RPN(nn.Module):
|
|||
|
||||
if self.training:
|
||||
assert gt_instances is not None, "RPN requires gt_instances in training!"
|
||||
gt_labels, gt_boxes = self.label_and_sample_anchors(anchors, gt_instances)
|
||||
|
||||
if loss_flag:
|
||||
gt_labels, gt_boxes, gt_labels_unk, gt_boxes_unk = self.label_and_sample_anchors(anchors, gt_instances, unk_gt_boxes)
|
||||
losses = self.losses(
|
||||
anchors, pred_objectness_logits, gt_labels, pred_anchor_deltas, gt_boxes
|
||||
anchors, pred_objectness_logits, gt_labels, gt_labels_unk, pred_anchor_deltas, gt_boxes, gt_boxes_unk
|
||||
)
|
||||
else:
|
||||
losses = {}
|
||||
else:
|
||||
losses = {}
|
||||
proposals = self.predict_proposals(
|
||||
anchors, pred_objectness_logits, pred_anchor_deltas, images.image_sizes
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue