mirror of https://github.com/RE-OWOD/RE-OWOD
117 lines
4.3 KiB
Python
117 lines
4.3 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
|
from detectron2.layers import batched_nms
|
|
from detectron2.modeling import ROI_HEADS_REGISTRY, StandardROIHeads
|
|
from detectron2.modeling.roi_heads.roi_heads import Res5ROIHeads
|
|
from detectron2.structures import Instances
|
|
|
|
|
|
def merge_branch_instances(instances, num_branch, nms_thresh, topk_per_image):
    """
    Merge detection results from different branches of TridentNet.

    For every image, the predictions of all branches are concatenated, per-class
    non-maximum suppression (NMS) is applied to the combined boxes, and the
    unsuppressed boxes are kept together with any other fields (e.g. masks).

    Args:
        instances (list[Instances]): A list of N * num_branch instances that store
            detection results. Contains N images, each with num_branch instances.
            The list is read branch-major: the result of image ``i`` produced by
            branch ``j`` is taken from index ``i + N * j``.
        num_branch (int): Number of branches used for merging detection results
            for each image.
        nms_thresh (float): The threshold to use for box non-maximum suppression.
            Value in [0, 1].
        topk_per_image (int): The number of top scoring detections to return.
            Set < 0 to return all detections.

    Returns:
        results (list[Instances]): A list of N instances, one for each image in
            the batch, that stores the top-k most confident detections after
            merging results from multiple branches. When ``num_branch == 1`` the
            input list is returned unchanged (no NMS or top-k is applied).
    """
    if num_branch == 1:
        # Nothing to merge with a single branch.
        return instances

    batch_size = len(instances) // num_branch
    results = []
    for i in range(batch_size):
        # Gather the predictions of image `i` from every branch.
        instance = Instances.cat([instances[i + batch_size * j] for j in range(num_branch)])

        # Apply per-class NMS
        keep = batched_nms(
            instance.pred_boxes.tensor, instance.scores, instance.pred_classes, nms_thresh
        )
        # Only truncate for a non-negative limit: slicing with a negative bound
        # (e.g. keep[:-1]) would silently drop trailing detections instead of
        # returning all of them as documented.
        if topk_per_image >= 0:
            keep = keep[:topk_per_image]

        results.append(instance[keep])

    return results
|
|
|
|
|
|
@ROI_HEADS_REGISTRY.register()
class TridentRes5ROIHeads(Res5ROIHeads):
    """
    The TridentNet ROIHeads in a typical "C4" R-CNN model.

    Extends :class:`Res5ROIHeads` by replicating the targets once per branch
    and, at inference time, merging the per-branch predictions of each image.
    """

    def __init__(self, cfg, input_shape):
        super().__init__(cfg, input_shape)

        trident_cfg = cfg.MODEL.TRIDENT
        self.num_branch = trident_cfg.NUM_BRANCH
        # Fast mode is enabled whenever a test branch index other than -1 is set.
        self.trident_fast = trident_cfg.TEST_BRANCH_IDX != -1

    def forward(self, images, features, proposals, targets=None):
        """
        See :class:`Res5ROIHeads.forward`.

        Returns ``(pred_instances, losses)`` from the parent in training mode;
        otherwise returns the branch-merged predictions and an empty dict.
        """
        # Use 1 branch if using trident_fast during inference.
        if self.training or not self.trident_fast:
            num_branch = self.num_branch
        else:
            num_branch = 1

        # Duplicate targets for all branches in TridentNet.
        all_targets = None if targets is None else targets * num_branch

        pred_instances, losses = super().forward(images, features, proposals, all_targets)
        del images, all_targets, targets

        if self.training:
            return pred_instances, losses

        merged = merge_branch_instances(
            pred_instances,
            num_branch,
            self.box_predictor.test_nms_thresh,
            self.box_predictor.test_topk_per_image,
        )
        return merged, {}
|
|
|
|
|
|
@ROI_HEADS_REGISTRY.register()
class TridentStandardROIHeads(StandardROIHeads):
    """
    The `StandardROIHeads` for TridentNet.

    Extends :class:`StandardROIHeads` by replicating the targets once per
    branch and, at inference time, merging the per-branch predictions of each
    image.
    """

    def __init__(self, cfg, input_shape):
        super().__init__(cfg, input_shape)

        trident_cfg = cfg.MODEL.TRIDENT
        self.num_branch = trident_cfg.NUM_BRANCH
        # Fast mode is enabled whenever a test branch index other than -1 is set.
        self.trident_fast = trident_cfg.TEST_BRANCH_IDX != -1

    def forward(self, images, features, proposals, targets=None):
        """
        See :class:`StandardROIHeads.forward`.

        Returns ``(pred_instances, losses)`` from the parent in training mode;
        otherwise returns the branch-merged predictions and an empty dict.
        """
        # Use 1 branch if using trident_fast during inference.
        if self.training or not self.trident_fast:
            num_branch = self.num_branch
        else:
            num_branch = 1

        # Duplicate targets for all branches in TridentNet.
        all_targets = None if targets is None else targets * num_branch

        pred_instances, losses = super().forward(images, features, proposals, all_targets)
        del images, all_targets, targets

        if self.training:
            return pred_instances, losses

        merged = merge_branch_instances(
            pred_instances,
            num_branch,
            self.box_predictor.test_nms_thresh,
            self.box_predictor.test_topk_per_image,
        )
        return merged, {}
|