mirror of https://github.com/RE-OWOD/RE-OWOD
968 lines · 42 KiB · Python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import inspect
import logging
import numpy as np
import heapq
import os
import shortuuid
import operator
import sys
import cv2
from typing import Dict, List, Optional, Tuple, Union
import torch
from torch import nn
# from ..drawBoxes import draw_boxes

from detectron2.config import configurable
from detectron2.layers import ShapeSpec, nonzero_tuple
from detectron2.structures import Boxes, ImageList, Instances, pairwise_iou
from detectron2.utils.events import get_event_storage
from detectron2.utils.registry import Registry

from ..backbone.resnet import BottleneckBlock, ResNet
from ..matcher import Matcher
from ..poolers import ROIPooler
from ..proposal_generator.proposal_utils import add_ground_truth_to_proposals
from ..sampling import subsample_labels
from .box_head import build_box_head
from .fast_rcnn import FastRCNNOutputLayers
from .keypoint_head import build_keypoint_head
from .mask_head import build_mask_head

ROI_HEADS_REGISTRY = Registry("ROI_HEADS")
ROI_HEADS_REGISTRY.__doc__ = """
Registry for ROI heads in a generalized R-CNN model.
ROIHeads take feature maps and region proposals, and
perform per-region computation.

The registered object will be called with `obj(cfg, input_shape)`.
The call is expected to return an :class:`ROIHeads`.
"""

logger = logging.getLogger(__name__)


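# Illustrative registry usage (this mirrors how Res5ROIHeads/StandardROIHeads
# are registered later in this file; `MyROIHeads` is a hypothetical example):
#
#   @ROI_HEADS_REGISTRY.register()
#   class MyROIHeads(ROIHeads):
#       ...
#
# `build_roi_heads` below then looks up cfg.MODEL.ROI_HEADS.NAME in the
# registry and instantiates the class with (cfg, input_shape).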
def build_roi_heads(cfg, input_shape):
    """
    Build ROIHeads defined by `cfg.MODEL.ROI_HEADS.NAME`.
    """
    name = cfg.MODEL.ROI_HEADS.NAME
    return ROI_HEADS_REGISTRY.get(name)(cfg, input_shape)


def select_foreground_proposals(
    proposals: List[Instances], bg_label: int
) -> Tuple[List[Instances], List[torch.Tensor]]:
    """
    Given a list of N Instances (for N images), each containing a `gt_classes` field,
    return a list of Instances that contain only instances with `gt_classes != -1 &&
    gt_classes != bg_label`.

    Args:
        proposals (list[Instances]): A list of N Instances, where N is the number of
            images in the batch.
        bg_label: label index of background class.

    Returns:
        list[Instances]: N Instances, each contains only the selected foreground instances.
        list[Tensor]: N boolean vectors, corresponding to the selection mask of
            each Instances object. True for selected instances.
    """
    assert isinstance(proposals, (list, tuple))
    assert isinstance(proposals[0], Instances)
    assert proposals[0].has("gt_classes")
    fg_proposals = []
    fg_selection_masks = []
    for proposals_per_image in proposals:
        gt_classes = proposals_per_image.gt_classes
        fg_selection_mask = (gt_classes != -1) & (gt_classes != bg_label)
        fg_idxs = fg_selection_mask.nonzero().squeeze(1)
        fg_proposals.append(proposals_per_image[fg_idxs])
        fg_selection_masks.append(fg_selection_mask)
    return fg_proposals, fg_selection_masks


def select_proposals_with_visible_keypoints(proposals: List[Instances]) -> List[Instances]:
    """
    Args:
        proposals (list[Instances]): a list of N Instances, where N is the
            number of images.

    Returns:
        proposals: only contains proposals with at least one visible keypoint.

    Note that this is still slightly different from Detectron.
    In Detectron, proposals for training keypoint head are re-sampled from
    all the proposals with IOU>threshold & >=1 visible keypoint.

    Here, the proposals are first sampled from all proposals with
    IOU>threshold, then proposals with no visible keypoint are filtered out.
    This strategy seems to make no difference on Detectron and is easier to implement.
    """
    ret = []
    all_num_fg = []
    for proposals_per_image in proposals:
        # If empty/unannotated image (hard negatives), skip filtering for train
        if len(proposals_per_image) == 0:
            ret.append(proposals_per_image)
            continue
        gt_keypoints = proposals_per_image.gt_keypoints.tensor
        # #fg x K x 3
        vis_mask = gt_keypoints[:, :, 2] >= 1
        xs, ys = gt_keypoints[:, :, 0], gt_keypoints[:, :, 1]
        proposal_boxes = proposals_per_image.proposal_boxes.tensor.unsqueeze(dim=1)  # #fg x 1 x 4
        kp_in_box = (
            (xs >= proposal_boxes[:, :, 0])
            & (xs <= proposal_boxes[:, :, 2])
            & (ys >= proposal_boxes[:, :, 1])
            & (ys <= proposal_boxes[:, :, 3])
        )
        selection = (kp_in_box & vis_mask).any(dim=1)
        selection_idxs = nonzero_tuple(selection)[0]
        all_num_fg.append(selection_idxs.numel())
        ret.append(proposals_per_image[selection_idxs])

    storage = get_event_storage()
    storage.put_scalar("keypoint_head/num_fg_samples", np.mean(all_num_fg))
    return ret


class ROIHeads(torch.nn.Module):
    """
    ROIHeads perform all per-region computation in an R-CNN.

    It typically contains logic to

    1. (in training only) match proposals with ground truth and sample them
    2. crop the regions and extract per-region features using proposals
    3. make per-region predictions with different heads

    It can have many variants, implemented as subclasses of this class.
    This base class contains the logic to match/sample proposals.
    But it is not necessary to inherit this class if the sampling logic is not needed.
    """

    @configurable
    def __init__(
        self,
        *,
        num_classes,
        batch_size_per_image,
        positive_fraction,
        proposal_matcher,
        proposal_matcher_unk,
        enable_thresold_autolabelling,
        unk_k,
        proposal_append_gt=True,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            num_classes (int): number of classes. Used to label background proposals.
            batch_size_per_image (int): number of proposals to sample for training
            positive_fraction (float): fraction of positive (foreground) proposals
                to sample for training.
            proposal_matcher (Matcher): matcher that matches proposals and ground truth
            proposal_matcher_unk (Matcher): matcher used to match auto-labelled unknown
                boxes against sampled proposals (OWOD).
            enable_thresold_autolabelling (bool): whether to auto-label high-objectness
                unmatched proposals as the "unknown" class (OWOD).
            unk_k (int): number of unknown instances per image
                (from cfg.OWOD.NUM_UNK_PER_IMAGE).
            proposal_append_gt (bool): whether to include ground truth as proposals as well
        """
        super().__init__()
        self.batch_size_per_image = batch_size_per_image
        self.positive_fraction = positive_fraction
        self.num_classes = num_classes
        self.proposal_matcher = proposal_matcher
        self.proposal_matcher_unk = proposal_matcher_unk
        self.proposal_append_gt = proposal_append_gt
        self.enable_thresold_autolabelling = enable_thresold_autolabelling
        self.unk_k = unk_k

    @classmethod
    def from_config(cls, cfg):
        return {
            "batch_size_per_image": cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE,
            "positive_fraction": cfg.MODEL.ROI_HEADS.POSITIVE_FRACTION,
            "num_classes": cfg.MODEL.ROI_HEADS.NUM_CLASSES,
            "proposal_append_gt": cfg.MODEL.ROI_HEADS.PROPOSAL_APPEND_GT,
            # Matcher to assign box proposals to gt boxes
            "proposal_matcher": Matcher(
                cfg.MODEL.ROI_HEADS.IOU_THRESHOLDS,
                cfg.MODEL.ROI_HEADS.IOU_LABELS,
                allow_low_quality_matches=False,
            ),
            # Matcher for auto-labelled unknown boxes. NOTE: the IoU threshold is
            # hardcoded to 0.8; the cfg key is left commented out in the original.
            "proposal_matcher_unk": Matcher(
                # cfg.MODEL.ROI_HEADS.IOU_THRESHOLDS_UNK,
                [0.8],
                cfg.MODEL.ROI_HEADS.IOU_LABELS,
                allow_low_quality_matches=False,
            ),
            "enable_thresold_autolabelling": cfg.OWOD.ENABLE_THRESHOLD_AUTOLABEL_UNK,
            "unk_k": cfg.OWOD.NUM_UNK_PER_IMAGE,
        }

    def _sample_proposals(
        self, matched_idxs: torch.Tensor, matched_labels: torch.Tensor, gt_classes: torch.Tensor,
        objectness_logits: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Based on the matching between N proposals and M groundtruth,
        sample the proposals and set their classification labels.

        Args:
            matched_idxs (Tensor): a vector of length N, each is the best-matched
                gt index in [0, M) for each proposal.
            matched_labels (Tensor): a vector of length N, the matcher's label
                (one of cfg.MODEL.ROI_HEADS.IOU_LABELS) for each proposal.
            gt_classes (Tensor): a vector of length M.
            objectness_logits (Tensor, optional): objectness scores of the N
                proposals (accepted for the OWOD call site; unused in this method).

        Returns:
            Tensor: a vector of indices of sampled proposals. Each is in [0, N).
            Tensor: a vector of the same length, the classification label for
                each sampled proposal. Each sample is labeled as either a category in
                [0, num_classes) or the background (num_classes).
        """
        has_gt = gt_classes.numel() > 0
        # Get the corresponding GT for each proposal
        if has_gt:
            gt_classes = gt_classes[matched_idxs]
            # Label unmatched proposals (0 label from matcher) as background (label=num_classes)
            gt_classes[matched_labels == 0] = self.num_classes
            # Label ignore proposals (-1 label)
            gt_classes[matched_labels == -1] = -1
        else:
            gt_classes = torch.zeros_like(matched_idxs) + self.num_classes

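        # A minimal reading of `subsample_labels`, per its use here: it returns
        # indices for at most `batch_size_per_image` proposals, with at most a
        # `positive_fraction` share drawn from foreground labels in
        # [0, num_classes); background is label == num_classes, and proposals
        # labeled -1 (ignore) are excluded from both pools.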
        sampled_fg_idxs, sampled_bg_idxs = subsample_labels(
            gt_classes, self.batch_size_per_image, self.positive_fraction, self.num_classes
        )

        sampled_idxs = torch.cat([sampled_fg_idxs, sampled_bg_idxs], dim=0)
        gt_classes_ss = gt_classes[sampled_idxs]

        return sampled_idxs, gt_classes_ss

    @torch.no_grad()
    def label_and_sample_proposals(
        self, proposals: List[Instances], targets: List[Instances], image_id=None, ori_image=None
    ) -> Tuple[List[Instances], List]:
        """
        Prepare some proposals to be used to train the ROI heads.
        It performs box matching between `proposals` and `targets`, and assigns
        training labels to the proposals.
        It returns ``self.batch_size_per_image`` random samples from proposals and groundtruth
        boxes, with a fraction of positives that is no larger than
        ``self.positive_fraction``.

        Args:
            See :meth:`ROIHeads.forward`

        Returns:
            list[Instances]:
                length `N` list of `Instances`s containing the proposals
                sampled for training. Each `Instances` has the following fields:

                - proposal_boxes: the proposal boxes
                - gt_boxes: the ground-truth box that the proposal is assigned to
                  (this is only meaningful if the proposal has a label > 0; if label = 0
                  then the ground-truth box is random)

                Other fields such as "gt_classes" and "gt_masks" that are included in `targets`.
            list:
                length `N` list of the auto-labelled unknown boxes per image
                (``unk_sel_gt``); an empty list for images without one.
        """
        gt_boxes = [x.gt_boxes for x in targets]
        # Augment proposals with ground-truth boxes.
        # In the case of learned proposals (e.g., RPN), when training starts
        # the proposals will be low quality due to random initialization.
        # It's possible that none of these initial
        # proposals have high enough overlap with the gt objects to be used
        # as positive examples for the second stage components (box head,
        # cls head, mask head). Adding the gt boxes to the set of proposals
        # ensures that the second stage components will have some positive
        # examples from the start of training. For RPN, this augmentation improves
        # convergence and empirically improves box AP on COCO by about 0.5
        # points (under one tested configuration).
        if self.proposal_append_gt:
            proposals = add_ground_truth_to_proposals(gt_boxes, proposals)

        proposals_with_gt = []
        unk_sel_gt = []

        num_fg_samples = []
        num_bg_samples = []
        for proposals_per_image, targets_per_image, image_id_i, ori_image_i in zip(
            proposals, targets, image_id, ori_image
        ):
            height_new, width_new = proposals_per_image.image_size
            has_gt = len(targets_per_image) > 0
            match_quality_matrix = pairwise_iou(
                targets_per_image.gt_boxes, proposals_per_image.proposal_boxes
            )
            matched_idxs, matched_labels = self.proposal_matcher(match_quality_matrix)
            sampled_idxs, gt_classes = self._sample_proposals(
                matched_idxs, matched_labels, targets_per_image.gt_classes, proposals_per_image.objectness_logits
            )
            del match_quality_matrix
            gt_flag = False
            unk_flag = False
            storage = get_event_storage()
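            # RE-OWOD threshold-based auto-labelling (after a 50k-iteration warm-up):
            # pick the top-50 background-matched proposals by RPN objectness,
            # cross-check them against objectness boxes pre-computed offline
            # (loaded from ../score_store/ below), and relabel the survivors as
            # the "unknown" class.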
            if self.enable_thresold_autolabelling and storage.iter > 50000:
                matched_labels_ss = matched_labels[sampled_idxs]
                pred_objectness_score_ss = proposals_per_image.objectness_logits[sampled_idxs]

                pred_objectness_score_ss[matched_labels_ss != 0] = -1
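                # Indices of the 50 highest-objectness candidates. Proposals the
                # matcher assigned to foreground/ignore were masked to -1 above,
                # so only background-matched proposals can rank highly here.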
                sorted_indices = list(zip(
                    *heapq.nlargest(50, enumerate(pred_objectness_score_ss), key=operator.itemgetter(1))))[0]
                mask = torch.zeros(pred_objectness_score_ss.shape, dtype=torch.bool)

                # Gather the candidate boxes and their objectness scores.
                new_flag = True
                for index in sorted_indices:
                    if new_flag:
                        autolabel_boxes = proposals_per_image.proposal_boxes[sampled_idxs[index].item()]
                        autolabel_score = proposals_per_image.objectness_logits[sampled_idxs[index].item()].view(1, -1)
                        new_flag = False
                    else:
                        box_i = proposals_per_image.proposal_boxes[sampled_idxs[index].item()]
                        score_i = proposals_per_image.objectness_logits[sampled_idxs[index].item()].view(1, -1)
                        autolabel_boxes = Boxes.cat([autolabel_boxes, box_i])
                        autolabel_score = torch.cat([autolabel_score, score_i], 1)

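                # Per-image objectness boxes/scores pre-computed offline and stored
                # as torch pickles under a hardcoded relative path; their
                # 'image_size' is the original (pre-resize) size, hence the box
                # rescaling below.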
                obj_save_path = "../score_store/" + image_id_i + ".jpg" + ".pickle"
                obj_score_save = torch.load(obj_save_path)
                height_ori, width_ori = obj_score_save['image_size']
                obj_score_boxes = obj_score_save['obj_boxes']

                if len(obj_score_boxes):
                    # Rescale the pre-computed boxes from the original image size
                    # to the (resized) training image size.
                    obj_boxes_sel = obj_score_boxes[:50, :4]
                    obj_boxes_sel[:, 0] = obj_boxes_sel[:, 0] * (width_new * 1.0 / width_ori)
                    obj_boxes_sel[:, 1] = obj_boxes_sel[:, 1] * (height_new * 1.0 / height_ori)
                    obj_boxes_sel[:, 2] = obj_boxes_sel[:, 2] * (width_new * 1.0 / width_ori)
                    obj_boxes_sel[:, 3] = obj_boxes_sel[:, 3] * (height_new * 1.0 / height_ori)
                    obj_boxes = Boxes(torch.Tensor(obj_boxes_sel).cuda())

                    # Discard pre-computed boxes covering >= 80% of the image area.
                    area_new = width_new * height_new
                    area_mask = obj_boxes.area() / area_new
                    area_mask = area_mask < 0.8

                    unk_match_matrix = pairwise_iou(obj_boxes, autolabel_boxes)
                    unk_match_matrix[unk_match_matrix < 0.9] = 0  # 0.7
                    score_matrix = torch.mm(area_mask.view(-1, 1).float(), autolabel_score.view(1, -1))
                    score_matrix = torch.mul(score_matrix, unk_match_matrix)
                    score_matrix, _ = torch.max(score_matrix, 0)
                    _, unk_max_index = torch.max(score_matrix, 0)
                    unk_obj_index = torch.nonzero(score_matrix).cpu()
                    del unk_match_matrix
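                    # A candidate survives if some pre-computed box both covers
                    # < 80% of the image (area_mask) and overlaps it with
                    # IoU >= 0.9; the highest-scoring survivor becomes a pseudo
                    # ground-truth "unknown" instance.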
                    if len(unk_obj_index):  # pseudo ground truth
                        gt_flag = True
                        unk_instances_gt = Instances(proposals_per_image.image_size)
                        unk_box = autolabel_boxes[unk_max_index.item()]
                        unk_instances_gt.gt_boxes = unk_box
                        unk_instances_gt.gt_classes = torch.Tensor([80]).long().cuda()
                        targets_per_image = Instances.cat([targets_per_image, unk_instances_gt])

                        # Mark sampled proposals that overlap the pseudo GT
                        # (IoU >= 0.8 via proposal_matcher_unk) as unknown too.
                        match_quality_matrix = pairwise_iou(
                            unk_instances_gt.gt_boxes, proposals_per_image.proposal_boxes[sampled_idxs]
                        )
                        _, matched_labels_unk = self.proposal_matcher_unk(match_quality_matrix)
                        del match_quality_matrix
                        matched_unk_mask = matched_labels_unk == 1
                        matched_unk_mask_idx = torch.nonzero(matched_unk_mask)
                        for index in matched_unk_mask_idx:
                            if sampled_idxs[index] < 100:
                                mask[index] = True

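                    # Second, looser pass (IoU >= 0.7): mark every remaining
                    # candidate that overlaps a pre-computed box as unknown as well.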
                    unk_match_matrix = pairwise_iou(obj_boxes, autolabel_boxes)
                    unk_match_matrix[unk_match_matrix < 0.7] = 0

                    score_matrix = torch.mm(area_mask.view(-1, 1).float(), autolabel_score.view(1, -1))
                    score_matrix = torch.mul(score_matrix, unk_match_matrix)
                    score_matrix, _ = torch.max(score_matrix, 0)
                    del unk_match_matrix
                    unk_obj_index = torch.nonzero(score_matrix).cpu()
                    score_matrix = score_matrix[score_matrix > 0]

                    if len(unk_obj_index):
                        for idx in unk_obj_index:
                            mask[sorted_indices[idx]] = True
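                # NOTE: 80 is the hardcoded unknown-class index (the OWOD setting
                # uses NUM_CLASSES == 81, with the last index reserved for
                # "unknown"); it matches the torch.Tensor([80]) pseudo GT above.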
                gt_classes[mask] = 80

            # Set target attributes of the sampled proposals:
            proposals_per_image = proposals_per_image[sampled_idxs]
            proposals_per_image.gt_classes = gt_classes

            # We index all the attributes of targets that start with "gt_"
            # and have not been added to proposals yet (="gt_classes").
            if has_gt:
                sampled_targets = matched_idxs[sampled_idxs]
                # NOTE: here the indexing wastes some compute, because heads
                # like masks, keypoints, etc, will filter the proposals again,
                # (by foreground/background, or number of keypoints in the image, etc)
                # so we essentially index the data twice.
                for (trg_name, trg_value) in targets_per_image.get_fields().items():
                    if trg_name.startswith("gt_") and not proposals_per_image.has(trg_name):
                        proposals_per_image.set(trg_name, trg_value[sampled_targets])
            else:
                gt_boxes = Boxes(
                    targets_per_image.gt_boxes.tensor.new_zeros((len(sampled_idxs), 4))
                )
                proposals_per_image.gt_boxes = gt_boxes

            num_bg_samples.append((gt_classes == self.num_classes).sum().item())
            num_fg_samples.append(gt_classes.numel() - num_bg_samples[-1])

            if gt_flag:
                unk_sel_gt.append(unk_instances_gt.gt_boxes)
            else:
                unk_gt_boxes = []
                unk_sel_gt.append(unk_gt_boxes)
            proposals_with_gt.append(proposals_per_image)

        # Log the number of fg/bg samples that are selected for training ROI heads
        # storage = get_event_storage()
        storage.put_scalar("roi_head/num_fg_samples", np.mean(num_fg_samples))
        storage.put_scalar("roi_head/num_bg_samples", np.mean(num_bg_samples))
        return proposals_with_gt, unk_sel_gt

    def forward(
        self,
        images: ImageList,
        features: Dict[str, torch.Tensor],
        proposals: List[Instances],
        targets: Optional[List[Instances]] = None,
    ) -> Tuple[List[Instances], Dict[str, torch.Tensor]]:
        """
        Args:
            images (ImageList):
            features (dict[str,Tensor]): input data as a mapping from feature
                map name to tensor. Axis 0 represents the number of images `N` in
                the input data; axes 1-3 are channels, height, and width, which may
                vary between feature maps (e.g., if a feature pyramid is used).
            proposals (list[Instances]): length `N` list of `Instances`. The i-th
                `Instances` contains object proposals for the i-th input image,
                with fields "proposal_boxes" and "objectness_logits".
            targets (list[Instances], optional): length `N` list of `Instances`. The i-th
                `Instances` contains the ground-truth per-instance annotations
                for the i-th input image. Specify `targets` during training only.
                It may have the following fields:

                - gt_boxes: the bounding box of each instance.
                - gt_classes: the label for each instance with a category ranging in [0, #class].
                - gt_masks: PolygonMasks or BitMasks, the ground-truth masks of each instance.
                - gt_keypoints: NxKx3, the ground-truth keypoints for each instance.

        Returns:
            list[Instances]: length `N` list of `Instances` containing the
                detected instances. Returned during inference only; may be [] during training.

            dict[str->Tensor]:
                mapping from a named loss to a tensor storing the loss. Used during training only.
        """
        raise NotImplementedError()


@ROI_HEADS_REGISTRY.register()
class Res5ROIHeads(ROIHeads):
    """
    The ROIHeads in a typical "C4" R-CNN model, where
    the box and mask head share the cropping and
    the per-region feature computation by a Res5 block.
    """

    def __init__(self, cfg, input_shape):
        super().__init__(cfg)

        # fmt: off
        self.in_features         = cfg.MODEL.ROI_HEADS.IN_FEATURES
        pooler_resolution        = cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION
        pooler_type              = cfg.MODEL.ROI_BOX_HEAD.POOLER_TYPE
        pooler_scales            = (1.0 / input_shape[self.in_features[0]].stride, )
        sampling_ratio           = cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO
        self.mask_on             = cfg.MODEL.MASK_ON
        self.enable_clustering   = cfg.OWOD.ENABLE_CLUSTERING
        self.compute_energy_flag = cfg.OWOD.COMPUTE_ENERGY
        self.energy_save_path    = os.path.join(cfg.OUTPUT_DIR, cfg.OWOD.ENERGY_SAVE_PATH)
        # fmt: on
        assert not cfg.MODEL.KEYPOINT_ON
        assert len(self.in_features) == 1

        self.pooler = ROIPooler(
            output_size=pooler_resolution,
            scales=pooler_scales,
            sampling_ratio=sampling_ratio,
            pooler_type=pooler_type,
        )

        self.res5, out_channels = self._build_res5_block(cfg)
        self.box_predictor = FastRCNNOutputLayers(
            cfg, ShapeSpec(channels=out_channels, height=1, width=1)
        )

        if self.mask_on:
            self.mask_head = build_mask_head(
                cfg,
                ShapeSpec(channels=out_channels, width=pooler_resolution, height=pooler_resolution),
            )

    def _build_res5_block(self, cfg):
        # fmt: off
        stage_channel_factor = 2 ** 3  # res5 is 8x res2
        num_groups           = cfg.MODEL.RESNETS.NUM_GROUPS
        width_per_group      = cfg.MODEL.RESNETS.WIDTH_PER_GROUP
        bottleneck_channels  = num_groups * width_per_group * stage_channel_factor
        out_channels         = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS * stage_channel_factor
        stride_in_1x1        = cfg.MODEL.RESNETS.STRIDE_IN_1X1
        norm                 = cfg.MODEL.RESNETS.NORM
        assert not cfg.MODEL.RESNETS.DEFORM_ON_PER_STAGE[-1], \
            "Deformable conv is not yet supported in res5 head."
        # fmt: on

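        # For reference, with the default R50 config (NUM_GROUPS=1,
        # WIDTH_PER_GROUP=64, RES2_OUT_CHANNELS=256) this yields
        # bottleneck_channels = 1 * 64 * 8 = 512 and out_channels = 256 * 8 = 2048.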
        blocks = ResNet.make_stage(
            BottleneckBlock,
            3,
            stride_per_block=[2, 1, 1],
            in_channels=out_channels // 2,
            bottleneck_channels=bottleneck_channels,
            out_channels=out_channels,
            num_groups=num_groups,
            norm=norm,
            stride_in_1x1=stride_in_1x1,
        )
        return nn.Sequential(*blocks), out_channels

    def _shared_roi_transform(self, features, boxes):
        x = self.pooler(features, boxes)
        return self.res5(x)

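    # Both helpers below dump (tensor, gt_classes) pairs to .pkl files for
    # offline analysis; compute_energy stores the per-proposal classification
    # logits on which OWOD-style energy-based unknown identification is fit.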
    def log_features(self, features, proposals):
        gt_classes = torch.cat([p.gt_classes for p in proposals])
        data = (features, gt_classes)
        # NOTE: hardcoded absolute path from the original authors' machine.
        location = '/home/fk1/workspace/OWOD/output/features/' + shortuuid.uuid() + '.pkl'
        torch.save(data, location)

    def compute_energy(self, predictions, proposals):
        gt_classes = torch.cat([p.gt_classes for p in proposals])
        logits = predictions[0]
        data = (logits, gt_classes)
        location = os.path.join(self.energy_save_path, shortuuid.uuid() + '.pkl')
        torch.save(data, location)

    def forward(self, images, features, proposals, targets=None, image_id=None, ori_image=None):
        """
        See :meth:`ROIHeads.forward`.
        """
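        # NOTE: unlike the base interface, this forward returns a third element:
        # the auto-labelled unknown GT boxes (unk_sel_gt) in training, [] at inference.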
        del images

        if self.training:
            assert targets
            proposals, unk_sel_gt = self.label_and_sample_proposals(proposals, targets, image_id, ori_image)
            del targets

        proposal_boxes = [x.proposal_boxes for x in proposals]
        box_features = self._shared_roi_transform(
            [features[f] for f in self.in_features], proposal_boxes
        )
        input_features = box_features.mean(dim=[2, 3])
        predictions = self.box_predictor(input_features)

        if self.training:
            # self.log_features(input_features, proposals)
            if self.enable_clustering:
                self.box_predictor.update_feature_store(input_features, proposals)
            del features
            if self.compute_energy_flag:
                self.compute_energy(predictions, proposals)
            losses = self.box_predictor.losses(predictions, proposals, input_features)
            if self.mask_on:
                proposals, fg_selection_masks = select_foreground_proposals(
                    proposals, self.num_classes
                )
                # Since the ROI feature transform is shared between boxes and masks,
                # we don't need to recompute features. The mask loss is only defined
                # on foreground proposals, so we need to select out the foreground
                # features.
                mask_features = box_features[torch.cat(fg_selection_masks, dim=0)]
                del box_features
                losses.update(self.mask_head(mask_features, proposals))
            return [], losses, unk_sel_gt
        else:
            pred_instances, _ = self.box_predictor.inference(predictions, proposals)
            pred_instances = self.forward_with_given_boxes(features, pred_instances)
            return pred_instances, {}, []

    def forward_with_given_boxes(self, features, instances):
        """
        Use the given boxes in `instances` to produce other (non-box) per-ROI outputs.

        Args:
            features: same as in `forward()`
            instances (list[Instances]): instances to predict other outputs. Expect the keys
                "pred_boxes" and "pred_classes" to exist.

        Returns:
            instances (Instances):
                the same `Instances` object, with extra
                fields such as `pred_masks` or `pred_keypoints`.
        """
        assert not self.training
        assert instances[0].has("pred_boxes") and instances[0].has("pred_classes")

        if self.mask_on:
            features = [features[f] for f in self.in_features]
            x = self._shared_roi_transform(features, [x.pred_boxes for x in instances])
            return self.mask_head(x, instances)
        else:
            return instances


@ROI_HEADS_REGISTRY.register()
class StandardROIHeads(ROIHeads):
    """
    It's "standard" in a sense that there is no ROI transform sharing
    or feature sharing between tasks.
    Each head independently processes the input features by each head's
    own pooler and head.

    This class is used by most models, such as FPN and C5.
    To implement more models, you can subclass it and implement a different
    :meth:`forward()` or a head.
    """

    @configurable
    def __init__(
        self,
        *,
        box_in_features: List[str],
        box_pooler: ROIPooler,
        box_head: nn.Module,
        box_predictor: nn.Module,
        mask_in_features: Optional[List[str]] = None,
        mask_pooler: Optional[ROIPooler] = None,
        mask_head: Optional[nn.Module] = None,
        keypoint_in_features: Optional[List[str]] = None,
        keypoint_pooler: Optional[ROIPooler] = None,
        keypoint_head: Optional[nn.Module] = None,
        train_on_pred_boxes: bool = False,
        **kwargs
    ):
        """
        NOTE: this interface is experimental.

        Args:
            box_in_features (list[str]): list of feature names to use for the box head.
            box_pooler (ROIPooler): pooler to extract region features for box head
            box_head (nn.Module): transform features to make box predictions
            box_predictor (nn.Module): make box predictions from the feature.
                Should have the same interface as :class:`FastRCNNOutputLayers`.
            mask_in_features (list[str]): list of feature names to use for the mask
                pooler or mask head. None if not using mask head.
            mask_pooler (ROIPooler): pooler to extract region features from image features.
                The mask head will then take region features to make predictions.
                If None, the mask head will directly take the dict of image features
                defined by `mask_in_features`
            mask_head (nn.Module): transform features to make mask predictions
            keypoint_in_features, keypoint_pooler, keypoint_head: similar to ``mask_*``.
            train_on_pred_boxes (bool): whether to use proposal boxes or
                predicted boxes from the box head to train other heads.
        """
        super().__init__(**kwargs)
        # keep self.in_features for backward compatibility
        self.in_features = self.box_in_features = box_in_features
        self.box_pooler = box_pooler
        self.box_head = box_head
        self.box_predictor = box_predictor

        self.mask_on = mask_in_features is not None
        if self.mask_on:
            self.mask_in_features = mask_in_features
            self.mask_pooler = mask_pooler
            self.mask_head = mask_head
        self.keypoint_on = keypoint_in_features is not None
        if self.keypoint_on:
            self.keypoint_in_features = keypoint_in_features
            self.keypoint_pooler = keypoint_pooler
            self.keypoint_head = keypoint_head

        self.train_on_pred_boxes = train_on_pred_boxes

    @classmethod
    def from_config(cls, cfg, input_shape):
        ret = super().from_config(cfg)
        ret["train_on_pred_boxes"] = cfg.MODEL.ROI_BOX_HEAD.TRAIN_ON_PRED_BOXES
        # Subclasses that have not been updated to use from_config style construction
        # may have overridden _init_*_head methods. In this case, those overridden methods
        # will not be classmethods and we need to avoid trying to call them here.
        # We test for this with ismethod which only returns True for bound methods of cls.
        # Such subclasses will need to handle calling their overridden _init_*_head methods.
        if inspect.ismethod(cls._init_box_head):
            ret.update(cls._init_box_head(cfg, input_shape))
        if inspect.ismethod(cls._init_mask_head):
            ret.update(cls._init_mask_head(cfg, input_shape))
        if inspect.ismethod(cls._init_keypoint_head):
            ret.update(cls._init_keypoint_head(cfg, input_shape))
        return ret

    @classmethod
    def _init_box_head(cls, cfg, input_shape):
        # fmt: off
        in_features       = cfg.MODEL.ROI_HEADS.IN_FEATURES
        pooler_resolution = cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION
        pooler_scales     = tuple(1.0 / input_shape[k].stride for k in in_features)
        sampling_ratio    = cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO
        pooler_type       = cfg.MODEL.ROI_BOX_HEAD.POOLER_TYPE
        # fmt: on

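        # e.g., with an FPN backbone, IN_FEATURES is typically ["p2", "p3", "p4", "p5"],
        # whose strides (4, 8, 16, 32) give pooler_scales of (1/4, 1/8, 1/16, 1/32).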
        # If StandardROIHeads is applied on multiple feature maps (as in FPN),
        # then we share the same predictors and therefore the channel counts must be the same
        in_channels = [input_shape[f].channels for f in in_features]
        # Check all channel counts are equal
        assert len(set(in_channels)) == 1, in_channels
        in_channels = in_channels[0]

        box_pooler = ROIPooler(
            output_size=pooler_resolution,
            scales=pooler_scales,
            sampling_ratio=sampling_ratio,
            pooler_type=pooler_type,
        )
        # Here we split "box head" and "box predictor", which is mainly due to historical reasons.
        # They are used together so the "box predictor" layers should be part of the "box head".
        # New subclasses of ROIHeads do not need "box predictor"s.
        box_head = build_box_head(
            cfg, ShapeSpec(channels=in_channels, height=pooler_resolution, width=pooler_resolution)
        )
        box_predictor = FastRCNNOutputLayers(cfg, box_head.output_shape)
        return {
            "box_in_features": in_features,
            "box_pooler": box_pooler,
            "box_head": box_head,
            "box_predictor": box_predictor,
        }

    @classmethod
    def _init_mask_head(cls, cfg, input_shape):
        if not cfg.MODEL.MASK_ON:
            return {}
        # fmt: off
        in_features       = cfg.MODEL.ROI_HEADS.IN_FEATURES
        pooler_resolution = cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION
        pooler_scales     = tuple(1.0 / input_shape[k].stride for k in in_features)
        sampling_ratio    = cfg.MODEL.ROI_MASK_HEAD.POOLER_SAMPLING_RATIO
        pooler_type       = cfg.MODEL.ROI_MASK_HEAD.POOLER_TYPE
        # fmt: on

        in_channels = [input_shape[f].channels for f in in_features][0]

        ret = {"mask_in_features": in_features}
        ret["mask_pooler"] = (
            ROIPooler(
                output_size=pooler_resolution,
                scales=pooler_scales,
                sampling_ratio=sampling_ratio,
                pooler_type=pooler_type,
            )
            if pooler_type
            else None
        )
        if pooler_type:
            shape = ShapeSpec(
                channels=in_channels, width=pooler_resolution, height=pooler_resolution
            )
        else:
            shape = {f: input_shape[f] for f in in_features}
        ret["mask_head"] = build_mask_head(cfg, shape)
        return ret

    @classmethod
    def _init_keypoint_head(cls, cfg, input_shape):
        if not cfg.MODEL.KEYPOINT_ON:
            return {}
        # fmt: off
        in_features       = cfg.MODEL.ROI_HEADS.IN_FEATURES
        pooler_resolution = cfg.MODEL.ROI_KEYPOINT_HEAD.POOLER_RESOLUTION
        pooler_scales     = tuple(1.0 / input_shape[k].stride for k in in_features)  # noqa
        sampling_ratio    = cfg.MODEL.ROI_KEYPOINT_HEAD.POOLER_SAMPLING_RATIO
        pooler_type       = cfg.MODEL.ROI_KEYPOINT_HEAD.POOLER_TYPE
        # fmt: on

        in_channels = [input_shape[f].channels for f in in_features][0]

        ret = {"keypoint_in_features": in_features}
        ret["keypoint_pooler"] = (
            ROIPooler(
                output_size=pooler_resolution,
                scales=pooler_scales,
                sampling_ratio=sampling_ratio,
                pooler_type=pooler_type,
            )
            if pooler_type
            else None
        )
        if pooler_type:
            shape = ShapeSpec(
                channels=in_channels, width=pooler_resolution, height=pooler_resolution
            )
        else:
            shape = {f: input_shape[f] for f in in_features}
        ret["keypoint_head"] = build_keypoint_head(cfg, shape)
        return ret

    def forward(
        self,
        images: ImageList,
        features: Dict[str, torch.Tensor],
        proposals: List[Instances],
        targets: Optional[List[Instances]] = None,
    ) -> Tuple[List[Instances], Dict[str, torch.Tensor]]:
        """
        See :class:`ROIHeads.forward`.
        """
        del images
        if self.training:
            assert targets
            # NOTE: `label_and_sample_proposals` was changed for OWOD to return
            # (proposals, unk_sel_gt) and to expect per-image `image_id`/`ori_image`
            # iterables; this call site still uses the original interface.
            proposals = self.label_and_sample_proposals(proposals, targets)
            del targets

        if self.training:
            losses = self._forward_box(features, proposals)
            # Usually the original proposals used by the box head are used by the mask, keypoint
            # heads. But when `self.train_on_pred_boxes is True`, proposals will contain boxes
            # predicted by the box head.
            losses.update(self._forward_mask(features, proposals))
            losses.update(self._forward_keypoint(features, proposals))
            return proposals, losses
        else:
            pred_instances = self._forward_box(features, proposals)
            # During inference cascaded prediction is used: the mask and keypoints heads are only
            # applied to the top scoring box detections.
            pred_instances = self.forward_with_given_boxes(features, pred_instances)
            return pred_instances, {}

    def forward_with_given_boxes(
        self, features: Dict[str, torch.Tensor], instances: List[Instances]
    ) -> List[Instances]:
        """
        Use the given boxes in `instances` to produce other (non-box) per-ROI outputs.

        This is useful for downstream tasks where a box is known, but one needs to obtain
        other attributes (outputs of other heads).
        Test-time augmentation also uses this.

        Args:
            features: same as in `forward()`
            instances (list[Instances]): instances to predict other outputs. Expect the keys
                "pred_boxes" and "pred_classes" to exist.

        Returns:
            instances (list[Instances]):
                the same `Instances` objects, with extra
                fields such as `pred_masks` or `pred_keypoints`.
        """
        assert not self.training
        assert instances[0].has("pred_boxes") and instances[0].has("pred_classes")

        instances = self._forward_mask(features, instances)
        instances = self._forward_keypoint(features, instances)
        return instances

    def _forward_box(
        self, features: Dict[str, torch.Tensor], proposals: List[Instances]
    ) -> Union[Dict[str, torch.Tensor], List[Instances]]:
        """
        Forward logic of the box prediction branch. If `self.train_on_pred_boxes is True`,
        the function puts predicted boxes in the `proposal_boxes` field of `proposals` argument.

        Args:
            features (dict[str, Tensor]): mapping from feature map names to tensor.
                Same as in :meth:`ROIHeads.forward`.
            proposals (list[Instances]): the per-image object proposals with
                their matching ground truth.
                Each has fields "proposal_boxes", and "objectness_logits",
                "gt_classes", "gt_boxes".

        Returns:
            In training, a dict of losses.
            In inference, a list of `Instances`, the predicted instances.
        """
        features = [features[f] for f in self.box_in_features]
        box_features = self.box_pooler(features, [x.proposal_boxes for x in proposals])
        box_features = self.box_head(box_features)
        predictions = self.box_predictor(box_features)
        del box_features

        if self.training:
            losses = self.box_predictor.losses(predictions, proposals)
            # proposals is modified in-place below, so losses must be computed first.
            if self.train_on_pred_boxes:
                with torch.no_grad():
                    pred_boxes = self.box_predictor.predict_boxes_for_gt_classes(
                        predictions, proposals
                    )
                    for proposals_per_image, pred_boxes_per_image in zip(proposals, pred_boxes):
                        proposals_per_image.proposal_boxes = Boxes(pred_boxes_per_image)
            return losses
        else:
            pred_instances, _ = self.box_predictor.inference(predictions, proposals)
            return pred_instances

    def _forward_mask(
        self, features: Dict[str, torch.Tensor], instances: List[Instances]
    ) -> Union[Dict[str, torch.Tensor], List[Instances]]:
        """
        Forward logic of the mask prediction branch.

        Args:
            features (dict[str, Tensor]): mapping from feature map names to tensor.
                Same as in :meth:`ROIHeads.forward`.
            instances (list[Instances]): the per-image instances to train/predict masks.
                In training, they can be the proposals.
                In inference, they can be the boxes predicted by R-CNN box head.

        Returns:
            In training, a dict of losses.
            In inference, update `instances` with new fields "pred_masks" and return it.
        """
        if not self.mask_on:
            return {} if self.training else instances

        if self.training:
            # head is only trained on positive proposals.
            instances, _ = select_foreground_proposals(instances, self.num_classes)

        if self.mask_pooler is not None:
            features = [features[f] for f in self.mask_in_features]
            boxes = [x.proposal_boxes if self.training else x.pred_boxes for x in instances]
            features = self.mask_pooler(features, boxes)
        else:
            features = {f: features[f] for f in self.mask_in_features}
        return self.mask_head(features, instances)

    def _forward_keypoint(
        self, features: Dict[str, torch.Tensor], instances: List[Instances]
    ) -> Union[Dict[str, torch.Tensor], List[Instances]]:
        """
        Forward logic of the keypoint prediction branch.

        Args:
            features (dict[str, Tensor]): mapping from feature map names to tensor.
                Same as in :meth:`ROIHeads.forward`.
            instances (list[Instances]): the per-image instances to train/predict keypoints.
                In training, they can be the proposals.
                In inference, they can be the boxes predicted by R-CNN box head.

        Returns:
            In training, a dict of losses.
            In inference, update `instances` with new fields "pred_keypoints" and return it.
        """
        if not self.keypoint_on:
            return {} if self.training else instances

        if self.training:
            # head is only trained on positive proposals with >=1 visible keypoints.
            instances, _ = select_foreground_proposals(instances, self.num_classes)
            instances = select_proposals_with_visible_keypoints(instances)

        if self.keypoint_pooler is not None:
            features = [features[f] for f in self.keypoint_in_features]
            boxes = [x.proposal_boxes if self.training else x.pred_boxes for x in instances]
            features = self.keypoint_pooler(features, boxes)
        else:
            features = {f: features[f] for f in self.keypoint_in_features}
        return self.keypoint_head(features, instances)