mirror of https://github.com/RE-OWOD/RE-OWOD
Add files via upload
parent
6257e7bcce
commit
3b865248f1
|
@ -1,5 +1,6 @@
|
||||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
import sys
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from fvcore.nn import giou_loss, smooth_l1_loss
|
from fvcore.nn import giou_loss, smooth_l1_loss
|
||||||
|
@ -267,7 +268,7 @@ class RPN(nn.Module):
|
||||||
@torch.jit.unused
|
@torch.jit.unused
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def label_and_sample_anchors(
|
def label_and_sample_anchors(
|
||||||
self, anchors: List[Boxes], gt_instances: List[Instances]
|
self, anchors: List[Boxes], gt_instances: List[Instances], unk_gt_boxes: Optional[List[Instances]] = None
|
||||||
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
@ -292,11 +293,15 @@ class RPN(nn.Module):
|
||||||
|
|
||||||
gt_labels = []
|
gt_labels = []
|
||||||
matched_gt_boxes = []
|
matched_gt_boxes = []
|
||||||
for image_size_i, gt_boxes_i in zip(image_sizes, gt_boxes):
|
gt_labels_unk = []
|
||||||
|
matched_gt_boxes_unk = []
|
||||||
|
for image_size_i, gt_boxes_i, unk_gt_boxes_i in zip(image_sizes, gt_boxes, unk_gt_boxes):
|
||||||
"""
|
"""
|
||||||
image_size_i: (h, w) for the i-th image
|
image_size_i: (h, w) for the i-th image
|
||||||
gt_boxes_i: ground-truth boxes for i-th image
|
gt_boxes_i: ground-truth boxes for i-th image
|
||||||
"""
|
"""
|
||||||
|
# print("unk_gt_boxes_i:",unk_gt_boxes_i)
|
||||||
|
# print("gt_boxes_i:",gt_boxes_i)
|
||||||
|
|
||||||
match_quality_matrix = retry_if_cuda_oom(pairwise_iou)(gt_boxes_i, anchors)
|
match_quality_matrix = retry_if_cuda_oom(pairwise_iou)(gt_boxes_i, anchors)
|
||||||
matched_idxs, gt_labels_i = retry_if_cuda_oom(self.anchor_matcher)(match_quality_matrix)
|
matched_idxs, gt_labels_i = retry_if_cuda_oom(self.anchor_matcher)(match_quality_matrix)
|
||||||
|
@ -322,7 +327,39 @@ class RPN(nn.Module):
|
||||||
|
|
||||||
gt_labels.append(gt_labels_i) # N,AHW
|
gt_labels.append(gt_labels_i) # N,AHW
|
||||||
matched_gt_boxes.append(matched_gt_boxes_i)
|
matched_gt_boxes.append(matched_gt_boxes_i)
|
||||||
return gt_labels, matched_gt_boxes
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if len(unk_gt_boxes_i):
|
||||||
|
gt_boxes_i = Boxes.cat([gt_boxes_i, unk_gt_boxes_i])
|
||||||
|
|
||||||
|
match_quality_matrix = retry_if_cuda_oom(pairwise_iou)(gt_boxes_i, anchors)
|
||||||
|
matched_idxs, gt_labels_i = retry_if_cuda_oom(self.anchor_matcher)(match_quality_matrix)
|
||||||
|
# Matching is memory-expensive and may result in CPU tensors. But the result is small
|
||||||
|
gt_labels_i = gt_labels_i.to(device=gt_boxes_i.device)
|
||||||
|
del match_quality_matrix
|
||||||
|
|
||||||
|
if self.anchor_boundary_thresh >= 0:
|
||||||
|
# Discard anchors that go out of the boundaries of the image
|
||||||
|
# NOTE: This is legacy functionality that is turned off by default in Detectron2
|
||||||
|
anchors_inside_image = anchors.inside_box(image_size_i, self.anchor_boundary_thresh)
|
||||||
|
gt_labels_i[~anchors_inside_image] = -1
|
||||||
|
|
||||||
|
# A vector of labels (-1, 0, 1) for each anchor
|
||||||
|
gt_labels_i = self._subsample_labels(gt_labels_i)
|
||||||
|
|
||||||
|
if len(gt_boxes_i) == 0:
|
||||||
|
# These values won't be used anyway since the anchor is labeled as background
|
||||||
|
matched_gt_boxes_i = torch.zeros_like(anchors.tensor)
|
||||||
|
else:
|
||||||
|
# TODO wasted indexing computation for ignored boxes
|
||||||
|
matched_gt_boxes_i = gt_boxes_i[matched_idxs].tensor
|
||||||
|
|
||||||
|
gt_labels_unk.append(gt_labels_i) # N,AHW
|
||||||
|
matched_gt_boxes_unk.append(matched_gt_boxes_i)
|
||||||
|
|
||||||
|
return gt_labels, matched_gt_boxes, gt_labels_unk, matched_gt_boxes_unk
|
||||||
|
|
||||||
@torch.jit.unused
|
@torch.jit.unused
|
||||||
def losses(
|
def losses(
|
||||||
|
@ -330,8 +367,10 @@ class RPN(nn.Module):
|
||||||
anchors: List[Boxes],
|
anchors: List[Boxes],
|
||||||
pred_objectness_logits: List[torch.Tensor],
|
pred_objectness_logits: List[torch.Tensor],
|
||||||
gt_labels: List[torch.Tensor],
|
gt_labels: List[torch.Tensor],
|
||||||
|
gt_labels_unk: List[torch.Tensor],
|
||||||
pred_anchor_deltas: List[torch.Tensor],
|
pred_anchor_deltas: List[torch.Tensor],
|
||||||
gt_boxes: List[torch.Tensor],
|
gt_boxes: List[torch.Tensor],
|
||||||
|
gt_boxes_unk: List[torch.Tensor],
|
||||||
) -> Dict[str, torch.Tensor]:
|
) -> Dict[str, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Return the losses from a set of RPN predictions and their associated ground-truth.
|
Return the losses from a set of RPN predictions and their associated ground-truth.
|
||||||
|
@ -355,6 +394,7 @@ class RPN(nn.Module):
|
||||||
"""
|
"""
|
||||||
num_images = len(gt_labels)
|
num_images = len(gt_labels)
|
||||||
gt_labels = torch.stack(gt_labels) # (N, sum(Hi*Wi*Ai))
|
gt_labels = torch.stack(gt_labels) # (N, sum(Hi*Wi*Ai))
|
||||||
|
gt_labels_unk = torch.stack(gt_labels_unk) # (N, sum(Hi*Wi*Ai))
|
||||||
|
|
||||||
# Log the number of positive/negative anchors per-image that's used in training
|
# Log the number of positive/negative anchors per-image that's used in training
|
||||||
pos_mask = gt_labels == 1
|
pos_mask = gt_labels == 1
|
||||||
|
@ -385,12 +425,20 @@ class RPN(nn.Module):
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid rpn box reg loss type '{self.box_reg_loss_type}'")
|
raise ValueError(f"Invalid rpn box reg loss type '{self.box_reg_loss_type}'")
|
||||||
|
|
||||||
|
if storage.iter < 5000:
|
||||||
valid_mask = gt_labels >= 0
|
valid_mask = gt_labels >= 0
|
||||||
objectness_loss = F.binary_cross_entropy_with_logits(
|
objectness_loss = F.binary_cross_entropy_with_logits(
|
||||||
cat(pred_objectness_logits, dim=1)[valid_mask],
|
cat(pred_objectness_logits, dim=1)[valid_mask],
|
||||||
gt_labels[valid_mask].to(torch.float32),
|
gt_labels[valid_mask].to(torch.float32),
|
||||||
reduction="sum",
|
reduction="sum",
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
valid_mask = gt_labels_unk >= 0
|
||||||
|
objectness_loss = F.binary_cross_entropy_with_logits(
|
||||||
|
cat(pred_objectness_logits, dim=1)[valid_mask],
|
||||||
|
gt_labels_unk[valid_mask].to(torch.float32),
|
||||||
|
reduction="sum",
|
||||||
|
)
|
||||||
normalizer = self.batch_size_per_image * num_images
|
normalizer = self.batch_size_per_image * num_images
|
||||||
losses = {
|
losses = {
|
||||||
"loss_rpn_cls": objectness_loss / normalizer,
|
"loss_rpn_cls": objectness_loss / normalizer,
|
||||||
|
@ -404,6 +452,8 @@ class RPN(nn.Module):
|
||||||
images: ImageList,
|
images: ImageList,
|
||||||
features: Dict[str, torch.Tensor],
|
features: Dict[str, torch.Tensor],
|
||||||
gt_instances: Optional[List[Instances]] = None,
|
gt_instances: Optional[List[Instances]] = None,
|
||||||
|
loss_flag: Optional[bool] = False,
|
||||||
|
unk_gt_boxes : Optional[List[Instances]] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
@ -439,12 +489,16 @@ class RPN(nn.Module):
|
||||||
|
|
||||||
if self.training:
|
if self.training:
|
||||||
assert gt_instances is not None, "RPN requires gt_instances in training!"
|
assert gt_instances is not None, "RPN requires gt_instances in training!"
|
||||||
gt_labels, gt_boxes = self.label_and_sample_anchors(anchors, gt_instances)
|
|
||||||
|
if loss_flag:
|
||||||
|
gt_labels, gt_boxes, gt_labels_unk, gt_boxes_unk = self.label_and_sample_anchors(anchors, gt_instances, unk_gt_boxes)
|
||||||
losses = self.losses(
|
losses = self.losses(
|
||||||
anchors, pred_objectness_logits, gt_labels, pred_anchor_deltas, gt_boxes
|
anchors, pred_objectness_logits, gt_labels, gt_labels_unk, pred_anchor_deltas, gt_boxes, gt_boxes_unk
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
losses = {}
|
losses = {}
|
||||||
|
else:
|
||||||
|
losses = {}
|
||||||
proposals = self.predict_proposals(
|
proposals = self.predict_proposals(
|
||||||
anchors, pred_objectness_logits, pred_anchor_deltas, images.image_sizes
|
anchors, pred_objectness_logits, pred_anchor_deltas, images.image_sizes
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue