mirror of https://github.com/RE-OWOD/RE-OWOD
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
import math
import numpy as np
from typing import List
import torch
from fvcore.nn import giou_loss, sigmoid_focal_loss_jit, smooth_l1_loss
from torch import nn
from torch.nn import functional as F

from detectron2.config import configurable
from detectron2.data.detection_utils import convert_image_to_rgb
from detectron2.layers import ShapeSpec, batched_nms, cat, get_norm
from detectron2.structures import Boxes, ImageList, Instances, pairwise_iou
from detectron2.utils.events import get_event_storage

from ..anchor_generator import build_anchor_generator
from ..backbone import build_backbone
from ..box_regression import Box2BoxTransform
from ..matcher import Matcher
from ..postprocessing import detector_postprocess
from .build import META_ARCH_REGISTRY

__all__ = ["RetinaNet"]


def permute_to_N_HWA_K(tensor, K):
    """
    Transpose/reshape a tensor from (N, (Ai x K), H, W) to (N, (HxWxAi), K)
    """
    assert tensor.dim() == 4, tensor.shape
    N, _, H, W = tensor.shape
    tensor = tensor.view(N, -1, K, H, W)
    tensor = tensor.permute(0, 3, 4, 1, 2)
    tensor = tensor.reshape(N, -1, K)  # Size=(N,HWA,K)
    return tensor


@META_ARCH_REGISTRY.register()
class RetinaNet(nn.Module):
    """
    Implement RetinaNet in :paper:`RetinaNet`.
    """

    def __init__(self, cfg):
        super().__init__()
        # fmt: off
        self.num_classes = cfg.MODEL.RETINANET.NUM_CLASSES
        self.in_features = cfg.MODEL.RETINANET.IN_FEATURES
        # Loss parameters:
        self.focal_loss_alpha = cfg.MODEL.RETINANET.FOCAL_LOSS_ALPHA
        self.focal_loss_gamma = cfg.MODEL.RETINANET.FOCAL_LOSS_GAMMA
        self.smooth_l1_loss_beta = cfg.MODEL.RETINANET.SMOOTH_L1_LOSS_BETA
        self.box_reg_loss_type = cfg.MODEL.RETINANET.BBOX_REG_LOSS_TYPE
        # Inference parameters:
        self.score_threshold = cfg.MODEL.RETINANET.SCORE_THRESH_TEST
        self.topk_candidates = cfg.MODEL.RETINANET.TOPK_CANDIDATES_TEST
        self.nms_threshold = cfg.MODEL.RETINANET.NMS_THRESH_TEST
        self.max_detections_per_image = cfg.TEST.DETECTIONS_PER_IMAGE
        # Vis parameters
        self.vis_period = cfg.VIS_PERIOD
        self.input_format = cfg.INPUT.FORMAT
        # fmt: on

        self.backbone = build_backbone(cfg)

        backbone_shape = self.backbone.output_shape()
        feature_shapes = [backbone_shape[f] for f in self.in_features]
        self.head = RetinaNetHead(cfg, feature_shapes)
        self.anchor_generator = build_anchor_generator(cfg, feature_shapes)

        # Matching and loss
        self.box2box_transform = Box2BoxTransform(weights=cfg.MODEL.RETINANET.BBOX_REG_WEIGHTS)
        self.anchor_matcher = Matcher(
            cfg.MODEL.RETINANET.IOU_THRESHOLDS,
            cfg.MODEL.RETINANET.IOU_LABELS,
            allow_low_quality_matches=True,
        )
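        # The matcher labels every anchor by its IoU with the GT boxes; in label_anchors() a label
        # of 0 is treated as background, -1 as ignore, and everything else as foreground.
        # allow_low_quality_matches keeps, for each GT box, its best-matching anchor even when that
        # anchor's IoU falls below the configured thresholds.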

        self.register_buffer("pixel_mean", torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(-1, 1, 1))
        self.register_buffer("pixel_std", torch.Tensor(cfg.MODEL.PIXEL_STD).view(-1, 1, 1))

"""
|
|
In Detectron1, loss is normalized by number of foreground samples in the batch.
|
|
When batch size is 1 per GPU, #foreground has a large variance and
|
|
using it lead to lower performance. Here we maintain an EMA of #foreground to
|
|
stabilize the normalizer.
|
|
"""
|
|
self.loss_normalizer = 100 # initialize with any reasonable #fg that's not too small
|
|
self.loss_normalizer_momentum = 0.9
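        # losses() updates this as loss_normalizer = 0.9 * loss_normalizer + 0.1 * max(#pos_anchors, 1),
        # so both losses are divided by a smoothed (EMA) estimate of the foreground-anchor count.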

    @property
    def device(self):
        return self.pixel_mean.device

    def visualize_training(self, batched_inputs, results):
        """
        A function used to visualize ground truth images and final network predictions.
        It shows ground truth bounding boxes on the original image and up to 20
        predicted object bounding boxes on the original image.

        Args:
            batched_inputs (list): a list that contains input to the model.
            results (List[Instances]): a list of #images elements.
        """
        from detectron2.utils.visualizer import Visualizer

        assert len(batched_inputs) == len(
            results
        ), "Cannot visualize inputs and results of different sizes"
        storage = get_event_storage()
        max_boxes = 20

        image_index = 0  # only visualize a single image
        img = batched_inputs[image_index]["image"]
        img = convert_image_to_rgb(img.permute(1, 2, 0), self.input_format)
        v_gt = Visualizer(img, None)
        v_gt = v_gt.overlay_instances(boxes=batched_inputs[image_index]["instances"].gt_boxes)
        anno_img = v_gt.get_image()
        processed_results = detector_postprocess(results[image_index], img.shape[0], img.shape[1])
        predicted_boxes = processed_results.pred_boxes.tensor.detach().cpu().numpy()

        v_pred = Visualizer(img, None)
        v_pred = v_pred.overlay_instances(boxes=predicted_boxes[0:max_boxes])
        prop_img = v_pred.get_image()
        vis_img = np.vstack((anno_img, prop_img))
        vis_img = vis_img.transpose(2, 0, 1)
        vis_name = f"Top: GT bounding boxes; Bottom: {max_boxes} Highest Scoring Results"
        storage.put_image(vis_name, vis_img)

    def forward(self, batched_inputs):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper` .
                Each item in the list contains the inputs for one image.
                For now, each item in the list is a dict that contains:

                * image: Tensor, image in (C, H, W) format.
                * instances: Instances

                Other information that's included in the original dicts, such as:

                * "height", "width" (int): the output resolution of the model, used in inference.
                  See :meth:`postprocess` for details.
        Returns:
            dict[str, Tensor]:
                mapping from a named loss to a tensor storing the loss. Used during training only.
        """
        images = self.preprocess_image(batched_inputs)
        features = self.backbone(images.tensor)
        features = [features[f] for f in self.in_features]

        anchors = self.anchor_generator(features)
        pred_logits, pred_anchor_deltas = self.head(features)
        # Transpose the Hi*Wi*A dimension to the middle:
        pred_logits = [permute_to_N_HWA_K(x, self.num_classes) for x in pred_logits]
        pred_anchor_deltas = [permute_to_N_HWA_K(x, 4) for x in pred_anchor_deltas]
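        # After the permute, each pred_logits[i] is (N, Hi*Wi*A, num_classes) and each
        # pred_anchor_deltas[i] is (N, Hi*Wi*A, 4), i.e. one row per anchor on feature level i.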

        if self.training:
            assert "instances" in batched_inputs[0], "Instance annotations are missing in training!"
            gt_instances = [x["instances"].to(self.device) for x in batched_inputs]

            gt_labels, gt_boxes = self.label_anchors(anchors, gt_instances)
            losses = self.losses(anchors, pred_logits, gt_labels, pred_anchor_deltas, gt_boxes)

            if self.vis_period > 0:
                storage = get_event_storage()
                if storage.iter % self.vis_period == 0:
                    results = self.inference(
                        anchors, pred_logits, pred_anchor_deltas, images.image_sizes
                    )
                    self.visualize_training(batched_inputs, results)

            return losses
        else:
            results = self.inference(anchors, pred_logits, pred_anchor_deltas, images.image_sizes)
            processed_results = []
            for results_per_image, input_per_image, image_size in zip(
                results, batched_inputs, images.image_sizes
            ):
                height = input_per_image.get("height", image_size[0])
                width = input_per_image.get("width", image_size[1])
                r = detector_postprocess(results_per_image, height, width)
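                # detector_postprocess rescales the predicted boxes from the resized, padded
                # network input back to the requested output resolution ("height" x "width").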
                processed_results.append({"instances": r})
            return processed_results

    def losses(self, anchors, pred_logits, gt_labels, pred_anchor_deltas, gt_boxes):
        """
        Args:
            anchors (list[Boxes]): a list of #feature level Boxes
            gt_labels, gt_boxes: see output of :meth:`RetinaNet.label_anchors`.
                Their shapes are (N, R) and (N, R, 4), respectively, where R is
                the total number of anchors across levels, i.e. sum(Hi x Wi x Ai)
            pred_logits, pred_anchor_deltas: both are list[Tensor]. Each element in the
                list corresponds to one level and has shape (N, Hi * Wi * Ai, K or 4),
                where K is the number of classes used in `pred_logits`.

        Returns:
            dict[str, Tensor]:
                mapping from a named loss to a scalar tensor
                storing the loss. Used during training only. The dict keys are:
                "loss_cls" and "loss_box_reg"
        """
        num_images = len(gt_labels)
        gt_labels = torch.stack(gt_labels)  # (N, R)
        anchors = type(anchors[0]).cat(anchors).tensor  # (R, 4)
        gt_anchor_deltas = [self.box2box_transform.get_deltas(anchors, k) for k in gt_boxes]
        gt_anchor_deltas = torch.stack(gt_anchor_deltas)  # (N, R, 4)

        valid_mask = gt_labels >= 0
        pos_mask = (gt_labels >= 0) & (gt_labels != self.num_classes)
        num_pos_anchors = pos_mask.sum().item()
        get_event_storage().put_scalar("num_pos_anchors", num_pos_anchors / num_images)
        self.loss_normalizer = self.loss_normalizer_momentum * self.loss_normalizer + (
            1 - self.loss_normalizer_momentum
        ) * max(num_pos_anchors, 1)

        # classification and regression loss
        gt_labels_target = F.one_hot(gt_labels[valid_mask], num_classes=self.num_classes + 1)[
            :, :-1
        ]  # no loss for the last (background) class
        loss_cls = sigmoid_focal_loss_jit(
            cat(pred_logits, dim=1)[valid_mask],
            gt_labels_target.to(pred_logits[0].dtype),
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        )
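        # The one-hot encoding above uses num_classes + 1 columns and drops the last one, so
        # foreground anchors get a one-hot target row while background anchors (label == num_classes)
        # get an all-zero row, which is the target format sigmoid focal loss expects.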

        if self.box_reg_loss_type == "smooth_l1":
            loss_box_reg = smooth_l1_loss(
                cat(pred_anchor_deltas, dim=1)[pos_mask],
                gt_anchor_deltas[pos_mask],
                beta=self.smooth_l1_loss_beta,
                reduction="sum",
            )
        elif self.box_reg_loss_type == "giou":
            pred_boxes = [
                self.box2box_transform.apply_deltas(k, anchors)
                for k in cat(pred_anchor_deltas, dim=1)
            ]
            loss_box_reg = giou_loss(
                torch.stack(pred_boxes)[pos_mask], torch.stack(gt_boxes)[pos_mask], reduction="sum"
            )
        else:
            raise ValueError(f"Invalid bbox reg loss type '{self.box_reg_loss_type}'")

        return {
            "loss_cls": loss_cls / self.loss_normalizer,
            "loss_box_reg": loss_box_reg / self.loss_normalizer,
        }

    @torch.no_grad()
    def label_anchors(self, anchors, gt_instances):
        """
        Args:
            anchors (list[Boxes]): A list of #feature level Boxes.
                The Boxes contains anchors of this image on the specific feature level.
            gt_instances (list[Instances]): a list of N `Instances`s. The i-th
                `Instances` contains the ground-truth per-instance annotations
                for the i-th input image.

        Returns:
            list[Tensor]:
                List of #img tensors. i-th element is a vector of labels whose length is
                the total number of anchors across all feature maps (sum(Hi * Wi * A)).
                Label values are in {-1, 0, ..., K}, where -1 means ignore and K means background.
            list[Tensor]:
                i-th element is a Rx4 tensor, where R is the total number of anchors across
                feature maps. The values are the matched gt boxes for each anchor.
                Values are undefined for those anchors not labeled as foreground.
        """
        anchors = Boxes.cat(anchors)  # Rx4

        gt_labels = []
        matched_gt_boxes = []
        for gt_per_image in gt_instances:
            match_quality_matrix = pairwise_iou(gt_per_image.gt_boxes, anchors)
            matched_idxs, anchor_labels = self.anchor_matcher(match_quality_matrix)
            del match_quality_matrix

            if len(gt_per_image) > 0:
                matched_gt_boxes_i = gt_per_image.gt_boxes.tensor[matched_idxs]

                gt_labels_i = gt_per_image.gt_classes[matched_idxs]
                # Anchors with label 0 are treated as background.
                gt_labels_i[anchor_labels == 0] = self.num_classes
                # Anchors with label -1 are ignored.
                gt_labels_i[anchor_labels == -1] = -1
            else:
                matched_gt_boxes_i = torch.zeros_like(anchors.tensor)
                gt_labels_i = torch.zeros_like(matched_idxs) + self.num_classes
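                # With no GT boxes in the image, every anchor is labeled background; the all-zero
                # "matched" boxes are placeholders and never contribute to the box-regression loss.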

            gt_labels.append(gt_labels_i)
            matched_gt_boxes.append(matched_gt_boxes_i)

        return gt_labels, matched_gt_boxes

    def inference(self, anchors, pred_logits, pred_anchor_deltas, image_sizes):
        """
        Arguments:
            anchors (list[Boxes]): A list of #feature level Boxes.
                The Boxes contain anchors of this image on the specific feature level.
            pred_logits, pred_anchor_deltas: list[Tensor], one per level. Each
                has shape (N, Hi * Wi * Ai, K or 4)
            image_sizes (List[torch.Size]): the input image sizes

        Returns:
            results (List[Instances]): a list of #images elements.
        """
        results = []
        for img_idx, image_size in enumerate(image_sizes):
            pred_logits_per_image = [x[img_idx] for x in pred_logits]
            deltas_per_image = [x[img_idx] for x in pred_anchor_deltas]
            results_per_image = self.inference_single_image(
                anchors, pred_logits_per_image, deltas_per_image, tuple(image_size)
            )
            results.append(results_per_image)
        return results

    def inference_single_image(self, anchors, box_cls, box_delta, image_size):
        """
        Single-image inference. Return bounding-box detection results by thresholding
        on scores and applying non-maximum suppression (NMS).

        Arguments:
            anchors (list[Boxes]): list of #feature levels. Each entry contains
                a Boxes object, which contains all the anchors in that feature level.
            box_cls (list[Tensor]): list of #feature levels. Each entry contains
                tensor of size (H x W x A, K)
            box_delta (list[Tensor]): Same shape as 'box_cls' except that K becomes 4.
            image_size (tuple(H, W)): a tuple of the image height and width.

        Returns:
            Same as `inference`, but for only one image.
        """
        boxes_all = []
        scores_all = []
        class_idxs_all = []

        # Iterate over every feature level
        for box_cls_i, box_reg_i, anchors_i in zip(box_cls, box_delta, anchors):
            # (HxWxAxK,)
            predicted_prob = box_cls_i.flatten().sigmoid_()

            # Apply the two filtering steps below to make NMS faster.
            # 1. Keep boxes with confidence score higher than threshold
            keep_idxs = predicted_prob > self.score_threshold
            predicted_prob = predicted_prob[keep_idxs]
            topk_idxs = torch.nonzero(keep_idxs, as_tuple=True)[0]

            # 2. Keep only the top-k scoring boxes
            num_topk = min(self.topk_candidates, topk_idxs.size(0))
            # torch.sort is actually faster than .topk (at least on GPUs)
            predicted_prob, idxs = predicted_prob.sort(descending=True)
            predicted_prob = predicted_prob[:num_topk]
            topk_idxs = topk_idxs[idxs[:num_topk]]

            anchor_idxs = topk_idxs // self.num_classes
            classes_idxs = topk_idxs % self.num_classes
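            # The flattened score index equals anchor_index * num_classes + class_index, so the
            # integer division recovers the anchor and the remainder recovers the class.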

            box_reg_i = box_reg_i[anchor_idxs]
            anchors_i = anchors_i[anchor_idxs]
            # predict boxes
            predicted_boxes = self.box2box_transform.apply_deltas(box_reg_i, anchors_i.tensor)

            boxes_all.append(predicted_boxes)
            scores_all.append(predicted_prob)
            class_idxs_all.append(classes_idxs)

        boxes_all, scores_all, class_idxs_all = [
            cat(x) for x in [boxes_all, scores_all, class_idxs_all]
        ]
        keep = batched_nms(boxes_all, scores_all, class_idxs_all, self.nms_threshold)
        keep = keep[: self.max_detections_per_image]

        result = Instances(image_size)
        result.pred_boxes = Boxes(boxes_all[keep])
        result.scores = scores_all[keep]
        result.pred_classes = class_idxs_all[keep]
        return result

    def preprocess_image(self, batched_inputs):
        """
        Normalize, pad and batch the input images.
        """
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.backbone.size_divisibility)
        return images


class RetinaNetHead(nn.Module):
    """
    The head used in RetinaNet for object classification and box regression.
    It has two subnets for the two tasks, with a common structure but separate parameters.
    """

    @configurable
    def __init__(
        self,
        *,
        input_shape: List[ShapeSpec],
        num_classes,
        num_anchors,
        conv_dims: List[int],
        norm="",
        prior_prob=0.01,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            input_shape (List[ShapeSpec]): input shape
            num_classes (int): number of classes. Used to label background proposals.
            num_anchors (int): number of generated anchors
            conv_dims (List[int]): dimensions for each convolution layer
            norm (str or callable):
                Normalization for conv layers except for the two output layers.
                See :func:`detectron2.layers.get_norm` for supported types.
            prior_prob (float): Prior weight for computing bias
        """
        super().__init__()

        if norm == "BN" or norm == "SyncBN":
            logger = logging.getLogger(__name__)
            logger.warn("Shared norm does not work well for BN, SyncBN, expect poor results")

        cls_subnet = []
        bbox_subnet = []
        for in_channels, out_channels in zip([input_shape[0].channels] + conv_dims, conv_dims):
            cls_subnet.append(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            )
            if norm:
                cls_subnet.append(get_norm(norm, out_channels))
            cls_subnet.append(nn.ReLU())
            bbox_subnet.append(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            )
            if norm:
                bbox_subnet.append(get_norm(norm, out_channels))
            bbox_subnet.append(nn.ReLU())

        self.cls_subnet = nn.Sequential(*cls_subnet)
        self.bbox_subnet = nn.Sequential(*bbox_subnet)
        self.cls_score = nn.Conv2d(
            conv_dims[-1], num_anchors * num_classes, kernel_size=3, stride=1, padding=1
        )
        self.bbox_pred = nn.Conv2d(
            conv_dims[-1], num_anchors * 4, kernel_size=3, stride=1, padding=1
        )

        # Initialization
        for modules in [self.cls_subnet, self.bbox_subnet, self.cls_score, self.bbox_pred]:
            for layer in modules.modules():
                if isinstance(layer, nn.Conv2d):
                    torch.nn.init.normal_(layer.weight, mean=0, std=0.01)
                    torch.nn.init.constant_(layer.bias, 0)

        # Use prior in model initialization to improve stability
        bias_value = -(math.log((1 - prior_prob) / prior_prob))
        torch.nn.init.constant_(self.cls_score.bias, bias_value)
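        # bias = -log((1 - prior_prob) / prior_prob) makes sigmoid(bias) == prior_prob, so every
        # anchor starts with ~prior_prob foreground probability. This keeps the initial focal loss
        # from the overwhelming number of background anchors small, as in the RetinaNet paper.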

    @classmethod
    def from_config(cls, cfg, input_shape: List[ShapeSpec]):
        num_anchors = build_anchor_generator(cfg, input_shape).num_cell_anchors
        # fmt: on
        assert (
            len(set(num_anchors)) == 1
        ), "Using different number of anchors between levels is not currently supported!"
        num_anchors = num_anchors[0]
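        # For reference, detectron2's default config uses NUM_CONVS = 4, PRIOR_PROB = 0.01 and
        # NORM = ""; the values actually used here come from whatever cfg is loaded.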

        return {
            "input_shape": input_shape,
            "num_classes": cfg.MODEL.RETINANET.NUM_CLASSES,
            "conv_dims": [input_shape[0].channels] * cfg.MODEL.RETINANET.NUM_CONVS,
            "prior_prob": cfg.MODEL.RETINANET.PRIOR_PROB,
            "norm": cfg.MODEL.RETINANET.NORM,
            "num_anchors": num_anchors,
        }

    def forward(self, features):
        """
        Arguments:
            features (list[Tensor]): FPN feature map tensors in high to low resolution.
                Each tensor in the list corresponds to a different feature level.

        Returns:
            logits (list[Tensor]): #lvl tensors, each has shape (N, AxK, Hi, Wi).
                The tensor predicts the classification probability
                at each spatial position for each of the A anchors and K object
                classes.
            bbox_reg (list[Tensor]): #lvl tensors, each has shape (N, Ax4, Hi, Wi).
                The tensor predicts 4-vector (dx,dy,dw,dh) box
                regression values for every anchor. These values are the
                relative offset between the anchor and the ground truth box.
        """
        logits = []
        bbox_reg = []
        for feature in features:
            logits.append(self.cls_score(self.cls_subnet(feature)))
            bbox_reg.append(self.bbox_pred(self.bbox_subnet(feature)))
        return logits, bbox_reg