# Mirror of https://github.com/RE-OWOD/RE-OWOD (Python, 219 lines, 8.2 KiB)
# -*- coding: utf-8 -*-
|
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
|
|
|
import torch
|
|
from torch import nn
|
|
|
|
from detectron2.structures import ImageList
|
|
|
|
from ..backbone import build_backbone
|
|
from ..postprocessing import detector_postprocess, sem_seg_postprocess
|
|
from ..proposal_generator import build_proposal_generator
|
|
from ..roi_heads import build_roi_heads
|
|
from .build import META_ARCH_REGISTRY
|
|
from .semantic_seg import build_sem_seg_head
|
|
|
|
__all__ = ["PanopticFPN"]
|
|
|
|
|
|
@META_ARCH_REGISTRY.register()
class PanopticFPN(nn.Module):
    """
    Implement the paper :paper:`PanopticFPN`.

    One backbone is shared by an instance branch (proposal generator + ROI heads)
    and a semantic-segmentation head. At inference time the two outputs can
    optionally be merged into a panoptic segmentation with
    :func:`combine_semantic_and_instance_outputs`.
    """

    def __init__(self, cfg):
        super().__init__()

        # Multiplier applied to the ROI-head losses only (see ``forward``);
        # semantic-segmentation and proposal losses are added unscaled.
        self.instance_loss_weight = cfg.MODEL.PANOPTIC_FPN.INSTANCE_LOSS_WEIGHT

        # options when combining instance & semantic outputs
        self.combine_on = cfg.MODEL.PANOPTIC_FPN.COMBINE.ENABLED
        self.combine_overlap_threshold = cfg.MODEL.PANOPTIC_FPN.COMBINE.OVERLAP_THRESH
        self.combine_stuff_area_limit = cfg.MODEL.PANOPTIC_FPN.COMBINE.STUFF_AREA_LIMIT
        self.combine_instances_confidence_threshold = (
            cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH
        )

        # Every head is built against (and later fed) the same backbone features.
        self.backbone = build_backbone(cfg)
        self.proposal_generator = build_proposal_generator(cfg, self.backbone.output_shape())
        self.roi_heads = build_roi_heads(cfg, self.backbone.output_shape())
        self.sem_seg_head = build_sem_seg_head(cfg, self.backbone.output_shape())

        # Per-channel normalization constants, viewed as (C, 1, 1) so they broadcast
        # over (C, H, W) images. Stored as buffers so they move with the module.
        self.register_buffer("pixel_mean", torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(-1, 1, 1))
        self.register_buffer("pixel_std", torch.Tensor(cfg.MODEL.PIXEL_STD).view(-1, 1, 1))

    @property
    def device(self):
        """Device of the model, taken from the ``pixel_mean`` buffer."""
        return self.pixel_mean.device

    def forward(self, batched_inputs):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
                Each item in the list contains the inputs for one image.

                For now, each item in the list is a dict that contains:

                * "image": Tensor, image in (C, H, W) format.
                * "instances": Instances
                * "sem_seg": semantic segmentation ground truth.
                * "proposals": precomputed proposals; only used (and then required)
                  when the model has no proposal generator.
                * Other information that's included in the original dicts, such as:
                  "height", "width" (int): the output resolution of the model, used
                  in inference. See :func:`detector_postprocess` and
                  :func:`sem_seg_postprocess` for details.

        Returns:
            In training mode, dict[str, Tensor]: the semantic-segmentation losses,
            the ROI-head losses multiplied by ``instance_loss_weight``, and the
            proposal losses (empty when precomputed proposals are used).

            In inference mode, list[dict]:
                each dict is the results for one image. The dict contains the following keys:

                * "instances": see :meth:`GeneralizedRCNN.forward` for its format.
                * "sem_seg": see :meth:`SemanticSegmentor.forward` for its format.
                * "panoptic_seg": available when `PANOPTIC_FPN.COMBINE.ENABLED`.
                  See the return value of
                  :func:`combine_semantic_and_instance_outputs` for its format.

        Raises:
            ValueError: if the model has no proposal generator and the inputs do not
                provide precomputed "proposals".
        """
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.backbone.size_divisibility)
        features = self.backbone(images.tensor)

        if "sem_seg" in batched_inputs[0]:
            gt_sem_seg = [x["sem_seg"].to(self.device) for x in batched_inputs]
            # Batch-pad with the head's ignore value (presumably so padded pixels
            # are excluded from the semantic loss).
            gt_sem_seg = ImageList.from_tensors(
                gt_sem_seg, self.backbone.size_divisibility, self.sem_seg_head.ignore_value
            ).tensor
        else:
            gt_sem_seg = None
        sem_seg_results, sem_seg_losses = self.sem_seg_head(features, gt_sem_seg)

        if "instances" in batched_inputs[0]:
            gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
        else:
            gt_instances = None

        # A learned proposal generator takes precedence over precomputed proposals.
        # Without either, fail loudly instead of with an UnboundLocalError below.
        if self.proposal_generator is not None:
            proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)
        elif "proposals" in batched_inputs[0]:
            proposals = [x["proposals"].to(self.device) for x in batched_inputs]
            proposal_losses = {}
        else:
            raise ValueError(
                "PanopticFPN has no proposal generator, so every input dict must "
                "provide precomputed 'proposals'."
            )
        detector_results, detector_losses = self.roi_heads(
            images, features, proposals, gt_instances
        )

        if self.training:
            losses = {}
            losses.update(sem_seg_losses)
            # Only the ROI-head losses are re-weighted.
            losses.update({k: v * self.instance_loss_weight for k, v in detector_losses.items()})
            losses.update(proposal_losses)
            return losses

        processed_results = []
        for sem_seg_result, detector_result, input_per_image, image_size in zip(
            sem_seg_results, detector_results, batched_inputs, images.image_sizes
        ):
            # Fall back to the (unpadded) network input size when no output size is given.
            height = input_per_image.get("height", image_size[0])
            width = input_per_image.get("width", image_size[1])
            sem_seg_r = sem_seg_postprocess(sem_seg_result, image_size, height, width)
            detector_r = detector_postprocess(detector_result, height, width)

            processed_results.append({"sem_seg": sem_seg_r, "instances": detector_r})

            if self.combine_on:
                # argmax over dim 0 presumably turns per-class scores (C, H, W)
                # into an (H, W) label map -- TODO confirm sem_seg_postprocess layout.
                panoptic_r = combine_semantic_and_instance_outputs(
                    detector_r,
                    sem_seg_r.argmax(dim=0),
                    self.combine_overlap_threshold,
                    self.combine_stuff_area_limit,
                    self.combine_instances_confidence_threshold,
                )
                processed_results[-1]["panoptic_seg"] = panoptic_r
        return processed_results
|
|
|
|
|
|
def combine_semantic_and_instance_outputs(
    instance_results,
    semantic_results,
    overlap_threshold,
    stuff_area_limit,
    instances_confidence_threshold,
):
    """
    Merge instance ("thing") and semantic ("stuff") predictions into one panoptic map.

    Follows the simple combining logic of
    "combine_semantic_and_instance_predictions.py" in panopticapi:

    1. Instances are visited from highest to lowest score; the walk stops at the
       first one scoring below ``instances_confidence_threshold``. An instance is
       skipped if its mask is empty or if more than ``overlap_threshold`` of its
       mask is already taken; otherwise it claims the still-free part of its mask.
    2. Each semantic label except 0 then claims its still-free pixels, provided at
       least ``stuff_area_limit`` of them remain.

    Args:
        instance_results: output of :func:`detector_postprocess`; its ``scores``,
            ``pred_masks`` and ``pred_classes`` fields are read.
        semantic_results: an (H, W) tensor, each is the contiguous semantic
            category id. Label 0 is the special "thing" class and never becomes a
            stuff segment.
        overlap_threshold: largest tolerated fraction of an instance mask that may
            overlap segments placed earlier.
        stuff_area_limit: minimum number of free pixels a stuff segment needs.
        instances_confidence_threshold: minimum score for an instance to be used.

    Returns:
        panoptic_seg (Tensor): int32 tensor shaped like ``semantic_results``; 0 marks
            unassigned pixels, segment ids start at 1.
        segments_info (list[dict]): one dict per segment, in id order. Thing
            segments have keys "id", "isthing" (True), "score", "category_id",
            "instance_id"; stuff segments have "id", "isthing" (False),
            "category_id", "area".
    """
    claimed = torch.zeros_like(semantic_results, dtype=torch.int32)
    info = []
    segment_id = 0

    scores = instance_results.scores
    masks = instance_results.pred_masks.to(dtype=torch.bool, device=claimed.device)

    # Pass 1: things, best-scoring first.
    for idx in torch.argsort(-scores):
        confidence = scores[idx].item()
        if confidence < instances_confidence_threshold:
            # Descending order: no later instance can clear the threshold either.
            break

        candidate = masks[idx]  # H,W
        area = candidate.sum().item()
        if area == 0:
            continue

        overlap = ((candidate > 0) & (claimed > 0)).sum().item()
        if overlap / area > overlap_threshold:
            continue
        if overlap > 0:
            # Keep only the pixels no earlier segment owns.
            candidate = candidate & (claimed == 0)

        segment_id += 1
        claimed[candidate] = segment_id
        info.append(
            {
                "id": segment_id,
                "isthing": True,
                "score": confidence,
                "category_id": instance_results.pred_classes[idx].item(),
                "instance_id": idx.item(),
            }
        )

    # Pass 2: stuff, filling whatever is still unclaimed.
    for label in torch.unique(semantic_results).cpu().tolist():
        if label == 0:  # 0 is a special "thing" class
            continue
        region = (semantic_results == label) & (claimed == 0)
        area = region.sum().item()
        if area < stuff_area_limit:
            continue

        segment_id += 1
        claimed[region] = segment_id
        info.append(
            {
                "id": segment_id,
                "isthing": False,
                "category_id": label,
                "area": area,
            }
        )

    return claimed, info
|