mmsegmentation/projects/sam_inference_demo/sam/modeling/sam.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Borrowed from https://github.com/facebookresearch/segment-anything

from typing import Any, Dict, List, Tuple

import torch
from torch import nn
from torch.nn import functional as F

from mmseg.registry import MODELS

from .mask_decoder import MaskDecoder
from .prompt_encoder import PromptEncoder


@MODELS.register_module()
class SAM(nn.Module):
    mask_threshold: float = 0.0
    image_format: str = 'RGB'

    def __init__(
        self,
        image_encoder_cfg: dict,
        prompt_encoder_cfg: dict,
        mask_decoder_cfg: dict,
        pixel_mean: List[float] = [123.675, 116.28, 103.53],
        pixel_std: List[float] = [58.395, 57.12, 57.375],
    ) -> None:
"""SAM predicts object masks from an image and input prompts. Borrowed
from https://github.com/facebookresearch/segment-anything.
Arguments:
image_encoder (ViTSAM): The backbone used to encode the
image into image embeddings that allow for efficient mask
prediction.
prompt_encoder (PromptEncoder): Encodes various types of input
prompts.
mask_decoder (MaskDecoder): Predicts masks from the image embeddings
and encoded prompts.
pixel_mean (list(float)): Mean values for normalizing pixels in the
input image.
pixel_std (list(float)): Std values for normalizing pixels in the
input image.
"""
        super().__init__()
        self.image_encoder = MODELS.build(image_encoder_cfg)
        self.prompt_encoder: PromptEncoder = MODELS.build(prompt_encoder_cfg)
        self.mask_decoder: MaskDecoder = MODELS.build(mask_decoder_cfg)
        self.register_buffer('pixel_mean',
                             torch.Tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer('pixel_std',
                             torch.Tensor(pixel_std).view(-1, 1, 1), False)
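
    # Illustrative construction sketch (not part of the original file). The
    # three ``*_cfg`` arguments are ordinary MMEngine registry configs; the
    # ``type`` names below follow the classes referenced in this package, but
    # the remaining fields are placeholders rather than a verified config.
    #
    #     sam = MODELS.build(
    #         dict(
    #             type='SAM',
    #             image_encoder_cfg=dict(type='ViTSAM', ...),
    #             prompt_encoder_cfg=dict(type='PromptEncoder', ...),
    #             mask_decoder_cfg=dict(type='MaskDecoder', ...)))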

    @property
    def device(self) -> Any:
        return self.pixel_mean.device

    @torch.no_grad()
    def forward(
        self,
        batched_input: List[Dict[str, Any]],
        multimask_output: bool,
    ) -> List[Dict[str, torch.Tensor]]:
"""Predicts masks end-to-end from provided images and prompts. If
prompts are not known in advance, using SamPredictor is recommended
over calling the model directly.
Borrowed from https://github.com/facebookresearch/segment-anything
Arguments:
batched_input (list(dict)): A list over input images, each a
dictionary with the following keys. A prompt key can be
excluded if it is not present.
'image': The image as a torch tensor in 3xHxW format,
already transformed for input to the model.
'original_size': (tuple(int, int)) The original size of
the image before transformation, as (H, W).
'point_coords': (torch.Tensor) Batched point prompts for
this image, with shape BxNx2. Already transformed to the
input frame of the model.
'point_labels': (torch.Tensor) Batched labels for point prompts,
with shape BxN.
'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
Already transformed to the input frame of the model.
'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
in the form Bx1xHxW.
multimask_output (bool): Whether the model should predict multiple
disambiguating masks, or return a single mask.
Returns:
(list(dict)): A list over input images, where each element is
as dictionary with the following keys.
'masks': (torch.Tensor) Batched binary mask predictions,
with shape BxCxHxW, where B is the number of input prompts,
C is determiend by multimask_output, and (H, W) is the
original size of the image.
'iou_predictions': (torch.Tensor) The model's predictions
of mask quality, in shape BxC.
'low_res_logits': (torch.Tensor) Low resolution logits with
shape BxCxHxW, where H=W=256. Can be passed as mask input
to subsequent iterations of prediction.
"""
        input_images = torch.stack(
            [self.preprocess(x['image']) for x in batched_input], dim=0)
        image_embeddings = self.image_encoder(input_images)

        outputs = []
        for image_record, curr_embedding in zip(batched_input,
                                                image_embeddings):
            if 'point_coords' in image_record:
                points = (image_record['point_coords'],
                          image_record['point_labels'])
            else:
                points = None
            sparse_embeddings, dense_embeddings = self.prompt_encoder(
                points=points,
                boxes=image_record.get('boxes', None),
                masks=image_record.get('mask_inputs', None),
            )
            low_res_masks, iou_predictions = self.mask_decoder(
                image_embeddings=curr_embedding.unsqueeze(0),
                image_pe=self.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )
            masks = self.postprocess_masks(
                low_res_masks,
                input_size=image_record['image'].shape[-2:],
                original_size=image_record['original_size'],
            )
            masks = masks > self.mask_threshold
            outputs.append({
                'masks': masks,
                'iou_predictions': iou_predictions,
                'low_res_logits': low_res_masks,
            })
        return outputs
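
    # Illustrative call pattern for ``forward`` (a sketch, not part of the
    # original file). The prompt values are placeholders, and the image
    # tensor is assumed to already be resized to the encoder's input frame,
    # as the docstring above requires.
    #
    #     batched_input = [
    #         dict(
    #             image=image_tensor,                           # 3xHxW, model input frame
    #             original_size=(1080, 1920),                   # (H, W) before resizing
    #             point_coords=torch.tensor([[[512., 512.]]]),  # BxNx2
    #             point_labels=torch.tensor([[1]]),             # BxN, 1 = foreground
    #         ),
    #     ]
    #     outputs = sam(batched_input, multimask_output=True)
    #     masks = outputs[0]['masks']  # boolean masks at the original size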

    def postprocess_masks(
        self,
        masks: torch.Tensor,
        input_size: Tuple[int, ...],
        original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """Remove padding and upscale masks to the original image size.

        Borrowed from https://github.com/facebookresearch/segment-anything

        Arguments:
            masks (torch.Tensor): Batched masks from the mask_decoder,
                in BxCxHxW format.
            input_size (tuple(int, int)): The size of the image input to the
                model, in (H, W) format. Used to remove padding.
            original_size (tuple(int, int)): The original size of the image
                before resizing for input to the model, in (H, W) format.

        Returns:
            (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
                is given by original_size.
        """
        masks = F.interpolate(
            masks,
            self.image_encoder.img_size,
            mode='bilinear',
            align_corners=False,
        )
        masks = masks[..., :input_size[0], :input_size[1]]
        masks = F.interpolate(
            masks, original_size, mode='bilinear', align_corners=False)
        return masks
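
    # Illustrative shape walk-through for the two-stage resize above (a
    # sketch; the 1024x1024 encoder input size is an assumption based on the
    # released SAM checkpoints, and the example sizes are made up):
    #
    #     low_res_masks: 1x3x256x256  -> upsample to encoder size: 1x3x1024x1024
    #     crop padding with input_size=(768, 1024)  -> 1x3x768x1024
    #     resize to original_size=(1080, 1440)      -> 1x3x1080x1440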

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize pixel values and pad to a square input."""
        # Normalize colors
        x = (x - self.pixel_mean) / self.pixel_std

        # Pad
        h, w = x.shape[-2:]
        img_size = max(self.image_encoder.img_size)
        padh = img_size - h
        padw = img_size - w
        x = F.pad(x, (0, padw, 0, padh))
        return x
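
# A minimal standalone sketch of the normalize-and-pad step in `preprocess`
# (not part of the original file). The 1024x1024 input size is an assumption
# matching the released SAM checkpoints; the input tensor below is synthetic.
#
#     import torch
#     import torch.nn.functional as F
#
#     pixel_mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
#     pixel_std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
#     img_size = 1024
#
#     x = torch.randint(0, 256, (3, 768, 1024)).float()  # long side already 1024
#     x = (x - pixel_mean) / pixel_std
#     h, w = x.shape[-2:]
#     x = F.pad(x, (0, img_size - w, 0, img_size - h))   # pad right/bottom only
#     assert x.shape[-2:] == (1024, 1024)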